Repository: alibaba/zvec Branch: main Commit: b49833bf56a0 Files: 988 Total size: 10.5 MB Directory structure: gitextract_w52102_2/ ├── .clang-format ├── .git/ │ ├── HEAD │ ├── config │ ├── description │ ├── hooks/ │ │ ├── applypatch-msg.sample │ │ ├── commit-msg.sample │ │ ├── fsmonitor-watchman.sample │ │ ├── post-update.sample │ │ ├── pre-applypatch.sample │ │ ├── pre-commit.sample │ │ ├── pre-merge-commit.sample │ │ ├── pre-push.sample │ │ ├── pre-rebase.sample │ │ ├── pre-receive.sample │ │ ├── prepare-commit-msg.sample │ │ ├── push-to-checkout.sample │ │ ├── sendemail-validate.sample │ │ └── update.sample │ ├── index │ ├── info/ │ │ └── exclude │ ├── logs/ │ │ ├── HEAD │ │ └── refs/ │ │ ├── heads/ │ │ │ └── main │ │ └── remotes/ │ │ └── origin/ │ │ └── HEAD │ ├── objects/ │ │ └── pack/ │ │ ├── pack-2b5e15ebe928a592991dc24c7ae7e8dc9e3500dc.idx │ │ ├── pack-2b5e15ebe928a592991dc24c7ae7e8dc9e3500dc.pack │ │ ├── pack-2b5e15ebe928a592991dc24c7ae7e8dc9e3500dc.promisor │ │ └── pack-2b5e15ebe928a592991dc24c7ae7e8dc9e3500dc.rev │ ├── packed-refs │ ├── refs/ │ │ ├── heads/ │ │ │ └── main │ │ └── remotes/ │ │ └── origin/ │ │ └── HEAD │ └── shallow ├── .github/ │ ├── ISSUE_TEMPLATE/ │ │ ├── benchmark.yml │ │ ├── bug_report.yml │ │ ├── config.yml │ │ ├── enhancement.yml │ │ ├── feature_request.yml │ │ ├── integration.yml │ │ └── profiling.yml │ ├── codecov.yml │ ├── dependabot.yml │ └── workflows/ │ ├── 01-ci-pipeline.yml │ ├── 02-lint-check.yml │ ├── 03-macos-linux-build.yml │ ├── 04-android-build.yml │ ├── _build_wheel_job.yml │ ├── build_test_wheel.yml │ ├── build_wheel.yml │ ├── continuous_bench.yml │ ├── docker/ │ │ └── Dockerfile.linux_x64_glibc228 │ ├── nightly_coverage.yml │ └── scripts/ │ └── run_vdb.sh ├── .gitignore ├── .gitmodules ├── CMakeLists.txt ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── cmake/ │ ├── bazel.cmake │ ├── option.cmake │ └── utils.cmake ├── examples/ │ └── c++/ │ ├── CMakeLists.txt │ ├── ailego/ │ │ └── main.cc │ 
├── core/ │ │ └── main.cc │ └── db/ │ └── main.cc ├── pyproject.toml ├── python/ │ ├── tests/ │ │ ├── detail/ │ │ │ ├── distance_helper.py │ │ │ ├── doc_helper.py │ │ │ ├── fixture_helper.py │ │ │ ├── params_helper.py │ │ │ ├── support_helper.py │ │ │ ├── test_collection_concurrency.py │ │ │ ├── test_collection_create_and_open.py │ │ │ ├── test_collection_ddl.py │ │ │ ├── test_collection_dml.py │ │ │ ├── test_collection_dql.py │ │ │ ├── test_collection_exception.py │ │ │ ├── test_collection_open.py │ │ │ ├── test_collection_recall.py │ │ │ └── test_db_config.py │ │ ├── test_collection.py │ │ ├── test_collection_hnsw_rabitq.py │ │ ├── test_convert.py │ │ ├── test_doc.py │ │ ├── test_embedding.py │ │ ├── test_params.py │ │ ├── test_query_executor.py │ │ ├── test_reranker.py │ │ ├── test_schema.py │ │ ├── test_typing.py │ │ └── test_util.py │ └── zvec/ │ ├── __init__.py │ ├── __init__.pyi │ ├── common/ │ │ ├── __init__.py │ │ └── constants.py │ ├── executor/ │ │ ├── __init__.py │ │ └── query_executor.py │ ├── extension/ │ │ ├── __init__.py │ │ ├── bm25_embedding_function.py │ │ ├── embedding_function.py │ │ ├── http_embedding_function.py │ │ ├── jina_embedding_function.py │ │ ├── jina_function.py │ │ ├── multi_vector_reranker.py │ │ ├── openai_embedding_function.py │ │ ├── openai_function.py │ │ ├── qwen_embedding_function.py │ │ ├── qwen_function.py │ │ ├── qwen_rerank_function.py │ │ ├── rerank_function.py │ │ ├── sentence_transformer_embedding_function.py │ │ ├── sentence_transformer_function.py │ │ └── sentence_transformer_rerank_function.py │ ├── model/ │ │ ├── __init__.py │ │ ├── collection.py │ │ ├── convert.py │ │ ├── doc.py │ │ ├── param/ │ │ │ ├── __init__.py │ │ │ ├── __init__.pyi │ │ │ └── vector_query.py │ │ └── schema/ │ │ ├── __init__.py │ │ ├── __init__.pyi │ │ ├── collection_schema.py │ │ └── field_schema.py │ ├── py.typed │ ├── tool/ │ │ ├── __init__.py │ │ └── util.py │ ├── typing/ │ │ ├── __init__.py │ │ ├── __init__.pyi │ │ └── enum.py │ └── 
zvec.py ├── scripts/ │ ├── README.md │ ├── build_android.sh │ └── gcov.sh ├── src/ │ ├── CMakeLists.txt │ ├── ailego/ │ │ ├── CMakeLists.txt │ │ ├── algorithm/ │ │ │ ├── binary_quantizer.cc │ │ │ ├── binary_quantizer.h │ │ │ ├── integer_quantizer.cc │ │ │ ├── integer_quantizer.h │ │ │ ├── kmeans.h │ │ │ └── lloyd_cluster.h │ │ ├── buffer/ │ │ │ ├── buffer_manager.cc │ │ │ └── buffer_pool.cc │ │ ├── container/ │ │ │ ├── bitmap.cc │ │ │ ├── bitmap.h │ │ │ ├── bloom_filter.h │ │ │ ├── params.cc │ │ │ ├── reservoir.h │ │ │ └── vector_array.h │ │ ├── encoding/ │ │ │ └── json/ │ │ │ └── mod_json.c │ │ ├── hash/ │ │ │ └── crc32c.cc │ │ ├── internal/ │ │ │ ├── cpu_features.cc │ │ │ └── cpu_features.h │ │ ├── io/ │ │ │ ├── file.cc │ │ │ ├── file_lock.cc │ │ │ ├── file_lock.h │ │ │ └── file_writer.h │ │ ├── logger/ │ │ │ └── logger.cc │ │ ├── math/ │ │ │ ├── cosine_distance_matrix.h │ │ │ ├── distance.h │ │ │ ├── distance_matrix.h │ │ │ ├── distance_matrix_accum_fp16.i │ │ │ ├── distance_matrix_accum_fp32.i │ │ │ ├── distance_matrix_accum_int4.i │ │ │ ├── distance_matrix_accum_int8.i │ │ │ ├── distance_matrix_euclidean_utility.i │ │ │ ├── distance_matrix_fp16.i │ │ │ ├── distance_matrix_fp32.i │ │ │ ├── distance_matrix_inner_product_utility.i │ │ │ ├── distance_matrix_int32.i │ │ │ ├── distance_matrix_int64.i │ │ │ ├── distance_matrix_mips_utility.i │ │ │ ├── distance_matrix_popcnt.i │ │ │ ├── distance_utility.h │ │ │ ├── euclidean_distance_matrix.h │ │ │ ├── euclidean_distance_matrix_fp16_avx.cc │ │ │ ├── euclidean_distance_matrix_fp16_avx512.cc │ │ │ ├── euclidean_distance_matrix_fp16_avx512fp16.cc │ │ │ ├── euclidean_distance_matrix_fp16_dispatch.cc │ │ │ ├── euclidean_distance_matrix_fp16_neon.cc │ │ │ ├── euclidean_distance_matrix_fp32_avx.cc │ │ │ ├── euclidean_distance_matrix_fp32_avx512.cc │ │ │ ├── euclidean_distance_matrix_fp32_dispatch.cc │ │ │ ├── euclidean_distance_matrix_fp32_neon.cc │ │ │ ├── euclidean_distance_matrix_fp32_sse.cc │ │ │ ├── 
euclidean_distance_matrix_int4_avx2.cc │ │ │ ├── euclidean_distance_matrix_int4_dispatch.cc │ │ │ ├── euclidean_distance_matrix_int4_sse.cc │ │ │ ├── euclidean_distance_matrix_int8_avx2.cc │ │ │ ├── euclidean_distance_matrix_int8_dispatch.cc │ │ │ ├── euclidean_distance_matrix_int8_sse.cc │ │ │ ├── euclidean_distance_matrix_scalar.cc │ │ │ ├── hamming_distance_matrix.cc │ │ │ ├── hamming_distance_matrix.h │ │ │ ├── inner_product_matrix.h │ │ │ ├── inner_product_matrix_fp16_avx.cc │ │ │ ├── inner_product_matrix_fp16_avx512.cc │ │ │ ├── inner_product_matrix_fp16_avx512fp16.cc │ │ │ ├── inner_product_matrix_fp16_dispatch.cc │ │ │ ├── inner_product_matrix_fp16_neon.cc │ │ │ ├── inner_product_matrix_fp32_avx.cc │ │ │ ├── inner_product_matrix_fp32_avx512.cc │ │ │ ├── inner_product_matrix_fp32_dispatch.cc │ │ │ ├── inner_product_matrix_fp32_neon.cc │ │ │ ├── inner_product_matrix_fp32_sse.cc │ │ │ ├── inner_product_matrix_int4_avx2.cc │ │ │ ├── inner_product_matrix_int4_dispatch.cc │ │ │ ├── inner_product_matrix_int4_sse.cc │ │ │ ├── inner_product_matrix_int8_avx2.cc │ │ │ ├── inner_product_matrix_int8_dispatch.cc │ │ │ ├── inner_product_matrix_int8_sse.cc │ │ │ ├── inner_product_matrix_scalar.cc │ │ │ ├── matrix_define.i │ │ │ ├── matrix_utility.i │ │ │ ├── mips_euclidean_distance_matrix.h │ │ │ ├── mips_euclidean_distance_matrix_fp16_avx.cc │ │ │ ├── mips_euclidean_distance_matrix_fp16_avx512.cc │ │ │ ├── mips_euclidean_distance_matrix_fp16_dispatch.cc │ │ │ ├── mips_euclidean_distance_matrix_fp16_neon.cc │ │ │ ├── mips_euclidean_distance_matrix_fp32_avx.cc │ │ │ ├── mips_euclidean_distance_matrix_fp32_avx512.cc │ │ │ ├── mips_euclidean_distance_matrix_fp32_dispatch.cc │ │ │ ├── mips_euclidean_distance_matrix_fp32_neon.cc │ │ │ ├── mips_euclidean_distance_matrix_fp32_sse.cc │ │ │ ├── mips_euclidean_distance_matrix_int4_avx2.cc │ │ │ ├── mips_euclidean_distance_matrix_int4_dispatch.cc │ │ │ ├── mips_euclidean_distance_matrix_int4_sse.cc │ │ │ ├── 
mips_euclidean_distance_matrix_int8_avx2.cc │ │ │ ├── mips_euclidean_distance_matrix_int8_dispatch.cc │ │ │ ├── mips_euclidean_distance_matrix_int8_sse.cc │ │ │ ├── mips_euclidean_distance_matrix_scalar.cc │ │ │ ├── norm1_matrix.h │ │ │ ├── norm1_matrix_fp16.cc │ │ │ ├── norm1_matrix_fp32.cc │ │ │ ├── norm2_matrix.h │ │ │ ├── norm2_matrix_fp16.cc │ │ │ ├── norm2_matrix_fp32.cc │ │ │ ├── norm_matrix.h │ │ │ ├── norm_matrix_fp16.i │ │ │ ├── norm_matrix_fp32.i │ │ │ ├── normalizer.cc │ │ │ └── normalizer.h │ │ ├── math_batch/ │ │ │ ├── cosine_distance_batch.h │ │ │ ├── distance_batch.h │ │ │ ├── inner_product_distance_batch.h │ │ │ ├── inner_product_distance_batch_dispatch.cc │ │ │ ├── inner_product_distance_batch_impl_fp16_avx2.cc │ │ │ ├── inner_product_distance_batch_impl_fp16_avx512.cc │ │ │ ├── inner_product_distance_batch_impl_fp16_avx512fp16.cc │ │ │ ├── inner_product_distance_batch_impl_fp32_avx2.cc │ │ │ ├── inner_product_distance_batch_impl_int8_avx2.cc │ │ │ └── inner_product_distance_batch_impl_int8_avx512fp16.cc │ │ ├── parallel/ │ │ │ ├── lock.h │ │ │ ├── multi_thread_list.h │ │ │ ├── semaphore.h │ │ │ └── thread_pool.cc │ │ ├── pattern/ │ │ │ ├── defer.h │ │ │ └── scope_guard.h │ │ ├── utility/ │ │ │ ├── bit_string_helper.h │ │ │ ├── bitset_helper.cc │ │ │ ├── bitset_helper.h │ │ │ ├── concurrency_helper.cc │ │ │ ├── concurrency_helper.h │ │ │ ├── dl_helper.cc │ │ │ ├── dl_helper.h │ │ │ ├── file_helper.cc │ │ │ ├── float_helper.cc │ │ │ ├── math_helper.h │ │ │ ├── matrix_helper.h │ │ │ ├── memory_helper.cc │ │ │ ├── memory_helper.h │ │ │ ├── string_helper.cc │ │ │ └── time_helper.cc │ │ ├── version.cc │ │ ├── version.h │ │ └── version.i │ ├── binding/ │ │ ├── CMakeLists.txt │ │ └── python/ │ │ ├── CMakeLists.txt │ │ ├── binding.cc │ │ ├── exports.mac │ │ ├── include/ │ │ │ ├── python_collection.h │ │ │ ├── python_config.h │ │ │ ├── python_doc.h │ │ │ ├── python_param.h │ │ │ ├── python_schema.h │ │ │ └── python_type.h │ │ ├── model/ │ │ │ ├── common/ │ 
│ │ │ └── python_config.cc │ │ │ ├── param/ │ │ │ │ └── python_param.cc │ │ │ ├── python_collection.cc │ │ │ ├── python_doc.cc │ │ │ └── schema/ │ │ │ └── python_schema.cc │ │ └── typing/ │ │ └── python_type.cc │ ├── core/ │ │ ├── CMakeLists.txt │ │ ├── algorithm/ │ │ │ ├── CMakeLists.txt │ │ │ ├── cluster/ │ │ │ │ ├── CMakeLists.txt │ │ │ │ ├── cluster_params.h │ │ │ │ ├── kmeans_cluster.cc │ │ │ │ ├── linear_seeker.cc │ │ │ │ ├── linear_seeker.h │ │ │ │ ├── opt_kmeans_cluster.cc │ │ │ │ ├── seeker.h │ │ │ │ ├── stratified_cluster.cc │ │ │ │ ├── stratified_cluster_trainer.cc │ │ │ │ ├── stratified_cluster_trainer.h │ │ │ │ └── vector_mean.h │ │ │ ├── flat/ │ │ │ │ ├── CMakeLists.txt │ │ │ │ ├── flat_builder.cc │ │ │ │ ├── flat_builder.h │ │ │ │ ├── flat_distance_matrix.h │ │ │ │ ├── flat_index_format.h │ │ │ │ ├── flat_searcher.cc │ │ │ │ ├── flat_searcher.h │ │ │ │ ├── flat_searcher_context.h │ │ │ │ ├── flat_searcher_provider.h │ │ │ │ ├── flat_streamer.cc │ │ │ │ ├── flat_streamer.h │ │ │ │ ├── flat_streamer_context.h │ │ │ │ ├── flat_streamer_dumper.h │ │ │ │ ├── flat_streamer_entity.cc │ │ │ │ ├── flat_streamer_entity.h │ │ │ │ ├── flat_streamer_provider.h │ │ │ │ └── flat_utility.h │ │ │ ├── flat_sparse/ │ │ │ │ ├── CMakeLists.txt │ │ │ │ ├── flat_sparse_builder.cc │ │ │ │ ├── flat_sparse_builder.h │ │ │ │ ├── flat_sparse_context.cc │ │ │ │ ├── flat_sparse_context.h │ │ │ │ ├── flat_sparse_entity.h │ │ │ │ ├── flat_sparse_index_format.h │ │ │ │ ├── flat_sparse_provider.h │ │ │ │ ├── flat_sparse_search.h │ │ │ │ ├── flat_sparse_searcher.cc │ │ │ │ ├── flat_sparse_searcher.h │ │ │ │ ├── flat_sparse_searcher_entity.cc │ │ │ │ ├── flat_sparse_searcher_entity.h │ │ │ │ ├── flat_sparse_streamer.cc │ │ │ │ ├── flat_sparse_streamer.h │ │ │ │ ├── flat_sparse_streamer_entity.cc │ │ │ │ ├── flat_sparse_streamer_entity.h │ │ │ │ └── flat_sparse_utility.h │ │ │ ├── hnsw/ │ │ │ │ ├── CMakeLists.txt │ │ │ │ ├── hnsw_algorithm.cc │ │ │ │ ├── hnsw_algorithm.h │ │ │ │ ├── 
hnsw_builder.cc │ │ │ │ ├── hnsw_builder.h │ │ │ │ ├── hnsw_builder_entity.cc │ │ │ │ ├── hnsw_builder_entity.h │ │ │ │ ├── hnsw_chunk.cc │ │ │ │ ├── hnsw_chunk.h │ │ │ │ ├── hnsw_context.cc │ │ │ │ ├── hnsw_context.h │ │ │ │ ├── hnsw_dist_calculator.h │ │ │ │ ├── hnsw_entity.cc │ │ │ │ ├── hnsw_entity.h │ │ │ │ ├── hnsw_index_hash.h │ │ │ │ ├── hnsw_index_provider.h │ │ │ │ ├── hnsw_params.h │ │ │ │ ├── hnsw_searcher.cc │ │ │ │ ├── hnsw_searcher.h │ │ │ │ ├── hnsw_searcher_entity.cc │ │ │ │ ├── hnsw_searcher_entity.h │ │ │ │ ├── hnsw_streamer.cc │ │ │ │ ├── hnsw_streamer.h │ │ │ │ ├── hnsw_streamer_entity.cc │ │ │ │ └── hnsw_streamer_entity.h │ │ │ ├── hnsw_rabitq/ │ │ │ │ ├── CMakeLists.txt │ │ │ │ ├── hnsw_rabitq_algorithm.cc │ │ │ │ ├── hnsw_rabitq_algorithm.h │ │ │ │ ├── hnsw_rabitq_builder.cc │ │ │ │ ├── hnsw_rabitq_builder.h │ │ │ │ ├── hnsw_rabitq_builder_entity.cc │ │ │ │ ├── hnsw_rabitq_builder_entity.h │ │ │ │ ├── hnsw_rabitq_chunk.cc │ │ │ │ ├── hnsw_rabitq_chunk.h │ │ │ │ ├── hnsw_rabitq_context.cc │ │ │ │ ├── hnsw_rabitq_context.h │ │ │ │ ├── hnsw_rabitq_dist_calculator.cc │ │ │ │ ├── hnsw_rabitq_dist_calculator.h │ │ │ │ ├── hnsw_rabitq_entity.cc │ │ │ │ ├── hnsw_rabitq_entity.h │ │ │ │ ├── hnsw_rabitq_index_hash.h │ │ │ │ ├── hnsw_rabitq_index_provider.h │ │ │ │ ├── hnsw_rabitq_params.h │ │ │ │ ├── hnsw_rabitq_query_algorithm.cc │ │ │ │ ├── hnsw_rabitq_query_algorithm.h │ │ │ │ ├── hnsw_rabitq_query_entity.h │ │ │ │ ├── hnsw_rabitq_register.cc │ │ │ │ ├── hnsw_rabitq_searcher.cc │ │ │ │ ├── hnsw_rabitq_searcher.h │ │ │ │ ├── hnsw_rabitq_searcher_entity.cc │ │ │ │ ├── hnsw_rabitq_searcher_entity.h │ │ │ │ ├── hnsw_rabitq_streamer.cc │ │ │ │ ├── hnsw_rabitq_streamer.h │ │ │ │ ├── hnsw_rabitq_streamer_entity.cc │ │ │ │ ├── hnsw_rabitq_streamer_entity.h │ │ │ │ ├── rabitq_converter.cc │ │ │ │ ├── rabitq_converter.h │ │ │ │ ├── rabitq_params.h │ │ │ │ ├── rabitq_reformer.cc │ │ │ │ ├── rabitq_reformer.h │ │ │ │ ├── rabitq_utils.cc │ │ │ │ └── 
rabitq_utils.h │ │ │ ├── hnsw_sparse/ │ │ │ │ ├── CMakeLists.txt │ │ │ │ ├── hnsw_sparse_algorithm.cc │ │ │ │ ├── hnsw_sparse_algorithm.h │ │ │ │ ├── hnsw_sparse_builder.cc │ │ │ │ ├── hnsw_sparse_builder.h │ │ │ │ ├── hnsw_sparse_builder_entity.cc │ │ │ │ ├── hnsw_sparse_builder_entity.h │ │ │ │ ├── hnsw_sparse_chunk.cc │ │ │ │ ├── hnsw_sparse_chunk.h │ │ │ │ ├── hnsw_sparse_context.cc │ │ │ │ ├── hnsw_sparse_context.h │ │ │ │ ├── hnsw_sparse_dist_calculator.h │ │ │ │ ├── hnsw_sparse_entity.cc │ │ │ │ ├── hnsw_sparse_entity.h │ │ │ │ ├── hnsw_sparse_index_hash.h │ │ │ │ ├── hnsw_sparse_index_provider.h │ │ │ │ ├── hnsw_sparse_params.h │ │ │ │ ├── hnsw_sparse_searcher.cc │ │ │ │ ├── hnsw_sparse_searcher.h │ │ │ │ ├── hnsw_sparse_searcher_entity.cc │ │ │ │ ├── hnsw_sparse_searcher_entity.h │ │ │ │ ├── hnsw_sparse_streamer.cc │ │ │ │ ├── hnsw_sparse_streamer.h │ │ │ │ ├── hnsw_sparse_streamer_entity.cc │ │ │ │ └── hnsw_sparse_streamer_entity.h │ │ │ └── ivf/ │ │ │ ├── CMakeLists.txt │ │ │ ├── ivf_builder.cc │ │ │ ├── ivf_builder.h │ │ │ ├── ivf_centroid_index.cc │ │ │ ├── ivf_centroid_index.h │ │ │ ├── ivf_distance_calculator.cc │ │ │ ├── ivf_distance_calculator.h │ │ │ ├── ivf_dumper.cc │ │ │ ├── ivf_dumper.h │ │ │ ├── ivf_entity.cc │ │ │ ├── ivf_entity.h │ │ │ ├── ivf_index_format.h │ │ │ ├── ivf_index_provider.h │ │ │ ├── ivf_params.h │ │ │ ├── ivf_searcher.cc │ │ │ ├── ivf_searcher.h │ │ │ ├── ivf_searcher_context.h │ │ │ ├── ivf_streamer.cc │ │ │ ├── ivf_streamer.h │ │ │ └── ivf_utility.h │ │ ├── framework/ │ │ │ ├── CMakeLists.txt │ │ │ ├── index_cluster.cc │ │ │ ├── index_context.cc │ │ │ ├── index_converter.cc │ │ │ ├── index_error.cc │ │ │ ├── index_factory.cc │ │ │ ├── index_flow.cc │ │ │ ├── index_helper.cc │ │ │ ├── index_logger.cc │ │ │ ├── index_mapping.cc │ │ │ ├── index_meta.cc │ │ │ ├── index_plugin.cc │ │ │ └── index_version.cc │ │ ├── interface/ │ │ │ ├── CMakeLists.txt │ │ │ ├── index.cc │ │ │ ├── index_factory.cc │ │ │ ├── index_param.cc │ │ │ 
├── indexes/ │ │ │ │ ├── flat_index.cc │ │ │ │ ├── hnsw_index.cc │ │ │ │ ├── hnsw_rabitq_index.cc │ │ │ │ └── ivf_index.cc │ │ │ └── utils/ │ │ │ └── utils.h │ │ ├── metric/ │ │ │ ├── CMakeLists.txt │ │ │ ├── cosine_metric.cc │ │ │ ├── euclidean_metric.cc │ │ │ ├── hamming_metric.cc │ │ │ ├── inner_product_metric.cc │ │ │ ├── metric_params.h │ │ │ ├── mips_euclidean_metric.cc │ │ │ ├── quantized_integer_metric.cc │ │ │ ├── quantized_integer_metric_batch.h │ │ │ └── quantized_integer_metric_matrix.h │ │ ├── mixed_reducer/ │ │ │ ├── CMakeLists.txt │ │ │ ├── mixed_reducer_params.h │ │ │ ├── mixed_streamer_reducer.cc │ │ │ └── mixed_streamer_reducer.h │ │ ├── quantizer/ │ │ │ ├── CMakeLists.txt │ │ │ ├── binary_converter.cc │ │ │ ├── binary_reformer.cc │ │ │ ├── cosine_converter.cc │ │ │ ├── cosine_reformer.cc │ │ │ ├── half_float_converter.cc │ │ │ ├── half_float_reformer.cc │ │ │ ├── integer_quantizer_converter.cc │ │ │ ├── integer_quantizer_reformer.cc │ │ │ ├── mips_converter.cc │ │ │ ├── mips_reformer.cc │ │ │ ├── quantizer_params.h │ │ │ └── record_quantizer.h │ │ └── utility/ │ │ ├── CMakeLists.txt │ │ ├── basic_refiner.cc │ │ ├── buffer_storage.cc │ │ ├── file_dumper.cc │ │ ├── file_read_storage.cc │ │ ├── memory_dumper.cc │ │ ├── memory_read_storage.cc │ │ ├── mmap_file_read_storage.cc │ │ ├── mmap_file_storage.cc │ │ ├── sparse_utility.h │ │ ├── utility_params.h │ │ └── visit_filter.h │ ├── db/ │ │ ├── CMakeLists.txt │ │ ├── collection.cc │ │ ├── common/ │ │ │ ├── CMakeLists.txt │ │ │ ├── cgroup_util.cc │ │ │ ├── cgroup_util.h │ │ │ ├── concurrent_roaring_bitmap.cc │ │ │ ├── concurrent_roaring_bitmap.h │ │ │ ├── config.cc │ │ │ ├── constants.h │ │ │ ├── error_code.cc │ │ │ ├── error_code.h │ │ │ ├── file_helper.cc │ │ │ ├── file_helper.h │ │ │ ├── global_resource.cc │ │ │ ├── global_resource.h │ │ │ ├── glogger.h │ │ │ ├── logger.h │ │ │ ├── profiler.h │ │ │ ├── rocbsdb_context.cc │ │ │ ├── rocksdb_context.h │ │ │ ├── status.cc │ │ │ ├── typedef.h │ │ │ ├── 
utils.cc │ │ │ └── utils.h │ │ ├── index/ │ │ │ ├── CMakeLists.txt │ │ │ ├── column/ │ │ │ │ ├── column_indexer.h │ │ │ │ ├── common/ │ │ │ │ │ └── index_results.h │ │ │ │ ├── inverted_column/ │ │ │ │ │ ├── inverted_codec.h │ │ │ │ │ ├── inverted_column_indexer.h │ │ │ │ │ ├── inverted_column_indexer_search.cc │ │ │ │ │ ├── inverted_column_indexer_util.cc │ │ │ │ │ ├── inverted_column_indexer_write.cc │ │ │ │ │ ├── inverted_doc_range.h │ │ │ │ │ ├── inverted_indexer.cc │ │ │ │ │ ├── inverted_indexer.h │ │ │ │ │ ├── inverted_rocksdb_merger.h │ │ │ │ │ └── inverted_search_result.h │ │ │ │ └── vector_column/ │ │ │ │ ├── combined_vector_column_indexer.cc │ │ │ │ ├── combined_vector_column_indexer.h │ │ │ │ ├── engine_helper.hpp │ │ │ │ ├── vector_column_indexer.cc │ │ │ │ ├── vector_column_indexer.h │ │ │ │ ├── vector_column_params.h │ │ │ │ └── vector_index_results.h │ │ │ ├── common/ │ │ │ │ ├── delete_store.h │ │ │ │ ├── doc.cc │ │ │ │ ├── id_map.cc │ │ │ │ ├── id_map.h │ │ │ │ ├── index_filter.h │ │ │ │ ├── index_params.cc │ │ │ │ ├── meta.h │ │ │ │ ├── proto_converter.cc │ │ │ │ ├── proto_converter.h │ │ │ │ ├── schema.cc │ │ │ │ ├── stats.cc │ │ │ │ ├── type_helper.cc │ │ │ │ ├── type_helper.h │ │ │ │ ├── version_manager.cc │ │ │ │ └── version_manager.h │ │ │ ├── segment/ │ │ │ │ ├── column_merging_reader.cc │ │ │ │ ├── column_merging_reader.h │ │ │ │ ├── segment.cc │ │ │ │ ├── segment.h │ │ │ │ ├── segment_helper.cc │ │ │ │ ├── segment_helper.h │ │ │ │ ├── segment_manager.cc │ │ │ │ ├── segment_manager.h │ │ │ │ ├── sql_expr_parser.cc │ │ │ │ └── sql_expr_parser.h │ │ │ └── storage/ │ │ │ ├── arrow_ipc_writer.cc │ │ │ ├── arrow_ipc_writer.h │ │ │ ├── base_forward_store.h │ │ │ ├── bufferpool_forward_store.cc │ │ │ ├── bufferpool_forward_store.h │ │ │ ├── chunked_file_writer.cc │ │ │ ├── chunked_file_writer.h │ │ │ ├── forward_writer.cc │ │ │ ├── forward_writer.h │ │ │ ├── lazy_record_batch_reader.h │ │ │ ├── memory_forward_store.cc │ │ │ ├── 
memory_forward_store.h │ │ │ ├── mmap_forward_store.cc │ │ │ ├── mmap_forward_store.h │ │ │ ├── parquet_writer.cc │ │ │ ├── parquet_writer.h │ │ │ ├── store_helper.h │ │ │ └── wal/ │ │ │ ├── local_wal_file.cc │ │ │ ├── local_wal_file.h │ │ │ ├── wal_file.cc │ │ │ └── wal_file.h │ │ ├── proto/ │ │ │ └── zvec.proto │ │ └── sqlengine/ │ │ ├── CMakeLists.txt │ │ ├── analyzer/ │ │ │ ├── query_analyzer.cc │ │ │ ├── query_analyzer.h │ │ │ ├── query_field_info.cc │ │ │ ├── query_field_info.h │ │ │ ├── query_info.cc │ │ │ ├── query_info.h │ │ │ ├── query_info_helper.cc │ │ │ ├── query_info_helper.h │ │ │ ├── query_node.cc │ │ │ ├── query_node.h │ │ │ ├── query_node_walker.cc │ │ │ ├── query_node_walker.h │ │ │ ├── query_orderby_info.cc │ │ │ ├── query_orderby_info.h │ │ │ ├── simple_rewriter.cc │ │ │ └── simple_rewriter.h │ │ ├── antlr/ │ │ │ ├── SQLLexer.g4 │ │ │ ├── SQLParser.g4 │ │ │ ├── gen/ │ │ │ │ ├── SQLLexer.cc │ │ │ │ ├── SQLLexer.h │ │ │ │ ├── SQLLexer.interp │ │ │ │ ├── SQLLexer.tokens │ │ │ │ ├── SQLParser.cc │ │ │ │ ├── SQLParser.h │ │ │ │ ├── SQLParser.interp │ │ │ │ ├── SQLParser.tokens │ │ │ │ ├── SQLParserBaseListener.cc │ │ │ │ ├── SQLParserBaseListener.h │ │ │ │ ├── SQLParserListener.cc │ │ │ │ └── SQLParserListener.h │ │ │ └── gen_parser.sh │ │ ├── common/ │ │ │ ├── generic_node.h │ │ │ ├── group_by.h │ │ │ ├── util.cc │ │ │ └── util.h │ │ ├── parser/ │ │ │ ├── base_info.h │ │ │ ├── case_changing_charstream.h │ │ │ ├── error_verbose_listener.h │ │ │ ├── node.cc │ │ │ ├── node.h │ │ │ ├── orderby_elem_info.h │ │ │ ├── query_parser.cc │ │ │ ├── query_parser.h │ │ │ ├── select_info.cc │ │ │ ├── select_info.h │ │ │ ├── selected_elem_info.cc │ │ │ ├── selected_elem_info.h │ │ │ ├── sql_info.cc │ │ │ ├── sql_info.h │ │ │ ├── sql_info_helper.cc │ │ │ ├── sql_info_helper.h │ │ │ ├── zvec_cached_sql_parser.cc │ │ │ ├── zvec_cached_sql_parser.h │ │ │ ├── zvec_parser.cc │ │ │ ├── zvec_parser.h │ │ │ ├── zvec_sql_parser.cc │ │ │ └── zvec_sql_parser.h │ │ ├── 
planner/ │ │ │ ├── doc_filter.cc │ │ │ ├── doc_filter.h │ │ │ ├── invert_recall_node.cc │ │ │ ├── invert_recall_node.h │ │ │ ├── invert_search.cc │ │ │ ├── invert_search.h │ │ │ ├── op_register.cc │ │ │ ├── op_register.h │ │ │ ├── ops/ │ │ │ │ ├── check_not_filtered_op.cc │ │ │ │ ├── check_not_filtered_op.h │ │ │ │ ├── contain_op.cc │ │ │ │ ├── contain_op.h │ │ │ │ ├── fetch_vector_op.cc │ │ │ │ └── fetch_vector_op.h │ │ │ ├── optimizer.cc │ │ │ ├── optimizer.h │ │ │ ├── plan_info.cc │ │ │ ├── plan_info.h │ │ │ ├── query_planner.cc │ │ │ ├── query_planner.h │ │ │ ├── segment_node.cc │ │ │ ├── segment_node.h │ │ │ ├── vector_recall_node.cc │ │ │ └── vector_recall_node.h │ │ ├── sqlengine.cc │ │ ├── sqlengine.h │ │ ├── sqlengine_impl.cc │ │ └── sqlengine_impl.h │ ├── include/ │ │ └── zvec/ │ │ ├── ailego/ │ │ │ ├── buffer/ │ │ │ │ ├── buffer_manager.h │ │ │ │ ├── buffer_pool.h │ │ │ │ └── concurrentqueue.h │ │ │ ├── container/ │ │ │ │ ├── blob.h │ │ │ │ ├── cube.h │ │ │ │ ├── heap.h │ │ │ │ ├── hypercube.h │ │ │ │ ├── params.h │ │ │ │ └── vector.h │ │ │ ├── encoding/ │ │ │ │ ├── json/ │ │ │ │ │ ├── mod_json.h │ │ │ │ │ └── mod_json_plus.h │ │ │ │ └── json.h │ │ │ ├── hash/ │ │ │ │ ├── crc32c.h │ │ │ │ └── jump_hash.h │ │ │ ├── internal/ │ │ │ │ └── platform.h │ │ │ ├── io/ │ │ │ │ ├── file.h │ │ │ │ └── mmap_file.h │ │ │ ├── logger/ │ │ │ │ └── logger.h │ │ │ ├── math_batch/ │ │ │ │ └── utils.h │ │ │ ├── parallel/ │ │ │ │ ├── thread_pool.h │ │ │ │ └── thread_queue.h │ │ │ ├── pattern/ │ │ │ │ ├── closure.h │ │ │ │ ├── expected.hpp │ │ │ │ ├── factory.h │ │ │ │ └── singleton.h │ │ │ ├── string/ │ │ │ │ ├── string_concat_helper.h │ │ │ │ └── string_view.h │ │ │ └── utility/ │ │ │ ├── file_helper.h │ │ │ ├── float_helper.h │ │ │ ├── string_helper.h │ │ │ ├── string_helper_impl.h │ │ │ ├── time_helper.h │ │ │ └── type_helper.h │ │ ├── core/ │ │ │ ├── framework/ │ │ │ │ ├── index_builder.h │ │ │ │ ├── index_bundle.h │ │ │ │ ├── index_cluster.h │ │ │ │ ├── index_context.h 
│ │ │ │ ├── index_converter.h │ │ │ │ ├── index_document.h │ │ │ │ ├── index_dumper.h │ │ │ │ ├── index_error.h │ │ │ │ ├── index_factory.h │ │ │ │ ├── index_features.h │ │ │ │ ├── index_filter.h │ │ │ │ ├── index_flow.h │ │ │ │ ├── index_format.h │ │ │ │ ├── index_framework.h │ │ │ │ ├── index_groupby.h │ │ │ │ ├── index_helper.h │ │ │ │ ├── index_holder.h │ │ │ │ ├── index_logger.h │ │ │ │ ├── index_mapping.h │ │ │ │ ├── index_memory.h │ │ │ │ ├── index_meta.h │ │ │ │ ├── index_metric.h │ │ │ │ ├── index_module.h │ │ │ │ ├── index_packer.h │ │ │ │ ├── index_plugin.h │ │ │ │ ├── index_provider.h │ │ │ │ ├── index_reducer.h │ │ │ │ ├── index_refiner.h │ │ │ │ ├── index_reformer.h │ │ │ │ ├── index_runner.h │ │ │ │ ├── index_searcher.h │ │ │ │ ├── index_segment_storage.h │ │ │ │ ├── index_stats.h │ │ │ │ ├── index_storage.h │ │ │ │ ├── index_streamer.h │ │ │ │ ├── index_threads.h │ │ │ │ ├── index_trainer.h │ │ │ │ ├── index_unpacker.h │ │ │ │ └── index_version.h │ │ │ └── interface/ │ │ │ ├── constants.h │ │ │ ├── index.h │ │ │ ├── index_factory.h │ │ │ ├── index_param.h │ │ │ └── index_param_builders.h │ │ ├── db/ │ │ │ ├── collection.h │ │ │ ├── config.h │ │ │ ├── doc.h │ │ │ ├── index_params.h │ │ │ ├── options.h │ │ │ ├── query_params.h │ │ │ ├── schema.h │ │ │ ├── stats.h │ │ │ ├── status.h │ │ │ └── type.h │ │ └── turbo/ │ │ └── turbo.h │ └── turbo/ │ ├── CMakeLists.txt │ ├── avx512_vnni/ │ │ └── record_quantized_int8/ │ │ ├── common.h │ │ ├── cosine.cc │ │ ├── cosine.h │ │ ├── squared_euclidean.cc │ │ └── squared_euclidean.h │ └── turbo.cc ├── tests/ │ ├── CMakeLists.txt │ ├── ailego/ │ │ ├── CMakeLists.txt │ │ ├── algorithm/ │ │ │ ├── integer_quantizer_test.cc │ │ │ └── kmeans_test.cc │ │ ├── buffer/ │ │ │ └── buffer_manager_test.cc │ │ ├── container/ │ │ │ ├── bitmap_test.cc │ │ │ ├── blob_test.cc │ │ │ ├── bloom_filter_test.cc │ │ │ ├── cube_test.cc │ │ │ ├── heap_test.cc │ │ │ ├── hypercube_test.cc │ │ │ ├── params_test.cc │ │ │ ├── reservoir_test.cc │ │ 
│ ├── vector_array_test.cc │ │ │ └── vector_test.cc │ │ ├── encoding/ │ │ │ └── json_parse_test.cc │ │ ├── hash/ │ │ │ ├── crc32c_test.cc │ │ │ └── jump_hash_test.cc │ │ ├── internal/ │ │ │ └── cpu_features_test.cc │ │ ├── io/ │ │ │ ├── file_lock_test.cc │ │ │ ├── file_test.cc │ │ │ └── mmap_file_test.cc │ │ ├── logger/ │ │ │ └── logger_test.cc │ │ ├── math/ │ │ │ ├── cosine_distance_matrix_fp16_test.cc │ │ │ ├── cosine_distance_matrix_fp32_test.cc │ │ │ ├── cosine_distance_matrix_int8_test.cc │ │ │ ├── euclidean_distance_matrix_fp16_test.cc │ │ │ ├── euclidean_distance_matrix_fp32_test.cc │ │ │ ├── euclidean_distance_matrix_int4_test.cc │ │ │ ├── euclidean_distance_matrix_int8_test.cc │ │ │ ├── hamming_distance_matrix_test.cc │ │ │ ├── inner_product_matrix_fp16_test.cc │ │ │ ├── inner_product_matrix_fp32_test.cc │ │ │ ├── inner_product_matrix_int4_test.cc │ │ │ ├── inner_product_matrix_int8_test.cc │ │ │ ├── mips_euclidean_distance_matrix_fp16_test.cc │ │ │ ├── mips_euclidean_distance_matrix_fp32_test.cc │ │ │ ├── mips_euclidean_distance_matrix_int4_test.cc │ │ │ ├── mips_euclidean_distance_matrix_int8_test.cc │ │ │ ├── norm_matrix_fp16_test.cc │ │ │ ├── norm_matrix_fp32_test.cc │ │ │ ├── norm_matrix_int4_test.cc │ │ │ ├── norm_matrix_int8_test.cc │ │ │ └── normalizer_test.cc │ │ ├── parallel/ │ │ │ ├── lock_test.cc │ │ │ ├── multi_thread_list_test.cc │ │ │ ├── semaphore_test.cc │ │ │ ├── thread_pool_test.cc │ │ │ └── thread_queue_test.cc │ │ ├── pattern/ │ │ │ ├── closure_test.cc │ │ │ ├── factory_test.cc │ │ │ ├── scope_guard_test.cc │ │ │ └── singleton_test.cc │ │ ├── utility/ │ │ │ ├── bit_string_helper_test.cc │ │ │ ├── bitset_helper_test.cc │ │ │ ├── dl_helper_test.cc │ │ │ ├── float_helper_test.cc │ │ │ ├── matrix_helper_test.cc │ │ │ ├── memory_helper_test.cc │ │ │ ├── string_helper_test.cc │ │ │ ├── time_helper_test.cc │ │ │ └── type_helper_test.cc │ │ └── version_test.cc │ ├── core/ │ │ ├── CMakeLists.txt │ │ ├── algorithm/ │ │ │ ├── CMakeLists.txt │ │ │ 
├── cluster/ │ │ │ │ ├── CMakeLists.txt │ │ │ │ ├── kmeans_cluster_test.cc │ │ │ │ └── opt_kmeans_cluster_test.cc │ │ │ ├── flat/ │ │ │ │ ├── CMakeLists.txt │ │ │ │ ├── flat_builder_test.cc │ │ │ │ ├── flat_searcher_test.cpp │ │ │ │ ├── flat_streamer_buffer_test.cc │ │ │ │ ├── flat_streamer_buffer_time_test.cc │ │ │ │ └── flat_streamer_test.cc │ │ │ ├── flat_sparse/ │ │ │ │ ├── CMakeLists.txt │ │ │ │ ├── flat_sparse_builder_test.cc │ │ │ │ ├── flat_sparse_searcher_test.cc │ │ │ │ ├── flat_sparse_streamer_buffer_test.cc │ │ │ │ └── flat_sparse_streamer_test.cc │ │ │ ├── hnsw/ │ │ │ │ ├── CMakeLists.txt │ │ │ │ ├── hnsw_builder_test.cc │ │ │ │ ├── hnsw_searcher_test.cpp │ │ │ │ ├── hnsw_streamer_buffer_test.cc │ │ │ │ └── hnsw_streamer_test.cc │ │ │ ├── hnsw_rabitq/ │ │ │ │ ├── CMakeLists.txt │ │ │ │ ├── hnsw_rabitq_builder_test.cc │ │ │ │ ├── hnsw_rabitq_searcher_test.cc │ │ │ │ └── hnsw_rabitq_streamer_test.cc │ │ │ ├── hnsw_sparse/ │ │ │ │ ├── CMakeLists.txt │ │ │ │ ├── hnsw_sparse_builder_test.cc │ │ │ │ ├── hnsw_sparse_searcher_test.cpp │ │ │ │ ├── hnsw_sparse_streamer_buffer_test.cpp │ │ │ │ └── hnsw_sparse_streamer_test.cc │ │ │ └── ivf/ │ │ │ ├── CMakeLists.txt │ │ │ ├── ivf_builder_test.cc │ │ │ └── ivf_searcher_test.cc │ │ ├── framework/ │ │ │ └── CMakeLists.txt │ │ ├── interface/ │ │ │ ├── CMakeLists.txt │ │ │ └── index_interface_test.cc │ │ ├── metric/ │ │ │ ├── CMakeLists.txt │ │ │ ├── cosine_metric_test.cc │ │ │ ├── euclidean_metric_test.cc │ │ │ ├── hamming_metric_test.cc │ │ │ ├── inner_product_metric_test.cc │ │ │ └── quantized_integer_metric_test.cc │ │ ├── quantizer/ │ │ │ ├── CMakeLists.txt │ │ │ ├── half_float_reformer_test.cc │ │ │ └── integer_quantizer_reformer_test.cc │ │ └── utility/ │ │ ├── CMakeLists.txt │ │ ├── buffer_storage_test.cpp │ │ ├── file_dumper_test.cc │ │ ├── memory_dumper_test.cc │ │ ├── mmap_file_container_test.cc │ │ └── mmap_file_storage_test.cpp │ └── db/ │ ├── CMakeLists.txt │ ├── collection_test.cc │ ├── common/ │ │ ├── 
CMakeLists.txt │ │ ├── config_test.cc │ │ └── status_test.cc │ ├── crash_recovery/ │ │ ├── CMakeLists.txt │ │ ├── data_generator.cc │ │ ├── utility.h │ │ └── write_recovery_test.cc │ ├── index/ │ │ ├── CMakeLists.txt │ │ ├── column/ │ │ │ ├── inverted_column/ │ │ │ │ ├── inverted_column_indexer_array_numbers_test.cc │ │ │ │ ├── inverted_column_indexer_bool_test.cc │ │ │ │ ├── inverted_column_indexer_cyclic_numbers_test.cc │ │ │ │ ├── inverted_column_indexer_sequential_numbers_test.cc │ │ │ │ ├── inverted_column_indexer_string_test.cc │ │ │ │ └── inverted_indexer_util_test.cc │ │ │ └── vector_column_indexer_test.cc │ │ ├── common/ │ │ │ ├── db_proto_converter_test.cc │ │ │ ├── db_type_helper_test.cc │ │ │ ├── doc_test.cc │ │ │ ├── index_params_test.cc │ │ │ ├── meta_test.cc │ │ │ ├── query_params_test.cc │ │ │ ├── schema_test.cc │ │ │ └── version_manager_test.cc │ │ ├── segment/ │ │ │ ├── column_merging_reader_test.cc │ │ │ ├── segment_helper_test.cc │ │ │ ├── segment_test.cc │ │ │ ├── sql_expr_parser_test.cc │ │ │ └── sql_expr_validator_test.cc │ │ ├── storage/ │ │ │ ├── arrow_ipc_writer_test.cc │ │ │ ├── bufferpool_store_test.cc │ │ │ ├── mem_store_test.cc │ │ │ ├── mmap_store_test.cc │ │ │ ├── parquet_writer_test.cc │ │ │ └── wal_file_test.cc │ │ └── utils/ │ │ ├── utils.cc │ │ └── utils.h │ └── sqlengine/ │ ├── CMakeLists.txt │ ├── contain_test.cc │ ├── forward_recall_test.cc │ ├── invert_recall_test.cc │ ├── like_test.cc │ ├── mock_segment.h │ ├── optimizer_test.cc │ ├── query_info_test.cc │ ├── recall_base.h │ ├── simple_rewriter_test.cc │ ├── sqlengine_test.cc │ ├── test_helper.h │ └── vector_recall_test.cc ├── thirdparty/ │ ├── CMakeLists.txt │ ├── CRoaring/ │ │ └── CMakeLists.txt │ ├── RaBitQ-Library/ │ │ └── CMakeLists.txt │ ├── antlr/ │ │ ├── CMakeLists.txt │ │ └── antlr4.patch │ ├── arrow/ │ │ ├── CMakeLists.txt │ │ ├── arrow.android.patch │ │ └── arrow.patch │ ├── gflags/ │ │ └── CMakeLists.txt │ ├── glog/ │ │ ├── CMakeLists.txt │ │ ├── 
glog.android.patch │ │ └── glog.patch │ ├── googletest/ │ │ └── CMakeLists.txt │ ├── lz4/ │ │ └── CMakeLists.txt │ ├── magic_enum/ │ │ └── CMakeLists.txt │ ├── protobuf/ │ │ └── CMakeLists.txt │ ├── rocksdb/ │ │ ├── CMakeLists.txt │ │ └── rocksdb.android.patch │ ├── sparsehash/ │ │ ├── CMakeLists.txt │ │ └── sparseconfig.h │ └── yaml-cpp/ │ └── CMakeLists.txt └── tools/ ├── CMakeLists.txt └── core/ ├── CMakeLists.txt ├── README.md ├── bench.cc ├── bench_original.cc ├── bench_result.h ├── convert_cohere_parquet.py ├── filter_result_cache.h ├── flow.h ├── helper.h ├── index_meta_helper.h ├── local_builder.cc ├── local_builder_original.cc ├── meta_segment_common.h ├── recall.cc ├── recall_original.cc ├── txt2vecs.cc ├── txt_input_reader.h ├── vecs_common.h ├── vecs_index_holder.h └── vecs_reader.h ================================================ FILE CONTENTS ================================================ ================================================ FILE: .clang-format ================================================ # http://clang.llvm.org/docs/ClangFormatStyleOptions.html # Defines the Google C++ style for automatic reformatting. 
BasedOnStyle: Google MaxEmptyLinesToKeep: 2 DerivePointerAlignment: false PointerAlignment: Right AllowShortFunctionsOnASingleLine: Empty IncludeBlocks: Merge IncludeCategories: - Regex: '^$' Priority: 200 - Regex: '^<[0-9A-Za-z_]+>$' Priority: 201 - Regex: '^<[0-9A-Za-z_]+\.[0-9A-Za-z]+>$' Priority: 202 - Regex: '^<[0-9A-Za-z_]+/[0-9A-Za-z]+' Priority: 203 - Regex: '^\"[0-9A-Za-z_]+/[0-9A-Za-z]+' Priority: 300 - Regex: '^\"[0-9A-Za-z_]+\.[0-9A-Za-z]+\"$' Priority: 301 - Regex: '.*' Priority: 1000 ================================================ FILE: .git/HEAD ================================================ ref: refs/heads/main ================================================ FILE: .git/config ================================================ [core] repositoryformatversion = 1 filemode = true bare = false logallrefupdates = true [remote "origin"] url = https://github.com/alibaba/zvec tagOpt = --no-tags fetch = +refs/heads/main:refs/remotes/origin/main promisor = true partialclonefilter = blob:limit=1048576 [branch "main"] remote = origin merge = refs/heads/main ================================================ FILE: .git/description ================================================ Unnamed repository; edit this file 'description' to name the repository. ================================================ FILE: .git/hooks/applypatch-msg.sample ================================================ #!/bin/sh # # An example hook script to check the commit log message taken by # applypatch from an e-mail message. # # The hook should exit with non-zero status after issuing an # appropriate message if it wants to stop the commit. The hook is # allowed to edit the commit message file. # # To enable this hook, rename this file to "applypatch-msg". . 
git-sh-setup commitmsg="$(git rev-parse --git-path hooks/commit-msg)" test -x "$commitmsg" && exec "$commitmsg" ${1+"$@"} : ================================================ FILE: .git/hooks/commit-msg.sample ================================================ #!/bin/sh # # An example hook script to check the commit log message. # Called by "git commit" with one argument, the name of the file # that has the commit message. The hook should exit with non-zero # status after issuing an appropriate message if it wants to stop the # commit. The hook is allowed to edit the commit message file. # # To enable this hook, rename this file to "commit-msg". # Uncomment the below to add a Signed-off-by line to the message. # Doing this in a hook is a bad idea in general, but the prepare-commit-msg # hook is more suited to it. # # SOB=$(git var GIT_AUTHOR_IDENT | sed -n 's/^\(.*>\).*$/Signed-off-by: \1/p') # grep -qs "^$SOB" "$1" || echo "$SOB" >> "$1" # This example catches duplicate Signed-off-by lines. test "" = "$(grep '^Signed-off-by: ' "$1" | sort | uniq -c | sed -e '/^[ ]*1[ ]/d')" || { echo >&2 Duplicate Signed-off-by lines. exit 1 } ================================================ FILE: .git/hooks/fsmonitor-watchman.sample ================================================ #!/usr/bin/perl use strict; use warnings; use IPC::Open2; # An example hook script to integrate Watchman # (https://facebook.github.io/watchman/) with git to speed up detecting # new and modified files. # # The hook is passed a version (currently 2) and last update token # formatted as a string and outputs to stdout a new update token and # all files that have been modified since the update token. Paths must # be relative to the root of the working tree and separated by a single NUL. 
# # To enable this hook, rename this file to "query-watchman" and set # 'git config core.fsmonitor .git/hooks/query-watchman' # my ($version, $last_update_token) = @ARGV; # Uncomment for debugging # print STDERR "$0 $version $last_update_token\n"; # Check the hook interface version if ($version ne 2) { die "Unsupported query-fsmonitor hook version '$version'.\n" . "Falling back to scanning...\n"; } my $git_work_tree = get_working_dir(); my $retry = 1; my $json_pkg; eval { require JSON::XS; $json_pkg = "JSON::XS"; 1; } or do { require JSON::PP; $json_pkg = "JSON::PP"; }; launch_watchman(); sub launch_watchman { my $o = watchman_query(); if (is_work_tree_watched($o)) { output_result($o->{clock}, @{$o->{files}}); } } sub output_result { my ($clockid, @files) = @_; # Uncomment for debugging watchman output # open (my $fh, ">", ".git/watchman-output.out"); # binmode $fh, ":utf8"; # print $fh "$clockid\n@files\n"; # close $fh; binmode STDOUT, ":utf8"; print $clockid; print "\0"; local $, = "\0"; print @files; } sub watchman_clock { my $response = qx/watchman clock "$git_work_tree"/; die "Failed to get clock id on '$git_work_tree'.\n" . "Falling back to scanning...\n" if $? != 0; return $json_pkg->new->utf8->decode($response); } sub watchman_query { my $pid = open2(\*CHLD_OUT, \*CHLD_IN, 'watchman -j --no-pretty') or die "open2() failed: $!\n" . "Falling back to scanning...\n"; # In the query expression below we're asking for names of files that # changed since $last_update_token but not from the .git folder. # # To accomplish this, we're using the "since" generator to use the # recency index to select candidate nodes and "fields" to limit the # output to file names only. Then we're using the "expression" term to # further constrain the results. 
my $last_update_line = ""; if (substr($last_update_token, 0, 1) eq "c") { $last_update_token = "\"$last_update_token\""; $last_update_line = qq[\n"since": $last_update_token,]; } my $query = <<" END"; ["query", "$git_work_tree", {$last_update_line "fields": ["name"], "expression": ["not", ["dirname", ".git"]] }] END # Uncomment for debugging the watchman query # open (my $fh, ">", ".git/watchman-query.json"); # print $fh $query; # close $fh; print CHLD_IN $query; close CHLD_IN; my $response = do {local $/; }; # Uncomment for debugging the watch response # open ($fh, ">", ".git/watchman-response.json"); # print $fh $response; # close $fh; die "Watchman: command returned no output.\n" . "Falling back to scanning...\n" if $response eq ""; die "Watchman: command returned invalid output: $response\n" . "Falling back to scanning...\n" unless $response =~ /^\{/; return $json_pkg->new->utf8->decode($response); } sub is_work_tree_watched { my ($output) = @_; my $error = $output->{error}; if ($retry > 0 and $error and $error =~ m/unable to resolve root .* directory (.*) is not watched/) { $retry--; my $response = qx/watchman watch "$git_work_tree"/; die "Failed to make watchman watch '$git_work_tree'.\n" . "Falling back to scanning...\n" if $? != 0; $output = $json_pkg->new->utf8->decode($response); $error = $output->{error}; die "Watchman: $error.\n" . "Falling back to scanning...\n" if $error; # Uncomment for debugging watchman output # open (my $fh, ">", ".git/watchman-output.out"); # close $fh; # Watchman will always return all files on the first query so # return the fast "everything is dirty" flag to git and do the # Watchman query just to get it over with now so we won't pay # the cost in git to look up each individual file. my $o = watchman_clock(); $error = $output->{error}; die "Watchman: $error.\n" . 
"Falling back to scanning...\n" if $error; output_result($o->{clock}, ("/")); $last_update_token = $o->{clock}; eval { launch_watchman() }; return 0; } die "Watchman: $error.\n" . "Falling back to scanning...\n" if $error; return 1; } sub get_working_dir { my $working_dir; if ($^O =~ 'msys' || $^O =~ 'cygwin') { $working_dir = Win32::GetCwd(); $working_dir =~ tr/\\/\//; } else { require Cwd; $working_dir = Cwd::cwd(); } return $working_dir; } ================================================ FILE: .git/hooks/post-update.sample ================================================ #!/bin/sh # # An example hook script to prepare a packed repository for use over # dumb transports. # # To enable this hook, rename this file to "post-update". exec git update-server-info ================================================ FILE: .git/hooks/pre-applypatch.sample ================================================ #!/bin/sh # # An example hook script to verify what is about to be committed # by applypatch from an e-mail message. # # The hook should exit with non-zero status after issuing an # appropriate message if it wants to stop the commit. # # To enable this hook, rename this file to "pre-applypatch". . git-sh-setup precommit="$(git rev-parse --git-path hooks/pre-commit)" test -x "$precommit" && exec "$precommit" ${1+"$@"} : ================================================ FILE: .git/hooks/pre-commit.sample ================================================ #!/bin/sh # # An example hook script to verify what is about to be committed. # Called by "git commit" with no arguments. The hook should # exit with non-zero status after issuing an appropriate message if # it wants to stop the commit. # # To enable this hook, rename this file to "pre-commit". if git rev-parse --verify HEAD >/dev/null 2>&1 then against=HEAD else # Initial commit: diff against an empty tree object against=$(git hash-object -t tree /dev/null) fi # If you want to allow non-ASCII filenames set this variable to true. 
allownonascii=$(git config --type=bool hooks.allownonascii) # Redirect output to stderr. exec 1>&2 # Cross platform projects tend to avoid non-ASCII filenames; prevent # them from being added to the repository. We exploit the fact that the # printable range starts at the space character and ends with tilde. if [ "$allownonascii" != "true" ] && # Note that the use of brackets around a tr range is ok here, (it's # even required, for portability to Solaris 10's /usr/bin/tr), since # the square bracket bytes happen to fall in the designated range. test $(git diff-index --cached --name-only --diff-filter=A -z $against | LC_ALL=C tr -d '[ -~]\0' | wc -c) != 0 then cat <<\EOF Error: Attempt to add a non-ASCII file name. This can cause problems if you want to work with people on other platforms. To be portable it is advisable to rename the file. If you know what you are doing you can disable this check using: git config hooks.allownonascii true EOF exit 1 fi # If there are whitespace errors, print the offending file names and fail. exec git diff-index --check --cached $against -- ================================================ FILE: .git/hooks/pre-merge-commit.sample ================================================ #!/bin/sh # # An example hook script to verify what is about to be committed. # Called by "git merge" with no arguments. The hook should # exit with non-zero status after issuing an appropriate message to # stderr if it wants to stop the merge commit. # # To enable this hook, rename this file to "pre-merge-commit". . git-sh-setup test -x "$GIT_DIR/hooks/pre-commit" && exec "$GIT_DIR/hooks/pre-commit" : ================================================ FILE: .git/hooks/pre-push.sample ================================================ #!/bin/sh # An example hook script to verify what is about to be pushed. Called by "git # push" after it has checked the remote status, but before anything has been # pushed. 
If this script exits with a non-zero status nothing will be pushed. # # This hook is called with the following parameters: # # $1 -- Name of the remote to which the push is being done # $2 -- URL to which the push is being done # # If pushing without using a named remote those arguments will be equal. # # Information about the commits which are being pushed is supplied as lines to # the standard input in the form: # # # # This sample shows how to prevent push of commits where the log message starts # with "WIP" (work in progress). remote="$1" url="$2" zero=$(git hash-object --stdin &2 "Found WIP commit in $local_ref, not pushing" exit 1 fi fi done exit 0 ================================================ FILE: .git/hooks/pre-rebase.sample ================================================ #!/bin/sh # # Copyright (c) 2006, 2008 Junio C Hamano # # The "pre-rebase" hook is run just before "git rebase" starts doing # its job, and can prevent the command from running by exiting with # non-zero status. # # The hook is called with the following parameters: # # $1 -- the upstream the series was forked from. # $2 -- the branch being rebased (or empty when rebasing the current branch). # # This sample shows how to prevent topic branches that are already # merged to 'next' branch from getting rebased, because allowing it # would result in rebasing already published history. publish=next basebranch="$1" if test "$#" = 2 then topic="refs/heads/$2" else topic=`git symbolic-ref HEAD` || exit 0 ;# we do not interrupt rebasing detached HEAD fi case "$topic" in refs/heads/??/*) ;; *) exit 0 ;# we do not interrupt others. ;; esac # Now we are dealing with a topic branch being rebased # on top of master. Is it OK to rebase it? # Does the topic really exist? git show-ref -q "$topic" || { echo >&2 "No such branch $topic" exit 1 } # Is topic fully merged to master? 
not_in_master=`git rev-list --pretty=oneline ^master "$topic"` if test -z "$not_in_master" then echo >&2 "$topic is fully merged to master; better remove it." exit 1 ;# we could allow it, but there is no point. fi # Is topic ever merged to next? If so you should not be rebasing it. only_next_1=`git rev-list ^master "^$topic" ${publish} | sort` only_next_2=`git rev-list ^master ${publish} | sort` if test "$only_next_1" = "$only_next_2" then not_in_topic=`git rev-list "^$topic" master` if test -z "$not_in_topic" then echo >&2 "$topic is already up to date with master" exit 1 ;# we could allow it, but there is no point. else exit 0 fi else not_in_next=`git rev-list --pretty=oneline ^${publish} "$topic"` /usr/bin/perl -e ' my $topic = $ARGV[0]; my $msg = "* $topic has commits already merged to public branch:\n"; my (%not_in_next) = map { /^([0-9a-f]+) /; ($1 => 1); } split(/\n/, $ARGV[1]); for my $elem (map { /^([0-9a-f]+) (.*)$/; [$1 => $2]; } split(/\n/, $ARGV[2])) { if (!exists $not_in_next{$elem->[0]}) { if ($msg) { print STDERR $msg; undef $msg; } print STDERR " $elem->[1]\n"; } } ' "$topic" "$not_in_next" "$not_in_master" exit 1 fi <<\DOC_END This sample hook safeguards topic branches that have been published from being rewound. The workflow assumed here is: * Once a topic branch forks from "master", "master" is never merged into it again (either directly or indirectly). * Once a topic branch is fully cooked and merged into "master", it is deleted. If you need to build on top of it to correct earlier mistakes, a new topic branch is created by forking at the tip of the "master". This is not strictly necessary, but it makes it easier to keep your history simple. * Whenever you need to test or publish your changes to topic branches, merge them into "next" branch. The script, being an example, hardcodes the publish branch name to be "next", but it is trivial to make it configurable via $GIT_DIR/config mechanism. With this workflow, you would want to know: (1) ... 
if a topic branch has ever been merged to "next". Young topic branches can have stupid mistakes you would rather clean up before publishing, and things that have not been merged into other branches can be easily rebased without affecting other people. But once it is published, you would not want to rewind it. (2) ... if a topic branch has been fully merged to "master". Then you can delete it. More importantly, you should not build on top of it -- other people may already want to change things related to the topic as patches against your "master", so if you need further changes, it is better to fork the topic (perhaps with the same name) afresh from the tip of "master". Let's look at this example: o---o---o---o---o---o---o---o---o---o "next" / / / / / a---a---b A / / / / / / / / c---c---c---c B / / / / \ / / / / b---b C \ / / / / / \ / ---o---o---o---o---o---o---o---o---o---o---o "master" A, B and C are topic branches. * A has one fix since it was merged up to "next". * B has finished. It has been fully merged up to "master" and "next", and is ready to be deleted. * C has not merged to "next" at all. We would want to allow C to be rebased, refuse A, and encourage B to be deleted. To compute (1): git rev-list ^master ^topic next git rev-list ^master next if these match, topic has not merged in next at all. To compute (2): git rev-list master..topic if this is empty, it is fully merged to "master". DOC_END ================================================ FILE: .git/hooks/pre-receive.sample ================================================ #!/bin/sh # # An example hook script to make use of push options. # The example simply echoes all push options that start with 'echoback=' # and rejects all pushes when the "reject" push option is used. # # To enable this hook, rename this file to "pre-receive". 
if test -n "$GIT_PUSH_OPTION_COUNT" then i=0 while test "$i" -lt "$GIT_PUSH_OPTION_COUNT" do eval "value=\$GIT_PUSH_OPTION_$i" case "$value" in echoback=*) echo "echo from the pre-receive-hook: ${value#*=}" >&2 ;; reject) exit 1 esac i=$((i + 1)) done fi ================================================ FILE: .git/hooks/prepare-commit-msg.sample ================================================ #!/bin/sh # # An example hook script to prepare the commit log message. # Called by "git commit" with the name of the file that has the # commit message, followed by the description of the commit # message's source. The hook's purpose is to edit the commit # message file. If the hook fails with a non-zero status, # the commit is aborted. # # To enable this hook, rename this file to "prepare-commit-msg". # This hook includes three examples. The first one removes the # "# Please enter the commit message..." help message. # # The second includes the output of "git diff --name-status -r" # into the message, just before the "git status" output. It is # commented because it doesn't cope with --amend or with squashed # commits. # # The third example adds a Signed-off-by line to the message, that can # still be edited. This is rarely a good idea. COMMIT_MSG_FILE=$1 COMMIT_SOURCE=$2 SHA1=$3 /usr/bin/perl -i.bak -ne 'print unless(m/^. Please enter the commit message/..m/^#$/)' "$COMMIT_MSG_FILE" # case "$COMMIT_SOURCE,$SHA1" in # ,|template,) # /usr/bin/perl -i.bak -pe ' # print "\n" . 
`git diff --cached --name-status -r` # if /^#/ && $first++ == 0' "$COMMIT_MSG_FILE" ;; # *) ;; # esac # SOB=$(git var GIT_COMMITTER_IDENT | sed -n 's/^\(.*>\).*$/Signed-off-by: \1/p') # git interpret-trailers --in-place --trailer "$SOB" "$COMMIT_MSG_FILE" # if test -z "$COMMIT_SOURCE" # then # /usr/bin/perl -i.bak -pe 'print "\n" if !$first_line++' "$COMMIT_MSG_FILE" # fi ================================================ FILE: .git/hooks/push-to-checkout.sample ================================================ #!/bin/sh # An example hook script to update a checked-out tree on a git push. # # This hook is invoked by git-receive-pack(1) when it reacts to git # push and updates reference(s) in its repository, and when the push # tries to update the branch that is currently checked out and the # receive.denyCurrentBranch configuration variable is set to # updateInstead. # # By default, such a push is refused if the working tree and the index # of the remote repository has any difference from the currently # checked out commit; when both the working tree and the index match # the current commit, they are updated to match the newly pushed tip # of the branch. This hook is to be used to override the default # behaviour; however the code below reimplements the default behaviour # as a starting point for convenient modification. # # The hook receives the commit with which the tip of the current # branch is going to be updated: commit=$1 # It can exit with a non-zero status to refuse the push (when it does # so, it must not modify the index or the working tree). die () { echo >&2 "$*" exit 1 } # Or it can make any necessary changes to the working tree and to the # index to bring them to the desired state when the tip of the current # branch is updated to the new commit, and exit with a zero status. 
# # For example, the hook can simply run git read-tree -u -m HEAD "$1" # in order to emulate git fetch that is run in the reverse direction # with git push, as the two-tree form of git read-tree -u -m is # essentially the same as git switch or git checkout that switches # branches while keeping the local changes in the working tree that do # not interfere with the difference between the branches. # The below is a more-or-less exact translation to shell of the C code # for the default behaviour for git's push-to-checkout hook defined in # the push_to_deploy() function in builtin/receive-pack.c. # # Note that the hook will be executed from the repository directory, # not from the working tree, so if you want to perform operations on # the working tree, you will have to adapt your code accordingly, e.g. # by adding "cd .." or using relative paths. if ! git update-index -q --ignore-submodules --refresh then die "Up-to-date check failed" fi if ! git diff-files --quiet --ignore-submodules -- then die "Working directory has unstaged changes" fi # This is a rough translation of: # # head_has_history() ? "HEAD" : EMPTY_TREE_SHA1_HEX if git cat-file -e HEAD 2>/dev/null then head=HEAD else head=$(git hash-object -t tree --stdin &2 exit 1 } unset GIT_DIR GIT_WORK_TREE cd "$worktree" && if grep -q "^diff --git " "$1" then validate_patch "$1" else validate_cover_letter "$1" fi && if test "$GIT_SENDEMAIL_FILE_COUNTER" = "$GIT_SENDEMAIL_FILE_TOTAL" then git config --unset-all sendemail.validateWorktree && trap 'git worktree remove -ff "$worktree"' EXIT && validate_series fi ================================================ FILE: .git/hooks/update.sample ================================================ #!/bin/sh # # An example hook script to block unannotated tags from entering. # Called by "git receive-pack" with arguments: refname sha1-old sha1-new # # To enable this hook, rename this file to "update". 
# # Config # ------ # hooks.allowunannotated # This boolean sets whether unannotated tags will be allowed into the # repository. By default they won't be. # hooks.allowdeletetag # This boolean sets whether deleting tags will be allowed in the # repository. By default they won't be. # hooks.allowmodifytag # This boolean sets whether a tag may be modified after creation. By default # it won't be. # hooks.allowdeletebranch # This boolean sets whether deleting branches will be allowed in the # repository. By default they won't be. # hooks.denycreatebranch # This boolean sets whether remotely creating branches will be denied # in the repository. By default this is allowed. # # --- Command line refname="$1" oldrev="$2" newrev="$3" # --- Safety check if [ -z "$GIT_DIR" ]; then echo "Don't run this script from the command line." >&2 echo " (if you want, you could supply GIT_DIR then run" >&2 echo " $0 )" >&2 exit 1 fi if [ -z "$refname" -o -z "$oldrev" -o -z "$newrev" ]; then echo "usage: $0 " >&2 exit 1 fi # --- Config allowunannotated=$(git config --type=bool hooks.allowunannotated) allowdeletebranch=$(git config --type=bool hooks.allowdeletebranch) denycreatebranch=$(git config --type=bool hooks.denycreatebranch) allowdeletetag=$(git config --type=bool hooks.allowdeletetag) allowmodifytag=$(git config --type=bool hooks.allowmodifytag) # check for no description projectdesc=$(sed -e '1q' "$GIT_DIR/description") case "$projectdesc" in "Unnamed repository"* | "") echo "*** Project description file hasn't been set" >&2 exit 1 ;; esac # --- Check types # if $newrev is 0000...0000, it's a commit to delete a ref. zero=$(git hash-object --stdin &2 echo "*** Use 'git tag [ -a | -s ]' for tags you want to propagate." 
>&2 exit 1 fi ;; refs/tags/*,delete) # delete tag if [ "$allowdeletetag" != "true" ]; then echo "*** Deleting a tag is not allowed in this repository" >&2 exit 1 fi ;; refs/tags/*,tag) # annotated tag if [ "$allowmodifytag" != "true" ] && git rev-parse $refname > /dev/null 2>&1 then echo "*** Tag '$refname' already exists." >&2 echo "*** Modifying a tag is not allowed in this repository." >&2 exit 1 fi ;; refs/heads/*,commit) # branch if [ "$oldrev" = "$zero" -a "$denycreatebranch" = "true" ]; then echo "*** Creating a branch is not allowed in this repository" >&2 exit 1 fi ;; refs/heads/*,delete) # delete branch if [ "$allowdeletebranch" != "true" ]; then echo "*** Deleting a branch is not allowed in this repository" >&2 exit 1 fi ;; refs/remotes/*,commit) # tracking branch ;; refs/remotes/*,delete) # delete tracking branch if [ "$allowdeletebranch" != "true" ]; then echo "*** Deleting a tracking branch is not allowed in this repository" >&2 exit 1 fi ;; *) # Anything else (is there anything else?) echo "*** Update hook: unknown type of update to ref $refname of type $newrev_type" >&2 exit 1 ;; esac # --- Finished exit 0 ================================================ FILE: .git/info/exclude ================================================ # git ls-files --others --exclude-from=.git/info/exclude # Lines that start with '#' are comments. 
# For a project mostly in C, the following would be a good set of # exclude patterns (uncomment them if you want to use them): # *.[oa] # *~ ================================================ FILE: .git/logs/HEAD ================================================ 0000000000000000000000000000000000000000 b49833bf56a0e102b8ac1ff95ed7766545f5bd1e appuser 1774064477 +0000 clone: from https://github.com/alibaba/zvec ================================================ FILE: .git/logs/refs/heads/main ================================================ 0000000000000000000000000000000000000000 b49833bf56a0e102b8ac1ff95ed7766545f5bd1e appuser 1774064477 +0000 clone: from https://github.com/alibaba/zvec ================================================ FILE: .git/logs/refs/remotes/origin/HEAD ================================================ 0000000000000000000000000000000000000000 b49833bf56a0e102b8ac1ff95ed7766545f5bd1e appuser 1774064477 +0000 clone: from https://github.com/alibaba/zvec ================================================ FILE: .git/objects/pack/pack-2b5e15ebe928a592991dc24c7ae7e8dc9e3500dc.promisor ================================================ b49833bf56a0e102b8ac1ff95ed7766545f5bd1e refs/heads/main ================================================ FILE: .git/packed-refs ================================================ # pack-refs with: peeled fully-peeled sorted b49833bf56a0e102b8ac1ff95ed7766545f5bd1e refs/remotes/origin/main ================================================ FILE: .git/refs/heads/main ================================================ b49833bf56a0e102b8ac1ff95ed7766545f5bd1e ================================================ FILE: .git/refs/remotes/origin/HEAD ================================================ ref: refs/remotes/origin/main ================================================ FILE: .git/shallow ================================================ b49833bf56a0e102b8ac1ff95ed7766545f5bd1e ================================================ FILE: 
.github/ISSUE_TEMPLATE/benchmark.yml ================================================ name: Benchmarking description: Add, update, or fix benchmark cases for zvec title: "[Benchmark]: " labels: ["benchmark"] body: - type: markdown attributes: value: | Use this for benchmark-related work: new test cases, CI integration, or performance regression tracking. - type: input id: benchmark_type attributes: label: Benchmark Type description: e.g., filtered search, batch insert, recall@k, ARM64 vs x86 validations: required: true - type: textarea id: goal attributes: label: Goal description: What performance aspect are you measuring or improving? validations: required: true - type: textarea id: methodology attributes: label: Methodology description: Dataset, query size, hardware, metrics (latency, throughput, memory) validations: required: true - type: textarea id: baseline attributes: label: Baseline (if applicable) description: Current performance numbers or competing systems for comparison. validations: required: false - type: textarea id: ci_integration attributes: label: CI Integration Plan description: Should this run in CI? How often? validations: required: false ================================================ FILE: .github/ISSUE_TEMPLATE/bug_report.yml ================================================ name: Bug Report description: Report a bug or unexpected behavior (e.g., crash, incorrect vector query, memory leak) title: "[Bug]: " labels: ["bug", "triage"] body: - type: markdown attributes: value: | Thank you for reporting! Please provide detailed info so we can reproduce and fix it quickly. - type: textarea id: description attributes: label: Description description: What happened? What did you expect? placeholder: | e.g. 
"Query with vector field crashes when using Zvec Python API" validations: required: true - type: textarea id: steps_to_reproduce attributes: label: Steps to Reproduce description: Exact steps to trigger the issue (code snippets welcome) placeholder: | 1. Build Zvec with CMake (Debug/Release) 2. Run Python script: `python test.py` 3. Call `collection.query(VectorQuery())` 4. Process segfaults / hangs / returns wrong results render: python validations: required: true - type: textarea id: logs_or_trace attributes: label: Logs / Stack Trace description: Paste relevant logs, LLDB/GDB backtrace, or CI failures placeholder: | Thread 1 "python" received signal SIGSEGV, Segmentation fault. 0x0000000104a2c3f0 in std::__1::shared_ptr<...>::... render: shell validations: required: false - type: input id: os attributes: label: Operating System placeholder: macOS 14 (M1), Ubuntu 22.04, Windows 11 (WSL2) validations: required: true - type: input id: build_env attributes: label: Build & Runtime Environment description: Compiler, CMake, Python, key dependencies placeholder: | clang 15.0.0, CMake 4.1.2, Python 3.11.9, magic_enum v0.9.7 (via git submodule) validations: required: true - type: checkboxes id: additional_context attributes: label: Additional Context options: - label: I've checked `git status` — no uncommitted submodule changes - label: I built with `CMAKE_BUILD_TYPE=Debug` - label: This occurs with or without `COVERAGE=ON` - label: The issue involves Python ↔ C++ integration (pybind11) ================================================ FILE: .github/ISSUE_TEMPLATE/config.yml ================================================ blank_issues_enabled: false contact_links: - name: Documentation url: https://zvec.org/en/ about: Check the quickstart, build guide, and API docs first. - name: Python API Examples url: https://zvec.org/en/docs/quickstart/ about: See working usage examples. 
================================================ FILE: .github/ISSUE_TEMPLATE/enhancement.yml ================================================ name: Enhancement description: Improve an existing feature or component title: "[Enhance]: " labels: ["enhancement"] body: - type: markdown attributes: value: | This template is for improving existing functionality (e.g., performance, usability, robustness). - type: input id: component attributes: label: Affected Component description: e.g., HNSW index, buffer manager, Python API validations: required: true - type: textarea id: current attributes: label: Current Behavior description: What is the current state and its limitations? validations: required: true - type: textarea id: desired attributes: label: Desired Improvement description: What should be improved and how? validations: required: true - type: textarea id: impact attributes: label: Impact description: How will this benefit users? (e.g., faster queries, lower memory, easier integration) validations: required: true ================================================ FILE: .github/ISSUE_TEMPLATE/feature_request.yml ================================================ name: Feature Request description: Suggest a new feature or improvement (e.g., better memory control, new query option) title: "[Feature]: " labels: ["feature"] body: - type: markdown attributes: value: | Thanks for your idea! Help us understand the motivation and scope. - type: textarea id: problem_or_motivation attributes: label: Problem / Motivation description: What problem does this solve? Why is it needed? placeholder: | e.g. "Current vector queries don't allow filtering by metadata + distance threshold at once" validations: required: true - type: textarea id: proposed_solution attributes: label: Proposed Solution description: How should it work? API sketch or pseudocode welcome. 
placeholder: | Add `filter=` and `max_distance=` args to `Zvec.query()`: ```python results = db.query(vector, filter="category == 'A'", max_distance=0.5) ``` render: python validations: required: false - type: textarea id: alternatives attributes: label: Alternatives Considered description: Are there workarounds? Why not use them? validations: required: false - type: dropdown id: impact_area attributes: label: Affected Area multiple: true options: - C++ Core (storage, indexing) - Python API / Bindings - Build System (CMake, Homebrew pkg) - Testing / CI / Coverage - Documentation validations: required: false ================================================ FILE: .github/ISSUE_TEMPLATE/integration.yml ================================================ name: Ecosystem Integration description: Integrate zvec with external frameworks (e.g., LangChain, LlamaIndex) title: "[Integration]: " labels: ["integration"] body: - type: input id: framework attributes: label: Target Framework description: e.g., LangChain, LlamaIndex, Haystack validations: required: true - type: textarea id: motivation attributes: label: Motivation description: Why integrate with this framework? Who benefits? validations: required: true - type: textarea id: interface attributes: label: Required Interface description: What adapter or interface must be implemented? (e.g., VectorStore base class) validations: required: true - type: textarea id: reference attributes: label: Reference Implementations description: Links to similar integrations in other vector DBs.
validations: required: false ================================================ FILE: .github/ISSUE_TEMPLATE/profiling.yml ================================================ name: Profiling / Investigation description: Profile performance, compatibility, or behavior in a specific scenario title: "[Profile]: " labels: ["profile"] body: - type: markdown attributes: value: | Use this for tasks like performance profiling, architecture compatibility checks, or feasibility studies. - type: input id: scenario attributes: label: Target Scenario description: e.g., ARM64 deployment, high-concurrency load, large dataset ingestion validations: required: true - type: textarea id: objective attributes: label: Objective description: What do you want to learn or validate? validations: required: true - type: textarea id: methodology attributes: label: Proposed Methodology description: How will you conduct the investigation? (tools, metrics, test data) validations: required: true - type: textarea id: expected_outcome attributes: label: Expected Outcome description: What deliverables are expected? 
(e.g., report, optimization PR, benchmark results) validations: required: true ================================================ FILE: .github/codecov.yml ================================================ codecov: require_ci_to_pass: true coverage: precision: 2 round: down range: "60...75" status: project: default: false patch: default: false parsers: gcov: branch_detection: conditional: true loop: true method: false macro: false comment: require_changes: false layout: "reach,diff,flags,tree" behavior: default ignore: - "thirdparty/" - "tests/" ================================================ FILE: .github/dependabot.yml ================================================ version: 2 updates: # GitHub Actions dependencies - package-ecosystem: "github-actions" directory: "/" schedule: interval: "weekly" day: "monday" time: "02:00" timezone: "Asia/Shanghai" labels: - "dependencies" - "github-actions" commit-message: prefix: "ci" include: "scope" open-pull-requests-limit: 5 ================================================ FILE: .github/workflows/01-ci-pipeline.yml ================================================ name: Main on: push: branches: [ "main" ] paths-ignore: - '**.md' merge_group: pull_request: branches: [ "main" ] paths-ignore: - '**.md' workflow_dispatch: concurrency: group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || '' }}-${{ github.base_ref || '' }}-${{ github.ref != 'refs/heads/main' || github.sha }} cancel-in-progress: true permissions: contents: read jobs: # Code quality checks (fast, run first) lint: uses: ./.github/workflows/02-lint-check.yml # Main build and test matrix build-and-test-macos-arm64: name: Build & Test (macos-arm64) needs: lint uses: ./.github/workflows/03-macos-linux-build.yml with: platform: macos-arm64 os: macos-15 build-and-test-linux-arm64: name: Build & Test (linux-arm64) needs: lint uses: ./.github/workflows/03-macos-linux-build.yml with: platform: linux-arm64 os: ubuntu-24.04-arm build-and-test-linux-x64: name: 
Build & Test (linux-x64) needs: lint uses: ./.github/workflows/03-macos-linux-build.yml with: platform: linux-x64 os: ubuntu-24.04 build-android: name: Build & Test (android) needs: lint uses: ./.github/workflows/04-android-build.yml ================================================ FILE: .github/workflows/02-lint-check.yml ================================================ name: Lint on: workflow_call: jobs: lint: name: Code Quality Checks runs-on: ubuntu-24.04 steps: - name: Checkout code uses: actions/checkout@v6 - name: Set up Python uses: actions/setup-python@v6 with: python-version: '3.10' cache: 'pip' cache-dependency-path: 'pyproject.toml' - name: Install linting tools run: | python -m pip install --upgrade pip \ ruff==v0.14.4 \ clang-format==18.1.8 shell: bash - name: Run Ruff Linter run: python -m ruff check . shell: bash - name: Run Ruff Formatter Check run: python -m ruff format --check . shell: bash - name: Run clang-format Check run: | CPP_FILES=$(find . -type f \( -name "*.cpp" -o -name "*.h" -o -name "*.hpp" -o -name "*.cc" -o -name "*.cxx" \) \ ! -path "./build/*" \ ! -path "./tests/*" \ ! -path "./scripts/*" \ ! -path "./python/*" \ ! -path "./thirdparty/*" \ ! -path "./.git/*") if [ -z "$CPP_FILES" ]; then echo "No C++ files found to check." 
exit 0 fi clang-format --dry-run --Werror $CPP_FILES shell: bash ================================================ FILE: .github/workflows/03-macos-linux-build.yml ================================================ name: MacOS & Linux Build on: workflow_call: inputs: platform: description: 'Platform identifier' required: true type: string os: description: 'GitHub Actions runner OS' required: true type: string permissions: contents: read jobs: # Build and test matrix (parallel execution) build-and-test: name: Build & Test (${{ inputs.platform }}) runs-on: ${{ inputs.os }} strategy: fail-fast: false matrix: include: - os: ${{ inputs.os }} platform: ${{ inputs.platform }} arch_flag: "" # Use appropriate architecture steps: - name: Checkout code uses: actions/checkout@v6 with: submodules: recursive - name: Set up Python uses: actions/setup-python@v6 with: python-version: '3.10' cache: 'pip' cache-dependency-path: 'pyproject.toml' - name: Set up environment variables run: | # Set number of processors for parallel builds if [[ "${{ matrix.platform }}" == "macos-arm64" ]]; then NPROC=$(sysctl -n hw.ncpu 2>/dev/null || echo 2) else NPROC=$(nproc 2>/dev/null || echo 2) fi echo "NPROC=$NPROC" >> $GITHUB_ENV echo "Using $NPROC parallel jobs for builds" # Add Python user base bin to PATH for pip-installed CLI tools echo "$(python -c 'import site; print(site.USER_BASE)')/bin" >> $GITHUB_PATH shell: bash - name: Install dependencies run: | python -m pip install --upgrade pip \ pybind11==3.0 \ cmake==3.30.0 \ ninja==1.11.1 \ pytest \ scikit-build-core \ setuptools_scm shell: bash - name: Build from source run: | cd "$GITHUB_WORKSPACE" CMAKE_GENERATOR="Unix Makefiles" \ CMAKE_BUILD_PARALLEL_LEVEL="$NPROC" \ python -m pip install -v . 
\ --no-build-isolation \ --config-settings='cmake.define.BUILD_TOOLS="ON"' \ ${{ matrix.arch_flag }} shell: bash - name: Run C++ Tests run: | cd "$GITHUB_WORKSPACE/build" make unittest -j$NPROC shell: bash - name: Run Python Tests run: | cd "$GITHUB_WORKSPACE" python -m pytest python/tests/ shell: bash - name: Run C++ Examples run: | cd "$GITHUB_WORKSPACE/examples/c++" mkdir build && cd build cmake .. -DCMAKE_BUILD_TYPE=Release make -j $NPROC ./db-example ./core-example ./ailego-example shell: bash ================================================ FILE: .github/workflows/04-android-build.yml ================================================ name: Android Cross Build on: workflow_call: permissions: contents: read jobs: build-android: # sdkmanager and other Android tools are x86‑only; ARM runners fail with exit code 1 # switch back to an x86 image so the setup-android action can install the SDK runs-on: ubuntu-24.04 strategy: fail-fast: false matrix: abi: [x86_64] api: [21] steps: - name: Checkout uses: actions/checkout@v6 - name: Cache dependencies uses: actions/cache@v5 with: path: | ~/.ccache key: ${{ runner.os }}-dependencies-cache-${{ hashFiles('**/CMakeLists.txt', 'thirdparty/**') }}-stl-fix - name: Install dependencies run: | sudo apt-get update sudo apt-get install -y --no-install-recommends \ cmake ninja-build git ca-certificates python3 \ build-essential make ccache - name: Setup Java 17 uses: actions/setup-java@v5 with: distribution: temurin java-version: '17' - name: Setup Android SDK uses: android-actions/setup-android@v3 - name: Install NDK (side by side) shell: bash run: | sdkmanager "ndk;26.1.10909125" - name: Cache host protoc build uses: actions/cache@v5 with: path: build-host key: ${{ runner.os }}-host-protoc-${{ hashFiles('src/**', 'CMakeLists.txt') }}-stl-fix restore-keys: | ${{ runner.os }}-host-protoc- - name: Use host env to compile protoc shell: bash run: | git submodule update --init if [ ! 
-d "build-host" ]; then export CCACHE_BASEDIR="$GITHUB_WORKSPACE" export CCACHE_NOHASHDIR=1 export CCACHE_SLOPPINESS=clang_index_store,file_stat_matches,include_file_mtime,locale,time_macros cmake -S . -B build-host -G Ninja \ -DCMAKE_C_COMPILER_LAUNCHER=ccache \ -DCMAKE_CXX_COMPILER_LAUNCHER=ccache cmake --build build-host --target protoc --parallel else echo "Using cached host protoc build" fi - name: Cache Android build uses: actions/cache@v5 with: path: build-android-${{ matrix.abi }} key: ${{ runner.os }}-android-build-${{ matrix.abi }}-${{ hashFiles('src/**', 'CMakeLists.txt', 'cmake/**', 'thirdparty/**') }}-stl-fix-3 - name: Configure and Build shell: bash run: | git submodule foreach --recursive 'git stash --include-untracked' export ANDROID_SDK_ROOT="$ANDROID_HOME" export ANDROID_NDK_HOME="$ANDROID_SDK_ROOT/ndk/26.1.10909125" export CCACHE_BASEDIR="$GITHUB_WORKSPACE" export CCACHE_NOHASHDIR=1 export CCACHE_SLOPPINESS=clang_index_store,file_stat_matches,include_file_mtime,locale,time_macros if [ ! -d "build-android-${{ matrix.abi }}" ]; then cmake -S . 
-B build-android-${{ matrix.abi }} -G Ninja \ -DCMAKE_BUILD_TYPE=Release \ -DCMAKE_TOOLCHAIN_FILE="$ANDROID_NDK_HOME/build/cmake/android.toolchain.cmake" \ -DANDROID_ABI=${{ matrix.abi }} \ -DANDROID_PLATFORM=android-${{ matrix.api }} \ -DANDROID_STL=c++_static \ -DBUILD_PYTHON_BINDINGS=OFF \ -DENABLE_NATIVE=OFF \ -DAUTO_DETECT_ARCH=OFF \ -DBUILD_TOOLS=OFF \ -DGLOBAL_CC_PROTOBUF_PROTOC="$GITHUB_WORKSPACE/build-host/bin/protoc" \ -DCMAKE_C_COMPILER_LAUNCHER=ccache \ -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ -DCMAKE_VERBOSE_MAKEFILE=ON cmake --build build-android-${{ matrix.abi }} --parallel --verbose else echo "Using cached Android build directory" fi - name: Cache examples build uses: actions/cache@v5 with: path: examples/c++/build-android-examples-${{ matrix.abi }} key: ${{ runner.os }}-examples-build-${{ matrix.abi }}-${{ hashFiles('examples/c++/**', 'CMakeLists.txt', 'src/**') }}-stl-fix-3 - name: Build examples shell: bash run: | export ANDROID_SDK_ROOT="$ANDROID_HOME" export ANDROID_NDK_HOME="$ANDROID_SDK_ROOT/ndk/26.1.10909125" if [ ! 
-d "examples/c++/build-android-examples-${{ matrix.abi }}" ]; then cmake -S examples/c++ -B examples/c++/build-android-examples-${{ matrix.abi }} -G Ninja \ -DCMAKE_TOOLCHAIN_FILE="$ANDROID_NDK_HOME/build/cmake/android.toolchain.cmake" \ -DANDROID_ABI=${{ matrix.abi }} \ -DANDROID_PLATFORM=android-${{ matrix.api }} \ -DANDROID_STL=c++_static \ -DCMAKE_BUILD_TYPE=Release \ -DCMAKE_INTERPROCEDURAL_OPTIMIZATION=ON \ -DHOST_BUILD_DIR="build-android-${{ matrix.abi }}" \ -DCMAKE_C_COMPILER_LAUNCHER=ccache \ -DCMAKE_CXX_COMPILER_LAUNCHER=ccache cmake --build examples/c++/build-android-examples-${{ matrix.abi }} --parallel else echo "Using cached examples build" fi - name: Run on Android emulator (${{ matrix.abi }}) and verify uses: reactivecircus/android-emulator-runner@v2 with: api-level: ${{ matrix.api }} arch: ${{ matrix.abi }} # target: google_apis # emulator-options: -no-window -gpu swiftshader_indirect -noaudio -no-boot-anim -netdelay none -netspeed full # disable-animations: true script: | adb wait-for-device echo "Device ABI:" adb shell getprop ro.product.cpu.abi adb shell getprop ro.product.cpu.abilist echo "=== CPU ISA / Instruction Set Support ===" echo "--- /proc/cpuinfo flags ---" adb shell 'cat /proc/cpuinfo | grep -E "^(Features|flags)"' echo "Checking binary sizes:" ls -lah examples/c++/build-android-examples-${{ matrix.abi }}/ # Push executables to device adb push examples/c++/build-android-examples-${{ matrix.abi }}/ailego-example /data/local/tmp/ adb push examples/c++/build-android-examples-${{ matrix.abi }}/core-example /data/local/tmp/ adb push examples/c++/build-android-examples-${{ matrix.abi }}/db-example /data/local/tmp/ adb shell chmod 755 /data/local/tmp/ailego-example adb shell chmod 755 /data/local/tmp/core-example adb shell chmod 755 /data/local/tmp/db-example echo "File info on device:" adb shell ls -la /data/local/tmp/ailego-example adb shell ls -la /data/local/tmp/core-example adb shell ls -la /data/local/tmp/db-example echo "Running ailego example:"
adb shell 'cd /data/local/tmp && ./ailego-example' echo "Running core example:" adb shell 'cd /data/local/tmp && ./core-example' echo "Running db example:" adb shell 'cd /data/local/tmp && ./db-example' ================================================ FILE: .github/workflows/_build_wheel_job.yml ================================================ name: "(Reusable) Build, Publish and Smoke-test a Wheel" on: workflow_call: inputs: runner: description: "GitHub Actions runner label" required: true type: string pypi_repository_url: description: "PyPI repository URL (empty string means official PyPI)" required: false type: string default: "" secrets: PYPI_API_TOKEN: required: true jobs: build_publish_test: name: Build / publish / smoke-test on ${{ inputs.runner }} runs-on: ${{ inputs.runner }} permissions: contents: read steps: - name: Checkout code uses: actions/checkout@v6 with: submodules: recursive - name: Set up Python (for cibuildwheel controller) uses: actions/setup-python@v6 with: python-version: '3.11' - name: Install cibuildwheel run: | pip install --upgrade pip pip install cibuildwheel==3.4.0 - name: Build wheels using cibuildwheel run: | python -m cibuildwheel --output-dir wheelhouse # Save list of built wheels for publishing ls wheelhouse/*.whl | tee $GITHUB_STEP_SUMMARY echo "wheels=$(ls wheelhouse/*.whl | tr '\n' ' ')" >> $GITHUB_ENV - name: Publish to PyPI if: success() && github.event_name == 'workflow_dispatch' env: TWINE_USERNAME: __token__ TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} TWINE_REPOSITORY_URL: ${{ inputs.pypi_repository_url }} run: | pip install twine twine upload --skip-existing --verbose wheelhouse/*.whl - name: Smoke test from PyPI if: success() && github.event_name == 'workflow_dispatch' shell: bash env: PYPI_REPOSITORY_URL: ${{ inputs.pypi_repository_url }} run: | # Extract version from wheel filename (e.g. 
zvec-0.2.1.dev24-cp311-...whl -> 0.2.1.dev24) WHEEL_FILE=$(ls wheelhouse/zvec-*.whl | head -1) ZVEC_VERSION=$(basename "$WHEEL_FILE" | sed 's/zvec-\([^-]*\)-.*/\1/') # Build index-url flags: use TestPyPI when repository URL is set, otherwise official PyPI if [ -n "$PYPI_REPOSITORY_URL" ]; then INDEX_FLAGS="--index-url https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple/" INDEX_NAME="TestPyPI" else INDEX_FLAGS="" INDEX_NAME="PyPI" fi echo "Waiting for zvec==$ZVEC_VERSION to become available on $INDEX_NAME..." # Poll until the version is available (max 5 minutes) FOUND=0 for i in $(seq 1 30); do if pip install $INDEX_FLAGS --dry-run "zvec==$ZVEC_VERSION" > /dev/null 2>&1; then echo "Version $ZVEC_VERSION is available." FOUND=1 break fi echo "Attempt $i/30: not yet available, retrying in 10s..." sleep 10 done if [ "$FOUND" -eq 0 ]; then echo "ERROR: Timed out (5 min) waiting for zvec==$ZVEC_VERSION on $INDEX_NAME. Aborting smoke test." exit 1 fi # Create a clean venv and install the published wheel once python -m venv test_env source test_env/bin/activate pip install --upgrade pip pip install $INDEX_FLAGS "zvec==$ZVEC_VERSION" # Run a simple smoke test python -c "import zvec; print('Import OK:', zvec.__version__)" ================================================ FILE: .github/workflows/build_test_wheel.yml ================================================ name: Build Test PyPi Wheels on: workflow_dispatch: permissions: contents: read jobs: build_wheels_linux_x64: name: Build wheels on ubuntu-24.04 (x64) for TestPyPi uses: ./.github/workflows/_build_wheel_job.yml with: runner: ubuntu-24.04 pypi_repository_url: https://test.pypi.org/legacy/ secrets: PYPI_API_TOKEN: ${{ secrets.TEST_PYPI_API_TOKEN }} build_wheels_linux_arm64: name: Build wheels on ubuntu-24.04-arm (arm64) for TestPyPi uses: ./.github/workflows/_build_wheel_job.yml with: runner: ubuntu-24.04-arm pypi_repository_url:
https://test.pypi.org/legacy/ secrets: PYPI_API_TOKEN: ${{ secrets.TEST_PYPI_API_TOKEN }} build_wheels_macos_arm64: name: Build wheels on macos-15 (arm64) for TestPyPi uses: ./.github/workflows/_build_wheel_job.yml with: runner: macos-15 pypi_repository_url: https://test.pypi.org/legacy/ secrets: PYPI_API_TOKEN: ${{ secrets.TEST_PYPI_API_TOKEN }} ================================================ FILE: .github/workflows/build_wheel.yml ================================================ name: Build Wheels on: workflow_dispatch: permissions: contents: read jobs: build_wheels_linux_x64: name: Build wheels on ubuntu-24.04 (x64) for PyPi uses: ./.github/workflows/_build_wheel_job.yml with: runner: ubuntu-24.04 secrets: PYPI_API_TOKEN: ${{ secrets.PYPI_API_TOKEN }} build_wheels_linux_arm64: name: Build wheels on ubuntu-24.04-arm (arm64) for PyPi uses: ./.github/workflows/_build_wheel_job.yml with: runner: ubuntu-24.04-arm secrets: PYPI_API_TOKEN: ${{ secrets.PYPI_API_TOKEN }} build_wheels_macos_arm64: name: Build wheels on macos-15 (arm64) for PyPi uses: ./.github/workflows/_build_wheel_job.yml with: runner: macos-15 secrets: PYPI_API_TOKEN: ${{ secrets.PYPI_API_TOKEN }} ================================================ FILE: .github/workflows/continuous_bench.yml ================================================ name: Continuous Benchmark on: push: branches: [ "main", "ci/continuous_bench_squash" ] paths-ignore: - '**.md' workflow_dispatch: concurrency: group: cb-${{ github.workflow }}-${{ github.ref }} cancel-in-progress: true permissions: contents: read jobs: benchmark: runs-on: vdbbench steps: - uses: actions/checkout@v6 - name: Run VectorDBBench env: DATABASE_URL: ${{ secrets.DATABASE_URL }} run: | bash .github/workflows/scripts/run_vdb.sh ================================================ FILE: .github/workflows/docker/Dockerfile.linux_x64_glibc228 ================================================ # 
============================================================================= # Dockerfile.linux_x64_glibc228 # Purpose: Ubuntu 18.10 gcc-9 + glibc 2.28 + CMake 3.30.0 + PyBind11 build environment # Warning: ubuntu:18.10 is EOL; use only for glibc 2.28 compatibility testing. # ============================================================================= # Use official Ubuntu 18.10 (Cosmic Cuttlefish) # glibc version: 2.28 (confirmed via `ldd --version`) FROM ubuntu:18.10 # Replace Ubuntu mirror with old-releases.ubuntu.com for older glibc compatibility RUN sed -i 's|http://\(.*\)/ubuntu|http://old-releases.ubuntu.com/ubuntu|g' /etc/apt/sources.list && \ sed -i 's|http://security.ubuntu.com/ubuntu|http://old-releases.ubuntu.com/ubuntu|g' /etc/apt/sources.list # Add Ubuntu 20.04 (focal) repo for GCC 9 ONLY RUN echo "deb http://archive.ubuntu.com/ubuntu/ focal main universe" >> /etc/apt/sources.list && \ echo "deb http://security.ubuntu.com/ubuntu/ focal-security main universe" >> /etc/apt/sources.list # Prevent interactive prompts & set non-root user ENV DEBIAN_FRONTEND=noninteractive \ TZ=Etc/UTC # Create non-root user for safety (optional but recommended) RUN useradd -m -u 1000 builder && \ mkdir -p /workspace && chown builder:builder /workspace # Install base system dependencies RUN apt-get update && \ apt-get install -y --no-install-recommends \ build-essential \ gcc-9 g++-9 \ ninja-build git curl ca-certificates vim wget lcov gnupg clang-format-18\ rsync lsb-release \ uuid-dev zlib1g-dev libssl-dev libffi-dev \ pybind11-dev && \ update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-9 90 \ --slave /usr/bin/g++ g++ /usr/bin/g++-9 && \ rm -rf /var/lib/apt/lists/* # Install Miniforge (Conda) as root, then assign to builder ENV MINIFORGE_VERSION="latest" ENV MINIFORGE_HOME="/opt/miniforge3" RUN curl -sSL "https://github.com/conda-forge/miniforge/releases/${MINIFORGE_VERSION}/download/Miniforge3-Linux-x86_64.sh" -o miniforge.sh && \ bash miniforge.sh -b -p 
${MINIFORGE_HOME} && \ rm miniforge.sh && \ chown -R builder:builder ${MINIFORGE_HOME} # Switch to non-root user USER builder ENV PATH="${MINIFORGE_HOME}/bin:${PATH}" WORKDIR /workspace # Create conda envs for supported Python versions RUN conda create -n py310 python=3.10 -y && \ conda create -n py311 python=3.11 -y && \ conda create -n py312 python=3.12 -y RUN conda clean --all -f -y # Install CMake 3.30.0 from Kitware official binary # Ref: https://github.com/Kitware/CMake/releases/tag/v3.30.0 RUN mkdir -p /tmp/cmake && cd /tmp/cmake && \ curl -sSL -o cmake.tar.gz \ "https://github.com/Kitware/CMake/releases/download/v3.30.0/cmake-3.30.0-linux-x86_64.tar.gz" && \ tar -xzf cmake.tar.gz --strip-components=1 -C /tmp/cmake && \ mkdir -p /home/builder/.local && \ mv * /home/builder/.local/ && \ chown -R builder:builder /home/builder/.local && \ rm -rf /tmp/cmake # Add CMake to PATH ENV PATH="/home/builder/.local/bin:${PATH}" # Verify installations RUN cmake --version && \ conda info && \ conda env list && \ python --version && \ gcc --version && \ ldd --version | head -n1 # Final setup WORKDIR /workspace ================================================ FILE: .github/workflows/nightly_coverage.yml ================================================ name: Nightly Coverage Report on: schedule: # Runs daily at 00:00 CST (China Standard Time) = 16:00 UTC - cron: '0 16 * * *' workflow_dispatch: permissions: contents: read jobs: coverage: name: Nightly Coverage Report runs-on: ubuntu-24.04 strategy: matrix: python-version: ['3.10'] fail-fast: false steps: - name: Checkout code uses: actions/checkout@v6 with: ref: main # Always use main for nightly submodules: recursive - name: Set up Python uses: actions/setup-python@v6 with: python-version: ${{ matrix.python-version }} cache: 'pip' cache-dependency-path: 'pyproject.toml' - name: Set up environment variables run: | # Set number of processors for parallel builds NPROC=$(nproc 2>/dev/null || echo 2) echo "NPROC=$NPROC" >> 
$GITHUB_ENV echo "Using $NPROC parallel jobs for builds" # Add Python user base bin to PATH for pip-installed CLI tools echo "$(python -c 'import site; print(site.USER_BASE)')/bin" >> $GITHUB_PATH shell: bash - name: Install dependencies run: | python -m pip install --upgrade pip \ cmake==3.30.0 \ ninja==1.11.1 \ pytest \ pytest-cov \ scikit-build-core \ setuptools_scm shell: bash - name: Build with COVERAGE config run: | cd "$GITHUB_WORKSPACE" CMAKE_GENERATOR="Unix Makefiles" \ CMAKE_BUILD_PARALLEL_LEVEL="$NPROC" \ python -m pip install -v . \ --no-build-isolation \ --config-settings="cmake.build-type=COVERAGE" \ --config-settings='cmake.define.ENABLE_ZEN3="ON"' shell: bash - name: Run Python Tests with Coverage run: | cd "$GITHUB_WORKSPACE" python -m pytest python/tests/ --cov=zvec --cov-report=xml shell: bash - name: Run C++ Tests and Generate Coverage run: | cd "$GITHUB_WORKSPACE/build" make unittest -j$NPROC cd "$GITHUB_WORKSPACE" # Ensure gcov.sh is executable chmod +x scripts/gcov.sh bash scripts/gcov.sh -k shell: bash - name: Upload Coverage to Codecov uses: codecov/codecov-action@v5 with: files: ./proxima-zvec-filtered.lcov.info,./coverage.xml flags: python,cpp,nightly name: nightly-linux-py${{ matrix.python-version }} token: ${{ secrets.CODECOV_TOKEN }} ================================================ FILE: .github/workflows/scripts/run_vdb.sh ================================================ set -e QUANTIZE_TYPE_LIST="int8 int4 fp16 fp32" CASE_TYPE_LIST="Performance768D1M Performance768D10M Performance1536D500K" # respectively test cosine, ip # Performance960D1M l2 metrics LOG_FILE="bench.log" DATE=$(date +%Y-%m-%d_%H-%M-%S) NPROC=$(nproc 2>/dev/null || getconf _NPROCESSORS_ONLN 2>/dev/null || echo 2) # COMMIT_ID = branch-date-sha COMMIT_ID=${GITHUB_REF_NAME}-"$DATE"-$(echo ${GITHUB_WORKFLOW_SHA} | cut -c1-8) COMMIT_ID=$(echo "$COMMIT_ID" | sed 's/\//_/g') echo "COMMIT_ID: $COMMIT_ID" echo "GITHUB_WORKFLOW_SHA: $GITHUB_WORKFLOW_SHA" echo "workspace: 
$GITHUB_WORKSPACE" DB_LABEL_PREFIX="Zvec16c64g-$COMMIT_ID" # install zvec git submodule update --init # for debug #cd .. #export SKBUILD_BUILD_DIR="$GITHUB_WORKSPACE/../build" pwd python3 -m venv .venv source .venv/bin/activate pip install cmake ninja psycopg2-binary loguru fire pip install -e /opt/VectorDBBench CMAKE_GENERATOR="Unix Makefiles" \ CMAKE_BUILD_PARALLEL_LEVEL="$NPROC" \ pip install -v "$GITHUB_WORKSPACE" for CASE_TYPE in $CASE_TYPE_LIST; do echo "Running VectorDBBench for $CASE_TYPE" DATASET_DESC="" if [ "$CASE_TYPE" == "Performance768D1M" ]; then DATASET_DESC="Performance768D1M - Cohere Cosine" elif [ "$CASE_TYPE" == "Performance768D10M" ]; then DATASET_DESC="Performance768D10M - Cohere Cosine" else DATASET_DESC="Performance1536D500K - OpenAI IP" fi for QUANTIZE_TYPE in $QUANTIZE_TYPE_LIST; do DB_LABEL="$DB_LABEL_PREFIX-$CASE_TYPE-$QUANTIZE_TYPE" echo "Running VectorDBBench for $DB_LABEL" VDB_PARAMS="--path ${DB_LABEL} --db-label ${DB_LABEL} --case-type ${CASE_TYPE} --num-concurrency 12,14,16,18,20" if [ "$CASE_TYPE" == "Performance768D1M" ]; then VDB_PARAMS="${VDB_PARAMS} --m 15 --ef-search 180" elif [ "$CASE_TYPE" == "Performance768D10M" ]; then VDB_PARAMS="${VDB_PARAMS} --m 50 --ef-search 118 --is-using-refiner" else #Performance1536D500K using default params + refiner to monitor performance degradation VDB_PARAMS="${VDB_PARAMS} --m 50 --ef-search 100 --is-using-refiner" fi if [ "$QUANTIZE_TYPE" == "fp32" ]; then vectordbbench zvec ${VDB_PARAMS} 2>&1 | tee $LOG_FILE else vectordbbench zvec ${VDB_PARAMS} --quantize-type "${QUANTIZE_TYPE}" 2>&1 | tee $LOG_FILE fi RESULT_JSON_PATH=$(grep -o "/opt/VectorDBBench/.*\.json" $LOG_FILE) QPS=$(jq -r '.results[0].metrics.qps' "$RESULT_JSON_PATH") RECALL=$(jq -r '.results[0].metrics.recall' "$RESULT_JSON_PATH") LATENCY_P99=$(jq -r '.results[0].metrics.serial_latency_p99' "$RESULT_JSON_PATH") LOAD_DURATION=$(jq -r '.results[0].metrics.load_duration' "$RESULT_JSON_PATH") #quote the var to avoid space in the 
label label_list="case_type=\"${CASE_TYPE}\",dataset_desc=\"${DATASET_DESC}\",db_label=\"${DB_LABEL}\",commit=\"${COMMIT_ID}\",date=\"${DATE}\",quantize_type=\"${QUANTIZE_TYPE}\"" # replace `/` with `_` in label_list label_list=$(echo "$label_list" | sed 's/\//_/g') # write the Prometheus exposition payload via a here-document (previously `cat < file`, which read a not-yet-existing file) cat > prom_metrics.txt <<EOF # TYPE vdb_bench_qps gauge vdb_bench_qps{$label_list} $QPS # TYPE vdb_bench_recall gauge vdb_bench_recall{$label_list} $RECALL # TYPE vdb_bench_latency_p99 gauge vdb_bench_latency_p99{$label_list} $LATENCY_P99 # TYPE vdb_bench_load_duration gauge vdb_bench_load_duration{$label_list} $LOAD_DURATION EOF echo "prom_metrics:" cat prom_metrics.txt curl --data-binary @prom_metrics.txt "http://47.93.34.27:9091/metrics/job/benchmarks-${CASE_TYPE}/case_type/${CASE_TYPE}/quantize_type/${QUANTIZE_TYPE}" -v done done ================================================ FILE: .gitignore ================================================ .* *~ bazel-* build* bin/* lib/* var/* venv* tests/integration/conf/* tests/de_integration/conf/* **/__pycache__/* tests/bench/log/* tests/integration/integration tests/integration/log tests/integration/*.log tests/de_integration/log tests/de_integration/*.log !.git* !.clang-format !.circleci !.drone.yml sdk/python/dist/ compile_commands.json dist html *.lcov.info # Dependencies /node_modules # Production /build # Generated files .docusaurus .cache-loader # Misc .DS_Store .env.local .env.development.local .env.test.local .env.production.local npm-debug.log* yarn-debug.log* yarn-error.log* allure-* !build_android.sh ================================================ FILE: .gitmodules ================================================ [submodule "thirdparty/googletest/googletest-1.10.0"] path = thirdparty/googletest/googletest-1.10.0 url = https://github.com/google/googletest.git [submodule "thirdparty/sparsehash/sparsehash-2.0.4"] path = thirdparty/sparsehash/sparsehash-2.0.4 url = https://github.com/sparsehash/sparsehash.git ignore = untracked [submodule
"thirdparty/gflags/gflags-2.2.2"] path = thirdparty/gflags/gflags-2.2.2 url = https://github.com/gflags/gflags.git [submodule "thirdparty/rocksdb/rocksdb-8.1.1"] path = thirdparty/rocksdb/rocksdb-8.1.1 url = https://github.com/facebook/rocksdb.git ignore = all [submodule "thirdparty/yaml-cpp/yaml-cpp-0.6.3"] path = thirdparty/yaml-cpp/yaml-cpp-0.6.3 url = https://github.com/jbeder/yaml-cpp.git [submodule "thirdparty/arrow/apache-arrow-21.0.0"] path = thirdparty/arrow/apache-arrow-21.0.0 url = https://github.com/apache/arrow.git ignore = all [submodule "thirdparty/CRoaring/CRoaring-2.0.4"] path = thirdparty/CRoaring/CRoaring-2.0.4 url = https://github.com/RoaringBitmap/CRoaring.git [submodule "thirdparty/glog/glog-0.5.0"] path = thirdparty/glog/glog-0.5.0 url = https://github.com/google/glog.git ignore = all [submodule "thirdparty/protobuf/protobuf-3.21.12"] path = thirdparty/protobuf/protobuf-3.21.12 url = https://github.com/protocolbuffers/protobuf.git [submodule "thirdparty/lz4/lz4-1.9.4"] path = thirdparty/lz4/lz4-1.9.4 url = https://github.com/lz4/lz4.git [submodule "thirdparty/antlr/antlr4"] path = thirdparty/antlr/antlr4 url = https://github.com/antlr/antlr4.git ignore = all [submodule "thirdparty/magic_enum/magic_enum-0.9.7"] path = thirdparty/magic_enum/magic_enum-0.9.7 url = https://github.com/Neargye/magic_enum.git ignore = all [submodule "thirdparty/RaBitQ-Library/RaBitQ-Library-0.1"] path = thirdparty/RaBitQ-Library/RaBitQ-Library-0.1 url = https://github.com/VectorDB-NTU/RaBitQ-Library.git ================================================ FILE: CMakeLists.txt ================================================
# Top-level build script for zvec: global compiler/linker flags, the shared
# helper modules (cmake/bazel.cmake, cmake/option.cmake), CPU-feature gating
# for RaBitQ, subdirectory registration, CPack packaging, and the optional
# install step for the Python extension module.
cmake_minimum_required(VERSION 3.13)
# CMP0077 NEW: option() does not override a normal variable of the same name
# that is already set, so an embedding project can pre-seed these options.
cmake_policy(SET CMP0077 NEW)
project(zvec)
# NOTE(review): CC_CXX_STANDARD is not read in this file; presumably it is
# consumed by the cc_* helpers from cmake/bazel.cmake — confirm.
set(CC_CXX_STANDARD 17)
# Enable common warnings and turn a missing return statement into an error.
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Werror=return-type")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Wall -Werror=return-type")
if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
  # With GCC, force --no-as-needed so every library on the link line is
  # recorded as a dependency even if none of its symbols is referenced.
  # NOTE(review): presumably required for libraries that self-register via
  # static initializers — confirm.
  set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -Wl,--no-as-needed")
  set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -Wl,--no-as-needed")
endif()
# Default PROJECT_ROOT_DIR to this source directory unless a non-empty value
# was already provided; it is written to the cache (FORCE) so later configure
# runs see the same value.
if(NOT DEFINED PROJECT_ROOT_DIR OR NOT PROJECT_ROOT_DIR)
  set(PROJECT_ROOT_DIR ${CMAKE_CURRENT_SOURCE_DIR} CACHE PATH "Root directory of the project" FORCE)
endif()
message(STATUS "PROJECT_ROOT_DIR = ${PROJECT_ROOT_DIR}")
include(${PROJECT_ROOT_DIR}/cmake/bazel.cmake)
include(${PROJECT_ROOT_DIR}/cmake/option.cmake)
# On non-Android x86 hosts with AUTO_DETECT_ARCH enabled, probe the best
# -march flags for each SIMD level.
# NOTE(review): AUTO_DETECT_ARCH and setup_compiler_march_for_x86 are not
# defined in this file; presumably they come from the included cmake modules.
if(NOT ANDROID AND AUTO_DETECT_ARCH AND CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64|i686|i386|x64")
  setup_compiler_march_for_x86(MATH_MARCH_FLAG_SSE MATH_MARCH_FLAG_AVX2 MATH_MARCH_FLAG_AVX512 MATH_MARCH_FLAG_AVX512FP16)
  message(STATUS "best compiler march, sse: " ${MATH_MARCH_FLAG_SSE} ", avx2: " ${MATH_MARCH_FLAG_AVX2} ", avx512: " ${MATH_MARCH_FLAG_AVX512} ", avx512fp16: " ${MATH_MARCH_FLAG_AVX512FP16})
endif()
include_directories(${PROJECT_ROOT_DIR}/src/include)
include_directories(${PROJECT_ROOT_DIR}/src)
option(BUILD_PYTHON_BINDINGS "Build Python bindings using pybind11" OFF)
message(STATUS "BUILD_PYTHON_BINDINGS:${BUILD_PYTHON_BINDINGS}")
option(BUILD_TOOLS "Build tools" ON)
message(STATUS "BUILD_TOOLS:${BUILD_TOOLS}")
# RaBitQ gating: enabled only on non-Android Linux x86_64 when the C compiler
# accepts AVX2 or AVX-512 flags. RABITQ_SUPPORTED and the RABITQ_SUPPORTED=0/1
# compile definition record the outcome. AVX-512 is selected only when
# RABITQ_ENABLE_AVX512 is ON *and* the AVX-512 flags are accepted; otherwise
# the AVX2 march flag is used.
# NOTE(review): RABITQ_ARCH_FLAG copies MATH_MARCH_FLAG_AVX2/AVX512, which this
# file only sets when AUTO_DETECT_ARCH is on — it may be empty otherwise;
# confirm option.cmake provides defaults.
option(RABITQ_ENABLE_AVX512 "Compile RaBitQ with AVX-512 support" OFF)
if(CMAKE_SYSTEM_NAME STREQUAL "Linux" AND CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64|amd64|AMD64" AND NOT ANDROID)
  include(CheckCCompilerFlag)
  check_c_compiler_flag("-mavx2" COMPILER_SUPPORTS_AVX2)
  check_c_compiler_flag("-mavx512f -mavx512bw -mavx512vl" COMPILER_SUPPORTS_AVX512)
  if(COMPILER_SUPPORTS_AVX2 OR COMPILER_SUPPORTS_AVX512)
    set(RABITQ_SUPPORTED ON)
    add_definitions(-DRABITQ_SUPPORTED=1)
    if(RABITQ_ENABLE_AVX512 AND COMPILER_SUPPORTS_AVX512)
      add_definitions(-DRABITQ_COMPILED_AVX512=1)
      set(RABITQ_ARCH_FLAG "${MATH_MARCH_FLAG_AVX512}")
    else()
      set(RABITQ_ARCH_FLAG "${MATH_MARCH_FLAG_AVX2}")
    endif()
  else()
    set(RABITQ_SUPPORTED OFF)
    add_definitions(-DRABITQ_SUPPORTED=0)
    message(STATUS "RaBitQ support disabled - compiler does not support AVX2 or AVX-512")
  endif()
else()
  set(RABITQ_SUPPORTED OFF)
  add_definitions(-DRABITQ_SUPPORTED=0)
  message(STATUS "RaBitQ support disabled - only supported on Linux x86_64")
endif()
message(STATUS "RABITQ_ARCH_FLAG: ${RABITQ_ARCH_FLAG}")
# A non-empty USE_OSS_MIRROR environment variable overrides the cached option
# (FORCE), so CI can toggle the mirror without passing -D.
option(USE_OSS_MIRROR "Use OSS mirror for faster third-party downloads" OFF)
if(DEFINED ENV{USE_OSS_MIRROR} AND NOT "$ENV{USE_OSS_MIRROR}" STREQUAL "")
  set(USE_OSS_MIRROR "$ENV{USE_OSS_MIRROR}" CACHE BOOL "Use OSS mirror for faster third-party downloads" FORCE)
endif()
message(STATUS "USE_OSS_MIRROR:${USE_OSS_MIRROR}")
# Register subprojects (cc_directory/cc_directories come from bazel.cmake).
cc_directory(thirdparty)
cc_directories(src)
cc_directories(tests)
if(BUILD_TOOLS)
  cc_directories(tools)
endif()
# Package version is derived from the git state of the source tree.
git_version(GIT_SRCS_VER ${PROJECT_ROOT_DIR})
set(CPACK_PACKAGE_VERSION ${GIT_SRCS_VER})
set(CPACK_PACKAGE_NAME zvec)
include(CPack)
# Python extension install. Destination priority: scikit-build's platlib dir,
# then Python_SITEARCH, then CMAKE_INSTALL_LIBDIR.
if(BUILD_PYTHON_BINDINGS)
  if(APPLE)
    # Stripping on macOS is disabled (see message) to keep the code signature.
    set(CMAKE_STRIP "")
    message(STATUS "Disabled strip on macOS to preserve code signature")
  endif()
  include(GNUInstallDirs)
  if(DEFINED SKBUILD_PLATLIB_DIR)
    set(ZVEC_PY_INSTALL_DIR "${SKBUILD_PLATLIB_DIR}")
  elseif(DEFINED Python_SITEARCH)
    set(ZVEC_PY_INSTALL_DIR "${Python_SITEARCH}")
  else()
    set(ZVEC_PY_INSTALL_DIR "${CMAKE_INSTALL_LIBDIR}")
  endif()
  message(STATUS "Zvec install path: ${ZVEC_PY_INSTALL_DIR}")
  install(TARGETS _zvec LIBRARY DESTINATION ${ZVEC_PY_INSTALL_DIR})
endif()
================================================ FILE: CODE_OF_CONDUCT.md ================================================ # Zvec Code of Conduct ## Our Pledge We pledge to foster an open, respectful, and harassment-free environment for everyone, regardless of age, body size, disability, ethnicity, gender identity and expression, level of experience, education, socioeconomic status, nationality, personal appearance, race, religion, or sexual identity and orientation.
## Expected Behavior - Use welcoming and inclusive language - Respect differing viewpoints and experiences - Gracefully accept constructive criticism - Focus on what is best for the community - Show empathy and kindness toward others ## Unacceptable Behavior - Harassment, intimidation, or discriminatory conduct - Trolling, insulting, or derogatory comments - Public or private harassment - Publishing others’ private information without consent - Any conduct that would reasonably be considered inappropriate in a professional setting ## Enforcement Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the project team at **zvec@alibaba-inc.com**. All complaints will be reviewed and investigated promptly and fairly. The project team is obligated to respect the privacy and security of the reporter. Consequences may include: - A formal warning - Temporary or permanent ban from project spaces - Removal of contributions (e.g., comments, PRs) ## Attribution This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 2.1, available at https://www.contributor-covenant.org/version/2/1/code_of_conduct.html [homepage]: https://www.contributor-covenant.org ================================================ FILE: CONTRIBUTING.md ================================================ # Contributing to Zvec First off, thank you for considering contributing to Zvec! 🙌 Whether you're reporting a bug, proposing a feature, improving documentation, or submitting code — every contribution helps make Zvec better. ## Code of Conduct By participating, you agree to abide by our [Code of Conduct](CODE_OF_CONDUCT.md). Please be respectful, collaborative, and inclusive.
--- ## Development Setup ### Prerequisites - Python 3.10 - 3.12 - CMake ≥ 3.26, < 4.0 (`cmake --version`) - A C++17-compatible compiler (e.g., GCC 11+ (`g++`), `clang++`, Apple Clang on macOS) ### Clone & Initialize ```bash git clone --recursive https://github.com/alibaba/zvec.git cd zvec ``` > 💡 **Tip** > - Forgot `--recursive`? Run: > ```bash > git submodule update --init --recursive > ``` > - Set up pre-commit hooks: > ```bash > pip install pre-commit && pre-commit install > ``` ### Build from Source (Editable Install) ```bash pip install -e ".[dev]" # This installs dev dependencies (pytest, ruff, etc.) and builds the C++ extension in-place ``` > ✅ Verify: > ```bash > python -c "import zvec; print('Success!')" > ``` --- ## Testing ### Run All Tests ```bash pytest python/tests/ -v ``` ### Run with Coverage (Debug/CI) ```bash pytest python/tests/ --cov=zvec --cov-report=term-missing ``` > 🔎 Python code is linted with Ruff (installed by the `dev` extras above); see the full rule set in the `[tool.ruff]` section of `pyproject.toml`. --- ## Build Customization You can control build behavior via environment variables or `pyproject.toml`: | Option | How to Set | Description | |--------|------------|-------------| | **Build Type** | `CMAKE_BUILD_TYPE=Debug` | `Debug`, `Release`, or `Coverage` (for gcov/lcov) | | **Generator** | `CMAKE_GENERATOR="Unix Makefiles"` | Default: `Ninja`; use Make if preferred | | **AVX-512** | `ENABLE_SKYLAKE_AVX512=ON` | Enable AVX-512 optimizations (x86_64 only) | Example (Debug + Make): ```bash CMAKE_BUILD_TYPE=Debug CMAKE_GENERATOR="Unix Makefiles" pip install -v . ``` --- ## Submitting Changes 1. Fork the repo and create a feature branch (`feat/...`, `fix/...`, `docs/...`) 2. Write clear commit messages (e.g., `fix(query): handle null vector in dense_fp32`) 3. Ensure tests pass & linter is clean 4. Open a Pull Request to `main` 5.
Link the related issue (e.g., `Closes #123`) ✅ **PRs should include**: - Test coverage for new behavior - Updates to documentation (if applicable) - Reasoning behind non-obvious design choices --- ## Documentation - User guides: `docs/` (built with MkDocs) - API reference: generated from docstrings (follow [Google style](https://google.github.io/styleguide/pyguide.html#38-comments-and-docstrings)) - Preview & build: `mkdocs serve` (live local preview) / `mkdocs build` (static site) --- ## Need Help? - Browse [existing issues](https://github.com/alibaba/zvec/issues) - For sensitive/security issues: email `zvec@alibaba-inc.com` --- ✨ Thanks again for being part of Zvec! ================================================ FILE: LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.
"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. 
Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. 
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. 
Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. 
Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: README.md ================================================
zvec logo

Code Coverage Main License PyPI Release Python Versions npm Release

alibaba%2Fzvec | Trendshift

🚀 Quickstart | 🏠 Home | 📚 Docs | 📊 Benchmarks | 🔎 DeepWiki | 🎮 Discord

**Zvec** is an open-source, in-process vector database — lightweight, lightning-fast, and designed to embed directly into applications. Built on **Proxima** (Alibaba's battle-tested vector search engine), it delivers production-grade, low-latency, scalable similarity search with minimal setup. ## 💫 Features - **Blazing Fast**: Searches billions of vectors in milliseconds. - **Simple, Just Works**: [Install](#-installation) and start searching in seconds. No servers, no config, no fuss. - **Dense + Sparse Vectors**: Work with both dense and sparse embeddings, with native support for multi-vector queries in a single call. - **Hybrid Search**: Combine semantic similarity with structured filters for precise results. - **Runs Anywhere**: As an in-process library, Zvec runs wherever your code runs — notebooks, servers, CLI tools, or even edge devices. ## 📦 Installation ### [Python](https://pypi.org/project/zvec/) **Requirements**: Python 3.10 - 3.12 ```bash pip install zvec ``` ### [Node.js](https://www.npmjs.com/package/@zvec/zvec) ```bash npm install @zvec/zvec ``` ### ✅ Supported Platforms - Linux (x86_64, ARM64) - macOS (ARM64) ### 🛠️ Building from Source If you prefer to build Zvec from source, please check the [Building from Source](https://zvec.org/en/docs/build/) guide. 
## ⚡ One-Minute Example ```python import zvec # Define collection schema schema = zvec.CollectionSchema( name="example", vectors=zvec.VectorSchema("embedding", zvec.DataType.VECTOR_FP32, 4), ) # Create collection collection = zvec.create_and_open(path="./zvec_example", schema=schema) # Insert documents collection.insert([ zvec.Doc(id="doc_1", vectors={"embedding": [0.1, 0.2, 0.3, 0.4]}), zvec.Doc(id="doc_2", vectors={"embedding": [0.2, 0.3, 0.4, 0.1]}), ]) # Search by vector similarity results = collection.query( zvec.VectorQuery("embedding", vector=[0.4, 0.3, 0.3, 0.1]), topk=10 ) # Results: list of {'id': str, 'score': float, ...}, sorted by relevance print(results) ``` ## 📈 Performance at Scale Zvec delivers exceptional speed and efficiency, making it ideal for demanding production workloads. Zvec Performance Benchmarks For detailed benchmark methodology, configurations, and complete results, please see our [Benchmarks documentation](https://zvec.org/en/docs/benchmarks/). ## 🤝 Join Our Community
Stay updated and get support — scan or click:
| 💬 DingTalk | 📱 WeChat | 🎮 Discord | |:---:|:---:|:---:| | | | [![Discord](https://img.shields.io/badge/Discord-Join%20Server-5865F2?style=for-the-badge&logo=discord&logoColor=white)](https://discord.gg/rKddFBBu9z) | | Scan to join | Scan to join | Click to join |
## ❤️ Contributing We welcome and appreciate contributions from the community! Whether you're fixing a bug, adding a feature, or improving documentation, your help makes Zvec better for everyone. Check out our [Contributing Guide](./CONTRIBUTING.md) to get started! ================================================ FILE: cmake/bazel.cmake ================================================ ## ## The following functions used by user's CMakeLists.txt: ## ## 1. Functions for C/C++ ## ## 1.1. Add a subdirectory to the build ## cc_directory( [binary_dir]) ## ## 1.2. Add subdirectories to the build ## cc_directories( [source_dir2 ...]) ## ## 1.3. Build a C/C++ static or shared library ## cc_library( ## NAME ## [STATIC] [SHARED] [STRICT] [ALWAYS_LINK] [EXCLUDE] [PACKED] [SRCS_NO_GLOB] ## SRCS [file2 ...] ## [INCS dir1 ...] ## [PUBINCS public_dir1 ...] ## [DEFS DEF1=1 ...] ## [LIBS lib1 ...] ## [CFLAGS flag1 ...] ## [CXXFLAGS flag1 ...] ## [LDFLAGS flag1 ...] ## [DEPS target1 ...] ## [PACKED_EXCLUDES pattern1 ...] ## [VERSION ] ## ) ## ## 1.4. Build a C/C++ executable program ## cc_binary( ## NAME ## [STRICT] [PACKED] ## SRCS [file2 ...] ## [INCS dir1 ...] ## [DEFS DEF1=1 ...] ## [LIBS lib1 ...] ## [CFLAGS flag1 ...] ## [CXXFLAGS flag1 ...] ## [LDFLAGS flag1 ...] ## [DEPS target1 ...] ## [VERSION ] ## ) ## ## 1.5. Build a C/C++ executable test program ## cc_test( ## NAME ## [STRICT] ## SRCS [file2 ...] ## [INCS dir1 ...] ## [DEFS DEF1=1 ...] ## [LIBS lib1 ...] ## [CFLAGS flag1 ...] ## [CXXFLAGS flag1 ...] ## [LDFLAGS flag1 ...] ## [DEPS target1 ...] ## [ARGS args1 ...] ## [VERSION ] ## ) ## ## 1.6. Add existing test cases to a test suite ## cc_test_suite( [test_name ...]) ## ## 1.7. Import a C/C++ static or shared library ## cc_import( ## NAME ## [STATIC | SHARED] [PACKED] ## PATH ## [INCS dir1 ...] ## [PUBINCS public_dir1 ...] ## [DEPS target1 ...] ## [IMPLIB ] ## [PACKED_EXCLUDES pattern1 ...] ## ) ## ## 1.8. 
Import a C/C++ interface library ## cc_interface( ## NAME ## [PACKED] ## [INCS dir1 ...] ## [PUBINCS public_dir1 ...] ## [DEPS target1 ...] ## [PACKED_EXCLUDES pattern1 ...] ## ) ## ## 1.9. Build a C/C++ executable google test program ## cc_gtest( ## NAME ## [STRICT] ## SRCS [file2 ...] ## [INCS dir1 ...] ## [DEFS DEF1=1 ...] ## [LIBS lib1 ...] ## [CFLAGS flag1 ...] ## [CXXFLAGS flag1 ...] ## [LDFLAGS flag1 ...] ## [DEPS target1 ...] ## [ARGS args1 ...] ## [VERSION ] ## ) ## ## 1.10. Build a C/C++ executable google mock program ## cc_gmock( ## NAME ## [STRICT] ## SRCS [file2 ...] ## [INCS dir1 ...] ## [DEFS DEF1=1 ...] ## [LIBS lib1 ...] ## [CFLAGS flag1 ...] ## [CXXFLAGS flag1 ...] ## [LDFLAGS flag1 ...] ## [DEPS target1 ...] ## [ARGS args1 ...] ## [VERSION ] ## ) ## ## 1.11. Build a C++ protobuf static or shared library ## cc_proto_library( ## NAME ## [STATIC] [SHARED] [STRICT] [EXCLUDE] [PACKED] ## SRCS [file2.proto ...] ## [PROTOROOT path] ## [CXXFLAGS flag1 ...] ## [LDFLAGS flag1 ...] ## [DEPS target1 ...] ## [VERSION ] ## [PROTOBUF_VERSION ] ## ) ## ## 2. Functions for CUDA ## ## 2.1. Add a subdirectory to the build ## cuda_directory( [binary_dir]) ## ## 2.2. Add subdirectories to the build ## cuda_directories( [source_dir2 ...]) ## ## 2.3. Build a CUDA static or shared library ## cuda_library( ## NAME ## [STATIC] [SHARED] [STRICT] [ALWAYS_LINK] [EXCLUDE] [PACKED] ## SRCS [file2 ...] ## [INCS dir1 ...] ## [PUBINCS public_dir1 ...] ## [DEFS DEF1=1 ...] ## [LIBS lib1 ...] ## [CFLAGS flag1 ...] ## [CXXFLAGS flag1 ...] ## [CUDAFLAGS flag1 ...] ## [LDFLAGS flag1 ...] ## [DEPS target1 ...] ## [PACKED_EXCLUDES pattern1 ...] ## [VERSION ] ## ) ## ## 2.4. Build a CUDA executable program ## cuda_binary( ## NAME ## [STRICT] [PACKED] ## SRCS [file2 ...] ## [INCS dir1 ...] ## [DEFS DEF1=1 ...] ## [LIBS lib1 ...] ## [CFLAGS flag1 ...] ## [CXXFLAGS flag1 ...] ## [CUDAFLAGS flag1 ...] ## [LDFLAGS flag1 ...] ## [DEPS target1 ...] ## [VERSION ] ## ) ## ## 2.5. 
Build a CUDA executable test program ## cuda_test( ## NAME ## [STRICT] ## SRCS [file2 ...] ## [INCS dir1 ...] ## [DEFS DEF1=1 ...] ## [LIBS lib1 ...] ## [CFLAGS flag1 ...] ## [CXXFLAGS flag1 ...] ## [CUDAFLAGS flag1 ...] ## [LDFLAGS flag1 ...] ## [DEPS target1 ...] ## [ARGS args1 ...] ## [VERSION ] ## ) ## ## 2.6. Add existing test cases to a test suite ## cuda_test_suite( [test_name ...]) ## ## 2.7. Import a C/C++/CUDA static or shared library ## cuda_import( ## NAME ## [STATIC | SHARED] [PACKED] ## PATH ## [INCS dir1 ...] ## [PUBINCS public_dir1 ...] ## [DEPS target1 ...] ## [IMPLIB ] ## [PACKED_EXCLUDES pattern1 ...] ## ) ## ## 2.8. Import a C/C++/CUDA interface library ## cuda_interface( ## NAME ## [PACKED] ## [INCS dir1 ...] ## [PUBINCS public_dir1 ...] ## [DEPS target1 ...] ## [PACKED_EXCLUDES pattern1 ...] ## ) ## ## 2.9. Build a CUDA executable google test program ## cuda_gtest( ## NAME ## [STRICT] ## SRCS [file2 ...] ## [INCS dir1 ...] ## [DEFS DEF1=1 ...] ## [LIBS lib1 ...] ## [CFLAGS flag1 ...] ## [CXXFLAGS flag1 ...] ## [CUDAFLAGS flag1 ...] ## [LDFLAGS flag1 ...] ## [DEPS target1 ...] ## [ARGS args1 ...] ## [VERSION ] ## ) ## ## 2.10. Build a CUDA executable google mock program ## cuda_gmock( ## NAME ## [STRICT] ## SRCS [file2 ...] ## [INCS dir1 ...] ## [DEFS DEF1=1 ...] ## [LIBS lib1 ...] ## [CFLAGS flag1 ...] ## [CXXFLAGS flag1 ...] ## [CUDAFLAGS flag1 ...] ## [LDFLAGS flag1 ...] ## [DEPS target1 ...] ## [ARGS args1 ...] ## [VERSION ] ## ) ## ## 3. Utility functions ## ## 3.1. Download a git repository ## git_repository( ## NAME ## URL ## [TAG ] ## [PATH ] ## ) ## ## 3.2. Download a hg repository ## hg_repository( ## NAME ## URL ## [TAG ] ## [PATH ] ## ) ## ## 3.3. Download a svn repository ## svn_repository( ## NAME ## URL ## [REV ] ## [PATH ] ## ) ## ## 3.4. Download a http archive ## http_archive( ## NAME ## URL ## [SHA256 | SHA1 | MD5 ] ## [PATH ] ## ) ## ## 3.5. Retrieve a version string from GIT ## git_version( ## ## ## ) ## ## 3.6. 
Retrieve a version string from HG ## hg_version( ## ## ## ) ## ## 3.7. Retrieve a version string from SVN ## svn_version( ## ## ## ) ##

# NOTE(review): in this extracted copy every generator expression appears to
# have lost its `<...>` parts (e.g. a bare `$`, or `$<$:...>`), presumably
# stripped as HTML tags during extraction. They are reproduced verbatim below;
# restore them from the upstream cmake/bazel.cmake rather than guessing here.

cmake_minimum_required(VERSION 3.1 FATAL_ERROR)
include(CMakeParseArguments)

# Using AppleClang instead of Clang (Compiler id)
if(POLICY CMP0025)
  cmake_policy(SET CMP0025 NEW)
endif()

# Enable unit testing
enable_testing()

# Add a `unittest` target that runs ctest and prints the output of failing
# tests (skipped if a target with that name already exists).
if(NOT TARGET unittest)
  add_custom_target(
      unittest
      COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure --build-config $
  )
endif()

# Directories of target output (only when the caller has not chosen them)
if(NOT CMAKE_ARCHIVE_OUTPUT_DIRECTORY)
  set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR}/lib)
endif()
if(NOT CMAKE_LIBRARY_OUTPUT_DIRECTORY)
  set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR}/lib)
endif()
if(NOT CMAKE_RUNTIME_OUTPUT_DIRECTORY)
  set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR}/bin)
endif()

# RPATH settings. Off macOS, build-tree binaries use the install RPATH
# directly (no separate build RPATH), resolved relative to the binary itself.
set(CMAKE_MACOSX_RPATH ON)
if(NOT ${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
  set(CMAKE_SKIP_BUILD_RPATH ON)
  set(CMAKE_BUILD_WITH_INSTALL_RPATH ON)
  if(${CMAKE_SIZEOF_VOID_P} EQUAL "8")
    # 64-bit builds also search a sibling lib64 directory.
    set(CMAKE_INSTALL_RPATH "$ORIGIN/../lib64:$ORIGIN/../lib:$ORIGIN")
  else()
    set(CMAKE_INSTALL_RPATH "$ORIGIN/../lib:$ORIGIN")
  endif()
else()
  set(CMAKE_INSTALL_RPATH "@loader_path/../lib:@loader_path")
endif()

# Define standard installation directories (caller-provided values win)
if(NOT CMAKE_INSTALL_LIBDIR)
  set(CMAKE_INSTALL_LIBDIR lib)
endif()
if(NOT CMAKE_INSTALL_BINDIR)
  set(CMAKE_INSTALL_BINDIR bin)
endif()
if(NOT CMAKE_INSTALL_INCDIR)
  set(CMAKE_INSTALL_INCDIR include)
endif()
if(NOT CMAKE_INSTALL_ETCDIR)
  set(CMAKE_INSTALL_ETCDIR etc)
endif()

# Generates a compile_commands.json
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS ON)

# CLANG_USE_LIBCXX defaults to ON on Apple and Android targets;
# CLANG_STDLIB_OPTION carries the matching -stdlib= flag.
if(APPLE OR ANDROID)
  option(CLANG_USE_LIBCXX "Use libc++ instead of libstdc++" ON)
else()
  option(CLANG_USE_LIBCXX "Use libc++ instead of libstdc++" OFF)
endif()
set(CLANG_STDLIB_OPTION "")
if(CLANG_USE_LIBCXX)
  set(CLANG_STDLIB_OPTION "-stdlib=libc++")
else()
  set(CLANG_STDLIB_OPTION "-stdlib=libstdc++")
endif()

if(NOT MSVC)
  # Use color in diagnostics
  set(
      _COMPILER_FLAGS
      "$<$:-fcolor-diagnostics;${CLANG_STDLIB_OPTION}>"
      "$<$:-fcolor-diagnostics>"
      "$<$:-fdiagnostics-color=always>"
  )
  add_compile_options(
      "$<$:${_COMPILER_FLAGS}>"
      "$<$:${_COMPILER_FLAGS}>"
  )
  unset(_COMPILER_FLAGS)
else()
  # Replace the default compiling flags: switch the static CRT (/MT) to the
  # dynamic one (/MD) and drop any /W<n> level, since the per-target flag sets
  # below choose /W3 or /W4 themselves.
  set(
      _COMPILER_FLAGS
      CMAKE_CXX_FLAGS
      CMAKE_CXX_FLAGS_DEBUG
      CMAKE_CXX_FLAGS_RELEASE
      CMAKE_CXX_FLAGS_RELWITHDEBINFO
      CMAKE_CXX_FLAGS_MINSIZEREL
      CMAKE_C_FLAGS
      CMAKE_C_FLAGS_DEBUG
      CMAKE_C_FLAGS_RELEASE
      CMAKE_C_FLAGS_RELWITHDEBINFO
      CMAKE_C_FLAGS_MINSIZEREL
  )
  foreach(COMPILER_FLAG ${_COMPILER_FLAGS})
    string(REPLACE "/MT" "/MD" ${COMPILER_FLAG} "${${COMPILER_FLAG}}")
    string(REGEX REPLACE "/W[0-9]" "" ${COMPILER_FLAG} "${${COMPILER_FLAG}}")
  endforeach()
  unset(_COMPILER_FLAGS)
  add_definitions(-D_CRT_SECURE_NO_WARNINGS)

  # Build shared library as default
  set(BUILD_SHARED_LIBS ON)
endif()

# The custom ASAN and COVERAGE build types start from the Debug flags.
set(CMAKE_C_FLAGS_ASAN ${CMAKE_C_FLAGS_DEBUG})
set(CMAKE_CXX_FLAGS_ASAN ${CMAKE_CXX_FLAGS_DEBUG})
set(CMAKE_EXE_LINKER_FLAGS_ASAN ${CMAKE_EXE_LINKER_FLAGS_DEBUG})
set(CMAKE_SHARED_LINKER_FLAGS_ASAN ${CMAKE_SHARED_LINKER_FLAGS_DEBUG})
set(CMAKE_STATIC_LINKER_FLAGS_ASAN ${CMAKE_STATIC_LINKER_FLAGS_DEBUG})
set(CMAKE_MODULE_LINKER_FLAGS_ASAN ${CMAKE_MODULE_LINKER_FLAGS_DEBUG})
set(CMAKE_C_FLAGS_COVERAGE ${CMAKE_C_FLAGS_DEBUG})
set(CMAKE_CXX_FLAGS_COVERAGE ${CMAKE_CXX_FLAGS_DEBUG})
set(CMAKE_EXE_LINKER_FLAGS_COVERAGE ${CMAKE_EXE_LINKER_FLAGS_DEBUG})
set(CMAKE_SHARED_LINKER_FLAGS_COVERAGE ${CMAKE_SHARED_LINKER_FLAGS_DEBUG})
set(CMAKE_STATIC_LINKER_FLAGS_COVERAGE ${CMAKE_STATIC_LINKER_FLAGS_DEBUG})
set(CMAKE_MODULE_LINKER_FLAGS_COVERAGE ${CMAKE_MODULE_LINKER_FLAGS_DEBUG})

# C/C++ ASAN compile flags
set(
    BAZEL_CC_ASAN_COMPILE_FLAGS
    "$<$:$<$:-fsanitize=address>>"
    "$<$:$<$:-fsanitize=address>>"
    "$<$:$<$:-fsanitize=address>>"
    "$<$:$<$:/fsanitize=address>>"
)

# C/C++ COVERAGE compile flags
set(
    BAZEL_CC_COVERAGE_COMPILE_FLAGS
    "$<$:$<$:--coverage>>"
    "$<$:$<$:--coverage>>"
    "$<$:$<$:--coverage>>"
)

# C/C++ strict compile flags (the third variant uses -Wshadow-local when
# CMAKE_CXX_COMPILER_VERSION is greater than 7.0)
if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0)
  set(
      BAZEL_CC_STRICT_COMPILE_FLAGS
      "$<$:-Wall;-Wextra;-Wshadow>"
      "$<$:-Wall;-Wextra;-Wshadow>"
      "$<$:-Wall;-Wextra;-Wshadow-local;-Wno-misleading-indentation>"
      "$<$:/W4>"
      ${BAZEL_CC_ASAN_COMPILE_FLAGS}
      ${BAZEL_CC_COVERAGE_COMPILE_FLAGS}
  )
else()
  set(
      BAZEL_CC_STRICT_COMPILE_FLAGS
      "$<$:-Wall;-Wextra;-Wshadow>"
      "$<$:-Wall;-Wextra;-Wshadow>"
      "$<$:-Wall;-Wextra;-Wshadow;-Wno-misleading-indentation>"
      "$<$:/W4>"
      ${BAZEL_CC_ASAN_COMPILE_FLAGS}
      ${BAZEL_CC_COVERAGE_COMPILE_FLAGS}
  )
endif()

# C/C++ strict link flags
set(
    BAZEL_CC_STRICT_LINK_FLAGS
    "$<$:${CLANG_STDLIB_OPTION}>"
    ${BAZEL_CC_ASAN_COMPILE_FLAGS}
    ${BAZEL_CC_COVERAGE_COMPILE_FLAGS}
)

# C/C++ unstrict compile flags
set(
    BAZEL_CC_UNSTRICT_COMPILE_FLAGS
    "$<$:-Wall>"
    "$<$:-Wall>"
    "$<$:-Wall>"
    "$<$:/W3>"
    ${BAZEL_CC_ASAN_COMPILE_FLAGS}
    ${BAZEL_CC_COVERAGE_COMPILE_FLAGS}
)

# C/C++ unstrict link flags
set(
    BAZEL_CC_UNSTRICT_LINK_FLAGS
    "$<$:${CLANG_STDLIB_OPTION}>"
    ${BAZEL_CC_ASAN_COMPILE_FLAGS}
    ${BAZEL_CC_COVERAGE_COMPILE_FLAGS}
)

# CUDA strict compile flags
set(
    BAZEL_CUDA_STRICT_COMPILE_FLAGS
    "$<$:$<$:-Wall;-Wextra;-Wshadow>>"
    "$<$:$<$:-Wall;-Wextra;-Wshadow>>"
    "$<$:$<$:-Wall;-Wextra;-Wshadow>>"
    "$<$:$<$:/W4>>"
    "$<$:$<$:-Wall;-Wextra;-Wshadow>>"
    "$<$:$<$:-Wall;-Wextra;-Wshadow>>"
    "$<$:$<$:-Wall;-Wextra;-Wshadow>>"
    "$<$:$<$:/W4>>"
    "$<$:$<$:-G>>"
)

# CUDA strict link flags
set(BAZEL_CUDA_STRICT_LINK_FLAGS "")

# CUDA unstrict compile flags
set(
    BAZEL_CUDA_UNSTRICT_COMPILE_FLAGS
    "$<$:$<$:-Wall>>"
    "$<$:$<$:-Wall>>"
    "$<$:$<$:-Wall>>"
    "$<$:$<$:/W3>>"
    "$<$:$<$:-Wall>>"
    "$<$:$<$:-Wall>>"
    "$<$:$<$:-Wall>>"
    "$<$:$<$:/W3>>"
    "$<$:$<$:-G>>"
)

# CUDA unstrict link flags
set(BAZEL_CUDA_UNSTRICT_LINK_FLAGS "")

## Find workspace directory
# Walks upward from CMAKE_CURRENT_SOURCE_DIR and stores in _RESULT (parent
# scope) the first directory containing a Workspace.cmake file. _RESULT is
# left untouched when none is found. The loop stops once a directory is its
# own parent, so the filesystem root itself is never checked.
function(_find_workspace_directory _RESULT)
  # Find Workspace.cmake folder
  set(_CURRENT_WORKSPACE_DIR ${CMAKE_CURRENT_SOURCE_DIR})
  get_filename_component(
      _PARENT_WORKSPACE_DIR ${_CURRENT_WORKSPACE_DIR} DIRECTORY
  )
  while(NOT ("${_CURRENT_WORKSPACE_DIR}" STREQUAL "${_PARENT_WORKSPACE_DIR}"))
    if(EXISTS "${_CURRENT_WORKSPACE_DIR}/Workspace.cmake")
      set(${_RESULT} ${_CURRENT_WORKSPACE_DIR} PARENT_SCOPE)
      # set(... PARENT_SCOPE) leaves this scope untouched, so ${${_RESULT}}
      # would still hold the caller's old value here; log the found dir.
      message(STATUS "Found workspace at ${_CURRENT_WORKSPACE_DIR}")
      break()
    endif()
    # Find next parent folder
    set(_CURRENT_WORKSPACE_DIR ${_PARENT_WORKSPACE_DIR})
    get_filename_component(
        _PARENT_WORKSPACE_DIR ${_CURRENT_WORKSPACE_DIR} DIRECTORY
    )
  endwhile()
endfunction()

## Retrieve absolute paths
# Converts every path in ARGN to an absolute path (relative paths are resolved
# against CMAKE_CURRENT_SOURCE_DIR) and stores the list in _RESULT (parent scope).
function(_absolute_paths _RESULT)
  # Start empty: a function scope inherits copies of the caller's variables,
  # so appending to an unset accumulator could extend a caller's FILEPATHS.
  set(FILEPATHS "")
  foreach(FILEPATH ${ARGN})
    if(NOT IS_ABSOLUTE ${FILEPATH})
      get_filename_component(FILEPATH ${FILEPATH} ABSOLUTE)
    endif()
    list(APPEND FILEPATHS ${FILEPATH})
  endforeach()
  set(${_RESULT} "${FILEPATHS}" PARENT_SCOPE)
endfunction()

## Add both shared and static library
# Compiles the sources once into an OBJECT library (<name>_objects) and creates
# a STATIC (<name>_static) and a SHARED (<name>) library whose sources come from
# the (stripped, see NOTE above) generator expression — presumably the objects
# of <name>_objects. Outside MSVC the static archive is renamed to <name>.
macro(_add_library _NAME _OPTION)
  add_library(${_NAME}_objects OBJECT ${_OPTION} ${ARGN})
  add_library(
      ${_NAME}_static STATIC ${_OPTION}
      $
  )
  add_library(
      ${_NAME} SHARED ${_OPTION}
      $
  )
  add_dependencies(${_NAME} ${_NAME}_static)
  if(NOT MSVC)
    set_property(TARGET ${_NAME}_static PROPERTY OUTPUT_NAME ${_NAME})
  endif()
endmacro()

## Link dependencies
# For every entry of ARGN that is a CMake target, adds a build-order dependency
# of _NAME on it plus a PRIVATE include-directory entry. Non-targets are ignored.
function(_targets_link_dependencies _NAME)
  # Accumulators start empty: this function scope inherits the caller's
  # variables, and stale LIBS_DEPS/LIBS_INCS values would otherwise leak in.
  set(LIBS_DEPS "")
  set(LIBS_INCS "")
  foreach(LIB ${ARGN})
    if(TARGET ${LIB})
      list(APPEND LIBS_DEPS ${LIB})
      list(
          APPEND LIBS_INCS
          "$"
      )
    endif()
  endforeach()
  if(LIBS_DEPS)
    add_dependencies(${_NAME} ${LIBS_DEPS})
    target_include_directories(${_NAME} PRIVATE "${LIBS_INCS}")
  endif()
endfunction()

## Link libraries
# Links _NAME against ARGN. Targets carrying the ALWAYS_LINK property — listed
# in ARGN or reachable through INTERFACE_LINK_LIBRARIES / LINK_LIBRARIES — are
# linked as whole archives (--whole-archive on ELF, -force_load on macOS,
# /WHOLEARCHIVE on MSVC). Non-target entries are passed through unchanged.
function(_target_link_libraries _NAME)
  # Recursively collects ALWAYS_LINK targets reachable from LIB_LIST into
  # RESULT_VAR (parent scope). The visited list is pushed to the calling frame
  # after every addition, so a target is processed at most once per top-level
  # call even across recursion.
  function(_collect_always_link_libs LIB_LIST RESULT_VAR)
    if(NOT _COLLECT_ALWAYS_LINK_VISITED)
      set(_COLLECT_ALWAYS_LINK_VISITED "" PARENT_SCOPE)
    endif()

    set(LOCAL_RESULT "")
    foreach(LIB ${LIB_LIST})
      if(NOT TARGET ${LIB})
        continue()
      endif()

      list(FIND _COLLECT_ALWAYS_LINK_VISITED ${LIB} ALREADY_VISITED)
      if(NOT ALREADY_VISITED EQUAL -1)
        continue()
      endif()
      list(APPEND _COLLECT_ALWAYS_LINK_VISITED ${LIB})
      set(_COLLECT_ALWAYS_LINK_VISITED "${_COLLECT_ALWAYS_LINK_VISITED}" PARENT_SCOPE)

      get_target_property(ALWAYS_LINK ${LIB} ALWAYS_LINK)
      if(ALWAYS_LINK)
        list(APPEND LOCAL_RESULT ${LIB})
      endif()

      get_target_property(DEP_LIBS ${LIB} INTERFACE_LINK_LIBRARIES)
      if(DEP_LIBS)
        _collect_always_link_libs("${DEP_LIBS}" DEP_ALWAYS_LINK_LIBS)
        list(APPEND LOCAL_RESULT ${DEP_ALWAYS_LINK_LIBS})
      endif()

      get_target_property(LINK_LIBS ${LIB} LINK_LIBRARIES)
      if(LINK_LIBS)
        _collect_always_link_libs("${LINK_LIBS}" LINK_ALWAYS_LINK_LIBS)
        list(APPEND LOCAL_RESULT ${LINK_ALWAYS_LINK_LIBS})
      endif()
    endforeach()

    list(REMOVE_DUPLICATES LOCAL_RESULT)
    set(${RESULT_VAR} "${LOCAL_RESULT}" PARENT_SCOPE)
  endfunction()

  _collect_always_link_libs("${ARGN}" ALL_ALWAYS_LINK_LIBS)

  # Process ARGN plus every transitive ALWAYS_LINK target not already listed.
  set(ALL_LIBS_TO_PROCESS ${ARGN})
  foreach(ALWAYS_LIB ${ALL_ALWAYS_LINK_LIBS})
    list(FIND ARGN ${ALWAYS_LIB} FOUND_INDEX)
    if(FOUND_INDEX EQUAL -1)
      list(APPEND ALL_LIBS_TO_PROCESS ${ALWAYS_LIB})
    endif()
  endforeach()
  list(REMOVE_DUPLICATES ALL_LIBS_TO_PROCESS)

  # Accumulators start empty: this function scope inherits the caller's
  # variables, and stale LINK_LIBS/LIBS_DEPS/LIBS_INCS values would leak in.
  set(LINK_LIBS "")
  set(LIBS_DEPS "")
  set(LIBS_INCS "")
  foreach(LIB ${ALL_LIBS_TO_PROCESS})
    if(NOT TARGET ${LIB})
      list(APPEND LINK_LIBS ${LIB})
      continue()
    endif()

    list(FIND ALL_ALWAYS_LINK_LIBS ${LIB} IS_ALWAYS_LINK)
    if(IS_ALWAYS_LINK EQUAL -1)
      list(APPEND LINK_LIBS ${LIB})
      continue()
    endif()

    if(NOT MSVC)
      if(NOT ${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
        list(APPEND LINK_LIBS -Wl,--whole-archive ${LIB} -Wl,--no-whole-archive)
      else()
        list(APPEND LINK_LIBS -Wl,-force_load ${LIB})
      endif()
    else()
      # Microsoft Visual C++: the whole-archive entry is a raw linker option
      # rather than a target reference, so the target's interface link
      # libraries, build-order dependency and include dirs are added explicitly.
      list(APPEND LINK_LIBS /WHOLEARCHIVE:$)
      get_target_property(OTHER_LINK_LIBS ${LIB} INTERFACE_LINK_LIBRARIES)
      if(OTHER_LINK_LIBS)
        foreach(OTHER_LIB ${OTHER_LINK_LIBS})
          list(FIND ALL_LIBS_TO_PROCESS ${OTHER_LIB} FOUND_INDEX)
          if(FOUND_INDEX EQUAL -1)
            list(APPEND LINK_LIBS ${OTHER_LIB})
          endif()
        endforeach()
      endif()
      list(APPEND LIBS_DEPS ${LIB})
      list(
          APPEND LIBS_INCS
          "$"
      )
    endif()
  endforeach()

  target_link_libraries(${_NAME} ${LINK_LIBS})
  if(LIBS_DEPS)
    add_dependencies(${_NAME} ${LIBS_DEPS})
    target_include_directories(${_NAME} PRIVATE "${LIBS_INCS}")
  endif()
endfunction()

## Add a subdirectory to the
## build
function(cc_directory)
  add_subdirectory(${ARGN})
endfunction()

## Add subdirectories to the build
function(cc_directories)
  foreach(SRC_DIR ${ARGN})
    add_subdirectory(${SRC_DIR})
  endforeach()
endfunction()

## Set the properties of target
#
# _cc_target_properties(NAME <target> [STRICT] [ALWAYS_LINK] [VERSION <v>]
#                       [C_STANDARD <n>] [CXX_STANDARD <n>]
#                       [INCS ...] [PUBINCS ...] [DEFS ...] [LIBS ...]
#                       [CFLAGS ...] [CXXFLAGS ...] [LDFLAGS ...] [DEPS ...])
# Applies warning flags, definitions, per-language flags, include paths,
# libraries, dependencies and language standards to an existing target.
function(_cc_target_properties)
  cmake_parse_arguments(
    CC_ARGS "STRICT;ALWAYS_LINK" "NAME;VERSION;C_STANDARD;CXX_STANDARD"
    "INCS;PUBINCS;DEFS;LIBS;CFLAGS;CXXFLAGS;LDFLAGS;DEPS" ${ARGN}
  )
  if(NOT CC_ARGS_NAME)
    message(FATAL_ERROR "No target name privated.")
  endif()

  # Only shared/static libraries and executables accept link items.
  set(TARGET_LINKABLE FALSE)
  get_target_property(TARGET_TYPE ${CC_ARGS_NAME} TYPE)
  if(("${TARGET_TYPE}" STREQUAL "SHARED_LIBRARY")
     OR ("${TARGET_TYPE}" STREQUAL "STATIC_LIBRARY")
     OR ("${TARGET_TYPE}" STREQUAL "EXECUTABLE"))
    set(TARGET_LINKABLE TRUE)
  endif()
  # ALWAYS_LINK is a custom property consumed by _target_link_libraries().
  if(CC_ARGS_ALWAYS_LINK)
    if(("${TARGET_TYPE}" STREQUAL "STATIC_LIBRARY")
       OR ("${TARGET_TYPE}" STREQUAL "OBJECT_LIBRARY"))
      set_property(TARGET ${CC_ARGS_NAME} PROPERTY ALWAYS_LINK TRUE)
    endif()
  endif()

  # Set the warning level of compiling
  if(CC_ARGS_STRICT)
    target_compile_options(
      ${CC_ARGS_NAME} PRIVATE "${BAZEL_CC_STRICT_COMPILE_FLAGS}"
    )
    if(TARGET_LINKABLE)
      target_link_libraries(${CC_ARGS_NAME} "${BAZEL_CC_STRICT_LINK_FLAGS}")
    endif()
  else()
    target_compile_options(
      ${CC_ARGS_NAME} PRIVATE "${BAZEL_CC_UNSTRICT_COMPILE_FLAGS}"
    )
    if(TARGET_LINKABLE)
      target_link_libraries(${CC_ARGS_NAME} "${BAZEL_CC_UNSTRICT_LINK_FLAGS}")
    endif()
  endif()
  if(CC_ARGS_DEFS)
    target_compile_definitions(${CC_ARGS_NAME} PRIVATE "${CC_ARGS_DEFS}")
  endif()
  # CFLAGS apply to C sources only, CXXFLAGS to C++ sources only.
  if(CC_ARGS_CFLAGS OR CC_ARGS_CXXFLAGS)
    target_compile_options(
      ${CC_ARGS_NAME} PRIVATE
      "$<$<COMPILE_LANGUAGE:C>:${CC_ARGS_CFLAGS}>"
      "$<$<COMPILE_LANGUAGE:CXX>:${CC_ARGS_CXXFLAGS}>"
    )
  endif()
  if(CC_ARGS_LDFLAGS)
    # LINK_FLAGS is one space-separated string, not a list.
    string(REPLACE ";" " " CC_ARGS_LDFLAGS "${CC_ARGS_LDFLAGS}")
    set_property(
      TARGET ${CC_ARGS_NAME} PROPERTY LINK_FLAGS "${CC_ARGS_LDFLAGS}"
    )
  endif()
  if(CC_ARGS_INCS)
    _absolute_paths(INC_DIRS ${CC_ARGS_INCS})
    target_include_directories(${CC_ARGS_NAME} PRIVATE "${INC_DIRS}")
  endif()
  if(BAZEL_WORKSPACE_DIR)
    target_include_directories(${CC_ARGS_NAME} PRIVATE "${BAZEL_WORKSPACE_DIR}")
  endif()
  if(CC_ARGS_PUBINCS)
    _absolute_paths(INC_DIRS ${CC_ARGS_PUBINCS})
    target_include_directories(${CC_ARGS_NAME} PUBLIC "${INC_DIRS}")
  endif()
  if(CC_ARGS_LIBS)
    if(NOT TARGET_LINKABLE)
      _targets_link_dependencies(${CC_ARGS_NAME} ${CC_ARGS_LIBS})
    else()
      # Executables resolve ALWAYS_LINK libraries as whole archives.
      if("${TARGET_TYPE}" STREQUAL "EXECUTABLE")
        _target_link_libraries(${CC_ARGS_NAME} "${CC_ARGS_LIBS}")
      else()
        target_link_libraries(${CC_ARGS_NAME} "${CC_ARGS_LIBS}")
      endif()
    endif()
  endif()
  if(CC_ARGS_DEPS)
    # Unquoted so that several DEPS become separate target names.
    add_dependencies(${CC_ARGS_NAME} ${CC_ARGS_DEPS})
  endif()
  if(CC_ARGS_VERSION)
    set_property(
      TARGET ${CC_ARGS_NAME} PROPERTY VERSION "${CC_ARGS_VERSION}"
    )
  endif()
  # Language standards: an explicit C_STANDARD / CXX_STANDARD argument wins,
  # then a CC_C_STANDARD / CC_CXX_STANDARD variable visible from the calling
  # scope, then the defaults C99 / C++11.
  if(CC_ARGS_C_STANDARD)
    set(CC_C_STANDARD ${CC_ARGS_C_STANDARD})
  elseif(NOT CC_C_STANDARD)
    set(CC_C_STANDARD 99)
  endif()
  if(CC_ARGS_CXX_STANDARD)
    set(CC_CXX_STANDARD ${CC_ARGS_CXX_STANDARD})
  elseif(NOT CC_CXX_STANDARD)
    set(CC_CXX_STANDARD 11)
  endif()
  set_target_properties(
    ${CC_ARGS_NAME}
    PROPERTIES DEFINE_SYMBOL ""
               C_STANDARD ${CC_C_STANDARD}
               CXX_STANDARD ${CC_CXX_STANDARD}
               C_STANDARD_REQUIRED ON
               C_EXTENSIONS ON
               CXX_STANDARD_REQUIRED ON
               CXX_EXTENSIONS OFF
               WINDOWS_EXPORT_ALL_SYMBOLS ON
  )
endfunction()

## Build a C/C++ static or shared library
#
# cc_library(NAME <name> [STATIC] [SHARED] [EXCLUDE] [PACKED] [SRCS_NO_GLOB]
#            [VERSION <v>] SRCS ... [INCS ...] [PUBINCS ...] [DEFS ...]
#            [LIBS ...] [CFLAGS ...] [CXXFLAGS ...] [LDFLAGS ...] [DEPS ...]
#            [PACKED_EXCLUDES ...])
function(cc_library)
  cmake_parse_arguments(
    CC_ARGS "STATIC;SHARED;EXCLUDE;PACKED;SRCS_NO_GLOB" "NAME;VERSION"
    "SRCS;INCS;PUBINCS;DEFS;LIBS;CFLAGS;CXXFLAGS;LDFLAGS;DEPS;PACKED_EXCLUDES"
    ${ARGN}
  )
  if(NOT CC_ARGS_NAME)
    message(FATAL_ERROR "No target name provided.")
  endif()
  if(CC_ARGS_SRCS_NO_GLOB)
    # Sources are used verbatim, without any globbing.
    set(SOURCE_FILES ${CC_ARGS_SRCS})
    if(NOT SOURCE_FILES)
      message(FATAL_ERROR "No source files provided for ${CC_ARGS_NAME} (SRCS_NO_GLOB mode).")
    endif()
  else()
    # Entries containing a wildcard (* or ?) are globbed, whether relative or
    # absolute; plain paths are kept as given so files that do not exist yet
    # at configure time are not silently dropped.
    set(SOURCE_FILES "")
    foreach(_src IN LISTS CC_ARGS_SRCS)
      if(NOT "${_src}" MATCHES "[*?]")
        list(APPEND SOURCE_FILES "${_src}")
      else()
        file(GLOB _globbed_srcs ${_src})
        list(APPEND SOURCE_FILES ${_globbed_srcs})
      endif()
    endforeach()
    if(NOT SOURCE_FILES)
      message(FATAL_ERROR "No source files found for ${CC_ARGS_NAME} after globbing.")
    endif()
  endif()
  if(CC_ARGS_VERSION)
    string(REPLACE "-" "_" MACRO_PREFIX "${CC_ARGS_NAME}")
    list(APPEND CC_ARGS_DEFS
${MACRO_PREFIX}_VERSION="${CC_ARGS_VERSION}")
  endif()
  # EXCLUDE keeps the library target(s) out of the default 'all' build.
  if(CC_ARGS_EXCLUDE)
    set(EXCLUDE_OPTION EXCLUDE_FROM_ALL)
  endif()
  # SHARED + STATIC: _add_library() creates <name> (shared), <name>_static
  # and the <name>_objects OBJECT library both are built from. Otherwise a
  # single library of the requested (or CMake's default) type is created.
  if(CC_ARGS_SHARED AND CC_ARGS_STATIC)
    _add_library(${CC_ARGS_NAME} "${EXCLUDE_OPTION}" ${SOURCE_FILES})
  elseif(CC_ARGS_SHARED)
    add_library(${CC_ARGS_NAME} SHARED ${EXCLUDE_OPTION} ${SOURCE_FILES})
  elseif(CC_ARGS_STATIC)
    add_library(${CC_ARGS_NAME} STATIC ${EXCLUDE_OPTION} ${SOURCE_FILES})
  else()
    add_library(${CC_ARGS_NAME} ${EXCLUDE_OPTION} ${SOURCE_FILES})
  endif()
  # The OBJECT library performs the compilation, so it gets both the private
  # and the public include directories.
  if(TARGET ${CC_ARGS_NAME}_objects)
    _cc_target_properties(
      NAME "${CC_ARGS_NAME}_objects"
      INCS "${CC_ARGS_INCS};${CC_ARGS_PUBINCS}"
      DEFS "${CC_ARGS_DEFS}"
      LIBS "${CC_ARGS_LIBS}"
      CFLAGS "${CC_ARGS_CFLAGS}"
      CXXFLAGS "${CC_ARGS_CXXFLAGS}"
      LDFLAGS "${CC_ARGS_LDFLAGS}"
      DEPS "${CC_ARGS_DEPS}"
      "${CC_ARGS_UNPARSED_ARGUMENTS}"
    )
  endif()
  if(TARGET ${CC_ARGS_NAME}_static)
    _cc_target_properties(
      NAME "${CC_ARGS_NAME}_static"
      INCS "${CC_ARGS_INCS}"
      PUBINCS "${CC_ARGS_PUBINCS}"
      DEFS "${CC_ARGS_DEFS}"
      LIBS "${CC_ARGS_LIBS}"
      CFLAGS "${CC_ARGS_CFLAGS}"
      CXXFLAGS "${CC_ARGS_CXXFLAGS}"
      LDFLAGS "${CC_ARGS_LDFLAGS}"
      DEPS "${CC_ARGS_DEPS}"
      "${CC_ARGS_UNPARSED_ARGUMENTS}"
    )
    if(CC_ARGS_PACKED)
      install(
        TARGETS ${CC_ARGS_NAME}_static
        ARCHIVE DESTINATION "${CMAKE_INSTALL_LIBDIR}"
      )
    endif()
  endif()
  # Unparsed arguments (e.g. the STRICT / ALWAYS_LINK options understood by
  # _cc_target_properties()) are forwarded unchanged.
  _cc_target_properties(
    NAME "${CC_ARGS_NAME}"
    INCS "${CC_ARGS_INCS}"
    PUBINCS "${CC_ARGS_PUBINCS}"
    DEFS "${CC_ARGS_DEFS}"
    LIBS "${CC_ARGS_LIBS}"
    CFLAGS "${CC_ARGS_CFLAGS}"
    CXXFLAGS "${CC_ARGS_CXXFLAGS}"
    LDFLAGS "${CC_ARGS_LDFLAGS}"
    DEPS "${CC_ARGS_DEPS}"
    VERSION "${CC_ARGS_VERSION}"
    "${CC_ARGS_UNPARSED_ARGUMENTS}"
  )
  if(CC_ARGS_PACKED)
    install(
      TARGETS ${CC_ARGS_NAME}
      ARCHIVE DESTINATION "${CMAKE_INSTALL_LIBDIR}"
      LIBRARY DESTINATION "${CMAKE_INSTALL_LIBDIR}"
    )
    if(CC_ARGS_PUBINCS)
      # Each PACKED_EXCLUDES entry becomes "PATTERN <entry> EXCLUDE".
      foreach(PACKED_EXCLUDE ${CC_ARGS_PACKED_EXCLUDES})
        list(APPEND PATTERN_EXCLUDES "PATTERN;${PACKED_EXCLUDE};EXCLUDE")
      endforeach()
      # Only *.h / *.hpp / *.hxx files are installed from PUBINCS.
      install(
        DIRECTORY ${CC_ARGS_PUBINCS}
        DESTINATION ${CMAKE_INSTALL_INCDIR}
        FILES_MATCHING PATTERN "*.h" PATTERN "*.hpp"
                       PATTERN "*.hxx" ${PATTERN_EXCLUDES}
      )
    endif()
  endif()
endfunction()

## Build a C/C++ executable program
#
# cc_binary(NAME <name> [PACKED] [VERSION <v>] SRCS ... [INCS ...] [DEFS ...]
#           [LIBS ...] [CFLAGS ...] [CXXFLAGS ...] [LDFLAGS ...] [DEPS ...])
function(cc_binary)
  cmake_parse_arguments(
    CC_ARGS "PACKED" "NAME;VERSION"
    "SRCS;INCS;DEFS;LIBS;CFLAGS;CXXFLAGS;LDFLAGS;DEPS" ${ARGN}
  )
  if(NOT CC_ARGS_NAME)
    message(FATAL_ERROR "No target name privated.")
  endif()
  # Every SRCS entry is passed through file(GLOB), so only files that exist
  # at configure time are kept.
  file(GLOB CC_ARGS_SRCS ${CC_ARGS_SRCS})
  if(NOT CC_ARGS_SRCS)
    message(FATAL_ERROR "No source files found of ${CC_ARGS_NAME}.")
  endif()
  # VERSION <v> defines <NAME>_VERSION="<v>" ('-' in NAME replaced by '_').
  if(CC_ARGS_VERSION)
    string(REPLACE "-" "_" MACRO_PREFIX "${CC_ARGS_NAME}")
    list(APPEND CC_ARGS_DEFS ${MACRO_PREFIX}_VERSION="${CC_ARGS_VERSION}")
  endif()
  add_executable(${CC_ARGS_NAME} ${CC_ARGS_SRCS})
  if(CC_ARGS_PACKED)
    install(
      TARGETS ${CC_ARGS_NAME} RUNTIME DESTINATION "${CMAKE_INSTALL_BINDIR}"
    )
  endif()
  _cc_target_properties(
    NAME "${CC_ARGS_NAME}"
    INCS "${CC_ARGS_INCS}"
    DEFS "${CC_ARGS_DEFS}"
    LIBS "${CC_ARGS_LIBS}"
    CFLAGS "${CC_ARGS_CFLAGS}"
    CXXFLAGS "${CC_ARGS_CXXFLAGS}"
    LDFLAGS "${CC_ARGS_LDFLAGS}"
    DEPS "${CC_ARGS_DEPS}"
    VERSION "${CC_ARGS_VERSION}"
    "${CC_ARGS_UNPARSED_ARGUMENTS}"
  )
endfunction()

## Build a C/C++ executable test program
#
# cc_test(NAME <name> [VERSION <v>] SRCS ... [ARGS ...] [INCS ...] ...)
# The executable is EXCLUDE_FROM_ALL; it is made a dependency of the
# 'unittest' target, gets a runner target unittest.<name> and is registered
# with add_test().
function(cc_test)
  cmake_parse_arguments(
    CC_ARGS "" "NAME;VERSION"
    "SRCS;INCS;DEFS;LIBS;CFLAGS;CXXFLAGS;LDFLAGS;DEPS;ARGS" ${ARGN}
  )
  if(NOT CC_ARGS_NAME)
    message(FATAL_ERROR "No target name privated.")
  endif()
  file(GLOB CC_ARGS_SRCS ${CC_ARGS_SRCS})
  if(NOT CC_ARGS_SRCS)
    message(FATAL_ERROR "No source files found of ${CC_ARGS_NAME}.")
  endif()
  if(CC_ARGS_VERSION)
    string(REPLACE "-" "_" MACRO_PREFIX "${CC_ARGS_NAME}")
    list(APPEND CC_ARGS_DEFS ${MACRO_PREFIX}_VERSION="${CC_ARGS_VERSION}")
  endif()
  add_executable(${CC_ARGS_NAME} EXCLUDE_FROM_ALL ${CC_ARGS_SRCS})
  _cc_target_properties(
    NAME "${CC_ARGS_NAME}"
    INCS "${CC_ARGS_INCS}"
    DEFS "${CC_ARGS_DEFS}"
    LIBS "${CC_ARGS_LIBS}"
    CFLAGS "${CC_ARGS_CFLAGS}"
    CXXFLAGS "${CC_ARGS_CXXFLAGS}"
    LDFLAGS "${CC_ARGS_LDFLAGS}"
    DEPS "${CC_ARGS_DEPS}"
    "${CC_ARGS_UNPARSED_ARGUMENTS}"
  )
  # NOTE(review): assumes a 'unittest' target is created elsewhere.
  add_dependencies(unittest ${CC_ARGS_NAME})
  add_custom_target(
unittest.${CC_ARGS_NAME}
    # Run the built test binary with the ARGS given to cc_test().
    COMMAND $<TARGET_FILE:${CC_ARGS_NAME}> "${CC_ARGS_ARGS}"
    WORKING_DIRECTORY ${PROJECT_BINARY_DIR}
    DEPENDS ${CC_ARGS_NAME}
  )
  add_test(
    NAME ${CC_ARGS_NAME}
    COMMAND $<TARGET_FILE:${CC_ARGS_NAME}> "${CC_ARGS_ARGS}"
    WORKING_DIRECTORY ${PROJECT_BINARY_DIR}
  )
endfunction()

## Add existing test cases to a test suite
#
# cc_test_suite(<suite> [test...])
# Creates unittest.<suite> once and makes it depend on the unittest.<test>
# runner target of every listed test.
function(cc_test_suite _NAME)
  if(NOT TARGET unittest.${_NAME})
    add_custom_target(unittest.${_NAME} COMMAND "")
  endif()
  set(TEST_TARGETS "")
  foreach(TEST_TARGET ${ARGN})
    list(APPEND TEST_TARGETS unittest.${TEST_TARGET})
  endforeach()
  if(TEST_TARGETS)
    add_dependencies(unittest.${_NAME} ${TEST_TARGETS})
  endif()
endfunction()

## Import a C/C++ static or shared library
#
# cc_import(NAME <name> PATH <file-or-glob> [STATIC|SHARED] [IMPLIB <lib>]
#           [PACKED] [INCS ...] [PUBINCS ...] [DEPS ...]
#           [PACKED_EXCLUDES ...])
# Wraps a prebuilt library file in a GLOBAL IMPORTED target.
function(cc_import)
  cmake_parse_arguments(
    CC_ARGS "STATIC;SHARED;PACKED" "NAME;PATH;IMPLIB"
    "INCS;PUBINCS;DEPS;PACKED_EXCLUDES" ${ARGN}
  )
  if(NOT CC_ARGS_NAME)
    message(FATAL_ERROR "No target name privated.")
  endif()
  file(GLOB CC_ARGS_PATH ${CC_ARGS_PATH})
  if(NOT CC_ARGS_PATH)
    message(FATAL_ERROR "No imported target file found of ${CC_ARGS_NAME}.")
  endif()
  # Default MSVC import library: the .dll path with a .lib extension. The dot
  # is escaped so that only a real ".dll" suffix is replaced.
  if(MSVC AND CC_ARGS_SHARED AND NOT CC_ARGS_IMPLIB)
    string(REGEX REPLACE "\\.[Dd][Ll][Ll]$" ".lib" CC_ARGS_IMPLIB
      ${CC_ARGS_PATH}
    )
  endif()
  if(CC_ARGS_SHARED)
    add_library(${CC_ARGS_NAME} SHARED IMPORTED GLOBAL)
  elseif(CC_ARGS_STATIC)
    add_library(${CC_ARGS_NAME} STATIC IMPORTED GLOBAL)
  else()
    add_library(${CC_ARGS_NAME} UNKNOWN IMPORTED GLOBAL)
  endif()
  set_property(
    TARGET ${CC_ARGS_NAME} PROPERTY IMPORTED_LOCATION ${CC_ARGS_PATH}
  )
  if(MSVC AND CC_ARGS_SHARED)
    set_property(
      TARGET ${CC_ARGS_NAME} PROPERTY IMPORTED_IMPLIB ${CC_ARGS_IMPLIB}
    )
  endif()
  # Both INCS and PUBINCS become usage requirements of the imported target.
  if(CC_ARGS_INCS)
    _absolute_paths(INC_DIRS ${CC_ARGS_INCS})
    foreach(INC_DIR ${INC_DIRS})
      set_property(
        TARGET ${CC_ARGS_NAME}
        APPEND PROPERTY INTERFACE_INCLUDE_DIRECTORIES "${INC_DIR}"
      )
    endforeach()
  endif()
  if(CC_ARGS_PUBINCS)
    _absolute_paths(INC_DIRS ${CC_ARGS_PUBINCS})
    foreach(INC_DIR ${INC_DIRS})
      set_property(
        TARGET ${CC_ARGS_NAME}
        APPEND PROPERTY INTERFACE_INCLUDE_DIRECTORIES "${INC_DIR}"
      )
    endforeach()
  endif()
  if(CC_ARGS_DEPS)
    # Unquoted so that several DEPS become separate target names.
    add_dependencies(${CC_ARGS_NAME} ${CC_ARGS_DEPS})
  endif()
  if(CC_ARGS_PACKED)
    # install(TARGETS) rejects IMPORTED targets, so install the imported
    # file(s) directly.
    install(FILES ${CC_ARGS_PATH} DESTINATION "${CMAKE_INSTALL_LIBDIR}")
    if(CC_ARGS_PUBINCS)
      set(PATTERN_EXCLUDES "")
      foreach(PACKED_EXCLUDE ${CC_ARGS_PACKED_EXCLUDES})
        list(APPEND PATTERN_EXCLUDES "PATTERN;${PACKED_EXCLUDE};EXCLUDE")
      endforeach()
      install(
        DIRECTORY ${CC_ARGS_PUBINCS}
        DESTINATION ${CMAKE_INSTALL_INCDIR}
        FILES_MATCHING PATTERN "*.h" PATTERN "*.hpp" PATTERN "*.hxx"
                       ${PATTERN_EXCLUDES}
      )
    endif()
  endif()
endfunction()

## Import a C/C++ interface library
#
# cc_interface(NAME <name> [PACKED] [INCS ...] [PUBINCS ...] [DEPS ...]
#              [PACKED_EXCLUDES ...])
# Creates a GLOBAL INTERFACE target; PACKED installs the PUBINCS headers.
function(cc_interface)
  cmake_parse_arguments(
    CC_ARGS "PACKED" "NAME" "INCS;PUBINCS;DEPS;PACKED_EXCLUDES" ${ARGN}
  )
  if(NOT CC_ARGS_NAME)
    message(FATAL_ERROR "No target name privated.")
  endif()
  add_library(${CC_ARGS_NAME} INTERFACE GLOBAL)
  if(CC_ARGS_INCS)
    _absolute_paths(INC_DIRS ${CC_ARGS_INCS})
    target_include_directories(${CC_ARGS_NAME} INTERFACE "${INC_DIRS}")
  endif()
  if(CC_ARGS_PUBINCS)
    _absolute_paths(INC_DIRS ${CC_ARGS_PUBINCS})
    target_include_directories(${CC_ARGS_NAME} INTERFACE "${INC_DIRS}")
  endif()
  if(CC_ARGS_DEPS)
    add_dependencies(${CC_ARGS_NAME} ${CC_ARGS_DEPS})
  endif()
  if(CC_ARGS_PACKED AND CC_ARGS_PUBINCS)
    set(PATTERN_EXCLUDES "")
    foreach(PACKED_EXCLUDE ${CC_ARGS_PACKED_EXCLUDES})
      list(APPEND PATTERN_EXCLUDES "PATTERN;${PACKED_EXCLUDE};EXCLUDE")
    endforeach()
    install(
      DIRECTORY ${CC_ARGS_PUBINCS}
      DESTINATION ${CMAKE_INSTALL_INCDIR}
      FILES_MATCHING PATTERN "*.h" PATTERN "*.hpp" PATTERN "*.hxx"
                     ${PATTERN_EXCLUDES}
    )
  endif()
endfunction()

## Find gtest library
#
# _find_gtest()
# Sets the cache variables FIND_GTEST_INCS / FIND_GTEST_LIBS once, from
# existing gtest/gtest_main targets or via find_package(GTest).
function(_find_gtest)
  if(DEFINED FIND_GTEST_LIBS AND DEFINED FIND_GTEST_INCS)
    return()
  endif()
  if(NOT TARGET gtest OR NOT TARGET gtest_main)
    # Find gtest using 'find_package'
    find_package(GTest REQUIRED)
    set(
      FIND_GTEST_INCS "${GTEST_INCLUDE_DIRS}" CACHE STRING "GTest includes"
    )
    set(
      FIND_GTEST_LIBS "${GTEST_BOTH_LIBRARIES}" CACHE STRING "GTest libraries"
    )
  else()
    # Find gtest using target names
    set(FIND_GTEST_INCS "" CACHE STRING "GTest includes")
set(FIND_GTEST_LIBS "gtest;gtest_main" CACHE STRING "GTest libraries")
  endif()
endfunction()

## Build a C/C++ executable google test program
#
# cc_gtest(...) accepts the cc_test() arguments and links GoogleTest.
function(cc_gtest)
  cmake_parse_arguments(
    CC_ARGS "" "NAME;VERSION"
    "SRCS;INCS;DEFS;LIBS;CFLAGS;CXXFLAGS;LDFLAGS;DEPS;ARGS" ${ARGN}
  )
  _find_gtest()
  # Unparsed options (e.g. STRICT) go first so cc_test() does not absorb
  # them into a multi-value keyword; cc_test() forwards its own unparsed
  # arguments to _cc_target_properties().
  cc_test(
    ${CC_ARGS_UNPARSED_ARGUMENTS}
    NAME "${CC_ARGS_NAME}"
    VERSION "${CC_ARGS_VERSION}"
    SRCS "${CC_ARGS_SRCS}"
    INCS "${CC_ARGS_INCS};${FIND_GTEST_INCS}"
    DEFS "${CC_ARGS_DEFS}"
    LIBS "${CC_ARGS_LIBS};${FIND_GTEST_LIBS}"
    CFLAGS "${CC_ARGS_CFLAGS}"
    CXXFLAGS "${CC_ARGS_CXXFLAGS}"
    LDFLAGS "${CC_ARGS_LDFLAGS}"
    DEPS "${CC_ARGS_DEPS}"
    ARGS "${CC_ARGS_ARGS}"
  )
endfunction()

## Find gmock library
#
# _find_gmock()
# Sets the cache variables FIND_GMOCK_INCS / FIND_GMOCK_LIBS once, from
# existing gmock/gmock_main targets or via find_package(GMock/GTest).
function(_find_gmock)
  if(DEFINED FIND_GMOCK_LIBS AND DEFINED FIND_GMOCK_INCS)
    return()
  endif()
  if(NOT TARGET gmock OR NOT TARGET gmock_main)
    # Find gmock/gtest using 'find_package'
    find_package(GMock REQUIRED)
    find_package(GTest REQUIRED)
    set(
      FIND_GMOCK_INCS "${GMOCK_INCLUDE_DIRS};${GTEST_INCLUDE_DIRS}"
      CACHE STRING "GMock includes"
    )
    set(
      FIND_GMOCK_LIBS "${GMOCK_BOTH_LIBRARIES};${GTEST_LIBRARIES}"
      CACHE STRING "GMock libraries"
    )
  else()
    # Find gmock using target names
    set(FIND_GMOCK_INCS "" CACHE STRING "GMock includes")
    set(FIND_GMOCK_LIBS "gmock;gmock_main" CACHE STRING "GMock libraries")
  endif()
endfunction()

## Build a C/C++ executable google mock program
#
# cc_gmock(...) accepts the cc_test() arguments and links GoogleMock.
function(cc_gmock)
  cmake_parse_arguments(
    CC_ARGS "" "NAME;VERSION"
    "SRCS;INCS;DEFS;LIBS;CFLAGS;CXXFLAGS;LDFLAGS;DEPS;ARGS" ${ARGN}
  )
  _find_gmock()
  # Unparsed options go first; see cc_gtest().
  cc_test(
    ${CC_ARGS_UNPARSED_ARGUMENTS}
    NAME "${CC_ARGS_NAME}"
    VERSION "${CC_ARGS_VERSION}"
    SRCS "${CC_ARGS_SRCS}"
    INCS "${CC_ARGS_INCS};${FIND_GMOCK_INCS}"
    DEFS "${CC_ARGS_DEFS}"
    LIBS "${CC_ARGS_LIBS};${FIND_GMOCK_LIBS}"
    CFLAGS "${CC_ARGS_CFLAGS}"
    CXXFLAGS "${CC_ARGS_CXXFLAGS}"
    LDFLAGS "${CC_ARGS_LDFLAGS}"
    DEPS "${CC_ARGS_DEPS}"
    ARGS "${CC_ARGS_ARGS}"
  )
endfunction()

## Find protobuf library
#
# _find_protobuf(<version>)
# Sets the cache variables CC_PROTOBUF_PROTOC_<version>,
# CC_PROTOBUF_INCS_<version> and CC_PROTOBUF_LIBS_<version>, from existing
# protoc/libprotobuf targets or via find_package(Protobuf).
function(_find_protobuf _VERSION)
  if(DEFINED CC_PROTOBUF_PROTOC_${_VERSION})
    return()
  endif()
  # Find protobuf using 'find_package'
  if(NOT TARGET protoc OR NOT TARGET libprotobuf)
    find_package(Protobuf ${_VERSION} REQUIRED)
    set(
      CC_PROTOBUF_PROTOC_${_VERSION} "${PROTOBUF_PROTOC_EXECUTABLE}"
      CACHE PATH "Protobuf compiler"
    )
    set(
      CC_PROTOBUF_INCS_${_VERSION} "${PROTOBUF_INCLUDE_DIRS}"
      CACHE STRING "Protobuf includes"
    )
    set(
      CC_PROTOBUF_LIBS_${_VERSION} "${PROTOBUF_LIBRARIES}"
      CACHE STRING "Protobuf libraries"
    )
    return()
  endif()

  # Find protobuf using target names
  get_target_property(protoc_VERSION protoc VERSION)
  get_target_property(libprotobuf_VERSION libprotobuf VERSION)
  if(_VERSION)
    # Quoted so if() compares the values and never re-reads them as
    # variable names.
    if("${protoc_VERSION}" VERSION_LESS "${_VERSION}")
      message(
        FATAL_ERROR
          "The 'protoc' version is ${protoc_VERSION}, less than ${_VERSION}."
      )
    endif()
    if("${libprotobuf_VERSION}" VERSION_LESS "${_VERSION}")
      message(
        FATAL_ERROR "The 'libprotobuf' version is ${libprotobuf_VERSION}, "
                    "less than ${_VERSION}."
      )
    endif()
  endif()
  message(STATUS "Found binary 'protoc ${protoc_VERSION}'")
  message(STATUS "Found library 'libprotobuf ${libprotobuf_VERSION}'")
  # Resolved at generate time to the path of the protoc target's binary.
  set(
    CC_PROTOBUF_PROTOC_${_VERSION} "$<TARGET_FILE:protoc>"
    CACHE PATH "Protobuf compiler"
  )
  # Presumably the in-tree protobuf headers live in <protoc dir>/../src.
  get_target_property(protoc_SOURCE_DIR protoc SOURCE_DIR)
  get_filename_component(protoc_INCLUDE_DIR ${protoc_SOURCE_DIR}/../src ABSOLUTE)
  set(
    CC_PROTOBUF_INCS_${_VERSION} "${protoc_INCLUDE_DIR}"
    CACHE STRING "Protobuf includes"
  )
  set(
    CC_PROTOBUF_LIBS_${_VERSION} libprotobuf CACHE STRING "Protobuf libraries"
  )
endfunction()

## Build a C++ protobuf static or shared library
#
# cc_proto_library(NAME <name> SRCS <*.proto>... [STATIC] [SHARED] [EXCLUDE]
#                  [PACKED] [VERSION <v>] [PROTOROOT <dir>]
#                  [PROTOBUF_VERSION <v>] [CXXFLAGS ...] [LDFLAGS ...]
#                  [DEPS ...])
# Runs protoc on every .proto located under PROTOROOT (default: the current
# source dir), wraps the generated .pb.cc/.pb.h in GCC diagnostic pragmas and
# compiles them into a library linked against protobuf.
function(cc_proto_library)
  cmake_parse_arguments(
    CC_ARGS "STATIC;SHARED;EXCLUDE;PACKED"
    "NAME;VERSION;PROTOROOT;PROTOBUF_VERSION" "SRCS;CXXFLAGS;LDFLAGS;DEPS"
    ${ARGN}
  )
  _find_protobuf("${CC_ARGS_PROTOBUF_VERSION}")
  # GLOBAL_CC_PROTOBUF_PROTOC, when defined, overrides the detected protoc.
  set(CC_PROTOBUF_PROTOC ${CC_PROTOBUF_PROTOC_${CC_ARGS_PROTOBUF_VERSION}})
  if(DEFINED GLOBAL_CC_PROTOBUF_PROTOC)
    set(CC_PROTOBUF_PROTOC ${GLOBAL_CC_PROTOBUF_PROTOC})
  endif()
  set(CC_PROTOBUF_INCS ${CC_PROTOBUF_INCS_${CC_ARGS_PROTOBUF_VERSION}})
  set(CC_PROTOBUF_LIBS ${CC_PROTOBUF_LIBS_${CC_ARGS_PROTOBUF_VERSION}})
  if(NOT CC_ARGS_NAME)
    message(FATAL_ERROR "No target name privated.")
  endif()
  file(GLOB CC_ARGS_SRCS ${CC_ARGS_SRCS})
  if(NOT CC_ARGS_SRCS)
    message(FATAL_ERROR "No source files found of ${CC_ARGS_NAME}.")
  endif()
  if(CC_ARGS_VERSION)
    string(REPLACE "-" "_" MACRO_PREFIX "${CC_ARGS_NAME}")
    list(APPEND CC_ARGS_DEFS ${MACRO_PREFIX}_VERSION="${CC_ARGS_VERSION}")
  endif()
  if(CC_ARGS_EXCLUDE)
    set(EXCLUDE_OPTION EXCLUDE_FROM_ALL)
  endif()
  set(PROTO_ROOT ${CMAKE_CURRENT_SOURCE_DIR})
  if(CC_ARGS_PROTOROOT)
    get_filename_component(PROTO_ROOT ${CC_ARGS_PROTOROOT} ABSOLUTE)
  endif()

  # Compile proto files to C++ sources
  set(CPP_OUTPATH "${CMAKE_CURRENT_BINARY_DIR}")
  set(CC_SRCS "")
  foreach(PROTO_FILE ${CC_ARGS_SRCS})
    get_filename_component(PROTO_FILE ${PROTO_FILE} ABSOLUTE)
    if(NOT "${PROTO_FILE}" MATCHES "\\.proto$")
      message(FATAL_ERROR "Unrecognized proto file ${PROTO_FILE}")
    endif()
    # Path of the .proto relative to PROTO_ROOT. file(RELATIVE_PATH) is used
    # instead of a "^${PROTO_ROOT}" regex, which misbehaves on roots holding
    # regex metacharacters (e.g. "c++") and accepts prefix-sibling dirs.
    file(RELATIVE_PATH ROOT_CLEANED_FILE "${PROTO_ROOT}" "${PROTO_FILE}")
    if(IS_ABSOLUTE "${ROOT_CLEANED_FILE}"
       OR "${ROOT_CLEANED_FILE}" MATCHES "^\\.\\.(/|$)")
      message(FATAL_ERROR "'${PROTO_FILE}' NOT IN '${PROTO_ROOT}'")
    endif()
    string(REGEX REPLACE "\\.proto$" "" EXT_CLEANED_FILE "${ROOT_CLEANED_FILE}")
    set(CPP_FILE "${CPP_OUTPATH}/${EXT_CLEANED_FILE}.pb.cc")
    set(HDR_FILE "${CPP_OUTPATH}/${EXT_CLEANED_FILE}.pb.h")
    set(INJ_FILE "${CPP_OUTPATH}/${EXT_CLEANED_FILE}.pb.cmake")
    file(RELATIVE_PATH REL_CPP_FILE ${CMAKE_BINARY_DIR} ${CPP_FILE})
    # Post-processing script: surrounds both generated files with a GCC
    # "diagnostic push / ignored -Wshadow, -Wunused-parameter / pop" block.
    # The custom command has no WORKING_DIRECTORY, so it runs in
    # CMAKE_CURRENT_BINARY_DIR (== CPP_OUTPATH); hence the relative names.
    set(INJECTED_SCRIPT
      "foreach(SRC ${EXT_CLEANED_FILE}.pb.cc ${EXT_CLEANED_FILE}.pb.h)\n"
      " file(READ \$\{SRC\} SRC_CODE)\n"
      " file(REMOVE \$\{SRC\})\n"
      " file(APPEND \$\{SRC\} \"#ifdef __GNUC__\\n\")\n"
      " file(APPEND \$\{SRC\} \"#pragma GCC diagnostic push\\n\")\n"
      " file(APPEND \$\{SRC\} \"#pragma GCC diagnostic ignored \\\"-Wshadow\\\"\\n\")\n"
      " file(APPEND \$\{SRC\} \"#pragma GCC diagnostic ignored \\\"-Wunused-parameter\\\"\\n\")\n"
      " file(APPEND \$\{SRC\} \"#endif\\n\\n\")\n"
      " file(APPEND \$\{SRC\} \"\$\{SRC_CODE\}\")\n"
      " file(APPEND \$\{SRC\} \"\\n#ifdef __GNUC__\\n\")\n"
      " file(APPEND \$\{SRC\} \"#pragma GCC diagnostic pop\\n\")\n"
      " file(APPEND \$\{SRC\} \"#endif\\n\")\n"
      "endforeach()\n"
    )
    file(WRITE "${INJ_FILE}" ${INJECTED_SCRIPT})
    add_custom_command(
      OUTPUT "${CPP_FILE}" "${HDR_FILE}"
      # COMMAND ${CMAKE_COMMAND} -E make_directory ${CPP_OUTPATH}
      COMMAND ${CC_PROTOBUF_PROTOC}
              --cpp_out "${CPP_OUTPATH}"
              --python_out "${CPP_OUTPATH}"
              --proto_path "${PROTO_ROOT}"
              --proto_path "${CC_PROTOBUF_INCS}"
              "${PROTO_FILE}"
      COMMAND ${CMAKE_COMMAND} -P "${INJ_FILE}"
      DEPENDS "${PROTO_FILE}"
      COMMENT "Generating CXX source ${REL_CPP_FILE}"
      VERBATIM
    )
    list(APPEND CC_SRCS "${CPP_FILE}" "${HDR_FILE}")
  endforeach()

  # Compile C++ sources
  if(CC_ARGS_SHARED AND CC_ARGS_STATIC)
    _add_library(${CC_ARGS_NAME} "${EXCLUDE_OPTION}" "${CC_SRCS}")
  elseif(CC_ARGS_SHARED)
    add_library(${CC_ARGS_NAME} SHARED ${EXCLUDE_OPTION} ${CC_SRCS})
  elseif(CC_ARGS_STATIC)
    add_library(${CC_ARGS_NAME} STATIC ${EXCLUDE_OPTION} ${CC_SRCS})
  else()
    add_library(${CC_ARGS_NAME} ${EXCLUDE_OPTION} ${CC_SRCS})
  endif()
  # DEFS carries <NAME>_VERSION when VERSION is given.
  if(TARGET ${CC_ARGS_NAME}_objects)
    _cc_target_properties(
      NAME "${CC_ARGS_NAME}_objects"
      INCS "${CPP_OUTPATH};${CC_PROTOBUF_INCS}"
      DEFS "${CC_ARGS_DEFS}"
      LIBS "${CC_PROTOBUF_LIBS}"
      CXXFLAGS "${CC_ARGS_CXXFLAGS}"
      LDFLAGS "${CC_ARGS_LDFLAGS}"
      DEPS "${CC_ARGS_DEPS}"
      "${CC_ARGS_UNPARSED_ARGUMENTS}"
    )
  endif()
  if(TARGET ${CC_ARGS_NAME}_static)
    _cc_target_properties(
      NAME "${CC_ARGS_NAME}_static"
      PUBINCS "${CPP_OUTPATH};${CC_PROTOBUF_INCS}"
      DEFS "${CC_ARGS_DEFS}"
      LIBS "${CC_PROTOBUF_LIBS}"
      CXXFLAGS "${CC_ARGS_CXXFLAGS}"
      LDFLAGS "${CC_ARGS_LDFLAGS}"
      DEPS "${CC_ARGS_DEPS}"
      "${CC_ARGS_UNPARSED_ARGUMENTS}"
    )
    if(CC_ARGS_PACKED)
      install(
        TARGETS ${CC_ARGS_NAME}_static
        ARCHIVE DESTINATION "${CMAKE_INSTALL_LIBDIR}"
      )
    endif()
  endif()
  _cc_target_properties(
    NAME "${CC_ARGS_NAME}"
    PUBINCS "${CPP_OUTPATH};${CC_PROTOBUF_INCS}"
    DEFS "${CC_ARGS_DEFS}"
    LIBS "${CC_PROTOBUF_LIBS}"
    CXXFLAGS "${CC_ARGS_CXXFLAGS}"
    LDFLAGS "${CC_ARGS_LDFLAGS}"
    DEPS "${CC_ARGS_DEPS}"
    VERSION "${CC_ARGS_VERSION}"
    "${CC_ARGS_UNPARSED_ARGUMENTS}"
  )
  if(CC_ARGS_PACKED)
    install(
      TARGETS ${CC_ARGS_NAME}
      ARCHIVE DESTINATION "${CMAKE_INSTALL_LIBDIR}"
      LIBRARY DESTINATION "${CMAKE_INSTALL_LIBDIR}"
    )
  endif()
endfunction()

## Add a subdirectory to the build
function(cuda_directory)
  if(NOT CMAKE_CUDA_COMPILER)
    message(FATAL_ERROR "No CUDA language supported.")
  endif()
  cc_directory(${ARGN})
endfunction()

## Add subdirectories to the build
function(cuda_directories)
  if(NOT CMAKE_CUDA_COMPILER)
    message(FATAL_ERROR "No CUDA language supported.")
  endif()
  cc_directories(${ARGN})
endfunction()

## Set the properties of cuda target
#
# Same contract as _cc_target_properties(), plus CUDAFLAGS for CUDA sources.
function(_cuda_target_properties)
  cmake_parse_arguments(
    CUDA_ARGS "STRICT;ALWAYS_LINK" "NAME;VERSION;C_STANDARD;CXX_STANDARD"
    "INCS;PUBINCS;DEFS;LIBS;CFLAGS;CXXFLAGS;CUDAFLAGS;LDFLAGS;DEPS" ${ARGN}
  )
  if(NOT CUDA_ARGS_NAME)
    message(FATAL_ERROR "No target name privated.")
  endif()

  # Only shared/static libraries and executables accept link items.
  set(TARGET_LINKABLE FALSE)
  get_target_property(TARGET_TYPE ${CUDA_ARGS_NAME} TYPE)
  if(("${TARGET_TYPE}" STREQUAL "SHARED_LIBRARY")
     OR ("${TARGET_TYPE}" STREQUAL "STATIC_LIBRARY")
     OR ("${TARGET_TYPE}" STREQUAL "EXECUTABLE"))
    set(TARGET_LINKABLE TRUE)
  endif()
  if(CUDA_ARGS_ALWAYS_LINK)
    if(("${TARGET_TYPE}" STREQUAL "STATIC_LIBRARY")
       OR ("${TARGET_TYPE}" STREQUAL "OBJECT_LIBRARY"))
      set_property(TARGET ${CUDA_ARGS_NAME} PROPERTY ALWAYS_LINK TRUE)
    endif()
  endif()

  # Set the warning level of compiling
  if(CUDA_ARGS_STRICT)
    target_compile_options(
      ${CUDA_ARGS_NAME} PRIVATE "${BAZEL_CUDA_STRICT_COMPILE_FLAGS}"
    )
    if(TARGET_LINKABLE)
      target_link_libraries(
        ${CUDA_ARGS_NAME} "${BAZEL_CUDA_STRICT_LINK_FLAGS}"
      )
    endif()
  else()
    target_compile_options(
      ${CUDA_ARGS_NAME} PRIVATE "${BAZEL_CUDA_UNSTRICT_COMPILE_FLAGS}"
    )
    if(TARGET_LINKABLE)
      target_link_libraries(
        ${CUDA_ARGS_NAME} "${BAZEL_CUDA_UNSTRICT_LINK_FLAGS}"
      )
    endif()
  endif()
  # For CUDA sources, pass the configured C++ compiler to nvcc as -ccbin.
  target_compile_options(
    ${CUDA_ARGS_NAME} PRIVATE
    "$<$<COMPILE_LANGUAGE:CUDA>:-ccbin=${CMAKE_CXX_COMPILER}>"
  )
  if(CUDA_ARGS_DEFS)
    target_compile_definitions(${CUDA_ARGS_NAME} PRIVATE "${CUDA_ARGS_DEFS}")
  endif()
  # CFLAGS -> C sources, CXXFLAGS -> C++ sources, CUDAFLAGS -> CUDA sources.
  if(CUDA_ARGS_CFLAGS OR CUDA_ARGS_CXXFLAGS OR CUDA_ARGS_CUDAFLAGS)
    target_compile_options(
      ${CUDA_ARGS_NAME} PRIVATE
      "$<$<COMPILE_LANGUAGE:C>:${CUDA_ARGS_CFLAGS}>"
      "$<$<COMPILE_LANGUAGE:CXX>:${CUDA_ARGS_CXXFLAGS}>"
      "$<$<COMPILE_LANGUAGE:CUDA>:${CUDA_ARGS_CUDAFLAGS}>"
    )
  endif()
  if(CUDA_ARGS_LDFLAGS)
    # LINK_FLAGS is one space-separated string, not a list.
    string(REPLACE ";" " " CUDA_ARGS_LDFLAGS "${CUDA_ARGS_LDFLAGS}")
    set_property(
      TARGET ${CUDA_ARGS_NAME} PROPERTY LINK_FLAGS "${CUDA_ARGS_LDFLAGS}"
    )
  endif()
  if(CUDA_ARGS_INCS)
    _absolute_paths(INC_DIRS ${CUDA_ARGS_INCS})
    target_include_directories(${CUDA_ARGS_NAME} PRIVATE "${INC_DIRS}")
  endif()
  target_include_directories(
    ${CUDA_ARGS_NAME} PRIVATE "${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}"
  )
  if(BAZEL_WORKSPACE_DIR)
    target_include_directories(
      ${CUDA_ARGS_NAME} PRIVATE "${BAZEL_WORKSPACE_DIR}"
    )
  endif()
  if(CUDA_ARGS_PUBINCS)
    _absolute_paths(INC_DIRS ${CUDA_ARGS_PUBINCS})
    target_include_directories(${CUDA_ARGS_NAME} PUBLIC "${INC_DIRS}")
  endif()
  if(CUDA_ARGS_LIBS)
    if(NOT TARGET_LINKABLE)
      _targets_link_dependencies(${CUDA_ARGS_NAME} ${CUDA_ARGS_LIBS})
    else()
      if("${TARGET_TYPE}" STREQUAL "EXECUTABLE")
        _target_link_libraries(${CUDA_ARGS_NAME} "${CUDA_ARGS_LIBS}")
      else()
        target_link_libraries(${CUDA_ARGS_NAME} "${CUDA_ARGS_LIBS}")
      endif()
    endif()
  endif()
  if(CUDA_ARGS_DEPS)
    # Unquoted so that several DEPS become separate target names.
    add_dependencies(${CUDA_ARGS_NAME} ${CUDA_ARGS_DEPS})
  endif()
  if(CUDA_ARGS_VERSION)
    set_property(
      TARGET ${CUDA_ARGS_NAME} PROPERTY VERSION "${CUDA_ARGS_VERSION}"
    )
  endif()
  # Language standards: an explicit C_STANDARD / CXX_STANDARD argument wins,
  # then a CUDA_C_STANDARD / CUDA_CXX_STANDARD variable visible from the
  # calling scope, then the defaults C99 / C++11.
  if(CUDA_ARGS_C_STANDARD)
    set(CUDA_C_STANDARD ${CUDA_ARGS_C_STANDARD})
  elseif(NOT CUDA_C_STANDARD)
    set(CUDA_C_STANDARD 99)
  endif()
  if(CUDA_ARGS_CXX_STANDARD)
    set(CUDA_CXX_STANDARD ${CUDA_ARGS_CXX_STANDARD})
  elseif(NOT CUDA_CXX_STANDARD)
    set(CUDA_CXX_STANDARD 11)
  endif()
  set_target_properties(
    ${CUDA_ARGS_NAME}
    PROPERTIES DEFINE_SYMBOL ""
               C_STANDARD ${CUDA_C_STANDARD}
               CXX_STANDARD ${CUDA_CXX_STANDARD}
               C_STANDARD_REQUIRED ON
               C_EXTENSIONS ON
               CXX_STANDARD_REQUIRED ON
               CXX_EXTENSIONS OFF
               CUDA_STANDARD 11
               CUDA_STANDARD_REQUIRED ON
               CUDA_EXTENSIONS OFF
               WINDOWS_EXPORT_ALL_SYMBOLS ON
  )
endfunction()

## Build a CUDA static or shared library
#
# CUDA counterpart of cc_library(); additionally accepts CUDAFLAGS.
function(cuda_library)
  if(NOT CMAKE_CUDA_COMPILER)
    message(FATAL_ERROR "No CUDA language supported.")
  endif()
  cmake_parse_arguments(
    CUDA_ARGS "STATIC;SHARED;EXCLUDE;PACKED" "NAME;VERSION"
    "SRCS;INCS;PUBINCS;DEFS;LIBS;CFLAGS;CXXFLAGS;CUDAFLAGS;LDFLAGS;DEPS;PACKED_EXCS"
    ${ARGN}
  )
  if(NOT
CUDA_ARGS_NAME)
    message(FATAL_ERROR "No target name privated.")
  endif()
  file(GLOB CUDA_ARGS_SRCS ${CUDA_ARGS_SRCS})
  if(NOT CUDA_ARGS_SRCS)
    message(FATAL_ERROR "No source files found of ${CUDA_ARGS_NAME}.")
  endif()
  if(CUDA_ARGS_VERSION)
    string(REPLACE "-" "_" MACRO_PREFIX "${CUDA_ARGS_NAME}")
    list(APPEND CUDA_ARGS_DEFS ${MACRO_PREFIX}_VERSION="${CUDA_ARGS_VERSION}")
  endif()
  if(CUDA_ARGS_EXCLUDE)
    set(EXCLUDE_OPTION EXCLUDE_FROM_ALL)
  endif()
  if(CUDA_ARGS_SHARED AND CUDA_ARGS_STATIC)
    _add_library(${CUDA_ARGS_NAME} "${EXCLUDE_OPTION}" "${CUDA_ARGS_SRCS}")
  elseif(CUDA_ARGS_SHARED)
    add_library(${CUDA_ARGS_NAME} SHARED ${EXCLUDE_OPTION} ${CUDA_ARGS_SRCS})
  elseif(CUDA_ARGS_STATIC)
    add_library(${CUDA_ARGS_NAME} STATIC ${EXCLUDE_OPTION} ${CUDA_ARGS_SRCS})
  else()
    add_library(${CUDA_ARGS_NAME} ${EXCLUDE_OPTION} ${CUDA_ARGS_SRCS})
  endif()
  # The OBJECT library compiles the sources, so it gets both include sets.
  if(TARGET ${CUDA_ARGS_NAME}_objects)
    _cuda_target_properties(
      NAME "${CUDA_ARGS_NAME}_objects"
      INCS "${CUDA_ARGS_INCS};${CUDA_ARGS_PUBINCS}"
      DEFS "${CUDA_ARGS_DEFS}"
      LIBS "${CUDA_ARGS_LIBS}"
      CFLAGS "${CUDA_ARGS_CFLAGS}"
      CXXFLAGS "${CUDA_ARGS_CXXFLAGS}"
      CUDAFLAGS "${CUDA_ARGS_CUDAFLAGS}"
      LDFLAGS "${CUDA_ARGS_LDFLAGS}"
      DEPS "${CUDA_ARGS_DEPS}"
      "${CUDA_ARGS_UNPARSED_ARGUMENTS}"
    )
  endif()
  if(TARGET ${CUDA_ARGS_NAME}_static)
    _cuda_target_properties(
      NAME "${CUDA_ARGS_NAME}_static"
      INCS "${CUDA_ARGS_INCS}"
      PUBINCS "${CUDA_ARGS_PUBINCS}"
      DEFS "${CUDA_ARGS_DEFS}"
      LIBS "${CUDA_ARGS_LIBS}"
      CFLAGS "${CUDA_ARGS_CFLAGS}"
      CXXFLAGS "${CUDA_ARGS_CXXFLAGS}"
      CUDAFLAGS "${CUDA_ARGS_CUDAFLAGS}"
      LDFLAGS "${CUDA_ARGS_LDFLAGS}"
      DEPS "${CUDA_ARGS_DEPS}"
      "${CUDA_ARGS_UNPARSED_ARGUMENTS}"
    )
    if(CUDA_ARGS_PACKED)
      install(
        TARGETS ${CUDA_ARGS_NAME}_static
        ARCHIVE DESTINATION "${CMAKE_INSTALL_LIBDIR}"
      )
    endif()
  endif()
  _cuda_target_properties(
    NAME "${CUDA_ARGS_NAME}"
    INCS "${CUDA_ARGS_INCS}"
    PUBINCS "${CUDA_ARGS_PUBINCS}"
    DEFS "${CUDA_ARGS_DEFS}"
    LIBS "${CUDA_ARGS_LIBS}"
    CFLAGS "${CUDA_ARGS_CFLAGS}"
    CXXFLAGS "${CUDA_ARGS_CXXFLAGS}"
    CUDAFLAGS "${CUDA_ARGS_CUDAFLAGS}"
    LDFLAGS "${CUDA_ARGS_LDFLAGS}"
    DEPS "${CUDA_ARGS_DEPS}"
    VERSION "${CUDA_ARGS_VERSION}"
    "${CUDA_ARGS_UNPARSED_ARGUMENTS}"
  )
  if(CUDA_ARGS_PACKED)
    install(
      TARGETS ${CUDA_ARGS_NAME}
      ARCHIVE DESTINATION "${CMAKE_INSTALL_LIBDIR}"
      LIBRARY DESTINATION "${CMAKE_INSTALL_LIBDIR}"
    )
    if(CUDA_ARGS_PUBINCS)
      # Exclude patterns come from cuda_library()'s PACKED_EXCS keyword.
      set(PATTERN_EXCLUDES "")
      foreach(PACKED_EXCLUDE ${CUDA_ARGS_PACKED_EXCS})
        list(APPEND PATTERN_EXCLUDES "PATTERN;${PACKED_EXCLUDE};EXCLUDE")
      endforeach()
      install(
        DIRECTORY ${CUDA_ARGS_PUBINCS}
        DESTINATION ${CMAKE_INSTALL_INCDIR}
        FILES_MATCHING PATTERN "*.h" PATTERN "*.hpp" PATTERN "*.hxx"
                       PATTERN "*.cuh" ${PATTERN_EXCLUDES}
      )
    endif()
  endif()
endfunction()

## Build a CUDA executable program
#
# CUDA counterpart of cc_binary(); additionally accepts CUDAFLAGS.
function(cuda_binary)
  if(NOT CMAKE_CUDA_COMPILER)
    message(FATAL_ERROR "No CUDA language supported.")
  endif()
  cmake_parse_arguments(
    CUDA_ARGS "PACKED" "NAME;VERSION"
    "SRCS;INCS;DEFS;LIBS;CFLAGS;CXXFLAGS;CUDAFLAGS;LDFLAGS;DEPS" ${ARGN}
  )
  if(NOT CUDA_ARGS_NAME)
    message(FATAL_ERROR "No target name privated.")
  endif()
  file(GLOB CUDA_ARGS_SRCS ${CUDA_ARGS_SRCS})
  if(NOT CUDA_ARGS_SRCS)
    message(FATAL_ERROR "No source files found of ${CUDA_ARGS_NAME}.")
  endif()
  if(CUDA_ARGS_VERSION)
    string(REPLACE "-" "_" MACRO_PREFIX "${CUDA_ARGS_NAME}")
    list(APPEND CUDA_ARGS_DEFS ${MACRO_PREFIX}_VERSION="${CUDA_ARGS_VERSION}")
  endif()
  add_executable(${CUDA_ARGS_NAME} ${CUDA_ARGS_SRCS})
  if(CUDA_ARGS_PACKED)
    install(
      TARGETS ${CUDA_ARGS_NAME} RUNTIME DESTINATION "${CMAKE_INSTALL_BINDIR}"
    )
  endif()
  _cuda_target_properties(
    NAME "${CUDA_ARGS_NAME}"
    INCS "${CUDA_ARGS_INCS}"
    DEFS "${CUDA_ARGS_DEFS}"
    LIBS "${CUDA_ARGS_LIBS}"
    CFLAGS "${CUDA_ARGS_CFLAGS}"
    CXXFLAGS "${CUDA_ARGS_CXXFLAGS}"
    CUDAFLAGS "${CUDA_ARGS_CUDAFLAGS}"
    LDFLAGS "${CUDA_ARGS_LDFLAGS}"
    DEPS "${CUDA_ARGS_DEPS}"
    VERSION "${CUDA_ARGS_VERSION}"
    "${CUDA_ARGS_UNPARSED_ARGUMENTS}"
  )
endfunction()

## Build a CUDA executable test program
#
# CUDA counterpart of cc_test(): EXCLUDE_FROM_ALL executable, dependency of
# 'unittest', runner target unittest.<name> and an add_test() entry.
function(cuda_test)
  if(NOT CMAKE_CUDA_COMPILER)
    message(FATAL_ERROR "No CUDA language supported.")
  endif()
  cmake_parse_arguments(
    CUDA_ARGS ""
    "NAME;VERSION"
    "SRCS;INCS;DEFS;LIBS;CFLAGS;CXXFLAGS;CUDAFLAGS;LDFLAGS;DEPS;ARGS" ${ARGN}
  )
  if(NOT CUDA_ARGS_NAME)
    message(FATAL_ERROR "No target name privated.")
  endif()
  file(GLOB CUDA_ARGS_SRCS ${CUDA_ARGS_SRCS})
  if(NOT CUDA_ARGS_SRCS)
    message(FATAL_ERROR "No source files found of ${CUDA_ARGS_NAME}.")
  endif()
  if(CUDA_ARGS_VERSION)
    string(REPLACE "-" "_" MACRO_PREFIX "${CUDA_ARGS_NAME}")
    list(APPEND CUDA_ARGS_DEFS ${MACRO_PREFIX}_VERSION="${CUDA_ARGS_VERSION}")
  endif()
  add_executable(${CUDA_ARGS_NAME} EXCLUDE_FROM_ALL ${CUDA_ARGS_SRCS})
  _cuda_target_properties(
    NAME "${CUDA_ARGS_NAME}"
    INCS "${CUDA_ARGS_INCS}"
    DEFS "${CUDA_ARGS_DEFS}"
    LIBS "${CUDA_ARGS_LIBS}"
    CFLAGS "${CUDA_ARGS_CFLAGS}"
    CXXFLAGS "${CUDA_ARGS_CXXFLAGS}"
    CUDAFLAGS "${CUDA_ARGS_CUDAFLAGS}"
    LDFLAGS "${CUDA_ARGS_LDFLAGS}"
    DEPS "${CUDA_ARGS_DEPS}"
    "${CUDA_ARGS_UNPARSED_ARGUMENTS}"
  )
  add_dependencies(unittest ${CUDA_ARGS_NAME})
  # Run the built test binary with the ARGS given to cuda_test().
  add_custom_target(
    unittest.${CUDA_ARGS_NAME}
    COMMAND $<TARGET_FILE:${CUDA_ARGS_NAME}> "${CUDA_ARGS_ARGS}"
    WORKING_DIRECTORY ${PROJECT_BINARY_DIR}
    DEPENDS ${CUDA_ARGS_NAME}
  )
  add_test(
    NAME ${CUDA_ARGS_NAME}
    COMMAND $<TARGET_FILE:${CUDA_ARGS_NAME}> "${CUDA_ARGS_ARGS}"
    WORKING_DIRECTORY ${PROJECT_BINARY_DIR}
  )
endfunction()

## Add existing test cases to a test suite
function(cuda_test_suite)
  if(NOT CMAKE_CUDA_COMPILER)
    message(FATAL_ERROR "No CUDA language supported.")
  endif()
  cc_test_suite(${ARGN})
endfunction()

## Import a C/C++/CUDA static or shared library
#
# CUDA counterpart of cc_import(); also installs *.cuh headers.
function(cuda_import)
  if(NOT CMAKE_CUDA_COMPILER)
    message(FATAL_ERROR "No CUDA language supported.")
  endif()
  cmake_parse_arguments(
    CUDA_ARGS "STATIC;SHARED;PACKED" "NAME;PATH;IMPLIB"
    "INCS;PUBINCS;DEPS;PACKED_EXCLUDES" ${ARGN}
  )
  if(NOT CUDA_ARGS_NAME)
    message(FATAL_ERROR "No target name privated.")
  endif()
  file(GLOB CUDA_ARGS_PATH ${CUDA_ARGS_PATH})
  if(NOT CUDA_ARGS_PATH)
    message(FATAL_ERROR "No imported target file found of ${CUDA_ARGS_NAME}.")
  endif()
  # Default MSVC import library: the .dll path with a .lib extension. The dot
  # is escaped so that only a real ".dll" suffix is replaced.
  if(MSVC AND CUDA_ARGS_SHARED AND NOT CUDA_ARGS_IMPLIB)
    string(REGEX REPLACE "\\.[Dd][Ll][Ll]$" ".lib" CUDA_ARGS_IMPLIB
      ${CUDA_ARGS_PATH}
)
  endif()
  if(CUDA_ARGS_SHARED)
    add_library(${CUDA_ARGS_NAME} SHARED IMPORTED GLOBAL)
  elseif(CUDA_ARGS_STATIC)
    add_library(${CUDA_ARGS_NAME} STATIC IMPORTED GLOBAL)
  else()
    add_library(${CUDA_ARGS_NAME} UNKNOWN IMPORTED GLOBAL)
  endif()
  set_property(
    TARGET ${CUDA_ARGS_NAME} PROPERTY IMPORTED_LOCATION ${CUDA_ARGS_PATH}
  )
  if(MSVC AND CUDA_ARGS_SHARED)
    set_property(
      TARGET ${CUDA_ARGS_NAME} PROPERTY IMPORTED_IMPLIB ${CUDA_ARGS_IMPLIB}
    )
  endif()
  # Both INCS and PUBINCS become usage requirements of the imported target.
  if(CUDA_ARGS_INCS)
    _absolute_paths(INC_DIRS ${CUDA_ARGS_INCS})
    foreach(INC_DIR ${INC_DIRS})
      set_property(
        TARGET ${CUDA_ARGS_NAME}
        APPEND PROPERTY INTERFACE_INCLUDE_DIRECTORIES "${INC_DIR}"
      )
    endforeach()
  endif()
  if(CUDA_ARGS_PUBINCS)
    _absolute_paths(INC_DIRS ${CUDA_ARGS_PUBINCS})
    foreach(INC_DIR ${INC_DIRS})
      set_property(
        TARGET ${CUDA_ARGS_NAME}
        APPEND PROPERTY INTERFACE_INCLUDE_DIRECTORIES "${INC_DIR}"
      )
    endforeach()
  endif()
  if(CUDA_ARGS_DEPS)
    # Unquoted so that several DEPS become separate target names.
    add_dependencies(${CUDA_ARGS_NAME} ${CUDA_ARGS_DEPS})
  endif()
  if(CUDA_ARGS_PACKED)
    # install(TARGETS) rejects IMPORTED targets, so install the imported
    # file(s) directly.
    install(FILES ${CUDA_ARGS_PATH} DESTINATION "${CMAKE_INSTALL_LIBDIR}")
    if(CUDA_ARGS_PUBINCS)
      set(PATTERN_EXCLUDES "")
      foreach(PACKED_EXCLUDE ${CUDA_ARGS_PACKED_EXCLUDES})
        list(APPEND PATTERN_EXCLUDES "PATTERN;${PACKED_EXCLUDE};EXCLUDE")
      endforeach()
      install(
        DIRECTORY ${CUDA_ARGS_PUBINCS}
        DESTINATION ${CMAKE_INSTALL_INCDIR}
        FILES_MATCHING PATTERN "*.h" PATTERN "*.hpp" PATTERN "*.hxx"
                       PATTERN "*.cuh" ${PATTERN_EXCLUDES}
      )
    endif()
  endif()
endfunction()

## Import a C/C++/CUDA interface library
#
# CUDA counterpart of cc_interface(); also installs *.cuh headers.
function(cuda_interface)
  if(NOT CMAKE_CUDA_COMPILER)
    message(FATAL_ERROR "No CUDA language supported.")
  endif()
  cmake_parse_arguments(
    CUDA_ARGS "PACKED" "NAME" "INCS;PUBINCS;DEPS;PACKED_EXCLUDES" ${ARGN}
  )
  if(NOT CUDA_ARGS_NAME)
    message(FATAL_ERROR "No target name privated.")
  endif()
  add_library(${CUDA_ARGS_NAME} INTERFACE GLOBAL)
  if(CUDA_ARGS_INCS)
    _absolute_paths(INC_DIRS ${CUDA_ARGS_INCS})
    target_include_directories(${CUDA_ARGS_NAME} INTERFACE "${INC_DIRS}")
  endif()
  if(CUDA_ARGS_PUBINCS)
    _absolute_paths(INC_DIRS ${CUDA_ARGS_PUBINCS})
    target_include_directories(${CUDA_ARGS_NAME} INTERFACE "${INC_DIRS}")
  endif()
  if(CUDA_ARGS_DEPS)
    add_dependencies(${CUDA_ARGS_NAME} ${CUDA_ARGS_DEPS})
  endif()
  if(CUDA_ARGS_PACKED AND CUDA_ARGS_PUBINCS)
    set(PATTERN_EXCLUDES "")
    foreach(PACKED_EXCLUDE ${CUDA_ARGS_PACKED_EXCLUDES})
      list(APPEND PATTERN_EXCLUDES "PATTERN;${PACKED_EXCLUDE};EXCLUDE")
    endforeach()
    install(
      DIRECTORY ${CUDA_ARGS_PUBINCS}
      DESTINATION ${CMAKE_INSTALL_INCDIR}
      FILES_MATCHING PATTERN "*.h" PATTERN "*.hpp" PATTERN "*.hxx"
                     PATTERN "*.cuh" ${PATTERN_EXCLUDES}
    )
  endif()
endfunction()

## Build a C/C++/CUDA executable google test program
#
# cuda_gtest(...) accepts the cuda_test() arguments and links GoogleTest.
function(cuda_gtest)
  cmake_parse_arguments(
    CUDA_ARGS "" "NAME;VERSION"
    "SRCS;INCS;DEFS;LIBS;CFLAGS;CXXFLAGS;CUDAFLAGS;LDFLAGS;DEPS;ARGS" ${ARGN}
  )
  _find_gtest()
  # Unparsed options (e.g. STRICT) go first so cuda_test() does not absorb
  # them into a multi-value keyword; cuda_test() forwards its own unparsed
  # arguments to _cuda_target_properties().
  cuda_test(
    ${CUDA_ARGS_UNPARSED_ARGUMENTS}
    NAME "${CUDA_ARGS_NAME}"
    VERSION "${CUDA_ARGS_VERSION}"
    SRCS "${CUDA_ARGS_SRCS}"
    INCS "${CUDA_ARGS_INCS};${FIND_GTEST_INCS}"
    DEFS "${CUDA_ARGS_DEFS}"
    LIBS "${CUDA_ARGS_LIBS};${FIND_GTEST_LIBS}"
    CFLAGS "${CUDA_ARGS_CFLAGS}"
    CXXFLAGS "${CUDA_ARGS_CXXFLAGS}"
    CUDAFLAGS "${CUDA_ARGS_CUDAFLAGS}"
    LDFLAGS "${CUDA_ARGS_LDFLAGS}"
    DEPS "${CUDA_ARGS_DEPS}"
    ARGS "${CUDA_ARGS_ARGS}"
  )
endfunction()

## Build a C/C++/CUDA executable google mock program
#
# cuda_gmock(...) accepts the cuda_test() arguments and links GoogleMock.
function(cuda_gmock)
  cmake_parse_arguments(
    CUDA_ARGS "" "NAME;VERSION"
    "SRCS;INCS;DEFS;LIBS;CFLAGS;CXXFLAGS;CUDAFLAGS;LDFLAGS;DEPS;ARGS" ${ARGN}
  )
  _find_gmock()
  # Unparsed options go first; see cuda_gtest().
  cuda_test(
    ${CUDA_ARGS_UNPARSED_ARGUMENTS}
    NAME "${CUDA_ARGS_NAME}"
    VERSION "${CUDA_ARGS_VERSION}"
    SRCS "${CUDA_ARGS_SRCS}"
    INCS "${CUDA_ARGS_INCS};${FIND_GMOCK_INCS}"
    DEFS "${CUDA_ARGS_DEFS}"
    LIBS "${CUDA_ARGS_LIBS};${FIND_GMOCK_LIBS}"
    CFLAGS "${CUDA_ARGS_CFLAGS}"
    CXXFLAGS "${CUDA_ARGS_CXXFLAGS}"
    CUDAFLAGS "${CUDA_ARGS_CUDAFLAGS}"
    LDFLAGS "${CUDA_ARGS_LDFLAGS}"
    DEPS "${CUDA_ARGS_DEPS}"
    ARGS "${CUDA_ARGS_ARGS}"
  )
endfunction()

## Add a subdirectory to the build
function(go_directory)
  add_subdirectory(${ARGN})
endfunction()

## Add subdirectories to the build
function(go_directories)
foreach(SRC_DIR ${ARGN}) add_subdirectory(${SRC_DIR}) endforeach()
endfunction()

## Build a go executable program
##   NAME     - custom target name (also the output executable name)
##   GOPATH   - list of GOPATH entries, joined with the platform separator
##   SRCS     - globbed go sources/directories passed to `go build`
##   ASMFLAGS/GCFLAGS/LDFLAGS - forwarded to the go toolchain
##   DEPS     - extra target dependencies; PACKED installs the binary
function(go_binary)
  # Search ~/go plus $GOROOT and $GOPATH (each with a bin/ suffix). Every
  # environment variable needs its own ENV keyword; a bare "GOPATH" would be
  # taken as a literal relative path.
  find_program(
    GO_EXECUTABLE go
    PATHS $ENV{HOME}/go ENV GOROOT ENV GOPATH
    PATH_SUFFIXES bin
  )
  if(NOT GO_EXECUTABLE)
    message(FATAL_ERROR "No go language compiler found.")
  endif()
  cmake_parse_arguments(
    GO_ARGS "PACKED" "NAME" "GOPATH;SRCS;ASMFLAGS;GCFLAGS;LDFLAGS;DEPS" ${ARGN}
  )
  if(NOT GO_ARGS_NAME)
    message(FATAL_ERROR "No target name privated.")
  endif()
  file(GLOB GO_ARGS_SRCS ${GO_ARGS_SRCS})
  if(NOT GO_ARGS_SRCS)
    message(FATAL_ERROR "No source files/directories found of ${GO_ARGS_NAME}.")
  endif()
  # Windows keeps ';' between GOPATH entries (escaped for CMake list
  # handling); other platforms use ':'.
  if(${CMAKE_SYSTEM_NAME} MATCHES "Windows")
    string(REPLACE ";" "\;" GO_ARGS_GOPATH "${GO_ARGS_GOPATH}")
  else()
    string(REPLACE ";" ":" GO_ARGS_GOPATH "${GO_ARGS_GOPATH}")
  endif()
  set(
    GO_OUTPUT_FILE
    ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${GO_ARGS_NAME}${CMAKE_EXECUTABLE_SUFFIX}
  )
  file(RELATIVE_PATH GO_OUTPUT_REL_FILE ${CMAKE_BINARY_DIR} ${GO_OUTPUT_FILE})
  add_custom_target(
    ${GO_ARGS_NAME}
    COMMAND ${CMAKE_COMMAND} -E env GOPATH="${GO_ARGS_GOPATH}"
            "${GO_EXECUTABLE}" build -v -buildmode=exe -compiler=gc
            -gcflags="${GO_ARGS_GCFLAGS}" -asmflags="${GO_ARGS_ASMFLAGS}"
            -ldflags="${GO_ARGS_LDFLAGS}" -o "${GO_OUTPUT_FILE}" "${GO_ARGS_SRCS}"
    WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}"
    DEPENDS "${GO_ARGS_DEPS}"
    COMMENT "Building GO executable ${GO_OUTPUT_REL_FILE}"
  )
  if(GO_ARGS_PACKED)
    install(PROGRAMS ${GO_OUTPUT_FILE} DESTINATION "${CMAKE_INSTALL_BINDIR}")
  endif()
endfunction()

## Fetch content
## Writes a throw-away CMake project to ${PROJECT_BINARY_DIR}/downloads/<NAME>
## whose only job is an ExternalProject_Add download into PATH (no configure,
## build, install or test steps). It is configured and built immediately via
## execute_process. An `external.<NAME>` target is also added to re-run it.
function(_fetch_content)
  cmake_parse_arguments(
    DL_ARGS "" "NAME;PATH;GIT_URL;GIT_TAG;HG_URL;HG_TAG;SVN_URL;SVN_REV;URL;URL_HASH" "" ${ARGN}
  )
  if(NOT DL_ARGS_NAME)
    message(FATAL_ERROR "No fetch name privated.")
  endif()
  if(NOT DL_ARGS_PATH)
    # Download to current source directory
    set(DL_ARGS_PATH "${CMAKE_CURRENT_SOURCE_DIR}/${DL_ARGS_NAME}")
  endif()
  # CMake 4.0 removed compatibility with cmake_minimum_required() versions
  # below 3.5. The generated project must therefore ask for at least 3.5, or
  # it fails to configure and the download silently never happens.
  set(
    CMAKELISTS_CONTENT
    "cmake_minimum_required(VERSION 3.5)\n"
    "project(${DL_ARGS_NAME})\n"
    "include(ExternalProject)\n"
    "ExternalProject_Add(\n"
    " ${DL_ARGS_NAME}\n"
    " PREFIX \"external\"\n"
    " GIT_REPOSITORY \"${DL_ARGS_GIT_URL}\"\n"
    " GIT_TAG \"${DL_ARGS_GIT_TAG}\"\n"
    " HG_REPOSITORY \"${DL_ARGS_HG_URL}\"\n"
    " HG_TAG \"${DL_ARGS_HG_TAG}\"\n"
    " SVN_REPOSITORY \"${DL_ARGS_SVN_URL}\"\n"
    " SVN_REVISION \"${DL_ARGS_SVN_REV}\"\n"
    " URL \"${DL_ARGS_URL}\"\n"
    " URL_HASH \"${DL_ARGS_URL_HASH}\"\n"
    " SOURCE_DIR \"${DL_ARGS_PATH}\"\n"
    " BINARY_DIR \"\"\n"
    " CONFIGURE_COMMAND \"\"\n"
    " BUILD_COMMAND \"\"\n"
    " INSTALL_COMMAND \"\"\n"
    " TEST_COMMAND \"\"\n"
    " LOG_DOWNLOAD ON\n"
    " )\n"
  )
  set(
    CMAKELISTS_DIRECTORY "${PROJECT_BINARY_DIR}/downloads/${DL_ARGS_NAME}"
  )
  add_custom_target(
    external.${DL_ARGS_NAME}
    COMMAND "${CMAKE_COMMAND}" -G "${CMAKE_GENERATOR}" . && "${CMAKE_COMMAND}" --build .
    WORKING_DIRECTORY "${CMAKELISTS_DIRECTORY}"
  )
  # Write a cmake script into folder
  file(WRITE "${CMAKELISTS_DIRECTORY}/CMakeLists.txt" ${CMAKELISTS_CONTENT})
  # NOTE(review): neither execute_process result is checked, so a failed
  # download is not reported at this point.
  execute_process(
    COMMAND "${CMAKE_COMMAND}" -G "${CMAKE_GENERATOR}" .
    WORKING_DIRECTORY "${CMAKELISTS_DIRECTORY}"
  )
  execute_process(
    COMMAND "${CMAKE_COMMAND}" --build .
    WORKING_DIRECTORY "${CMAKELISTS_DIRECTORY}"
  )
endfunction()

## Download a git repository
## (a relative PATH is made absolute before being handed to _fetch_content)
function(git_repository)
  cmake_parse_arguments(GIT_ARGS "" "NAME;PATH;URL;TAG" "" ${ARGN})
  if(NOT GIT_ARGS_NAME)
    message(FATAL_ERROR "No repository name privated.")
  endif()
  if(NOT GIT_ARGS_URL)
    message(FATAL_ERROR "No repository URL privated.")
  endif()
  if(GIT_ARGS_PATH AND NOT IS_ABSOLUTE ${GIT_ARGS_PATH})
    get_filename_component(GIT_ARGS_PATH ${GIT_ARGS_PATH} ABSOLUTE)
  endif()
  _fetch_content(
    NAME "${GIT_ARGS_NAME}"
    PATH "${GIT_ARGS_PATH}"
    GIT_URL "${GIT_ARGS_URL}"
    GIT_TAG "${GIT_ARGS_TAG}"
  )
endfunction()

## Download a hg repository
function(hg_repository)
  cmake_parse_arguments(HG_ARGS "" "NAME;PATH;URL;TAG" "" ${ARGN})
  if(NOT HG_ARGS_NAME)
    message(FATAL_ERROR "No repository name privated.")
  endif()
  if(NOT HG_ARGS_URL)
    message(FATAL_ERROR "No repository URL privated.")
  endif()
  if(HG_ARGS_PATH AND NOT IS_ABSOLUTE ${HG_ARGS_PATH})
    get_filename_component(HG_ARGS_PATH ${HG_ARGS_PATH} ABSOLUTE)
  endif()
  _fetch_content(
    NAME "${HG_ARGS_NAME}"
    PATH "${HG_ARGS_PATH}"
    HG_URL "${HG_ARGS_URL}"
    HG_TAG "${HG_ARGS_TAG}"
  )
endfunction()

## Download a svn repository
function(svn_repository)
  cmake_parse_arguments(SVN_ARGS "" "NAME;PATH;URL;REV" "" ${ARGN})
  if(NOT SVN_ARGS_NAME)
    message(FATAL_ERROR "No repository name privated.")
  endif()
  if(NOT SVN_ARGS_URL)
    message(FATAL_ERROR "No repository URL privated.")
  endif()
  if(SVN_ARGS_PATH AND NOT IS_ABSOLUTE ${SVN_ARGS_PATH})
    get_filename_component(SVN_ARGS_PATH ${SVN_ARGS_PATH} ABSOLUTE)
  endif()
  _fetch_content(
    NAME "${SVN_ARGS_NAME}"
    PATH "${SVN_ARGS_PATH}"
    SVN_URL "${SVN_ARGS_URL}"
    SVN_REV "${SVN_ARGS_REV}"
  )
endfunction()

## Download a http archive
## Only one checksum is forwarded: the first supplied in the order
## SHA256, SHA1, MD5 (none -> no verification hash).
function(http_archive)
  cmake_parse_arguments(HTTP_ARGS "" "NAME;PATH;URL;SHA256;SHA1;MD5" "" ${ARGN})
  if(NOT HTTP_ARGS_NAME)
    message(FATAL_ERROR "No archive name privated.")
  endif()
  if(NOT HTTP_ARGS_URL)
    message(FATAL_ERROR "No archive URL privated.")
  endif()
  if(HTTP_ARGS_PATH AND NOT IS_ABSOLUTE ${HTTP_ARGS_PATH})
    get_filename_component(HTTP_ARGS_PATH ${HTTP_ARGS_PATH} ABSOLUTE)
  endif()
  if(HTTP_ARGS_SHA256)
    set(HTTP_URL_HASH "SHA256=${HTTP_ARGS_SHA256}")
  elseif(HTTP_ARGS_SHA1)
    set(HTTP_URL_HASH "SHA1=${HTTP_ARGS_SHA1}")
  elseif(HTTP_ARGS_MD5)
    set(HTTP_URL_HASH "MD5=${HTTP_ARGS_MD5}")
  else()
    set(HTTP_URL_HASH "")
  endif()
  _fetch_content(
    NAME "${HTTP_ARGS_NAME}"
    PATH "${HTTP_ARGS_PATH}"
    URL "${HTTP_ARGS_URL}"
    URL_HASH "${HTTP_URL_HASH}"
  )
endfunction()

## Retrieve a version string from GIT
## Sets _RESULT to `git describe --tags`, else to "g<short hash>", else "".
function(git_version _RESULT _SOURCES_DIR)
  find_package(Git REQUIRED)
  if(NOT IS_ABSOLUTE ${_SOURCES_DIR})
    get_filename_component(_SOURCES_DIR ${_SOURCES_DIR} ABSOLUTE)
  endif()
  # git describe --tags
  execute_process(
    COMMAND "${GIT_EXECUTABLE}" describe --tags
    WORKING_DIRECTORY "${_SOURCES_DIR}"
    RESULT_VARIABLE GIT_VER_RESULT
    OUTPUT_VARIABLE GIT_VER_OUTPUT
    ERROR_VARIABLE GIT_VER_ERROR
  )
  if(GIT_VER_RESULT EQUAL 0)
    # Quoted so an empty output cannot leave STRIP without its input argument.
    string(STRIP "${GIT_VER_OUTPUT}" GIT_VER_OUTPUT)
    set(${_RESULT} "${GIT_VER_OUTPUT}" PARENT_SCOPE)
    return()
  endif()
  # git rev-parse --short HEAD
  execute_process(
    COMMAND "${GIT_EXECUTABLE}" rev-parse --short HEAD
    WORKING_DIRECTORY "${_SOURCES_DIR}"
    RESULT_VARIABLE GIT_VER_RESULT
    OUTPUT_VARIABLE GIT_VER_OUTPUT
    ERROR_VARIABLE GIT_VER_ERROR
  )
  if(GIT_VER_RESULT EQUAL 0)
    string(STRIP "${GIT_VER_OUTPUT}" GIT_VER_OUTPUT)
    set(${_RESULT} "g${GIT_VER_OUTPUT}" PARENT_SCOPE)
    return()
  endif()
  set(${_RESULT} "" PARENT_SCOPE)
endfunction()

## Retrieve a version string from HG
## On an exactly-tagged revision: "<tag>"; otherwise
## "<tag>-<distance>-h<node>"; falls back to "h<node>" (or "") when no usable
## tag exists or the commands fail.
function(hg_version _RESULT _SOURCES_DIR)
  find_package(Hg REQUIRED)
  if(NOT IS_ABSOLUTE ${_SOURCES_DIR})
    get_filename_component(_SOURCES_DIR ${_SOURCES_DIR} ABSOLUTE)
  endif()
  # hg log -T "{latesttagdistance}" -r .
  execute_process(
    COMMAND "${HG_EXECUTABLE}" log -T "{latesttagdistance}" -r .
    WORKING_DIRECTORY "${_SOURCES_DIR}"
    RESULT_VARIABLE HG_VER_RESULT
    OUTPUT_VARIABLE HG_VER_OUTPUT
    ERROR_VARIABLE HG_VER_ERROR
  )
  if(HG_VER_RESULT EQUAL 0)
    string(STRIP "${HG_VER_OUTPUT}" HG_VER_OUTPUT)
    if(HG_VER_OUTPUT STREQUAL "0")
      # hg log -T "{latesttag}" -r .
      execute_process(
        COMMAND "${HG_EXECUTABLE}" log -T "{latesttag}" -r .
        WORKING_DIRECTORY "${_SOURCES_DIR}"
        RESULT_VARIABLE HG_VER_RESULT
        OUTPUT_VARIABLE HG_VER_OUTPUT
        ERROR_VARIABLE HG_VER_ERROR
      )
    else()
      # hg log -T "{latesttag}-{latesttagdistance}-h{node|short}" -r .
      execute_process(
        COMMAND "${HG_EXECUTABLE}" log -T "{latesttag}-{latesttagdistance}-h{node|short}" -r .
        WORKING_DIRECTORY "${_SOURCES_DIR}"
        RESULT_VARIABLE HG_VER_RESULT
        OUTPUT_VARIABLE HG_VER_OUTPUT
        ERROR_VARIABLE HG_VER_ERROR
      )
    endif()
    if(HG_VER_RESULT EQUAL 0)
      string(STRIP "${HG_VER_OUTPUT}" HG_VER_OUTPUT)
      # A result starting with "null" is rejected and the h<node> fallback
      # below is used instead.
      if(NOT HG_VER_OUTPUT MATCHES "^null.*")
        set(${_RESULT} "${HG_VER_OUTPUT}" PARENT_SCOPE)
        return()
      endif()
    endif()
  endif()
  # hg log -T "h{node|short}" -r .
  execute_process(
    COMMAND "${HG_EXECUTABLE}" log -T "h{node|short}" -r .
WORKING_DIRECTORY "${_SOURCES_DIR}" RESULT_VARIABLE HG_VER_RESULT OUTPUT_VARIABLE HG_VER_OUTPUT ERROR_VARIABLE HG_VER_ERROR )
  if(HG_VER_RESULT EQUAL 0)
    # Quoted so an empty output cannot leave STRIP without its input argument.
    string(STRIP "${HG_VER_OUTPUT}" HG_VER_OUTPUT)
    set(${_RESULT} "${HG_VER_OUTPUT}" PARENT_SCOPE)
    return()
  endif()
  set(${_RESULT} "" PARENT_SCOPE)
endfunction()

## Retrieve a version string from SVN
## Sets _RESULT to "r<revision>" (from `svn info --show-item revision`), or
## to "" when the command fails.
function(svn_version _RESULT _SOURCES_DIR)
  find_package(Subversion REQUIRED)
  if(NOT IS_ABSOLUTE ${_SOURCES_DIR})
    get_filename_component(_SOURCES_DIR ${_SOURCES_DIR} ABSOLUTE)
  endif()
  # svn info --show-item revision
  execute_process(
    COMMAND "${Subversion_SVN_EXECUTABLE}" info --show-item revision
    WORKING_DIRECTORY "${_SOURCES_DIR}"
    RESULT_VARIABLE SVN_VER_RESULT
    OUTPUT_VARIABLE SVN_VER_OUTPUT
    ERROR_VARIABLE SVN_VER_ERROR
  )
  if(SVN_VER_RESULT EQUAL 0)
    string(STRIP "${SVN_VER_OUTPUT}" SVN_VER_OUTPUT)
    set(${_RESULT} "r${SVN_VER_OUTPUT}" PARENT_SCOPE)
    return()
  endif()
  set(${_RESULT} "" PARENT_SCOPE)
endfunction()

# Load the workspace definitions when a workspace directory is found.
_find_workspace_directory(BAZEL_WORKSPACE_DIR)
if(BAZEL_WORKSPACE_DIR)
  include("${BAZEL_WORKSPACE_DIR}/Workspace.cmake")
endif()
================================================
FILE: cmake/option.cmake
================================================
## https://en.wikipedia.org/wiki/List_of_Intel_CPU_microarchitectures
## https://en.wikipedia.org/wiki/List_of_AMD_CPU_microarchitectures
## https://gcc.gnu.org/onlinedocs/gcc/x86-Options.html

## Intel Microarchitectures
option(ENABLE_NEHALEM "Enable Intel Nehalem CPU microarchitecture" OFF)
option(ENABLE_SANDYBRIDGE "Enable Intel Sandy Bridge CPU microarchitecture" OFF)
option(ENABLE_HASWELL "Enable Intel Haswell CPU microarchitecture" OFF)
option(ENABLE_BROADWELL "Enable Intel Broadwell CPU microarchitecture" OFF)
option(ENABLE_SKYLAKE "Enable Intel Skylake CPU microarchitecture" OFF)
option(ENABLE_SKYLAKE_AVX512 "Enable Intel Skylake Server CPU microarchitecture" OFF)
option(ENABLE_ICELAKE "Enable Intel Icelake CPU microarchitecture" OFF)
option(ENABLE_SAPPHIRERAPIDS "Enable Intel Sapphire Rapids Server CPU microarchitecture" OFF)
option(ENABLE_EMERALDRAPIDS "Enable Intel Emerald Rapids Server CPU microarchitecture" OFF)
option(ENABLE_GRANITERAPIDS "Enable Intel Granite Rapids Server CPU microarchitecture" OFF)
option(ENABLE_NATIVE "Enable native CPU microarchitecture" OFF)

## AMD Microarchitectures
option(ENABLE_ZEN1 "Enable AMD Zen+ Family 17h CPU microarchitecture" OFF)
option(ENABLE_ZEN2 "Enable AMD Zen 2 Family 17h CPU microarchitecture" OFF)
option(ENABLE_ZEN3 "Enable AMD Zen 3 Family 19h CPU microarchitecture" OFF)

## ARM architectures
option(ENABLE_ARMV8A "Enable ARMv8-a architecture" OFF)
option(ENABLE_ARMV8.1A "Enable ARMv8.1-a architecture" OFF)
option(ENABLE_ARMV8.2A "Enable ARMv8.2-a architecture" OFF)
option(ENABLE_ARMV8.3A "Enable ARMv8.3-a architecture" OFF)
option(ENABLE_ARMV8.4A "Enable ARMv8.4-a architecture" OFF)
option(ENABLE_ARMV8.5A "Enable ARMv8.5-a architecture" OFF)
option(ENABLE_ARMV8.6A "Enable ARMv8.6-a architecture" OFF)

## OpenMP option
option(ENABLE_OPENMP "Enable OpenMP support" OFF)

set(ARCH_OPTIONS
  ENABLE_NEHALEM ENABLE_SANDYBRIDGE ENABLE_HASWELL ENABLE_BROADWELL ENABLE_SKYLAKE
  ENABLE_SKYLAKE_AVX512 ENABLE_ICELAKE ENABLE_SAPPHIRERAPIDS ENABLE_EMERALDRAPIDS
  ENABLE_GRANITERAPIDS ENABLE_ZEN1 ENABLE_ZEN2 ENABLE_ZEN3 ENABLE_ARMV8A
  ENABLE_ARMV8.1A ENABLE_ARMV8.2A ENABLE_ARMV8.3A ENABLE_ARMV8.4A ENABLE_ARMV8.5A
  ENABLE_ARMV8.6A ENABLE_NATIVE
)

# Auto-detection is on unless the user explicitly enabled any arch option.
option(AUTO_DETECT_ARCH "Auto detect CPU microarchitecture" ON)
foreach(opt IN LISTS ARCH_OPTIONS)
  if(${opt})
    set(AUTO_DETECT_ARCH OFF)
    break()
  endif()
endforeach()

include(CheckCCompilerFlag)

## Append _FLAG to the space-separated flag string stored in the caller's
## variable named _RESULT, unless that exact flag is already present.
## An unset/empty variable becomes just _FLAG.
## The previous version had two problems:
##   - When the flag was already present it fell into its else() branch and
##     replaced ALL existing flags with just _FLAG.
##   - It used a regex MATCHES, so substrings and metacharacters in the flag
##     gave false matches.
function(_AppendFlags _RESULT _FLAG)
  set(_current "${${_RESULT}}")
  # Pad with spaces so only whole flags match.
  string(FIND " ${_current} " " ${_FLAG} " _pos)
  if(NOT _pos EQUAL -1)
    return()
  endif()
  if(_current STREQUAL "")
    set(${_RESULT} "${_FLAG}" PARENT_SCOPE)
  else()
    set(${_RESULT} "${_current} ${_FLAG}" PARENT_SCOPE)
  endif()
endfunction()

## Probe FLAG with the C compiler.
## If supported: append it to CMAKE_C_FLAGS / CMAKE_CXX_FLAGS and set
## <VAR_NAME>_ENABLED ON. If not supported: fail hard when OPTION_NAME is ON,
## otherwise set <VAR_NAME>_ENABLED OFF.
## (A macro, so the flag variables are updated in the including scope.)
macro(add_arch_flag FLAG VAR_NAME OPTION_NAME)
  check_c_compiler_flag("${FLAG}" COMPILER_SUPPORT_${VAR_NAME})
  if(COMPILER_SUPPORT_${VAR_NAME})
    _AppendFlags(CMAKE_C_FLAGS "${FLAG}")
    _AppendFlags(CMAKE_CXX_FLAGS "${FLAG}")
    set(${VAR_NAME}_ENABLED ON)
  else()
    if(${OPTION_NAME})
      message(FATAL_ERROR "Compiler does not support required flag: '${FLAG}' for ${OPTION_NAME}")
    else()
      set(${VAR_NAME}_ENABLED OFF)
    endif()
  endif()
endmacro()

## Auto-detect helper for AArch64 hosts: add -march=armv8 when accepted.
## NOTE(review): GCC's AArch64 -march expects names such as "armv8-a"; confirm
## whether "armv8" is ever accepted here, otherwise this always warns.
function(_setup_armv8_march)
  set(_arch "armv8")
  check_c_compiler_flag("-march=${_arch}" _COMP_SUPP_${_arch})
  if(_COMP_SUPP_${_arch})
    _AppendFlags(CMAKE_C_FLAGS "-march=${_arch}")
    _AppendFlags(CMAKE_CXX_FLAGS "-march=${_arch}")
    set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS}" PARENT_SCOPE)
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}" PARENT_SCOPE)
    return()
  else()
    message(WARNING "No ARMv8 march flag supported by compiler.")
  endif()
endfunction()

## Auto-detect helper for x86 hosts: add the generic -march=x86-64 baseline.
function(_setup_x86_march)
  set(_arch "x86-64")
  check_c_compiler_flag("-march=${_arch}" _COMP_SUPP_${_arch})
  if(_COMP_SUPP_${_arch})
    _AppendFlags(CMAKE_C_FLAGS "-march=${_arch}")
    _AppendFlags(CMAKE_CXX_FLAGS "-march=${_arch}")
    set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS}" PARENT_SCOPE)
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}" PARENT_SCOPE)
    return()
  else()
    message(WARNING "No known x86 march flag supported; falling back to generic.")
  endif()
endfunction()

## Store per-ISA -march flags into the caller's four named variables.
## SSE and AVX2 flags are fixed. For AVX-512 / AVX-512-FP16, the first
## candidate the compiler accepts is used.
## NOTE(review): the candidate lists end in core-avx2 / x86-64, so on older
## compilers the "AVX512" variables can hold a non-AVX-512 flag.
function(setup_compiler_march_for_x86 VAR_NAME_SSE VAR_NAME_AVX2 VAR_NAME_AVX512 VAR_NAME_AVX512FP16)
  #sse
  set(${VAR_NAME_SSE} "-march=corei7" PARENT_SCOPE)

  #avx 2
  set(${VAR_NAME_AVX2} "-march=core-avx2" PARENT_SCOPE)

  #avx512
  set(_x86_flags_avx512 "icelake-server" "skylake-avx512" "core-avx2" "x86-64")
  foreach(_arch_avx512 IN LISTS _x86_flags_avx512)
    check_c_compiler_flag("-march=${_arch_avx512}" _COMP_SUPP_${_arch_avx512})
    if(_COMP_SUPP_${_arch_avx512})
      set(${VAR_NAME_AVX512} "-march=${_arch_avx512}" PARENT_SCOPE)
      break()
    endif()
  endforeach()

  #avx512fp16
  set(_x86_flags_avx512fp16
    "sapphirerapids" "icelake-server" "skylake-avx512" "core-avx2" "x86-64"
  )
  foreach(_arch_avx512fp16 IN LISTS _x86_flags_avx512fp16)
    check_c_compiler_flag("-march=${_arch_avx512fp16}" _COMP_SUPP_${_arch_avx512fp16})
    if(_COMP_SUPP_${_arch_avx512fp16})
      set(${VAR_NAME_AVX512FP16} "-march=${_arch_avx512fp16}" PARENT_SCOPE)
      break()
    endif()
  endforeach()
endfunction()

# Architecture selection. MSVC, explicit options and auto-detection are
# mutually exclusive branches. The OpenMP section below runs for every
# compiler; it used to be skipped on MSVC by an early return().
if(MSVC)
  # Prefer higher ISAs
  # NOTE(review): this probes what the compiler accepts, not what the build
  # or target CPU supports — confirm an /arch:AVX512 build is intended.
  foreach(_isa IN ITEMS "AVX512" "AVX2" "AVX" "SSE2")
    check_c_compiler_flag("/arch:${_isa}" _COMP_SUPP_${_isa})
    if(_COMP_SUPP_${_isa})
      _AppendFlags(CMAKE_C_FLAGS "/arch:${_isa}")
      _AppendFlags(CMAKE_CXX_FLAGS "/arch:${_isa}")
      message(STATUS "MSVC: enabled /arch:${_isa}")
      break()
    endif()
  endforeach()
elseif(NOT AUTO_DETECT_ARCH)
  # Only the last -march= on a command line takes effect. Stacking every
  # enabled option therefore used to select the OLDEST one. Apply only the
  # first enabled option of each family, in the order listed.
  # x86: native, then AMD newest first, then Intel newest first.
  if(ENABLE_NATIVE)
    add_arch_flag("-march=native" NATIVE ENABLE_NATIVE)
  elseif(ENABLE_ZEN3)
    add_arch_flag("-march=znver3" ZNVER3 ENABLE_ZEN3)
  elseif(ENABLE_ZEN2)
    add_arch_flag("-march=znver2" ZNVER2 ENABLE_ZEN2)
  elseif(ENABLE_ZEN1)
    add_arch_flag("-march=znver1" ZNVER1 ENABLE_ZEN1)
  elseif(ENABLE_GRANITERAPIDS)
    add_arch_flag("-march=graniterapids" GRANITERAPIDS ENABLE_GRANITERAPIDS)
  elseif(ENABLE_EMERALDRAPIDS)
    add_arch_flag("-march=emeraldrapids" EMERALDRAPIDS ENABLE_EMERALDRAPIDS)
  elseif(ENABLE_SAPPHIRERAPIDS)
    add_arch_flag("-march=sapphirerapids" SAPPHIRERAPIDS ENABLE_SAPPHIRERAPIDS)
  elseif(ENABLE_ICELAKE)
    add_arch_flag("-march=icelake-server" ICELAKE ENABLE_ICELAKE)
  elseif(ENABLE_SKYLAKE_AVX512)
    add_arch_flag("-march=skylake-avx512" SKYLAKE_AVX512 ENABLE_SKYLAKE_AVX512)
  elseif(ENABLE_SKYLAKE)
    add_arch_flag("-march=skylake" SKYLAKE ENABLE_SKYLAKE)
  elseif(ENABLE_BROADWELL)
    add_arch_flag("-march=broadwell" BROADWELL ENABLE_BROADWELL)
  elseif(ENABLE_HASWELL)
    add_arch_flag("-march=haswell" HASWELL ENABLE_HASWELL)
  elseif(ENABLE_SANDYBRIDGE)
    add_arch_flag("-march=sandybridge" SANDYBRIDGE ENABLE_SANDYBRIDGE)
  elseif(ENABLE_NEHALEM)
    add_arch_flag("-march=nehalem" NEHALEM ENABLE_NEHALEM)
  endif()

  # ARM (newest first): GCC honours only one -march=, so honor the highest
  # enabled level.
  if(ENABLE_ARMV8.6A)
    add_arch_flag("-march=armv8.6-a" ARMV86A ENABLE_ARMV8.6A)
  elseif(ENABLE_ARMV8.5A)
    add_arch_flag("-march=armv8.5-a" ARMV85A ENABLE_ARMV8.5A)
  elseif(ENABLE_ARMV8.4A)
    add_arch_flag("-march=armv8.4-a" ARMV84A ENABLE_ARMV8.4A)
  elseif(ENABLE_ARMV8.3A)
    add_arch_flag("-march=armv8.3-a" ARMV83A ENABLE_ARMV8.3A)
  elseif(ENABLE_ARMV8.2A)
    add_arch_flag("-march=armv8.2-a" ARMV82A ENABLE_ARMV8.2A)
  elseif(ENABLE_ARMV8.1A)
    add_arch_flag("-march=armv8.1-a" ARMV81A ENABLE_ARMV8.1A)
  elseif(ENABLE_ARMV8A)
    add_arch_flag("-march=armv8-a" ARMV8A ENABLE_ARMV8A)
  endif()
else()
  # AUTO DETECT
  # Heuristic: detect host architecture and probe appropriate flags.
  # Windows and FreeBSD report x86-64 hosts as AMD64/amd64.
  if(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64|arm64|ARM64")
    _setup_armv8_march()
  elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64|i686|i386|x64|AMD64|amd64")
    _setup_x86_march()
  else()
    message(WARNING "Unknown host architecture: ${CMAKE_SYSTEM_PROCESSOR}; no -march= set.")
  endif()
endif()

# -----------------------------
# OpenMP
# -----------------------------
if(ENABLE_OPENMP)
  find_package(OpenMP REQUIRED)
  if(OpenMP_C_FLAGS)
    _AppendFlags(CMAKE_C_FLAGS "${OpenMP_C_FLAGS}")
  endif()
  if(OpenMP_CXX_FLAGS)
    _AppendFlags(CMAKE_CXX_FLAGS "${OpenMP_CXX_FLAGS}")
  endif()
endif()
================================================
FILE: cmake/utils.cmake
================================================
## Apply patch_file (patch -p1) inside target_dir at most once.
## Success is recorded with a marker file named .<patch_name>_patched.
function(apply_patch_once patch_name target_dir patch_file)
  set(mark_file "${target_dir}/.${patch_name}_patched")
  if(EXISTS "${mark_file}")
    #message(STATUS "Patch '${patch_name}' already applied to ${target_dir}, skipping.")
    return()
  endif()
  if(NOT EXISTS "${patch_file}")
    message(FATAL_ERROR "Patch file '${patch_file}' not found!")
  endif()
  #message(STATUS "Applying patch '${patch_name}' to ${target_dir} ...")
  execute_process(
    COMMAND patch -p1 -i "${patch_file}"
    WORKING_DIRECTORY "${target_dir}"
    RESULT_VARIABLE patch_result
    OUTPUT_VARIABLE patch_stdout
ERROR_VARIABLE patch_stderr )
  if(NOT patch_result EQUAL 0)
    message(FATAL_ERROR "Failed to apply patch '${patch_name}' to ${target_dir}:\n${patch_stderr}")
  else()
    #message(STATUS "Patch '${patch_name}' applied successfully:\n${patch_stdout}")
    # The marker file makes later configure runs skip this patch.
    file(WRITE "${mark_file}" "patched")
  endif()
endfunction()
================================================
FILE: examples/c++/CMakeLists.txt
================================================
cmake_minimum_required(VERSION 3.13)
cmake_policy(SET CMP0077 NEW)
project(zvec-example-c++)
set(CMAKE_CXX_STANDARD 17)

# Enable compile_commands.json
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)

# --- Paths to Zvec and dependencies ---
# Allow custom host build directory, default to "build"
if(NOT DEFINED HOST_BUILD_DIR)
  set(HOST_BUILD_DIR "build")
endif()
# The paths below assume this project is configured from a build directory
# three levels under the repository root (e.g. examples/c++/build).
set(ZVEC_INCLUDE_DIR ${CMAKE_BINARY_DIR}/../../../src/include)
set(ZVEC_LIB_DIR ${CMAKE_BINARY_DIR}/../../../${HOST_BUILD_DIR}/lib)
set(ZVEC_DEPENDENCY_LIB_DIR ${CMAKE_BINARY_DIR}/../../../${HOST_BUILD_DIR}/external/usr/local/lib)

# Add include and library search paths
include_directories(${ZVEC_INCLUDE_DIR})
link_directories(${ZVEC_LIB_DIR} ${ZVEC_DEPENDENCY_LIB_DIR})

# --- Determine debug/release library names ---
if(CMAKE_BUILD_TYPE STREQUAL "Debug")
  set(GLOG_LIB glogd)
  set(GFLAGS_LIB gflags_nothreads_debug)
  set(PROTOBUF_LIB protobufd)
else()
  set(GLOG_LIB glog)
  set(GFLAGS_LIB gflags_nothreads)
  set(PROTOBUF_LIB protobuf)
endif()

# --- Dependency groups ---
find_package(Threads REQUIRED)
set(zvec_ailego_deps
  arrow
  parquet
  arrow_bundled_dependencies
  ${CMAKE_THREAD_LIBS_INIT}
  ${CMAKE_DL_LIBS}
)
set(zvec_core_deps
  zvec_turbo
)
set(zvec_db_deps
  roaring
  rocksdb
  arrow
  arrow_acero
  arrow_bundled_dependencies
  arrow_compute
  arrow_dataset
  parquet
  antlr4-runtime
  ${GLOG_LIB}
  ${GFLAGS_LIB}
  ${PROTOBUF_LIB}
  lz4
)

# --- Create INTERFACE targets for Zvec components ---
# zvec_ailego: links libzvec_ailego.a + its deps
add_library(zvec-ailego INTERFACE)
target_link_libraries(zvec-ailego INTERFACE
  -lzvec_ailego
  ${zvec_ailego_deps}
)

# zvec_core: links libzvec_core.a via special flags (handled externally), but declare logical deps
# Linux and Android share the GNU-ld flag set (previously duplicated branch
# by branch); macOS uses -force_load instead of --whole-archive.
add_library(zvec-core INTERFACE)
if(CMAKE_SYSTEM_NAME STREQUAL "Linux" OR ANDROID)
  # --whole-archive keeps every object file of libzvec_core.a (presumably so
  # self-registering components are not dropped — confirm).
  # --start-group/--end-group lets the linker rescan the grouped archives
  # to resolve mutual references.
  target_link_libraries(zvec-core INTERFACE
    -Wl,--whole-archive zvec_core -Wl,--no-whole-archive
    -Wl,--start-group zvec-ailego ${zvec_core_deps} -Wl,--end-group
  )
elseif(APPLE)
  target_link_libraries(zvec-core INTERFACE
    -Wl,-force_load ${ZVEC_LIB_DIR}/libzvec_core.a
    zvec-ailego ${zvec_core_deps}
  )
else()
  message(FATAL_ERROR "Unsupported platform: ${CMAKE_SYSTEM_NAME}")
endif()

# zvec_db: links libzvec_db.a + all deps
add_library(zvec-db INTERFACE)
if(CMAKE_SYSTEM_NAME STREQUAL "Linux" OR ANDROID)
  target_link_libraries(zvec-db INTERFACE
    zvec_db zvec-core zvec-ailego
    -Wl,--start-group ${zvec_db_deps} -Wl,--end-group
  )
elseif(APPLE)
  target_link_libraries(zvec-db INTERFACE
    zvec_db zvec-core zvec-ailego ${zvec_db_deps}
  )
else()
  message(FATAL_ERROR "Unsupported platform: ${CMAKE_SYSTEM_NAME}")
endif()

# --- Main executable ---
add_executable(db-example db/main.cc)
target_link_libraries(db-example PRIVATE
  zvec-db
)
if(ANDROID)
  target_link_libraries(db-example PRIVATE
    log
  )
endif()

add_executable(core-example core/main.cc)
target_link_libraries(core-example PRIVATE
  zvec-core
)

add_executable(ailego-example ailego/main.cc)
target_link_libraries(ailego-example PRIVATE
  zvec-ailego
)

# Strip symbols to reduce executable size
if(CMAKE_BUILD_TYPE STREQUAL "Release" AND ANDROID)
  add_custom_command(TARGET db-example POST_BUILD
    COMMAND ${CMAKE_STRIP} "$"
    COMMENT "Stripping symbols from db-example")
  add_custom_command(TARGET core-example POST_BUILD
    COMMAND ${CMAKE_STRIP} "$"
    COMMENT "Stripping symbols from core-example")
  add_custom_command(TARGET ailego-example POST_BUILD
    COMMAND ${CMAKE_STRIP} "$"
    COMMENT "Stripping symbols from ailego-example")
endif()

# Optimize for size
if(CMAKE_BUILD_TYPE STREQUAL "Release" AND ANDROID)
  set_property(TARGET db-example core-example ailego-example PROPERTY COMPILE_FLAGS "-Os")
  set_property(TARGET db-example core-example ailego-example PROPERTY INTERPROCEDURAL_OPTIMIZATION TRUE)
endif()
================================================
FILE: examples/c++/ailego/main.cc
================================================
#include
#include
#include
using namespace zvec;

// Print the result of ailego::StringHelper::StartsWith("hello world", "hello").
int main() {
  std::string a{"hello world"};
  std::cout << ailego::StringHelper::StartsWith(a, "hello") << std::endl;
}
================================================
FILE: examples/c++/core/main.cc
================================================
#include
#include
#include
#include
#include
#include
using namespace zvec::core_interface;

constexpr uint32_t kDimension = 64;
const std::string index_name{"test.index"};

// Create an index from `param` and open it at `index_name` (kMMAP storage).
// Then add `doc_num` dense vectors (every component of vector i equals
// i / 10 + 0.1) and train the index.
// Returns nullptr, after printing the step that failed, on any error.
Index::Pointer create_index(const BaseIndexParam::Pointer &param,
                            int doc_num = 10) {
  auto index = IndexFactory::CreateAndInitIndex(*param);
  if (!index) {
    std::cout << "Failed to create index." << std::endl;
    return nullptr;
  }

  int ret = index->Open(
      index_name, StorageOptions{StorageOptions::StorageType::kMMAP, true});
  if (ret != 0) {
    std::cout << "Failed to open index." << std::endl;
    return nullptr;
  }

  for (int i = 0; i < doc_num; ++i) {
    std::vector vector(kDimension, i / 10.0f + 0.1f);
    VectorData vector_data;
    vector_data.vector = DenseVector{vector.data()};
    ret = index->Add(vector_data, i);
    if (ret != 0) {
      std::cout << "Failed to add to index." << std::endl;
      return nullptr;
    }
  }

  ret = index->Train();
  if (ret != 0) {
    std::cout << "Failed to train index."
<< std::endl; return nullptr; } return index; } int main() { char cmd_buf[100]; snprintf(cmd_buf, 100, "rm -f %s", index_name.c_str()); system(cmd_buf); auto param = HNSWIndexParamBuilder() .WithMetricType(MetricType::kInnerProduct) .WithDataType(DataType::DT_FP32) .WithDimension(kDimension) .WithIsSparse(false) .Build(); auto index = create_index(param, 1); std::cout << "index stats: " << index->GetDocCount() << std::endl; // query auto query_param = HNSWQueryParamBuilder() .with_topk(10) .with_fetch_vector(true) .with_ef_search(20) .build(); SearchResult result; VectorData query; std::vector vector(kDimension, 0.1f); query.vector = DenseVector{vector.data()}; int ret = index->Search(query, query_param, &result); if (ret != 0) { std::cout << "Failed to search index." << std::endl; return -1; } std::cout << "query results: " << result.doc_list_.size() << std::endl; if (result.doc_list_.size() == 0) { std::cout << "No results found." << std::endl; return -1; } std::cout << "key: " << result.doc_list_[0].key() << ", score: " << result.doc_list_[0].score() << std::endl; return 0; } ================================================ FILE: examples/c++/db/main.cc ================================================ #include #include #include #include #include #include #include using namespace zvec; Doc create_doc(const uint64_t doc_id, const CollectionSchema &schema, std::string pk = "") { Doc new_doc; if (pk.empty()) { pk = "pk_" + std::to_string(doc_id); } new_doc.set_pk(pk); for (auto &field : schema.fields()) { switch (field->data_type()) { case DataType::BINARY: { std::string binary_str("binary_" + std::to_string(doc_id)); new_doc.set(field->name(), binary_str); break; } case DataType::BOOL: new_doc.set(field->name(), doc_id % 10 == 0); break; case DataType::INT32: new_doc.set(field->name(), (int32_t)doc_id); break; case DataType::INT64: new_doc.set(field->name(), (int64_t)doc_id); break; case DataType::UINT32: new_doc.set(field->name(), (uint32_t)doc_id); break; case 
DataType::UINT64: new_doc.set(field->name(), (uint64_t)doc_id); break; case DataType::FLOAT: new_doc.set(field->name(), (float)doc_id); break; case DataType::DOUBLE: new_doc.set(field->name(), (double)doc_id); break; case DataType::STRING: new_doc.set(field->name(), "value_" + std::to_string(doc_id)); break; case DataType::ARRAY_BINARY: { std::vector bin_vec; for (size_t i = 0; i < (doc_id % 10); i++) { bin_vec.push_back("bin_" + std::to_string(i)); } new_doc.set>(field->name(), bin_vec); break; } case DataType::ARRAY_BOOL: new_doc.set>(field->name(), std::vector(10, doc_id % 10 == 0)); break; case DataType::ARRAY_INT32: new_doc.set>( field->name(), std::vector(10, (int32_t)doc_id)); break; case DataType::ARRAY_INT64: new_doc.set>( field->name(), std::vector(10, (int64_t)doc_id)); break; case DataType::ARRAY_UINT32: new_doc.set>( field->name(), std::vector(10, (uint32_t)doc_id)); break; case DataType::ARRAY_UINT64: new_doc.set>( field->name(), std::vector(10, (uint64_t)doc_id)); break; case DataType::ARRAY_FLOAT: new_doc.set>(field->name(), std::vector(10, (float)doc_id)); break; case DataType::ARRAY_DOUBLE: new_doc.set>( field->name(), std::vector(10, (double)doc_id)); break; case DataType::ARRAY_STRING: new_doc.set>( field->name(), std::vector(10, "value_" + std::to_string(doc_id))); break; case DataType::VECTOR_BINARY32: new_doc.set>( field->name(), std::vector(field->dimension(), uint32_t(doc_id + 0.1))); break; case DataType::VECTOR_BINARY64: new_doc.set>( field->name(), std::vector(field->dimension(), uint64_t(doc_id + 0.1))); break; case DataType::VECTOR_FP32: new_doc.set>( field->name(), std::vector(field->dimension(), float(doc_id + 0.1))); break; case DataType::VECTOR_FP64: new_doc.set>( field->name(), std::vector(field->dimension(), double(doc_id + 0.1))); break; case DataType::VECTOR_FP16: new_doc.set>( field->name(), std::vector( field->dimension(), static_cast( float(doc_id + 0.1)))); break; case DataType::VECTOR_INT8: new_doc.set>( field->name(), 
std::vector(field->dimension(), (int8_t)doc_id)); break; case DataType::VECTOR_INT16: new_doc.set>( field->name(), std::vector(field->dimension(), (int16_t)doc_id)); break; case DataType::SPARSE_VECTOR_FP16: { std::vector indices; std::vector values; for (uint32_t i = 0; i < 100; i++) { indices.push_back(i); values.push_back(zvec::float16_t(float(doc_id + 0.1))); } std::pair, std::vector> sparse_float_vec; sparse_float_vec.first = indices; sparse_float_vec.second = values; new_doc.set< std::pair, std::vector>>( field->name(), sparse_float_vec); break; } case DataType::SPARSE_VECTOR_FP32: { std::vector indices; std::vector values; for (uint32_t i = 0; i < 100; i++) { indices.push_back(i); values.push_back(float(doc_id + 0.1)); } std::pair, std::vector> sparse_float_vec; sparse_float_vec.first = indices; sparse_float_vec.second = values; new_doc.set, std::vector>>( field->name(), sparse_float_vec); break; } default: std::cout << "Unsupported data type: " << field->name() << std::endl; throw std::runtime_error("Unsupported vector data type"); } } return new_doc; } CollectionSchema::Ptr create_schema() { auto schema = std::make_shared("demo"); schema->set_max_doc_count_per_segment(1000); schema->add_field(std::make_shared( "id", DataType::INT64, false, std::make_shared(true))); schema->add_field(std::make_shared( "name", DataType::STRING, false, std::make_shared(false))); schema->add_field( std::make_shared("weight", DataType::FLOAT, true)); schema->add_field(std::make_shared( "dense", DataType::VECTOR_FP32, 128, false, std::make_shared(MetricType::IP))); schema->add_field(std::make_shared( "sparse", DataType::SPARSE_VECTOR_FP32, 0, false, std::make_shared(MetricType::IP))); return schema; } int main() { std::string path = "./demo"; std::string rm_cmd = "rm -rf " + path; system(rm_cmd.c_str()); auto schema = create_schema(); CollectionOptions options{false, true}; auto result = Collection::CreateAndOpen(path, *schema, options); if (!result.has_value()) { std::cout << 
result.error().message() << std::endl; return -1; } std::cout << "init stats: " << result.value()->Stats().value().to_string() << std::endl; auto coll = std::move(result).value(); // insert docs { auto doc1 = create_doc(0, *schema); std::vector docs{doc1}; auto res = coll->Insert(docs); if (!res.has_value()) { std::cout << res.error().message() << std::endl; return -1; } std::cout << "after insert stats " << coll->Stats().value().to_string() << std::endl; } // optimize { auto res = coll->Optimize(); if (!res.ok()) { std::cout << res.message() << std::endl; return -1; } std::cout << "after optimize stats " << coll->Stats().value().to_string() << std::endl; } // query { VectorQuery query; query.topk_ = 10; query.field_name_ = "dense"; query.include_vector_ = true; std::vector query_vector = std::vector(128, 0.1); query.query_vector_.assign((char *)query_vector.data(), query_vector.size() * sizeof(float)); auto res = coll->Query(query); if (!res.has_value()) { std::cout << res.error().message() << std::endl; return -1; } std::cout << "query result: doc_count[" << res.value().size() << "]" << std::endl; std::cout << "first doc: " << res.value()[0]->to_detail_string() << std::endl; } // close and reopen coll.reset(); options.read_only_ = true; result = Collection::Open(path, options); if (!result.has_value()) { std::cout << result.error().message() << std::endl; return -1; } std::cout << "reopen stats: " << result.value()->Stats().value().to_string() << std::endl; return 0; } ================================================ FILE: pyproject.toml ================================================ ###################################################################################################### # Zvec: High-Performance Vector Database with PyBind11 & C++ Backend ###################################################################################################### [project] name = "zvec" dynamic = ["version"] description = "A high-performance vector database engine with 
native C++ backend and Python bindings" readme = "README.md" license = { text = "Apache-2.0" } authors = [ { name = "zvec", email = "zvec@alibaba-inc.com" }, ] maintainers = [ { name = "Zvec Core Team", email = "zvec@alibaba-inc.com" }, ] requires-python = ">=3.9" classifiers = [ "Development Status :: 3 - Alpha", "Intended Audience :: Developers", "Intended Audience :: Education", "Intended Audience :: Science/Research", "License :: OSI Approved :: Apache Software License", "Operating System :: POSIX :: Linux", "Operating System :: MacOS", "Programming Language :: C++", "Programming Language :: Python :: 3 :: Only", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", "Programming Language :: Python :: 3.14", "Topic :: Database", "Topic :: Scientific/Engineering :: Artificial Intelligence", "Topic :: Software Development :: Libraries :: Python Modules", ] keywords = [ "vector-database", "ann", "nearest-neighbor" ] dependencies = [ "numpy >=1.23", ] [project.urls] Homepage = "https://github.com/alibaba/zvec" Repository = "https://github.com/alibaba/zvec" "Bug Tracker" = "https://github.com/alibaba/zvec/issues" "Documentation" = "https://zvec.org" [project.optional-dependencies] test = [ "pytest >=8.0", "pytest-cov >=4.1", "pytest-mock >=3.12", "cibuildwheel == 3.4.0", ] docs = [ "mkdocs >=1.5", "mkdocs-material >=9.5", "mkdocstrings[python] >=0.24", ] dev = [ "ruff >=0.4", "black >=24.0", "mypy >=1.8", "pre-commit >=3.6", "build >=1.0", "twine >=4.0", "numpy >=1.23", # Inherit test deps "pytest >=8.0", "pytest-cov >=4.1", "pytest-mock >=3.12", "cibuildwheel == 3.4.0", # Inherit docs deps "mkdocs >=1.5", "mkdocs-material >=9.5", "mkdocstrings[python] >=0.24", "pybind11-stubgen>=2.5.5", "pybind11 >=3.0", ] ###################################################################################################### # BUILD SYSTEM CONFIGURATION 
(scikit-build-core) ###################################################################################################### [build-system] requires = [ "scikit-build-core >=0.11", "pybind11 >=3.0", "setuptools_scm>=8.0", "cmake>=3.26,<4.0", "ninja>=1.11", ] build-backend = "scikit_build_core.build" [tool.scikit-build] # Core settings minimum-version = "0.11" metadata.version.provider = "scikit_build_core.metadata.setuptools_scm" # CMake configuration cmake.version = ">=3.26,<4.0" ninja.version = ">=1.11" cmake.build-type = "Release" install.strip = true # Strip symbols in release builds to reduce wheel size # Build directory build-dir = "build" # Platform support wheel.expand-macos-universal-tags = true wheel.packages = ["python/zvec"] # Source distribution sdist.include = [ "README.md", "LICENSE", "pyproject.toml", "CMakeLists.txt", "src/**/*", "stub/zvec/**/*", "python/zvec/py.typed", ] # CMake defines (env-overridable) [tool.scikit-build.cmake.define] BUILD_TOOLS = "OFF" BUILD_PYTHON_BINDINGS = "ON" #CMAKE_VERBOSE_MAKEFILE = "ON" # Setuptools config for test pypi [tool.setuptools_scm] local_scheme = "no-local-version" version_scheme = "guess-next-dev" fallback_version = "0.2.1b1" ###################################################################################################### # TESTING & QUALITY ###################################################################################################### [tool.pytest.ini_options] minversion = "8.0" addopts = [ "-ra", "--showlocals", "--strict-markers", "--strict-config", "--tb=short", ] xfail_strict = true log_cli_level = "INFO" filterwarnings = [ "error", "ignore::pytest.PytestCacheWarning", # Ignore numpy deprecation warnings in tests (if any) "ignore:.*numpy.*:DeprecationWarning", ] testpaths = ["python/tests"] markers = [ "title: Custom marker for test title/description", # "slow: marks tests as slow", ] ###################################################################################################### # 
BUILD WHEEL ###################################################################################################### [tool.cibuildwheel] build = [ "cp310-*", "cp311-*", "cp312-*", "cp313-*", "cp314-*", ] build-frontend = "build" test-requires = ["pytest", "numpy"] test-command = "cd {project} && pytest python/tests -v --tb=short" build-verbosity = 1 [tool.cibuildwheel.linux] archs = ["auto"] environment = { CMAKE_GENERATOR = "Unix Makefiles", CMAKE_BUILD_PARALLEL_LEVEL = "16" } manylinux-x86_64-image = "manylinux_2_28" manylinux-aarch64-image = "manylinux_2_28" # Skip 32-bit builds and musllinux skip = ["*-manylinux_i686", "*-musllinux*"] [tool.cibuildwheel.macos] archs = ["arm64"] # Inherits CMAKE_GENERATOR and CMAKE_BUILD_PARALLEL_LEVEL from [tool.cibuildwheel] won't work; # platform-level environment overrides the top-level entirely, so all vars must be listed here environment = { CMAKE_GENERATOR = "Unix Makefiles", CMAKE_BUILD_PARALLEL_LEVEL = "16", MACOSX_DEPLOYMENT_TARGET = "11.0" } ###################################################################################################### # CODE QUALITY & FORMATTING (Ruff) ###################################################################################################### [tool.ruff] target-version = "py310" line-length = 88 exclude = [ "build/", "dist/", ".git/", ".venv/", "venv/", "thirdparty", ] [tool.ruff.lint] extend-select = [ "B", # flake8-bugbear "I", # isort "ARG", # flake8-unused-arguments "C4", # flake8-comprehensions "EM", # flake8-errmsg "ICN", # flake8-import-conventions "G", # flake8-logging-format "PGH", # pygrep-hooks "PIE", # flake8-pie "PL", # pylint "PT", # flake8-pytest-style "PTH", # flake8-use-pathlib "RET", # flake8-return "RUF", # Ruff-specific "SIM", # flake8-simplify "T20", # flake8-print "UP", # pyupgrade "YTT", # flake8-2020 "EXE", # flake8-executable "NPY", # NumPy-specific "PD", # pandas-vet ] ignore = [ "PLR0913", # Too many arguments (common in bindings) "PLR2004", # Magic value 
used in comparison "UP045", "UP007", # Keep Optional[X] / Union[X, Y] annotations instead of PEP 604 `X | Y` "EM101", "EM102", # Raw string / f-string literals in exception messages (ok in tests/utils) "B008", # Function calls in default arguments (cautiously allowed in config) "E731", # Lambda assignment (used in callbacks) "B019", # `functools.lru_cache` on methods (handled manually) "PLR0912", # Too many branches "PLC0105", # TypeVar name does not reflect its variance (missing `_co`/`_contra` suffix) "RUF002", # Ambiguous Unicode characters in docstrings ] fixable = ["ALL"] unfixable = [] # Docstring convention used by pydocstyle (`D`) rules, if they are enabled [tool.ruff.lint.pydocstyle] convention = "google" # or "numpy", "pep257" ignore-decorators = ["typing.overload"] [tool.ruff.lint.flake8-type-checking] # Allow quoting annotations so typing-only imports can move into TYPE_CHECKING blocks quote-annotations = true [tool.ruff.lint.isort] required-imports = ["from __future__ import annotations"] known-first-party = ["zvec"] [tool.ruff.lint.per-file-ignores] "python/tests/**" = ["ALL"] "bench/core/**" = ["ALL"] "python/zvec/__init__.py" = [ "F401", # Unused import (for __all__) "E402", # Module level import not at top (C++ module init order) "PLE0605", # Invalid format for __all__ "RUF022", # __all__ is not sorted ] "python/zvec/model/doc.py" = [ "RUF023", # __slots__ is not sorted (order kept intentionally) ] "python/zvec/extension/**" = [ "PLC0415", # Import outside top-level (dynamic imports in _get_model) ] [tool.ruff.format] indent-style = "space" quote-style = "double" line-ending = "lf" skip-magic-trailing-comma = false ================================================ FILE: python/tests/detail/distance_helper.py ================================================ import logging import math import numpy as np from zvec import ( MetricType, DataType, QuantizeType, Doc, CollectionSchema, FieldSchema, VectorSchema, ) from typing import Dict def is_float_equal(actual, expected, rel_tol=1e-5, abs_tol=1e-8): if actual is None and expected is None: return True return math.isclose(actual, expected, rel_tol=rel_tol, abs_tol=abs_tol) def is_dense_vector_equal(vec1, vec2, rtol=1e-5, atol=1e-8): """Compare two dense
vectors with tolerance.""" return np.allclose(vec1, vec2, rtol=rtol, atol=atol) def is_sparse_vector_equal(vec1, vec2, rtol=1e-5, atol=1e-8): """Compare two sparse vectors with tolerance.""" # Check if they have the same keys if set(vec1.keys()) != set(vec2.keys()): return False # Check if all values are close for key in vec1: if not math.isclose(vec1[key], vec2[key], rel_tol=rtol, abs_tol=atol): return False return True def is_float_array_equal(arr1, arr2, rtol=1e-5, atol=1e-8): """Compare two float arrays with tolerance.""" return np.allclose(arr1, arr2, rtol=rtol, atol=atol) def is_double_array_equal(arr1, arr2, rtol=1e-9, atol=1e-12): """Compare two double arrays with tolerance.""" return np.allclose(arr1, arr2, rtol=rtol, atol=atol) def is_int_array_equal(arr1, arr2): """Compare two integer arrays with exact equality.""" return np.array_equal(arr1, arr2) def cosine_distance_dense( vec1, vec2, dtype: DataType = DataType.VECTOR_FP32, quantize_type: QuantizeType = QuantizeType.UNDEFINED, ): if dtype == DataType.VECTOR_FP16 or quantize_type == QuantizeType.FP16: # More stable conversion to float16 to avoid numerical issues vec1 = [float(np.float16(a)) for a in vec1] vec2 = [float(np.float16(b)) for b in vec2] elif dtype == DataType.VECTOR_INT8: # For INT8 vectors, convert to integers for proper calculation vec1 = [ int(round(min(max(val, -128), 127))) for val in vec1 ] # Clamp to valid INT8 range vec2 = [ int(round(min(max(val, -128), 127))) for val in vec2 ] # Clamp to valid INT8 range dot_product = sum(a * b for a, b in zip(vec1, vec2)) magnitude1 = math.sqrt(sum(a * a for a in vec1)) magnitude2 = math.sqrt(sum(b * b for b in vec2)) if magnitude1 == 0 or magnitude2 == 0: return 1.0 # Zero vector case - maximum distance cosine_similarity = dot_product / (magnitude1 * magnitude2) # Clamp to [-1, 1] range to handle floating-point precision errors cosine_similarity = max(-1.0, min(1.0, cosine_similarity)) # For identical vectors (within floating point precision), 
ensure cosine distance is 0.0 # This is especially important for low-precision types which have limited precision if ( dtype == DataType.VECTOR_FP16 or quantize_type == QuantizeType.FP16 or dtype == DataType.VECTOR_INT8 ): if ( abs(cosine_similarity - 1.0) < 1e-3 ): # Handle precision issues for low-precision types cosine_similarity = 1.0 # Return cosine distance (1 - cosine similarity) to maintain compatibility # with system internal processing and existing test expectations return 1.0 - cosine_similarity def dp_distance_dense( vec1, vec2, dtype: DataType = DataType.VECTOR_FP32, quantize_type: QuantizeType = QuantizeType.UNDEFINED, ): if dtype == DataType.VECTOR_FP16 or quantize_type == QuantizeType.FP16: # More stable computation to avoid numerical issues products = [ float(np.float16(a)) * float(np.float16(b)) for a, b in zip(vec1, vec2) ] return sum(products) elif dtype == DataType.VECTOR_INT8: # For INT8 vectors, convert to integers for proper calculation products = [ int(round(min(max(a, -128), 127))) * int(round(min(max(b, -128), 127))) for a, b in zip(vec1, vec2) ] return sum(products) return sum(a * b for a, b in zip(vec1, vec2)) def euclidean_distance_dense( vec1, vec2, dtype: DataType = DataType.VECTOR_FP32, quantize_type: QuantizeType = QuantizeType.UNDEFINED, ): if dtype == DataType.VECTOR_FP16 or quantize_type == QuantizeType.FP16: # Convert to float16 and compute squared differences safely # Use a more stable computation to avoid overflow squared_diffs = [] for a, b in zip(vec1, vec2): diff = np.float16(a) - np.float16(b) squared_diff = float(diff) * float( diff ) # Convert to float for multiplication squared_diffs.append(squared_diff) squared_distance = sum(squared_diffs) elif dtype == DataType.VECTOR_INT8: # For INT8 vectors, convert to integers and handle potential scaling # INT8 values might be treated differently in the library implementation vec1_int = [ int(round(min(max(val, -128), 127))) for val in vec1 ] # Clamp to valid INT8 range vec2_int 
= [ int(round(min(max(val, -128), 127))) for val in vec2 ] # Clamp to valid INT8 range # Use float type to prevent overflow when summing large squared differences squared_distance = sum(float(a - b) ** 2 for a, b in zip(vec1_int, vec2_int)) else: squared_distance = sum((a - b) ** 2 for a, b in zip(vec1, vec2)) return squared_distance # Return squared distance for INT8 def distance_dense( vec1, vec2, metric: MetricType, data_type: DataType = DataType.VECTOR_FP32, quantize_type: QuantizeType = QuantizeType.UNDEFINED, ): if metric == MetricType.COSINE: return cosine_distance_dense(vec1, vec2, data_type, quantize_type) elif metric == MetricType.L2: return euclidean_distance_dense(vec1, vec2, data_type, quantize_type) elif metric == MetricType.IP: return dp_distance_dense(vec1, vec2, data_type, quantize_type) else: raise ValueError("Unsupported metric type") def dp_distance_sparse( vec1, vec2, data_type: DataType = DataType.SPARSE_VECTOR_FP32, quantize_type: QuantizeType = QuantizeType.UNDEFINED, ): dot_product = 0.0 for dim in set(vec1.keys()) & set(vec2.keys()): print("dim,vec1,vec2:\n") print(dim, vec1, vec2) if ( data_type == DataType.SPARSE_VECTOR_FP16 or quantize_type == QuantizeType.FP16 ): vec1[dim] = np.float16(vec1[dim]) vec2[dim] = np.float16(vec2[dim]) dot_product += vec1[dim] * vec2[dim] return dot_product def distance( vec1, vec2, metric: MetricType, data_type: DataType, quantize_type: QuantizeType = QuantizeType.UNDEFINED, ): is_sparse = ( data_type == DataType.SPARSE_VECTOR_FP32 or data_type == DataType.SPARSE_VECTOR_FP16 ) if is_sparse: if metric != MetricType.IP: raise ValueError("Unsupported metric type for sparse vectors") if is_sparse: return dp_distance_sparse(vec1, vec2, data_type, quantize_type) else: return distance_dense(vec1, vec2, metric, data_type, quantize_type) def distance_recall( vec1, vec2, metric: MetricType, data_type: DataType, quantize_type: QuantizeType = QuantizeType.UNDEFINED, ): is_sparse = ( data_type == 
DataType.SPARSE_VECTOR_FP32 or data_type == DataType.SPARSE_VECTOR_FP16 ) if is_sparse: return dp_distance_sparse(vec1, vec2, data_type, quantize_type) else: if data_type in [DataType.VECTOR_FP32, DataType.VECTOR_FP16]: return distance_dense(vec1, vec2, metric, data_type, quantize_type) elif data_type in [DataType.VECTOR_INT8] and metric in [ MetricType.L2, MetricType.IP, ]: return distance_dense(vec1, vec2, metric, data_type, quantize_type) else: return dp_distance_dense(vec1, vec2, data_type, quantize_type) def calculate_rrf_score(rank, k=60): return 1.0 / (k + rank + 1) def calculate_multi_vector_rrf_scores(query_results: Dict[str, Doc], k=60): rrf_scores = {} for vector_name, docs in query_results.items(): for rank, doc in enumerate(docs): doc_id = doc.id rrf_score = calculate_rrf_score(rank, k) if doc_id in rrf_scores: rrf_scores[doc_id] += rrf_score else: rrf_scores[doc_id] = rrf_score return rrf_scores def calculate_multi_vector_weighted_scores( query_results: Dict[str, Doc], weights: Dict[str, float], metric: MetricType ): def _normalize_score(score: float, metric: MetricType) -> float: if metric == MetricType.L2: return 1.0 - 2 * math.atan(score) / math.pi if metric == MetricType.IP: return 0.5 + math.atan(score) / math.pi if metric == MetricType.COSINE: return 1.0 - score / 2.0 raise ValueError("Unsupported metric type") weighted_scores = {} for vector_name, docs in query_results.items(): weight = weights.get(vector_name, 1.0) for doc in docs: doc_id = doc.id weighted_score = (_normalize_score(doc.score, metric)) * weight if doc_id in weighted_scores: weighted_scores[doc_id] += weighted_score else: weighted_scores[doc_id] = weighted_score return weighted_scores def is_field_equal(field1, field2, schema: FieldSchema) -> bool: if field1 is None and field2 is None: return True if field1 is None or field2 is None: return False if schema.data_type == DataType.ARRAY_FLOAT: return is_float_array_equal(field1, field2) elif schema.data_type == 
DataType.ARRAY_DOUBLE: return is_double_array_equal(field1, field2) elif schema.data_type in [ DataType.ARRAY_INT32, DataType.ARRAY_INT64, DataType.ARRAY_BOOL, DataType.ARRAY_STRING, DataType.ARRAY_UINT32, DataType.ARRAY_UINT64, DataType.ARRAY_INT64, ]: return is_int_array_equal(field1, field2) elif schema.data_type in [DataType.FLOAT, DataType.DOUBLE]: return is_float_equal(field1, field2) return field1 == field2 def is_vector_equal(vec1, vec2, schema: VectorSchema) -> bool: if ( schema.data_type == DataType.SPARSE_VECTOR_FP16 or schema.data_type == DataType.VECTOR_FP16 ): # skip fp16 vector equal return True is_sparse = ( schema.data_type == DataType.SPARSE_VECTOR_FP32 or schema.data_type == DataType.SPARSE_VECTOR_FP16 ) if is_sparse: return is_sparse_vector_equal(vec1, vec2) else: return is_dense_vector_equal(vec1, vec2) def is_doc_equal( doc1: Doc, doc2: Doc, schema: CollectionSchema, except_score: bool = True, include_vector: bool = True, ): if doc1.id != doc2.id: logging.error("doc ids are not equal") return False reduce_field_names = set(doc1.field_names() + doc2.field_names()) reduce_vector_names = set(doc1.vector_names() + doc2.vector_names()) is_doc1_fields_empty = doc1.fields is None or doc1.fields == {} is_doc2_fields_empty = doc2.fields is None or doc2.fields == {} if is_doc1_fields_empty or is_doc2_fields_empty: if is_doc1_fields_empty != is_doc2_fields_empty: return False else: for field_name in reduce_field_names: field_schema = schema.field(field_name) if field_schema is None: return False if is_field_equal( doc1.field(field_name), doc2.field(field_name), field_schema ): continue else: logging.error(f"{field_name} are not equal") return False if include_vector: is_doc1_vectors_empty = doc1.vectors is None or doc1.vectors == {} is_doc2_vectors_empty = doc2.vectors is None or doc2.vectors == {} if is_doc1_vectors_empty or is_doc2_vectors_empty: if is_doc1_fields_empty != is_doc2_vectors_empty: return False else: for vector_name in 
reduce_vector_names: vector_schema = schema.vector(vector_name) if vector_schema is None: return False if is_vector_equal( doc1.vector(vector_name), doc2.vector(vector_name), vector_schema ): continue else: return False return True ================================================ FILE: python/tests/detail/doc_helper.py ================================================ from zvec import CollectionSchema, Doc from support_helper import * import numpy as np from typing import Literal, Optional, Union, Tuple import random import string import math def generate_constant_vector( i: int, dimension: int, dtype: Literal["int8", "float16", "float32"] = "float32" ): if dtype == "int8": vec = [(i % 127)] * dimension vec[i % dimension] = (i + 1) % 127 else: base_val = (i % 1000) / 256.0 special_val = ((i + 1) % 1000) / 256.0 vec = [base_val] * dimension vec[i % dimension] = special_val return vec def generate_constant_vector_recall( i: int, dimension: int, dtype: Literal["int8", "float16", "float32"] = "float32" ): if dtype == "int8": vec = [(i % 127)] * dimension vec[i % dimension] = (i + 1) % 127 else: base_val = math.sin((i) * 1000) / 256.0 special_val = math.sin((i + 1) * 1000) / 256.0 vec = [base_val] * dimension vec[i % dimension] = special_val return vec def generate_sparse_vector(i: int): return {i: i + 0.1} def generate_vectordict(i: int, schema: CollectionSchema) -> Doc: doc_fields = {} doc_vectors = {} doc_fields = {} doc_vectors = {} for field in schema.fields: if field.data_type == DataType.BOOL: doc_fields[field.name] = i % 2 == 0 elif field.data_type == DataType.INT32: doc_fields[field.name] = i elif field.data_type == DataType.UINT32: doc_fields[field.name] = i elif field.data_type == DataType.INT64: doc_fields[field.name] = i elif field.data_type == DataType.UINT64: doc_fields[field.name] = i elif field.data_type == DataType.FLOAT: doc_fields[field.name] = float(i) + 0.1 elif field.data_type == DataType.DOUBLE: doc_fields[field.name] = float(i) + 0.11 elif 
field.data_type == DataType.STRING: doc_fields[field.name] = f"test_{i}" elif field.data_type == DataType.ARRAY_BOOL: doc_fields[field.name] = [i % 2 == 0, i % 3 == 0] elif field.data_type == DataType.ARRAY_INT32: doc_fields[field.name] = [i, i + 1, i + 2] elif field.data_type == DataType.ARRAY_UINT32: doc_fields[field.name] = [i, i + 1, i + 2] elif field.data_type == DataType.ARRAY_INT64: doc_fields[field.name] = [i, i + 1, i + 2] elif field.data_type == DataType.ARRAY_UINT64: doc_fields[field.name] = [i, i + 1, i + 2] elif field.data_type == DataType.ARRAY_FLOAT: doc_fields[field.name] = [float(i + 0.1), float(i + 1.1), float(i + 2.1)] elif field.data_type == DataType.ARRAY_DOUBLE: doc_fields[field.name] = [float(i + 0.11), float(i + 1.11), float(i + 2.11)] elif field.data_type == DataType.ARRAY_STRING: doc_fields[field.name] = [f"test_{i}", f"test_{i + 1}", f"test_{i + 2}"] else: raise ValueError(f"Unsupported field type: {field.data_type}") for vector in schema.vectors: if vector.data_type == DataType.VECTOR_FP16: doc_vectors[vector.name] = generate_constant_vector( i, vector.dimension, "float16" ) elif vector.data_type == DataType.VECTOR_FP32: doc_vectors[vector.name] = generate_constant_vector( i, vector.dimension, "float32" ) elif vector.data_type == DataType.VECTOR_INT8: doc_vectors[vector.name] = generate_constant_vector( i, vector.dimension, "int8", ) elif vector.data_type == DataType.SPARSE_VECTOR_FP32: doc_vectors[vector.name] = generate_sparse_vector(i) elif vector.data_type == DataType.SPARSE_VECTOR_FP16: doc_vectors[vector.name] = generate_sparse_vector(i) else: raise ValueError(f"Unsupported vector type: {vector.data_type}") return doc_fields, doc_vectors def generate_vectordict_recall(i: int, schema: CollectionSchema) -> Doc: doc_fields = {} doc_vectors = {} doc_fields = {} doc_vectors = {} for field in schema.fields: if field.data_type == DataType.BOOL: doc_fields[field.name] = i % 2 == 0 elif field.data_type == DataType.INT32: 
doc_fields[field.name] = i elif field.data_type == DataType.UINT32: doc_fields[field.name] = i elif field.data_type == DataType.INT64: doc_fields[field.name] = i elif field.data_type == DataType.UINT64: doc_fields[field.name] = i elif field.data_type == DataType.FLOAT: doc_fields[field.name] = float(i) + 0.1 elif field.data_type == DataType.DOUBLE: doc_fields[field.name] = float(i) + 0.11 elif field.data_type == DataType.STRING: doc_fields[field.name] = f"test_{i}" elif field.data_type == DataType.ARRAY_BOOL: doc_fields[field.name] = [i % 2 == 0, i % 3 == 0] elif field.data_type == DataType.ARRAY_INT32: doc_fields[field.name] = [i, i + 1, i + 2] elif field.data_type == DataType.ARRAY_UINT32: doc_fields[field.name] = [i, i + 1, i + 2] elif field.data_type == DataType.ARRAY_INT64: doc_fields[field.name] = [i, i + 1, i + 2] elif field.data_type == DataType.ARRAY_UINT64: doc_fields[field.name] = [i, i + 1, i + 2] elif field.data_type == DataType.ARRAY_FLOAT: doc_fields[field.name] = [float(i + 0.1), float(i + 1.1), float(i + 2.1)] elif field.data_type == DataType.ARRAY_DOUBLE: doc_fields[field.name] = [float(i + 0.11), float(i + 1.11), float(i + 2.11)] elif field.data_type == DataType.ARRAY_STRING: doc_fields[field.name] = [f"test_{i}", f"test_{i + 1}", f"test_{i + 2}"] else: raise ValueError(f"Unsupported field type: {field.data_type}") for vector in schema.vectors: if vector.data_type == DataType.VECTOR_FP16: doc_vectors[vector.name] = generate_constant_vector_recall( i, vector.dimension, "float16" ) elif vector.data_type == DataType.VECTOR_FP32: doc_vectors[vector.name] = generate_constant_vector_recall( i, vector.dimension, "float32" ) elif vector.data_type == DataType.VECTOR_INT8: doc_vectors[vector.name] = generate_constant_vector_recall( i, vector.dimension, "int8", ) elif vector.data_type == DataType.SPARSE_VECTOR_FP32: doc_vectors[vector.name] = generate_sparse_vector(i) elif vector.data_type == DataType.SPARSE_VECTOR_FP16: doc_vectors[vector.name] = 
generate_sparse_vector(i) else: raise ValueError(f"Unsupported vector type: {vector.data_type}") return doc_fields, doc_vectors def generate_vectordict_update(i: int, schema: CollectionSchema) -> Doc: doc_fields = {} doc_vectors = {} doc_fields = {} doc_vectors = {} for field in schema.fields: if field.data_type == DataType.BOOL: doc_fields[field.name] = (i + 1) % 2 == 0 elif field.data_type == DataType.INT32: doc_fields[field.name] = i + 1 elif field.data_type == DataType.UINT32: doc_fields[field.name] = i + 1 elif field.data_type == DataType.INT64: doc_fields[field.name] = i + 1 elif field.data_type == DataType.UINT64: doc_fields[field.name] = i + 1 elif field.data_type == DataType.FLOAT: doc_fields[field.name] = float(i + 1) + 0.1 elif field.data_type == DataType.DOUBLE: doc_fields[field.name] = float(i + 1) + 0.11 elif field.data_type == DataType.STRING: doc_fields[field.name] = f"test_{i + 1}" elif field.data_type == DataType.ARRAY_BOOL: doc_fields[field.name] = [(i + 1) % 2 == 0, (i + 1) % 3 == 0] elif field.data_type == DataType.ARRAY_INT32: doc_fields[field.name] = [i + 1, i + 1, i + 2] elif field.data_type == DataType.ARRAY_UINT32: doc_fields[field.name] = [i + 1, i + 1, i + 2] elif field.data_type == DataType.ARRAY_INT64: doc_fields[field.name] = [i + 1, i + 1, i + 2] elif field.data_type == DataType.ARRAY_UINT64: doc_fields[field.name] = [i + 1, i + 1, i + 2] elif field.data_type == DataType.ARRAY_FLOAT: doc_fields[field.name] = [float(i + 1.1), float(i + 2.1), float(i + 3.1)] elif field.data_type == DataType.ARRAY_DOUBLE: doc_fields[field.name] = [float(i + 1.11), float(i + 2.11), float(i + 3.11)] elif field.data_type == DataType.ARRAY_STRING: doc_fields[field.name] = [f"test_{i + 1}", f"test_{i + 2}", f"test_{i + 3}"] else: raise ValueError(f"Unsupported field type: {field.data_type}") for vector in schema.vectors: if vector.data_type == DataType.VECTOR_FP16: doc_vectors[vector.name] = generate_constant_vector( i + 1, vector.dimension, "float16" ) elif 
vector.data_type == DataType.VECTOR_FP32: doc_vectors[vector.name] = generate_constant_vector( i + 1, vector.dimension, "float32" ) elif vector.data_type == DataType.VECTOR_INT8: doc_vectors[vector.name] = generate_constant_vector( i + 1, vector.dimension, "int8", ) elif vector.data_type == DataType.SPARSE_VECTOR_FP32: doc_vectors[vector.name] = generate_sparse_vector(i + 1) elif vector.data_type == DataType.SPARSE_VECTOR_FP16: doc_vectors[vector.name] = generate_sparse_vector(i + 1) else: raise ValueError(f"Unsupported vector type: {vector.data_type}") return doc_fields, doc_vectors def generate_doc(i: int, schema: CollectionSchema) -> Doc: doc_fields = {} doc_vectors = {} doc_fields, doc_vectors = generate_vectordict(i, schema) doc = Doc(id=str(i), fields=doc_fields, vectors=doc_vectors) return doc def generate_doc_recall(i: int, schema: CollectionSchema) -> Doc: doc_fields = {} doc_vectors = {} doc_fields, doc_vectors = generate_vectordict_recall(i, schema) doc = Doc(id=str(i), fields=doc_fields, vectors=doc_vectors) return doc def generate_update_doc(i: int, schema: CollectionSchema) -> Doc: doc_fields = {} doc_vectors = {} doc_fields, doc_vectors = generate_vectordict_update(i, schema) doc = Doc(id=str(i), fields=doc_fields, vectors=doc_vectors) return doc def generate_doc_random(i, schema: CollectionSchema) -> Doc: doc_fields = {} doc_vectors = {} random.seed(i) for field in schema.fields: if field.data_type == DataType.BOOL: doc_fields[field.name] = random.choice([True, False]) elif field.data_type == DataType.INT32: doc_fields[field.name] = random.randint(-2147483648, 2147483647) elif field.data_type == DataType.UINT32: doc_fields[field.name] = random.randint(0, 4294967295) elif field.data_type == DataType.INT64: doc_fields[field.name] = random.randint( -9223372036854775808, 9223372036854775807 ) elif field.data_type == DataType.UINT64: doc_fields[field.name] = random.randint(0, 18446744073709551615) elif field.data_type == DataType.FLOAT: 
doc_fields[field.name] = random.uniform(-3.4028235e38, 3.4028235e38) elif field.data_type == DataType.DOUBLE: doc_fields[field.name] = random.uniform( -1.7976931348623157e308, 1.7976931348623157e308 ) elif field.data_type == DataType.STRING: length = random.randint(1, 999) doc_fields[field.name] = "".join( random.choices(string.ascii_letters + string.digits, k=length) ) elif field.data_type == DataType.ARRAY_BOOL: array_length = random.randint(0, 10) doc_fields[field.name] = [ random.choice([True, False]) for _ in range(array_length) ] elif field.data_type == DataType.ARRAY_INT32: array_length = random.randint(0, 10) doc_fields[field.name] = [ random.randint(-2147483648, 2147483647) for _ in range(array_length) ] elif field.data_type == DataType.ARRAY_UINT32: array_length = random.randint(0, 10) doc_fields[field.name] = [ random.randint(0, 4294967295) for _ in range(array_length) ] elif field.data_type == DataType.ARRAY_INT64: array_length = random.randint(0, 10) doc_fields[field.name] = [ random.randint(-9223372036854775808, 9223372036854775807) for _ in range(array_length) ] elif field.data_type == DataType.ARRAY_UINT64: array_length = random.randint(0, 10) doc_fields[field.name] = [ random.randint(0, 18446744073709551615) for _ in range(array_length) ] elif field.data_type == DataType.ARRAY_FLOAT: array_length = random.randint(0, 10) doc_fields[field.name] = [ random.uniform(-3.4028235e38, 3.4028235e38) for _ in range(array_length) ] elif field.data_type == DataType.ARRAY_DOUBLE: array_length = random.randint(0, 10) doc_fields[field.name] = [ random.uniform(-1.7976931348623157e308, 1.7976931348623157e308) for _ in range(array_length) ] elif field.data_type == DataType.ARRAY_STRING: array_length = random.randint(0, 10) doc_fields[field.name] = [ "".join( random.choices( string.ascii_letters + string.digits, k=random.randint(1, 100) ) ) for _ in range(array_length) ] else: raise ValueError(f"Unsupported field type: {field.data_type}") for vector in schema.vectors: 
if vector.data_type == DataType.VECTOR_FP16: doc_vectors[vector.name] = generate_constant_vector( random.randint(1, 100), DEFAULT_VECTOR_DIMENSION, "float16" ) elif vector.data_type == DataType.VECTOR_FP32: doc_vectors[vector.name] = generate_constant_vector( random.randint(1, 100), DEFAULT_VECTOR_DIMENSION, "float32" ) elif vector.data_type == DataType.VECTOR_INT8: doc_vectors[vector.name] = generate_constant_vector( random.randint(1, 100), DEFAULT_VECTOR_DIMENSION, "int8" ) elif vector.data_type == DataType.SPARSE_VECTOR_FP32: doc_vectors[vector.name] = generate_sparse_vector(random.randint(1, 100)) elif vector.data_type == DataType.SPARSE_VECTOR_FP16: doc_vectors[vector.name] = generate_sparse_vector(random.randint(1, 100)) else: raise ValueError(f"Unsupported vector type: {vector.data_type}") doc = Doc(id=i, fields=doc_fields, vectors=doc_vectors) return doc def generate_vectordict_random(schema: CollectionSchema): doc_fields = {} doc_vectors = {} for field in schema.fields: if field.data_type == DataType.BOOL: doc_fields[field.name] = random.choice([True, False]) elif field.data_type == DataType.INT32: doc_fields[field.name] = random.randint(-2147483648, 2147483647) elif field.data_type == DataType.UINT32: doc_fields[field.name] = random.randint(0, 4294967295) elif field.data_type == DataType.INT64: doc_fields[field.name] = random.randint( -9223372036854775808, 9223372036854775807 ) elif field.data_type == DataType.UINT64: doc_fields[field.name] = random.randint(0, 18446744073709551615) elif field.data_type == DataType.FLOAT: doc_fields[field.name] = random.uniform(-3.4028235e38, 3.4028235e38) elif field.data_type == DataType.DOUBLE: doc_fields[field.name] = random.uniform( -1.7976931348623157e308, 1.7976931348623157e308 ) elif field.data_type == DataType.STRING: length = random.randint(1, 999) doc_fields[field.name] = "".join( random.choices(string.ascii_letters + string.digits, k=length) ) elif field.data_type == DataType.ARRAY_BOOL: array_length = 
random.randint(0, 10) doc_fields[field.name] = [ random.choice([True, False]) for _ in range(array_length) ] elif field.data_type == DataType.ARRAY_INT32: array_length = random.randint(0, 10) doc_fields[field.name] = [ random.randint(-2147483648, 2147483647) for _ in range(array_length) ] elif field.data_type == DataType.ARRAY_UINT32: array_length = random.randint(0, 10) doc_fields[field.name] = [ random.randint(0, 4294967295) for _ in range(array_length) ] elif field.data_type == DataType.ARRAY_INT64: array_length = random.randint(0, 10) doc_fields[field.name] = [ random.randint(-9223372036854775808, 9223372036854775807) for _ in range(array_length) ] elif field.data_type == DataType.ARRAY_UINT64: array_length = random.randint(0, 10) doc_fields[field.name] = [ random.randint(0, 18446744073709551615) for _ in range(array_length) ] elif field.data_type == DataType.ARRAY_FLOAT: array_length = random.randint(0, 10) doc_fields[field.name] = [ random.uniform(-3.4028235e38, 3.4028235e38) for _ in range(array_length) ] elif field.data_type == DataType.ARRAY_DOUBLE: array_length = random.randint(0, 10) doc_fields[field.name] = [ random.uniform(-1.7976931348623157e308, 1.7976931348623157e308) for _ in range(array_length) ] elif field.data_type == DataType.ARRAY_STRING: array_length = random.randint(0, 10) doc_fields[field.name] = [ "".join( random.choices( string.ascii_letters + string.digits, k=random.randint(1, 100) ) ) for _ in range(array_length) ] else: raise ValueError(f"Unsupported field type: {field.data_type}") for vector in schema.vectors: if vector.data_type == DataType.VECTOR_FP16: doc_vectors[vector.name] = generate_constant_vector( random.randint(1, 100), vector.dimension, "float16" ) elif vector.data_type == DataType.VECTOR_FP32: doc_vectors[vector.name] = generate_constant_vector( random.randint(1, 100), vector.dimension, "float32" ) elif vector.data_type == DataType.VECTOR_INT8: doc_vectors[vector.name] = generate_constant_vector( random.randint(1, 100), 
vector.dimension, "int8" ) elif vector.data_type == DataType.SPARSE_VECTOR_FP32: doc_vectors[vector.name] = generate_sparse_vector(random.randint(1, 100)) elif vector.data_type == DataType.SPARSE_VECTOR_FP16: doc_vectors[vector.name] = generate_sparse_vector(random.randint(1, 100)) else: raise ValueError(f"Unsupported vector type: {vector.data_type}") return doc_fields, doc_vectors ================================================ FILE: python/tests/detail/fixture_helper.py ================================================ import pytest import logging from typing import Any, Generator from zvec.typing import DataType, StatusCode, MetricType, QuantizeType import zvec from zvec import ( CollectionOption, InvertIndexParam, HnswIndexParam, FlatIndexParam, IVFIndexParam, FieldSchema, VectorSchema, CollectionSchema, Collection, Doc, VectorQuery, ) from support_helper import * @pytest.fixture(scope="session") def basic_schema(collection_name="test_collection") -> CollectionSchema: return CollectionSchema( name=collection_name if len(collection_name) > 0 else "test_collection", fields=[ FieldSchema( "id", DataType.INT64, nullable=False, index_param=InvertIndexParam(enable_range_optimization=True), ), FieldSchema( "name", DataType.STRING, nullable=False, index_param=InvertIndexParam() ), FieldSchema("weight", DataType.FLOAT, nullable=True), ], vectors=[ VectorSchema( "dense", DataType.VECTOR_FP32, dimension=128, index_param=HnswIndexParam(), ), VectorSchema( "sparse", DataType.SPARSE_VECTOR_FP32, index_param=HnswIndexParam() ), ], ) @pytest.fixture(scope="session") def full_schema( nullable: bool = False, has_index: bool = False, ) -> CollectionSchema: scalar_index_param = None vector_index_param = None if has_index: scalar_index_param = InvertIndexParam(enable_range_optimization=True) vector_index_param = HnswIndexParam() fields = [] for k, v in DEFAULT_SCALAR_FIELD_NAME.items(): fields.append( FieldSchema( v, k, nullable=nullable, index_param=scalar_index_param, ) ) vetors 
= [] for k, v in DEFAULT_VECTOR_FIELD_NAME.items(): vetors.append( VectorSchema( v, k, dimension=DEFAULT_VECTOR_DIMENSION, index_param=vector_index_param, ) ) return CollectionSchema( name="full_collection", fields=fields, vectors=vetors, ) @pytest.fixture(scope="function") def full_schema_new(request) -> CollectionSchema: if hasattr(request, "param"): nullable, has_index, vector_index = request.param else: nullable, has_index, vector_index = True, False, HnswIndexParam() scalar_index_param = None vector_index_param = None if has_index: scalar_index_param = InvertIndexParam(enable_range_optimization=True) vector_index_param = vector_index fields = [] for k, v in DEFAULT_SCALAR_FIELD_NAME.items(): fields.append( FieldSchema( v, k, nullable=nullable, index_param=scalar_index_param, ) ) vectors = [] if vector_index_param in [ HnswIndexParam(), FlatIndexParam(), HnswIndexParam( metric_type=MetricType.IP, m=16, ef_construction=100, ), FlatIndexParam( metric_type=MetricType.IP, ), ]: for k, v in DEFAULT_VECTOR_FIELD_NAME.items(): vectors.append( VectorSchema( v, k, dimension=DEFAULT_VECTOR_DIMENSION, index_param=vector_index_param, ) ) elif vector_index_param in [ IVFIndexParam(), IVFIndexParam( metric_type=MetricType.IP, n_list=100, n_iters=10, use_soar=False, ), IVFIndexParam( metric_type=MetricType.L2, n_list=200, n_iters=20, use_soar=True, ), ( IVFIndexParam( metric_type=MetricType.COSINE, n_list=150, n_iters=15, use_soar=False, ) ), ( HnswIndexParam( metric_type=MetricType.COSINE, m=24, ef_construction=150, ) ), ( HnswIndexParam( metric_type=MetricType.L2, m=32, ef_construction=200, ) ), ( FlatIndexParam( metric_type=MetricType.COSINE, ) ), ( FlatIndexParam( metric_type=MetricType.L2, ) ), ]: for k, v in DEFAULT_VECTOR_FIELD_NAME.items(): if v in ["vector_fp16_field", "vector_fp32_field"]: vectors.append( VectorSchema( v, k, dimension=DEFAULT_VECTOR_DIMENSION, index_param=vector_index_param, ) ) elif v in ["vector_int8_field"] and vector_index_param in [ 
IVFIndexParam( metric_type=MetricType.L2, n_list=200, n_iters=20, use_soar=True, ), ( HnswIndexParam( metric_type=MetricType.L2, m=32, ef_construction=200, ) ), ( FlatIndexParam( metric_type=MetricType.L2, ) ), ]: vectors.append( VectorSchema( v, k, dimension=DEFAULT_VECTOR_DIMENSION, index_param=vector_index_param, ) ) else: vectors.append( VectorSchema( v, k, dimension=DEFAULT_VECTOR_DIMENSION, index_param=HnswIndexParam(), ) ) else: for k, v in DEFAULT_VECTOR_FIELD_NAME.items(): if v in ["vector_fp16_field", "vector_fp32_field"]: vectors.append( VectorSchema( v, k, dimension=DEFAULT_VECTOR_DIMENSION, index_param=vector_index_param, ) ) else: vectors.append( VectorSchema( v, k, dimension=DEFAULT_VECTOR_DIMENSION, index_param=HnswIndexParam(), ) ) return CollectionSchema( name="full_collection_new", fields=fields, vectors=vectors, ) @pytest.fixture(scope="function") def full_schema_ivf(request) -> CollectionSchema: if hasattr(request, "param"): nullable, has_index, vector_index = request.param else: nullable, has_index, vector_index = True, False, IVFIndexParam() scalar_index_param = None vector_index_param = None if has_index: scalar_index_param = InvertIndexParam(enable_range_optimization=True) vector_index_param = vector_index fields = [] for k, v in DEFAULT_SCALAR_FIELD_NAME.items(): fields.append( FieldSchema( v, k, nullable=nullable, index_param=scalar_index_param, ) ) vectors = [] for k, v in DEFAULT_VECTOR_FIELD_NAME.items(): if v in ["vector_fp16_field", "vector_fp32_field"]: vectors.append( VectorSchema( v, k, dimension=DEFAULT_VECTOR_DIMENSION, index_param=vector_index_param, ) ) return CollectionSchema( name="full_collection_ivf", fields=fields, vectors=vectors, ) @pytest.fixture(scope="function") def full_schema_1024(request) -> CollectionSchema: if hasattr(request, "param"): nullable, has_index, vector_index = request.param else: nullable, has_index, vector_index = True, False, HnswIndexParam() scalar_index_param = None vector_index_param = None if 
has_index: scalar_index_param = InvertIndexParam(enable_range_optimization=True) vector_index_param = vector_index fields = [] for k, v in DEFAULT_SCALAR_FIELD_NAME.items(): fields.append( FieldSchema( v, k, nullable=nullable, index_param=scalar_index_param, ) ) vectors = [] if vector_index_param in [ HnswIndexParam(), FlatIndexParam(), HnswIndexParam( metric_type=MetricType.IP, m=16, ef_construction=100, ), FlatIndexParam( metric_type=MetricType.IP, ), ]: for k, v in DEFAULT_VECTOR_FIELD_NAME.items(): vectors.append( VectorSchema( v, k, dimension=VECTOR_DIMENSION_1024, index_param=vector_index_param, ) ) elif vector_index_param in [ IVFIndexParam(), IVFIndexParam( metric_type=MetricType.IP, n_list=100, n_iters=10, use_soar=False, ), IVFIndexParam( metric_type=MetricType.L2, n_list=200, n_iters=20, use_soar=True, ), IVFIndexParam( metric_type=MetricType.COSINE, n_list=150, n_iters=15, use_soar=False, ), ]: for k, v in DEFAULT_VECTOR_FIELD_NAME.items(): if v in ["vector_fp16_field", "vector_fp32_field"]: vectors.append( VectorSchema( v, k, dimension=VECTOR_DIMENSION_1024, index_param=vector_index_param, ) ) elif v in ["vector_int8_field"] and vector_index_param in [ IVFIndexParam( metric_type=MetricType.L2, n_list=200, n_iters=20, use_soar=True, ), IVFIndexParam( metric_type=MetricType.COSINE, n_list=150, n_iters=15, use_soar=False, ), ]: vectors.append( VectorSchema( v, k, dimension=DVECTOR_DIMENSION_1024, index_param=vector_index_param, ) ) else: vectors.append( VectorSchema( v, k, dimension=VECTOR_DIMENSION_1024, index_param=HnswIndexParam(), ) ) else: for k, v in DEFAULT_VECTOR_FIELD_NAME.items(): if v in ["vector_fp16_field", "vector_fp32_field", "vector_int8_field"]: vectors.append( VectorSchema( v, k, dimension=VECTOR_DIMENSION_1024, index_param=vector_index_param, ) ) else: vectors.append( VectorSchema( v, k, dimension=VECTOR_DIMENSION_1024, index_param=HnswIndexParam(), ) ) return CollectionSchema( name="full_collection_new", fields=fields, vectors=vectors, 
) @pytest.fixture(scope="function") def single_vector_schema( data_type: DataType, ) -> CollectionSchema: vector_schema = [ VectorSchema( DEFAULT_VECTOR_FIELD_NAME[data_type], data_type, DEFAULT_VECTOR_DIMENSION, ) ] return CollectionSchema( name="full_collection", vectors=vector_schema, ) @pytest.fixture(scope="function") def single_vector_schema_with_index_param( data_type: DataType, index_param ) -> CollectionSchema: vector_schema = [ VectorSchema( DEFAULT_VECTOR_FIELD_NAME[data_type], data_type, DEFAULT_VECTOR_DIMENSION, index_param, ) ] return CollectionSchema( name="full_collection", vectors=vector_schema, ) def create_collection_fixture( collection_temp_dir, schema: CollectionSchema, collection_option: CollectionOption ) -> Generator[Any, Any, Collection]: """Common helper function to create and manage collection fixtures.""" coll = zvec.create_and_open( path=str(collection_temp_dir), schema=schema, option=collection_option, ) assert coll is not None, "Failed to create and open collection" assert coll.path == str(collection_temp_dir) assert coll.schema.name == schema.name assert list(coll.schema.fields) == list(schema.fields) assert list(coll.schema.vectors) == list(schema.vectors) assert coll.option.read_only == collection_option.read_only assert coll.option.enable_mmap == collection_option.enable_mmap try: yield coll finally: if hasattr(coll, "destroy") and coll is not None: try: coll.destroy() except Exception as e: logging.warning(f"Warning: failed to destroy collection: {e}") @pytest.fixture(scope="function") def basic_collection( collection_temp_dir, basic_schema, collection_option ) -> Generator[Any, Any, Collection]: yield from create_collection_fixture( collection_temp_dir, basic_schema, collection_option ) @pytest.fixture(scope="function") def collection_option(): return CollectionOption(read_only=False, enable_mmap=True) @pytest.fixture(scope="function") def collection_temp_dir(tmp_path_factory): temp_dir = tmp_path_factory.mktemp("zvec") 
collection_path = temp_dir / "test_collection_path" return str(collection_path) @pytest.fixture(scope="function") def full_collection( collection_temp_dir, full_schema, collection_option, nullable: bool = True, has_index: bool = False, ) -> Generator[Any, Any, Collection]: yield from create_collection_fixture( collection_temp_dir, full_schema, collection_option ) @pytest.fixture(scope="function") def full_collection_new( collection_temp_dir, full_schema_new, collection_option ) -> Generator[Any, Any, Collection]: yield from create_collection_fixture( collection_temp_dir, full_schema_new, collection_option ) @pytest.fixture(scope="function") def full_collection_ivf( collection_temp_dir, full_schema_ivf, collection_option ) -> Generator[Any, Any, Collection]: yield from create_collection_fixture( collection_temp_dir, full_schema_ivf, collection_option ) @pytest.fixture(scope="function") def full_collection_1024( collection_temp_dir, full_schema_1024, collection_option ) -> Generator[Any, Any, Collection]: yield from create_collection_fixture( collection_temp_dir, full_schema_1024, collection_option ) @pytest.fixture def sample_field_list(nullable: bool = True, scalar_index_param=None, name_prefix=""): field_list = [] for k, v in DEFAULT_SCALAR_FIELD_NAME.items(): field_list.append( FieldSchema( f"{name_prefix}_{v}" if len(name_prefix) > 0 else v, k, nullable=nullable, index_param=scalar_index_param, ) ) return field_list @pytest.fixture def sample_vector_list(vector_index_param=None, name_prefix=""): vector_list = [] for k, v in DEFAULT_VECTOR_FIELD_NAME.items(): vector_list.append( VectorSchema( f"{name_prefix}_{v}" if len(name_prefix) > 0 else v, k, dimension=DEFAULT_VECTOR_DIMENSION, index_param=vector_index_param, ) ) return vector_list ================================================ FILE: python/tests/detail/params_helper.py ================================================ from zvec import ( CollectionOption, IndexOption, OptimizeOption, InvertIndexParam, 
HnswIndexParam, IVFIndexParam, FlatIndexParam, AlterColumnOption, AddColumnOption, DataType, MetricType, QuantizeType, ) VALID_VECTOR_DATA_TYPE_INDEX_PARAM_MAP = { DataType.VECTOR_FP32: [ HnswIndexParam(), HnswIndexParam( metric_type=MetricType.IP, m=16, ef_construction=100, quantize_type=QuantizeType.INT8, ), HnswIndexParam( metric_type=MetricType.COSINE, m=24, ef_construction=150, quantize_type=QuantizeType.INT4, ), HnswIndexParam( metric_type=MetricType.L2, m=32, ef_construction=200, quantize_type=QuantizeType.FP16, ), FlatIndexParam(), FlatIndexParam(metric_type=MetricType.IP, quantize_type=QuantizeType.INT4), FlatIndexParam(metric_type=MetricType.L2, quantize_type=QuantizeType.INT8), FlatIndexParam(metric_type=MetricType.COSINE, quantize_type=QuantizeType.FP16), IVFIndexParam(), IVFIndexParam( metric_type=MetricType.IP, quantize_type=QuantizeType.INT4, n_list=100, n_iters=10, use_soar=False, ), IVFIndexParam( metric_type=MetricType.L2, quantize_type=QuantizeType.INT8, n_list=200, n_iters=20, use_soar=True, ), IVFIndexParam( metric_type=MetricType.COSINE, quantize_type=QuantizeType.FP16, n_list=150, n_iters=15, use_soar=False, ), ], DataType.VECTOR_FP16: [ HnswIndexParam(), FlatIndexParam(), # IVFIndexParam(), ], DataType.VECTOR_INT8: [ HnswIndexParam(), FlatIndexParam(), # IVFIndexParam(), ], DataType.SPARSE_VECTOR_FP32: [ HnswIndexParam(), FlatIndexParam(), HnswIndexParam( metric_type=MetricType.IP, m=16, ef_construction=100, quantize_type=QuantizeType.FP16, ), ], DataType.SPARSE_VECTOR_FP16: [ HnswIndexParam(), FlatIndexParam(), HnswIndexParam( metric_type=MetricType.IP, m=16, ef_construction=100, ), ], } VALID_VECTOR_DATA_TYPE_INDEX_PARAM_MAP_PARAMS = [ (data_type, param) for data_type, params in VALID_VECTOR_DATA_TYPE_INDEX_PARAM_MAP.items() for param in params ] INVALID_VECTOR_DATA_TYPE_INDEX_PARAM_MAP = { DataType.VECTOR_FP32: [ InvertIndexParam(), ], DataType.VECTOR_FP16: [ InvertIndexParam(), ], DataType.VECTOR_INT8: [ InvertIndexParam(), ], 
DataType.SPARSE_VECTOR_FP32: [ HnswIndexParam(metric_type=MetricType.L2), FlatIndexParam(metric_type=MetricType.COSINE), IVFIndexParam(), InvertIndexParam(), ], DataType.SPARSE_VECTOR_FP16: [ HnswIndexParam(metric_type=MetricType.L2), FlatIndexParam(metric_type=MetricType.COSINE), IVFIndexParam(), InvertIndexParam(), ], } INVALID_VECTOR_DATA_TYPE_INDEX_PARAM_MAP_PARAMS = [ (data_type, param) for data_type, params in INVALID_VECTOR_DATA_TYPE_INDEX_PARAM_MAP.items() for param in params ] COLLECTION_NAME_MAX_LENGTH = 64 COLLECTION_NAME_VALID_LIST = [ "col", "C0llECTION", "Collection1", "collection_2", "123collection-", "a" * COLLECTION_NAME_MAX_LENGTH, ] COLLECTION_NAME_INVALID_LIST = [ "l", "1C", "", " ", None, "abcdefghijklmnopqrstuvwxzy123456abcdefghijklmnopqrstuvwxzy1234561", "test/", "!@#$%^&*()test", ] FIELD_NAME_VALID_LIST = [ "1", "12", "col", "ID", "name1", "Weigt_12-", "123age", "name_with_underscores", "123numeric_start", "name-with-dashes", ] FIELD_NAME_INVALID_LIST = [ "", " ", None, "abcdefghijklmnopqrstuvwxzy1234561", "test/", "!@#$%^&*()test", "name@with#special$chars", "name with spaces", ] FIELD_LIST_MAX_LENGTH = 1024 VECTOR_LIST_MAX_LENGTH = 5 DENSE_VECTOR_MAX_DIMENSION = 20000 SPARSE_VECTOR_MAX_DIMENSION = 4096 FIELD_VECTOR_LIST_DIMENSION_VALID_LIST = [ # field_list_len, vector_list_len, dimension (1, 1, 1), (2, 2, 512), (512, 3, 1024), (1024, 4, 20000), ] FIELD_VECTOR_LIST_DIMENSION_INVALID_LIST = [ # field_list_len, vector_list_len, dimension (1, 1, 0), (1, 1, -1), (1, 1, "1"), (1, 1, 20001), ] INCOMPATIBLE_CONSTRUCTOR_ERROR_MSG = "incompatible constructor arguments" SCHEMA_VALIDATE_ERROR_MSG = "schema validate failed" CREATE_READ_ONLY_ERROR_MSG = "Unable to create collection with read-only mode" INCOMPATIBLE_FUNCTION_ERROR_MSG = "incompatible function arguments" INVALID_PATH_ERROR_MSG = "path validate failed" INDEX_NON_EXISTENT_COLUMN_ERROR_MSG = "not found in schema" ACCESS_DESTROYED_COLLECTION_ERROR_MSG = "is already destroyed" 
COLLECTION_PATH_NOT_EXIST_ERROR_MSG = "not exist" NOT_SUPPORT_ADD_COLUMN_ERROR_MSG = "Only support basic numeric data type" NOT_EXIST_COLUMN_TO_DROP_ERROR_MSG = "Column not exists" ================================================ FILE: python/tests/detail/support_helper.py ================================================ from zvec import ( CollectionOption, IndexOption, OptimizeOption, InvertIndexParam, HnswIndexParam, IVFIndexParam, FlatIndexParam, DataType, IndexType, QuantizeType, ) SUPPORT_SCALAR_DATA_TYPES = [ DataType.BOOL, DataType.FLOAT, DataType.DOUBLE, DataType.INT32, DataType.INT64, DataType.UINT32, DataType.UINT64, DataType.STRING, DataType.ARRAY_BOOL, DataType.ARRAY_FLOAT, DataType.ARRAY_DOUBLE, DataType.ARRAY_INT32, DataType.ARRAY_INT64, DataType.ARRAY_UINT32, DataType.ARRAY_UINT64, DataType.ARRAY_STRING, ] DEFAULT_SCALAR_FIELD_NAME = { DataType.BOOL: "bool_field", DataType.FLOAT: "float_field", DataType.DOUBLE: "double_field", DataType.INT32: "int32_field", DataType.INT64: "int64_field", DataType.UINT32: "uint32_field", DataType.UINT64: "uint64_field", DataType.STRING: "string_field", DataType.ARRAY_BOOL: "array_bool_field", DataType.ARRAY_FLOAT: "array_float_field", DataType.ARRAY_DOUBLE: "array_double_field", DataType.ARRAY_INT32: "array_int32_field", DataType.ARRAY_INT64: "array_int64_field", DataType.ARRAY_UINT32: "array_uint32_field", DataType.ARRAY_UINT64: "array_uint64_field", DataType.ARRAY_STRING: "array_string_field", } SUPPORT_SCALAR_INDEX_TYPES = [ IndexType.INVERT, ] SUPPORT_VECTOR_DATA_TYPES = [ DataType.VECTOR_FP16, DataType.VECTOR_FP32, DataType.VECTOR_INT8, DataType.SPARSE_VECTOR_FP32, DataType.SPARSE_VECTOR_FP16, ] SUPPORT_VECTOR_INDEX_TYPES = [ IndexType.FLAT, IndexType.HNSW, IndexType.IVF, ] DEFAULT_VECTOR_FIELD_NAME = { DataType.VECTOR_FP16: "vector_fp16_field", DataType.VECTOR_FP32: "vector_fp32_field", DataType.VECTOR_INT8: "vector_int8_field", DataType.SPARSE_VECTOR_FP32: "sparse_vector_fp32_field", 
DataType.SPARSE_VECTOR_FP16: "sparse_vector_fp16_field", } DEFAULT_VECTOR_DIMENSION = 128 VECTOR_DIMENSION_1024 = 4 SUPPORT_VECTOR_DATA_TYPE_INDEX_MAP = { DataType.VECTOR_FP16: [IndexType.FLAT, IndexType.HNSW, IndexType.IVF], DataType.VECTOR_FP32: [IndexType.FLAT, IndexType.HNSW, IndexType.IVF], DataType.VECTOR_INT8: [IndexType.FLAT, IndexType.HNSW], DataType.SPARSE_VECTOR_FP32: [IndexType.FLAT, IndexType.HNSW], DataType.SPARSE_VECTOR_FP16: [IndexType.FLAT, IndexType.HNSW], } SUPPORT_VECTOR_DATA_TYPE_INDEX_MAP_PARAMS = [ (data_type, index_type) for data_type, index_types in SUPPORT_VECTOR_DATA_TYPE_INDEX_MAP.items() for index_type in index_types ] DEFAULT_INDEX_PARAMS = { IndexType.FLAT: FlatIndexParam(), IndexType.HNSW: HnswIndexParam(), IndexType.IVF: IVFIndexParam(), IndexType.INVERT: InvertIndexParam(), } SUPPORT_VECTOR_DATA_TYPE_QUANT_MAP = { DataType.VECTOR_FP32: [QuantizeType.FP16, QuantizeType.INT8, QuantizeType.INT4], DataType.SPARSE_VECTOR_FP32: [QuantizeType.FP16], } SUPPORT_ADD_COLUMN_DATA_TYPE = [ DataType.INT32, DataType.UINT32, DataType.INT64, DataType.UINT64, DataType.FLOAT, DataType.DOUBLE, ] NOT_SUPPORT_ADD_COLUMN_DATA_TYPE = [ DataType.BOOL, DataType.STRING, DataType.ARRAY_BOOL, DataType.ARRAY_INT32, DataType.ARRAY_INT64, DataType.ARRAY_UINT32, DataType.ARRAY_UINT64, DataType.ARRAY_FLOAT, DataType.ARRAY_DOUBLE, DataType.ARRAY_STRING, ] ================================================ FILE: python/tests/detail/test_collection_concurrency.py ================================================ # Copyright 2025-present the zvec project # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import logging import pytest import threading import numpy as np import zvec from zvec import ( CollectionOption, InvertIndexParam, HnswIndexParam, Collection, Doc, DataType, FieldSchema, VectorSchema, ) class TestCollectionConcurrency: @pytest.fixture(scope="function") def test_collection(self, tmp_path_factory): """Fixture to create a test collection""" collection_schema = zvec.CollectionSchema( name="test_collection", fields=[ FieldSchema( "id", DataType.INT64, nullable=False, index_param=InvertIndexParam(enable_range_optimization=True), ), FieldSchema( "name", DataType.STRING, nullable=False, index_param=InvertIndexParam(), ), FieldSchema("weight", DataType.FLOAT, nullable=True), ], vectors=[ VectorSchema( "dense", DataType.VECTOR_FP32, dimension=128, index_param=HnswIndexParam(), ), VectorSchema( "sparse", DataType.SPARSE_VECTOR_FP32, index_param=HnswIndexParam() ), ], ) collection_option = CollectionOption(read_only=False, enable_mmap=True) temp_dir = tmp_path_factory.mktemp("zvec") collection_path = temp_dir / "test_collection" coll = zvec.create_and_open( path=str(collection_path), schema=collection_schema, option=collection_option, ) assert coll is not None, "Failed to create and open collection" yield coll # Clean up if hasattr(coll, "destroy") and coll is not None: try: coll.destroy() except Exception as e: print(f"Warning: failed to destroy collection: {e}") def test_concurrent_read_write(self, test_collection: Collection): results = [] def insert_docs(thread_id): try: docs = [ Doc( id=f"{thread_id}_{i}", fields={ "id": int(f"{thread_id}{i}"), "name": 
f"thread_{thread_id}_doc_{i}", "weight": float(i), }, vectors={ "dense": np.random.random(128).tolist(), "sparse": {1: float(i), 2: float(i * 2)}, }, ) for i in range(5) ] result = test_collection.insert(docs) results.append((thread_id, "insert", len(result))) except Exception as e: results.append((thread_id, "insert_exception", str(e))) def query_docs(thread_id): try: result = test_collection.query(filter="id > 0", topk=10) results.append((thread_id, "query", len(result))) except Exception as e: results.append((thread_id, "query_exception", str(e))) # Create threads for concurrent operations threads = [] # Start insert threads for i in range(3): thread = threading.Thread(target=insert_docs, args=(i,)) threads.append(thread) thread.start() # Start query threads for i in range(3): thread = threading.Thread(target=query_docs, args=(i,)) threads.append(thread) thread.start() # Wait for all threads to complete for thread in threads: thread.join() # Analyze results insert_results = [r for r in results if r[1] == "insert"] query_results = [r for r in results if r[1] == "query"] logging.info( f"Concurrent read/write results - Inserts: {len(insert_results)}, Queries: {len(query_results)}" ) # At least some operations should succeed assert len(insert_results) + len(query_results) > 0 def test_concurrent_query(self, test_collection: Collection): # First insert some data docs = [ Doc( id=f"{i}", fields={"id": i, "name": f"test_{i}", "weight": float(i)}, vectors={ "dense": np.random.random(128).tolist(), "sparse": {1: float(i), 2: float(i * 2)}, }, ) for i in range(20) ] insert_result = test_collection.insert(docs) assert len(insert_result) == 20 results = [] def query_operation(thread_id): """Perform query operation from a thread""" try: result = test_collection.query(filter=f"id > {thread_id}", topk=5) results.append((thread_id, "query", len(result))) except Exception as e: results.append((thread_id, "query_exception", str(e))) # Create multiple threads for concurrent 
queries threads = [] for i in range(5): thread = threading.Thread(target=query_operation, args=(i,)) threads.append(thread) thread.start() # Wait for all threads to complete for thread in threads: thread.join() # Analyze results query_results = [r for r in results if r[1] == "query"] logging.info(f"Concurrent query results - Queries: {len(query_results)}") # All query operations should succeed assert len(query_results) == 5 def test_concurrent_modifications(self, test_collection: Collection): # First insert some data docs = [ Doc( id=f"{i}", fields={"id": i, "name": f"test_{i}", "weight": float(i)}, vectors={ "dense": np.random.random(128).tolist(), "sparse": {1: float(i), 2: float(i * 2)}, }, ) for i in range(10) ] insert_result = test_collection.insert(docs) assert len(insert_result) == 10 results = [] def update_operation(thread_id): """Perform update operation from a thread""" try: # Each thread updates different documents update_docs = [ Doc( id=f"{i}", fields={ "id": i, "name": f"updated_by_thread_{thread_id}", "weight": float(i + thread_id), }, vectors={ "dense": np.random.random(128).tolist(), "sparse": {1: float(i) + 0.5, 2: float(i * 2) + 0.5}, }, ) for i in range(thread_id * 2, thread_id * 2 + 2) ] result = test_collection.update(update_docs) results.append((thread_id, "update", len(result))) except Exception as e: results.append((thread_id, "update_exception", str(e))) def delete_operation(thread_id): """Perform delete operation from a thread""" try: # Each thread deletes different documents delete_ids = [f"{thread_id * 2 + 2}", f"{thread_id * 2 + 3}"] result = test_collection.delete(delete_ids) results.append((thread_id, "delete", len(result))) except Exception as e: results.append((thread_id, "delete_exception", str(e))) # Create threads for concurrent operations threads = [] # Start update threads for i in range(3): thread = threading.Thread(target=update_operation, args=(i,)) threads.append(thread) thread.start() # Start delete threads for i in 
range(2): thread = threading.Thread(target=delete_operation, args=(i,)) threads.append(thread) thread.start() # Wait for all threads to complete for thread in threads: thread.join() # Analyze results update_results = [r for r in results if r[1] == "update"] delete_results = [r for r in results if r[1] == "delete"] logging.info( f"Concurrent modification results - Updates: {len(update_results)}, Deletes: {len(delete_results)}" ) # At least some operations should succeed assert len(update_results) + len(delete_results) > 0 def test_read_write_locking(self, test_collection: Collection): # Perform operations that should be thread-safe docs = [ Doc( id=f"{i}", fields={"id": i, "name": f"test_{i}", "weight": float(i)}, vectors={ "dense": np.random.random(128).tolist(), "sparse": {1: float(i), 2: float(i * 2)}, }, ) for i in range(5) ] # Insert data insert_result = test_collection.insert(docs) assert len(insert_result) == 5 # Concurrent operations should not cause data corruption results = [] def mixed_operation(thread_id): """Perform mixed operations from a thread""" try: # Mix of read and write operations if thread_id % 2 == 0: # Read operation result = test_collection.fetch([f"{thread_id % 5}"]) results.append((thread_id, "read", len(result))) else: # Write operation doc = Doc( id=f"{thread_id % 5}", fields={ "id": thread_id % 5, "name": f"mixed_op_{thread_id}", "weight": float(thread_id), }, vectors={ "dense": np.random.random(128).tolist(), "sparse": {1: float(thread_id), 2: float(thread_id * 2)}, }, ) result = test_collection.upsert(doc) results.append((thread_id, "write", len(result))) except Exception as e: results.append((thread_id, "exception", str(e))) # Create multiple threads threads = [] for i in range(10): thread = threading.Thread(target=mixed_operation, args=(i,)) threads.append(thread) thread.start() # Wait for all threads to complete for thread in threads: thread.join() # Verify that the collection is still in a consistent state final_result = 
test_collection.query() assert len(final_result) >= 0 # Should not crash or return corrupted data def test_race_condition_detection(self, test_collection: Collection): # Insert initial data docs = [ Doc( id=f"{i}", fields={"id": i, "name": f"initial_{i}", "weight": float(i)}, vectors={ "dense": np.random.random(128).tolist(), "sparse": {1: float(i), 2: float(i * 2)}, }, ) for i in range(10) ] insert_result = test_collection.insert(docs) assert len(insert_result) == 10 # Perform many rapid concurrent operations operation_count = 100 results = [] def rapid_operation(op_id): """Perform rapid operations""" try: # Alternate between different types of operations if op_id % 4 == 0: # Insert doc = Doc( id=f"rapid_{op_id}", fields={ "id": op_id, "name": f"rapid_{op_id}", "weight": float(op_id), }, vectors={ "dense": np.random.random(128).tolist(), "sparse": {1: float(op_id), 2: float(op_id * 2)}, }, ) result = test_collection.insert(doc) results.append(("insert", len(result))) elif op_id % 4 == 1: # Update doc = Doc( id=f"{op_id % 10}", fields={ "id": op_id % 10, "name": f"rapid_update_{op_id}", "weight": float(op_id), }, vectors={ "dense": np.random.random(128).tolist(), "sparse": {1: float(op_id), 2: float(op_id * 2)}, }, ) result = test_collection.update(doc) results.append(("update", len(result))) elif op_id % 4 == 2: # Query result = test_collection.query(filter=f"id > {op_id % 5}", topk=3) results.append(("query", len(result))) else: # Fetch result = test_collection.fetch([f"{op_id % 10}"]) results.append(("fetch", len(result))) except Exception as e: results.append(("exception", str(e))) # Create many threads for rapid concurrent operations threads = [] for i in range(operation_count): thread = threading.Thread(target=rapid_operation, args=(i,)) threads.append(thread) thread.start() # Wait for all threads to complete for thread in threads: thread.join() # Verify collection is still functional final_query = test_collection.query() assert len(final_query) >= 0 # Should 
not be corrupted logging.info( f"Rapid concurrent operations completed - Total operations: {len(results)}" ) ================================================ FILE: python/tests/detail/test_collection_create_and_open.py ================================================ # Copyright 2025-present the zvec project # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import threading import os from distance_helper import * from fixture_helper import * from doc_helper import * from params_helper import * def check_collection_info( coll: Collection, schema: CollectionSchema, option: CollectionOption, path: str ): assert coll is not None, "Failed to create and open collection" assert coll.path == path assert coll.schema.name == schema.name assert list(coll.schema.fields) == list(schema.fields) assert list(coll.schema.vectors) == list(schema.vectors) assert coll.option.read_only == option.read_only assert coll.option.enable_mmap == option.enable_mmap def check_collection_basic(coll: Collection, optimize: bool = False): schema = coll.schema docs = [generate_doc(i, schema) for i in range(10)] results = coll.insert(docs=docs) assert len(results) == len(docs) for result in results: assert result.ok() assert coll.stats.doc_count == len(docs) def check_fetch_query(): results = coll.fetch([str(i) for i in range(len(docs))]) assert len(results) == len(docs) for i in range(len(docs)): assert str(i) in results results = coll.query() assert len(results) == len(docs) check_fetch_query() if optimize: 
coll.optimize() check_fetch_query() def check_collection_full(coll: Collection): test_doc = generate_doc(1, coll.schema) insert_result = coll.insert(test_doc) assert insert_result.ok() stats = coll.stats assert stats.doc_count == 1 fetched_docs = coll.fetch(ids=["1"]) assert len(fetched_docs) == 1 assert "1" in fetched_docs assert fetched_docs["1"] is not None assert is_doc_equal(fetched_docs["1"], test_doc, coll.schema) query_result = coll.query() assert len(query_result) == 1 updated_doc = Doc( id="1", fields={"int32_field": 1}, vectors={"vector_fp32_field": [0.2] * 128}, ) update_result = coll.update(updated_doc) assert update_result.ok() upserted_doc = generate_doc(1, coll.schema) upsert_result = coll.upsert(upserted_doc) assert upsert_result.ok() # 8. Delete document delete_result = coll.delete("1") assert delete_result.ok() # Verify document was deleted stats = coll.stats assert stats.doc_count == 0 valid_collection_options = [ # (read_only, enable_mmap) (False, True), (False, False), ] invalid_collection_options = [ # (read_only, enable_mmap) (True, True), (True, False), ] duplicate_names_test = [ ("field1", "field1", "vector1", "vector2"), ("field1", "field2", "vector1", "vector1"), ( "shared_name1", "shared_name2", "shared_name1", "shared_name2", ), ] long_names = [ "a" * 100, # 100 characters "b" * 200, # 200 characters ] valid_path_list = [ "/tmp/nonexistent/directory/test_collection", "test/collection/with/slashes", ] invalid_path_list = [ "invalid:path", "", "test_collection_with_spaces ", "test@#$%collection", ] class TestCreateAndOpen: @pytest.mark.parametrize("collection_name", COLLECTION_NAME_VALID_LIST) def test_valid_collection_name( self, collection_temp_dir, collection_name, collection_option, sample_field_list, sample_vector_list, ): collection_schema = zvec.CollectionSchema( name=collection_name, fields=sample_field_list, vectors=sample_vector_list, ) coll = zvec.create_and_open( path=collection_temp_dir, schema=collection_schema, 
option=collection_option, ) check_collection_info( coll, collection_schema, collection_option, collection_temp_dir ) check_collection_basic(coll) coll.destroy() @pytest.mark.parametrize("collection_name", COLLECTION_NAME_INVALID_LIST) def test_invalid_collection_name( self, collection_temp_dir, collection_name, collection_option, sample_field_list, sample_vector_list, ): with pytest.raises(Exception) as exc_info: collection_schema = zvec.CollectionSchema( name=collection_name, fields=sample_field_list, vectors=sample_vector_list, ) coll = zvec.create_and_open( path=collection_temp_dir, schema=collection_schema, option=collection_option, ) assert SCHEMA_VALIDATE_ERROR_MSG in str(exc_info.value), str(exc_info.value) @pytest.mark.parametrize("name_prefix", FIELD_NAME_VALID_LIST) def test_valid_field_vector_name( self, collection_temp_dir, collection_option, name_prefix, sample_field_list, sample_vector_list, ): collection_schema = zvec.CollectionSchema( name="test_collection", fields=sample_field_list, vectors=sample_vector_list, ) coll = zvec.create_and_open( path=collection_temp_dir, schema=collection_schema, option=collection_option, ) check_collection_info( coll, collection_schema, collection_option, collection_temp_dir ) check_collection_basic(coll) coll.destroy() @pytest.mark.parametrize("field_name", FIELD_NAME_INVALID_LIST) def test_invalid_field_name( self, collection_temp_dir, collection_option, field_name ): with pytest.raises(Exception) as exc_info: field_list = [FieldSchema(field_name, DataType.STRING)] vector_list = [ VectorSchema( "dense", DataType.VECTOR_FP32, dimension=128, index_param=HnswIndexParam(), ) ] collection_schema = zvec.CollectionSchema( name="collection_name", fields=field_list, vectors=vector_list ) coll = zvec.create_and_open( path=collection_temp_dir, schema=collection_schema, option=collection_option, ) assert SCHEMA_VALIDATE_ERROR_MSG in str(exc_info.value), str(exc_info.value) @pytest.mark.parametrize("vector_name", 
FIELD_NAME_INVALID_LIST) def test_invalid_vector_name( self, collection_temp_dir, collection_option, vector_name ): with pytest.raises(Exception) as exc_info: field_list = [ FieldSchema( "id", DataType.INT64, nullable=False, index_param=InvertIndexParam(enable_range_optimization=True), ) ] vector_list = [ VectorSchema(vector_name, DataType.VECTOR_FP32, dimension=128) ] collection_schema = zvec.CollectionSchema( name="collection_name", fields=field_list, vectors=vector_list ) coll = zvec.create_and_open( path=collection_temp_dir, schema=collection_schema, option=collection_option, ) assert SCHEMA_VALIDATE_ERROR_MSG in str(exc_info.value), str(exc_info.value) @pytest.mark.parametrize( "field_list_len,vector_list_len,dimension", FIELD_VECTOR_LIST_DIMENSION_VALID_LIST, ) def test_valid_field_vector_size_dimension( self, collection_temp_dir, collection_option, field_list_len, vector_list_len, dimension, ): field_list = [] vector_list = [] for i in range(0, field_list_len): field_list.append( FieldSchema("id_" + str(i), DataType.INT64, nullable=True) ) for i in range(0, vector_list_len): vector_list.append( VectorSchema( "dense_vector_" + str(i), DataType.VECTOR_FP32, dimension=dimension, index_param=HnswIndexParam(), ) ) collection_schema = zvec.CollectionSchema( name="test_dense_vector_list", fields=field_list, vectors=vector_list ) coll = zvec.create_and_open( path=collection_temp_dir, schema=collection_schema, option=collection_option, ) check_collection_info( coll, collection_schema, collection_option, collection_temp_dir ) check_collection_basic(coll) coll.destroy() @pytest.mark.parametrize( "field_list_len,vector_list_len,dimension", FIELD_VECTOR_LIST_DIMENSION_INVALID_LIST, ) def test_invalid_field_vector_size_dimension( self, collection_temp_dir, collection_option, vector_list_len, field_list_len, dimension, ): with pytest.raises(Exception) as exc_info: field_list = [] vector_list = [] for i in range(0, field_list_len): field_list.append( FieldSchema( "id_" + 
str(i), DataType.INT64, nullable=False, ) ) for i in range(0, vector_list_len): vector_list.append( VectorSchema( "dense_vector_" + str(i), DataType.VECTOR_FP32, dimension=dimension, index_param=HnswIndexParam(), ) ) collection_schema = zvec.CollectionSchema( name="test_dense_vector_list", fields=field_list, vectors=vector_list ) coll = zvec.create_and_open( path=collection_temp_dir, schema=collection_schema, option=collection_option, ) assert SCHEMA_VALIDATE_ERROR_MSG in str(exc_info.value), str(exc_info.value) def test_valid_single_vector_field_construction( self, collection_temp_dir, collection_option ): field = FieldSchema( "id", DataType.INT64, nullable=True, index_param=InvertIndexParam(enable_range_optimization=True), ) vector = VectorSchema( "dense_vector", DataType.VECTOR_FP32, dimension=128, index_param=HnswIndexParam(), ) collection_schema = zvec.CollectionSchema( name="test_single_dense_vector_non_list", fields=field, vectors=vector, # Non-list form ) coll = zvec.create_and_open( path=collection_temp_dir, schema=collection_schema, option=collection_option, ) check_collection_info( coll, collection_schema, collection_option, collection_temp_dir ) check_collection_basic(coll) coll.destroy() def test_collection_concurrent_create( self, collection_temp_dir, basic_schema, collection_option ): results = [] errors = [] lock = threading.Lock() # Function to be executed by each thread def create_collection_thread(thread_id): try: coll = zvec.create_and_open( path=collection_temp_dir, schema=basic_schema, option=collection_option, ) with lock: results.append((thread_id, coll)) except Exception as e: with lock: errors.append((thread_id, str(e))) threads = [] for i in range(5): thread = threading.Thread(target=create_collection_thread, args=(i,)) threads.append(thread) thread.start() for thread in threads: thread.join() assert len(results) == 1, ( f"Expected exactly one successful creation, but got {len(results)}" ) assert len(errors) == 4, ( f"Expected exactly 
four failures, but got {len(errors)}" ) successful_thread_id, successful_collection = results[0] assert successful_collection is not None, ( "Successful creation should return a valid collection" ) assert successful_collection.path == collection_temp_dir, ( "Collection path mismatch" ) def test_create_open_loop( self, collection_temp_dir, collection_option, full_schema ): for cycle in range(10): coll = zvec.create_and_open( path=collection_temp_dir, schema=full_schema, option=collection_option, ) assert coll is not None, ( f"Failed to create and open collection in cycle {cycle}" ) assert coll.path == collection_temp_dir, ( f"Collection path mismatch in cycle {cycle}" ) del coll reopened_coll = zvec.open( path=collection_temp_dir, option=collection_option ) assert reopened_coll is not None, ( f"Failed to reopen collection in cycle {cycle}" ) assert reopened_coll.path == collection_temp_dir, ( f"Reopened collection path mismatch in cycle {cycle}" ) check_collection_full(reopened_coll) reopened_coll.destroy() @pytest.mark.parametrize( "data_type, index_param", VALID_VECTOR_DATA_TYPE_INDEX_PARAM_MAP_PARAMS ) def test_valid_vector_index_params( self, data_type, index_param, single_vector_schema_with_index_param, collection_temp_dir, collection_option, ): coll = zvec.create_and_open( path=collection_temp_dir, schema=single_vector_schema_with_index_param, option=collection_option, ) check_collection_info( coll, single_vector_schema_with_index_param, collection_option, collection_temp_dir, ) check_collection_basic(coll, True) @pytest.mark.parametrize( "data_type, index_param", INVALID_VECTOR_DATA_TYPE_INDEX_PARAM_MAP_PARAMS ) def test_invalid_vector_index_params( self, data_type, index_param, single_vector_schema_with_index_param, collection_temp_dir, collection_option, ): with pytest.raises(Exception) as exc_info: coll = zvec.create_and_open( path=collection_temp_dir, schema=single_vector_schema_with_index_param, option=collection_option, ) assert 
SCHEMA_VALIDATE_ERROR_MSG in str(exc_info.value), str(exc_info.value) def test_open_concurrent_same_path(self, tmp_path_factory, collection_option): """Test concurrent opening of the same collection path. - Multi-threading concurrency: 5 threads simultaneously open the same collection - Result verification: Verify that only one can open successfully, others must fail """ # Create a temporary directory and path for the collection temp_dir = tmp_path_factory.mktemp("zvec") collection_path = temp_dir / "concurrent_open_test_collection" # First, create a collection that we'll try to open concurrently field_list = [ FieldSchema( "id", DataType.INT64, nullable=False, index_param=InvertIndexParam(enable_range_optimization=True), ), FieldSchema( "name", DataType.STRING, nullable=False, index_param=InvertIndexParam() ), ] vector_list = [ VectorSchema( "dense_vector", DataType.VECTOR_FP32, dimension=128, index_param=HnswIndexParam(), ) ] collection_schema = zvec.CollectionSchema( name="concurrent_open_test_collection", fields=field_list, vectors=vector_list, ) # Create the collection first coll = zvec.create_and_open( path=str(collection_path), schema=collection_schema, option=collection_option, ) # Close the collection so we can test opening it if hasattr(coll, "close") and coll is not None: coll.close() # Shared variables to collect results from threads results = [] errors = [] # Lock for thread-safe operations lock = threading.Lock() # Clean up the created collection reference del coll # Function to be executed by each thread def open_collection_thread(thread_id): try: reopened_coll = zvec.open( path=str(collection_path), option=collection_option ) with lock: results.append((thread_id, reopened_coll)) # Clean up the collection if opened successfully if hasattr(reopened_coll, "close") and reopened_coll is not None: reopened_coll.close() except Exception as e: with lock: errors.append((thread_id, str(e))) # Create and start 5 threads threads = [] for i in range(5): thread = 
threading.Thread(target=open_collection_thread, args=(i,)) threads.append(thread) thread.start() # Wait for all threads to complete for thread in threads: thread.join() # Verify results: # 1. Only one open should succeed (exactly one collection in results) # 2. Others should fail (4 errors in errors) assert len(results) == 1, ( f"Expected exactly one successful open, but got {len(results)}" ) assert len(errors) == 4, ( f"Expected exactly four failures, but got {len(errors)}" ) # Additional verification: check that the successful open has a valid collection successful_thread_id, successful_collection = results[0] assert successful_collection is not None, ( "Successful open should return a valid collection" ) assert successful_collection.path == str(collection_path), ( "Collection path mismatch" ) @pytest.mark.parametrize("read_only,enable_mmap", valid_collection_options) def test_valid_option( self, collection_temp_dir, basic_schema, read_only, enable_mmap ): option = CollectionOption(read_only=read_only, enable_mmap=enable_mmap) coll = zvec.create_and_open( path=collection_temp_dir, schema=basic_schema, option=option, ) check_collection_info(coll, basic_schema, option, collection_temp_dir) check_collection_basic(coll) coll.destroy() def test_valid_none_option(self, collection_temp_dir, basic_schema): zvec.create_and_open( path=collection_temp_dir, schema=basic_schema, option=None, ) @pytest.mark.parametrize("read_only,enable_mmap", invalid_collection_options) def test_invalid_option( self, collection_temp_dir, basic_schema, read_only, enable_mmap ): with pytest.raises(Exception) as exc_info: coll = zvec.create_and_open( path=collection_temp_dir, schema=basic_schema, option=CollectionOption(read_only=read_only, enable_mmap=enable_mmap), ) assert CREATE_READ_ONLY_ERROR_MSG in str(exc_info.value), str(exc_info.value) @pytest.mark.parametrize( "field_name1,field_name2,vector_name1,vector_name2", duplicate_names_test, ) def test_duplicate_field_names( self, 
collection_temp_dir, collection_option, field_name1, field_name2, vector_name1, vector_name2, ): with pytest.raises(Exception) as exc_info: collection_schema = zvec.CollectionSchema( name="test_collection", fields=[ FieldSchema( field_name1, DataType.INT64, nullable=False, index_param=InvertIndexParam(enable_range_optimization=True), ), FieldSchema( field_name2, DataType.INT64, nullable=False, index_param=InvertIndexParam(enable_range_optimization=True), ), ], vectors=[ VectorSchema( vector_name1, DataType.VECTOR_FP32, dimension=128, index_param=HnswIndexParam(), ), VectorSchema( vector_name2, DataType.VECTOR_FP32, dimension=128, index_param=HnswIndexParam(), ), ], ) coll = zvec.create_and_open( path=collection_temp_dir, schema=collection_schema, option=collection_option, ) assert SCHEMA_VALIDATE_ERROR_MSG in str(exc_info.value), str(exc_info.value) @pytest.mark.parametrize("long_name", long_names) def test_invalid_long_field_names( self, collection_option, collection_temp_dir, long_name ): collection_schema = zvec.CollectionSchema( name=long_name, fields=[ FieldSchema( long_name + "_field", DataType.INT64, nullable=False, index_param=InvertIndexParam(enable_range_optimization=True), ), ], vectors=[ VectorSchema( long_name + "_vector", DataType.VECTOR_FP32, dimension=128, index_param=HnswIndexParam(), ) ], ) with pytest.raises(Exception) as exc_info: coll = zvec.create_and_open( path=collection_temp_dir, schema=collection_schema, option=collection_option, ) assert SCHEMA_VALIDATE_ERROR_MSG in str(exc_info.value), str(exc_info.value) def test_invalid_empty_fields_and_vectors( self, collection_temp_dir, collection_option ): collection_schema = zvec.CollectionSchema( name="test_collection", fields=[], # Empty fields vectors=[], # Empty vectors ) with pytest.raises(Exception) as exc_info: coll = zvec.create_and_open( path=collection_temp_dir, schema=collection_schema, option=collection_option, ) assert SCHEMA_VALIDATE_ERROR_MSG in str(exc_info.value), 
str(exc_info.value) @pytest.mark.parametrize("valid_path", valid_path_list) def test_valid_path(self, basic_schema, collection_option, valid_path): if os.path.exists(valid_path): import shutil shutil.rmtree(valid_path) coll = zvec.create_and_open( path=valid_path, schema=basic_schema, option=collection_option ) check_collection_info(coll, basic_schema, collection_option, valid_path) coll.destroy() @pytest.mark.parametrize("invalid_path", invalid_path_list) def test_invalid_path(self, basic_schema, collection_option, invalid_path): with pytest.raises(Exception) as exc_info: coll = zvec.create_and_open( path=invalid_path, schema=basic_schema, option=collection_option ) assert INVALID_PATH_ERROR_MSG in str(exc_info.value), str(exc_info.value) ================================================ FILE: python/tests/detail/test_collection_ddl.py ================================================ # Copyright 2025-present the zvec project # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from distance_helper import * from fixture_helper import * from doc_helper import * from params_helper import * class TestDDL: def test_collection_stats(self, basic_collection: Collection): assert basic_collection.stats is not None stats = basic_collection.stats assert stats.doc_count == 0 assert len(stats.index_completeness) == 2 assert stats.index_completeness["dense"] == 1 assert stats.index_completeness["sparse"] == 1 def test_collection_destroy( self, basic_collection: Collection, collection_temp_dir, collection_option ): doc = generate_doc(1, basic_collection.schema) result = basic_collection.insert(doc) assert bool(result) assert result.ok() stats = basic_collection.stats assert stats is not None assert stats.doc_count == 1 basic_collection.destroy() with pytest.raises(Exception) as exc_info: stats = basic_collection.stats assert ACCESS_DESTROYED_COLLECTION_ERROR_MSG in str(exc_info.value) with pytest.raises(Exception) as exc_info: zvec.open(path=collection_temp_dir, option=collection_option) assert COLLECTION_PATH_NOT_EXIST_ERROR_MSG in str(exc_info.value) def test_collection_flush(self, basic_collection: Collection): doc = generate_doc(1, basic_collection.schema) result = basic_collection.insert(doc) assert bool(result) assert result.ok() basic_collection.flush() fetched_docs = basic_collection.fetch(["1"]) assert "1" in fetched_docs assert fetched_docs["1"].id == "1" class TestIndexDDL: @pytest.mark.parametrize("field_name", DEFAULT_SCALAR_FIELD_NAME.values()) @pytest.mark.parametrize("index_type", SUPPORT_SCALAR_INDEX_TYPES) def test_scalar_index_operation( self, full_collection: Collection, field_name: str, index_type: IndexType, ): # INSERT 0~5 Doc docs = [generate_doc(i, full_collection.schema) for i in range(5)] result = full_collection.insert(docs) assert len(result) == 5 for item in result: assert item.ok() stats = full_collection.stats assert stats is not None assert stats.doc_count == 5 if field_name in ["bool_field"]: query_filter = 
f"{field_name} = true" elif field_name in ["double_field", "float_field"]: query_filter = f"{field_name} >= 3.0" elif field_name in [ "int32_field", "int64_field", "uint32_field", "uint64_field", ]: query_filter = f"{field_name} >= 30" elif field_name in ["string_field"]: query_filter = f"{field_name} >= 'test_3'" elif field_name in ["array_bool_field"]: query_filter = f"{field_name} contain_any (false)" elif field_name in ["array_double_field", "array_float_field"]: query_filter = f"{field_name} contain_any (3.0, 4.0)" elif field_name in [ "array_int64_field", "array_int32_field", "array_uint64_field", "array_uint32_field", ]: query_filter = f"{field_name} contain_any (3, 4)" elif field_name == "array_string_field": query_filter = f"{field_name} contain_any ('test_3', 'test_4')" else: assert False, f"Unsupported field type for index creation: {field_name}" query_result_before = full_collection.query(filter=query_filter, topk=10) if index_type not in DEFAULT_INDEX_PARAMS: pytest.fail(f"Unsupported index type for index creation: {index_type}") index_param = DEFAULT_INDEX_PARAMS[index_type] full_collection.create_index( field_name=field_name, index_param=index_param, option=IndexOption() ) stats_after_create = full_collection.stats assert stats_after_create is not None assert stats_after_create.doc_count == 5 query_result_after = full_collection.query(filter=query_filter, topk=10) assert len(query_result_before) == len(query_result_after), ( f"Query result count mismatch for {field_name} with index type {index_type}: before={len(query_result_before)}, after={len(query_result_after)}" ) before_ids = set(doc.id for doc in query_result_before) after_ids = set(doc.id for doc in query_result_after) assert before_ids == after_ids, ( f"Query result IDs mismatch for {field_name} with index type {index_type}: before={before_ids}, after={after_ids}" ) # INSERT 5~8 Doc new_docs = [generate_doc(i, full_collection.schema) for i in range(5, 8)] result = 
full_collection.insert(new_docs) assert len(result) == 3 for item in result: assert item.ok() stats_after_insert1 = full_collection.stats assert stats_after_insert1 is not None assert stats_after_insert1.doc_count == 8 fetched_docs = full_collection.fetch([f"{i}" for i in range(5, 8)]) assert len(fetched_docs) == 3 for i in range(5, 8): doc_id = f"{i}" assert doc_id in fetched_docs query_result = full_collection.query(filter=query_filter, topk=20) assert len(query_result) >= len(query_result_before) full_collection.drop_index(field_name=field_name) # Insert 8~10 Doc more_docs = [generate_doc(i, full_collection.schema) for i in range(8, 10)] result = full_collection.insert(more_docs) assert len(result) == 2 for item in result: assert item.ok() stats_after_insert2 = full_collection.stats assert stats_after_insert2 is not None assert stats_after_insert2.doc_count == 10 fetched_docs = full_collection.fetch([f"{i}" for i in range(8, 10)]) assert len(fetched_docs) == 2 for i in range(8, 10): doc_id = f"{i}" assert doc_id in fetched_docs query_result = full_collection.query(filter=query_filter, topk=20) assert len(query_result) >= len(query_result_before) final_stats = full_collection.stats assert final_stats is not None assert final_stats.doc_count == 10 full_collection.destroy() @pytest.mark.parametrize("field_name", DEFAULT_SCALAR_FIELD_NAME.values()) @pytest.mark.parametrize("index_type", SUPPORT_SCALAR_INDEX_TYPES) def test_duplicate_create_index( self, full_collection: Collection, field_name: str, index_type: IndexType ): docs = [generate_doc(i, full_collection.schema) for i in range(10)] result = full_collection.insert(docs) assert bool(result) for item in result: assert item.ok() stats = full_collection.stats assert stats is not None assert stats.doc_count == 10 if field_name in ["bool_field"]: query_filter = f"{field_name} = true" elif field_name in ["double_field", "float_field"]: query_filter = f"{field_name} >= 3.0" elif field_name in [ "int32_field", 
"int64_field", "uint32_field", "uint64_field", ]: query_filter = f"{field_name} >= 30" elif field_name in ["string_field"]: query_filter = f"{field_name} >= 'test_3'" elif field_name in ["array_bool_field"]: query_filter = f"{field_name} contain_any (false)" elif field_name in ["array_double_field", "array_float_field"]: query_filter = f"{field_name} contain_any (3.0, 4.0)" elif field_name in [ "array_int64_field", "array_int32_field", "array_uint64_field", "array_uint32_field", ]: query_filter = f"{field_name} contain_any (3, 4)" elif field_name == "array_string_field": query_filter = f"{field_name} contain_any ('test_3', 'test_4')" else: assert False, f"Unsupported field type for index creation: {field_name}" query_result_before = full_collection.query(filter=query_filter, topk=5) if index_type not in DEFAULT_INDEX_PARAMS: pytest.fail(f"Unsupported index type for index creation: {index_type}") index_param = DEFAULT_INDEX_PARAMS[index_type] full_collection.create_index( field_name=field_name, index_param=index_param, option=IndexOption() ) query_result_after = full_collection.query(filter=query_filter, topk=5) assert len(query_result_before) == len(query_result_after), ( f"Query result count mismatch: before={len(query_result_before)}, after={len(query_result_after)}" ) before_ids = set(doc.id for doc in query_result_before) after_ids = set(doc.id for doc in query_result_after) assert before_ids == after_ids, ( f"Query result IDs mismatch: before={before_ids}, after={after_ids}" ) full_collection.create_index( field_name=field_name, index_param=index_param, option=IndexOption() ) def test_optimize(self, full_collection: Collection): docs = [generate_doc(i, full_collection.schema) for i in range(10)] result = full_collection.insert(docs) assert bool(result) for item in result: assert item.ok() stats = full_collection.stats assert stats is not None assert stats.doc_count == 10 full_collection.optimize(option=OptimizeOption()) fetched_docs = 
full_collection.fetch(["1"]) assert "1" in fetched_docs assert fetched_docs["1"].id == "1" @pytest.mark.parametrize( "vector_type, index_type", SUPPORT_VECTOR_DATA_TYPE_INDEX_MAP_PARAMS ) def test_vector_index_operation( self, full_collection: Collection, vector_type: DataType, index_type: IndexType, ): vector_field_name = DEFAULT_VECTOR_FIELD_NAME[vector_type] docs = [generate_doc(i, full_collection.schema) for i in range(5)] result = full_collection.insert(docs) assert len(result) == 5, ( f"Expected 5 insertion results, got {len(result)} for vector type {vector_type} and index type {index_type}" ) for i, item in enumerate(result): assert item.ok(), ( f"Before create_index,result={result},Insertion result {i} is not OK for vector type {vector_type} and index type {index_type} and result={result}" ) stats = full_collection.stats assert stats is not None, ( f"stats is None for vector type {vector_type} and index type {index_type}" ) assert stats.doc_count == 5, ( f"doc_count!=5 for vector type {vector_type} and index type {index_type}" ) if index_type not in DEFAULT_INDEX_PARAMS: pytest.fail( f"Unsupported index type {index_type} for vector type {vector_type} in test_vector_all_data_types_index_create_drop_validation" ) index_param = DEFAULT_INDEX_PARAMS[index_type] full_collection.create_index( field_name=vector_field_name, index_param=index_param, option=IndexOption(), ) stats_after_create = full_collection.stats assert stats_after_create is not None, ( f"stats_after_create_index is None for vector type {vector_type} and index type {index_type}" ) new_docs = [generate_doc(i, full_collection.schema) for i in range(5, 8)] result = full_collection.insert(new_docs) assert len(result) == 3, ( f"Expected 3 insertion results, got {len(result)} for vector type {vector_type} and index type {index_type}" ) for i, item in enumerate(result): assert item.ok(), ( f"Before drop_index,result={result},BInsertion result {i} is not OK for vector type {vector_type} and index type 
{index_type} and " ) stats_after_insert1 = full_collection.stats assert stats_after_insert1 is not None, ( f"stats_after_insert1 is None for vector type {vector_type} and index type {index_type}" ) assert stats_after_insert1.doc_count == 8, ( f"Expected 8 documents, got {stats_after_insert1.doc_count} for vector type {vector_type} and index type {index_type}" ) fetched_docs = full_collection.fetch([f"{i}" for i in range(5, 8)]) assert len(fetched_docs) == 3, ( f"Expected 3 fetched documents, got {len(fetched_docs)} for vector type {vector_type} and index type {index_type}" ) for i in range(5, 8): doc_id = f"{i}" assert doc_id in fetched_docs, ( f"Document ID {doc_id} not found in fetched results for vector type {vector_type} and index type {index_type}" ) assert fetched_docs[doc_id].id == doc_id, ( f"Document {doc_id} has incorrect ID field value for vector type {vector_type} and index type {index_type}" ) full_collection.drop_index(field_name=vector_field_name) more_docs = [generate_doc(i, full_collection.schema) for i in range(8, 10)] result = full_collection.insert(more_docs) assert len(result) == 2, ( f"Expected 2 insertion results, got {len(result)} for vector type {vector_type} and index type {index_type}" ) for i, item in enumerate(result): assert item.ok(), ( f"After drop_index,Insertion result {i} is not OK for vector type {vector_type} and index type {index_type} and result={result}" ) # Verify document count after second insertion stats_after_insert2 = full_collection.stats assert stats_after_insert2 is not None, ( f"stats_after_insert2 is None for vector type {vector_type} and index type {index_type}" ) assert stats_after_insert2.doc_count == 10, ( f"Expected 10 documents, got {stats_after_insert2.doc_count} for vector type {vector_type} and index type {index_type}" ) # Fetch data fetched_docs = full_collection.fetch([f"{i}" for i in range(8, 10)]) assert len(fetched_docs) == 2, ( f"Expected 2 fetched documents, got {len(fetched_docs)} for vector type 
{vector_type} and index type {index_type}" ) # Verify fetched documents have correct data for i in range(8, 10): doc_id = f"{i}" assert doc_id in fetched_docs, ( f"Document ID {doc_id} not found in fetched results for vector type {vector_type} and index type {index_type}" ) assert fetched_docs[doc_id].id == doc_id, ( f"Document {doc_id} has incorrect ID field value for vector type {vector_type} and index type {index_type}" ) # Final verification final_stats = full_collection.stats assert final_stats is not None, ( f"final_stats is None for vector type {vector_type} and index type {index_type}" ) assert final_stats.doc_count == 10, ( f"Expected 10 documents, got {final_stats.doc_count} for vector type {vector_type} and index type {index_type}" ) full_collection.destroy() @staticmethod def create_collection( collection_path, collection_option: CollectionOption ) -> Collection: schema = CollectionSchema( name="test_collection_invalid_vector_index", fields=[ FieldSchema( "id", DataType.INT64, nullable=False, index_param=InvertIndexParam(enable_range_optimization=True), ), FieldSchema( "name", DataType.STRING, nullable=True, index_param=InvertIndexParam(), ), ], vectors=[ VectorSchema( "dense", DataType.VECTOR_FP32, dimension=128, index_param=HnswIndexParam(), ), ], ) coll = zvec.create_and_open( path=collection_path, schema=schema, option=collection_option ) assert coll is not None, "Failed to create and open collection" return coll @staticmethod def check_error_message(exc_info, invalid_name): if type(invalid_name) is str: assert INDEX_NON_EXISTENT_COLUMN_ERROR_MSG in str(exc_info.value), ( "Error message is unreasonable: e=" + str(exc_info.value) ) else: assert INCOMPATIBLE_FUNCTION_ERROR_MSG in str(exc_info.value), ( "Error message is unreasonable: e=" + str(exc_info.value) ) @pytest.mark.parametrize( "invalid_field_name,invalid_vector_name", [ ("", ""), # Empty string (" ", " "), # Space only ("v" * 33, "v" * 33), # Too long (33 characters, exceeds 32) ("vector 
name", "vector_name"),  # Contains space
            ("vector@name", "vector@name"),  # Contains special character
            ("vector/name", "vector/name"),  # Contains slash
            ("vector\\name", "vector\\name"),  # Contains backslash
            ("vector.name", "vector.name"),  # Contains dot
            ("vector$data", "vector$data"),  # Contains dollar sign
            ("vector+name", "vector+name"),  # Contains plus sign
            ("vector=name", "vector=name"),  # Contains equals sign
            (None, None),  # None value,
            (1, 1),
            (1.1, 1.1),
        ],
    )
    def test_invalid_field_and_vector_name(
        self,
        collection_temp_dir,
        collection_option: CollectionOption,
        invalid_field_name: Any,
        invalid_vector_name: Any,
    ):
        """create_index and drop_index must raise for every invalid field/vector name.

        The raised error text is validated by ``self.check_error_message``
        (defined elsewhere in this class). A fresh collection is created for the
        create_index checks and another for the drop_index checks.
        """
        coll = self.create_collection(collection_temp_dir, collection_option)
        with pytest.raises(Exception) as exc_info:
            coll.create_index(
                field_name=invalid_vector_name,
                index_param=HnswIndexParam(),
                option=IndexOption(),
            )
        self.check_error_message(exc_info, invalid_vector_name)
        with pytest.raises(Exception) as exc_info:
            coll.create_index(
                field_name=invalid_field_name,
                index_param=InvertIndexParam(),
                option=IndexOption(),
            )
        self.check_error_message(exc_info, invalid_field_name)
        coll.destroy()

        # Re-create the collection (same directory) for the drop_index checks.
        coll = self.create_collection(collection_temp_dir, collection_option)
        with pytest.raises(Exception) as exc_info:
            coll.drop_index(field_name=invalid_vector_name)
        self.check_error_message(exc_info, invalid_vector_name)
        with pytest.raises(Exception) as exc_info:
            coll.drop_index(field_name=invalid_field_name)
        self.check_error_message(exc_info, invalid_field_name)
        coll.destroy()

    @pytest.mark.parametrize(
        "field_name,vector_name",
        [
            ("2", "3"),
            ("col", "co1"),
            ("ID", "IM"),
            ("name-1", "name2"),
            ("Weigt_12", "Weigt_13"),
            ("123age", "123agl"),
        ],
    )
    def test_valid_field_and_vector_name(
        self,
        collection_temp_dir,
        collection_option: CollectionOption,
        field_name: str,
        vector_name: str,
    ):
        """A schema using the given field/vector names can be created, filled
        with 5 docs, and indexed on both the scalar field and the vector."""
        schema = zvec.CollectionSchema(
            name="test_index_names",
            fields=[
                FieldSchema(
                    "id",
                    DataType.INT64,
                    nullable=False,
                    index_param=InvertIndexParam(enable_range_optimization=True),
                ),
                FieldSchema(field_name, DataType.STRING, nullable=True),
            ],
            vectors=[
                VectorSchema(
                    vector_name,
                    DataType.VECTOR_FP32,
                    dimension=128,
                    index_param=HnswIndexParam(),
                )
            ],
        )

        coll = zvec.create_and_open(
            path=collection_temp_dir, schema=schema, option=collection_option
        )
        assert coll is not None, (
            f"Failed to create and open collection with field_name={field_name}, vector_name={vector_name}"
        )

        # Insert some data
        docs = [
            Doc(
                id=f"{i}",
                fields={"id": i, field_name: f"value_{i}"},
                vectors={vector_name: [float(j % 10) for j in range(128)]},
            )
            for i in range(5)
        ]
        result = coll.insert(docs)
        assert len(result) == 5, (
            f"Expected 5 insertion results, got {len(result)} for field_name={field_name}, vector_name={vector_name}"
        )
        for item in result:
            assert item.ok(), (
                f"Insertion failed for field_name={field_name}, vector_name={vector_name}: {item}"
            )

        # Create index on field
        coll.create_index(
            field_name=field_name,
            index_param=InvertIndexParam(),
            option=IndexOption(),
        )

        # Create index on vector
        coll.create_index(
            field_name=vector_name,
            index_param=HnswIndexParam(),
            option=IndexOption(),
        )

        # Verify indexes were created successfully
        stats = coll.stats
        assert stats is not None, (
            f"Stats is None for field_name={field_name}, vector_name={vector_name}"
        )

        coll.destroy()

    # NOTE(review): "compicated" is a typo for "complicated"; the name is kept
    # unchanged so existing test selections (-k / node ids) keep working.
    def test_compicated_workflow(
        self,
        collection_temp_dir,
        basic_schema: CollectionSchema,
        collection_option: CollectionOption,
    ):
        """
        Test the complete workflow:
        1. Create collection
        2. Create index
        3. Insert doc
        4. Upsert
        5. Update doc
        6. Fetch doc
        7. Query doc
        8. Drop index
        9. Insert doc
        10. Update doc
        11. Upsert doc
        12. Fetch doc
        13. Query doc
        14. Flush
        15. Destroy
        """
        # Step 1: Create collection
        coll = zvec.create_and_open(
            path=collection_temp_dir,
            schema=basic_schema,
            option=collection_option,
        )
        assert coll is not None, "Failed to create and open collection"
        assert coll.path == collection_temp_dir
        assert coll.schema.name == basic_schema.name
        assert coll.stats.doc_count == 0

        # Step 2: Create index
        coll.create_index(
            field_name="name", index_param=InvertIndexParam(), option=IndexOption()
        )
        # Verify index was created
        stats = coll.stats
        assert stats is not None, "coll.stats is None!"

        # Step 3: Insert doc
        doc1 = Doc(
            id="1",
            fields={"id": 1, "name": "test1", "weight": 80.5},
            vectors={
                "dense": np.random.random(128).tolist(),
                "sparse": {1: 1.0, 2: 2.0},
            },
        )
        result = coll.insert(doc1)
        assert bool(result)
        assert result.ok()
        assert coll.stats.doc_count == 1

        # Step 4: Upsert (existing doc) -- doc count must stay at 1
        doc1_updated = Doc(
            id="1",
            fields={"id": 1, "name": "test1_updated", "weight": 85.0},
            vectors={
                "dense": np.random.random(128).tolist(),
                "sparse": {1: 1.5, 2: 2.5},
            },
        )
        result = coll.upsert(doc1_updated)
        assert bool(result)
        assert result.ok()
        assert coll.stats.doc_count == 1

        # Step 5: Update doc
        doc2 = Doc(
            id="2",
            fields={"id": 2, "name": "test2", "weight": 90.0},
            vectors={
                "dense": np.random.random(128).tolist(),
                "sparse": {1: 3.0, 2: 4.0},
            },
        )
        # First insert doc2
        result = coll.insert(doc2)
        assert bool(result)
        assert result.ok()
        assert coll.stats.doc_count == 2

        # Then update it
        doc2_updated = Doc(
            id="2",
            fields={"id": 2, "name": "test2_updated", "weight": 95.0},
            vectors={
                "dense": np.random.random(128).tolist(),
                "sparse": {1: 3.5, 2: 4.5},
            },
        )
        result = coll.update(doc2_updated)
        assert bool(result)
        assert result.ok()
        assert coll.stats.doc_count == 2

        # Step 6: Fetch doc
        fetched_docs = coll.fetch(["1", "2"])
        assert len(fetched_docs) == 2
        assert "1" in fetched_docs
        assert "2" in fetched_docs
        assert fetched_docs["1"].field("name") == "test1_updated"
        assert fetched_docs["2"].field("name") == "test2_updated"

        # Step 7: Query doc
        query_result = coll.query(filter="id >= 1", topk=10)
        assert len(query_result) == 2

        # Step 8: Drop index
        coll.drop_index(field_name="name")

        # Step 9: Insert doc
        doc3 = Doc(
            id="3",
            fields={"id": 3, "name": "test3", "weight": 100.0},
            vectors={
                "dense": np.random.random(128).tolist(),
                "sparse": {1: 5.0, 2: 6.0},
            },
        )
        result = coll.insert(doc3)
        assert bool(result)
        assert result.ok()
        assert coll.stats.doc_count == 3

        # Step 10: Update doc
        doc3_updated = Doc(
            id="3",
            fields={"id": 3, "name": "test3_updated", "weight": 105.0},
            vectors={
                "dense": np.random.random(128).tolist(),
                "sparse": {1: 5.5, 2: 6.5},
            },
        )
        result = coll.update(doc3_updated)
        assert bool(result)
        assert result.ok()
        assert coll.stats.doc_count == 3

        # Step 11: Upsert doc (new id, so the count grows)
        doc4 = Doc(
            id="4",
            fields={"id": 4, "name": "test4", "weight": 110.0},
            vectors={
                "dense": np.random.random(128).tolist(),
                "sparse": {1: 7.0, 2: 8.0},
            },
        )
        result = coll.upsert(doc4)
        assert bool(result)
        assert result.ok()
        assert coll.stats.doc_count == 4

        # Step 12: Fetch doc
        fetched_docs = coll.fetch(["3", "4"])
        assert len(fetched_docs) == 2
        assert "3" in fetched_docs
        assert "4" in fetched_docs
        assert fetched_docs["3"].field("name") == "test3_updated"
        assert fetched_docs["4"].field("name") == "test4"

        # Step 13: Query doc
        query_result = coll.query(filter="id >= 3", topk=10)
        assert len(query_result) == 2

        # Step 14: Flush
        coll.flush()

        # Verify data is still accessible after flush
        fetched_docs = coll.fetch(["1", "2", "3", "4"])
        assert len(fetched_docs) == 4

        # Step 15: Destroy
        coll.destroy()

    @pytest.mark.parametrize(
        "data_type, index_param", VALID_VECTOR_DATA_TYPE_INDEX_PARAM_MAP_PARAMS
    )
    def test_vector_index_params(
        self,
        collection_temp_dir,
        collection_option: CollectionOption,
        data_type: DataType,
        index_param,
        single_vector_schema,
    ):
        """Fetch/query results stay consistent across index create/drop cycles.

        check_result runs four times:
        - before any index;
        - after create_index;
        - after drop_index;
        - after inserting 3 more docs and re-creating the index.
        """
        vector_name = DEFAULT_VECTOR_FIELD_NAME[data_type]
        dimension = DEFAULT_VECTOR_DIMENSION

        coll = zvec.create_and_open(
            path=collection_temp_dir,
            schema=single_vector_schema,
            option=collection_option,
        )
        assert coll is not None, (
            f"Failed to create and open collection, {data_type}, {index_param}"
        )

        docs = {str(i): generate_doc(i, single_vector_schema) for i in range(5)}
        result = coll.insert(docs.values())
        assert len(result) == len(docs), (
            f"Expected 5 results, got {len(result)}, {data_type}, {index_param}"
        )
        for item in result:
            assert item.ok(), f"Insertion failed for, {data_type}, {index_param}"

        def check_result(
            label: str, metric_type: MetricType, quantize_type: QuantizeType
        ):
            """Check fetch/query results against the docs inserted so far.

            Reads the enclosing ``docs`` dict, which is extended in place
            (``docs |= new_docs``) before the last call. Exact scores are only
            compared when ``quantize_type`` is UNDEFINED. Result ordering is
            checked descending for IP and ascending otherwise.
            """
            query_vector = [1] * dimension
            if data_type in [DataType.SPARSE_VECTOR_FP16, DataType.SPARSE_VECTOR_FP32]:
                query_vector = {1: 1}

            fetch_result = coll.fetch([str(i) for i in range(len(docs))])
            assert len(fetch_result) == len(docs), (
                f"{label}, Expected 5 fetched docs, got {len(fetch_result)}, {data_type}, {index_param}"
            )
            for i in range(len(docs)):
                doc_id = str(i)
                assert doc_id in fetch_result, (
                    f"{label}, Document ID '{doc_id}' not found, {data_type}, {index_param}"
                )
                fetched_doc = fetch_result[doc_id]
                # Verify doc equal
                assert is_doc_equal(fetched_doc, docs[doc_id], single_vector_schema), (
                    f"{label}, doc not equal, insert: {docs[doc_id]}, fetched: {fetched_doc}, {data_type}, {index_param}"
                )

            query_result: list[Doc] = coll.query(
                VectorQuery(field_name=vector_name, vector=query_vector),
                include_vector=False,
                topk=len(docs),
            )
            assert len(query_result) == len(docs), (
                f"{label}, Expected {len(docs)} result, got {len(query_result)}, {data_type}, {index_param}"
            )
            inserted_ids = [str(i) for i in range(len(docs))]
            queried_ids = [doc.id for doc in query_result]
            assert set(inserted_ids) == set(queried_ids), (
                f"{label}, inserted_ids != queried_ids, insert: {inserted_ids}, query: {queried_ids}, {data_type}, {index_param}"
            )

            last_score = None
            for i, doc in enumerate(query_result):
                # Get the document's vector for comparison
                expect_doc = generate_doc(int(doc.id), single_vector_schema)
                doc_vector = expect_doc.vector(vector_name)
                expected_score = distance(
                    doc_vector,
                    query_vector,
                    metric_type,
                    data_type,
                    quantize_type,
                )
                # NOTE(review): debug output; "expect_core" presumably means
                # "expected_score".
                print(f"query: {doc}, expect_core: {expected_score}")
                # Quantized indexes return approximate scores, so exact
                # comparison is limited to the unquantized case.
                if quantize_type is QuantizeType.UNDEFINED:
                    assert is_float_equal(doc.score, expected_score), (
                        f"{label} top{i} pk{doc.id} score {doc.score:6f} expected:{expected_score:6f}, {data_type}, {index_param}"
                    )
                if last_score is not None:
                    if metric_type == MetricType.IP:
                        assert last_score >= doc.score, (
                            f"{label}, score not sorted, last_score: {last_score}, current_score: {doc.score}, {data_type}, {index_param}"
                        )
                    else:
                        assert last_score <= doc.score, (
                            f"{label}, score not sorted, last_score: {last_score}, current_score: {doc.score}, {data_type}, {index_param}"
                        )
                last_score = doc.score

        # default metric_type=IP, quantize_type=None
        check_result("pre_create_index", MetricType.IP, QuantizeType.UNDEFINED)

        # create index
        coll.create_index(
            field_name=vector_name,
            index_param=index_param,
            option=IndexOption(),
        )
        check_result(
            "post_create_index", index_param.metric_type, index_param.quantize_type
        )

        coll.drop_index(field_name=vector_name)
        check_result("post_drop_index", MetricType.IP, QuantizeType.UNDEFINED)

        new_docs = {str(i): generate_doc(i, single_vector_schema) for i in range(5, 8)}
        new_result = coll.insert(new_docs.values())
        assert len(new_result) == len(new_docs), (
            f"Expected {len(new_docs)} insertion results for new docs, got {len(new_result)} for vector {vector_name}"
        )
        for item in new_result:
            assert item.ok(), (
                f"New document insertion failed for vector {vector_name}: {item}"
            )
        # Extend in place so check_result (which closes over docs) sees ids 0..7.
        docs |= new_docs

        coll.create_index(
            field_name=vector_name,
            index_param=index_param,
            option=IndexOption(),
        )
        check_result(
            "post_create_index2", index_param.metric_type, index_param.quantize_type
        )

        coll.destroy()


class TestColumnDDL:
    """Tests for add_column / alter_column / drop_column on ``basic_collection``."""

    def test_add_column(self, basic_collection: Collection):
        """Adding an INT32 column allows inserting a doc that sets it."""
        basic_collection.add_column(
            field_schema=FieldSchema("income", DataType.INT32),
            # NOTE(review): 'weight' is single-quoted; if the expression grammar
            # treats single quotes as string literals (as the "'test'"
            # expressions elsewhere in this file suggest), this does not
            # reference the weight column -- confirm intent.
            expression="'weight' * 2",  # Simple expression
        )

        doc = Doc(
            id="1",
            fields={"id": 1, "name": "test", "weight": 80.5, "income": 1},
            vectors={
                "dense": np.random.random(128).tolist(),
                "sparse": {1: 1.0, 2: 2.0},
            },
        )
        result = basic_collection.insert(doc)
        assert bool(result), f"Expected 1 result, but got {len(result)}"
        assert result.ok(), (
            f"result={result},Insert operation failed with code = {result.code()}"
        )
        stats = basic_collection.stats
        assert stats is not None
        assert stats.doc_count == 1

    def test_add_column_with_default_option(self, basic_collection: Collection):
        """add_column with a default-constructed AddColumnOption succeeds."""
        # Add a new column with default option
        basic_collection.add_column(
            field_schema=FieldSchema("test_column_default", DataType.INT32),
            expression="100",
            option=AddColumnOption(),  # Default option
        )

        # Verify column was added by inserting data
        doc = Doc(
            id="1",
            fields={"id": 1, "name": "test", "weight": 80.5, "test_column_default": 1},
            vectors={
                "dense": np.random.random(128).tolist(),
                "sparse": {1: 1.0, 2: 2.0},
            },
        )
        result = basic_collection.insert(doc)
        assert bool(result), f"Expected 1 result, but got {len(result)}"
        assert result.ok(), (
            f"result={result},Insert operation failed with code = {result.code()}"
        )

        # Verify document was inserted
        stats = basic_collection.stats
        assert stats is not None
        assert stats.doc_count == 1

    @pytest.mark.parametrize("concurrency", [0, 1, 4, 8])
    def test_add_column_with_various_concurrency_options(
        self, basic_collection: Collection, concurrency
    ):
        """add_column succeeds for each tested AddColumnOption concurrency value."""
        field_name = f"test_column_concurrent_{concurrency}"
        basic_collection.add_column(
            field_schema=FieldSchema(field_name, DataType.INT32),
            expression="100",
            option=AddColumnOption(concurrency=concurrency),
        )

        doc = Doc(
            id="1",
            fields={"id": 1, "name": "test", "weight": 80.5, field_name: 200},
            vectors={
                "dense": np.random.random(128).tolist(),
                "sparse": {1: 1.0, 2: 2.0},
            },
        )
        result = basic_collection.insert(doc)
        assert bool(result), f"Expected 1 result, but got {len(result)}"
        assert result.ok(), (
            f"result={result},Insert operation failed with code = {result.code()}"
        )
        stats = basic_collection.stats
        assert stats is not None
        assert stats.doc_count == 1
@pytest.mark.parametrize("data_type", SUPPORT_ADD_COLUMN_DATA_TYPE) def test_add_column_valid_data_types(self, basic_collection: Collection, data_type): field_name = f"test_field_{data_type.name.lower()}" # Add a new column with specific data type basic_collection.add_column( field_schema=FieldSchema(field_name, data_type), expression="1" if data_type != DataType.STRING else "'test'", ) # Verify column was added by inserting data if data_type == DataType.STRING: field_value = "test_value" elif data_type in [DataType.ARRAY_STRING]: field_value = ["test_value"] elif data_type in [DataType.ARRAY_INT32, DataType.ARRAY_INT64]: field_value = [1, 2, 3] elif data_type in [DataType.ARRAY_FLOAT, DataType.ARRAY_DOUBLE]: field_value = [1.1, 2.2, 3.3] elif data_type == DataType.ARRAY_BOOL: field_value = [True, False] elif data_type in [DataType.FLOAT, DataType.DOUBLE]: field_value = 1.5 elif data_type in [DataType.INT32, DataType.INT64]: field_value = 100 elif data_type == DataType.BOOL: field_value = True else: field_value = 1 doc = Doc( id="1", fields={ "id": 1, "name": "test", "weight": 80.5, field_name: field_value, }, vectors={ "dense": np.random.random(128).tolist(), "sparse": {1: 1.0, 2: 2.0}, }, ) result = basic_collection.insert(doc) assert bool(result), f"Expected 1 result, but got {len(result)}" assert result.ok(), ( f"result={result},Insert operation failed with code = {result.code()}" ) # Verify document was inserted stats = basic_collection.stats assert stats is not None assert stats.doc_count == 1 @pytest.mark.parametrize("data_type", NOT_SUPPORT_ADD_COLUMN_DATA_TYPE) def test_add_column_invalid_data_types( self, basic_collection: Collection, data_type ): with pytest.raises(Exception) as exc_info: field_name = f"test_field_{data_type.name.lower()}" # Add a new column with specific data type basic_collection.add_column( field_schema=FieldSchema(field_name, data_type), expression="1" if data_type != DataType.STRING else "'test'", ) assert 
NOT_SUPPORT_ADD_COLUMN_ERROR_MSG in str(exc_info.value) @pytest.mark.parametrize("nullable", [True, False]) def test_add_column_with_nullable_options( self, basic_collection: Collection, nullable ): field_name = f"test_field_nullable_{str(nullable).lower()}" # Add a new column with specific nullable option basic_collection.add_column( field_schema=FieldSchema(field_name, DataType.INT32, nullable=nullable), expression="100", ) # Verify column was added by inserting data doc = Doc( id="1", fields={"id": 1, "name": "test", "weight": 80.5, field_name: 200}, vectors={ "dense": np.random.random(128).tolist(), "sparse": {1: 1.0, 2: 2.0}, }, ) result = basic_collection.insert(doc) assert bool(result), f"Expected 1 result, but got {len(result)}" assert result.ok(), ( f"result={result},Insert operation failed with code = {result.code()}" ) # Verify document was inserted stats = basic_collection.stats assert stats is not None assert stats.doc_count == 1 # Verify column was added by inserting data doc = Doc( id="2", fields={"id": 2, "name": "test", "weight": 80.5, field_name: None}, vectors={ "dense": np.random.random(128).tolist(), "sparse": {1: 1.0, 2: 2.0}, }, ) if nullable: result = basic_collection.insert(doc) assert bool(result), f"Expected 1 result, but got {len(result)}" assert result.ok(), ( f"result={result},Insert operation failed with code = {result.code()}" ) else: with pytest.raises(ValueError) as e: basic_collection.insert(doc) assert ( "Field 'test_field_nullable_false': expected non-nullable type" in str(e.value) ) # Verify document was inserted stats = basic_collection.stats assert stats is not None if nullable: assert stats.doc_count == 2 else: assert stats.doc_count == 1 @pytest.mark.parametrize( "expression", [ "1", # Constant integer "1.5", # Constant float "'test'", # Constant string "id", # Reference to existing field "weight * 2", # Simple arithmetic "weight + id", # Complex arithmetic "CASE WHEN weight > 50 THEN 1 ELSE 0 END", # Conditional expression 
], ) def test_add_column_with_different_expressions( self, basic_collection: Collection, expression ): field_name = f"test_field_expr_{abs(hash(expression)) % 1000}" # Add a new column with specific expression basic_collection.add_column( field_schema=FieldSchema(field_name, DataType.INT32), expression=expression, ) # Verify column was added by inserting data doc = Doc( id="1", fields={"id": 1, "name": "test", "weight": 80.5, field_name: 200}, vectors={ "dense": np.random.random(128).tolist(), "sparse": {1: 1.0, 2: 2.0}, }, ) result = basic_collection.insert(doc) assert bool(result), f"Expected 1 result, but got {len(result)}" assert result.ok(), ( f"result={result},Insert operation failed with code = {result.code()}" ) # Verify document was inserted stats = basic_collection.stats assert stats is not None assert stats.doc_count == 1 def test_add_column_with_index_param(self, basic_collection: Collection): basic_collection.add_column( field_schema=FieldSchema( "indexed_field", DataType.INT32, index_param=InvertIndexParam(enable_range_optimization=True), ), expression="id * 2", ) # Verify column was added by inserting data doc = Doc( id="1", fields={"id": 1, "name": "test", "weight": 80.5, "indexed_field": 200}, vectors={ "dense": np.random.random(128).tolist(), "sparse": {1: 1.0, 2: 2.0}, }, ) result = basic_collection.insert(doc) assert bool(result), f"Expected 1 result, but got {len(result)}" assert result.ok(), ( f"result={result},Insert operation failed with code = {result.code()}" ) # Verify document was inserted stats = basic_collection.stats assert stats is not None assert stats.doc_count == 1 @pytest.mark.parametrize( "field_name", [ "a", # Minimum length "a" * 32, # Maximum length (32 characters) "valid_field_name_123", # Alphanumeric with underscore "Valid-Field-Name", # With hyphens "_underscore_start", # Starting with underscore "field_name_with_123_numbers", # Numbers in middle "FIELD_NAME_UPPERCASE", # Uppercase # "field_with_nums_123_and_hyphens-456", 
# Complex valid name within limit ], ) def test_add_column_with_valid_field_names( self, basic_collection: Collection, field_name ): basic_collection.add_column( field_schema=FieldSchema(field_name, DataType.INT32), expression="200" ) doc = Doc( id="1", fields={"id": 1, "name": "test", "weight": 80.5, field_name: 300}, vectors={ "dense": np.random.random(128).tolist(), "sparse": {1: 1.0, 2: 2.0}, }, ) result = basic_collection.insert(doc) assert bool(result), f"Expected 1 result, but got {len(result)}" assert result.ok(), ( f"result={result},Insert operation failed with code = {result.code()}" ) stats = basic_collection.stats assert stats is not None assert stats.doc_count == 1 @pytest.mark.parametrize( "invalid_field_name", [ "", # Empty string " ", # Space only "a" * 33, # Too long (33 characters, exceeds 32) "field name", # Contains space "field.name", # Contains dot "field@name", # Contains special character "field/name", # Contains slash "field\\name", # Contains backslash "field$name", # Contains dollar sign "field+name", # Contains plus sign "field=name", # Contains equals sign None, # None value ], ) def test_add_column_with_invalid_field_names( self, basic_collection: Collection, invalid_field_name ): with pytest.raises(Exception) as exc_info: basic_collection.add_column( field_schema=FieldSchema(invalid_field_name, DataType.INT32), expression="100", ) if invalid_field_name is None: assert "validate failed" in str(exc_info.value), ( "Error message is unreasonable: e=" + str(exc_info.value) ) else: assert ( "invalid" in str(exc_info.value).lower() or "name" in str(exc_info.value).lower() ) def test_alter_column_rename(self, basic_collection: Collection): basic_collection.alter_column( old_name="weight", new_name="mass", option=AlterColumnOption(), ) doc = Doc( id="1", fields={"id": 1, "name": "test", "mass": 80.5}, # Use new name vectors={ "dense": np.random.random(128).tolist(), "sparse": {1: 1.0, 2: 2.0}, }, ) result = basic_collection.insert(doc) assert 
bool(result), f"Expected 1 result, but got {len(result)}" assert result.ok(), ( f"result={result},Insert operation failed with code = {result.code()}" ) stats = basic_collection.stats assert stats is not None assert stats.doc_count == 1 def test_alter_column_non_exist(self, basic_collection: Collection): with pytest.raises(Exception) as exc_info: basic_collection.alter_column( old_name="non_existing", new_name="new_name", field_schema=FieldSchema("new_name", DataType.STRING), ) assert "column non_existing not found" in str(exc_info.value), ( "Error message is unreasonable: e=" + str(exc_info.value) ) def test_alter_column_with_default_option(self, basic_collection: Collection): basic_collection.add_column( field_schema=FieldSchema("original_field", DataType.INT32), expression="100" ) basic_collection.alter_column( old_name="original_field", new_name="renamed_field", option=AlterColumnOption(), ) doc = Doc( id="1", fields={"id": 1, "name": "test", "weight": 80.5, "renamed_field": 200}, vectors={ "dense": np.random.random(128).tolist(), "sparse": {1: 1.0, 2: 2.0}, }, ) result = basic_collection.insert(doc) assert bool(result), f"Expected 1 result, but got {len(result)}" assert result.ok(), ( f"result={result},Insert operation failed with code = {result.code()}" ) stats = basic_collection.stats assert stats is not None assert stats.doc_count == 1 @pytest.mark.parametrize("concurrency", [0, 1, 4, 8]) def test_alter_column_with_various_concurrency_options( self, basic_collection: Collection, concurrency ): old_field_name = f"orig_field_{concurrency}" new_field_name = f"modified_field_{concurrency}" basic_collection.add_column( field_schema=FieldSchema(old_field_name, DataType.INT32), expression="100", ) basic_collection.alter_column( old_name=old_field_name, new_name=new_field_name, option=AlterColumnOption(concurrency=concurrency), ) doc = Doc( id="1", fields={"id": 1, "name": "test", "weight": 80.5, new_field_name: 200}, vectors={ "dense": 
np.random.random(128).tolist(), "sparse": {1: 1.0, 2: 2.0}, }, ) result = basic_collection.insert(doc) assert bool(result), f"Expected 1 result, but got {len(result)}" assert result.ok(), ( f"result={result},Insert operation failed with code = {result.code()}" ) stats = basic_collection.stats assert stats is not None assert stats.doc_count == 1 @pytest.mark.parametrize( "old_field_name,new_field_name", [ ("a", "new_a"), # Minimum length ( "abcdefghijklmnopqrstuvwxyz123456", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", ), # Maximum length (32 characters) ("valid_field_name_123", "new_valid_field"), # Alphanumeric with underscore ("Valid-Field-Name", "New-Field-Name"), # With hyphens ("_underscore_start", "new_underscore"), # Starting with underscore ("field_name_with_123_numbers", "new_with_nums"), # Numbers in middle ("FIELD_NAME_UPPERCASE", "new_uppercase"), # Uppercase ( "field_with_nums_3_and_hyphens-6", "new_field_hyphens", ), # Complex valid name ], ) def test_alter_column_field_name_valid( self, basic_collection: Collection, old_field_name, new_field_name ): basic_collection.add_column( field_schema=FieldSchema(old_field_name, DataType.INT32), expression="100", ) basic_collection.alter_column( old_name=old_field_name, new_name=new_field_name, option=AlterColumnOption(), ) doc = Doc( id="1", fields={"id": 1, "name": "test", "weight": 80.5, new_field_name: 200}, vectors={ "dense": np.random.random(128).tolist(), "sparse": {1: 1.0, 2: 2.0}, }, ) result = basic_collection.insert(doc) assert bool(result), f"Expected 1 result, but got {len(result)}" assert result.ok(), ( f"result={result},Insert operation failed with code = {result.code()}" ) stats = basic_collection.stats assert stats is not None assert stats.doc_count == 1 @pytest.mark.parametrize( "valid_old_name,invalid_new_name", [ ("temp_field", ""), # Empty new name ("temp_field", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"), # Too long new name ("temp_field", "field name"), # New name with space ("temp_field", 
"field.name"), # New name with dot ("temp_field", "field@name"), # New name with special character ("temp_field", "field/name"), # New name with slash ("temp_field", "field\\name"), # New name with backslash ("temp_field", "field$name"), # New name with dollar sign ("temp_field", "field+name"), # New name with plus sign ("temp_field", "field=name"), # New name with equals sign ("temp_field", None), # None new name ], ) def test_alter_column_with_invalid_field_names( self, basic_collection: Collection, valid_old_name, invalid_new_name ): basic_collection.add_column( field_schema=FieldSchema("temp_field", DataType.INT32), expression="100" ) with pytest.raises(Exception) as exc_info: basic_collection.alter_column( old_name=valid_old_name, new_name=invalid_new_name if invalid_new_name is not None else "", field_schema=FieldSchema( invalid_new_name if invalid_new_name is not None else "", DataType.INT32, ), ) assert ( "invalid" in str(exc_info.value).lower() or "name" in str(exc_info.value).lower() or "incompatible" in str(exc_info.value).lower() ) def test_drop_column_exist(self, basic_collection: Collection): basic_collection.add_column( field_schema=FieldSchema("temp_field", DataType.INT32), expression="100" ) doc = Doc( id="1", fields={"id": 1, "name": "test", "weight": 80.5, "temp_field": 1}, vectors={ "dense": np.random.random(128).tolist(), "sparse": {1: 1.0, 2: 2.0}, }, ) result = basic_collection.insert(doc) assert bool(result), f"Expected 1 result, but got {len(result)}" assert result.ok(), ( f"result={result},Insert operation failed with code = {result.code()}" ) stats = basic_collection.stats assert stats is not None assert stats.doc_count == 1 basic_collection.drop_column("temp_field") doc = Doc( id="2", fields={"id": 2, "name": "test", "weight": 80.5, "temp_field": 1}, vectors={ "dense": np.random.random(128).tolist(), "sparse": {1: 1.0, 2: 2.0}, }, ) with pytest.raises(Exception) as exc_info: result = basic_collection.insert(doc) assert 
SCHEMA_VALIDATE_ERROR_MSG in str(exc_info.value) def test_drop_column_non_exist(self, basic_collection: Collection): with pytest.raises(Exception) as exc_info: basic_collection.drop_column("non_existing_column") assert NOT_EXIST_COLUMN_TO_DROP_ERROR_MSG in str(exc_info.value) @pytest.mark.parametrize( "field_name", [ "a", # Minimum length "a" * 32, # Maximum length (32 characters) "valid_field_name_123", # Alphanumeric with underscore "Valid-Field-Name", # With hyphens "_underscore_start", # Starting with underscore "field_name_with_123_numbers", # Numbers in middle "FIELD_NAME_UPPERCASE", # Uppercase "field_with_nums_3_and_hyphens-6", # Complex valid name within limit ], ) def test_drop_column_field_name_valid( self, basic_collection: Collection, field_name ): basic_collection.add_column( field_schema=FieldSchema(field_name, DataType.INT32), expression="100" ) doc = Doc( id="1", fields={"id": 1, "name": "test", "weight": 80.5, field_name: 200}, vectors={ "dense": np.random.random(128).tolist(), "sparse": {1: 1.0, 2: 2.0}, }, ) result = basic_collection.insert(doc) assert bool(result), f"Expected 1 result, but got {len(result)}" assert result.ok(), ( f"result={result},Insert operation failed with code = {result.code()}" ) stats = basic_collection.stats assert stats is not None assert stats.doc_count == 1 basic_collection.drop_column(field_name) doc = Doc( id="2", fields={"id": 2, "name": "test", "weight": 80.5, field_name: 200}, vectors={ "dense": np.random.random(128).tolist(), "sparse": {1: 1.0, 2: 2.0}, }, ) with pytest.raises(Exception) as exc_info: result = basic_collection.insert(doc) assert SCHEMA_VALIDATE_ERROR_MSG in str(exc_info.value) ================================================ FILE: python/tests/detail/test_collection_dml.py ================================================ import logging import pytest from zvec import ( CollectionOption, InvertIndexParam, HnswIndexParam, FieldSchema, VectorSchema, CollectionSchema, Collection, Doc, VectorQuery, 
StatusCode, ) from distance_helper import * from fixture_helper import * from doc_helper import * Maximum = 1024 DOCID_VALID_LIST = [ "1valid_Id", "123.45", "123abc", "-!@#$%+=.123abc_+", "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ123456789012", ] DOCID_INVALID_LIST = [ None, "", "()qsd123", " ", "/&AS12", "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890121", ] FIELD_VALUE_VALID_LIST = [ ( "bool_field", [ None, True, False, ], ), ( "float_field", [ None, 0.0, -1.0, 1.0, 3.4028235e38, -3.4028235e38, 1.17549435e-38, -1.17549435e-38, float("inf"), float("-inf"), ], ), ( "double_field", [ None, 0.0, -1.0, 1.0, 1.7976931348623157e308, -1.7976931348623157e308, 2.2250738585072014e-308, -2.2250738585072014e-308, float("inf"), float("-inf"), ], ), ( "int32_field", [ None, 0, 1, -1, 2147483647, -2147483648, ], ), ( "int64_field", [ None, 0, 1, -1, 9223372036854775807, -9223372036854775808, ], ), ( "uint32_field", [ None, 0, 1, 4294967295, ], ), ( "uint64_field", [ None, 0, 1, 18446744073709551615, ], ), ( "string_field", [ None, "", "a", "test_name", "这是一个中文名称测试", "a" * 1000, ], ), ( "array_bool_field", [ None, [], [True], [False, True], [True, False, True, False] * 10, ], ), ( "array_float_field", [ None, [], [0.0], [1.0, 2.0, 3.0], [3.4028235e38, -3.4028235e38], ], ), ( "array_double_field", [ None, [], [0.0], [1.0, 2.0, 3.0], [1.7976931348623157e308, -1.7976931348623157e308], ], ), ( "array_int32_field", [ None, [], [0], [1, 2, 3], [2147483647, -2147483648], ], ), ( "array_int64_field", [ None, [], [0], [1, 2, 3], [9223372036854775807, -9223372036854775808], ], ), ( "array_uint32_field", [ None, [], [0], [1, 2, 3], [4294967295], ], ), ( "array_uint64_field", [ None, [], [0], [1, 2, 3], [18446744073709551615], ], ), ( "array_string_field", [ None, [], [""], ["a", "b", "c"], ["test_string", "测试字符串"], ["a" * 100] * 5, ], ), ] FIELD_VALUE_INVALID_LIST = [ ( 
"bool_field", [ "True", "False", "", ], ), ("float_field", ["invalid", [1.0], {"value": 1.0}]), ("double_field", ["invalid", [1.0], {"value": 1.0}]), ( "int32_field", [ "invalid", [1], {"value": 1}, 2147483648, -2147483649, ], ), ( "int64_field", [ "invalid", [1], {"value": 1}, 9223372036854775808, -9223372036854775809, ], ), ( "uint32_field", [ "invalid", [1], {"value": 1}, 4294967296, -1, ], ), ( "uint64_field", [ "invalid", [1], {"value": 1}, 18446744073709551616, -1, ], ), ( "string_field", [ 123, 12.34, True, ["array"], {"key": "value"}, ], ), ( "array_bool_field", [ True, False, [True, "invalid"], {"key": True}, ], ), ( "array_float_field", [ [1.0, "invalid"], [1.0, None], "invalid", [1.0, [2.0]], 1.0, ], ), ( "array_double_field", [ [1.0, "invalid"], [1.0, None], "invalid", [1.0, [2.0]], 1.0, ], ), ( "array_int32_field", [ [1, "invalid"], [1, None], "invalid", [1, [2]], 1, ], ), ( "array_int64_field", [ [1, "invalid"], [1, None], "invalid", [1, [2]], 1, ], ), ( "array_uint32_field", [ [1, "invalid"], [1, None], [1, -1], "invalid", [1, [2]], 1, ], ), ( "array_uint64_field", [ [1, "invalid"], [1, None], [1, -1], "invalid", [1, [2]], 1, ], ), ( "array_string_field", [ ["valid", 123], ["valid", None], "invalid", [["nested"]], 123, ], ), ] VECTOR_VALUE_VALID_LIST = [ ( "vector_fp32_field", [ [0.0] * 128, [1.0] * 128, [-1.0] * 128, [float("inf")] * 128, [float("-inf")] * 128, [i / 128.0 for i in range(128)], [-i / 128.0 for i in range(128)], ], ), ( "vector_fp16_field", [ [0.0] * 128, [1.0] * 128, [-1.0] * 128, [float("inf")] * 128, [float("-inf")] * 128, [i / 128.0 for i in range(128)], [-i / 128.0 for i in range(128)], ], ), ("vector_int8_field", [[100] * 128, [0] * 128, [-100] * 128]), ( "sparse_vector_fp32_field", [ {0: 1.0}, {0: 0.0, 1: 1.0, 2: -1.0}, {0: float("inf"), 1: float("-inf")}, {i: float(i) for i in range(10)}, {128: 1.0, 256: -1.0, 512: 0.5}, ], ), ( "sparse_vector_fp16_field", [ {0: 1.0}, {0: 0.0, 1: 1.0, 2: -1.0}, {0: float("inf"), 1: 
float("-inf")}, {i: float(i) for i in range(10)}, {128: 1.0, 256: -1.0, 512: 0.5}, ], ), ] VECTOR_VALUE_INVALID_LIST = [ ( "vector_fp32_field", [ None, [], [0.0] * 127, [0.0] * 129, [0.0] * 1000, ["invalid"], [0, 1, 2], [None] * 128, ], ), ( "vector_fp16_field", [ None, [], [0.0] * 127, [0.0] * 129, [0.0] * 1000, ["invalid"], [0, 1, 2], [None] * 128, ], ), ( "vector_int8_field", [ None, [], [1] * 127, [10] * 129, [0] * 1000, ["invalid"], [0, 1, 2], [None] * 128, ], ), ( "sparse_vector_fp32_field", [ None, "invalid", {None: 1.0}, {"0": 1.0}, {0: "invalid"}, {0: None}, {-1: 1.0}, ], ), ( "sparse_vector_fp16_field", [ None, "invalid", {None: 1.0}, {"0": 1.0}, {0: "invalid"}, {0: None}, {-1: 1.0}, ], ), ] UPDATE_PARTIAL_VALUE = [ ( "partial_fields", {"string_field": "partially_updated_test", "float_field": 95.5}, {}, ), ("dense_vector_only", {}, {"vector_fp32_field": [0.3] * 128}), ("dense_vector_only", {}, {"vector_fp16_field": [0.6] * 128}), ("dense_vector_only", {}, {"vector_int8_field": [3] * 128}), ("sparse_vector_only", {}, {"sparse_vector_fp32_field": {1: 2.0, 2: 3.0, 4: 4.0}}), ( "sparse_vector_only", {}, {"sparse_vector_fp16_field": {10: 2.1, 20: 3.1, 40: 4.1}}, ), ( "fields_and_vectors", {"string_field": "fully_updated_test", "bool_field": False}, { "vector_fp32_field": [0.4] * 128, "sparse_vector_fp32_field": {1: 3.0, 3: 5.0}, }, ), ] # ==================== helper ==================== def singledoc_and_check( collection: Collection, insert_doc, operator="insert", is_delete=1 ): if operator == "insert": result = collection.insert(insert_doc) elif operator == "upsert": result = collection.upsert(insert_doc) elif operator == "update": result = collection.update(insert_doc) else: logging.error("operator value is error!") assert bool(result) assert result.ok() stats = collection.stats assert stats is not None assert stats.doc_count == 1 fetched_docs = collection.fetch([insert_doc.id]) assert len(fetched_docs) == 1 assert insert_doc.id in fetched_docs fetched_doc 
= fetched_docs[insert_doc.id] assert is_doc_equal(fetched_doc, insert_doc, collection.schema) assert hasattr(fetched_doc, "score"), "Document should have a score attribute" assert fetched_doc.score == 0.0, ( "Fetch operation should return default score of 0.0" ) for k, v in DEFAULT_VECTOR_FIELD_NAME.items(): if v != {}: query_result = collection.query( VectorQuery(field_name=v, vector=insert_doc.vectors[v]), topk=10, ) assert len(query_result) > 0, ( f"Expected at least 1 query result, but got {len(query_result)}" ) found_doc = None for doc in query_result: if doc.id == insert_doc.id: found_doc = doc break assert found_doc is not None, ( f"Inserted document {insert_doc.id} not found in query results" ) assert is_doc_equal(found_doc, insert_doc, collection.schema, True, False) if is_delete == 1: collection.delete(insert_doc.id) assert collection.stats.doc_count == 0, "Document should be deleted" def updatedoc_partial_check( collection, update_doc_partial, update_doc_full, operator="update", is_delete=1 ): if operator == "upsert": result = collection.upsert(update_doc_partial) elif operator == "update": result = collection.update(update_doc_partial) else: logging.error("operator value is error!") assert bool(result) assert result.ok() stats = collection.stats assert stats is not None assert stats.doc_count == 1 fetched_docs = collection.fetch([update_doc_partial.id]) assert len(fetched_docs) == 1, ( f"fetched_docs={fetched_docs},Expected 1 fetched document, but got {len(fetched_docs)}" ) assert update_doc_partial.id in fetched_docs, ( f"Expected document ID {update_doc_partial.id} in fetched documents" ) fetched_doc = fetched_docs[update_doc_partial.id] assert is_doc_equal(fetched_doc, update_doc_full, collection.schema) assert hasattr(fetched_doc, "score"), "Document should have a score attribute" assert fetched_doc.score == 0.0, ( "Fetch operation should return default score of 0.0" ) for k, v in DEFAULT_VECTOR_FIELD_NAME.items(): if v != {}: query_result = 
collection.query( VectorQuery(field_name=v, vector=update_doc_full.vectors[v]), topk=10, ) assert len(query_result) > 0, ( f"Expected at least 1 query result, but got {len(query_result)}" ) found_doc = None for doc in query_result: if doc.id == update_doc_partial.id: found_doc = doc break assert found_doc is not None, ( f"Inserted document {update_doc_partial.id} not found in query results" ) assert is_doc_equal( found_doc, update_doc_full, collection.schema, True, False ) if is_delete == 1: collection.delete(update_doc_partial.id) assert collection.stats.doc_count == 0, "Document should be deleted" def batchdoc_and_check(collection, multiple_docs, doc_num, operator="insert"): if operator == "insert": result = collection.insert(multiple_docs) elif operator == "upsert": result = collection.upsert(multiple_docs) elif operator == "update": result = collection.update(multiple_docs) else: logging.error("operator value is error!") assert len(result) == len(multiple_docs) for item in result: assert item.ok(), ( f"result={result},Insert operation failed with code {item.code()}" ) stats = collection.stats assert stats is not None, "Collection stats should not be None" assert stats.doc_count == len(multiple_docs), ( f"Document count should be {len(multiple_docs)} after insert, but got {stats.doc_count}" ) doc_ids = [doc.id for doc in multiple_docs] fetched_docs = collection.fetch(doc_ids) assert len(fetched_docs) == len(multiple_docs), ( f"fetched_docs={fetched_docs},Expected {len(multiple_docs)} fetched documents, but got {len(fetched_docs)}" ) for original_doc in multiple_docs: assert original_doc.id in fetched_docs, ( f"Expected document ID {original_doc.id} in fetched documents" ) fetched_doc = fetched_docs[original_doc.id] assert is_doc_equal(fetched_doc, original_doc, collection.schema) assert hasattr(fetched_doc, "score"), "Document should have a score attribute" assert fetched_doc.score == 0.0, ( "Fetch operation should return default score of 0.0" ) first_doc = 
multiple_docs[doc_num - 1] for k, v in DEFAULT_VECTOR_FIELD_NAME.items(): query_result = collection.query( VectorQuery(field_name=v, vector=first_doc.vectors[v]), topk=1024, ) assert len(query_result) > 0, ( f"Expected at least 1 query result, but got {len(query_result)}" ) found_doc = None for doc in query_result: if doc.id == first_doc.id: found_doc = doc break assert found_doc is not None, ( f"Inserted document {first_doc.id} not found in query results" ) assert is_doc_equal(found_doc, first_doc, collection.schema, True, False) # ==================== Tests ==================== # ---------------------------- # Collection Insert Test Case # ---------------------------- class TestCollectionInsert: def test_insert(self, full_collection: Collection): single_doc = generate_doc(1, full_collection.schema) singledoc_and_check(full_collection, single_doc) @pytest.mark.parametrize("doc_num", [1, 5, Maximum]) def test_insert_batch(self, full_collection: Collection, doc_num): multiple_docs = [ generate_doc(i, full_collection.schema) for i in range(doc_num) ] batchdoc_and_check(full_collection, multiple_docs, doc_num) def test_insert_duplicate(self, full_collection: Collection): insert_doc = generate_doc(1, full_collection.schema) result = full_collection.insert(insert_doc) assert result.code().value == 0 assert result.ok() # Verify documents were inserted stats = full_collection.stats assert stats is not None assert stats.doc_count == 1 insert_doc_duplicate = full_collection.insert(insert_doc) assert bool(insert_doc_duplicate) assert insert_doc_duplicate.code() == StatusCode.ALREADY_EXISTS, ( f"Second insert operation should fail with ALREADY_EXISTS, but got code {insert_doc_duplicate.code()}" ) stats = full_collection.stats assert stats is not None, "Collection stats should not be None" assert stats.doc_count == 1, ( f"Document count should still be 1 after failed insert, but got {stats.doc_count}" ) @pytest.mark.parametrize("doc_id", DOCID_VALID_LIST) def 
test_insert_docid_valid(self, full_collection: Collection, doc_id): insert_doc = generate_doc_random(doc_id, full_collection.schema) singledoc_and_check(full_collection, insert_doc) @pytest.mark.parametrize("doc_id", DOCID_INVALID_LIST) def test_insert_docid_invalid(self, full_collection: Collection, doc_id): insert_doc = generate_doc_random(doc_id, full_collection.schema) with pytest.raises(Exception) as exc_info: full_collection.insert(insert_doc) assert exc_info.value is not None stats = full_collection.stats assert stats is not None assert stats.doc_count == 0 @pytest.mark.parametrize("field_name, field_values", FIELD_VALUE_VALID_LIST) @pytest.mark.parametrize( "full_schema_new", [(True, True, HnswIndexParam()), (False, True, HnswIndexParam())], indirect=True, ) def test_insert_fields_valid( self, full_collection_new: Collection, field_name: str, field_values, request ): for i, field_value in enumerate(field_values): doc_id = str(field_value) if field_name == "id" else str(i) doc_fields, doc_vectors = generate_vectordict_random( full_collection_new.schema ) full_schema_params = request.getfixturevalue("full_schema_new") target_field = None for field in full_schema_params.fields: if field.name == field_name: target_field = field break doc_fields[field_name] = field_value insert_doc = Doc(id=doc_id, fields=doc_fields, vectors=doc_vectors) if target_field and not target_field.nullable and field_value is None: with pytest.raises(Exception) as exc_info: full_collection_new.insert(insert_doc) assert exc_info.value is not None else: singledoc_and_check(full_collection_new, insert_doc) @pytest.mark.parametrize("field_name, field_values", FIELD_VALUE_INVALID_LIST) def test_insert_fields_invalid( self, full_collection: Collection, field_name: str, field_values ): for i, field_value in enumerate(field_values): doc_id = str(field_value) if field_name == "id" else str(i) doc_fields, doc_vectors = generate_vectordict_random(full_collection.schema) doc_fields[field_name] = 
field_value insert_doc = Doc(id=doc_id, fields=doc_fields, vectors=doc_vectors) with pytest.raises(Exception) as exc_info: full_collection.insert(insert_doc) assert exc_info.value is not None stats = full_collection.stats assert stats is not None assert stats.doc_count == 0 @pytest.mark.parametrize("vector_field, vector_values", VECTOR_VALUE_VALID_LIST) def test_insert_vector_valid( self, full_collection: Collection, vector_field: str, vector_values ): for i, vector_value in enumerate(vector_values): doc_fields, doc_vectors = generate_vectordict_random(full_collection.schema) doc_vectors[vector_field] = vector_value insert_doc = Doc(id=str(i), fields=doc_fields, vectors=doc_vectors) singledoc_and_check(full_collection, insert_doc) @pytest.mark.parametrize("vector_field, vector_values", VECTOR_VALUE_INVALID_LIST) def test_insert_vector_invalid( self, full_collection: Collection, vector_field: str, vector_values ): for i, vector_value in enumerate(vector_values): doc_fields, doc_vectors = generate_vectordict_random(full_collection.schema) doc_vectors[vector_field] = vector_value insert_doc = Doc(id=str(i), fields=doc_fields, vectors=doc_vectors) with pytest.raises(Exception) as exc_info: full_collection.insert(insert_doc) assert exc_info.value is not None stats = full_collection.stats assert stats is not None assert stats.doc_count == 0 class TestCollectionUpdate: def test_update(self, full_collection: Collection): insert_doc = generate_doc(1, full_collection.schema) singledoc_and_check(full_collection, insert_doc, is_delete=0) updated_doc = generate_update_doc(1, full_collection.schema) singledoc_and_check(full_collection, updated_doc, operator="update") @pytest.mark.parametrize("doc_num", [1, 5, Maximum]) def test_update_batch(self, full_collection: Collection, doc_num): multiple_docs = [ generate_doc(i, full_collection.schema) for i in range(doc_num) ] batchdoc_and_check(full_collection, multiple_docs, doc_num) multiple_update_docs = [ generate_update_doc(i, 
full_collection.schema) for i in range(doc_num) ] batchdoc_and_check( full_collection, multiple_update_docs, doc_num, operator="update" ) def test_empty_collection_update(self, full_collection: Collection): updated_doc = generate_update_doc(1, full_collection.schema) result = full_collection.update(updated_doc) assert bool(result), f"Expected 1 result, but got {len(result)}" assert result.code() == StatusCode.NOT_FOUND, ( f"Update operation should fail with NOT_FOUND, but got code {result.code()}" ) fetched_docs = full_collection.fetch([updated_doc.id]) assert len(fetched_docs) == 0 stats = full_collection.stats assert stats is not None, "Collection stats should not be None" assert stats.doc_count == 0, ( f"Document count should be 0, but got {stats.doc_count}" ) @pytest.mark.parametrize("doc_num", [1, 5, Maximum]) def test_empty_collection_update_batch(self, full_collection: Collection, doc_num): multiple_update_docs = [ generate_update_doc(i, full_collection.schema) for i in range(doc_num) ] result = full_collection.update(multiple_update_docs) assert len(result) == len(multiple_update_docs), ( f"Expected {len(multiple_update_docs)} results, but got {len(result)}" ) for item in result: assert item.code() == StatusCode.NOT_FOUND, ( f"Update operation should fail with NOT_FOUND, but got code {item.code()}" ) stats = full_collection.stats assert stats is not None, "Collection stats should not be None" assert stats.doc_count == 0, ( f"Document count should be 0, but got {stats.doc_count}" ) doc_ids = [doc.id for doc in multiple_update_docs] fetched_docs = full_collection.fetch(doc_ids) assert len(fetched_docs) == 0 @pytest.mark.parametrize("field_name, field_values", FIELD_VALUE_VALID_LIST) @pytest.mark.parametrize( "full_schema_new", [(True, True, HnswIndexParam()), (False, True, HnswIndexParam())], indirect=True, ) def test_update_fields_valid( self, full_collection_new: Collection, field_name: str, field_values, request ): for i, field_value in 
enumerate(field_values): insert_doc = generate_doc(i, full_collection_new.schema) singledoc_and_check(full_collection_new, insert_doc, is_delete=0) update_doc_fields, update_doc_vectors = generate_vectordict_random( full_collection_new.schema ) full_schema_params = request.getfixturevalue("full_schema_new") target_field = None for field in full_schema_params.fields: if field.name == field_name: target_field = field break update_doc_fields[field_name] = field_value update_doc = Doc( id=str(i), fields=update_doc_fields, vectors=update_doc_vectors ) if target_field and not target_field.nullable and field_value is None: with pytest.raises(Exception) as exc_info: update_doc_fields[field_name] = field_value full_collection_new.update(update_doc) assert exc_info.value is not None full_collection_new.delete(insert_doc.id) else: singledoc_and_check( full_collection_new, update_doc, operator="update", is_delete=1 ) @pytest.mark.parametrize("field_name, field_values", FIELD_VALUE_INVALID_LIST) def test_update_fields_invalid( self, full_collection: Collection, field_name: str, field_values ): for i, field_value in enumerate(field_values): insert_doc = generate_doc(i, full_collection.schema) singledoc_and_check(full_collection, insert_doc, is_delete=0) update_doc_fields, update_doc_vectors = generate_vectordict_random( full_collection.schema ) update_doc_fields[field_name] = field_value update_doc = Doc( id=str(i), fields=update_doc_fields, vectors=update_doc_vectors ) with pytest.raises(Exception) as exc_info: full_collection.update(update_doc) assert exc_info.value is not None full_collection.delete(insert_doc.id) stats = full_collection.stats assert stats is not None assert stats.doc_count == 0 @pytest.mark.parametrize("vector_field, vector_values", VECTOR_VALUE_VALID_LIST) def test_update_doc_vector_valid( self, full_collection: Collection, collection_temp_dir, collection_option, vector_field: str, vector_values, ): for i, vector_value in enumerate(vector_values): 
insert_doc = generate_doc(i, full_collection.schema) singledoc_and_check(full_collection, insert_doc, is_delete=0) update_doc_fields, update_doc_vectors = generate_vectordict_random( full_collection.schema ) update_doc_vectors[vector_field] = vector_value update_doc = Doc( id=str(i), fields=update_doc_fields, vectors=update_doc_vectors ) singledoc_and_check(full_collection, update_doc, operator="update") @pytest.mark.parametrize("vector_field, vector_values", VECTOR_VALUE_INVALID_LIST) def test_update_doc_vector_invalid( self, full_collection: Collection, collection_temp_dir, collection_option, vector_field: str, vector_values, ): for i, vector_value in enumerate(vector_values): insert_doc = generate_doc(i, full_collection.schema) singledoc_and_check(full_collection, insert_doc, is_delete=0) update_doc_fields, update_doc_vectors = generate_vectordict_random( full_collection.schema ) update_doc_vectors[vector_field] = vector_value update_doc = Doc( id=str(i), fields=update_doc_fields, vectors=update_doc_vectors ) with pytest.raises(Exception) as exc_info: full_collection.update(update_doc) assert exc_info.value is not None full_collection.delete(insert_doc.id) stats = full_collection.stats assert stats is not None assert stats.doc_count == 0 @pytest.mark.parametrize( "update_type, fields_to_update, vectors_to_update", UPDATE_PARTIAL_VALUE ) def test_update_partial_fields( self, full_collection: Collection, collection_temp_dir, collection_option, update_type: str, fields_to_update: dict, vectors_to_update: dict, doc_id=1, ): insert_doc = generate_doc(doc_id, full_collection.schema) singledoc_and_check(full_collection, insert_doc, is_delete=0) update_doc_fields, update_doc_vectors = insert_doc.fields, insert_doc.vectors for k, v in fields_to_update.items(): update_doc_fields[k] = v for k, v in vectors_to_update.items(): update_doc_vectors[k] = v update_doc_full = Doc( id=str(doc_id), fields=update_doc_fields, vectors=update_doc_vectors ) update_doc_partial = Doc( 
id=str(doc_id), fields=fields_to_update, vectors=vectors_to_update ) updatedoc_partial_check( full_collection, update_doc_partial, update_doc_full, operator="update", is_delete=1, ) class TestCollectionUpsert: def test_new_doc_upsert(self, full_collection: Collection): single_doc = generate_doc(1, full_collection.schema) singledoc_and_check(full_collection, single_doc, operator="upsert", is_delete=1) @pytest.mark.parametrize("doc_num", [1, 5, Maximum]) def test_new_doc_upsert_batch(self, full_collection: Collection, doc_num): multiple_docs = [ generate_doc(i, full_collection.schema) for i in range(doc_num) ] batchdoc_and_check(full_collection, multiple_docs, doc_num, operator="upsert") def test_existing_doc_upsert(self, full_collection: Collection): insert_doc = generate_doc(1, full_collection.schema) singledoc_and_check(full_collection, insert_doc, is_delete=0) updated_doc = generate_update_doc(1, full_collection.schema) singledoc_and_check(full_collection, updated_doc, operator="upsert") @pytest.mark.parametrize("doc_id", DOCID_VALID_LIST) def test_upsert_docid_valid(self, full_collection: Collection, doc_id): upsert_doc = generate_doc_random(doc_id, full_collection.schema) singledoc_and_check(full_collection, upsert_doc, operator="upsert", is_delete=1) @pytest.mark.parametrize("doc_id", DOCID_INVALID_LIST) def test_upsert_docid_invalid(self, full_collection: Collection, doc_id): upsert_doc = generate_doc_random(doc_id, full_collection.schema) with pytest.raises(Exception) as exc_info: full_collection.upsert(upsert_doc) assert exc_info.value is not None stats = full_collection.stats assert stats is not None assert stats.doc_count == 0 @pytest.mark.parametrize("field_name, field_values", FIELD_VALUE_VALID_LIST) @pytest.mark.parametrize( "full_schema_new", [(True, True, HnswIndexParam()), (False, True, HnswIndexParam())], indirect=True, ) def test_upsert_fields_valid( self, full_collection_new: Collection, field_name: str, field_values, request ): for i, 
field_value in enumerate(field_values): doc_id = str(field_value) if field_name == "id" else str(i) doc_fields, doc_vectors = generate_vectordict_random( full_collection_new.schema ) full_schema_params = request.getfixturevalue("full_schema_new") target_field = None for field in full_schema_params.fields: if field.name == field_name: target_field = field break doc_fields[field_name] = field_value upsert_doc = Doc(id=doc_id, fields=doc_fields, vectors=doc_vectors) if target_field and not target_field.nullable and field_value is None: with pytest.raises(Exception) as exc_info: full_collection_new.upsert(upsert_doc) assert exc_info.value is not None else: singledoc_and_check( full_collection_new, upsert_doc, operator="upsert", is_delete=1 ) @pytest.mark.parametrize("field_name, field_values", FIELD_VALUE_INVALID_LIST) def test_upsert_fields_invalid( self, full_collection: Collection, field_name: str, field_values ): for i, field_value in enumerate(field_values): doc_id = str(field_value) if field_name == "id" else str(i) doc_fields, doc_vectors = generate_vectordict_random(full_collection.schema) doc_fields[field_name] = field_value upsert_doc = Doc(id=doc_id, fields=doc_fields, vectors=doc_vectors) with pytest.raises(Exception) as exc_info: full_collection.upsert(upsert_doc) assert exc_info.value is not None stats = full_collection.stats assert stats is not None assert stats.doc_count == 0 @pytest.mark.parametrize("vector_field, vector_values", VECTOR_VALUE_VALID_LIST) def test_upsert_vector_valid( self, full_collection: Collection, vector_field: str, vector_values ): for i, vector_value in enumerate(vector_values): doc_fields, doc_vectors = generate_vectordict_random(full_collection.schema) doc_vectors[vector_field] = vector_value upsert_doc = Doc(id=str(i), fields=doc_fields, vectors=doc_vectors) singledoc_and_check( full_collection, upsert_doc, operator="upsert", is_delete=1 ) @pytest.mark.parametrize("vector_field, vector_values", VECTOR_VALUE_INVALID_LIST) def 
test_upsert_vector_invalid( self, full_collection: Collection, vector_field: str, vector_values ): for i, vector_value in enumerate(vector_values): doc_fields, doc_vectors = generate_vectordict_random(full_collection.schema) doc_vectors[vector_field] = vector_value upsert_doc = Doc(id=str(i), fields=doc_fields, vectors=doc_vectors) with pytest.raises(Exception) as exc_info: full_collection.upsert(upsert_doc) assert exc_info.value is not None stats = full_collection.stats assert stats is not None assert stats.doc_count == 0 class TestCollectionDelete: @pytest.mark.parametrize("doc_num", [1, 5, Maximum]) def test_delete_batch(self, full_collection: Collection, doc_num): multiple_docs = [ generate_doc(i, full_collection.schema) for i in range(doc_num) ] batchdoc_and_check(full_collection, multiple_docs, doc_num, operator="insert") doc_ids = [doc.id for doc in multiple_docs] result = full_collection.delete(doc_ids) assert len(result) == len(doc_ids) for item in result: assert item.ok() def test_delete_non_exist(self, full_collection: Collection): result = full_collection.delete("non_existing_id") assert result.code().value == 1 assert result.code() == StatusCode.NOT_FOUND @pytest.mark.parametrize("doc_num", [5]) def test_delete_batch_part_non_exist(self, full_collection: Collection, doc_num): multiple_docs = [ generate_doc(i, full_collection.schema) for i in range(doc_num) ] batchdoc_and_check(full_collection, multiple_docs, doc_num, operator="insert") doc_ids = [doc.id for doc in multiple_docs] doc_ids.extend([str(doc_num), str(doc_num + 1)]) result = full_collection.delete(doc_ids) assert len(result) == len(doc_ids) for i in range(len(result)): if i < doc_num: assert result[i].ok() else: assert result[i].code().value == 1 assert result[i].code() == StatusCode.NOT_FOUND @pytest.mark.parametrize("doc_num", [5]) def test_delete_by_filter(self, full_collection: Collection, doc_num): multiple_docs = [ generate_doc(i, full_collection.schema) for i in range(doc_num) ] 
batchdoc_and_check(full_collection, multiple_docs, doc_num, operator="insert") result = full_collection.delete_by_filter("int32_field > 0") assert result is None def test_delete_empty_ids(self, full_collection: Collection): result = full_collection.delete([]) assert len(result) == 0 ================================================ FILE: python/tests/detail/test_collection_dql.py ================================================ # Copyright 2025-present the zvec project # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from zvec.typing import DataType, StatusCode, MetricType, QuantizeType from zvec.model import Collection, Doc, VectorQuery from zvec.model.param import ( CollectionOption, InvertIndexParam, HnswIndexParam, FlatIndexParam, IVFIndexParam, HnswQueryParam, IVFQueryParam, ) from zvec.model.schema import FieldSchema, VectorSchema from zvec.extension import RrfReRanker, WeightedReRanker, QwenReRanker from distance_helper import * from zvec import StatusCode from distance_helper import * from fixture_helper import * from doc_helper import * from params_helper import * # ==================== helper ==================== def batchdoc_and_check( collection: Collection, multiple_docs, doc_num, operator="insert" ): if operator == "insert": result = collection.insert(multiple_docs) elif operator == "upsert": result = collection.upsert(multiple_docs) elif operator == "update": result = collection.update(multiple_docs) else: logging.error("operator value is error!") assert len(result) == len(multiple_docs) for 
item in result: assert item.ok(), ( f"result={result},Insert operation failed with code {item.code()}" ) stats = collection.stats assert stats is not None, "Collection stats should not be None" assert stats.doc_count == len(multiple_docs), ( f"Document count should be {len(multiple_docs)} after insert, but got {stats.doc_count}" ) doc_ids = [doc.id for doc in multiple_docs] fetched_docs = collection.fetch(doc_ids) assert len(fetched_docs) == len(multiple_docs), ( f"fetched_docs={fetched_docs},Expected {len(multiple_docs)} fetched documents, but got {len(fetched_docs)}" ) for original_doc in multiple_docs: assert original_doc.id in fetched_docs, ( f"Expected document ID {original_doc.id} in fetched documents" ) fetched_doc = fetched_docs[original_doc.id] assert is_doc_equal(fetched_doc, original_doc, collection.schema) assert hasattr(fetched_doc, "score"), "Document should have a score attribute" assert fetched_doc.score == 0.0, ( "Fetch operation should return default score of 0.0" ) first_doc = multiple_docs[doc_num - 1] for k, v in DEFAULT_VECTOR_FIELD_NAME.items(): query_result = collection.query( VectorQuery(field_name=v, vector=first_doc.vectors[v]), topk=1024, include_vector=True, ) assert len(query_result) > 0, ( f"Expected at least 1 query result, but got {len(query_result)}" ) found_doc = None for doc in query_result: if doc.id == first_doc.id: found_doc = doc break assert found_doc is not None, ( f"Inserted document {first_doc.id} not found in query results" ) assert is_doc_equal(found_doc, first_doc, collection.schema) assert hasattr(found_doc, "score") assert isinstance(found_doc.score, (int, float)) def batchdoc_and_check_ivf( collection: Collection, multiple_docs, doc_num, operator="insert" ): if operator == "insert": result = collection.insert(multiple_docs) elif operator == "upsert": result = collection.upsert(multiple_docs) elif operator == "update": result = collection.update(multiple_docs) else: logging.error("operator value is error!") assert 
len(result) == len(multiple_docs) for item in result: assert item.ok(), ( f"result={result},Insert operation failed with code {item.code()}" ) stats = collection.stats assert stats is not None, "Collection stats should not be None" assert stats.doc_count == len(multiple_docs), ( f"Document count should be {len(multiple_docs)} after insert, but got {stats.doc_count}" ) doc_ids = [doc.id for doc in multiple_docs] fetched_docs = collection.fetch(doc_ids) assert len(fetched_docs) == len(multiple_docs), ( f"fetched_docs={fetched_docs},Expected {len(multiple_docs)} fetched documents, but got {len(fetched_docs)}" ) for original_doc in multiple_docs: assert original_doc.id in fetched_docs, ( f"Expected document ID {original_doc.id} in fetched documents" ) fetched_doc = fetched_docs[original_doc.id] assert is_doc_equal(fetched_doc, original_doc, collection.schema) assert hasattr(fetched_doc, "score"), "Document should have a score attribute" assert fetched_doc.score == 0.0, ( "Fetch operation should return default score of 0.0" ) first_doc = multiple_docs[doc_num - 1] for k, v in DEFAULT_VECTOR_FIELD_NAME.items(): if v in ["vector_fp16_field", "vector_fp32_field"]: query_result = collection.query( VectorQuery(field_name=v, vector=first_doc.vectors[v]), topk=1024, include_vector=True, ) assert len(query_result) > 0, ( f"Expected at least 1 query result, but got {len(query_result)}" ) found_doc = None for doc in query_result: if doc.id == first_doc.id: found_doc = doc break assert found_doc is not None, ( f"Inserted document {first_doc.id} not found in query results" ) assert is_doc_equal(found_doc, first_doc, collection.schema) assert hasattr(found_doc, "score") assert isinstance(found_doc.score, (int, float)) def single_querydoc_check( multiple_docs, query_result, full_collection: Collection, is_by_vector=0, query_vector=None, data_type=None, vector_name=None, metric_type=MetricType.IP, id_include_vector: bool = False, is_output_fields=0, ): for original_doc in 
multiple_docs: for doc in query_result: if doc.id == original_doc.id: found_doc = doc if is_output_fields == 0: assert is_doc_equal( found_doc, original_doc, full_collection.schema, True, id_include_vector, ) assert hasattr(found_doc, "score") # assert found_doc.score >= 0.0 if not id_include_vector: for k, v in DEFAULT_VECTOR_FIELD_NAME.items(): assert found_doc.vector(v) == {} else: for k, v in DEFAULT_VECTOR_FIELD_NAME.items(): assert found_doc.vector(v) != {} if is_by_vector: prev_score = float("inf") for i, doc in enumerate(query_result): doc_vector = full_collection.fetch(doc.id)[doc.id].vector( vector_name ) expected_score = distance( query_vector, doc_vector, metric_type, data_type, k ) if ( full_collection.schema.vector(vector_name).data_type != DataType.VECTOR_FP16 ): assert abs(doc.score - expected_score) < 0.001, ( f"{data_type} {vector_name} :Expected score {expected_score:.6f}, but got {doc.score:.6f} for document {doc.id}" ) assert doc.score <= prev_score, ( f"{data_type} {vector_name} :Scores should be in descending order. 
Current: {doc.score}, Previous: {prev_score}" ) prev_score = doc.score def multi_querydoc_check(multiple_docs, query_result, full_collection): for original_doc in multiple_docs: for doc in query_result: if doc.id == original_doc.id: found_doc = doc assert is_doc_equal( found_doc, original_doc, full_collection.schema, False, False ) assert hasattr(found_doc, "score"), ( "Document should have a score attribute" ) assert found_doc.score >= 0.0, ( "Fetch operation should return default score of 0.0" ) for k, v in DEFAULT_VECTOR_FIELD_NAME.items(): assert found_doc.vector(v) == {} # ==================== Tests ==================== class TestCollectionFetch: def test_fetch_non_existing(self, full_collection: Collection): result = full_collection.fetch(ids=["non_existing_id1", "non_existing_id2"]) assert len(result) == 0 @pytest.mark.parametrize("doc_num", [3]) def test_fetch_partial_non_existing(self, full_collection: Collection, doc_num): multiple_docs = [ generate_doc(i, full_collection.schema) for i in range(doc_num) ] batchdoc_and_check(full_collection, multiple_docs, doc_num, operator="insert") fetch_id_list = [doc.id for doc in multiple_docs] fetch_id_list.append("non_existing_id") result = full_collection.fetch(ids=fetch_id_list) assert len(result) == doc_num assert "non_existing_id" not in result.keys() def test_fetch_empty_ids(self, full_collection: Collection): result = full_collection.fetch(ids=[]) assert len(result) == 0, ( f"Expected 0 results for empty ID list, but got {len(result)}" ) class TestCollectionQuery: @pytest.mark.parametrize("doc_num", [5]) def test_query_with_no_condition(self, full_collection: Collection, doc_num): multiple_docs = [ generate_doc(i, full_collection.schema) for i in range(doc_num) ] batchdoc_and_check(full_collection, multiple_docs, doc_num, operator="insert") query_result = full_collection.query() assert len(query_result) == doc_num single_querydoc_check(multiple_docs, query_result, full_collection) 
@pytest.mark.parametrize("doc_num", [10]) def test_query_with_filter_empty(self, full_collection: Collection, doc_num): multiple_docs = [ generate_doc(i, full_collection.schema) for i in range(doc_num) ] batchdoc_and_check(full_collection, multiple_docs, doc_num, operator="insert") result1 = full_collection.query(filter="") assert len(result1) == doc_num single_querydoc_check(multiple_docs, result1, full_collection) result2 = full_collection.query(filter=None) assert len(result2) == doc_num single_querydoc_check(multiple_docs, result2, full_collection) ids1 = set(doc.id for doc in result1) ids2 = set(doc.id for doc in result2) assert ids1 == ids2 @pytest.mark.parametrize("field_name", ["int32_field"]) @pytest.mark.parametrize("doc_num", [10]) def test_query_with_filter_single_condition( self, full_collection: Collection, doc_num, field_name ): multiple_docs = [ generate_doc(i, full_collection.schema) for i in range(doc_num) ] batchdoc_and_check(full_collection, multiple_docs, doc_num, operator="insert") filter = field_name + " > 5" query_result = full_collection.query(filter=filter) assert len(query_result) == doc_num - 6 returned_doc_ids = set() for doc in query_result: returned_doc_ids.add(doc.id) expected_doc_ids = set(str(i) for i in range(6, doc_num)) for doc in query_result: assert doc.id in expected_doc_ids assert int(doc.field(field_name)) > 5 single_querydoc_check(multiple_docs, query_result, full_collection) @pytest.mark.parametrize("field_name", ["int32_field"]) @pytest.mark.parametrize( "filter", [ "int32_field > 3 and int32_field < 9", "int32_field >= 5 and int32_field <= 7", ], ) @pytest.mark.parametrize("doc_num", [10]) def test_query_with_filter_and( self, full_collection: Collection, doc_num, field_name, filter ): multiple_docs = [ generate_doc(i, full_collection.schema) for i in range(doc_num) ] batchdoc_and_check(full_collection, multiple_docs, doc_num, operator="insert") filter = field_name + " > 3 and " + field_name + " < 9" query_result = 
full_collection.query(filter=filter) if filter == "int32_field > 3 and int32_field < 9": assert len(query_result) == doc_num - 4 - 1 expected_doc_ids = set(str(i) for i in range(4, 9)) for doc in query_result: assert doc.id in expected_doc_ids field_value = int(doc.field(field_name)) assert field_value > 3 and field_value < 9 else: assert len(query_result) == 3 expected_doc_ids = set(str(i) for i in range(5, 8)) for doc in query_result: assert doc.id in expected_doc_ids field_value = int(doc.field(field_name)) assert field_value >= 5 and field_value <= 7 single_querydoc_check(multiple_docs, query_result, full_collection) @pytest.mark.parametrize("field_name", ["int32_field"]) @pytest.mark.parametrize( "filter", [ "int32_field < 3 or int32_field > 8", "int32_field = 3 or int32_field = 7", "int32_field <= 3 or int32_field >= 8", ], ) @pytest.mark.parametrize("doc_num", [10]) def test_query_with_filter_or( self, full_collection: Collection, doc_num, field_name, filter ): multiple_docs = [ generate_doc(i, full_collection.schema) for i in range(doc_num) ] batchdoc_and_check(full_collection, multiple_docs, doc_num, operator="insert") query_result = full_collection.query(filter=filter) if filter == "int32_field < 3 or int32_field > 8": assert len(query_result) == 4 expected_doc_ids = set([str(0), str(1), str(2), str(9)]) for doc in query_result: assert doc.id in expected_doc_ids field_value = int(doc.field(field_name)) assert field_value < 3 or field_value > 8 elif filter == "int32_field = 3 or int32_field = 7": assert len(query_result) == 2 expected_doc_ids = set([str(3), str(7)]) for doc in query_result: assert doc.id in expected_doc_ids field_value = int(doc.field(field_name)) assert field_value == 3 or field_value == 7 else: assert len(query_result) == 6 expected_doc_ids = set([str(0), str(1), str(2), str(3), str(8), str(9)]) for doc in query_result: assert doc.id in expected_doc_ids field_value = int(doc.field(field_name)) assert field_value <= 3 or field_value >= 8 
single_querydoc_check(multiple_docs, query_result, full_collection) @pytest.mark.parametrize("field_names", [("int32_field", "bool_field")]) @pytest.mark.parametrize( "filter", [ "(int32_field < 3 or int32_field > 8) and bool_field = false", "(int32_field > 2 and int32_field < 5) or (int32_field > 7 and bool_field = true)", ], ) @pytest.mark.parametrize("doc_num", [10]) def test_query_with_filter_parentheses( self, full_collection: Collection, doc_num, field_names, filter ): multiple_docs = [ generate_doc(i, full_collection.schema) for i in range(doc_num) ] batchdoc_and_check(full_collection, multiple_docs, doc_num, operator="insert") query_result = full_collection.query(filter=filter) if filter == "(int32_field < 3 or int32_field > 8) and bool_field = false": assert len(query_result) == 2 expected_doc_ids = set([str(1), str(9)]) for doc in query_result: assert doc.id in expected_doc_ids assert ( int(doc.field(field_names[0])) < 3 or int(doc.field(field_names[0])) > 8 ) and doc.field(field_names[1]) == False else: assert len(query_result) == 3 expected_doc_ids = set([str(3), str(4), str(8)]) for doc in query_result: assert doc.id in expected_doc_ids assert ( ( int(doc.field(field_names[0])) > 2 and int(doc.field(field_names[0])) < 5 ) or (doc.field(field_names[0])) > 7 and doc.field(field_names[1]) == True ) single_querydoc_check(multiple_docs, query_result, full_collection) @pytest.mark.parametrize( "filter", [ "int32_field >", "int32_field = 'string'", "nonexistent_field = 5", "int32_field > 5 and", "int32_field > > 5", ], ) @pytest.mark.parametrize("doc_num", [10]) def test_query_filter_invalid(self, full_collection: Collection, doc_num, filter): multiple_docs = [ generate_doc(i, full_collection.schema) for i in range(doc_num) ] batchdoc_and_check(full_collection, multiple_docs, doc_num, operator="insert") with pytest.raises(Exception) as exc_info: full_collection.query(filter=filter) if filter in ["int32_field = 'string'", "nonexistent_field = 5"]: assert 
"Analyze sql info failed" in str(exc_info.value) else: assert "Invalid filter" in str(exc_info.value) @pytest.mark.parametrize("field_name", ["int32_field"]) @pytest.mark.parametrize("topk_value", [1, 5, 10, 50, 100, 500, 1000, 1024]) def test_query_with_filter_topk_valid( self, full_collection: Collection, topk_value: int, field_name ): multiple_docs = [ generate_doc(i, full_collection.schema) for i in range(topk_value) ] batchdoc_and_check( full_collection, multiple_docs, topk_value, operator="insert" ) filter = ( field_name + f" >={topk_value - 1} and " + field_name + f" <={topk_value}" ) print("filter:\n") print(filter) query_result = full_collection.query(filter=filter, topk=topk_value) assert len(query_result) == 1 expected_doc_ids = [str(topk_value - 1)] for doc in query_result: assert doc.id in expected_doc_ids field_value = int(doc.field(field_name)) assert field_value >= topk_value - 1 and field_value <= topk_value single_querydoc_check(multiple_docs, query_result, full_collection) @pytest.mark.parametrize("field_name", ["int32_field"]) @pytest.mark.parametrize("topk_value", [1, 5, 10, 50, 100, 500, 1000, 1024]) def test_query_without_filter_topk_valid( self, full_collection: Collection, topk_value: int, field_name ): multiple_docs = [ generate_doc(i, full_collection.schema) for i in range(topk_value) ] batchdoc_and_check( full_collection, multiple_docs, topk_value, operator="insert" ) query_result = full_collection.query(topk=topk_value) assert len(query_result) == topk_value single_querydoc_check(multiple_docs, query_result, full_collection) @pytest.mark.parametrize("doc_num", [10]) def test_query_with_include_vector(self, full_collection: Collection, doc_num): multiple_docs = [ generate_doc(i, full_collection.schema) for i in range(doc_num) ] batchdoc_and_check(full_collection, multiple_docs, doc_num, operator="insert") query_result = full_collection.query(include_vector=True) assert len(query_result) > 0 single_querydoc_check( multiple_docs, 
query_result, full_collection, id_include_vector=1 ) @pytest.mark.parametrize("output_fields", [["int32_field", "int64_field"]]) @pytest.mark.parametrize("doc_num", [10]) def test_query_with_output_fields( self, full_collection: Collection, doc_num, output_fields ): multiple_docs = [ generate_doc(i, full_collection.schema) for i in range(doc_num) ] batchdoc_and_check(full_collection, multiple_docs, doc_num, operator="insert") query_result = full_collection.query(output_fields=output_fields) assert len(query_result) > 0 for doc in query_result: field_names = doc.field_names() assert field_names == output_fields @pytest.mark.parametrize( "filter", [ "int32_field >= 10 and int32_field <= 20", "int32_field = 3 and int32_field = 8", ], ) @pytest.mark.parametrize("doc_num", [10]) def test_query_empty_result(self, full_collection: Collection, doc_num, filter): multiple_docs = [ generate_doc(i, full_collection.schema) for i in range(doc_num) ] batchdoc_and_check(full_collection, multiple_docs, doc_num, operator="insert") result = full_collection.query(filter=filter) assert len(result) == 0 @pytest.mark.parametrize( "full_schema_new", [(True, True, HnswIndexParam()), (False, True, FlatIndexParam())], indirect=True, ) @pytest.mark.parametrize("doc_num", [10]) def test_query_by_id( self, full_collection_new: Collection, doc_num, full_schema_new ): multiple_docs = [ generate_doc(i, full_collection_new.schema) for i in range(doc_num) ] batchdoc_and_check( full_collection_new, multiple_docs, doc_num, operator="insert" ) for k, v in DEFAULT_VECTOR_FIELD_NAME.items(): query_result = full_collection_new.query(VectorQuery(field_name=v, id="1")) assert len(query_result) > 0 query_doc = full_collection_new.fetch(ids=["1"]) query_vector = query_doc["1"].vector(v) single_querydoc_check( multiple_docs, query_result, full_collection_new, is_by_vector=1, query_vector=query_vector, data_type=k, vector_name=v, ) @pytest.mark.parametrize("doc_num", [10]) def test_query_by_id_ivf(self, 
full_collection_ivf: Collection, doc_num): multiple_docs = [ generate_doc(i, full_collection_ivf.schema) for i in range(doc_num) ] batchdoc_and_check_ivf( full_collection_ivf, multiple_docs, doc_num, operator="insert" ) for k, v in DEFAULT_VECTOR_FIELD_NAME.items(): if v in ["vector_fp16_field", "vector_fp32_field"]: query_result = full_collection_ivf.query( VectorQuery(field_name=v, id="1") ) assert len(query_result) > 0 query_doc = full_collection_ivf.fetch(ids=["1"]) query_vector = query_doc["1"].vector(v) single_querydoc_check( multiple_docs, query_result, full_collection_ivf, is_by_vector=1, query_vector=query_vector, data_type=k, vector_name=v, ) @pytest.mark.parametrize( "full_schema_new", [(True, True, HnswIndexParam()), (False, True, FlatIndexParam())], indirect=True, ) @pytest.mark.parametrize("doc_num", [10]) @pytest.mark.parametrize("topk", [None, 1024]) @pytest.mark.parametrize("filter", [None, "int32_field >= 3 and int32_field <= 7"]) def test_query_by_vector( self, full_collection_new: Collection, doc_num, full_schema_new, topk, filter ): multiple_docs = [ generate_doc(i, full_collection_new.schema) for i in range(doc_num) ] batchdoc_and_check( full_collection_new, multiple_docs, doc_num, operator="insert" ) for k, v in DEFAULT_VECTOR_FIELD_NAME.items(): doc_fields, doc_vectors = generate_vectordict_random( full_collection_new.schema ) query_vector = doc_vectors[v] if topk and filter: query_result = full_collection_new.query( filter=filter, vectors=VectorQuery(field_name=v, vector=query_vector), topk=topk, ) elif topk and not filter: query_result = full_collection_new.query( VectorQuery(field_name=v, vector=query_vector), topk=topk ) elif not topk and filter: query_result = full_collection_new.query( filter=filter, vectors=VectorQuery(field_name=v, vector=query_vector), ) else: query_result = full_collection_new.query( VectorQuery(field_name=v, vector=query_vector) ) assert len(query_result) > 0, ( f"Expected at least 1 query result, but got 
{len(query_result)}" ) single_querydoc_check( multiple_docs, query_result, full_collection_new, is_by_vector=1, query_vector=query_vector, data_type=k, vector_name=v, ) @pytest.mark.parametrize("doc_num", [10]) def test_query_by_vector_ivf(self, full_collection_ivf: Collection, doc_num): multiple_docs = [ generate_doc(i, full_collection_ivf.schema) for i in range(doc_num) ] batchdoc_and_check_ivf( full_collection_ivf, multiple_docs, doc_num, operator="insert" ) for k, v in DEFAULT_VECTOR_FIELD_NAME.items(): if v in ["vector_fp16_field", "vector_fp32_field"]: doc_fields, doc_vectors = generate_vectordict_random( full_collection_ivf.schema ) query_vector = doc_vectors[v] query_result = full_collection_ivf.query( VectorQuery(field_name=v, vector=query_vector), topk=1024, ) assert len(query_result) > 0, ( f"Expected at least 1 query result, but got {len(query_result)}" ) single_querydoc_check( multiple_docs, query_result, full_collection_ivf, is_by_vector=1, query_vector=query_vector, data_type=k, vector_name=v, ) @pytest.mark.parametrize("doc_num", [10]) def test_query_multivector_rrf(self, full_collection: Collection, doc_num): multiple_docs = [ generate_doc(i, full_collection.schema) for i in range(doc_num) ] batchdoc_and_check(full_collection, multiple_docs, doc_num, operator="insert") doc_fields, doc_vectors = generate_vectordict_random(full_collection.schema) single_query_results = {} for k, v in DEFAULT_VECTOR_FIELD_NAME.items(): single_query_results[v] = full_collection.query( VectorQuery(field_name=v, vector=doc_vectors[v]) ) expected_rrf_scores = calculate_multi_vector_rrf_scores(single_query_results) multi_query_vectors = [] for k, v in DEFAULT_VECTOR_FIELD_NAME.items(): multi_query_vectors.append(VectorQuery(field_name=v, vector=doc_vectors[v])) rrf_reranker = RrfReRanker(topn=3) multi_query_result = full_collection.query( vectors=multi_query_vectors, reranker=rrf_reranker, ) assert len(multi_query_result) > 0, ( f"Expected at least 1 result, but got 
{len(multi_query_result)}" ) multi_querydoc_check(multiple_docs, multi_query_result, full_collection) prev_score = float("inf") for i, doc in enumerate(multi_query_result): doc_id = doc.id assert doc_id in expected_rrf_scores, ( f"Document {doc_id} should be in expected RRF scores" ) expected_score = expected_rrf_scores[doc_id] actual_score = doc.score assert abs(actual_score - expected_score) < 1e-10, ( f"RRF score mismatch for document {doc_id}: expected {expected_score}, got {actual_score}" ) assert doc.score <= prev_score, ( f"Scores should be in descending order. Current: {doc.score}, Previous: {prev_score}" ) prev_score = doc.score @pytest.mark.parametrize( "weights", [ { "vector_fp32_field": 0.3, "vector_fp16_field": 0.2, "vector_int8_field": 0.3, "sparse_vector_fp32_field": 0.1, "sparse_vector_fp16_field": 0.1, } ], ) @pytest.mark.parametrize( "metric_type", [MetricType.L2, MetricType.IP, MetricType.COSINE] ) @pytest.mark.parametrize("doc_num", [10]) def test_query_multivector_weighted( self, full_collection: Collection, doc_num, weights, metric_type ): multiple_docs = [ generate_doc(i, full_collection.schema) for i in range(doc_num) ] batchdoc_and_check(full_collection, multiple_docs, doc_num, operator="insert") doc_fields, doc_vectors = generate_vectordict_random(full_collection.schema) weighted_reranker = WeightedReRanker( topn=3, weights=weights, metric=MetricType.IP ) single_query_results = {} for k, v in DEFAULT_VECTOR_FIELD_NAME.items(): single_query_results[v] = full_collection.query( VectorQuery(field_name=v, vector=doc_vectors[v]) ) expected_weighted_scores = calculate_multi_vector_weighted_scores( single_query_results, weights, MetricType.IP ) multi_query_vectors = [] for k, v in DEFAULT_VECTOR_FIELD_NAME.items(): multi_query_vectors.append(VectorQuery(field_name=v, vector=doc_vectors[v])) multi_query_result = full_collection.query( vectors=multi_query_vectors, reranker=weighted_reranker, ) assert len(multi_query_result) > 0, ( f"Expected at 
least 1 result, but got {len(multi_query_result)}" ) multi_querydoc_check(multiple_docs, multi_query_result, full_collection) prev_score = float("inf") for i, doc in enumerate(multi_query_result): doc_id = doc.id assert doc_id in expected_weighted_scores, ( f"Document {doc_id} should be in expected scores" ) expected_score = expected_weighted_scores[doc_id] actual_score = doc.score assert abs(actual_score - expected_score) < 1e-10, ( f"score mismatch for document {doc_id}: expected {expected_score}, got {actual_score}" ) assert doc.score <= prev_score, ( f"Scores should be in descending order. Current: {doc.score}, Previous: {prev_score}" ) prev_score = doc.score @pytest.mark.parametrize("topk", [5]) @pytest.mark.parametrize("doc_num", [10]) @pytest.mark.parametrize("filter", ["int32_field >= 3 and int32_field <= 7"]) def test_query_consistency( self, full_collection: Collection, filter, doc_num, topk ): multiple_docs = [ generate_doc(i, full_collection.schema) for i in range(doc_num) ] batchdoc_and_check(full_collection, multiple_docs, doc_num, operator="insert") results = [] for i in range(5): query_result = full_collection.query(filter=filter, topk=topk) single_querydoc_check(multiple_docs, query_result, full_collection) results.append(query_result) assert len(results) == 5 expected_count = len(results[0]) for i, result in enumerate(results): assert len(result) == expected_count expected_ids = set(doc.id for doc in results[0]) for i, result in enumerate(results): result_ids = set(doc.id for doc in result) assert result_ids == expected_ids for i, result in enumerate(results): result_ids = [doc.id for doc in result] expected_sorted_ids = sorted(result_ids, key=lambda x: int(x)) assert result_ids == expected_sorted_ids @pytest.mark.parametrize("ef", [0, 100, 1024, 2048]) @pytest.mark.parametrize("doc_num", [10]) @pytest.mark.parametrize("topk", [1024]) @pytest.mark.parametrize("filter", ["int32_field >= 3 and int32_field <= 7"]) def 
test_query_vector_with_HnswQueryParam_valid( self, full_collection_new: Collection, doc_num, full_schema_new, topk, filter, ef, ): multiple_docs = [ generate_doc(i, full_collection_new.schema) for i in range(doc_num) ] batchdoc_and_check( full_collection_new, multiple_docs, doc_num, operator="insert" ) for k, v in DEFAULT_VECTOR_FIELD_NAME.items(): doc_fields, doc_vectors = generate_vectordict_random( full_collection_new.schema ) query_vector = doc_vectors[v] query_result = full_collection_new.query( filter=filter, vectors=VectorQuery( field_name=v, vector=query_vector, param=HnswQueryParam(ef=ef) ), topk=topk, ) assert len(query_result) > 0, ( f"Expected at least 1 query result, but got {len(query_result)}" ) single_querydoc_check( multiple_docs, query_result, full_collection_new, is_by_vector=1, query_vector=query_vector, data_type=k, vector_name=v, ) @pytest.mark.parametrize("ef", [None, "invalid", 10.5]) @pytest.mark.parametrize("doc_num", [10]) @pytest.mark.parametrize("topk", [10]) @pytest.mark.parametrize("filter", ["int32_field >= 3 and int32_field <= 7"]) def test_query_vector_with_HnswQueryParam_invalid( self, full_collection: Collection, doc_num, topk, ef, filter ): multiple_docs = [ generate_doc(i, full_collection.schema) for i in range(doc_num) ] batchdoc_and_check(full_collection, multiple_docs, doc_num, operator="insert") for k, v in DEFAULT_VECTOR_FIELD_NAME.items(): doc_fields, doc_vectors = generate_vectordict_random(full_collection.schema) query_vector = doc_vectors[v] with pytest.raises(Exception) as exc_info: full_collection.query( filter=filter, vectors=VectorQuery( field_name=v, vector=query_vector, param=HnswQueryParam(ef=ef) ), topk=topk, ) assert INCOMPATIBLE_CONSTRUCTOR_ERROR_MSG in str(exc_info.value) @pytest.mark.parametrize("nprobe", [1, 10, 100, 2048]) @pytest.mark.parametrize("doc_num", [10]) @pytest.mark.parametrize("topk", [10]) @pytest.mark.parametrize("filter", ["int32_field >= 3 and int32_field <= 7"]) def 
test_query_vector_with_IVFQueryParam_valid( self, full_collection_ivf: Collection, nprobe, doc_num, topk, filter ): multiple_docs = [ generate_doc(i, full_collection_ivf.schema) for i in range(doc_num) ] batchdoc_and_check_ivf( full_collection_ivf, multiple_docs, doc_num, operator="insert" ) for k, v in DEFAULT_VECTOR_FIELD_NAME.items(): doc_fields, doc_vectors = generate_vectordict_random( full_collection_ivf.schema ) if v in ["vector_fp32_field"]: query_vector = doc_vectors[v] query_result = full_collection_ivf.query( filter=filter, vectors=VectorQuery( field_name=v, vector=query_vector, param=IVFQueryParam(nprobe=nprobe), ), topk=topk, ) assert len(query_result) > 0 single_querydoc_check( multiple_docs, query_result, full_collection_ivf, is_by_vector=1, query_vector=query_vector, data_type=k, vector_name=v, ) @pytest.mark.parametrize("nprobe", [None, 10.5]) @pytest.mark.parametrize("doc_num", [10]) @pytest.mark.parametrize("topk", [10]) @pytest.mark.parametrize("filter", ["int32_field >= 3 and int32_field <= 7"]) def test_query_vector_with_IVFQueryParam_invalid( self, full_collection_ivf: Collection, nprobe, doc_num, topk, filter ): multiple_docs = [ generate_doc(i, full_collection_ivf.schema) for i in range(doc_num) ] batchdoc_and_check_ivf( full_collection_ivf, multiple_docs, doc_num, operator="insert" ) for k, v in DEFAULT_VECTOR_FIELD_NAME.items(): doc_fields, doc_vectors = generate_vectordict_random( full_collection_ivf.schema ) if v in ["vector_fp32_field"]: print("v:\n") print(v) query_vector = doc_vectors[v] with pytest.raises(Exception) as exc_info: full_collection_ivf.query( # filter=filter, vectors=VectorQuery( field_name=v, vector=query_vector, param=IVFQueryParam(nprobe=nprobe), ), topk=topk, ) assert INCOMPATIBLE_CONSTRUCTOR_ERROR_MSG in str(exc_info.value) @pytest.mark.parametrize("filter", ["int32_field >= 3 and int32_field <= 7"]) @pytest.mark.parametrize("doc_num", [10]) def test_query_vector_with_param_invalid( self, full_collection: 
Collection, doc_num, filter ): multiple_docs = [ generate_doc(i, full_collection.schema) for i in range(doc_num) ] batchdoc_and_check(full_collection, multiple_docs, doc_num, operator="insert") with pytest.raises(Exception) as exc_info: for k, v in DEFAULT_VECTOR_FIELD_NAME.items(): doc_fields, doc_vectors = generate_vectordict_random( full_collection.schema ) query_vector = doc_vectors[v] if v in ["vector_fp16_field", "vector_fp32_field"]: full_collection.query( filter=filter, vectors=VectorQuery( field_name=v, vector=query_vector, param=HnswIndexParam() ), ) assert INCOMPATIBLE_FUNCTION_ERROR_MSG in str(exc_info.value) @pytest.mark.parametrize("doc_num", [10]) @pytest.mark.parametrize( "test_case_name,vector_query,expected_error_msg", [ ( "Non-existent vector field name", lambda ref_dense_vector: VectorQuery( field_name="nonexistent_vector", vector=ref_dense_vector ), "Expected exception for non-existent vector field name", ), ( "Invalid vector data type for dense vector (string instead of list)", lambda ref_dense_vector: VectorQuery( field_name="vector_fp32_field", vector="invalid_vector_data" ), "Expected exception for invalid dense vector data type", ), ( "Invalid vector data type for sparse vector (list instead of dict)", lambda ref_dense_vector: VectorQuery( field_name="sparse_fp32", vector=[1.0, 2.0, 3.0] ), "Expected exception for invalid sparse vector data type", ), ( "Empty vector data for dense vector", lambda ref_dense_vector: VectorQuery( field_name="vector_fp32_field", vector=[] ), "Expected exception for empty dense vector data", ), ( "Invalid dimension for dense vector", lambda ref_dense_vector: VectorQuery( field_name="vector_fp32_field", vector=[1.0, 2.0] ), # Only 2 dimensions instead of 128 "Expected exception for invalid dense vector dimension", ), ( "Non-existent document ID for by_id query", lambda ref_dense_vector: VectorQuery( field_name="vector_fp32_field", id="999" ), # Non-existent ID "Expected exception for non-existent document ID", 
), ( "Both vector and id specified (invalid combination)", lambda ref_dense_vector: VectorQuery( field_name="vector_fp32_field", vector=ref_dense_vector, id="5" ), "Expected exception for specifying both vector and id", ), ( "Neither vector nor id specified", lambda ref_dense_vector: VectorQuery( field_name="vector_fp32_field" ), # Neither vector nor id "Expected exception for specifying neither vector nor id", ), ], ) def test_query_vector_with_vectors_invalid( self, full_collection: Collection, doc_num, test_case_name, vector_query, expected_error_msg, ): multiple_docs = [ generate_doc(i, full_collection.schema) for i in range(doc_num) ] batchdoc_and_check(full_collection, multiple_docs, doc_num, operator="insert") ref_doc_result = full_collection.fetch(ids=["5"]) assert "5" in ref_doc_result ref_doc = ref_doc_result["5"] ref_dense_vector = ref_doc.vector("vector_fp32_field") with pytest.raises(Exception) as exc_info: full_collection.query(vectors=[vector_query(ref_dense_vector)]) assert exc_info.value is not None, expected_error_msg @pytest.mark.parametrize("filter", ["int32_field >= 3 and int32_field <= 7"]) @pytest.mark.parametrize("doc_num", [10]) def test_query_invalid_param_incompatible_type( self, full_collection: Collection, doc_num, filter ): multiple_docs = [ generate_doc(i, full_collection.schema) for i in range(doc_num) ] batchdoc_and_check(full_collection, multiple_docs, doc_num, operator="insert") with pytest.raises(Exception) as exc_info: for k, v in DEFAULT_VECTOR_FIELD_NAME.items(): doc_fields, doc_vectors = generate_vectordict_random( full_collection.schema ) query_vector = doc_vectors[v] full_collection.query( filter=filter, vectors=VectorQuery(field_name=v, vector=query_vector), param=HnswIndexParam(), topk=3, ) assert "query() got an unexpected keyword argument 'param'" in str( exc_info.value ) class TestRRFScoreCalculation: class MockDoc: def __init__(self, id, score=0.0): self._id = id self._score = score @property def id(self): return 
self._id @property def score(self): return self._score @score.setter def score(self, score): self._score = score def test_rrf_score_calculation_formula(self): k = 60 assert abs(calculate_rrf_score(0, k) - 1.0 / 61) < 1e-10, ( "RRF score for rank 0 should be 1/61" ) assert abs(calculate_rrf_score(1, k) - 1.0 / 62) < 1e-10, ( "RRF score for rank 1 should be 1/62" ) assert abs(calculate_rrf_score(2, k) - 1.0 / 63) < 1e-10, ( "RRF score for rank 2 should be 1/63" ) assert abs(calculate_rrf_score(10, k) - 1.0 / 71) < 1e-10, ( "RRF score for rank 10 should be 1/71" ) k = 10 assert abs(calculate_rrf_score(0, k) - 1.0 / 11) < 1e-10, ( "RRF score for rank 0 with k=10 should be 1/11" ) assert abs(calculate_rrf_score(1, k) - 1.0 / 12) < 1e-10, ( "RRF score for rank 1 with k=10 should be 1/12" ) def test_multi_vector_rrf_scores(self): query1_results = [self.MockDoc("1"), self.MockDoc("2"), self.MockDoc("3")] query2_results = [self.MockDoc("3"), self.MockDoc("1"), self.MockDoc("4")] query3_results = [self.MockDoc("2"), self.MockDoc("4"), self.MockDoc("5")] query_results = { "vector1": query1_results, "vector2": query2_results, "vector3": query3_results, } rrf_scores = calculate_multi_vector_rrf_scores(query_results, k=60) expected_doc1_score = 1.0 / 61 + 1.0 / 62 assert abs(rrf_scores["1"] - expected_doc1_score) < 1e-10, ( f"RRF score for doc1 mismatch: expected {expected_doc1_score}, got {rrf_scores['1']}" ) expected_doc2_score = 1.0 / 62 + 1.0 / 61 assert abs(rrf_scores["2"] - expected_doc2_score) < 1e-10, ( f"RRF score for doc2 mismatch: expected {expected_doc2_score}, got {rrf_scores['2']}" ) expected_doc3_score = 1.0 / 63 + 1.0 / 61 assert abs(rrf_scores["3"] - expected_doc3_score) < 1e-10, ( f"RRF score for doc3 mismatch: expected {expected_doc3_score}, got {rrf_scores['3']}" ) expected_doc4_score = 1.0 / 63 + 1.0 / 62 assert abs(rrf_scores["4"] - expected_doc4_score) < 1e-10, ( f"RRF score for doc4 mismatch: expected {expected_doc4_score}, got {rrf_scores['4']}" ) 
expected_doc5_score = 1.0 / 63 assert abs(rrf_scores["5"] - expected_doc5_score) < 1e-10, ( f"RRF score for doc5 mismatch: expected {expected_doc5_score}, got {rrf_scores['5']}" ) sorted_scores = sorted(rrf_scores.items(), key=lambda x: x[1], reverse=True) expected_order = ["1", "2", "3", "4", "5"] actual_order = [item[0] for item in sorted_scores] assert actual_order == expected_order, ( f"RRF score ranking mismatch: expected {expected_order}, got {actual_order}" ) class TestCollectionConcurrencyOperations: @pytest.mark.parametrize("doc_num", [10]) def test_concurrent_insert_update_upsert_query( self, full_collection: Collection, doc_num ): import threading results = [] errors = [] multiple_docs = [ generate_doc(i, full_collection.schema) for i in range(1000, 1010) ] batchdoc_and_check(full_collection, multiple_docs, doc_num, operator="insert") def insert_operation(thread_id): try: multiple_docs = [ generate_doc(i, full_collection.schema) for i in range(thread_id, thread_id + 5) ] result = full_collection.insert(multiple_docs) results.append(("insert", thread_id, len(result))) except Exception as e: errors.append(("insert", thread_id, str(e))) def update_operation(thread_id): try: multiple_docs = [ generate_doc_random(i, full_collection.schema) for i in range(1000, 1001) ] result = full_collection.update(multiple_docs) results.append(("update", thread_id, len(result))) except Exception as e: errors.append(("update", thread_id, str(e))) def upsert_operation(thread_id): try: multiple_docs = [ generate_doc(i, full_collection.schema) for i in range(thread_id, thread_id + 5) ] result = full_collection.upsert(multiple_docs) results.append(("upsert", thread_id, len(result))) except Exception as e: errors.append(("upsert", thread_id, str(e))) def query_operation(thread_id): try: if thread_id % 3 == 0: result = full_collection.query(filter="int32_field > 1", topk=5) elif thread_id % 3 == 1: result = full_collection.query(filter="bool_field = true", topk=3) else: 
query_vector = [0.1] * 128 result = full_collection.query( VectorQuery( field_name="vector_fp32_field", vector=query_vector ), topk=3, ) results.append(("query", thread_id, len(result))) except Exception as e: errors.append(("query", thread_id, str(e))) def delete_operation(thread_id): try: # Delete some existing documents delete_ids = ( [f"{thread_id + 1}", f"{thread_id + 2}"] if thread_id < 5 else [f"{thread_id % 5 + 1}"] ) result = full_collection.delete(delete_ids) results.append(("delete", thread_id, len(result))) except Exception as e: errors.append(("delete", thread_id, str(e))) threads = [] for i in range(1): thread = threading.Thread(target=insert_operation, args=(i,)) threads.append(thread) thread.start() for i in range(1): thread = threading.Thread(target=update_operation, args=(i,)) threads.append(thread) thread.start() for i in range(1): thread = threading.Thread(target=upsert_operation, args=(i,)) threads.append(thread) thread.start() for i in range(1): thread = threading.Thread(target=query_operation, args=(i,)) threads.append(thread) thread.start() for i in range(1): thread = threading.Thread(target=delete_operation, args=(i,)) threads.append(thread) thread.start() for thread in threads: thread.join() insert_results = [r for r in results if r[0] == "insert"] update_results = [r for r in results if r[0] == "update"] upsert_results = [r for r in results if r[0] == "upsert"] query_results = [r for r in results if r[0] == "query"] delete_results = [r for r in results if r[0] == "delete"] assert ( len(insert_results) + len(update_results) + len(upsert_results) + len(query_results) + len(delete_results) > 0 ), f"No operations succeeded. 
Errors: {errors}" critical_errors = [ e for e in errors if "critical" in e[2].lower() or "fatal" in e[2].lower() ] assert len(critical_errors) == 0, f"Critical errors occurred: {critical_errors}" ================================================ FILE: python/tests/detail/test_collection_exception.py ================================================ # Copyright 2025-present the zvec project # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import logging import pytest import numpy as np import zvec from zvec import ( CollectionOption, InvertIndexParam, HnswIndexParam, DataType, Collection, Doc, FieldSchema, VectorSchema, VectorQuery, ) class TestCollectionExceptionHandling: @pytest.fixture(scope="function") def test_collection(self, tmp_path_factory): """Fixture to create a test collection""" collection_schema = zvec.CollectionSchema( name="test_collection", fields=[ FieldSchema( "id", DataType.INT64, nullable=False, index_param=InvertIndexParam(enable_range_optimization=True), ), FieldSchema( "name", DataType.STRING, nullable=False, index_param=InvertIndexParam(), ), FieldSchema("weight", DataType.FLOAT, nullable=True), ], vectors=[ VectorSchema( "dense", DataType.VECTOR_FP32, dimension=128, index_param=HnswIndexParam(), ), VectorSchema( "sparse", DataType.SPARSE_VECTOR_FP32, index_param=HnswIndexParam() ), ], ) collection_option = CollectionOption(read_only=False, enable_mmap=True) temp_dir = tmp_path_factory.mktemp("zvec") collection_path = temp_dir / "test_collection" coll = 
zvec.create_and_open( path=str(collection_path), schema=collection_schema, option=collection_option, ) assert coll is not None, "Failed to create and open collection" yield coll # Clean up if hasattr(coll, "destroy") and coll is not None: try: coll.destroy() except Exception as e: print(f"Warning: failed to destroy collection: {e}") def test_create_and_open_missing_path(self, tmp_path_factory): collection_schema = zvec.CollectionSchema( name="test_collection", fields=[ FieldSchema( "id", DataType.INT64, nullable=False, index_param=InvertIndexParam(enable_range_optimization=True), ), FieldSchema( "name", DataType.STRING, nullable=False, index_param=InvertIndexParam(), ), ], vectors=[ VectorSchema( "dense", DataType.VECTOR_FP32, dimension=128, index_param=HnswIndexParam(), ) ], ) collection_option = CollectionOption(read_only=False, enable_mmap=True) with pytest.raises(Exception) as exc_info: coll = zvec.create_and_open( schema=collection_schema, option=collection_option ) assert exc_info.value is not None, ( "Expected exception for missing path parameter" ) def test_create_and_open_missing_schema(self, tmp_path_factory): temp_dir = tmp_path_factory.mktemp("zvec") collection_path = temp_dir / "test_collection" collection_option = CollectionOption(read_only=False, enable_mmap=True) with pytest.raises(Exception) as exc_info: coll = zvec.create_and_open( path=str(collection_path), option=collection_option ) assert exc_info.value is not None, ( "Expected exception for missing schema parameter" ) def test_open_missing_path(self): collection_option = CollectionOption(read_only=False, enable_mmap=True) with pytest.raises(Exception) as exc_info: coll = zvec.open(option=collection_option) assert exc_info.value is not None, ( "Expected exception for missing path parameter" ) def test_insert_missing_docs(self, test_collection: Collection): with pytest.raises(Exception) as exc_info: result = test_collection.insert() assert exc_info.value is not None, ( "Expected exception for 
missing docs parameter" ) def test_update_missing_docs(self, test_collection: Collection): with pytest.raises(Exception) as exc_info: result = test_collection.update() assert exc_info.value is not None, ( "Expected exception for missing docs parameter" ) def test_upsert_missing_docs(self, test_collection: Collection): with pytest.raises(Exception) as exc_info: result = test_collection.upsert() assert exc_info.value is not None, ( "Expected exception for missing docs parameter" ) def test_delete_missing_ids(self, test_collection: Collection): with pytest.raises(Exception) as exc_info: result = test_collection.delete() assert exc_info.value is not None, ( "Expected exception for missing ids parameter" ) def test_fetch_missing_ids(self, test_collection: Collection): with pytest.raises(Exception) as exc_info: result = test_collection.fetch() assert exc_info.value is not None, ( "Expected exception for missing ids parameter" ) def test_query_missing_vectorquery_field_name(self, test_collection: Collection): with pytest.raises(Exception) as exc_info: result = test_collection.query(vectors=[VectorQuery()]) assert exc_info.value is not None, ( "Expected exception for missing VectorQuery field_name parameter" ) def test_add_column_missing_field_schema(self, test_collection: Collection): with pytest.raises(Exception) as exc_info: test_collection.add_column() assert exc_info.value is not None, ( "Expected exception for missing field_schema parameter" ) def test_alter_column_missing_old_name(self, test_collection: Collection): with pytest.raises(Exception) as exc_info: test_collection.alter_column(new_name="new_name") assert exc_info.value is not None, ( "Expected exception for missing old_name parameter" ) def test_alter_column_missing_new_name(self, test_collection: Collection): with pytest.raises(Exception) as exc_info: test_collection.alter_column(old_name="old_name") assert exc_info.value is not None, ( "Expected exception for missing new_name parameter" ) def 
test_drop_column_missing_field_name(self, test_collection: Collection): with pytest.raises(Exception) as exc_info: test_collection.drop_column() assert exc_info.value is not None, ( "Expected exception for missing field_name parameter" ) def test_invalid_parameter_types(self, test_collection: Collection): # This test depends on specific implementation details # Generally, we would expect TypeErrors or similar exceptions pass def test_missing_required_parameters(self, test_collection: Collection): # This test depends on specific implementation details # Generally, we would expect TypeErrors or similar exceptions pass def test_empty_collection_operations(self, tmp_path_factory): collection_schema = zvec.CollectionSchema( name="empty_test_collection", fields=[ FieldSchema( "id", DataType.INT64, nullable=False, index_param=InvertIndexParam(enable_range_optimization=True), ), FieldSchema( "name", DataType.STRING, nullable=False, index_param=InvertIndexParam(), ), ], vectors=[ VectorSchema( "dense", DataType.VECTOR_FP32, dimension=128, index_param=HnswIndexParam(), ) ], ) collection_option = CollectionOption(read_only=False, enable_mmap=True) temp_dir = tmp_path_factory.mktemp("zvec") collection_path = temp_dir / "empty_test_collection" coll = zvec.create_and_open( path=str(collection_path), schema=collection_schema, option=collection_option, ) assert coll is not None, "Failed to create and open collection" # Test fetch on empty collection result = coll.fetch(["1"]) assert len(result) >= 0 # May be empty or have special handling # Test query on empty collection result = coll.query() assert len(result) == 0 # Test update on empty collection doc = Doc( id="1", fields={"id": 1, "name": "test"}, vectors={"dense": np.random.random(128).tolist()}, ) result = coll.update(doc) # Should handle gracefully, possibly with NOT_FOUND status # Clean up if hasattr(coll, "destroy") and coll is not None: try: coll.destroy() except Exception as e: print(f"Warning: failed to destroy 
collection: {e}") def test_resource_management(self, test_collection: Collection): doc = Doc( id="1", fields={"id": 1, "name": "test", "weight": 80.5}, vectors={ "dense": np.random.random(128).tolist(), "sparse": {1: 1.0, 2: 2.0}, }, ) # Insert result = test_collection.insert(doc) assert result.ok() # Fetch result = test_collection.fetch(["1"]) assert len(result) == 1 # Query result = test_collection.query() assert len(result) >= 0 # Update result = test_collection.update(doc) assert result.ok() # Delete result = test_collection.delete("1") assert result.ok() def test_exception_resource_cleanup(self, test_collection: Collection): # This test would need to simulate exception conditions # which is difficult without specific failure injection points pass ================================================ FILE: python/tests/detail/test_collection_open.py ================================================ # Copyright 2025-present the zvec project # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import threading import numpy as np from fixture_helper import * COLLECTION_OPTION_TEST_CASES_VALID = [ # (read_only, enable_mmap, description) (False, True, "Read-write with mmap enabled"), (False, False, "Read-write with mmap disabled"), (True, True, "Read-only with mmap enabled"), (True, False, "Read-only with mmap disabled"), ] # Test data for invalid paths INVALID_PATH_LIST = [ "/nonexistent/directory/test_collection", "invalid:path", "", # Empty path ] @pytest.fixture(scope="session") def collection_schema(): return zvec.CollectionSchema( name="test_collection", fields=[ FieldSchema( "id", DataType.INT64, nullable=False, index_param=InvertIndexParam(enable_range_optimization=True), ), FieldSchema( "name", DataType.STRING, nullable=False, index_param=InvertIndexParam() ), FieldSchema( "weight", DataType.FLOAT, nullable=False, index_param=InvertIndexParam() ), ], vectors=[ VectorSchema( "dense", DataType.VECTOR_FP32, dimension=128, index_param=HnswIndexParam(), ), VectorSchema( "sparse", DataType.SPARSE_VECTOR_FP32, index_param=HnswIndexParam() ), ], ) @pytest.fixture def single_doc(): id = 0 return Doc( id=f"{id}", fields={"id": id, "name": "test"}, vectors={ "dense": [id + 0.1] * 128, }, ) @pytest.fixture(scope="function") def test_collection( tmp_path_factory, collection_schema, collection_option ) -> Generator[Any, Any, Collection]: temp_dir = tmp_path_factory.mktemp("zvec") collection_path = temp_dir / "test_collection" coll = zvec.create_and_open( path=str(collection_path), schema=collection_schema, option=collection_option ) assert coll is not None, "Failed to create and open collection" assert coll.path == str(collection_path) assert coll.schema.name == collection_schema.name assert list(coll.schema.fields) == list(collection_schema.fields) assert list(coll.schema.vectors) == list(collection_schema.vectors) assert coll.option.read_only == collection_option.read_only assert coll.option.enable_mmap == collection_option.enable_mmap try: yield coll finally: 
if hasattr(coll, "destroy") and coll is not None: try: coll.destroy() except Exception as e: print(f"Warning: failed to destroy collection: {e}") class TestCollectionOpen: def test_open_basic_functionality( self, tmp_path_factory, collection_schema, collection_option ): import sys import time import os # Create unique temp directory temp_dir = tmp_path_factory.mktemp("zvec") collection_path = temp_dir / "test_collection" # Ensure the path exists collection_path_str = str(collection_path) print(f"DEBUG: Collection path: {collection_path_str}") print(f"DEBUG: Temp directory exists: {temp_dir.exists()}") # Create and open collection first created_coll = zvec.create_and_open( path=collection_path_str, schema=collection_schema, option=collection_option ) assert created_coll is not None, ( f"Failed to create collection, returned None instead of valid Collection object. Path: {collection_path_str}" ) assert created_coll.path == collection_path_str, ( f"Collection path mismatch. Expected: {collection_path_str}, Actual: {created_coll.path}" ) assert created_coll.schema.name == "test_collection", ( f"Collection schema name mismatch. Expected: test_collection, Actual: {created_coll.schema.name}" ) # Insert multiple documents to verify persistence docs = [] for i in range(3): doc = Doc( id=f"{i}", fields={"id": i, "name": f"test_{i}", "weight": float(i * 10)}, vectors={ "dense": [float(j + i) for j in range(128)], "sparse": {j: float(j + i) for j in range(5)}, }, ) docs.append(doc) result = created_coll.insert(docs) assert len(result) == 3, f"Expected 3 insertion results, but got {len(result)}" for i, res in enumerate(result): assert res.ok(), ( f"Insertion result {i} is not OK. 
Status code: {res.code()}, Message: {res.message()}" ) # Verify documents were inserted using fetch interface fetched_docs_after_insert = created_coll.fetch(["0", "1", "2"]) assert len(fetched_docs_after_insert) == 3, ( f"Expected 3 fetched documents after insertion, but got {len(fetched_docs_after_insert)}" ) assert "0" in fetched_docs_after_insert, ( "Document with ID '0' not found in fetched results after insertion" ) assert "1" in fetched_docs_after_insert, ( "Document with ID '1' not found in fetched results after insertion" ) assert "2" in fetched_docs_after_insert, ( "Document with ID '2' not found in fetched results after insertion" ) # Verify fetched document content after insertion for i in range(3): doc = fetched_docs_after_insert[f"{i}"] assert doc is not None, ( f"Fetched document with ID '{i}' is None after insertion" ) assert doc.id == f"{i}", ( f"Document ID mismatch for document '{i}' after insertion. Expected: {i}, Actual: {doc.id}" ) assert doc.field("id") == i, ( f"Document id field mismatch for document '{i}' after insertion. Expected: {i}, Actual: {doc.field('id')}" ) assert doc.field("name") == f"test_{i}", ( f"Document name field mismatch for document '{i}' after insertion. Expected: test_{i}, Actual: {doc.field('name')}" ) assert doc.field("weight") == float(i * 10), ( f"Document weight field mismatch for document '{i}' after insertion. 
Expected: {float(i * 10)}, Actual: {doc.field('weight')}" ) # Verify vector access after insertion assert doc.vector("dense") is not None, ( f"Document {i} should have dense vector after insertion" ) assert doc.vector("sparse") is not None, ( f"Document {i} should have sparse vector after insertion" ) # Verify vector types after insertion assert isinstance(doc.vector("dense"), list), ( f"Document {i} dense vector should be dict after insertion, got {type(doc.vector('dense'))}" ) assert isinstance(doc.vector("sparse"), dict), ( f"Document {i} sparse vector should be dict after insertion, got {type(doc.vector('sparse'))}" ) # Verify documents were inserted using stats stats = created_coll.stats assert stats is not None, "Collection stats should not be None" assert stats.doc_count == 3, ( f"Document count mismatch after insertion. Expected: 3, Actual: {stats.doc_count}" ) # Store the collection path before cleanup collection_path = created_coll.path # Clean up the created collection reference del created_coll # Wait and verify the path still exists print(f"DEBUG: Collection path after destroy: {collection_path}") print(f"DEBUG: Path exists after destroy: {os.path.exists(collection_path)}") # Now open the existing collection try: print(f"DEBUG: Path exists before open: {os.path.exists(collection_path)}") # List contents of parent directory for debugging parent_dir = os.path.dirname(collection_path) if os.path.exists(parent_dir): print(f"DEBUG: Parent directory contents: {os.listdir(parent_dir)}") opened_coll = zvec.open(path=collection_path, option=collection_option) assert opened_coll is not None, ( f"Failed to open existing collection at path: {collection_path}. Returned None instead of valid Collection object" ) assert opened_coll.path == collection_path, ( f"Opened collection path mismatch. Expected: {collection_path}, Actual: {opened_coll.path}" ) assert opened_coll.schema.name == "test_collection", ( f"Opened collection schema name mismatch. 
Expected: test_collection, Actual: {opened_coll.schema.name}" ) # Check reference count of opened collection opened_ref_count = sys.getrefcount(opened_coll) print(f"DEBUG: Reference count of opened collection: {opened_ref_count}") # Verify data persistence # Verify data persistence using fetch interface fetched_docs = opened_coll.fetch(["0", "1", "2"]) assert len(fetched_docs) == 3, ( f"Expected 3 fetched documents after reopening, but got {len(fetched_docs)}" ) assert "0" in fetched_docs, ( "Document with ID '0' not found in fetched results after reopening" ) assert "1" in fetched_docs, ( "Document with ID '1' not found in fetched results after reopening" ) assert "2" in fetched_docs, ( "Document with ID '2' not found in fetched results after reopening" ) # Verify fetched document content after reopening collection for i in range(3): doc = fetched_docs[f"{i}"] assert doc is not None, ( f"Fetched document with ID '{i}' is None after reopening collection" ) assert doc.id == f"{i}", ( f"Document ID mismatch for document '{i}' after reopening. Expected: {i}, Actual: {doc.id}" ) assert doc.field("id") == i, ( f"Document id field mismatch for document '{i}' after reopening. Expected: {i}, Actual: {doc.field('id')}" ) assert doc.field("name") == f"test_{i}", ( f"Document name field mismatch for document '{i}' after reopening. Expected: test_{i}, Actual: {doc.field('name')}" ) assert doc.field("weight") == float(i * 10), ( f"Document weight field mismatch for document '{i}' after reopening. 
Expected: {float(i * 10)}, Actual: {doc.field('weight')}" ) # Verify vector access after reopening assert doc.vector("dense") is not None, ( f"Document {i} should have dense vector after reopening" ) assert doc.vector("sparse") is not None, ( f"Document {i} should have sparse vector after reopening" ) # Verify vector types after reopening assert isinstance(doc.vector("dense"), list), ( f"Document {i} dense vector should be dict after reopening, got {type(doc.vector('dense'))}" ) assert isinstance(doc.vector("sparse"), dict), ( f"Document {i} sparse vector should be dict after reopening, got {type(doc.vector('sparse'))}" ) # Verify score attribute exists assert hasattr(doc, "score"), ( f"Document {i} should have a score attribute after reopening" ) assert isinstance(doc.score, (int, float)), ( f"Document {i} score should be numeric after reopening, got {type(doc.score)}" ) # For fetch operations, score is typically 0.0 assert doc.score == 0.0, ( f"Document {i} score should be 0.0 for fetch operation after reopening, but got {doc.score}" ) # Test query functionality query_result = opened_coll.query(include_vector=True) assert len(query_result) == 3, ( f"Expected 3 query results, but got {len(query_result)}" ) # Verify query results have proper structure and content with detailed validation returned_doc_ids = set() for doc in query_result: # Verify basic document structure assert doc.id is not None, f"Query result document should have an ID" assert doc.id in ["0", "1", "2"], ( f"Query result document ID should be one of ['0', '1', '2'], but got {doc.id}" ) returned_doc_ids.add(doc.id) # Verify field access assert doc.field("id") is not None, ( f"Document {doc.id} should have id field" ) assert doc.field("name") is not None, ( f"Document {doc.id} should have name field" ) assert doc.field("weight") is not None, ( f"Document {doc.id} should have weight field" ) # Verify field values expected_id = int(doc.id) assert doc.field("id") == expected_id, ( f"Document {doc.id} 
id field mismatch. Expected: {expected_id}, Actual: {doc.field('id')}" ) assert doc.field("name") == f"test_{expected_id}", ( f"Document {doc.id} name field mismatch. Expected: test_{expected_id}, Actual: {doc.field('name')}" ) assert doc.field("weight") == float(expected_id * 10), ( f"Document {doc.id} weight field mismatch. Expected: {float(expected_id * 10)}, Actual: {doc.field('weight')}" ) # Verify vector access assert doc.vector("dense") is not None, ( f"Document {doc.id} should have dense vector" ) assert doc.vector("sparse") is not None, ( f"Document {doc.id} should have sparse vector" ) # Verify vector types assert isinstance(doc.vector("dense"), list), ( f"Document {doc.id} dense vector should be list, got {type(doc.vector('dense'))}" ) assert isinstance(doc.vector("sparse"), dict), ( f"Document {doc.id} sparse vector should be dict, got {type(doc.vector('sparse'))}" ) # Verify score attribute exists assert hasattr(doc, "score"), ( f"Document {doc.id} should have a score attribute" ) assert isinstance(doc.score, (int, float)), ( f"Document {doc.id} score should be numeric, got {type(doc.score)}" ) # Verify all expected documents are returned expected_doc_ids = {"0", "1", "2"} assert returned_doc_ids == expected_doc_ids, ( f"Query should return all expected documents. Expected: {expected_doc_ids}, Actual: {returned_doc_ids}" ) # === Enhanced validation based on test_collection_dql_operations.py === # Verify vector field names accessibility for all documents for doc in query_result: vector_names = doc.vector_names() expected_vector_names = {"dense", "sparse"} assert set(vector_names) == expected_vector_names, ( f"Document {doc.id} vector names mismatch. 
Expected: {expected_vector_names}, Actual: {set(vector_names)}" ) # Verify all vector fields can be accessed for vector_name in expected_vector_names: vector_data = doc.vector(vector_name) assert vector_data is not None, ( f"Document {doc.id} should have accessible vector '{vector_name}'" ) if vector_name == "dense": assert isinstance(vector_data, list), ( f"Document {doc.id} vector '{vector_name}' should be list, got {type(vector_data)}" ) else: assert isinstance(vector_data, dict), ( f"Document {doc.id} vector '{vector_name}' should be dict, got {type(vector_data)}" ) # Test query with filter filtered_result = opened_coll.query(filter="id >= 1", include_vector=True) assert len(filtered_result) == 2, ( f"Expected 2 filtered query results (id >= 1), but got {len(filtered_result)}" ) # Verify filtered query results filtered_doc_ids = set() for doc in filtered_result: assert doc.id is not None, ( f"Filtered query result document should have an ID" ) assert doc.id in ["1", "2"], ( f"Filtered query result document ID should be one of ['1', '2'], but got {doc.id}" ) filtered_doc_ids.add(doc.id) # Verify filter condition is satisfied doc_id = int(doc.id) assert doc_id >= 1, ( f"Document {doc.id} should satisfy filter condition id >= 1" ) # Verify document structure assert doc.field("id") is not None, ( f"Document {doc.id} should have id field" ) assert doc.field("name") is not None, ( f"Document {doc.id} should have name field" ) assert doc.field("weight") is not None, ( f"Document {doc.id} should have weight field" ) # Verify field values assert doc.field("id") == doc_id, ( f"Document {doc.id} id field mismatch. Expected: {doc_id}, Actual: {doc.field('id')}" ) assert doc.field("name") == f"test_{doc_id}", ( f"Document {doc.id} name field mismatch. Expected: test_{doc_id}, Actual: {doc.field('name')}" ) assert doc.field("weight") == float(doc_id * 10), ( f"Document {doc.id} weight field mismatch. 
Expected: {float(doc_id * 10)}, Actual: {doc.field('weight')}" ) # Verify vector access assert doc.vector("dense") is not None, ( f"Document {doc.id} should have dense vector" ) assert doc.vector("sparse") is not None, ( f"Document {doc.id} should have sparse vector" ) # Verify score attribute exists assert hasattr(doc, "score"), ( f"Document {doc.id} should have a score attribute" ) assert isinstance(doc.score, (int, float)), ( f"Document {doc.id} score should be numeric, got {type(doc.score)}" ) # Verify filtered documents expected_filtered_ids = {"1", "2"} assert filtered_doc_ids == expected_filtered_ids, ( f"Filtered query should return expected documents. Expected: {expected_filtered_ids}, Actual: {filtered_doc_ids}" ) # Test vector query functionality for dense vectors query_vector_dense = [0.1] * 128 vector_query_result = opened_coll.query( VectorQuery(field_name="dense", vector=query_vector_dense) ) assert len(vector_query_result) > 0, ( f"Expected at least 1 vector query result, but got {len(vector_query_result)}" ) # Verify vector query results structure for doc in vector_query_result[:3]: # Check first 3 results assert doc.id is not None, ( f"Vector query result document should have an ID" ) assert doc.id in ["0", "1", "2"], ( f"Vector query result document ID should be one of ['0', '1', '2'], but got {doc.id}" ) # Verify document structure assert doc.field("id") is not None, ( f"Document {doc.id} should have id field" ) assert doc.field("name") is not None, ( f"Document {doc.id} should have name field" ) assert doc.field("weight") is not None, ( f"Document {doc.id} should have weight field" ) # Verify vector access assert doc.vector("dense") is not None, ( f"Document {doc.id} should have dense vector" ) assert doc.vector("sparse") is not None, ( f"Document {doc.id} should have sparse vector" ) # Verify score attribute exists and is numeric assert hasattr(doc, "score"), ( f"Document {doc.id} should have a score attribute" ) assert isinstance(doc.score, 
(int, float)), ( f"Document {doc.id} score should be numeric, got {type(doc.score)}" ) # For dense vector queries, score should typically be non-negative (depending on metric) # Note: This may vary based on the metric type used assert doc.score >= 0 or doc.score < 0, ( f"Document {doc.id} score should be a valid number" ) # Test vector query functionality for sparse vectors query_vector_sparse = {1: 1.0, 2: 2.0, 3: 3.0} sparse_vector_query_result = opened_coll.query( VectorQuery(field_name="sparse", vector=query_vector_sparse) ) assert len(sparse_vector_query_result) > 0, ( f"Expected at least 1 sparse vector query result, but got {len(sparse_vector_query_result)}" ) # Verify sparse vector query results structure for doc in sparse_vector_query_result[:3]: # Check first 3 results assert doc.id is not None, ( f"Sparse vector query result document should have an ID" ) assert doc.id in ["0", "1", "2"], ( f"Sparse vector query result document ID should be one of ['0', '1', '2'], but got {doc.id}" ) # Verify document structure assert doc.field("id") is not None, ( f"Document {doc.id} should have id field" ) assert doc.field("name") is not None, ( f"Document {doc.id} should have name field" ) assert doc.field("weight") is not None, ( f"Document {doc.id} should have weight field" ) # Verify vector access assert doc.vector("dense") is not None, ( f"Document {doc.id} should have dense vector" ) assert doc.vector("sparse") is not None, ( f"Document {doc.id} should have sparse vector" ) # Verify score attribute exists and is numeric assert hasattr(doc, "score"), ( f"Document {doc.id} should have a score attribute" ) assert isinstance(doc.score, (int, float)), ( f"Document {doc.id} score should be numeric, got {type(doc.score)}" ) # Clean up if hasattr(opened_coll, "destroy") and opened_coll is not None: opened_coll.destroy() print("DEBUG: Opened collection destroyed successfully") except Exception as e: logging.error("Exception occurred: [{}]".format(e)) raise e 
@pytest.mark.parametrize( "read_only,enable_mmap,description", COLLECTION_OPTION_TEST_CASES_VALID ) @pytest.mark.parametrize("createAndopen_enable_mmap", [True, False]) def test_open_with_different_collection_options_valid( self, tmp_path_factory, createAndopen_enable_mmap, read_only, enable_mmap, description, collection_schema, ): # Create collection with initial option temp_dir = tmp_path_factory.mktemp("zvec") collection_path = temp_dir / "test_collection" initial_option = CollectionOption( read_only=False, enable_mmap=createAndopen_enable_mmap ) # Create and open collection first created_coll = zvec.create_and_open( path=str(collection_path), schema=collection_schema, option=initial_option ) assert created_coll is not None, "Failed to create collection" # Clean up the created collection reference del created_coll # Now open with different options collection_option = CollectionOption( read_only=read_only, enable_mmap=enable_mmap ) try: opened_coll = zvec.open(path=str(collection_path), option=collection_option) assert opened_coll is not None, ( f"Failed to open collection with option: {description}. Returned None instead of valid Collection object. Path: {collection_path}" ) assert opened_coll.path == str(collection_path), ( f"Opened collection path mismatch. Expected: {collection_path}, Actual: {opened_coll.path}" ) assert opened_coll.schema.name == collection_schema.name, ( f"Opened collection schema name mismatch. Expected: {collection_schema.name}, Actual: {opened_coll.schema.name}" ) assert opened_coll.option.read_only == read_only, ( f"Opened collection read_only option mismatch. Expected: {read_only}, Actual: {opened_coll.option.read_only}" ) assert opened_coll.option.enable_mmap == createAndopen_enable_mmap, ( f"Opened collection mmap option mismatch. 
Expected: {createAndopen_enable_mmap}, Actual: {opened_coll.option.enable_mmap}" ) # Clean up if ( hasattr(opened_coll, "destroy") and opened_coll is not None and read_only == False ): opened_coll.destroy() except Exception as e: logging.error("Exception occurred: [{}]".format(e)) pytest.fail(f"Failed to open collection with different options: {e}") def test_open_with_none_option(self, tmp_path_factory, collection_schema): # Create collection temp_dir = tmp_path_factory.mktemp("zvec") collection_path = temp_dir / "test_collection" initial_option = CollectionOption(read_only=False, enable_mmap=True) # Create and open collection first created_coll = zvec.create_and_open( path=str(collection_path), schema=collection_schema, option=initial_option ) assert created_coll is not None, ( f"Failed to create collection. Returned None instead of valid Collection object. Path: {collection_path}" ) # Clean up the created collection reference del created_coll # Now open with None option with pytest.raises(Exception) as exc_info: zvec.open(path=str(collection_path), option=None) assert "incompatible function arguments" in str(exc_info.value), ( f"Expected 'incompatible function arguments' error, but got: {exc_info.value}" ) def test_reopen_collection(self, tmp_path_factory): # Prepare schema collection_schema = zvec.CollectionSchema( name="test_collection", fields=[ FieldSchema( "id", DataType.INT64, nullable=False, index_param=InvertIndexParam(enable_range_optimization=True), ), FieldSchema( "name", DataType.STRING, nullable=False, index_param=InvertIndexParam(), ), ], vectors=[ VectorSchema( "dense", DataType.VECTOR_FP32, dimension=128, index_param=HnswIndexParam(), ) ], ) collection_option = CollectionOption(read_only=False, enable_mmap=True) # Create collection temp_dir = tmp_path_factory.mktemp("zvec") collection_path = temp_dir / "test_collection" # Create and open collection coll1 = zvec.create_and_open( path=str(collection_path), schema=collection_schema, 
option=collection_option, ) assert coll1 is not None, "Failed to create and open collection" # Insert some data doc = Doc( id="1", fields={"id": 1, "name": "test"}, vectors={"dense": np.random.random(128).tolist()}, ) result = coll1.insert(doc) assert result.ok() # Close the first collection (delete reference) del coll1 # Reopen the collection coll2 = zvec.open(path=str(collection_path), option=collection_option) assert coll2 is not None, "Failed to reopen collection" assert coll2.path == str(collection_path) assert coll2.schema.name == collection_schema.name # Verify data is still there fetched_docs = coll2.fetch(["1"]) assert "1" in fetched_docs fetched_doc = fetched_docs["1"] assert fetched_doc.id == "1" assert fetched_doc.field("name") == "test" # Clean up if hasattr(coll2, "destroy") and coll2 is not None: try: coll2.destroy() except Exception as e: print(f"Warning: failed to destroy collection: {e}") def test_open_concurrent_same_path(self, tmp_path_factory): # First create a collection collection_schema = zvec.CollectionSchema( name="test_collection", fields=[ FieldSchema( "id", DataType.INT64, nullable=False, index_param=InvertIndexParam(enable_range_optimization=True), ), FieldSchema( "name", DataType.STRING, nullable=False, index_param=InvertIndexParam(), ), ], vectors=[ VectorSchema( "dense", DataType.VECTOR_FP32, dimension=128, index_param=HnswIndexParam(), ) ], ) collection_option = CollectionOption(read_only=False, enable_mmap=True) # Create collection path temp_dir = tmp_path_factory.mktemp("zvec") collection_path = temp_dir / "test_collection" # First create the collection created_coll = zvec.create_and_open( path=str(collection_path), schema=collection_schema, option=collection_option, ) assert created_coll is not None, "Failed to create collection" # Close the collection so we can test concurrent opening if hasattr(created_coll, "close") and created_coll is not None: created_coll.close() # Shared variables to collect results from threads results = 
[] errors = [] # Lock for thread-safe operations lock = threading.Lock() # Clean up the created collection reference del created_coll # Function to be executed by each thread def open_collection_thread(thread_id): try: coll = zvec.open(path=str(collection_path), option=collection_option) with lock: results.append((thread_id, coll)) # Close the collection if opened successfully if hasattr(coll, "close") and coll is not None: coll.close() except Exception as e: with lock: errors.append((thread_id, str(e))) # Create 5 threads to call open concurrently threads = [] for i in range(5): thread = threading.Thread(target=open_collection_thread, args=(i,)) threads.append(thread) thread.start() # Wait for all threads to complete for thread in threads: thread.join() # Verify concurrency safety: only one should succeed, others should fail assert len(results) == 1, ( f"Expected exactly one successful open, but got {len(results)}" ) assert len(errors) == 4, ( f"Expected exactly four failures, but got {len(errors)}" ) # Additional verification: check that the successful open has a valid collection successful_thread_id, successful_collection = results[0] assert successful_collection is not None, ( "Successful open should return a valid collection" ) assert successful_collection.path == str(collection_path), ( "Collection path mismatch" ) # Clean up the successfully opened collection if ( hasattr(successful_collection, "destroy") and successful_collection is not None ): try: successful_collection.destroy() except Exception as e: print(f"Warning: failed to destroy collection: {e}") def test_open_with_corrupted_files(self, tmp_path_factory): # First create a collection collection_schema = zvec.CollectionSchema( name="test_collection", fields=[ FieldSchema( "id", DataType.INT64, nullable=False, index_param=InvertIndexParam(enable_range_optimization=True), ), FieldSchema( "name", DataType.STRING, nullable=False, index_param=InvertIndexParam(), ), ], vectors=[ VectorSchema( "dense", 
DataType.VECTOR_FP32, dimension=128, index_param=HnswIndexParam(), ) ], ) collection_option = CollectionOption(read_only=False, enable_mmap=True) # Create collection path temp_dir = tmp_path_factory.mktemp("zvec") collection_path = temp_dir / "test_collection" # First create the collection created_coll = zvec.create_and_open( path=str(collection_path), schema=collection_schema, option=collection_option, ) assert created_coll is not None, "Failed to create collection" # Close the collection so we can manipulate its files if hasattr(created_coll, "close") and created_coll is not None: created_coll.close() # Test case 1: Delete some files in the collection directory (simulate partial corruption) import os import shutil import random # Get the collection directory path collection_dir = str(collection_path) # List all files in the collection directory files_in_dir = [] for root, dirs, files in os.walk(collection_dir): for file in files: files_in_dir.append(os.path.join(root, file)) # Randomly delete approximately half of the files to simulate partial corruption if files_in_dir: # Shuffle the list to randomly select files random.shuffle(files_in_dir) files_to_delete = files_in_dir[: len(files_in_dir) // 2] for file_path in files_to_delete: try: os.remove(file_path) except Exception as e: pass # Ignore errors during deletion # Try to open the collection with missing files - should raise an exception with pytest.raises(Exception): zvec.open(path=str(collection_path), option=collection_option) # Test case 2: Delete all files in the collection directory (simulate complete corruption) # Recreate the collection recreated_coll = zvec.create_and_open( path=str(collection_path) + "_all", schema=collection_schema, option=collection_option, ) assert recreated_coll is not None, "Failed to recreate collection" # Close the collection so we can manipulate its files if hasattr(recreated_coll, "close") and recreated_coll is not None: recreated_coll.close() # Delete all files in the 
collection directory try: shutil.rmtree(collection_dir) os.makedirs(collection_dir) # Recreate empty directory except Exception as e: pass # Ignore errors during deletion # Try to open the collection with missing files - should raise an exception with pytest.raises(Exception): zvec.open(path=str(collection_path), option=collection_option) ================================================ FILE: python/tests/detail/test_collection_recall.py ================================================ # Copyright 2025-present the zvec project # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import pytest from zvec.typing import DataType, StatusCode, MetricType, QuantizeType from zvec.model import Collection, Doc, VectorQuery from zvec.model.param import ( CollectionOption, InvertIndexParam, HnswIndexParam, FlatIndexParam, IVFIndexParam, HnswQueryParam, IVFQueryParam, ) from zvec.model.schema import FieldSchema, VectorSchema from zvec.extension import RrfReRanker, WeightedReRanker, QwenReRanker from distance_helper import * from zvec import StatusCode from distance_helper import * from fixture_helper import * from doc_helper import * from params_helper import * import time # ==================== helper ==================== def batchdoc_and_check(collection: Collection, multiple_docs, operator="insert"): if operator == "insert": result = collection.insert(multiple_docs) elif operator == "upsert": result = collection.upsert(multiple_docs) elif operator == "update": result = collection.update(multiple_docs) else: logging.error("operator value is error!") assert len(result) == len(multiple_docs) for item in result: assert item.ok(), ( f"result={result},Insert operation failed with code {item.code()}" ) stats = collection.stats assert stats is not None, "Collection stats should not be None" """assert stats.doc_count == len(multiple_docs), ( f"Document count should be {len(multiple_docs)} after insert, but got {stats.doc_count}" )""" doc_ids = [doc.id for doc in multiple_docs] fetched_docs = collection.fetch(doc_ids) assert len(fetched_docs) == len(multiple_docs), ( f"fetched_docs={fetched_docs},Expected {len(multiple_docs)} fetched documents, but got {len(fetched_docs)}" ) for original_doc in multiple_docs: assert original_doc.id in fetched_docs, ( f"Expected document ID {original_doc.id} in fetched documents" ) fetched_doc = fetched_docs[original_doc.id] assert is_doc_equal(fetched_doc, original_doc, collection.schema) assert hasattr(fetched_doc, "score"), "Document should have a score attribute" assert fetched_doc.score == 0.0, ( "Fetch operation should 
return default score of 0.0" ) def compute_exact_similarity_scores( vectors_a, vectors_b, metric_type=MetricType.IP, DataType=DataType.VECTOR_FP32, QuantizeType=QuantizeType.UNDEFINED, ): similarities = [] for i, vec_a in enumerate(vectors_a): for j, vec_b in enumerate(vectors_b): similarity = distance_recall(vec_a, vec_b, metric_type, DataType) similarities.append((j, similarity)) # For L2,COSINE metric, smaller distances mean higher similarity, so sort in ascending order if ( metric_type in [MetricType.L2] and DataType in [DataType.VECTOR_FP32, DataType.VECTOR_FP16, DataType.VECTOR_INT8] ) or ( metric_type in [MetricType.COSINE] and DataType in [DataType.VECTOR_FP32, DataType.VECTOR_FP16] ): similarities.sort(key=lambda x: x[1], reverse=False) # Ascending order for L2 else: similarities.sort( key=lambda x: x[1], reverse=True ) # Descending order for others # Special handling for COSINE in FP16 to address precision issues if metric_type == MetricType.COSINE and DataType == DataType.VECTOR_FP16: # Clamp values to valid cosine distance range [0, 2] and handle floating point errors similarities = [(idx, max(0.0, min(2.0, score))) for idx, score in similarities] return similarities def get_ground_truth_for_vector_query( collection, query_vector, field_name, all_docs, query_idx, metric_type, k, use_exact_computation=False, ): if use_exact_computation: all_vectors = [doc.vectors[field_name] for doc in all_docs] for d, f in DEFAULT_VECTOR_FIELD_NAME.items(): if field_name == f: DataType = d break similarities = compute_exact_similarity_scores( [query_vector], all_vectors, metric_type, DataType=DataType, QuantizeType=QuantizeType, ) if metric_type == MetricType.COSINE and DataType == DataType.VECTOR_FP16: # Filter out tiny non-zero values that may be caused by precision errors similarities = [ (idx, max(0.0, min(2.0, score))) for idx, score in similarities ] ground_truth_ids_scores = similarities[:k] print("Get the most similar k document IDs k:,ground_truth_ids_scores") 
print(k, ground_truth_ids_scores) return ground_truth_ids_scores else: full_result = collection.query( VectorQuery(field_name=field_name, vector=query_vector), topk=min(len(all_docs), 1024), include_vector=True, ) ground_truth_ids_scores = [ (result.id, result.score) for result in full_result[:k] ] if not ground_truth_ids_scores: ground_truth_ids_scores = [(all_docs[query_idx].id, 0)] return ground_truth_ids_scores def get_ground_truth_map(collection, test_docs, query_vectors_map, metric_type, k): ground_truth_map = {} for field_name, query_vectors in query_vectors_map.items(): ground_truth_map[field_name] = {} for i, query_vector in enumerate(query_vectors): # Get the ground truth for this query relevant_doc_ids_scores = get_ground_truth_for_vector_query( collection, query_vector, field_name, test_docs, i, metric_type, k, True ) ground_truth_map[field_name][i] = relevant_doc_ids_scores print("ground_truth_map:\n") print(ground_truth_map) return ground_truth_map def calculate_recall_at_k( collection: Collection, test_docs, query_vectors_map, schema, k=1, expected_doc_ids_scores_map=None, tolerance=0.01, ): recall_stats = {} for field_name, query_vectors in query_vectors_map.items(): recall_stats[field_name] = { "relevant_retrieved_count": 0, "total_relevant_count": 0, "retrieved_count": 0, "recall_at_k": 0.0, } for i, query_vector in enumerate(query_vectors): print("Starting %dth query" % i) query_result_list = collection.query( VectorQuery(field_name=field_name, vector=query_vector), topk=1024, include_vector=True, ) retrieved_count = len(query_result_list) query_result_ids_scores = [] for word in query_result_list: query_result_ids_scores.append((word.id, word.score)) recall_stats[field_name]["retrieved_count"] += retrieved_count print("expected_doc_ids_scores_map:\n") print(expected_doc_ids_scores_map) if i in (expected_doc_ids_scores_map[field_name]): expected_relevant_ids_scores = expected_doc_ids_scores_map[field_name][ i ] print( 
"field_name,i,expected_relevant_ids_scores, query_result_ids_scores:\n" ) print( field_name, i, "\n", expected_relevant_ids_scores, "\n", len(query_result_ids_scores), query_result_ids_scores, ) # Update total relevant documents count recall_stats[field_name]["total_relevant_count"] += len( expected_relevant_ids_scores ) relevant_found_count = 0 for ids_scores_except in expected_relevant_ids_scores: for ids_scores_result in query_result_ids_scores[:k]: if int(ids_scores_result[0]) == int(ids_scores_except[0]): relevant_found_count += 1 break elif ( int(ids_scores_result[0]) != int(ids_scores_except[0]) and abs(ids_scores_result[1] - ids_scores_except[1]) <= tolerance ): print("IDs are not equal, but the error is small, tolerance") print( ids_scores_result[0], ids_scores_except[0], ids_scores_result[1], ids_scores_except[1], tolerance, ) relevant_found_count += 1 break else: continue recall_stats[field_name]["relevant_retrieved_count"] += relevant_found_count # Calculate Recall@K if recall_stats[field_name]["total_relevant_count"] > 0: recall_stats[field_name]["recall_at_k"] = ( recall_stats[field_name]["relevant_retrieved_count"] / recall_stats[field_name]["total_relevant_count"] ) return recall_stats class TestRecall: @pytest.mark.parametrize( "full_schema_new", [ (True, True, HnswIndexParam()), (False, True, IVFIndexParam()), (False, True, FlatIndexParam()), # ——ok ( True, True, HnswIndexParam( metric_type=MetricType.IP, m=16, ef_construction=100, ), ), ( True, True, HnswIndexParam( metric_type=MetricType.COSINE, m=24, ef_construction=150, ), ), ( True, True, HnswIndexParam( metric_type=MetricType.L2, m=32, ef_construction=200, ), ), ( False, True, FlatIndexParam( metric_type=MetricType.IP, ), ), ( True, True, FlatIndexParam( metric_type=MetricType.COSINE, ), ), ( True, True, FlatIndexParam( metric_type=MetricType.L2, ), ), ( True, True, IVFIndexParam( metric_type=MetricType.IP, n_list=100, n_iters=10, use_soar=False, ), ), ( True, True, IVFIndexParam( 
metric_type=MetricType.L2, n_list=200, n_iters=20, use_soar=True, ), ), ( True, True, IVFIndexParam( metric_type=MetricType.COSINE, n_list=150, n_iters=15, use_soar=False, ), ), ], indirect=True, ) @pytest.mark.parametrize("doc_num", [500]) @pytest.mark.parametrize("query_num", [10]) @pytest.mark.parametrize("top_k", [1]) def test_recall_with_single_vector_valid_500( self, full_collection_new: Collection, doc_num, query_num, top_k, full_schema_new, request, ): full_schema_params = request.getfixturevalue("full_schema_new") for vector_para in full_schema_params.vectors: if vector_para.name == "vector_fp32_field": metric_type = vector_para.index_param.metric_type break multiple_docs = [ generate_doc_recall(i, full_collection_new.schema) for i in range(doc_num) ] print("len(multiple_docs):\n") print(len(multiple_docs)) # print(multiple_docs) for i in range(10): if i != 0: pass # print(multiple_docs[i * 1000:1000 * (i + 1)]) batchdoc_and_check( full_collection_new, multiple_docs[i * 1000 : 1000 * (i + 1)], operator="insert", ) stats = full_collection_new.stats assert stats.doc_count == len(multiple_docs) doc_ids = ["0", "1"] fetched_docs = full_collection_new.fetch(doc_ids) print("fetched_docs,multiple_docs") print( fetched_docs[doc_ids[0]].vectors["sparse_vector_fp32_field"], fetched_docs[doc_ids[0]].vectors["sparse_vector_fp16_field"], fetched_docs[doc_ids[1]].vectors["sparse_vector_fp32_field"], fetched_docs[doc_ids[1]].vectors["sparse_vector_fp16_field"], "\n", multiple_docs[0].vectors["sparse_vector_fp32_field"], multiple_docs[0].vectors["sparse_vector_fp32_field"], multiple_docs[1].vectors["sparse_vector_fp32_field"], multiple_docs[1].vectors["sparse_vector_fp16_field"], ) full_collection_new.optimize(option=OptimizeOption()) time.sleep(2) query_vectors_map = {} for field_name in DEFAULT_VECTOR_FIELD_NAME.values(): query_vectors_map[field_name] = [ multiple_docs[i].vectors[field_name] for i in range(query_num) ] # Get ground truth mapping ground_truth_map = 
get_ground_truth_map( full_collection_new, multiple_docs, query_vectors_map, metric_type, top_k ) # Validate ground truth mapping structure for field_name in DEFAULT_VECTOR_FIELD_NAME.values(): assert field_name in ground_truth_map field_gt = ground_truth_map[field_name] assert len(field_gt) == query_num for query_idx in range(query_num): assert query_idx in field_gt relevant_ids = field_gt[query_idx] assert isinstance(relevant_ids, list) assert len(relevant_ids) <= top_k # Print ground truth statistics print(f"Ground Truth for Top-{top_k} Retrieval:") for field_name, field_gt in ground_truth_map.items(): print(f" {field_name}:") for query_idx, relevant_ids in field_gt.items(): print( f" Query {query_idx}: {len(relevant_ids)} relevant docs - {relevant_ids[:5]}{'...' if len(relevant_ids) > 5 else ''}" ) # Calculate Recall@K using ground truth recall_at_k_stats = calculate_recall_at_k( full_collection_new, multiple_docs, query_vectors_map, full_schema_new, k=top_k, expected_doc_ids_scores_map=ground_truth_map, tolerance=0.01, ) print("ground_truth_map:\n") print(ground_truth_map) print("(recall_at_k_stats:\n") print(recall_at_k_stats) print("metric_type:") print(metric_type) # Print Recall@K statistics print(f"Recall@{top_k} using Ground Truth:") for field_name, stats in recall_at_k_stats.items(): print(f" {field_name}:") print( f" Relevant Retrieved: {stats['relevant_retrieved_count']}/{stats['total_relevant_count']}" ) print(f" Recall@{top_k}: {stats['recall_at_k']:.4f}") for k, v in recall_at_k_stats.items(): assert v["recall_at_k"] == 1.0 @pytest.mark.parametrize( "full_schema_new", [ (True, True, HnswIndexParam()), (False, True, IVFIndexParam()), (False, True, FlatIndexParam()), # ——ok ( True, True, HnswIndexParam( metric_type=MetricType.IP, m=16, ef_construction=100, ), ), ( True, True, HnswIndexParam( metric_type=MetricType.COSINE, m=24, ef_construction=150, ), ), # (True, True, HnswIndexParam(metric_type=MetricType.L2, m=32, ef_construction=200, )), ( False, 
True, FlatIndexParam( metric_type=MetricType.IP, ), ), ( True, True, FlatIndexParam( metric_type=MetricType.COSINE, ), ), # (True, True, FlatIndexParam(metric_type=MetricType.L2, )), ( True, True, IVFIndexParam( metric_type=MetricType.IP, n_list=100, n_iters=10, use_soar=False, ), ), ( True, True, IVFIndexParam( metric_type=MetricType.L2, n_list=200, n_iters=20, use_soar=True, ), ), # (True, True, IVFIndexParam(metric_type=MetricType.COSINE, n_list=150, n_iters=15, use_soar=False, )), ], indirect=True, ) @pytest.mark.parametrize("doc_num", [2000]) @pytest.mark.parametrize("query_num", [2]) @pytest.mark.parametrize("top_k", [1]) @pytest.mark.skip(reason="known bug") def test_recall_with_single_vector_valid_2000( self, full_collection_new: Collection, doc_num, query_num, top_k, full_schema_new, request, ): full_schema_params = request.getfixturevalue("full_schema_new") for vector_para in full_schema_params.vectors: if vector_para.name == "vector_fp32_field": metric_type = vector_para.index_param.metric_type break multiple_docs = [ generate_doc_recall(i, full_collection_new.schema) for i in range(doc_num) ] print("len(multiple_docs):\n") print(len(multiple_docs)) # print(multiple_docs) for i in range(10): if i != 0: pass # print(multiple_docs[i * 1000:1000 * (i + 1)]) batchdoc_and_check( full_collection_new, multiple_docs[i * 1000 : 1000 * (i + 1)], operator="insert", ) stats = full_collection_new.stats assert stats.doc_count == len(multiple_docs) doc_ids = ["0", "1"] fetched_docs = full_collection_new.fetch(doc_ids) print("fetched_docs,multiple_docs") print( fetched_docs[doc_ids[0]].vectors["sparse_vector_fp32_field"], fetched_docs[doc_ids[0]].vectors["sparse_vector_fp16_field"], fetched_docs[doc_ids[1]].vectors["sparse_vector_fp32_field"], fetched_docs[doc_ids[1]].vectors["sparse_vector_fp16_field"], "\n", multiple_docs[0].vectors["sparse_vector_fp32_field"], multiple_docs[0].vectors["sparse_vector_fp32_field"], multiple_docs[1].vectors["sparse_vector_fp32_field"], 
multiple_docs[1].vectors["sparse_vector_fp16_field"], ) full_collection_new.optimize(option=OptimizeOption()) time.sleep(2) query_vectors_map = {} for field_name in DEFAULT_VECTOR_FIELD_NAME.values(): query_vectors_map[field_name] = [ multiple_docs[i].vectors[field_name] for i in range(query_num) ] # Get ground truth mapping ground_truth_map = get_ground_truth_map( full_collection_new, multiple_docs, query_vectors_map, metric_type, top_k ) # Validate ground truth mapping structure for field_name in DEFAULT_VECTOR_FIELD_NAME.values(): assert field_name in ground_truth_map field_gt = ground_truth_map[field_name] assert len(field_gt) == query_num for query_idx in range(query_num): assert query_idx in field_gt relevant_ids = field_gt[query_idx] assert isinstance(relevant_ids, list) assert len(relevant_ids) <= top_k # Print ground truth statistics print(f"Ground Truth for Top-{top_k} Retrieval:") for field_name, field_gt in ground_truth_map.items(): print(f" {field_name}:") for query_idx, relevant_ids in field_gt.items(): print( f" Query {query_idx}: {len(relevant_ids)} relevant docs - {relevant_ids[:5]}{'...' 
if len(relevant_ids) > 5 else ''}" ) # Calculate Recall@K using ground truth recall_at_k_stats = calculate_recall_at_k( full_collection_new, multiple_docs, query_vectors_map, full_schema_new, k=top_k, expected_doc_ids_scores_map=ground_truth_map, tolerance=0.01, ) print("ground_truth_map:\n") print(ground_truth_map) print("(recall_at_k_stats:\n") print(recall_at_k_stats) print("metric_type:") print(metric_type) # Print Recall@K statistics print(f"Recall@{top_k} using Ground Truth:") for field_name, stats in recall_at_k_stats.items(): print(f" {field_name}:") print( f" Relevant Retrieved: {stats['relevant_retrieved_count']}/{stats['total_relevant_count']}" ) print(f" Recall@{top_k}: {stats['recall_at_k']:.4f}") for k, v in recall_at_k_stats.items(): assert v["recall_at_k"] == 1.0 ================================================ FILE: python/tests/detail/test_db_config.py ================================================ # Copyright 2025-present the zvec project # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import logging import pytest import tempfile import os import sys import subprocess import zvec import zvec from zvec import LogType, LogLevel # Error messages INITIALIZATION_ERROR_MSG = "initialization failed" RUNTIME_ERROR_MSG = "RuntimeError" VALUE_ERROR_MSG = "ValueError" TYPE_ERROR_MSG = "TypeError" # ==================== helper ==================== def run_in_subprocess(func): def wrapper(*args, **kwargs): if os.getenv("RUNNING_IN_SUBPROCESS"): return func(*args, **kwargs) env = os.environ.copy() env["RUNNING_IN_SUBPROCESS"] = "1" env["PYTEST_CURRENT_TEST"] = func.__name__ import inspect filepath = inspect.getfile(func) qualname = func.__qualname__.replace(".", "::") test_id = f"{filepath}::{qualname}" project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) env["PYTHONPATH"] = project_root + ":" + env.get("PYTHONPATH", "") cmd = [sys.executable, "-m", "pytest", "-v", "-s", test_id] result = subprocess.run(cmd, env=env, capture_output=True, text=True) if result.returncode != 0: pytest.fail( f"Subprocess test {func.__name__} failed with code {result.returncode}\n" f"STDOUT:\n{result.stdout}\nSTDERR:\n{result.stderr}" ) return wrapper # ==================== Fixtures ==================== @pytest.fixture(scope="function") def temp_log_dir(tmp_path_factory): return tmp_path_factory.mktemp("logs") # ==================== Tests ==================== class TestDbConfigInitialization: @run_in_subprocess def test_init_default(self): # default config # log_type: Optional[LogType] = LogType.CONSOLE, # log_level: Optional[LogLevel] = LogLevel.WARN, # log_dir: Optional[str] = "./logs", # log_basename: Optional[str] = "zvec.log", # log_file_size: Optional[int] = 2048, # log_overdue_days: Optional[int] = 7, zvec.init() @run_in_subprocess def test_init_file_logger(self): from pathlib import Path import shutil zvec.init( log_level=LogLevel.DEBUG, log_type=LogType.FILE, ) # assert logdir exist log_dir = Path("./logs") assert log_dir.exists() # validate write log 
col = zvec.create_and_open( "/tmp/test/1", zvec.CollectionSchema( name="test", vectors=zvec.VectorSchema( dimension=4, data_type=zvec.DataType.VECTOR_FP32, name="image", ), ), ) col.insert(docs=[zvec.Doc(id="1", vectors={"image": [1.0, 2.0, 3.0, 4.0]})]) assert any(log_dir.glob("zvec.log.*")) # clear col.destroy() shutil.rmtree(log_dir, ignore_errors=True) @run_in_subprocess def test_init_with_mixed_config(self): zvec.init( memory_limit_mb=128, log_type=LogType.FILE, query_threads=1, log_level=LogLevel.WARN, ) @run_in_subprocess def test_repeated_initialization(self): # Calling init() repeatedly is allowed: # it succeeds but becomes a no-op after the first successful init() zvec.init() class TestDbConfigMemoryLimitValidation: @run_in_subprocess def test_memory_limit_min_valid(self): # MIN_MEMORY_LIMIT_BYTES is 100M with pytest.raises(RuntimeError): zvec.init(memory_limit_mb=99) @run_in_subprocess def test_memory_limit_invalid_value(self): # memory_limit_mb must >= 0 and must be int and if None, set default value with pytest.raises(ValueError): zvec.init(memory_limit_mb=0) with pytest.raises(ValueError): zvec.init(memory_limit_mb=-1) with pytest.raises(TypeError): zvec.init(memory_limit_mb="512") with pytest.raises(TypeError): zvec.init(memory_limit_mb=512.5) class TestDbConfigThreadValidation: @run_in_subprocess def test_query_threads(self): zvec.init(query_threads=1) @run_in_subprocess def test_query_threads_invalid(self): # query_threads must >= 0 and must be int and if None, set default value with pytest.raises(ValueError): zvec.init(query_threads=0) with pytest.raises(ValueError): zvec.init(query_threads=-1) with pytest.raises(TypeError): zvec.init(query_threads="value") with pytest.raises(TypeError): zvec.init(query_threads=512.5) with pytest.raises(TypeError): zvec.init(query_threads="512") @run_in_subprocess def test_optimize_threads(self): zvec.init(optimize_threads=1) @run_in_subprocess def test_optimize_threads_invalid(self): # optimize_threads must >= 0 
and must be int and if None, set default value with pytest.raises(ValueError): zvec.init(optimize_threads=0) with pytest.raises(ValueError): zvec.init(optimize_threads=-1) with pytest.raises(TypeError): zvec.init(optimize_threads="value") with pytest.raises(TypeError): zvec.init(optimize_threads=512.5) with pytest.raises(TypeError): zvec.init(optimize_threads="512") class TestDbConfigRatioValidation: @run_in_subprocess def test_init_invert_to_forward_scan_ratio(self): # must be in [0,1] zvec.init(invert_to_forward_scan_ratio=0.8) @run_in_subprocess def test_init_invert_to_forward_scan_ratio_invalid(self): with pytest.raises(ValueError): zvec.init(invert_to_forward_scan_ratio=1.1) with pytest.raises(ValueError): zvec.init(invert_to_forward_scan_ratio=-0.1) with pytest.raises(TypeError): zvec.init(invert_to_forward_scan_ratio="0.8") @run_in_subprocess def test_init_brute_force_by_keys_ratio(self): zvec.init(brute_force_by_keys_ratio=0.8) @run_in_subprocess def test_init_brute_force_by_keys_ratio_invalid(self): with pytest.raises(ValueError): zvec.init(brute_force_by_keys_ratio=1.1) with pytest.raises(ValueError): zvec.init(brute_force_by_keys_ratio=-0.1) with pytest.raises(TypeError): zvec.init(brute_force_by_keys_ratio="0.8") class TestDbConfigLogValidation: @run_in_subprocess def test_log_type_valid(self): zvec.init(log_type=LogType.CONSOLE) @run_in_subprocess def test_log_type_invalid(self): with pytest.raises(TypeError): zvec.init(log_type="FILE") with pytest.raises(TypeError): zvec.init(log_type="") with pytest.raises(TypeError): zvec.init(log_type="invalid") with pytest.raises(TypeError): zvec.init(log_type=123) @run_in_subprocess def test_log_level_valid(self): zvec.init(log_level=LogLevel.ERROR) @run_in_subprocess def test_log_level_invalid(self): with pytest.raises(TypeError): zvec.init(log_level="WARN") with pytest.raises(TypeError): zvec.init(log_level="") with pytest.raises(TypeError): zvec.init(log_level="invalid") with pytest.raises(TypeError): 
zvec.init(log_level=123) @run_in_subprocess def test_init_file_logger(self): from pathlib import Path import shutil temp_dir = tempfile.mkdtemp(prefix="log_test_") abs_temp_dir = os.path.abspath(temp_dir) zvec.init( log_level=LogLevel.DEBUG, log_type=LogType.FILE, log_dir=abs_temp_dir, log_basename="test", ) # assert logdir exist log_dir = Path(abs_temp_dir) assert log_dir.exists() # validate write log col = zvec.create_and_open( "/tmp/test/1", zvec.CollectionSchema( name="test", vectors=zvec.VectorSchema( dimension=4, data_type=zvec.DataType.VECTOR_FP32, name="image", ), ), ) col.insert(docs=[zvec.Doc(id="1", vectors={"image": [1.0, 2.0, 3.0, 4.0]})]) assert any(log_dir.glob("test.*")) # clear col.destroy() shutil.rmtree(log_dir, ignore_errors=True) @run_in_subprocess def test_log_file_size_invalid(self): with pytest.raises(TypeError): zvec.init(log_type=LogType.FILE, log_file_size="df") with pytest.raises(ValueError): zvec.init(log_type=LogType.FILE, log_file_size=0) with pytest.raises(ValueError): zvec.init(log_type=LogType.FILE, log_file_size=-1) @run_in_subprocess def test_log_overdue_days_invalid(self): with pytest.raises(TypeError): zvec.init(log_type=LogType.FILE, log_overdue_days="df") with pytest.raises(ValueError): zvec.init(log_type=LogType.FILE, log_overdue_days=0) with pytest.raises(ValueError): zvec.init(log_type=LogType.FILE, log_overdue_days=-1) ================================================ FILE: python/tests/test_collection.py ================================================ # Copyright 2025-present the zvec project # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import pytest
import zvec
from zvec import (
    Collection,
    CollectionOption,
    DataType,
    Doc,
    FieldSchema,
    HnswIndexParam,
    InvertIndexParam,
    LogLevel,
    LogType,
    VectorSchema,
    StatusCode,
    IndexOption,
    IndexType,
    VectorQuery,
    OptimizeOption,
)


# ==================== Common ====================
@pytest.fixture(scope="session")
def collection_schema():
    """Schema shared by this module.

    Scalar fields: id (INT64, range-optimized inverted index), name (STRING,
    inverted index), weight (FLOAT, nullable), height (INT32, nullable).
    Vector fields: 128-dim dense FP32 and sparse FP32, both HNSW-indexed.
    """
    return zvec.CollectionSchema(
        name="test_collection",
        fields=[
            FieldSchema(
                "id",
                DataType.INT64,
                nullable=False,
                index_param=InvertIndexParam(enable_range_optimization=True),
            ),
            FieldSchema(
                "name", DataType.STRING, nullable=False, index_param=InvertIndexParam()
            ),
            FieldSchema("weight", DataType.FLOAT, nullable=True),
            FieldSchema("height", DataType.INT32, nullable=True),
        ],
        vectors=[
            VectorSchema(
                "dense",
                DataType.VECTOR_FP32,
                dimension=128,
                index_param=HnswIndexParam(),
            ),
            VectorSchema(
                "sparse", DataType.SPARSE_VECTOR_FP32, index_param=HnswIndexParam()
            ),
        ],
    )


@pytest.fixture(scope="session")
def collection_option():
    """Writable, mmap-enabled collection option."""
    return CollectionOption(read_only=False, enable_mmap=True)


@pytest.fixture
def single_doc():
    """One doc with id "0" matching ``collection_schema``."""
    doc_id = 0
    return Doc(
        id=f"{doc_id}",
        fields={"id": doc_id, "name": "test", "weight": 80.0, "height": doc_id + 140},
        vectors={"dense": [doc_id + 0.1] * 128, "sparse": {1: 1.0, 2: 2.0, 3: 3.0}},
    )


@pytest.fixture
def multiple_docs():
    """100 docs with ids "1".."100" matching ``collection_schema``."""
    return [
        Doc(
            id=f"{doc_id}",
            fields={"id": doc_id, "name": "test", "weight": 80.0, "height": 210},
            vectors={"dense": [doc_id + 0.1] * 128, "sparse": {1: 1.0, 2: 2.0, 3: 3.0}},
        )
        for doc_id in range(1, 101)
    ]


@pytest.fixture(scope="function")
def test_collection(
    tmp_path_factory, collection_schema, collection_option
) -> Collection:
    """
    Function-scoped fixture: creates and opens a fresh collection per test.

    Each invocation calls ``tmp_path_factory.mktemp``, so every test gets its
    own new temporary directory; nothing is shared between tests or classes.
    The collection is destroyed on teardown (failures are only printed).
    """
    temp_dir = tmp_path_factory.mktemp("zvec")
    collection_path = temp_dir / "test_collection"

    coll = zvec.create_and_open(
        path=str(collection_path), schema=collection_schema, option=collection_option
    )

    assert coll is not None, "Failed to create and open collection"
    assert coll.path == str(collection_path)
    assert coll.schema.name == collection_schema.name
    assert list(coll.schema.fields) == list(collection_schema.fields)
    assert list(coll.schema.vectors) == list(collection_schema.vectors)
    assert coll.option.read_only == collection_option.read_only
    assert coll.option.enable_mmap == collection_option.enable_mmap

    try:
        yield coll
    finally:
        if hasattr(coll, "destroy") and coll is not None:
            try:
                coll.destroy()
            except Exception as e:
                print(f"Warning: failed to destroy collection: {e}")


@pytest.fixture
def collection_with_single_doc(test_collection: Collection, single_doc) -> Collection:
    """``test_collection`` holding exactly ``single_doc``; removed on teardown."""
    # Setup: insert single doc
    assert test_collection.stats.doc_count == 0
    result = test_collection.insert(single_doc)
    assert bool(result)
    assert result.ok()
    assert test_collection.stats.doc_count == 1

    yield test_collection

    # Teardown: delete single doc
    test_collection.delete(single_doc.id)
    assert test_collection.stats.doc_count == 0


@pytest.fixture
def collection_with_multiple_docs(
    test_collection: Collection, multiple_docs
) -> Collection:
    """``test_collection`` holding all of ``multiple_docs``; removed on teardown."""
    # Setup: insert multiple docs
    assert test_collection.stats.doc_count == 0
    result = test_collection.insert(multiple_docs)
    assert len(result) == len(multiple_docs)
    for item in result:
        assert item.ok()
    assert test_collection.stats.doc_count == len(multiple_docs)

    yield test_collection

    # Teardown: delete multiple docs
    test_collection.delete([doc.id for doc in multiple_docs])


# ==================== Tests ====================
# ----------------------------
# Config Test Case
# ----------------------------
class TestConfig:
    def test_config(self):
        zvec.init(log_type=LogType.CONSOLE, log_level=LogLevel.ERROR, log_dir="./log")


# ----------------------------
# Collection DDL Test Case
# ----------------------------
@pytest.mark.usefixtures("test_collection")
class TestCollectionDDL:
    def test_collection_stats(self, test_collection: Collection):
        # A fresh collection is empty and both vector indexes report complete.
        assert test_collection.stats is not None
        stats = test_collection.stats
        assert stats.doc_count == 0
        assert len(stats.index_completeness) == 2
        assert stats.index_completeness["dense"] == 1
        assert stats.index_completeness["sparse"] == 1


# ----------------------------
# Collection Index DDL Test Case
# ----------------------------
@pytest.mark.usefixtures("test_collection")
class TestCollectionIndexDDL:
    # Each test is defined exactly once. A duplicated chunk used to redefine
    # test_drop_index / test_create_index_field_is_not_exist, which silently
    # replaced the first definitions, and also left stray statements that
    # referenced an undefined `field_schema`.

    def test_create_index(self, test_collection: Collection):
        # before create: "weight" has no index
        field_schema = test_collection.schema.field("weight")
        assert field_schema is not None
        assert field_schema.data_type == DataType.FLOAT
        assert field_schema.name == "weight"
        index_param = field_schema.index_param
        assert index_param is None

        # create
        test_collection.create_index(
            field_name="weight", index_param=InvertIndexParam(), option=IndexOption()
        )
        assert test_collection.schema is not None
        field_schema = test_collection.schema.field("weight")
        assert field_schema is not None
        assert field_schema.data_type == DataType.FLOAT
        assert field_schema.name == "weight"
        index_param = field_schema.index_param
        assert index_param.type == IndexType.INVERT
        assert index_param.enable_range_optimization is False
        assert index_param.enable_extended_wildcard is False

    def test_drop_index(self, test_collection: Collection):
        # before drop: "name" carries a default inverted index
        field_schema = test_collection.schema.field("name")
        assert field_schema is not None
        assert field_schema.data_type == DataType.STRING
        assert field_schema.name == "name"
        index_param = field_schema.index_param
        assert index_param.type == IndexType.INVERT
        assert index_param.enable_range_optimization is False
        assert index_param.enable_extended_wildcard is False

        # drop
        test_collection.drop_index("name")
        field_schema = test_collection.schema.field("name")
        assert field_schema is not None
        assert field_schema.data_type == DataType.STRING
        assert field_schema.name == "name"

        # without index
        index_param = field_schema.index_param
        assert index_param is None

    def test_create_index_field_is_not_exist(self, test_collection: Collection):
        with pytest.raises(Exception):
            test_collection.create_index(
                field_name="not_exist",
                index_param=InvertIndexParam(),
            )


# ----------------------------
# Collection Column DDL Test Case
# ----------------------------
@pytest.mark.usefixtures("test_collection")
class TestCollectionColumnDDL:
    def test_create_column(self, test_collection: Collection):
        # before create column
        field_schema = test_collection.schema.field("age")
        assert field_schema is None

        # create
        test_collection.add_column(FieldSchema("age", DataType.INT32, nullable=True))
        field_schema = test_collection.schema.field("age")
        assert field_schema is not None
        assert field_schema.data_type == DataType.INT32
        assert field_schema.name == "age"
        assert field_schema.index_param is None

    def test_create_column_is_nullable(self, test_collection: Collection):
        with pytest.raises(ValueError):
            test_collection.add_column(
                FieldSchema("age", DataType.INT32, nullable=False)
            )

    def test_drop_column(self, test_collection: Collection):
        # before drop column
        field_schema = test_collection.schema.field("id")
        assert field_schema is not None
        assert field_schema.data_type == DataType.INT64
        assert field_schema.name == "id"
        index_param = field_schema.index_param
        assert index_param is not None
        assert index_param.type == IndexType.INVERT

        # drop
        test_collection.drop_column("id")
        field_schema = test_collection.schema.field("id")
        assert field_schema is None

    def test_alert_column_to_rename(self, test_collection: Collection):
        # before alert column
        field_schema = test_collection.schema.field("id")
        assert field_schema is not None
        assert field_schema.data_type == DataType.INT64
        assert field_schema.name == "id"
        index_param = field_schema.index_param
        assert index_param is not None
        assert index_param.type == IndexType.INVERT
        assert index_param.enable_range_optimization is True
        assert index_param.enable_extended_wildcard is False

        # alert rename
        test_collection.alter_column("id", "doc_id")

        # validate old column
        field_schema = test_collection.schema.field("id")
        assert field_schema is None

        # validate rename column
        field_schema = test_collection.schema.field("doc_id")
        assert field_schema is not None
        assert field_schema.data_type == DataType.INT64
        assert field_schema.name == "doc_id"
        assert field_schema.nullable is False
        index_param = field_schema.index_param
        assert index_param is not None
        assert index_param.type == IndexType.INVERT
assert index_param.enable_range_optimization is True assert index_param.enable_extended_wildcard is False def test_alert_column_to_modify_schema(self, test_collection: Collection): # before alert column field_schema = test_collection.schema.field("id") assert field_schema is not None assert field_schema.data_type == DataType.INT64 assert field_schema.name == "id" index_param = field_schema.index_param assert index_param.type == IndexType.INVERT test_collection.alter_column( old_name="id", field_schema=FieldSchema("doc_id", DataType.UINT64, nullable=True), ) field_schema = test_collection.schema.field("doc_id") assert field_schema is not None assert field_schema.data_type == DataType.UINT64 assert field_schema.name == "doc_id" def test_column_with_other_dtype(self, test_collection: Collection): # only allow number type test_collection.add_column(FieldSchema("age", DataType.INT32, nullable=True)) with pytest.raises(ValueError): test_collection.add_column(FieldSchema("full_name", DataType.STRING)) with pytest.raises(ValueError): test_collection.drop_column("name") with pytest.raises(ValueError): test_collection.alter_column(old_name="name", new_name="full_name") with pytest.raises(ValueError): test_collection.alter_column( old_name="name", field_schema=FieldSchema("full_name", DataType.STRING) ) # ---------------------------- # Collection Optimize Test Case # ---------------------------- @pytest.mark.usefixtures("test_collection") class TestCollectionOptimize: def test_collection_optimize(self, test_collection: Collection): test_collection.optimize(option=OptimizeOption()) # ---------------------------- # Collection Fetch Test Case # ---------------------------- @pytest.mark.usefixtures("test_collection") class TestCollectionFetch: def test_collection_fetch( self, collection_with_single_doc: Collection, single_doc: Doc ): result = collection_with_single_doc.fetch(ids=[single_doc.id]) assert bool(result) assert single_doc.id in result.keys() doc = result[single_doc.id] 
assert doc is not None assert doc.id == single_doc.id assert set(doc.field_names()) == set(single_doc.field_names()) for field_name in doc.field_names(): if field_name in ["dense", "sparse"]: continue assert doc.field(field_name) == single_doc.field(field_name) def test_collection_fetch_contains_nodata_ids( self, collection_with_multiple_docs: Collection, multiple_docs: list[Doc] ): ids = [doc.id for doc in multiple_docs] no_data_key = "x" ids_with_no_data = [no_data_key] + ids result = collection_with_multiple_docs.fetch(ids=ids_with_no_data) assert bool(result) assert len(result) == len(ids) assert no_data_key not in result # ---------------------------- # Collection Insert Test Case # ---------------------------- @pytest.mark.usefixtures("test_collection") class TestCollectionInsert: def test_collection_insert(self, test_collection, single_doc): result = test_collection.insert(single_doc) assert bool(result) assert result.ok() stats = test_collection.stats assert stats is not None assert stats.doc_count == 1 def test_collection_insert_with_nullable_false_field(self, test_collection): # id, name's nullable == False # weight, height's nullable == True doc = Doc( id="0", fields={ "id": 1, "name": "test", }, vectors={"dense": [1 + 0.1] * 128, "sparse": {1: 1.0, 2: 2.0, 3: 3.0}}, ) result = test_collection.insert(doc) assert bool(result) assert result.ok() stats = test_collection.stats assert stats is not None assert stats.doc_count == 1 def test_collection_insert_without_nullable_false_field(self, test_collection): # id, name's nullable == False # weight, height's nullable == True # without id, name doc = Doc( id="0", vectors={"dense": [1 + 0.1] * 128, "sparse": {1: 1.0, 2: 2.0, 3: 3.0}}, ) with pytest.raises(ValueError) as e: # ValueError: doc validate failed: field[id] is configured not nullable, # but doc does not contain this field test_collection.insert(doc) assert "field[id] is configured not nullable" in str(e.value) # without name doc = Doc( id="0", fields={ 
"id": 1, }, vectors={"dense": [1 + 0.1] * 128, "sparse": {1: 1.0, 2: 2.0, 3: 3.0}}, ) with pytest.raises(ValueError) as e: test_collection.insert(doc) assert "field[name] is configured not nullable" in str(e.value) def test_collection_insert_with_nullable_true_field(self, test_collection): # id, name's nullable == False # weight, height's nullable == True doc = Doc( id="0", fields={ "id": 1, "name": "test", }, vectors={"dense": [1 + 0.1] * 128, "sparse": {1: 1.0, 2: 2.0, 3: 3.0}}, ) result = test_collection.insert(doc) assert bool(result) assert result.ok() stats = test_collection.stats assert stats is not None assert stats.doc_count == 1 result = test_collection.fetch(ids=[doc.id]) assert doc.id in result ret = result[doc.id] assert ret.field("id") == 1 assert ret.field("name") == "test" assert ret.field("weight") is None assert ret.field("height") is None def test_collection_insert_batch(self, test_collection, multiple_docs): result = test_collection.insert(multiple_docs) assert len(result) == len(multiple_docs) for item in result: assert item.ok() stats = test_collection.stats assert stats is not None assert stats.doc_count == len(multiple_docs) def test_collection_insert_duplicate( self, test_collection, single_doc, multiple_docs ): test_collection.insert(single_doc) result = test_collection.insert(single_doc) assert bool(result) assert result.code() == StatusCode.ALREADY_EXISTS stats = test_collection.stats assert stats is not None assert stats.doc_count == 1 # ---------------------------- # Collection Update Test Case # ---------------------------- @pytest.mark.usefixtures("test_collection") class TestCollectionUpdate: def test_empty_collection_update( self, test_collection: Collection, single_doc: Doc ): result = test_collection.update(single_doc) assert bool(result) assert result.code() == StatusCode.NOT_FOUND stats = test_collection.stats assert stats is not None assert stats.doc_count == 0 def test_collection_update_with_nullable_false_field( self, 
collection_with_single_doc: Collection, single_doc: Doc ): # id, name's nullable == False # weight, height's nullable == True # update doc field id doc = Doc( id=single_doc.id, fields={"id": single_doc.field("id") + 1}, ) result = collection_with_single_doc.update(doc) assert bool(result) assert result.ok() stats = collection_with_single_doc.stats assert stats is not None assert stats.doc_count == 1 # fetch result = collection_with_single_doc.fetch(ids=[doc.id]) assert doc.id in result ret = result[doc.id] assert ret.field("id") == doc.field("id") assert ret.field("name") == single_doc.field("name") assert ret.field("weight") == single_doc.field("weight") assert ret.field("height") == single_doc.field("height") def test_collection_update_with_nullable_false_field_is_none( self, collection_with_single_doc: Collection, single_doc: Doc ): # id, name's nullable == False # weight, height's nullable == True # update doc field id doc = Doc( id=single_doc.id, fields={"id": None}, ) with pytest.raises(ValueError) as e: # ValueError: doc validate failed: field[id] is configured not nullable, # but doc does not contain this field collection_with_single_doc.update(doc) doc = Doc( id=single_doc.id, fields={"id": single_doc.field("id") + 1, "weight": None}, ) result = collection_with_single_doc.update(doc) assert bool(result) assert result.ok() stats = collection_with_single_doc.stats assert stats is not None assert stats.doc_count == 1 ret = collection_with_single_doc.fetch(ids=[doc.id]) assert doc.id in ret ret = ret[doc.id] assert ret.field("id") == doc.field("id") assert ret.field("name") == single_doc.field("name") assert ret.field("weight") is None assert ret.field("height") == single_doc.field("height") def test_collection_update_without_nullable_false_field( self, collection_with_single_doc: Collection, single_doc: Doc ): # id, name's nullable == False # weight, height's nullable == True # update doc field weight doc = Doc( id=single_doc.id, fields={"weight": 
single_doc.field("weight") + 1}, ) result = collection_with_single_doc.update(doc) assert bool(result) assert result.ok() stats = collection_with_single_doc.stats assert stats is not None assert stats.doc_count == 1 # fetch ret = collection_with_single_doc.fetch(ids=[doc.id]) assert doc.id in ret ret = ret[doc.id] assert ret.field("id") == single_doc.field("id") assert ret.field("name") == single_doc.field("name") assert ret.field("weight") == doc.field("weight") assert ret.field("height") == single_doc.field("height") def test_collection_update_without_nullable_false_field_set_null( self, collection_with_single_doc: Collection, single_doc: Doc ): # id, name's nullable == False # weight, height's nullable == True # update doc field weight is None doc = Doc( id=single_doc.id, fields={"weight": None}, ) result = collection_with_single_doc.update(doc) assert bool(result) assert result.ok() stats = collection_with_single_doc.stats assert stats is not None assert stats.doc_count == 1 # fetch ret = collection_with_single_doc.fetch(ids=[doc.id]) assert doc.id in ret ret = ret[doc.id] assert ret.field("id") == single_doc.field("id") assert ret.field("name") == single_doc.field("name") assert ret.field("weight") is None assert ret.field("height") == single_doc.field("height") def test_empty_collection_update_batch( self, test_collection: Collection, multiple_docs ): result = test_collection.update(multiple_docs) assert len(result) == len(multiple_docs) for item in result: assert item.code() == StatusCode.NOT_FOUND stats = test_collection.stats assert stats is not None assert stats.doc_count == 0 def test_collection_update( self, collection_with_single_doc: Collection, single_doc ): result = collection_with_single_doc.update(single_doc) assert bool(result) == 1 assert result.ok() stats = collection_with_single_doc.stats assert stats is not None assert stats.doc_count == 1 def test_collection_update_batch( self, collection_with_multiple_docs: Collection, multiple_docs ): 
result = collection_with_multiple_docs.update(multiple_docs) assert len(result) == len(multiple_docs) for item in result: assert item.ok() stats = collection_with_multiple_docs.stats assert stats is not None assert stats.doc_count == len(multiple_docs) # ---------------------------- # Collection Upsert Test Case # ---------------------------- @pytest.mark.usefixtures("test_collection") class TestCollectionUpsert: def test_empty_collection_upsert(self, test_collection: Collection, single_doc): result = test_collection.upsert(single_doc) assert bool(result) assert result.ok() stats = test_collection.stats assert stats is not None assert stats.doc_count == 1 def test_empty_collection_upsert_batch( self, test_collection: Collection, multiple_docs ): result = test_collection.upsert(multiple_docs) assert len(result) == len(multiple_docs) for item in result: assert item.ok() stats = test_collection.stats assert stats is not None assert stats.doc_count == len(multiple_docs) def test_collection_upsert( self, collection_with_single_doc: Collection, single_doc, multiple_docs ): # doc is existing # upsert => update result = collection_with_single_doc.upsert(single_doc) assert bool(result) assert result.ok() stats = collection_with_single_doc.stats assert stats is not None assert stats.doc_count == 1 def test_collection_upsert_batch( self, collection_with_multiple_docs: Collection, multiple_docs ): # doc is existing # upsert => update result = collection_with_multiple_docs.upsert(multiple_docs) assert len(result) == len(multiple_docs) for item in result: assert item.ok() stats = collection_with_multiple_docs.stats assert stats is not None assert stats.doc_count == len(multiple_docs) # ---------------------------- # Collection Upsert Test Case # ---------------------------- @pytest.mark.usefixtures("test_collection") class TestCollectionDelete: def test_empty_collection_delete(self, test_collection: Collection, single_doc): result = test_collection.delete(single_doc.id) assert 
bool(result) assert result.code() == StatusCode.NOT_FOUND def test_empty_collection_delete_batch( self, test_collection: Collection, multiple_docs ): result = test_collection.delete([doc.id for doc in multiple_docs]) assert len(result) == len(multiple_docs) for item in result: assert item.code() == StatusCode.NOT_FOUND def test_collection_delete( self, collection_with_single_doc: Collection, single_doc ): result = collection_with_single_doc.delete(single_doc.id) assert bool(result) assert result.ok() stats = collection_with_single_doc.stats assert stats is not None assert stats.doc_count == 0 result = collection_with_single_doc.insert(single_doc) assert bool(result) assert result.ok() stats = collection_with_single_doc.stats assert stats is not None assert stats.doc_count == 1 def test_collection_delete_batch( self, collection_with_multiple_docs: Collection, multiple_docs ): result = collection_with_multiple_docs.delete([doc.id for doc in multiple_docs]) assert len(result) == len(multiple_docs) for item in result: assert item.ok() stats = collection_with_multiple_docs.stats assert stats is not None assert stats.doc_count == 0 def test_collection_delete_by_filter( self, collection_with_single_doc: Collection, single_doc ): collection_with_single_doc.delete_by_filter( filter=f"height={single_doc.field('height')}" ) stats = collection_with_single_doc.stats assert stats is not None assert stats.doc_count == 0 def test_collection_delete_by_filter_invert_field( self, collection_with_single_doc: Collection, single_doc ): collection_with_single_doc.delete_by_filter( filter=f"id={single_doc.field('id')}" ) stats = collection_with_single_doc.stats assert stats is not None assert stats.doc_count == 0 # ---------------------------- # Collection Upsert Test Case # ---------------------------- @pytest.mark.usefixtures("test_collection") class TestCollectionQuery: def test_empty_collection_query(self, test_collection: Collection): result = test_collection.query() assert 
len(result) == 0 def test_collection_query(self, collection_with_single_doc: Collection, single_doc): result = collection_with_single_doc.query() assert len(result) == 1 doc = result[0] assert doc.id == single_doc.id assert "dense" not in doc.field_names() assert "sparse" not in doc.field_names() field_without_vector = single_doc.field_names() assert set(doc.field_names()) == set(field_without_vector) for name in field_without_vector: assert doc.field(name) == single_doc.field(name) def test_collection_query_with_include_vector( self, collection_with_single_doc: Collection, single_doc ): result = collection_with_single_doc.query(include_vector=True) assert len(result) == 1 doc = result[0] assert doc.vector("dense") is not None assert doc.vector("sparse") is not None def test_collection_query_with_output_fields( self, collection_with_single_doc: Collection, single_doc ): result = collection_with_single_doc.query(output_fields=["id", "name"]) assert len(result) == 1 doc = result[0] assert doc.id == single_doc.id assert len(doc.field_names()) == 2 assert set(doc.field_names()) == {"id", "name"} def test_collection_query_with_topk( self, collection_with_multiple_docs: Collection ): result = collection_with_multiple_docs.query() assert len(result) == 10 result = collection_with_multiple_docs.query(topk=5) assert len(result) == 5 def test_collection_query_with_range_filter_int_field( self, collection_with_multiple_docs: Collection, multiple_docs ): index = 10 idx = multiple_docs[index].id result = collection_with_multiple_docs.query(filter=f"id>{idx}", topk=100) assert len(result) == len(multiple_docs) - index - 1 result = collection_with_multiple_docs.query(filter=f"id>={idx}", topk=100) assert len(result) == len(multiple_docs) - index result = collection_with_multiple_docs.query(filter=f"id<{idx}", topk=100) assert len(result) == index result = collection_with_multiple_docs.query(filter=f"id<={idx}", topk=100) assert len(result) == index + 1 result = 
collection_with_multiple_docs.query(filter=f"id={idx}", topk=100) assert len(result) == 1 result = collection_with_multiple_docs.query(filter=f"id!={idx}", topk=100) assert len(result) == len(multiple_docs) - 1 left, right = 10, 90 l_id, r_id = multiple_docs[left].id, multiple_docs[right].id result = collection_with_multiple_docs.query( filter=f"id>{l_id} and id<{r_id}", topk=100 ) assert len(result) == right - left - 1 result = collection_with_multiple_docs.query( filter=f"id>={l_id} and id<{r_id}", topk=100 ) assert len(result) == right - left result = collection_with_multiple_docs.query( filter=f"id>={l_id} and id<={r_id}", topk=100 ) assert len(result) == right - left + 1 result = collection_with_multiple_docs.query( filter=f"id<{l_id} or id>{r_id}", topk=100 ) assert len(result) == len(multiple_docs) - (right - left) - 1 result = collection_with_multiple_docs.query( filter=f"id<={l_id} or id>{r_id}", topk=100 ) assert len(result) == len(multiple_docs) - (right - left) result = collection_with_multiple_docs.query( filter=f"id<={l_id} or id>={r_id}", topk=100 ) assert len(result) == len(multiple_docs) - (right - left) + 1 result = collection_with_multiple_docs.query(filter="id in (1)", topk=100) assert len(result) == 1 def test_collection_query_with_vector_and_id( self, collection_with_single_doc: Collection, single_doc: Doc ): with pytest.raises(ValueError): collection_with_single_doc.query( VectorQuery( field_name="dense", id=single_doc.id, vector=single_doc.vector("dense"), ) ) def test_collection_query_with_filter_not_in( self, collection_with_multiple_docs: Collection, multiple_docs ): result = collection_with_multiple_docs.query(filter="id not in (1)", topk=100) assert len(result) == len(multiple_docs) - 1 def test_collection_with_error_query_vector( self, collection_with_multiple_docs: Collection, multiple_docs ): query = VectorQuery( field_name="dense", vector=multiple_docs[0].vector("dense"), param=[1, 2, 3] ) with pytest.raises(TypeError): result = 
collection_with_multiple_docs.query( filter="id in (1)", topk=100, vectors=query ) def test_collection_query_by_id( self, collection_with_multiple_docs: Collection, multiple_docs ): result = collection_with_multiple_docs.query( VectorQuery(field_name="dense", id=multiple_docs[0].id) ) assert len(result) == 10 def test_collection_query_multi_vector_with_same_field( self, collection_with_multiple_docs: Collection, multiple_docs ): with pytest.raises(ValueError): collection_with_multiple_docs.query( [ VectorQuery( field_name="dense", vector=multiple_docs[0].vector("dense") ), VectorQuery( field_name="dense", vector=multiple_docs[0].vector("dense") ), ] ) @pytest.mark.skip(reason="TODO: This test case is pending implementation") def test_collection_query_by_dense_vector( self, collection_with_multiple_docs: Collection, multiple_docs ): pass @pytest.mark.skip(reason="TODO: This test case is pending implementation") def test_collection_query_by_sparse_vector( self, collection_with_multiple_docs: Collection, multiple_docs ): pass @pytest.mark.skip(reason="TODO: This test case is pending implementation") def test_collection_query_by_dense_vector_with_filter( self, collection_with_multiple_docs: Collection, multiple_docs ): pass @pytest.mark.skip(reason="TODO: This test case is pending implementation") def test_collection_query_by_sparse_vector_with_filter( self, collection_with_multiple_docs: Collection, multiple_docs ): pass @pytest.mark.skip(reason="TODO: This test case is pending implementation") def test_collection_query_with_rrf_reranker_by_multi_dense_vector( self, collection_with_multiple_docs: Collection, multiple_docs ): pass @pytest.mark.skip(reason="TODO: This test case is pending implementation") def test_collection_query_with_rrf_reranker_by_multi_sparse_vector( self, collection_with_multiple_docs: Collection, multiple_docs ): pass @pytest.mark.skip(reason="TODO: This test case is pending implementation") def 
test_collection_query_with_rrf_reranker_by_hybrid_vector( self, collection_with_multiple_docs: Collection, multiple_docs ): pass @pytest.mark.skip(reason="TODO: This test case is pending implementation") def test_collection_query_with_weighted_reranker_by_multi_dense_vector( self, collection_with_multiple_docs: Collection, multiple_docs ): pass @pytest.mark.skip(reason="TODO: This test case is pending implementation") def test_collection_query_with_weighted_reranker_by_multi_sparse_vector( self, collection_with_multiple_docs: Collection, multiple_docs ): pass @pytest.mark.skip(reason="TODO: This test case is pending implementation") def test_collection_query_with_weighted_reranker_by_hybrid_vector( self, collection_with_multiple_docs: Collection, multiple_docs ): pass ================================================ FILE: python/tests/test_collection_hnsw_rabitq.py ================================================ # Copyright 2025-present the zvec project # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from __future__ import annotations import platform import sys import pytest import math import zvec pytestmark = pytest.mark.skipif( not (sys.platform == "linux" and platform.machine() in ("x86_64", "AMD64")), reason="HNSW RaBitQ only supported on Linux x86_64", ) from zvec import ( Collection, CollectionOption, DataType, Doc, FieldSchema, HnswRabitqIndexParam, HnswRabitqQueryParam, MetricType, VectorSchema, VectorQuery, ) # ==================== Fixtures ==================== @pytest.fixture(scope="session") def hnsw_rabitq_collection_schema(): """Create a collection schema with HNSW RaBitQ index.""" return zvec.CollectionSchema( name="test_hnsw_rabitq_collection", fields=[ FieldSchema("id", DataType.INT64, nullable=False), FieldSchema("name", DataType.STRING, nullable=False), ], vectors=[ VectorSchema( "embedding", DataType.VECTOR_FP32, dimension=128, index_param=HnswRabitqIndexParam( metric_type=MetricType.L2, m=16, ef_construction=200, total_bits=7, num_clusters=64, ), ), ], ) @pytest.fixture(scope="session") def collection_option(): """Create collection options.""" return CollectionOption(read_only=False, enable_mmap=True) @pytest.fixture def single_doc(): """Create a single document for testing.""" return Doc( id="0", fields={"id": 0, "name": "test_doc_0"}, vectors={"embedding": [0.1 + i * 0.01 for i in range(128)]}, ) @pytest.fixture def multiple_docs(): """Create multiple documents for testing.""" return [ Doc( id=f"{i}", fields={"id": i, "name": f"test_doc_{i}"}, vectors={"embedding": [i * 0.1 + j * 0.01 for j in range(128)]}, ) for i in range(1, 101) ] @pytest.fixture(scope="function") def hnsw_rabitq_collection( tmp_path_factory, hnsw_rabitq_collection_schema, collection_option ) -> Collection: """ Function-scoped fixture: creates and opens a collection with HNSW RaBitQ index. 
""" temp_dir = tmp_path_factory.mktemp("zvec_hnsw_rabitq") collection_path = temp_dir / "test_hnsw_rabitq_collection" coll = zvec.create_and_open( path=str(collection_path), schema=hnsw_rabitq_collection_schema, option=collection_option, ) assert coll is not None, "Failed to create and open HNSW RaBitQ collection" assert coll.path == str(collection_path) assert coll.schema.name == hnsw_rabitq_collection_schema.name try: yield coll finally: if hasattr(coll, "destroy") and coll is not None: try: coll.destroy() except Exception as e: print(f"Warning: failed to destroy collection: {e}") @pytest.fixture def collection_with_single_doc( hnsw_rabitq_collection: Collection, single_doc: Doc ) -> Collection: """Setup: insert single doc into collection.""" assert hnsw_rabitq_collection.stats.doc_count == 0 result = hnsw_rabitq_collection.insert(single_doc) assert bool(result) assert result.ok() assert hnsw_rabitq_collection.stats.doc_count == 1 yield hnsw_rabitq_collection # Teardown: delete single doc hnsw_rabitq_collection.delete(single_doc.id) assert hnsw_rabitq_collection.stats.doc_count == 0 @pytest.fixture def collection_with_multiple_docs( hnsw_rabitq_collection: Collection, multiple_docs: list[Doc] ) -> Collection: """Setup: insert multiple docs into collection.""" assert hnsw_rabitq_collection.stats.doc_count == 0 result = hnsw_rabitq_collection.insert(multiple_docs) assert len(result) == len(multiple_docs) for item in result: assert item.ok() assert hnsw_rabitq_collection.stats.doc_count == len(multiple_docs) yield hnsw_rabitq_collection # Teardown: delete multiple docs hnsw_rabitq_collection.delete([doc.id for doc in multiple_docs]) # ==================== Tests ==================== @pytest.mark.usefixtures("hnsw_rabitq_collection") class TestHnswRabitqCollectionCreation: """Test HNSW RaBitQ collection creation and schema validation.""" def test_collection_creation( self, hnsw_rabitq_collection: Collection, hnsw_rabitq_collection_schema ): """Test that collection is 
created with correct schema.""" assert hnsw_rabitq_collection is not None assert hnsw_rabitq_collection.schema.name == hnsw_rabitq_collection_schema.name assert len(hnsw_rabitq_collection.schema.fields) == len( hnsw_rabitq_collection_schema.fields ) assert len(hnsw_rabitq_collection.schema.vectors) == len( hnsw_rabitq_collection_schema.vectors ) def test_vector_schema_validation(self, hnsw_rabitq_collection: Collection): """Test that vector schema has correct HNSW RaBitQ configuration.""" vector_schema = hnsw_rabitq_collection.schema.vector("embedding") assert vector_schema is not None assert vector_schema.name == "embedding" assert vector_schema.data_type == DataType.VECTOR_FP32 assert vector_schema.dimension == 128 index_param = vector_schema.index_param assert index_param is not None assert index_param.metric_type == MetricType.L2 assert index_param.m == 16 assert index_param.ef_construction == 200 assert index_param.total_bits == 7 assert index_param.num_clusters == 64 def test_collection_stats(self, hnsw_rabitq_collection: Collection): """Test initial collection statistics.""" stats = hnsw_rabitq_collection.stats assert stats is not None assert stats.doc_count == 0 assert len(stats.index_completeness) == 1 assert stats.index_completeness["embedding"] == 1 @pytest.mark.usefixtures("hnsw_rabitq_collection") class TestHnswRabitqCollectionInsert: """Test document insertion into HNSW RaBitQ collection.""" def test_insert_single_doc( self, hnsw_rabitq_collection: Collection, single_doc: Doc ): """Test inserting a single document.""" result = hnsw_rabitq_collection.insert(single_doc) assert bool(result) assert result.ok() stats = hnsw_rabitq_collection.stats assert stats is not None assert stats.doc_count == 1 def test_insert_multiple_docs( self, hnsw_rabitq_collection: Collection, multiple_docs: list[Doc] ): """Test inserting multiple documents.""" result = hnsw_rabitq_collection.insert(multiple_docs) assert len(result) == len(multiple_docs) for item in result: 
assert item.ok() stats = hnsw_rabitq_collection.stats assert stats is not None assert stats.doc_count == len(multiple_docs) @pytest.mark.usefixtures("hnsw_rabitq_collection") class TestHnswRabitqCollectionFetch: """Test document fetching from HNSW RaBitQ collection.""" def test_fetch_single_doc( self, collection_with_single_doc: Collection, single_doc: Doc ): """Test fetching a single document by ID.""" result = collection_with_single_doc.fetch(ids=[single_doc.id]) assert bool(result) assert single_doc.id in result.keys() doc = result[single_doc.id] assert doc is not None assert doc.id == single_doc.id assert doc.field("id") == single_doc.field("id") assert doc.field("name") == single_doc.field("name") def test_fetch_multiple_docs( self, collection_with_multiple_docs: Collection, multiple_docs: list[Doc] ): """Test fetching multiple documents by IDs.""" ids = [doc.id for doc in multiple_docs[:10]] result = collection_with_multiple_docs.fetch(ids=ids) assert bool(result) assert len(result) == len(ids) for doc_id in ids: assert doc_id in result doc = result[doc_id] assert doc is not None assert doc.id == doc_id def test_fetch_nonexistent_doc(self, collection_with_single_doc: Collection): """Test fetching a non-existent document.""" result = collection_with_single_doc.fetch(ids=["nonexistent_id"]) assert len(result) == 0 @pytest.mark.usefixtures("hnsw_rabitq_collection") class TestHnswRabitqCollectionQuery: """Test vector search queries on HNSW RaBitQ collection.""" def test_query_by_vector( self, collection_with_multiple_docs: Collection, multiple_docs: list[Doc] ): """Test querying by vector with HNSW RaBitQ index.""" query_vector = multiple_docs[0].vector("embedding") query = VectorQuery( field_name="embedding", vector=query_vector, param=HnswRabitqQueryParam(ef=300), ) result = collection_with_multiple_docs.query(vectors=query, topk=10) assert len(result) > 0 assert len(result) <= 10 # First result should be the query document itself (or very close) first_doc = 
result[0] assert first_doc is not None assert first_doc.id is not None def test_query_by_id( self, collection_with_multiple_docs: Collection, multiple_docs: list[Doc] ): """Test querying by document ID with HNSW RaBitQ index.""" query = VectorQuery( field_name="embedding", id=multiple_docs[0].id, param=HnswRabitqQueryParam(ef=300), ) result = collection_with_multiple_docs.query(vectors=query, topk=10) assert len(result) > 0 assert len(result) <= 10 def test_query_with_different_ef_values( self, collection_with_multiple_docs: Collection, multiple_docs: list[Doc] ): """Test querying with different ef parameter values.""" query_vector = multiple_docs[0].vector("embedding") # Test with ef=100 query_100 = VectorQuery( field_name="embedding", vector=query_vector, param=HnswRabitqQueryParam(ef=100), ) result_100 = collection_with_multiple_docs.query(vectors=query_100, topk=10) assert len(result_100) > 0 # Test with ef=500 query_500 = VectorQuery( field_name="embedding", vector=query_vector, param=HnswRabitqQueryParam(ef=500), ) result_500 = collection_with_multiple_docs.query(vectors=query_500, topk=10) assert len(result_500) > 0 def test_query_with_topk( self, collection_with_multiple_docs: Collection, multiple_docs: list[Doc] ): """Test querying with different topk values.""" query_vector = multiple_docs[0].vector("embedding") query = VectorQuery( field_name="embedding", vector=query_vector, param=HnswRabitqQueryParam(ef=300), ) # Test topk=5 result_5 = collection_with_multiple_docs.query(vectors=query, topk=5) assert len(result_5) <= 5 # Test topk=20 result_20 = collection_with_multiple_docs.query(vectors=query, topk=20) assert len(result_20) <= 20 def test_query_with_filter( self, collection_with_multiple_docs: Collection, multiple_docs: list[Doc] ): """Test querying with filter conditions.""" query_vector = multiple_docs[0].vector("embedding") query = VectorQuery( field_name="embedding", vector=query_vector, param=HnswRabitqQueryParam(ef=300), ) # Query with id 
filter result = collection_with_multiple_docs.query( vectors=query, topk=10, filter="id < 50" ) assert len(result) > 0 for doc in result: assert doc.field("id") < 50 def test_query_with_output_fields( self, collection_with_multiple_docs: Collection, multiple_docs: list[Doc] ): """Test querying with specific output fields.""" query_vector = multiple_docs[0].vector("embedding") query = VectorQuery( field_name="embedding", vector=query_vector, param=HnswRabitqQueryParam(ef=300), ) result = collection_with_multiple_docs.query( vectors=query, topk=10, output_fields=["id", "name"] ) assert len(result) > 0 first_doc = result[0] assert "id" in first_doc.field_names() assert "name" in first_doc.field_names() def test_query_with_include_vector( self, collection_with_multiple_docs: Collection, multiple_docs: list[Doc] ): """Test querying with vector data included in results.""" query_vector = multiple_docs[0].vector("embedding") query = VectorQuery( field_name="embedding", vector=query_vector, param=HnswRabitqQueryParam(ef=300), ) result = collection_with_multiple_docs.query( vectors=query, topk=10, include_vector=True ) assert len(result) > 0 first_doc = result[0] assert first_doc.vector("embedding") is not None assert len(first_doc.vector("embedding")) == 128 @pytest.mark.usefixtures("hnsw_rabitq_collection") class TestHnswRabitqCollectionUpdate: """Test document update in HNSW RaBitQ collection.""" def test_update_doc_fields( self, collection_with_single_doc: Collection, single_doc: Doc ): """Test updating document fields.""" updated_doc = Doc( id=single_doc.id, fields={"id": single_doc.field("id"), "name": "updated_name"}, ) result = collection_with_single_doc.update(updated_doc) assert bool(result) assert result.ok() # Verify update fetched = collection_with_single_doc.fetch(ids=[single_doc.id]) assert single_doc.id in fetched doc = fetched[single_doc.id] assert doc.field("name") == "updated_name" def test_update_doc_vector( self, collection_with_single_doc: Collection, 
single_doc: Doc ): """Test updating document vector.""" new_vector = [0.5 + i * 0.01 for i in range(128)] updated_doc = Doc( id=single_doc.id, vectors={"embedding": new_vector}, ) result = collection_with_single_doc.update(updated_doc) assert bool(result) assert result.ok() # Verify update fetched = collection_with_single_doc.fetch( ids=[single_doc.id], ) assert single_doc.id in fetched doc = fetched[single_doc.id] assert doc.vector("embedding") is not None embedding = doc.vector("embedding") assert len(embedding) == 128 # Verify vector values are approximately equal (float comparison) for i in range(128): assert math.isclose(embedding[i], new_vector[i], rel_tol=1e-5) @pytest.mark.usefixtures("hnsw_rabitq_collection") class TestHnswRabitqCollectionDelete: """Test document deletion from HNSW RaBitQ collection.""" def test_delete_single_doc( self, collection_with_single_doc: Collection, single_doc: Doc ): """Test deleting a single document.""" result = collection_with_single_doc.delete(single_doc.id) assert bool(result) assert result.ok() stats = collection_with_single_doc.stats assert stats.doc_count == 0 def test_delete_multiple_docs( self, collection_with_multiple_docs: Collection, multiple_docs: list[Doc] ): """Test deleting multiple documents.""" ids_to_delete = [doc.id for doc in multiple_docs[:10]] result = collection_with_multiple_docs.delete(ids_to_delete) assert len(result) == len(ids_to_delete) for item in result: assert item.ok() stats = collection_with_multiple_docs.stats assert stats.doc_count == len(multiple_docs) - len(ids_to_delete) @pytest.mark.usefixtures("hnsw_rabitq_collection") class TestHnswRabitqCollectionOptimizeAndReopen: """Test collection optimize and reopen functionality.""" def test_optimize_close_reopen_and_query( self, tmp_path_factory, hnsw_rabitq_collection_schema, collection_option, multiple_docs: list[Doc], ): """Test inserting 100 docs, optimize, close, reopen and query.""" # Create collection and insert 100 documents temp_dir = 
tmp_path_factory.mktemp("zvec_hnsw_rabitq_optimize") collection_path = temp_dir / "test_optimize_collection" coll = zvec.create_and_open( path=str(collection_path), schema=hnsw_rabitq_collection_schema, option=collection_option, ) assert coll is not None assert coll.stats.doc_count == 0 # Insert 100 documents result = coll.insert(multiple_docs) assert len(result) == len(multiple_docs) for item in result: assert item.ok() assert coll.stats.doc_count == len(multiple_docs) # Call optimize from zvec import OptimizeOption coll.optimize(option=OptimizeOption()) # Verify data is still accessible after optimize query_vector = multiple_docs[0].vector("embedding") query = VectorQuery( field_name="embedding", vector=query_vector, param=HnswRabitqQueryParam(ef=300), ) result_before_close = coll.query(vectors=query, topk=10) assert len(result_before_close) > 0 # Close collection (destroy will close it) collection_path_str = str(collection_path) del coll # Reopen collection reopened_coll = zvec.open(path=collection_path_str, option=collection_option) assert reopened_coll is not None assert reopened_coll.stats.doc_count == len(multiple_docs) # Execute query on reopened collection query_after_reopen = VectorQuery( field_name="embedding", vector=query_vector, param=HnswRabitqQueryParam(ef=300), ) result_after_reopen = reopened_coll.query(vectors=query_after_reopen, topk=10) assert len(result_after_reopen) > 0 assert len(result_after_reopen) <= 10 # Verify query results are valid first_doc = result_after_reopen[0] assert first_doc is not None assert first_doc.id is not None assert first_doc.field("id") is not None assert first_doc.field("name") is not None # Cleanup reopened_coll.destroy() ================================================ FILE: python/tests/test_convert.py ================================================ from __future__ import annotations import math import pytest from _zvec import _Doc from zvec.model.convert import convert_to_py_doc, convert_to_cpp_doc from zvec 
import Doc, CollectionSchema, DataType, FieldSchema, VectorSchema # ---------------------------- # Convert Cpp Doc Test Case # ---------------------------- class TestConvertCppDoc: def test_default(self): doc = Doc(id="1") schema = CollectionSchema( name="test_collection", fields=FieldSchema("name", DataType.STRING), ) cpp_doc = convert_to_cpp_doc(doc, collection_schema=schema) assert cpp_doc is not None assert cpp_doc.pk() == doc.id def test_with_field_notin_schema(self): doc = Doc(id="1", fields={"name": "Tom"}) schema = CollectionSchema( name="test_collection", fields=[ FieldSchema("id", DataType.UINT64), FieldSchema("salary", DataType.UINT32), FieldSchema("age", DataType.INT32), FieldSchema("create_at", DataType.INT64), FieldSchema("author", DataType.STRING), FieldSchema("weight", DataType.FLOAT), ], ) with pytest.raises(ValueError): convert_to_cpp_doc(doc, collection_schema=schema) def test_with_scalar_fields(self): schema = CollectionSchema( name="test_collection", fields=[ FieldSchema("id", DataType.UINT64), FieldSchema("salary", DataType.UINT32), FieldSchema("age", DataType.INT32), FieldSchema("create_at", DataType.INT64), FieldSchema("author", DataType.STRING), FieldSchema("weight", DataType.FLOAT), FieldSchema("bmi", DataType.DOUBLE), FieldSchema("is_male", DataType.BOOL), ], ) doc = Doc( id="1", fields={ "id": 1, "salary": 1000, "age": 18, "create_at": 1640995200, "bmi": 80.0 / 200.0, "author": "Tom", "weight": 80.0, "is_male": True, }, ) cpp_doc = convert_to_cpp_doc(doc, collection_schema=schema) assert cpp_doc is not None assert cpp_doc.pk() == doc.id assert cpp_doc.get_any("id", DataType.UINT64) == 1 assert cpp_doc.get_any("salary", DataType.UINT32) == 1000 assert cpp_doc.get_any("age", DataType.INT32) == 18 assert cpp_doc.get_any("create_at", DataType.INT64) == 1640995200 assert cpp_doc.get_any("author", DataType.STRING) == "Tom" assert math.isclose( cpp_doc.get_any("weight", DataType.FLOAT), 80.0, rel_tol=1e-6 ) assert math.isclose( 
cpp_doc.get_any("bmi", DataType.DOUBLE), 80.0 / 200.0, rel_tol=1e-6 ) assert cpp_doc.get_any("is_male", DataType.BOOL) == True def test_with_array_fields(self): schema = CollectionSchema( name="test_collection", fields=[ FieldSchema("tags", DataType.ARRAY_STRING), FieldSchema("ids", DataType.ARRAY_UINT64), FieldSchema("marks", DataType.ARRAY_UINT32), FieldSchema("x", DataType.ARRAY_INT32), FieldSchema("y", DataType.ARRAY_INT64), FieldSchema("scores", DataType.ARRAY_FLOAT), FieldSchema("ratios", DataType.ARRAY_DOUBLE), FieldSchema("results", DataType.ARRAY_BOOL), ], ) doc = Doc( id="1", fields={ "tags": ["tag1", "tag2", "tag3"], "ids": [111111111111, 222222222222, 333333333333], "marks": [100, 200, 300], "x": [1, 2, 3], "y": [100, 200, 300], "scores": [1.1, 2.2, 3.3], "ratios": [0.1, 0.2, 0.3], "results": [True, False, True], }, ) cpp_doc = convert_to_cpp_doc(doc, collection_schema=schema) assert cpp_doc is not None assert cpp_doc.pk() == doc.id assert cpp_doc.get_any("tags", DataType.ARRAY_STRING) == doc.field("tags") assert cpp_doc.get_any("ids", DataType.ARRAY_UINT64) == doc.field("ids") assert cpp_doc.get_any("marks", DataType.ARRAY_UINT32) == doc.field("marks") assert cpp_doc.get_any("x", DataType.ARRAY_INT32) == doc.field("x") assert cpp_doc.get_any("y", DataType.ARRAY_INT64) == doc.field("y") scores = cpp_doc.get_any("scores", DataType.ARRAY_FLOAT) for i in range(len(doc.field("scores"))): assert math.isclose(scores[i], doc.field("scores")[i], rel_tol=1e-1) ratios = cpp_doc.get_any("ratios", DataType.ARRAY_DOUBLE) for i in range(len(doc.field("ratios"))): assert math.isclose(ratios[i], doc.field("ratios")[i], rel_tol=1e-1) results = cpp_doc.get_any("results", DataType.ARRAY_BOOL) for i in range(len(doc.field("results"))): assert results[i] == doc.field("results")[i] def test_with_dense_vector_fields(self): schema = CollectionSchema( name="test_collection", vectors=[ VectorSchema( name="embedding", data_type=DataType.VECTOR_FP16, dimension=4, ), VectorSchema( 
name="image", data_type=DataType.VECTOR_FP32, dimension=8, ), VectorSchema( name="text", data_type=DataType.VECTOR_INT8, dimension=32, ), ], ) doc = Doc( id="1", vectors={ "embedding": [1.1] * 4, "image": [2.2] * 8, "text": [4] * 32, }, ) cpp_doc = convert_to_cpp_doc(doc, collection_schema=schema) assert cpp_doc is not None assert cpp_doc.pk() == doc.id embedding_vector = cpp_doc.get_any("embedding", DataType.VECTOR_FP16) assert len(embedding_vector) == 4 for i in range(4): assert math.isclose( embedding_vector[i], doc.vector("embedding")[i], rel_tol=1e-1 ) image_vector = cpp_doc.get_any("image", DataType.VECTOR_FP32) assert len(image_vector) == 8 for i in range(8): assert math.isclose(image_vector[i], doc.vector("image")[i], rel_tol=1e-1) text_vector = cpp_doc.get_any("text", DataType.VECTOR_INT8) assert len(text_vector) == 32 for i in range(32): assert text_vector[i] == doc.vectors["text"][i] def test_with_sparse_vector_fields(self): schema = CollectionSchema( name="test_collection", vectors=[ VectorSchema( name="author", data_type=DataType.SPARSE_VECTOR_FP32, ), VectorSchema( name="content", data_type=DataType.SPARSE_VECTOR_FP16, ), ], ) doc = Doc( id="1", vectors={ "author": {1: 1.1, 2: 2.2, 3: 3.3}, "content": {4: 4.4, 5: 5.5, 6: 6.6}, }, ) cpp_doc = convert_to_cpp_doc(doc, collection_schema=schema) assert cpp_doc is not None assert cpp_doc.pk() == doc.id author_vector = cpp_doc.get_any("author", DataType.SPARSE_VECTOR_FP32) assert isinstance(author_vector, dict) for key, value in doc.vector("author").items(): assert math.isclose(author_vector[key], value, rel_tol=1e-1) content_vector = cpp_doc.get_any("content", DataType.SPARSE_VECTOR_FP16) assert isinstance(content_vector, dict) for key, value in doc.vector("content").items(): assert math.isclose(content_vector[key], value, rel_tol=1e-1) def test_with_scalar_fields_error_datatype(self): schema = CollectionSchema( name="test_collection", fields=[ FieldSchema("id", DataType.UINT64), FieldSchema("salary", 
DataType.UINT32), FieldSchema("age", DataType.INT32), FieldSchema("create_at", DataType.INT64), FieldSchema("author", DataType.STRING), FieldSchema("weight", DataType.FLOAT), FieldSchema("bmi", DataType.DOUBLE), FieldSchema("is_male", DataType.BOOL), ], ) doc = Doc( id="1", fields={ "id": "1", }, ) with pytest.raises(TypeError): convert_to_cpp_doc(doc, collection_schema=schema) doc = Doc(id="1", fields={"salary": "1000"}) with pytest.raises(TypeError): convert_to_cpp_doc(doc, collection_schema=schema) doc = Doc(id="1", fields={"age": "18"}) with pytest.raises(TypeError): convert_to_cpp_doc(doc, collection_schema=schema) doc = Doc(id="1", fields={"create_at": "2021-01-01"}) with pytest.raises(TypeError): convert_to_cpp_doc(doc, collection_schema=schema) doc = Doc(id="1", fields={"author": 1}) with pytest.raises(TypeError): convert_to_cpp_doc(doc, collection_schema=schema) doc = Doc(id="1", fields={"weight": "80.5"}) with pytest.raises(TypeError): convert_to_cpp_doc(doc, collection_schema=schema) doc = Doc(id="1", fields={"bmi": "25.0"}) with pytest.raises(TypeError): convert_to_cpp_doc(doc, collection_schema=schema) doc = Doc(id="1", fields={"is_male": "true"}) with pytest.raises(TypeError): convert_to_cpp_doc(doc, collection_schema=schema) def test_with_array_fields_error_datatype(self): schema = CollectionSchema( name="test_collection", fields=[ FieldSchema("tags", DataType.ARRAY_STRING), FieldSchema("ids", DataType.ARRAY_UINT64), FieldSchema("marks", DataType.ARRAY_UINT32), FieldSchema("x", DataType.ARRAY_INT32), FieldSchema("y", DataType.ARRAY_INT64), FieldSchema("scores", DataType.ARRAY_FLOAT), FieldSchema("ratios", DataType.ARRAY_DOUBLE), FieldSchema("results", DataType.ARRAY_BOOL), ], ) doc = Doc(id="1", fields={"tags": [1, 2, 3]}) with pytest.raises(TypeError): convert_to_cpp_doc(doc, collection_schema=schema) doc = Doc(id="1", fields={"ids": ["1", "2", "3"]}) with pytest.raises(TypeError): convert_to_cpp_doc(doc, collection_schema=schema) doc = Doc(id="1", 
fields={"marks": [1.1, 2.2, 3.3]}) with pytest.raises(TypeError): convert_to_cpp_doc(doc, collection_schema=schema) doc = Doc(id="1", fields={"x": [1.1, 2.2, 3.3]}) with pytest.raises(TypeError): convert_to_cpp_doc(doc, collection_schema=schema) doc = Doc(id="1", fields={"y": [1.1, 2.2, 3.3]}) with pytest.raises(TypeError): convert_to_cpp_doc(doc, collection_schema=schema) doc = Doc(id="1", fields={"scores": ["1", "2", "3"]}) with pytest.raises(TypeError): convert_to_cpp_doc(doc, collection_schema=schema) doc = Doc(id="1", fields={"ratios": ["1", "2", "3"]}) with pytest.raises(TypeError): convert_to_cpp_doc(doc, collection_schema=schema) doc = Doc(id="1", fields={"results": ["1", "2", "3"]}) with pytest.raises(TypeError): convert_to_cpp_doc(doc, collection_schema=schema) def test_with_vector_fields_error_datatype(self): schema = CollectionSchema( name="test_collection", vectors=[ VectorSchema( name="embedding", data_type=DataType.VECTOR_FP16, dimension=4, ), VectorSchema( name="image", data_type=DataType.VECTOR_FP32, dimension=8, ), VectorSchema( name="text", data_type=DataType.VECTOR_INT8, dimension=32, ), ], ) doc = Doc(id="1", vectors={"image": ["1.1"] * 4}) with pytest.raises(TypeError): convert_to_cpp_doc(doc, collection_schema=schema) doc = Doc(id="1", vectors={"text": ["1"] * 4}) with pytest.raises(TypeError): convert_to_cpp_doc(doc, collection_schema=schema) doc = Doc(id="1", vectors={"embedding": ["1"] * 4}) with pytest.raises(TypeError): convert_to_cpp_doc(doc, collection_schema=schema) def test_with_sparse_vector_error_datatype(self): schema = CollectionSchema( name="test_collection", vectors=[ VectorSchema( name="author", data_type=DataType.SPARSE_VECTOR_FP32, ), VectorSchema( name="content", data_type=DataType.SPARSE_VECTOR_FP16, ), ], ) doc = Doc( id="1", vectors={ "author": {"1": 1.1, "2": 2.2, "3": 3.3}, }, ) with pytest.raises(TypeError): convert_to_cpp_doc(doc, collection_schema=schema) doc = Doc( id="1", vectors={ "content": {"1": 1.1, "2": 2.2, 
"3": 3.3}, }, ) with pytest.raises(TypeError): convert_to_cpp_doc(doc, collection_schema=schema) doc = Doc( id="1", vectors={ "author": {1: "1", 2: "2", 3: "3"}, }, ) with pytest.raises(TypeError): convert_to_cpp_doc(doc, collection_schema=schema) # ---------------------------- # Convert Py Doc Test Case # ---------------------------- class TestConvertPyDoc: def test_default(self): doc = _Doc() doc.set_pk("1") doc.set_score(1.0) schema = CollectionSchema( name="test_collection", fields=FieldSchema("name", DataType.STRING), ) py_doc = convert_to_py_doc(doc, schema) assert py_doc.id == "1" assert py_doc.score == 1.0 def test_with_scalar_fields(self): schema = CollectionSchema( name="test_collection", fields=[ FieldSchema("id", DataType.UINT64), FieldSchema("salary", DataType.UINT32), FieldSchema("age", DataType.INT32), FieldSchema("create_at", DataType.INT64), FieldSchema("author", DataType.STRING), FieldSchema("weight", DataType.FLOAT), FieldSchema("bmi", DataType.DOUBLE), FieldSchema("is_male", DataType.BOOL), ], ) doc = _Doc() doc.set_pk("1") doc.set_any("id", schema.field("id")._get_object(), 1111111111111111) doc.set_any("salary", schema.field("salary")._get_object(), 1000) doc.set_any("age", schema.field("age")._get_object(), 18) doc.set_any("create_at", schema.field("create_at")._get_object(), 1640995200) doc.set_any("author", schema.field("author")._get_object(), "Tom") doc.set_any("weight", schema.field("weight")._get_object(), 80.0) doc.set_any("bmi", schema.field("bmi")._get_object(), 80.0 / 200.0) doc.set_any("is_male", schema.field("is_male")._get_object(), True) py_doc = convert_to_py_doc(doc, schema) assert py_doc.id == "1" assert py_doc.field("id") == 1111111111111111 assert py_doc.field("salary") == 1000 assert py_doc.field("age") == 18 assert py_doc.field("create_at") == 1640995200 assert py_doc.field("author") == "Tom" assert py_doc.field("weight") == 80.0 assert py_doc.field("bmi") == 80.0 / 200.0 assert py_doc.field("is_male") == True def 
test_with_array_fields(self): schema = CollectionSchema( name="test_collection", fields=[ FieldSchema("tags", DataType.ARRAY_STRING), FieldSchema("ids", DataType.ARRAY_UINT64), FieldSchema("marks", DataType.ARRAY_UINT32), FieldSchema("x", DataType.ARRAY_INT32), FieldSchema("y", DataType.ARRAY_INT64), FieldSchema("scores", DataType.ARRAY_FLOAT), FieldSchema("ratios", DataType.ARRAY_DOUBLE), FieldSchema("results", DataType.ARRAY_BOOL), ], ) doc = _Doc() doc.set_pk("1") doc.set_any( "tags", schema.field("tags")._get_object(), ["tag1", "tag2", "tag3"] ) doc.set_any( "ids", schema.field("ids")._get_object(), [111111111111, 222222222222, 3333333333333], ) doc.set_any("marks", schema.field("marks")._get_object(), [1000, 2000, 3000]) doc.set_any("x", schema.field("x")._get_object(), [1, 2, 3]) doc.set_any("y", schema.field("y")._get_object(), [100, 200, 300]) doc.set_any("scores", schema.field("scores")._get_object(), [0.1, 0.2, 0.3]) doc.set_any("ratios", schema.field("ratios")._get_object(), [0.1, 0.2, 0.3]) doc.set_any( "results", schema.field("results")._get_object(), [True, False, True] ) py_doc = convert_to_py_doc(doc, schema) assert py_doc.field("tags") == ["tag1", "tag2", "tag3"] assert py_doc.field("ids") == [111111111111, 222222222222, 3333333333333] assert py_doc.field("marks") == [1000, 2000, 3000] assert py_doc.field("x") == [1, 2, 3] assert py_doc.field("y") == [100, 200, 300] scores = doc.get_any("scores", DataType.ARRAY_FLOAT) for i in range(len(scores)): assert math.isclose(scores[i], py_doc.field("scores")[i], rel_tol=1e-1) ratios = doc.get_any("ratios", DataType.ARRAY_DOUBLE) for i in range(len(ratios)): assert math.isclose(ratios[i], py_doc.field("ratios")[i], rel_tol=1e-1) results = doc.get_any("results", DataType.ARRAY_BOOL) for i in range(len(results)): assert results[i] == py_doc.field("results")[i] def test_with_dense_vector_fields(self): schema = CollectionSchema( name="test_collection", vectors=[ VectorSchema( name="embedding", 
data_type=DataType.VECTOR_FP16, dimension=4, ), VectorSchema( name="image", data_type=DataType.VECTOR_FP32, dimension=8, ), VectorSchema( name="text", data_type=DataType.VECTOR_INT8, dimension=32, ), ], ) doc = _Doc() doc.set_pk("1") doc.set_any("embedding", schema.vector("embedding")._get_object(), [1.1] * 4) doc.set_any("image", schema.vector("image")._get_object(), [2.2] * 8) doc.set_any("text", schema.vector("text")._get_object(), [4] * 32) py_doc = convert_to_py_doc(doc, schema) assert py_doc.id == "1" embedding_vector = py_doc.vector("embedding") assert len(embedding_vector) == 4 for i in range(4): assert math.isclose( py_doc.vector("embedding")[i], embedding_vector[i], rel_tol=1e-1 ) image_vector = py_doc.vector("image") assert len(image_vector) == 8 for i in range(8): assert math.isclose( py_doc.vector("image")[i], image_vector[i], rel_tol=1e-1 ) text_vector = py_doc.vector("text") assert len(text_vector) == 32 for i in range(32): assert py_doc.vector("text")[i] == text_vector[i] def test_with_sparse_vector_fields(self): schema = CollectionSchema( name="test_collection", vectors=[ VectorSchema( name="author", data_type=DataType.SPARSE_VECTOR_FP32, ), VectorSchema( name="content", data_type=DataType.SPARSE_VECTOR_FP16, ), ], ) doc = _Doc() doc.set_pk("1") doc.set_any( "author", schema.vector("author")._get_object(), {1: 1.1, 2: 2.2, 3: 3.3} ) doc.set_any( "content", schema.vector("content")._get_object(), {4: 4.4, 5: 5.5, 6: 6.6} ) py_doc = convert_to_py_doc(doc, schema) assert py_doc.id == "1" author_vector = py_doc.vector("author") assert isinstance(author_vector, dict) for key, value in doc.get_any("author", DataType.SPARSE_VECTOR_FP32).items(): assert math.isclose(author_vector[key], value, rel_tol=1e-1) content_vector = py_doc.vector("content") assert isinstance(content_vector, dict) for key, value in doc.get_any("content", DataType.SPARSE_VECTOR_FP16).items(): assert math.isclose(content_vector[key], value, rel_tol=1e-1) 
================================================ FILE: python/tests/test_doc.py ================================================ # Copyright 2025-present the zvec project # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations import math import pytest from _zvec import _Doc from zvec import FieldSchema, VectorSchema, Doc, DataType # ---------------------------- # PyDoc Test Case # ---------------------------- class TestPyDoc: def test_default(self): Doc(id="1") def test_with_single_vector(self): doc = Doc(id="1", vectors={"dense": [1, 2, 3]}) assert doc is not None assert doc.id == "1" assert doc.vector("dense") == [1, 2, 3] def test_with_hybrid_vectors(self): doc = Doc( id="1", vectors={"dense": [1, 2, 3], "sparse": {1: 1.0, 2: 2.0, 3: 3.0}} ) assert doc is not None assert doc.id == "1" assert doc.vector("dense") == [1, 2, 3] assert doc.vector("sparse") == {1: 1.0, 2: 2.0, 3: 3.0} def test_with_multi_vectors(self): doc = Doc( id="1", vectors={ "image": [1, 2, 3], "description": [4, 5, 6], "keys": {1: 1.0, 2: 2.0, 3: 3.0}, }, fields={"author": "Tom", "age": 19, "is_male": True, "weight": 60.5}, ) assert doc is not None assert doc.id == "1" assert doc.vector("image") == [1, 2, 3] assert doc.vector("description") == [4, 5, 6] assert doc.vector("keys") == {1: 1.0, 2: 2.0, 3: 3.0} assert doc.field("author") == "Tom" assert doc.field("age") == 19 assert doc.field("is_male") == True assert doc.field("weight") == 60.5 def test_with_numpy_array(self): import 
numpy as np doc = Doc._from_tuple( ( "1", 0.0, None, { "image": np.array([1, 2, 3]), "description": np.random.random(512), "keys": {1: 1.0, 2: 2.0, 3: 3.0}, }, ) ) assert doc is not None assert doc.id == "1" assert doc.vector("image") == [1, 2, 3] assert doc.vector("keys") == {1: 1.0, 2: 2.0, 3: 3.0} # ---------------------------- # CppDoc Test Case # ---------------------------- class TestCppDoc: def test_default(self): doc = _Doc() assert doc is not None def test_doc_set_pk(self): doc = _Doc() doc.set_pk("1") assert doc.pk() == "1" def test_doc_set_score(self): doc = _Doc() doc.set_score(0.9) assert math.isclose(doc.score(), 0.9, rel_tol=1e-6) def test_doc_get_null_field(self): doc = _Doc() schema = FieldSchema("author", DataType.STRING, nullable=True) doc.set_any("author", schema._get_object(), None) assert doc.has_field("author") assert doc.get_any("author", schema.data_type) is None def test_doc_get_set_has_null_field(self): doc = _Doc() schema = FieldSchema("author", DataType.STRING, nullable=False) with pytest.raises(ValueError): doc.set_any("author", schema._get_object(), None) def test_doc_get_set_has_string_field(self): doc = _Doc() schema = FieldSchema("author", DataType.STRING) doc.set_any("author", schema._get_object(), "Tom") assert doc.has_field("author") assert doc.get_any("author", DataType.STRING) == "Tom" def test_doc_get_set_has_bool_field(self): doc = _Doc() schema = FieldSchema("is_male", DataType.BOOL) doc.set_any("is_male", schema._get_object(), True) assert doc.has_field("is_male") assert doc.get_any("is_male", DataType.BOOL) == True def test_doc_get_set_has_int32_field(self): doc = _Doc() schema = FieldSchema("age", DataType.INT32) doc.set_any("age", schema._get_object(), 19) assert doc.has_field("age") assert doc.get_any("age", DataType.INT32) == 19 def test_doc_get_set_has_int64_field(self): doc = _Doc() schema = FieldSchema("id", DataType.INT64) doc.set_any("id", schema._get_object(), 1111111111111111111) assert doc.has_field("id") 
assert doc.get_any("id", DataType.INT64) == 1111111111111111111 def test_doc_get_set_has_float_field(self): doc = _Doc() schema = FieldSchema("weight", DataType.FLOAT) doc.set_any("weight", schema._get_object(), 60.5) assert doc.has_field("weight") assert math.isclose(doc.get_any("weight", DataType.FLOAT), 60.5, rel_tol=1e-6) def test_doc_get_set_has_double_field(self): doc = _Doc() schema = FieldSchema("height", DataType.DOUBLE) doc.set_any("height", schema._get_object(), 1.77777777777) assert doc.has_field("height") assert math.isclose( doc.get_any("height", DataType.DOUBLE), 1.7777777777, rel_tol=1e-9 ) def test_doc_get_set_has_uint32_field(self): doc = _Doc() schema = FieldSchema("id", DataType.UINT32) doc.set_any("id", schema._get_object(), 4294967295) assert doc.has_field("id") assert doc.get_any("id", DataType.UINT32) == 4294967295 def test_doc_get_set_has_uint64_field(self): doc = _Doc() schema = FieldSchema("id", DataType.UINT64) doc.set_any("id", schema._get_object(), 18446744073709551615) assert doc.has_field("id") assert doc.get_any("id", DataType.UINT64) == 18446744073709551615 def test_doc_get_set_has_array_string_field(self): doc = _Doc() schema = FieldSchema("tags", DataType.ARRAY_STRING) doc.set_any("tags", schema._get_object(), ["tag1", "tag2", "tag3"]) assert doc.has_field("tags") assert doc.get_any("tags", DataType.ARRAY_STRING) == ["tag1", "tag2", "tag3"] def test_doc_get_set_has_array_int32_field(self): doc = _Doc() schema = FieldSchema("ids", DataType.ARRAY_INT32) doc.set_any("ids", schema._get_object(), [1, 2, 3]) assert doc.has_field("ids") assert doc.get_any("ids", DataType.ARRAY_INT32) == [1, 2, 3] def test_doc_get_set_has_array_int64_field(self): doc = _Doc() schema = FieldSchema("ids", DataType.ARRAY_INT64) doc.set_any("ids", schema._get_object(), [1, 2, 3]) assert doc.has_field("ids") assert doc.get_any("ids", DataType.ARRAY_INT64) == [1, 2, 3] def test_doc_get_set_has_array_float_field(self): doc = _Doc() schema = 
FieldSchema("weights", DataType.ARRAY_FLOAT) doc.set_any("weights", schema._get_object(), [1.0, 2.0, 3.0]) assert doc.has_field("weights") assert doc.get_any("weights", DataType.ARRAY_FLOAT) == [1.0, 2.0, 3.0] def test_doc_get_set_has_array_double_field(self): doc = _Doc() schema = FieldSchema("heights", DataType.ARRAY_DOUBLE) doc.set_any("heights", schema._get_object(), [1.0, 2.0, 3.0]) assert doc.has_field("heights") assert doc.get_any("heights", DataType.ARRAY_DOUBLE) == [1.0, 2.0, 3.0] def test_doc_get_set_has_array_bool_field(self): doc = _Doc() schema = FieldSchema("bools", DataType.ARRAY_BOOL) doc.set_any("bools", schema._get_object(), [True, False, True]) assert doc.has_field("bools") assert doc.get_any("bools", DataType.ARRAY_BOOL) == [True, False, True] def test_doc_get_set_has_vector_fp16(self): doc = _Doc() schema = VectorSchema("image", DataType.VECTOR_FP16) doc.set_any("image", schema._get_object(), [1.0, 2.0, 3.0]) assert doc.has_field("image") image_vector = doc.get_any("image", DataType.VECTOR_FP16) assert image_vector is not None for i in range(len(image_vector)): assert math.isclose(image_vector[i], [1.0, 2.0, 3.0][i], rel_tol=1e-6) def test_doc_get_set_has_vector_fp32(self): doc = _Doc() schema = VectorSchema("image", DataType.VECTOR_FP32) doc.set_any("image", schema._get_object(), [1.111111, 2.222222, 3.333333]) assert doc.has_field("image") vector = doc.get_any("image", DataType.VECTOR_FP32) assert vector is not None for i in range(len(vector)): assert math.isclose( vector[i], [1.111111, 2.222222, 3.333333][i], rel_tol=1e-6 ) def test_doc_get_set_has_vector_int8(self): doc = _Doc() schema = VectorSchema("image", DataType.VECTOR_INT8) doc.set_any("image", schema._get_object(), [1, 2, 3]) assert doc.has_field("image") assert doc.get_any("image", DataType.VECTOR_INT8) == [1, 2, 3] def test_doc_get_set_has_sparse_vector_fp32(self): doc = _Doc() sparse = {1: 1.111111, 2: 2.222222, 3: 3.333333} schema = VectorSchema("key", 
DataType.SPARSE_VECTOR_FP32) doc.set_any("key", schema._get_object(), sparse) assert doc.has_field("key") vector = doc.get_any("key", DataType.SPARSE_VECTOR_FP32) assert vector is not None assert isinstance(vector, dict) for key, value in sparse.items(): assert math.isclose(vector[key], value, rel_tol=1e-6) def test_doc_get_set_has_sparse_vector_fp16(self): doc = _Doc() sparse = {1: 1.1, 2: 2.2, 3: 3.3} schema = VectorSchema("key", DataType.SPARSE_VECTOR_FP16) doc.set_any("key", schema._get_object(), sparse) assert doc.has_field("key") vector = doc.get_any("key", DataType.SPARSE_VECTOR_FP16) assert vector is not None assert isinstance(vector, dict) for key, value in sparse.items(): assert math.isclose(vector[key], value, rel_tol=1e-1) ================================================ FILE: python/tests/test_embedding.py ================================================ # Copyright 2025-present the zvec project # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# Tests for the zvec.extension embedding functions covered in this section:
# QwenDenseEmbedding / QwenSparseEmbedding (DashScope), OpenAIDenseEmbedding and
# DefaultLocalDenseEmbedding (Sentence Transformers). External SDKs are replaced
# by patching each implementation module's `require_module` helper; tests that
# need real API keys or model downloads are gated by RUN_INTEGRATION_TESTS.
from __future__ import annotations

import os
from http import HTTPStatus
from unittest.mock import MagicMock, patch, Mock

import numpy as np
import pytest

from zvec.extension import (
    BM25EmbeddingFunction,
    DefaultLocalDenseEmbedding,
    DefaultLocalSparseEmbedding,
    OpenAIDenseEmbedding,
    QwenDenseEmbedding,
    QwenSparseEmbedding,
)

# Environment variable to control integration tests
# Set ZVEC_RUN_INTEGRATION_TESTS=1 to run real API/model tests
# (unset or any other value keeps them skipped).
RUN_INTEGRATION_TESTS = os.environ.get("ZVEC_RUN_INTEGRATION_TESTS", "0") == "1"


# ----------------------------
# QwenDenseEmbedding Test Case
# ----------------------------
class TestQwenDenseEmbedding:
    """Tests for QwenDenseEmbedding.

    Embedding tests patch ``zvec.extension.qwen_function.require_module`` so the
    DashScope module it returns is a MagicMock whose ``TextEmbedding.call``
    yields a canned response.
    """

    def test_init_with_api_key(self):
        # Test initialization with explicit API key
        embedding_func = QwenDenseEmbedding(dimension=128, api_key="test_key")
        assert embedding_func.dimension == 128
        assert embedding_func.model == "text-embedding-v4"
        assert embedding_func._api_key == "test_key"

    @patch.dict(os.environ, {"DASHSCOPE_API_KEY": "env_key"})
    def test_init_with_env_api_key(self):
        # Test initialization with API key from environment
        embedding_func = QwenDenseEmbedding(dimension=128)
        assert embedding_func._api_key == "env_key"

    @patch.dict(os.environ, {"DASHSCOPE_API_KEY": ""})
    def test_init_with_empty_env_api_key(self):
        # Test initialization with empty API key from environment
        with pytest.raises(ValueError, match="DashScope API key is required"):
            QwenDenseEmbedding(dimension=128)

    def test_model_property(self):
        # Default model name, then an explicitly supplied one.
        embedding_func = QwenDenseEmbedding(dimension=128, api_key="test_key")
        assert embedding_func.model == "text-embedding-v4"
        embedding_func = QwenDenseEmbedding(
            dimension=128, model="custom-model", api_key="test_key"
        )
        assert embedding_func.model == "custom-model"

    @patch("zvec.extension.qwen_function.require_module")
    def test_embed_with_empty_text(self, mock_require_module):
        # Test embed method with empty text raises ValueError,
        # and a non-string (None) input raises TypeError.
        embedding_func = QwenDenseEmbedding(dimension=128, api_key="test_key")
        with pytest.raises(
            ValueError, match="Input text cannot be empty or whitespace only"
        ):
            embedding_func.embed("")

        with pytest.raises(TypeError):
            embedding_func.embed(None)

    @patch("zvec.extension.qwen_function.require_module")
    def test_embed_success(self, mock_require_module):
        # Test successful embedding
        mock_dashscope = MagicMock()
        mock_response = MagicMock()
        mock_response.status_code = HTTPStatus.OK
        mock_response.output = {"embeddings": [{"embedding": [0.1, 0.2, 0.3]}]}
        mock_dashscope.TextEmbedding.call.return_value = mock_response
        mock_require_module.return_value = mock_dashscope

        embedding_func = QwenDenseEmbedding(dimension=3, api_key="test_key")
        # Clear cache to avoid interference: `embed` exposes cache_clear(), so
        # results are memoized (presumably functools.lru_cache) and a cached hit
        # would bypass the mocked API call asserted below.
        embedding_func.embed.cache_clear()
        result = embedding_func.embed("test text")
        assert result == [0.1, 0.2, 0.3]
        mock_dashscope.TextEmbedding.call.assert_called_once_with(
            model="text-embedding-v4",
            input="test text",
            dimension=3,
            output_type="dense",
        )

    @patch("zvec.extension.qwen_function.require_module")
    def test_embed_http_error(self, mock_require_module):
        # Test embedding with HTTP error: a non-OK status must surface as ValueError.
        mock_dashscope = MagicMock()
        mock_response = MagicMock()
        mock_response.status_code = HTTPStatus.BAD_REQUEST
        mock_response.message = "Bad Request"
        mock_dashscope.TextEmbedding.call.return_value = mock_response
        mock_require_module.return_value = mock_dashscope

        embedding_func = QwenDenseEmbedding(dimension=128, api_key="test_key")
        embedding_func.embed.cache_clear()
        with pytest.raises(ValueError):
            embedding_func.embed("test text")

    @patch("zvec.extension.qwen_function.require_module")
    def test_embed_invalid_response(self, mock_require_module):
        # Test embedding with invalid response (wrong number of embeddings)
        mock_dashscope = MagicMock()
        mock_response = MagicMock()
        mock_response.status_code = HTTPStatus.OK
        mock_response.output = {"embeddings": []}
        mock_dashscope.TextEmbedding.call.return_value = mock_response
        mock_require_module.return_value = mock_dashscope

        embedding_func = QwenDenseEmbedding(dimension=128, api_key="test_key")
        embedding_func.embed.cache_clear()
        with pytest.raises(ValueError):
            embedding_func.embed("test text")

    @pytest.mark.skipif(
        not RUN_INTEGRATION_TESTS,
        reason="Integration test skipped. Set ZVEC_RUN_INTEGRATION_TESTS=1 to run.",
    )
    def test_real_embed_success(self):
        """Integration test with real DashScope API.

        To run this test, set environment variable:
            export ZVEC_RUN_INTEGRATION_TESTS=1
            export DASHSCOPE_API_KEY=your-api-key
        """
        embedding_func = QwenDenseEmbedding(dimension=128)
        dense = embedding_func("test text")
        assert len(dense) == 128


# ----------------------------
# QwenSparseEmbedding Test Case
# ----------------------------
class TestQwenSparseEmbedding:
    """Test suite for QwenSparseEmbedding (Qwen sparse embedding via DashScope API)."""

    def test_init_with_api_key(self):
        """Test initialization with explicit API key."""
        embedding_func = QwenSparseEmbedding(dimension=1024, api_key="test_key")
        assert embedding_func._dimension == 1024
        assert embedding_func.model == "text-embedding-v4"
        assert embedding_func._api_key == "test_key"
        # encoding_type defaults to "query" via extra_params
        # NOTE(review): because .get() supplies "query" as its own default, this
        # assertion also passes when "encoding_type" is absent from extra_params;
        # the request-level default is pinned by test_embed_success (text_type).
        assert embedding_func.extra_params.get("encoding_type", "query") == "query"

    def test_init_with_custom_encoding_type(self):
        """Test initialization with custom encoding type."""
        embedding_func = QwenSparseEmbedding(
            dimension=1024, encoding_type="document", api_key="test_key"
        )
        assert embedding_func.extra_params.get("encoding_type") == "document"

    @patch.dict(os.environ, {"DASHSCOPE_API_KEY": "env_key"})
    def test_init_with_env_api_key(self):
        """Test initialization with API key from environment."""
        embedding_func = QwenSparseEmbedding(dimension=1024)
        assert embedding_func._api_key == "env_key"

    @patch.dict(os.environ, {"DASHSCOPE_API_KEY": ""})
    def test_init_without_api_key(self):
        """Test initialization fails without API key."""
        with pytest.raises(ValueError, match="DashScope API key is required"):
            QwenSparseEmbedding(dimension=1024)

    def test_model_property(self):
        """Test model property."""
        embedding_func = QwenSparseEmbedding(dimension=1024, api_key="test_key")
        assert embedding_func.model == "text-embedding-v4"

        embedding_func = QwenSparseEmbedding(
            dimension=1024, model="text-embedding-v3", api_key="test_key"
        )
        assert embedding_func.model == "text-embedding-v3"

    def test_encoding_type_property(self):
        """Test encoding_type via extra_params."""
        query_emb = QwenSparseEmbedding(
            dimension=1024, encoding_type="query", api_key="test_key"
        )
        assert query_emb.extra_params.get("encoding_type") == "query"

        doc_emb = QwenSparseEmbedding(
            dimension=1024, encoding_type="document", api_key="test_key"
        )
        assert doc_emb.extra_params.get("encoding_type") == "document"

    @patch("zvec.extension.qwen_function.require_module")
    def test_embed_with_empty_text(self, mock_require_module):
        """Test embed method with empty or whitespace-only text raises ValueError."""
        embedding_func = QwenSparseEmbedding(dimension=1024, api_key="test_key")
        with pytest.raises(
            ValueError, match="Input text cannot be empty or whitespace only"
        ):
            embedding_func.embed("")

        with pytest.raises(
            ValueError, match="Input text cannot be empty or whitespace only"
        ):
            embedding_func.embed(" ")

    @patch("zvec.extension.qwen_function.require_module")
    def test_embed_with_non_string_input(self, mock_require_module):
        """Test embed method with non-string input raises TypeError."""
        embedding_func = QwenSparseEmbedding(dimension=1024, api_key="test_key")
        with pytest.raises(TypeError, match="Expected 'input' to be str"):
            embedding_func.embed(123)

        with pytest.raises(TypeError, match="Expected 'input' to be str"):
            embedding_func.embed(None)

    @patch("zvec.extension.qwen_function.require_module")
    def test_embed_success(self, mock_require_module):
        """Test successful sparse embedding generation.

        Also pins the request arguments, including the default text_type="query".
        """
        mock_dashscope = MagicMock()
        mock_response = MagicMock()
        mock_response.status_code = HTTPStatus.OK
        # Sparse embedding returns array of {index, value, token} objects
        mock_response.output = {
            "embeddings": [
                {
                    "sparse_embedding": [
                        {"index": 10, "value": 0.5, "token": "机器"},
                        {"index": 245, "value": 0.8, "token": "学习"},
                        {"index": 1023, "value": 1.2, "token": "算法"},
                    ]
                }
            ]
        }
        mock_dashscope.TextEmbedding.call.return_value = mock_response
        mock_require_module.return_value = mock_dashscope

        embedding_func = QwenSparseEmbedding(dimension=1024, api_key="test_key")
        # Clear cache to avoid interference
        embedding_func.embed.cache_clear()
        result = embedding_func.embed("test text")

        # Verify result is a dict
        assert isinstance(result, dict)
        # Verify keys are integers
        assert all(isinstance(k, int) for k in result.keys())
        # Verify values are floats
        assert all(isinstance(v, float) for v in result.values())
        # Verify all values are positive
        assert all(v > 0 for v in result.values())
        # Verify sorted by indices
        keys = list(result.keys())
        assert keys == sorted(keys)
        # Verify specific keys
        assert keys == [10, 245, 1023]

        mock_dashscope.TextEmbedding.call.assert_called_once_with(
            model="text-embedding-v4",
            input="test text",
            dimension=1024,
            output_type="sparse",
            text_type="query",
        )

    @patch("zvec.extension.qwen_function.require_module")
    def test_embed_with_document_encoding_type(self, mock_require_module):
        """Test embedding with document encoding type."""
        mock_dashscope = MagicMock()
        mock_response = MagicMock()
        mock_response.status_code = HTTPStatus.OK
        mock_response.output = {
            "embeddings": [
                {
                    "sparse_embedding": [
                        {"index": 5, "value": 0.3, "token": "文档"},
                        {"index": 100, "value": 0.7, "token": "内容"},
                        {"index": 500, "value": 0.9, "token": "检索"},
                    ]
                }
            ]
        }
        mock_dashscope.TextEmbedding.call.return_value = mock_response
        mock_require_module.return_value = mock_dashscope

        embedding_func = QwenSparseEmbedding(
            dimension=1024, encoding_type="document", api_key="test_key"
        )
        embedding_func.embed.cache_clear()
        result = embedding_func.embed("test document")

        assert isinstance(result, dict)
        assert list(result.keys()) == [5, 100, 500]

        # Verify text_type parameter is "document"
        # (call_args[1] is the keyword-argument dict of the recorded call).
        call_args = mock_dashscope.TextEmbedding.call.call_args
        assert call_args[1]["text_type"] == "document"
        assert call_args[1]["output_type"] == "sparse"

    @patch("zvec.extension.qwen_function.require_module")
    def test_embed_output_sorted_by_indices(self, mock_require_module):
        """Test that output is always sorted by indices in ascending order."""
        mock_dashscope = MagicMock()
        mock_response = MagicMock()
        mock_response.status_code = HTTPStatus.OK
        # Return unsorted indices
        mock_response.output = {
            "embeddings": [
                {
                    "sparse_embedding": [
                        {"index": 9999, "value": 1.5, "token": "A"},
                        {"index": 5, "value": 2.0, "token": "B"},
                        {"index": 1234, "value": 0.8, "token": "C"},
                        {"index": 77, "value": 3.2, "token": "D"},
                        {"index": 500, "value": 1.1, "token": "E"},
                    ]
                }
            ]
        }
        mock_dashscope.TextEmbedding.call.return_value = mock_response
        mock_require_module.return_value = mock_dashscope

        embedding_func = QwenSparseEmbedding(dimension=1024, api_key="test_key")
        embedding_func.embed.cache_clear()
        result = embedding_func.embed("test sorting")

        # Verify keys are sorted
        result_keys = list(result.keys())
        assert result_keys == sorted(result_keys)

        # Verify expected sorted order
        assert result_keys == [5, 77, 500, 1234, 9999]

    @patch("zvec.extension.qwen_function.require_module")
    def test_embed_filters_zero_values(self, mock_require_module):
        """Test that zero and negative values are filtered out."""
        mock_dashscope = MagicMock()
        mock_response = MagicMock()
        mock_response.status_code = HTTPStatus.OK
        # Include zero and negative values
        mock_response.output = {
            "embeddings": [
                {
                    "sparse_embedding": [
                        {"index": 10, "value": 0.5, "token": "正"},
                        {
                            "index": 20,
                            "value": 0.0,
                            "token": "零",
                        },  # Should be filtered
                        {
                            "index": 30,
                            "value": -0.3,
                            "token": "负",
                        },  # Should be filtered
                        {"index": 40, "value": 0.8, "token": "正"},
                        {
                            "index": 50,
                            "value": 0.0,
                            "token": "零",
                        },  # Should be filtered
                    ]
                }
            ]
        }
        mock_dashscope.TextEmbedding.call.return_value = mock_response
        mock_require_module.return_value = mock_dashscope

        embedding_func = QwenSparseEmbedding(dimension=1024, api_key="test_key")
        embedding_func.embed.cache_clear()
        result = embedding_func.embed("test filtering")

        # Only positive values should remain
        assert list(result.keys()) == [10, 40]
        assert all(v > 0 for v in result.values())

    @patch("zvec.extension.qwen_function.require_module")
    def test_embed_http_error(self, mock_require_module):
        """Test embedding with HTTP error."""
        mock_dashscope = MagicMock()
        mock_response = MagicMock()
        mock_response.status_code = HTTPStatus.BAD_REQUEST
        mock_response.message = "Bad Request"
        mock_dashscope.TextEmbedding.call.return_value = mock_response
        mock_require_module.return_value = mock_dashscope

        embedding_func = QwenSparseEmbedding(dimension=1024, api_key="test_key")
        embedding_func.embed.cache_clear()
        with pytest.raises(ValueError, match="DashScope API error"):
            embedding_func.embed("test text")

    @patch("zvec.extension.qwen_function.require_module")
    def test_embed_invalid_response_no_embeddings(self, mock_require_module):
        """Test embedding with invalid response (no embeddings)."""
        mock_dashscope = MagicMock()
        mock_response = MagicMock()
        mock_response.status_code = HTTPStatus.OK
        mock_response.output = {"embeddings": []}
        mock_dashscope.TextEmbedding.call.return_value = mock_response
        mock_require_module.return_value = mock_dashscope

        embedding_func = QwenSparseEmbedding(dimension=1024, api_key="test_key")
        embedding_func.embed.cache_clear()
        with pytest.raises(ValueError, match="Expected exactly 1 embedding"):
            embedding_func.embed("test text")

    @patch("zvec.extension.qwen_function.require_module")
    def test_embed_invalid_response_not_dict(self, mock_require_module):
        """Test embedding with invalid response (sparse_embedding not list)."""
        # NOTE(review): despite the test name, the malformed payload is a dict
        # where a list is required; the name could be clarified.
        mock_dashscope = MagicMock()
        mock_response = MagicMock()
        mock_response.status_code = HTTPStatus.OK
        # sparse_embedding should be list, not dict
        mock_response.output = {
            "embeddings": [{"sparse_embedding": {"index": 10, "value": 0.5}}]
        }
        mock_dashscope.TextEmbedding.call.return_value = mock_response
        mock_require_module.return_value = mock_dashscope

        embedding_func = QwenSparseEmbedding(dimension=1024, api_key="test_key")
        embedding_func.embed.cache_clear()
        with pytest.raises(
            ValueError, match="'sparse_embedding' field is missing or not a list"
        ):
            embedding_func.embed("test text")

    @patch("zvec.extension.qwen_function.require_module")
    def test_embed_callable_interface(self, mock_require_module):
        """Test that embedding function is callable."""
        mock_dashscope = MagicMock()
        mock_response = MagicMock()
        mock_response.status_code = HTTPStatus.OK
        mock_response.output = {
            "embeddings": [
                {
                    "sparse_embedding": [
                        {"index": 100, "value": 1.0, "token": "测试"},
                        {"index": 200, "value": 0.5, "token": "调用"},
                    ]
                }
            ]
        }
        mock_dashscope.TextEmbedding.call.return_value = mock_response
        mock_require_module.return_value = mock_dashscope

        embedding_func = QwenSparseEmbedding(dimension=1024, api_key="test_key")
        embedding_func.embed.cache_clear()

        # Test calling the function directly
        result = embedding_func("test text")
        assert isinstance(result, dict)
        assert list(result.keys()) == [100, 200]

    @patch("zvec.extension.qwen_function.require_module")
    def test_embed_api_connection_error(self, mock_require_module):
        """Test handling of API connection errors.

        An arbitrary exception from the SDK call is expected to be re-raised as
        RuntimeError("Failed to call DashScope API ...").
        """
        mock_dashscope = MagicMock()
        mock_dashscope.TextEmbedding.call.side_effect = Exception("Connection timeout")
        mock_require_module.return_value = mock_dashscope

        embedding_func = QwenSparseEmbedding(dimension=1024, api_key="test_key")
        embedding_func.embed.cache_clear()

        with pytest.raises(RuntimeError, match="Failed to call DashScope API"):
            embedding_func.embed("test text")

    @pytest.mark.skipif(
        not RUN_INTEGRATION_TESTS,
        reason="Integration test skipped. Set ZVEC_RUN_INTEGRATION_TESTS=1 to run.",
    )
    def test_real_embed_success(self):
        """Integration test with real DashScope API.

        To run this test, set environment variable:
            export ZVEC_RUN_INTEGRATION_TESTS=1
            export DASHSCOPE_API_KEY=your-api-key
        """
        # Test query embedding
        query_emb = QwenSparseEmbedding(dimension=1024, encoding_type="query")
        query_vec = query_emb.embed("machine learning")

        assert isinstance(query_vec, dict)
        assert len(query_vec) > 0
        assert all(isinstance(k, int) for k in query_vec.keys())
        assert all(isinstance(v, float) and v > 0 for v in query_vec.values())

        # Verify sorted output
        keys = list(query_vec.keys())
        assert keys == sorted(keys)

        # Test document embedding
        doc_emb = QwenSparseEmbedding(dimension=1024, encoding_type="document")
        doc_vec = doc_emb.embed("Machine learning is a subset of AI")

        assert isinstance(doc_vec, dict)
        assert len(doc_vec) > 0

        # Verify sorted output
        doc_keys = list(doc_vec.keys())
        assert doc_keys == sorted(doc_keys)


# ----------------------------
# OpenAIDenseEmbedding Test Case
# ----------------------------
class TestOpenAIDenseEmbedding:
    """Tests for OpenAIDenseEmbedding.

    Embedding tests patch ``zvec.extension.openai_function.require_module`` so
    the returned ``openai`` module is a Mock whose ``OpenAI(...)`` client yields
    canned ``embeddings.create`` responses.
    """

    def test_init_with_api_key(self):
        """Test initialization with explicit API key."""
        embedding_func = OpenAIDenseEmbedding(api_key="sk-test-key")
        assert embedding_func.dimension == 1536  # Default for text-embedding-3-small
        assert embedding_func.model == "text-embedding-3-small"
        assert embedding_func._api_key == "sk-test-key"

    @patch.dict(os.environ, {"OPENAI_API_KEY": "sk-env-key"})
    def test_init_with_env_api_key(self):
        """Test initialization with API key from environment."""
        embedding_func = OpenAIDenseEmbedding()
        assert embedding_func._api_key == "sk-env-key"

    @patch.dict(os.environ, {"OPENAI_API_KEY": ""})
    def test_init_without_api_key(self):
        """Test initialization fails without API key."""
        with pytest.raises(ValueError, match="OpenAI API key is required"):
            OpenAIDenseEmbedding()

    def test_init_with_custom_dimension(self):
        """Test initialization with custom dimension."""
        embedding_func = OpenAIDenseEmbedding(
            model="text-embedding-3-large", dimension=1024, api_key="sk-test"
        )
        assert embedding_func.dimension == 1024
        assert embedding_func.model == "text-embedding-3-large"

    def test_init_with_base_url(self):
        """Test initialization with custom base URL."""
        embedding_func = OpenAIDenseEmbedding(
            api_key="sk-test", base_url="https://custom.openai.com/"
        )
        assert embedding_func._base_url == "https://custom.openai.com/"

    def test_model_property(self):
        """Test model property."""
        embedding_func = OpenAIDenseEmbedding(api_key="sk-test")
        assert embedding_func.model == "text-embedding-3-small"

        embedding_func = OpenAIDenseEmbedding(
            model="text-embedding-ada-002", api_key="sk-test"
        )
        assert embedding_func.model == "text-embedding-ada-002"

    def test_extra_params(self):
        """Test extra_params property (unrecognized keyword arguments are kept)."""
        # Test without extra params
        embedding_func = OpenAIDenseEmbedding(api_key="sk-test")
        assert embedding_func.extra_params == {}

        # Test with extra params
        embedding_func = OpenAIDenseEmbedding(
            api_key="sk-test",
            encoding_format="float",
            user="test-user",
        )
        assert embedding_func.extra_params == {
            "encoding_format": "float",
            "user": "test-user",
        }

    @patch("zvec.extension.openai_function.require_module")
    def test_embed_with_empty_text(self, mock_require_module):
        """Test embed method with empty text raises ValueError."""
        embedding_func = OpenAIDenseEmbedding(api_key="sk-test")
        with pytest.raises(
            ValueError, match="Input text cannot be empty or whitespace only"
        ):
            embedding_func.embed("")

        with pytest.raises(
            ValueError, match="Input text cannot be empty or whitespace only"
        ):
            embedding_func.embed(" ")

    @patch("zvec.extension.openai_function.require_module")
    def test_embed_with_non_string_input(self, mock_require_module):
        """Test embed method with non-string input raises TypeError."""
        embedding_func = OpenAIDenseEmbedding(api_key="sk-test")
        with pytest.raises(TypeError, match="Expected 'input' to be str"):
            embedding_func.embed(123)

        with pytest.raises(TypeError, match="Expected 'input' to be str"):
            embedding_func.embed(None)

    @patch("zvec.extension.openai_function.require_module")
    def test_embed_success(self, mock_require_module):
        """Test successful embedding generation.

        With an explicit dimension, the request is expected to forward it as
        ``dimensions``.
        """
        # Mock OpenAI client
        mock_openai = Mock()
        mock_client = Mock()
        mock_response = Mock()

        # Create mock embedding data
        fake_embedding = [0.1, 0.2, 0.3]
        mock_embedding_obj = Mock()
        mock_embedding_obj.embedding = fake_embedding
        mock_response.data = [mock_embedding_obj]

        mock_client.embeddings.create.return_value = mock_response
        mock_openai.OpenAI.return_value = mock_client
        mock_require_module.return_value = mock_openai

        embedding_func = OpenAIDenseEmbedding(dimension=3, api_key="sk-test")
        # `embed` is memoized (it exposes cache_clear()); clear it so the mocked
        # client is actually called.
        embedding_func.embed.cache_clear()

        result = embedding_func.embed("test text")

        assert result == [0.1, 0.2, 0.3]
        mock_client.embeddings.create.assert_called_once_with(
            model="text-embedding-3-small", input="test text", dimensions=3
        )

    @patch("zvec.extension.openai_function.require_module")
    def test_embed_with_custom_model(self, mock_require_module):
        """Test embedding with custom model.

        No dimension is passed here, and the asserted request carries no
        ``dimensions`` argument.
        """
        mock_openai = Mock()
        mock_client = Mock()
        mock_response = Mock()

        fake_embedding = [0.1] * 1536
        mock_embedding_obj = Mock()
        mock_embedding_obj.embedding = fake_embedding
        mock_response.data = [mock_embedding_obj]

        mock_client.embeddings.create.return_value = mock_response
        mock_openai.OpenAI.return_value = mock_client
        mock_require_module.return_value = mock_openai

        embedding_func = OpenAIDenseEmbedding(
            model="text-embedding-ada-002", api_key="sk-test"
        )
        embedding_func.embed.cache_clear()

        result = embedding_func.embed("test text")

        assert len(result) == 1536
        mock_client.embeddings.create.assert_called_once_with(
            model="text-embedding-ada-002", input="test text"
        )

    @patch("zvec.extension.openai_function.require_module")
    def test_embed_api_error(self, mock_require_module):
        """Test handling of API errors."""
        mock_openai = Mock()
        mock_client = Mock()

        # Simulate API error
        # NOTE(review): `api_error` is never used below; the raised error is the
        # mock_openai.APIError instance set as side_effect.
        api_error = Mock()
        api_error.__class__.__name__ = "APIError"
        # Real exception classes (not Mock attributes): presumably the
        # implementation names these in `except` clauses, which require
        # BaseException subclasses — confirm against openai_function.
        mock_openai.APIError = type("APIError", (Exception,), {})
        mock_openai.APIConnectionError = type("APIConnectionError", (Exception,), {})
        mock_client.embeddings.create.side_effect = mock_openai.APIError(
            "Rate limit exceeded"
        )
        mock_openai.OpenAI.return_value = mock_client
        mock_require_module.return_value = mock_openai

        embedding_func = OpenAIDenseEmbedding(api_key="sk-test")
        embedding_func.embed.cache_clear()

        with pytest.raises(RuntimeError, match="Failed to call OpenAI API"):
            embedding_func.embed("test text")

    @patch("zvec.extension.openai_function.require_module")
    def test_embed_invalid_response(self, mock_require_module):
        """Test handling of invalid API response."""
        mock_openai = Mock()
        mock_client = Mock()
        mock_response = Mock()

        # Empty response data
        mock_response.data = []

        mock_client.embeddings.create.return_value = mock_response
        mock_openai.OpenAI.return_value = mock_client
        mock_openai.APIError = type("APIError", (Exception,), {})
        mock_openai.APIConnectionError = type("APIConnectionError", (Exception,), {})
        mock_require_module.return_value = mock_openai

        embedding_func = OpenAIDenseEmbedding(api_key="sk-test")
        embedding_func.embed.cache_clear()

        with pytest.raises(ValueError, match="no embedding data returned"):
            embedding_func.embed("test text")

    @patch("zvec.extension.openai_function.require_module")
    def test_embed_dimension_mismatch(self, mock_require_module):
        """Test handling of dimension mismatch."""
        mock_openai = Mock()
        mock_client = Mock()
        mock_response = Mock()

        # Return embedding with wrong dimension
        fake_embedding = [0.1] * 512
        mock_embedding_obj = Mock()
        mock_embedding_obj.embedding = fake_embedding
        mock_response.data = [mock_embedding_obj]

        mock_client.embeddings.create.return_value = mock_response
        mock_openai.OpenAI.return_value = mock_client
        mock_openai.APIError = type("APIError", (Exception,), {})
        mock_openai.APIConnectionError = type("APIConnectionError", (Exception,), {})
        mock_require_module.return_value = mock_openai

        embedding_func = OpenAIDenseEmbedding(dimension=1536, api_key="sk-test")
        embedding_func.embed.cache_clear()

        with pytest.raises(ValueError, match="Dimension mismatch"):
            embedding_func.embed("test text")

    @patch("zvec.extension.openai_function.require_module")
    def test_embed_callable(self, mock_require_module):
        """Test that embedding function is callable."""
        mock_openai = Mock()
        mock_client = Mock()
        mock_response = Mock()

        fake_embedding = [0.1] * 1536
        mock_embedding_obj = Mock()
        mock_embedding_obj.embedding = fake_embedding
        mock_response.data = [mock_embedding_obj]

        mock_client.embeddings.create.return_value = mock_response
        mock_openai.OpenAI.return_value = mock_client
        mock_openai.APIError = type("APIError", (Exception,), {})
        mock_openai.APIConnectionError = type("APIConnectionError", (Exception,), {})
        mock_require_module.return_value = mock_openai

        embedding_func = OpenAIDenseEmbedding(api_key="sk-test")
        embedding_func.embed.cache_clear()

        # Test calling the function directly
        result = embedding_func("test text")
        assert isinstance(result, list)
        assert len(result) == 1536

    @patch("zvec.extension.openai_function.require_module")
    def test_embed_with_base_url(self, mock_require_module):
        """Test embedding with custom base URL."""
        mock_openai = Mock()
        mock_client = Mock()
        mock_response = Mock()

        fake_embedding = [0.1] * 1536
        mock_embedding_obj = Mock()
        mock_embedding_obj.embedding = fake_embedding
        mock_response.data = [mock_embedding_obj]

        mock_client.embeddings.create.return_value = mock_response
        mock_openai.OpenAI.return_value = mock_client
        mock_openai.APIError = type("APIError", (Exception,), {})
        mock_openai.APIConnectionError = type("APIConnectionError", (Exception,), {})
        mock_require_module.return_value = mock_openai

        embedding_func = OpenAIDenseEmbedding(
            api_key="sk-test", base_url="https://custom.openai.com/"
        )
        embedding_func.embed.cache_clear()

        result = embedding_func.embed("test text")

        # Verify client was created with custom base URL
        mock_openai.OpenAI.assert_called_once_with(
            api_key="sk-test", base_url="https://custom.openai.com/"
        )
        assert len(result) == 1536

    @pytest.mark.skipif(
        not RUN_INTEGRATION_TESTS,
        reason="Integration test skipped. Set ZVEC_RUN_INTEGRATION_TESTS=1 to run.",
    )
    def test_real_embed_success(self):
        """Integration test with real OpenAI API.

        To run this test, set environment variable:
            export ZVEC_RUN_INTEGRATION_TESTS=1
            export OPENAI_API_KEY=sk-...
        """
        # NOTE(review): the request goes to DashScope's OpenAI-compatible
        # endpoint with a DashScope model name, so OPENAI_API_KEY presumably has
        # to hold a DashScope key rather than an OpenAI one — confirm and align
        # the docstring.
        embedding_func = OpenAIDenseEmbedding(
            model="text-embedding-v4",
            dimension=256,
            base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
        )
        vector = embedding_func.embed("Hello, world!")
        assert len(vector) == 256
        assert isinstance(vector, list)
        assert all(isinstance(x, float) for x in vector)


# ----------------------------
# DefaultLocalDenseEmbedding Test Case
# ----------------------------
class TestDefaultLocalDenseEmbedding:
    """Test cases for DefaultLocalDenseEmbedding.

    Most tests patch
    ``zvec.extension.sentence_transformer_function.require_module`` so that
    ``SentenceTransformer(...)`` returns a Mock model instead of loading weights.
    """

    @patch("zvec.extension.sentence_transformer_function.require_module")
    def test_init_success(self, mock_require_module):
        """Test successful initialization with mocked model."""
        # Mock sentence_transformers module
        mock_st = Mock()
        mock_model = Mock()
        mock_model.get_sentence_embedding_dimension.return_value = 384
        mock_model.device = "cpu"
        mock_st.SentenceTransformer.return_value = mock_model
        mock_require_module.return_value = mock_st

        # Initialize embedding function
        emb_func = DefaultLocalDenseEmbedding()

        # Assertions
        assert emb_func.dimension == 384
        assert emb_func.model_name == "all-MiniLM-L6-v2"
        assert emb_func.model_source == "huggingface"
        assert emb_func.device == "cpu"
        mock_st.SentenceTransformer.assert_called_once_with(
            "all-MiniLM-L6-v2", device=None, trust_remote_code=True
        )

    @patch("zvec.extension.sentence_transformer_function.require_module")
    def test_init_with_custom_device(self, mock_require_module):
        """Test initialization with custom device."""
        mock_st = Mock()
        mock_model = Mock()
        mock_model.get_sentence_embedding_dimension.return_value = 384
        mock_model.device = "cuda"
        mock_st.SentenceTransformer.return_value = mock_model
        mock_require_module.return_value = mock_st

        emb_func = DefaultLocalDenseEmbedding(device="cuda")

        assert emb_func.device == "cuda"
        mock_st.SentenceTransformer.assert_called_once_with(
            "all-MiniLM-L6-v2", device="cuda", trust_remote_code=True
        )

    # Although fully mocked, this is gated as an integration test; patch() must
    # import the "modelscope.hub.snapshot_download" target, so the modelscope
    # package presumably has to be installed for it to run.
    @pytest.mark.skipif(
        not RUN_INTEGRATION_TESTS,
        reason="Integration test skipped. Set ZVEC_RUN_INTEGRATION_TESTS=1 to run.",
    )
    @patch("zvec.extension.sentence_transformer_function.require_module")
    def test_init_with_modelscope(self, mock_require_module):
        """Test initialization with ModelScope as model source."""
        mock_st = Mock()
        mock_ms = Mock()
        mock_model = Mock()
        mock_model.get_sentence_embedding_dimension.return_value = 384
        mock_model.device = "cpu"
        mock_st.SentenceTransformer.return_value = mock_model

        # Route require_module by module name; anything else is reported missing.
        def require_module_side_effect(module_name):
            if module_name == "sentence_transformers":
                return mock_st
            elif module_name == "modelscope":
                return mock_ms
            raise ImportError(f"No module named '{module_name}'")

        mock_require_module.side_effect = require_module_side_effect

        # Mock snapshot_download at the correct import location
        with patch(
            "modelscope.hub.snapshot_download.snapshot_download",
            return_value="/path/to/cached/model",
        ):
            emb_func = DefaultLocalDenseEmbedding(model_source="modelscope")

            # Assertions
            assert emb_func.dimension == 384
            assert emb_func.model_name == "iic/nlp_gte_sentence-embedding_chinese-small"
            assert emb_func.model_source == "modelscope"

    @patch("zvec.extension.sentence_transformer_function.require_module")
    def test_init_with_invalid_model_source(self, mock_require_module):
        """Test initialization with invalid model_source raises ValueError."""
        mock_st = Mock()
        mock_model = Mock()
        mock_model.get_sentence_embedding_dimension.return_value = 384
        mock_st.SentenceTransformer.return_value = mock_model
        mock_require_module.return_value = mock_st

        with pytest.raises(ValueError, match="Invalid model_source"):
            DefaultLocalDenseEmbedding(model_source="invalid_source")

    @patch("zvec.extension.sentence_transformer_function.require_module")
    def test_embed_success(self, mock_require_module):
        """Test successful embedding generation.

        Pins the default encode() arguments (numpy output, normalization on,
        batch_size=32).
        """
        # Mock embedding output
        fake_embedding = np.random.rand(384).astype(np.float32)

        mock_st = Mock()
        mock_model = Mock()
        mock_model.get_sentence_embedding_dimension.return_value = 384
        # Configure encode method
        mock_model.encode = Mock(return_value=fake_embedding)
        mock_st.SentenceTransformer.return_value = mock_model
        mock_require_module.return_value = mock_st

        emb_func = DefaultLocalDenseEmbedding()
        result = emb_func.embed("Hello, world!")

        # Assertions
        assert isinstance(result, list)
        assert len(result) == 384
        assert all(isinstance(x, float) for x in result)
        mock_model.encode.assert_called_once_with(
            "Hello, world!",
            convert_to_numpy=True,
            normalize_embeddings=True,
            batch_size=32,
        )

    @patch("zvec.extension.sentence_transformer_function.require_module")
    def test_embed_with_normalization(self, mock_require_module):
        """Test embedding with L2 normalization."""
        # Create a normalized vector
        # (the mock returns it pre-normalized, so this checks the value is
        # passed through intact rather than the normalization itself).
        fake_embedding = np.random.rand(384).astype(np.float32)
        fake_embedding = fake_embedding / np.linalg.norm(fake_embedding)

        mock_st = Mock()
        mock_model = Mock()
        mock_model.get_sentence_embedding_dimension.return_value = 384
        # Configure encode method
        mock_model.encode = Mock(return_value=fake_embedding)
        mock_st.SentenceTransformer.return_value = mock_model
        mock_require_module.return_value = mock_st

        emb_func = DefaultLocalDenseEmbedding(normalize_embeddings=True)
        result = emb_func.embed("Test sentence")

        # Check if vector is normalized (L2 norm should be close to 1.0)
        result_array = np.array(result)
        norm = np.linalg.norm(result_array)
        assert abs(norm - 1.0) < 1e-5

    @patch("zvec.extension.sentence_transformer_function.require_module")
    def test_embed_empty_string(self, mock_require_module):
        """Test embedding with empty string raises ValueError."""
        mock_st = Mock()
        mock_model = Mock()
        mock_model.get_sentence_embedding_dimension.return_value = 384
        mock_st.SentenceTransformer.return_value = mock_model
        mock_require_module.return_value = mock_st

        emb_func = DefaultLocalDenseEmbedding()

        with pytest.raises(ValueError, match="Input text cannot be empty"):
            emb_func.embed("")

        with pytest.raises(ValueError, match="Input text cannot be empty"):
            emb_func.embed(" ")

    @patch("zvec.extension.sentence_transformer_function.require_module")
    def test_embed_non_string_input(self, mock_require_module):
        """Test embedding with non-string input raises TypeError."""
        mock_st = Mock()
        mock_model = Mock()
        mock_model.get_sentence_embedding_dimension.return_value = 384
        mock_st.SentenceTransformer.return_value = mock_model
        mock_require_module.return_value = mock_st

        emb_func = DefaultLocalDenseEmbedding()

        with pytest.raises(TypeError, match="Expected 'input' to be str"):
            emb_func.embed(123)

        with pytest.raises(TypeError, match="Expected 'input' to be str"):
            emb_func.embed(None)

    @patch("zvec.extension.sentence_transformer_function.require_module")
    def test_embed_callable(self, mock_require_module):
        """Test that embedding function is callable."""
        fake_embedding = np.random.rand(384).astype(np.float32)

        mock_st = Mock()
        mock_model = Mock()
        mock_model.get_sentence_embedding_dimension.return_value = 384
        # Configure encode method
        mock_model.encode = Mock(return_value=fake_embedding)
        mock_st.SentenceTransformer.return_value = mock_model
        mock_require_module.return_value = mock_st

        emb_func = DefaultLocalDenseEmbedding()

        # Test calling the function directly
        result = emb_func("Test text")
        assert isinstance(result, list)
        assert len(result) == 384

    @patch("zvec.extension.sentence_transformer_function.require_module")
    def test_semantic_similarity(self, mock_require_module):
        """Test semantic similarity between similar and different texts."""
        # Create mock embeddings for similar and different texts
        similar_emb_1 = np.array([1.0, 0.0, 0.0] + [0.0] * 381, dtype=np.float32)
        similar_emb_2 = np.array([0.9, 0.1, 0.0] + [0.0] * 381, dtype=np.float32)
        different_emb = np.array([0.0, 0.0, 1.0] + [0.0] * 381, dtype=np.float32)

        # Normalize
        similar_emb_1 = similar_emb_1 / np.linalg.norm(similar_emb_1)
        similar_emb_2 = similar_emb_2 / np.linalg.norm(similar_emb_2)
        different_emb = different_emb / np.linalg.norm(different_emb)

        mock_st = Mock()
        mock_model = Mock()
        mock_model.get_sentence_embedding_dimension.return_value = 384
        # Configure encode method with side_effect for multiple calls
        # (vectors are handed out in call order, so the three embed() calls
        # below must keep this order).
        mock_model.encode = Mock(
            side_effect=[similar_emb_1, similar_emb_2, different_emb]
        )
        mock_st.SentenceTransformer.return_value = mock_model
        mock_require_module.return_value = mock_st

        emb_func = DefaultLocalDenseEmbedding()

        v1 = emb_func.embed("The cat sits on the mat")
        v2 = emb_func.embed("A feline rests on a rug")
        v3 = emb_func.embed("Python programming")

        # Calculate similarities
        similarity_high = np.dot(v1, v2)
        similarity_low = np.dot(v1, v3)

        assert similarity_high > similarity_low

    @patch("zvec.extension.sentence_transformer_function.require_module")
    def test_model_loading_error(self, mock_require_module):
        """Test handling of model loading failure."""
        # Clear model cache
        # NOTE(review): this clears DefaultLocalSparseEmbedding's cache although
        # the test targets DefaultLocalDenseEmbedding — confirm whether the dense
        # class keeps its own model cache that should be cleared here instead.
        from zvec.extension.sentence_transformer_embedding_function import (
            DefaultLocalSparseEmbedding,
        )

        DefaultLocalSparseEmbedding.clear_cache()

        mock_st = Mock()
        mock_st.SentenceTransformer.side_effect = Exception("Model not found")
        mock_require_module.return_value = mock_st

        with pytest.raises(
            ValueError, match="Failed to load Sentence Transformer model"
        ):
            DefaultLocalDenseEmbedding()

    @patch("zvec.extension.sentence_transformer_function.require_module")
    def test_modelscope_import_error(self, mock_require_module):
        """Test handling of ModelScope import error."""
        mock_st = Mock()

        # Simulate modelscope being unavailable; other names fall through to None.
        def require_module_side_effect(module_name):
            if module_name == "sentence_transformers":
                return mock_st
            elif module_name == "modelscope":
                raise ImportError("No module named 'modelscope'")

        mock_require_module.side_effect = require_module_side_effect

        with pytest.raises(
            ImportError, match="ModelScope support requires the 'modelscope' package"
        ):
            DefaultLocalDenseEmbedding(model_source="modelscope")
@patch("zvec.extension.sentence_transformer_function.require_module") def test_embed_dimension_mismatch(self, mock_require_module): """Test handling of dimension mismatch in embedding output.""" # Return embedding with wrong dimension fake_embedding = np.random.rand(256).astype(np.float32) mock_st = Mock() mock_model = Mock() mock_model.get_sentence_embedding_dimension.return_value = 384 # Configure encode method mock_model.encode = Mock(return_value=fake_embedding) mock_st.SentenceTransformer.return_value = mock_model mock_require_module.return_value = mock_st emb_func = DefaultLocalDenseEmbedding() with pytest.raises(ValueError, match="Dimension mismatch"): emb_func.embed("Test text") @pytest.mark.skipif( not RUN_INTEGRATION_TESTS, reason="Integration test skipped. Set ZVEC_RUN_INTEGRATION_TESTS=1 to run.", ) def test_real_embedding_generation(self): """Integration test with real model (requires sentence-transformers). To run this test, set environment variable: export ZVEC_RUN_INTEGRATION_TESTS=1 Note: First run will download the model (~80MB). """ emb_func = DefaultLocalDenseEmbedding() # Test basic embedding vector = emb_func.embed("Hello, world!") assert len(vector) == 384 assert isinstance(vector, list) assert all(isinstance(x, float) for x in vector) # Test normalization norm = np.linalg.norm(vector) assert abs(norm - 1.0) < 1e-5 # Test semantic similarity v1 = emb_func.embed("The cat sits on the mat") v2 = emb_func.embed("A feline rests on a rug") v3 = emb_func.embed("Python programming language") similarity_high = np.dot(v1, v2) similarity_low = np.dot(v1, v3) assert similarity_high > similarity_low @pytest.mark.skipif( not RUN_INTEGRATION_TESTS, reason="Integration test skipped. 
Set ZVEC_RUN_INTEGRATION_TESTS=1 to run.", ) @patch("zvec.extension.sentence_transformer_function.require_module") def test_model_properties(self, mock_require_module): """Test model_name and model_source properties.""" mock_st = Mock() mock_model = Mock() mock_model.get_sentence_embedding_dimension.return_value = 384 mock_model.device = "cpu" mock_st.SentenceTransformer.return_value = mock_model mock_require_module.return_value = mock_st # Test Hugging Face emb_func_hf = DefaultLocalDenseEmbedding(model_source="huggingface") assert emb_func_hf.model_name == "all-MiniLM-L6-v2" assert emb_func_hf.model_source == "huggingface" # Test ModelScope with patch( "modelscope.hub.snapshot_download.snapshot_download", return_value="/path/to/model", ): mock_ms = Mock() mock_require_module.side_effect = ( lambda m: mock_st if m == "sentence_transformers" else mock_ms ) emb_func_ms = DefaultLocalDenseEmbedding(model_source="modelscope") assert ( emb_func_ms.model_name == "iic/nlp_gte_sentence-embedding_chinese-small" ) assert emb_func_ms.model_source == "modelscope" # ----------------------------------- # DefaultLocalSparseEmbedding Test Case # ----------------------------------- class TestDefaultLocalSparseEmbedding: """Test suite for DefaultLocalSparseEmbedding (SPLADE sparse embedding). Note: DefaultLocalSparseEmbedding uses naver/splade-cocondenser-ensembledistil instead of naver/splade-v3 because: - splade-v3 is a gated model requiring Hugging Face authentication - cocondenser-ensembledistil is publicly accessible - Performance difference is minimal (~2%) - Avoids "Access to model is restricted" errors This allows all users to run tests without authentication setup. """ @patch("zvec.extension.sentence_transformer_function.require_module") def test_init_success(self, mock_require_module): """Test successful initialization. 
Verifies that DefaultLocalSparseEmbedding initializes with the publicly accessible naver/splade-cocondenser-ensembledistil model instead of the gated naver/splade-v3 model. """ mock_st = Mock() mock_model = Mock() mock_model.device = "cpu" mock_st.SentenceTransformer.return_value = mock_model mock_require_module.return_value = mock_st sparse_emb = DefaultLocalSparseEmbedding() assert sparse_emb.model_name == "naver/splade-cocondenser-ensembledistil" assert sparse_emb.model_source == "huggingface" assert sparse_emb.device == "cpu" mock_st.SentenceTransformer.assert_called_once_with( "naver/splade-cocondenser-ensembledistil", device=None, trust_remote_code=True, ) @patch("zvec.extension.sentence_transformer_function.require_module") def test_init_with_custom_device(self, mock_require_module): """Test initialization with custom device.""" mock_st = Mock() mock_model = Mock() mock_model.device = "cuda" mock_st.SentenceTransformer.return_value = mock_model mock_require_module.return_value = mock_st sparse_emb = DefaultLocalSparseEmbedding(device="cuda") assert sparse_emb.device == "cuda" mock_st.SentenceTransformer.assert_called_once_with( "naver/splade-cocondenser-ensembledistil", device="cuda", trust_remote_code=True, ) @pytest.mark.skipif( not RUN_INTEGRATION_TESTS, reason="Integration test skipped. 
Set ZVEC_RUN_INTEGRATION_TESTS=1 to run.", ) @patch("zvec.extension.sentence_transformer_function.require_module") def test_embed_success(self, mock_require_module): """Test successful sparse embedding generation with official API.""" import numpy as np # Clear model cache to ensure fresh mock from zvec.extension.sentence_transformer_embedding_function import ( DefaultLocalSparseEmbedding, ) DefaultLocalSparseEmbedding.clear_cache() # Create a mock sparse matrix that simulates scipy.sparse behavior # The code will call: sparse_matrix[0].toarray().flatten() mock_sparse_matrix = Mock() # Create a dense array representation with vocab_size=30522 vocab_size = 30522 dense_array = np.zeros(vocab_size) # Set specific non-zero values at indices [10, 245, 1023, 5678] dense_array[10] = 0.5 dense_array[245] = 0.8 dense_array[1023] = 1.2 dense_array[5678] = 0.3 # Mock the method chain: sparse_matrix[0].toarray().flatten() mock_row = Mock() mock_dense = Mock() mock_row.toarray.return_value = mock_dense mock_dense.flatten.return_value = dense_array mock_sparse_matrix.__getitem__ = Mock(return_value=mock_row) # Also mock hasattr check for 'toarray' mock_sparse_matrix.toarray = Mock() mock_st = Mock() mock_model = Mock() mock_model.device = "cpu" # Configure mock methods to return sparse matrix # Must set return_value BEFORE hasattr() check in the code mock_model.encode_query = Mock(return_value=mock_sparse_matrix) mock_model.encode_document = Mock(return_value=mock_sparse_matrix) mock_st.SentenceTransformer.return_value = mock_model mock_require_module.return_value = mock_st sparse_emb = DefaultLocalSparseEmbedding() result = sparse_emb.embed("machine learning") # Verify result is a dictionary assert isinstance(result, dict) # Verify keys are integers and values are floats assert all(isinstance(k, int) for k in result.keys()) assert all(isinstance(v, float) for v in result.values()) # Verify all values are positive assert all(v > 0 for v in result.values()) # Sparse vectors 
should have specific dimensions assert len(result) == 4 # Verify output is sorted by indices (keys) keys = list(result.keys()) assert keys == sorted(keys), ( "Sparse vector keys must be sorted in ascending order" ) # Verify expected keys assert keys == [10, 245, 1023, 5678] # Verify encode_query was called with a list mock_model.encode_query.assert_called_once() call_args = mock_model.encode_query.call_args[0][0] assert isinstance(call_args, list) assert call_args == ["machine learning"] @patch("zvec.extension.sentence_transformer_function.require_module") def test_embed_empty_input(self, mock_require_module): """Test embedding with empty input.""" mock_st = Mock() mock_model = Mock() mock_st.SentenceTransformer.return_value = mock_model mock_require_module.return_value = mock_st sparse_emb = DefaultLocalSparseEmbedding() with pytest.raises(ValueError, match="Input text cannot be empty"): sparse_emb.embed("") with pytest.raises(ValueError, match="Input text cannot be empty"): sparse_emb.embed(" ") @patch("zvec.extension.sentence_transformer_function.require_module") def test_embed_non_string_input(self, mock_require_module): """Test embedding with non-string input.""" mock_st = Mock() mock_model = Mock() mock_st.SentenceTransformer.return_value = mock_model mock_require_module.return_value = mock_st sparse_emb = DefaultLocalSparseEmbedding() with pytest.raises(TypeError, match="Expected 'input' to be str"): sparse_emb.embed(123) with pytest.raises(TypeError, match="Expected 'input' to be str"): sparse_emb.embed(["text"]) @pytest.mark.skipif( not RUN_INTEGRATION_TESTS, reason="Integration test skipped. 
Set ZVEC_RUN_INTEGRATION_TESTS=1 to run.", ) @patch("zvec.extension.sentence_transformer_function.require_module") def test_callable_interface(self, mock_require_module): """Test that DefaultSparseEmbedding is callable.""" import numpy as np # Clear model cache from zvec.extension.sentence_transformer_embedding_function import ( DefaultLocalSparseEmbedding, ) DefaultLocalSparseEmbedding.clear_cache() # Create a mock sparse matrix mock_sparse_matrix = Mock() # Create a dense array representation with vocab_size=30522 vocab_size = 30522 dense_array = np.zeros(vocab_size) # Set specific non-zero values at indices [100, 200, 300] dense_array[100] = 1.0 dense_array[200] = 0.5 dense_array[300] = 0.8 # Mock the method chain: sparse_matrix[0].toarray().flatten() mock_row = Mock() mock_dense = Mock() mock_row.toarray.return_value = mock_dense mock_dense.flatten.return_value = dense_array mock_sparse_matrix.__getitem__ = Mock(return_value=mock_row) # Also mock hasattr check for 'toarray' mock_sparse_matrix.toarray = Mock() mock_st = Mock() mock_model = Mock() mock_model.device = "cpu" # Configure mock methods mock_model.encode_query = Mock(return_value=mock_sparse_matrix) mock_model.encode_document = Mock(return_value=mock_sparse_matrix) mock_st.SentenceTransformer.return_value = mock_model mock_require_module.return_value = mock_st sparse_emb = DefaultLocalSparseEmbedding() # Test callable interface result = sparse_emb("test input") assert isinstance(result, dict) assert all(isinstance(k, int) for k in result.keys()) # Verify sorted output keys = list(result.keys()) assert keys == sorted(keys), "Callable interface must also return sorted keys" assert keys == [100, 200, 300] @patch("zvec.extension.sentence_transformer_function.require_module") def test_model_loading_failure(self, mock_require_module): """Test handling of model loading failure.""" # Clear model cache to ensure the test actually tries to load the model from 
zvec.extension.sentence_transformer_embedding_function import ( DefaultLocalSparseEmbedding, ) DefaultLocalSparseEmbedding.clear_cache() mock_st = Mock() mock_st.SentenceTransformer.side_effect = Exception("Model not found") mock_require_module.return_value = mock_st with pytest.raises( ValueError, match="Failed to load Sentence Transformer model" ): DefaultLocalSparseEmbedding() @patch("zvec.extension.sentence_transformer_function.require_module") def test_inference_failure(self, mock_require_module): """Test handling of inference failure.""" # Clear model cache from zvec.extension.sentence_transformer_embedding_function import ( DefaultLocalSparseEmbedding, ) DefaultLocalSparseEmbedding.clear_cache() mock_st = Mock() mock_model = Mock() mock_model.device = "cpu" # Configure mock methods to raise RuntimeError mock_model.encode_query = Mock(side_effect=RuntimeError("CUDA out of memory")) mock_model.encode_document = Mock( side_effect=RuntimeError("CUDA out of memory") ) mock_st.SentenceTransformer.return_value = mock_model mock_require_module.return_value = mock_st sparse_emb = DefaultLocalSparseEmbedding() with pytest.raises(RuntimeError, match="Failed to generate sparse embedding"): sparse_emb.embed("test input") @patch("zvec.extension.sentence_transformer_function.require_module") def test_sparse_vector_properties(self, mock_require_module): """Test properties of sparse vectors (sparsity, non-zero values, sorted order).""" import numpy as np # Clear model cache from zvec.extension.sentence_transformer_embedding_function import ( DefaultLocalSparseEmbedding, ) DefaultLocalSparseEmbedding.clear_cache() # Create a mock sparse matrix that simulates scipy.sparse behavior # The code will call: sparse_matrix[0].toarray().flatten() mock_sparse_matrix = Mock() # Create a dense array representation with vocab_size=30522 vocab_size = 30522 dense_array = np.zeros(vocab_size) # Set specific non-zero values at indices [50, 100, 200, 400, 500] dense_array[50] = 3.0 
dense_array[100] = 2.0 dense_array[200] = 1.5 dense_array[400] = 2.5 dense_array[500] = 1.8 # Mock the method chain: sparse_matrix[0].toarray().flatten() mock_row = Mock() mock_dense = Mock() mock_row.toarray.return_value = mock_dense mock_dense.flatten.return_value = dense_array mock_sparse_matrix.__getitem__ = Mock(return_value=mock_row) # Also mock hasattr check for 'toarray' mock_sparse_matrix.toarray = Mock() mock_st = Mock() mock_model = Mock() mock_model.device = "cpu" # Configure mock methods mock_model.encode_query = Mock(return_value=mock_sparse_matrix) mock_model.encode_document = Mock(return_value=mock_sparse_matrix) mock_st.SentenceTransformer.return_value = mock_model mock_require_module.return_value = mock_st sparse_emb = DefaultLocalSparseEmbedding() result = sparse_emb.embed("test") # Verify sparsity: result should have much fewer dimensions than vocab_size assert len(result) < vocab_size # All values should be positive assert all(v > 0 for v in result.values()) # Verify keys are sorted in ascending order keys = list(result.keys()) assert keys == sorted(keys), "Sparse vector keys must be sorted" # Verify the specific non-zero indices are present and sorted # Expected order: [50, 100, 200, 400, 500] (sorted) expected_keys = [50, 100, 200, 400, 500] assert keys == expected_keys, f"Expected {expected_keys}, got {keys}" # First key should be smallest if len(result) > 0: first_key = next(iter(result.keys())) assert first_key == min(result.keys()), "First key must be the smallest" @patch("zvec.extension.sentence_transformer_function.require_module") def test_output_sorted_by_indices(self, mock_require_module): """Test that output dictionary is always sorted by indices (keys) in ascending order.""" import numpy as np # Clear model cache from zvec.extension.sentence_transformer_embedding_function import ( DefaultLocalSparseEmbedding, ) DefaultLocalSparseEmbedding.clear_cache() # Create sparse output with deliberately out-of-order indices # Non-sequential 
indices: 9999, 5, 1234, 77, 500 mock_sparse_matrix = Mock() # Create a dense array representation with vocab_size=30522 vocab_size = 30522 dense_array = np.zeros(vocab_size) # Set specific non-zero values at out-of-order indices dense_array[9999] = 1.5 dense_array[5] = 2.0 dense_array[1234] = 0.8 dense_array[77] = 3.2 dense_array[500] = 1.1 # Mock the method chain: sparse_matrix[0].toarray().flatten() mock_row = Mock() mock_dense = Mock() mock_row.toarray.return_value = mock_dense mock_dense.flatten.return_value = dense_array mock_sparse_matrix.__getitem__ = Mock(return_value=mock_row) # Also mock hasattr check for 'toarray' mock_sparse_matrix.toarray = Mock() mock_st = Mock() mock_model = Mock() mock_model.device = "cpu" # Configure mock methods mock_model.encode_query = Mock(return_value=mock_sparse_matrix) mock_model.encode_document = Mock(return_value=mock_sparse_matrix) mock_st.SentenceTransformer.return_value = mock_model mock_require_module.return_value = mock_st sparse_emb = DefaultLocalSparseEmbedding() result = sparse_emb.embed("test sorting") # Extract keys from result result_keys = list(result.keys()) # Verify keys are sorted assert result_keys == sorted(result_keys), ( f"Keys must be sorted in ascending order. " f"Got: {result_keys}, Expected: {sorted(result_keys)}" ) # Verify expected keys are present and in correct order # Expected sorted order: [5, 77, 500, 1234, 9999] expected_sorted_keys = [5, 77, 500, 1234, 9999] assert result_keys == expected_sorted_keys, ( f"All expected keys should be present in sorted order. 
" f"Expected: {expected_sorted_keys}, Got: {result_keys}" ) # Verify first and last keys assert result_keys[0] == 5, "First key must be minimum" assert result_keys[-1] == 9999, "Last key must be maximum" # Verify iteration order matches sorted order for i, (key, value) in enumerate(result.items()): if i > 0: prev_key = list(result.keys())[i - 1] assert key > prev_key, ( f"Key at position {i} must be greater than previous key" ) @patch("zvec.extension.sentence_transformer_function.require_module") def test_device_property(self, mock_require_module): """Test device property returns correct device.""" mock_st = Mock() mock_model = Mock() mock_model.device = "cuda" mock_st.SentenceTransformer.return_value = mock_model mock_require_module.return_value = mock_st sparse_emb = DefaultLocalSparseEmbedding(device="cuda") assert sparse_emb.device == "cuda" @pytest.mark.skipif( not RUN_INTEGRATION_TESTS, reason="Integration test: requires ZVEC_RUN_INTEGRATION_TESTS=1 and model download", ) @patch("zvec.extension.sentence_transformer_function.require_module") def test_modelscope_source(self, mock_require_module): """Test initialization with ModelScope source.""" mock_st = Mock() mock_ms = Mock() mock_model = Mock() mock_model.device = "cpu" mock_st.SentenceTransformer.return_value = mock_model # Mock ModelScope snapshot_download with patch( "modelscope.hub.snapshot_download.snapshot_download", return_value="/cache/splade-cocondenser", ): mock_require_module.side_effect = ( lambda m: mock_st if m == "sentence_transformers" else mock_ms ) sparse_emb = DefaultLocalSparseEmbedding(model_source="modelscope") assert sparse_emb.model_name == "naver/splade-cocondenser-ensembledistil" assert sparse_emb.model_source == "modelscope" @pytest.mark.skipif( not RUN_INTEGRATION_TESTS, reason="Integration test: requires ZVEC_RUN_INTEGRATION_TESTS=1 and model download", ) def test_integration_real_model(self): """Integration test with real SPLADE model (requires model download). 
This test uses naver/splade-cocondenser-ensembledistil instead of naver/splade-v3 because splade-v3 requires Hugging Face authentication. The cocondenser-ensembledistil model is publicly accessible and provides comparable performance. To run this test: export ZVEC_RUN_INTEGRATION_TESTS=1 pytest tests/test_embedding.py::TestDefaultSparseEmbedding::test_integration_real_model -v Note: First run will download ~100MB model from Hugging Face. Alternative models: If you have access to splade-v3, you can create a custom embedding class following the example in DefaultSparseEmbedding docstring. """ # Clear model cache to ensure fresh load from zvec.extension.sentence_transformer_embedding_function import ( DefaultLocalSparseEmbedding, ) DefaultLocalSparseEmbedding.clear_cache() sparse_emb = DefaultLocalSparseEmbedding() # Test with real input text = "machine learning and artificial intelligence" result = sparse_emb.embed(text) # Verify result structure assert isinstance(result, dict) assert len(result) > 0 assert all(isinstance(k, int) and k >= 0 for k in result.keys()) assert all(isinstance(v, float) and v > 0 for v in result.values()) # SPLADE typically produces 100-300 non-zero dimensions assert 50 < len(result) < 500 # Verify keys are sorted in ascending order keys = list(result.keys()) assert keys == sorted(keys), "Real model output must be sorted by indices" # Test callable interface result2 = sparse_emb(text) assert result == result2 @pytest.mark.skipif( not RUN_INTEGRATION_TESTS, reason="Integration test: requires ZVEC_RUN_INTEGRATION_TESTS=1", ) def test_integration_multiple_inputs(self): """Integration test with multiple different inputs.""" # Clear model cache from zvec.extension.sentence_transformer_embedding_function import ( DefaultLocalSparseEmbedding, ) DefaultLocalSparseEmbedding.clear_cache() sparse_emb = DefaultLocalSparseEmbedding() texts = [ "Hello, world!", "Machine learning is fascinating", "Python programming language", ] results = 
[sparse_emb.embed(text) for text in texts] # All results should be different assert len(results) == 3 assert all(isinstance(r, dict) for r in results) # Different inputs should produce different sparse vectors assert results[0] != results[1] assert results[1] != results[2] # All results must be sorted by indices for i, result in enumerate(results): keys = list(result.keys()) assert keys == sorted(keys), f"Result {i} must have sorted keys" # ---------------------------- # BM25EmbeddingFunction Test Case # ---------------------------- class TestBM25EmbeddingFunction: """Test suite for BM25EmbeddingFunction (BM25-based sparse embedding using DashText SDK).""" def test_init_with_built_in_encoder(self): """Test successful initialization with built-in encoder (no corpus).""" with patch( "zvec.extension.bm25_embedding_function.require_module" ) as mock_require: mock_dashtext = Mock() mock_encoder = Mock() mock_dashtext.SparseVectorEncoder.default.return_value = mock_encoder mock_require.return_value = mock_dashtext # Test with default language (Chinese) bm25 = BM25EmbeddingFunction() assert bm25.corpus_size == 0 assert bm25.encoding_type == "query" assert bm25.language == "zh" mock_dashtext.SparseVectorEncoder.default.assert_called_once_with(name="zh") def test_init_with_custom_encoder(self): """Test successful initialization with custom encoder (with corpus).""" corpus = [ "a cat is a feline and likes to purr", "a dog is the human's best friend", "a bird is a beautiful animal that can fly", ] with patch( "zvec.extension.bm25_embedding_function.require_module" ) as mock_require: mock_dashtext = Mock() mock_encoder = Mock() mock_dashtext.SparseVectorEncoder.return_value = mock_encoder mock_require.return_value = mock_dashtext bm25 = BM25EmbeddingFunction(corpus=corpus, b=0.75, k1=1.2) assert bm25.corpus_size == 3 assert bm25.encoding_type == "query" mock_dashtext.SparseVectorEncoder.assert_called_once_with(b=0.75, k1=1.2) mock_encoder.train.assert_called_once_with(corpus) 
def test_init_with_empty_corpus(self): """Test initialization with empty corpus raises ValueError.""" with pytest.raises(ValueError, match="Corpus must be a non-empty list"): BM25EmbeddingFunction(corpus=[]) def test_init_with_invalid_corpus(self): """Test initialization with invalid corpus elements.""" with pytest.raises(ValueError, match="All corpus documents must be strings"): BM25EmbeddingFunction(corpus=["text", 123, "another"]) with pytest.raises(ValueError, match="All corpus documents must be strings"): BM25EmbeddingFunction(corpus=[None, "text"]) def test_init_with_language_parameter(self): """Test initialization with different language settings.""" with patch( "zvec.extension.bm25_embedding_function.require_module" ) as mock_require: mock_dashtext = Mock() mock_encoder = Mock() mock_dashtext.SparseVectorEncoder.default.return_value = mock_encoder mock_require.return_value = mock_dashtext # Test English language bm25_en = BM25EmbeddingFunction(language="en") assert bm25_en.language == "en" mock_dashtext.SparseVectorEncoder.default.assert_called_with(name="en") def test_init_with_encoding_type(self): """Test initialization with different encoding types.""" with patch( "zvec.extension.bm25_embedding_function.require_module" ) as mock_require: mock_dashtext = Mock() mock_encoder = Mock() mock_dashtext.SparseVectorEncoder.default.return_value = mock_encoder mock_require.return_value = mock_dashtext # Test document encoding type bm25_doc = BM25EmbeddingFunction(encoding_type="document") assert bm25_doc.encoding_type == "document" def test_init_with_missing_dashtext_library(self): """Test initialization fails when dashtext library is not installed.""" with patch( "zvec.extension.bm25_embedding_function.require_module" ) as mock_require: mock_require.side_effect = ImportError("dashtext package is required") with pytest.raises(ImportError, match="dashtext package is required"): BM25EmbeddingFunction() def test_embed_with_query_encoding(self): """Test successful 
sparse embedding generation with query encoding.""" with patch( "zvec.extension.bm25_embedding_function.require_module" ) as mock_require: mock_dashtext = Mock() mock_encoder = Mock() # Mock encode_queries to return sparse vector mock_encoder.encode_queries.return_value = { 5: 0.89, 12: 1.45, 23: 0.67, 45: 1.12, } mock_dashtext.SparseVectorEncoder.default.return_value = mock_encoder mock_require.return_value = mock_dashtext bm25 = BM25EmbeddingFunction(encoding_type="query") # Clear LRU cache to ensure fresh call bm25.embed.cache_clear() result = bm25.embed("cat purr loud") # Verify result structure assert isinstance(result, dict) assert all(isinstance(k, int) for k in result.keys()) assert all(isinstance(v, float) for v in result.values()) # Verify all values are positive assert all(v > 0 for v in result.values()) # Verify output is sorted by indices keys = list(result.keys()) assert keys == sorted(keys), "Output must be sorted by indices" # Verify expected keys from mock response assert result == {5: 0.89, 12: 1.45, 23: 0.67, 45: 1.12} # Verify encode_queries was called mock_encoder.encode_queries.assert_called_once_with("cat purr loud") def test_embed_with_document_encoding(self): """Test successful sparse embedding generation with document encoding.""" with patch( "zvec.extension.bm25_embedding_function.require_module" ) as mock_require: mock_dashtext = Mock() mock_encoder = Mock() # Mock encode_documents to return sparse vector mock_encoder.encode_documents.return_value = {10: 1.5, 20: 2.3} mock_dashtext.SparseVectorEncoder.default.return_value = mock_encoder mock_require.return_value = mock_dashtext bm25 = BM25EmbeddingFunction(encoding_type="document") bm25.embed.cache_clear() result = bm25.embed("document text") assert result == {10: 1.5, 20: 2.3} mock_encoder.encode_documents.assert_called_once_with("document text") def test_embed_with_empty_input(self): """Test embedding with empty input raises ValueError.""" with patch( 
"zvec.extension.bm25_embedding_function.require_module" ) as mock_require: mock_dashtext = Mock() mock_encoder = Mock() mock_dashtext.SparseVectorEncoder.default.return_value = mock_encoder mock_require.return_value = mock_dashtext bm25 = BM25EmbeddingFunction() with pytest.raises(ValueError, match="Input text cannot be empty"): bm25.embed("") with pytest.raises(ValueError, match="Input text cannot be empty"): bm25.embed(" ") def test_embed_with_non_string_input(self): """Test embedding with non-string input raises TypeError.""" with patch( "zvec.extension.bm25_embedding_function.require_module" ) as mock_require: mock_dashtext = Mock() mock_encoder = Mock() mock_dashtext.SparseVectorEncoder.default.return_value = mock_encoder mock_require.return_value = mock_dashtext bm25 = BM25EmbeddingFunction() # Test with hashable non-string types - should get our custom error message with pytest.raises(TypeError, match="Expected 'input' to be str"): bm25.embed(123) with pytest.raises(TypeError, match="Expected 'input' to be str"): bm25.embed(None) # Test with unhashable type (list) # Note: lru_cache raises TypeError("unhashable type: 'list'") before our type check # This is still a valid type error, just caught at a different layer with pytest.raises(TypeError, match="unhashable type"): bm25.embed(["text"]) def test_embed_callable_interface(self): """Test that BM25EmbeddingFunction is callable.""" with patch( "zvec.extension.bm25_embedding_function.require_module" ) as mock_require: mock_dashtext = Mock() mock_encoder = Mock() mock_encoder.encode_queries.return_value = {10: 1.5} mock_dashtext.SparseVectorEncoder.default.return_value = mock_encoder mock_require.return_value = mock_dashtext bm25 = BM25EmbeddingFunction() bm25.embed.cache_clear() # Test callable interface result = bm25("test query") assert isinstance(result, dict) assert 10 in result def test_embed_output_sorted_by_indices(self): """Test that output is always sorted by indices in ascending order.""" with patch( 
"zvec.extension.bm25_embedding_function.require_module" ) as mock_require: mock_dashtext = Mock() mock_encoder = Mock() # Mock encode_queries with unsorted indices mock_encoder.encode_queries.return_value = { 9999: 1.5, 5: 2.0, 1234: 0.8, 77: 3.2, 500: 1.1, } mock_dashtext.SparseVectorEncoder.default.return_value = mock_encoder mock_require.return_value = mock_dashtext bm25 = BM25EmbeddingFunction() bm25.embed.cache_clear() result = bm25.embed("test query") # Verify keys are sorted result_keys = list(result.keys()) assert result_keys == sorted(result_keys), ( f"Keys must be sorted. Got: {result_keys}, Expected: {sorted(result_keys)}" ) # Verify expected sorted order: [5, 77, 500, 1234, 9999] expected_keys = [5, 77, 500, 1234, 9999] assert result_keys == expected_keys def test_embed_filters_zero_values(self): """Test that zero and negative values are filtered out.""" with patch( "zvec.extension.bm25_embedding_function.require_module" ) as mock_require: mock_dashtext = Mock() mock_encoder = Mock() # Mock encode_queries with zero and negative values mock_encoder.encode_queries.return_value = { 0: 1.5, # Positive - should be included 1: 0.0, # Zero - should be filtered 2: -0.5, # Negative - should be filtered } mock_dashtext.SparseVectorEncoder.default.return_value = mock_encoder mock_require.return_value = mock_dashtext bm25 = BM25EmbeddingFunction() bm25.embed.cache_clear() result = bm25.embed("test") # Only positive token should be in result assert 0 in result assert 1 not in result # Zero value filtered assert 2 not in result # Negative value filtered assert all(v > 0 for v in result.values()) def test_properties(self): """Test property accessors.""" corpus = ["doc1", "doc2", "doc3"] with patch( "zvec.extension.bm25_embedding_function.require_module" ) as mock_require: mock_dashtext = Mock() mock_encoder = Mock() mock_dashtext.SparseVectorEncoder.return_value = mock_encoder mock_require.return_value = mock_dashtext bm25 = BM25EmbeddingFunction( corpus=corpus, 
encoding_type="document", language="en", b=0.8, k1=1.5, custom_param="test", ) assert bm25.corpus_size == 3 assert bm25.encoding_type == "document" assert bm25.language == "en" assert bm25.extra_params == {"custom_param": "test"} @pytest.mark.skipif( not RUN_INTEGRATION_TESTS, reason="Integration test skipped. Set ZVEC_RUN_INTEGRATION_TESTS=1 to run.", ) def test_real_dashtext_bm25_embedding(self): """Integration test with real DashText library. To run this test: export ZVEC_RUN_INTEGRATION_TESTS=1 pip install dashtext Note: This test requires the dashtext package to be installed. """ # Test built-in encoder (Chinese) bm25_zh = BM25EmbeddingFunction(language="zh", encoding_type="query") query_zh = "什么是向量检索服务" result_zh = bm25_zh.embed(query_zh) assert isinstance(result_zh, dict) assert len(result_zh) > 0 assert all(isinstance(k, int) for k in result_zh.keys()) assert all(isinstance(v, float) and v > 0 for v in result_zh.values()) # Verify sorted output keys = list(result_zh.keys()) assert keys == sorted(keys), "Real DashText BM25 output must be sorted" # Test custom corpus corpus = [ "The cat sits on the mat", "The dog plays in the garden", "Birds fly in the sky", "Fish swim in the water", ] bm25_custom = BM25EmbeddingFunction(corpus=corpus, encoding_type="query") query_en = "cat on mat" result_en = bm25_custom.embed(query_en) assert isinstance(result_en, dict) assert len(result_en) > 0 assert all(isinstance(k, int) for k in result_en.keys()) assert all(isinstance(v, float) and v > 0 for v in result_en.values()) # Test callable interface result2 = bm25_custom(query_en) assert result_en == result2 # Verify properties assert bm25_custom.corpus_size == 4 ================================================ FILE: python/tests/test_params.py ================================================ # Copyright 2025-present the zvec project # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations import sys import time import numpy as np import pytest from zvec import ( AddColumnOption, AlterColumnOption, CollectionOption, FlatIndexParam, HnswIndexParam, IndexOption, InvertIndexParam, IVFIndexParam, OptimizeOption, HnswQueryParam, IVFQueryParam, VectorQuery, IndexType, MetricType, QuantizeType, DataType, VectorSchema, ) from _zvec.param import _VectorQuery # ---------------------------- # Invert Index Param Test Case # ---------------------------- class TestInvertIndexParam: def test_default(self): param = InvertIndexParam() assert param.enable_range_optimization is False assert param.enable_extended_wildcard is False assert param.type == IndexType.INVERT def test_custom(self): param = InvertIndexParam( enable_range_optimization=True, enable_extended_wildcard=True ) assert param.enable_range_optimization is True assert param.enable_extended_wildcard is True def test_readonly(self): param = InvertIndexParam() import sys if sys.version_info >= (3, 11): match_pattern = r"(can't set attribute|has no setter|readonly attribute)" else: match_pattern = r"can't set attribute" with pytest.raises(AttributeError, match=match_pattern): param.enable_range_optimization = False param.enable_extended_wildcard = False # ---------------------------- # Hnsw Index Param Test Case # ---------------------------- class TestHnswIndexParam: def test_default(self): param = HnswIndexParam() assert param.metric_type == MetricType.IP assert param.m == 50 assert param.ef_construction == 500 assert param.quantize_type == QuantizeType.UNDEFINED assert 
param.type == IndexType.HNSW def test_custom(self): param = HnswIndexParam( metric_type=MetricType.L2, m=10, ef_construction=1000, quantize_type=QuantizeType.FP16, ) assert param.metric_type == MetricType.L2 assert param.m == 10 assert param.ef_construction == 1000 assert param.quantize_type == QuantizeType.FP16 @pytest.mark.parametrize( "attr", ["metric_type", "m", "ef_construction", "quantize_type"] ) def test_readonly_attributes(self, attr): param = HnswIndexParam() import sys if sys.version_info >= (3, 11): match_pattern = r"(can't set attribute|has no setter|readonly attribute)" else: match_pattern = r"can't set attribute" with pytest.raises(AttributeError, match=match_pattern): setattr(param, attr, getattr(param, attr)) # ---------------------------- # Flat Index Param Test Case # ---------------------------- class TestFlatIndexParam: def test_default(self): param = FlatIndexParam() assert param.type == IndexType.FLAT assert param.quantize_type == QuantizeType.UNDEFINED assert param.metric_type == MetricType.IP def test_custom(self): param = FlatIndexParam( metric_type=MetricType.L2, quantize_type=QuantizeType.INT8 ) assert param.metric_type == MetricType.L2 assert param.quantize_type == QuantizeType.INT8 @pytest.mark.parametrize("attr", ["metric_type", "quantize_type"]) def test_readonly_attributes(self, attr): param = FlatIndexParam() import sys if sys.version_info >= (3, 11): match_pattern = r"(can't set attribute|has no setter|readonly attribute)" else: match_pattern = r"can't set attribute" with pytest.raises(AttributeError, match=match_pattern): setattr(param, attr, getattr(param, attr)) # ---------------------------- # Ivf Index Param Test Case # ---------------------------- class TestIVFIndexParam: def test_default(self): param = IVFIndexParam() assert param.metric_type == MetricType.IP assert param.n_list == 0 assert param.quantize_type == QuantizeType.UNDEFINED assert param.type == IndexType.IVF def test_custom(self): param = IVFIndexParam( 
metric_type=MetricType.L2, n_list=1000, quantize_type=QuantizeType.FP16 ) assert param.metric_type == MetricType.L2 assert param.n_list == 1000 assert param.quantize_type == QuantizeType.FP16 assert param.type == IndexType.IVF @pytest.mark.parametrize("attr", ["metric_type", "n_list", "quantize_type"]) def test_readonly_attributes(self, attr): param = IVFIndexParam() import sys if sys.version_info >= (3, 11): match_pattern = r"(can't set attribute|has no setter|readonly attribute)" else: match_pattern = r"can't set attribute" with pytest.raises(AttributeError, match=match_pattern): setattr(param, attr, getattr(param, attr)) # ---------------------------- # CollectionOption Test Case # ---------------------------- class TestCollectionOption: def test_default(self): option = CollectionOption() assert option is not None assert option.read_only == False assert option.enable_mmap == True def test_custom(self): option = CollectionOption(read_only=True, enable_mmap=False) assert option.read_only == True assert option.enable_mmap == False option = CollectionOption(read_only=False, enable_mmap=True) assert option.read_only == False assert option.enable_mmap == True @pytest.mark.parametrize("attr", ["read_only", "enable_mmap"]) def test_readonly_attributes(self, attr): param = CollectionOption() import sys if sys.version_info >= (3, 11): match_pattern = r"(can't set attribute|has no setter|readonly attribute)" else: match_pattern = r"can't set attribute" with pytest.raises(AttributeError, match=match_pattern): setattr(param, attr, getattr(param, attr)) # ---------------------------- # IndexOption Test Case # ---------------------------- class TestIndexOption: def test_default(self): option = IndexOption() assert option is not None assert option.concurrency == 0 def test_custom(self): option = IndexOption(concurrency=10) assert option.concurrency == 10 @pytest.mark.parametrize("attr", ["concurrency"]) def test_readonly_attributes(self, attr): param = IndexOption() import sys 
if sys.version_info >= (3, 11): match_pattern = r"(can't set attribute|has no setter|readonly attribute)" else: match_pattern = r"can't set attribute" with pytest.raises(AttributeError, match=match_pattern): setattr(param, attr, getattr(param, attr)) # ---------------------------- # AddColumnOption Test Case # ---------------------------- class TestAddColumnOption: def test_default(self): option = AddColumnOption() assert option is not None assert option.concurrency == 0 def test_custom(self): option = AddColumnOption(concurrency=10) assert option.concurrency == 10 @pytest.mark.parametrize("attr", ["concurrency"]) def test_readonly_attributes(self, attr): param = AddColumnOption() import sys if sys.version_info >= (3, 11): match_pattern = r"(can't set attribute|has no setter|readonly attribute)" else: match_pattern = r"can't set attribute" with pytest.raises(AttributeError, match=match_pattern): setattr(param, attr, getattr(param, attr)) # ---------------------------- # AlterColumnOption Test Case # ---------------------------- class TestAlterColumnOption: def test_default(self): option = AlterColumnOption() assert option is not None assert option.concurrency == 0 def test_custom(self): option = AlterColumnOption(concurrency=10) assert option.concurrency == 10 @pytest.mark.parametrize("attr", ["concurrency"]) def test_readonly_attributes(self, attr): param = AlterColumnOption() import sys if sys.version_info >= (3, 11): match_pattern = r"(can't set attribute|has no setter|readonly attribute)" else: match_pattern = r"can't set attribute" with pytest.raises(AttributeError, match=match_pattern): setattr(param, attr, getattr(param, attr)) # ---------------------------- # OptimizeOption Test Case # ---------------------------- class TestOptimizeOption: def test_default(self): option = OptimizeOption() assert option is not None assert option.concurrency == 0 def test_custom(self): option = OptimizeOption(concurrency=10) assert option.concurrency == 10 
@pytest.mark.parametrize("attr", ["concurrency"]) def test_readonly_attributes(self, attr): param = OptimizeOption() import sys if sys.version_info >= (3, 11): match_pattern = r"(can't set attribute|has no setter|readonly attribute)" else: match_pattern = r"can't set attribute" with pytest.raises(AttributeError, match=match_pattern): setattr(param, attr, getattr(param, attr)) # ---------------------------- # HnswQueryParam Test Case # ---------------------------- class TestHnswQueryParam: def test_default(self): param = HnswQueryParam() assert param is not None assert param.ef == 300 assert param.is_using_refiner == False assert param.radius == 0 assert param.is_linear == False def test_custom(self): param = HnswQueryParam(ef=10, is_using_refiner=True, radius=30, is_linear=True) assert param.ef == 10 assert param.is_using_refiner == True assert param.radius == 30 assert param.is_linear == True def test_readonly_attributes(self): param = HnswQueryParam() if sys.version_info >= (3, 11): match_pattern = r"(can't set attribute|has no setter|readonly attribute)" else: match_pattern = r"can't set attribute" with pytest.raises(AttributeError, match=match_pattern): param.ef = 10 param.is_using_refiner = True param.radius = 30 param.is_linear = True # # ---------------------------- # # IVFQueryParam Test Case # # ---------------------------- # class TestIVFQueryParam: # def test_default(self): # param = IVFQueryParam() # assert param is not None # assert param.nprobe == 10 # assert param.is_using_refiner == False # assert param.radius == 0 # assert param.is_linear == False # assert param.scale_factor == 10 # # def test_custom(self): # param = IVFQueryParam( # nprobe=20, # is_using_refiner=True, # radius=30, # is_linear=True, # scale_factor=40 # ) # assert param.nprobe == 20 # assert param.is_using_refiner == True # assert param.radius == 30 # assert param.is_linear == True # assert param.scale_factor == 40 class TestVectorQuery: def test_init_with_valid_id(self): vq = 
VectorQuery(field_name="embedding", id="doc123") assert vq.field_name == "embedding" assert vq.id == "doc123" assert vq.vector is None assert vq.param is None def test_init_with_valid_vector(self): vec = [0.1, 0.2, 0.3] param = HnswQueryParam(ef=300) vq = VectorQuery(field_name="embedding", vector=vec, param=param) assert vq.field_name == "embedding" assert vq.vector == vec assert vq.param == param def test_init_both_id_and_vector_raises_error(self): with pytest.raises(ValueError): VectorQuery(field_name="embedding", id="doc123", vector=[0.1])._validate() def test_init_without_field_name_raises_error(self): with pytest.raises(ValueError): VectorQuery(field_name=None)._validate() def test_has_id_returns_true_when_id_set(self): vq = VectorQuery(field_name="embedding", id="doc123") assert vq.has_id() def test_has_id_returns_false_when_no_id(self): vq = VectorQuery(field_name="embedding", vector=[0.1]) assert not vq.has_id() def test_has_vector_returns_true_with_non_empty_vector(self): vq = VectorQuery(field_name="embedding", vector=[0.1]) assert vq.has_vector() def test_validate_fails_on_both_id_and_vector(self): vq = VectorQuery(field_name="test", id="doc123", vector=[0.1]) with pytest.raises(ValueError): vq._validate() ================================================ FILE: python/tests/test_query_executor.py ================================================ # Copyright 2025-present the zvec project # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from __future__ import annotations from typing import Dict, Union from unittest.mock import MagicMock import numpy as np import math from _zvec.param import _VectorQuery import pytest from zvec.executor.query_executor import ( MultiVectorQueryExecutor, NoVectorQueryExecutor, QueryContext, QueryExecutor, QueryExecutorFactory, SingleVectorQueryExecutor, VectorQuery, ) from zvec import RrfReRanker, HnswQueryParam, CollectionSchema, VectorSchema, DataType # ---------------------------- # Mock Vector Schema # ---------------------------- class MockVectorSchema(VectorSchema): def __init__(self, name="test_vector"): self._name = name @property def name(self): return self._name def _get_object(self): return MagicMock() # ---------------------------- # Mock Collection Schema # ---------------------------- class MockCollectionSchema(CollectionSchema): def __init__(self, vectors=Union[VectorSchema, Dict[str, VectorSchema]]): self._vectors = ( [vectors] if not isinstance(vectors, Dict) else list(vectors.values()) ) @property def vectors(self): return self._vectors # ---------------------------- # VectorQuery Test Case # ---------------------------- class TestVectorQuery: def test_init(self): query = VectorQuery(field_name="test_field") assert query.field_name == "test_field" assert query.id is None assert query.vector is None assert query.param is None param = HnswQueryParam() query = VectorQuery( field_name="test_field", id="test_id", vector=[1, 2, 3], param=param ) assert query.field_name == "test_field" assert query.id == "test_id" assert query.vector == [1, 2, 3] assert query.param == param def test_has_id(self): query = VectorQuery(field_name="test_field") assert not query.has_id() query = VectorQuery(field_name="test_field", id="test_id") assert query.has_id() def test_has_vector(self): query = VectorQuery(field_name="test_field") assert not query.has_vector() query = VectorQuery(field_name="test_field", vector=[]) assert not query.has_vector() query = 
VectorQuery(field_name="test_field", vector=[1, 2, 3]) assert query.has_vector() def test_validate_dense_fp16_convert(self): v = _VectorQuery() schema = VectorSchema(name="test", data_type=DataType.VECTOR_FP16) vec = np.array([1.1, 2.1, 3.1], dtype=np.float16) v.set_vector(schema._get_object(), vec) ret = v.get_vector(schema._get_object()) assert np.array_equal(vec, ret) def test_validate_dense_fp32_convert(self): v = _VectorQuery() schema = VectorSchema(name="test", data_type=DataType.VECTOR_FP32) vec = np.array([1.1, 2.1, 3.1], dtype=np.float32) v.set_vector(schema._get_object(), vec) ret = v.get_vector(schema._get_object()) assert np.array_equal(vec, ret) def test_validate_dense_fp64_convert(self): v = _VectorQuery() schema = VectorSchema(name="test", data_type=DataType.VECTOR_FP64) vec = np.array([1.1, 2.1, 3.1], dtype=np.float64) v.set_vector(schema._get_object(), vec) ret = v.get_vector(schema._get_object()) assert np.array_equal(vec, ret) def test_validate_dense_int8_convert(self): v = _VectorQuery() schema = VectorSchema(name="test", data_type=DataType.VECTOR_INT8) vec = np.array([1, 2, 3], dtype=np.int8) v.set_vector(schema._get_object(), vec) ret = v.get_vector(schema._get_object()) assert np.array_equal(vec, ret) def test_validate_sparse_fp32_convert(self): v = _VectorQuery() schema = VectorSchema(name="test", data_type=DataType.SPARSE_VECTOR_FP32) vec = {1: 1.1, 2: 2.2, 3: 3.3} v.set_vector(schema._get_object(), vec) ret = v.get_vector(schema._get_object()) for k in vec.keys(): assert math.isclose(vec[k], ret[k], abs_tol=1e-6) def test_validate_sparse_fp16_convert(self): v = _VectorQuery() schema = VectorSchema(name="test", data_type=DataType.SPARSE_VECTOR_FP16) vec = {1: 1.1, 2: 2.2, 3: 3.3} v.set_vector(schema._get_object(), vec) ret = v.get_vector(schema._get_object()) for k in vec.keys(): assert math.isclose(np.float16(vec[k]), ret[k], abs_tol=1e-6) class TestQueryContext: def test_init(self): ctx = QueryContext(topk=10) assert ctx.topk == 10 assert 
ctx.queries == [] assert ctx.filter is None assert ctx.reranker is None assert ctx.output_fields is None assert ctx.include_vector is False assert ctx.core_vectors == [] def test_properties(self): queries = [VectorQuery(field_name="test")] reranker = RrfReRanker() output_fields = ["field1", "field2"] ctx = QueryContext( topk=5, filter="test_filter", include_vector=True, queries=queries, output_fields=output_fields, reranker=reranker, ) assert ctx.topk == 5 assert ctx.queries == queries assert ctx.filter == "test_filter" assert ctx.reranker == reranker assert ctx.output_fields == output_fields assert ctx.include_vector is True def test_core_vectors_setter(self): ctx = QueryContext(topk=10) core_vectors = [MagicMock()] ctx.core_vectors = core_vectors assert ctx.core_vectors == core_vectors class TestNoVectorQueryExecutor: def test_init(self): schema = MockCollectionSchema() executor = NoVectorQueryExecutor(schema) assert isinstance(executor, QueryExecutor) def test_do_validate_with_queries(self): schema = MockCollectionSchema() executor = NoVectorQueryExecutor(schema) ctx = QueryContext(topk=10, queries=[VectorQuery(field_name="test")]) with pytest.raises( ValueError, match="Collection does not support query with vector or id" ): executor._do_validate(ctx) def test_do_validate_without_queries(self): schema = MockCollectionSchema() executor = NoVectorQueryExecutor(schema) ctx = QueryContext(topk=10) executor._do_validate(ctx) def test_do_build(self): schema = MockCollectionSchema() executor = NoVectorQueryExecutor(schema) ctx = QueryContext(topk=5, filter="test_filter") result = executor._do_build(ctx, MagicMock()) assert len(result) == 1 assert result[0].topk == 5 assert result[0].filter == "test_filter" class TestSingleVectorQueryExecutor: def test_init(self): schema = MockCollectionSchema() executor = SingleVectorQueryExecutor(schema) assert isinstance(executor, NoVectorQueryExecutor) def test_do_validate_multiple_queries(self): schema = MockCollectionSchema() 
executor = SingleVectorQueryExecutor(schema) queries = [VectorQuery(field_name="test1"), VectorQuery(field_name="test2")] ctx = QueryContext(topk=10, queries=queries) with pytest.raises( ValueError, match="Collection has only one vector field, cannot query with multiple vectors", ): executor._do_validate(ctx) def test_do_build_without_queries(self): schema = MockCollectionSchema() executor = SingleVectorQueryExecutor(schema) ctx = QueryContext(topk=5) result = executor._do_build(ctx, MagicMock()) assert len(result) == 1 assert result[0].topk == 5 class TestMultiVectorQueryExecutor: def test_init(self): schema = MockCollectionSchema() executor = MultiVectorQueryExecutor(schema) assert isinstance(executor, SingleVectorQueryExecutor) def test_do_validate_multiple_queries_without_reranker(self): schema = MockCollectionSchema() executor = MultiVectorQueryExecutor(schema) queries = [VectorQuery(field_name="test1"), VectorQuery(field_name="test2")] ctx = QueryContext(topk=10, queries=queries) with pytest.raises( ValueError, match="Reranker is required for multi-vector query" ): executor._do_validate(ctx) def test_do_validate_multiple_queries_with_reranker(self): schema = MockCollectionSchema() executor = MultiVectorQueryExecutor(schema) queries = [VectorQuery(field_name="test1"), VectorQuery(field_name="test2")] reranker = RrfReRanker() ctx = QueryContext(topk=10, queries=queries, reranker=reranker) executor._do_validate(ctx) class TestQueryExecutorFactory: def test_create_no_vectors(self): schema = MockCollectionSchema() executor = QueryExecutorFactory.create(schema) assert isinstance(executor, NoVectorQueryExecutor) def test_create_single_vector(self): schema = MockCollectionSchema(vectors=MockVectorSchema()) executor = QueryExecutorFactory.create(schema) assert isinstance(executor, SingleVectorQueryExecutor) def test_create_multiple_vectors(self): schema = MockCollectionSchema( vectors={"test1": MockVectorSchema(), "test2": MockVectorSchema()} ) executor = 
QueryExecutorFactory.create(schema) assert isinstance(executor, MultiVectorQueryExecutor) ================================================ FILE: python/tests/test_reranker.py ================================================ # Copyright 2025-present the zvec project # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations from unittest.mock import patch, MagicMock import pytest import math import os from zvec import Doc, MetricType from zvec.extension.multi_vector_reranker import ( RrfReRanker, WeightedReRanker, ) from zvec.extension.sentence_transformer_rerank_function import ( DefaultLocalReRanker, ) from zvec.extension.qwen_rerank_function import QwenReRanker # Set ZVEC_RUN_INTEGRATION_TESTS=1 to run real API tests RUN_INTEGRATION_TESTS = os.environ.get("ZVEC_RUN_INTEGRATION_TESTS", "0") == "1" # ---------------------------- # RrfRanker Test Case # ---------------------------- class TestRrfReRanker: def test_init(self): reranker = RrfReRanker(topn=5, rerank_field="content", rank_constant=100) assert reranker.topn == 5 assert reranker.rerank_field == "content" assert reranker.rank_constant == 100 def test_rrf_score(self): reranker = RrfReRanker(rank_constant=60) # 根据公式 1.0 / (k + rank + 1),其中k=60 assert reranker._rrf_score(0) == 1.0 / (60 + 0 + 1) assert reranker._rrf_score(1) == 1.0 / (60 + 1 + 1) assert reranker._rrf_score(10) == 1.0 / (60 + 10 + 1) def test_rerank(self): reranker = RrfReRanker(topn=3) doc1 = Doc(id="1", score=0.8) doc2 = Doc(id="2", 
score=0.7) doc3 = Doc(id="3", score=0.9) doc4 = Doc(id="4", score=0.6) query_results = {"vector1": [doc1, doc2, doc3], "vector2": [doc3, doc1, doc4]} results = reranker.rerank(query_results) assert len(results) <= reranker.topn for doc in results: assert hasattr(doc, "score") scores = [doc.score for doc in results] assert scores == sorted(scores, reverse=True) # ---------------------------- # WeightedRanker Test Case # ---------------------------- class TestWeightedReRanker: def test_init(self): weights = {"vector1": 0.7, "vector2": 0.3} reranker = WeightedReRanker( topn=5, rerank_field="content", metric=MetricType.L2, weights=weights, ) assert reranker.topn == 5 assert reranker.rerank_field == "content" assert reranker.metric == MetricType.L2 assert reranker.weights == weights def test_normalize_score(self): reranker = WeightedReRanker() score = reranker._normalize_score(1.0, MetricType.L2) expected = 1.0 - 2 * math.atan(1.0) / math.pi assert score == expected score = reranker._normalize_score(1.0, MetricType.IP) expected = 0.5 + math.atan(1.0) / math.pi assert score == expected score = reranker._normalize_score(1.0, MetricType.COSINE) expected = 1.0 - 1.0 / 2.0 assert score == expected with pytest.raises(ValueError, match="Unsupported metric type"): reranker._normalize_score(1.0, "unsupported_metric") def test_rerank(self): weights = {"vector1": 0.7, "vector2": 0.3} reranker = WeightedReRanker(topn=3, weights=weights, metric=MetricType.L2) doc1 = Doc(id="1", score=0.8) doc2 = Doc(id="2", score=0.7) doc3 = Doc(id="3", score=0.9) query_results = {"vector1": [doc1, doc2], "vector2": [doc2, doc3]} results = reranker.rerank(query_results) assert len(results) <= reranker.topn for doc in results: assert hasattr(doc, "score") scores = [doc.score for doc in results] assert scores == sorted(scores, reverse=True) # ---------------------------- # QwenReRanker Test Case # ---------------------------- class TestQwenReRanker: def test_init_without_query(self): with 
pytest.raises(ValueError, match="Query is required for QwenReRanker"): QwenReRanker(api_key="test_key") def test_init_without_api_key(self): with patch.dict(os.environ, {}, clear=True): with pytest.raises(ValueError, match="DashScope API key is required"): QwenReRanker(query="test") @patch.dict(os.environ, {"DASHSCOPE_API_KEY": "test_key"}) def test_init_with_env_api_key(self): reranker = QwenReRanker(query="test", rerank_field="content") assert reranker.query == "test" assert reranker._api_key == "test_key" assert reranker.rerank_field == "content" def test_init_with_explicit_api_key(self): reranker = QwenReRanker( query="test", api_key="explicit_key", rerank_field="content" ) assert reranker.query == "test" assert reranker._api_key == "explicit_key" def test_model_property(self): reranker = QwenReRanker( query="test", api_key="test_key", rerank_field="content" ) assert reranker.model == "gte-rerank-v2" reranker = QwenReRanker( query="test", model="custom-model", api_key="test_key", rerank_field="content", ) assert reranker.model == "custom-model" def test_query_property(self): reranker = QwenReRanker( query="test query", api_key="test_key", rerank_field="content" ) assert reranker.query == "test query" def test_topn_property(self): reranker = QwenReRanker( query="test", topn=5, api_key="test_key", rerank_field="content" ) assert reranker.topn == 5 def test_rerank_field_property(self): reranker = QwenReRanker(query="test", api_key="test_key", rerank_field="title") assert reranker.rerank_field == "title" def test_rerank_empty_results(self): reranker = QwenReRanker( query="test", api_key="test_key", rerank_field="content" ) results = reranker.rerank({}) assert results == [] def test_rerank_no_valid_documents(self): reranker = QwenReRanker( query="test", api_key="test_key", rerank_field="content" ) # Document without the rerank_field query_results = {"vector1": [Doc(id="1")]} with pytest.raises(ValueError, match="No documents to rerank"): 
reranker.rerank(query_results) def test_rerank_skip_empty_content(self): reranker = QwenReRanker( query="test", api_key="test_key", rerank_field="content" ) query_results = { "vector1": [ Doc(id="1", fields={"content": ""}), Doc(id="2", fields={"content": " "}), ] } with pytest.raises(ValueError, match="No documents to rerank"): reranker.rerank(query_results) @patch("zvec.extension.qwen_function.require_module") def test_rerank_success(self, mock_require_module): # Mock dashscope module mock_dashscope = MagicMock() mock_require_module.return_value = mock_dashscope # Mock API response mock_response = MagicMock() mock_response.status_code = 200 mock_response.output = { "results": [ {"index": 0, "relevance_score": 0.95}, {"index": 1, "relevance_score": 0.85}, ] } mock_dashscope.TextReRank.call.return_value = mock_response reranker = QwenReRanker( query="test query", topn=2, api_key="test_key", rerank_field="content" ) query_results = { "vector1": [ Doc(id="1", fields={"content": "Document 1"}), Doc(id="2", fields={"content": "Document 2"}), ] } results = reranker.rerank(query_results) assert len(results) == 2 assert results[0].id == "1" assert results[0].score == 0.95 assert results[1].id == "2" assert results[1].score == 0.85 # Verify API call mock_dashscope.TextReRank.call.assert_called_once_with( model="gte-rerank-v2", query="test query", documents=["Document 1", "Document 2"], top_n=2, return_documents=False, ) @patch("zvec.extension.qwen_function.require_module") def test_rerank_deduplicate_documents(self, mock_require_module): # Mock dashscope module mock_dashscope = MagicMock() mock_require_module.return_value = mock_dashscope # Mock API response mock_response = MagicMock() mock_response.status_code = 200 mock_response.output = { "results": [ {"index": 0, "relevance_score": 0.9}, ] } mock_dashscope.TextReRank.call.return_value = mock_response reranker = QwenReRanker( query="test", topn=5, api_key="test_key", rerank_field="content" ) # Same document in multiple 
vector results doc1 = Doc(id="1", fields={"content": "Document 1"}) query_results = {"vector1": [doc1], "vector2": [doc1]} results = reranker.rerank(query_results) # Should only call API with document once call_args = mock_dashscope.TextReRank.call.call_args assert len(call_args[1]["documents"]) == 1 @patch("zvec.extension.qwen_function.require_module") def test_rerank_api_error(self, mock_require_module): # Mock dashscope module mock_dashscope = MagicMock() mock_require_module.return_value = mock_dashscope # Mock API error response mock_response = MagicMock() mock_response.status_code = 400 mock_response.message = "Invalid request" mock_response.code = "InvalidParameter" mock_dashscope.TextReRank.call.return_value = mock_response reranker = QwenReRanker( query="test", api_key="test_key", rerank_field="content" ) query_results = {"vector1": [Doc(id="1", fields={"content": "Document 1"})]} with pytest.raises(ValueError, match="DashScope API error"): reranker.rerank(query_results) @patch("zvec.extension.qwen_function.require_module") def test_rerank_runtime_error(self, mock_require_module): # Mock dashscope module that raises exception mock_dashscope = MagicMock() mock_require_module.return_value = mock_dashscope mock_dashscope.TextReRank.call.side_effect = Exception("Network error") reranker = QwenReRanker( query="test", api_key="test_key", rerank_field="content" ) query_results = {"vector1": [Doc(id="1", fields={"content": "Document 1"})]} with pytest.raises(RuntimeError, match="Failed to call DashScope API"): reranker.rerank(query_results) @pytest.mark.skipif( not RUN_INTEGRATION_TESTS, reason="Integration test skipped. Set ZVEC_RUN_INTEGRATION_TESTS=1 to run.", ) def test_real_qwen_rerank(self): """Integration test with real DashScope TextReRank API. 
To run this test, set environment variables: export ZVEC_RUN_INTEGRATION_TESTS=1 export DASHSCOPE_API_KEY=your-api-key """ # Create reranker with real API reranker = QwenReRanker( query="What is machine learning?", topn=3, rerank_field="content", model="gte-rerank-v2", ) # Prepare test documents query_results = { "vector1": [ Doc( id="1", score=0.8, fields={ "content": "Machine learning is a subset of artificial intelligence that focuses on building systems that can learn from data." }, ), Doc( id="2", score=0.7, fields={ "content": "The weather is nice today with clear skies and sunshine." }, ), Doc( id="3", score=0.75, fields={ "content": "Deep learning is a specialized branch of machine learning using neural networks with multiple layers." }, ), ], "vector2": [ Doc( id="4", score=0.6, fields={ "content": "Python is a popular programming language for data science and machine learning applications." }, ), Doc( id="5", score=0.65, fields={ "content": "A recipe for chocolate cake includes flour, sugar, eggs, and cocoa powder." 
}, ), ], } # Call real API results = reranker.rerank(query_results) # Verify results assert len(results) <= 3, "Should return at most topn documents" assert len(results) > 0, "Should return at least one document" # All results should have valid scores for doc in results: assert hasattr(doc, "score"), "Each document should have a score" assert isinstance(doc.score, (int, float)), "Score should be numeric" assert doc.score > 0, "Score should be positive" # Verify scores are in descending order scores = [doc.score for doc in results] assert scores == sorted(scores, reverse=True), ( "Results should be sorted by score in descending order" ) # Verify relevant documents are ranked higher # Document 1 and 3 are about machine learning, should rank higher than weather/recipe docs result_ids = [doc.id for doc in results] # At least one of the ML-related documents should be in top results ml_related_docs = {"1", "3", "4"} assert any(doc_id in ml_related_docs for doc_id in result_ids[:2]), ( "ML-related documents should rank higher" ) # Print results for manual verification (useful during development) print("\nReranking results:") for i, doc in enumerate(results, 1): print(f"{i}. 
ID={doc.id}, Score={doc.score:.4f}") if doc.fields: content = doc.field("content") if content: print(f" Content: {content[:80]}...") # ---------------------------- # DefaultLocalReRanker Test Case # ---------------------------- class TestDefaultLocalReRanker: """Test cases for DefaultLocalReRanker.""" def test_init_without_query(self): """Test initialization fails without query.""" with pytest.raises( ValueError, match="Query is required for DefaultLocalReRanker" ): DefaultLocalReRanker(rerank_field="content") def test_init_with_empty_query(self): """Test initialization fails with empty query.""" with pytest.raises( ValueError, match="Query is required for DefaultLocalReRanker" ): DefaultLocalReRanker(query="", rerank_field="content") @patch("zvec.extension.sentence_transformer_rerank_function.require_module") def test_init_success(self, mock_require_module): """Test successful initialization with mocked model.""" # Mock sentence_transformers module mock_st = MagicMock() mock_model = MagicMock() mock_model.predict = MagicMock() # Cross-encoder has predict method mock_model.device = "cpu" mock_st.CrossEncoder.return_value = mock_model mock_require_module.return_value = mock_st reranker = DefaultLocalReRanker( query="test query", topn=5, rerank_field="content", model_name="cross-encoder/ms-marco-MiniLM-L6-v2", ) assert reranker.query == "test query" assert reranker.topn == 5 assert reranker.rerank_field == "content" assert reranker.model_name == "cross-encoder/ms-marco-MiniLM-L6-v2" assert reranker.model_source == "huggingface" assert reranker.batch_size == 32 @pytest.mark.skipif( not RUN_INTEGRATION_TESTS, reason="Integration test skipped. 
Set ZVEC_RUN_INTEGRATION_TESTS=1 to run.", ) @patch("zvec.extension.sentence_transformer_rerank_function.require_module") def test_init_with_custom_params(self, mock_require_module): """Test initialization with custom parameters.""" mock_st = MagicMock() mock_model = MagicMock() mock_model.predict = MagicMock() mock_model.device = "cuda" mock_st.CrossEncoder.return_value = mock_model mock_require_module.return_value = mock_st reranker = DefaultLocalReRanker( query="custom query", topn=10, rerank_field="title", model_name="cross-encoder/ms-marco-MiniLM-L12-v2", model_source="modelscope", device="cuda", batch_size=64, ) assert reranker.query == "custom query" assert reranker.topn == 10 assert reranker.rerank_field == "title" assert reranker.model_name == "cross-encoder/ms-marco-MiniLM-L12-v2" assert reranker.model_source == "modelscope" assert reranker.batch_size == 64 @patch("zvec.extension.sentence_transformer_rerank_function.require_module") def test_init_invalid_model(self, mock_require_module): """Test initialization fails with non-cross-encoder model.""" # Mock a model without predict method (not a cross-encoder) mock_st = MagicMock() mock_model = MagicMock(spec=[]) # No predict method mock_st.CrossEncoder.return_value = mock_model mock_require_module.return_value = mock_st with pytest.raises(ValueError, match="does not appear to be a cross-encoder"): DefaultLocalReRanker(query="test", rerank_field="content") def test_query_property(self): """Test query property.""" mock_model = MagicMock() mock_model.predict = MagicMock() mock_st = MagicMock() mock_st.CrossEncoder.return_value = mock_model with patch( "zvec.extension.sentence_transformer_rerank_function.require_module", return_value=mock_st, ): reranker = DefaultLocalReRanker(query="test query", rerank_field="content") assert reranker.query == "test query" def test_topn_property(self): """Test topn property.""" mock_model = MagicMock() mock_model.predict = MagicMock() mock_st = MagicMock() 
mock_st.CrossEncoder.return_value = mock_model with patch( "zvec.extension.sentence_transformer_rerank_function.require_module", return_value=mock_st, ): reranker = DefaultLocalReRanker( query="test", topn=15, rerank_field="content" ) assert reranker.topn == 15 def test_rerank_field_property(self): """Test rerank_field property.""" mock_model = MagicMock() mock_model.predict = MagicMock() mock_st = MagicMock() mock_st.CrossEncoder.return_value = mock_model with patch( "zvec.extension.sentence_transformer_rerank_function.require_module", return_value=mock_st, ): reranker = DefaultLocalReRanker(query="test", rerank_field="title") assert reranker.rerank_field == "title" def test_batch_size_property(self): """Test batch_size property.""" mock_model = MagicMock() mock_model.predict = MagicMock() mock_st = MagicMock() mock_st.CrossEncoder.return_value = mock_model with patch( "zvec.extension.sentence_transformer_rerank_function.require_module", return_value=mock_st, ): reranker = DefaultLocalReRanker( query="test", rerank_field="content", batch_size=128 ) assert reranker.batch_size == 128 def test_rerank_empty_results(self): """Test rerank with empty query_results.""" mock_model = MagicMock() mock_model.predict = MagicMock() mock_st = MagicMock() mock_st.CrossEncoder.return_value = mock_model with patch( "zvec.extension.sentence_transformer_rerank_function.require_module", return_value=mock_st, ): reranker = DefaultLocalReRanker(query="test", rerank_field="content") results = reranker.rerank({}) assert results == [] def test_rerank_no_valid_documents(self): """Test rerank with documents missing rerank_field.""" mock_model = MagicMock() mock_model.predict = MagicMock() mock_st = MagicMock() mock_st.CrossEncoder.return_value = mock_model with patch( "zvec.extension.sentence_transformer_rerank_function.require_module", return_value=mock_st, ): reranker = DefaultLocalReRanker(query="test", rerank_field="content") # Document without the rerank_field query_results = 
{"vector1": [Doc(id="1")]} with pytest.raises(ValueError, match="No documents to rerank"): reranker.rerank(query_results) def test_rerank_skip_empty_content(self): """Test rerank skips documents with empty content.""" mock_model = MagicMock() mock_model.predict = MagicMock() mock_st = MagicMock() mock_st.CrossEncoder.return_value = mock_model with patch( "zvec.extension.sentence_transformer_rerank_function.require_module", return_value=mock_st, ): reranker = DefaultLocalReRanker(query="test", rerank_field="content") query_results = { "vector1": [ Doc(id="1", fields={"content": ""}), Doc(id="2", fields={"content": " "}), ] } with pytest.raises(ValueError, match="No documents to rerank"): reranker.rerank(query_results) def test_rerank_success(self): """Test successful rerank with mocked model.""" # Mock standard cross-encoder model mock_model = MagicMock() # Mock predict method to return scores import numpy as np mock_scores = np.array([0.95, 0.85, 0.75]) mock_model.predict.return_value = mock_scores mock_model.device = "cpu" # Mock sentence_transformers module mock_st = MagicMock() mock_st.CrossEncoder.return_value = mock_model with patch( "zvec.extension.sentence_transformer_rerank_function.require_module", return_value=mock_st, ): reranker = DefaultLocalReRanker( query="test query", topn=3, rerank_field="content" ) query_results = { "vector1": [ Doc(id="1", score=0.8, fields={"content": "Document 1"}), Doc(id="2", score=0.7, fields={"content": "Document 2"}), Doc(id="3", score=0.6, fields={"content": "Document 3"}), ] } results = reranker.rerank(query_results) # Verify results assert len(results) == 3 assert results[0].id == "1" assert results[0].score == 0.95 assert results[1].id == "2" assert results[1].score == 0.85 assert results[2].id == "3" assert results[2].score == 0.75 # Verify model.predict was called correctly assert mock_model.predict.called call_args = mock_model.predict.call_args pairs = call_args[0][0] assert len(pairs) == 3 assert pairs[0] == 
["test query", "Document 1"] assert pairs[1] == ["test query", "Document 2"] assert pairs[2] == ["test query", "Document 3"] assert call_args[1]["batch_size"] == 32 assert call_args[1]["show_progress_bar"] is False def test_rerank_with_topn_limit(self): """Test rerank respects topn limit.""" mock_model = MagicMock() import numpy as np mock_scores = np.array([0.9, 0.8, 0.7, 0.6, 0.5]) mock_model.predict.return_value = mock_scores # Mock sentence_transformers module mock_st = MagicMock() mock_st.CrossEncoder.return_value = mock_model with patch( "zvec.extension.sentence_transformer_rerank_function.require_module", return_value=mock_st, ): reranker = DefaultLocalReRanker( query="test", topn=2, rerank_field="content" ) query_results = { "vector1": [ Doc(id="1", fields={"content": "Doc 1"}), Doc(id="2", fields={"content": "Doc 2"}), Doc(id="3", fields={"content": "Doc 3"}), Doc(id="4", fields={"content": "Doc 4"}), Doc(id="5", fields={"content": "Doc 5"}), ] } results = reranker.rerank(query_results) # Should only return top 2 assert len(results) == 2 assert results[0].id == "1" assert results[0].score == 0.9 assert results[1].id == "2" assert results[1].score == 0.8 def test_rerank_deduplicate_documents(self): """Test rerank deduplicates documents across multiple vectors.""" mock_model = MagicMock() import numpy as np mock_scores = np.array([0.95, 0.85]) mock_model.predict.return_value = mock_scores # Mock sentence_transformers module mock_st = MagicMock() mock_st.CrossEncoder.return_value = mock_model with patch( "zvec.extension.sentence_transformer_rerank_function.require_module", return_value=mock_st, ): reranker = DefaultLocalReRanker( query="test", topn=5, rerank_field="content" ) # Same document in multiple vector results doc1 = Doc(id="1", fields={"content": "Document 1"}) doc2 = Doc(id="2", fields={"content": "Document 2"}) query_results = { "vector1": [doc1, doc2], "vector2": [doc1], # doc1 appears in both } results = reranker.rerank(query_results) # Should 
only process each document once assert len(results) == 2 assert mock_model.predict.call_count == 1 call_args = mock_model.predict.call_args pairs = call_args[0][0] assert len(pairs) == 2 # Only 2 unique documents def test_rerank_sorting(self): """Test rerank sorts documents by score in descending order.""" mock_model = MagicMock() import numpy as np # Return scores in non-sorted order mock_scores = np.array([0.6, 0.9, 0.7]) mock_model.predict.return_value = mock_scores # Mock sentence_transformers module mock_st = MagicMock() mock_st.CrossEncoder.return_value = mock_model with patch( "zvec.extension.sentence_transformer_rerank_function.require_module", return_value=mock_st, ): reranker = DefaultLocalReRanker( query="test", topn=3, rerank_field="content" ) query_results = { "vector1": [ Doc(id="1", fields={"content": "Doc 1"}), Doc(id="2", fields={"content": "Doc 2"}), Doc(id="3", fields={"content": "Doc 3"}), ] } results = reranker.rerank(query_results) # Should be sorted by score (descending) assert len(results) == 3 assert results[0].id == "2" # score 0.9 assert results[0].score == 0.9 assert results[1].id == "3" # score 0.7 assert results[1].score == 0.7 assert results[2].id == "1" # score 0.6 assert results[2].score == 0.6 def test_rerank_model_error(self): """Test rerank handles model prediction errors.""" mock_model = MagicMock() # Mock predict to raise exception mock_model.predict.side_effect = Exception("Model inference error") # Mock sentence_transformers module mock_st = MagicMock() mock_st.CrossEncoder.return_value = mock_model with patch( "zvec.extension.sentence_transformer_rerank_function.require_module", return_value=mock_st, ): reranker = DefaultLocalReRanker(query="test", rerank_field="content") query_results = {"vector1": [Doc(id="1", fields={"content": "Document 1"})]} with pytest.raises(RuntimeError, match="Failed to compute rerank scores"): reranker.rerank(query_results) def test_rerank_with_custom_batch_size(self): """Test rerank uses custom 
batch_size.""" mock_model = MagicMock() import numpy as np mock_scores = np.array([0.9, 0.8]) mock_model.predict.return_value = mock_scores # Mock sentence_transformers module mock_st = MagicMock() mock_st.CrossEncoder.return_value = mock_model with patch( "zvec.extension.sentence_transformer_rerank_function.require_module", return_value=mock_st, ): reranker = DefaultLocalReRanker( query="test", rerank_field="content", batch_size=64 ) query_results = { "vector1": [ Doc(id="1", fields={"content": "Doc 1"}), Doc(id="2", fields={"content": "Doc 2"}), ] } reranker.rerank(query_results) # Verify batch_size is passed to predict call_args = mock_model.predict.call_args assert call_args[1]["batch_size"] == 64 @pytest.mark.skipif( not RUN_INTEGRATION_TESTS, reason="Integration test skipped. Set ZVEC_RUN_INTEGRATION_TESTS=1 to run.", ) def test_real_sentence_transformer_rerank(self): """Integration test with real SentenceTransformer cross-encoder model. To run this test, set environment variable: export ZVEC_RUN_INTEGRATION_TESTS=1 Note: This test requires sentence-transformers package and will download the MS MARCO MiniLM model (~80MB) on first run. """ # Create reranker with real model (using default lightweight model) reranker = DefaultLocalReRanker( query="What is machine learning?", topn=3, rerank_field="content", ) # Prepare test documents query_results = { "vector1": [ Doc( id="1", score=0.8, fields={ "content": "Machine learning is a subset of artificial intelligence that focuses on building systems that can learn from data." }, ), Doc( id="2", score=0.7, fields={ "content": "The weather is nice today with clear skies and sunshine." }, ), Doc( id="3", score=0.75, fields={ "content": "Deep learning is a specialized branch of machine learning using neural networks with multiple layers." }, ), ], "vector2": [ Doc( id="4", score=0.6, fields={ "content": "Python is a popular programming language for data science and machine learning applications." 
}, ), Doc( id="5", score=0.65, fields={ "content": "A recipe for chocolate cake includes flour, sugar, eggs, and cocoa powder." }, ), ], } # Call real model results = reranker.rerank(query_results) # Verify results assert len(results) <= 3, "Should return at most topn documents" assert len(results) > 0, "Should return at least one document" # All results should have valid scores for doc in results: assert hasattr(doc, "score"), "Each document should have a score" assert isinstance(doc.score, (int, float)), "Score should be numeric" # Verify scores are in descending order scores = [doc.score for doc in results] assert scores == sorted(scores, reverse=True), ( "Results should be sorted by score in descending order" ) # Verify relevant documents are ranked higher # Documents 1, 3, and 4 are about machine learning, should rank higher result_ids = [doc.id for doc in results] # At least one of the ML-related documents should be in top results ml_related_docs = {"1", "3", "4"} assert any(doc_id in ml_related_docs for doc_id in result_ids[:2]), ( "ML-related documents should rank higher" ) # Print results for manual verification (useful during development) print("\nSentenceTransformer Reranking results:") for i, doc in enumerate(results, 1): print(f"{i}. ID={doc.id}, Score={doc.score:.4f}") if doc.fields: content = doc.field("content") if content: print(f" Content: {content[:80]}...") ================================================ FILE: python/tests/test_schema.py ================================================ # Copyright 2025-present the zvec project # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations import pytest from zvec import ( CollectionSchema, CollectionStats, FieldSchema, VectorSchema, HnswIndexParam, InvertIndexParam, DataType, IndexType, MetricType, ) # ---------------------------- # FieldSchema Test Case # ---------------------------- class TestFieldSchema: def test_default(self): field = FieldSchema("field", data_type=DataType.FLOAT) assert field.name == "field" assert field.data_type == DataType.FLOAT assert field.nullable is False assert field.index_param is None def test_custom(self): field_1 = FieldSchema( name="float", data_type=DataType.FLOAT, nullable=True, index_param=InvertIndexParam(), ) assert field_1.name == "float" assert field_1.data_type == DataType.FLOAT assert field_1.nullable is True assert field_1.index_param.enable_range_optimization is False field_2 = FieldSchema( name="str", data_type=DataType.STRING, nullable=True, index_param=InvertIndexParam(enable_range_optimization=True), ) assert field_2.name == "str" assert field_2.data_type == DataType.STRING assert field_2.nullable is True assert field_2.index_param.enable_range_optimization is True def test_readonly(self): field = FieldSchema( name="float", data_type=DataType.FLOAT, nullable=True, index_param=InvertIndexParam(), ) import sys if sys.version_info >= (3, 11): match_pattern = r"(can't set attribute|has no setter|readonly attribute)" else: match_pattern = r"can't set attribute" with pytest.raises(AttributeError, match=match_pattern): field.index_param = InvertIndexParam(enable_range_optimization=True) # ---------------------------- # 
VectorSchema Test Case # ---------------------------- class TestVectorSchema: def test_default(self): field = VectorSchema("vector", data_type=DataType.VECTOR_FP32, dimension=128) assert field.name == "vector" assert field.data_type == DataType.VECTOR_FP32 assert field.dimension == 128 assert field.index_param is not None assert field.index_param.type == IndexType.FLAT assert field.index_param.metric_type == MetricType.IP def test_custom(self): field = VectorSchema( name="vector", data_type=DataType.VECTOR_INT8, dimension=512, index_param=HnswIndexParam( metric_type=MetricType.COSINE, m=15, ef_construction=300 ), ) assert field.name == "vector" assert field.data_type == DataType.VECTOR_INT8 assert field.index_param.metric_type == MetricType.COSINE assert field.index_param.m == 15 assert field.index_param.ef_construction == 300 def test_readonly(self): field = VectorSchema( name="vector", dimension=128, data_type=DataType.VECTOR_INT8, ) import sys if sys.version_info >= (3, 11): match_pattern = r"(can't set attribute|has no setter|readonly attribute)" else: match_pattern = r"can't set attribute" with pytest.raises(AttributeError, match=match_pattern): field.dimension = 4 # ---------------------------- # CollectionSchema Test Case # ---------------------------- class TestCollectionSchema: def test_collection_schema_with_single_field(self): collection_schema = CollectionSchema( name="test_collection", fields=FieldSchema( name="id", data_type=DataType.INT64, index_param=InvertIndexParam(), nullable=False, ), vectors=VectorSchema( name="vector", data_type=DataType.VECTOR_INT8, dimension=128, index_param=HnswIndexParam(), ), ) assert collection_schema is not None assert collection_schema.name == "test_collection" assert len(collection_schema.fields) == 1 assert len(collection_schema.vectors) == 1 field = collection_schema.field("id") assert field is not None assert field.name == "id" assert field.data_type == DataType.INT64 assert not field.nullable assert 
field.index_param.type == IndexType.INVERT assert not field.index_param.enable_range_optimization vector = collection_schema.vector("vector") assert vector is not None assert vector.name == "vector" assert vector.data_type == DataType.VECTOR_INT8 assert vector.dimension == 128 assert vector.index_param.type == IndexType.HNSW assert vector.index_param.m == 50 assert vector.index_param.ef_construction == 500 assert vector.index_param.metric_type == MetricType.IP def test_collection_schema_with_multi_fields(self): collection_schema = CollectionSchema( name="test_collection", fields=[ FieldSchema( "id", DataType.INT64, nullable=False, index_param=InvertIndexParam(enable_range_optimization=True), ), FieldSchema( "name", DataType.STRING, nullable=False, index_param=InvertIndexParam(), ), FieldSchema( "weight", DataType.INT32, nullable=True, ), ], vectors=[ VectorSchema( "dense", DataType.VECTOR_FP32, dimension=128, index_param=HnswIndexParam(), ), VectorSchema( "sparse", DataType.SPARSE_VECTOR_FP32, index_param=HnswIndexParam() ), ], ) assert collection_schema is not None assert collection_schema.name == "test_collection" assert len(collection_schema.fields) == 3 assert len(collection_schema.vectors) == 2 field_id = collection_schema.field("id") assert field_id is not None assert field_id.name == "id" assert field_id.data_type == DataType.INT64 assert not field_id.nullable assert field_id.index_param.type == IndexType.INVERT dense = collection_schema.vector("dense") assert dense is not None assert dense.name == "dense" assert dense.data_type == DataType.VECTOR_FP32 assert dense.dimension == 128 assert dense.index_param.type == IndexType.HNSW sparse = collection_schema.vector("sparse") assert sparse is not None assert sparse.name == "sparse" assert sparse.data_type == DataType.SPARSE_VECTOR_FP32 assert sparse.dimension == 0 assert sparse.index_param.type == IndexType.HNSW assert str(collection_schema) is not None # ---------------------------- # CollectionStats Test Case 
# ---------------------------- class TestCollectionStats: """ The constructor of CollectionStats is not provided. It can only be obtained through collection.stats() """ def test_collection_stats(self): stats = CollectionStats() assert stats is not None ================================================ FILE: python/tests/test_typing.py ================================================ # Copyright 2025-present the zvec project # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations import pytest from zvec import ( DataType, IndexType, MetricType, QuantizeType, Status, StatusCode, ) # ---------------------------- # Enum Test Case # ---------------------------- @pytest.mark.parametrize( "member, name", [ (DataType.FLOAT, "FLOAT"), (IndexType.HNSW, "HNSW"), (MetricType.COSINE, "COSINE"), (QuantizeType.INT8, "INT8"), (StatusCode.OK, "OK"), ], ) def test_enum_names(member, name): assert member.name == name @pytest.mark.parametrize( "member, value", [ (DataType.FLOAT, 8), (IndexType.HNSW, 1), (MetricType.COSINE, 3), (QuantizeType.INT8, 2), (StatusCode.OK, 0), ], ) def test_enum_values(member, value): assert member.value == value @pytest.mark.parametrize("member", ["L2", "IP", "COSINE"]) def test_metric_type_has_member(member): assert member in MetricType.__members__ @pytest.mark.parametrize( "member", [ "STRING", "BOOL", "INT32", "INT64", "FLOAT", "DOUBLE", "UINT32", "UINT64", "VECTOR_FP16", "VECTOR_FP32", "VECTOR_FP64", "VECTOR_INT8", "SPARSE_VECTOR_FP32", 
"SPARSE_VECTOR_FP16", "ARRAY_STRING", "ARRAY_INT32", "ARRAY_INT64", "ARRAY_FLOAT", "ARRAY_DOUBLE", "ARRAY_BOOL", "ARRAY_UINT32", "ARRAY_UINT64", ], ) def test_data_type_has_member(member): assert member in DataType.__members__ @pytest.mark.parametrize("member", ["HNSW", "IVF", "FLAT", "INVERT"]) def test_index_type_has_member(member): assert member in IndexType.__members__ @pytest.mark.parametrize("member", ["FP16", "INT8", "INT4", "UNDEFINED"]) def test_quantize_type_has_member(member): assert member in QuantizeType.__members__ @pytest.mark.parametrize( "member", [ "OK", "UNKNOWN", "NOT_FOUND", "ALREADY_EXISTS", "INVALID_ARGUMENT", "PERMISSION_DENIED", "FAILED_PRECONDITION", "RESOURCE_EXHAUSTED", "UNAVAILABLE", "INTERNAL_ERROR", "NOT_SUPPORTED", ], ) def test_status_code_has_member(member): assert member in StatusCode.__members__ # ---------------------------- # Status Test Case # ---------------------------- class TestStatus: def test_status_code(self): status = Status(StatusCode.OK) assert status.code() == StatusCode.OK def test_status_message(self): status = Status(StatusCode.OK, "OK") assert status.message() == "OK" status = Status(StatusCode.NOT_FOUND, "Not Found") assert status.message() == "Not Found" def test_status_ok(self): status = Status(StatusCode.OK) assert status.ok() ================================================ FILE: python/tests/test_util.py ================================================ # Copyright 2025-present the zvec project # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations from unittest.mock import MagicMock, patch import pytest from zvec import require_module # ---------------------------- # require_module func Test Case # ---------------------------- def test_require_module_success(): module = require_module("os") assert module is not None assert hasattr(module, "path") def test_require_module_with_submodule_success(): module = require_module("os.path") assert module is not None assert hasattr(module, "join") def test_require_module_import_error(): with pytest.raises(ImportError) as exc_info: require_module("nonexistent_module") exception_msg = str(exc_info.value) assert "Required package 'nonexistent_module' is not installed." in exception_msg def test_require_module_with_mitigation_import_error(): with pytest.raises(ImportError) as exc_info: require_module("nonexistent_module.submodule", mitigation="custom_package") exception_msg = str(exc_info.value) assert "Required package 'custom_package' is not installed." in exception_msg assert ( "Module 'nonexistent_module.submodule' is part of 'nonexistent_module'" in exception_msg ) assert "please pip install 'custom_package'." in exception_msg def test_require_module_submodule_import_error(): with pytest.raises(ImportError) as exc_info: require_module("os.nonexistent_submodule") exception_msg = str(exc_info.value) assert ( "Required package 'os.nonexistent_submodule' is not installed." in exception_msg ) assert "Module 'os.nonexistent_submodule' is part of 'os'" in exception_msg assert "please pip install 'os'." 
in exception_msg @patch("importlib.import_module") def test_require_module_wraps_original_exception(mock_import_module): original_exception = ImportError("Original error") mock_import_module.side_effect = original_exception with pytest.raises(ImportError) as exc_info: require_module("some_module") assert exc_info.value.__cause__ is original_exception @patch("importlib.import_module") def test_require_module_calls_importlib(mock_import_module): mock_module = MagicMock() mock_import_module.return_value = mock_module result = require_module("test_module") mock_import_module.assert_called_once_with("test_module") assert result is mock_module ================================================ FILE: python/zvec/__init__.py ================================================ # Copyright 2025-present the zvec project # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations import sys from typing import TYPE_CHECKING if TYPE_CHECKING: from importlib.metadata import PackageNotFoundError # ============================== # Public API — grouped by category # ============================== from . 
import model as model # —— Extensions —— from .extension import ( BM25EmbeddingFunction, DefaultLocalDenseEmbedding, DefaultLocalReRanker, DefaultLocalSparseEmbedding, DenseEmbeddingFunction, OpenAIDenseEmbedding, OpenAIFunctionBase, QwenDenseEmbedding, QwenFunctionBase, QwenReRanker, QwenSparseEmbedding, ReRanker, RrfReRanker, SentenceTransformerFunctionBase, SparseEmbeddingFunction, WeightedReRanker, ) # —— Typing —— from .model import param as param from .model import schema as schema # —— Core data structures —— from .model.collection import Collection from .model.doc import Doc # —— Query & index parameters —— from .model.param import ( AddColumnOption, AlterColumnOption, CollectionOption, FlatIndexParam, HnswIndexParam, HnswQueryParam, HnswRabitqIndexParam, HnswRabitqQueryParam, IndexOption, InvertIndexParam, IVFIndexParam, IVFQueryParam, OptimizeOption, ) from .model.param.vector_query import VectorQuery # —— Schema & field definitions —— from .model.schema import CollectionSchema, CollectionStats, FieldSchema, VectorSchema # —— tools —— from .tool import require_module from .typing import ( DataType, IndexType, MetricType, QuantizeType, Status, StatusCode, ) from .typing.enum import LogLevel, LogType # —— lifecycle —— from .zvec import create_and_open, init, open # ============================== # Public interface declaration # ============================== __all__ = [ # Zvec functions "create_and_open", "init", "open", # Core classes "Collection", "Doc", # Schema "CollectionSchema", "FieldSchema", "VectorSchema", "CollectionStats", # Parameters "VectorQuery", "InvertIndexParam", "HnswIndexParam", "HnswRabitqIndexParam", "HnswRabitqQueryParam", "FlatIndexParam", "IVFIndexParam", "CollectionOption", "IndexOption", "OptimizeOption", "AddColumnOption", "AlterColumnOption", "HnswQueryParam", "IVFQueryParam", # Extensions "DenseEmbeddingFunction", "SparseEmbeddingFunction", "QwenFunctionBase", "OpenAIFunctionBase", "SentenceTransformerFunctionBase", "ReRanker", 
"DefaultLocalDenseEmbedding", "DefaultLocalSparseEmbedding", "BM25EmbeddingFunction", "OpenAIDenseEmbedding", "QwenDenseEmbedding", "QwenSparseEmbedding", "RrfReRanker", "WeightedReRanker", "DefaultLocalReRanker", "QwenReRanker", # Typing "DataType", "MetricType", "QuantizeType", "IndexType", "LogLevel", "LogType", "Status", "StatusCode", # Tools "require_module", ] # ============================== # Version handling # ============================== __version__: str try: from importlib.metadata import version except ImportError: from importlib_metadata import version # Python < 3.8 try: __version__ = version("zvec") except Exception: __version__ = "unknown" ================================================ FILE: python/zvec/__init__.pyi ================================================ """ Zvec core module """ from __future__ import annotations import collections from . import typing from .extension import ReRanker, RrfReRanker, WeightedReRanker from .extension.embedding import DenseEmbeddingFunction from .model import param, schema from .model.collection import Collection from .model.doc import Doc from .model.param import ( AddColumnOption, AlterColumnOption, CollectionOption, FlatIndexParam, HnswIndexParam, HnswQueryParam, IndexOption, InvertIndexParam, IVFIndexParam, IVFQueryParam, OptimizeOption, ) from .model.param.vector_query import VectorQuery from .model.schema import CollectionSchema, CollectionStats, FieldSchema, VectorSchema from .tool import require_module from .typing import ( DataType, IndexType, MetricType, QuantizeType, Status, StatusCode, ) from .typing.enum import LogLevel, LogType from .zvec import create_and_open, init, open __all__: list = [ "AddColumnOption", "AlterColumnOption", "Collection", "CollectionOption", "CollectionSchema", "CollectionStats", "DataType", "DenseEmbeddingFunction", "DenseEmbeddingFunction", "Doc", "FieldSchema", "FlatIndexParam", "HnswIndexParam", "HnswQueryParam", "IVFIndexParam", "IVFQueryParam", "IndexOption", 
"IndexType", "InvertIndexParam", "LogLevel", "LogType", "MetricType", "OptimizeOption", "QuantizeType", "ReRanker", "ReRanker", "RrfReRanker", "Status", "StatusCode", "VectorQuery", "VectorSchema", "WeightedReRanker", "create_and_open", "init", "open", "require_module", ] class _Collection: @staticmethod def CreateAndOpen( arg0: str, arg1: schema._CollectionSchema, arg2: param.CollectionOption ) -> _Collection: ... @staticmethod def Open(arg0: str, arg1: param.CollectionOption) -> _Collection: ... def AddColumn( self, arg0: schema._FieldSchema, arg1: str, arg2: param.AddColumnOption, ) -> None: ... def AlterColumn( self, arg0: str, arg1: str, arg2: schema._FieldSchema, arg3: param.AlterColumnOption, ) -> None: ... def CreateIndex( self, arg0: str, arg1: param.IndexParam, arg2: param.IndexOption ) -> None: ... def Delete(self, arg0: collections.abc.Sequence[str]) -> list[typing.Status]: ... def DeleteByFilter(self, arg0: str) -> None: ... def Destroy(self) -> None: ... def DropColumn(self, arg0: str) -> None: ... def DropIndex(self, arg0: str) -> None: ... def Fetch(self, arg0: collections.abc.Sequence[str]) -> dict[str, _Doc]: ... def Flush(self) -> None: ... def GroupByQuery(self, arg0: ...) -> list[...]: ... def Insert(self, arg0: collections.abc.Sequence[_Doc]) -> list[typing.Status]: ... def Optimize(self, arg0: param.OptimizeOption) -> None: ... def Options(self) -> param.CollectionOption: ... def Path(self) -> str: ... def Query(self, arg0: param._VectorQuery) -> list[_Doc]: ... def Schema(self) -> schema._CollectionSchema: ... def Stats(self) -> schema.CollectionStats: ... def Update(self, arg0: collections.abc.Sequence[_Doc]) -> list[typing.Status]: ... def Upsert(self, arg0: collections.abc.Sequence[_Doc]) -> list[typing.Status]: ... def __getstate__(self) -> tuple: ... def __setstate__(self, arg0: tuple) -> None: ... class _Doc: def __getstate__(self) -> bytes: ... def __init__(self) -> None: ... def __setstate__(self, arg0: bytes) -> None: ... 
def field_names(self) -> list[str]: ... def get_any(self, arg0: str, arg1: typing.DataType) -> typing.Any: ... def has_field(self, arg0: str) -> bool: ... def pk(self) -> str: ... def score(self) -> float: ... def set_any(self, arg0: str, arg1: typing.DataType, arg2: typing.Any) -> bool: ... def set_pk(self, arg0: str) -> None: ... def set_score(self, arg0: typing.SupportsFloat) -> None: ... class _DocOp: """ Members: INSERT UPDATE DELETE UPSERT """ DELETE: typing.ClassVar[_DocOp] # value = <_DocOp.DELETE: 3> INSERT: typing.ClassVar[_DocOp] # value = <_DocOp.INSERT: 0> UPDATE: typing.ClassVar[_DocOp] # value = <_DocOp.UPDATE: 2> UPSERT: typing.ClassVar[_DocOp] # value = <_DocOp.UPSERT: 1> __members__: typing.ClassVar[ dict[str, _DocOp] ] # value = {'INSERT': <_DocOp.INSERT: 0>, 'UPDATE': <_DocOp.UPDATE: 2>, 'DELETE': <_DocOp.DELETE: 3>, 'UPSERT': <_DocOp.UPSERT: 1>} def __eq__(self, other: typing.Any) -> bool: ... def __getstate__(self) -> int: ... def __hash__(self) -> int: ... def __index__(self) -> int: ... def __init__(self, value: typing.SupportsInt) -> None: ... def __int__(self) -> int: ... def __ne__(self, other: typing.Any) -> bool: ... def __repr__(self) -> str: ... def __setstate__(self, state: typing.SupportsInt) -> None: ... def __str__(self) -> str: ... @property def name(self) -> str: ... @property def value(self) -> int: ... ================================================ FILE: python/zvec/common/__init__.py ================================================ # Copyright 2025-present the zvec project # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations from .constants import DenseVectorType, SparseVectorType, VectorType __all__ = ["DenseVectorType", "SparseVectorType", "VectorType"] ================================================ FILE: python/zvec/common/constants.py ================================================ # Copyright 2025-present the zvec project # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations from typing import Optional, TypeVar, Union import numpy as np # VectorType: DenseVectorType | SparseVectorType DenseVectorType = Union[list[float], list[int], np.ndarray] SparseVectorType = dict[int, float] VectorType = Optional[Union[DenseVectorType, SparseVectorType]] # Embeddable: Text | Image | Audio TEXT = str IMAGE = Union[str, bytes, np.ndarray] # file path, raw bytes, or numpy array AUDIO = Union[str, bytes, np.ndarray] # file path, raw bytes, or numpy array Embeddable = Optional[Union[TEXT, IMAGE, AUDIO]] # Multimodal Embeddable MD = TypeVar("MD", bound=Embeddable, contravariant=True) ================================================ FILE: python/zvec/executor/__init__.py ================================================ # Copyright 2025-present the zvec project # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations from .query_executor import ( QueryContext, QueryExecutor, QueryExecutorFactory, ) __all__ = [ "QueryContext", "QueryExecutor", "QueryExecutorFactory", ] ================================================ FILE: python/zvec/executor/query_executor.py ================================================ # Copyright 2025-present the zvec project # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from __future__ import annotations import os from abc import ABC, abstractmethod from concurrent.futures import ThreadPoolExecutor, as_completed from typing import Optional, Union, final import numpy as np from _zvec import _Collection from _zvec.param import _VectorQuery from ..extension import ReRanker, RrfReRanker, WeightedReRanker from ..model.convert import convert_to_py_doc from ..model.doc import Doc from ..model.param.vector_query import VectorQuery from ..model.schema import CollectionSchema from ..typing import DataType __all__ = [ "QueryContext", "QueryExecutor", "QueryExecutorFactory", ] DTYPE_MAP = { DataType.VECTOR_FP16.value: np.float16, DataType.VECTOR_FP32.value: np.float32, DataType.VECTOR_FP64.value: np.float64, DataType.VECTOR_INT8.value: np.int8, } def convert_to_numpy(vec: Union[list, np.ndarray], dtype: np.dtype) -> np.ndarray: if isinstance(vec, np.ndarray): if vec.dtype == dtype and vec.ndim == 1: return vec return np.asarray(vec, dtype=dtype).flatten() try: arr = np.asarray(vec, dtype=dtype) if arr.ndim != 1: arr = arr.flatten() return arr except (ValueError, TypeError) as e: raise TypeError( f"Cannot convert input to 1D numpy array with dtype={dtype}: {type(vec)}" ) from e class QueryContext: def __init__( self, topk: int, filter: Optional[str] = None, include_vector: bool = False, queries: Optional[list[VectorQuery]] = None, output_fields: Optional[list[str]] = None, reranker: Optional[ReRanker] = None, ): # query param self._filter = filter self._queries = queries or [] self._topk = topk self._include_vector = include_vector self._output_fields = output_fields # reranker self._reranker = reranker # core vectors self._core_vectors = [] @property def topk(self): return self._topk @property def queries(self): return self._queries @property def filter(self): return self._filter @property def reranker(self): return self._reranker @property def output_fields(self): return self._output_fields @property def include_vector(self): return 
self._include_vector @property def core_vectors(self): return self._core_vectors @core_vectors.setter def core_vectors(self, core_vectors: list[_VectorQuery]): self._core_vectors = core_vectors class QueryExecutor(ABC): def __init__(self, schema: CollectionSchema): self._schema = schema self._concurrency = max(1, int(os.getenv("ZVEC_QUERY_CONCURRENCY", "1"))) @abstractmethod def _do_validate(self, ctx: QueryContext) -> None: pass @abstractmethod def _do_build( self, ctx: QueryContext, collection: _Collection ) -> list[_VectorQuery]: pass def _do_build_query_wo_vector(self, ctx: QueryContext) -> _VectorQuery: core_vector = _VectorQuery() core_vector.topk = ctx.topk core_vector.include_vector = ctx.include_vector if ctx.filter: core_vector.filter = ctx.filter if ctx.output_fields: core_vector.output_fields = ctx.output_fields return core_vector def _do_build_query_with_vector( self, ctx: QueryContext, query: VectorQuery, collection: _Collection ) -> _VectorQuery: core_vector = self._do_build_query_wo_vector(ctx) core_vector.field_name = query.field_name if query.param: core_vector.query_params = query.param vector_schema = ( self._schema.vector(query.field_name) if query else self._schema.vectors[0] ) if vector_schema is None: raise ValueError("No vector field found") # set output_fields core_vector.output_fields = ctx.output_fields # set vector if query.has_vector(): vec_data = query.vector else: fetched = collection.Fetch([query.id]) doc = next(iter(fetched.values())) if not doc: return core_vector vec_data = doc.get_any(vector_schema.name, vector_schema.data_type) target_dtype = DTYPE_MAP.get(vector_schema.data_type.value) core_vector.set_vector( vector_schema._get_object(), convert_to_numpy(vec_data, target_dtype) if target_dtype else vec_data, ) return core_vector def _do_execute( self, vectors: list[_VectorQuery], collection: _Collection ) -> dict[str, list[Doc]]: query_cnt = len(vectors) if query_cnt == 0: raise ValueError("No query to execute") if 
len(vectors) == 1 or self._concurrency == 1: results = {} for query in vectors: docs = collection.Query(query) results[query.field_name] = [ convert_to_py_doc(doc, self._schema) for doc in docs ] return results results = {} with ThreadPoolExecutor(max_workers=self._concurrency) as executor: future_to_query = { executor.submit(collection.Query, query): query.field_name for query in vectors } for future in as_completed(future_to_query): field_name = future_to_query[future] try: docs = future.result() results[field_name] = [ convert_to_py_doc(doc, self._schema) for doc in docs ] except Exception as e: raise e return results def _do_merge_rerank_results( self, ctx: QueryContext, docs_map: dict[str, list[Doc]] ) -> list[Doc]: query_result_cnt = len(docs_map) if docs_map else 0 if query_result_cnt == 0: raise ValueError("Query results is none and dost not to rerank") if query_result_cnt == 1: if not ctx.reranker or isinstance( ctx.reranker, (RrfReRanker, WeightedReRanker) ): return next(iter(docs_map.values())) return ctx.reranker.rerank(docs_map) return ctx.reranker.rerank(docs_map) @final def execute(self, ctx: QueryContext, collection: _Collection) -> list[Doc]: # 1. validate query self._do_validate(ctx) # 2. build query vector query_vectors = self._do_build(ctx, collection) if not query_vectors: raise ValueError("No query to execute") # 3. execute query docs = self._do_execute(query_vectors, collection) # 4. 
merge and rerank result return self._do_merge_rerank_results(ctx, docs) class NoVectorQueryExecutor(QueryExecutor): def __init__(self, schema: CollectionSchema): super().__init__(schema) def _do_validate(self, ctx: QueryContext) -> None: if len(ctx.queries) > 0: raise ValueError("Collection does not support query with vector or id") def _do_build( self, ctx: QueryContext, _collection: _Collection ) -> list[_VectorQuery]: return [self._do_build_query_wo_vector(ctx)] class SingleVectorQueryExecutor(NoVectorQueryExecutor): def __init__(self, schema: CollectionSchema) -> None: super().__init__(schema) def _do_validate(self, ctx: QueryContext) -> None: if len(ctx.queries) > 1: raise ValueError( "Collection has only one vector field, cannot query with multiple vectors" ) for query in ctx.queries: query._validate() def _do_build( self, ctx: QueryContext, collection: _Collection ) -> list[_VectorQuery]: if len(ctx.queries) == 0: return [self._do_build_query_wo_vector(ctx)] vectors = [] for query in ctx.queries: vectors.append(self._do_build_query_with_vector(ctx, query, collection)) return vectors class MultiVectorQueryExecutor(SingleVectorQueryExecutor): def __init__(self, schema: CollectionSchema) -> None: super().__init__(schema) def _do_validate(self, ctx: QueryContext) -> None: if len(ctx.queries) > 1 and ctx.reranker is None: raise ValueError("Reranker is required for multi-vector query") seen_fields = set() for query in ctx.queries: query._validate() field = query.field_name if field in seen_fields: raise ValueError(f"Query field name '{field}' appears more than once") seen_fields.add(field) def _do_execute( self, vectors: list[_VectorQuery], collection: _Collection ) -> dict[str, list[Doc]]: return super()._do_execute(vectors, collection) class QueryExecutorFactory: @staticmethod def create(schema: CollectionSchema) -> QueryExecutor: vectors = schema.vectors if len(vectors) == 0: return NoVectorQueryExecutor(schema) if len(vectors) == 1: return 
SingleVectorQueryExecutor(schema) return MultiVectorQueryExecutor(schema) ================================================ FILE: python/zvec/extension/__init__.py ================================================ # Copyright 2025-present the zvec project # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations from .bm25_embedding_function import BM25EmbeddingFunction from .embedding_function import DenseEmbeddingFunction, SparseEmbeddingFunction from .http_embedding_function import HTTPDenseEmbedding from .jina_embedding_function import JinaDenseEmbedding from .jina_function import JinaFunctionBase from .multi_vector_reranker import RrfReRanker, WeightedReRanker from .openai_embedding_function import OpenAIDenseEmbedding from .openai_function import OpenAIFunctionBase from .qwen_embedding_function import QwenDenseEmbedding, QwenSparseEmbedding from .qwen_function import QwenFunctionBase from .qwen_rerank_function import QwenReRanker from .rerank_function import RerankFunction as ReRanker from .sentence_transformer_embedding_function import ( DefaultLocalDenseEmbedding, DefaultLocalSparseEmbedding, ) from .sentence_transformer_function import SentenceTransformerFunctionBase from .sentence_transformer_rerank_function import DefaultLocalReRanker __all__ = [ "BM25EmbeddingFunction", "DefaultLocalDenseEmbedding", "DefaultLocalReRanker", "DefaultLocalSparseEmbedding", "DenseEmbeddingFunction", "HTTPDenseEmbedding", "JinaDenseEmbedding", "JinaFunctionBase", 
"OpenAIDenseEmbedding", "OpenAIFunctionBase", "QwenDenseEmbedding", "QwenFunctionBase", "QwenReRanker", "QwenSparseEmbedding", "ReRanker", "RrfReRanker", "SentenceTransformerFunctionBase", "SparseEmbeddingFunction", "WeightedReRanker", ] ================================================ FILE: python/zvec/extension/bm25_embedding_function.py ================================================ # Copyright 2025-present the zvec project # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations from functools import lru_cache from typing import Literal, Optional from ..common.constants import TEXT, SparseVectorType from ..tool import require_module from .embedding_function import SparseEmbeddingFunction class BM25EmbeddingFunction(SparseEmbeddingFunction[TEXT]): """BM25-based sparse embedding function using DashText SDK. This class provides text-to-sparse-vector embedding capabilities using the DashText library with BM25 algorithm. BM25 (Best Matching 25) is a probabilistic retrieval function used for lexical search and document ranking based on term frequency and inverse document frequency. BM25 generates sparse vectors where each dimension corresponds to a term in the vocabulary, and the value represents the BM25 score for that term. 
It's particularly effective for: - Lexical search and keyword matching - Document ranking and information retrieval - Combining with dense embeddings for hybrid search - Traditional IR tasks where exact term matching is important This implementation uses DashText's SparseVectorEncoder, which provides efficient BM25 computation for Chinese and English text using either a built-in encoder or custom corpus training. Args: corpus (Optional[list[str]], optional): List of documents to train the BM25 encoder. If provided, creates a custom encoder trained on this corpus for better domain-specific accuracy. If ``None``, uses the built-in encoder. Defaults to ``None``. encoding_type (Literal["query", "document"], optional): Encoding mode for text processing. Use ``"query"`` for search queries (default) and ``"document"`` for document indexing. This distinction optimizes the BM25 scoring for asymmetric retrieval tasks. Defaults to ``"query"``. language (Literal["zh", "en"], optional): Language for built-in encoder. Only used when corpus is None. ``"zh"`` for Chinese (trained on Chinese Wikipedia), ``"en"`` for English. Defaults to ``"zh"``. b (float, optional): Document length normalization parameter for BM25. Range [0, 1]. 0 means no normalization, 1 means full normalization. Only used with custom corpus. Defaults to ``0.75``. k1 (float, optional): Term frequency saturation parameter for BM25. Higher values give more weight to term frequency. Only used with custom corpus. Defaults to ``1.2``. **kwargs: Additional parameters for DashText encoder customization. Attributes: corpus_size (int): Number of documents in the training corpus (0 if using built-in encoder). encoding_type (str): The encoding type being used ("query" or "document"). language (str): The language of the built-in encoder ("zh" or "en"). Raises: ValueError: If corpus is provided but empty or contains non-string elements. TypeError: If input to ``embed()`` is not a string. 
RuntimeError: If DashText encoder initialization or training fails. Note: - Requires Python 3.10, 3.11, or 3.12 - Requires the ``dashtext`` package: ``pip install dashtext`` - Two encoder options available: 1. **Built-in encoder** (no corpus needed): Pre-trained models for Chinese (zh) and English (en), good generalization, works out-of-the-box 2. **Custom encoder** (corpus required): Better accuracy for domain-specific terminology, requires training on your full corpus with BM25 parameters - Encoding types: * ``encoding_type="query"``: Optimized for search queries (shorter text) * ``encoding_type="document"``: Optimized for document indexing (longer text) - BM25 parameters (b, k1) only apply to custom encoder training - Output is sorted by indices (vocabulary term IDs) for consistency - Results are cached (LRU cache, maxsize=10) to reduce computation - No API key or network connectivity required (local computation) Examples: >>> # Option 1: Using built-in encoder for Chinese (no corpus needed) >>> from zvec.extension import BM25EmbeddingFunction >>> >>> # For query encoding (Chinese) >>> bm25_query_zh = BM25EmbeddingFunction(language="zh", encoding_type="query") >>> query_vec = bm25_query_zh.embed("什么是机器学习") >>> isinstance(query_vec, dict) True >>> # query_vec: {1169440797: 0.29, 2045788977: 0.70, ...} >>> # For document encoding (Chinese) >>> bm25_doc_zh = BM25EmbeddingFunction(language="zh", encoding_type="document") >>> doc_vec = bm25_doc_zh.embed("机器学习是人工智能的一个重要分支...") >>> isinstance(doc_vec, dict) True >>> # Using built-in encoder for English >>> bm25_query_en = BM25EmbeddingFunction(language="en", encoding_type="query") >>> query_vec_en = bm25_query_en.embed("what is vector search service") >>> isinstance(query_vec_en, dict) True >>> # Option 2: Using custom corpus for domain-specific accuracy >>> corpus = [ ... "机器学习是人工智能的一个重要分支", ... "深度学习使用多层神经网络进行特征提取", ... "自然语言处理技术用于理解和生成人类语言" ... ] >>> bm25_custom = BM25EmbeddingFunction( ... corpus=corpus, ... 
encoding_type="query", ... b=0.75, ... k1=1.2 ... ) >>> custom_vec = bm25_custom.embed("机器学习算法") >>> isinstance(custom_vec, dict) True >>> # Hybrid search: combining with dense embeddings >>> from zvec.extension import DefaultLocalDenseEmbedding >>> dense_emb = DefaultLocalDenseEmbedding() >>> bm25_emb = BM25EmbeddingFunction(language="zh", encoding_type="query") >>> >>> query = "machine learning algorithms" >>> dense_vec = dense_emb.embed(query) # Semantic similarity >>> sparse_vec = bm25_emb.embed(query) # Lexical matching >>> # Combine scores for hybrid retrieval >>> # Callable interface >>> sparse_vec = bm25_query_zh("information retrieval") >>> isinstance(sparse_vec, dict) True >>> # Error handling >>> try: ... bm25_query_zh.embed("") # Empty query ... except ValueError as e: ... print(f"Error: {e}") Error: Input text cannot be empty or whitespace only See Also: - ``SparseEmbeddingFunction``: Base class for sparse embeddings - ``DefaultLocalSparseEmbedding``: SPLADE-based sparse embedding - ``QwenSparseEmbedding``: API-based sparse embedding using Qwen - ``DefaultLocalDenseEmbedding``: Dense embedding for semantic search References: - DashText Documentation: https://help.aliyun.com/zh/document_detail/2546039.html - DashText PyPI: https://pypi.org/project/dashtext/ - BM25 Algorithm: Robertson & Zaragoza (2009) """ def __init__( self, corpus: Optional[list[str]] = None, encoding_type: Literal["query", "document"] = "query", language: Literal["zh", "en"] = "zh", b: float = 0.75, k1: float = 1.2, **kwargs, ): """Initialize the BM25 embedding function. Args: corpus (Optional[list[str]]): Optional corpus for training custom encoder. If None, uses built-in encoder. Defaults to None. encoding_type (Literal["query", "document"]): Text encoding mode. Use "query" for search queries, "document" for indexing. Defaults to "query". language (Literal["zh", "en"]): Language for built-in encoder. "zh" for Chinese, "en" for English. Defaults to "zh". 
b (float): Document length normalization for BM25 [0, 1]. Only used with custom corpus. Defaults to 0.75. k1 (float): Term frequency saturation for BM25. Only used with custom corpus. Defaults to 1.2. **kwargs: Additional DashText encoder parameters. Raises: ValueError: If corpus is provided but empty or invalid. ImportError: If dashtext package is not installed. RuntimeError: If encoder initialization or training fails. """ # Validate corpus if provided if corpus is not None: if not corpus or not isinstance(corpus, list): raise ValueError("Corpus must be a non-empty list of strings") if not all(isinstance(doc, str) for doc in corpus): raise ValueError("All corpus documents must be strings") # Import dashtext self._dashtext = require_module("dashtext") self._corpus = corpus self._encoding_type = encoding_type self._language = language self._b = b self._k1 = k1 self._extra_params = kwargs # Initialize the BM25 encoder self._build_encoder() def _build_encoder(self): """Build the BM25 sparse vector encoder. Creates either a built-in encoder (pre-trained) or a custom encoder trained on the provided corpus. Raises: RuntimeError: If encoder initialization or training fails. ImportError: If dashtext package is not installed. """ try: if self._corpus is None: # Use built-in encoder (pre-trained on Wikipedia) # language: 'zh' for Chinese, 'en' for English self._encoder = self._dashtext.SparseVectorEncoder.default( name=self._language ) else: # Create custom encoder with BM25 parameters self._encoder = self._dashtext.SparseVectorEncoder( b=self._b, k1=self._k1, **self._extra_params ) # Train encoder with the corpus self._encoder.train(self._corpus) except ImportError as e: raise ImportError( "dashtext package is required for BM25EmbeddingFunction. 
" "Install it with: pip install dashtext" ) from e except Exception as e: if isinstance(e, (ValueError, RuntimeError)): raise raise RuntimeError(f"Failed to build BM25 encoder: {e!s}") from e @property def corpus_size(self) -> int: """int: Number of documents in the training corpus (0 if using built-in encoder).""" return len(self._corpus) if self._corpus is not None else 0 @property def encoding_type(self) -> str: """str: The encoding type being used ("query" or "document").""" return self._encoding_type @property def language(self) -> str: """str: The language of the built-in encoder ("zh" or "en").""" return self._language @property def extra_params(self) -> dict: """dict: Extra parameters for DashText encoder customization.""" return self._extra_params def __call__(self, input: TEXT) -> SparseVectorType: """Make the embedding function callable. Args: input (TEXT): Input text to embed. Returns: SparseVectorType: Sparse vector as dictionary. """ return self.embed(input) @lru_cache(maxsize=10) def embed(self, input: TEXT) -> SparseVectorType: """Generate BM25 sparse embedding for the input text. This method computes BM25 scores for the input text using DashText's SparseVectorEncoder. The encoding behavior depends on the encoding_type: - ``encoding_type="query"``: Uses ``encode_queries()`` for search queries - ``encoding_type="document"``: Uses ``encode_documents()`` for documents The result is a sparse vector where keys are term indices in the vocabulary and values are BM25 scores. Args: input (TEXT): Input text string to embed. Must be non-empty after stripping whitespace. Returns: SparseVectorType: A dictionary mapping vocabulary term index to BM25 score. Only non-zero scores are included. The dictionary is sorted by indices (keys) in ascending order for consistent output. Example: ``{1169440797: 0.29, 2045788977: 0.70, ...}`` Raises: TypeError: If ``input`` is not a string. ValueError: If input is empty or whitespace-only. RuntimeError: If BM25 encoding fails. 
Examples: >>> bm25 = BM25EmbeddingFunction(language="zh", encoding_type="query") >>> sparse_vec = bm25.embed("query text") >>> isinstance(sparse_vec, dict) True >>> all(isinstance(k, int) and isinstance(v, float) for k, v in sparse_vec.items()) True >>> # Verify sorted output >>> keys = list(sparse_vec.keys()) >>> keys == sorted(keys) True >>> # Error: empty input >>> bm25.embed(" ") ValueError: Input text cannot be empty or whitespace only >>> # Error: non-string input >>> bm25.embed(123) TypeError: Expected 'input' to be str, got int Note: - BM25 scores are relative to the vocabulary statistics - Output dictionary is always sorted by indices for consistency - Terms not in the vocabulary will have zero scores (not included) - This method is cached (maxsize=10) for performance - DashText automatically handles Chinese/English text segmentation """ if not isinstance(input, str): raise TypeError(f"Expected 'input' to be str, got {type(input).__name__}") input = input.strip() if not input: raise ValueError("Input text cannot be empty or whitespace only") try: # Encode based on encoding_type if self._encoding_type == "query": sparse_vector = self._encoder.encode_queries(input) else: # encoding_type == "document" sparse_vector = self._encoder.encode_documents(input) # DashText returns dict with int/long keys and float values # Convert to standard format: {int: float} sparse_dict: dict[int, float] = {} for key, value in sparse_vector.items(): try: idx = int(key) val = float(value) if val > 0: sparse_dict[idx] = val except (ValueError, TypeError): # Skip invalid entries continue # Sort by indices (keys) to ensure consistent ordering return dict(sorted(sparse_dict.items())) except Exception as e: if isinstance(e, (TypeError, ValueError)): raise raise RuntimeError(f"Failed to generate BM25 embedding: {e!s}") from e ================================================ FILE: python/zvec/extension/embedding_function.py ================================================ # Copyright 
2025-present the zvec project # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations from abc import abstractmethod from typing import Protocol, runtime_checkable from ..common.constants import MD, DenseVectorType, SparseVectorType @runtime_checkable class DenseEmbeddingFunction(Protocol[MD]): """Protocol for dense vector embedding functions. Dense embedding functions map multimodal input (text, image, or audio) to fixed-length real-valued vectors. This is a Protocol class that defines the interface - implementations should provide their own initialization and properties. Type Parameters: MD: The type of input data (bound to Embeddable: TEXT, IMAGE, or AUDIO). Note: - This is a Protocol class - it only defines the ``embed()`` interface. - Implementations are free to define their own ``__init__``, properties, and additional methods as needed. - The ``embed()`` method is the only required interface. Examples: >>> # Custom text embedding implementation >>> class MyTextEmbedding: ... def __init__(self, dimension: int, model_name: str): ... self.dimension = dimension ... self.model = load_model(model_name) ... ... def embed(self, input: str) -> list[float]: ... return self.model.encode(input).tolist() >>> # Custom image embedding implementation >>> class MyImageEmbedding: ... def __init__(self, dimension: int = 512): ... self.dimension = dimension ... self.model = load_image_model() ... ... 
def embed(self, input: Union[str, bytes, np.ndarray]) -> list[float]: ... if isinstance(input, str): ... image = load_image_from_path(input) ... else: ... image = input ... return self.model.extract_features(image).tolist() >>> # Using built-in implementations >>> from zvec.extension import QwenDenseEmbedding >>> text_emb = QwenDenseEmbedding(dimension=768, api_key="sk-xxx") >>> vector = text_emb.embed("Hello world") """ @abstractmethod def embed(self, input: MD) -> DenseVectorType: """Generate a dense embedding vector for the input data. Args: input (MD): Multimodal input data to embed. Can be: - TEXT (str): Text string - IMAGE (str | bytes | np.ndarray): Image file path, raw bytes, or array - AUDIO (str | bytes | np.ndarray): Audio file path, raw bytes, or array Returns: DenseVectorType: A dense vector representing the embedding. Can be list[float], list[int], or np.ndarray. Length should match the implementation's dimension. """ ... @runtime_checkable class SparseEmbeddingFunction(Protocol[MD]): """Abstract base class for sparse vector embedding functions. Sparse embedding functions map multimodal input (text, image, or audio) to a dictionary of {index: weight}, where only non-zero dimensions are stored. You can inherit this class to create custom sparse embedding functions. Type Parameters: MD: The type of input data (bound to Embeddable: TEXT, IMAGE, or AUDIO). Note: Subclasses must implement the ``embed()`` method. Examples: >>> # Using built-in text sparse embedding (e.g., BM25, TF-IDF) >>> sparse_emb = SomeSparseEmbedding() >>> vector = sparse_emb.embed("Hello world") >>> # Returns: {0: 0.5, 42: 1.2, 100: 0.8} >>> # Custom BM25 sparse embedding function >>> class MyBM25Embedding(SparseEmbeddingFunction): ... def __init__(self, vocab_size: int = 10000): ... self.vocab_size = vocab_size ... self.tokenizer = MyTokenizer() ... ... def embed(self, input: str) -> dict[int, float]: ... tokens = self.tokenizer.tokenize(input) ... sparse_vector = {} ... 
for token_id, weight in self._calculate_bm25(tokens): ... if weight > 0: ... sparse_vector[token_id] = weight ... return sparse_vector ... ... def _calculate_bm25(self, tokens): ... # BM25 calculation logic ... pass >>> # Custom sparse image feature extractor >>> class MySparseImageEmbedding(SparseEmbeddingFunction): ... def embed(self, input: Union[str, bytes, np.ndarray]) -> dict[int, float]: ... image = self._load_image(input) ... features = self._extract_sparse_features(image) ... return {idx: val for idx, val in enumerate(features) if val != 0} """ @abstractmethod def embed(self, input: MD) -> SparseVectorType: """Generate a sparse embedding for the input data. Args: input (MD): Multimodal input data to embed. Can be: - TEXT (str): Text string - IMAGE (str | bytes | np.ndarray): Image file path, raw bytes, or array - AUDIO (str | bytes | np.ndarray): Audio file path, raw bytes, or array Returns: SparseVectorType: Mapping from dimension index to non-zero weight. Only dimensions with non-zero values are included. """ ... ================================================ FILE: python/zvec/extension/http_embedding_function.py ================================================ # Copyright 2025-present the zvec project # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from __future__ import annotations import json import os import urllib.request from functools import lru_cache from typing import Optional from ..common.constants import TEXT, DenseVectorType from .embedding_function import DenseEmbeddingFunction class HTTPDenseEmbedding(DenseEmbeddingFunction[TEXT]): """Dense text embedding function using any OpenAI-compatible HTTP endpoint. This class calls any server that implements the ``/v1/embeddings`` API (LM Studio, Ollama, vLLM, LocalAI, etc.) using only the Python standard library — no extra dependencies are required. The embedding dimension is detected automatically from the first server response. Args: base_url (str, optional): Base URL of the embedding server. Defaults to ``"http://localhost:1234"`` (LM Studio). Common values: - ``"http://localhost:1234"`` — LM Studio - ``"http://localhost:11434"`` — Ollama model (str, optional): Model identifier as expected by the server. Defaults to ``"text-embedding-nomic-embed-text-v1.5@f16"``. api_key (Optional[str], optional): Bearer token for authenticated endpoints. Falls back to the ``OPENAI_API_KEY`` environment variable. Leave as ``None`` for local servers that do not require authentication. timeout (int, optional): HTTP request timeout in seconds. Defaults to 30. Attributes: dimension (int): Embedding vector dimensionality (auto-detected). Raises: TypeError: If ``embed()`` receives a non-string input. ValueError: If input is empty/whitespace-only or the server returns an unexpected response format. RuntimeError: If the HTTP request fails or the server is unreachable. Examples: >>> from zvec.extension import HTTPDenseEmbedding >>> >>> # LM Studio (default) >>> emb = HTTPDenseEmbedding() >>> vector = emb.embed("Hello, world!") >>> len(vector) 768 >>> >>> # Ollama >>> emb = HTTPDenseEmbedding( ... base_url="http://localhost:11434", ... model="nomic-embed-text", ... 
) >>> vector = emb.embed("Semantic search with local models") See Also: - ``DenseEmbeddingFunction``: Protocol for dense embeddings. - ``OpenAIDenseEmbedding``: Cloud embedding via the OpenAI API. """ ENDPOINT = "/v1/embeddings" def __init__( self, base_url: str = "http://localhost:1234", model: str = "text-embedding-nomic-embed-text-v1.5@f16", api_key: Optional[str] = None, timeout: int = 30, ) -> None: self._base_url = base_url.rstrip("/") self._model = model self._api_key = api_key or os.environ.get("OPENAI_API_KEY", "") self._timeout = timeout self._dimension: Optional[int] = None @property def dimension(self) -> int: """int: Embedding vector dimensionality (auto-detected on first call).""" if self._dimension is None: self._dimension = len(self.embed("dimension probe")) return self._dimension def __call__(self, input: TEXT) -> DenseVectorType: """Make the embedding function callable.""" return self.embed(input) @lru_cache(maxsize=256) def embed(self, input: TEXT) -> DenseVectorType: """Generate a dense embedding vector for the input text. Results are cached (LRU, up to 256 entries) so repeated strings do not trigger extra HTTP requests. Args: input (TEXT): Input text string to embed. Must be non-empty after stripping whitespace. Returns: DenseVectorType: A list of floats representing the embedding. Raises: TypeError: If *input* is not a string. ValueError: If *input* is empty/whitespace-only or the server returns an unexpected response format. RuntimeError: If the HTTP request fails. 
""" if not isinstance(input, TEXT): raise TypeError(f"Expected 'input' to be str, got {type(input).__name__}") input = input.strip() if not input: raise ValueError("Input text cannot be empty or whitespace only") url = self._base_url + self.ENDPOINT payload = json.dumps({"model": self._model, "input": input}).encode() headers: dict[str, str] = {"Content-Type": "application/json"} if self._api_key: headers["Authorization"] = f"Bearer {self._api_key}" req = urllib.request.Request(url, data=payload, headers=headers, method="POST") try: with urllib.request.urlopen(req, timeout=self._timeout) as resp: body = json.loads(resp.read()) except urllib.error.HTTPError as exc: raise RuntimeError( f"Embedding server returned HTTP {exc.code}: {exc.read().decode()}" ) from exc except OSError as exc: raise RuntimeError( f"Could not reach embedding server at {url}: {exc}" ) from exc try: vector: list[float] = body["data"][0]["embedding"] except (KeyError, IndexError) as exc: raise ValueError( f"Unexpected response format from embedding server: {body}" ) from exc return vector ================================================ FILE: python/zvec/extension/jina_embedding_function.py ================================================ # Copyright 2025-present the zvec project # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from __future__ import annotations from functools import lru_cache from typing import Optional from ..common.constants import TEXT, DenseVectorType from .embedding_function import DenseEmbeddingFunction from .jina_function import JinaFunctionBase class JinaDenseEmbedding(JinaFunctionBase, DenseEmbeddingFunction[TEXT]): """Dense text embedding function using Jina AI API. This class provides text-to-vector embedding capabilities using Jina AI's embedding models. It inherits from ``DenseEmbeddingFunction`` and implements dense text embedding via the Jina Embeddings API (OpenAI-compatible). Jina Embeddings v5 models support task-specific embedding through the ``task`` parameter, which optimizes the embedding for different use cases such as retrieval, text matching, or classification. They also support Matryoshka Representation Learning, allowing flexible output dimensions. Args: model (str, optional): Jina embedding model identifier. Defaults to ``"jina-embeddings-v5-text-nano"``. Available models: - ``"jina-embeddings-v5-text-nano"``: 768 dims, 239M params, 8K context - ``"jina-embeddings-v5-text-small"``: 1024 dims, 677M params, 32K context dimension (Optional[int], optional): Desired output embedding dimension. If ``None``, uses model's default dimension. Supports Matryoshka dimensions: 32, 64, 128, 256, 512, 768 (nano) / 1024 (small). Defaults to ``None``. api_key (Optional[str], optional): Jina API authentication key. If ``None``, reads from ``JINA_API_KEY`` environment variable. Obtain your key from: https://jina.ai/api-dashboard task (Optional[str], optional): Task type to optimize embeddings for. Defaults to ``None``. Valid values: - ``"retrieval.query"``: For search queries - ``"retrieval.passage"``: For documents/passages to be searched - ``"text-matching"``: For symmetric text similarity - ``"classification"``: For text classification - ``"separation"``: For clustering/separation tasks Attributes: dimension (int): The embedding vector dimension. 
data_type (DataType): Always ``DataType.VECTOR_FP32`` for this implementation. model (str): The Jina model name being used. task (Optional[str]): The task type for embedding optimization. Raises: ValueError: If API key is not provided and not found in environment, if task is not a valid task type, or if API returns an error response. TypeError: If input to ``embed()`` is not a string. RuntimeError: If network error or Jina service error occurs. Note: - Requires Python 3.10, 3.11, or 3.12 - Requires the ``openai`` package: ``pip install openai`` - Jina API is OpenAI-compatible, so it uses the ``openai`` Python client - Embedding results are cached (LRU cache, maxsize=10) to reduce API calls - For retrieval tasks, use ``"retrieval.query"`` for queries and ``"retrieval.passage"`` for documents - API usage requires a Jina API key from https://jina.ai/api-dashboard Examples: >>> # Basic usage with default model >>> from zvec.extension import JinaDenseEmbedding >>> import os >>> os.environ["JINA_API_KEY"] = "jina_..." >>> >>> emb_func = JinaDenseEmbedding() >>> vector = emb_func.embed("Hello, world!") >>> len(vector) 768 >>> # Retrieval use case: embed queries and documents differently >>> query_emb = JinaDenseEmbedding(task="retrieval.query") >>> doc_emb = JinaDenseEmbedding(task="retrieval.passage") >>> >>> query_vector = query_emb.embed("What is machine learning?") >>> doc_vector = doc_emb.embed("Machine learning is a subset of AI...") >>> # Using larger model with custom dimension (Matryoshka) >>> emb_func = JinaDenseEmbedding( ... model="jina-embeddings-v5-text-small", ... dimension=256, ... api_key="jina_...", ... task="text-matching", ... ) >>> vector = emb_func.embed("Semantic similarity comparison") >>> len(vector) 256 >>> # Using with zvec collection >>> import zvec >>> emb_func = JinaDenseEmbedding(task="retrieval.passage") >>> schema = zvec.CollectionSchema( ... name="docs", ... vectors=zvec.VectorSchema( ... 
"embedding", zvec.DataType.VECTOR_FP32, emb_func.dimension ... ), ... ) >>> collection = zvec.create_and_open(path="./my_docs", schema=schema) See Also: - ``DenseEmbeddingFunction``: Base class for dense embeddings - ``OpenAIDenseEmbedding``: Alternative using OpenAI API - ``QwenDenseEmbedding``: Alternative using Qwen/DashScope API - ``DefaultLocalDenseEmbedding``: Local model without API calls """ def __init__( self, model: str = "jina-embeddings-v5-text-nano", dimension: Optional[int] = None, api_key: Optional[str] = None, task: Optional[str] = None, **kwargs, ): """Initialize the Jina dense embedding function. Args: model (str): Jina model name. Defaults to "jina-embeddings-v5-text-nano". dimension (Optional[int]): Target embedding dimension or None for default. api_key (Optional[str]): API key or None to use environment variable. task (Optional[str]): Task type for embedding optimization or None. **kwargs: Additional parameters for API calls. Raises: ValueError: If API key is not provided and not in environment, or if task is not a valid task type. """ # Initialize base class for API connection JinaFunctionBase.__init__(self, model=model, api_key=api_key, task=task) # Store dimension configuration self._custom_dimension = dimension # Determine actual dimension if dimension is None: self._dimension = self._MODEL_DIMENSIONS.get(model, 768) else: self._dimension = dimension # Store extra attributes self._extra_params = kwargs @property def dimension(self) -> int: """int: The expected dimensionality of the embedding vector.""" return self._dimension @property def extra_params(self) -> dict: """dict: Extra parameters for model-specific customization.""" return self._extra_params def __call__(self, input: TEXT) -> DenseVectorType: """Make the embedding function callable.""" return self.embed(input) @lru_cache(maxsize=10) def embed(self, input: TEXT) -> DenseVectorType: """Generate dense embedding vector for the input text. 
This method calls the Jina Embeddings API to convert input text into a dense vector representation. Results are cached to improve performance for repeated inputs. Args: input (TEXT): Input text string to embed. Must be non-empty after stripping whitespace. Maximum length depends on model: 8192 tokens for v5-nano, 32768 tokens for v5-small. Returns: DenseVectorType: A list of floats representing the embedding vector. Length equals ``self.dimension``. Example: ``[0.123, -0.456, 0.789, ...]`` Raises: TypeError: If ``input`` is not a string. ValueError: If input is empty/whitespace-only, or if the API returns an error or malformed response. RuntimeError: If network connectivity issues or Jina service errors occur. Examples: >>> emb = JinaDenseEmbedding(task="retrieval.query") >>> vector = emb.embed("What is deep learning?") >>> len(vector) 768 >>> isinstance(vector[0], float) True >>> # Error: empty input >>> emb.embed(" ") ValueError: Input text cannot be empty or whitespace only >>> # Error: non-string input >>> emb.embed(123) TypeError: Expected 'input' to be str, got int Note: - This method is cached (maxsize=10). Identical inputs return cached results. - The cache is based on exact string match (case-sensitive). - Task type affects embedding optimization but not caching behavior. 
""" if not isinstance(input, TEXT): raise TypeError(f"Expected 'input' to be str, got {type(input).__name__}") input = input.strip() if not input: raise ValueError("Input text cannot be empty or whitespace only") # Call API embedding_vector = self._call_text_embedding_api( input=input, dimension=self._custom_dimension, ) # Verify dimension if len(embedding_vector) != self.dimension: raise ValueError( f"Dimension mismatch: expected {self.dimension}, " f"got {len(embedding_vector)}" ) return embedding_vector ================================================ FILE: python/zvec/extension/jina_function.py ================================================ # Copyright 2025-present the zvec project # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations import os from typing import ClassVar, Optional from ..common.constants import TEXT from ..tool import require_module class JinaFunctionBase: """Base class for Jina AI functions. This base class provides common functionality for calling Jina AI APIs and handling responses. It supports embeddings (dense) operations via the OpenAI-compatible Jina Embeddings API. This class is not meant to be used directly. Use concrete implementations: - ``JinaDenseEmbedding`` for dense embeddings Args: model (str): Jina embedding model identifier. api_key (Optional[str]): Jina API authentication key. task (Optional[str]): Task type for the embedding model. 
Note: - This is an internal base class for code reuse across Jina features - Subclasses should inherit from appropriate Protocol - Provides unified API connection and response handling - Jina API is OpenAI-compatible, using the ``openai`` Python client """ _BASE_URL: ClassVar[str] = "https://api.jina.ai/v1" # Model default dimensions _MODEL_DIMENSIONS: ClassVar[dict[str, int]] = { "jina-embeddings-v5-text-nano": 768, "jina-embeddings-v5-text-small": 1024, } # Model max tokens _MODEL_MAX_TOKENS: ClassVar[dict[str, int]] = { "jina-embeddings-v5-text-nano": 8192, "jina-embeddings-v5-text-small": 32768, } # Valid task types _VALID_TASKS: ClassVar[tuple[str, ...]] = ( "retrieval.query", "retrieval.passage", "text-matching", "classification", "separation", ) def __init__( self, model: str, api_key: Optional[str] = None, task: Optional[str] = None, ): """Initialize the base Jina functionality. Args: model (str): Jina model name. api_key (Optional[str]): API key or None to use environment variable. task (Optional[str]): Task type for the embedding model. Valid values: "retrieval.query", "retrieval.passage", "text-matching", "classification", "separation". Raises: ValueError: If API key is not provided and not in environment, or if task is not a valid task type. """ self._model = model self._api_key = api_key or os.environ.get("JINA_API_KEY") self._task = task if not self._api_key: raise ValueError( "Jina API key is required. Please provide 'api_key' parameter " "or set the 'JINA_API_KEY' environment variable. " "Get your key from: https://jina.ai/api-dashboard" ) if task is not None and task not in self._VALID_TASKS: raise ValueError( f"Invalid task '{task}'. 
Valid tasks: {', '.join(self._VALID_TASKS)}" ) @property def model(self) -> str: """str: The Jina model name currently in use.""" return self._model @property def task(self) -> Optional[str]: """Optional[str]: The task type for the embedding model.""" return self._task def _get_client(self): """Get OpenAI-compatible client instance configured for Jina API. Returns: OpenAI: Configured OpenAI client pointing to Jina API. Raises: ImportError: If openai package is not installed. """ openai = require_module("openai") return openai.OpenAI(api_key=self._api_key, base_url=self._BASE_URL) def _call_text_embedding_api( self, input: TEXT, dimension: Optional[int] = None, ) -> list: """Call Jina Embeddings API. Args: input (TEXT): Input text to embed. dimension (Optional[int]): Target dimension for Matryoshka embeddings. Returns: list: Embedding vector as list of floats. Raises: RuntimeError: If API call fails. ValueError: If API returns error response. """ try: client = self._get_client() # Prepare embedding parameters params = {"model": self.model, "input": input} # Add dimension parameter for Matryoshka support if dimension is not None: params["dimensions"] = dimension # Add task parameter via extra_body if self._task is not None: params["extra_body"] = {"task": self._task} # Call Jina API (OpenAI-compatible) response = client.embeddings.create(**params) except Exception as e: # Check if it's an OpenAI API error openai = require_module("openai") if isinstance(e, (openai.APIError, openai.APIConnectionError)): raise RuntimeError(f"Failed to call Jina API: {e!s}") from e raise RuntimeError(f"Unexpected error during API call: {e!s}") from e # Extract embedding from response try: if not response.data: raise ValueError("Invalid API response: no embedding data returned") embedding_vector = response.data[0].embedding if not isinstance(embedding_vector, list): raise ValueError( "Invalid API response: embedding is not a list of numbers" ) return embedding_vector except 
(AttributeError, IndexError, TypeError) as e: raise ValueError(f"Failed to parse API response: {e!s}") from e ================================================ FILE: python/zvec/extension/multi_vector_reranker.py ================================================ # Copyright 2025-present the zvec project # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations import heapq import math from collections import defaultdict from typing import Optional from ..model.doc import Doc from ..typing import MetricType from .rerank_function import RerankFunction class RrfReRanker(RerankFunction): """Re-ranker using Reciprocal Rank Fusion (RRF) for multi-vector search. RRF combines results from multiple vector queries without requiring relevance scores. It assigns higher weight to documents that appear early in multiple result lists. The RRF score for a document at rank ``r`` is: ``1 / (k + r + 1)``, where ``k`` is the rank constant. Note: This re-ranker is specifically designed for multi-vector scenarios where query results from multiple vector fields need to be combined. Args: topn (int, optional): Number of top documents to return. Defaults to 10. rerank_field (Optional[str], optional): Ignored by RRF. Defaults to None. rank_constant (int, optional): Smoothing constant ``k`` in RRF formula. Larger values reduce the impact of early ranks. Defaults to 60. 
""" def __init__( self, topn: int = 10, rerank_field: Optional[str] = None, rank_constant: int = 60, ): super().__init__(topn=topn, rerank_field=rerank_field) self._rank_constant = rank_constant @property def rank_constant(self) -> int: return self._rank_constant def _rrf_score(self, rank: int) -> float: return 1.0 / (self._rank_constant + rank + 1) def rerank(self, query_results: dict[str, list[Doc]]) -> list[Doc]: """Apply Reciprocal Rank Fusion to combine multiple query results. Args: query_results (dict[str, list[Doc]]): Results from one or more vector queries. Returns: list[Doc]: Re-ranked documents with RRF scores in the ``score`` field. """ rrf_scores: dict[str, float] = defaultdict(float) id_to_doc: dict[str, Doc] = {} for _, query_result in query_results.items(): for rank, doc in enumerate(query_result): doc_id = doc.id rrf_score = self._rrf_score(rank) rrf_scores[doc_id] += rrf_score if doc_id not in id_to_doc: id_to_doc[doc_id] = doc top_docs = heapq.nlargest(self.topn, rrf_scores.items(), key=lambda x: x[1]) results: list[Doc] = [] for doc_id, rrf_score in top_docs: doc = id_to_doc[doc_id] new_doc = doc._replace(score=rrf_score) results.append(new_doc) return results class WeightedReRanker(RerankFunction): """Re-ranker that combines scores from multiple vector fields using weights. Each vector field's relevance score is normalized based on its metric type, then scaled by a user-provided weight. Final scores are summed across fields. Note: This re-ranker is specifically designed for multi-vector scenarios where query results from multiple vector fields need to be combined with configurable weights. Args: topn (int, optional): Number of top documents to return. Defaults to 10. rerank_field (Optional[str], optional): Ignored. Defaults to None. metric (MetricType, optional): Distance metric used for score normalization. Defaults to ``MetricType.L2``. weights (Optional[dict[str, float]], optional): Weight per vector field. Fields not listed use weight 1.0. 
Defaults to None. Note: Supported metrics: L2, IP, COSINE. Scores are normalized to [0, 1]. """ def __init__( self, topn: int = 10, rerank_field: Optional[str] = None, metric: MetricType = MetricType.L2, weights: Optional[dict[str, float]] = None, ): super().__init__(topn=topn, rerank_field=rerank_field) self._weights = weights or {} self._metric = metric @property def weights(self) -> dict[str, float]: """dict[str, float]: Weight mapping for vector fields.""" return self._weights @property def metric(self) -> MetricType: """MetricType: Distance metric used for score normalization.""" return self._metric def rerank(self, query_results: dict[str, list[Doc]]) -> list[Doc]: """Combine scores from multiple vector fields using weighted sum. Args: query_results (dict[str, list[Doc]]): Results per vector field. Returns: list[Doc]: Re-ranked documents with combined scores in ``score`` field. """ weighted_scores: dict[str, float] = defaultdict(float) id_to_doc: dict[str, Doc] = {} for vector_name, query_result in query_results.items(): for _, doc in enumerate(query_result): doc_id = doc.id weighted_score = self._normalize_score( doc.score, self.metric ) * self.weights.get(vector_name, 1.0) weighted_scores[doc_id] += weighted_score if doc_id not in id_to_doc: id_to_doc[doc_id] = doc top_docs = heapq.nlargest( self.topn, weighted_scores.items(), key=lambda x: x[1] ) results: list[Doc] = [] for doc_id, weighted_score in top_docs: doc = id_to_doc[doc_id] new_doc = doc._replace(score=weighted_score) results.append(new_doc) return results def _normalize_score(self, score: float, metric: MetricType) -> float: if metric == MetricType.L2: return 1.0 - 2 * math.atan(score) / math.pi if metric == MetricType.IP: return 0.5 + math.atan(score) / math.pi if metric == MetricType.COSINE: return 1.0 - score / 2.0 raise ValueError("Unsupported metric type") ================================================ FILE: python/zvec/extension/openai_embedding_function.py 
================================================ # Copyright 2025-present the zvec project # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations from functools import lru_cache from typing import Optional from ..common.constants import TEXT, DenseVectorType from .embedding_function import DenseEmbeddingFunction from .openai_function import OpenAIFunctionBase class OpenAIDenseEmbedding(OpenAIFunctionBase, DenseEmbeddingFunction[TEXT]): """Dense text embedding function using OpenAI API. This class provides text-to-vector embedding capabilities using OpenAI's embedding models. It inherits from ``DenseEmbeddingFunction`` and implements dense text embedding via the OpenAI API. The implementation supports various OpenAI embedding models with different dimensions and includes automatic result caching for improved performance. Args: model (str, optional): OpenAI embedding model identifier. Defaults to ``"text-embedding-3-small"``. Common options: - ``"text-embedding-3-small"``: 1536 dims, cost-efficient, good performance - ``"text-embedding-3-large"``: 3072 dims, highest quality - ``"text-embedding-ada-002"``: 1536 dims, legacy model dimension (Optional[int], optional): Desired output embedding dimension. If ``None``, uses model's default dimension. For text-embedding-3 models, you can specify custom dimensions (e.g., 256, 512, 1024, 1536). Defaults to ``None``. api_key (Optional[str], optional): OpenAI API authentication key. 
If ``None``, reads from ``OPENAI_API_KEY`` environment variable. Obtain your key from: https://platform.openai.com/api-keys base_url (Optional[str], optional): Custom API base URL for OpenAI-compatible services. Defaults to ``None`` (uses official OpenAI endpoint). Attributes: dimension (int): The embedding vector dimension. data_type (DataType): Always ``DataType.VECTOR_FP32`` for this implementation. model (str): The OpenAI model name being used. Raises: ValueError: If API key is not provided and not found in environment, or if API returns an error response. TypeError: If input to ``embed()`` is not a string. RuntimeError: If network error or OpenAI service error occurs. Note: - Requires Python 3.10, 3.11, or 3.12 - Requires the ``openai`` package: ``pip install openai`` - Embedding results are cached (LRU cache, maxsize=10) to reduce API calls - Network connectivity to OpenAI API endpoints is required - API usage incurs costs based on your OpenAI subscription plan - Rate limits apply based on your OpenAI account tier Examples: >>> # Basic usage with default model >>> from zvec.extension import OpenAIDenseEmbedding >>> import os >>> os.environ["OPENAI_API_KEY"] = "sk-..." >>> >>> emb_func = OpenAIDenseEmbedding() >>> vector = emb_func.embed("Hello, world!") >>> len(vector) 1536 >>> # Using specific model with custom dimension >>> emb_func = OpenAIDenseEmbedding( ... model="text-embedding-3-large", ... dimension=1024, ... api_key="sk-..." ... ) >>> vector = emb_func.embed("Machine learning is fascinating") >>> len(vector) 1024 >>> # Using with custom base URL (e.g., Azure OpenAI) >>> emb_func = OpenAIDenseEmbedding( ... model="text-embedding-ada-002", ... api_key="your-azure-key", ... base_url="https://your-resource.openai.azure.com/" ... 
) >>> vector = emb_func("Natural language processing") >>> isinstance(vector, list) True >>> # Batch processing with caching benefit >>> texts = ["First text", "Second text", "First text"] >>> vectors = [emb_func.embed(text) for text in texts] >>> # Third call uses cached result for "First text" >>> # Error handling >>> try: ... emb_func.embed("") # Empty string ... except ValueError as e: ... print(f"Error: {e}") Error: Input text cannot be empty or whitespace only See Also: - ``DenseEmbeddingFunction``: Base class for dense embeddings - ``QwenDenseEmbedding``: Alternative using Qwen/DashScope API - ``DefaultDenseEmbedding``: Local model without API calls - ``SparseEmbeddingFunction``: Base class for sparse embeddings """ def __init__( self, model: str = "text-embedding-3-small", dimension: Optional[int] = None, api_key: Optional[str] = None, base_url: Optional[str] = None, **kwargs, ): """Initialize the OpenAI dense embedding function. Args: model (str): OpenAI model name. Defaults to "text-embedding-3-small". dimension (Optional[int]): Target embedding dimension or None for default. api_key (Optional[str]): API key or None to use environment variable. base_url (Optional[str]): Custom API base URL or None for default. **kwargs: Additional parameters for API calls. Examples: - ``encoding_format`` (str): Format of embeddings, "float" or "base64". - ``user`` (str): User identifier for tracking. Raises: ValueError: If API key is not provided and not in environment. 
""" # Initialize base class for API connection OpenAIFunctionBase.__init__( self, model=model, api_key=api_key, base_url=base_url ) # Store dimension configuration self._custom_dimension = dimension # Determine actual dimension if dimension is None: # Use model default dimension self._dimension = self._MODEL_DIMENSIONS.get(model, 1536) else: self._dimension = dimension # Store dense-specific attributes self._extra_params = kwargs @property def dimension(self) -> int: """int: The expected dimensionality of the embedding vector.""" return self._dimension @property def extra_params(self) -> dict: """dict: Extra parameters for model-specific customization.""" return self._extra_params def __call__(self, input: TEXT) -> DenseVectorType: """Make the embedding function callable.""" return self.embed(input) @lru_cache(maxsize=10) def embed(self, input: TEXT) -> DenseVectorType: """Generate dense embedding vector for the input text. This method calls the OpenAI Embeddings API to convert input text into a dense vector representation. Results are cached to improve performance for repeated inputs. Args: input (TEXT): Input text string to embed. Must be non-empty after stripping whitespace. Maximum length is 8191 tokens for most models. Returns: DenseVectorType: A list of floats representing the embedding vector. Length equals ``self.dimension``. Example: ``[0.123, -0.456, 0.789, ...]`` Raises: TypeError: If ``input`` is not a string. ValueError: If input is empty/whitespace-only, or if the API returns an error or malformed response. RuntimeError: If network connectivity issues or OpenAI service errors occur. 
Examples: >>> emb = OpenAIDenseEmbedding() >>> vector = emb.embed("Natural language processing") >>> len(vector) 1536 >>> isinstance(vector[0], float) True >>> # Error: empty input >>> emb.embed(" ") ValueError: Input text cannot be empty or whitespace only >>> # Error: non-string input >>> emb.embed(123) TypeError: Expected 'input' to be str, got int Note: - This method is cached (maxsize=10). Identical inputs return cached results. - The cache is based on exact string match (case-sensitive). - Consider pre-processing text (lowercasing, normalization) for better caching. """ if not isinstance(input, TEXT): raise TypeError(f"Expected 'input' to be str, got {type(input).__name__}") input = input.strip() if not input: raise ValueError("Input text cannot be empty or whitespace only") # Call API embedding_vector = self._call_text_embedding_api( input=input, dimension=self._custom_dimension, ) # Verify dimension if len(embedding_vector) != self.dimension: raise ValueError( f"Dimension mismatch: expected {self.dimension}, " f"got {len(embedding_vector)}" ) return embedding_vector ================================================ FILE: python/zvec/extension/openai_function.py ================================================ # Copyright 2025-present the zvec project # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from __future__ import annotations import os from typing import ClassVar, Optional from ..common.constants import TEXT from ..tool import require_module class OpenAIFunctionBase: """Base class for OpenAI functions. This base class provides common functionality for calling OpenAI APIs and handling responses. It supports embeddings (dense) operations. This class is not meant to be used directly. Use concrete implementations: - ``OpenAIDenseEmbedding`` for dense embeddings Args: model (str): OpenAI model identifier. api_key (Optional[str]): OpenAI API authentication key. base_url (Optional[str]): Custom API base URL. Note: - This is an internal base class for code reuse across OpenAI features - Subclasses should inherit from appropriate Protocol - Provides unified API connection and response handling """ # Model default dimensions _MODEL_DIMENSIONS: ClassVar[dict[str, int]] = { "text-embedding-3-small": 1536, "text-embedding-3-large": 3072, "text-embedding-ada-002": 1536, } def __init__( self, model: str, api_key: Optional[str] = None, base_url: Optional[str] = None, ): """Initialize the base OpenAI functionality. Args: model (str): OpenAI model name. api_key (Optional[str]): API key or None to use environment variable. base_url (Optional[str]): Custom API base URL or None for default. Raises: ValueError: If API key is not provided and not in environment. """ self._model = model self._api_key = api_key or os.environ.get("OPENAI_API_KEY") self._base_url = base_url if not self._api_key: raise ValueError( "OpenAI API key is required. Please provide 'api_key' parameter " "or set the 'OPENAI_API_KEY' environment variable." ) @property def model(self) -> str: """str: The OpenAI model name currently in use.""" return self._model def _get_client(self): """Get OpenAI client instance. Returns: OpenAI: Configured OpenAI client. Raises: ImportError: If openai package is not installed. 
""" openai = require_module("openai") if self._base_url: return openai.OpenAI(api_key=self._api_key, base_url=self._base_url) return openai.OpenAI(api_key=self._api_key) def _call_text_embedding_api( self, input: TEXT, dimension: Optional[int] = None, ) -> list: """Call OpenAI Embeddings API. Args: input (TEXT): Input text to embed. dimension (Optional[int]): Target dimension (for models that support it). Returns: list: Embedding vector as list of floats. Raises: RuntimeError: If API call fails. ValueError: If API returns error response. """ try: client = self._get_client() # Prepare embedding parameters params = {"model": self.model, "input": input} # Add dimension parameter for models that support it if dimension is not None: params["dimensions"] = dimension # Call OpenAI API response = client.embeddings.create(**params) except Exception as e: # Check if it's an OpenAI API error openai = require_module("openai") if isinstance(e, (openai.APIError, openai.APIConnectionError)): raise RuntimeError(f"Failed to call OpenAI API: {e!s}") from e raise RuntimeError(f"Unexpected error during API call: {e!s}") from e # Extract embedding from response try: if not response.data: raise ValueError("Invalid API response: no embedding data returned") embedding_vector = response.data[0].embedding if not isinstance(embedding_vector, list): raise ValueError( "Invalid API response: embedding is not a list of numbers" ) return embedding_vector except (AttributeError, IndexError, TypeError) as e: raise ValueError(f"Failed to parse API response: {e!s}") from e ================================================ FILE: python/zvec/extension/qwen_embedding_function.py ================================================ # Copyright 2025-present the zvec project # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations from functools import lru_cache from typing import Optional from ..common.constants import TEXT, DenseVectorType, SparseVectorType from .embedding_function import DenseEmbeddingFunction, SparseEmbeddingFunction from .qwen_function import QwenFunctionBase class QwenDenseEmbedding(QwenFunctionBase, DenseEmbeddingFunction[TEXT]): """Dense text embedding function using Qwen (DashScope) API. This class provides text-to-vector embedding capabilities using Alibaba Cloud's DashScope service and Qwen embedding models. It inherits from ``DenseEmbeddingFunction`` and implements dense text embedding. The implementation supports various Qwen embedding models with configurable dimensions and includes automatic result caching for improved performance. Args: dimension (int): Desired output embedding dimension. Common values: - 512: Balanced performance and accuracy - 1024: Higher accuracy, larger storage - 1536: Maximum accuracy for supported models model (str, optional): DashScope embedding model identifier. Defaults to ``"text-embedding-v4"``. Other options include: - ``"text-embedding-v3"`` - ``"text-embedding-v2"`` - ``"text-embedding-v1"`` api_key (Optional[str], optional): DashScope API authentication key. If ``None``, reads from ``DASHSCOPE_API_KEY`` environment variable. Obtain your key from: https://dashscope.console.aliyun.com/ **kwargs: Additional DashScope API parameters. Supported options: - ``text_type`` (str): Specifies the text role in retrieval tasks. Options: ``"query"`` (search query) or ``"document"`` (indexed content). 
This parameter optimizes embeddings for asymmetric search scenarios. Reference: https://help.aliyun.com/zh/model-studio/text-embedding-synchronous-api Attributes: dimension (int): The embedding vector dimension. data_type (DataType): Always ``DataType.VECTOR_FP32`` for this implementation. model (str): The DashScope model name being used. Raises: ValueError: If API key is not provided and not found in environment, or if API returns an error response. TypeError: If input to ``embed()`` is not a string. RuntimeError: If network error or DashScope service error occurs. Note: - Requires Python 3.10, 3.11, or 3.12 - Requires the ``dashscope`` package: ``pip install dashscope`` - Embedding results are cached (LRU cache, maxsize=10) to reduce API calls - Network connectivity to DashScope API endpoints is required - API usage may incur costs based on your DashScope subscription plan **Parameter Guidelines:** - Use ``text_type="query"`` for search queries and ``text_type="document"`` for indexed content to optimize asymmetric retrieval tasks. - For detailed API specifications and parameter usage, refer to: https://help.aliyun.com/zh/model-studio/text-embedding-synchronous-api Examples: >>> # Basic usage with default model >>> from zvec.extension import QwenDenseEmbedding >>> import os >>> os.environ["DASHSCOPE_API_KEY"] = "your-api-key" >>> >>> emb_func = QwenDenseEmbedding(dimension=1024) >>> vector = emb_func.embed("Hello, world!") >>> len(vector) 1024 >>> # Using specific model with explicit API key >>> emb_func = QwenDenseEmbedding( ... dimension=512, ... model="text-embedding-v3", ... api_key="sk-xxxxx" ... ) >>> vector = emb_func("Machine learning is fascinating") >>> isinstance(vector, list) True >>> # Using with custom parameters (text_type) >>> # For search queries - optimize for query-document matching >>> emb_func = QwenDenseEmbedding( ... dimension=1024, ... text_type="query" ... 
) >>> query_vector = emb_func.embed("What is machine learning?") >>> >>> # For document embeddings - optimize for being matched by queries >>> doc_emb_func = QwenDenseEmbedding( ... dimension=1024, ... text_type="document" ... ) >>> doc_vector = doc_emb_func.embed( ... "Machine learning is a subset of artificial intelligence..." ... ) >>> # Batch processing with caching benefit >>> texts = ["First text", "Second text", "First text"] >>> vectors = [emb_func.embed(text) for text in texts] >>> # Third call uses cached result for "First text" >>> # Error handling >>> try: ... emb_func.embed("") # Empty string ... except ValueError as e: ... print(f"Error: {e}") Error: Input text cannot be empty or whitespace only See Also: - ``DenseEmbeddingFunction``: Base class for dense embeddings - ``SparseEmbeddingFunction``: Base class for sparse embeddings """ def __init__( self, dimension: int, model: str = "text-embedding-v4", api_key: Optional[str] = None, **kwargs, ): """Initialize the Qwen dense embedding function. Args: dimension (int): Target embedding dimension. model (str): DashScope model name. Defaults to "text-embedding-v4". api_key (Optional[str]): API key or None to use environment variable. **kwargs: Additional DashScope API parameters. Supported options: - ``text_type`` (str): Text role in asymmetric retrieval. * ``"query"``: Optimize for search queries (short, question-like). * ``"document"``: Optimize for indexed documents (longer content). Using appropriate text_type improves retrieval accuracy by optimizing the embedding space for query-document matching. For detailed API documentation, see: https://help.aliyun.com/zh/model-studio/text-embedding-synchronous-api Raises: ValueError: If API key is not provided and not in environment. 
""" # Initialize base class for API connection QwenFunctionBase.__init__(self, model=model, api_key=api_key) # Store dense-specific attributes self._dimension = dimension self._extra_params = kwargs @property def dimension(self) -> int: """int: The expected dimensionality of the embedding vector.""" return self._dimension @property def extra_params(self) -> dict: """dict: Extra parameters for model-specific customization.""" return self._extra_params def __call__(self, input: TEXT) -> DenseVectorType: """Make the embedding function callable.""" return self.embed(input) @lru_cache(maxsize=10) def embed(self, input: TEXT) -> DenseVectorType: """Generate dense embedding vector for the input text. This method calls the DashScope TextEmbedding API to convert input text into a dense vector representation. Results are cached to improve performance for repeated inputs. Args: input (TEXT): Input text string to embed. Must be non-empty after stripping whitespace. Maximum length depends on the model used (typically 2048-8192 tokens). Returns: DenseVectorType: A list of floats representing the embedding vector. Length equals ``self.dimension``. Example: ``[0.123, -0.456, 0.789, ...]`` Raises: TypeError: If ``input`` is not a string. ValueError: If input is empty/whitespace-only, or if the API returns an error or malformed response. RuntimeError: If network connectivity issues or DashScope service errors occur. Examples: >>> emb = QwenDenseEmbedding(dimension=1024) >>> vector = emb.embed("Natural language processing") >>> len(vector) 1024 >>> isinstance(vector[0], float) True >>> # Error: empty input >>> emb.embed(" ") ValueError: Input text cannot be empty or whitespace only >>> # Error: non-string input >>> emb.embed(123) TypeError: Expected 'input' to be str, got int Note: - This method is cached (maxsize=10). Identical inputs return cached results. - The cache is based on exact string match (case-sensitive). 
- Consider pre-processing text (lowercasing, normalization) for better caching. """ if not isinstance(input, TEXT): raise TypeError(f"Expected 'input' to be str, got {type(input).__name__}") input = input.strip() if not input: raise ValueError("Input text cannot be empty or whitespace only") # Call API with dense output type output = self._call_text_embedding_api( input=input, dimension=self.dimension, output_type="dense", text_type=self.extra_params.get("text_type"), ) embeddings = output.get("embeddings") if not isinstance(embeddings, list): raise ValueError( "Invalid API response: 'embeddings' field is missing or not a list" ) if len(embeddings) != 1: raise ValueError( f"Expected exactly 1 embedding in response, got {len(embeddings)}" ) first_emb = embeddings[0] if not isinstance(first_emb, dict): raise ValueError("Invalid API response: embedding item is not a dictionary") embedding_vector = first_emb.get("embedding") if not isinstance(embedding_vector, list): raise ValueError( "Invalid API response: 'embedding' field is missing or not a list" ) if len(embedding_vector) != self.dimension: raise ValueError( f"Dimension mismatch: expected {self.dimension}, " f"got {len(embedding_vector)}" ) return list(embedding_vector) class QwenSparseEmbedding(QwenFunctionBase, SparseEmbeddingFunction[TEXT]): """Sparse text embedding function using Qwen (DashScope) API. This class provides text-to-sparse-vector embedding capabilities using Alibaba Cloud's DashScope service and Qwen embedding models. It generates sparse keyword-weighted vectors suitable for lexical matching and BM25-style retrieval scenarios. Sparse embeddings are particularly useful for: - Keyword-based search and exact matching - Hybrid retrieval (combining with dense embeddings) - Interpretable search results (weights show term importance) Args: dimension (int): Desired output embedding dimension. 
Common values: - 512: Balanced performance and accuracy - 1024: Higher accuracy, larger storage - 1536: Maximum accuracy for supported models model (str, optional): DashScope embedding model identifier. Defaults to ``"text-embedding-v4"``. Other options include: - ``"text-embedding-v3"`` - ``"text-embedding-v2"`` api_key (Optional[str], optional): DashScope API authentication key. If ``None``, reads from ``DASHSCOPE_API_KEY`` environment variable. Obtain your key from: https://dashscope.console.aliyun.com/ **kwargs: Additional DashScope API parameters. Supported options: - ``encoding_type`` (Literal["query", "document"]): Encoding type. * ``"query"``: Optimize for search queries (default). * ``"document"``: Optimize for indexed documents. This distinction is important for asymmetric retrieval tasks. Attributes: model (str): The DashScope model name being used. encoding_type (str): The encoding type ("query" or "document"). Raises: ValueError: If API key is not provided and not found in environment, or if API returns an error response. TypeError: If input to ``embed()`` is not a string. RuntimeError: If network error or DashScope service error occurs. Note: - Requires Python 3.10, 3.11, or 3.12 - Requires the ``dashscope`` package: ``pip install dashscope`` - Embedding results are cached (LRU cache, maxsize=10) to reduce API calls - Network connectivity to DashScope API endpoints is required - API usage may incur costs based on your DashScope subscription plan - Sparse vectors have only non-zero dimensions stored as dict - Output is sorted by indices (keys) in ascending order **Parameter Guidelines:** - Use ``encoding_type="query"`` for search queries and ``encoding_type="document"`` for indexed content to optimize asymmetric retrieval tasks. 
- For detailed API specifications, refer to: https://help.aliyun.com/zh/model-studio/text-embedding-synchronous-api Examples: >>> # Basic usage for query embedding >>> from zvec.extension import QwenSparseEmbedding >>> import os >>> os.environ["DASHSCOPE_API_KEY"] = "your-api-key" >>> >>> query_emb = QwenSparseEmbedding(dimension=1024, encoding_type="query") >>> query_vec = query_emb.embed("machine learning") >>> type(query_vec) >>> len(query_vec) # Only non-zero dimensions 156 >>> # Document embedding >>> doc_emb = QwenSparseEmbedding(dimension=1024, encoding_type="document") >>> doc_vec = doc_emb.embed("Machine learning is a subset of AI") >>> isinstance(doc_vec, dict) True >>> # Asymmetric retrieval example >>> query_vec = query_emb.embed("what causes aging fast") >>> doc_vec = doc_emb.embed( ... "UV-A light causes tanning, skin aging, and cataracts..." ... ) >>> >>> # Calculate similarity (dot product for sparse vectors) >>> similarity = sum( ... query_vec.get(k, 0) * doc_vec.get(k, 0) ... for k in set(query_vec) | set(doc_vec) ... ) >>> # Output is sorted by indices >>> list(query_vec.items())[:5] # First 5 dimensions (by index) [(10, 0.45), (23, 0.87), (56, 0.32), (89, 1.12), (120, 0.65)] >>> # Hybrid retrieval (combining dense + sparse) >>> from zvec.extension import QwenDenseEmbedding >>> dense_emb = QwenDenseEmbedding(dimension=1024) >>> sparse_emb = QwenSparseEmbedding(dimension=1024) >>> >>> query = "deep learning neural networks" >>> dense_vec = dense_emb.embed(query) # [0.1, -0.3, 0.5, ...] >>> sparse_vec = sparse_emb.embed(query) # {12: 0.8, 45: 1.2, ...} >>> # Error handling >>> try: ... sparse_emb.embed("") # Empty string ... except ValueError as e: ... 
print(f"Error: {e}") Error: Input text cannot be empty or whitespace only See Also: - ``SparseEmbeddingFunction``: Base class for sparse embeddings - ``QwenDenseEmbedding``: Dense embedding using Qwen API - ``DefaultSparseEmbedding``: Sparse embedding with SPLADE model """ def __init__( self, dimension: int, model: str = "text-embedding-v4", api_key: Optional[str] = None, **kwargs, ): """Initialize the Qwen sparse embedding function. Args: dimension (int): Target embedding dimension. model (str): DashScope model name. Defaults to "text-embedding-v4". api_key (Optional[str]): API key or None to use environment variable. **kwargs: Additional DashScope API parameters. Supported options: - ``encoding_type`` (Literal["query", "document"]): Encoding type. * ``"query"``: Optimize for search queries (default). * ``"document"``: Optimize for indexed documents. This distinction is important for asymmetric retrieval tasks. Raises: ValueError: If API key is not provided and not in environment. """ # Initialize base class for API connection QwenFunctionBase.__init__(self, model=model, api_key=api_key) self._dimension = dimension self._extra_params = kwargs @property def extra_params(self) -> dict: """dict: Extra parameters for model-specific customization.""" return self._extra_params def __call__(self, input: TEXT) -> SparseVectorType: """Make the embedding function callable.""" return self.embed(input) @lru_cache(maxsize=10) def embed(self, input: TEXT) -> SparseVectorType: """Generate sparse embedding vector for the input text. This method calls the DashScope TextEmbedding API with sparse output type to convert input text into a sparse vector representation. The result is a dictionary where keys are dimension indices and values are importance weights (only non-zero values included). The embedding is optimized based on the ``encoding_type`` specified during initialization: "query" for search queries or "document" for indexed content. 
Args: input (TEXT): Input text string to embed. Must be non-empty after stripping whitespace. Maximum length depends on the model used (typically 2048-8192 tokens). Returns: SparseVectorType: A dictionary mapping dimension index to weight. Only non-zero dimensions are included. The dictionary is sorted by indices (keys) in ascending order for consistent output. Example: ``{10: 0.5, 245: 0.8, 1023: 1.2, 5678: 0.5}`` Raises: TypeError: If ``input`` is not a string. ValueError: If input is empty/whitespace-only, or if the API returns an error or malformed response. RuntimeError: If network connectivity issues or DashScope service errors occur. Examples: >>> emb = QwenSparseEmbedding(dimension=1024, encoding_type="query") >>> sparse_vec = emb.embed("machine learning") >>> isinstance(sparse_vec, dict) True >>> >>> # Verify sorted output >>> keys = list(sparse_vec.keys()) >>> keys == sorted(keys) True >>> # Error: empty input >>> emb.embed(" ") ValueError: Input text cannot be empty or whitespace only >>> # Error: non-string input >>> emb.embed(123) TypeError: Expected 'input' to be str, got int Note: - This method is cached (maxsize=10). Identical inputs return cached results. - The cache is based on exact string match (case-sensitive). - Output dictionary is always sorted by indices for consistency. 
""" if not isinstance(input, TEXT): raise TypeError(f"Expected 'input' to be str, got {type(input).__name__}") input = input.strip() if not input: raise ValueError("Input text cannot be empty or whitespace only") # Call API with sparse output type output = self._call_text_embedding_api( input=input, dimension=self._dimension, output_type="sparse", text_type=self.extra_params.get("encoding_type", "query"), ) embeddings = output.get("embeddings") if not isinstance(embeddings, list): raise ValueError( "Invalid API response: 'embeddings' field is missing or not a list" ) if len(embeddings) != 1: raise ValueError( f"Expected exactly 1 embedding in response, got {len(embeddings)}" ) first_emb = embeddings[0] if not isinstance(first_emb, dict): raise ValueError("Invalid API response: embedding item is not a dictionary") sparse_embedding = first_emb.get("sparse_embedding") if not isinstance(sparse_embedding, list): raise ValueError( "Invalid API response: 'sparse_embedding' field is missing or not a list" ) # Parse sparse embedding: convert array of {index, value, token} to dict sparse_dict = {} for item in sparse_embedding: if not isinstance(item, dict): raise ValueError( "Invalid API response: sparse_embedding item is not a dictionary" ) index = item.get("index") value = item.get("value") if index is None or value is None: raise ValueError( "Invalid API response: sparse_embedding item missing 'index' or 'value'" ) # Convert to int and float, filter positive values idx = int(index) val = float(value) if val > 0: sparse_dict[idx] = val # Sort by indices (keys) to ensure consistent ordering return dict(sorted(sparse_dict.items())) ================================================ FILE: python/zvec/extension/qwen_function.py ================================================ # Copyright 2025-present the zvec project # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations import os from http import HTTPStatus from typing import Optional from ..common.constants import TEXT from ..tool import require_module class QwenFunctionBase: """Base class for Qwen (DashScope) functions. This base class provides common functionality for calling DashScope APIs and handling responses. It supports embeddings (dense and sparse) and re-ranking operations. This class is not meant to be used directly. Use concrete implementations: - ``QwenDenseEmbedding`` for dense embeddings - ``QwenSparseEmbedding`` for sparse embeddings - ``QwenReRanker`` for semantic re-ranking Args: model (str): DashScope model identifier. api_key (Optional[str]): DashScope API authentication key. Note: - This is an internal base class for code reuse across Qwen features - Subclasses should inherit from appropriate Protocol/ABC - Provides unified API connection and response handling """ def __init__( self, model: str, api_key: Optional[str] = None, ): """Initialize the base Qwen embedding functionality. Args: model (str): DashScope model name. api_key (Optional[str]): API key or None to use environment variable. Raises: ValueError: If API key is not provided and not in environment. """ self._model = model self._api_key = api_key or os.environ.get("DASHSCOPE_API_KEY") if not self._api_key: raise ValueError( "DashScope API key is required. Please provide 'api_key' parameter " "or set the 'DASHSCOPE_API_KEY' environment variable." 
) @property def model(self) -> str: """str: The DashScope embedding model name currently in use.""" return self._model def _get_connection(self): """Establish connection to DashScope API. Returns: module: The dashscope module with API key configured. Raises: ImportError: If dashscope package is not installed. """ dashscope = require_module("dashscope") dashscope.api_key = self._api_key return dashscope def _call_text_embedding_api( self, input: TEXT, dimension: int, output_type: str, text_type: Optional[str] = None, ) -> dict: """Call DashScope TextEmbedding API. Args: input (TEXT): Input text to embed. dimension (int): Target embedding dimension. output_type (str): Output type ("dense" or "sparse"). text_type (Optional[str]): Text type ("query" or "document"). Returns: dict: API response output field. Raises: RuntimeError: If API call fails. ValueError: If API returns error response. """ try: # Prepare API call parameters call_params = { "model": self.model, "input": input, "dimension": dimension, "output_type": output_type, } # Add optional text_type parameter if provided if text_type is not None: call_params["text_type"] = text_type resp = self._get_connection().TextEmbedding.call(**call_params) except Exception as e: raise RuntimeError(f"Failed to call DashScope API: {e!s}") from e if resp.status_code != HTTPStatus.OK: error_msg = getattr(resp, "message", "Unknown error") error_code = getattr(resp, "code", "N/A") raise ValueError( f"DashScope API error: [Code={error_code}, " f"Status={resp.status_code}] {error_msg}" ) output = getattr(resp, "output", None) if not isinstance(output, dict): raise ValueError( "Invalid API response: missing or malformed 'output' field" ) return output def _call_rerank_api( self, query: str, documents: list[str], top_n: int, ) -> dict: """Call DashScope TextReRank API. Args: query (str): Query text for semantic matching. documents (list[str]): List of document texts to re-rank. top_n (int): Maximum number of documents to return. 
Returns: dict: API response output field containing re-ranked results. Raises: RuntimeError: If API call fails. ValueError: If API returns error response. """ try: resp = self._get_connection().TextReRank.call( model=self.model, query=query, documents=documents, top_n=top_n, return_documents=False, ) except Exception as e: raise RuntimeError(f"Failed to call DashScope API: {e!s}") from e if resp.status_code != HTTPStatus.OK: error_msg = getattr(resp, "message", "Unknown error") error_code = getattr(resp, "code", "N/A") raise ValueError( f"DashScope API error: [Code={error_code}, " f"Status={resp.status_code}] {error_msg}" ) output = getattr(resp, "output", None) if not isinstance(output, dict): raise ValueError( "Invalid API response: missing or malformed 'output' field" ) return output ================================================ FILE: python/zvec/extension/qwen_rerank_function.py ================================================ # Copyright 2025-present the zvec project # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations from typing import Optional from ..model.doc import Doc from .qwen_function import QwenFunctionBase from .rerank_function import RerankFunction class QwenReRanker(QwenFunctionBase, RerankFunction): """Re-ranker using Qwen (DashScope) cross-encoder API for semantic re-ranking. This re-ranker leverages DashScope's TextReRank service to perform cross-encoder style re-ranking. 
It sends query and document pairs to the API and receives relevance scores based on deep semantic understanding. The re-ranker is suitable for single-vector or multi-vector search scenarios where semantic relevance to a specific query is required. Args: query (str): Query text for semantic re-ranking. **Required**. topn (int, optional): Maximum number of documents to return after re-ranking. Defaults to 10. rerank_field (str): Document field name to use as re-ranking input text. **Required** (e.g., "content", "title", "body"). model (str, optional): DashScope re-ranking model identifier. Defaults to ``"gte-rerank-v2"``. api_key (Optional[str], optional): DashScope API authentication key. If not provided, reads from ``DASHSCOPE_API_KEY`` environment variable. Raises: ValueError: If ``query`` is empty/None, ``rerank_field`` is None, or API key is not available. Note: - Requires ``dashscope`` Python package installed - Documents without valid content in ``rerank_field`` are skipped - API rate limits and quotas apply per DashScope subscription Example: >>> reranker = QwenReRanker( ... query="machine learning algorithms", ... topn=5, ... rerank_field="content", ... model="gte-rerank-v2", ... api_key="your-api-key" ... ) >>> # Use in collection.query(reranker=reranker) """ def __init__( self, query: Optional[str] = None, topn: int = 10, rerank_field: Optional[str] = None, model: str = "gte-rerank-v2", api_key: Optional[str] = None, ): """Initialize QwenReRanker with query and configuration. Args: query (Optional[str]): Query text for semantic matching. Required. topn (int): Number of top results to return. rerank_field (Optional[str]): Document field for re-ranking input. model (str): DashScope model name. api_key (Optional[str]): API key or None to use environment variable. Raises: ValueError: If query is empty or API key is unavailable. 
""" QwenFunctionBase.__init__(self, model=model, api_key=api_key) RerankFunction.__init__(self, topn=topn, rerank_field=rerank_field) if not query: raise ValueError("Query is required for QwenReRanker") self._query = query @property def query(self) -> str: """str: Query text used for semantic re-ranking.""" return self._query def rerank(self, query_results: dict[str, list[Doc]]) -> list[Doc]: """Re-rank documents using Qwen's TextReRank API. Sends document texts to DashScope TextReRank service along with the query. Returns documents sorted by relevance scores from the cross-encoder model. Args: query_results (dict[str, list[Doc]]): Mapping from vector field names to lists of retrieved documents. Documents from all fields are deduplicated and re-ranked together. Returns: list[Doc]: Re-ranked documents (up to ``topn``) with updated ``score`` fields containing relevance scores from the API. Raises: ValueError: If no valid documents are found or API call fails. Note: - Duplicate documents (same ID) across fields are processed once - Documents with empty/missing ``rerank_field`` content are skipped - Returned scores are relevance scores from the cross-encoder model """ if not query_results: return [] # Collect and deduplicate documents id_to_doc: dict[str, Doc] = {} doc_ids: list[str] = [] contents: list[str] = [] for _, query_result in query_results.items(): for doc in query_result: doc_id = doc.id if doc_id in id_to_doc: continue # Extract text content from specified field field_value = doc.field(self.rerank_field) rank_content = str(field_value).strip() if field_value else "" if not rank_content: continue id_to_doc[doc_id] = doc doc_ids.append(doc_id) contents.append(rank_content) if not contents: raise ValueError("No documents to rerank") # Call DashScope TextReRank API output = self._call_rerank_api( query=self.query, documents=contents, top_n=self.topn, ) # Build result list with updated scores results: list[Doc] = [] for item in output["results"]: idx = 
item["index"] doc_id = doc_ids[idx] doc = id_to_doc[doc_id] new_doc = doc._replace(score=item["relevance_score"]) results.append(new_doc) return results ================================================ FILE: python/zvec/extension/rerank_function.py ================================================ # Copyright 2025-present the zvec project # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations from abc import ABC, abstractmethod from typing import Optional from ..model.doc import Doc class RerankFunction(ABC): """Abstract base class for re-ranking search results. Re-rankers refine the output of one or more vector queries by applying a secondary scoring strategy. They are used in the ``query()`` method of ``Collection`` via the ``reranker`` parameter. Args: topn (int, optional): Number of top documents to return after re-ranking. Defaults to 10. rerank_field (Optional[str], optional): Field name used as input for re-ranking (e.g., document title or body). Defaults to None. Note: Subclasses must implement the ``rerank()`` method. 
""" def __init__( self, topn: int = 10, rerank_field: Optional[str] = None, ): self._topn = topn self._rerank_field = rerank_field @property def topn(self) -> int: """int: Number of top documents to return after re-ranking.""" return self._topn @property def rerank_field(self) -> Optional[str]: """Optional[str]: Field name used as re-ranking input.""" return self._rerank_field @abstractmethod def rerank(self, query_results: dict[str, list[Doc]]) -> list[Doc]: """Re-rank documents from one or more vector queries. Args: query_results (dict[str, list[Doc]]): Mapping from vector field name to list of retrieved documents (sorted by relevance). Returns: list[Doc]: Re-ranked list of documents (length ≤ ``topn``), with updated ``score`` fields. """ ... ================================================ FILE: python/zvec/extension/sentence_transformer_embedding_function.py ================================================ # Copyright 2025-present the zvec project # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations from typing import ClassVar, Literal, Optional import numpy as np from ..common.constants import TEXT, DenseVectorType, SparseVectorType from .embedding_function import DenseEmbeddingFunction, SparseEmbeddingFunction from .sentence_transformer_function import SentenceTransformerFunctionBase class DefaultLocalDenseEmbedding( SentenceTransformerFunctionBase, DenseEmbeddingFunction[TEXT] ): """Default local dense embedding using all-MiniLM-L6-v2 model. 
This is the default implementation for dense text embedding that uses the ``all-MiniLM-L6-v2`` model from Hugging Face by default. This model provides a good balance between speed and quality for general-purpose text embedding. The class provides text-to-vector dense embedding capabilities using the sentence-transformers library. It supports models from Hugging Face Hub and ModelScope, runs locally without API calls, and supports CPU/GPU acceleration. The model produces 384-dimensional embeddings and is optimized for semantic similarity tasks. It runs locally without requiring API keys. Args: model_source (Literal["huggingface", "modelscope"], optional): Model source. - ``"huggingface"``: Use Hugging Face Hub (default, for international users) - ``"modelscope"``: Use ModelScope (recommended for users in China) Defaults to ``"huggingface"``. device (Optional[str], optional): Device to run the model on. Options: ``"cpu"``, ``"cuda"``, ``"mps"`` (for Apple Silicon), or ``None`` for automatic detection. Defaults to ``None``. normalize_embeddings (bool, optional): Whether to normalize embeddings to unit length (L2 normalization). Useful for cosine similarity. Defaults to ``True``. batch_size (int, optional): Batch size for encoding. Defaults to ``32``. **kwargs: Additional parameters for future extension. Attributes: dimension (int): Always 384 for both models. model_name (str): "all-MiniLM-L6-v2" (HF) or "iic/nlp_gte_sentence-embedding_chinese-small" (MS). model_source (str): The model source being used. device (str): The device the model is running on. Raises: ValueError: If the model cannot be loaded or input is invalid. TypeError: If input to ``embed()`` is not a string. RuntimeError: If model inference fails. 
Note: - Requires Python 3.10, 3.11, or 3.12 - Requires the ``sentence-transformers`` package: ``pip install sentence-transformers`` - For ModelScope, also requires: ``pip install modelscope`` - First run downloads the model (~50-80MB) from chosen source - Hugging Face cache: ``~/.cache/torch/sentence_transformers/`` - ModelScope cache: ``~/.cache/modelscope/hub/`` - No API keys or network required after initial download - Inference speed: ~1000 sentences/sec on CPU, ~10000 on GPU **For users in China:** If you encounter Hugging Face access issues, use ModelScope instead: .. code-block:: python # Recommended for users in China emb = DefaultLocalDenseEmbedding(model_source="modelscope") Alternatively, use Hugging Face mirror: .. code-block:: bash export HF_ENDPOINT=https://hf-mirror.com # Then use default Hugging Face mode Examples: >>> # Basic usage with Hugging Face (default) >>> from zvec.extension import DefaultLocalDenseEmbedding >>> >>> emb_func = DefaultLocalDenseEmbedding() >>> vector = emb_func.embed("Hello, world!") >>> len(vector) 384 >>> isinstance(vector, list) True >>> # Recommended for users in China (uses ModelScope) >>> emb_func = DefaultLocalDenseEmbedding(model_source="modelscope") >>> vector = emb_func.embed("你好,世界!") # Works well with Chinese text >>> len(vector) 384 >>> # Alternative for China users: Use Hugging Face mirror >>> import os >>> os.environ["HF_ENDPOINT"] = "https://hf-mirror.com" >>> emb_func = DefaultLocalDenseEmbedding() # Uses HF mirror >>> vector = emb_func.embed("Hello, world!") >>> # Using GPU for faster inference >>> emb_func = DefaultLocalDenseEmbedding(device="cuda") >>> vector = emb_func("Machine learning is fascinating") >>> # Normalized vector has unit length >>> import numpy as np >>> np.linalg.norm(vector) 1.0 >>> # Batch processing >>> texts = ["First text", "Second text", "Third text"] >>> vectors = [emb_func.embed(text) for text in texts] >>> len(vectors) 3 >>> all(len(v) == 384 for v in vectors) True >>> # Semantic 
similarity >>> v1 = emb_func.embed("The cat sits on the mat") >>> v2 = emb_func.embed("A feline rests on a rug") >>> v3 = emb_func.embed("Python programming") >>> similarity_high = np.dot(v1, v2) # Similar sentences >>> similarity_low = np.dot(v1, v3) # Different topics >>> similarity_high > similarity_low True >>> # Error handling >>> try: ... emb_func.embed("") # Empty string ... except ValueError as e: ... print(f"Error: {e}") Error: Input text cannot be empty or whitespace only See Also: - ``DenseEmbeddingFunction``: Base class for dense embeddings - ``DefaultLocalSparseEmbedding``: Sparse embedding with SPLADE - ``QwenDenseEmbedding``: Alternative using Qwen API """ def __init__( self, model_source: Literal["huggingface", "modelscope"] = "huggingface", device: Optional[str] = None, normalize_embeddings: bool = True, batch_size: int = 32, **kwargs, ): """Initialize with all-MiniLM-L6-v2 model. Args: model_source (Literal["huggingface", "modelscope"]): Model source. Defaults to "huggingface". device (Optional[str]): Target device ("cpu", "cuda", "mps", or None). Defaults to None (automatic detection). normalize_embeddings (bool): Whether to L2-normalize output vectors. Defaults to True. batch_size (int): Batch size for encoding. Defaults to 32. **kwargs: Additional parameters for future extension. Raises: ImportError: If sentence-transformers or modelscope is not installed. ValueError: If model cannot be loaded. 
""" # Use different models based on source if model_source == "modelscope": # Use Chinese-optimized model for ModelScope (better for Chinese text) model_name = "iic/nlp_gte_sentence-embedding_chinese-small" else: model_name = "all-MiniLM-L6-v2" # Initialize base class for model loading SentenceTransformerFunctionBase.__init__( self, model_name=model_name, model_source=model_source, device=device ) self._normalize_embeddings = normalize_embeddings self._batch_size = batch_size # Load model and get dimension model = self._get_model() self._dimension = model.get_sentence_embedding_dimension() # Store extra parameters self._extra_params = kwargs @property def dimension(self) -> int: """int: The expected dimensionality of the embedding vector.""" return self._dimension @property def extra_params(self) -> dict: """dict: Extra parameters for model-specific customization.""" return self._extra_params def __call__(self, input: str) -> DenseVectorType: """Make the embedding function callable.""" return self.embed(input) def embed(self, input: str) -> DenseVectorType: """Generate dense embedding vector for the input text. This method uses the Sentence Transformer model to convert input text into a dense vector representation. The model runs locally without requiring API calls. Args: input (str): Input text string to embed. Must be non-empty after stripping whitespace. Maximum length depends on the model used (typically 128-512 tokens for most models). Returns: DenseVectorType: A list of floats representing the embedding vector. Length equals ``self.dimension``. If ``normalize_embeddings=True``, the vector has unit length. Example: ``[0.123, -0.456, 0.789, ...]`` Raises: TypeError: If ``input`` is not a string. ValueError: If input is empty or whitespace-only. RuntimeError: If model inference fails. 
Examples: >>> emb = DefaultLocalDenseEmbedding() >>> vector = emb.embed("Natural language processing") >>> len(vector) 384 >>> isinstance(vector[0], float) True >>> # Normalized vectors have unit length >>> import numpy as np >>> emb = DefaultLocalDenseEmbedding(normalize_embeddings=True) >>> vector = emb.embed("Test sentence") >>> np.linalg.norm(vector) 1.0 >>> # Error: empty input >>> emb.embed(" ") ValueError: Input text cannot be empty or whitespace only >>> # Error: non-string input >>> emb.embed(123) TypeError: Expected 'input' to be str, got int >>> # Semantic similarity example >>> v1 = emb.embed("The cat sits on the mat") >>> v2 = emb.embed("A feline rests on a rug") >>> similarity = np.dot(v1, v2) # High similarity due to semantic meaning >>> similarity > 0.7 True Note: - First call may be slower due to model loading - Subsequent calls are much faster as the model stays in memory - For batch processing, consider encoding multiple texts together (though this method handles single texts only) - GPU acceleration provides 5-10x speedup over CPU """ if not isinstance(input, str): raise TypeError(f"Expected 'input' to be str, got {type(input).__name__}") input = input.strip() if not input: raise ValueError("Input text cannot be empty or whitespace only") try: model = self._get_model() embedding = model.encode( input, convert_to_numpy=True, normalize_embeddings=self._normalize_embeddings, batch_size=self._batch_size, ) # Convert numpy array to list if isinstance(embedding, np.ndarray): embedding_list = embedding.tolist() else: embedding_list = list(embedding) # Validate dimension if len(embedding_list) != self.dimension: raise ValueError( f"Dimension mismatch: expected {self.dimension}, " f"got {len(embedding_list)}" ) return embedding_list except Exception as e: if isinstance(e, (TypeError, ValueError)): raise raise RuntimeError(f"Failed to generate embedding: {e!s}") from e class DefaultLocalSparseEmbedding( SentenceTransformerFunctionBase, 
SparseEmbeddingFunction[TEXT] ): """Default local sparse embedding using SPLADE model. This class provides sparse vector embedding using the SPLADE (SParse Lexical AnD Expansion) model. SPLADE generates sparse, interpretable representations where each dimension corresponds to a vocabulary term with learned importance weights. It's ideal for lexical matching, BM25-style retrieval, and hybrid search scenarios. The default model is ``naver/splade-cocondenser-ensembledistil``, which is publicly available without authentication. It produces sparse vectors with thousands of dimensions but only hundreds of non-zero values, making them efficient for storage and retrieval while maintaining strong lexical matching. **Model Caching:** This class uses class-level caching to share the SPLADE model across all instances with the same configuration (model_source, device). This significantly reduces memory usage when creating multiple instances for different encoding types (query vs document). **Cache Management:** The class provides methods to manage the model cache: - ``clear_cache()``: Clear all cached models to free memory - ``get_cache_info()``: Get information about cached models - ``remove_from_cache(model_source, device)``: Remove a specific model from cache .. note:: **Why not use splade-v3?** The newer ``naver/splade-v3`` model is gated (requires access approval). We use ``naver/splade-cocondenser-ensembledistil`` instead. **To use splade-v3 (if you have access):** 1. Request access at https://huggingface.co/naver/splade-v3 2. Get your Hugging Face token from https://huggingface.co/settings/tokens 3. Set environment variable: .. code-block:: bash export HF_TOKEN="your_huggingface_token" 4. Or login programmatically: .. code-block:: python from huggingface_hub import login login(token="your_huggingface_token") 5. 
To use a custom SPLADE model, you can subclass this class and override the model_name in ``__init__``, or create your own implementation inheriting from ``SentenceTransformerFunctionBase`` and ``SparseEmbeddingFunction``. Args: model_source (Literal["huggingface", "modelscope"], optional): Model source. Defaults to ``"huggingface"``. ModelScope support may vary for SPLADE models. device (Optional[str], optional): Device to run the model on. Options: ``"cpu"``, ``"cuda"``, ``"mps"`` (for Apple Silicon), or ``None`` for automatic detection. Defaults to ``None``. encoding_type (Literal["query", "document"], optional): Encoding type. - ``"query"``: Optimize for search queries (default) - ``"document"``: Optimize for indexed documents **kwargs: Additional parameters (currently unused, for future extension). Attributes: model_name (str): Model identifier. model_source (str): The model source being used. device (str): The device the model is running on. Raises: ValueError: If the model cannot be loaded or input is invalid. TypeError: If input to ``embed()`` is not a string. RuntimeError: If model inference fails. 
Note: - Requires Python 3.10, 3.11, or 3.12 - Requires the ``sentence-transformers`` package: ``pip install sentence-transformers`` - First run downloads the model (~100MB) from Hugging Face - Cache location: ``~/.cache/torch/sentence_transformers/`` - No API keys or authentication required - Sparse vectors have ~30k dimensions but only ~100-200 non-zero values - Best combined with dense embeddings for hybrid retrieval **SPLADE vs Dense Embeddings:** - **Dense**: Continuous semantic vectors, good for semantic similarity - **Sparse**: Lexical keyword-based, interpretable, good for exact matching - **Hybrid**: Combine both for best retrieval performance Examples: >>> # Memory-efficient: both instances share the same model (~200MB) >>> from zvec.extension import DefaultLocalSparseEmbedding >>> >>> # Query embedding >>> query_emb = DefaultLocalSparseEmbedding(encoding_type="query") >>> query_vec = query_emb.embed("machine learning algorithms") >>> type(query_vec) >>> len(query_vec) # Only non-zero dimensions 156 >>> # Document embedding (shares model with query_emb) >>> doc_emb = DefaultLocalSparseEmbedding(encoding_type="document") >>> doc_vec = doc_emb.embed("Machine learning is a subset of AI") >>> # Total memory: ~200MB (not 400MB) thanks to model caching >>> # Asymmetric retrieval example >>> query_vec = query_emb.embed("what causes aging fast") >>> doc_vec = doc_emb.embed( ... "UV-A light causes tanning, skin aging, and cataracts..." ... ) >>> >>> # Calculate similarity (dot product for sparse vectors) >>> similarity = sum( ... query_vec.get(k, 0) * doc_vec.get(k, 0) ... for k in set(query_vec) | set(doc_vec) ... 
) >>> # Batch processing >>> queries = ["query 1", "query 2", "query 3"] >>> query_vecs = [query_emb.embed(q) for q in queries] >>> >>> documents = ["doc 1", "doc 2", "doc 3"] >>> doc_vecs = [doc_emb.embed(d) for d in documents] >>> # Inspecting sparse dimensions (output is sorted by indices) >>> query_vec = query_emb.embed("machine learning") >>> list(query_vec.items())[:5] # First 5 dimensions (by index) [(10, 0.45), (23, 0.87), (56, 0.32), (89, 1.12), (120, 0.65)] >>> >>> # Sort by weight to find most important terms >>> sorted_by_weight = sorted(query_vec.items(), key=lambda x: x[1], reverse=True) >>> top_5 = sorted_by_weight[:5] # Top 5 most important terms >>> top_5 [(1023, 1.45), (245, 1.23), (8901, 0.98), (5678, 0.87), (12034, 0.76)] >>> # Using GPU for faster inference >>> sparse_emb = DefaultLocalSparseEmbedding(device="cuda") >>> vector = sparse_emb.embed("natural language processing") >>> # Hybrid retrieval example (combining dense + sparse) >>> from zvec.extension import DefaultDenseEmbedding >>> dense_emb = DefaultDenseEmbedding() >>> sparse_emb = DefaultLocalSparseEmbedding() >>> >>> query = "deep learning neural networks" >>> dense_vec = dense_emb.embed(query) # [0.1, -0.3, 0.5, ...] >>> sparse_vec = sparse_emb.embed(query) # {12: 0.8, 45: 1.2, ...} >>> # Error handling >>> try: ... sparse_emb.embed("") # Empty string ... except ValueError as e: ... 
print(f"Error: {e}") Error: Input text cannot be empty or whitespace only >>> # Cache management >>> # Check cache status >>> info = DefaultLocalSparseEmbedding.get_cache_info() >>> print(f"Cached models: {info['cached_models']}") Cached models: 1 >>> >>> # Clear cache to free memory >>> DefaultLocalSparseEmbedding.clear_cache() >>> info = DefaultLocalSparseEmbedding.get_cache_info() >>> print(f"Cached models: {info['cached_models']}") Cached models: 0 >>> >>> # Remove specific model from cache >>> query_emb = DefaultLocalSparseEmbedding() # Creates CPU model >>> cuda_emb = DefaultLocalSparseEmbedding(device="cuda") # Creates CUDA model >>> info = DefaultLocalSparseEmbedding.get_cache_info() >>> print(f"Cached models: {info['cached_models']}") Cached models: 2 >>> >>> # Remove only CPU model >>> removed = DefaultLocalSparseEmbedding.remove_from_cache(device=None) >>> print(f"Removed: {removed}") True >>> info = DefaultLocalSparseEmbedding.get_cache_info() >>> print(f"Cached models: {info['cached_models']}") Cached models: 1 See Also: - ``SparseEmbeddingFunction``: Base class for sparse embeddings - ``DefaultDenseEmbedding``: Dense embedding with all-MiniLM-L6-v2 - ``QwenDenseEmbedding``: Alternative using Qwen API References: - SPLADE Paper: https://arxiv.org/abs/2109.10086 - Model: https://huggingface.co/naver/splade-cocondenser-ensembledistil """ # Class-level model cache: {(model_name, model_source, device): model} # Shared across all DefaultLocalSparseEmbedding instances to save memory _model_cache: ClassVar[dict] = {} @classmethod def clear_cache(cls) -> None: """Clear all cached SPLADE models from memory. This is useful for: - Freeing memory when models are no longer needed - Forcing a fresh model reload - Testing and debugging Examples: >>> # Clear cache to free memory >>> DefaultLocalSparseEmbedding.clear_cache() >>> # Or in tests to ensure fresh model loading >>> def test_something(): ... DefaultLocalSparseEmbedding.clear_cache() ... 
emb = DefaultLocalSparseEmbedding() ... # Test with fresh model """ cls._model_cache.clear() @classmethod def get_cache_info(cls) -> dict: """Get information about currently cached models. Returns: dict: Dictionary with cache statistics: - cached_models (int): Number of cached model instances - cache_keys (list): List of cache keys (model_name, model_source, device) Examples: >>> info = DefaultLocalSparseEmbedding.get_cache_info() >>> print(f"Cached models: {info['cached_models']}") Cached models: 2 >>> print(f"Cache keys: {info['cache_keys']}") Cache keys: [('naver/splade-cocondenser-ensembledistil', 'huggingface', None), ('naver/splade-cocondenser-ensembledistil', 'huggingface', 'cuda')] """ return { "cached_models": len(cls._model_cache), "cache_keys": list(cls._model_cache.keys()), } @classmethod def remove_from_cache( cls, model_source: str = "huggingface", device: Optional[str] = None ) -> bool: """Remove a specific model from cache. Args: model_source (str): Model source ("huggingface" or "modelscope"). Defaults to "huggingface". device (Optional[str]): Device identifier. Defaults to None. Returns: bool: True if model was found and removed, False otherwise. Examples: >>> # Remove CPU model from cache >>> removed = DefaultLocalSparseEmbedding.remove_from_cache() >>> print(f"Removed: {removed}") True >>> # Remove CUDA model from cache >>> removed = DefaultLocalSparseEmbedding.remove_from_cache(device="cuda") >>> print(f"Removed: {removed}") True """ model_name = "naver/splade-cocondenser-ensembledistil" cache_key = (model_name, model_source, device) if cache_key in cls._model_cache: del cls._model_cache[cache_key] return True return False def __init__( self, model_source: Literal["huggingface", "modelscope"] = "huggingface", device: Optional[str] = None, encoding_type: Literal["query", "document"] = "query", **kwargs, ): """Initialize with SPLADE model. Args: model_source (Literal["huggingface", "modelscope"]): Model source. Defaults to "huggingface". 
device (Optional[str]): Target device ("cpu", "cuda", "mps", or None). Defaults to None (automatic detection). encoding_type (Literal["query", "document"]): Encoding type for embeddings. - "query": Optimize for search queries (default) - "document": Optimize for indexed documents This distinction is important for asymmetric retrieval tasks. **kwargs: Additional parameters (reserved for future use). Raises: ImportError: If sentence-transformers is not installed. ValueError: If model cannot be loaded. Note: Multiple instances with the same (model_source, device) configuration will share the same underlying model to save memory. Different instances can use different encoding_type settings while sharing the model. **Model Selection:** Uses ``naver/splade-cocondenser-ensembledistil`` instead of the newer ``naver/splade-v3`` because splade-v3 is a gated model requiring Hugging Face authentication. The cocondenser-ensembledistil variant: - Does not require authentication or API tokens - Is immediately available for all users - Provides comparable retrieval performance (~2% difference) - Avoids "Access to model is restricted" errors If you need splade-v3 and have obtained access, you can subclass this class and override the model_name parameter. 
Examples: >>> # Both instances share the same model (saves memory) >>> query_emb = DefaultLocalSparseEmbedding(encoding_type="query") >>> doc_emb = DefaultLocalSparseEmbedding(encoding_type="document") >>> # Only one model is loaded in memory """ # Use publicly available SPLADE model (no gated access required) # Note: naver/splade-v3 requires authentication, so we use the # cocondenser-ensembledistil variant which is publicly accessible model_name = "naver/splade-cocondenser-ensembledistil" # Initialize base class for model loading SentenceTransformerFunctionBase.__init__( self, model_name=model_name, model_source=model_source, device=device ) self._encoding_type = encoding_type self._extra_params = kwargs # Create cache key for this model configuration self._cache_key = (model_name, model_source, device) # Load model to ensure it's available (will use cache if exists) self._get_model() @property def extra_params(self) -> dict: """dict: Extra parameters for model-specific customization.""" return self._extra_params def __call__(self, input: str) -> SparseVectorType: """Make the embedding function callable.""" return self.embed(input) def embed(self, input: str) -> SparseVectorType: """Generate sparse embedding vector for the input text. This method uses the SPLADE model to convert input text into a sparse vector representation. The result is a dictionary where keys are dimension indices and values are importance weights (only non-zero values included). The embedding is optimized based on the ``encoding_type`` specified during initialization: "query" for search queries or "document" for indexed content. Args: input (str): Input text string to embed. Must be non-empty after stripping whitespace. Returns: SparseVectorType: A dictionary mapping dimension index to weight. Only non-zero dimensions are included. The dictionary is sorted by indices (keys) in ascending order for consistent output. 
Example: ``{10: 0.5, 245: 0.8, 1023: 1.2, 5678: 0.5}`` Raises: TypeError: If ``input`` is not a string. ValueError: If input is empty or whitespace-only. RuntimeError: If model inference fails. Examples: >>> # Query embedding >>> query_emb = DefaultLocalSparseEmbedding(encoding_type="query") >>> query_vec = query_emb.embed("machine learning") >>> isinstance(query_vec, dict) True Note: - First call may be slower due to model loading - Subsequent calls are much faster as the model stays in memory - GPU acceleration provides significant speedup - Sparse vectors are memory-efficient (only store non-zero values) """ if not isinstance(input, str): raise TypeError(f"Expected 'input' to be str, got {type(input).__name__}") input = input.strip() if not input: raise ValueError("Input text cannot be empty or whitespace only") try: model = self._get_model() # Use appropriate encoding method based on type if self._encoding_type == "document" and hasattr(model, "encode_document"): # Use document encoding sparse_matrix = model.encode_document([input]) elif hasattr(model, "encode_query"): # Use query encoding (default) sparse_matrix = model.encode_query([input]) else: # Fallback: manual implementation for older sentence-transformers return self._manual_sparse_encode(input) # Convert sparse matrix to dictionary # SPLADE returns shape [1, vocab_size] for single input # Check if it's a sparse matrix (duck typing - has toarray method) if hasattr(sparse_matrix, "toarray"): # Sparse matrix (CSR/CSC/etc.) 
- convert to dense array sparse_array = sparse_matrix[0].toarray().flatten() sparse_dict = { int(idx): float(val) for idx, val in enumerate(sparse_array) if val > 0 } else: # Dense array format (numpy array or similar) if isinstance(sparse_matrix, np.ndarray): sparse_array = sparse_matrix[0] else: sparse_array = sparse_matrix sparse_dict = { int(idx): float(val) for idx, val in enumerate(sparse_array) if val > 0 } # Sort by indices (keys) to ensure consistent ordering return dict(sorted(sparse_dict.items())) except Exception as e: if isinstance(e, (TypeError, ValueError)): raise raise RuntimeError(f"Failed to generate sparse embedding: {e!s}") from e def _manual_sparse_encode(self, input: str) -> SparseVectorType: """Fallback manual SPLADE encoding for older sentence-transformers. Args: input (str): Input text to encode. Returns: SparseVectorType: Sparse vector as dictionary. """ import torch model = self._get_model() # Tokenize input features = model.tokenize([input]) # Move to correct device features = {k: v.to(model.device) for k, v in features.items()} # Forward pass with no gradient with torch.no_grad(): embeddings = model.forward(features) # Get logits from model output # SPLADE models typically output 'token_embeddings' if isinstance(embeddings, dict) and "token_embeddings" in embeddings: logits = embeddings["token_embeddings"][0] # First batch item elif hasattr(embeddings, "token_embeddings"): logits = embeddings.token_embeddings[0] # Fallback: try to get first value elif isinstance(embeddings, dict): logits = next(iter(embeddings.values()))[0] else: logits = embeddings[0] # Apply SPLADE activation: log(1 + relu(x)) relu_log = torch.log(1 + torch.relu(logits)) # Max pooling over token dimension (reduce to vocab size) if relu_log.dim() > 1: sparse_vec, _ = torch.max(relu_log, dim=0) else: sparse_vec = relu_log # Convert to sparse dictionary (only non-zero values) sparse_vec_np = sparse_vec.cpu().numpy() sparse_dict = { int(idx): float(val) for idx, val in 
enumerate(sparse_vec_np) if val > 0 } # Sort by indices (keys) to ensure consistent ordering return dict(sorted(sparse_dict.items())) def _get_model(self): """Load or retrieve the SPLADE model from class-level cache. Returns: SentenceTransformer: The loaded SPLADE model instance. Raises: ImportError: If required packages are not installed. ValueError: If model cannot be loaded. Note: Models are cached at class level and shared across all instances with the same (model_name, model_source, device) configuration. This allows memory-efficient usage when creating multiple instances with different encoding_type settings. """ # Check class-level cache first if self._cache_key in self._model_cache: return self._model_cache[self._cache_key] # Use parent class method to load model model = super()._get_model() # Cache the model at class level self._model_cache[self._cache_key] = model return model ================================================ FILE: python/zvec/extension/sentence_transformer_function.py ================================================ # Copyright 2025-present the zvec project # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations from typing import Literal, Optional from ..tool import require_module class SentenceTransformerFunctionBase: """Base class for Sentence Transformer functions (both dense and sparse). This base class provides common functionality for loading and managing sentence-transformers models from Hugging Face or ModelScope. 
It supports both dense models (e.g., all-MiniLM-L6-v2) and sparse models (e.g., SPLADE). This class is not meant to be used directly. Use concrete implementations: - ``SentenceTransformerEmbeddingFunction`` for dense embeddings - ``SentenceTransformerSparseEmbeddingFunction`` for sparse embeddings - ``DefaultDenseEmbedding`` for default dense embeddings - ``DefaultSparseEmbedding`` for default sparse embeddings Args: model_name (str): Model identifier or local path. model_source (Literal["huggingface", "modelscope"]): Model source. device (Optional[str]): Device to run the model on. Note: - This is an internal base class for code reuse - Subclasses should inherit from appropriate Protocol (Dense/Sparse) - Provides model loading and management functionality """ def __init__( self, model_name: str, model_source: Literal["huggingface", "modelscope"] = "huggingface", device: Optional[str] = None, ): """Initialize the base Sentence Transformer functionality. Args: model_name (str): Model identifier or local path. model_source (Literal["huggingface", "modelscope"]): Model source. device (Optional[str]): Device to run the model on. Raises: ValueError: If model_source is invalid. """ # Validate model_source if model_source not in ("huggingface", "modelscope"): raise ValueError( f"Invalid model_source: '{model_source}'. " "Must be 'huggingface' or 'modelscope'." 
) self._model_name = model_name self._model_source = model_source self._device = device self._model = None @property def model_name(self) -> str: """str: The Sentence Transformer model name currently in use.""" return self._model_name @property def model_source(self) -> str: """str: The model source being used ("huggingface" or "modelscope").""" return self._model_source @property def device(self) -> str: """str: The device the model is running on.""" model = self._get_model() if model is not None: return str(model.device) return self._device or "cpu" def _get_model(self): """Load or retrieve the Sentence Transformer model. Returns: SentenceTransformer or SparseEncoder: The loaded model instance. Raises: ImportError: If required packages are not installed. ValueError: If model cannot be loaded. """ # Return cached model if exists if self._model is not None: return self._model # Load model try: sentence_transformers = require_module("sentence_transformers") if self._model_source == "modelscope": # Load from ModelScope require_module("modelscope") from modelscope.hub.snapshot_download import snapshot_download # Download model to cache model_dir = snapshot_download(self._model_name) # Load from local path self._model = sentence_transformers.SentenceTransformer( model_dir, device=self._device, trust_remote_code=True ) else: # Load from Hugging Face (default) self._model = sentence_transformers.SentenceTransformer( self._model_name, device=self._device, trust_remote_code=True ) return self._model except ImportError as e: if "modelscope" in str(e) and self._model_source == "modelscope": raise ImportError( "ModelScope support requires the 'modelscope' package. 
" "Please install it with: pip install modelscope" ) from e raise except Exception as e: raise ValueError( f"Failed to load Sentence Transformer model '{self._model_name}' " f"from {self._model_source}: {e!s}" ) from e def _is_sparse_model(self) -> bool: """Check if the loaded model is a sparse encoder (e.g., SPLADE). Returns: bool: True if model supports sparse encoding. """ model = self._get_model() # Check if model has sparse encoding methods return hasattr(model, "encode_query") or hasattr(model, "encode_document") ================================================ FILE: python/zvec/extension/sentence_transformer_rerank_function.py ================================================ # Copyright 2025-present the zvec project # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations from typing import Literal, Optional from ..model.doc import Doc from ..tool import require_module from .rerank_function import RerankFunction from .sentence_transformer_function import SentenceTransformerFunctionBase class DefaultLocalReRanker(SentenceTransformerFunctionBase, RerankFunction): """Re-ranker using Sentence Transformer cross-encoder models for semantic re-ranking. This re-ranker leverages pre-trained cross-encoder models to perform deep semantic re-ranking of search results. It runs locally without API calls, supports GPU acceleration, and works with models from Hugging Face or ModelScope. 
Cross-encoder models evaluate query-document pairs jointly, providing more accurate relevance scores than bi-encoder (embedding-based) similarity. Args: query (str): Query text for semantic re-ranking. **Required**. topn (int, optional): Maximum number of documents to return after re-ranking. Defaults to 10. rerank_field (Optional[str], optional): Document field name to use as re-ranking input text. **Required** (e.g., "content", "title", "body"). model_name (str, optional): Cross-encoder model identifier or local path. Defaults to ``"cross-encoder/ms-marco-MiniLM-L6-v2"`` (MS MARCO MiniLM). Common options: - ``"cross-encoder/ms-marco-MiniLM-L6-v2"``: Lightweight, fast (~80MB, recommended) - ``"cross-encoder/ms-marco-MiniLM-L12-v2"``: Better accuracy (~120MB) - ``"BAAI/bge-reranker-base"``: BGE Reranker Base (~280MB) - ``"BAAI/bge-reranker-large"``: BGE Reranker Large (highest quality, ~560MB) model_source (Literal["huggingface", "modelscope"], optional): Model source. Defaults to ``"huggingface"``. - ``"huggingface"``: Load from Hugging Face Hub - ``"modelscope"``: Load from ModelScope (recommended for users in China) device (Optional[str], optional): Device to run the model on. Options: ``"cpu"``, ``"cuda"``, ``"mps"`` (for Apple Silicon), or ``None`` for automatic detection. Defaults to ``None``. batch_size (int, optional): Batch size for processing query-document pairs. Larger values speed up processing but use more memory. Defaults to ``32``. Attributes: query (str): The query text used for re-ranking. topn (int): Maximum number of documents to return. rerank_field (Optional[str]): Field name used for re-ranking input. model_name (str): The cross-encoder model being used. model_source (str): The model source ("huggingface" or "modelscope"). device (str): The device the model is running on. Raises: ValueError: If ``query`` is empty/None, ``rerank_field`` is None, or model cannot be loaded. TypeError: If input types are invalid. 
RuntimeError: If model inference fails. Note: - Requires Python 3.10, 3.11, or 3.12 - Requires ``sentence-transformers`` package: ``pip install sentence-transformers`` - For ModelScope support, also requires: ``pip install modelscope`` - First run downloads the model (~80-560MB depending on model) from chosen source - No API keys or network required after initial download - Cross-encoders are slower than bi-encoders but more accurate - GPU acceleration provides significant speedup (5-10x) **MS MARCO MiniLM-L6-v2 Model (Default):** The default model ``cross-encoder/ms-marco-MiniLM-L6-v2`` is a lightweight and efficient cross-encoder trained on MS MARCO dataset. It provides: - Fast inference speed (suitable for real-time applications) - Small model size (~80MB, quick to download) - Good balance between speed and accuracy - Trained on 500K+ query-document pairs - Public availability without authentication **For users in China:** If you encounter Hugging Face access issues, use ModelScope instead: .. code-block:: python # Recommended for users in China reranker = SentenceTransformerReRanker( query="机器学习算法", rerank_field="content", model_source="modelscope" ) Alternatively, use Hugging Face mirror: .. code-block:: bash export HF_ENDPOINT=https://hf-mirror.com Examples: >>> # Basic usage with default MS MARCO MiniLM model >>> from zvec.extension import SentenceTransformerReRanker >>> >>> reranker = SentenceTransformerReRanker( ... query="machine learning algorithms", ... topn=5, ... rerank_field="content" ... ) >>> >>> # Use in collection.query() >>> results = collection.query( ... data={"vector_field": query_vector}, ... reranker=reranker, ... topk=20 ... ) >>> # Using ModelScope for users in China >>> reranker = SentenceTransformerReRanker( ... query="深度学习", ... topn=10, ... rerank_field="content", ... model_source="modelscope" ... ) >>> # Using larger model for better quality >>> reranker = SentenceTransformerReRanker( ... query="neural networks", ... topn=5, ... 
rerank_field="content", ... model_name="BAAI/bge-reranker-large", ... device="cuda", ... batch_size=64 ... ) >>> # Direct rerank call (for testing) >>> query_results = { ... "vector1": [ ... Doc(id="1", score=0.9, fields={"content": "Machine learning is..."}), ... Doc(id="2", score=0.8, fields={"content": "Deep learning is..."}), ... ] ... } >>> reranked = reranker.rerank(query_results) >>> for doc in reranked: ... print(f"ID: {doc.id}, Score: {doc.score:.4f}") ID: 2, Score: 0.9234 ID: 1, Score: 0.8567 See Also: - ``RerankFunction``: Abstract base class for re-rankers - ``QwenReRanker``: Re-ranker using Qwen API - ``RrfReRanker``: Multi-vector re-ranker using RRF - ``WeightedReRanker``: Multi-vector re-ranker using weighted scores References: - MS MARCO Cross-Encoder: https://huggingface.co/cross-encoder/ms-marco-MiniLM-L6-v2 - BGE Reranker: https://huggingface.co/BAAI/bge-reranker-base - Cross-Encoder vs Bi-Encoder: https://www.sbert.net/examples/applications/cross-encoder/README.html """ def __init__( self, query: Optional[str] = None, topn: int = 10, rerank_field: Optional[str] = None, model_name: str = "cross-encoder/ms-marco-MiniLM-L6-v2", model_source: Literal["huggingface", "modelscope"] = "huggingface", device: Optional[str] = None, batch_size: int = 32, ): """Initialize SentenceTransformerReRanker with query and configuration. Args: query (Optional[str]): Query text for semantic matching. Required. topn (int): Number of top results to return. rerank_field (Optional[str]): Document field for re-ranking input. model_name (str): Cross-encoder model identifier. model_source (Literal["huggingface", "modelscope"]): Model source. device (Optional[str]): Target device ("cpu", "cuda", "mps", or None). batch_size (int): Batch size for processing query-document pairs. Raises: ValueError: If query is empty or model cannot be loaded. 
""" # Initialize base class for model loading SentenceTransformerFunctionBase.__init__( self, model_name=model_name, model_source=model_source, device=device ) # Initialize rerank function RerankFunction.__init__(self, topn=topn, rerank_field=rerank_field) # Validate query if not query: raise ValueError("Query is required for DefaultLocalReRanker") self._query = query self._batch_size = batch_size # Load and validate cross-encoder model model = self._get_model() if not hasattr(model, "predict"): raise ValueError( f"Model '{model_name}' does not appear to be a cross-encoder model. " "Cross-encoder models should have a 'predict' method." ) self._model = model def _get_model(self): """Load or retrieve the CrossEncoder model. This overrides the base class method to load CrossEncoder instead of SentenceTransformer, as reranking requires cross-encoder models. Returns: CrossEncoder: The loaded cross-encoder model instance. Raises: ImportError: If required packages are not installed. ValueError: If model cannot be loaded. """ # Return cached model if exists if self._model is not None: return self._model # Load cross-encoder model try: sentence_transformers = require_module("sentence_transformers") if self._model_source == "modelscope": # Load from ModelScope require_module("modelscope") from modelscope.hub.snapshot_download import snapshot_download # Download model to cache model_dir = snapshot_download(self._model_name) # Load CrossEncoder from local path model = sentence_transformers.CrossEncoder( model_dir, device=self._device ) else: # Load CrossEncoder from Hugging Face (default) model = sentence_transformers.CrossEncoder( self._model_name, device=self._device ) return model except ImportError as e: if "modelscope" in str(e) and self._model_source == "modelscope": raise ImportError( "ModelScope support requires the 'modelscope' package. 
" "Please install it with: pip install modelscope" ) from e raise except Exception as e: raise ValueError( f"Failed to load CrossEncoder model '{self._model_name}' " f"from {self._model_source}: {e!s}" ) from e @property def query(self) -> str: """str: Query text used for semantic re-ranking.""" return self._query @property def batch_size(self) -> int: """int: Batch size for processing query-document pairs.""" return self._batch_size def rerank(self, query_results: dict[str, list[Doc]]) -> list[Doc]: """Re-rank documents using Sentence Transformer cross-encoder model. Evaluates each query-document pair using the cross-encoder model to compute relevance scores. Documents are then sorted by these scores and the top-k results are returned. Args: query_results (dict[str, list[Doc]]): Mapping from vector field names to lists of retrieved documents. Documents from all fields are deduplicated and re-ranked together. Returns: list[Doc]: Re-ranked documents (up to ``topn``) with updated ``score`` fields containing relevance scores from the cross-encoder model. Raises: ValueError: If no valid documents are found or model inference fails. Note: - Duplicate documents (same ID) across fields are processed once - Documents with empty/missing ``rerank_field`` content are skipped - Returned scores are logits from the cross-encoder model - Higher scores indicate higher relevance - Processing time is O(n) where n is the number of documents Examples: >>> reranker = SentenceTransformerReRanker( ... query="machine learning", ... topn=3, ... rerank_field="content" ... ) >>> query_results = { ... "vector1": [ ... Doc(id="1", score=0.9, fields={"content": "ML basics"}), ... Doc(id="2", score=0.8, fields={"content": "DL tutorial"}), ... ] ... 
} >>> reranked = reranker.rerank(query_results) >>> len(reranked) <= 3 True """ if not query_results: return [] # Collect and deduplicate documents id_to_doc: dict[str, Doc] = {} doc_ids: list[str] = [] contents: list[str] = [] for _, query_result in query_results.items(): for doc in query_result: doc_id = doc.id if doc_id in id_to_doc: continue # Extract text content from specified field field_value = doc.field(self.rerank_field) rank_content = str(field_value).strip() if field_value else "" if not rank_content: continue id_to_doc[doc_id] = doc doc_ids.append(doc_id) contents.append(rank_content) if not contents: raise ValueError("No documents to rerank") try: # Use standard cross-encoder predict method pairs = [[self.query, content] for content in contents] scores = self._model.predict( pairs, batch_size=self.batch_size, show_progress_bar=False, convert_to_numpy=True, ) # Convert to float list if needed if hasattr(scores, "tolist"): scores = scores.tolist() else: scores = [float(s) for s in scores] except Exception as e: raise RuntimeError(f"Failed to compute rerank scores: {e!s}") from e # Create scored documents scored_docs = [ (doc_ids[i], id_to_doc[doc_ids[i]], scores[i]) for i in range(len(doc_ids)) ] # Sort by score (descending) and take top-k scored_docs.sort(key=lambda x: x[2], reverse=True) top_scored_docs = scored_docs[: self.topn] # Build result list with updated scores results: list[Doc] = [] for _, doc, score in top_scored_docs: new_doc = doc._replace(score=score) results.append(new_doc) return results ================================================ FILE: python/zvec/model/__init__.py ================================================ # Copyright 2025-present the zvec project # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations from .collection import Collection from .doc import Doc from .param.vector_query import VectorQuery from .schema.collection_schema import CollectionSchema from .schema.field_schema import FieldSchema __all__ = ["Collection", "CollectionSchema", "Doc", "FieldSchema", "VectorQuery"] ================================================ FILE: python/zvec/model/collection.py ================================================ # Copyright 2025-present the zvec project # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from __future__ import annotations from typing import Optional, Union, overload from _zvec import _Collection from ..executor import QueryContext, QueryExecutorFactory from ..extension import ReRanker from ..typing import Status from .convert import convert_to_cpp_doc, convert_to_py_doc from .doc import Doc from .param import ( AddColumnOption, AlterColumnOption, CollectionOption, FlatIndexParam, HnswIndexParam, IndexOption, InvertIndexParam, IVFIndexParam, OptimizeOption, ) from .param.vector_query import VectorQuery from .schema import CollectionSchema, CollectionStats, FieldSchema __all__ = ["Collection"] _VECTOR_INDEX_TYPES = (HnswIndexParam, IVFIndexParam, FlatIndexParam) class Collection: """Represents an opened collection in Zvec. A `Collection` provides methods for data definition (DDL), data manipulation (DML), and querying (DQL). It is obtained via `create_and_open()` or `open()`. This class is not meant to be instantiated directly; use factory functions instead. """ def __init__(self, obj: _Collection): self._obj = obj self._schema = None self._querier = None @classmethod def _from_core(cls, core_collection: _Collection) -> Collection: if not core_collection: raise ValueError("Collection is None") inst = cls.__new__(cls) inst._obj = core_collection schema = CollectionSchema._from_core(core_collection.Schema()) inst._schema = schema inst._querier = QueryExecutorFactory.create(schema) return inst @property def path(self) -> str: """str: The filesystem path of the collection.""" return self._obj.Path() @property def option(self) -> CollectionOption: """CollectionOption: The options used to open the collection.""" return self._obj.Options() @property def schema(self) -> CollectionSchema: """CollectionSchema: The schema defining the structure of the collection.""" return self._schema @property def stats(self) -> CollectionStats: """CollectionStats: Runtime statistics about the collection (e.g., doc count, size).""" return self._obj.Stats() # ========== 
Collection DDL Methods ========== def destroy(self) -> None: """Permanently delete the collection from disk. Warning: This operation is irreversible. All data will be lost. """ self._obj.Destroy() def flush(self) -> None: """Force all pending writes to disk. Ensures durability of recent inserts/updates. """ self._obj.Flush() # ========== Index DDL Methods ========== def create_index( self, field_name: str, index_param: Union[ HnswIndexParam, IVFIndexParam, FlatIndexParam, InvertIndexParam ], option: IndexOption = IndexOption(), ) -> None: """Create an index on a field. Vector index types (HNSW, IVF, FLAT) can only be applied to vector fields. Inverted index (`InvertIndexParam`) is for scalar fields. Args: field_name (str): Name of the field to index. index_param (Union[HnswIndexParam, IVFIndexParam, FlatIndexParam, InvertIndexParam]): Index configuration. option (Optional[IndexOption], optional): Index creation options. Defaults to ``IndexOption()``. Raises: ValueError: If a vector index is applied to a non-vector field. """ if index_param in _VECTOR_INDEX_TYPES and not self.schema.vector(field_name): supported_types = ", ".join(cls.__name__ for cls in _VECTOR_INDEX_TYPES) raise ValueError( f"Cannot apply vector index to non-vector field '{field_name}'. " f"The field must be of vector type to use index types like {supported_types}." ) self._obj.CreateIndex(field_name, index_param, option) self._schema = CollectionSchema._from_core(self._obj.Schema()) def drop_index(self, field_name: str) -> None: """Remove the index from a field. Args: field_name (str): Name of the indexed field. """ self._obj.DropIndex(field_name) self._schema = CollectionSchema._from_core(self._obj.Schema()) def optimize(self, option: OptimizeOption = OptimizeOption()) -> None: """Optimize the collection (e.g., merge segments, rebuild index). Args: option (Optional[OptimizeOption], optional): Optimization options. Defaults to ``OptimizeOption()``. 
""" self._obj.Optimize(option) # ========== COLUMN DDL Methods ========== def add_column( self, field_schema: FieldSchema, expression: str = "", option: AddColumnOption = AddColumnOption(), ) -> None: """Add a new column to the collection. The column is populated using the provided expression (e.g., SQL-like formula). Args: field_schema (FieldSchema): Schema definition for the new column. expression (str): Expression to compute values for existing documents. option (Optional[AddColumnOption], optional): Options for the operation. Defaults to ``AddColumnOption()``. """ self._obj.AddColumn(field_schema._get_object(), expression, option) self._schema = CollectionSchema._from_core(self._obj.Schema()) def drop_column(self, field_name: str) -> None: """Remove a column from the collection. Args: field_name (str): Name of the column to drop. """ self._obj.DropColumn(field_name) self._schema = CollectionSchema._from_core(self._obj.Schema()) def alter_column( self, old_name: str, new_name: Optional[str] = None, field_schema: Optional[FieldSchema] = None, option: AlterColumnOption = AlterColumnOption(), ) -> None: """Rename a column, update its schema. This method supports three atomic operations: 1. Rename only (when `field_schema` is None). 2. Modify schema only (when `new_name` is None or empty string). Args: old_name (str): The current name of the column to be altered. new_name (Optional[str]): The new name for the column. - If provided and non-empty, the column will be renamed. - If `None` or empty string, no rename occurs. field_schema (Optional[FieldSchema]): The new schema definition. - If provided, the column's type, dimension, or other properties will be updated. - If `None`, only renaming (if requested) is performed. option (AlterColumnOption, optional): Options controlling the alteration behavior. Defaults to ``AlterColumnOption()``. **Limitation**: This operation **only supports scalar numeric columns**. 
such as: - `DOUBLE`, `FLOAT`, - `INT32`, `INT64`, `UINT32`, `UINT64` Note: - Schema modification may trigger data migration or index rebuild. Examples: >>> # Rename column only >>> results = collection.alter_column(old_name="id", new_name="doc_id") >>> # Modify schema only >>> new_schema = FieldSchema(name="doc_id", dtype=DataType.INT64) >>> collection.alter_column("id", field_schema=new_schema) """ self._obj.AlterColumn( old_name, new_name or "", field_schema._get_object() if field_schema else None, option, ) self._schema = CollectionSchema._from_core(self._obj.Schema()) # ========== Collection DDL Methods ========== @overload def insert(self, docs: Doc) -> Status: pass @overload def insert(self, docs: list[Doc]) -> list[Status]: pass def insert(self, docs: Union[Doc, list[Doc]]) -> Union[Status, list[Status]]: """Insert new documents into the collection. Documents must have unique IDs and conform to the schema. Args: docs (Union[Doc, list[Doc]]): One or more documents to insert. Returns: Union[Status, list[Status]]: If a single Doc was given, returns its Status; if a list was given, returns a list of Status objects. """ is_single = isinstance(docs, Doc) doc_list = [docs] if is_single else docs results = self._obj.Insert( [convert_to_cpp_doc(doc, self.schema) for doc in doc_list] ) return results[0] if is_single else results @overload def upsert(self, docs: Doc) -> Status: pass @overload def upsert(self, docs: list[Doc]) -> list[Status]: pass def upsert(self, docs: Union[Doc, list[Doc]]) -> Union[Status, list[Status]]: """Insert new documents or update existing ones by ID. Args: docs (Union[Doc, list[Doc]]): Documents to upsert. Returns: Union[Status, list[Status]]: If a single Doc was given, returns its Status; if a list was given, returns a list of Status objects. 
""" is_single = isinstance(docs, Doc) doc_list = [docs] if is_single else docs results = self._obj.Upsert( [convert_to_cpp_doc(doc, self.schema) for doc in doc_list] ) return results[0] if is_single else results @overload def update(self, docs: Doc) -> Status: pass @overload def update(self, docs: list[Doc]) -> list[Status]: pass def update(self, docs: Union[Doc, list[Doc]]) -> Union[Status, list[Status]]: """Update existing documents by ID. Only specified fields are updated; others remain unchanged. Args: docs (Union[Doc, list[Doc]]): Documents containing updated fields. Returns: Union[Status, list[Status]]: If a single Doc was given, returns its Status; if a list was given, returns a list of Status objects. """ is_single = isinstance(docs, Doc) doc_list = [docs] if is_single else docs results = self._obj.Update( [convert_to_cpp_doc(doc, self.schema) for doc in doc_list] ) return results[0] if is_single else results @overload def delete(self, ids: str) -> Status: pass @overload def delete(self, ids: list[str]) -> list[Status]: pass def delete(self, ids: Union[str, list[str]]) -> Union[Status, list[Status]]: """Delete documents by ID. Args: ids (Union[str, list[str]]): One or more document IDs to delete. Returns: Union[Status, list[Status]]: If a single id was given, returns its Status; if a list was given, returns a list of Status objects. """ is_single = isinstance(ids, str) id_list = [ids] if isinstance(ids, str) else ids results = self._obj.Delete(id_list) return results[0] if is_single else results def delete_by_filter(self, filter: str) -> None: """Delete documents matching a filter expression. Args: filter (str): Boolean expression (e.g., ``"age > 30"``). """ self._obj.DeleteByFilter(filter) # ========== Collection DQL-fetch Methods ========== def fetch(self, ids: Union[str, list[str]]) -> dict[str, Doc]: """Retrieve documents by ID. Args: ids (Union[str, list[str]]): Document IDs to fetch. Returns: dict[str, Doc]: Mapping from ID to document. 
Missing IDs are omitted. """ ids = [ids] if isinstance(ids, str) else ids docs = self._obj.Fetch(ids) return { doc_id: py_doc for doc_id, core_doc in docs.items() if (py_doc := convert_to_py_doc(core_doc, self.schema)) is not None } # ========== Collection DQL-Query Methods ========== def query( self, vectors: Optional[Union[VectorQuery, list[VectorQuery]]] = None, *, topk: int = 10, filter: Optional[str] = None, include_vector: bool = False, output_fields: Optional[list[str]] = None, reranker: Optional[ReRanker] = None, ) -> list[Doc]: """Perform vector similarity search with optional filtering and re-ranking. At least one `VectorQuery` must be provided. Args: vectors (Optional[Union[VectorQuery, list[VectorQuery]]], optional): One or more vector queries. Defaults to None. topk (int, optional): Number of nearest neighbors to return. Defaults to 10. filter (Optional[str], optional): Boolean expression to pre-filter candidates. Defaults to None. include_vector (bool, optional): Whether to include vector data in results. Defaults to False. output_fields (Optional[list[str]], optional): Scalar fields to include. If None, all fields are returned. Defaults to None. reranker (Optional[ReRanker], optional): Re-ranker to refine results. Defaults to None. Returns: list[Doc]: Top-k matching documents, sorted by relevance score. Examples: >>> from zvec import VectorQuery >>> results = collection.query( ... vectors=VectorQuery("embedding", vector=[0.1, 0.2]), ... topk=5, ... filter="category == 'tech'", ... output_fields=["title", "url"] ... 
) """ ctx = QueryContext( topk=topk, filter=filter, queries=[vectors] if isinstance(vectors, VectorQuery) else vectors, include_vector=include_vector, output_fields=output_fields, reranker=reranker, ) return self._querier.execute(ctx, self._obj) ================================================ FILE: python/zvec/model/convert.py ================================================ # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations from _zvec import _Doc from .doc import Doc from .schema import CollectionSchema def convert_to_cpp_doc(doc: Doc, collection_schema: CollectionSchema) -> _Doc: if not doc or not collection_schema: return None _doc = _Doc() # set pk _doc.set_pk(doc.id) # set scalar fields for k, v in doc.fields.items(): field_schema = collection_schema.field(k) if not field_schema: raise ValueError( f"schema validate failed: {k} not found in collection schema" ) _doc.set_any(k, field_schema._get_object(), v) # set vector fields for k, v in doc.vectors.items(): vector_schema = collection_schema.vector(k) if not vector_schema: raise ValueError( f"schema validate failed: {k} not found in collection schema" ) _doc.set_any(k, vector_schema._get_object(), v) return _doc def convert_to_py_doc(doc: _Doc, collection_schema: CollectionSchema) -> Doc: if not doc or not collection_schema: return None data_tuple = doc.get_all(collection_schema._get_object()) return Doc._from_tuple(data_tuple) ================================================ FILE: 
python/zvec/model/doc.py ================================================ # Copyright 2025-present the zvec project # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations import json from typing import Any, Optional from ..common import VectorType __all__ = [ "Doc", ] class Doc: """Represents a retrieved document with optional metadata, fields, and vectors. This immutable data class encapsulates the result of a search or retrieval operation. It includes the document ID, relevance score (if applicable), scalar fields, and vector embeddings. During initialization, any `numpy.ndarray` in `vectors` is automatically converted to a plain Python list for JSON serialization and immutability. Attributes: id (str): Unique identifier of the document. score (Optional[float], optional): Relevance score from search. Defaults to None. vectors (Optional[dict[str, VectorType]], optional): Named vector embeddings associated with the document. Values are converted to lists if originally `np.ndarray`. Defaults to None. fields (Optional[dict[str, Any]], optional): Scalar metadata fields (e.g., title, timestamp). Defaults to None. Examples: >>> import numpy as np >>> import zvec >>> doc = zvec.Doc( ... id="doc1", ... score=0.95, ... vectors={"emb": np.array([0.1, 0.2, 0.3])}, ... fields={"title": "Hello World"} ... 
) >>> print(doc.vector("emb")) [0.1, 0.2, 0.3] >>> print(doc.has_field("title")) True """ __slots__ = ("id", "score", "vectors", "fields") def __init__( self, id: str, score: Optional[float] = None, vectors: Optional[dict[str, VectorType]] = None, fields: Optional[dict[str, Any]] = None, ): self.id = id self.score = score self.vectors = vectors or {} self.fields = fields or {} def has_field(self, name: str) -> bool: """Check if the document contains a scalar field with the given name. Args: name (str): Name of the field to check. Returns: bool: True if the field exists, False otherwise. """ return name in self.fields def has_vector(self, name: str) -> bool: """Check if the document contains a vector with the given name. Args: name (str): Name of the vector to check. Returns: bool: True if the vector exists, False otherwise. """ return name in self.vectors def vector(self, name: str): """Get a vector by name. Args: name (str): Name of the vector. Returns: Any: The vector (as a list) if it exists, otherwise None. """ return self.vectors and self.vectors.get(name) def field(self, name: str): """Get a scalar field by name. Args: name (str): Name of the field. Returns: Any: The field value if it exists, otherwise None. """ return self.fields and self.fields.get(name) def vector_names(self) -> list[str]: """Get the list of all vector names in this document. Returns: list[str]: A list of vector field names. Empty if no vectors. """ return [] if not self.vectors else list(self.vectors.keys()) def field_names(self) -> list[str]: """Get the list of all scalar field names in this document. Returns: list[str]: A list of field names. Empty if no fields. 
""" return [] if not self.fields else list(self.fields.keys()) def __repr__(self) -> str: try: schema = { "id": self.id, "score": self.score, "fields": self.fields, "vectors": self.vectors, } return json.dumps(schema, indent=2, ensure_ascii=False) except Exception as e: return f"" def _replace(self, **changes): new_tuple = ( changes.get("id", self.id), changes.get("score", self.score), changes.get("fields", self.fields.copy() if self.fields else None), changes.get("vectors", self.vectors.copy() if self.vectors else None), ) return type(self)._from_tuple(new_tuple) @classmethod def _from_tuple( cls, data_tuple: tuple[str, float, dict[str, Any], dict[str, VectorType]] ): obj = object.__new__(cls) obj.id = data_tuple[0] obj.score = data_tuple[1] obj.fields = data_tuple[2] or {} vectors = data_tuple[3] if vectors is not None: obj.vectors = { name: (vec.tolist() if hasattr(vec, "tolist") else vec) for name, vec in vectors.items() } else: obj.vectors = {} return obj ================================================ FILE: python/zvec/model/param/__init__.py ================================================ # Copyright 2025-present the zvec project # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from __future__ import annotations from _zvec.param import ( AddColumnOption, AlterColumnOption, CollectionOption, FlatIndexParam, HnswIndexParam, HnswQueryParam, HnswRabitqIndexParam, HnswRabitqQueryParam, IndexOption, InvertIndexParam, IVFIndexParam, IVFQueryParam, OptimizeOption, ) __all__ = [ "AddColumnOption", "AlterColumnOption", "CollectionOption", "FlatIndexParam", "HnswIndexParam", "HnswQueryParam", "HnswRabitqIndexParam", "HnswRabitqQueryParam", "IVFIndexParam", "IVFQueryParam", "IndexOption", "InvertIndexParam", "OptimizeOption", ] ================================================ FILE: python/zvec/model/param/__init__.pyi ================================================ """ This module contains the params of Zvec """ from __future__ import annotations import collections import typing import _zvec.typing __all__: list[str] = [ "AddColumnOption", "AlterColumnOption", "CollectionOption", "FlatIndexParam", "HnswIndexParam", "HnswQueryParam", "HnswRabitqIndexParam", "HnswRabitqQueryParam", "IVFIndexParam", "IVFQueryParam", "IndexOption", "IndexParam", "InvertIndexParam", "OptimizeOption", "QueryParam", "SegmentOption", "VectorIndexParam", ] class AddColumnOption: """ Options for adding a new column to a collection. Attributes: concurrency (int): Number of threads to use when backfilling data for the new column. If 0, auto-detect is used. Default is 0. Examples: >>> opt = AddColumnOption(concurrency=1) >>> print(opt.concurrency) 1 """ def __getstate__(self) -> tuple: ... def __init__(self, concurrency: typing.SupportsInt = 0) -> None: """ Constructs an AddColumnOption instance. Args: concurrency (int, optional): Number of threads for data backfill. 0 means auto-detect. Defaults to 0. """ def __setstate__(self, arg0: tuple) -> None: ... @property def concurrency(self) -> int: """ int: Number of threads used when adding a column (0 = auto). """ class AlterColumnOption: """ Options for altering an existing column (e.g., changing index settings). 
Attributes: concurrency (int): Number of threads to use during the alteration process. If 0, the system will choose an optimal value automatically. Default is 0. Examples: >>> opt = AlterColumnOption(concurrency=1) >>> print(opt.concurrency) 1 """ def __getstate__(self) -> tuple: ... def __init__(self, concurrency: typing.SupportsInt = 0) -> None: """ Constructs an AlterColumnOption instance. Args: concurrency (int, optional): Number of threads for column alteration. 0 means auto-detect. Defaults to 0. """ def __setstate__(self, arg0: tuple) -> None: ... @property def concurrency(self) -> int: """ int: Number of threads used when altering a column (0 = auto). """ class CollectionOption: """ Options for opening or creating a collection. Attributes: read_only (bool): Whether the collection is opened in read-only mode. Default is False. enable_mmap (bool): Whether to use memory-mapped I/O for data files. Default is True. Examples: >>> opt = CollectionOption(read_only=True, enable_mmap=False) >>> print(opt.read_only) True """ def __getstate__(self) -> tuple: ... def __init__(self, read_only: bool = False, enable_mmap: bool = True) -> None: """ Constructs a CollectionOption instance. Args: read_only (bool, optional): Open collection in read-only mode. Defaults to False. enable_mmap (bool, optional): Enable memory-mapped I/O. Defaults to True. """ def __repr__(self) -> str: ... def __setstate__(self, arg0: tuple) -> None: ... @property def enable_mmap(self) -> bool: ... @property def read_only(self) -> bool: ... class FlatIndexParam(VectorIndexParam): """ Parameters for configuring a flat (brute-force) index. A flat index performs exact nearest neighbor search by comparing the query vector against all vectors in the collection. It is simple, accurate, and suitable for small to medium datasets or as a baseline. Attributes: metric_type (MetricType): Distance metric used for similarity computation. Default is ``MetricType.IP`` (inner product). 
quantize_type (QuantizeType): Optional quantization type for vector compression (e.g., FP16, INT8). Use ``QuantizeType.UNDEFINED`` to disable quantization. Default is ``QuantizeType.UNDEFINED``. Examples: >>> from zvec.typing import MetricType, QuantizeType >>> params = FlatIndexParam( ... metric_type=MetricType.L2, ... quantize_type=QuantizeType.FP16 ... ) >>> print(params) {'metric_type': 'L2', 'quantize_type': 'FP16'} """ def __getstate__(self) -> tuple: ... def __init__( self, metric_type: _zvec.typing.MetricType = ..., quantize_type: _zvec.typing.QuantizeType = ..., ) -> None: """ Constructs a FlatIndexParam instance. Args: metric_type (MetricType, optional): Distance metric. Defaults to MetricType.IP. quantize_type (QuantizeType, optional): Vector quantization type. Defaults to QuantizeType.UNDEFINED (no quantization). """ def __repr__(self) -> str: ... def __setstate__(self, arg0: tuple) -> None: ... def to_dict(self) -> dict: """ Convert to dictionary with all fields """ class HnswIndexParam(VectorIndexParam): """ Parameters for configuring an HNSW (Hierarchical Navigable Small World) index. HNSW is a graph-based approximate nearest neighbor search index. This class encapsulates its construction hyperparameters. Attributes: metric_type (MetricType): Distance metric used for similarity computation. Default is ``MetricType.IP`` (inner product). m (int): Number of bi-directional links created for every new element during construction. Higher values improve accuracy but increase memory usage and construction time. Default is 50. ef_construction (int): Size of the dynamic candidate list for nearest neighbors during index construction. Larger values yield better graph quality at the cost of slower build time. Default is 500. quantize_type (QuantizeType): Optional quantization type for vector compression (e.g., FP16, INT8). Default is `QuantizeType.UNDEFINED` to disable quantization. 
Examples: >>> from zvec.typing import MetricType, QuantizeType >>> params = HnswIndexParam( ... metric_type=MetricType.COSINE, ... m=16, ... ef_construction=200, ... quantize_type=QuantizeType.INT8 ... ) >>> print(params) {'metric_type': 'IP', 'm': 16, 'ef_construction': 200, 'quantize_type': 'INT8'} """ def __getstate__(self) -> tuple: ... def __init__( self, metric_type: _zvec.typing.MetricType = ..., m: typing.SupportsInt = 50, ef_construction: typing.SupportsInt = 500, quantize_type: _zvec.typing.QuantizeType = ..., ) -> None: ... def __repr__(self) -> str: ... def __setstate__(self, arg0: tuple) -> None: ... def to_dict(self) -> dict: """ Convert to dictionary with all fields """ @property def ef_construction(self) -> int: """ int: Candidate list size during index construction. """ @property def m(self) -> int: """ int: Maximum number of neighbors per node in upper layers. """ class HnswQueryParam(QueryParam): """ Query parameters for HNSW (Hierarchical Navigable Small World) index. Controls the trade-off between search speed and accuracy via the `ef` parameter. Attributes: type (IndexType): Always ``IndexType.HNSW``. ef (int): Size of the dynamic candidate list during search. Larger values improve recall but slow down search. Default is 300. radius (float): Search radius for range queries. Default is 0.0. is_linear (bool): Force linear search. Default is False. is_using_refiner (bool, optional): Whether to use refiner for the query. Default is False. Examples: >>> params = HnswQueryParam(ef=300) >>> print(params.ef) 300 >>> print(params.to_dict() if hasattr(params, 'to_dict') else params) {"type":"HNSW", "ef":300} """ def __getstate__(self) -> tuple: ... def __init__( self, ef: typing.SupportsInt = 300, radius: typing.SupportsFloat = 0.0, is_linear: bool = False, is_using_refiner: bool = False, ) -> None: """ Constructs an HnswQueryParam instance. Args: ef (int, optional): Search-time candidate list size. Higher values improve accuracy. Defaults to 300. 
radius (float, optional): Search radius for range queries. Default is 0.0. is_linear (bool, optional): Force linear search. Default is False. is_using_refiner (bool, optional): Whether to use refiner for the query. Default is False. """ def __repr__(self) -> str: ... def __setstate__(self, arg0: tuple) -> None: ... @property def ef(self) -> int: """ int: Size of the dynamic candidate list during HNSW search. """ class HnswRabitqIndexParam(VectorIndexParam): """ Parameters for configuring an HNSW (Hierarchical Navigable Small World) index with RabitQ quantization. HNSW is a graph-based approximate nearest neighbor search index. RabitQ is a quantization method that provides high compression with minimal accuracy loss. Attributes: metric_type (MetricType): Distance metric used for similarity computation. Default is ``MetricType.IP`` (inner product). total_bits (int): Total bits for RabitQ quantization. Default is 7. num_clusters (int): Number of clusters for RabitQ. Default is 16. m (int): Number of bi-directional links created for every new element during construction. Higher values improve accuracy but increase memory usage and construction time. Default is 50. ef_construction (int): Size of the dynamic candidate list for nearest neighbors during index construction. Larger values yield better graph quality at the cost of slower build time. Default is 500. sample_count (int): Sample count for RabitQ training. Default is 0. Examples: >>> from zvec.typing import MetricType >>> params = HnswRabitqIndexParam( ... metric_type=MetricType.COSINE, ... total_bits=8, ... num_clusters=256, ... m=16, ... ef_construction=200, ... sample_count=10000 ... ) >>> print(params) {'metric_type': 'COSINE', 'total_bits': 8, 'num_clusters': 256, 'm': 16, 'ef_construction': 200, 'sample_count': 10000} """ def __getstate__(self) -> tuple: ... 
def __init__( self, metric_type: _zvec.typing.MetricType = ..., total_bits: typing.SupportsInt = 7, num_clusters: typing.SupportsInt = 16, m: typing.SupportsInt = 50, ef_construction: typing.SupportsInt = 500, sample_count: typing.SupportsInt = 0, ) -> None: ... def __repr__(self) -> str: ... def __setstate__(self, arg0: tuple) -> None: ... def to_dict(self) -> dict: """ Convert to dictionary with all fields """ @property def ef_construction(self) -> int: """ int: Candidate list size during index construction. """ @property def m(self) -> int: """ int: Maximum number of neighbors per node. """ @property def total_bits(self) -> int: """ int: Total bits for RabitQ quantization. """ @property def num_clusters(self) -> int: """ int: Number of clusters for RabitQ. """ @property def sample_count(self) -> int: """ int: Sample count for RabitQ training. """ class HnswRabitqQueryParam(QueryParam): """ Query parameters for HNSW index with RabitQ quantization. Controls the trade-off between search speed and accuracy via the `ef` parameter. Attributes: type (IndexType): Always ``IndexType.HNSW_RABITQ``. ef (int): Size of the dynamic candidate list during search. Larger values improve recall but slow down search. Default is 300. radius (float): Search radius for range queries. Default is 0.0. is_linear (bool): Force linear search. Default is False. is_using_refiner (bool, optional): Whether to use refiner for the query. Default is False. Examples: >>> params = HnswRabitqQueryParam(ef=300) >>> print(params.ef) 300 """ def __getstate__(self) -> tuple: ... def __init__( self, ef: typing.SupportsInt = 300, radius: typing.SupportsFloat = 0.0, is_linear: bool = False, is_using_refiner: bool = False, ) -> None: """ Constructs an HnswRabitqQueryParam instance. Args: ef (int, optional): Search-time candidate list size. Higher values improve accuracy. Defaults to 300. radius (float, optional): Search radius for range queries. Default is 0.0. 
is_linear (bool, optional): Force linear search. Default is False. is_using_refiner (bool, optional): Whether to use refiner for the query. Default is False. """ def __repr__(self) -> str: ... def __setstate__(self, arg0: tuple) -> None: ... @property def ef(self) -> int: """ int: Size of the dynamic candidate list during HNSW search. """ class IVFIndexParam(VectorIndexParam): """ Parameters for configuring an IVF (Inverted File Index) index. IVF partitions the vector space into clusters (inverted lists). At query time, only a subset of clusters is searched, providing a trade-off between speed and accuracy. Attributes: metric_type (MetricType): Distance metric used for similarity computation. Default is ``MetricType.IP`` (inner product). n_list (int): Number of clusters (inverted lists) to partition the dataset into. If set to 0, the system will auto-select a reasonable value based on data size. Default is 0 (auto). n_iters (int): Number of iterations for k-means clustering during index training. Higher values yield more stable centroids. Default is 10. use_soar (bool): Whether to enable SOAR (Scalable Optimized Adaptive Routing) for improved IVF search performance. Default is False. quantize_type (QuantizeType): Optional quantization type for vector compression (e.g., FP16, INT8). Default is ``QuantizeType.UNDEFINED``. Examples: >>> from zvec.typing import MetricType, QuantizeType >>> params = IVFIndexParam( ... metric_type=MetricType.COSINE, ... n_list=100, ... n_iters=15, ... use_soar=True, ... quantize_type=QuantizeType.INT8 ... ) >>> print(params.n_list) 100 """ def __getstate__(self) -> tuple: ... def __init__( self, metric_type: _zvec.typing.MetricType = ..., n_list: typing.SupportsInt = 0, n_iters: typing.SupportsInt = 10, use_soar: bool = False, quantize_type: _zvec.typing.QuantizeType = ..., ) -> None: """ Constructs an IVFIndexParam instance. Args: metric_type (MetricType, optional): Distance metric. Defaults to MetricType.IP. 
n_list (int, optional): Number of inverted lists (clusters). Set to 0 for auto. Defaults to 0. n_iters (int, optional): Number of k-means iterations during training. Defaults to 10. use_soar (bool, optional): Enable SOAR optimization. Defaults to False. quantize_type (QuantizeType, optional): Vector quantization type. Defaults to QuantizeType.UNDEFINED. """ def __repr__(self) -> str: ... def __setstate__(self, arg0: tuple) -> None: ... def to_dict(self) -> dict: """ Convert to dictionary with all fields """ @property def n_iters(self) -> int: """ int: Number of k-means iterations during training. """ @property def n_list(self) -> int: """ int: Number of inverted lists (0 = auto). """ @property def use_soar(self) -> bool: """ bool: Whether SOAR optimization is enabled. """ class IVFQueryParam(QueryParam): """ Query parameters for IVF (Inverted File Index) index. Controls how many inverted lists (`nprobe`) to visit during search. Attributes: type (IndexType): Always ``IndexType.IVF``. nprobe (int): Number of closest clusters (inverted lists) to search. Higher values improve recall but increase latency. Default is 10. radius (float): Search radius for range queries. Default is 0.0. is_linear (bool): Force linear search. Default is False. Examples: >>> params = IVFQueryParam(nprobe=20) >>> print(params.nprobe) 20 """ def __getstate__(self) -> tuple: ... def __init__(self, nprobe: typing.SupportsInt = 10) -> None: """ Constructs an IVFQueryParam instance. Args: nprobe (int, optional): Number of inverted lists to probe during search. Higher values improve accuracy. Defaults to 10. """ def __repr__(self) -> str: ... def __setstate__(self, arg0: tuple) -> None: ... @property def nprobe(self) -> int: """ int: Number of inverted lists to search during IVF query. """ class IndexOption: """ Options for creating an index. Attributes: concurrency (int): Number of threads to use during index creation. If 0, the system will choose an optimal value automatically. Default is 0. 
Examples: >>> opt = IndexOption(concurrency=4) >>> print(opt.concurrency) 4 """ def __getstate__(self) -> tuple: ... def __init__(self, concurrency: typing.SupportsInt = 0) -> None: """ Constructs an IndexOption instance. Args: concurrency (int, optional): Number of concurrent threads. 0 means auto-detect. Defaults to 0. """ def __setstate__(self, arg0: tuple) -> None: ... @property def concurrency(self) -> int: """ int: Number of threads used for index creation (0 = auto). """ class IndexParam: """ Base class for all index parameter configurations. This abstract base class defines the common interface for index types. It should not be instantiated directly; use derived classes instead. Attributes: type (IndexType): The type of the index (e.g., HNSW, FLAT, INVERT). """ __hash__: typing.ClassVar[None] = None def __eq__(self, arg0: typing.Any) -> bool: ... def __getstate__(self) -> tuple: ... def __setstate__(self, arg0: tuple) -> None: ... def clone(self) -> IndexParam: ... def to_dict(self) -> dict: """ Convert to dictionary with all fields """ @property def type(self) -> _zvec.typing.IndexType: """ IndexType: The type of the index. """ class InvertIndexParam(IndexParam): """ Parameters for configuring an invert index. This class controls whether range query optimization is enabled for invert index structures. Attributes: type (IndexType): Always `IndexType.INVERTED`. enable_range_optimization (bool): Whether range optimization is enabled. enable_extended_wildcard (bool): Whether extended wildcard (suffix and infix) search is enabled. Examples: >>> params = InvertIndexParam(enable_range_optimization=True, enable_extended_wildcard=False) >>> print(params.enable_range_optimization) True >>> print(params.enable_extended_wildcard) False >>> config = params.to_dict() >>> print(config) {'enable_range_optimization': True, 'enable_extended_wildcard': False} """ def __getstate__(self) -> tuple: ... 
def __init__( self, enable_range_optimization: bool = False, enable_extended_wildcard: bool = False, ) -> None: """ Constructs an InvertIndexParam instance. Args: enable_range_optimization (bool, optional): If True, enables range query optimization for the invert index. Defaults to False. enable_extended_wildcard (bool, optional): If True, enables extended wildcard search including suffix and infix patterns. Defaults to False. """ def __repr__(self) -> str: ... def __setstate__(self, arg0: tuple) -> None: ... def to_dict(self) -> dict: """ Convert to dictionary with all fields """ @property def enable_extended_wildcard(self) -> bool: """ bool: Whether extended wildcard (suffix and infix) search is enabled. Note: Prefix search is always enabled regardless of this setting. """ @property def enable_range_optimization(self) -> bool: """ bool: Whether range optimization is enabled for this inverted index. """ class OptimizeOption: """ Options for optimizing a collection (e.g., merging segments). Attributes: concurrency (int): Number of threads to use during optimization. If 0, the system will choose an optimal value automatically. Default is 0. Examples: >>> opt = OptimizeOption(concurrency=2) >>> print(opt.concurrency) 2 """ def __getstate__(self) -> tuple: ... def __init__(self, concurrency: typing.SupportsInt = 0) -> None: """ Constructs an OptimizeOption instance. Args: concurrency (int, optional): Number of concurrent threads. 0 means auto-detect. Defaults to 0. """ def __setstate__(self, arg0: tuple) -> None: ... @property def concurrency(self) -> int: """ int: Number of threads used for optimization (0 = auto). """ class QueryParam: """ Base class for all query parameter configurations. This abstract base class defines common query settings such as search radius and whether to force linear (brute-force) search. It should not be instantiated directly; use derived classes like `HnswQueryParam` or `IVFQueryParam`. 
Attributes: type (IndexType): The index type this query is configured for. radius (float): Search radius for range queries. Used in combination with top-k to filter results. Default is 0.0 (disabled). is_linear (bool): If True, forces brute-force linear search instead of using the index. Useful for debugging or small datasets. Default is False. is_using_refiner (bool, optional): Whether to use refiner for the query. Default is False. """ def __getstate__(self) -> tuple: ... def __setstate__(self, arg0: tuple) -> None: ... @property def is_linear(self) -> bool: """ bool: Whether to bypass the index and use brute-force linear search. """ @property def is_using_refiner(self) -> bool: """ bool: Whether to use refiner for the query. """ @property def radius(self) -> float: """ IndexType: The type of index this query targets. """ @property def type(self) -> _zvec.typing.IndexType: """ IndexType: The type of index this query targets. """ class SegmentOption: """ Options for segment-level operations. Currently, this class mirrors CollectionOption and is used internally. It supports read-only mode, memory mapping, and buffer configuration. Note: This class is primarily for internal use. Most users should use CollectionOption instead. Examples: >>> opt = SegmentOption() >>> print(opt.enable_mmap) True """ def __getstate__(self) -> tuple: ... def __init__(self) -> None: """ Constructs a SegmentOption with default settings. """ def __repr__(self) -> str: ... def __setstate__(self, arg0: tuple) -> None: ... @property def enable_mmap(self) -> bool: """ bool: Whether memory-mapped I/O is enabled. """ @property def max_buffer_size(self) -> int: """ int: Maximum buffer size in bytes (internal use). """ @property def read_only(self) -> bool: """ bool: Whether the segment is read-only. """ class VectorIndexParam(IndexParam): """ Base class for vector index parameter configurations. Encapsulates common settings for all vector index types. 
Attributes: type (IndexType): The specific vector index type (e.g., HNSW, FLAT). metric_type (MetricType): Distance metric used for similarity search. quantize_type (QuantizeType): Optional vector quantization type. """ def __getstate__(self) -> tuple: ... def __setstate__(self, arg0: tuple) -> None: ... def to_dict(self) -> dict: """ Convert to dictionary with all fields """ @property def metric_type(self) -> _zvec.typing.MetricType: """ MetricType: Distance metric (e.g., IP, COSINE, L2). """ @property def quantize_type(self) -> _zvec.typing.QuantizeType: """ QuantizeType: Vector quantization type (e.g., FP16, INT8). """ class _VectorQuery: field_name: str filter: str include_vector: bool query_params: QueryParam def __getstate__(self) -> tuple: ... def __init__(self) -> None: ... def __setstate__(self, arg0: tuple) -> None: ... def set_vector(self, arg0: ..., arg1: typing.Any) -> None: ... @property def output_fields(self) -> list[str] | None: ... @output_fields.setter def output_fields(self, arg0: collections.abc.Sequence[str] | None) -> None: ... @property def topk(self) -> int: ... @topk.setter def topk(self, arg0: typing.SupportsInt) -> None: ... ================================================ FILE: python/zvec/model/param/vector_query.py ================================================ # Copyright 2025-present the zvec project # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from __future__ import annotations from dataclasses import dataclass from typing import Optional, Union from ...common import VectorType from . import HnswQueryParam, IVFQueryParam __all__ = ["VectorQuery"] @dataclass(frozen=True) class VectorQuery: """Represents a vector search query for a specific field in a collection. A `VectorQuery` can be constructed using either a document ID (to look up its vector) or an explicit vector. It may optionally include index-specific query parameters to control search behavior (e.g., `ef` for HNSW, `nprobe` for IVF). Exactly one of `id` or `vector` should be provided. If both are given, behavior is implementation-defined (typically `id` takes precedence). Attributes: field_name (str): Name of the vector field to query. id (Optional[str], optional): Document ID to fetch vector from. Default is None. vector (VectorType, optional): Explicit query vector. Default is None. param (Optional[Union[HnswQueryParam, IVFQueryParam]], optional): Index-specific query parameters. Default is None. Examples: >>> import zvec >>> # Query by ID >>> q1 = zvec.VectorQuery(field_name="embedding", id="doc123") >>> # Query by vector >>> q2 = zvec.VectorQuery( ... field_name="embedding", ... vector=[0.1, 0.2, 0.3], ... param=HnswQueryParam(ef=300) ... ) """ field_name: str id: Optional[str] = None vector: VectorType = None param: Optional[Union[HnswQueryParam, IVFQueryParam]] = None def has_id(self) -> bool: """Check if the query is based on a document ID. Returns: bool: True if `id` is set, False otherwise. """ return self.id is not None def has_vector(self) -> bool: """Check if the query contains an explicit vector. Returns: bool: True if `vector` is non-empty, False otherwise. 
""" return self.vector is not None and len(self.vector) > 0 def _validate(self) -> None: if self.field_name is None: raise ValueError("Field name cannot be empty") if self.id and self.vector: raise ValueError("Cannot provide both id and vector") ================================================ FILE: python/zvec/model/schema/__init__.py ================================================ # Copyright 2025-present the zvec project # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations from _zvec.schema import CollectionStats from .collection_schema import CollectionSchema from .field_schema import FieldSchema, VectorSchema __all__ = ["CollectionSchema", "CollectionStats", "FieldSchema", "VectorSchema"] ================================================ FILE: python/zvec/model/schema/__init__.pyi ================================================ """ This module contains the schema of Zvec """ from __future__ import annotations import collections.abc import typing import _zvec.param import _zvec.typing from .collection_schema import CollectionSchema from .field_schema import FieldSchema, VectorSchema __all__: list[str] = [ "CollectionSchema", "CollectionStats", "FieldSchema", "VectorSchema", ] class CollectionStats: def __init__(self) -> None: ... def __repr__(self) -> str: ... @property def doc_count(self) -> int: ... @property def index_completeness(self) -> dict[str, float]: ... 
# Type stub for the native collection-schema object. At runtime
# collection_schema.py imports it from `_zvec.schema`. This stub only declares
# signatures; the bodies live in the native extension.
class _CollectionSchema:
    # Setting __hash__ to None makes instances unhashable (__eq__ is defined).
    __hash__: typing.ClassVar[None] = None
    def __eq__(self, arg0: _CollectionSchema) -> bool: ...
    def __init__(
        self, name: str, fields: collections.abc.Sequence[_FieldSchema]
    ) -> None:
        """
        Construct with a collection name and the sequence of field schemas
        (scalar and vector fields together).
        """
    def __ne__(self, arg0: _CollectionSchema) -> bool: ...
    def fields(self) -> list[_FieldSchema]:
        """
        Return list of all field schemas.
        """
    def forward_fields(self) -> list[_FieldSchema]:
        """
        Return list of forward-indexed fields.
        """
    # NOTE(review): the docstring says this returns None when the field is
    # missing, but the annotation is not Optional -- confirm against the binding.
    def get_field(self, field_name: str) -> _FieldSchema:
        """
        Get field by name; returns None if not found.
        """
    # CollectionSchema.field() treats a falsy result from this as "not found".
    def get_forward_field(self, field_name: str) -> _FieldSchema:
        """
        Get forward field (used for filtering) by name.
        """
    # CollectionSchema.vector() treats a falsy result from this as "not found".
    def get_vector_field(self, field_name: str) -> _FieldSchema:
        """
        Get vector field by name.
        """
    def has_field(self, field_name: str) -> bool:
        """
        Check if a field exists.
        """
    def vector_fields(self) -> list[_FieldSchema]:
        """
        Return list of vector fields.
        """
    # Collection name given at construction.
    @property
    def name(self) -> str: ...

# Type stub for the native field-schema object used for both scalar fields
# (FieldSchema) and vector fields (VectorSchema).
class _FieldSchema:
    # Setting __hash__ to None makes instances unhashable (__eq__ is defined).
    __hash__: typing.ClassVar[None] = None
    def __eq__(self, arg0: _FieldSchema) -> bool: ...
    # `dimension` is 0 for scalar fields (FieldSchema passes dimension=0);
    # `index_param` may be None when no index is configured.
    def __init__(
        self,
        name: str,
        data_type: _zvec.typing.DataType,
        nullable: bool = False,
        dimension: typing.SupportsInt = 0,
        index_param: _zvec.param.IndexParam = None,
    ) -> None: ...
    def __ne__(self, arg0: _FieldSchema) -> bool: ...
    @property
    def data_type(self) -> _zvec.typing.DataType: ...
    @property
    def dimension(self) -> int: ...
    # Typed as Any: the concrete index-param class depends on the field.
    @property
    def index_param(self) -> typing.Any: ...
    @property
    def index_type(self) -> _zvec.typing.IndexType: ...
    @property
    def is_dense_vector(self) -> bool: ...
    @property
    def is_sparse_vector(self) -> bool: ...
    @property
    def name(self) -> str: ...
    @property
    def nullable(self) -> bool: ...
================================================ FILE: python/zvec/model/schema/collection_schema.py ================================================ # Copyright 2025-present the zvec project # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations import json from typing import Optional, Union from _zvec.schema import _CollectionSchema, _FieldSchema from .field_schema import FieldSchema, VectorSchema __all__ = [ "CollectionSchema", ] class CollectionSchema: """Defines the structure of a collection in Zvec. A collection schema specifies the name of the collection and its fields, including both scalar fields (e.g., int, string) and vector fields. Field names must be unique across both scalar and vector fields. Args: name (str): Name of the collection. fields (Optional[Union[FieldSchema, list[FieldSchema]]], optional): One or more scalar field definitions. Defaults to None. vectors (Optional[Union[VectorSchema, list[VectorSchema]]], optional): One or more vector field definitions. Defaults to None. Raises: TypeError: If `fields` or `vectors` are of unsupported types. ValueError: If any field or vector name is duplicated. Examples: >>> from zvec import FieldSchema, VectorSchema, DataType, IndexType >>> id_field = FieldSchema("id", DataType.INT64, is_primary=True) >>> emb_field = VectorSchema("embedding", dim=128, data_type=DataType.VECTOR_FP32) >>> schema = CollectionSchema( ... name="my_collection", ... fields=id_field, ... vectors=emb_field ... 
) >>> print(schema.name) my_collection """ def __init__( self, name: str, fields: Optional[Union[FieldSchema, list[FieldSchema]]] = None, vectors: Optional[Union[VectorSchema, list[VectorSchema]]] = None, ): if name is None or not isinstance(name, str): raise ValueError( f"schema validate failed: collection name must be str, got {type(name).__name__}" ) # handle fields _fields_name: list[str] = [] _fields_list: list[_FieldSchema] = [] self._check_fields(fields, _fields_name, _fields_list) self._check_vectors(vectors, _fields_name, _fields_list) # init self._cpp_obj = _CollectionSchema( name=name, fields=_fields_list, ) def _check_fields( self, fields: Optional[Union[FieldSchema, list[FieldSchema]]], _fields_name: list[str], _fields_list: list[_FieldSchema], ) -> None: field_items = [] if isinstance(fields, FieldSchema): field_items = [fields] elif isinstance(fields, list): field_items = fields elif fields is None: field_items = [] else: raise TypeError( f"schema validate failed: invalid 'fields' type, expected FieldSchema or list[FieldSchema], " f"got {type(fields).__name__}" ) for idx, field in enumerate(field_items): if not isinstance(field, FieldSchema): raise TypeError( f"schema validate failed: invalid field type in 'fields' list, expected FieldSchema, " f"got {type(field).__name__} at index {idx}" ) if field.name in _fields_name: raise ValueError( f"schema validate failed: duplicate field name '{field.name}': field names must be unique" ) _fields_name.append(field.name) _fields_list.append(field._get_object()) def _check_vectors( self, vectors: Optional[Union[VectorSchema, list[VectorSchema]]], _fields_name: list[str], _fields_list: list[_FieldSchema], ) -> None: # handle vector if isinstance(vectors, VectorSchema): vectors_items = [vectors] elif isinstance(vectors, list): vectors_items = vectors elif vectors is None: vectors_items = [] else: raise TypeError( f"schema validate failed: invalid 'vectors' type, expected VectorSchema or list[VectorSchema], " 
f"got {type(vectors).__name__}" ) for idx, vector in enumerate(vectors_items): if not isinstance(vector, VectorSchema): raise TypeError( f"schema validate failed: invalid vector type in 'vectors' list, expected VectorSchema, " f"got {type(vector).__name__} at index {idx}" ) if vector.name in _fields_name: raise ValueError( f"schema validate failed: duplicate vector name '{vector.name}', vector names must be unique " f"(conflicts with existing field or vector)" ) _fields_name.append(vector.name) _fields_list.append(vector._get_object()) @classmethod def _from_core(cls, core_collection_schema: _CollectionSchema): inst = cls.__new__(cls) if not core_collection_schema: raise ValueError("schema validate failed: schema is null") inst._cpp_obj = core_collection_schema return inst @property def name(self) -> str: """str: The name of the collection.""" return self._cpp_obj.name def field(self, name: str) -> Optional[FieldSchema]: """Retrieve a scalar field by name. Args: name (str): Name of the field. Returns: Optional[FieldSchema]: The field if found, otherwise None. """ _field = self._cpp_obj.get_forward_field(name) return FieldSchema._from_core(_field) if _field else None def vector(self, name: str) -> Optional[VectorSchema]: """Retrieve a vector field by name. Args: name (str): Name of the vector field. Returns: Optional[VectorSchema]: The vector field if found, otherwise None. 
""" _field = self._cpp_obj.get_vector_field(name) return VectorSchema._from_core(_field) if _field else None @property def fields(self) -> list[FieldSchema]: """list[FieldSchema]: All scalar (non-vector) fields in the schema.""" _fields = self._cpp_obj.forward_fields() return [FieldSchema._from_core(_field) for _field in _fields] @property def vectors(self) -> list[VectorSchema]: """list[VectorSchema]: All vector fields in the schema.""" _vectors = self._cpp_obj.vector_fields() return [VectorSchema._from_core(_vector) for _vector in _vectors] def _get_object(self) -> _CollectionSchema: return self._cpp_obj def __repr__(self) -> str: try: schema = { "name": self.name, "fields": {field.name: field.__dict__() for field in self.fields}, "vectors": {vector.name: vector.__dict__() for vector in self.vectors}, } return json.dumps(schema, indent=2, ensure_ascii=False) except Exception as e: return f"" def __str__(self) -> str: return self.__repr__() ================================================ FILE: python/zvec/model/schema/field_schema.py ================================================ # Copyright 2025-present the zvec project # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from __future__ import annotations import json from typing import Any, Optional, Union from _zvec.schema import _FieldSchema from zvec.model.param import ( FlatIndexParam, HnswIndexParam, InvertIndexParam, IVFIndexParam, ) from zvec.typing import DataType __all__ = [ "FieldSchema", "VectorSchema", ] SUPPORT_VECTOR_DATA_TYPE = [ DataType.VECTOR_FP16, DataType.VECTOR_FP32, DataType.VECTOR_FP64, DataType.VECTOR_INT8, DataType.SPARSE_VECTOR_FP16, DataType.SPARSE_VECTOR_FP32, ] SUPPORT_SCALAR_DATA_TYPE = [ DataType.INT32, DataType.INT64, DataType.UINT32, DataType.UINT64, DataType.FLOAT, DataType.DOUBLE, DataType.STRING, DataType.BOOL, DataType.ARRAY_INT32, DataType.ARRAY_INT64, DataType.ARRAY_UINT32, DataType.ARRAY_UINT64, DataType.ARRAY_FLOAT, DataType.ARRAY_DOUBLE, DataType.ARRAY_STRING, DataType.ARRAY_BOOL, ] class FieldSchema: """Represents a scalar (non-vector) field in a collection schema. A `FieldSchema` defines the name, data type, nullability, and optional inverted index configuration for a regular field (e.g., ID, timestamp, category). Args: name (str): Name of the field. Must be unique within the collection. data_type (DataType): Data type of the field (e.g., INT64, STRING). nullable (bool, optional): Whether the field can contain null values. Defaults to False. index_param (Optional[InvertIndexParam], optional): Inverted index parameters for this field. Only applicable to fields that support indexing (e.g., scalar fields used in filtering). Defaults to None. Examples: >>> from zvec.typing import DataType >>> from zvec.model.param import InvertIndexParam >>> id_field = FieldSchema( ... name="id", ... data_type=DataType.INT64, ... nullable=False, ... index_param=InvertIndexParam(enable_range_optimization=True) ... 
) """ def __init__( self, name: str, data_type: DataType, nullable: bool = False, index_param: Optional[InvertIndexParam] = None, ): if name is None or not isinstance(name, str): raise ValueError( f"schema validate failed: field name must be str, got {type(name).__name__}" ) if data_type not in SUPPORT_SCALAR_DATA_TYPE: raise ValueError( f"schema validate failed: scalar_field's data_type must be one of " f"{', '.join(str(dt) for dt in SUPPORT_SCALAR_DATA_TYPE)}, " f"but field[{name}]'s data_type is {data_type}" ) self._cpp_obj = _FieldSchema( name=name, data_type=data_type, dimension=0, nullable=nullable, index_param=index_param, ) @classmethod def _from_core(cls, core_field_schema: _FieldSchema): if core_field_schema is None: raise ValueError("schema validate failed: field schema is None") inst = cls.__new__(cls) inst._cpp_obj = core_field_schema return inst def _get_object(self) -> _FieldSchema: return self._cpp_obj @property def name(self) -> str: """str: The name of the field.""" return self._cpp_obj.name @property def data_type(self) -> DataType: """DataType: The data type of the field (e.g., INT64, STRING).""" return self._cpp_obj.data_type @property def nullable(self) -> bool: """bool: Whether the field allows null values.""" return self._cpp_obj.nullable @property def index_param(self) -> Optional[InvertIndexParam]: """Optional[InvertIndexParam]: Inverted index configuration, if any.""" return self._cpp_obj.index_param def __dict__(self) -> dict[str, Any]: return { "name": self.name, "data_type": ( self.data_type.name if hasattr(self.data_type, "name") else str(self.data_type) ), "nullable": self.nullable, "index_param": ( self.index_param.to_dict() if self.index_param is not None else None ), } def __repr__(self) -> str: try: schema = self.__dict__() return json.dumps(schema, indent=2, ensure_ascii=False) except Exception as e: return f"" def __str__(self) -> str: return self.__repr__() def __eq__(self, other: object) -> bool: if not isinstance(other, 
FieldSchema): return False return self._cpp_obj == other._cpp_obj def __hash__(self) -> int: return hash((self.name, self.data_type, self.nullable)) class VectorSchema: """Represents a vector field in a collection schema. A `VectorSchema` defines the name, data type, dimensionality, and index configuration for a vector field used in similarity search. Args: name (str): Name of the vector field. Must be unique within the collection. data_type (DataType): Vector data type (e.g., VECTOR_FP32, VECTOR_INT8). dimension (int, optional): Dimensionality of the vector. Must be > 0 for dense vectors; may be `None` for sparse vectors. index_param (Union[HnswIndexParam, IVFIndexParam, FlatIndexParam], optional): Index configuration for this vector field. Defaults to ``HnswIndexParam()``. Examples: >>> from zvec.typing import DataType >>> from zvec.model.param import HnswIndexParam >>> emb_field = VectorSchema( ... name="embedding", ... data_type=DataType.VECTOR_FP32, ... dimension=128, ... index_param=HnswIndexParam(ef_construction=200, m=16) ... 
) """ def __init__( self, name: str, data_type: DataType, dimension: Optional[int] = 0, index_param: Optional[ Union[HnswIndexParam, FlatIndexParam, IVFIndexParam] ] = None, ): if name is None or not isinstance(name, str): raise ValueError( f"schema validate failed: field name must be str, got {type(name).__name__}" ) if not isinstance(dimension, int) or dimension < 0: raise ValueError("schema validate failed: vector's dimension must be >= 0") if data_type not in SUPPORT_VECTOR_DATA_TYPE: raise ValueError( f"schema validate failed: vector's data_type must be one of " f"{', '.join(str(dt) for dt in SUPPORT_VECTOR_DATA_TYPE)}, " f"but field[{name}]'s data_type is {data_type}" ) if index_param is None: index_param = FlatIndexParam() self._cpp_obj = _FieldSchema( name=name, data_type=data_type, dimension=dimension, nullable=False, index_param=index_param, ) @classmethod def _from_core(cls, core_field_schema: _FieldSchema): inst = cls.__new__(cls) inst._cpp_obj = core_field_schema return inst def _get_object(self) -> _FieldSchema: return self._cpp_obj @property def name(self) -> str: """str: The name of the vector field.""" return self._cpp_obj.name @property def data_type(self) -> DataType: """DataType: The vector data type (e.g., VECTOR_FP32).""" return self._cpp_obj.data_type @property def dimension(self) -> int: """int: The dimensionality of the vector.""" return self._cpp_obj.dimension @property def index_param(self) -> Union[HnswIndexParam, IVFIndexParam, FlatIndexParam]: """Union[HnswIndexParam, IVFIndexParam, FlatIndexParam]: Index configuration for the vector.""" return self._cpp_obj.index_param def __dict__(self) -> dict[str, Any]: return { "name": self.name, "data_type": ( self.data_type.name if hasattr(self.data_type, "name") else str(self.data_type) ), "dimension": self.dimension, "index_param": ( self.index_param.to_dict() if self.index_param is not None else None ), } def __repr__(self) -> str: try: schema = self.__dict__() return json.dumps(schema, 
indent=2, ensure_ascii=False) except Exception as e: return f"" def __str__(self) -> str: return self.__repr__() def __eq__(self, other: object) -> bool: if not isinstance(other, VectorSchema): return False return self._cpp_obj == other._cpp_obj def __hash__(self) -> int: return hash((self.name, self.data_type, self.dimension)) ================================================ FILE: python/zvec/py.typed ================================================ ================================================ FILE: python/zvec/tool/__init__.py ================================================ # Copyright 2025-present the zvec project # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations from .util import require_module __all__ = ["require_module"] ================================================ FILE: python/zvec/tool/util.py ================================================ # Copyright 2025-present the zvec project # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from __future__ import annotations import importlib from typing import Any, Optional def require_module(module: str, mitigation: Optional[str] = None) -> Any: """Import a Python module and raise a user-friendly error if it is not available. This utility helps provide actionable error messages when optional dependencies are missing. It attempts to import the given module and, on failure, suggests a `pip install` command based on either the module name or an optional mitigation package name. Args: module (str): The full module name to import (e.g., ``"numpy"``, ``"pandas.io.parquet"``). mitigation (Optional[str], optional): The package name to suggest for installation if the import fails. If not provided, the top-level package of `module` will be used (e.g., ``"pandas"`` for ``"pandas.io.parquet"``). Returns: Any: The imported module object. Raises: ImportError: If the module cannot be imported, with a clear installation hint. Examples: >>> import zvec >>> np = zvec.require_module("numpy") >>> pq = zvec.require_module("pyarrow.parquet", mitigation="pyarrow") Note: This function is intended for lazy-loading optional dependencies with helpful error messages, not for core dependencies. """ try: return importlib.import_module(module) except ImportError as e: package = mitigation or module msg = f"Required package '{package}' is not installed. " if "." in module: top_level = module.split(".", maxsplit=1)[0] msg += f"Module '{module}' is part of '{top_level}', " if mitigation: msg += f"please pip install '{mitigation}'." else: msg += f"please pip install '{top_level}'." else: msg += f"Please pip install '{package}'." raise ImportError(msg) from e ================================================ FILE: python/zvec/typing/__init__.py ================================================ # Copyright 2025-present the zvec project # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

# Re-export the native enum/status types under the public `zvec.typing` name.
from _zvec.typing import (
    DataType,
    IndexType,
    MetricType,
    QuantizeType,
    Status,
    StatusCode,
)

__all__ = [
    "DataType",
    "IndexType",
    "MetricType",
    "QuantizeType",
    "Status",
    "StatusCode",
]

================================================
FILE: python/zvec/typing/__init__.pyi
================================================
"""
This module contains the basic data types of Zvec
"""

from __future__ import annotations

import typing

__all__: list[str] = [
    "DataType",
    "IndexType",
    "MetricType",
    "QuantizeType",
    "Status",
    "StatusCode",
]

class DataType:
    """
    Enumeration of supported data types in Zvec.

    Includes scalar types, dense/sparse vector types, and array types.

    Examples:
        >>> import zvec
        >>> print(zvec.DataType.FLOAT)
        DataType.FLOAT
        >>> print(zvec.DataType.VECTOR_FP32)
        DataType.VECTOR_FP32

    Members:

      STRING

      BOOL

      INT32

      INT64

      FLOAT

      DOUBLE

      UINT32

      UINT64

      VECTOR_FP16

      VECTOR_FP32

      VECTOR_FP64

      VECTOR_INT8

      SPARSE_VECTOR_FP32

      SPARSE_VECTOR_FP16

      ARRAY_STRING

      ARRAY_INT32

      ARRAY_INT64

      ARRAY_FLOAT

      ARRAY_DOUBLE

      ARRAY_BOOL

      ARRAY_UINT32

      ARRAY_UINT64
    """

    # Enum members. Their integer values (see `value` / `__int__`) are defined
    # by the native extension and are not recorded in this stub.
    ARRAY_BOOL: typing.ClassVar[DataType]
    ARRAY_DOUBLE: typing.ClassVar[DataType]
    ARRAY_FLOAT: typing.ClassVar[DataType]
    ARRAY_INT32: typing.ClassVar[DataType]
    ARRAY_INT64: typing.ClassVar[DataType]
    ARRAY_STRING: typing.ClassVar[DataType]
    ARRAY_UINT32: typing.ClassVar[DataType]
    ARRAY_UINT64: typing.ClassVar[DataType]
    BOOL: typing.ClassVar[DataType]
    DOUBLE: typing.ClassVar[DataType]
    FLOAT: typing.ClassVar[DataType]
    INT32: typing.ClassVar[DataType]
    INT64: typing.ClassVar[DataType]
    SPARSE_VECTOR_FP16: typing.ClassVar[DataType]
    SPARSE_VECTOR_FP32: typing.ClassVar[DataType]
    STRING: typing.ClassVar[DataType]
    UINT32: typing.ClassVar[DataType]
    UINT64: typing.ClassVar[DataType]
    VECTOR_FP16: typing.ClassVar[DataType]
    VECTOR_FP32: typing.ClassVar[DataType]
    VECTOR_FP64: typing.ClassVar[DataType]
    VECTOR_INT8: typing.ClassVar[DataType]
    # Maps each member name listed above to its DataType member.
    __members__: typing.ClassVar[dict[str, DataType]]
    def __eq__(self, other: typing.Any) -> bool: ...
    # __getstate__/__setstate__ carry the integer value for pickling.
    def __getstate__(self) -> int: ...
    def __hash__(self) -> int: ...
    def __index__(self) -> int: ...
    def __init__(self, value: typing.SupportsInt) -> None: ...
    def __int__(self) -> int: ...
    def __ne__(self, other: typing.Any) -> bool: ...
    def __repr__(self) -> str: ...
    def __setstate__(self, state: typing.SupportsInt) -> None: ...
    def __str__(self) -> str: ...
    @property
    def name(self) -> str: ...
    @property
    def value(self) -> int: ...

class IndexType:
    """
    Enumeration of supported index types in Zvec.

    Examples:
        >>> import zvec
        >>> print(zvec.IndexType.HNSW)
        IndexType.HNSW

    Members:

      UNDEFINED

      HNSW

      IVF

      FLAT

      INVERT
    """

    # Enum members; integer values are not recorded in this stub.
    FLAT: typing.ClassVar[IndexType]
    HNSW: typing.ClassVar[IndexType]
    INVERT: typing.ClassVar[IndexType]
    IVF: typing.ClassVar[IndexType]
    UNDEFINED: typing.ClassVar[IndexType]
    # Maps each member name listed above to its IndexType member.
    __members__: typing.ClassVar[dict[str, IndexType]]
    def __eq__(self, other: typing.Any) -> bool: ...
    def __getstate__(self) -> int: ...
    def __hash__(self) -> int: ...
    def __index__(self) -> int: ...
    def __init__(self, value: typing.SupportsInt) -> None: ...
    def __int__(self) -> int: ...
    def __ne__(self, other: typing.Any) -> bool: ...
    def __repr__(self) -> str: ...
    def __setstate__(self, state: typing.SupportsInt) -> None: ...
    def __str__(self) -> str: ...
    @property
    def name(self) -> str: ...
    @property
    def value(self) -> int: ...

class MetricType:
    """
    Enumeration of supported distance/similarity metrics.

    - COSINE: Cosine similarity.
    - IP: Inner product (dot product).
    - L2: Euclidean distance (L2 norm).

    Examples:
        >>> import zvec
        >>> print(zvec.MetricType.COSINE)
        MetricType.COSINE

    Members:

      COSINE

      IP

      L2
    """

    # Enum members; integer values are not recorded in this stub.
    COSINE: typing.ClassVar[MetricType]
    IP: typing.ClassVar[MetricType]
    L2: typing.ClassVar[MetricType]
    # Maps each member name listed above to its MetricType member.
    __members__: typing.ClassVar[dict[str, MetricType]]
    def __eq__(self, other: typing.Any) -> bool: ...
    def __getstate__(self) -> int: ...
    def __hash__(self) -> int: ...
    def __index__(self) -> int: ...
    def __init__(self, value: typing.SupportsInt) -> None: ...
    def __int__(self) -> int: ...
    def __ne__(self, other: typing.Any) -> bool: ...
    def __repr__(self) -> str: ...
    def __setstate__(self, state: typing.SupportsInt) -> None: ...
    def __str__(self) -> str: ...
    @property
    def name(self) -> str: ...
    @property
    def value(self) -> int: ...

class QuantizeType:
    """
    Enumeration of supported quantization types for vector compression.

    Examples:
        >>> import zvec
        >>> print(zvec.QuantizeType.INT8)
        QuantizeType.INT8

    Members:

      UNDEFINED

      FP16

      INT8

      INT4
    """

    # Enum members; integer values are not recorded in this stub.
    FP16: typing.ClassVar[QuantizeType]
    INT4: typing.ClassVar[QuantizeType]
    INT8: typing.ClassVar[QuantizeType]
    UNDEFINED: typing.ClassVar[QuantizeType]
    # Maps each member name listed above to its QuantizeType member.
    __members__: typing.ClassVar[dict[str, QuantizeType]]
    def __eq__(self, other: typing.Any) -> bool: ...
    def __getstate__(self) -> int: ...
    def __hash__(self) -> int: ...
    def __index__(self) -> int: ...
    def __init__(self, value: typing.SupportsInt) -> None: ...
    def __int__(self) -> int: ...
    def __ne__(self, other: typing.Any) -> bool: ...
    def __repr__(self) -> str: ...
    def __setstate__(self, state: typing.SupportsInt) -> None: ...
    def __str__(self) -> str: ...
    @property
    def name(self) -> str: ...
    @property
    def value(self) -> int: ...

class Status:
    """
    Represents the outcome of a Zvec operation.

    A `Status` object is either OK (success) or carries an error code and message.

    Examples:
        >>> from zvec.typing import Status, StatusCode
        >>> s = Status()
        >>> print(s.ok())
        True
        >>> s = Status(StatusCode.INVALID_ARGUMENT, "Field not found")
        >>> print(s.code() == StatusCode.INVALID_ARGUMENT)
        True
        >>> print(s.message())
        Field not found
    """

    # Setting __hash__ to None makes Status unhashable (__eq__ is defined).
    __hash__: typing.ClassVar[None] = None
    # Factory helpers: each builds a Status carrying the matching StatusCode
    # (by name) and the given message.
    @staticmethod
    def AlreadyExists(message: str) -> Status: ...
    @staticmethod
    def InternalError(message: str) -> Status: ...
    @staticmethod
    def InvalidArgument(message: str) -> Status: ...
    @staticmethod
    def NotFound(message: str) -> Status: ...
    @staticmethod
    def OK() -> Status:
        """
        Create an OK status.
        """
    @staticmethod
    def PermissionDenied(message: str) -> Status: ...
    def __eq__(self, arg0: Status) -> bool: ...
    @typing.overload
    def __init__(self) -> None: ...
    @typing.overload
    def __init__(self, code: StatusCode, message: str = "") -> None:
        """
        Construct a status with the given code and optional message.

        Args:
            code (StatusCode): The status code.
            message (str, optional): Error message. Defaults to empty string.
        """
    def __ne__(self, arg0: Status) -> bool: ...
    def __repr__(self) -> str: ...
    def code(self) -> StatusCode:
        """
        StatusCode: Returns the status code.
        """
    def message(self) -> str:
        """
        str: Returns the error message (may be empty).
        """
    def ok(self) -> bool:
        """
        bool: Returns True if the status is OK.
        """

class StatusCode:
    """
    Enumeration of possible status codes for Zvec operations.

    Used by the `Status` class to indicate success or failure reason.

    Members:

      OK

      NOT_FOUND

      ALREADY_EXISTS

      INVALID_ARGUMENT

      PERMISSION_DENIED

      FAILED_PRECONDITION

      RESOURCE_EXHAUSTED

      UNAVAILABLE

      INTERNAL_ERROR

      NOT_SUPPORTED

      UNKNOWN
    """

    # Enum members; integer values are not recorded in this stub.
    ALREADY_EXISTS: typing.ClassVar[StatusCode]
    FAILED_PRECONDITION: typing.ClassVar[StatusCode]
    INTERNAL_ERROR: typing.ClassVar[StatusCode]
    INVALID_ARGUMENT: typing.ClassVar[StatusCode]
    NOT_FOUND: typing.ClassVar[StatusCode]
    NOT_SUPPORTED: typing.ClassVar[StatusCode]
    OK: typing.ClassVar[StatusCode]
    PERMISSION_DENIED: typing.ClassVar[StatusCode]
    RESOURCE_EXHAUSTED: typing.ClassVar[StatusCode]
    UNAVAILABLE: typing.ClassVar[StatusCode]
    UNKNOWN: typing.ClassVar[StatusCode]
    # Maps each member name listed above to its StatusCode member.
    __members__: typing.ClassVar[dict[str, StatusCode]]
    def __eq__(self, other: typing.Any) -> bool: ...
    def __getstate__(self) -> int: ...
    def __hash__(self) -> int: ...
    def __index__(self) -> int: ...
    def __init__(self, value: typing.SupportsInt) -> None: ...
    def __int__(self) -> int: ...
    def __ne__(self, other: typing.Any) -> bool: ...
    def __repr__(self) -> str: ...
    def __setstate__(self, state: typing.SupportsInt) -> None: ...
    def __str__(self) -> str: ...
    @property
    def name(self) -> str: ...
    @property
    def value(self) -> int: ...

================================================
FILE: python/zvec/typing/enum.py
================================================
# Copyright 2025-present the zvec project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from enum import IntEnum

__all__ = ["LogLevel", "LogType"]


class LogLevel(IntEnum):
    """Enumeration of logging severity levels, ordered from lowest to highest priority.

    Used to control verbosity and filtering of log messages. Higher numeric values
    indicate more severe conditions.

    Note:
        ``WARNING`` is an alias for ``WARN`` to match Python's built-in
        :mod:`logging` module convention.

    Attributes:
        DEBUG (int): Detailed information, typically of interest only when diagnosing problems.
        INFO (int): Confirmation that things are working as expected.
        WARN (int): An indication that something unexpected happened, or indicative of
            potential future problems. (Alias: ``WARNING``)
        WARNING (int): Same as ``WARN``.
        ERROR (int): Due to a more serious problem, the software has not been able
            to perform some function.
        FATAL (int): A serious error, indicating that the program itself may be unable
            to continue running.
    """

    DEBUG = 0
    INFO = 1
    WARN = 2
    # Duplicate value: Enum turns this into an alias, so
    # `LogLevel.WARNING is LogLevel.WARN` and iteration yields WARN only.
    WARNING = 2
    ERROR = 3
    FATAL = 4


class LogType(IntEnum):
    """Enumeration of supported log output destinations.

    Specifies where log messages should be written.

    Attributes:
        CONSOLE (int): Output logs to standard output/error (e.g., terminal or IDE console).
        FILE (int): Write logs to a persistent file on disk.
    """

    CONSOLE = 0
    FILE = 1

================================================
FILE: python/zvec/zvec.py
================================================
# Copyright 2025-present the zvec project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations from typing import Optional from _zvec import Initialize, _Collection from .model import Collection from .model.param import CollectionOption from .model.schema import CollectionSchema __all__ = ["create_and_open", "init", "open"] from .typing.enum import LogLevel, LogType def init( *, log_type: Optional[LogType] = LogType.CONSOLE, log_level: Optional[LogLevel] = LogLevel.WARN, log_dir: Optional[str] = "./logs", log_basename: Optional[str] = "zvec.log", log_file_size: Optional[int] = 2048, log_overdue_days: Optional[int] = 7, query_threads: Optional[int] = None, optimize_threads: Optional[int] = None, invert_to_forward_scan_ratio: Optional[float] = None, brute_force_by_keys_ratio: Optional[float] = None, memory_limit_mb: Optional[int] = None, ) -> None: """Initialize Zvec with configuration options. This function must be called before any other operation. It can only be called once — subsequent calls raise a ``RuntimeError``. Parameters set to ``None`` are **omitted** from the configuration and fall back to Zvec's internal defaults, which may be derived from the runtime environment (e.g., cgroup CPU/memory limits). Explicitly provided values always override defaults. Args: log_type (Optional[LogType], optional): Logger destination. - ``LogType.CONSOLE`` (default if omitted or set to this) - ``LogType.FILE`` - If ``None``, uses internal default (currently ``CONSOLE``). log_level (Optional[LogLevel], optional): Minimum log severity. Default: ``LogLevel.WARN``. Accepted values: ``DEBUG``, ``INFO``, ``WARN``, ``ERROR``, ``FATAL``. If ``None``, uses internal default (``WARN``). log_dir (Optional[str], optional): Directory for log files (only used when ``log_type=FILE``). Parent directories are **not** created automatically. Default: ``"./logs"``. If ``None``, internal default is used. log_basename (Optional[str], optional): Base name for rotated log files (e.g., ``zvec.log.1``, ``zvec.log.2``). Default: ``"zvec.log"``. 
log_file_size (Optional[int], optional): Max size per log file in **MB** before rotation. Default: ``2048`` MB (2 GB). log_overdue_days (Optional[int], optional): Days to retain rotated log files before deletion. Default: ``7`` days. query_threads (Optional[int], optional): Number of threads for query execution. If ``None`` (default), inferred from available CPU cores (via cgroup). Must be ≥ 1 if provided. optimize_threads (Optional[int], optional): Threads for background tasks (e.g., compaction, indexing). If ``None``, defaults to same as ``query_threads`` or CPU count. invert_to_forward_scan_ratio (Optional[float], optional): Threshold to switch from inverted index to full forward scan. Range: [0.0, 1.0]. Higher → more aggressive index skipping. Default: ``0.9`` (if omitted). brute_force_by_keys_ratio (Optional[float], optional): Threshold to use brute-force key lookup over index. Lower → prefer index; higher → prefer brute-force. Range: [0.0, 1.0]. Default: ``0.1``. memory_limit_mb (Optional[int], optional): Soft memory cap in MB. Zvec may throttle or fail operations approaching this limit. If ``None``, inferred from cgroup memory limit * 0.8 (e.g., in Docker). Must be > 0 if provided. Raises: RuntimeError: If Zvec is already initialized. ValueError: On invalid values (e.g., negative thread count, log level out of range). TypeError: If a value has incorrect type (e.g., string for ``query_threads``). Note: - All ``None`` arguments are **excluded** from the configuration payload, allowing the core library to apply environment-aware defaults. - This design ensures container-friendliness: in Kubernetes/Docker, omitting ``memory_limit_mb`` and thread counts lets Zvec auto-adapt. Examples: Initialize with defaults (log to console, auto-detect resources): >>> import zvec >>> zvec.init() Customize logging to file with rotation: >>> zvec.init( ... log_type=LogType.FILE, ... log_dir="/var/log/zvec", ... log_file_size=1024, ... log_overdue_days=30 ... 
) Limit resources explicitly: >>> zvec.init( ... memory_limit_mb=2048, ... query_threads=4, ... optimize_threads=2 ... ) Fine-tune query heuristics: >>> zvec.init( ... invert_to_forward_scan_ratio=0.95, ... brute_force_by_keys_ratio=0.05 ... ) """ # Build config dict, skipping None values config_dict = {} if log_type is not None: if not isinstance(log_type, LogType): raise TypeError("log_type must be LogType") config_dict["log_type"] = log_type.name if log_level is not None: if not isinstance(log_level, LogLevel): raise TypeError("log_level must be LogLevel") config_dict["log_level"] = log_level.name if log_dir is not None: config_dict["log_dir"] = log_dir if log_basename is not None: config_dict["log_basename"] = log_basename if log_file_size is not None: config_dict["log_file_size"] = log_file_size if log_overdue_days is not None: config_dict["log_overdue_days"] = log_overdue_days if query_threads is not None: config_dict["query_threads"] = query_threads if optimize_threads is not None: config_dict["optimize_threads"] = optimize_threads if invert_to_forward_scan_ratio is not None: config_dict["invert_to_forward_scan_ratio"] = invert_to_forward_scan_ratio if brute_force_by_keys_ratio is not None: config_dict["brute_force_by_keys_ratio"] = brute_force_by_keys_ratio if memory_limit_mb is not None: config_dict["memory_limit_mb"] = memory_limit_mb Initialize(config_dict) def create_and_open( path: str, schema: CollectionSchema, option: Optional[CollectionOption] = None, ) -> Collection: """Create a new collection and open it for use. If a collection already exists at the given path, it may raise an error depending on the underlying implementation. Args: path (str): Path or name of the collection to create. schema (CollectionSchema): Schema defining the structure of the collection. option (CollectionOption): Configuration options for opening the collection. Defaults to a default-constructed ``CollectionOption()`` if not provided. 
Returns: Collection: An opened collection instance ready for operations. Examples: >>> import zvec >>> schema = zvec.CollectionSchema( ... name="my_collection", ... fields=[zvec.FieldSchema("id", zvec.DataType.INT64, nullable=True)] ... ) >>> coll = create_and_open("./my_collection", schema) """ if not isinstance(path, str): raise TypeError("path must be a string") if not isinstance(schema, CollectionSchema): raise TypeError("schema must be a CollectionSchema") option = option or CollectionOption() if not isinstance(option, CollectionOption): raise TypeError("option must be a CollectionOption") _collection = _Collection.CreateAndOpen(path, schema._get_object(), option) return Collection._from_core(_collection) def open(path: str, option: CollectionOption = CollectionOption()) -> Collection: """Open an existing collection from disk. The collection must have been previously created with ``create_and_open``. Args: path (str): Path or name of the existing collection. option (CollectionOption): Configuration options for opening the collection. Defaults to a default-constructed ``CollectionOption()`` if not provided. Returns: Collection: An opened collection instance. Examples: >>> import zvec >>> coll = zvec.open("./my_collection") """ _collection = _Collection.Open(path, option) return Collection._from_core(_collection) ================================================ FILE: scripts/README.md ================================================ ================================================ FILE: scripts/build_android.sh ================================================ #!/bin/bash set -e CURRENT_DIR=$(pwd) ABI=${1:-"arm64-v8a"} API_LEVEL=${2:-21} BUILD_TYPE=${3:-"Release"} # step1: use host env to compile protoc echo "step1: building protoc for host..." HOST_BUILD_DIR="build_host" mkdir -p $HOST_BUILD_DIR cd $HOST_BUILD_DIR cmake -DCMAKE_BUILD_TYPE="$BUILD_TYPE" .. 
make -j protoc
PROTOC_EXECUTABLE="$CURRENT_DIR/$HOST_BUILD_DIR/bin/protoc"
cd "$CURRENT_DIR"
echo "step1: Done!!!"

# step2: cross build zvec based on android ndk
echo "step2: building zvec for android..."

# reset thirdparty directory
git submodule foreach --recursive 'git stash --include-untracked'

# Respect SDK/NDK locations already exported by the caller (e.g. CI); fall
# back to the macOS Android Studio SDK location and the pinned NDK otherwise.
export ANDROID_SDK_ROOT="${ANDROID_SDK_ROOT:-$HOME/Library/Android/sdk}"
export ANDROID_HOME="$ANDROID_SDK_ROOT"
export ANDROID_NDK_HOME="${ANDROID_NDK_HOME:-$ANDROID_SDK_ROOT/ndk/28.2.13676358}"
export CMAKE_TOOLCHAIN_FILE="$ANDROID_NDK_HOME/build/cmake/android.toolchain.cmake"
export PATH="$PATH:$ANDROID_SDK_ROOT/cmdline-tools/latest/bin"
export PATH="$PATH:$ANDROID_SDK_ROOT/platform-tools"
export PATH="$PATH:$ANDROID_NDK_HOME"

# ANDROID_NDK_HOME always has a value at this point, so an emptiness test
# could never fire; verify the NDK's CMake toolchain file is really there.
if [ ! -f "$CMAKE_TOOLCHAIN_FILE" ]; then
  echo "error: Android NDK not found at ANDROID_NDK_HOME=$ANDROID_NDK_HOME"
  echo "please install NDK and set env variable ANDROID_NDK_HOME"
  exit 1
fi

BUILD_DIR="build_android_${ABI}"
mkdir -p "$BUILD_DIR"
cd "$BUILD_DIR"

echo "configure CMake..."
cmake \
  -DANDROID_NDK="$ANDROID_NDK_HOME" \
  -DCMAKE_TOOLCHAIN_FILE="$ANDROID_NDK_HOME/build/cmake/android.toolchain.cmake" \
  -DANDROID_ABI="$ABI" \
  -DANDROID_NATIVE_API_LEVEL="$API_LEVEL" \
  -DANDROID_STL="c++_static" \
  -DCMAKE_BUILD_TYPE="$BUILD_TYPE" \
  -DBUILD_PYTHON_BINDINGS=OFF \
  -DBUILD_TOOLS=OFF \
  -DCMAKE_INSTALL_PREFIX="./install" \
  -DGLOBAL_CC_PROTOBUF_PROTOC="$PROTOC_EXECUTABLE" \
  ../

echo "building..."
# `sysctl -n hw.ncpu` only works on macOS; fall back to nproc, then 1 job.
CORE_COUNT=$(sysctl -n hw.ncpu 2>/dev/null || nproc 2>/dev/null || echo 1)
make -j"$CORE_COUNT"
echo "step2: Done!!!"
================================================
FILE: scripts/gcov.sh
================================================
#!/bin/bash
# Capture lcov coverage data from the current directory, filter out
# test/third-party/generated paths, and render an HTML report.
#
# Options:
#   -t <tool>  gcov executable passed to lcov --gcov-tool (default: gcov)
#   -p <name>  project name; prefixes the .lcov.info files and titles the report
#   -o <dir>   HTML output directory (default: html)
#   -z         also zip the HTML directory into <dir>.zip
#   -k         keep the intermediate *.lcov.info files

project_name=proxima-zvec
gcov_tool=gcov
zip_html=false
output_name=html
keep_info=false
# Repository root: the parent of the directory holding this script.
script_dir=$(cd "$(dirname "$0")"; pwd)
source_base=$(dirname "$script_dir")
# Each pattern is single-quoted inside the string so the shell does not glob it
# when the lcov command line is rebuilt by `eval` below.
filter_list="'*/tests/*' '*/thirdparty/*' '*/deps/*' '*/proto/*' '*/external/*' '*/sqlengine/antlr/gen/*'"

while getopts t:p:o:zk option; do
  case "$option" in
    t) gcov_tool=$OPTARG;;
    p) project_name=$OPTARG;;
    o) output_name=$OPTARG;;
    z) zip_html=true;;
    k) keep_info=true;;
  esac
done

# Process sources
lcov -c -b "$source_base" -d . -o $project_name.lcov.info --gcov-tool=$gcov_tool --no-external || exit 1
# `eval` re-parses the command so the quotes embedded in filter_list split it
# into separate, unexpanded pattern arguments for `lcov -r`.
eval $(echo lcov -r $project_name.lcov.info -o $project_name-filtered.lcov.info $filter_list) || exit 1

# Gather HTML files
genhtml -t "$project_name" -o $output_name $project_name-filtered.lcov.info || exit 1

if [ "$keep_info" = false ]; then
  rm -rf *.lcov.info
fi

# Zip HTML files
if $zip_html ; then
  zip -r $output_name.zip $output_name/
fi

================================================
FILE: src/CMakeLists.txt
================================================
include(${PROJECT_ROOT_DIR}/cmake/bazel.cmake)
include(${PROJECT_ROOT_DIR}/cmake/option.cmake)

# Retrieve version from git repository
git_version(ZVEC_VERSION ${CMAKE_CURRENT_SOURCE_DIR})

# Add repository
# Sub-libraries of the source tree; the Python binding is optional.
cc_directory(ailego)
cc_directory(turbo)
cc_directory(core)
cc_directory(db)
if(BUILD_PYTHON_BINDINGS)
  cc_directory(binding)
endif()

================================================
FILE: src/ailego/CMakeLists.txt
================================================
include(${PROJECT_ROOT_DIR}/cmake/bazel.cmake)
include(${PROJECT_ROOT_DIR}/cmake/option.cmake)

find_package(Threads REQUIRED)

# librt is only searched for on non-Apple, non-Android Unix; elsewhere LIB_RT
# is empty, so the list(APPEND ...) below adds nothing on Android.
if(UNIX AND NOT APPLE AND NOT ANDROID)
  find_library(LIB_RT NAMES rt)
else()
  set(LIB_RT "")
endif()

git_version(GIT_SRCS_VER ${CMAKE_CURRENT_SOURCE_DIR})
# Every C/C++ source and header under src/ailego goes into one library.
file(GLOB_RECURSE ALL_SRCS *.cc *.c *.h)

set(EXTRA_LIBS ${CMAKE_THREAD_LIBS_INIT}
    ${CMAKE_DL_LIBS})
if(UNIX AND NOT APPLE)
  list(APPEND EXTRA_LIBS ${LIB_RT})
endif()

# Per-ISA compile flags for the SIMD kernels in math/ and math_batch/.
# Each kernel file is selected by its filename suffix (_sse, _avx/_avx2,
# _avx512, _avx512fp16, _neon) and only that file gets the wider -march flag.
if(NOT ANDROID AND AUTO_DETECT_ARCH)
  if(CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64|i686|i386|x64")
    # Presumably defined in one of the included cmake/ scripts; it fills in
    # the best supported flag for each ISA tier.
    setup_compiler_march_for_x86(MATH_MARCH_FLAG_SSE MATH_MARCH_FLAG_AVX2 MATH_MARCH_FLAG_AVX512 MATH_MARCH_FLAG_AVX512FP16)
    message(STATUS "best compiler march, sse: " ${MATH_MARCH_FLAG_SSE} ", avx2: " ${MATH_MARCH_FLAG_AVX2} ", avx512: " ${MATH_MARCH_FLAG_AVX512} ", avx512fp16: " ${MATH_MARCH_FLAG_AVX512FP16})
    file(GLOB_RECURSE MATH_FILES_SSE
      ${CMAKE_CURRENT_SOURCE_DIR}/math/*_sse.cc
      ${CMAKE_CURRENT_SOURCE_DIR}/math/*_sse.c
      ${CMAKE_CURRENT_SOURCE_DIR}/math_batch/*_sse.cc
      ${CMAKE_CURRENT_SOURCE_DIR}/math_batch/*_sse.c
    )
    # *_avx.* files share the AVX2 flag set.
    file(GLOB_RECURSE MATH_FILES_AVX2
      ${CMAKE_CURRENT_SOURCE_DIR}/math/*_avx2.cc
      ${CMAKE_CURRENT_SOURCE_DIR}/math/*_avx2.c
      ${CMAKE_CURRENT_SOURCE_DIR}/math_batch/*_avx2.cc
      ${CMAKE_CURRENT_SOURCE_DIR}/math_batch/*_avx2.c
      ${CMAKE_CURRENT_SOURCE_DIR}/math/*_avx.cc
      ${CMAKE_CURRENT_SOURCE_DIR}/math/*_avx.c
      ${CMAKE_CURRENT_SOURCE_DIR}/math_batch/*_avx.cc
      ${CMAKE_CURRENT_SOURCE_DIR}/math_batch/*_avx.c
    )
    file(GLOB_RECURSE MATH_FILES_AVX512
      ${CMAKE_CURRENT_SOURCE_DIR}/math/*_avx512.cc
      ${CMAKE_CURRENT_SOURCE_DIR}/math/*_avx512.c
      ${CMAKE_CURRENT_SOURCE_DIR}/math_batch/*_avx512.cc
      ${CMAKE_CURRENT_SOURCE_DIR}/math_batch/*_avx512.c
    )
    # NOTE(review): the *_dispatch.* sources are grouped here and therefore
    # compiled with ${MATH_MARCH_FLAG_AVX512FP16}. If that flag enables AVX-512
    # code generation and the dispatch code runs on CPUs without it, the
    # compiler may emit unsupported instructions there — confirm the flag's
    # value and how the dispatch files guard ISA-specific code.
    file(GLOB_RECURSE MATH_FILES_AVX512FP16
      ${CMAKE_CURRENT_SOURCE_DIR}/math/*_dispatch.cc
      ${CMAKE_CURRENT_SOURCE_DIR}/math/*_dispatch.c
      ${CMAKE_CURRENT_SOURCE_DIR}/math/*_avx512fp16.cc
      ${CMAKE_CURRENT_SOURCE_DIR}/math/*_avx512fp16.c
      ${CMAKE_CURRENT_SOURCE_DIR}/math_batch/*_dispatch.cc
      ${CMAKE_CURRENT_SOURCE_DIR}/math_batch/*_dispatch.c
      ${CMAKE_CURRENT_SOURCE_DIR}/math_batch/*_avx512fp16.cc
      ${CMAKE_CURRENT_SOURCE_DIR}/math_batch/*_avx512fp16.c
    )
    foreach(MATH_FILE ${MATH_FILES_SSE})
      set_source_files_properties(
        ${MATH_FILE}
        PROPERTIES
        COMPILE_FLAGS "${MATH_MARCH_FLAG_SSE}"
      )
    endforeach()
    foreach(MATH_FILE ${MATH_FILES_AVX2})
      set_source_files_properties(
        ${MATH_FILE}
        PROPERTIES
        COMPILE_FLAGS "${MATH_MARCH_FLAG_AVX2}"
      )
    endforeach()
    foreach(MATH_FILE ${MATH_FILES_AVX512})
      set_source_files_properties(
        ${MATH_FILE}
        PROPERTIES
        COMPILE_FLAGS "${MATH_MARCH_FLAG_AVX512}"
      )
    endforeach()
    foreach(MATH_FILE ${MATH_FILES_AVX512FP16})
      set_source_files_properties(
        ${MATH_FILE}
        PROPERTIES
        COMPILE_FLAGS "${MATH_MARCH_FLAG_AVX512FP16}"
      )
    endforeach()
  elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64|arm64|ARM64")
    # set(CMAKE_CXX_FLAGS "-march=armv8-a")
    # set(CMAKE_C_FLAGS "-march=armv8-a")
    # NEON kernels and the dispatch sources are built for the armv8-a baseline.
    set(MATH_MARCH_FLAG_NEON "-march=armv8-a")
    file(GLOB_RECURSE MATH_FILES_NEON
      ${CMAKE_CURRENT_SOURCE_DIR}/math/*_dispatch.cc
      ${CMAKE_CURRENT_SOURCE_DIR}/math/*_dispatch.c
      ${CMAKE_CURRENT_SOURCE_DIR}/math_batch/*_dispatch.cc
      ${CMAKE_CURRENT_SOURCE_DIR}/math_batch/*_dispatch.c
      ${CMAKE_CURRENT_SOURCE_DIR}/math/*_neon.cc
      ${CMAKE_CURRENT_SOURCE_DIR}/math/*_neon.c
      ${CMAKE_CURRENT_SOURCE_DIR}/math_batch/*_neon.cc
      ${CMAKE_CURRENT_SOURCE_DIR}/math_batch/*_neon.c
    )
    foreach(MATH_FILE ${MATH_FILES_NEON})
      set_source_files_properties(
        ${MATH_FILE}
        PROPERTIES
        COMPILE_FLAGS "${MATH_MARCH_FLAG_NEON}"
      )
    endforeach()
  endif()
endif()

# Static library linking threads/dl/rt plus Arrow and Parquet (static targets).
cc_library(
  NAME zvec_ailego STATIC STRICT PACKED
  SRCS ${ALL_SRCS}
  LIBS ${EXTRA_LIBS} Arrow::arrow_static Arrow::parquet_static
  VERSION "${GIT_SRCS_VER}"
)

================================================
FILE: src/ailego/algorithm/binary_quantizer.cc
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // limitations under the License. #include "binary_quantizer.h" #include #include #include #include #include #include namespace zvec { namespace ailego { //! Feed the training data bool BinaryQuantizer::feed(const float *vec, size_t dim) { for (size_t i = 0; i < dim; ++i) { data_.emplace_back(vec[i]); } return true; } //! Train the quantizer bool BinaryQuantizer::train(void) { return true; } //! Quantize data: encode the float input to uint32_t output void BinaryQuantizer::encode(const float *in, size_t dim, uint32_t *out) const { for (size_t i = 0; i < dim; i += 32) { size_t remain = i + 32 <= dim ? 32 : dim - i; uint32_t data = 0; uint32_t mask = 1; for (size_t j = 0; j < remain; j++) { if (in[i + j] >= threshold_) { data |= mask; } mask <<= 1; } *out = data; out++; } } //! De-quantize data: decode the input uint32_t to float output //! bit value 1 will be mapped to 1.0 //! bit value 0 will be mapped to -1.0 void BinaryQuantizer::decode(const uint32_t *in, size_t dim, float *out) const { for (size_t i = 0; i < dim; ++i) { uint8_t bit = (in[i >> 5] >> (i & 31)) & 0x01; if (bit == 1) { out[i] = 1.0f; } else { out[i] = -1.0f; } // std::cout << "dim: " << i << ", value: " << (size_t)bit << std::endl; } } } // namespace ailego } // namespace zvec ================================================ FILE: src/ailego/algorithm/binary_quantizer.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include namespace zvec { namespace ailego { /*! Binary Quantization Algorithm */ class BinaryQuantizer { public: //! Constructor BinaryQuantizer(void) {} //! Feed the training data bool feed(const float *vec, size_t dim); //! Train the quantizer bool train(void); //! Quantize data: encode the float input to uint32_t output void encode(const float *in, size_t dim, uint32_t *out) const; //! De-quantize data: decode the input uint32_t to float output void decode(const uint32_t *in, size_t dim, float *out) const; //! Get encoded elements in type of uint32_t static size_t EncodedSizeInBinary32(size_t dim) { return (dim + 31) / 32; } //! Set quantization threshold void set_threshold(float threshold) { threshold_ = threshold; } //! Get quantization threshold float threshold(void) const { return threshold_; } private: //! Disable them BinaryQuantizer(const BinaryQuantizer &) = delete; BinaryQuantizer &operator=(const BinaryQuantizer &) = delete; private: //! Members std::vector data_{}; float threshold_{0.0f}; }; } // namespace ailego } // namespace zvec ================================================ FILE: src/ailego/algorithm/integer_quantizer.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "integer_quantizer.h"
#include
#include
#include
#include
#include
#include

namespace zvec {
namespace ailego {

//! Make smooth the distribution to eliminate zero in hist
//! The vector is first L1-normalized. Then every (near-)zero bin, meaning
//! |v| < float epsilon, gains epsilon, and every other bin loses
//! y = epsilon * zero_count / nonzero_count, so the total mass is preserved.
//! Left untouched after normalization when there are no zero bins, or only
//! zero bins.
//! NOTE(review): a non-zero bin smaller than y becomes negative here, which
//! would make std::log in ComputeKlDivergence produce NaN — confirm inputs
//! can never be that sparse.
static inline void MakeSmooth(std::vector &dist) {
  constexpr float epsilon = std::numeric_limits::epsilon();

  // L1 Normalize first
  float norm = 1.0f;
  Normalizer::L1(dist.data(), dist.size(), &norm);

  size_t zero_count = std::count_if(dist.begin(), dist.end(), [](float val) {
    return (std::abs(val) < std::numeric_limits::epsilon());
  });
  size_t nonzero_count = dist.size() - zero_count;

  // Double check
  // Nothing to redistribute (or nothing to take it from).
  if (nonzero_count == 0 || zero_count == 0) {
    return;
  }

  float y = epsilon * zero_count / static_cast(nonzero_count);
  for (auto &it : dist) {
    if (std::abs(it) < epsilon) {
      it += epsilon;
    } else {
      it -= y;
    }
  }  // end of for
}

//! Compute the Entropy of distribution p/q by Kullback-Leibler Divergence
//! Returns sum_i p[i] * log(p[i] / q[i]). The sentinel
//! std::numeric_limits::max() is returned when the sizes differ, the
//! inputs are empty, or any entry of p or q is exactly 0. Callers smooth both
//! inputs first so zeros should not normally occur.
static inline double ComputeKlDivergence(const std::vector &p,
                                         const std::vector &q) {
  if (p.size() != q.size() || p.size() == 0) {
    return std::numeric_limits::max();
  }
  double v = 0.0f;
  for (size_t i = 0; i != p.size(); ++i) {
    if (p[i] == 0 || q[i] == 0) {
      return std::numeric_limits::max();
    }
    v += p[i] * std::log(static_cast(p[i]) / static_cast(q[i]));
  }
  return v;
}

//! Expand the quantization distribution to origin distribution in
//! [-threshold, threshold]
//!
//! `distribution` is the full histogram, and the window is the
//! 2 * threshold bins centred on its midpoint. Each quantized bin i covers
//! merged_cnt = 2 * threshold / quantized_distribution.size() source bins,
//! possibly fractionally at both ends. Its mass is divided evenly among the
//! covered source bins that are non-zero in `distribution`, with partial
//! edge bins weighted by their covered fraction. Source bins that are zero
//! stay zero, so the expanded histogram keeps the original's zero pattern.
//! `expand_distribution` is resized to 2 * threshold; only newly added
//! elements are zero-filled (assumes the caller passes an empty vector —
//! TODO confirm).
static inline void ExpandCandidateDistribution(
    const std::vector &distribution,
    const std::vector &quantized_distribution, size_t threshold,
    std::vector *expand_distribution) {
  expand_distribution->resize(threshold * 2, 0);
  float merged_cnt = static_cast(expand_distribution->size()) /
                     quantized_distribution.size();
  // Index in `distribution` of the first bin of the [-threshold, threshold) window.
  size_t left_boundary = distribution.size() / 2 - threshold;
  for (size_t i = 0; i < quantized_distribution.size(); ++i) {
    // [start, end) is quantized bin i's span in window coordinates.
    float start = i * merged_cnt;
    float end = start + merged_cnt;
    const size_t start_ceil = static_cast(std::ceil(start));
    const size_t end_floor = static_cast(std::floor(end));
    // Fractions of the partially covered bins at the left and right edges.
    float left_ratio = static_cast(start_ceil) - start;
    float right_ratio = end - static_cast(end_floor);
    float nonzero_count = 0;

    //! Count the non-zeros bins, if the histogram bin is partially included,
    //! non-zero bins is also partially counted
    if (left_ratio > 0 && left_boundary + start_ceil > 0) {
      if (distribution[left_boundary + start_ceil - 1] != 0) {
        nonzero_count += left_ratio;
      }
    }
    if (right_ratio > 0 && left_boundary + end_floor < distribution.size()) {
      if (distribution[left_boundary + end_floor] != 0) {
        nonzero_count += right_ratio;
      }
    }
    for (size_t j = start_ceil; j < end_floor; j++) {
      nonzero_count += distribution[left_boundary + j] != 0;
    }
    // Every covered source bin is empty, so this quantized bin contributes nothing.
    if (nonzero_count == 0) {
      continue;
    }

    //! expand the quantized value
    float value = quantized_distribution[i] / nonzero_count;
    // Partial edge bins are shared with the neighbouring quantized bin, hence +=.
    if (left_ratio > 0 && start_ceil > 0) {
      (*expand_distribution)[start_ceil - 1] += value * left_ratio;
    }
    if (right_ratio > 0 && end_floor < expand_distribution->size()) {
      (*expand_distribution)[end_floor] += value * right_ratio;
    }
    // Fully covered bins belong to this quantized bin only.
    for (size_t j = start_ceil; j < end_floor; j++) {
      if (distribution[left_boundary + j] != 0) {
        (*expand_distribution)[j] = value;
      }
    }  // end of for
  }  // end of for
}

/*!
Compute quantization threshold bins * Implement Int8 Quantization Algorithm ref: * http://on-demand.gputechconf.com/gtc/2017/presentation/s7310-8-bit-inference-with-tensorrt.pdf */ static inline size_t ComputeThreshold(const std::vector &hist, const size_t target_bins) { std::vector P_distribution(hist.size()); size_t zero_point_index = hist.size() / 2; size_t start_bin = target_bins / 2; size_t end_bin = hist.size() / 2; size_t negative_outliers_count = 0; size_t positive_outliers_count = 0; double min_divergence = std::numeric_limits::max(); size_t target_threshold = end_bin; for (size_t threshold = start_bin; threshold <= end_bin; ++threshold) { negative_outliers_count += hist[zero_point_index - threshold]; positive_outliers_count += hist[zero_point_index + threshold - 1]; } //! for each zero-axised quantization range: [-threshold, threshold], search //! the best solution for (size_t threshold = start_bin; threshold <= end_bin; ++threshold) { P_distribution.resize(threshold * 2); auto p_hist = &hist[zero_point_index - threshold]; for (size_t i = 0; i != P_distribution.size(); ++i) { P_distribution[i] = static_cast(p_hist[i]); } negative_outliers_count -= hist[zero_point_index - threshold]; positive_outliers_count -= hist[zero_point_index + threshold - 1]; P_distribution[0] += negative_outliers_count; P_distribution[P_distribution.size() - 1] += positive_outliers_count; //! 
Quantize the bins in range [-threshold, threshold] to target_bins std::vector Q_distribution(target_bins, 0); float merged_cnt = static_cast(threshold * 2) / target_bins; size_t left_boundary = zero_point_index - threshold; for (size_t i = 0; i < target_bins; ++i) { float start = i * merged_cnt; float end = start + merged_cnt; const size_t start_ceil = static_cast(std::ceil(start)); const size_t end_floor = static_cast(std::floor(end)); if (left_boundary + start_ceil > 0) { Q_distribution[i] += ((float)start_ceil - start) * hist[left_boundary + start_ceil - 1]; } if (left_boundary + end_floor < hist.size()) { Q_distribution[i] += (end - (float)end_floor) * hist[left_boundary + end_floor]; } for (size_t j = start_ceil; j < end_floor; j++) { Q_distribution[i] += hist[left_boundary + j]; } } std::vector Q_expand_distribution; ExpandCandidateDistribution(hist, Q_distribution, threshold, &Q_expand_distribution); //! Compute Kullback-Leibler Divergence, normalize the smooth the data //! first. Ref: http://hanj.cs.illinois.edu/cs412/bk3/KL-divergence.pdf MakeSmooth(P_distribution); MakeSmooth(Q_expand_distribution); double divergence = ComputeKlDivergence(P_distribution, Q_expand_distribution); if (divergence < min_divergence) { min_divergence = divergence; target_threshold = threshold; } } return target_threshold; } // Quantize the value in range template static inline float QuantizeValue(float val, float scale, float bias) { val = (val + bias) * scale; if (val > RANGE_MAX) { val = RANGE_MAX; } else if (val < RANGE_MIN) { val = RANGE_MIN; } return val; } // Init the historgram params #define INIT_HISTOGRAM() \ { \ if (histogram_bins_ == 0) { \ size_t range = non_bias_ \ ? 
std::max(std::abs(MIN_VALUE), std::abs(MAX_VALUE)) \ : (MAX_VALUE - MIN_VALUE); \ histogram_bins_ = std::max(4096u, range * 8); \ } \ histogram_.resize((histogram_bins_ + 1) >> 1 << 1); \ if (non_bias_) { \ bias_ = 0.0f; \ auto val = std::max(std::abs(max_), std::abs(min_)); \ left_boundary_ = -val; \ hist_interval_ = (val * 2) / static_cast(histogram_.size()); \ } else { \ bias_ = -static_cast(min_ + (max_ - min_) * 0.5); \ left_boundary_ = min_; \ hist_interval_ = (max_ - min_) / static_cast(histogram_.size()); \ } \ } // Feed vector and update the historgram #define UPDATE_HISTOGRAM(vec, dim) \ { \ if (max_ < min_) { \ return false; \ } \ if (histogram_.size() == 0) { \ INIT_HISTOGRAM() \ } \ for (size_t i = 0; i < dim; ++i) { \ ssize_t index = 0; \ if (hist_interval_ > 0.0) { \ index = \ static_cast((vec[i] - left_boundary_) / hist_interval_); \ } \ if (index < 0) { \ index = 0; \ } else if ((size_t)index >= histogram_.size()) { \ index = histogram_.size() - 1; \ } \ ailego_assert_with((size_t)index < histogram_.size(), "Invalid index"); \ histogram_[index] += 1; \ } \ return true; \ } // Train the quantizer #define TRAIN_QUANTIZER() \ { \ auto sum = std::accumulate(histogram_.begin(), histogram_.end(), 0); \ if (sum == 0) { \ return false; \ } \ size_t target_bins = \ ailego_align(static_cast(MAX_VALUE - MIN_VALUE), 2); \ auto threshold_bins = ComputeThreshold(histogram_, target_bins); \ auto threshold = \ (static_cast(threshold_bins) + 0.5f) * hist_interval_; \ scale_ = target_bins / 2 / threshold; \ if (!non_bias_) { \ bias_ += (MAX_VALUE + MIN_VALUE) * 0.5f / scale_; \ } \ scale_reciprocal_ = 1 / scale_; \ return true; \ } // Feed the INT16 quantizer bool EntropyInt16Quantizer::feed(const float *vec, size_t dim) { UPDATE_HISTOGRAM(vec, dim) } // Train the INT16 quantizer bool EntropyInt16Quantizer::train(void) { TRAIN_QUANTIZER() } // Encode to INT16 void EntropyInt16Quantizer::encode(const float *in, size_t dim, int16_t *out) const { for (size_t i = 0; i < 
dim; ++i) { out[i] = static_cast( std::round(QuantizeValue(in[i], scale_, bias_))); } } // Decode from INT16 void EntropyInt16Quantizer::decode(const int16_t *in, size_t dim, float *out) const { for (size_t i = 0; i < dim; ++i) { out[i] = in[i] * this->scale_reciprocal() - this->bias(); } } // Feed the UINT16 quantizer bool EntropyUInt16Quantizer::feed(const float *vec, size_t dim) { UPDATE_HISTOGRAM(vec, dim) } // Train the UINT16 quantizer bool EntropyUInt16Quantizer::train(void) { TRAIN_QUANTIZER() } // Encode to UINT16 void EntropyUInt16Quantizer::encode(const float *in, size_t dim, uint16_t *out) const { for (size_t i = 0; i < dim; ++i) { out[i] = static_cast_from_float_to_uint16( std::round(QuantizeValue(in[i], scale_, bias_))); } } // Decode from INT16 void EntropyUInt16Quantizer::decode(const uint16_t *in, size_t dim, float *out) const { for (size_t i = 0; i < dim; ++i) { out[i] = in[i] * this->scale_reciprocal() - this->bias(); } } // Feed the INT8 quantizer bool EntropyInt8Quantizer::feed(const float *vec, size_t dim) { UPDATE_HISTOGRAM(vec, dim) } // Train the INT8 quantizer bool EntropyInt8Quantizer::train(void) { TRAIN_QUANTIZER() } // Encode to INT8 void EntropyInt8Quantizer::encode(const float *in, size_t dim, int8_t *out) const { for (size_t i = 0; i < dim; ++i) { out[i] = static_cast( std::round(QuantizeValue(in[i], scale_, bias_))); } } // Decode from INT8 void EntropyInt8Quantizer::decode(const int8_t *in, size_t dim, float *out) const { for (size_t i = 0; i < dim; ++i) { out[i] = in[i] * this->scale_reciprocal() - this->bias(); } } // Feed the UINT8 quantizer bool EntropyUInt8Quantizer::feed(const float *vec, size_t dim) { UPDATE_HISTOGRAM(vec, dim) } // Train the UINT8 quantizer bool EntropyUInt8Quantizer::train(void) { TRAIN_QUANTIZER() } // Encode to INT8 void EntropyUInt8Quantizer::encode(const float *in, size_t dim, uint8_t *out) const { for (size_t i = 0; i < dim; ++i) { out[i] = static_cast_from_float_to_uint8( 
std::round(QuantizeValue(in[i], scale_, bias_))); } } // Decode from UINT8 void EntropyUInt8Quantizer::decode(const uint8_t *in, size_t dim, float *out) const { for (size_t i = 0; i < dim; ++i) { out[i] = in[i] * this->scale_reciprocal() - this->bias(); } } // Feed the INT4 quantizer bool EntropyInt4Quantizer::feed(const float *vec, size_t dim) { UPDATE_HISTOGRAM(vec, dim) } // Train the INT4 quantizer bool EntropyInt4Quantizer::train(void) { TRAIN_QUANTIZER() } // Encode to INT4 void EntropyInt4Quantizer::encode(const float *in, size_t dim, uint8_t *out) const { ailego_assert_with(dim % 2 == 0, "Dimension must be aligned with 2"); for (size_t i = 0; i < dim; i += 2) { float lo = QuantizeValue(in[i], scale_, bias_); float hi = QuantizeValue(in[i + 1], scale_, bias_); out[i / 2] = (static_cast_from_float_to_uint8(std::round(hi)) << 4) | (static_cast_from_float_to_uint8(std::round(lo)) & 0xF); } } // Decode from INT4 void EntropyInt4Quantizer::decode(const uint8_t *in, size_t dim, float *out) const { ailego_assert_with(dim % 2 == 0, "Dimension must be aligned with 2"); size_t size = dim / 2; for (size_t i = 0; i < size; i += 1) { uint8_t v = in[i]; int8_t lo = (static_cast(v << 4) >> 4); int8_t hi = (static_cast(v & 0xf0) >> 4); out[2 * i] = lo * this->scale_reciprocal() - this->bias(); out[2 * i + 1] = hi * this->scale_reciprocal() - this->bias(); } } // Feed the UINT4 quantizer bool EntropyUInt4Quantizer::feed(const float *vec, size_t dim) { UPDATE_HISTOGRAM(vec, dim) } // Train the UINT4 quantizer bool EntropyUInt4Quantizer::train(void) { TRAIN_QUANTIZER() } // Encode to INT4 void EntropyUInt4Quantizer::encode(const float *in, size_t dim, uint8_t *out) const { ailego_assert_with(dim % 2 == 0, "Dimension must be aligned with 2"); for (size_t i = 0; i < dim; i += 2) { float lo = QuantizeValue(in[i], scale_, bias_); float hi = QuantizeValue(in[i + 1], scale_, bias_); out[i / 2] = (static_cast_from_float_to_uint8(std::round(hi)) << 4) | 
(static_cast_from_float_to_uint8(std::round(lo)) & 0xF); } } // Decode from INT4 void EntropyUInt4Quantizer::decode(const uint8_t *in, size_t dim, float *out) const { ailego_assert_with(dim % 2 == 0, "Dimension must be aligned with 2"); size_t size = dim / 2; for (size_t i = 0; i < size; i += 1) { uint8_t v = in[i]; out[2 * i] = (v & 0xf) * this->scale_reciprocal() - this->bias(); out[2 * i + 1] = (v >> 4) * this->scale_reciprocal() - this->bias(); } } } // namespace ailego } // namespace zvec ================================================ FILE: src/ailego/algorithm/integer_quantizer.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include namespace zvec { namespace ailego { /*! Entropy-based Integer Quantization Algorithm */ template class EntropyIntegerQuantizer { public: //! Primitive Built-in Types to store the quantized data using ValueType = typename std::remove_cv::type; //! Constants constexpr static int MIN_VALUE = RANGE_MIN; constexpr static int MAX_VALUE = RANGE_MAX; // Check supporting type static_assert(std::is_integral::value, "ValueType must be integral"); // Check template values static_assert(RANGE_MIN < RANGE_MAX, "Invalid value range"); //! Constructor EntropyIntegerQuantizer(void) {} //! Set histogram bins in train void set_histogram_bins(size_t bins) { if (bins > (RANGE_MAX - RANGE_MIN)) { histogram_bins_ = bins; } } //! 
Set quantization params scale void set_scale(float val) { if (val > 0.0f) { scale_ = val; scale_reciprocal_ = 1 / scale_; } } //! Set quantization params bias void set_bias(float val) { bias_ = val; } //! Set quantization params max void set_max(float val) { max_ = val; } //! Set quantization params min void set_min(float val) { min_ = val; } //! Set quantization params non bias void set_non_bias(bool val) { non_bias_ = val; } //! Get histogram bins in train size_t histogram_bins(void) const { return histogram_bins_; } //! Get quantization params scale float scale(void) const { return scale_; } //! Get quantization params bias float bias(void) const { return bias_; } //! Get quantization params max float max(void) const { return max_; } //! Get quantization params min float min(void) const { return min_; } //! Get quantization params non bias bool non_bias(void) const { return non_bias_; } //! Retrieve the scale reciprocal for decoding float scale_reciprocal(void) const { return scale_reciprocal_; } protected: //! Disable them EntropyIntegerQuantizer(const EntropyIntegerQuantizer &) = delete; EntropyIntegerQuantizer &operator=(const EntropyIntegerQuantizer &) = delete; //! Members size_t histogram_bins_{0}; float hist_interval_{1.0f}; float max_{std::numeric_limits::min()}; float min_{std::numeric_limits::max()}; float bias_{0.0f}; float scale_{0.0f}; float scale_reciprocal_{0.0f}; float left_boundary_{0.0f}; bool non_bias_{false}; std::vector histogram_{}; }; /*! INT16 Quantizer */ class EntropyInt16Quantizer : public EntropyIntegerQuantizer { public: //! Feed the training data bool feed(const float *vec, size_t dim); //! Train the quantizer bool train(void); //! Encode float vector to int16 void encode(const float *in, size_t dim, ValueType *out) const; //! Decode to float vector from int16 void decode(const ValueType *in, size_t dim, float *out) const; }; /*! UINT16 Quantizer */ class EntropyUInt16Quantizer : public EntropyIntegerQuantizer { public: //! 
Feed the training data bool feed(const float *vec, size_t dim); //! Train the quantizer bool train(void); //! Encode float vector to uint16 void encode(const float *in, size_t dim, ValueType *out) const; //! Decode to float vector from uint16 void decode(const ValueType *in, size_t dim, float *out) const; }; /*! INT8 Quantizer */ class EntropyInt8Quantizer : public EntropyIntegerQuantizer { public: //! Feed the training data bool feed(const float *vec, size_t dim); //! Train the quantizer bool train(void); //! Encode float vector to int8 void encode(const float *in, size_t dim, ValueType *out) const; //! Decode to float vector from int8 void decode(const ValueType *in, size_t dim, float *out) const; }; /*! UINT8 Quantizer */ class EntropyUInt8Quantizer : public EntropyIntegerQuantizer { public: //! Feed the training data bool feed(const float *vec, size_t dim); //! Train the quantizer bool train(void); //! Encode float vector to uint8 void encode(const float *in, size_t dim, ValueType *out) const; //! Decode to float vector from uint8 void decode(const ValueType *in, size_t dim, float *out) const; }; /*! INT4 Quantizer */ class EntropyInt4Quantizer : public EntropyIntegerQuantizer { public: //! Feed the training data bool feed(const float *vec, size_t dim); //! Train the quantizer bool train(void); //! Encode float vector to int4 void encode(const float *in, size_t dim, ValueType *out) const; //! Decode to float vector from int4 void decode(const ValueType *in, size_t dim, float *out) const; }; /*! UINT4 Quantizer */ class EntropyUInt4Quantizer : public EntropyIntegerQuantizer { public: //! Feed the training data bool feed(const float *vec, size_t dim); //! Train the quantizer bool train(void); //! Encode float vector to uint4 void encode(const float *in, size_t dim, ValueType *out) const; //! 
Decode to float vector from uint4 void decode(const ValueType *in, size_t dim, float *out) const; }; } // namespace ailego } // namespace zvec ================================================ FILE: src/ailego/algorithm/kmeans.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include "lloyd_cluster.h" namespace zvec { namespace ailego { /*! K-MC2 Centroids Generator */ template class Kmc2CentroidsGenerator { public: //! Type of values using OwnerType = typename std::decay::type; using ContainerType = typename OwnerType::ContainerType; using ContextType = typename OwnerType::ContextType; using ValueType = typename OwnerType::ValueType; using StoreType = typename OwnerType::StoreType; using ThreadPoolType = TPool; //! constexpr variables constexpr static size_t BatchCount = OwnerType::BatchCount; //! Generate centroids void operator()(OwnerType *owner, ThreadPoolType &pool) const { if (chain_length_ == 0) { this->init_centroids_random(owner); } else if (!assumption_free_) { this->init_centroids_kmc2(owner, pool); } else { this->init_centroids_afkmc2(owner, pool); } } //! Retrieve the markov chain length size_t chain_length(void) const { return chain_length_; } //! 
Set the mutable markov chain length void set_chain_length(size_t len) { chain_length_ = len; } //! Retrieve assumption free option bool assumption_free(void) const { return assumption_free_; } //! Set the assumption free option void set_assumption_free(bool val) { assumption_free_ = val; } protected: //! Initialize centroids randomly void init_centroids_random(OwnerType *owner) const { RandomSelectBenches(owner->feature_cache(), owner->feature_matrix(), owner->k_value(), owner->mutable_centroids()); } //! Initialize centroids with K-MC2 void init_centroids_kmc2(OwnerType *owner, ThreadPoolType &pool) const { const auto &matrix = owner->feature_matrix(); const auto &cache = owner->feature_cache(); auto *centroids = owner->mutable_centroids(); std::mt19937 mt((std::random_device())()); std::uniform_real_distribution dist(0.0, 1.0); ContainerType benches(cache.dimension()); std::vector scores; // Sample first center uniformly RandomSelectBenches(cache, matrix, 1, centroids); // Make a thread group auto group = pool.make_group(); for (size_t i = 1, k = owner->k_value(); i < k; ++i) { RandomSelectBenches(cache, matrix, chain_length_, &benches); // Update bench scores scores.resize(benches.count()); for (size_t j = 0; j != scores.size(); ++j) { group->submit(Closure::New(&Kmc2CentroidsGenerator::UpdateBenchScores, centroids, benches[j], &scores[j])); } group->wait_finish(); //! Select the better centroid randomly float x = scores[0]; size_t xj = 0; for (size_t j = 1; j != scores.size(); ++j) { float y = scores[j]; if (x == 0.0f || x * dist(mt) < y) { x = y; xj = j; } } centroids->append(benches[xj], benches.dimension()); } // end of for } //! 
Initialize centroids with K-MC2 void init_centroids_afkmc2(OwnerType *owner, ThreadPoolType &pool) const { const auto &matrix = owner->feature_matrix(); const auto &cache = owner->feature_cache(); // Probability std::vector probs(matrix.count() + cache.count()); // Sample first center uniformly RandomSelectBenches(cache, matrix, 1, owner->mutable_centroids()); // Make a thread group auto group = pool.make_group(); if (!matrix.empty()) { size_t n = matrix.count() / BatchCount; size_t c = std::max(n / pool.count() / 2u, 1u); size_t m = n / c * c; for (size_t i = 0; i != m; i += c) { group->submit(Closure::New(&Kmc2CentroidsGenerator::UpdateMatrixScores, owner, i, i + c, &probs[0])); } for (size_t i = m; i != n; i += 1) { group->submit(Closure::New(&Kmc2CentroidsGenerator::UpdateMatrixScores, owner, i, i + 1, &probs[0])); } } if (!cache.empty()) { group->submit(Closure::New(&Kmc2CentroidsGenerator::UpdateCacheScores, owner, &probs[matrix.count()])); } group->wait_finish(); // Update probabilities double p_sum = std::accumulate(probs.begin(), probs.end(), 0.0); for (auto it = probs.begin(); it != probs.end(); ++it) { *it = static_cast((*it / p_sum + 1.0 / probs.size()) * 0.5); } std::mt19937 mt((std::random_device())()); std::uniform_real_distribution dist(0.0, 1.0); ContainerType benches(cache.dimension()); std::vector scores; std::vector bench_probs; for (size_t i = 1; i < owner->k_value(); ++i) { RandomSelectBenches(cache, matrix, chain_length_, probs, &benches, &bench_probs); // Update bench scores scores.resize(benches.count()); for (size_t j = 0; j != scores.size(); ++j) { group->submit(Closure::New(&Kmc2CentroidsGenerator::UpdateBenchScores, owner->mutable_centroids(), benches[j], &scores[j])); } group->wait_finish(); // Update scores with probabilities for (size_t j = 0; j != scores.size(); ++j) { scores[j] /= bench_probs[j]; } //! 
Select the better centroid randomly float x = scores[0]; size_t xj = 0; for (size_t j = 1; j != scores.size(); ++j) { float y = scores[j]; if (x == 0.0f || x * dist(mt) < y) { x = y; xj = j; } } owner->mutable_centroids()->append(benches[xj], benches.dimension()); } // end of for } //! Update matrix score static void UpdateMatrixScores(const OwnerType *owner, size_t first, size_t last, float *out) { const auto &matrix = owner->feature_matrix(); const auto *bench = owner->centroids().data(); for (size_t i = first * BatchCount; i != last * BatchCount; i += BatchCount) { ContextType::template BatchDistance<1>(matrix[i], bench, matrix.dimension(), &out[i]); } } //! Update cache score static void UpdateCacheScores(const OwnerType *owner, float *out) { const auto &cache = owner->feature_cache(); const auto *bench = owner->centroids().data(); for (size_t i = 0, n = cache.count(); i != n; ++i) { ContextType::Distance(bench, cache[i], cache.dimension(), &out[i]); } } //! Update bench score static void UpdateBenchScores(const ContainerType *benches, const StoreType *feat, float *out) { float min_score = std::numeric_limits::max(); for (size_t i = 0, c = benches->count(); i != c; ++i) { float new_score; ContextType::Distance(benches->at(i), feat, benches->dimension(), &new_score); if (new_score < min_score) { min_score = new_score; } } *out = min_score; } //! 
Select k benches randomly static void RandomSelectBenches(const ContainerType &cache, const ContainerType &matrix, size_t k, ContainerType *benches) { ContainerType rows(cache.dimension()); size_t m = matrix.count(); size_t n = m + cache.count(); std::mt19937 mt((std::random_device())()); rows.resize(BatchCount); benches->reset(cache.dimension()); benches->reserve(k); for (size_t i = 0; k > 0 && i < n; ++i) { if (mt() % (n - i) >= k) { continue; } // Selected a feature if (i < m) { ContextType::MatrixReverseTranspose(matrix[i / BatchCount * BatchCount], matrix.dimension(), rows.data()); benches->append(rows[i & (BatchCount - 1u)], matrix.dimension()); } else { benches->append(cache[i - m], cache.dimension()); } --k; } // end of for } //! Select k benches randomly static void RandomSelectBenches(const ContainerType &cache, const ContainerType &matrix, size_t k, const std::vector &probs, ContainerType *benches, std::vector *bench_probs) { std::mt19937 mt((std::random_device())()); std::uniform_real_distribution dist(0.0, 1.0); // Sample features KeyValueHeap> samples(k); for (size_t i = 0; i < probs.size(); ++i) { samples.emplace(i, std::pow(dist(mt), 1.0 / probs[i])); } ContainerType rows(cache.dimension()); size_t matrix_count = matrix.count(); rows.resize(BatchCount); benches->reset(cache.dimension()); benches->reserve(k); bench_probs->clear(); bench_probs->reserve(k); for (const auto &it : samples) { // Selected a feature if (it.first < matrix_count) { ContextType::MatrixReverseTranspose( matrix[it.first / BatchCount * BatchCount], matrix.dimension(), rows.data()); benches->append(rows[it.first & (BatchCount - 1u)], matrix.dimension()); } else { benches->append(cache[it.first - matrix_count], cache.dimension()); } bench_probs->push_back(probs[it.first]); } } private: size_t chain_length_{32}; bool assumption_free_{false}; }; /*! Numerical K-Means Context */ template class NumericalKmeansContext { public: //! 
constexpr variables constexpr static size_t BatchCount = BATCH_COUNT; //! Type of values using ValueType = typename std::remove_cv::type; using StoreType = typename std::remove_cv::type; // Check supporting type static_assert(IsSignedArithmetic::value, "ValueType must be signed arithmetic"); /*! K-Means Context Cluster */ class Cluster { public: //! Constructor Cluster(size_t dim) : accum_(dim, 0.0) {} //! Constructor Cluster(const Cluster &rhs) : cost_(rhs.cost_), count_(rhs.count_), accum_(rhs.accum_) {} //! Constructor Cluster(Cluster &&rhs) : cost_(rhs.cost_), count_(rhs.count_), accum_(std::move(rhs.accum_)) {} //! Assignment Cluster &operator=(const Cluster &rhs) { cost_ = rhs.cost_; count_ = rhs.count_; accum_ = rhs.accum_; return *this; } //! Assignment Cluster &operator=(Cluster &&rhs) { cost_ = rhs.cost_; count_ = rhs.count_; accum_ = std::move(rhs.accum_); return *this; } //! Append a vector void append(const ValueType *vec, size_t dim, float dist) { ailego_check_with(dim == accum_.size(), "Unmatched dimension"); mutex_.lock(); cost_ += dist; count_ += 1; for (size_t i = 0; i != dim; ++i) { accum_[i] += vec[i]; } mutex_.unlock(); } //! Retrieve the centroid of vectors void centroid(ValueType *out, size_t dim) const { ailego_check_with(dim == accum_.size(), "Unmatched dimension"); for (size_t i = 0; i != dim; ++i) { out[i] = count_ == 0 ? FloatCast(NAN) : FloatCast(accum_[i] / count_); } } //! Retrieve squared error double cost(void) const { return cost_; } //! Retrieve feature count size_t count(void) const { return count_; } protected: //! Convert float type to another type template static auto FloatCast(const double val) -> typename std::enable_if::value, U>::type { return static_cast(val); } //! 
Convert float type to another type template static auto FloatCast(const double val) -> typename std::enable_if::value, U>::type { return static_cast(std::round(val)); } private: SpinMutex mutex_{}; double cost_{0.0}; size_t count_{0u}; std::vector accum_{}; }; //! operator [] const Cluster &operator[](size_t i) const { return clusters_[i]; } //! operator [] Cluster &operator[](size_t i) { return clusters_[i]; } //! Clear the context void clear(void) { clusters_.clear(); } //! Reset the context void reset(size_t k_value, size_t dim) { clusters_.clear(); clusters_.resize(k_value, dim); } //! Retrieve context of clusters const std::vector &clusters(void) const { return clusters_; } //! Compute the distance between matrix and query (batch) template static void BatchDistance(const ValueType *m, const ValueType *q, size_t dim, float *out) { SquaredEuclideanDistanceMatrix::Compute(m, q, dim, out); } //! Compute the distance between matrix and query (single) static void Distance(const ValueType *m, const ValueType *q, size_t dim, float *out) { SquaredEuclideanDistanceMatrix::Compute(m, q, dim, out); } //! Transpose a matrix template static auto MatrixTranspose(const U *src, size_t dim, T *dst) -> typename std::enable_if= 2>::type { MatrixHelper::Transpose(src, dim, dst); } //! Transpose a matrix template static auto MatrixTranspose(const U *src, size_t dim, U *dst) -> typename std::enable_if::type { MatrixHelper::Transpose(src, dim >> 2, dst); } //! Reverse transpose a matrix template static auto MatrixReverseTranspose(const U *src, size_t dim, U *dst) -> typename std::enable_if= 2>::type { MatrixHelper::ReverseTranspose(src, dim, dst); } //! Reverse transpose a matrix template static auto MatrixReverseTranspose(const U *src, size_t dim, U *dst) -> typename std::enable_if::type { MatrixHelper::ReverseTranspose(src, dim >> 2, dst); } //! 
Compute Norm2 template ::value>::type> static void Norm2(ValueType *data, size_t dim, float *norm) { Normalizer::L2(data, dim, norm); } //! Compute Norm2, for non-float do nothing static void Norm2(ValueType * /*data*/, size_t /*dim*/, float *norm) { *norm = 0.0f; } private: //! Members std::vector clusters_{}; }; /*! Nibble K-Means Context (INT4) */ template class NibbleKmeansContext { public: //! constexpr variables constexpr static size_t BatchCount = BATCH_COUNT; //! Type of values using ValueType = typename std::remove_cv::type; using StoreType = typename std::make_unsigned::type; // Check supporting type static_assert(std::is_same::value || std::is_same::value, "ValueType must be int32_t or int64_t"); /*! K-Means Context Cluster */ class Cluster { public: //! Constructor Cluster(size_t dim) : accum_(dim, 0.0) {} //! Constructor Cluster(const Cluster &rhs) : cost_(rhs.cost_), count_(rhs.count_), accum_(rhs.accum_) {} //! Constructor Cluster(Cluster &&rhs) : cost_(rhs.cost_), count_(rhs.count_), accum_(std::move(rhs.accum_)) {} //! Assignment Cluster &operator=(const Cluster &rhs) { cost_ = rhs.cost_; count_ = rhs.count_; accum_ = rhs.accum_; return *this; } //! Assignment Cluster &operator=(Cluster &&rhs) { cost_ = rhs.cost_; count_ = rhs.count_; accum_ = std::move(rhs.accum_); return *this; } //! Append a vector void append(const StoreType *vec, size_t dim, float dist) { ailego_check_with(dim == accum_.size(), "Unmatched dimension"); mutex_.lock(); cost_ += dist; count_ += 1; const uint8_t *arr = reinterpret_cast(vec); dim = (dim >> 1) << 1; for (size_t i = 0; i != dim; i += 2) { uint8_t val = arr[i >> 1]; accum_[i] += ((int8_t)(val << 4) >> 4); accum_[i + 1] += ((int8_t)(val) >> 4); } mutex_.unlock(); } //! 
Retrieve the centroid of vectors void centroid(StoreType *out, size_t dim) const { ailego_check_with(dim == accum_.size(), "Unmatched dimension"); uint8_t *arr = reinterpret_cast(out); dim = (dim >> 1) << 1; for (size_t i = 0; i != dim; i += 2) { int lo = count_ == 0 ? 0 : static_cast(std::round(accum_[i] / count_)); int hi = count_ == 0 ? 0 : static_cast(std::round(accum_[i + 1] / count_)); arr[i >> 1] = (uint8_t)((hi << 4) & 0xf0) | (uint8_t)(lo & 0xf); } } //! Retrieve squared error double cost(void) const { return cost_; } //! Retrieve feature count size_t count(void) const { return count_; } private: SpinMutex mutex_{}; double cost_{0.0}; size_t count_{0u}; std::vector accum_{}; }; //! operator [] const Cluster &operator[](size_t i) const { return clusters_[i]; } //! operator [] Cluster &operator[](size_t i) { return clusters_[i]; } //! Clear the context void clear(void) { clusters_.clear(); } //! Reset the context void reset(size_t k_value, size_t dim) { clusters_.clear(); clusters_.resize(k_value, dim); } //! Retrieve context of clusters const std::vector &clusters(void) const { return clusters_; } //! Compute the distance between matrix and query (batch) template static void BatchDistance(const StoreType *m, const StoreType *q, size_t dim, float *out) { SquaredEuclideanDistanceMatrix::Compute( reinterpret_cast(m), reinterpret_cast(q), dim, out); } //! Compute the distance between matrix and query (single) static void Distance(const StoreType *m, const StoreType *q, size_t dim, float *out) { SquaredEuclideanDistanceMatrix::Compute( reinterpret_cast(m), reinterpret_cast(q), dim, out); } //! Transpose a matrix static void MatrixTranspose(const StoreType *src, size_t dim, StoreType *dst) { MatrixHelper::Transpose(src, dim >> 3, dst); } //! Reverse transpose a matrix static void MatrixReverseTranspose(const StoreType *src, size_t dim, StoreType *dst) { MatrixHelper::ReverseTranspose(src, dim >> 3, dst); } //! 
Compute and do norm2 static void Norm2(StoreType * /*data*/, size_t /*dim*/, float *norm) { *norm = 0; } private: //! Members std::vector clusters_{}; }; /*! Binary K-Means Context */ template class BinaryKmeansContext { public: //! constexpr variables constexpr static size_t BatchCount = BATCH_COUNT; //! Type of values using ValueType = typename std::remove_cv::type; using StoreType = typename std::remove_cv::type; // Check supporting type static_assert(std::is_same::value || std::is_same::value, "ValueType must be uint32_t or uint64_t"); /*! K-Means Context Cluster */ class Cluster { public: //! Constructor Cluster(size_t dim) : accum_(dim, 0) {} //! Constructor Cluster(const Cluster &rhs) : cost_(rhs.cost_), count_(rhs.count_), accum_(rhs.accum_) {} //! Constructor Cluster(Cluster &&rhs) : cost_(rhs.cost_), count_(rhs.count_), accum_(std::move(rhs.accum_)) {} //! Assignment Cluster &operator=(const Cluster &rhs) { cost_ = rhs.cost_; count_ = rhs.count_; accum_ = rhs.accum_; return *this; } //! Assignment Cluster &operator=(Cluster &&rhs) { cost_ = rhs.cost_; count_ = rhs.count_; accum_ = std::move(rhs.accum_); return *this; } //! Append a vector void append(const ValueType *vec, size_t dim, float dist) { ailego_check_with(dim == accum_.size(), "Unmatched dimension"); mutex_.lock(); cost_ += dist; count_ += 1; const uint8_t *arr = reinterpret_cast(vec); for (size_t i = 0; i != dim; ++i) { if (arr[i >> 3] & (1u << (i & 7))) { accum_[i] += 1; } } mutex_.unlock(); } //! Retrieve the centroid of vectors void centroid(ValueType *out, size_t dim) const { ailego_check_with(dim == accum_.size(), "Unmatched dimension"); uint8_t *arr = reinterpret_cast(out); size_t half = count_ >> 1; for (size_t i = 0; i != dim; ++i) { if (accum_[i] > half) { arr[i >> 3] |= static_cast(1 << (i & 0x7)); } else { arr[i >> 3] &= ~static_cast(1 << (i & 0x7)); } } } //! Retrieve squared error double cost(void) const { return cost_; } //! 
Retrieve feature count size_t count(void) const { return count_; } private: SpinMutex mutex_{}; double cost_{0.0}; size_t count_{0u}; std::vector accum_{}; }; //! operator [] const Cluster &operator[](size_t i) const { return clusters_[i]; } //! operator [] Cluster &operator[](size_t i) { return clusters_[i]; } //! Clear the context void clear(void) { clusters_.clear(); } //! Reset the context void reset(size_t k_value, size_t dim) { clusters_.clear(); clusters_.resize(k_value, dim); } //! Retrieve context of clusters const std::vector &clusters(void) const { return clusters_; } //! Compute the distance between matrix and query (batch) template static void BatchDistance(const ValueType *m, const ValueType *q, size_t dim, float *out) { HammingDistanceMatrix::Compute(m, q, dim, out); } //! Compute the distance between matrix and query (single) static void Distance(const ValueType *m, const ValueType *q, size_t dim, float *out) { HammingDistanceMatrix::Compute(m, q, dim, out); } //! Transpose a matrix static void MatrixTranspose(const ValueType *src, size_t dim, T *dst) { MatrixHelper::Transpose( src, (dim >> 3) / sizeof(ValueType), dst); } //! Reverse transpose a matrix static void MatrixReverseTranspose(const ValueType *src, size_t dim, T *dst) { MatrixHelper::ReverseTranspose( src, (dim >> 3) / sizeof(ValueType), dst); } //! Compute Norm2 static void Norm2(ValueType * /*data*/, size_t /*dim*/, float *norm) { *norm = 0; } private: //! Members std::vector clusters_{}; }; /*! Numerical InnerProduct K-Means Context */ template class NumericalInnerProductKmeansContext { public: //! constexpr variables constexpr static size_t BatchCount = BATCH_COUNT; //! Type of values using ValueType = typename std::remove_cv::type; using StoreType = typename std::remove_cv::type; // Check supporting type static_assert(IsSignedArithmetic::value, "ValueType must be signed arithmetic"); /*! K-Means Context Cluster */ class Cluster { public: //! 
Constructor Cluster(size_t dim) : accum_(dim, 0.0) {} //! Constructor Cluster(const Cluster &rhs) : cost_(rhs.cost_), count_(rhs.count_), accum_(rhs.accum_) {} //! Constructor Cluster(Cluster &&rhs) : cost_(rhs.cost_), count_(rhs.count_), accum_(std::move(rhs.accum_)) {} //! Assignment Cluster &operator=(const Cluster &rhs) { cost_ = rhs.cost_; count_ = rhs.count_; accum_ = rhs.accum_; return *this; } //! Assignment Cluster &operator=(Cluster &&rhs) { cost_ = rhs.cost_; count_ = rhs.count_; accum_ = std::move(rhs.accum_); return *this; } //! Append a vector void append(const ValueType *vec, size_t dim, float dist) { ailego_check_with(dim == accum_.size(), "Unmatched dimension"); mutex_.lock(); cost_ += dist; count_ += 1; for (size_t i = 0; i != dim; ++i) { accum_[i] += vec[i]; } mutex_.unlock(); } //! Retrieve the centroid of vectors void centroid(ValueType *out, size_t dim) const { ailego_check_with(dim == accum_.size(), "Unmatched dimension"); for (size_t i = 0; i != dim; ++i) { out[i] = count_ == 0 ? FloatCast(NAN) : FloatCast(accum_[i] / count_); } } //! Retrieve squared error double cost(void) const { return cost_; } //! Retrieve feature count size_t count(void) const { return count_; } protected: //! Convert float type to another type template static auto FloatCast(const double val) -> typename std::enable_if::value, U>::type { return static_cast(val); } //! Convert float type to another type template static auto FloatCast(const double val) -> typename std::enable_if::value, U>::type { return static_cast(std::round(val)); } private: SpinMutex mutex_{}; double cost_{0.0}; size_t count_{0u}; std::vector accum_{}; }; //! operator [] const Cluster &operator[](size_t i) const { return clusters_[i]; } //! operator [] Cluster &operator[](size_t i) { return clusters_[i]; } //! Clear the context void clear(void) { clusters_.clear(); } //! Reset the context void reset(size_t k_value, size_t dim) { clusters_.clear(); clusters_.resize(k_value, dim); } //! 
Retrieve context of clusters const std::vector &clusters(void) const { return clusters_; } //! Compute the distance between matrix and query (batch) template static void BatchDistance(const ValueType *m, const ValueType *q, size_t dim, float *out) { MinusInnerProductMatrix::Compute(m, q, dim, out); } //! Compute the distance between matrix and query (single) static void Distance(const ValueType *m, const ValueType *q, size_t dim, float *out) { MinusInnerProductMatrix::Compute(m, q, dim, out); } //! Transpose a matrix template static auto MatrixTranspose(const U *src, size_t dim, T *dst) -> typename std::enable_if= 2>::type { MatrixHelper::Transpose(src, dim, dst); } //! Transpose a matrix template static auto MatrixTranspose(const U *src, size_t dim, U *dst) -> typename std::enable_if::type { MatrixHelper::Transpose(src, dim >> 2, dst); } //! Reverse transpose a matrix template static auto MatrixReverseTranspose(const U *src, size_t dim, U *dst) -> typename std::enable_if= 2>::type { MatrixHelper::ReverseTranspose(src, dim, dst); } //! Reverse transpose a matrix template static auto MatrixReverseTranspose(const U *src, size_t dim, U *dst) -> typename std::enable_if::type { MatrixHelper::ReverseTranspose(src, dim >> 2, dst); } //! Compute Norm2 template ::value>::type> static void Norm2(ValueType *data, size_t dim, float *norm) { Normalizer::L2(data, dim, norm); } //! Compute Norm2, for non-float do nothing static void Norm2(ValueType * /*data*/, size_t /*dim*/, float *norm) { *norm = 0.0f; } private: //! Members std::vector clusters_{}; }; /*! Nibble InnerProduct K-Means Context (INT4) */ template class NibbleInnerProductKmeansContext { public: //! constexpr variables constexpr static size_t BatchCount = BATCH_COUNT; //! Type of values using ValueType = typename std::remove_cv::type; using StoreType = typename std::make_unsigned::type; // Check supporting type static_assert(std::is_same::value || std::is_same::value, "ValueType must be int32_t or int64_t"); /*! 
K-Means Context Cluster */ class Cluster { public: //! Constructor Cluster(size_t dim) : accum_(dim, 0.0) {} //! Constructor Cluster(const Cluster &rhs) : cost_(rhs.cost_), count_(rhs.count_), accum_(rhs.accum_) {} //! Constructor Cluster(Cluster &&rhs) : cost_(rhs.cost_), count_(rhs.count_), accum_(std::move(rhs.accum_)) {} //! Assignment Cluster &operator=(const Cluster &rhs) { cost_ = rhs.cost_; count_ = rhs.count_; accum_ = rhs.accum_; return *this; } //! Assignment Cluster &operator=(Cluster &&rhs) { cost_ = rhs.cost_; count_ = rhs.count_; accum_ = std::move(rhs.accum_); return *this; } //! Append a vector void append(const StoreType *vec, size_t dim, float dist) { ailego_check_with(dim == accum_.size(), "Unmatched dimension"); mutex_.lock(); cost_ += dist; count_ += 1; const uint8_t *arr = reinterpret_cast(vec); dim = (dim >> 1) << 1; for (size_t i = 0; i != dim; i += 2) { uint8_t val = arr[i >> 1]; accum_[i] += ((int8_t)(val << 4) >> 4); accum_[i + 1] += ((int8_t)(val) >> 4); } mutex_.unlock(); } //! Retrieve the centroid of vectors void centroid(StoreType *out, size_t dim) const { ailego_check_with(dim == accum_.size(), "Unmatched dimension"); uint8_t *arr = reinterpret_cast(out); dim = (dim >> 1) << 1; for (size_t i = 0; i != dim; i += 2) { int lo = count_ == 0 ? 0 : static_cast(std::round(accum_[i] / count_)); int hi = count_ == 0 ? 0 : static_cast(std::round(accum_[i + 1] / count_)); arr[i >> 1] = (uint8_t)((hi << 4) & 0xf0) | (uint8_t)(lo & 0xf); } } //! Retrieve squared error double cost(void) const { return cost_; } //! Retrieve feature count size_t count(void) const { return count_; } private: SpinMutex mutex_{}; double cost_{0.0}; size_t count_{0u}; std::vector accum_{}; }; //! operator [] const Cluster &operator[](size_t i) const { return clusters_[i]; } //! operator [] Cluster &operator[](size_t i) { return clusters_[i]; } //! Clear the context void clear(void) { clusters_.clear(); } //! 
Reset the context void reset(size_t k_value, size_t dim) { clusters_.clear(); clusters_.resize(k_value, dim); } //! Retrieve context of clusters const std::vector &clusters(void) const { return clusters_; } //! Compute the distance between matrix and query (batch) template static void BatchDistance(const StoreType *m, const StoreType *q, size_t dim, float *out) { MinusInnerProductMatrix::Compute( reinterpret_cast(m), reinterpret_cast(q), dim, out); } //! Compute the distance between matrix and query (single) static void Distance(const StoreType *m, const StoreType *q, size_t dim, float *out) { MinusInnerProductMatrix::Compute( reinterpret_cast(m), reinterpret_cast(q), dim, out); } //! Transpose a matrix static void MatrixTranspose(const StoreType *src, size_t dim, StoreType *dst) { MatrixHelper::Transpose(src, dim >> 3, dst); } //! Reverse transpose a matrix static void MatrixReverseTranspose(const StoreType *src, size_t dim, StoreType *dst) { MatrixHelper::ReverseTranspose(src, dim >> 3, dst); } //! Compute Norm2 static void Norm2(StoreType * /*data*/, size_t /*dim*/, float *norm) { *norm = 0; } private: //! Members std::vector clusters_{}; }; /*! Numerical K-Means cluster algorithm */ template > using NumericalKmeans = LloydCluster>; /*! Nibble K-Means cluster algorithm */ template > using NibbleKmeans = LloydCluster>; /*! Binary K-Means cluster algorithm */ template > using BinaryKmeans = LloydCluster>; /*! Numerical K-Means cluster algorithm */ template > using NumericalInnerProductKmeans = LloydCluster>; /*! Nibble K-Means cluster algorithm */ template > using NibbleInnerProductKmeans = LloydCluster>; } // namespace ailego } // namespace zvec ================================================ FILE: src/ailego/algorithm/lloyd_cluster.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include #include #include #include namespace zvec { namespace ailego { /*! Random Centroids Generator */ template struct RandomCentroidsGenerator { //! Type of values using OwnerType = typename std::decay::type; using ContainerType = typename OwnerType::ContainerType; using ContextType = typename OwnerType::ContextType; using ThreadPoolType = TPool; //! constexpr variables constexpr static size_t BatchCount = OwnerType::BatchCount; //! Generate centroids void operator()(OwnerType *owner, ThreadPoolType &) const { const auto &matrix = owner->feature_matrix(); const auto &cache = owner->feature_cache(); auto *centroids = owner->mutable_centroids(); ContainerType rows(cache.dimension()); size_t m = matrix.count(); size_t n = m + cache.count(); size_t k = owner->k_value(); std::mt19937 mt((std::random_device())()); rows.resize(BatchCount); centroids->reset(cache.dimension()); centroids->reserve(k); for (size_t i = 0; k > 0 && i < n; ++i) { if (mt() % (n - i) >= k) { continue; } // Selected a feature if (i < m) { ContextType::MatrixReverseTranspose(matrix[i / BatchCount * BatchCount], matrix.dimension(), rows.data()); centroids->append(rows[i & (BatchCount - 1u)], matrix.dimension()); } else { centroids->append(cache[i - m], cache.dimension()); } --k; } } }; /*! Lloyd's algorithm cluster */ template class LloydCluster { public: //! constexpr variables constexpr static size_t BatchCount = TContext::BatchCount; //! 
Type of values using ThreadPoolType = TPool; using ContainerType = TContainer; using ContextType = TContext; using ValueType = typename TContext::ValueType; using StoreType = typename TContext::StoreType; //! Constructor LloydCluster(size_t k, size_t dim) : k_value_(k), feature_cache_(dim), feature_matrix_(dim), centroids_matrix_(dim), centroids_(dim) {} //! Constructor LloydCluster(size_t k, size_t dim, bool spherical) : k_value_(k), feature_cache_(dim), feature_matrix_(dim), centroids_matrix_(dim), centroids_(dim), spherical_{spherical} {} //! Constructor LloydCluster(void) {} //! Destructor ~LloydCluster(void) {} //! Append a feature void append(const StoreType *arr, size_t dim) { feature_cache_.append(arr, dim); if (feature_cache_.count() == BatchCount) { size_t pos = feature_matrix_.count(); feature_matrix_.resize(pos + BatchCount); ContextType::MatrixTranspose(feature_cache_.data(), dim, feature_matrix_[pos]); feature_cache_.clear(); } } //! Reset cluster void reset(size_t k, size_t dim) { k_value_ = k; feature_cache_.reset(dim); feature_matrix_.reset(dim); centroids_.reset(dim); centroids_matrix_.reset(dim); context_.clear(); } //! Reset cluster void reset(size_t k, size_t dim, bool spherical) { k_value_ = k; feature_cache_.reset(dim); feature_matrix_.reset(dim); centroids_.reset(dim); centroids_matrix_.reset(dim); context_.clear(); spherical_ = spherical; } //! Initialize centroids template > void init_centroids(ThreadPoolType &pool, const G &g = G()) { g(this, pool); } //! 
Cluster one time template bool cluster_once(ThreadPoolType &pool, double *cost) { if (centroids_.empty()) { RandomCentroidsGenerator g; this->init_centroids(pool, g); } if (centroids_.count() != k_value_) { return false; } context_.reset(centroids_.count(), centroids_.dimension()); size_t count = centroids_.count() / BatchCount * BatchCount; centroids_matrix_.resize(count); for (size_t i = 0; i != count; i += BatchCount) { ContextType::MatrixTranspose(centroids_[i], centroids_.dimension(), centroids_matrix_[i]); } size_t remain = static_cast(centroids_.count() - count); if (remain > 0) { centroids_matrix_.append(centroids_[count], centroids_.dimension(), remain); } // Using thread pool auto group = pool.make_group(); if (!feature_matrix_.empty()) { size_t n = feature_matrix_.count() / BatchCount; size_t c = std::max(n / pool.count() / 2u, 1u); size_t m = n / c * c; for (size_t i = 0; i != m; i += c) { group->submit(Closure::New(this, &LloydCluster::cluster_matrix_features, i, i + c)); } for (size_t i = m; i != n; i += 1) { group->submit(Closure::New(this, &LloydCluster::cluster_matrix_features, i, i + 1)); } } if (!feature_cache_.empty()) { group->submit(Closure::New(this, &LloydCluster::cluster_cache_features)); } group->wait_finish(); *cost = 0.0; for (size_t i = 0, n = centroids_.count(); i != n; ++i) { const auto &item = context_[i]; item.centroid(centroids_[i], centroids_.dimension()); *cost += item.cost(); } if (spherical_) { for (size_t i = 0, n = centroids_.count(); i != n; ++i) { float norm; ContextType::Norm2(centroids_[i], centroids_.dimension(), &norm); } } return true; } //! Retrieve the controids ContainerType *mutable_centroids(void) { return ¢roids_; } //! Retrieve the controids const ContainerType ¢roids(void) const { return centroids_; } //! Retrieve the K value size_t k_value(void) const { return k_value_; } //! Retrieve context const ContextType &context(void) const { return context_; } //! 
Retrieve the feature cache const ContainerType &feature_cache(void) const { return feature_cache_; } //! Retrieve the feature matrix const ContainerType &feature_matrix(void) const { return feature_matrix_; } //! Reserve the feature matrix void feature_matrix_reserve(size_t count) { feature_matrix_.reserve(count); } protected: //! Cluster the cache features void cluster_cache_features(void) { std::array scores; for (size_t i = 0, n = feature_cache_.count(); i != n; ++i) { size_t count = centroids_matrix_.count() / BatchCount * BatchCount; const StoreType *feature = feature_cache_[i]; float nearest_score = std::numeric_limits::max(); size_t nearest_index = 0; for (size_t j = 0; j != count; j += BatchCount) { ContextType::template BatchDistance<1>(centroids_matrix_[j], feature, centroids_matrix_.dimension(), scores.data()); for (size_t k = 0; k < BatchCount; ++k) { if (scores[k] < nearest_score) { nearest_score = scores[k]; nearest_index = j + k; } } } // end of for for (size_t j = count, total = centroids_matrix_.count(); j != total; ++j) { ContextType::Distance(centroids_matrix_[j], feature, centroids_matrix_.dimension(), scores.data()); if (scores[0] < nearest_score) { nearest_score = scores[0]; nearest_index = j; } } context_[nearest_index].append(feature, feature_cache_.dimension(), nearest_score); } // end of for } //! 
Cluster the matrix features void cluster_matrix_features(size_t first, size_t last) { std::array scores; ContainerType rows(centroids_matrix_.dimension()); auto comp = [](float i, float j) { if (std::isnan(i)) return false; if (std::isnan(j)) return true; return i < j; }; std::array nearest_scores; std::array nearest_indexes; rows.resize(BatchCount); for (size_t i = first * BatchCount; i != last * BatchCount; i += BatchCount) { size_t count = centroids_matrix_.count() / BatchCount * BatchCount; const StoreType *block = feature_matrix_[i]; std::fill(nearest_indexes.data(), nearest_indexes.data() + BatchCount, 0); std::fill(nearest_scores.data(), nearest_scores.data() + BatchCount, std::numeric_limits::max()); for (size_t j = 0; j != count; j += BatchCount) { ContextType::template BatchDistance( centroids_matrix_[j], block, centroids_matrix_.dimension(), scores.data()); for (size_t k = 0; k < BatchCount; ++k) { const float *start = &scores[k * BatchCount]; const float *result = std::min_element(start, start + BatchCount, comp); if (*result < nearest_scores[k]) { nearest_scores[k] = *result; nearest_indexes[k] = j + (result - start); } } } // end of for for (size_t j = count, total = centroids_matrix_.count(); j != total; ++j) { ContextType::template BatchDistance<1>(block, centroids_matrix_[j], centroids_matrix_.dimension(), scores.data()); for (size_t k = 0; k < BatchCount; ++k) { float score = scores[k]; if (score < nearest_scores[k]) { nearest_scores[k] = score; nearest_indexes[k] = j; } } } // end of for ContextType::MatrixReverseTranspose(block, feature_matrix_.dimension(), rows.data()); for (size_t k = 0; k < BatchCount; ++k) { context_[nearest_indexes[k]].append( rows[k], feature_matrix_.dimension(), nearest_scores[k]); } } // end of for } private: //! 
Members size_t k_value_{0u}; ContainerType feature_cache_{}; ContainerType feature_matrix_{}; ContainerType centroids_matrix_{}; ContainerType centroids_{}; ContextType context_{}; bool spherical_{false}; }; } // namespace ailego } // namespace zvec ================================================ FILE: src/ailego/buffer/buffer_manager.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include #include #include #include #include #include #include #ifdef __clang__ #pragma clang diagnostic push #pragma clang diagnostic ignored "-Wunused-parameter" #pragma clang diagnostic ignored "-Wshadow" #elif defined(__GNUC__) || defined(__GNUG__) #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wunused-parameter" #pragma GCC diagnostic ignored "-Wshadow" #endif #include #ifdef __clang__ #pragma clang diagnostic pop #elif defined(__GNUC__) || defined(__GNUG__) #pragma GCC diagnostic pop #endif namespace zvec { namespace ailego { namespace { struct IDHash { size_t operator()(const BufferID &buffer_id) const { size_t hash = std::hash{}(static_cast(buffer_id.type)); hash = hash ^ (std::hash{}(buffer_id.file_id)); if (buffer_id.type == BufferID::TYPE::PARQUET) { hash = hash * 31 + std::hash{}(buffer_id.parquet().column); hash = hash * 31 + std::hash{}(buffer_id.parquet().row_group); } else if (buffer_id.type == BufferID::TYPE::VECTOR) { hash = hash * 31 + 
std::hash{}(buffer_id.vector().offset); } return hash; } }; struct IDEqual { bool operator()(const BufferID &a, const BufferID &b) const { if (a.type != b.type) { return false; } if (a.file_name != b.file_name) { return false; } if (a.file_id != b.file_id) { return false; } if (a.mtime != b.mtime) { return false; } if (a.type == BufferID::TYPE::PARQUET) { return a.parquet().column == b.parquet().column && a.parquet().row_group == b.parquet().row_group; } else if (a.type == BufferID::TYPE::VECTOR) { return a.vector().offset == b.vector().offset; } else { return false; } } }; } // namespace struct BufferManager::BufferContext { BufferContext(const BufferID &id, BufferPool *p) : id(id), pool(p) {}; BufferContext(const BufferContext &) = delete; BufferContext(BufferContext &&) = delete; BufferContext &operator=(const BufferContext &) = delete; BufferContext &operator=(BufferContext &&) = delete; ~BufferContext() { if (vector) { free(vector); } } typedef std::unique_ptr Pointer; enum State : uint32_t { IDLE = 0, // Empty and not held by any users, not in LRU RESERVED = 1, // Pinned by a user but no data yet, not in LRU IN_USE = 2, // Pinned by a user and data is present, not in LRU CACHED = 3, // Data is present but not held by any users, in LRU ERROR = 4 // Something went wrong, not in LRU }; // Identifier for the buffer BufferID id; // Current state State state{IDLE}; // The size of the buffer uint32_t size{0}; // Handle of the file backing this buffer File file; // The number of external references to this buffer (via pin/unpin) std::atomic refs_buf{0}; // The number of external references to this context (via BufferHandle) std::atomic refs_context{0}; BufferPool *pool{nullptr}; // A shared pointer to the buffers allocated for arrow parquet data std::shared_ptr arrow{nullptr}; // Guard original arrow buffers to prevent premature deletion std::vector> arrow_refs{}; // A pointer to the buffer allocated for vector data void *vector{nullptr}; // Doubly linked LRU list 
BufferContext *next{nullptr}; BufferContext *prev{nullptr}; // Return a string representation of the status const std::string status_string() const; // Populate the buffer with parquet data arrow::Status read_arrow_parquet(); // Populate the buffer with vector data bool read_vector(); }; const std::string BufferManager::BufferContext::status_string() const { std::string msg{id.to_string() + ": "}; switch (state) { case State::IDLE: { msg += "Idle"; break; } case State::RESERVED: { msg += "Reserved"; break; } case State::IN_USE: { msg += "In use"; break; } case State::CACHED: { msg += "Cached"; break; } case State::ERROR: { msg += "Error"; break; } } return msg; } arrow::Status BufferManager::BufferContext::read_arrow_parquet() { // TODO: file handler and memory pool can be optimized arrow::MemoryPool *mem_pool = arrow::default_memory_pool(); // Open file std::shared_ptr input; const auto &file_name = id.file_name; ARROW_ASSIGN_OR_RAISE(input, arrow::io::ReadableFile::Open(file_name)); // Open reader std::unique_ptr reader; ARROW_ASSIGN_OR_RAISE(reader, parquet::arrow::OpenFile(input, mem_pool)); // Perform read int row_group = id.parquet().row_group; int column = id.parquet().column; auto s = reader->RowGroup(row_group)->Column(column)->Read(&arrow); if (!s.ok()) { LOG_ERROR("Failed to read parquet file[%s]", file_name.c_str()); arrow = nullptr; return s; } // Compute the memory usage and hijack Arrow's buffers with our implementation for (auto &array : arrow->chunks()) { auto &buffers = array->data()->buffers; for (size_t buf_idx = 0; buf_idx < buffers.size(); ++buf_idx) { if (buffers[buf_idx] == nullptr) { continue; } // Keep references to original buffers to prevent premature deletion arrow_refs.emplace_back(buffers[buf_idx]); size += buffers[buf_idx]->capacity(); // Create hijacked buffer with custom deleter that notifies us when Arrow // is finished with the buffer std::shared_ptr hijacked_buffer( buffers[buf_idx].get(), 
BufferManager::ArrowBufferDeleter(this)); buffers[buf_idx] = hijacked_buffer; } } return arrow::Status::OK(); } bool BufferManager::BufferContext::read_vector() { const auto &file_name = id.file_name; if (!file.is_valid()) { if (!File::IsExist(file_name)) { LOG_ERROR("File[%s] does not exist", file_name.c_str()); return false; } if (!File::IsRegular(file_name)) { LOG_ERROR("[%s] is not a regular file", file_name.c_str()); return false; } if (!file.open(file_name.c_str(), true, false)) { LOG_ERROR("Failed to open file[%s]", file_name.c_str()); return false; } } AILEGO_DEFER([this] { file.close(); }); uint32_t len = id.vector().length; auto ret = posix_memalign((void **)&vector, 64, len); // 64-byte alignment if (ret != 0 || vector == nullptr) { LOG_ERROR("Failed to allocate buffer for file[%s]", file_name.c_str()); return false; } uint32_t offset = id.vector().offset; if (file.read(offset, vector, len) != len) { LOG_ERROR("Failed to read file[%s]", file_name.c_str()); free(vector); vector = nullptr; return false; } size = len; return true; } // Thread-safe buffer pool implementation. // // BufferContext states: // 1. Must exist in the lookup (hash) table. // 2. LRU list presence: // - In LRU: holds memory but not pinned by any users // - Not in LRU: either holds memory pinned by users, or doesn't hold memory // 3. External references: when an external user acquires a context and pins the // memory, that context is removed from LRU list; when they unpins the // memory, that context is moved to LRU list if it was the last reference. // // Any operation on the hash table is protected by mutex_table_. // Any change to context state and LRU list is protected by mutex_context_. 
// class BufferManager::BufferPool { public: explicit BufferPool(uint64_t limit) : limit_(limit) { sentinel_.next = &sentinel_; sentinel_.prev = &sentinel_; } BufferContext *acquire_locked(BufferID &id) { std::lock_guard lock(mutex_context_); if (auto iter = table_.find(id); iter != table_.end()) { return iter->second.get(); } auto [iter, _] = table_.emplace(id, std::make_unique(id, this)); return iter->second.get(); } void try_release_context_locked(BufferContext *context) { if (context->refs_context.load() != 0) { return; } std::lock_guard lock(mutex_table_); if (context->refs_context.load() != 0) { return; } if (context->state == BufferContext::State::IDLE) { table_.erase(context->id); } } void pin_locked(BufferContext *ctx) { std::lock_guard lock(mutex_context_); if (ctx->state == BufferContext::State::IDLE) { return pin_at_IDLE(ctx); } if (ctx->state == BufferContext::State::IN_USE) { return pin_at_IN_USE(ctx); } if (ctx->state == BufferContext::State::CACHED) { return pin_at_CACHED(ctx); } if (ctx->state == BufferContext::State::ERROR) { return; } } bool unpin_locked(BufferContext *ctx) { uint32_t prev_refs = ctx->refs_buf.fetch_sub(1); if (prev_refs > 1) { return false; } std::lock_guard lock(mutex_context_); if (ctx->refs_buf.load() == 0 && ctx->state != BufferContext::State::CACHED) { ctx->state = BufferContext::State::CACHED; LRU_insert(ctx); return true; } else { return false; } } void LRU_insert_locked(BufferContext *context) { std::lock_guard lock(mutex_context_); LRU_insert(context); } void LRU_remove_locked(BufferContext *context) { std::lock_guard lock(mutex_context_); LRU_remove(context); } uint64_t usage() const { return usage_; } private: void pin_at_IDLE(BufferContext *ctx) { ctx->state = BufferContext::State::RESERVED; while (usage_ >= limit_) { // The tail of LRU list is the least recently used context BufferContext *victim = sentinel_.prev; if (victim == &sentinel_) { // No victim could be found ctx->state = BufferContext::State::ERROR; 
return; } if (victim->state == BufferContext::State::ERROR) { LRU_remove(victim); try_release_context_locked(ctx); continue; } if (victim->id.type == BufferID::TYPE::PARQUET) { victim->arrow_refs.clear(); } else { free(victim->vector); victim->vector = nullptr; } victim->state = BufferContext::State::IDLE; LRU_remove(victim); try_release_context_locked(ctx); usage_ -= victim->size; } if (ctx->id.type == BufferID::TYPE::PARQUET) { if (ctx->read_arrow_parquet().ok()) { ctx->state = BufferContext::State::IN_USE; ctx->refs_buf.fetch_add(ctx->arrow_refs.size()); usage_ += ctx->size; } else { LOG_ERROR("Failed to read to %s", ctx->id.to_string().c_str()); ctx->state = BufferContext::State::ERROR; } } else { if (ctx->read_vector()) { ctx->state = BufferContext::State::IN_USE; ctx->refs_buf.fetch_add(1); usage_ += ctx->size; } else { LOG_ERROR("Failed to read to %s", ctx->id.to_string().c_str()); ctx->state = BufferContext::State::ERROR; } } } void pin_at_IN_USE(BufferContext *ctx) { if (ctx->id.type == BufferID::TYPE::PARQUET) { ctx->refs_buf.fetch_add(ctx->arrow_refs.size()); } else { ctx->refs_buf.fetch_add(1); } } void pin_at_CACHED(BufferContext *ctx) { if (ctx->id.type == BufferID::TYPE::PARQUET) { ctx->refs_buf.fetch_add(ctx->arrow_refs.size()); } else { ctx->refs_buf.fetch_add(1); } LRU_remove(ctx); ctx->state = BufferContext::State::IN_USE; } void LRU_insert(BufferContext *context) { if (context->refs_buf > 0) { return; // Already pinned, should not be evicted } if (context->next != nullptr || context->prev != nullptr) { return; } // Insert the context to the head of LRU list context->next = sentinel_.next; context->prev = &sentinel_; sentinel_.next = context; context->next->prev = context; inactive_ += context->size; } void LRU_remove(BufferContext *context) { if (context->next == nullptr) { return; // Not in LRU list } context->next->prev = context->prev; context->prev->next = context->next; context->next = nullptr; context->prev = nullptr; inactive_ -= 
context->size; } private: using Table = std::unordered_map; uint64_t limit_; std::atomic usage_{0}; std::atomic inactive_{0}; Table table_{}; std::mutex mutex_table_{}; BufferContext sentinel_{BufferID{}, this}; // LRU list sentinel std::mutex mutex_context_{}; }; BufferManager::ArrowBufferDeleter::ArrowBufferDeleter(BufferContext *c) : context(c) {} void BufferManager::ArrowBufferDeleter::operator()(arrow::Buffer *) { context->pool->unpin_locked(context); } BufferHandle::BufferHandle(BufferContext *context) : context_(context) { if (context_ != nullptr) { pool_ = context_->pool; context_->refs_context.fetch_add(1); } } BufferHandle::~BufferHandle() { if (context_ != nullptr) { uint32_t prev_refs = context_->refs_context.fetch_sub(1); if (prev_refs > 1) { return; } if (context_->state == BufferContext::State::IDLE) { pool_->try_release_context_locked(context_); } } } std::shared_ptr BufferHandle::pin_parquet_data() { pool_->pin_locked(context_); return context_->arrow; } void *BufferHandle::pin_vector_data() { if (!context_) { return nullptr; } pool_->pin_locked(context_); return context_->vector; } bool BufferHandle::unpin_vector_data() { if (!context_) { return true; } return pool_->unpin_locked(context_); } uint32_t BufferHandle::references() const { return context_->refs_buf.load(); } uint32_t BufferHandle::size() const { return context_->size; } void BufferManager::init(uint64_t limit, uint32_t num_shards) { pools_.clear(); uint64_t limit_per_shard = ailego_align(limit / num_shards, 4096); for (uint32_t i = 0; i < num_shards; ++i) { auto pool = new BufferPool(limit_per_shard); pools_.push_back(pool); } LOG_INFO( "BufferManager initialized with [%u] buffer pools, [%zu] bytes memory " "limit per pool, total memory limit [%zu] bytes", num_shards, (size_t)limit_per_shard, (size_t)limit); } BufferHandle BufferManager::acquire(BufferID &buffer_id) { static IDHash id_hash{}; auto hash_val = id_hash(buffer_id); auto ctx = pools_[hash_val % 
pools_.size()]->acquire_locked(buffer_id); return BufferHandle(ctx); } std::unique_ptr BufferManager::acquire_ptr(BufferID &buffer_id) { static IDHash id_hash{}; auto hash_val = id_hash(buffer_id); auto ctx = pools_[hash_val % pools_.size()]->acquire_locked(buffer_id); return std::make_unique(ctx); } uint64_t BufferManager::total_size_in_bytes() const { uint64_t total_usage = 0; for (auto pool : pools_) { total_usage += pool->usage(); } return total_usage; } BufferManager::~BufferManager() { for (auto pool : pools_) { delete pool; } } } // namespace ailego } // namespace zvec ================================================ FILE: src/ailego/buffer/buffer_pool.cc ================================================ #include #include namespace zvec { namespace ailego { int LRUCache::init(size_t block_size) { block_size_ = block_size; for (size_t i = 0; i < CATCH_QUEUE_NUM; i++) { queues_.push_back(ConcurrentQueue(block_size)); } return 0; } bool LRUCache::evict_single_block(BlockType &item) { bool found = false; for (size_t i = 0; i < CATCH_QUEUE_NUM; i++) { found = queues_[i].try_dequeue(item); if (found) { break; } } return found; } bool LRUCache::add_single_block(const LPMap *lp_map, const BlockType &block, int block_type) { bool ok = queues_[block_type].enqueue(block); if (!ok) { LOG_ERROR("enqueue failed."); return false; } evict_queue_insertions_.fetch_add(1, std::memory_order_relaxed); if (evict_queue_insertions_ % block_size_ == 0) { this->clear_dead_node(lp_map); } return true; } void LRUCache::clear_dead_node(const LPMap *lp_map) { for (size_t i = 0; i < CATCH_QUEUE_NUM; i++) { size_t clear_size = block_size_ * 2; if (queues_[i].size_approx() < clear_size * 4) { continue; } size_t clear_count = 0; ConcurrentQueue tmp(block_size_); BlockType item; while (queues_[i].try_dequeue(item) && (clear_count++ < clear_size)) { if (!lp_map->isDeadBlock(item)) { if (!tmp.enqueue(item)) { LOG_ERROR("enqueue failed."); } } } while (tmp.try_dequeue(item)) { if 
(!lp_map->isDeadBlock(item)) { if (!queues_[i].enqueue(item)) { LOG_ERROR("enqueue failed."); } } } } } void LPMap::init(size_t entry_num) { if (entries_) { delete[] entries_; } entry_num_ = entry_num; entries_ = new Entry[entry_num_]; for (size_t i = 0; i < entry_num_; i++) { entries_[i].ref_count.store(std::numeric_limits::min()); entries_[i].load_count.store(0); entries_[i].buffer = nullptr; } cache_.init(entry_num * 4); } char *LPMap::acquire_block(block_id_t block_id, bool lru_mode) { assert(block_id < entry_num_); Entry &entry = entries_[block_id]; if (!lru_mode) { return entry.buffer; } while (true) { int current_count = entry.ref_count.load(std::memory_order_acquire); if (current_count < 0) { return nullptr; } if (entry.ref_count.compare_exchange_weak(current_count, current_count + 1, std::memory_order_acq_rel, std::memory_order_acquire)) { if (current_count == 0) { entry.load_count.fetch_add(1, std::memory_order_relaxed); } return entry.buffer; } } } void LPMap::release_block(block_id_t block_id) { assert(block_id < entry_num_); Entry &entry = entries_[block_id]; if (entry.ref_count.fetch_sub(1, std::memory_order_release) == 1) { std::atomic_thread_fence(std::memory_order_acquire); LRUCache::BlockType block; block.first = block_id; block.second = entry.load_count.load(); cache_.add_single_block(this, block, 0); } } char *LPMap::evict_block(block_id_t block_id) { assert(block_id < entry_num_); Entry &entry = entries_[block_id]; int expected = 0; if (entry.ref_count.compare_exchange_strong( expected, std::numeric_limits::min())) { char *buffer = entry.buffer; entry.buffer = nullptr; return buffer; } else { return nullptr; } } char *LPMap::set_block_acquired(block_id_t block_id, char *buffer) { assert(block_id < entry_num_); Entry &entry = entries_[block_id]; while (true) { int current_count = entry.ref_count.load(std::memory_order_relaxed); if (current_count >= 0) { if (entry.ref_count.compare_exchange_weak( current_count, current_count + 1, 
std::memory_order_acq_rel, std::memory_order_acquire)) { return entry.buffer; } } else { if (entry.ref_count.compare_exchange_weak(current_count, 1, std::memory_order_acq_rel, std::memory_order_acquire)) { entry.buffer = buffer; entry.load_count.fetch_add(1, std::memory_order_relaxed); return entry.buffer; } } } } void LPMap::recycle(moodycamel::ConcurrentQueue &free_buffers) { LRUCache::BlockType block; do { bool ok = cache_.evict_single_block(block); if (!ok) { return; } } while (isDeadBlock(block)); char *buffer = evict_block(block.first); if (buffer) { if (!free_buffers.enqueue(buffer)) { LOG_ERROR("recycle buffer enqueue failed."); ailego_free(buffer); } } } VecBufferPool::VecBufferPool(const std::string &filename) { fd_ = open(filename.c_str(), O_RDONLY); if (fd_ < 0) { throw std::runtime_error("Failed to open file: " + filename); } struct stat st; if (fstat(fd_, &st) < 0) { ::close(fd_); throw std::runtime_error("Failed to stat file: " + filename); } file_size_ = st.st_size; } int VecBufferPool::init(size_t pool_capacity, size_t block_size, size_t segment_count) { if (block_size == 0) { LOG_ERROR("block_size must not be 0"); return -1; } pool_capacity_ = pool_capacity; size_t buffer_num = pool_capacity_ / block_size + 10; size_t block_num = segment_count + 10; lp_map_.init(block_num); mutex_vec_.reserve(block_num); for (int i = 0; i < block_num; i++) { mutex_vec_.emplace_back(std::make_unique()); } for (size_t i = 0; i < buffer_num; i++) { char *buffer = (char *)ailego_malloc(block_size); if (buffer != nullptr) { if (!free_buffers_.enqueue(buffer)) { LOG_ERROR("recycle buffer enqueue failed."); ailego_free(buffer); return -1; } } else { LOG_ERROR("aligned_alloc %zu(size: %zu) failed", i, block_size); return -1; } } LOG_DEBUG("Buffer pool num: %zu, entry num: %zu", buffer_num, lp_map_.entry_num()); no_lru_mode_ = false; if (lp_map_.entry_num() <= buffer_num) { no_lru_mode_ = true; } return 0; } VecBufferPoolHandle VecBufferPool::get_handle() { return 
VecBufferPoolHandle(*this); } char *VecBufferPool::acquire_buffer(block_id_t block_id, size_t offset, size_t size, int retry) { char *buffer = lp_map_.acquire_block(block_id, !no_lru_mode()); if (buffer) { return buffer; } std::lock_guard lock(*mutex_vec_[block_id]); buffer = lp_map_.acquire_block(block_id, !no_lru_mode()); if (buffer) { return buffer; } { bool found = free_buffers_.try_dequeue(buffer); if (!found && !no_lru_mode_) { for (int i = 0; i < retry; i++) { lp_map_.recycle(free_buffers_); found = free_buffers_.try_dequeue(buffer); if (found) { break; } } } if (!found) { LOG_ERROR("Buffer pool failed to get free buffer"); return nullptr; } } ssize_t read_bytes = pread(fd_, buffer, size, offset); if (read_bytes != static_cast(size)) { LOG_ERROR("Buffer pool failed to read file at offset: %zu", offset); free_buffers_.enqueue(buffer); return nullptr; } return lp_map_.set_block_acquired(block_id, buffer); } int VecBufferPool::get_meta(size_t offset, size_t length, char *buffer) { ssize_t read_bytes = pread(fd_, buffer, length, offset); if (read_bytes != static_cast(length)) { LOG_ERROR("Buffer pool failed to read file at offset: %zu", offset); return -1; } return 0; } char *VecBufferPoolHandle::get_block(size_t offset, size_t size, size_t block_id) { char *buffer = pool_.acquire_buffer(block_id, offset, size, 5); return buffer; } int VecBufferPoolHandle::get_meta(size_t offset, size_t length, char *buffer) { return pool_.get_meta(offset, length, buffer); } void VecBufferPoolHandle::release_one(block_id_t block_id) { if (!pool_.no_lru_mode()) { pool_.lp_map_.release_block(block_id); } } void VecBufferPoolHandle::acquire_one(block_id_t block_id) { if (!pool_.no_lru_mode()) { pool_.lp_map_.acquire_block(block_id, true); } } } // namespace ailego } // namespace zvec ================================================ FILE: src/ailego/container/bitmap.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under 
the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "bitmap.h" namespace zvec { namespace ailego { size_t Bitset::BitwiseAndCardinality(const Bitset &lhs, const Bitset &rhs) { return BitsetHelper::BitwiseAndCardinality( lhs.array_.data(), rhs.array_.data(), std::min(lhs.array_.size(), rhs.array_.size())); } size_t Bitset::BitwiseAndnotCardinality(const Bitset &lhs, const Bitset &rhs) { size_t lsize = lhs.array_.size(); size_t rsize = rhs.array_.size(); if (lsize > rsize) { return ( BitsetHelper::BitwiseAndnotCardinality(lhs.array_.data(), rhs.array_.data(), rsize) + BitsetHelper::Cardinality(lhs.array_.data() + rsize, lsize - rsize)); } return BitsetHelper::BitwiseAndnotCardinality(lhs.array_.data(), rhs.array_.data(), lsize); } size_t Bitset::BitwiseXorCardinality(const Bitset &lhs, const Bitset &rhs) { size_t lsize = lhs.array_.size(); size_t rsize = rhs.array_.size(); if (lsize < rsize) { return ( BitsetHelper::BitwiseXorCardinality(lhs.array_.data(), rhs.array_.data(), lsize) + BitsetHelper::Cardinality(rhs.array_.data() + lsize, rsize - lsize)); } else if (lsize > rsize) { return ( BitsetHelper::BitwiseXorCardinality(lhs.array_.data(), rhs.array_.data(), rsize) + BitsetHelper::Cardinality(lhs.array_.data() + rsize, lsize - rsize)); } return BitsetHelper::BitwiseXorCardinality(lhs.array_.data(), rhs.array_.data(), lsize); } size_t Bitset::BitwiseOrCardinality(const Bitset &lhs, const Bitset &rhs) { size_t lsize = lhs.array_.size(); size_t rsize = rhs.array_.size(); if (lsize < 
rsize) { return ( BitsetHelper::BitwiseOrCardinality(lhs.array_.data(), rhs.array_.data(), lsize) + BitsetHelper::Cardinality(rhs.array_.data() + lsize, rsize - lsize)); } else if (lsize > rsize) { return ( BitsetHelper::BitwiseOrCardinality(lhs.array_.data(), rhs.array_.data(), rsize) + BitsetHelper::Cardinality(lhs.array_.data() + rsize, lsize - rsize)); } return BitsetHelper::BitwiseOrCardinality(lhs.array_.data(), rhs.array_.data(), lsize); } void Bitmap::clear(void) { for (std::vector::iterator iter = array_.begin(); iter != array_.end(); ++iter) { delete (*iter); } array_.clear(); } void Bitmap::copy(const Bitmap &rhs) { this->clear(); for (std::vector::const_iterator iter = rhs.array_.begin(); iter != rhs.array_.end(); ++iter) { Bucket *bucket = NULL; if (*iter) { bucket = new Bucket(*(*iter)); } array_.push_back(bucket); } } void Bitmap::shrink_to_fit(void) { size_t shrink_count = 0; std::vector::reverse_iterator iter; for (iter = array_.rbegin(); iter != array_.rend(); ++iter) { if (*iter) { if (!(*iter)->test_none()) { break; } delete (*iter); *iter = NULL; } ++shrink_count; } for (; iter != array_.rend(); ++iter) { if ((*iter) && (*iter)->test_none()) { delete (*iter); *iter = NULL; } } if (shrink_count != 0) { array_.resize(array_.size() - shrink_count); } } bool Bitmap::test(size_t num) const { // High 16 bits size_t offset = num >> 16; if (offset < array_.size()) { const Bucket *bucket = array_[offset]; if (bucket) { // Low 16 bits return bucket->test(static_cast(num)); } } return false; } void Bitmap::set(size_t num) { // High 16 bits size_t offset = num >> 16; if (offset >= array_.size()) { array_.resize(offset + 1, NULL); } Bucket *&bucket = array_[offset]; if (!bucket) { bucket = new Bucket; } // Low 16 bits bucket->set(static_cast(num)); } void Bitmap::reset(size_t num) { // High 16 bits size_t offset = num >> 16; if (offset >= array_.size()) { array_.resize(offset + 1, NULL); } if (offset < array_.size()) { Bucket *bucket = array_[offset]; if 
(bucket) { // Low 16 bits bucket->reset(static_cast(num)); } } } void Bitmap::flip(size_t num) { // High 16 bits uint16_t offset = (uint16_t)(num >> 16); if (offset >= array_.size()) { array_.resize(offset + 1, NULL); } Bucket *&bucket = array_[offset]; if (!bucket) { bucket = new Bucket; } // Low 16 bits bucket->flip(static_cast(num)); } void Bitmap::bitwise_and(const Bitmap &rhs) { size_t overlap = std::min(array_.size(), rhs.array_.size()); for (size_t i = 0; i < overlap; ++i) { Bucket *&dst = array_[i]; if (dst) { const Bucket *src = rhs.array_[i]; if (src) { dst->bitwise_and(*src); } else { delete dst; dst = NULL; } } } for (size_t i = overlap; i < array_.size(); ++i) { Bucket *&dst = array_[i]; delete dst; dst = NULL; } } void Bitmap::bitwise_andnot(const Bitmap &rhs) { size_t overlap = std::min(array_.size(), rhs.array_.size()); for (size_t i = 0; i < overlap; ++i) { Bucket *&dst = array_[i]; if (dst) { const Bucket *src = rhs.array_[i]; if (src) { dst->bitwise_andnot(*src); } } } } void Bitmap::bitwise_or(const Bitmap &rhs) { size_t overlap = std::min(array_.size(), rhs.array_.size()); for (size_t i = 0; i < overlap; ++i) { const Bucket *src = rhs.array_[i]; if (src) { Bucket *&dst = array_[i]; if (dst) { dst->bitwise_or(*src); } else { dst = new Bucket(*src); } } } for (size_t i = overlap; i < rhs.array_.size(); ++i) { const Bucket *src = rhs.array_[i]; Bucket *bucket = NULL; if (src) { bucket = new Bucket(*src); } array_.push_back(bucket); } } void Bitmap::bitwise_xor(const Bitmap &rhs) { size_t overlap = std::min(array_.size(), rhs.array_.size()); for (size_t i = 0; i < overlap; ++i) { const Bucket *src = rhs.array_[i]; if (src) { Bucket *&dst = array_[i]; if (dst) { dst->bitwise_xor(*src); } else { dst = new Bucket(*src); } } } for (size_t i = overlap; i < rhs.array_.size(); ++i) { const Bucket *src = rhs.array_[i]; Bucket *bucket = NULL; if (src) { bucket = new Bucket(*src); } array_.push_back(bucket); } } void Bitmap::bitwise_not(void) { for 
(std::vector::iterator iter = array_.begin(); iter != array_.end(); ++iter) { Bucket *&bucket = *iter; if (!bucket) { bucket = new Bucket; } bucket->bitwise_not(); } } bool Bitmap::test_all(void) const { if (array_.empty()) { return false; } for (std::vector::const_iterator iter = array_.begin(); iter != array_.end(); ++iter) { if (!(*iter) || !(*iter)->test_all()) { return false; } } return true; } bool Bitmap::test_any(void) const { for (std::vector::const_iterator iter = array_.begin(); iter != array_.end(); ++iter) { if (*iter && (*iter)->test_any()) { return true; } } return false; } bool Bitmap::test_none(void) const { for (std::vector::const_iterator iter = array_.begin(); iter != array_.end(); ++iter) { if (*iter && !(*iter)->test_none()) { return false; } } return true; } size_t Bitmap::cardinality(void) const { size_t result = 0; for (std::vector::const_iterator iter = array_.begin(); iter != array_.end(); ++iter) { if (*iter) { result += (*iter)->cardinality(); } } return result; } void Bitmap::extract(size_t base, std::vector *out) const { for (std::vector::const_iterator iter = array_.begin(); iter != array_.end(); ++iter) { if (*iter) { (*iter)->extract(base, out); } base += Bucket::MAX_SIZE; } } size_t Bitmap::BitwiseAndCardinality(const Bitmap &lhs, const Bitmap &rhs) { size_t overlap = std::min(lhs.array_.size(), rhs.array_.size()); size_t dist = 0; for (size_t i = 0; i < overlap; ++i) { const Bucket *l = lhs.array_[i]; const Bucket *r = rhs.array_[i]; if (l && r) { dist += Bucket::BitwiseAndCardinality(*l, *r); } } return dist; } size_t Bitmap::BitwiseAndnotCardinality(const Bitmap &lhs, const Bitmap &rhs) { size_t overlap = std::min(lhs.array_.size(), rhs.array_.size()); size_t dist = 0; for (size_t i = 0; i < overlap; ++i) { const Bucket *l = lhs.array_[i]; if (l) { const Bucket *r = rhs.array_[i]; if (r) { dist += Bucket::BitwiseAndnotCardinality(*l, *r); } else { dist += l->cardinality(); } } } for (size_t i = overlap; i < lhs.array_.size(); 
++i) { const Bucket *l = lhs.array_[i]; if (l) { dist += l->cardinality(); } } return dist; } size_t Bitmap::BitwiseXorCardinality(const Bitmap &lhs, const Bitmap &rhs) { size_t overlap = std::min(lhs.array_.size(), rhs.array_.size()); size_t dist = 0; for (size_t i = 0; i < overlap; ++i) { const Bucket *l = lhs.array_[i]; const Bucket *r = rhs.array_[i]; if (l && r) { dist += Bucket::BitwiseXorCardinality(*l, *r); } else if (l) { dist += l->cardinality(); } else if (r) { dist += r->cardinality(); } } for (size_t i = overlap; i < lhs.array_.size(); ++i) { const Bucket *l = lhs.array_[i]; if (l) { dist += l->cardinality(); } } for (size_t i = overlap; i < rhs.array_.size(); ++i) { const Bucket *r = rhs.array_[i]; if (r) { dist += r->cardinality(); } } return dist; } size_t Bitmap::BitwiseOrCardinality(const Bitmap &lhs, const Bitmap &rhs) { size_t overlap = std::min(lhs.array_.size(), rhs.array_.size()); size_t dist = 0; for (size_t i = 0; i < overlap; ++i) { const Bucket *l = lhs.array_[i]; const Bucket *r = rhs.array_[i]; if (l && r) { dist += Bucket::BitwiseOrCardinality(*l, *r); } else if (l) { dist += l->cardinality(); } else if (r) { dist += r->cardinality(); } } for (size_t i = overlap; i < lhs.array_.size(); ++i) { const Bucket *l = lhs.array_[i]; if (l) { dist += l->cardinality(); } } for (size_t i = overlap; i < rhs.array_.size(); ++i) { const Bucket *r = rhs.array_[i]; if (r) { dist += r->cardinality(); } } return dist; } } // namespace ailego } // namespace zvec ================================================ FILE: src/ailego/container/bitmap.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include #include /* NOTE(review): the header names of these #include lines are missing in this copy of the file; restore them from upstream. */ namespace zvec { namespace ailego { /* FixedBitset stores N bits inline in array_, an array of (N + 0x1f) >> 5 uint32_t words; bit num lives in word num >> 5 at bit position num & 0x1f. NOTE(review): the template parameter list below is malformed in this copy. */ /*! Fixed Bitset */ template ::type> class FixedBitset { public: enum { MAX_SIZE = N }; //! Constructor FixedBitset(void) { memset(array_, 0, sizeof(array_)); } //! Constructor FixedBitset(const FixedBitset &rhs) { memcpy(array_, rhs.array_, sizeof(array_)); } //! Destructor ~FixedBitset(void) {} //! Assignment FixedBitset &operator=(const FixedBitset &rhs) { memcpy(array_, rhs.array_, sizeof(array_)); return *this; } //! Retrieve data pointer uint32_t *data(void) { return reinterpret_cast(array_); } //! Retrieve data pointer const uint32_t *data(void) const { return reinterpret_cast(array_); } //! Retrieve count of bits in set constexpr size_t size(void) const { return MAX_SIZE; } //! Clear the bitset (zero every bit) void clear(void) { memset(array_, 0, sizeof(array_)); } //! Test a bit in bitset bool test(size_t num) const { ailego_assert_with(N > num, "overflow argument"); return ((array_[num >> 5] & (1u << (num & 0x1f))) != 0); } //! Set a bit in bitset void set(size_t num) { ailego_assert_with(N > num, "overflow argument"); uint32_t mask = (1u << (num & 0x1f)); array_[num >> 5] |= mask; } //! Clear a bit in bitset void reset(size_t num) { ailego_assert_with(N > num, "overflow argument"); uint32_t mask = (1u << (num & 0x1f)); array_[num >> 5] &= ~mask; } //! Toggle a bit in bitset void flip(size_t num) { ailego_assert_with(N > num, "overflow argument"); uint32_t mask = (1u << (num & 0x1f)); array_[num >> 5] ^= mask; } //!
Perform binary AND void bitwise_and(const FixedBitset &rhs) { BitsetHelper::BitwiseAnd(array_, rhs.array_, ((N + 0x1f) >> 5)); } //! Perform binary AND NOT void bitwise_andnot(const FixedBitset &rhs) { BitsetHelper::BitwiseAndnot(array_, rhs.array_, ((N + 0x1f) >> 5)); } //! Perform binary OR void bitwise_or(const FixedBitset &rhs) { BitsetHelper::BitwiseOr(array_, rhs.array_, ((N + 0x1f) >> 5)); } //! Perform binary XOR void bitwise_xor(const FixedBitset &rhs) { BitsetHelper::BitwiseXor(array_, rhs.array_, ((N + 0x1f) >> 5)); } //! Perform binary NOT void bitwise_not(void) { BitsetHelper::BitwiseNot(array_, ((N + 0x1f) >> 5)); } //! Check if all bits are set to true bool test_all(void) const { return BitsetHelper::TestAll(array_, ((N + 0x1f) >> 5)); } //! Check if any bits are set to true bool test_any(void) const { return BitsetHelper::TestAny(array_, ((N + 0x1f) >> 5)); } //! Check if none of the bits are set to true bool test_none(void) const { return BitsetHelper::TestNone(array_, ((N + 0x1f) >> 5)); } //! Compute the cardinality of a bitset size_t cardinality(void) const { return BitsetHelper::Cardinality(array_, ((N + 0x1f) >> 5)); } /* extract() appends base + i for every set bit i, scanning words from low to high, so the output is in ascending order. */ //! Extract the bitset to an array void extract(size_t base, std::vector *out) const { const uint32_t *iter = array_; const uint32_t *last = array_ + ((N + 0x1f) >> 5); for (; iter != last; ++iter) { uint32_t w = *iter; while (w != 0) { uint32_t c = ailego_ctz32(w); w &= ~(1u << c); out->push_back(base + c); } base += 32u; } } //! Extract the bitset to an array void extract(std::vector *out) const { this->extract(0, out); } //! Compute the AND cardinality between two bitsets static size_t BitwiseAndCardinality(const FixedBitset &lhs, const FixedBitset &rhs) { return BitsetHelper::BitwiseAndCardinality(lhs.array_, rhs.array_, ((N + 0x1f) >> 5)); } //!
Compute the ANDNOT cardinality between two bitsets static size_t BitwiseAndnotCardinality(const FixedBitset &lhs, const FixedBitset &rhs) { return BitsetHelper::BitwiseAndnotCardinality(lhs.array_, rhs.array_, ((N + 0x1f) >> 5)); } //! Compute the XOR cardinality between two bitsets static size_t BitwiseXorCardinality(const FixedBitset &lhs, const FixedBitset &rhs) { return BitsetHelper::BitwiseXorCardinality(lhs.array_, rhs.array_, ((N + 0x1f) >> 5)); } //! Compute the OR cardinality between two bitsets static size_t BitwiseOrCardinality(const FixedBitset &lhs, const FixedBitset &rhs) { return BitsetHelper::BitwiseOrCardinality(lhs.array_, rhs.array_, ((N + 0x1f) >> 5)); } /* The Cast() overloads reinterpret caller-owned word storage as a FixedBitset view without copying; presumably the storage must cover sizeof(FixedBitset) bytes -- confirm at call sites. */ //! Convert a array pointer to bitset pointer static FixedBitset *Cast(uint32_t *arr) { return reinterpret_cast *>(arr); } //! Convert a array pointer to bitset pointer static const FixedBitset *Cast(const uint32_t *arr) { return reinterpret_cast *>(arr); } //! Convert a array pointer to bitset pointer static FixedBitset *Cast(uint64_t *arr) { return reinterpret_cast *>(arr); } //! Convert a array pointer to bitset pointer static const FixedBitset *Cast(const uint64_t *arr) { return reinterpret_cast *>(arr); } private: uint32_t array_[(N + 0x1f) >> 5]; }; /*! Fixed Bitset (Special) */ template <> class FixedBitset<0> { public: enum { MAX_SIZE = 0 }; //! Retrieve max size of bitset constexpr size_t size(void) const { return MAX_SIZE; } }; /* Bitset: resizable bitset backed by a vector of 32-bit words; Bitset(bits) allocates (bits + 0x1f) >> 5 value-initialized (zero) words. NOTE(review): Bitset(size_t) is not explicit, so a size_t converts to Bitset implicitly. */ /*! Bitset */ class Bitset { public: //! Constructor Bitset(void) : array_() {} //! Constructor Bitset(size_t bits) : array_((bits + 0x1f) >> 5) {} //! Constructor Bitset(const Bitset &rhs) : array_(rhs.array_) {} //! Constructor Bitset(Bitset &&rhs) : array_(std::move(rhs.array_)) {} //! Destructor ~Bitset(void) {} //! Assignment Bitset &operator=(const Bitset &rhs) { array_ = rhs.array_; return *this; } //! Assignment Bitset &operator=(Bitset &&rhs) { array_ = std::move(rhs.array_); return *this; } //!
Retrieve data pointer uint32_t *data(void) { return array_.data(); } //! Retrieve data pointer const uint32_t *data(void) const { return array_.data(); } //! Retrieve count of bits in set size_t size(void) const { return (array_.size() << 5); } //! Resize the bitset void resize(size_t bits) { array_.resize((bits + 0x1f) >> 5); } // !Clear the bitset void clear(void) { array_.clear(); } //! Test a bit in bitset bool test(size_t num) const { ailego_assert_with(this->size() > num, "overflow argument"); return ((array_[num >> 5] & (1u << (num & 0x1f))) != 0); } //! Set a bit in bitset void set(size_t num) { ailego_assert_with(this->size() > num, "overflow argument"); uint32_t mask = (1u << (num & 0x1f)); array_[num >> 5] |= mask; } //! Clear a bit in bitset void reset(size_t num) { ailego_assert_with(this->size() > num, "overflow argument"); uint32_t mask = (1u << (num & 0x1f)); array_[num >> 5] &= ~mask; } //! Toggle a bit in bitset void flip(size_t num) { ailego_assert_with(this->size() > num, "overflow argument"); uint32_t mask = (1u << (num & 0x1f)); array_[num >> 5] ^= mask; } //! Perform binary AND void bitwise_and(const Bitset &rhs) { BitsetHelper::BitwiseAnd(array_.data(), rhs.array_.data(), std::min(array_.size(), rhs.array_.size())); } //! Perform binary AND NOT void bitwise_andnot(const Bitset &rhs) { BitsetHelper::BitwiseAndnot(array_.data(), rhs.array_.data(), std::min(array_.size(), rhs.array_.size())); } //! Perform binary OR void bitwise_or(const Bitset &rhs) { BitsetHelper::BitwiseOr(array_.data(), rhs.array_.data(), std::min(array_.size(), rhs.array_.size())); } //! Perform binary XOR void bitwise_xor(const Bitset &rhs) { BitsetHelper::BitwiseXor(array_.data(), rhs.array_.data(), std::min(array_.size(), rhs.array_.size())); } //! Perform binary NOT void bitwise_not(void) { BitsetHelper::BitwiseNot(array_.data(), array_.size()); } //! 
Check if all bits are set to true bool test_all(void) const { return BitsetHelper::TestAll(array_.data(), array_.size()); } //! Check if any bits are set to true bool test_any(void) const { return BitsetHelper::TestAny(array_.data(), array_.size()); } //! Check if none of the bits are set to true bool test_none(void) const { return BitsetHelper::TestNone(array_.data(), array_.size()); } //! Compute the cardinality of a bitset size_t cardinality(void) const { return BitsetHelper::Cardinality(array_.data(), array_.size()); } /* extract() appends base + i for every set bit i, in ascending order. NOTE(review): the std::vector element type is missing in this copy. */ //! Extract the bitset to an array void extract(size_t base, std::vector *out) const { const uint32_t *iter = array_.data(); const uint32_t *last = array_.data() + array_.size(); for (; iter != last; ++iter) { uint32_t w = *iter; while (w != 0) { uint32_t c = ailego_ctz32(w); w &= ~(1u << c); out->push_back(base + c); } base += 32u; } } //! Extract the bitset to an array void extract(std::vector *out) const { this->extract(0, out); } //! Compute the AND cardinality between two bitsets static size_t BitwiseAndCardinality(const Bitset &lhs, const Bitset &rhs); //! Compute the ANDNOT cardinality between two bitsets static size_t BitwiseAndnotCardinality(const Bitset &lhs, const Bitset &rhs); //! Compute the XOR cardinality between two bitsets static size_t BitwiseXorCardinality(const Bitset &lhs, const Bitset &rhs); //! Compute the OR cardinality between two bitsets static size_t BitwiseOrCardinality(const Bitset &lhs, const Bitset &rhs); private: std::vector array_; }; /* Bitmap: sparse bitmap made of 65536-bit buckets; array_ holds one Bucket pointer per 65536-bit range (null when that range has no storage), and bit num belongs to bucket num >> 16. */ /*! Bitmap */ class Bitmap { public: typedef FixedBitset<65536u> Bucket; //! Constructor Bitmap(void) : array_() {} //! Constructor Bitmap(const Bitmap &rhs) { this->copy(rhs); } //! Destructor ~Bitmap(void) { this->clear(); } /* NOTE(review): operator= delegates to copy() with no self-assignment check of its own, so copy() must tolerate &rhs == this. */ //! Assignment Bitmap &operator=(const Bitmap &rhs) { this->copy(rhs); return *this; } //! Retrieve bucket size of bitmap size_t bucket_size(void) const { return array_.size(); } // !Clear the bitmap void clear(void); //!
Remove the none buckets void shrink_to_fit(void); //! Test a bit in bitmap bool test(size_t num) const; //! Set a bit in bitmap void set(size_t num); //! Reset a bit in bitmap void reset(size_t num); //! Toggle a bit in bitmap void flip(size_t num); //! Perform binary AND void bitwise_and(const Bitmap &rhs); //! Perform binary AND NOT void bitwise_andnot(const Bitmap &rhs); //! Perform binary OR void bitwise_or(const Bitmap &rhs); //! Perform binary XOR void bitwise_xor(const Bitmap &rhs); //! Perform binary NOT (allocates every missing bucket up to bucket_size(); the bucket array itself does not grow) void bitwise_not(void); //! Check if all bits are set to true bool test_all(void) const; //! Check if any bits are set to true bool test_any(void) const; //! Check if none of the bits are set to true bool test_none(void) const; //! Compute the cardinality of a bitmap size_t cardinality(void) const; //! Extract the bitmap to an array void extract(size_t base, std::vector *out) const; //! Extract the bitmap to an array void extract(std::vector *out) const { this->extract(0, out); } //! Compute the AND cardinality between two bitmaps static size_t BitwiseAndCardinality(const Bitmap &lhs, const Bitmap &rhs); //! Compute the ANDNOT cardinality between two bitmaps static size_t BitwiseAndnotCardinality(const Bitmap &lhs, const Bitmap &rhs); //! Compute the XOR cardinality between two bitmaps static size_t BitwiseXorCardinality(const Bitmap &lhs, const Bitmap &rhs); //! Compute the OR cardinality between two bitmaps static size_t BitwiseOrCardinality(const Bitmap &lhs, const Bitmap &rhs); protected: //!
Copy the content from another bitmap void copy(const Bitmap &rhs); private: std::vector array_; }; } // namespace ailego } // namespace zvec ================================================ FILE: src/ailego/container/bloom_filter.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include /* NOTE(review): include targets are missing in this copy; the code below uses std::pow, std::exp, std::log, std::ceil, std::round, memset and std::nothrow. */ namespace zvec { namespace ailego { /*! Bloom Filter Calculator */ struct BloomFilterCalculator { /** * \brief Calculate probability of false positives * \param n Number of items in the filter * \param m Number of bits in the filter * \param k Number of hash functions * \return Probability of false positives */ static double Probability(size_t n, size_t m, size_t k) { return std::pow(1.0 - std::exp(-((double)k / (double)m * (double)n)), k); } /** * \brief Calculate number of items in the filter * \param m Number of bits in the filter * \param k Number of hash functions * \param p Probability of false positives * \return Number of items in the filter */ static size_t NumberOfItems(size_t m, size_t k, double p) { return (size_t)std::ceil( -((double)m / (double)k * std::log(1.0 - std::exp(std::log(p) / (double)k)))); } /** * \brief Calculate number of bits in the filter * \param n Number of items in the filter * \param p Probability of false positives * \return Number of bits in the filter */ static size_t NumberOfBits(size_t n, double p) { return
(size_t)std::ceil((double)n * std::log(p) / std::log(1.0 / std::pow(2.0, std::log(2.0)))); } /** * \brief Calculate number of bits in the filter * \param n Number of items in the filter * \param k Number of hash functions * \param p Probability of false positives * \return Number of bits in the filter */ static size_t NumberOfBits(size_t n, size_t k, double p) { return (size_t)std::ceil(-((double)k * (double)n / std::log(1.0 - std::pow(p, 1.0 / (double)k)))); } /** * \brief Calculate number of bytes in the filter * \param n Number of items in the filter * \param p Probability of false positives * \return Number of bytes in the filter */ static size_t NumberOfBytes(size_t n, double p) { return ((NumberOfBits(n, p) + 7) >> 3); } /** * \brief Calculate number of bytes in the filter * \param n Number of items in the filter * \param k Number of hash functions * \param p Probability of false positives * \return Number of bytes in the filter */ static size_t NumberOfBytes(size_t n, size_t k, double p) { return ((NumberOfBits(n, k, p) + 7) >> 3); } /** * \brief Calculate number of hash functions * \param n Number of items in the filter * \param m Number of bits in the filter * \return Number of hash functions * \note n is used as a divisor, so it must be non-zero */ static size_t NumberOfHash(size_t n, size_t m) { return (size_t)std::round((double)m / (double)n * std::log(2.0)); } }; /* NOTE(review): the template parameter list of BloomFilter is missing in this copy; K is used as the number of hash values supplied per item (see sizeof...(TArgs) == K below). */ /*! Bloom Filter */ template class BloomFilter { public: //! Constructor BloomFilter(void) {} /* NOTE(review): this constructor allocates with throwing new[], whereas reset() uses new (std::nothrow) and reports failure by returning false. Invalid n/p leave the filter unsized (is_valid() == false). */ //! Constructor BloomFilter(size_t n, double p) { if (n > 0 && p > 0.0 && p < 1.0) { capacity_ = n; bits_count_ = BloomFilterCalculator::NumberOfBits(n, K, p); bits_count_ = ((bits_count_ + 31) >> 5) << 5; probability_ = BloomFilterCalculator::Probability(n, bits_count_, K); bitset_ = new uint32_t[bits_count_ >> 5]; memset(bitset_, 0, (bits_count_ >> 3)); } } //!
Constructor BloomFilter(BloomFilter &&rhs) : bitset_(rhs.bitset_), bits_count_(rhs.bits_count_), capacity_(rhs.capacity_), count_(rhs.count_), probability_(rhs.probability_) { rhs.bitset_ = nullptr; rhs.bits_count_ = 0u; rhs.capacity_ = 0u; rhs.count_ = 0u; rhs.probability_ = 0u; } /* NOTE(review): no move assignment operator is declared, so BloomFilter is move-constructible but not move-assignable. */ //! Destructor ~BloomFilter(void) { delete[] bitset_; } //! Test if the filter is valid bool is_valid(void) const { return (bitset_ != nullptr); } /* reset() returns false for n == 0 (n is unsigned) or p outside (0, 1), leaving the filter untouched; it also returns false if allocation fails, in which case is_valid() becomes false. */ //! Reset the bloom filter bool reset(size_t n, double p) { if (n <= 0 || p <= 0.0 || p >= 1.0) { return false; } delete[] bitset_; capacity_ = n; count_ = 0u; bits_count_ = BloomFilterCalculator::NumberOfBits(n, K, p); bits_count_ = ((bits_count_ + 31) >> 5) << 5; probability_ = BloomFilterCalculator::Probability(n, bits_count_, K); bitset_ = new (std::nothrow) uint32_t[bits_count_ >> 5]; if (!bitset_) { return false; } memset(bitset_, 0, (bits_count_ >> 3)); return true; } //! Clear the bloom filter void clear(void) { if (bitset_) { memset(bitset_, 0, (bits_count_ >> 3)); count_ = 0u; } } //! Insert a item into bloom filter template ...>::value && sizeof...(TArgs) == K>::type> bool insert(TArgs... vals) { if (count_ >= capacity_) { return false; } this->set_bits(vals...); ++count_; return true; } /* NOTE(review): force_insert() and has() index with % bits_count_, which is 0 for a filter that was never sized; callers presumably check is_valid() first. */ //! Force insert a item into bloom filter template ...>::value && sizeof...(TArgs) == K>::type> void force_insert(TArgs... vals) { this->set_bits(vals...); ++count_; } //! Test if an item may be in the bloom filter (false positives are possible) template ...>::value && sizeof...(TArgs) == K>::type> bool has(TArgs... vals) const { return this->test_bits(vals...); } //! Retrieve count of bits in bloom filter size_t bits_count(void) const { return bits_count_; } //! Retrieve capacity of bloom filter size_t capacity(void) const { return capacity_; } //! Retrieve count of items in bloom filter size_t count(void) const { return count_; } //! Retrieve probability of false positives double probability(void) const { return probability_; } protected: //!
Disable them BloomFilter(const BloomFilter &) = delete; BloomFilter &operator=(const BloomFilter &) = delete; /* Each hash value selects one bit: (value % bits_count_). */ //! Set bits in bloom filter template void set_bits(TArg val) { size_t num = static_cast(val) % bits_count_; bitset_[num >> 5] |= (1u << (num & 0x1f)); } //! Set bits in bloom filter template void set_bits(TArg val, TArgs... vals) { this->set_bits(val); this->set_bits(vals...); } //! Test bits in bloom filter template bool test_bits(TArg val) const { size_t num = static_cast(val) % bits_count_; return ((bitset_[num >> 5] & (1u << (num & 0x1f))) != 0); } //! Test bits in bloom filter template bool test_bits(TArg val, TArgs... vals) const { if (!this->test_bits(val)) { return false; } return this->test_bits(vals...); } private: uint32_t *bitset_{nullptr}; size_t bits_count_{0u}; size_t capacity_{0u}; size_t count_{0u}; double probability_{0.0}; }; /*! Bloom Filter (Special) */ template <> struct BloomFilter<0> {}; } // namespace ailego } // namespace zvec ================================================ FILE: src/ailego/container/params.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include #include #include /* NOTE(review): include targets are missing in this copy. */ //!
Global environ variable extern char **environ; namespace zvec { namespace ailego { /* Copies each member of obj into params: booleans, integers, floats, strings (decoded) and nested objects (recursively, as sub-Params). Members of any other JSON type, such as arrays or null, are skipped silently. */ static void ParseFromJsonObject(const ailego::JsonObject &obj, Params *params) { for (ailego::JsonObject::const_iterator it = obj.begin(); it != obj.end(); ++it) { const ailego::JsonValue &val = it->value(); if (val.is_boolean()) { params->set(it->key().as_stl_string(), val.as_bool()); } else if (val.is_integer()) { params->set(it->key().as_stl_string(), static_cast(val.as_integer())); } else if (val.is_float()) { params->set(it->key().as_stl_string(), val.as_float()); } else if (val.is_string()) { params->set(it->key().as_stl_string(), val.as_string().decode().as_stl_string()); } else if (val.is_object()) { Params subparams; ParseFromJsonObject(val.as_object(), &subparams); params->set(it->key().as_stl_string(), std::move(subparams)); } } } /* Parses buf with comments, simple syntax and single quotes enabled and unstrict mode off; returns false if parsing fails or the top-level value is not an object. */ bool Params::ParseFromBuffer(const std::string &buf, Params *params) { ailego::JsonValue val; ailego::JsonParser parser; parser.set_comment(true); parser.set_simple(true); parser.set_squote(true); parser.set_unstrict(false); if (!parser.parse(buf.c_str(), &val)) { return false; } if (!val.is_object()) { return false; } ParseFromJsonObject(val.as_object(), params); return true; } /* Adds every NAME=VALUE entry of environ as a string param; entries without '=' are skipped. */ void Params::ParseFromEnvironment(Params *params) { // Dump all environ string for (size_t i = 0; environ[i]; ++i) { const char *env = environ[i]; const char *p = std::strchr(env, '='); if (p) { params->set(std::string(env, p - env), std::string(p + 1)); } } } /* NOTE(review): the parameter spelled with a pilcrow below looks like a mangled `&params`, and the template arguments of compatible()/unsafe_cast() are missing in this copy. */ static void SerializeToJsonObject(const Params ¶ms, ailego::JsonObject *obj) { for (const auto &it : params.hypercube().cubes()) { const ailego::Cube &cube = it.second; const char *key = it.first.c_str(); if (cube.compatible()) { const auto &val = cube.unsafe_cast(); ailego::JsonString str(val.data(), val.size()); obj->set(key, ailego::JsonValue(str.encode())); } else if (cube.compatible()) { obj->set(key, ailego::JsonValue(cube.unsafe_cast())); } else if (cube.compatible()) { obj->set(key,
ailego::JsonValue(cube.unsafe_cast())); } else if (cube.compatible()) { obj->set(key, ailego::JsonValue(cube.unsafe_cast())); } else if (cube.compatible()) { obj->set(key, ailego::JsonValue(cube.unsafe_cast())); } else if (cube.compatible()) { obj->set(key, ailego::JsonValue(cube.unsafe_cast())); } else if (cube.compatible()) { obj->set(key, ailego::JsonValue(cube.unsafe_cast())); } else if (cube.compatible()) { obj->set(key, ailego::JsonValue(cube.unsafe_cast())); } else if (cube.compatible()) { obj->set(key, ailego::JsonValue(cube.unsafe_cast())); } else if (cube.compatible()) { obj->set(key, ailego::JsonValue(cube.unsafe_cast())); } else if (cube.compatible()) { obj->set(key, ailego::JsonValue(cube.unsafe_cast())); } else if (cube.compatible()) { obj->set(key, ailego::JsonValue(cube.unsafe_cast())); } else if (cube.compatible()) { obj->set(key, ailego::JsonValue(cube.unsafe_cast())); } else if (cube.compatible()) { obj->set(key, ailego::JsonValue(cube.unsafe_cast())); } else if (cube.compatible()) { obj->set(key, ailego::JsonValue(cube.unsafe_cast())); } else if (cube.compatible()) { obj->set(key, ailego::JsonValue(cube.unsafe_cast())); } else if (cube.compatible()) { ailego::JsonObject subobj; SerializeToJsonObject(cube.unsafe_cast(), &subobj); obj->set(key, ailego::JsonValue(subobj)); } else { LOG_WARN("Unsupported serializing \'%s\' <%s>.", key, cube.type().name()); } } } void Params::SerializeToBuffer(const Params ¶ms, std::string *buf) { if (buf != nullptr) { ailego::JsonObject obj; SerializeToJsonObject(params, &obj); buf->assign(ailego::JsonValue(obj).as_json_string().as_stl_string()); } } } // namespace ailego } // namespace zvec ================================================ FILE: src/ailego/container/reservoir.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include namespace zvec { namespace ailego { /*! Sampling Reservoir */ template > class Reservoir { public: //! Constructor Reservoir(size_t cnt) : samples_(cnt), total_(0), mt_(std::random_device()()), pool_() { pool_.reserve(samples_); } //! Constructor Reservoir(const Reservoir &rhs) : samples_(rhs.samples_), total_(rhs.total_), mt_(std::random_device()()), pool_(rhs.pool_) {} //! Constructor Reservoir(Reservoir &&rhs) : samples_(rhs.samples_), total_(rhs.total_), mt_(std::random_device()()), pool_(std::move(rhs.pool_)) {} //! Destructor ~Reservoir(void) {} //! Assignment Reservoir &operator=(const Reservoir &rhs) { samples_ = rhs.samples_; total_ = rhs.total_; pool_ = rhs.pool_; return *this; } //! Assignment Reservoir &operator=(Reservoir &&rhs) { samples_ = rhs.samples_; total_ = rhs.total_; pool_ = std::move(rhs.pool_); return *this; } //! Retrieve pool of reservoir std::vector *mutable_pool(void) { return &pool_; } //! Retrieve pool of reservoir const std::vector &pool(void) const { return pool_; } //! Retrieve count of samples size_t samples(void) const { return samples_; } //! Retrieve total count of filling size_t total(void) const { return total_; } //! Reset the reservoir void reset(void) { total_ = 0; pool_.clear(); pool_.reserve(samples_); } //! Fill the reservoir void fill(const T &item) { if (samples_ > 0) { if (pool_.size() >= samples_) { std::uniform_int_distribution dt(0, total_); size_t i = dt(mt_); if (i < samples_) { pool_[i] = item; } } else { pool_.push_back(item); } } ++total_; } //! 
Fill the reservoir void fill(T &&item) { if (samples_ > 0) { if (pool_.size() >= samples_) { std::uniform_int_distribution dt(0, total_); size_t i = dt(mt_); if (i < samples_) { pool_[i] = std::move(item); } } else { pool_.push_back(std::move(item)); } } ++total_; } private: //! Disable them Reservoir(void) = delete; //! Members size_t samples_; size_t total_; std::mt19937 mt_; std::vector pool_; }; } // namespace ailego } // namespace zvec ================================================ FILE: src/ailego/container/vector_array.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include namespace zvec { namespace ailego { /*! Numerical Vector Array */ template ::value>::type> class NumericalVectorArray { public: //! Type of value using ValueType = typename NumericalVector::ValueType; //! Constructor NumericalVectorArray(void) {} //! Constructor explicit NumericalVectorArray(size_t dim) : dimension_(dim) {} //! Constructor NumericalVectorArray(const NumericalVectorArray &rhs) : dimension_(rhs.dimension_), buffer_(rhs.buffer_) {} //! Constructor NumericalVectorArray(NumericalVectorArray &&rhs) : dimension_(rhs.dimension_), buffer_(std::move(rhs.buffer_)) {} //! Assignment NumericalVectorArray &operator=(const NumericalVectorArray &rhs) { dimension_ = rhs.dimension_; buffer_ = rhs.buffer_; return *this; } //! 
Assignment NumericalVectorArray &operator=(NumericalVectorArray &&rhs) { dimension_ = rhs.dimension_; buffer_ = std::move(rhs.buffer_); return *this; } //! Overloaded operator [] ValueType *operator[](size_t i) { return (reinterpret_cast(&buffer_[0]) + i * dimension_); } //! Overloaded operator [] const ValueType *operator[](size_t i) const { return (reinterpret_cast(buffer_.data()) + i * dimension_); } //! Append a vector void append(const ValueType *vec, size_t dim) { if (ailego_unlikely(dim != dimension_)) { throw std::length_error("Unmatched dimension"); } buffer_.append(reinterpret_cast(vec), dim * sizeof(ValueType)); } //! Append vectors void append(const ValueType *vec, size_t dim, size_t cnt) { if (ailego_unlikely(dim != dimension_)) { throw std::length_error("Unmatched dimension"); } buffer_.append(reinterpret_cast(vec), cnt * dim * sizeof(ValueType)); } //! Append a vector void append(const NumericalVector &vec) { this->append(vec.data(), vec.dimension()); } //! Replace a vector void replace(size_t index, const ValueType *vec, size_t dim) { if (ailego_unlikely(dim != dimension_)) { throw std::length_error("Unmatched dimension"); } size_t element_size = dim * sizeof(ValueType); buffer_.replace(index * element_size, element_size, reinterpret_cast(vec), element_size); } //! Replace a vector void replace(size_t index, const ValueType *vec, size_t dim, size_t cnt) { if (ailego_unlikely(dim != dimension_)) { throw std::length_error("Unmatched dimension"); } size_t element_size = dim * sizeof(ValueType); size_t total = element_size * cnt; buffer_.replace(index * element_size, total, reinterpret_cast(vec), total); } //! Replace a vector void replace(size_t index, const NumericalVector &vec) { this->replace(index, vec.data(), vec.dimension()); } //! Request a change in capacity void reserve(size_t n) { buffer_.reserve(n * dimension_ * sizeof(ValueType)); } //! 
Resize the array to a length of n elements void resize(size_t n) { buffer_.resize(n * dimension_ * sizeof(ValueType)); } //! Clear the vector array void clear(void) { buffer_.clear(); } //! Reset the vector array void reset(size_t dim) { dimension_ = dim; buffer_.clear(); } //! Requests the removal of unused capacity. void shrink_to_fit(void) { buffer_.shrink_to_fit(); } //! Retrieve pointer of data ValueType *data(void) { return reinterpret_cast(&buffer_[0]); } //! Retrieve pointer of data const ValueType *data(void) const { return reinterpret_cast(buffer_.data()); } //! Retrieve pointer of data ValueType *at(size_t i) { if (ailego_unlikely(i >= this->count())) { throw std::out_of_range("Index overflow"); } return (reinterpret_cast(&buffer_[0]) + i * dimension_); } //! Retrieve pointer of data const ValueType *at(size_t i) const { if (ailego_unlikely(i >= this->count())) { throw std::out_of_range("Index overflow"); } return (reinterpret_cast(buffer_.data()) + i * dimension_); } //! Test if the array is empty bool empty(void) const { return buffer_.empty(); } //! Retrieve count of vectors size_t count(void) const { return (dimension_ > 0 ? buffer_.size() / (dimension_ * sizeof(ValueType)) : 0u); } //! Retrieve dimension of vector size_t dimension(void) const { return dimension_; } //! Retrieve size of array in bytes size_t bytes(void) const { return buffer_.size(); } private: size_t dimension_{0u}; std::string buffer_{}; }; /*! Nibble Vector Array */ template ::value>::type> class NibbleVectorArray { public: //! Type of value using ValueType = typename NibbleVector::ValueType; using StoreType = typename NibbleVector::StoreType; //! Constructor NibbleVectorArray(void) {} //! Constructor explicit NibbleVectorArray(size_t dim) : dimension_((dim + (sizeof(ValueType) << 1) - 1) / (sizeof(ValueType) << 1) * sizeof(ValueType) << 1) {} //! Constructor NibbleVectorArray(const NibbleVectorArray &rhs) : dimension_(rhs.dimension_), buffer_(rhs.buffer_) {} //! 
Constructor NibbleVectorArray(NibbleVectorArray &&rhs) : dimension_(rhs.dimension_), buffer_(std::move(rhs.buffer_)) {} //! Assignment NibbleVectorArray &operator=(const NibbleVectorArray &rhs) { dimension_ = rhs.dimension_; buffer_ = rhs.buffer_; return *this; } //! Assignment NibbleVectorArray &operator=(NibbleVectorArray &&rhs) { dimension_ = rhs.dimension_; buffer_ = std::move(rhs.buffer_); return *this; } //! Overloaded operator [] StoreType *operator[](size_t i) { return reinterpret_cast(&buffer_[0] + i * (dimension_ >> 1)); } //! Overloaded operator [] const StoreType *operator[](size_t i) const { return reinterpret_cast(&buffer_[0] + i * (dimension_ >> 1)); } //! Append a vector void append(const StoreType *vec, size_t dim) { if (ailego_unlikely(dim != dimension_)) { throw std::length_error("Unmatched dimension"); } buffer_.append(reinterpret_cast(vec), dim >> 1); } //! Append vectors void append(const StoreType *vec, size_t dim, size_t cnt) { if (ailego_unlikely(dim != dimension_)) { throw std::length_error("Unmatched dimension"); } buffer_.append(reinterpret_cast(vec), cnt * (dim >> 1)); } //! Append a vector void append(const NibbleVector &vec) { this->append(vec.data(), vec.dimension()); } //! Replace a vector void replace(size_t index, const StoreType *vec, size_t dim) { if (ailego_unlikely(dim != dimension_)) { throw std::length_error("Unmatched dimension"); } size_t element_size = (dim >> 1); buffer_.replace(index * element_size, element_size, reinterpret_cast(vec), element_size); } //! Replace a vector void replace(size_t index, const StoreType *vec, size_t dim, size_t cnt) { if (ailego_unlikely(dim != dimension_)) { throw std::length_error("Unmatched dimension"); } size_t element_size = (dim >> 1); size_t total = element_size * cnt; buffer_.replace(index * element_size, total, reinterpret_cast(vec), total); } //! Replace a vector void replace(size_t index, const NibbleVector &vec) { this->replace(index, vec.data(), vec.dimension()); } //! 
Request a change in capacity void reserve(size_t n) { buffer_.reserve(n * (dimension_ >> 1)); } //! Resize the array to a length of n elements void resize(size_t n) { buffer_.resize(n * (dimension_ >> 1)); } //! Clear the vector array void clear(void) { buffer_.clear(); } //! Reset the vector array void reset(size_t dim) { dimension_ = (dim + (sizeof(ValueType) << 1) - 1) / (sizeof(ValueType) << 1) * sizeof(ValueType) << 1; buffer_.clear(); } //! Requests the removal of unused capacity. void shrink_to_fit(void) { buffer_.shrink_to_fit(); } //! Retrieve pointer of data StoreType *data(void) { return reinterpret_cast(&buffer_[0]); } //! Retrieve pointer of data const StoreType *data(void) const { return reinterpret_cast(buffer_.data()); } //! Retrieve pointer of data StoreType *at(size_t i) { if (ailego_unlikely(i >= this->count())) { throw std::out_of_range("Index overflow"); } return reinterpret_cast(&buffer_[0] + i * (dimension_ >> 1)); } //! Retrieve pointer of data const StoreType *at(size_t i) const { if (ailego_unlikely(i >= this->count())) { throw std::out_of_range("Index overflow"); } return reinterpret_cast(buffer_.data() + i * (dimension_ >> 1)); } //! Test if the array is empty bool empty(void) const { return buffer_.empty(); } //! Retrieve count of vectors size_t count(void) const { return (dimension_ > 1 ? buffer_.size() / (dimension_ >> 1) : 0u); } //! Retrieve dimension of vector size_t dimension(void) const { return dimension_; } //! Retrieve size of array in bytes size_t bytes(void) const { return buffer_.size(); } private: size_t dimension_{0u}; std::string buffer_{}; }; /*! Binary Vector Array */ template ::value>::type> class BinaryVectorArray { public: //! Type of value using ValueType = typename BinaryVector::ValueType; //! Constructor BinaryVectorArray(void) {} //! Constructor explicit BinaryVectorArray(size_t dim) : dimension_((dim + (sizeof(ValueType) << 3) - 1) / (sizeof(ValueType) << 3) * (sizeof(ValueType) << 3)) {} //! 
Constructor BinaryVectorArray(const BinaryVectorArray &rhs) : dimension_(rhs.dimension_), buffer_(rhs.buffer_) {} //! Constructor BinaryVectorArray(BinaryVectorArray &&rhs) : dimension_(rhs.dimension_), buffer_(std::move(rhs.buffer_)) {} //! Assignment BinaryVectorArray &operator=(const BinaryVectorArray &rhs) { dimension_ = rhs.dimension_; buffer_ = rhs.buffer_; return *this; } //! Assignment BinaryVectorArray &operator=(BinaryVectorArray &&rhs) { dimension_ = rhs.dimension_; buffer_ = std::move(rhs.buffer_); return *this; } //! Overloaded operator [] ValueType *operator[](size_t i) { return reinterpret_cast(&buffer_[0] + i * (dimension_ >> 3)); } //! Overloaded operator [] const ValueType *operator[](size_t i) const { return reinterpret_cast(buffer_.data() + i * (dimension_ >> 3)); } //! Append a vector void append(const ValueType *vec, size_t dim) { if (ailego_unlikely(dim != dimension_)) { throw std::length_error("Unmatched dimension"); } buffer_.append(reinterpret_cast(vec), (dim >> 3)); } //! Append vectors void append(const ValueType *vec, size_t dim, size_t cnt) { if (ailego_unlikely(dim != dimension_)) { throw std::length_error("Unmatched dimension"); } buffer_.append(reinterpret_cast(vec), cnt * (dim >> 3)); } //! Append a vector void append(const BinaryVector &vec) { this->append(vec.data(), vec.dimension()); } //! Replace a vector void replace(size_t index, const ValueType *vec, size_t dim) { if (ailego_unlikely(dim != dimension_)) { throw std::length_error("Unmatched dimension"); } size_t element_size = (dim >> 3); buffer_.replace(index * element_size, element_size, reinterpret_cast(vec), element_size); } //! Replace a vector void replace(size_t index, const ValueType *vec, size_t dim, size_t cnt) { if (ailego_unlikely(dim != dimension_)) { throw std::length_error("Unmatched dimension"); } size_t element_size = (dim >> 3); size_t total = element_size * cnt; buffer_.replace(index * element_size, total, reinterpret_cast(vec), total); } //! 
Replace a vector void replace(size_t index, const BinaryVector &vec) { this->replace(index, vec.data(), vec.dimension()); } //! Request a change in capacity void reserve(size_t n) { buffer_.reserve(n * (dimension_ >> 3)); } //! Resize the array to a length of n elements void resize(size_t n) { buffer_.resize(n * (dimension_ >> 3)); } //! Clear the vector array void clear(void) { buffer_.clear(); } //! Reset the vector array void reset(size_t dim) { dimension_ = (dim + (sizeof(ValueType) << 3) - 1) / (sizeof(ValueType) << 3) * (sizeof(ValueType) << 3); buffer_.clear(); } //! Requests the removal of unused capacity. void shrink_to_fit(void) { buffer_.shrink_to_fit(); } //! Retrieve pointer of data ValueType *data(void) { return reinterpret_cast(&buffer_[0]); } //! Retrieve pointer of data const ValueType *data(void) const { return reinterpret_cast(buffer_.data()); } //! Retrieve pointer of data ValueType *at(size_t i) { if (ailego_unlikely(i >= this->count())) { throw std::out_of_range("Index overflow"); } return reinterpret_cast(&buffer_[0] + i * (dimension_ >> 3)); } //! Retrieve pointer of data const ValueType *at(size_t i) const { if (ailego_unlikely(i >= this->count())) { throw std::out_of_range("Index overflow"); } return reinterpret_cast(buffer_.data() + i * (dimension_ >> 3)); } //! Test if the array is empty bool empty(void) const { return buffer_.empty(); } //! Retrieve count of vectors size_t count(void) const { return (dimension_ > 0 ? buffer_.size() / (dimension_ >> 3) : 0u); } //! Retrieve dimension of vector size_t dimension(void) const { return dimension_; } //! 
Retrieve size of array in bytes size_t bytes(void) const { return buffer_.size(); } private: size_t dimension_{0u}; std::string buffer_{}; }; } // namespace ailego } // namespace zvec ================================================ FILE: src/ailego/encoding/json/mod_json.c ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include #include #include #include #ifndef MOD_JSON_TOKEN_DEFOPTS #define MOD_JSON_TOKEN_DEFOPTS 0 /* default options of token */ #endif #ifndef MOD_JSON_TOKEN_DEFOBJDEP #define MOD_JSON_TOKEN_DEFOBJDEP 64 /* default objects depth of token */ #endif #ifndef MOD_JSON_TOKEN_DEFARRDEP #define MOD_JSON_TOKEN_DEFARRDEP 64 /* default arrays depth of token */ #endif #ifndef MOD_JSON_STRING_DEFSIZE #define MOD_JSON_STRING_DEFSIZE 32 /* default started size of string */ #endif #ifndef MOD_JSON_ARRAY_DEFSIZE #define MOD_JSON_ARRAY_DEFSIZE 32 /* default started size of array */ #endif #ifndef MOD_JSON_OBJECT_DEFSIZE #define MOD_JSON_OBJECT_DEFSIZE 32 /* default started size of object */ #endif #ifndef mod_json_malloc #define mod_json_malloc malloc #endif #ifndef mod_json_free #define mod_json_free free #endif #ifdef __GNUC__ #define mod_json_likely(x) __builtin_expect(!!(x), 1) #define mod_json_unlikely(x) __builtin_expect(!!(x), 0) #else #define mod_json_likely(x) (x) #define mod_json_unlikely(x) (x) #endif #define mod_json_minus_if_ne_zero(COND) \ if 
(mod_json_unlikely((COND) != 0)) return (-1) #define mod_json_minus_if_false(COND) \ if (mod_json_unlikely(!(COND))) return (-1) #define mod_json_null_if_ne_zero(COND) \ if (mod_json_unlikely((COND) != 0)) return (NULL) #define mod_json_null_if_false(COND) \ if (mod_json_unlikely(!(COND))) return (NULL) #if defined(_MSC_VER) #pragma warning(disable : 4200) #define strtoull _strtoui64 #define snprintf(buf, size, format, ...) \ _snprintf_s(buf, size, _TRUNCATE, format, ##__VA_ARGS__) #endif #define mod_json_utils_snprintf snprintf #define mod_json_utils_strtoi strtoull #define mod_json_utils_strtof strtod #define mod_json_utils_strlen strlen /*! JSON Token */ struct mod_json_token { mod_json_state_t state; mod_json_error_t error; mod_json_cchar_t *context; mod_json_size_t options; mod_json_size_t object_max_depth; mod_json_size_t array_max_depth; mod_json_size_t object_depth; mod_json_size_t array_depth; mod_json_event_t event_code; mod_json_event_proc event_proc; mod_json_void_t *param; mod_json_char_t tags[0]; }; typedef struct mod_json_parser mod_json_parser_t; /*! 
JSON Parser */ struct mod_json_parser { mod_json_string_t *key; mod_json_value_t *val_null; mod_json_value_t *val_true; mod_json_value_t *val_false; mod_json_value_t *val_zero; mod_json_value_t *val_zerof; mod_json_value_t *val_empty; mod_json_value_t *vals[0]; }; static inline mod_json_size_t mod_json_utils_clp2(mod_json_size_t n) { n = n - 1; n = n | (n >> 1); n = n | (n >> 2); n = n | (n >> 4); n = n | (n >> 8); n = n | (n >> 16); return (n + 1); } static inline mod_json_size_t mod_json_utils_itostr(mod_json_char_t *buf, mod_json_integer_t val) { mod_json_char_t *pos, *first, *last; pos = buf; if (val < 0) { *pos++ = '-'; val = -val; } /* save pointer to first digit */ first = pos; do { /* convert to ASCII and store */ *pos++ = (mod_json_char_t)(val % 10 + '0'); /* next digit */ val /= 10; } while (val > 0); *pos = '\0'; /* save pointer to last digit */ last = pos - 1; /* reverse digit string */ while (first < last) { mod_json_char_t temp = *first; *first++ = *last; *last-- = temp; } return (mod_json_size_t)(pos - buf); } static inline mod_json_float_t mod_json_utils_pow10(int n) { /* 1e-308...1e308: 617 * 8 bytes = 4936 bytes */ static const mod_json_float_t etab[] = { 1e-308, 1e-307, 1e-306, 1e-305, 1e-304, 1e-303, 1e-302, 1e-301, 1e-300, 1e-299, 1e-298, 1e-297, 1e-296, 1e-295, 1e-294, 1e-293, 1e-292, 1e-291, 1e-290, 1e-289, 1e-288, 1e-287, 1e-286, 1e-285, 1e-284, 1e-283, 1e-282, 1e-281, 1e-280, 1e-279, 1e-278, 1e-277, 1e-276, 1e-275, 1e-274, 1e-273, 1e-272, 1e-271, 1e-270, 1e-269, 1e-268, 1e-267, 1e-266, 1e-265, 1e-264, 1e-263, 1e-262, 1e-261, 1e-260, 1e-259, 1e-258, 1e-257, 1e-256, 1e-255, 1e-254, 1e-253, 1e-252, 1e-251, 1e-250, 1e-249, 1e-248, 1e-247, 1e-246, 1e-245, 1e-244, 1e-243, 1e-242, 1e-241, 1e-240, 1e-239, 1e-238, 1e-237, 1e-236, 1e-235, 1e-234, 1e-233, 1e-232, 1e-231, 1e-230, 1e-229, 1e-228, 1e-227, 1e-226, 1e-225, 1e-224, 1e-223, 1e-222, 1e-221, 1e-220, 1e-219, 1e-218, 1e-217, 1e-216, 1e-215, 1e-214, 1e-213, 1e-212, 1e-211, 1e-210, 1e-209, 1e-208, 
1e-207, 1e-206, 1e-205, 1e-204, 1e-203, 1e-202, 1e-201, 1e-200, 1e-199, 1e-198, 1e-197, 1e-196, 1e-195, 1e-194, 1e-193, 1e-192, 1e-191, 1e-190, 1e-189, 1e-188, 1e-187, 1e-186, 1e-185, 1e-184, 1e-183, 1e-182, 1e-181, 1e-180, 1e-179, 1e-178, 1e-177, 1e-176, 1e-175, 1e-174, 1e-173, 1e-172, 1e-171, 1e-170, 1e-169, 1e-168, 1e-167, 1e-166, 1e-165, 1e-164, 1e-163, 1e-162, 1e-161, 1e-160, 1e-159, 1e-158, 1e-157, 1e-156, 1e-155, 1e-154, 1e-153, 1e-152, 1e-151, 1e-150, 1e-149, 1e-148, 1e-147, 1e-146, 1e-145, 1e-144, 1e-143, 1e-142, 1e-141, 1e-140, 1e-139, 1e-138, 1e-137, 1e-136, 1e-135, 1e-134, 1e-133, 1e-132, 1e-131, 1e-130, 1e-129, 1e-128, 1e-127, 1e-126, 1e-125, 1e-124, 1e-123, 1e-122, 1e-121, 1e-120, 1e-119, 1e-118, 1e-117, 1e-116, 1e-115, 1e-114, 1e-113, 1e-112, 1e-111, 1e-110, 1e-109, 1e-108, 1e-107, 1e-106, 1e-105, 1e-104, 1e-103, 1e-102, 1e-101, 1e-100, 1e-99, 1e-98, 1e-97, 1e-96, 1e-95, 1e-94, 1e-93, 1e-92, 1e-91, 1e-90, 1e-89, 1e-88, 1e-87, 1e-86, 1e-85, 1e-84, 1e-83, 1e-82, 1e-81, 1e-80, 1e-79, 1e-78, 1e-77, 1e-76, 1e-75, 1e-74, 1e-73, 1e-72, 1e-71, 1e-70, 1e-69, 1e-68, 1e-67, 1e-66, 1e-65, 1e-64, 1e-63, 1e-62, 1e-61, 1e-60, 1e-59, 1e-58, 1e-57, 1e-56, 1e-55, 1e-54, 1e-53, 1e-52, 1e-51, 1e-50, 1e-49, 1e-48, 1e-47, 1e-46, 1e-45, 1e-44, 1e-43, 1e-42, 1e-41, 1e-40, 1e-39, 1e-38, 1e-37, 1e-36, 1e-35, 1e-34, 1e-33, 1e-32, 1e-31, 1e-30, 1e-29, 1e-28, 1e-27, 1e-26, 1e-25, 1e-24, 1e-23, 1e-22, 1e-21, 1e-20, 1e-19, 1e-18, 1e-17, 1e-16, 1e-15, 1e-14, 1e-13, 1e-12, 1e-11, 1e-10, 1e-9, 1e-8, 1e-7, 1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1e+0, 1e+1, 1e+2, 1e+3, 1e+4, 1e+5, 1e+6, 1e+7, 1e+8, 1e+9, 1e+10, 1e+11, 1e+12, 1e+13, 1e+14, 1e+15, 1e+16, 1e+17, 1e+18, 1e+19, 1e+20, 1e+21, 1e+22, 1e+23, 1e+24, 1e+25, 1e+26, 1e+27, 1e+28, 1e+29, 1e+30, 1e+31, 1e+32, 1e+33, 1e+34, 1e+35, 1e+36, 1e+37, 1e+38, 1e+39, 1e+40, 1e+41, 1e+42, 1e+43, 1e+44, 1e+45, 1e+46, 1e+47, 1e+48, 1e+49, 1e+50, 1e+51, 1e+52, 1e+53, 1e+54, 1e+55, 1e+56, 1e+57, 1e+58, 1e+59, 1e+60, 1e+61, 1e+62, 1e+63, 1e+64, 1e+65, 
1e+66, 1e+67, 1e+68, 1e+69, 1e+70, 1e+71, 1e+72, 1e+73, 1e+74, 1e+75, 1e+76, 1e+77, 1e+78, 1e+79, 1e+80, 1e+81, 1e+82, 1e+83, 1e+84, 1e+85, 1e+86, 1e+87, 1e+88, 1e+89, 1e+90, 1e+91, 1e+92, 1e+93, 1e+94, 1e+95, 1e+96, 1e+97, 1e+98, 1e+99, 1e+100, 1e+101, 1e+102, 1e+103, 1e+104, 1e+105, 1e+106, 1e+107, 1e+108, 1e+109, 1e+110, 1e+111, 1e+112, 1e+113, 1e+114, 1e+115, 1e+116, 1e+117, 1e+118, 1e+119, 1e+120, 1e+121, 1e+122, 1e+123, 1e+124, 1e+125, 1e+126, 1e+127, 1e+128, 1e+129, 1e+130, 1e+131, 1e+132, 1e+133, 1e+134, 1e+135, 1e+136, 1e+137, 1e+138, 1e+139, 1e+140, 1e+141, 1e+142, 1e+143, 1e+144, 1e+145, 1e+146, 1e+147, 1e+148, 1e+149, 1e+150, 1e+151, 1e+152, 1e+153, 1e+154, 1e+155, 1e+156, 1e+157, 1e+158, 1e+159, 1e+160, 1e+161, 1e+162, 1e+163, 1e+164, 1e+165, 1e+166, 1e+167, 1e+168, 1e+169, 1e+170, 1e+171, 1e+172, 1e+173, 1e+174, 1e+175, 1e+176, 1e+177, 1e+178, 1e+179, 1e+180, 1e+181, 1e+182, 1e+183, 1e+184, 1e+185, 1e+186, 1e+187, 1e+188, 1e+189, 1e+190, 1e+191, 1e+192, 1e+193, 1e+194, 1e+195, 1e+196, 1e+197, 1e+198, 1e+199, 1e+200, 1e+201, 1e+202, 1e+203, 1e+204, 1e+205, 1e+206, 1e+207, 1e+208, 1e+209, 1e+210, 1e+211, 1e+212, 1e+213, 1e+214, 1e+215, 1e+216, 1e+217, 1e+218, 1e+219, 1e+220, 1e+221, 1e+222, 1e+223, 1e+224, 1e+225, 1e+226, 1e+227, 1e+228, 1e+229, 1e+230, 1e+231, 1e+232, 1e+233, 1e+234, 1e+235, 1e+236, 1e+237, 1e+238, 1e+239, 1e+240, 1e+241, 1e+242, 1e+243, 1e+244, 1e+245, 1e+246, 1e+247, 1e+248, 1e+249, 1e+250, 1e+251, 1e+252, 1e+253, 1e+254, 1e+255, 1e+256, 1e+257, 1e+258, 1e+259, 1e+260, 1e+261, 1e+262, 1e+263, 1e+264, 1e+265, 1e+266, 1e+267, 1e+268, 1e+269, 1e+270, 1e+271, 1e+272, 1e+273, 1e+274, 1e+275, 1e+276, 1e+277, 1e+278, 1e+279, 1e+280, 1e+281, 1e+282, 1e+283, 1e+284, 1e+285, 1e+286, 1e+287, 1e+288, 1e+289, 1e+290, 1e+291, 1e+292, 1e+293, 1e+294, 1e+295, 1e+296, 1e+297, 1e+298, 1e+299, 1e+300, 1e+301, 1e+302, 1e+303, 1e+304, 1e+305, 1e+306, 1e+307, 1e+308}; return (n < -308 ? 
0.0 : etab[n + 308]); } static inline mod_json_cchar_t *mod_json_utils_strskpb(mod_json_cchar_t *cstr) { static const mod_json_char_t blanks[256] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, '\t', '\n', '\v', '\f', '\r', 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ' ', 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; while (*(blanks + *cstr)) { ++cstr; } return cstr; } static inline mod_json_cchar_t *mod_json_utils_strskpc1( mod_json_cchar_t *cstr) { mod_json_char_t c; while ((c = *cstr++) != '\0') { if (c == '\r' || c == '\n') { return mod_json_utils_strskpb(cstr); } } return (cstr - 1); } static inline mod_json_cchar_t *mod_json_utils_strskpc2( mod_json_cchar_t *cstr) { mod_json_char_t c; while ((c = *cstr++) != '\0') { /* asterisk, slash */ if (c == '*' && *cstr == '/') { return mod_json_utils_strskpb(cstr + 1); } } return (cstr - 1); } static inline mod_json_cchar_t *mod_json_utils_strskp(mod_json_cchar_t *cstr) { cstr = mod_json_utils_strskpb(cstr); /* treat it as comments? 
*/ while (*cstr == '/') { mod_json_char_t c = *(cstr + 1); /* second char */ if (c == '/') { /* two slashes */ cstr = mod_json_utils_strskpc1(cstr + 2); } else if (c == '*') { /* slash, asterisk */ cstr = mod_json_utils_strskpc2(cstr + 2); } else { /* invalid format */ break; } } return cstr; } static inline int mod_json_utils_char2hex(mod_json_char_t ch) { static const mod_json_char_t char2hex[256] = { 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 16, 16, 16, 16, 16, 16, 16, 10, 11, 12, 13, 14, 15, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 10, 11, 12, 13, 14, 15, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16}; return *(char2hex + ch); } static inline mod_json_cchar_t *mod_json_utils_strfquo(mod_json_cchar_t *cstr, mod_json_char_t quo) { mod_json_char_t c; for (c = *cstr; c != quo; c = *(++cstr)) { if ((mod_json_uchar_t)c <= 0x1f) { return NULL; } if (c != '\\') { continue; } /* next char */ switch (*(++cstr)) { case '\"': case '/': case 'b': case 'f': case '\\': case 'n': case 'r': case 't': /* ignore next char */ break; case 'u': if (mod_json_utils_char2hex(*(cstr + 1)) > 15) { return NULL; } if (mod_json_utils_char2hex(*(cstr + 2)) > 15) { return NULL; } if 
(mod_json_utils_char2hex(*(cstr + 3)) > 15) { return NULL; } if (mod_json_utils_char2hex(*(cstr + 4)) > 15) { return NULL; } cstr += 4; break; default: /* invalid */ return NULL; } } /* found it */ return cstr; } static inline mod_json_cchar_t *mod_json_utils_strfquo2(mod_json_cchar_t *cstr, mod_json_char_t quo) { mod_json_char_t c; for (c = *cstr; c; c = *(++cstr)) { if (c == quo) { /* found it */ return cstr; } if (c == '\\') { /* ignore next char */ if (*(++cstr) == '\0') { break; } } } return NULL; } static inline mod_json_cchar_t *mod_json_utils_strfsep(mod_json_cchar_t *cstr) { mod_json_char_t c; while ((c = *cstr++) != '\0') { switch (c) { case ':': case ' ': case '\t': case '\r': case '\n': case '\f': case '\v': return (cstr - 1); } } return (cstr - 1); } static inline mod_json_cchar_t *mod_json_utils_strfsep2( mod_json_cchar_t *cstr) { mod_json_char_t c; while ((c = *cstr++) != '\0') { switch (c) { case ':': case ' ': case '\t': case '\r': case '\n': case '\f': case '\v': return (cstr - 1); case '/': if (*cstr == '/' || *cstr == '*') { return (cstr - 1); } } } return (cstr - 1); } static inline mod_json_char_t *mod_json_utils_uni2utf8(mod_json_char_t *buf, mod_json_size_t size, mod_json_uchar_t high, mod_json_uchar_t low) { /* convert to UTF-8 */ if (high >= 0x8) { /* 0800 - FFFF | 1110xxxx 10xxxxxx 10xxxxxx */ if (size >= 3) { *buf++ = (mod_json_char_t)(0xE0 | (high >> 4)); *buf++ = (mod_json_char_t)(0x80 | ((high & 0xF) << 2) | (low >> 6)); *buf++ = (mod_json_char_t)(0x80 | (low & 0x3F)); return buf; } } else if (high > 0 || low >= 0x80) { /* 0080 - 07FF | 110xxxxx 10xxxxxx */ if (size >= 2) { *buf++ = (mod_json_char_t)(0xC0 | (high << 2) | (low >> 6)); *buf++ = (mod_json_char_t)(0x80 | (low & 0x3F)); return buf; } } else { /* 0000 - 007F | 0xxxxxxx */ if (size >= 1) { *buf++ = (mod_json_char_t)(low); return buf; } } return (mod_json_char_t *)0; } mod_json_value_t *mod_json_value_set_null(void) { mod_json_value_t *val; /* create a value */ val = 
(mod_json_value_t *)mod_json_malloc(sizeof(mod_json_value_t)); mod_json_null_if_false(val); val->refer = 1; val->type = mod_json_type_null; val->data.c_int = 0; return val; } mod_json_value_t *mod_json_value_set_object(mod_json_object_t *obj) { mod_json_value_t *val; /* create a value */ val = (mod_json_value_t *)mod_json_malloc(sizeof(mod_json_value_t)); mod_json_null_if_false(val); val->refer = 1; val->type = mod_json_type_object; val->data.c_obj = obj ? mod_json_object_grab(obj) : NULL; return val; } mod_json_value_t *mod_json_value_set_array(mod_json_array_t *arr) { mod_json_value_t *val; /* create a value */ val = (mod_json_value_t *)mod_json_malloc(sizeof(mod_json_value_t)); mod_json_null_if_false(val); val->refer = 1; val->type = mod_json_type_array; val->data.c_arr = arr ? mod_json_array_grab(arr) : NULL; return val; } mod_json_value_t *mod_json_value_set_string(mod_json_string_t *str) { mod_json_value_t *val; /* create a value */ val = (mod_json_value_t *)mod_json_malloc(sizeof(mod_json_value_t)); mod_json_null_if_false(val); val->refer = 1; val->type = mod_json_type_string; val->data.c_str = str ? 
mod_json_string_grab(str) : NULL; return val; } mod_json_value_t *mod_json_value_set_buffer(mod_json_cchar_t *buf, mod_json_size_t len) { mod_json_value_t *val; mod_json_string_t *str; /* create a value */ val = (mod_json_value_t *)mod_json_malloc(sizeof(mod_json_value_t)); mod_json_null_if_false(val); /* create a string */ str = mod_json_string_set(buf, len); if (mod_json_unlikely(!str)) { mod_json_free(val); return NULL; } val->refer = 1; val->type = mod_json_type_string; val->data.c_str = str; return val; } mod_json_value_t *mod_json_value_set_integer(mod_json_integer_t num) { mod_json_value_t *val; /* create a value */ val = (mod_json_value_t *)mod_json_malloc(sizeof(mod_json_value_t)); mod_json_null_if_false(val); val->refer = 1; val->type = mod_json_type_integer; val->data.c_int = num; return val; } mod_json_value_t *mod_json_value_set_float(mod_json_float_t dbl) { mod_json_value_t *val; /* create a value */ val = (mod_json_value_t *)mod_json_malloc(sizeof(mod_json_value_t)); mod_json_null_if_false(val); val->refer = 1; val->type = mod_json_type_float; val->data.c_float = dbl; return val; } mod_json_value_t *mod_json_value_set_boolean(mod_json_boolean_t bol) { mod_json_value_t *val; /* create a value */ val = (mod_json_value_t *)mod_json_malloc(sizeof(mod_json_value_t)); mod_json_null_if_false(val); val->refer = 1; val->type = mod_json_type_boolean; val->data.c_bool = bol ? 
MOD_JSON_TRUE : MOD_JSON_FALSE; return val; } static inline void mod_json_value_clear(mod_json_value_t *val) { switch (val->type) { case mod_json_type_object: mod_json_object_unset(val->data.c_obj); break; case mod_json_type_array: mod_json_array_unset(val->data.c_arr); break; case mod_json_type_string: mod_json_string_unset(val->data.c_str); break; default: break; } } void mod_json_value_assign_null(mod_json_value_t *val) { if (val) { mod_json_value_clear(val); val->type = mod_json_type_null; val->data.c_int = 0; } } void mod_json_value_assign_object(mod_json_value_t *val, mod_json_object_t *obj) { if (val) { mod_json_value_clear(val); val->type = mod_json_type_object; val->data.c_obj = obj ? mod_json_object_grab(obj) : NULL; } } void mod_json_value_assign_array(mod_json_value_t *val, mod_json_array_t *arr) { if (val) { mod_json_value_clear(val); val->type = mod_json_type_array; val->data.c_arr = arr ? mod_json_array_grab(arr) : NULL; } } void mod_json_value_assign_string(mod_json_value_t *val, mod_json_string_t *str) { if (val) { mod_json_value_clear(val); val->type = mod_json_type_string; val->data.c_str = str ? mod_json_string_grab(str) : NULL; } } void mod_json_value_assign_integer(mod_json_value_t *val, mod_json_integer_t num) { if (val) { mod_json_value_clear(val); val->type = mod_json_type_integer; val->data.c_int = num; } } void mod_json_value_assign_float(mod_json_value_t *val, mod_json_float_t dbl) { if (val) { mod_json_value_clear(val); val->type = mod_json_type_float; val->data.c_float = dbl; } } void mod_json_value_assign_boolean(mod_json_value_t *val, mod_json_boolean_t bol) { if (val) { mod_json_value_clear(val); val->type = mod_json_type_boolean; val->data.c_bool = bol ? 
MOD_JSON_TRUE : MOD_JSON_FALSE; } } void mod_json_value_assign(mod_json_value_t *dst, mod_json_value_t *src) { if (!dst || dst == src) { return; } if (!src) { /* treat as JSON null */ mod_json_value_assign_null(dst); return; } switch (src->type) { case mod_json_type_boolean: mod_json_value_assign_boolean(dst, src->data.c_bool); break; case mod_json_type_integer: mod_json_value_assign_integer(dst, src->data.c_int); break; case mod_json_type_float: mod_json_value_assign_float(dst, src->data.c_float); break; case mod_json_type_string: mod_json_value_assign_string(dst, src->data.c_str); break; case mod_json_type_array: mod_json_value_assign_array(dst, src->data.c_arr); break; case mod_json_type_object: mod_json_value_assign_object(dst, src->data.c_obj); break; default: mod_json_value_assign_null(dst); break; } } static inline int mod_json_value_merge_array(mod_json_value_t *val, mod_json_array_t *arr) { if (val->type != mod_json_type_array || !val->data.c_arr) { mod_json_value_assign_array(val, arr); return 0; } if (arr) { if (mod_json_array_is_shared(val->data.c_arr)) { mod_json_array_put(val->data.c_arr); val->data.c_arr = mod_json_array_clone(val->data.c_arr); } return mod_json_array_merge(val->data.c_arr, arr); } return 0; } static inline int mod_json_value_merge_object(mod_json_value_t *val, mod_json_object_t *obj) { if (val->type != mod_json_type_object || !val->data.c_obj) { mod_json_value_assign_object(val, obj); return 0; } if (obj) { if (mod_json_object_is_shared(val->data.c_obj)) { mod_json_object_put(val->data.c_obj); val->data.c_obj = mod_json_object_clone(val->data.c_obj); } return mod_json_object_merge(val->data.c_obj, obj); } return 0; } int mod_json_value_merge(mod_json_value_t *dst, mod_json_value_t *src) { mod_json_minus_if_false(dst && dst != src); if (!src) { mod_json_value_assign_null(dst); return 0; } switch (src->type) { case mod_json_type_boolean: mod_json_value_assign_boolean(dst, src->data.c_bool); break; case mod_json_type_integer: 
mod_json_value_assign_integer(dst, src->data.c_int); break; case mod_json_type_float: mod_json_value_assign_float(dst, src->data.c_float); break; case mod_json_type_string: mod_json_value_assign_string(dst, src->data.c_str); break; case mod_json_type_array: return mod_json_value_merge_array(dst, src->data.c_arr); case mod_json_type_object: return mod_json_value_merge_object(dst, src->data.c_obj); default: mod_json_value_assign_null(dst); break; } return 0; } mod_json_object_t *mod_json_value_object(mod_json_value_t *val) { if (val && val->type == mod_json_type_object) { return (val->data.c_obj); } return NULL; } mod_json_array_t *mod_json_value_array(mod_json_value_t *val) { if (val && val->type == mod_json_type_array) { return (val->data.c_arr); } return NULL; } mod_json_string_t *mod_json_value_string(mod_json_value_t *val) { if (val && val->type == mod_json_type_string) { return (val->data.c_str); } return NULL; } mod_json_cchar_t *mod_json_value_cstring(mod_json_value_t *val) { if (val && val->type == mod_json_type_string) { return mod_json_string_cstr(val->data.c_str); } return NULL; } mod_json_float_t mod_json_value_float(mod_json_value_t *val) { if (val) { switch (val->type) { case mod_json_type_boolean: return (val->data.c_bool ? 
1.0 : 0.0); case mod_json_type_integer: return (mod_json_float_t)(val->data.c_int); case mod_json_type_float: return (val->data.c_float); case mod_json_type_string: return mod_json_string_float(val->data.c_str); default: break; } } return (0.0); } mod_json_boolean_t mod_json_value_boolean(mod_json_value_t *val) { if (val) { switch (val->type) { case mod_json_type_null: return MOD_JSON_FALSE; case mod_json_type_object: return (mod_json_object_count(val->data.c_obj) != 0); case mod_json_type_array: return (mod_json_array_count(val->data.c_arr) != 0); case mod_json_type_string: return (mod_json_string_length(val->data.c_str) != 0); case mod_json_type_integer: return (val->data.c_int != 0); case mod_json_type_float: return (val->data.c_float != 0); case mod_json_type_boolean: return (val->data.c_bool); default: break; } } return MOD_JSON_FALSE; } mod_json_integer_t mod_json_value_integer(mod_json_value_t *val) { if (val) { switch (val->type) { case mod_json_type_boolean: return (val->data.c_bool ? 
1 : 0); case mod_json_type_integer: return (val->data.c_int); case mod_json_type_float: return (mod_json_integer_t)(val->data.c_float); case mod_json_type_string: return mod_json_string_integer(val->data.c_str); default: break; } } return (0); } mod_json_value_t *mod_json_value_clone(mod_json_value_t *val) { if (val) { switch (val->type) { case mod_json_type_null: return mod_json_value_set_null(); case mod_json_type_object: return mod_json_value_set_object(val->data.c_obj); case mod_json_type_array: return mod_json_value_set_array(val->data.c_arr); case mod_json_type_string: return mod_json_value_set_string(val->data.c_str); case mod_json_type_integer: return mod_json_value_set_integer(val->data.c_int); case mod_json_type_float: return mod_json_value_set_float(val->data.c_float); case mod_json_type_boolean: return mod_json_value_set_boolean(val->data.c_bool); default: break; } } return NULL; } static inline mod_json_boolean_t mod_json_value_is_equal_float( mod_json_float_t lhs, mod_json_float_t rhs) { mod_json_float_t diff = lhs - rhs; return ((diff < DBL_EPSILON) && (diff > -DBL_EPSILON)); } mod_json_boolean_t mod_json_value_is_equal(mod_json_value_t *lhs, mod_json_value_t *rhs) { if (lhs == rhs) { /* The same pointer */ return MOD_JSON_TRUE; } if (lhs && rhs && lhs->type == rhs->type) { switch (lhs->type) { case mod_json_type_null: return MOD_JSON_TRUE; case mod_json_type_object: return mod_json_object_is_equal(lhs->data.c_obj, rhs->data.c_obj); case mod_json_type_array: return mod_json_array_is_equal(lhs->data.c_arr, rhs->data.c_arr); case mod_json_type_string: return (mod_json_string_compare(lhs->data.c_str, rhs->data.c_str) == 0); case mod_json_type_integer: return (lhs->data.c_int == rhs->data.c_int); case mod_json_type_float: return mod_json_value_is_equal_float(lhs->data.c_float, rhs->data.c_float); case mod_json_type_boolean: return ((!lhs->data.c_bool) == (!rhs->data.c_bool)); default: break; } } return MOD_JSON_FALSE; } void 
mod_json_value_unset(mod_json_value_t *val) { if (val && mod_json_value_put(val) <= 0) { mod_json_value_clear(val); mod_json_free(val); } } static inline int mod_json_string_expand(mod_json_string_t *str, mod_json_size_t size) { mod_json_char_t *cstr; mod_json_size_t len; size = mod_json_utils_clp2(size); if (size < MOD_JSON_STRING_DEFSIZE) { size = MOD_JSON_STRING_DEFSIZE; } mod_json_minus_if_false(size > str->size); cstr = (mod_json_char_t *)mod_json_malloc(size * sizeof(mod_json_char_t)); mod_json_minus_if_false(cstr); len = (mod_json_size_t)(str->last - str->first); if (len != 0) { memcpy(cstr, str->first, len + 1); } else { *cstr = '\0'; /* terminal character */ } mod_json_free(str->first); str->first = cstr; str->last = cstr + len; str->size = size; /* success */ return 0; } int mod_json_string_reserve(mod_json_string_t *str, mod_json_size_t n) { mod_json_minus_if_false(str); if (str->size >= n + 1) { /* needn't grow */ return 0; } return mod_json_string_expand(str, n + 1); } static inline mod_json_string_t *mod_json_string_malloc(mod_json_size_t size) { mod_json_string_t *str; mod_json_char_t *buf; buf = (mod_json_char_t *)mod_json_malloc(size * sizeof(mod_json_char_t)); mod_json_null_if_false(buf); str = (mod_json_string_t *)mod_json_malloc(sizeof(mod_json_string_t)); if (mod_json_unlikely(!str)) { mod_json_free(buf); return NULL; } str->refer = 1; str->size = size; str->first = buf; str->last = buf; *buf = '\0'; return str; } int mod_json_string_assign(mod_json_string_t *str, mod_json_cchar_t *cstr, mod_json_size_t len) { mod_json_string_reset(str); mod_json_minus_if_ne_zero(mod_json_string_reserve(str, len)); if (cstr && len) { memcpy(str->first, cstr, len); } str->last = str->first + len; *(str->last) = '\0'; /* success */ return 0; } static inline mod_json_string_t *mod_json_string_set_empty(void) { return mod_json_string_malloc(MOD_JSON_STRING_DEFSIZE); } static inline mod_json_string_t *mod_json_string_set_cstr( mod_json_cchar_t *cstr, mod_json_size_t 
len) { mod_json_string_t *str; str = mod_json_string_malloc(mod_json_utils_clp2(len + 1)); mod_json_null_if_false(str); str->last = str->first + len; memcpy(str->first, cstr, len); *(str->last) = '\0'; return str; } mod_json_string_t *mod_json_string_set(mod_json_cchar_t *cstr, mod_json_size_t len) { return ((cstr && len) ? mod_json_string_set_cstr(cstr, len) : mod_json_string_set_empty()); } void mod_json_string_unset(mod_json_string_t *str) { if (str && mod_json_string_put(str) <= 0) { mod_json_free(str->first); mod_json_free(str); } } void mod_json_string_reset(mod_json_string_t *str) { if (str) { str->last = str->first; *(str->first) = '\0'; } } static inline int mod_json_string_add_char(mod_json_string_t *str, mod_json_char_t ch) { mod_json_size_t need; need = (mod_json_size_t)(str->last - str->first) + 2; if (need > str->size) { mod_json_minus_if_ne_zero(mod_json_string_expand(str, need)); } /* append to string */ *(str->last++) = ch; *(str->last) = '\0'; /* success */ return 0; } static inline int mod_json_string_add_cstr(mod_json_string_t *str, mod_json_cchar_t *cstr, mod_json_size_t len) { if (cstr && len) { mod_json_size_t need; need = len + (mod_json_size_t)(str->last - str->first) + 1; if (need > str->size) { mod_json_minus_if_ne_zero(mod_json_string_expand(str, need)); } /* append to string */ memcpy(str->last, cstr, len); str->last += len; *(str->last) = '\0'; } /* success */ return 0; } static inline int mod_json_string_add_jstr(mod_json_string_t *str, mod_json_string_t *val) { return mod_json_string_add_cstr(str, val->first, (mod_json_size_t)(val->last - val->first)); } int mod_json_string_add(mod_json_string_t *str, mod_json_string_t *val) { return mod_json_string_add_jstr(str, val); } int mod_json_string_append(mod_json_string_t *str, mod_json_cchar_t *cstr, mod_json_size_t len) { return mod_json_string_add_cstr(str, cstr, len); } mod_json_size_t mod_json_string_hash(mod_json_string_t *str) { mod_json_size_t hash = 1; if (str) { mod_json_cchar_t 
*iter = str->first; mod_json_cchar_t *last = str->last; for (; iter != last; ++iter) { mod_json_size_t c = (mod_json_size_t)(*iter); hash = hash * 131 + c; } } return hash; } int mod_json_string_compare(mod_json_string_t *str1, mod_json_string_t *str2) { mod_json_size_t len1 = 0, len2 = 0; if (str1 == str2) { /* The same pointer */ return 0; } if (str1) { len1 = (mod_json_size_t)(str1->last - str1->first); if (str2) { len2 = (mod_json_size_t)(str2->last - str2->first); if (len1 == len2) { return memcmp(str1->first, str2->first, len1); } } } else { /* The first string is null, and the second string it not null. */ len2 = (mod_json_size_t)(str2->last - str2->first); } return (int)(len1 - len2); } mod_json_integer_t mod_json_string_integer(mod_json_string_t *str) { return (str ? (mod_json_integer_t)mod_json_utils_strtoi(str->first, NULL, 0) : 0); } mod_json_float_t mod_json_string_float(mod_json_string_t *str) { return (str ? mod_json_utils_strtof(str->first, NULL) : 0.0); } static inline int mod_json_string_flat(mod_json_string_t *dst, mod_json_string_t *src) { static mod_json_cchar_t *flattab[32] = { "\\u0000", "\\u0001", "\\u0002", "\\u0003", "\\u0004", "\\u0005", "\\u0006", "\\u0007", "\\b", "\\t", "\\n", "\\u000b", "\\f", "\\r", "\\u000e", "\\u000f", "\\u0010", "\\u0011", "\\u0012", "\\u0013", "\\u0014", "\\u0015", "\\u0016", "\\u0017", "\\u0018", "\\u0019", "\\u001a", "\\u001b", "\\u001c", "\\u001d", "\\u001e", "\\u001f"}; /* length of items in flat table */ static const mod_json_uchar_t flatlen[32] = {6, 6, 6, 6, 6, 6, 6, 6, 2, 2, 2, 6, 2, 2, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6}; mod_json_cchar_t *first = src->first; mod_json_cchar_t *iter = src->first; mod_json_cchar_t *last = src->last; /* the whole string */ for (; iter != last; ++iter) { int c = *iter; if ((mod_json_uchar_t)c <= 0x1f) { if (iter > first) { mod_json_minus_if_ne_zero(mod_json_string_add_cstr( dst, first, (mod_json_size_t)(iter - first))); } mod_json_minus_if_ne_zero( 
mod_json_string_add_cstr(dst, flattab[c], flatlen[c])); /* skip current character */ first = iter + 1; } else if (c == '\"' || c == '\\') { if (iter > first) { mod_json_minus_if_ne_zero(mod_json_string_add_cstr( dst, first, (mod_json_size_t)(iter - first))); } mod_json_minus_if_ne_zero(mod_json_string_add_char(dst, '\\')); /* don't skip current character */ first = iter; } } if (iter > first) { mod_json_minus_if_ne_zero( mod_json_string_add_cstr(dst, first, (mod_json_size_t)(iter - first))); } /* success */ return 0; } static inline int mod_json_string_unflat(mod_json_string_t *dst, mod_json_string_t *src) { enum { state_normal, state_rev_slash, state_digit_1, state_digit_2, state_digit_3, state_digit_4 } state; mod_json_char_t *pbuf = dst->first; mod_json_char_t *pend = dst->first + dst->size; mod_json_cchar_t *iter = src->first; mod_json_cchar_t *last = src->last; mod_json_uchar_t high = 0; mod_json_uchar_t low = 0; /* the whole string */ for (state = state_normal; iter != last; ++iter) { int c = *iter; switch (state) { case state_normal: if (c != '\\') { mod_json_minus_if_false(pbuf < pend); *pbuf++ = (mod_json_char_t)c; } else { /* '\\' in process */ state = state_rev_slash; } break; case state_rev_slash: mod_json_minus_if_false(pbuf < pend); switch (c) { case '\"': state = state_normal; *pbuf++ = '\"'; break; case '/': state = state_normal; *pbuf++ = '/'; break; case 'b': state = state_normal; *pbuf++ = '\b'; break; case 'f': state = state_normal; *pbuf++ = '\f'; break; case '\\': state = state_normal; *pbuf++ = '\\'; break; case 'n': state = state_normal; *pbuf++ = '\n'; break; case 'r': state = state_normal; *pbuf++ = '\r'; break; case 't': state = state_normal; *pbuf++ = '\t'; break; case 'u': state = state_digit_1; break; default: return -1; } break; case state_digit_1: if ((c = mod_json_utils_char2hex((mod_json_char_t)c)) > 15) { /* invalid character */ return -1; } high = (mod_json_uchar_t)(c << 4); state = state_digit_2; break; case state_digit_2: if 
((c = mod_json_utils_char2hex((mod_json_char_t)c)) > 15) { /* invalid character */ return -1; } high |= (mod_json_uchar_t)c; state = state_digit_3; break; case state_digit_3: if ((c = mod_json_utils_char2hex((mod_json_char_t)c)) > 15) { /* invalid character */ return -1; } low = (mod_json_uchar_t)(c << 4); state = state_digit_4; break; case state_digit_4: if ((c = mod_json_utils_char2hex((mod_json_char_t)c)) > 15) { /* invalid character */ return -1; } low |= (mod_json_uchar_t)c; /* decode as a UTF-8 string */ pbuf = mod_json_utils_uni2utf8(pbuf, (mod_json_size_t)(pend - pbuf), high, low); if (!pbuf) { /* lack of buffer */ return -1; } state = state_normal; break; } } if (state != state_normal) { /* uncompleted state */ return -1; } mod_json_minus_if_false(pbuf < pend); /* update the last pointer */ *(dst->last = pbuf) = '\0'; /* success */ return 0; } mod_json_string_t *mod_json_string_encode(mod_json_string_t *src) { mod_json_string_t *dst; mod_json_null_if_false(src); dst = mod_json_string_malloc( mod_json_utils_clp2((mod_json_size_t)(src->last - src->first) + 1)); mod_json_null_if_false(dst); if (mod_json_unlikely(mod_json_string_flat(dst, src) != 0)) { mod_json_string_unset(dst); return NULL; } return dst; } mod_json_string_t *mod_json_string_decode(mod_json_string_t *src) { mod_json_string_t *dst; mod_json_null_if_false(src); dst = mod_json_string_malloc( mod_json_utils_clp2((mod_json_size_t)(src->last - src->first) + 1)); mod_json_null_if_false(dst); if (mod_json_unlikely(mod_json_string_unflat(dst, src) != 0)) { mod_json_string_unset(dst); return NULL; } return dst; } mod_json_array_t *mod_json_array_set(mod_json_size_t size) { mod_json_array_t *arr; mod_json_value_t **buf; size = (size ? 
mod_json_utils_clp2(size) : MOD_JSON_ARRAY_DEFSIZE); buf = (mod_json_value_t **)mod_json_malloc(size * sizeof(mod_json_value_t *)); mod_json_null_if_false(buf); /* create an array */ arr = (mod_json_array_t *)mod_json_malloc(sizeof(mod_json_array_t)); if (mod_json_unlikely(!arr)) { mod_json_free(buf); return NULL; } arr->refer = 1; arr->size = size; arr->first = buf; arr->last = buf; return arr; } mod_json_array_t *mod_json_array_clone(mod_json_array_t *arr) { mod_json_array_t *arr2 = NULL; if (arr) { arr2 = mod_json_array_set((mod_json_size_t)(arr->last - arr->first)); if (arr2) { mod_json_value_t **iter = arr->first; /* clone items */ for (; iter != arr->last; ++iter) { *arr2->last++ = *iter ? mod_json_value_grab(*iter) : NULL; } } } return arr2; } mod_json_boolean_t mod_json_array_is_equal(mod_json_array_t *lhs, mod_json_array_t *rhs) { mod_json_value_t **itl, **itr; if (lhs == rhs) { return MOD_JSON_TRUE; } if (!lhs || !rhs || ((lhs->last - lhs->first) != (rhs->last - rhs->first))) { return MOD_JSON_FALSE; } /* compare items */ for (itl = lhs->first, itr = rhs->first; itl != lhs->last; ++itl, ++itr) { if (!mod_json_value_is_equal(*itl, *itr)) { return MOD_JSON_FALSE; } } return MOD_JSON_TRUE; } void mod_json_array_unset(mod_json_array_t *arr) { if (arr && mod_json_array_put(arr) <= 0) { mod_json_value_t **iter = arr->first; for (; iter != arr->last; ++iter) { mod_json_value_unset(*iter); } mod_json_free(arr->first); mod_json_free(arr); } } void mod_json_array_reset(mod_json_array_t *arr) { if (arr) { mod_json_value_t **iter = arr->first; for (; iter != arr->last; ++iter) { mod_json_value_unset(*iter); } arr->last = arr->first; } } static inline void mod_json_array_migrate(mod_json_array_t *arr, mod_json_value_t **buf, mod_json_size_t size) { mod_json_size_t count = (mod_json_size_t)(arr->last - arr->first); if (count > 0) { memcpy(buf, arr->first, count * sizeof(mod_json_value_t *)); } mod_json_free(arr->first); arr->first = buf; arr->last = buf + count; 
arr->size = size; } static inline int mod_json_array_expand(mod_json_array_t *arr, mod_json_size_t n) { mod_json_size_t size; mod_json_value_t **vals; size = mod_json_utils_clp2(n); if (size < MOD_JSON_ARRAY_DEFSIZE) { size = MOD_JSON_ARRAY_DEFSIZE; } mod_json_minus_if_false(size > arr->size); vals = (mod_json_value_t **)mod_json_malloc(size * sizeof(mod_json_value_t *)); mod_json_minus_if_false(vals); /* use new buffer */ mod_json_array_migrate(arr, vals, size); /* success */ return 0; } int mod_json_array_reserve(mod_json_array_t *arr, mod_json_size_t n) { mod_json_minus_if_false(arr); if (arr->size >= n) { /* needn't grow */ return 0; } return mod_json_array_expand(arr, n); } void mod_json_array_reverse(mod_json_array_t *arr) { if (arr) { mod_json_value_t **first = arr->first; mod_json_value_t **last = arr->last - 1; while (first < last) { mod_json_value_t *temp = *first; *first++ = *last; *last-- = temp; } } } int mod_json_array_push(mod_json_array_t *arr, mod_json_value_t *val) { mod_json_size_t count; mod_json_minus_if_false(arr); count = (mod_json_size_t)(arr->last - arr->first); if (count >= arr->size) { mod_json_minus_if_ne_zero(mod_json_array_expand(arr, count + 1)); } *arr->last++ = val ? 
mod_json_value_grab(val) : NULL; return 0; } void mod_json_array_pop(mod_json_array_t *arr) { if (arr && arr->first != arr->last) { mod_json_value_unset(*(--arr->last)); } } void mod_json_array_shift(mod_json_array_t *arr) { if (arr && arr->first != arr->last) { mod_json_value_t **it = arr->first; mod_json_value_t **last = --arr->last; mod_json_value_unset(*it++); for (; it <= last; ++it) { *(it - 1) = *it; } } } mod_json_value_t *mod_json_array_at(mod_json_array_t *arr, mod_json_size_t id) { if (arr && ((arr->first + id) < arr->last)) { return (arr->first[id]); } return NULL; } int mod_json_array_merge(mod_json_array_t *dst, mod_json_array_t *src) { long count, len1, len2; mod_json_minus_if_false(dst && src && dst != src); /* update length of array */ len1 = (mod_json_size_t)(src->last - src->first); len2 = (mod_json_size_t)(dst->last - dst->first); mod_json_minus_if_false(len1 >= 0 && len2 >= 0); /* append empty values */ count = len1 - len2; for (; count > 0; --count) { mod_json_array_push(dst, NULL); } /* It must be assigned again. */ len2 = (mod_json_size_t)(dst->last - dst->first); count = (len1 < len2 ? len1 : len2); while ((count--) > 0) { mod_json_value_t **iter1 = src->first + count; mod_json_value_t **iter2 = dst->first + count; if (!(*iter2)) { *iter2 = *iter1 ? mod_json_value_grab(*iter1) : NULL; continue; } if (mod_json_value_is_shared(*iter2)) { mod_json_value_put(*iter2); *iter2 = mod_json_value_clone(*iter2); } mod_json_value_merge(*iter2, *iter1); } /* success */ return 0; } int mod_json_array_resize(mod_json_array_t *arr, mod_json_size_t n, mod_json_value_t *val) { mod_json_size_t orig; /* check input */ mod_json_minus_if_false(arr); /* original count of array */ orig = (mod_json_size_t)(arr->last - arr->first); if (orig < n) { mod_json_value_t **iter; if (arr->size < n) { mod_json_minus_if_ne_zero(mod_json_array_expand(arr, n)); } iter = arr->last; arr->last = arr->first + n; /* grab the first one, but get the others */ *iter++ = val = val ? 
mod_json_value_grab(val) : NULL; for (; iter != arr->last; ++iter) { *iter = val ? mod_json_value_get(val) : NULL; } } else if (orig > n) { mod_json_value_t **iter = arr->first + n; for (; iter != arr->last; ++iter) { mod_json_value_unset(*iter); *iter = NULL; } arr->last = arr->first + n; } /* success */ return 0; } static inline void mod_json_pair_init(mod_json_pair_t *pair, mod_json_string_t *key, mod_json_value_t *val) { pair->key = mod_json_string_grab(key); pair->val = val ? mod_json_value_grab(val) : NULL; } static inline void mod_json_pair_cleanup(mod_json_pair_t *pair) { mod_json_string_unset(pair->key); mod_json_value_unset(pair->val); pair->key = NULL; pair->val = NULL; } mod_json_object_t *mod_json_object_set(mod_json_size_t size) { mod_json_object_t *obj; mod_json_pair_t *buf; size = (size ? mod_json_utils_clp2(size) : MOD_JSON_OBJECT_DEFSIZE); buf = (mod_json_pair_t *)mod_json_malloc(size * sizeof(mod_json_pair_t)); mod_json_null_if_false(buf); /* create a object */ obj = (mod_json_object_t *)mod_json_malloc(sizeof(mod_json_object_t)); if (mod_json_unlikely(!obj)) { mod_json_free(buf); return NULL; } obj->refer = 1; obj->size = size; obj->first = buf; obj->last = buf; return obj; } void mod_json_object_unset(mod_json_object_t *obj) { if (obj && mod_json_object_put(obj) <= 0) { mod_json_pair_t *iter = obj->first; for (; iter != obj->last; ++iter) { mod_json_pair_cleanup(iter); } mod_json_free(obj->first); mod_json_free(obj); } } void mod_json_object_reset(mod_json_object_t *obj) { if (obj) { mod_json_pair_t *iter = obj->first; for (; iter != obj->last; ++iter) { mod_json_pair_cleanup(iter); } obj->last = obj->first; } } static inline void mod_json_object_migrate(mod_json_object_t *obj, mod_json_pair_t *buf, mod_json_size_t size) { mod_json_size_t count = (mod_json_size_t)(obj->last - obj->first); if (count > 0) { memcpy(buf, obj->first, count * sizeof(mod_json_pair_t)); } mod_json_free(obj->first); obj->first = buf; obj->last = buf + count; obj->size = 
size; } static inline int mod_json_object_expand(mod_json_object_t *obj, mod_json_size_t n) { mod_json_size_t size; mod_json_pair_t *buf; size = mod_json_utils_clp2(n); if (size < MOD_JSON_OBJECT_DEFSIZE) { size = MOD_JSON_OBJECT_DEFSIZE; } mod_json_minus_if_false(size > obj->size); buf = (mod_json_pair_t *)mod_json_malloc(size * sizeof(mod_json_pair_t)); mod_json_minus_if_false(buf); /* use new buffer */ mod_json_object_migrate(obj, buf, size); /* success */ return 0; } static inline mod_json_pair_t *mod_json_object_find_pair(mod_json_object_t *obj, mod_json_string_t *key, mod_json_size_t *out) { mod_json_pair_t *first = obj->first; mod_json_pair_t *last = obj->last; while (first < last) { mod_json_pair_t *middle = first + ((last - first) >> 2); int diff = mod_json_string_compare(middle->key, key); if (diff < 0) { first = middle + 1; } else if (diff > 0) { last = middle; } else /*if (diff == 0)*/ { *out = (mod_json_size_t)(middle - obj->first); return middle; } } *out = (mod_json_size_t)(first - obj->first); return NULL; } mod_json_pair_t *mod_json_object_insert_force(mod_json_object_t *obj, mod_json_size_t npos, mod_json_string_t *key, mod_json_value_t *val) { mod_json_pair_t *iter, *pos; mod_json_size_t count; count = (mod_json_size_t)(obj->last - obj->first); if (count >= obj->size) { mod_json_null_if_ne_zero(mod_json_object_expand(obj, count + 1)); } pos = obj->first + npos; iter = obj->last++; for (; iter != pos; --iter) { mod_json_pair_t *prev = iter - 1; iter->key = prev->key; iter->val = prev->val; } mod_json_pair_init(pos, key, val); return pos; } mod_json_pair_t *mod_json_object_insert(mod_json_object_t *obj, mod_json_string_t *key, mod_json_value_t *val) { mod_json_size_t npos; mod_json_null_if_false(obj && key); if (mod_json_object_find_pair(obj, key, &npos)) { /* One in object */ return NULL; } return mod_json_object_insert_force(obj, npos, key, val); } mod_json_pair_t *mod_json_object_assign(mod_json_object_t *obj, mod_json_string_t *key, 
mod_json_value_t *val) { mod_json_pair_t *elem = NULL; if (obj && key) { mod_json_size_t npos; elem = mod_json_object_find_pair(obj, key, &npos); if (elem) { if (!elem->val) { elem->val = val ? mod_json_value_grab(val) : NULL; } else { /* overwrite the old value */ mod_json_value_assign(elem->val, val); } } else { /* insert a new one */ elem = mod_json_object_insert_force(obj, npos, key, val); } } return elem; } mod_json_pair_t *mod_json_object_touch(mod_json_object_t *obj, mod_json_cchar_t *key) { mod_json_pair_t *elem = NULL; if (obj && key) { mod_json_string_t str; mod_json_size_t npos; str.first = (mod_json_char_t *)key; str.last = str.first + mod_json_utils_strlen(key); elem = mod_json_object_find_pair(obj, &str, &npos); if (!elem) { mod_json_string_t *jkey; /* insert a new one */ jkey = mod_json_string_set(key, (mod_json_size_t)mod_json_utils_strlen(key)); elem = mod_json_object_insert_force(obj, npos, jkey, NULL); mod_json_string_unset(jkey); } } return elem; } mod_json_object_t *mod_json_object_clone(mod_json_object_t *obj) { mod_json_object_t *obj2 = NULL; if (obj) { obj2 = mod_json_object_set((mod_json_size_t)(obj->last - obj->first)); if (obj2) { mod_json_pair_t *iter = obj->first; /* clone items */ for (; iter != obj->last; ++iter) { mod_json_pair_init(obj2->last++, iter->key, iter->val); } } } return obj2; } mod_json_boolean_t mod_json_object_is_equal(mod_json_object_t *lhs, mod_json_object_t *rhs) { mod_json_pair_t *itl, *itr; if (lhs == rhs) { /* The same pointer */ return MOD_JSON_TRUE; } if (!lhs || !rhs || ((lhs->last - lhs->first) != (rhs->last - rhs->first))) { return MOD_JSON_FALSE; } /* compare items */ for (itl = lhs->first, itr = rhs->first; itl != lhs->last; ++itl, ++itr) { if ((mod_json_string_compare(itl->key, itr->key) != 0) || (!mod_json_value_is_equal(itl->val, itr->val))) { return MOD_JSON_FALSE; } } return MOD_JSON_TRUE; } void mod_json_object_erase(mod_json_object_t *obj, mod_json_cchar_t *key) { if (obj && key) { mod_json_string_t 
str; mod_json_pair_t *iter; mod_json_size_t npos; str.first = (mod_json_char_t *)key; str.last = str.first + mod_json_utils_strlen(key); iter = mod_json_object_find_pair(obj, &str, &npos); if (iter) { mod_json_pair_cleanup(iter++); for (; iter != obj->last; ++iter) { mod_json_pair_t *prev = iter - 1; prev->key = iter->key; prev->val = iter->val; } --obj->last; } } } mod_json_value_t *mod_json_object_at(mod_json_object_t *obj, mod_json_cchar_t *key) { if (obj && key) { mod_json_string_t str; mod_json_pair_t *elem; mod_json_size_t npos; str.first = (mod_json_char_t *)key; str.last = str.first + mod_json_utils_strlen(key); elem = mod_json_object_find_pair(obj, &str, &npos); if (elem) { return (elem->val); } } return NULL; } mod_json_pair_t *mod_json_object_find(mod_json_object_t *obj, mod_json_cchar_t *key) { if (obj && key) { mod_json_string_t str; mod_json_size_t npos; str.first = (mod_json_char_t *)key; str.last = str.first + mod_json_utils_strlen(key); return mod_json_object_find_pair(obj, &str, &npos); } return NULL; } int mod_json_object_merge(mod_json_object_t *dst, mod_json_object_t *src) { mod_json_pair_t *iter; mod_json_minus_if_false(dst && src && dst != src); for (iter = src->first; iter != src->last; ++iter) { mod_json_pair_t *elem; mod_json_size_t npos; elem = mod_json_object_find_pair(dst, iter->key, &npos); if (!elem) { /* insert a new one */ mod_json_object_insert_force(dst, npos, iter->key, iter->val); continue; } if (!elem->val) { elem->val = iter->val ? 
mod_json_value_grab(iter->val) : NULL; continue; } if (mod_json_value_is_shared(elem->val)) { mod_json_value_put(elem->val); elem->val = mod_json_value_clone(elem->val); } mod_json_value_merge(elem->val, iter->val); } return 0; } static inline mod_json_cchar_t *mod_json_token_strskp(mod_json_token_t *tok, mod_json_cchar_t *cstr) { if ((tok->options & MOD_JSON_COMMENT) == 0) { return mod_json_utils_strskpb(cstr); } return mod_json_utils_strskp(cstr); } static inline mod_json_cchar_t *mod_json_token_strfquo(mod_json_token_t *tok, mod_json_cchar_t *cstr, mod_json_char_t quo) { if ((tok->options & MOD_JSON_UNSTRICT) == 0) { return mod_json_utils_strfquo(cstr, quo); } return mod_json_utils_strfquo2(cstr, quo); } static inline mod_json_cchar_t *mod_json_token_strfsep(mod_json_token_t *tok, mod_json_cchar_t *cstr) { if ((tok->options & MOD_JSON_COMMENT) == 0) { return mod_json_utils_strfsep(cstr); } return mod_json_utils_strfsep2(cstr); } mod_json_token_t *mod_json_token_create(mod_json_option_t *opt) { mod_json_token_t *tok; mod_json_size_t opts = MOD_JSON_TOKEN_DEFOPTS; mod_json_size_t mobj = MOD_JSON_TOKEN_DEFOBJDEP; mod_json_size_t marr = MOD_JSON_TOKEN_DEFARRDEP; if (opt) { opts = opt->options; if (opt->object_depth > 0) { mobj = opt->object_depth; } if (opt->array_depth > 0) { marr = opt->array_depth; } } tok = (mod_json_token_t *)mod_json_malloc( (mobj + marr) * sizeof(mod_json_char_t) + sizeof(mod_json_token_t)); mod_json_null_if_false(tok); memset(tok, 0, sizeof(mod_json_token_t)); tok->state = mod_json_state_null; tok->error = mod_json_error_null; tok->options = opts; tok->object_max_depth = mobj; tok->array_max_depth = marr; return tok; } void mod_json_token_destroy(mod_json_token_t *tok) { mod_json_free(tok); } static inline void mod_json_token_set_tag(mod_json_token_t *tok, mod_json_char_t tag) { mod_json_size_t depth = tok->object_depth + tok->array_depth; if (depth != 0) { tok->tags[depth - 1] = tag; } } static inline mod_json_char_t 
mod_json_token_tag(mod_json_token_t *tok) { mod_json_size_t depth = tok->object_depth + tok->array_depth; /* type of current depth */ return (depth ? tok->tags[depth - 1] : (mod_json_char_t)-1); } mod_json_error_t mod_json_token_error(mod_json_token_t *tok) { return (tok->error); } mod_json_cchar_t *mod_json_token_context(mod_json_token_t *tok) { return (tok->context); } mod_json_state_t mod_json_token_state(mod_json_token_t *tok) { return (tok->state); } mod_json_size_t mod_json_token_object_depth(mod_json_token_t *tok) { return (tok->object_depth); } mod_json_size_t mod_json_token_array_depth(mod_json_token_t *tok) { return (tok->array_depth); } mod_json_size_t mod_json_token_depth(mod_json_token_t *tok) { return (tok->object_depth + tok->array_depth); } mod_json_size_t mod_json_token_max_object_depth(mod_json_token_t *tok) { return (tok->object_max_depth); } mod_json_size_t mod_json_token_max_array_depth(mod_json_token_t *tok) { return (tok->array_max_depth); } mod_json_size_t mod_json_token_max_depth(mod_json_token_t *tok) { return (tok->object_max_depth + tok->array_max_depth); } mod_json_void_t *mod_json_token_param(mod_json_token_t *tok) { return (tok->param); } void mod_json_token_set_param(mod_json_token_t *tok, mod_json_void_t *param) { tok->param = param; } void mod_json_token_set_event(mod_json_token_t *tok, mod_json_event_proc proc) { tok->event_proc = proc; } mod_json_event_t mod_json_token_event(mod_json_token_t *tok) { return (tok->event_code); } static inline int mod_json_token_invoke_field(mod_json_token_t *tok, mod_json_cchar_t *val, mod_json_size_t len) { mod_json_event_proc invoke = tok->event_proc; if (invoke) { tok->event_code = mod_json_event_field; return invoke(tok, (mod_json_void_t *)val, len); } return 0; } static inline int mod_json_token_invoke_object(mod_json_token_t *tok) { mod_json_event_proc invoke = tok->event_proc; if (invoke) { tok->event_code = mod_json_event_object; return invoke(tok, NULL, 0); } return 0; } static inline int 
mod_json_token_invoke_array(mod_json_token_t *tok) { mod_json_event_proc invoke = tok->event_proc; if (invoke) { tok->event_code = mod_json_event_array; return invoke(tok, NULL, 0); } return 0; } static inline int mod_json_token_invoke_null(mod_json_token_t *tok) { mod_json_event_proc invoke = tok->event_proc; if (invoke) { tok->event_code = mod_json_event_null; return invoke(tok, NULL, 0); } return 0; } static inline int mod_json_token_invoke_boolean(mod_json_token_t *tok, mod_json_boolean_t val) { mod_json_event_proc invoke = tok->event_proc; if (invoke) { tok->event_code = mod_json_event_boolean; return invoke(tok, &val, sizeof(val)); } return 0; } static inline int mod_json_token_invoke_integer(mod_json_token_t *tok, mod_json_integer_t val) { mod_json_event_proc invoke = tok->event_proc; if (invoke) { tok->event_code = mod_json_event_integer; return invoke(tok, &val, sizeof(val)); } return 0; } static inline int mod_json_token_invoke_float(mod_json_token_t *tok, mod_json_float_t val) { mod_json_event_proc invoke = tok->event_proc; if (invoke) { tok->event_code = mod_json_event_float; return invoke(tok, &val, sizeof(val)); } return 0; } static inline int mod_json_token_invoke_string(mod_json_token_t *tok, mod_json_cchar_t *val, mod_json_size_t len) { mod_json_event_proc invoke = tok->event_proc; if (invoke) { tok->event_code = mod_json_event_string; return invoke(tok, (mod_json_void_t *)val, len); } return 0; } static inline mod_json_cchar_t *mod_json_token_start(mod_json_token_t *tok, mod_json_cchar_t *cstr) { cstr = mod_json_token_strskp(tok, cstr); switch (*cstr) { case '{': tok->state = mod_json_state_object_start; return (cstr + 1); case '[': tok->state = mod_json_state_array_start; return (cstr + 1); case '\0': tok->error = mod_json_error_empty; tok->context = cstr; break; default: tok->error = mod_json_error_start; tok->context = cstr; } return NULL; } static inline mod_json_cchar_t *mod_json_token_value_null( mod_json_token_t *tok, mod_json_cchar_t 
*cstr) { mod_json_char_t c1 = *(cstr + 1); mod_json_char_t c2 = *(cstr + 2); mod_json_char_t c3 = *(cstr + 3); if ((c1 != 'u' && c1 != 'U') || (c2 != 'l' && c2 != 'L') || (c3 != 'l' && c3 != 'L')) { tok->error = mod_json_error_value; tok->context = cstr; return NULL; } if (mod_json_token_invoke_null(tok) != 0) { tok->error = mod_json_error_break; tok->context = cstr; return NULL; } return (cstr + 4); } static inline mod_json_cchar_t *mod_json_token_value_true( mod_json_token_t *tok, mod_json_cchar_t *cstr) { mod_json_char_t c1 = *(cstr + 1); mod_json_char_t c2 = *(cstr + 2); mod_json_char_t c3 = *(cstr + 3); if ((c1 != 'r' && c1 != 'R') || (c2 != 'u' && c2 != 'U') || (c3 != 'e' && c3 != 'E')) { tok->error = mod_json_error_value; tok->context = cstr; return NULL; } if (mod_json_token_invoke_boolean(tok, MOD_JSON_TRUE) != 0) { tok->error = mod_json_error_break; tok->context = cstr; return NULL; } return (cstr + 4); } static inline mod_json_cchar_t *mod_json_token_value_false( mod_json_token_t *tok, mod_json_cchar_t *cstr) { mod_json_char_t c1 = *(cstr + 1); mod_json_char_t c2 = *(cstr + 2); mod_json_char_t c3 = *(cstr + 3); mod_json_char_t c4 = *(cstr + 4); if ((c1 != 'a' && c1 != 'A') || (c2 != 'l' && c2 != 'L') || (c3 != 's' && c3 != 'S') || (c4 != 'e' && c4 != 'E')) { tok->error = mod_json_error_value; tok->context = cstr; return NULL; } if (mod_json_token_invoke_boolean(tok, MOD_JSON_FALSE) != 0) { tok->error = mod_json_error_break; tok->context = cstr; return NULL; } return (cstr + 5); } static inline mod_json_cchar_t *mod_json_token_value_infinity( mod_json_token_t *tok, mod_json_cchar_t *cstr) { mod_json_char_t c1 = *(cstr + 1); mod_json_char_t c2 = *(cstr + 2); if ((c1 != 'n' && c1 != 'N') || (c2 != 'f' && c2 != 'F')) { tok->error = mod_json_error_value; tok->context = cstr; return NULL; } if (mod_json_token_invoke_float(tok, MOD_JSON_INFINITY) != 0) { tok->error = mod_json_error_break; tok->context = cstr; return NULL; } return (cstr + 3); } static inline 
mod_json_cchar_t *mod_json_token_value_string(
    mod_json_token_t *tok, mod_json_cchar_t *cstr, mod_json_char_t quo) {
  /* Skip the opening quote and look for the matching closing one. */
  mod_json_cchar_t *cstr2 = mod_json_token_strfquo(tok, ++cstr, quo);
  if (!cstr2) {
    tok->error = mod_json_error_quote;
    tok->context = cstr;
    return NULL;
  }
  /* Report the raw characters between the quotes (no unescaping here). */
  if (mod_json_token_invoke_string(tok, cstr,
                                   (mod_json_size_t)(cstr2 - cstr)) != 0) {
    tok->error = mod_json_error_break;
    tok->context = cstr;
    return NULL;
  }
  return (cstr2 + 1);
}

/* Records a malformed-number error whose context is the start of the number
 * and returns NULL, so the number parser can `return` it directly. */
static inline mod_json_cchar_t *mod_json_token_number_error(
    mod_json_token_t *tok, mod_json_cchar_t *start) {
  tok->error = mod_json_error_value;
  tok->context = start;
  return NULL;
}

/**
 * Parses a number at `cstr` (optional sign, digits, optional fraction and
 * exponent) and emits an integer or float event.
 *
 * Non-negative integers up to 2^64 - 1 and negative integers down to -2^63
 * are reported as integers. Larger magnitudes, or anything with a fraction
 * or exponent, are reported as floats.
 *
 * Returns the position after the number. On failure it returns NULL with
 * `tok->error` set: mod_json_error_value for malformed or out-of-range input,
 * and mod_json_error_break when the callback stops. Previously the format
 * failures returned NULL without an error, so mod_json_token_parse()
 * reported success on input like `[1,-]`.
 */
static inline mod_json_cchar_t *mod_json_token_value_number(
    mod_json_token_t *tok, mod_json_cchar_t *cstr) {
  enum { number_integer, number_float } num_type = number_integer;
  mod_json_cchar_t *start = cstr;
  mod_json_float_t dbl = 0.0;
  uint32_t dig = 0;
  uint64_t u64 = 0;
  int32_t minus = 0;
  int32_t exp_frac = 0, exp = 0;

  /* Parse minus */
  minus = *cstr;
  if (minus == '-' || minus == '+') {
    ++cstr;
  }

  /* The first digit */
  if ((dig = (uint32_t)(*cstr - '0')) > 9) {
    return mod_json_token_number_error(tok, start);
  }

  /* Save the first digit */
  u64 = dig;

  /* Parse as 64bit integer */
  if (minus != '-') {
    while ((dig = (uint32_t)(*(++cstr) - '0')) <= 9) {
      if (u64 >= 1844674407370955161uLL) {
        /* 2^64 - 1 = 18446744073709551615 */
        if (u64 != 1844674407370955161uLL || dig > 5) {
          dbl = (mod_json_float_t)u64 * 10 + dig;
          num_type = number_float;
          break;
        }
      }
      u64 = u64 * 10 + dig;
    }
  } else {
    while ((dig = (uint32_t)(*(++cstr) - '0')) <= 9) {
      /* 2^63 = 9223372036854775808 */
      if (u64 >= 922337203685477580uLL) {
        if (u64 != 922337203685477580uLL || dig > 8) {
          dbl = (mod_json_float_t)u64 * 10 + dig;
          num_type = number_float;
          break;
        }
      }
      u64 = u64 * 10 + dig;
    }
  }

  /* Force double for big integer */
  if (num_type == number_float) {
    while ((dig = (uint32_t)(*(++cstr) - '0')) <= 9) {
      if (dbl >= 1E307) {
        /* Number too big to store in double */
        return mod_json_token_number_error(tok, start);
      }
      dbl = dbl * 10 + dig;
    }
  }

  /* Parse frac = decimal-point 1*DIGIT */
  if (*cstr == '.') {
    if (num_type != number_float) {
      dbl = (mod_json_float_t)u64;
      num_type = number_float;
    }
    if ((dig = (uint32_t)(*(++cstr) - '0')) > 9) {
      /* At least one digit in fraction part */
      return mod_json_token_number_error(tok, start);
    }
    dbl = dbl * 10 + dig;
    --exp_frac;
    while ((dig = (uint32_t)(*(++cstr) - '0')) <= 9) {
      /* Only the first 16 fraction digits are accumulated; the rest are
       * consumed and ignored. */
      if (exp_frac > -16) {
        dbl = dbl * 10 + dig;
        --exp_frac;
      }
    }
  }

  /* Parse exp = e [ minus / plus ] 1*DIGIT */
  if (*cstr == 'e' || *cstr == 'E') {
    int32_t exp_minus = 0;

    if (num_type != number_float) {
      dbl = (mod_json_float_t)u64;
      num_type = number_float;
    }
    exp_minus = *(++cstr);
    if (exp_minus == '-' || exp_minus == '+') {
      ++cstr;
    }
    /* The first number char after 'e/E' */
    if ((dig = (uint32_t)(*cstr - '0')) > 9) {
      return mod_json_token_number_error(tok, start);
    }
    exp = (int32_t)dig;
    while ((dig = (uint32_t)(*(++cstr) - '0')) <= 9) {
      exp = exp * 10 + (int32_t)dig;
      if (exp > 308) {
        /* Number too big to store in double */
        return mod_json_token_number_error(tok, start);
      }
    }
    if (exp_minus == '-') {
      exp = -exp;
    }
  }

  /* Finish parsing, call event according to the type of number. */
  if (num_type == number_float) {
    dbl *= mod_json_utils_pow10(exp + exp_frac);
    if (minus == '-') {
      dbl = -dbl;
    }
    if (mod_json_token_invoke_float(tok, dbl) != 0) {
      tok->error = mod_json_error_break;
      tok->context = cstr;
      return NULL;
    }
  } else {
    if (minus == '-') {
      /* Negate in unsigned arithmetic: the magnitude of -2^63 does not fit in
       * int64_t, so negating it as a signed value would overflow (undefined
       * behaviour). The result is the same two's-complement bit pattern. */
      u64 = (uint64_t)0 - u64;
    }
    /* NOTE(review): non-negative values above INT64_MAX are kept as integers
     * and wrap if mod_json_integer_t is a signed 64-bit type — confirm that
     * callers expect this. */
    if (mod_json_token_invoke_integer(tok, (mod_json_integer_t)u64) != 0) {
      tok->error = mod_json_error_break;
      tok->context = cstr;
      return NULL;
    }
  }
  return cstr;
}

/**
 * Handles the state after a '[' has been consumed. It emits the array event
 * and pushes a '[' nesting level; once array_max_depth is reached it fails
 * with mod_json_error_depth instead. It then picks the next state from the
 * first significant character.
 */
static inline mod_json_cchar_t *mod_json_token_array_start(
    mod_json_token_t *tok, mod_json_cchar_t *cstr) {
  if (tok->array_depth < tok->array_max_depth) {
    /* callback */
    if (mod_json_token_invoke_array(tok) != 0) {
      tok->error = mod_json_error_break;
      tok->context = cstr;
      return NULL;
    }
    /* increase depth */
    ++tok->array_depth;
    /* push current tag */
    mod_json_token_set_tag(tok, '[');

    cstr = mod_json_token_strskp(tok, cstr);
    switch (*cstr) {
      case '[':
        tok->state = mod_json_state_array_start;
        return (cstr + 1);
      case ']':
        tok->state = mod_json_state_array_finish;
        return (cstr + 1);
      case '\0':
        tok->error = mod_json_error_trunc;
        tok->context = cstr;
        break;
      default:
        tok->state = mod_json_state_array_half;
        return (cstr);
    }
  } else {
tok->error = mod_json_error_depth; tok->context = cstr; } return NULL; } static inline mod_json_cchar_t *mod_json_token_array_half( mod_json_token_t *tok, mod_json_cchar_t *cstr) { cstr = mod_json_token_strskp(tok, cstr); switch (*cstr) { case ',': tok->state = mod_json_state_array_half; return (cstr + 1); case '[': tok->state = mod_json_state_array_start; return (cstr + 1); case ']': tok->state = mod_json_state_array_finish; return (cstr + 1); case '{': tok->state = mod_json_state_object_start; return (cstr + 1); case '\0': tok->error = mod_json_error_trunc; tok->context = cstr; return NULL; /* value in array */ case 't': case 'T': cstr = mod_json_token_value_true(tok, cstr); if (!cstr) { return NULL; } break; case 'f': case 'F': cstr = mod_json_token_value_false(tok, cstr); if (!cstr) { return NULL; } break; case 'n': case 'N': cstr = mod_json_token_value_null(tok, cstr); if (!cstr) { return NULL; } break; case '0': case '1': case '2': case '3': case '4': case '5': case '6': case '7': case '8': case '9': case '+': case '-': cstr = mod_json_token_value_number(tok, cstr); if (!cstr) { return NULL; } break; case '\"': cstr = mod_json_token_value_string(tok, cstr, '\"'); if (!cstr) { return NULL; } break; case '\'': if (tok->options & MOD_JSON_SQUOTE) { cstr = mod_json_token_value_string(tok, cstr, '\''); if (!cstr) { return NULL; } break; } /* FALLTHRU */ default: tok->error = mod_json_error_value; tok->context = cstr; return NULL; } cstr = mod_json_token_strskp(tok, cstr); switch (*cstr) { case ',': tok->state = mod_json_state_array_half; return (cstr + 1); case ']': tok->state = mod_json_state_array_finish; return (cstr + 1); case '\0': tok->error = mod_json_error_trunc; tok->context = cstr; break; default: tok->error = mod_json_error_value; tok->context = cstr; break; } return NULL; } static inline mod_json_cchar_t *mod_json_token_array_finish( mod_json_token_t *tok, mod_json_cchar_t *cstr) { if (tok->array_depth) { /* decrease depth */ --tok->array_depth; /* 
callback */ if (mod_json_token_invoke_array(tok) != 0) { tok->error = mod_json_error_break; tok->context = cstr; return NULL; } cstr = mod_json_token_strskp(tok, cstr); switch (*cstr) { case ']': tok->state = mod_json_state_array_finish; return (cstr + 1); case '}': tok->state = mod_json_state_object_finish; return (cstr + 1); case '\0': if (tok->object_depth || tok->array_depth) { tok->error = mod_json_error_trunc; tok->context = cstr; } else { tok->state = mod_json_state_finish; } break; case ',': if (tok->object_depth || tok->array_depth) { mod_json_char_t tag = mod_json_token_tag(tok); if (tag == '{') { tok->state = mod_json_state_object_half1; return (cstr + 1); } else if (tag == '[') { tok->state = mod_json_state_array_half; return (cstr + 1); } } /* FALLTHRU */ default: tok->error = mod_json_error_array; tok->context = cstr; } } else { tok->error = mod_json_error_depth; tok->context = cstr; } return NULL; } static inline mod_json_cchar_t *mod_json_token_object_start( mod_json_token_t *tok, mod_json_cchar_t *cstr) { if (tok->object_depth < tok->object_max_depth) { /* callback */ if (mod_json_token_invoke_object(tok) != 0) { tok->error = mod_json_error_break; tok->context = cstr; return NULL; } /* increase depth */ ++tok->object_depth; /* push current tag */ mod_json_token_set_tag(tok, '{'); cstr = mod_json_token_strskp(tok, cstr); switch (*cstr) { case '}': tok->state = mod_json_state_object_finish; return (cstr + 1); case '\0': tok->error = mod_json_error_trunc; tok->context = cstr; break; default: tok->state = mod_json_state_object_half1; return (cstr); } } else { tok->error = mod_json_error_depth; tok->context = cstr; } return NULL; } static inline mod_json_cchar_t *mod_json_token_object_quotekey( mod_json_token_t *tok, mod_json_cchar_t *cstr, mod_json_char_t quo) { mod_json_cchar_t *cstr2 = mod_json_token_strfquo(tok, ++cstr, quo); if (cstr2) { /* callback */ if (mod_json_token_invoke_field(tok, cstr, (mod_json_size_t)(cstr2 - cstr)) != 0) { tok->error = 
mod_json_error_break;
      tok->context = cstr;
      return NULL;
    }
    /* Skip the closing quote; the key must be followed by ':'. */
    cstr2 = mod_json_token_strskp(tok, ++cstr2);
    switch (*cstr2) {
      case ':':
        tok->state = mod_json_state_object_half2;
        return (cstr2 + 1);
      case '\0':
        tok->error = mod_json_error_trunc;
        tok->context = cstr;
        break;
      default:
        tok->error = mod_json_error_key;
        tok->context = cstr2;
        break;
    }
  } else {
    /* No matching closing quote was found. */
    tok->error = mod_json_error_quote;
    tok->context = cstr;
  }
  return NULL;
}

/**
 * Parses an unquoted key (MOD_JSON_SIMPLE mode) starting at `cstr`. The end
 * of the key is where mod_json_token_strfsep() stops. Emits a field event
 * and expects ':' after the key. An empty key fails with
 * mod_json_error_key.
 */
static inline mod_json_cchar_t *mod_json_token_object_simplekey(
    mod_json_token_t *tok, mod_json_cchar_t *cstr) {
  mod_json_cchar_t *cstr2 = mod_json_token_strfsep(tok, cstr);
  if (cstr2 != cstr) {
    /* callback */
    if (mod_json_token_invoke_field(tok, cstr,
                                    (mod_json_size_t)(cstr2 - cstr)) != 0) {
      tok->error = mod_json_error_break;
      tok->context = cstr;
      return NULL;
    }
    cstr2 = mod_json_token_strskp(tok, cstr2);
    switch (*cstr2) {
      case ':':
        tok->state = mod_json_state_object_half2;
        return (cstr2 + 1);
      case '\0':
        tok->error = mod_json_error_trunc;
        tok->context = cstr;
        break;
      default:
        tok->error = mod_json_error_key;
        tok->context = cstr2;
        break;
    }
  } else {
    tok->error = mod_json_error_key;
    tok->context = cstr;
  }
  return NULL;
}

/**
 * Object state before a key. It accepts ',' (stays in this state), '}'
 * (closes the object) or a key. Double-quoted keys are always accepted.
 * Single-quoted keys need MOD_JSON_SQUOTE and unquoted keys need
 * MOD_JSON_SIMPLE.
 */
static inline mod_json_cchar_t *mod_json_token_object_half1(
    mod_json_token_t *tok, mod_json_cchar_t *cstr) {
  cstr = mod_json_token_strskp(tok, cstr);
  switch (*cstr) {
    case ',':
      tok->state = mod_json_state_object_half1;
      return (cstr + 1);
    case '}':
      tok->state = mod_json_state_object_finish;
      return (cstr + 1);
    case '\0':
      tok->error = mod_json_error_trunc;
      tok->context = cstr;
      break;
    case '\"':
      /* The key with double quotes */
      return mod_json_token_object_quotekey(tok, cstr, '\"');
    case '\'':
      if (tok->options & MOD_JSON_SQUOTE) {
        /* The key with single quotes */
        return mod_json_token_object_quotekey(tok, cstr, '\'');
      }
      /* FALLTHRU */
    default:
      /* Unquoted keys are only allowed in MOD_JSON_SIMPLE mode. */
      if (tok->options & MOD_JSON_SIMPLE) {
        return mod_json_token_object_simplekey(tok, cstr);
      } else {
        tok->error = mod_json_error_quote;
        tok->context = cstr;
      }
      break;
  }
  return NULL;
}

/**
 * Object state after "key:", which parses the member value. Nested
 * containers only switch state. Scalars (true/false/inf/null/number/string)
 * are parsed here and must be followed by ',' or '}'.
 *
 * NOTE(review): ',' and '}' are also accepted directly after ':', so a key
 * without a value (e.g. `{"a":}`) is tolerated and no value event is emitted
 * for it — confirm this leniency is intended.
 */
static inline mod_json_cchar_t *mod_json_token_object_half2(
    mod_json_token_t *tok, mod_json_cchar_t *cstr) {
  cstr = mod_json_token_strskp(tok, cstr);
  switch (*cstr) {
    case '{':
      tok->state = mod_json_state_object_start;
      return (cstr + 1);
    case '[':
      tok->state = mod_json_state_array_start;
      return (cstr + 1);
    case ',':
      tok->state = mod_json_state_object_half1;
      return (cstr + 1);
    case '}':
      tok->state = mod_json_state_object_finish;
      return (cstr + 1);
    case '\0':
      tok->error = mod_json_error_trunc;
      tok->context = cstr;
      return NULL;
    case 't':
    case 'T':
      cstr = mod_json_token_value_true(tok, cstr);
      if (!cstr) {
        return NULL;
      }
      break;
    case 'f':
    case 'F':
      cstr = mod_json_token_value_false(tok, cstr);
      if (!cstr) {
        return NULL;
      }
      break;
    case 'i':
    case 'I':
      cstr = mod_json_token_value_infinity(tok, cstr);
      if (!cstr) {
        return NULL;
      }
      break;
    case 'n':
    case 'N':
      cstr = mod_json_token_value_null(tok, cstr);
      if (!cstr) {
        return NULL;
      }
      break;
    case '0':
    case '1':
    case '2':
    case '3':
    case '4':
    case '5':
    case '6':
    case '7':
    case '8':
    case '9':
    case '+':
    case '-':
      cstr = mod_json_token_value_number(tok, cstr);
      if (!cstr) {
        return NULL;
      }
      break;
    case '\"':
      cstr = mod_json_token_value_string(tok, cstr, '\"');
      if (!cstr) {
        return NULL;
      }
      break;
    case '\'':
      if (tok->options & MOD_JSON_SQUOTE) {
        cstr = mod_json_token_value_string(tok, cstr, '\'');
        if (!cstr) {
          return NULL;
        }
        break;
      }
      /* FALLTHRU */
    default:
      tok->error = mod_json_error_value;
      tok->context = cstr;
      return NULL;
  }

  /* After a scalar value only ',' or '}' may follow. */
  cstr = mod_json_token_strskp(tok, cstr);
  switch (*cstr) {
    case ',':
      tok->state = mod_json_state_object_half1;
      return (cstr + 1);
    case '}':
      tok->state = mod_json_state_object_finish;
      return (cstr + 1);
    case '\0':
      tok->error = mod_json_error_trunc;
      tok->context = cstr;
      break;
    default:
      tok->error = mod_json_error_value;
      tok->context = cstr;
break; } return NULL; } static inline mod_json_cchar_t *mod_json_token_object_finish( mod_json_token_t *tok, mod_json_cchar_t *cstr) { if (tok->object_depth) { /* decrease depth */ --tok->object_depth; /* callback */ if (mod_json_token_invoke_object(tok) != 0) { tok->error = mod_json_error_break; tok->context = cstr; return NULL; } cstr = mod_json_token_strskp(tok, cstr); switch (*cstr) { case '}': tok->state = mod_json_state_object_finish; return (cstr + 1); case ']': tok->state = mod_json_state_array_finish; return (cstr + 1); case '\0': if (tok->object_depth || tok->array_depth) { tok->error = mod_json_error_trunc; tok->context = cstr; } else { tok->state = mod_json_state_finish; } break; case ',': if (tok->object_depth || tok->array_depth) { mod_json_char_t tag = mod_json_token_tag(tok); if (tag == '{') { tok->state = mod_json_state_object_half1; return (cstr + 1); } else if (tag == '[') { tok->state = mod_json_state_array_half; return (cstr + 1); } } /* FALLTHRU */ default: tok->error = mod_json_error_object; tok->context = cstr; } } else { tok->error = mod_json_error_depth; tok->context = cstr; } return NULL; } static inline mod_json_cchar_t *mod_json_token_null(mod_json_token_t *tok, mod_json_cchar_t *cstr) { if (!cstr || *cstr == '\0') { tok->error = mod_json_error_invalid; tok->context = cstr; return NULL; } tok->state = mod_json_state_start; return cstr; } static inline mod_json_cchar_t *mod_json_token_finish(mod_json_token_t *tok, mod_json_cchar_t *cstr) { tok->error = mod_json_error_null; (void)cstr; return NULL; } static inline mod_json_cchar_t *mod_json_token_default(mod_json_token_t *tok, mod_json_cchar_t *cstr) { tok->error = mod_json_error_state; tok->context = cstr; return NULL; } int mod_json_token_parse(mod_json_token_t *tok, mod_json_cchar_t *cstr) { while (cstr) { switch (tok->state) { case mod_json_state_start: cstr = mod_json_token_start(tok, cstr); break; case mod_json_state_array_start: cstr = mod_json_token_array_start(tok, cstr); break; 
case mod_json_state_array_half:
        cstr = mod_json_token_array_half(tok, cstr);
        break;
      case mod_json_state_array_finish:
        cstr = mod_json_token_array_finish(tok, cstr);
        break;
      case mod_json_state_object_start:
        cstr = mod_json_token_object_start(tok, cstr);
        break;
      case mod_json_state_object_half1:
        cstr = mod_json_token_object_half1(tok, cstr);
        break;
      case mod_json_state_object_half2:
        cstr = mod_json_token_object_half2(tok, cstr);
        break;
      case mod_json_state_object_finish:
        cstr = mod_json_token_object_finish(tok, cstr);
        break;
      case mod_json_state_null:
        cstr = mod_json_token_null(tok, cstr);
        break;
      case mod_json_state_finish:
        cstr = mod_json_token_finish(tok, cstr);
        break;
      default:
        cstr = mod_json_token_default(tok, cstr);
        break;
    }
  }
  /* Handlers return NULL both on error and when the document is complete;
   * the recorded error code tells the two apart. */
  return (tok->error == mod_json_error_null ? 0 : -1);
}

/**
 * Inserts `val` into the container recorded in par->vals[depth - 1]. For an
 * object it is stored under the pending key (par->key); for an array it is
 * appended.
 *
 * Returns 0 on success. Returns -1 at depth 0, for a parent that is not a
 * container, or when the object insertion fails. For arrays the result of
 * mod_json_array_push() is returned as is.
 * NOTE(review): callers release their own reference right after this call,
 * so the container presumably takes its own reference to `val` — confirm.
 */
static inline int mod_json_parser_insert(mod_json_parser_t *par,
                                         mod_json_size_t depth,
                                         mod_json_value_t *val) {
  if (depth > 0) {
    mod_json_value_t *cur = par->vals[depth - 1];
    switch (cur->type) {
      case mod_json_type_object:
        return (mod_json_object_insert(cur->data.c_obj, par->key, val) ? 0
                                                                         : -1);
      case mod_json_type_array:
        return mod_json_array_push(cur->data.c_arr, val);
      default:
        break;
    }
  }
  return -1;
}

/**
 * Creates an empty object value for nesting level `depth`.
 *
 * At depth 0 it becomes the document root in par->vals[0], and the creation
 * reference is kept for mod_json_parse() to hand out. At deeper levels it is
 * inserted into the parent. On success it is also recorded in
 * par->vals[depth] so its members can be added later. In both cases the
 * local reference is then released.
 */
static inline int mod_json_parser_insert_object(mod_json_parser_t *par,
                                                mod_json_size_t depth) {
  mod_json_object_t *obj;
  mod_json_value_t *jval;

  obj = mod_json_object_set_default();
  mod_json_minus_if_false(obj);
  /* Wrap the object in a value, then drop the local object reference
   * (presumably the value now holds its own). */
  jval = mod_json_value_set_object(obj);
  mod_json_object_unset(obj);
  mod_json_minus_if_false(jval);

  if (depth > 0) {
    int ret = mod_json_parser_insert(par, depth, jval);
    if (ret == 0) {
      par->vals[depth] = jval;
    }
    mod_json_value_unset(jval);
    return ret;
  } else {
    /* It's the root, save the pointer. Don't unset it. */
    par->vals[0] = jval;
  }
  return 0;
}

/**
 * Creates an empty array value for nesting level `depth`. It uses the same
 * root/child and reference handling as mod_json_parser_insert_object().
 */
static inline int mod_json_parser_insert_array(mod_json_parser_t *par,
                                               mod_json_size_t depth) {
  mod_json_array_t *arr;
  mod_json_value_t *jval;

  arr = mod_json_array_set_default();
  mod_json_minus_if_false(arr);
  jval = mod_json_value_set_array(arr);
  mod_json_array_unset(arr);
  mod_json_minus_if_false(jval);

  if (depth > 0) {
    int ret = mod_json_parser_insert(par, depth, jval);
    if (ret == 0) {
      par->vals[depth] = jval;
    }
    mod_json_value_unset(jval);
    return ret;
  } else {
    /* It's the root, save the pointer. Don't unset it. */
    par->vals[0] = jval;
  }
  return 0;
}

/* Stores the parser as the token's user parameter. */
static inline void mod_json_token_set_parser(mod_json_token_t *tok,
                                             mod_json_parser_t *par) {
  mod_json_token_set_param(tok, par);
}

/* Retrieves the parser stored by mod_json_token_set_parser(). */
static inline mod_json_parser_t *mod_json_token_parser(mod_json_token_t *tok) {
  return (mod_json_parser_t *)mod_json_token_param(tok);
}

/* Field event: replaces the pending key with a new string built from the
 * `len` bytes at `val`. Returns -1 if that string cannot be created. */
static inline int mod_json_parser_event_field(mod_json_token_t *tok,
                                              mod_json_cchar_t *val,
                                              mod_json_size_t len) {
  mod_json_parser_t *parser;

  /* get information */
  parser = mod_json_token_parser(tok);

  /* unset previous one */
  mod_json_string_unset(parser->key);
  parser->key = mod_json_string_set(val, len);
  return (parser->key ? 0 : -1);
}

/* Array event. In the array_start state it creates the array at the current
 * depth; in the array_finish state there is nothing to do. Any other state
 * is an error. */
static inline int mod_json_parser_event_array(mod_json_token_t *tok) {
  switch (mod_json_token_state(tok)) {
    case mod_json_state_array_finish:
      /* continue */
      return 0;
    case mod_json_state_array_start:
      return mod_json_parser_insert_array(mod_json_token_parser(tok),
                                          mod_json_token_depth(tok));
    default:
      break;
  }
  return -1;
}

/* Object event: same scheme as mod_json_parser_event_array(). */
static inline int mod_json_parser_event_object(mod_json_token_t *tok) {
  switch (mod_json_token_state(tok)) {
    case mod_json_state_object_finish:
      /* continue */
      return 0;
    case mod_json_state_object_start:
      return mod_json_parser_insert_object(mod_json_token_parser(tok),
                                           mod_json_token_depth(tok));
    default:
      break;
  }
  return -1;
}

/* Null event: inserts the parser's single cached null value, creating it on
 * first use; every null in the document shares it. */
static inline int mod_json_parser_event_null(mod_json_token_t *tok) {
  mod_json_parser_t *parser;

  /* get information */
  parser = mod_json_token_parser(tok);
  if (!parser->val_null) {
    parser->val_null = mod_json_value_set_null();
    mod_json_minus_if_false(parser->val_null);
  }
  return mod_json_parser_insert(parser, mod_json_token_depth(tok),
                                parser->val_null);
}

/* Inserts the parser's cached `true` value, creating it on first use. */
static inline int mod_json_parser_event_true(mod_json_token_t *tok) {
  mod_json_parser_t *parser;

  /* get information */
  parser = mod_json_token_parser(tok);
  if (!parser->val_true) {
    parser->val_true = mod_json_value_set_boolean(MOD_JSON_TRUE);
    mod_json_minus_if_false(parser->val_true);
  }
  return mod_json_parser_insert(parser, mod_json_token_depth(tok),
                                parser->val_true);
}

/* Inserts the parser's cached `false` value, creating it on first use. */
static inline int mod_json_parser_event_false(mod_json_token_t *tok) {
  mod_json_parser_t *parser;

  /* get information */
  parser = mod_json_token_parser(tok);
  if (!parser->val_false) {
    parser->val_false = mod_json_value_set_boolean(MOD_JSON_FALSE);
    mod_json_minus_if_false(parser->val_false);
  }
  return mod_json_parser_insert(parser, mod_json_token_depth(tok),
                                parser->val_false);
}

/* Boolean event: routes to the cached true/false value. */
static inline int mod_json_parser_event_boolean(mod_json_token_t *tok,
                                                mod_json_boolean_t val) {
  if (!val) {
    return mod_json_parser_event_false(tok);
  }
  return mod_json_parser_event_true(tok);
}

/* Inserts the parser's cached integer 0 value, creating it on first use. */
static inline int
mod_json_parser_event_zero(mod_json_token_t *tok) {
  mod_json_parser_t *parser;

  /* get information */
  parser = mod_json_token_parser(tok);
  if (!parser->val_zero) {
    parser->val_zero = mod_json_value_set_integer(0);
    mod_json_minus_if_false(parser->val_zero);
  }
  return mod_json_parser_insert(parser, mod_json_token_depth(tok),
                                parser->val_zero);
}

/* Integer event: a non-zero value gets a fresh value object (local reference
 * released after insertion); 0 reuses the cached zero value. */
static inline int mod_json_parser_event_integer(mod_json_token_t *tok,
                                                mod_json_integer_t val) {
  int ret = -1;

  if (val != 0) {
    mod_json_value_t *jval;

    jval = mod_json_value_set_integer(val);
    if (jval) {
      ret = mod_json_parser_insert(mod_json_token_parser(tok),
                                   mod_json_token_depth(tok), jval);
      mod_json_value_unset(jval);
    }
  } else {
    /* zero event */
    ret = mod_json_parser_event_zero(tok);
  }
  return ret;
}

/* Inserts the parser's cached float 0.0 value, creating it on first use. */
static inline int mod_json_parser_event_zerof(mod_json_token_t *tok) {
  mod_json_parser_t *parser;

  /* get information */
  parser = mod_json_token_parser(tok);
  if (!parser->val_zerof) {
    parser->val_zerof = mod_json_value_set_float(0.0);
    mod_json_minus_if_false(parser->val_zerof);
  }
  return mod_json_parser_insert(parser, mod_json_token_depth(tok),
                                parser->val_zerof);
}

/* Float event: a non-zero value gets a fresh value object; 0.0 reuses the
 * cached zero value.
 * NOTE(review): -0.0 compares equal to 0.0, so a negative zero is stored as
 * the cached +0.0 and its sign is lost — confirm that is acceptable. */
static inline int mod_json_parser_event_float(mod_json_token_t *tok,
                                              mod_json_float_t val) {
  int ret = -1;

  if (val != 0.0) {
    mod_json_value_t *jval;

    jval = mod_json_value_set_float(val);
    if (jval) {
      ret = mod_json_parser_insert(mod_json_token_parser(tok),
                                   mod_json_token_depth(tok), jval);
      mod_json_value_unset(jval);
    }
  } else {
    /* zero event */
    ret = mod_json_parser_event_zerof(tok);
  }
  return ret;
}

/* Inserts the parser's cached empty-string value, creating it on first
 * use. */
static inline int mod_json_parser_event_empty(mod_json_token_t *tok) {
  mod_json_parser_t *parser;

  /* get information */
  parser = mod_json_token_parser(tok);
  if (!parser->val_empty) {
    mod_json_string_t *str;

    str = mod_json_string_set("", 0);
    mod_json_minus_if_false(str);
    parser->val_empty = mod_json_value_set_string(str);
    mod_json_string_unset(str);
    mod_json_minus_if_false(parser->val_empty);
  }
  return mod_json_parser_insert(parser, mod_json_token_depth(tok),
                                parser->val_empty);
}

/* String event: builds a string value from the `len` raw bytes at `val` and
 * inserts it; an empty string reuses the cached empty value. */
static inline int mod_json_parser_event_string(mod_json_token_t *tok,
                                               mod_json_cchar_t *val,
                                               mod_json_size_t len) {
  int ret = -1;

  if (len > 0) {
    mod_json_string_t *str;
    mod_json_value_t *jval;

    str = mod_json_string_set(val, len);
    if (str) {
      jval = mod_json_value_set_string(str);
    } else {
      jval = NULL;
    }
    /* Release the local string reference (also reached when str is NULL). */
    mod_json_string_unset(str);

    if (jval) {
      ret = mod_json_parser_insert(mod_json_token_parser(tok),
                                   mod_json_token_depth(tok), jval);
      mod_json_value_unset(jval);
    }
  } else {
    /* empty event */
    ret = mod_json_parser_event_empty(tok);
  }
  return ret;
}

/**
 * Event callback installed by mod_json_parse(). It routes each tokenizer
 * event to the matching builder above. `val`/`len` carry the payload passed
 * by the mod_json_token_invoke_* helpers. Returns 0 to continue; any
 * non-zero value makes the tokenizer stop with mod_json_error_break.
 */
static int mod_json_parser_event(mod_json_token_t *tok, mod_json_void_t *val,
                                 mod_json_size_t len) {
  switch (tok->event_code) {
    case mod_json_event_field:
      return mod_json_parser_event_field(tok, (mod_json_cchar_t *)val, len);
    case mod_json_event_object:
      return mod_json_parser_event_object(tok);
    case mod_json_event_array:
      return mod_json_parser_event_array(tok);
    case mod_json_event_null:
      return mod_json_parser_event_null(tok);
    case mod_json_event_boolean:
      return mod_json_parser_event_boolean(tok, *(mod_json_boolean_t *)val);
    case mod_json_event_integer:
      return mod_json_parser_event_integer(tok, *(mod_json_integer_t *)val);
    case mod_json_event_float:
      return mod_json_parser_event_float(tok, *(mod_json_float_t *)val);
    case mod_json_event_string:
      return mod_json_parser_event_string(tok, (mod_json_cchar_t *)val, len);
    default:
      break;
  }
  return -1;
}

/**
 * Allocates a parser with room for `depth` container slots in its trailing
 * `vals` array; returns NULL when `depth` is 0 or the allocation fails.
 * The memset only covers sizeof(mod_json_parser_t); vals[0] is cleared
 * explicitly, and deeper slots are assigned by the insert_object /
 * insert_array helpers before they are read.
 */
static inline mod_json_parser_t *mod_json_parser_create(mod_json_size_t depth) {
  mod_json_parser_t *parser;

  mod_json_null_if_false(depth > 0);
  parser = (mod_json_parser_t *)mod_json_malloc(
      depth * sizeof(mod_json_value_t *) + sizeof(mod_json_parser_t));
  mod_json_null_if_false(parser);
  memset(parser, 0, sizeof(mod_json_parser_t));
  parser->vals[0] = NULL;
  return parser;
}

/* Releases the cached shared values, the pending key and the parser itself.
 * vals[] is not touched: the root is handed to or released by the caller. */
static inline void mod_json_parser_destroy(mod_json_parser_t *par) {
  mod_json_value_unset(par->val_null);
  mod_json_value_unset(par->val_true);
  mod_json_value_unset(par->val_false);
  mod_json_value_unset(par->val_zero);
mod_json_value_unset(par->val_zerof); mod_json_value_unset(par->val_empty); mod_json_string_unset(par->key); mod_json_free(par); } mod_json_value_t *mod_json_parse(mod_json_token_t *tok, mod_json_cchar_t *cstr) { mod_json_parser_t *parser; mod_json_value_t *root; mod_json_null_if_false(tok && cstr && *cstr); parser = mod_json_parser_create(mod_json_token_max_depth(tok)); mod_json_null_if_false(parser); mod_json_token_set_parser(tok, parser); mod_json_token_set_event(tok, mod_json_parser_event); if (mod_json_token_parse(tok, cstr) == 0) { root = parser->vals[0]; } else { /* error occur */ root = NULL; mod_json_value_unset(parser->vals[0]); } /* clean up */ mod_json_parser_destroy(parser); /* success? */ return root; } mod_json_value_t *mod_json_parse_simply(mod_json_cchar_t *cstr, mod_json_size_t opts) { mod_json_value_t *val; mod_json_token_t *tok; mod_json_option_t opt; opt.options = opts; opt.object_depth = 0; /* Use default object depth */ opt.array_depth = 0; /* Use default array depth */ tok = mod_json_token_create(&opt); mod_json_null_if_false(tok); val = mod_json_parse(tok, cstr); mod_json_token_destroy(tok); /* value of root */ return val; } static inline int mod_json_dump_null(mod_json_string_t *str) { return mod_json_string_add_cstr(str, "null", 4); } static inline int mod_json_dump_boolean(mod_json_string_t *str, mod_json_boolean_t bol) { if (!bol) { return mod_json_string_add_cstr(str, "false", 5); } return mod_json_string_add_cstr(str, "true", 4); } static inline int mod_json_dump_integer(mod_json_string_t *str, mod_json_integer_t num) { mod_json_char_t buf[32]; return mod_json_string_add_cstr(str, buf, mod_json_utils_itostr(buf, num)); } static inline int mod_json_dump_float(mod_json_string_t *str, mod_json_float_t dbl) { mod_json_char_t buf[32]; return mod_json_string_add_cstr( str, buf, (mod_json_size_t)mod_json_utils_snprintf(buf, sizeof(buf), "%g", dbl)); } static inline int mod_json_dump_string(mod_json_string_t *str, mod_json_string_t *val) { 
mod_json_minus_if_ne_zero(mod_json_string_add_char(str, '\"')); if (val) { mod_json_minus_if_ne_zero(mod_json_string_add_jstr(str, val)); } mod_json_minus_if_ne_zero(mod_json_string_add_char(str, '\"')); return 0; } static inline int mod_json_dump_value(mod_json_string_t *str, mod_json_value_t *val); static inline int mod_json_dump_array(mod_json_string_t *str, mod_json_array_t *arr) { mod_json_minus_if_ne_zero(mod_json_string_add_char(str, '[')); if (arr) { mod_json_value_t **iter = arr->first; for (; iter != arr->last; ++iter) { mod_json_minus_if_ne_zero(mod_json_dump_value(str, *iter)); if (iter + 1 != arr->last) { mod_json_minus_if_ne_zero(mod_json_string_add_char(str, ',')); } } } mod_json_minus_if_ne_zero(mod_json_string_add_char(str, ']')); return 0; } static inline int mod_json_dump_key(mod_json_string_t *str, mod_json_string_t *key) { mod_json_minus_if_ne_zero(mod_json_string_add_char(str, '\"')); mod_json_minus_if_ne_zero(mod_json_string_add_jstr(str, key)); mod_json_minus_if_ne_zero(mod_json_string_add_cstr(str, "\":", 2)); return 0; } static inline int mod_json_dump_object(mod_json_string_t *str, mod_json_object_t *obj) { mod_json_minus_if_ne_zero(mod_json_string_add_char(str, '{')); if (obj) { mod_json_pair_t *iter = obj->first; for (; iter != obj->last; ++iter) { mod_json_minus_if_ne_zero(mod_json_dump_key(str, iter->key)); mod_json_minus_if_ne_zero(mod_json_dump_value(str, iter->val)); if (iter + 1 != obj->last) { mod_json_minus_if_ne_zero(mod_json_string_add_char(str, ',')); } } } mod_json_minus_if_ne_zero(mod_json_string_add_char(str, '}')); return 0; } static inline int mod_json_dump_value(mod_json_string_t *str, mod_json_value_t *val) { if (val) { switch (val->type) { case mod_json_type_null: return mod_json_dump_null(str); case mod_json_type_boolean: return mod_json_dump_boolean(str, val->data.c_bool); case mod_json_type_integer: return mod_json_dump_integer(str, val->data.c_int); case mod_json_type_float: return mod_json_dump_float(str, 
val->data.c_float); case mod_json_type_string: return mod_json_dump_string(str, val->data.c_str); case mod_json_type_array: return mod_json_dump_array(str, val->data.c_arr); case mod_json_type_object: return mod_json_dump_object(str, val->data.c_obj); default: return -1; } } return mod_json_dump_null(str); } mod_json_string_t *mod_json_dump(mod_json_value_t *val) { mod_json_string_t *str = mod_json_string_set("", 0); mod_json_null_if_false(str); if (mod_json_unlikely(mod_json_dump_value(str, val) != 0)) { /* error occur */ mod_json_string_unset(str); return NULL; } return str; } ================================================ FILE: src/ailego/hash/crc32c.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include #include #if !defined(__SSE4_2__) && !defined(__ARM_FEATURE_CRC32) /** * The following CRC lookup table was generated automagically * using the following model parameters: * * Generator Polynomial = ................. 0x1EDC6F41 * Generator Polynomial Length = .......... 32 bits * Reflected Bits = ....................... TRUE * Table Generation Offset = .............. 32 bits * Number of Slices = ..................... 8 slices * Slice Lengths = ........................ 
8 8 8 8 8 8 8 8 */ /* crc32c_slicing8 uses this table in three places: bits 24-31 of the second 32-bit word of each 8-byte step, the byte-at-a-time head (until 4-byte alignment), and the 0-7 byte tail. Entry 128 is 0x82F63B78, the bit-reflected CRC32C (Castagnoli) polynomial. */ static uint32_t crc_tableil8_o32[256] = { 0x00000000, 0xF26B8303, 0xE13B70F7, 0x1350F3F4, 0xC79A971F, 0x35F1141C, 0x26A1E7E8, 0xD4CA64EB, 0x8AD958CF, 0x78B2DBCC, 0x6BE22838, 0x9989AB3B, 0x4D43CFD0, 0xBF284CD3, 0xAC78BF27, 0x5E133C24, 0x105EC76F, 0xE235446C, 0xF165B798, 0x030E349B, 0xD7C45070, 0x25AFD373, 0x36FF2087, 0xC494A384, 0x9A879FA0, 0x68EC1CA3, 0x7BBCEF57, 0x89D76C54, 0x5D1D08BF, 0xAF768BBC, 0xBC267848, 0x4E4DFB4B, 0x20BD8EDE, 0xD2D60DDD, 0xC186FE29, 0x33ED7D2A, 0xE72719C1, 0x154C9AC2, 0x061C6936, 0xF477EA35, 0xAA64D611, 0x580F5512, 0x4B5FA6E6, 0xB93425E5, 0x6DFE410E, 0x9F95C20D, 0x8CC531F9, 0x7EAEB2FA, 0x30E349B1, 0xC288CAB2, 0xD1D83946, 0x23B3BA45, 0xF779DEAE, 0x05125DAD, 0x1642AE59, 0xE4292D5A, 0xBA3A117E, 0x4851927D, 0x5B016189, 0xA96AE28A, 0x7DA08661, 0x8FCB0562, 0x9C9BF696, 0x6EF07595, 0x417B1DBC, 0xB3109EBF, 0xA0406D4B, 0x522BEE48, 0x86E18AA3, 0x748A09A0, 0x67DAFA54, 0x95B17957, 0xCBA24573, 0x39C9C670, 0x2A993584, 0xD8F2B687, 0x0C38D26C, 0xFE53516F, 0xED03A29B, 0x1F682198, 0x5125DAD3, 0xA34E59D0, 0xB01EAA24, 0x42752927, 0x96BF4DCC, 0x64D4CECF, 0x77843D3B, 0x85EFBE38, 0xDBFC821C, 0x2997011F, 0x3AC7F2EB, 0xC8AC71E8, 0x1C661503, 0xEE0D9600, 0xFD5D65F4, 0x0F36E6F7, 0x61C69362, 0x93AD1061, 0x80FDE395, 0x72966096, 0xA65C047D, 0x5437877E, 0x4767748A, 0xB50CF789, 0xEB1FCBAD, 0x197448AE, 0x0A24BB5A, 0xF84F3859, 0x2C855CB2, 0xDEEEDFB1, 0xCDBE2C45, 0x3FD5AF46, 0x7198540D, 0x83F3D70E, 0x90A324FA, 0x62C8A7F9, 0xB602C312, 0x44694011, 0x5739B3E5, 0xA55230E6, 0xFB410CC2, 0x092A8FC1, 0x1A7A7C35, 0xE811FF36, 0x3CDB9BDD, 0xCEB018DE, 0xDDE0EB2A, 0x2F8B6829, 0x82F63B78, 0x709DB87B, 0x63CD4B8F, 0x91A6C88C, 0x456CAC67, 0xB7072F64, 0xA457DC90, 0x563C5F93, 0x082F63B7, 0xFA44E0B4, 0xE9141340, 0x1B7F9043, 0xCFB5F4A8, 0x3DDE77AB, 0x2E8E845F, 0xDCE5075C, 0x92A8FC17, 0x60C37F14, 0x73938CE0, 0x81F80FE3, 0x55326B08, 0xA759E80B, 0xB4091BFF, 0x466298FC, 0x1871A4D8, 0xEA1A27DB, 0xF94AD42F, 0x0B21572C, 0xDFEB33C7, 0x2D80B0C4, 0x3ED04330, 0xCCBBC033, 0xA24BB5A6,
0x502036A5, 0x4370C551, 0xB11B4652, 0x65D122B9, 0x97BAA1BA, 0x84EA524E, 0x7681D14D, 0x2892ED69, 0xDAF96E6A, 0xC9A99D9E, 0x3BC21E9D, 0xEF087A76, 0x1D63F975, 0x0E330A81, 0xFC588982, 0xB21572C9, 0x407EF1CA, 0x532E023E, 0xA145813D, 0x758FE5D6, 0x87E466D5, 0x94B49521, 0x66DF1622, 0x38CC2A06, 0xCAA7A905, 0xD9F75AF1, 0x2B9CD9F2, 0xFF56BD19, 0x0D3D3E1A, 0x1E6DCDEE, 0xEC064EED, 0xC38D26C4, 0x31E6A5C7, 0x22B65633, 0xD0DDD530, 0x0417B1DB, 0xF67C32D8, 0xE52CC12C, 0x1747422F, 0x49547E0B, 0xBB3FFD08, 0xA86F0EFC, 0x5A048DFF, 0x8ECEE914, 0x7CA56A17, 0x6FF599E3, 0x9D9E1AE0, 0xD3D3E1AB, 0x21B862A8, 0x32E8915C, 0xC083125F, 0x144976B4, 0xE622F5B7, 0xF5720643, 0x07198540, 0x590AB964, 0xAB613A67, 0xB831C993, 0x4A5A4A90, 0x9E902E7B, 0x6CFBAD78, 0x7FAB5E8C, 0x8DC0DD8F, 0xE330A81A, 0x115B2B19, 0x020BD8ED, 0xF0605BEE, 0x24AA3F05, 0xD6C1BC06, 0xC5914FF2, 0x37FACCF1, 0x69E9F0D5, 0x9B8273D6, 0x88D28022, 0x7AB90321, 0xAE7367CA, 0x5C18E4C9, 0x4F48173D, 0xBD23943E, 0xF36E6F75, 0x0105EC76, 0x12551F82, 0xE03E9C81, 0x34F4F86A, 0xC69F7B69, 0xD5CF889D, 0x27A40B9E, 0x79B737BA, 0x8BDCB4B9, 0x988C474D, 0x6AE7C44E, 0xBE2DA0A5, 0x4C4623A6, 0x5F16D052, 0xAD7D5351}; /** * The following CRC lookup table was generated automagically * using the following model parameters: * * Generator Polynomial = ................. 0x1EDC6F41 * Generator Polynomial Length = .......... 32 bits * Reflected Bits = ....................... TRUE * Table Generation Offset = .............. 32 bits * Number of Slices = ..................... 8 slices * Slice Lengths = ........................
8 8 8 8 8 8 8 8 */ /* Slice table used by crc32c_slicing8 for bits 16-23 of the second 32-bit word loaded in each 8-byte step (index (word >> 16) & 0xFF). */ static uint32_t crc_tableil8_o40[256] = { 0x00000000, 0x13A29877, 0x274530EE, 0x34E7A899, 0x4E8A61DC, 0x5D28F9AB, 0x69CF5132, 0x7A6DC945, 0x9D14C3B8, 0x8EB65BCF, 0xBA51F356, 0xA9F36B21, 0xD39EA264, 0xC03C3A13, 0xF4DB928A, 0xE7790AFD, 0x3FC5F181, 0x2C6769F6, 0x1880C16F, 0x0B225918, 0x714F905D, 0x62ED082A, 0x560AA0B3, 0x45A838C4, 0xA2D13239, 0xB173AA4E, 0x859402D7, 0x96369AA0, 0xEC5B53E5, 0xFFF9CB92, 0xCB1E630B, 0xD8BCFB7C, 0x7F8BE302, 0x6C297B75, 0x58CED3EC, 0x4B6C4B9B, 0x310182DE, 0x22A31AA9, 0x1644B230, 0x05E62A47, 0xE29F20BA, 0xF13DB8CD, 0xC5DA1054, 0xD6788823, 0xAC154166, 0xBFB7D911, 0x8B507188, 0x98F2E9FF, 0x404E1283, 0x53EC8AF4, 0x670B226D, 0x74A9BA1A, 0x0EC4735F, 0x1D66EB28, 0x298143B1, 0x3A23DBC6, 0xDD5AD13B, 0xCEF8494C, 0xFA1FE1D5, 0xE9BD79A2, 0x93D0B0E7, 0x80722890, 0xB4958009, 0xA737187E, 0xFF17C604, 0xECB55E73, 0xD852F6EA, 0xCBF06E9D, 0xB19DA7D8, 0xA23F3FAF, 0x96D89736, 0x857A0F41, 0x620305BC, 0x71A19DCB, 0x45463552, 0x56E4AD25, 0x2C896460, 0x3F2BFC17, 0x0BCC548E, 0x186ECCF9, 0xC0D23785, 0xD370AFF2, 0xE797076B, 0xF4359F1C, 0x8E585659, 0x9DFACE2E, 0xA91D66B7, 0xBABFFEC0, 0x5DC6F43D, 0x4E646C4A, 0x7A83C4D3, 0x69215CA4, 0x134C95E1, 0x00EE0D96, 0x3409A50F, 0x27AB3D78, 0x809C2506, 0x933EBD71, 0xA7D915E8, 0xB47B8D9F, 0xCE1644DA, 0xDDB4DCAD, 0xE9537434, 0xFAF1EC43, 0x1D88E6BE, 0x0E2A7EC9, 0x3ACDD650, 0x296F4E27, 0x53028762, 0x40A01F15, 0x7447B78C, 0x67E52FFB, 0xBF59D487, 0xACFB4CF0, 0x981CE469, 0x8BBE7C1E, 0xF1D3B55B, 0xE2712D2C, 0xD69685B5, 0xC5341DC2, 0x224D173F, 0x31EF8F48, 0x050827D1, 0x16AABFA6, 0x6CC776E3, 0x7F65EE94, 0x4B82460D, 0x5820DE7A, 0xFBC3FAF9, 0xE861628E, 0xDC86CA17, 0xCF245260, 0xB5499B25, 0xA6EB0352, 0x920CABCB, 0x81AE33BC, 0x66D73941, 0x7575A136, 0x419209AF, 0x523091D8, 0x285D589D, 0x3BFFC0EA, 0x0F186873, 0x1CBAF004, 0xC4060B78, 0xD7A4930F, 0xE3433B96, 0xF0E1A3E1, 0x8A8C6AA4, 0x992EF2D3, 0xADC95A4A, 0xBE6BC23D, 0x5912C8C0, 0x4AB050B7, 0x7E57F82E, 0x6DF56059, 0x1798A91C, 0x043A316B, 0x30DD99F2, 0x237F0185, 0x844819FB,
0x97EA818C, 0xA30D2915, 0xB0AFB162, 0xCAC27827, 0xD960E050, 0xED8748C9, 0xFE25D0BE, 0x195CDA43, 0x0AFE4234, 0x3E19EAAD, 0x2DBB72DA, 0x57D6BB9F, 0x447423E8, 0x70938B71, 0x63311306, 0xBB8DE87A, 0xA82F700D, 0x9CC8D894, 0x8F6A40E3, 0xF50789A6, 0xE6A511D1, 0xD242B948, 0xC1E0213F, 0x26992BC2, 0x353BB3B5, 0x01DC1B2C, 0x127E835B, 0x68134A1E, 0x7BB1D269, 0x4F567AF0, 0x5CF4E287, 0x04D43CFD, 0x1776A48A, 0x23910C13, 0x30339464, 0x4A5E5D21, 0x59FCC556, 0x6D1B6DCF, 0x7EB9F5B8, 0x99C0FF45, 0x8A626732, 0xBE85CFAB, 0xAD2757DC, 0xD74A9E99, 0xC4E806EE, 0xF00FAE77, 0xE3AD3600, 0x3B11CD7C, 0x28B3550B, 0x1C54FD92, 0x0FF665E5, 0x759BACA0, 0x663934D7, 0x52DE9C4E, 0x417C0439, 0xA6050EC4, 0xB5A796B3, 0x81403E2A, 0x92E2A65D, 0xE88F6F18, 0xFB2DF76F, 0xCFCA5FF6, 0xDC68C781, 0x7B5FDFFF, 0x68FD4788, 0x5C1AEF11, 0x4FB87766, 0x35D5BE23, 0x26772654, 0x12908ECD, 0x013216BA, 0xE64B1C47, 0xF5E98430, 0xC10E2CA9, 0xD2ACB4DE, 0xA8C17D9B, 0xBB63E5EC, 0x8F844D75, 0x9C26D502, 0x449A2E7E, 0x5738B609, 0x63DF1E90, 0x707D86E7, 0x0A104FA2, 0x19B2D7D5, 0x2D557F4C, 0x3EF7E73B, 0xD98EEDC6, 0xCA2C75B1, 0xFECBDD28, 0xED69455F, 0x97048C1A, 0x84A6146D, 0xB041BCF4, 0xA3E32483}; /** * The following CRC lookup table was generated automagically * using the following model parameters: * * Generator Polynomial = ................. 0x1EDC6F41 * Generator Polynomial Length = .......... 32 bits * Reflected Bits = ....................... TRUE * Table Generation Offset = .............. 32 bits * Number of Slices = ..................... 8 slices * Slice Lengths = ........................
8 8 8 8 8 8 8 8 */ /* Slice table used by crc32c_slicing8 for bits 8-15 of the second 32-bit word loaded in each 8-byte step (index (word >> 8) & 0xFF). */ static uint32_t crc_tableil8_o48[256] = { 0x00000000, 0xA541927E, 0x4F6F520D, 0xEA2EC073, 0x9EDEA41A, 0x3B9F3664, 0xD1B1F617, 0x74F06469, 0x38513EC5, 0x9D10ACBB, 0x773E6CC8, 0xD27FFEB6, 0xA68F9ADF, 0x03CE08A1, 0xE9E0C8D2, 0x4CA15AAC, 0x70A27D8A, 0xD5E3EFF4, 0x3FCD2F87, 0x9A8CBDF9, 0xEE7CD990, 0x4B3D4BEE, 0xA1138B9D, 0x045219E3, 0x48F3434F, 0xEDB2D131, 0x079C1142, 0xA2DD833C, 0xD62DE755, 0x736C752B, 0x9942B558, 0x3C032726, 0xE144FB14, 0x4405696A, 0xAE2BA919, 0x0B6A3B67, 0x7F9A5F0E, 0xDADBCD70, 0x30F50D03, 0x95B49F7D, 0xD915C5D1, 0x7C5457AF, 0x967A97DC, 0x333B05A2, 0x47CB61CB, 0xE28AF3B5, 0x08A433C6, 0xADE5A1B8, 0x91E6869E, 0x34A714E0, 0xDE89D493, 0x7BC846ED, 0x0F382284, 0xAA79B0FA, 0x40577089, 0xE516E2F7, 0xA9B7B85B, 0x0CF62A25, 0xE6D8EA56, 0x43997828, 0x37691C41, 0x92288E3F, 0x78064E4C, 0xDD47DC32, 0xC76580D9, 0x622412A7, 0x880AD2D4, 0x2D4B40AA, 0x59BB24C3, 0xFCFAB6BD, 0x16D476CE, 0xB395E4B0, 0xFF34BE1C, 0x5A752C62, 0xB05BEC11, 0x151A7E6F, 0x61EA1A06, 0xC4AB8878, 0x2E85480B, 0x8BC4DA75, 0xB7C7FD53, 0x12866F2D, 0xF8A8AF5E, 0x5DE93D20, 0x29195949, 0x8C58CB37, 0x66760B44, 0xC337993A, 0x8F96C396, 0x2AD751E8, 0xC0F9919B, 0x65B803E5, 0x1148678C, 0xB409F5F2, 0x5E273581, 0xFB66A7FF, 0x26217BCD, 0x8360E9B3, 0x694E29C0, 0xCC0FBBBE, 0xB8FFDFD7, 0x1DBE4DA9, 0xF7908DDA, 0x52D11FA4, 0x1E704508, 0xBB31D776, 0x511F1705, 0xF45E857B, 0x80AEE112, 0x25EF736C, 0xCFC1B31F, 0x6A802161, 0x56830647, 0xF3C29439, 0x19EC544A, 0xBCADC634, 0xC85DA25D, 0x6D1C3023, 0x8732F050, 0x2273622E, 0x6ED23882, 0xCB93AAFC, 0x21BD6A8F, 0x84FCF8F1, 0xF00C9C98, 0x554D0EE6, 0xBF63CE95, 0x1A225CEB, 0x8B277743, 0x2E66E53D, 0xC448254E, 0x6109B730, 0x15F9D359, 0xB0B84127, 0x5A968154, 0xFFD7132A, 0xB3764986, 0x1637DBF8, 0xFC191B8B, 0x595889F5, 0x2DA8ED9C, 0x88E97FE2, 0x62C7BF91, 0xC7862DEF, 0xFB850AC9, 0x5EC498B7, 0xB4EA58C4, 0x11ABCABA, 0x655BAED3, 0xC01A3CAD, 0x2A34FCDE, 0x8F756EA0, 0xC3D4340C, 0x6695A672, 0x8CBB6601, 0x29FAF47F, 0x5D0A9016, 0xF84B0268, 0x1265C21B, 0xB7245065, 0x6A638C57,
0xCF221E29, 0x250CDE5A, 0x804D4C24, 0xF4BD284D, 0x51FCBA33, 0xBBD27A40, 0x1E93E83E, 0x5232B292, 0xF77320EC, 0x1D5DE09F, 0xB81C72E1, 0xCCEC1688, 0x69AD84F6, 0x83834485, 0x26C2D6FB, 0x1AC1F1DD, 0xBF8063A3, 0x55AEA3D0, 0xF0EF31AE, 0x841F55C7, 0x215EC7B9, 0xCB7007CA, 0x6E3195B4, 0x2290CF18, 0x87D15D66, 0x6DFF9D15, 0xC8BE0F6B, 0xBC4E6B02, 0x190FF97C, 0xF321390F, 0x5660AB71, 0x4C42F79A, 0xE90365E4, 0x032DA597, 0xA66C37E9, 0xD29C5380, 0x77DDC1FE, 0x9DF3018D, 0x38B293F3, 0x7413C95F, 0xD1525B21, 0x3B7C9B52, 0x9E3D092C, 0xEACD6D45, 0x4F8CFF3B, 0xA5A23F48, 0x00E3AD36, 0x3CE08A10, 0x99A1186E, 0x738FD81D, 0xD6CE4A63, 0xA23E2E0A, 0x077FBC74, 0xED517C07, 0x4810EE79, 0x04B1B4D5, 0xA1F026AB, 0x4BDEE6D8, 0xEE9F74A6, 0x9A6F10CF, 0x3F2E82B1, 0xD50042C2, 0x7041D0BC, 0xAD060C8E, 0x08479EF0, 0xE2695E83, 0x4728CCFD, 0x33D8A894, 0x96993AEA, 0x7CB7FA99, 0xD9F668E7, 0x9557324B, 0x3016A035, 0xDA386046, 0x7F79F238, 0x0B899651, 0xAEC8042F, 0x44E6C45C, 0xE1A75622, 0xDDA47104, 0x78E5E37A, 0x92CB2309, 0x378AB177, 0x437AD51E, 0xE63B4760, 0x0C158713, 0xA954156D, 0xE5F54FC1, 0x40B4DDBF, 0xAA9A1DCC, 0x0FDB8FB2, 0x7B2BEBDB, 0xDE6A79A5, 0x3444B9D6, 0x91052BA8}; /** * The following CRC lookup table was generated automagically * using the following model parameters: * * Generator Polynomial = ................. 0x1EDC6F41 * Generator Polynomial Length = .......... 32 bits * Reflected Bits = ....................... TRUE * Table Generation Offset = .............. 32 bits * Number of Slices = ..................... 8 slices * Slice Lengths = ........................
8 8 8 8 8 8 8 8 */ /* Slice table used by crc32c_slicing8 for bits 0-7 of the second 32-bit word loaded in each 8-byte step (index word & 0xFF). */ static uint32_t crc_tableil8_o56[256] = { 0x00000000, 0xDD45AAB8, 0xBF672381, 0x62228939, 0x7B2231F3, 0xA6679B4B, 0xC4451272, 0x1900B8CA, 0xF64463E6, 0x2B01C95E, 0x49234067, 0x9466EADF, 0x8D665215, 0x5023F8AD, 0x32017194, 0xEF44DB2C, 0xE964B13D, 0x34211B85, 0x560392BC, 0x8B463804, 0x924680CE, 0x4F032A76, 0x2D21A34F, 0xF06409F7, 0x1F20D2DB, 0xC2657863, 0xA047F15A, 0x7D025BE2, 0x6402E328, 0xB9474990, 0xDB65C0A9, 0x06206A11, 0xD725148B, 0x0A60BE33, 0x6842370A, 0xB5079DB2, 0xAC072578, 0x71428FC0, 0x136006F9, 0xCE25AC41, 0x2161776D, 0xFC24DDD5, 0x9E0654EC, 0x4343FE54, 0x5A43469E, 0x8706EC26, 0xE524651F, 0x3861CFA7, 0x3E41A5B6, 0xE3040F0E, 0x81268637, 0x5C632C8F, 0x45639445, 0x98263EFD, 0xFA04B7C4, 0x27411D7C, 0xC805C650, 0x15406CE8, 0x7762E5D1, 0xAA274F69, 0xB327F7A3, 0x6E625D1B, 0x0C40D422, 0xD1057E9A, 0xABA65FE7, 0x76E3F55F, 0x14C17C66, 0xC984D6DE, 0xD0846E14, 0x0DC1C4AC, 0x6FE34D95, 0xB2A6E72D, 0x5DE23C01, 0x80A796B9, 0xE2851F80, 0x3FC0B538, 0x26C00DF2, 0xFB85A74A, 0x99A72E73, 0x44E284CB, 0x42C2EEDA, 0x9F874462, 0xFDA5CD5B, 0x20E067E3, 0x39E0DF29, 0xE4A57591, 0x8687FCA8, 0x5BC25610, 0xB4868D3C, 0x69C32784, 0x0BE1AEBD, 0xD6A40405, 0xCFA4BCCF, 0x12E11677, 0x70C39F4E, 0xAD8635F6, 0x7C834B6C, 0xA1C6E1D4, 0xC3E468ED, 0x1EA1C255, 0x07A17A9F, 0xDAE4D027, 0xB8C6591E, 0x6583F3A6, 0x8AC7288A, 0x57828232, 0x35A00B0B, 0xE8E5A1B3, 0xF1E51979, 0x2CA0B3C1, 0x4E823AF8, 0x93C79040, 0x95E7FA51, 0x48A250E9, 0x2A80D9D0, 0xF7C57368, 0xEEC5CBA2, 0x3380611A, 0x51A2E823, 0x8CE7429B, 0x63A399B7, 0xBEE6330F, 0xDCC4BA36, 0x0181108E, 0x1881A844, 0xC5C402FC, 0xA7E68BC5, 0x7AA3217D, 0x52A0C93F, 0x8FE56387, 0xEDC7EABE, 0x30824006, 0x2982F8CC, 0xF4C75274, 0x96E5DB4D, 0x4BA071F5, 0xA4E4AAD9, 0x79A10061, 0x1B838958, 0xC6C623E0, 0xDFC69B2A, 0x02833192, 0x60A1B8AB, 0xBDE41213, 0xBBC47802, 0x6681D2BA, 0x04A35B83, 0xD9E6F13B, 0xC0E649F1, 0x1DA3E349, 0x7F816A70, 0xA2C4C0C8, 0x4D801BE4, 0x90C5B15C, 0xF2E73865, 0x2FA292DD, 0x36A22A17, 0xEBE780AF, 0x89C50996, 0x5480A32E, 0x8585DDB4,
0x58C0770C, 0x3AE2FE35, 0xE7A7548D, 0xFEA7EC47, 0x23E246FF, 0x41C0CFC6, 0x9C85657E, 0x73C1BE52, 0xAE8414EA, 0xCCA69DD3, 0x11E3376B, 0x08E38FA1, 0xD5A62519, 0xB784AC20, 0x6AC10698, 0x6CE16C89, 0xB1A4C631, 0xD3864F08, 0x0EC3E5B0, 0x17C35D7A, 0xCA86F7C2, 0xA8A47EFB, 0x75E1D443, 0x9AA50F6F, 0x47E0A5D7, 0x25C22CEE, 0xF8878656, 0xE1873E9C, 0x3CC29424, 0x5EE01D1D, 0x83A5B7A5, 0xF90696D8, 0x24433C60, 0x4661B559, 0x9B241FE1, 0x8224A72B, 0x5F610D93, 0x3D4384AA, 0xE0062E12, 0x0F42F53E, 0xD2075F86, 0xB025D6BF, 0x6D607C07, 0x7460C4CD, 0xA9256E75, 0xCB07E74C, 0x16424DF4, 0x106227E5, 0xCD278D5D, 0xAF050464, 0x7240AEDC, 0x6B401616, 0xB605BCAE, 0xD4273597, 0x09629F2F, 0xE6264403, 0x3B63EEBB, 0x59416782, 0x8404CD3A, 0x9D0475F0, 0x4041DF48, 0x22635671, 0xFF26FCC9, 0x2E238253, 0xF36628EB, 0x9144A1D2, 0x4C010B6A, 0x5501B3A0, 0x88441918, 0xEA669021, 0x37233A99, 0xD867E1B5, 0x05224B0D, 0x6700C234, 0xBA45688C, 0xA345D046, 0x7E007AFE, 0x1C22F3C7, 0xC167597F, 0xC747336E, 0x1A0299D6, 0x782010EF, 0xA565BA57, 0xBC65029D, 0x6120A825, 0x0302211C, 0xDE478BA4, 0x31035088, 0xEC46FA30, 0x8E647309, 0x5321D9B1, 0x4A21617B, 0x9764CBC3, 0xF54642FA, 0x2803E842}; /** * The following CRC lookup table was generated automagically * using the following model parameters: * * Generator Polynomial = ................. 0x1EDC6F41 * Generator Polynomial Length = .......... 32 bits * Reflected Bits = ....................... TRUE * Table Generation Offset = .............. 32 bits * Number of Slices = ..................... 8 slices * Slice Lengths = ........................
8 8 8 8 8 8 8 8 */ /* Slice table used by crc32c_slicing8 for bits 24-31 of the first 32-bit word of each 8-byte step, after that word is XORed into the running CRC. */ static uint32_t crc_tableil8_o64[256] = { 0x00000000, 0x38116FAC, 0x7022DF58, 0x4833B0F4, 0xE045BEB0, 0xD854D11C, 0x906761E8, 0xA8760E44, 0xC5670B91, 0xFD76643D, 0xB545D4C9, 0x8D54BB65, 0x2522B521, 0x1D33DA8D, 0x55006A79, 0x6D1105D5, 0x8F2261D3, 0xB7330E7F, 0xFF00BE8B, 0xC711D127, 0x6F67DF63, 0x5776B0CF, 0x1F45003B, 0x27546F97, 0x4A456A42, 0x725405EE, 0x3A67B51A, 0x0276DAB6, 0xAA00D4F2, 0x9211BB5E, 0xDA220BAA, 0xE2336406, 0x1BA8B557, 0x23B9DAFB, 0x6B8A6A0F, 0x539B05A3, 0xFBED0BE7, 0xC3FC644B, 0x8BCFD4BF, 0xB3DEBB13, 0xDECFBEC6, 0xE6DED16A, 0xAEED619E, 0x96FC0E32, 0x3E8A0076, 0x069B6FDA, 0x4EA8DF2E, 0x76B9B082, 0x948AD484, 0xAC9BBB28, 0xE4A80BDC, 0xDCB96470, 0x74CF6A34, 0x4CDE0598, 0x04EDB56C, 0x3CFCDAC0, 0x51EDDF15, 0x69FCB0B9, 0x21CF004D, 0x19DE6FE1, 0xB1A861A5, 0x89B90E09, 0xC18ABEFD, 0xF99BD151, 0x37516AAE, 0x0F400502, 0x4773B5F6, 0x7F62DA5A, 0xD714D41E, 0xEF05BBB2, 0xA7360B46, 0x9F2764EA, 0xF236613F, 0xCA270E93, 0x8214BE67, 0xBA05D1CB, 0x1273DF8F, 0x2A62B023, 0x625100D7, 0x5A406F7B, 0xB8730B7D, 0x806264D1, 0xC851D425, 0xF040BB89, 0x5836B5CD, 0x6027DA61, 0x28146A95, 0x10050539, 0x7D1400EC, 0x45056F40, 0x0D36DFB4, 0x3527B018, 0x9D51BE5C, 0xA540D1F0, 0xED736104, 0xD5620EA8, 0x2CF9DFF9, 0x14E8B055, 0x5CDB00A1, 0x64CA6F0D, 0xCCBC6149, 0xF4AD0EE5, 0xBC9EBE11, 0x848FD1BD, 0xE99ED468, 0xD18FBBC4, 0x99BC0B30, 0xA1AD649C, 0x09DB6AD8, 0x31CA0574, 0x79F9B580, 0x41E8DA2C, 0xA3DBBE2A, 0x9BCAD186, 0xD3F96172, 0xEBE80EDE, 0x439E009A, 0x7B8F6F36, 0x33BCDFC2, 0x0BADB06E, 0x66BCB5BB, 0x5EADDA17, 0x169E6AE3, 0x2E8F054F, 0x86F90B0B, 0xBEE864A7, 0xF6DBD453, 0xCECABBFF, 0x6EA2D55C, 0x56B3BAF0, 0x1E800A04, 0x269165A8, 0x8EE76BEC, 0xB6F60440, 0xFEC5B4B4, 0xC6D4DB18, 0xABC5DECD, 0x93D4B161, 0xDBE70195, 0xE3F66E39, 0x4B80607D, 0x73910FD1, 0x3BA2BF25, 0x03B3D089, 0xE180B48F, 0xD991DB23, 0x91A26BD7, 0xA9B3047B, 0x01C50A3F, 0x39D46593, 0x71E7D567, 0x49F6BACB, 0x24E7BF1E, 0x1CF6D0B2, 0x54C56046, 0x6CD40FEA, 0xC4A201AE, 0xFCB36E02, 0xB480DEF6, 0x8C91B15A, 0x750A600B,
0x4D1B0FA7, 0x0528BF53, 0x3D39D0FF, 0x954FDEBB, 0xAD5EB117, 0xE56D01E3, 0xDD7C6E4F, 0xB06D6B9A, 0x887C0436, 0xC04FB4C2, 0xF85EDB6E, 0x5028D52A, 0x6839BA86, 0x200A0A72, 0x181B65DE, 0xFA2801D8, 0xC2396E74, 0x8A0ADE80, 0xB21BB12C, 0x1A6DBF68, 0x227CD0C4, 0x6A4F6030, 0x525E0F9C, 0x3F4F0A49, 0x075E65E5, 0x4F6DD511, 0x777CBABD, 0xDF0AB4F9, 0xE71BDB55, 0xAF286BA1, 0x9739040D, 0x59F3BFF2, 0x61E2D05E, 0x29D160AA, 0x11C00F06, 0xB9B60142, 0x81A76EEE, 0xC994DE1A, 0xF185B1B6, 0x9C94B463, 0xA485DBCF, 0xECB66B3B, 0xD4A70497, 0x7CD10AD3, 0x44C0657F, 0x0CF3D58B, 0x34E2BA27, 0xD6D1DE21, 0xEEC0B18D, 0xA6F30179, 0x9EE26ED5, 0x36946091, 0x0E850F3D, 0x46B6BFC9, 0x7EA7D065, 0x13B6D5B0, 0x2BA7BA1C, 0x63940AE8, 0x5B856544, 0xF3F36B00, 0xCBE204AC, 0x83D1B458, 0xBBC0DBF4, 0x425B0AA5, 0x7A4A6509, 0x3279D5FD, 0x0A68BA51, 0xA21EB415, 0x9A0FDBB9, 0xD23C6B4D, 0xEA2D04E1, 0x873C0134, 0xBF2D6E98, 0xF71EDE6C, 0xCF0FB1C0, 0x6779BF84, 0x5F68D028, 0x175B60DC, 0x2F4A0F70, 0xCD796B76, 0xF56804DA, 0xBD5BB42E, 0x854ADB82, 0x2D3CD5C6, 0x152DBA6A, 0x5D1E0A9E, 0x650F6532, 0x081E60E7, 0x300F0F4B, 0x783CBFBF, 0x402DD013, 0xE85BDE57, 0xD04AB1FB, 0x9879010F, 0xA0686EA3}; /** * The following CRC lookup table was generated automagically * using the following model parameters: * * Generator Polynomial = ................. 0x1EDC6F41 * Generator Polynomial Length = .......... 32 bits * Reflected Bits = ....................... TRUE * Table Generation Offset = .............. 32 bits * Number of Slices = ..................... 8 slices * Slice Lengths = ........................
8 8 8 8 8 8 8 8 */ /* Slice table used by crc32c_slicing8 for bits 16-23 of the first 32-bit word of each 8-byte step, after that word is XORed into the running CRC. */ static uint32_t crc_tableil8_o72[256] = { 0x00000000, 0xEF306B19, 0xDB8CA0C3, 0x34BCCBDA, 0xB2F53777, 0x5DC55C6E, 0x697997B4, 0x8649FCAD, 0x6006181F, 0x8F367306, 0xBB8AB8DC, 0x54BAD3C5, 0xD2F32F68, 0x3DC34471, 0x097F8FAB, 0xE64FE4B2, 0xC00C303E, 0x2F3C5B27, 0x1B8090FD, 0xF4B0FBE4, 0x72F90749, 0x9DC96C50, 0xA975A78A, 0x4645CC93, 0xA00A2821, 0x4F3A4338, 0x7B8688E2, 0x94B6E3FB, 0x12FF1F56, 0xFDCF744F, 0xC973BF95, 0x2643D48C, 0x85F4168D, 0x6AC47D94, 0x5E78B64E, 0xB148DD57, 0x370121FA, 0xD8314AE3, 0xEC8D8139, 0x03BDEA20, 0xE5F20E92, 0x0AC2658B, 0x3E7EAE51, 0xD14EC548, 0x570739E5, 0xB83752FC, 0x8C8B9926, 0x63BBF23F, 0x45F826B3, 0xAAC84DAA, 0x9E748670, 0x7144ED69, 0xF70D11C4, 0x183D7ADD, 0x2C81B107, 0xC3B1DA1E, 0x25FE3EAC, 0xCACE55B5, 0xFE729E6F, 0x1142F576, 0x970B09DB, 0x783B62C2, 0x4C87A918, 0xA3B7C201, 0x0E045BEB, 0xE13430F2, 0xD588FB28, 0x3AB89031, 0xBCF16C9C, 0x53C10785, 0x677DCC5F, 0x884DA746, 0x6E0243F4, 0x813228ED, 0xB58EE337, 0x5ABE882E, 0xDCF77483, 0x33C71F9A, 0x077BD440, 0xE84BBF59, 0xCE086BD5, 0x213800CC, 0x1584CB16, 0xFAB4A00F, 0x7CFD5CA2, 0x93CD37BB, 0xA771FC61, 0x48419778, 0xAE0E73CA, 0x413E18D3, 0x7582D309, 0x9AB2B810, 0x1CFB44BD, 0xF3CB2FA4, 0xC777E47E, 0x28478F67, 0x8BF04D66, 0x64C0267F, 0x507CEDA5, 0xBF4C86BC, 0x39057A11, 0xD6351108, 0xE289DAD2, 0x0DB9B1CB, 0xEBF65579, 0x04C63E60, 0x307AF5BA, 0xDF4A9EA3, 0x5903620E, 0xB6330917, 0x828FC2CD, 0x6DBFA9D4, 0x4BFC7D58, 0xA4CC1641, 0x9070DD9B, 0x7F40B682, 0xF9094A2F, 0x16392136, 0x2285EAEC, 0xCDB581F5, 0x2BFA6547, 0xC4CA0E5E, 0xF076C584, 0x1F46AE9D, 0x990F5230, 0x763F3929, 0x4283F2F3, 0xADB399EA, 0x1C08B7D6, 0xF338DCCF, 0xC7841715, 0x28B47C0C, 0xAEFD80A1, 0x41CDEBB8, 0x75712062, 0x9A414B7B, 0x7C0EAFC9, 0x933EC4D0, 0xA7820F0A, 0x48B26413, 0xCEFB98BE, 0x21CBF3A7, 0x1577387D, 0xFA475364, 0xDC0487E8, 0x3334ECF1, 0x0788272B, 0xE8B84C32, 0x6EF1B09F, 0x81C1DB86, 0xB57D105C, 0x5A4D7B45, 0xBC029FF7, 0x5332F4EE, 0x678E3F34, 0x88BE542D, 0x0EF7A880, 0xE1C7C399, 0xD57B0843, 0x3A4B635A, 0x99FCA15B,
0x76CCCA42, 0x42700198, 0xAD406A81, 0x2B09962C, 0xC439FD35, 0xF08536EF, 0x1FB55DF6, 0xF9FAB944, 0x16CAD25D, 0x22761987, 0xCD46729E, 0x4B0F8E33, 0xA43FE52A, 0x90832EF0, 0x7FB345E9, 0x59F09165, 0xB6C0FA7C, 0x827C31A6, 0x6D4C5ABF, 0xEB05A612, 0x0435CD0B, 0x308906D1, 0xDFB96DC8, 0x39F6897A, 0xD6C6E263, 0xE27A29B9, 0x0D4A42A0, 0x8B03BE0D, 0x6433D514, 0x508F1ECE, 0xBFBF75D7, 0x120CEC3D, 0xFD3C8724, 0xC9804CFE, 0x26B027E7, 0xA0F9DB4A, 0x4FC9B053, 0x7B757B89, 0x94451090, 0x720AF422, 0x9D3A9F3B, 0xA98654E1, 0x46B63FF8, 0xC0FFC355, 0x2FCFA84C, 0x1B736396, 0xF443088F, 0xD200DC03, 0x3D30B71A, 0x098C7CC0, 0xE6BC17D9, 0x60F5EB74, 0x8FC5806D, 0xBB794BB7, 0x544920AE, 0xB206C41C, 0x5D36AF05, 0x698A64DF, 0x86BA0FC6, 0x00F3F36B, 0xEFC39872, 0xDB7F53A8, 0x344F38B1, 0x97F8FAB0, 0x78C891A9, 0x4C745A73, 0xA344316A, 0x250DCDC7, 0xCA3DA6DE, 0xFE816D04, 0x11B1061D, 0xF7FEE2AF, 0x18CE89B6, 0x2C72426C, 0xC3422975, 0x450BD5D8, 0xAA3BBEC1, 0x9E87751B, 0x71B71E02, 0x57F4CA8E, 0xB8C4A197, 0x8C786A4D, 0x63480154, 0xE501FDF9, 0x0A3196E0, 0x3E8D5D3A, 0xD1BD3623, 0x37F2D291, 0xD8C2B988, 0xEC7E7252, 0x034E194B, 0x8507E5E6, 0x6A378EFF, 0x5E8B4525, 0xB1BB2E3C}; /** * The following CRC lookup table was generated automagically * using the following model parameters: * * Generator Polynomial = ................. 0x1EDC6F41 * Generator Polynomial Length = .......... 32 bits * Reflected Bits = ....................... TRUE * Table Generation Offset = .............. 32 bits * Number of Slices = ..................... 8 slices * Slice Lengths = ........................
8 8 8 8 8 8 8 8 */ /* Slice table used by crc32c_slicing8 for bits 8-15 of the first 32-bit word of each 8-byte step, after that word is XORed into the running CRC. */ static uint32_t crc_tableil8_o80[256] = { 0x00000000, 0x68032CC8, 0xD0065990, 0xB8057558, 0xA5E0C5D1, 0xCDE3E919, 0x75E69C41, 0x1DE5B089, 0x4E2DFD53, 0x262ED19B, 0x9E2BA4C3, 0xF628880B, 0xEBCD3882, 0x83CE144A, 0x3BCB6112, 0x53C84DDA, 0x9C5BFAA6, 0xF458D66E, 0x4C5DA336, 0x245E8FFE, 0x39BB3F77, 0x51B813BF, 0xE9BD66E7, 0x81BE4A2F, 0xD27607F5, 0xBA752B3D, 0x02705E65, 0x6A7372AD, 0x7796C224, 0x1F95EEEC, 0xA7909BB4, 0xCF93B77C, 0x3D5B83BD, 0x5558AF75, 0xED5DDA2D, 0x855EF6E5, 0x98BB466C, 0xF0B86AA4, 0x48BD1FFC, 0x20BE3334, 0x73767EEE, 0x1B755226, 0xA370277E, 0xCB730BB6, 0xD696BB3F, 0xBE9597F7, 0x0690E2AF, 0x6E93CE67, 0xA100791B, 0xC90355D3, 0x7106208B, 0x19050C43, 0x04E0BCCA, 0x6CE39002, 0xD4E6E55A, 0xBCE5C992, 0xEF2D8448, 0x872EA880, 0x3F2BDDD8, 0x5728F110, 0x4ACD4199, 0x22CE6D51, 0x9ACB1809, 0xF2C834C1, 0x7AB7077A, 0x12B42BB2, 0xAAB15EEA, 0xC2B27222, 0xDF57C2AB, 0xB754EE63, 0x0F519B3B, 0x6752B7F3, 0x349AFA29, 0x5C99D6E1, 0xE49CA3B9, 0x8C9F8F71, 0x917A3FF8, 0xF9791330, 0x417C6668, 0x297F4AA0, 0xE6ECFDDC, 0x8EEFD114, 0x36EAA44C, 0x5EE98884, 0x430C380D, 0x2B0F14C5, 0x930A619D, 0xFB094D55, 0xA8C1008F, 0xC0C22C47, 0x78C7591F, 0x10C475D7, 0x0D21C55E, 0x6522E996, 0xDD279CCE, 0xB524B006, 0x47EC84C7, 0x2FEFA80F, 0x97EADD57, 0xFFE9F19F, 0xE20C4116, 0x8A0F6DDE, 0x320A1886, 0x5A09344E, 0x09C17994, 0x61C2555C, 0xD9C72004, 0xB1C40CCC, 0xAC21BC45, 0xC422908D, 0x7C27E5D5, 0x1424C91D, 0xDBB77E61, 0xB3B452A9, 0x0BB127F1, 0x63B20B39, 0x7E57BBB0, 0x16549778, 0xAE51E220, 0xC652CEE8, 0x959A8332, 0xFD99AFFA, 0x459CDAA2, 0x2D9FF66A, 0x307A46E3, 0x58796A2B, 0xE07C1F73, 0x887F33BB, 0xF56E0EF4, 0x9D6D223C, 0x25685764, 0x4D6B7BAC, 0x508ECB25, 0x388DE7ED, 0x808892B5, 0xE88BBE7D, 0xBB43F3A7, 0xD340DF6F, 0x6B45AA37, 0x034686FF, 0x1EA33676, 0x76A01ABE, 0xCEA56FE6, 0xA6A6432E, 0x6935F452, 0x0136D89A, 0xB933ADC2, 0xD130810A, 0xCCD53183, 0xA4D61D4B, 0x1CD36813, 0x74D044DB, 0x27180901, 0x4F1B25C9, 0xF71E5091, 0x9F1D7C59, 0x82F8CCD0, 0xEAFBE018, 0x52FE9540, 0x3AFDB988, 0xC8358D49,
0xA036A181, 0x1833D4D9, 0x7030F811, 0x6DD54898, 0x05D66450, 0xBDD31108, 0xD5D03DC0, 0x8618701A, 0xEE1B5CD2, 0x561E298A, 0x3E1D0542, 0x23F8B5CB, 0x4BFB9903, 0xF3FEEC5B, 0x9BFDC093, 0x546E77EF, 0x3C6D5B27, 0x84682E7F, 0xEC6B02B7, 0xF18EB23E, 0x998D9EF6, 0x2188EBAE, 0x498BC766, 0x1A438ABC, 0x7240A674, 0xCA45D32C, 0xA246FFE4, 0xBFA34F6D, 0xD7A063A5, 0x6FA516FD, 0x07A63A35, 0x8FD9098E, 0xE7DA2546, 0x5FDF501E, 0x37DC7CD6, 0x2A39CC5F, 0x423AE097, 0xFA3F95CF, 0x923CB907, 0xC1F4F4DD, 0xA9F7D815, 0x11F2AD4D, 0x79F18185, 0x6414310C, 0x0C171DC4, 0xB412689C, 0xDC114454, 0x1382F328, 0x7B81DFE0, 0xC384AAB8, 0xAB878670, 0xB66236F9, 0xDE611A31, 0x66646F69, 0x0E6743A1, 0x5DAF0E7B, 0x35AC22B3, 0x8DA957EB, 0xE5AA7B23, 0xF84FCBAA, 0x904CE762, 0x2849923A, 0x404ABEF2, 0xB2828A33, 0xDA81A6FB, 0x6284D3A3, 0x0A87FF6B, 0x17624FE2, 0x7F61632A, 0xC7641672, 0xAF673ABA, 0xFCAF7760, 0x94AC5BA8, 0x2CA92EF0, 0x44AA0238, 0x594FB2B1, 0x314C9E79, 0x8949EB21, 0xE14AC7E9, 0x2ED97095, 0x46DA5C5D, 0xFEDF2905, 0x96DC05CD, 0x8B39B544, 0xE33A998C, 0x5B3FECD4, 0x333CC01C, 0x60F48DC6, 0x08F7A10E, 0xB0F2D456, 0xD8F1F89E, 0xC5144817, 0xAD1764DF, 0x15121187, 0x7D113D4F}; /** * The following CRC lookup table was generated automagically * using the following model parameters: * * Generator Polynomial = ................. 0x1EDC6F41 * Generator Polynomial Length = .......... 32 bits * Reflected Bits = ....................... TRUE * Table Generation Offset = .............. 32 bits * Number of Slices = ..................... 8 slices * Slice Lengths = ........................
8 8 8 8 8 8 8 8 */ /* Slice table used by crc32c_slicing8 for bits 0-7 of the first 32-bit word of each 8-byte step, after that word is XORed into the running CRC. */ static uint32_t crc_tableil8_o88[256] = { 0x00000000, 0x493C7D27, 0x9278FA4E, 0xDB448769, 0x211D826D, 0x6821FF4A, 0xB3657823, 0xFA590504, 0x423B04DA, 0x0B0779FD, 0xD043FE94, 0x997F83B3, 0x632686B7, 0x2A1AFB90, 0xF15E7CF9, 0xB86201DE, 0x847609B4, 0xCD4A7493, 0x160EF3FA, 0x5F328EDD, 0xA56B8BD9, 0xEC57F6FE, 0x37137197, 0x7E2F0CB0, 0xC64D0D6E, 0x8F717049, 0x5435F720, 0x1D098A07, 0xE7508F03, 0xAE6CF224, 0x7528754D, 0x3C14086A, 0x0D006599, 0x443C18BE, 0x9F789FD7, 0xD644E2F0, 0x2C1DE7F4, 0x65219AD3, 0xBE651DBA, 0xF759609D, 0x4F3B6143, 0x06071C64, 0xDD439B0D, 0x947FE62A, 0x6E26E32E, 0x271A9E09, 0xFC5E1960, 0xB5626447, 0x89766C2D, 0xC04A110A, 0x1B0E9663, 0x5232EB44, 0xA86BEE40, 0xE1579367, 0x3A13140E, 0x732F6929, 0xCB4D68F7, 0x827115D0, 0x593592B9, 0x1009EF9E, 0xEA50EA9A, 0xA36C97BD, 0x782810D4, 0x31146DF3, 0x1A00CB32, 0x533CB615, 0x8878317C, 0xC1444C5B, 0x3B1D495F, 0x72213478, 0xA965B311, 0xE059CE36, 0x583BCFE8, 0x1107B2CF, 0xCA4335A6, 0x837F4881, 0x79264D85, 0x301A30A2, 0xEB5EB7CB, 0xA262CAEC, 0x9E76C286, 0xD74ABFA1, 0x0C0E38C8, 0x453245EF, 0xBF6B40EB, 0xF6573DCC, 0x2D13BAA5, 0x642FC782, 0xDC4DC65C, 0x9571BB7B, 0x4E353C12, 0x07094135, 0xFD504431, 0xB46C3916, 0x6F28BE7F, 0x2614C358, 0x1700AEAB, 0x5E3CD38C, 0x857854E5, 0xCC4429C2, 0x361D2CC6, 0x7F2151E1, 0xA465D688, 0xED59ABAF, 0x553BAA71, 0x1C07D756, 0xC743503F, 0x8E7F2D18, 0x7426281C, 0x3D1A553B, 0xE65ED252, 0xAF62AF75, 0x9376A71F, 0xDA4ADA38, 0x010E5D51, 0x48322076, 0xB26B2572, 0xFB575855, 0x2013DF3C, 0x692FA21B, 0xD14DA3C5, 0x9871DEE2, 0x4335598B, 0x0A0924AC, 0xF05021A8, 0xB96C5C8F, 0x6228DBE6, 0x2B14A6C1, 0x34019664, 0x7D3DEB43, 0xA6796C2A, 0xEF45110D, 0x151C1409, 0x5C20692E, 0x8764EE47, 0xCE589360, 0x763A92BE, 0x3F06EF99, 0xE44268F0, 0xAD7E15D7, 0x572710D3, 0x1E1B6DF4, 0xC55FEA9D, 0x8C6397BA, 0xB0779FD0, 0xF94BE2F7, 0x220F659E, 0x6B3318B9, 0x916A1DBD, 0xD856609A, 0x0312E7F3, 0x4A2E9AD4, 0xF24C9B0A, 0xBB70E62D, 0x60346144, 0x29081C63, 0xD3511967, 0x9A6D6440, 0x4129E329, 0x08159E0E, 0x3901F3FD,
0x703D8EDA, 0xAB7909B3, 0xE2457494, 0x181C7190, 0x51200CB7, 0x8A648BDE, 0xC358F6F9, 0x7B3AF727, 0x32068A00, 0xE9420D69, 0xA07E704E, 0x5A27754A, 0x131B086D, 0xC85F8F04, 0x8163F223, 0xBD77FA49, 0xF44B876E, 0x2F0F0007, 0x66337D20, 0x9C6A7824, 0xD5560503, 0x0E12826A, 0x472EFF4D, 0xFF4CFE93, 0xB67083B4, 0x6D3404DD, 0x240879FA, 0xDE517CFE, 0x976D01D9, 0x4C2986B0, 0x0515FB97, 0x2E015D56, 0x673D2071, 0xBC79A718, 0xF545DA3F, 0x0F1CDF3B, 0x4620A21C, 0x9D642575, 0xD4585852, 0x6C3A598C, 0x250624AB, 0xFE42A3C2, 0xB77EDEE5, 0x4D27DBE1, 0x041BA6C6, 0xDF5F21AF, 0x96635C88, 0xAA7754E2, 0xE34B29C5, 0x380FAEAC, 0x7133D38B, 0x8B6AD68F, 0xC256ABA8, 0x19122CC1, 0x502E51E6, 0xE84C5038, 0xA1702D1F, 0x7A34AA76, 0x3308D751, 0xC951D255, 0x806DAF72, 0x5B29281B, 0x1215553C, 0x230138CF, 0x6A3D45E8, 0xB179C281, 0xF845BFA6, 0x021CBAA2, 0x4B20C785, 0x906440EC, 0xD9583DCB, 0x613A3C15, 0x28064132, 0xF342C65B, 0xBA7EBB7C, 0x4027BE78, 0x091BC35F, 0xD25F4436, 0x9B633911, 0xA777317B, 0xEE4B4C5C, 0x350FCB35, 0x7C33B612, 0x866AB316, 0xCF56CE31, 0x14124958, 0x5D2E347F, 0xE54C35A1, 0xAC704886, 0x7734CFEF, 0x3E08B2C8, 0xC451B7CC, 0x8D6DCAEB, 0x56294D82, 0x1F1530A5}; /** * Implementations adapted from Intel's Slicing By 8 Sourceforge Project * http://sourceforge.net/projects/slicing-by-8/ * http://www.evanjones.ca/crc32c.html */ /* Portable table-driven CRC32C, compiled only when neither __SSE4_2__ nor __ARM_FEATURE_CRC32 is defined. It updates and returns `crc` over `len` bytes at `data`; no ~crc pre/post-conditioning is applied in this function. Bytes are consumed singly until p_buf is 4-byte aligned, then 8 at a time via the eight slice tables, then the remaining 0-7 bytes singly. NOTE(review): the 32-bit word loads assume little-endian byte order, as in Intel's reference code; confirm no big-endian target reaches this path. */ static inline uint32_t crc32c_slicing8(const void *data, size_t len, uint32_t crc) { const uint8_t *p_buf = (const uint8_t *)data; /* Handle leading misaligned bytes */ size_t init_bytes = (sizeof(int32_t) - (intptr_t)p_buf) & (sizeof(int32_t) - 1); if (len < init_bytes) { init_bytes = len; } for (size_t li = 0; li < init_bytes; li++) { crc = crc_tableil8_o32[(crc ^ *p_buf++) & 0x000000FF] ^ (crc >> 8); } len -= init_bytes; /* Bulk phase: running_length is the largest multiple of 8 not exceeding the remaining length; each iteration folds two 32-bit words through the eight slice tables. end_bytes (0-7) are handled afterwards. */ size_t running_length = len & ~(sizeof(uint64_t) - 1); size_t end_bytes = len - running_length; for (size_t li = 0; li < running_length / 8; li++) { uint32_t term1, term2; crc ^= *(uint32_t *)p_buf; p_buf += 4; term1 = crc_tableil8_o88[crc & 0x000000FF] ^
crc_tableil8_o80[(crc >> 8) & 0x000000FF]; term2 = crc >> 16; crc = term1 ^ crc_tableil8_o72[term2 & 0x000000FF] ^ crc_tableil8_o64[(term2 >> 8) & 0x000000FF]; term1 = crc_tableil8_o56[(*(uint32_t *)p_buf) & 0x000000FF] ^ crc_tableil8_o48[((*(uint32_t *)p_buf) >> 8) & 0x000000FF]; term2 = (*(uint32_t *)p_buf) >> 16; crc = crc ^ term1 ^ crc_tableil8_o40[term2 & 0x000000FF] ^ crc_tableil8_o32[(term2 >> 8) & 0x000000FF]; p_buf += 4; } for (size_t li = 0; li < end_bytes; li++) { crc = crc_tableil8_o32[(crc ^ *p_buf++) & 0x000000FF] ^ (crc >> 8); } return crc; } #endif // !__SSE4_2__ && !__ARM_FEATURE_CRC32 #if defined(__SSE4_2__) #if defined(AILEGO_M64) /* SSE4.2 hardware CRC32C for 64-bit builds: one crc32 instruction per 8 bytes, then a 1-7 byte tail split into 4/2/1-byte steps. NOTE(review): *(uint64_t *)first and the narrower casts are unaligned, type-punned loads; x86 tolerates them, but memcpy into a local is the standard-conforming form. */ static inline uint32_t crc32c_sse42(const void *data, size_t len, uint32_t crc) { const uint8_t *first = (const uint8_t *)data; const uint8_t *last = first + ((len >> 3) << 3); for (; first != last; first += 8) { crc = (uint32_t)_mm_crc32_u64(crc, *(uint64_t *)first); } switch (((uint8_t *)data + len) - last) { case 1: crc = _mm_crc32_u8(crc, *last); break; case 2: crc = _mm_crc32_u16(crc, *(uint16_t *)last); break; case 3: crc = _mm_crc32_u16(crc, *(uint16_t *)last); crc = _mm_crc32_u8(crc, *(last + 2)); break; case 4: crc = _mm_crc32_u32(crc, *(uint32_t *)last); break; case 5: crc = _mm_crc32_u32(crc, *(uint32_t *)last); crc = _mm_crc32_u8(crc, *(last + 4)); break; case 6: crc = _mm_crc32_u32(crc, *(uint32_t *)last); crc = _mm_crc32_u16(crc, *(uint16_t *)(last + 4)); break; case 7: crc = _mm_crc32_u32(crc, *(uint32_t *)last); crc = _mm_crc32_u16(crc, *(uint16_t *)(last + 4)); crc = _mm_crc32_u8(crc, *(last + 6)); break; } return crc; } #else /* SSE4.2 hardware CRC32C for 32-bit builds: one crc32 instruction per 4 bytes, then a 1-3 byte tail. */ static inline uint32_t crc32c_sse42(const void *data, size_t len, uint32_t crc) { const uint8_t *first = (const uint8_t *)data; const uint8_t *last = first + ((len >> 2) << 2); for (; first != last; first += 4) { crc = _mm_crc32_u32(crc, *(uint32_t *)first); } switch (((uint8_t *)data + len) - last) { case 1: crc = _mm_crc32_u8(crc, *last); break; case 2: crc = _mm_crc32_u16(crc, *(uint16_t *)last);
break; case 3: crc = _mm_crc32_u16(crc, *(uint16_t *)last); crc = _mm_crc32_u8(crc, *(last + 2)); break; } return crc; } #endif // AILEGO_M64 #endif // __SSE4_2__ #if defined(__ARM_FEATURE_CRC32) /* ARMv8 CRC32 extension path: same 8-byte stride and 1-7 byte tail layout as the 64-bit SSE4.2 version, using the __crc32c* (Castagnoli) intrinsics. The same unaligned, type-punned load caveat applies. */ static inline uint32_t crc32c_neon(const void *data, size_t len, uint32_t crc) { const uint8_t *first = (const uint8_t *)data; const uint8_t *last = first + ((len >> 3) << 3); for (; first != last; first += 8) { crc = __crc32cd(crc, *(uint64_t *)first); } switch (((uint8_t *)data + len) - last) { case 1: crc = __crc32cb(crc, *last); break; case 2: crc = __crc32ch(crc, *(uint16_t *)last); break; case 3: crc = __crc32ch(crc, *(uint16_t *)last); crc = __crc32cb(crc, *(last + 2)); break; case 4: crc = __crc32cw(crc, *(uint32_t *)last); break; case 5: crc = __crc32cw(crc, *(uint32_t *)last); crc = __crc32cb(crc, *(last + 4)); break; case 6: crc = __crc32cw(crc, *(uint32_t *)last); crc = __crc32ch(crc, *(uint16_t *)(last + 4)); break; case 7: crc = __crc32cw(crc, *(uint32_t *)last); crc = __crc32ch(crc, *(uint16_t *)(last + 4)); crc = __crc32cb(crc, *(last + 6)); break; } return crc; } #endif // __ARM_FEATURE_CRC32 namespace zvec { namespace ailego { /* Computes CRC32C over len bytes at data, continuing from crc. The implementation is chosen at compile time: SSE4.2 first, then the ARMv8 CRC32 extension, then the table-driven fallback. */ uint32_t Crc32c::Hash(const void *data, size_t len, uint32_t crc) { #if defined(__SSE4_2__) return crc32c_sse42(data, len, crc); #elif defined(__ARM_FEATURE_CRC32) return crc32c_neon(data, len, crc); #else return crc32c_slicing8(data, len, crc); #endif } } // namespace ailego } // namespace zvec ================================================ FILE: src/ailego/internal/cpu_features.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "cpu_features.h" #include #if !defined(_MSC_VER) && !defined(__ARM_ARCH) #include #endif namespace zvec { namespace ailego { namespace internal { // // REFER: https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/ // tree/arch/x86/include/asm/cpufeatures.h // https://software.intel.com/sites/default/files/managed/c5/15/ // architecture-instruction-set-extensions-programming-reference.pdf // CpuFeatures::CpuFlags CpuFeatures::flags_; #if defined(_MSC_VER) CpuFeatures::CpuFlags::CpuFlags(void) : L1_ECX(0), L1_EDX(0), L7_EBX(0), L7_ECX(0), L7_EDX(0) { int l1[4] = {0, 0, 0, 0}; int l7[4] = {0, 0, 0, 0}; __cpuidex(l1, 1, 0); __cpuidex(l7, 7, 0); L1_ECX = l1[2]; L1_EDX = l1[3]; L7_EBX = l7[1]; L7_ECX = l7[2]; L7_EDX = l7[3]; } #elif !defined(__ARM_ARCH) CpuFeatures::CpuFlags::CpuFlags(void) : L1_ECX(0), L1_EDX(0), L7_EBX(0), L7_ECX(0), L7_EDX(0) { uint32_t eax, ebx, ecx, edx; if (__get_cpuid(1, &eax, &ebx, &ecx, &edx)) { L1_ECX = ecx; L1_EDX = edx; } if (__get_cpuid_max(0, NULL) >= 7) { __cpuid_count(7, 0, eax, ebx, ecx, edx); L7_EBX = ebx; L7_ECX = ecx; L7_EDX = edx; } } #else CpuFeatures::CpuFlags::CpuFlags(void) : L1_ECX(0), L1_EDX(0), L7_EBX(0), L7_ECX(0), L7_EDX(0) {} #endif //! 16-bit FP conversions bool CpuFeatures::F16C(void) { return !!(flags_.L1_ECX & (1u << 29)); } //! Multimedia Extensions bool CpuFeatures::MMX(void) { return !!(flags_.L1_EDX & (1u << 23)); } //! Streaming SIMD Extensions bool CpuFeatures::SSE(void) { return !!(flags_.L1_EDX & (1u << 25)); } //! 
Streaming SIMD Extensions 2 bool CpuFeatures::SSE2(void) { return !!(flags_.L1_EDX & (1u << 26)); } //! Streaming SIMD Extensions 3 bool CpuFeatures::SSE3(void) { return !!(flags_.L1_ECX & (1u << 0)); } //! Supplemental Streaming SIMD Extensions 3 bool CpuFeatures::SSSE3(void) { return !!(flags_.L1_ECX & (1u << 9)); } //! Streaming SIMD Extensions 4.1 bool CpuFeatures::SSE4_1(void) { return !!(flags_.L1_ECX & (1u << 19)); } //! Streaming SIMD Extensions 4.2 bool CpuFeatures::SSE4_2(void) { return !!(flags_.L1_ECX & (1u << 20)); } //! Advanced Vector Extensions bool CpuFeatures::AVX(void) { return !!(flags_.L1_ECX & (1u << 28)); } //! Advanced Vector Extensions 2 bool CpuFeatures::AVX2(void) { return !!(flags_.L7_EBX & (1u << 5)); } //! AVX-512 Foundation bool CpuFeatures::AVX512F(void) { return !!(flags_.L7_EBX & (1u << 16)); } //! AVX-512 DQ (Double/Quad granular) Instructions bool CpuFeatures::AVX512DQ(void) { return !!(flags_.L7_EBX & (1u << 17)); } //! AVX-512 Prefetch bool CpuFeatures::AVX512PF(void) { return !!(flags_.L7_EBX & (1u << 26)); } //! AVX-512 Exponential and Reciprocal bool CpuFeatures::AVX512ER(void) { return !!(flags_.L7_EBX & (1u << 27)); } //! AVX-512 Conflict Detection bool CpuFeatures::AVX512CD(void) { return !!(flags_.L7_EBX & (1u << 28)); } //! AVX-512 BW (Byte/Word granular) Instructions bool CpuFeatures::AVX512BW(void) { return !!(flags_.L7_EBX & (1u << 30)); } //! AVX-512 VL (128/256 Vector Length) Extensions bool CpuFeatures::AVX512VL(void) { return !!(flags_.L7_EBX & (1u << 31)); } //! AVX-512 Integer Fused Multiply-Add instructions bool CpuFeatures::AVX512_IFMA(void) { return !!(flags_.L7_EBX & (1u << 21)); } //! AVX512 Vector Bit Manipulation instructions bool CpuFeatures::AVX512_VBMI(void) { return !!(flags_.L7_ECX & (1u << 1)); } //! Additional AVX512 Vector Bit Manipulation Instructions bool CpuFeatures::AVX512_VBMI2(void) { return !!(flags_.L7_ECX & (1u << 6)); } //! 
Vector Neural Network Instructions bool CpuFeatures::AVX512_VNNI(void) { return !!(flags_.L7_ECX & (1u << 11)); } //! Support for VPOPCNT[B,W] and VPSHUF-BITQMB instructions bool CpuFeatures::AVX512_BITALG(void) { return !!(flags_.L7_ECX & (1u << 12)); } //! POPCNT for vectors of DW/QW bool CpuFeatures::AVX512_VPOPCNTDQ(void) { return !!(flags_.L7_ECX & (1u << 14)); } //! AVX-512 Neural Network Instructions bool CpuFeatures::AVX512_4VNNIW(void) { return !!(flags_.L7_EDX & (1u << 2)); } //! AVX-512 Multiply Accumulation Single precision bool CpuFeatures::AVX512_4FMAPS(void) { return !!(flags_.L7_EDX & (1u << 3)); } //! AVX-512 FP16 instructions bool CpuFeatures::AVX512_FP16(void) { return !!(flags_.L7_EDX & (1u << 23)); } //! CMPXCHG8 instruction bool CpuFeatures::CX8(void) { return !!(flags_.L1_EDX & (1u << 8)); } //! CMPXCHG16B instruction bool CpuFeatures::CX16(void) { return !!(flags_.L1_ECX & (1u << 13)); } //! PCLMULQDQ instruction bool CpuFeatures::PCLMULQDQ(void) { return !!(flags_.L1_ECX & (1u << 1)); } //! Carry-Less Multiplication Double Quadword bool CpuFeatures::VPCLMULQDQ(void) { return !!(flags_.L7_ECX & (1u << 10)); } //! CMOV instructions (plus FCMOVcc, FCOMI with FPU) bool CpuFeatures::CMOV(void) { return !!(flags_.L1_EDX & (1u << 15)); } //! MOVBE instruction bool CpuFeatures::MOVBE(void) { return !!(flags_.L1_ECX & (1u << 22)); } //! Enhanced REP MOVSB/STOSB instructions bool CpuFeatures::ERMS(void) { return !!(flags_.L7_EBX & (1u << 9)); } //! POPCNT instruction bool CpuFeatures::POPCNT(void) { return !!(flags_.L1_ECX & (1u << 23)); } //! XSAVE/XRSTOR/XSETBV/XGETBV instructions bool CpuFeatures::XSAVE(void) { return !!(flags_.L1_ECX & (1u << 26)); } //! Fused multiply-add bool CpuFeatures::FMA(void) { return !!(flags_.L1_ECX & (1u << 12)); } //! ADCX and ADOX instructions bool CpuFeatures::ADX(void) { return !!(flags_.L7_EBX & (1u << 19)); } //! 
Galois Field New Instructions bool CpuFeatures::GFNI(void) { return !!(flags_.L7_ECX & (1u << 8)); } //! AES instructions bool CpuFeatures::AES(void) { return !!(flags_.L1_ECX & (1u << 25)); } //! Vector AES bool CpuFeatures::VAES(void) { return !!(flags_.L7_ECX & (1u << 9)); } //! RDSEED instruction bool CpuFeatures::RDSEED(void) { return !!(flags_.L7_EBX & (1u << 18)); } //! RDRAND instruction bool CpuFeatures::RDRAND(void) { return !!(flags_.L1_ECX & (1u << 30)); } //! SHA1/SHA256 Instruction Extensions bool CpuFeatures::SHA(void) { return !!(flags_.L7_EBX & (1u << 29)); } //! 1st group bit manipulation extensions bool CpuFeatures::BMI1(void) { return !!(flags_.L7_EBX & (1u << 3)); } //! 2nd group bit manipulation extensions bool CpuFeatures::BMI2(void) { return !!(flags_.L7_EBX & (1u << 8)); } //! CLFLUSH instruction bool CpuFeatures::CLFLUSH(void) { return !!(flags_.L1_EDX & (1u << 19)); } //! CLFLUSHOPT instruction bool CpuFeatures::CLFLUSHOPT(void) { return !!(flags_.L7_EBX & (1u << 23)); } //! CLWB instruction bool CpuFeatures::CLWB(void) { return !!(flags_.L7_EBX & (1u << 24)); } //! RDPID instruction bool CpuFeatures::RDPID(void) { return !!(flags_.L7_ECX & (1u << 22)); } //! Onboard FPU bool CpuFeatures::FPU(void) { return !!(flags_.L1_EDX & (1u << 0)); } //! Hyper-Threading bool CpuFeatures::HT(void) { return !!(flags_.L1_EDX & (1u << 28)); } //! 
Hardware virtualization bool CpuFeatures::VMX(void) { return !!(flags_.L1_ECX & (1u << 5)); } // !Running on a hypervisor bool CpuFeatures::HYPERVISOR(void) { return !!(flags_.L1_ECX & (1u << 31)); } const char *CpuFeatures::Intrinsics(void) { return "" #if defined(__ARM_NEON) "Neon" #if defined(__ARM_FEATURE_CRC32) "+CRC" #endif #if defined(__ARM_FEATURE_FP16_SCALAR_ARITHMETIC) || \ defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) "+FP16" #endif #elif defined(__AVX512F__) "AVX512F" #if defined(__AVX512VL__) "+AVX512VL" #endif #if defined(__AVX512BW__) "+AVX512BW" #endif #if defined(__AVX512DQ__) "+AVX512DQ" #endif #if defined(__AVX512CD__) "+AVX512CD" #endif #if defined(__AVX512ER__) "+AVX512ER" #endif #if defined(__AVX512PF__) "+AVX512PF" #endif #if defined(__AVX512IFMA__) "+AVX512IFMA" #endif #if defined(__AVX512VBMI__) "+AVX512VBMI" #endif #if defined(__AVX512VBMI2__) "+AVX512VBMI2" #endif #if defined(__AVX512VNNI__) "+AVX512VNNI" #endif #if defined(__AVX512BITALG__) "+AVX512BITALG" #endif #if defined(__AVX512VPOPCNTDQ__) "+AVX512VPOPCNTDQ" #endif #if defined(__AVX512FP16__) "+AVX512FP16" #endif #elif defined(__AVX2__) "AVX2" #elif defined(__AVX__) "AVX" #elif defined(__SSE4_2__) "SSE4.2" #elif defined(__SSE4_1__) "SSE4.1" #elif defined(__SSSE3__) "SSSE3" #elif defined(__SSE3__) "SSE3" #elif defined(__SSE2__) "SSE2" #elif defined(__SSE__) "SSE" #elif defined(__MMX__) "MMX" #endif #if defined(__FMA__) "+FMA" #endif #if defined(__BMI2__) "+BMI2" #elif defined(__BMI__) "+BMI" #endif #if defined(__F16C__) "+F16C" #endif ; } CpuFeatures::StaticFlags CpuFeatures::static_flags_; } // namespace internal } // namespace ailego } // namespace zvec ================================================ FILE: src/ailego/internal/cpu_features.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include namespace zvec { namespace ailego { namespace internal { /*! Cpu Features */ class CpuFeatures { public: //! 16-bit FP conversions static bool F16C(void); //! Multimedia Extensions static bool MMX(void); //! Streaming SIMD Extensions static bool SSE(void); //! Streaming SIMD Extensions 2 static bool SSE2(void); //! Streaming SIMD Extensions 3 static bool SSE3(void); //! Supplemental Streaming SIMD Extensions 3 static bool SSSE3(void); //! Streaming SIMD Extensions 4.1 static bool SSE4_1(void); //! Streaming SIMD Extensions 4.2 static bool SSE4_2(void); //! Advanced Vector Extensions static bool AVX(void); //! Advanced Vector Extensions 2 static bool AVX2(void); //! AVX-512 Foundation static bool AVX512F(void); //! AVX-512 DQ (Double/Quad granular) Instructions static bool AVX512DQ(void); //! AVX-512 Prefetch static bool AVX512PF(void); //! AVX-512 Exponential and Reciprocal static bool AVX512ER(void); //! AVX-512 Conflict Detection static bool AVX512CD(void); //! AVX-512 BW (Byte/Word granular) Instructions static bool AVX512BW(void); //! AVX-512 VL (128/256 Vector Length) Extensions static bool AVX512VL(void); //! AVX-512 Integer Fused Multiply-Add instructions static bool AVX512_IFMA(void); //! AVX512 Vector Bit Manipulation instructions static bool AVX512_VBMI(void); //! Additional AVX512 Vector Bit Manipulation Instructions static bool AVX512_VBMI2(void); //! Vector Neural Network Instructions static bool AVX512_VNNI(void); //! Support for VPOPCNT[B,W] and VPSHUF-BITQMB instructions static bool AVX512_BITALG(void); //! 
POPCNT for vectors of DW/QW static bool AVX512_VPOPCNTDQ(void); //! AVX-512 Neural Network Instructions static bool AVX512_4VNNIW(void); //! AVX-512 Multiply Accumulation Single precision static bool AVX512_4FMAPS(void); //! AVX-512 FP16 instructions static bool AVX512_FP16(void); //! CMPXCHG8 instruction static bool CX8(void); //! CMPXCHG16B instruction static bool CX16(void); //! PCLMULQDQ instruction static bool PCLMULQDQ(void); //! Carry-Less Multiplication Double Quadword static bool VPCLMULQDQ(void); //! CMOV instructions (plus FCMOVcc, FCOMI with FPU) static bool CMOV(void); //! MOVBE instruction static bool MOVBE(void); //! Enhanced REP MOVSB/STOSB instructions static bool ERMS(void); //! POPCNT instruction static bool POPCNT(void); //! XSAVE/XRSTOR/XSETBV/XGETBV instructions static bool XSAVE(void); //! Fused multiply-add static bool FMA(void); //! ADCX and ADOX instructions static bool ADX(void); //! Galois Field New Instructions static bool GFNI(void); //! AES instructions static bool AES(void); //! Vector AES static bool VAES(void); //! RDSEED instruction static bool RDSEED(void); //! RDRAND instruction static bool RDRAND(void); //! SHA1/SHA256 Instruction Extensions static bool SHA(void); //! 1st group bit manipulation extensions static bool BMI1(void); //! 2nd group bit manipulation extensions static bool BMI2(void); //! CLFLUSH instruction static bool CLFLUSH(void); //! CLFLUSHOPT instruction static bool CLFLUSHOPT(void); //! CLWB instruction static bool CLWB(void); //! RDPID instruction static bool RDPID(void); //! Onboard FPU static bool FPU(void); //! Hyper-Threading static bool HT(void); //! Hardware virtualization static bool VMX(void); // !Running on a hypervisor static bool HYPERVISOR(void); //! Intrinsics of compiling static const char *Intrinsics(void); private: struct CpuFlags { //! Constructor CpuFlags(void); //! Members uint32_t L1_ECX; uint32_t L1_EDX; uint32_t L7_EBX; uint32_t L7_ECX; uint32_t L7_EDX; }; //! 
Static Members static CpuFlags flags_; public: struct StaticFlags { //! 16-bit FP conversions bool F16C = CpuFeatures::F16C(); //! Multimedia Extensions bool MMX = CpuFeatures::MMX(); //! Streaming SIMD Extensions bool SSE = CpuFeatures::SSE(); //! Streaming SIMD Extensions 2 bool SSE2 = CpuFeatures::SSE2(); //! Streaming SIMD Extensions 3 bool SSE3 = CpuFeatures::SSE3(); //! Supplemental Streaming SIMD Extensions 3 bool SSSE3 = CpuFeatures::SSSE3(); //! Streaming SIMD Extensions 4.1 bool SSE4_1 = CpuFeatures::SSE4_1(); //! Streaming SIMD Extensions 4.2 bool SSE4_2 = CpuFeatures::SSE4_2(); //! Advanced Vector Extensions bool AVX = CpuFeatures::AVX(); //! Advanced Vector Extensions 2 bool AVX2 = CpuFeatures::AVX2(); //! AVX-512 Foundation bool AVX512F = CpuFeatures::AVX512F(); //! AVX-512 DQ (Double/Quad granular) Instructions bool AVX512DQ = CpuFeatures::AVX512DQ(); //! AVX-512 Prefetch bool AVX512PF = CpuFeatures::AVX512PF(); //! AVX-512 Exponential and Reciprocal bool AVX512ER = CpuFeatures::AVX512ER(); //! AVX-512 Conflict Detection bool AVX512CD = CpuFeatures::AVX512CD(); //! AVX-512 BW (Byte/Word granular) Instructions bool AVX512BW = CpuFeatures::AVX512BW(); //! AVX-512 VL (128/256 Vector Length) Extensions bool AVX512VL = CpuFeatures::AVX512VL(); //! AVX-512 Integer Fused Multiply-Add instructions bool AVX512_IFMA = CpuFeatures::AVX512_IFMA(); //! AVX512 Vector Bit Manipulation instructions bool AVX512_VBMI = CpuFeatures::AVX512_VBMI(); //! Additional AVX512 Vector Bit Manipulation Instructions bool AVX512_VBMI2 = CpuFeatures::AVX512_VBMI2(); //! Vector Neural Network Instructions bool AVX512_VNNI = CpuFeatures::AVX512_VNNI(); //! Support for VPOPCNT[B,W] and VPSHUF-BITQMB instructions bool AVX512_BITALG = CpuFeatures::AVX512_BITALG(); //! POPCNT for vectors of DW/QW bool AVX512_VPOPCNTDQ = CpuFeatures::AVX512_VPOPCNTDQ(); //! AVX-512 Neural Network Instructions bool AVX512_4VNNIW = CpuFeatures::AVX512_4VNNIW(); //! 
AVX-512 Multiply Accumulation Single precision bool AVX512_4FMAPS = CpuFeatures::AVX512_4FMAPS(); //! AVX-512 FP16 instructions bool AVX512_FP16 = CpuFeatures::AVX512_FP16(); //! CMPXCHG8 instruction bool CX8 = CpuFeatures::CX8(); //! CMPXCHG16B instruction bool CX16 = CpuFeatures::CX16(); //! PCLMULQDQ instruction bool PCLMULQDQ = CpuFeatures::PCLMULQDQ(); //! Carry-Less Multiplication Double Quadword bool VPCLMULQDQ = CpuFeatures::VPCLMULQDQ(); //! CMOV instructions (plus FCMOVcc, FCOMI with FPU) bool CMOV = CpuFeatures::CMOV(); //! MOVBE instruction bool MOVBE = CpuFeatures::MOVBE(); //! Enhanced REP MOVSB/STOSB instructions bool ERMS = CpuFeatures::ERMS(); //! POPCNT instruction bool POPCNT = CpuFeatures::POPCNT(); //! XSAVE/XRSTOR/XSETBV/XGETBV instructions bool XSAVE = CpuFeatures::XSAVE(); //! Fused multiply-add bool FMA = CpuFeatures::FMA(); //! ADCX and ADOX instructions bool ADX = CpuFeatures::ADX(); //! Galois Field New Instructions bool GFNI = CpuFeatures::GFNI(); //! AES instructions bool AES = CpuFeatures::AES(); //! Vector AES bool VAES = CpuFeatures::VAES(); //! RDSEED instruction bool RDSEED = CpuFeatures::RDSEED(); //! RDRAND instruction bool RDRAND = CpuFeatures::RDRAND(); //! SHA1/SHA256 Instruction Extensions bool SHA = CpuFeatures::SHA(); //! 1st group bit manipulation extensions bool BMI1 = CpuFeatures::BMI1(); //! 2nd group bit manipulation extensions bool BMI2 = CpuFeatures::BMI2(); //! CLFLUSH instruction bool CLFLUSH = CpuFeatures::CLFLUSH(); //! CLFLUSHOPT instruction bool CLFLUSHOPT = CpuFeatures::CLFLUSHOPT(); //! CLWB instruction bool CLWB = CpuFeatures::CLWB(); //! RDPID instruction bool RDPID = CpuFeatures::RDPID(); //! Onboard FPU bool FPU = CpuFeatures::FPU(); //! Hyper-Threading bool HT = CpuFeatures::HT(); //! 
Hardware virtualization bool VMX = CpuFeatures::VMX(); // !Running on a hypervisor bool HYPERVISOR = CpuFeatures::HYPERVISOR(); }; static StaticFlags static_flags_; }; } // namespace internal } // namespace ailego } // namespace zvec ================================================ FILE: src/ailego/io/file.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include #if !defined(_WIN64) && !defined(_WIN32) #include #include #include #include #include #include #include #else #include #endif namespace zvec { namespace ailego { #if !defined(_WIN64) && !defined(_WIN32) static inline int OpenSafely(const char *path, int flags) { int fd = open(path, flags, S_IRUSR | S_IWUSR | S_IRGRP | S_IROTH); while (fd == -1 && errno == EINTR) { fd = open(path, flags, S_IRUSR | S_IWUSR | S_IRGRP | S_IROTH); } return fd; } static inline void CloseSafely(int fd) { int ret = close(fd); while (ret == -1 && errno == EINTR) { ret = close(fd); } } static inline ssize_t ReadSafely(int fd, void *buf, size_t count) { ssize_t ret = read(fd, buf, count); while (ret == -1 && errno == EINTR) { ret = read(fd, buf, count); } return ret; } static inline ssize_t PreadSafely(int fd, void *buf, size_t count, ssize_t offset) { ssize_t ret = pread(fd, buf, count, offset); while (ret == -1 && errno == EINTR) { ret = pread(fd, buf, count, offset); } return ret; } static inline ssize_t WriteSafely(int fd, const 
void *buf, size_t count) { ssize_t ret = write(fd, buf, count); while (ret == -1 && errno == EINTR) { ret = write(fd, buf, count); } return ret; } static inline ssize_t PwriteSafely(int fd, const void *buf, size_t count, ssize_t offset) { ssize_t ret = pwrite(fd, buf, count, offset); while (ret == -1 && errno == EINTR) { ret = pwrite(fd, buf, count, offset); } return ret; } static inline size_t ReadAll(int fd, void *buf, size_t count) { size_t rdlen = 0; while (rdlen < count) { ssize_t ret = ReadSafely(fd, (char *)buf + rdlen, count - rdlen); if (ret <= 0) { break; } rdlen += ret; } return rdlen; } static inline size_t PreadAll(int fd, void *buf, size_t count, ssize_t offset) { size_t rdlen = 0; while (rdlen < count) { ssize_t ret = PreadSafely(fd, (char *)buf + rdlen, count - rdlen, offset + rdlen); if (ret <= 0) { break; } rdlen += ret; } return rdlen; } static inline size_t WriteAll(int fd, const void *buf, size_t count) { size_t wrlen = 0; while (wrlen < count) { ssize_t ret = WriteSafely(fd, (const char *)buf + wrlen, count - wrlen); if (ret <= 0) { break; } wrlen += ret; } return wrlen; } static inline size_t PwriteAll(int fd, const void *buf, size_t count, ssize_t offset) { size_t wrlen = 0; while (wrlen < count) { ssize_t ret = PwriteSafely(fd, (const char *)buf + wrlen, count - wrlen, offset + wrlen); if (ret <= 0) { break; } wrlen += ret; } return wrlen; } bool File::create(const char *path, size_t len, bool direct) { ailego_false_if_false(native_handle_ == File::InvalidHandle && path); // Try opening or creating a file int flags = O_RDWR | O_CREAT; #ifdef O_DIRECT if (direct) { flags |= O_DIRECT; } #else (void)direct; #endif int fd = OpenSafely(path, flags); ailego_false_if_lt_zero(fd); #ifdef F_NOCACHE // Direct IO canonical solution for Mac OSX if (direct) { ailego_false_if_ne_zero(fcntl(fd, F_NOCACHE, 1)); } #endif // Truncate the file to the specified size ailego_do_if_ne_zero(ftruncate(fd, len)) { CloseSafely(fd); return false; } read_only_ = false; 
native_handle_ = fd; return true; } bool File::open(const char *path, bool rdonly, bool direct) { ailego_false_if_false(native_handle_ == File::InvalidHandle && path); // Try opening the file int flags = rdonly ? O_RDONLY : O_RDWR; #ifdef O_DIRECT if (direct) { flags |= O_DIRECT; } #else (void)direct; #endif int fd = OpenSafely(path, flags); ailego_false_if_lt_zero(fd); #ifdef F_NOCACHE // Direct IO canonical solution for Mac OSX if (direct) { ailego_false_if_ne_zero(fcntl(fd, F_NOCACHE, 1)); } #endif read_only_ = rdonly; native_handle_ = fd; return true; } void File::close(void) { ailego_return_if_false(native_handle_ != File::InvalidHandle); CloseSafely(native_handle_); native_handle_ = File::InvalidHandle; } void File::reset(void) { ailego_return_if_false(native_handle_ != File::InvalidHandle); lseek(native_handle_, 0, SEEK_SET); } size_t File::write(const void *data, size_t len) { const size_t block_size = 0x40000000u; size_t total = 0u; for (; len >= block_size; len -= block_size) { size_t wrlen = WriteAll(native_handle_, (const uint8_t *)data + total, block_size); if (wrlen != block_size) { return (total + wrlen); } total += block_size; } if (len > 0) { total += WriteAll(native_handle_, (const uint8_t *)data + total, len); } return total; } size_t File::write(ssize_t off, const void *data, size_t len) { const size_t block_size = 0x40000000u; size_t total = 0u; for (; len >= block_size; len -= block_size) { size_t wrlen = PwriteAll(native_handle_, (const uint8_t *)data + total, block_size, off + total); if (wrlen != block_size) { return (total + wrlen); } total += block_size; } if (len > 0) { total += PwriteAll(native_handle_, (const uint8_t *)data + total, len, off + total); } return total; } size_t File::read(void *buf, size_t len) { const size_t block_size = 0x40000000u; size_t total = 0u; for (; len >= block_size; len -= block_size) { size_t rdlen = ReadAll(native_handle_, (uint8_t *)buf + total, block_size); if (rdlen != block_size) { return (total + 
rdlen); } total += block_size; } if (len > 0) { total += ReadAll(native_handle_, (uint8_t *)buf + total, len); } return total; } size_t File::read(ssize_t off, void *buf, size_t len) { const size_t block_size = 0x40000000u; size_t total = 0u; for (; len >= block_size; len -= block_size) { size_t rdlen = PreadAll(native_handle_, (uint8_t *)buf + total, block_size, off + total); if (rdlen != block_size) { return (total + rdlen); } total += block_size; } if (len > 0) { total += PreadAll(native_handle_, (uint8_t *)buf + total, len, off + total); } return total; } bool File::flush(void) { ailego_false_if_false(native_handle_ != File::InvalidHandle); return (fsync(native_handle_) == 0); } bool File::seek(ssize_t off, Origin origin) { ailego_false_if_false(native_handle_ != File::InvalidHandle); ailego_false_if_false(lseek(native_handle_, off, (int)origin) != (off_t)-1); return true; } bool File::truncate(size_t len) { ailego_false_if_false(native_handle_ != File::InvalidHandle); ailego_false_if_ne_zero(ftruncate(native_handle_, (off_t)len)); return true; } size_t File::size(void) const { struct stat fs; ailego_zero_if_false(native_handle_ != File::InvalidHandle && fstat(native_handle_, &fs) == 0); return (fs.st_size); } ssize_t File::offset(void) const { off_t off; ailego_zero_if_false(native_handle_ != File::InvalidHandle && (off = lseek(native_handle_, 0, SEEK_CUR)) != -1); return off; } void *File::MemoryMap(NativeHandle handle, ssize_t off, size_t len, int opts) { int prot = ((opts & File::MMAP_READONLY) ? PROT_READ : PROT_READ | PROT_WRITE); #if defined(MAP_POPULATE) if (opts & File::MMAP_POPULATE) { prot |= MAP_POPULATE; } #endif int flags = (opts & File::MMAP_SHARED) ? 
MAP_SHARED : MAP_PRIVATE; #if defined(MAP_HUGETLB) if (opts & File::MMAP_HUGE_PAGE) { flags |= MAP_HUGETLB; } #endif void *addr = mmap(nullptr, len, prot, flags, handle, off); ailego_null_if_false(addr != MAP_FAILED); if (opts & File::MMAP_LOCKED) { mlock(addr, len); } if (opts & File::MMAP_WARMUP) { File::MemoryWarmup(addr, len); } return addr; } #if !defined(MAP_ANONYMOUS) && defined(MAP_ANON) #define MAP_ANONYMOUS MAP_ANON #endif void *File::MemoryMap(size_t len, int opts) { #if defined(MAP_ANONYMOUS) int prot = ((opts & File::MMAP_READONLY) ? PROT_READ : PROT_READ | PROT_WRITE); #if defined(MAP_POPULATE) if (opts & File::MMAP_POPULATE) { prot |= MAP_POPULATE; } #endif int flags = (opts & File::MMAP_SHARED) ? MAP_SHARED | MAP_ANONYMOUS : MAP_PRIVATE | MAP_ANONYMOUS; #if defined(MAP_HUGETLB) if (opts & File::MMAP_HUGE_PAGE) { flags |= MAP_HUGETLB; } #endif void *addr = mmap(nullptr, len, prot, flags, -1, 0); ailego_null_if_false(addr != MAP_FAILED); return addr; #else (void)len; (void)opts; return nullptr; #endif // MAP_ANONYMOUS } void *File::MemoryRemap(void *oldptr, size_t oldsize, void *newptr, size_t newsize) { #if defined(__linux) || defined(__linux__) return newptr ? mremap(oldptr, oldsize, newsize, MREMAP_FIXED, newptr) : mremap(oldptr, oldsize, newsize, MREMAP_MAYMOVE); #elif defined(__NetBSD__) return newptr ? 
mremap(oldptr, oldsize, newptr, newsize, MAP_FIXED) : mremap(oldptr, oldsize, nullptr, newsize, 0); #else (void)oldptr; (void)oldsize; (void)newptr; (void)newsize; errno = ENOTSUP; return nullptr; #endif } void File::MemoryUnmap(void *addr, size_t len) { ailego_return_if_false(addr); munmap(addr, len); } bool File::MemoryFlush(void *addr, size_t len) { ailego_false_if_false(addr); return (msync(addr, len, MS_ASYNC) == 0); } bool File::MemoryLock(void *addr, size_t len) { ailego_false_if_false(addr && len); return (mlock(addr, len) == 0); } bool File::MemoryUnlock(void *addr, size_t len) { ailego_false_if_false(addr && len); return (munlock(addr, len) == 0); } #else //! Create a local file bool File::create(const char *path, size_t len, bool direct) { ailego_false_if_false(native_handle_ == File::InvalidHandle && path); // Try opening or creating the file HANDLE file_handle = CreateFileA(path, GENERIC_WRITE | GENERIC_READ, FILE_SHARE_READ, nullptr, CREATE_ALWAYS, FILE_ATTRIBUTE_NORMAL, nullptr); ailego_false_if_false(file_handle != INVALID_HANDLE_VALUE); // Truncate the file to the specified size LARGE_INTEGER file_size; file_size.QuadPart = len; ailego_do_if_false( SetFilePointerEx(file_handle, file_size, nullptr, FILE_BEGIN) && SetEndOfFile(file_handle)) { CloseHandle(file_handle); return false; } if (!direct) { // Reset the file pointer SetFilePointer(file_handle, 0, nullptr, FILE_BEGIN); } else { // Close and reopen file CloseHandle(file_handle); file_handle = CreateFileA( path, GENERIC_WRITE | GENERIC_READ, FILE_SHARE_READ, nullptr, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL | FILE_FLAG_NO_BUFFERING, nullptr); ailego_false_if_false(file_handle != INVALID_HANDLE_VALUE); } read_only_ = false; native_handle_ = file_handle; return true; } //! 
Open a local file bool File::open(const char *path, bool rdonly, bool direct) { ailego_false_if_false(native_handle_ == File::InvalidHandle && path); // Try opening the file DWORD flags = FILE_ATTRIBUTE_NORMAL; if (direct) { flags |= FILE_FLAG_NO_BUFFERING; } HANDLE file_handle = CreateFileA(path, (rdonly ? GENERIC_READ : GENERIC_READ | GENERIC_WRITE), FILE_SHARE_READ, nullptr, OPEN_EXISTING, flags, nullptr); ailego_false_if_false(file_handle != INVALID_HANDLE_VALUE); read_only_ = rdonly; native_handle_ = file_handle; return true; } void File::close(void) { ailego_return_if_false(native_handle_ != File::InvalidHandle); CloseHandle(native_handle_); native_handle_ = File::InvalidHandle; } void File::reset(void) { ailego_return_if_false(native_handle_ != File::InvalidHandle); SetFilePointer(native_handle_, 0, nullptr, FILE_BEGIN); } size_t File::write(const void *data, size_t len) { const DWORD block_size = 0x40000000u; DWORD wrlen = 0u; size_t total = 0u; for (; len >= block_size; len -= block_size) { if (!WriteFile(native_handle_, (const uint8_t *)data + total, block_size, &wrlen, nullptr)) { return total; } if (wrlen != block_size) { return (total + wrlen); } total += block_size; } if (len > 0 && WriteFile(native_handle_, (const uint8_t *)data + total, (DWORD)len, &wrlen, nullptr)) { total += wrlen; } return total; } size_t File::write(ssize_t off, const void *data, size_t len) { const DWORD block_size = 0x40000000u; DWORD wrlen = 0u; size_t total = 0u; OVERLAPPED overlapped; memset(&overlapped, 0, sizeof(OVERLAPPED)); for (; len >= block_size; len -= block_size) { uint64_t current = off + total; overlapped.OffsetHigh = (DWORD)(current >> 32); overlapped.Offset = (DWORD)(current & 0xffffffffu); if (!WriteFile(native_handle_, (const uint8_t *)data + total, block_size, &wrlen, &overlapped)) { return total; } if (wrlen != block_size) { return (total + wrlen); } total += block_size; } if (len > 0) { uint64_t current = off + total; overlapped.OffsetHigh = 
(DWORD)(current >> 32); overlapped.Offset = (DWORD)(current & 0xffffffffu); if (WriteFile(native_handle_, (const uint8_t *)data + total, (DWORD)len, &wrlen, &overlapped)) { total += wrlen; } } return total; } size_t File::read(void *buf, size_t len) { const DWORD block_size = 0x40000000u; DWORD rdlen = 0u; size_t total = 0u; for (; len >= block_size; len -= block_size) { if (!ReadFile(native_handle_, (uint8_t *)buf + total, block_size, &rdlen, nullptr)) { return total; } if (rdlen != block_size) { return (total + rdlen); } total += block_size; } if (len > 0 && ReadFile(native_handle_, (uint8_t *)buf + total, (DWORD)len, &rdlen, nullptr)) { total += rdlen; } return total; } size_t File::read(ssize_t off, void *buf, size_t len) { const DWORD block_size = 0x40000000u; DWORD rdlen = 0u; size_t total = 0u; OVERLAPPED overlapped; memset(&overlapped, 0, sizeof(OVERLAPPED)); for (; len >= block_size; len -= block_size) { uint64_t current = off + total; overlapped.OffsetHigh = (DWORD)(current >> 32); overlapped.Offset = (DWORD)(current & 0xffffffffu); if (!ReadFile(native_handle_, (uint8_t *)buf + total, block_size, &rdlen, &overlapped)) { return total; } if (rdlen != block_size) { return (total + rdlen); } total += block_size; } if (len > 0) { uint64_t current = off + total; overlapped.OffsetHigh = (DWORD)(current >> 32); overlapped.Offset = (DWORD)(current & 0xffffffffu); if (ReadFile(native_handle_, (uint8_t *)buf + total, (DWORD)len, &rdlen, &overlapped)) { total += rdlen; } } return total; } bool File::flush(void) { ailego_false_if_false(native_handle_ != File::InvalidHandle); return (!!FlushFileBuffers(native_handle_)); } bool File::seek(ssize_t off, Origin origin) { ailego_false_if_false(native_handle_ != File::InvalidHandle); LARGE_INTEGER file_offset; file_offset.QuadPart = off; ailego_false_if_false(SetFilePointerEx(native_handle_, file_offset, nullptr, (DWORD)origin) != 0); return true; } bool File::truncate(size_t len) { ailego_false_if_false(native_handle_ != 
File::InvalidHandle); LARGE_INTEGER file_size, orig_file_size; file_size.QuadPart = 0; orig_file_size.QuadPart = 0; ailego_false_if_false(SetFilePointerEx(native_handle_, file_size, &orig_file_size, FILE_CURRENT)); // Truncate the file to the specified size file_size.QuadPart = len; ailego_false_if_false( SetFilePointerEx(native_handle_, file_size, nullptr, FILE_BEGIN) && SetEndOfFile(native_handle_)); // Reset the file pointer SetFilePointerEx(native_handle_, orig_file_size, nullptr, FILE_BEGIN); return true; } size_t File::size(void) const { LARGE_INTEGER file_size; ailego_zero_if_false(native_handle_ != File::InvalidHandle && GetFileSizeEx(native_handle_, &file_size)); return (size_t)file_size.QuadPart; } ssize_t File::offset(void) const { LARGE_INTEGER file_size; LARGE_INTEGER file_size_new; file_size.QuadPart = 0; ailego_zero_if_false(native_handle_ != File::InvalidHandle && SetFilePointerEx(native_handle_, file_size, &file_size_new, FILE_CURRENT)); return (size_t)file_size_new.QuadPart; } void *File::MemoryMap(NativeHandle handle, ssize_t off, size_t len, int opts) { LARGE_INTEGER file_size; file_size.QuadPart = len; // Create map object HANDLE file_mapping = CreateFileMapping( handle, nullptr, ((opts & File::MMAP_READONLY) ? 
PAGE_READONLY : PAGE_READWRITE), file_size.HighPart, file_size.LowPart, nullptr); ailego_null_if_false(file_mapping != nullptr); DWORD desired_access = FILE_MAP_READ; if (!(opts & File::MMAP_READONLY)) { desired_access |= FILE_MAP_WRITE; } if (!(opts & File::MMAP_SHARED)) { desired_access |= FILE_MAP_COPY; } file_size.QuadPart = off; // Map the whole file to memory and close handle void *addr = MapViewOfFile(file_mapping, desired_access, file_size.HighPart, file_size.LowPart, 0); CloseHandle(file_mapping); ailego_null_if_false(addr); if (opts & File::MMAP_LOCKED) { VirtualLock(addr, len); } if (opts & File::MMAP_WARMUP) { File::MemoryWarmup(addr, len); } return addr; } void *File::MemoryMap(size_t, int) { return nullptr; } void *File::MemoryRemap(void *, size_t, void *, size_t) { return nullptr; } void File::MemoryUnmap(void *addr, size_t /*len*/) { ailego_return_if_false(addr); UnmapViewOfFile(addr); } bool File::MemoryFlush(void *addr, size_t /*len*/) { ailego_false_if_false(addr); return (!!FlushViewOfFile(addr, 0)); } bool File::MemoryLock(void *addr, size_t len) { ailego_false_if_false(addr && len); return (!!VirtualLock(addr, len)); } bool File::MemoryUnlock(void *addr, size_t len) { ailego_false_if_false(addr && len); return (!!VirtualUnlock(addr, len)); } static inline int getpagesize(void) { SYSTEM_INFO info; GetSystemInfo(&info); return info.dwPageSize; } #endif void File::MemoryWarmup(void *addr, size_t len) { static int page_size = getpagesize(); if (addr && len) { uint8_t *p = reinterpret_cast(addr); uint8_t *end = p + len; volatile uint8_t tmp = 0; while (p < end) { tmp ^= *p; p += page_size; } } } } // namespace ailego } // namespace zvec ================================================ FILE: src/ailego/io/file_lock.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "file_lock.h" #if !defined(_WIN64) && !defined(_WIN32) #include #include <cerrno> #else #include #endif namespace zvec { namespace ailego { #if !defined(_WIN64) && !defined(_WIN32) //! Apply flock() with the given operation, retrying while a caught signal interrupts the call (EINTR) so callers only observe genuine lock failures static bool FlockRetryOnInterrupt(int fd, int operation) { int rc = 0; do { rc = flock(fd, operation); } while (rc != 0 && errno == EINTR); return (rc == 0); } //! Acquire an exclusive lock, blocking until it is granted bool FileLock::Lock(int fd) { return FlockRetryOnInterrupt(fd, LOCK_EX); } //! Try to acquire an exclusive lock without blocking bool FileLock::TryLock(int fd) { return FlockRetryOnInterrupt(fd, LOCK_EX | LOCK_NB); } //! Acquire a shared lock, blocking until it is granted bool FileLock::LockShared(int fd) { return FlockRetryOnInterrupt(fd, LOCK_SH); } //! Try to acquire a shared lock without blocking bool FileLock::TryLockShared(int fd) { return FlockRetryOnInterrupt(fd, LOCK_SH | LOCK_NB); } //! Release the lock held on the file bool FileLock::Unlock(int fd) { return FlockRetryOnInterrupt(fd, LOCK_UN); } #else //! Acquire an exclusive lock on the full 64-bit byte range starting at offset 0, blocking until granted bool FileLock::Lock(HANDLE handle) { OVERLAPPED ol = {0}; return (!!LockFileEx(handle, LOCKFILE_EXCLUSIVE_LOCK, 0, MAXDWORD, MAXDWORD, &ol)); } //! Try to acquire an exclusive lock on the full byte range without blocking bool FileLock::TryLock(HANDLE handle) { OVERLAPPED ol = {0}; return (!!LockFileEx(handle, LOCKFILE_EXCLUSIVE_LOCK | LOCKFILE_FAIL_IMMEDIATELY, 0, MAXDWORD, MAXDWORD, &ol)); } //! Acquire a shared lock on the full byte range, blocking until granted bool FileLock::LockShared(HANDLE handle) { OVERLAPPED ol = {0}; return (!!LockFileEx(handle, 0, 0, MAXDWORD, MAXDWORD, &ol)); } //! Try to acquire a shared lock on the full byte range without blocking bool FileLock::TryLockShared(HANDLE handle) { OVERLAPPED ol = {0}; return (!!LockFileEx(handle, LOCKFILE_FAIL_IMMEDIATELY, 0, MAXDWORD, MAXDWORD, &ol)); } //! Release the lock on the full byte range bool FileLock::Unlock(HANDLE handle) { OVERLAPPED ol = {0}; return (!!UnlockFileEx(handle, 0, MAXDWORD, MAXDWORD, &ol)); } #endif // !_WIN64 && !_WIN32 } // namespace ailego } // namespace zvec ================================================ FILE: src/ailego/io/file_lock.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the
"License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include namespace zvec { namespace ailego { /*! File Utility */ class FileLock { public: //! Constructor FileLock(const File &file) : native_handle_(file.native_handle()) {} //! Constructor FileLock(File::NativeHandle handle) : native_handle_(handle) {} //! Locking bool lock(void) const { return FileLock::Lock(native_handle_); } //! Try locking bool try_lock(void) const { return FileLock::TryLock(native_handle_); } //! Locking (shared) bool lock_shared(void) const { return FileLock::LockShared(native_handle_); } //! Try locking (shared) bool try_lock_shared(void) const { return FileLock::TryLockShared(native_handle_); } //! Unlocking bool unlock(void) const { return FileLock::Unlock(native_handle_); } //! Locking static bool Lock(File::NativeHandle handle); //! Try locking static bool TryLock(File::NativeHandle handle); //! Locking (shared) static bool LockShared(File::NativeHandle handle); //! Try locking (shared) static bool TryLockShared(File::NativeHandle handle); //! Unlocking static bool Unlock(File::NativeHandle handle); private: //! Disable them FileLock(const FileLock &) = delete; FileLock(FileLock &&) = delete; FileLock &operator=(const FileLock &) = delete; FileLock &operator=(FileLock &&) = delete; //! 
Members File::NativeHandle native_handle_; }; } // namespace ailego } // namespace zvec ================================================ FILE: src/ailego/io/file_writer.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include "file.h" namespace zvec { namespace ailego { /*! File Stream Writer */ class FileWriter { public: //! Constructor FileWriter(void) {} //! Constructor FileWriter(FileWriter &&rhs) : file_(std::move(rhs.file_)) {} //! Destructor ~FileWriter(void) {} //! Assignment FileWriter &operator=(FileWriter &&rhs) { file_ = std::move(rhs.file_); return *this; } //! Output to writer FileWriter &operator<<(const char *str) { size_t len = std::strlen(str); if (file_.write(str, len) != len) { throw std::ios_base::failure("Write error"); } return *this; } //! Output to writer FileWriter &operator<<(const std::string &str) { if (file_.write(str.data(), str.size()) != str.size()) { throw std::ios_base::failure("Write error"); } return *this; } //! Output to writer FileWriter &operator<<(char c) { if (file_.write(&c, 1) != 1) { throw std::ios_base::failure("Write error"); } return *this; } //! Test if the file is valid bool is_valid(void) const { return file_.is_valid(); } //! Create a local file bool create(const char *path) { return file_.create(path, 0, false); } //! 
Open a local file bool open(const char *path) { return file_.open(path, false, false); } //! Close the local file void close(void) { file_.close(); } //! Write data into the file size_t write(const void *data, size_t len) { return file_.write(data, len); } //! Synchronize memory with physical storage bool flush(void) { return file_.flush(); } //! Output with format void print(const char *format, va_list args) { char buf[8192]; std::vsnprintf(buf, sizeof(buf), format, args); (*this) << buf; } //! Output with format #if defined(__GNUC__) void print(const char *format, ...) __attribute__((format(printf, 2, 3))) { #else void print(const char *format, ...) { #endif va_list args; va_start(args, format); this->print(format, args); va_end(args); } private: //! Disable them FileWriter(const FileWriter &) = delete; FileWriter &operator=(const FileWriter &) = delete; //! Members File file_; }; } // namespace ailego } // namespace zvec ================================================ FILE: src/ailego/logger/logger.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include #include #include #include #include namespace zvec { namespace ailego { const int Logger::LEVEL_DEBUG = 0; const int Logger::LEVEL_INFO = 1; const int Logger::LEVEL_WARN = 2; const int Logger::LEVEL_ERROR = 3; const int Logger::LEVEL_FATAL = 4; /*! Console Logger */ struct ConsoleLogger : public Logger { //! 
Initialize Logger int init(const Params &) override { return 0; } //! Cleanup Logger int cleanup(void) override { return 0; } //! Log Message void log(int level, const char *file, int line, const char *format, va_list args) override { char buffer[8192]; std::ostringstream stream; ailego::Realtime::Localtime(buffer, sizeof(buffer)); stream << '[' << LevelString(level) << ' ' << buffer << ' ' << std::this_thread::get_id() << ' ' << ailego::File::BaseName(file) << ':' << line << "] "; vsnprintf(buffer, sizeof(buffer), format, args); stream << buffer << '\n'; if (level <= LEVEL_INFO) { std::cout << stream.str() << std::flush; } else { std::cerr << stream.str() << std::flush; } } }; //! Logger Level int LoggerBroker::logger_level_ = Logger::LEVEL_WARN; //! Logger Logger::Pointer LoggerBroker::logger_(new ConsoleLogger); //! Register Console Logger in Factory FACTORY_REGISTER_LOGGER(ConsoleLogger); } // namespace ailego } // namespace zvec ================================================ FILE: src/ailego/math/cosine_distance_matrix.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include "inner_product_matrix.h" namespace zvec { namespace ailego { /*! Cosine Distance Matrix */ template struct CosineDistanceMatrix; /*! Cosine Distance Matrix (M=1, N=1) */ template struct CosineDistanceMatrix< T, 1, 1, typename std::enable_if::value>::type> { //! 
Type of value using ValueType = typename std::remove_cv::type; //! Compute the distance between matrix and query static inline void Compute(const ValueType *m, const ValueType *q, size_t dim, float *out) { ailego_assert(m && q && dim && out); constexpr size_t extra_dim = sizeof(float) / sizeof(ValueType); size_t d = dim - extra_dim; float ip; InnerProductMatrix::Compute(m, q, d, &ip); *out = 1 - ip; } }; /*! Cosine Distance Matrix */ template struct CosineDistanceMatrix< T, M, N, typename std::enable_if::value && M >= 2 && N >= 2>::type> { //! Type of value using ValueType = typename std::remove_cv::type; //! Compute the distance between matrix and query static inline void Compute(const ValueType * /*m*/, const ValueType * /*q*/, size_t /*dim*/, float *out) { // ailego_assert(m && q && dim && out); *out = 0.0f; } }; /*! Cosine Distance Matrix (N=1) */ template struct CosineDistanceMatrix< T, M, 1, typename std::enable_if::value && M >= 2>::type> { //! Type of value using ValueType = typename std::remove_cv::type; //! Compute the distance between matrix and query static inline void Compute(const ValueType * /*m*/, const ValueType * /*q*/, size_t /*dim*/, float *out) { // ailego_assert(m && q && dim && out); *out = 0.0f; } }; } // namespace ailego } // namespace zvec ================================================ FILE: src/ailego/math/distance.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and // limitations under the License. #pragma once #include "distance_matrix.h" namespace zvec { namespace ailego { /*! Distance module (wrappers that compute a single distance value between two vectors via the distance matrix kernels) */ struct Distance { //! Compute the hamming distance between two vectors (BINARY) static float Hamming(const uint32_t *lhs, const uint32_t *rhs, size_t dim) { float result; HammingDistanceMatrix::Compute(lhs, rhs, dim, &result); return result; } #if defined(AILEGO_M64) //! Compute the hamming distance between two vectors (BINARY) static float Hamming(const uint64_t *lhs, const uint64_t *rhs, size_t dim) { float result; HammingDistanceMatrix::Compute(lhs, rhs, dim, &result); return result; } #else //! Compute the hamming distance between two vectors (BINARY; the 64-bit words are reinterpreted for the non-AILEGO_M64 kernel) static float Hamming(const uint64_t *lhs, const uint64_t *rhs, size_t dim) { float result; HammingDistanceMatrix::Compute( reinterpret_cast(lhs), reinterpret_cast(rhs), dim, &result); return result; } #endif //! Compute the squared euclidean distance between two vectors (FP32) static float SquaredEuclidean(const float *lhs, const float *rhs, size_t dim) { float result; SquaredEuclideanDistanceMatrix::Compute(lhs, rhs, dim, &result); return result; } //! Compute the squared euclidean distance between two vectors (FP16) static float SquaredEuclidean(const Float16 *lhs, const Float16 *rhs, size_t dim) { float result; SquaredEuclideanDistanceMatrix::Compute(lhs, rhs, dim, &result); return result; } //! Compute the squared euclidean distance between two vectors (INT8) static float SquaredEuclidean(const int8_t *lhs, const int8_t *rhs, size_t dim) { float result; SquaredEuclideanDistanceMatrix::Compute(lhs, rhs, dim, &result); return result; } //! 
Compute the squared euclidean distance between two vectors (INT4) static float SquaredEuclidean(const uint8_t *lhs, const uint8_t *rhs, size_t dim) { float result; SquaredEuclideanDistanceMatrix::Compute(lhs, rhs, dim, &result); return result; } //! 
Compute the euclidean distance between two vectors (FP32) static float Euclidean(const float *lhs, const float *rhs, size_t dim) { float result; EuclideanDistanceMatrix::Compute(lhs, rhs, dim, &result); return result; } //! Compute the euclidean distance between two vectors (FP16) static float Euclidean(const Float16 *lhs, const Float16 *rhs, size_t dim) { float result; EuclideanDistanceMatrix::Compute(lhs, rhs, dim, &result); return result; } //! Compute the euclidean distance between two vectors (INT8) static float Euclidean(const int8_t *lhs, const int8_t *rhs, size_t dim) { float result; EuclideanDistanceMatrix::Compute(lhs, rhs, dim, &result); return result; } //! Compute the euclidean distance between two vectors (INT4) static float Euclidean(const uint8_t *lhs, const uint8_t *rhs, size_t dim) { float result; EuclideanDistanceMatrix::Compute(lhs, rhs, dim, &result); return result; } //! Compute the inner product between two vectors (FP32) static float InnerProduct(const float *lhs, const float *rhs, size_t dim) { float result; InnerProductMatrix::Compute(lhs, rhs, dim, &result); return result; } //! Compute the inner product between two vectors (FP16) static float InnerProduct(const Float16 *lhs, const Float16 *rhs, size_t dim) { float result; InnerProductMatrix::Compute(lhs, rhs, dim, &result); return result; } //! Compute the inner product between two vectors (INT8) static float InnerProduct(const int8_t *lhs, const int8_t *rhs, size_t dim) { float result; InnerProductMatrix::Compute(lhs, rhs, dim, &result); return result; } //! Compute the inner product between two vectors (INT4) static float InnerProduct(const uint8_t *lhs, const uint8_t *rhs, size_t dim) { float result; InnerProductMatrix::Compute(lhs, rhs, dim, &result); return result; } //! 
Compute the minus inner product between two vectors (FP32) static float MinusInnerProduct(const float *lhs, const float *rhs, size_t dim) { float result; MinusInnerProductMatrix::Compute(lhs, rhs, dim, &result); return result; } //! Compute the minus inner product between two vectors (FP16) static float MinusInnerProduct(const Float16 *lhs, const Float16 *rhs, size_t dim) { float result; MinusInnerProductMatrix::Compute(lhs, rhs, dim, &result); return result; } //! Compute the minus inner product between two vectors (INT8) static float MinusInnerProduct(const int8_t *lhs, const int8_t *rhs, size_t dim) { float result; MinusInnerProductMatrix::Compute(lhs, rhs, dim, &result); return result; } //! Compute the minus inner product between two vectors (INT4) static float MinusInnerProduct(const uint8_t *lhs, const uint8_t *rhs, size_t dim) { float result; MinusInnerProductMatrix::Compute(lhs, rhs, dim, &result); return result; } //! Compute the mips squared L2 distance between two vectors //! (FP32, RepeatedQuadraticInjection) static float MipsSquaredEuclidean(const float *lhs, const float *rhs, size_t dim, size_t m, float eta) { float result; MipsSquaredEuclideanDistanceMatrix::Compute(lhs, rhs, dim, m, eta, &result); return result; } //! Compute the mips squared L2 distance between two vectors //! (FP16, RepeatedQuadraticInjection) static float MipsSquaredEuclidean(const Float16 *lhs, const Float16 *rhs, size_t dim, size_t m, float eta) { float result; MipsSquaredEuclideanDistanceMatrix::Compute(lhs, rhs, dim, m, eta, &result); return result; } //! Compute the mips squared L2 distance between two vectors //! (INT8, RepeatedQuadraticInjection) static float MipsSquaredEuclidean(const int8_t *lhs, const int8_t *rhs, size_t dim, size_t m, float eta) { float result; MipsSquaredEuclideanDistanceMatrix::Compute(lhs, rhs, dim, m, eta, &result); return result; } //! Compute the mips squared L2 distance between two vectors //! 
(INT4, RepeatedQuadraticInjection) static float MipsSquaredEuclidean(const uint8_t *lhs, const uint8_t *rhs, size_t dim, size_t m, float eta) { float result; MipsSquaredEuclideanDistanceMatrix::Compute(lhs, rhs, dim, m, eta, &result); return result; } //! Compute the mips squared L2 distance between two vectors //! (FP32, SphericalInjection) static float MipsSquaredEuclidean(const float *lhs, const float *rhs, size_t dim, float eta) { float result; MipsSquaredEuclideanDistanceMatrix::Compute(lhs, rhs, dim, eta, &result); return result; } //! Compute the mips squared L2 distance between two vectors //! (FP16, SphericalInjection) static float MipsSquaredEuclidean(const Float16 *lhs, const Float16 *rhs, size_t dim, float eta) { float result; MipsSquaredEuclideanDistanceMatrix::Compute(lhs, rhs, dim, eta, &result); return result; } //! Compute the mips squared L2 distance between two vectors //! (INT8, SphericalInjection) static float MipsSquaredEuclidean(const int8_t *lhs, const int8_t *rhs, size_t dim, float eta) { float result; MipsSquaredEuclideanDistanceMatrix::Compute(lhs, rhs, dim, eta, &result); return result; } //! Compute the mips squared L2 distance between two vectors //! (INT4, SphericalInjection) static float MipsSquaredEuclidean(const uint8_t *lhs, const uint8_t *rhs, size_t dim, float eta) { float result; MipsSquaredEuclideanDistanceMatrix::Compute(lhs, rhs, dim, eta, &result); return result; } //! Compute the cosine distance between two vectors (FP32) static float Cosine(const float *lhs, const float *rhs, size_t dim) { float result; CosineDistanceMatrix::Compute(lhs, rhs, dim, &result); return result; } //! Compute the cosine distance between two vectors (FP16) static float Cosine(const Float16 *lhs, const Float16 *rhs, size_t dim) { float result; CosineDistanceMatrix::Compute(lhs, rhs, dim, &result); return result; } //! 
Compute the cosine distance between two vectors (INT8) static float Cosine(const int8_t *lhs, const int8_t *rhs, size_t dim) { float result; CosineDistanceMatrix::Compute(lhs, rhs, dim, &result); return result; } }; } // namespace ailego } // namespace zvec ================================================ FILE: src/ailego/math/distance_matrix.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include "cosine_distance_matrix.h" #include "euclidean_distance_matrix.h" #include "hamming_distance_matrix.h" #include "inner_product_matrix.h" #include "mips_euclidean_distance_matrix.h" ================================================ FILE: src/ailego/math/distance_matrix_accum_fp16.i ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License.
#include "distance_matrix_fp16.i" #include "matrix_utility.i" #if !defined(__FMA__) #define _mm_fmadd_ps(a, b, c) _mm_add_ps(_mm_mul_ps((a), (b)), (c)) #define _mm256_fmadd_ps(a, b, c) _mm256_add_ps(_mm256_mul_ps((a), (b)), (c)) #endif // !__FMA__ #if defined(__AVX512F__) && !defined(__AVX512DQ__) #define _mm512_and_ps(a, b) \ _mm512_castsi512_ps( \ _mm512_and_epi32(_mm512_castps_si512(a), _mm512_castps_si512(b))) #define _mm512_mask_and_ps(src, k, a, b) \ _mm512_castsi512_ps(_mm512_mask_and_epi32(_mm512_castps_si512(src), (k), \ _mm512_castps_si512(a), \ _mm512_castps_si512(b))) #endif // __AVX512F__ && !__AVX512DQ__ //! Compute the distance between matrix and query (FP16, M=1, N=1; 16 values per main-loop step, one optional 8-value step, then the remaining tail handled with _MASK) #define ACCUM_FP16_1X1_AVX(m, q, dim, out, _MASK, _NORM) \ MATRIX_VAR_INIT(1, 1, __m256, ymm_sum, _mm256_setzero_ps()) \ const Float16 *qe = q + dim; \ const Float16 *qe_aligned = q + ((dim >> 4) << 4); \ if (((uintptr_t)m & 0x1f) == 0 && ((uintptr_t)q & 0x1f) == 0) { \ for (; q != qe_aligned; m += 16, q += 16) { \ MATRIX_FP16_ITER_1X1_AVX(m, q, ymm_sum, _mm256_load_si256, \ ACCUM_FP32_STEP_AVX) \ } \ if (qe >= qe_aligned + 8) { \ __m256 ymm_m = _mm256_cvtph_ps(_mm_load_si128((const __m128i *)m)); \ __m256 ymm_q = _mm256_cvtph_ps(_mm_load_si128((const __m128i *)q)); \ ACCUM_FP32_STEP_AVX(ymm_m, ymm_q, ymm_sum_0_0) \ m += 8; \ q += 8; \ } \ } else { \ for (; q != qe_aligned; m += 16, q += 16) { \ MATRIX_FP16_ITER_1X1_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ ACCUM_FP32_STEP_AVX) \ } \ if (qe >= qe_aligned + 8) { \ __m256 ymm_m = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)m)); \ __m256 ymm_q = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q)); \ ACCUM_FP32_STEP_AVX(ymm_m, ymm_q, ymm_sum_0_0) \ m += 8; \ q += 8; \ } \ } \ MATRIX_FP16_MASK_AVX(m, q, (qe - q), _MASK, ymm_sum, ACCUM_FP32_STEP_AVX) \ *out = _NORM(HorizontalAdd_FP32_V256(ymm_sum_0_0)); //! 
Compute the distance between matrix and query (FP16, M=2, N=1; m advances 2 values per query value and the 2 results are written with _mm_storel_pi) #define ACCUM_FP16_2X1_AVX(m, q, dim, out, _NORM) \ MATRIX_VAR_INIT(1, 1, __m256, ymm_sum, _mm256_setzero_ps()) \ const Float16 *qe_aligned = q + ((dim >> 2) << 2); \ const Float16 *qe = q + dim; \ if (((uintptr_t)m & 0xf) == 0) { \ for (; q != qe_aligned; m += 8, q += 4) { \ MATRIX_FP16_ITER_2X1_AVX(m, q, ymm_sum, _mm_load_si128, \ ACCUM_FP32_STEP_AVX) \ } \ } else { \ for (; q != qe_aligned; m += 8, q += 4) { \ MATRIX_FP16_ITER_2X1_AVX(m, q, ymm_sum, _mm_loadu_si128, \ ACCUM_FP32_STEP_AVX) \ } \ } \ __m128 xmm_sum_0_0 = _mm_add_ps(_mm256_castps256_ps128(ymm_sum_0_0), \ _mm256_extractf128_ps(ymm_sum_0_0, 1)); \ if (qe >= qe_aligned + 2) { \ __m128 xmm_m = _mm_cvtph_ps(_mm_set1_epi64(*(const __m64 *)(m))); \ __m128 xmm_q = _mm_cvtph_ps( \ _mm_shufflelo_epi16(_mm_broadcast_si32(q), _MM_SHUFFLE(1, 1, 0, 0))); \ ACCUM_FP32_STEP_SSE(xmm_m, xmm_q, xmm_sum_0_0) \ m += 4; \ q += 2; \ } \ xmm_sum_0_0 = \ _mm_add_ps(xmm_sum_0_0, _mm_movehl_ps(xmm_sum_0_0, xmm_sum_0_0)); \ if (q != qe) { \ __m128 xmm_m = _mm_cvtph_ps( \ _mm_shufflelo_epi16(_mm_broadcast_si32(m), _MM_SHUFFLE(0, 0, 1, 0))); \ __m128 xmm_q = _mm_cvtph_ps(_mm_set1_epi16(*(const short *)(q))); \ ACCUM_FP32_STEP_SSE(xmm_m, xmm_q, xmm_sum_0_0) \ } \ _mm_storel_pi((__m64 *)out, _NORM(xmm_sum_0_0)); //! 
Compute the distance between matrix and query (FP16, M=2, N=2; m and q both advance 2 values per dimension, aligned loads only when both are 16-byte aligned) #define ACCUM_FP16_2X2_AVX(m, q, dim, out, _NORM) \ MATRIX_VAR_INIT(1, 2, __m256, ymm_sum, _mm256_setzero_ps()) \ const Float16 *qe_aligned = q + ((dim >> 2) << 3); \ const Float16 *qe = q + (dim << 1); \ if (((uintptr_t)m & 0xf) == 0 && ((uintptr_t)q & 0xf) == 0) { \ for (; q != qe_aligned; m += 8, q += 8) { \ MATRIX_FP16_ITER_2X2_AVX(m, q, ymm_sum, _mm_load_si128, \ ACCUM_FP32_STEP_AVX) \ } \ } else { \ for (; q != qe_aligned; m += 8, q += 8) { \ MATRIX_FP16_ITER_2X2_AVX(m, q, ymm_sum, _mm_loadu_si128, \ ACCUM_FP32_STEP_AVX) \ } \ } \ __m128 xmm_sum_0_0 = _mm_add_ps(_mm256_castps256_ps128(ymm_sum_0_0), \ _mm256_extractf128_ps(ymm_sum_0_0, 1)); \ __m128 xmm_sum_0_1 = _mm_add_ps(_mm256_castps256_ps128(ymm_sum_0_1), \ _mm256_extractf128_ps(ymm_sum_0_1, 1)); \ if (qe >= qe_aligned + 4) { \ __m128 xmm_m = _mm_cvtph_ps(_mm_set1_epi64(*(const __m64 *)(m))); \ __m128 xmm_q = _mm_cvtph_ps(_mm_set1_epi64(*(const __m64 *)(q))); \ __m128 xmm_p = _mm_permute_ps(xmm_q, _MM_SHUFFLE(2, 2, 0, 0)); \ ACCUM_FP32_STEP_SSE(xmm_m, xmm_p, xmm_sum_0_0) \ xmm_p = _mm_permute_ps(xmm_q, _MM_SHUFFLE(3, 3, 1, 1)); \ ACCUM_FP32_STEP_SSE(xmm_m, xmm_p, xmm_sum_0_1) \ m += 4; \ q += 4; \ } \ xmm_sum_0_0 = _mm_add_ps(_mm_movelh_ps(xmm_sum_0_0, xmm_sum_0_1), \ _mm_movehl_ps(xmm_sum_0_1, xmm_sum_0_0)); \ if (q != qe) { \ __m128 xmm_m = _mm_cvtph_ps( \ _mm_shufflelo_epi16(_mm_broadcast_si32(m), _MM_SHUFFLE(1, 0, 1, 0))); \ __m128 xmm_q = _mm_cvtph_ps( \ _mm_shufflelo_epi16(_mm_broadcast_si32(q), _MM_SHUFFLE(1, 1, 0, 0))); \ ACCUM_FP32_STEP_SSE(xmm_m, xmm_q, xmm_sum_0_0) \ } \ if (((uintptr_t)out & 0xf) == 0) { \ MATRIX_VAR_STORE(1, 1, 4, xmm_sum, out, _mm_store_ps, _NORM) \ } else { \ MATRIX_VAR_STORE(1, 1, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \ } //! 
Compute the distance between matrix and query (FP16, M=4, N=1; m advances 4 values per query value, aligned loads only when m is 16-byte aligned) #define ACCUM_FP16_4X1_AVX(m, q, dim, out, _NORM) \ MATRIX_VAR_INIT(1, 1, __m256, ymm_sum, _mm256_setzero_ps()) \ const Float16 *qe = q + dim; \ if (((uintptr_t)m & 0xf) == 0) { \ for (const Float16 *qe_aligned = q + ((dim >> 1) << 1); q != qe_aligned; \ m += 8, q += 2) { \ MATRIX_FP16_ITER_4X1_AVX(m, q, ymm_sum, _mm_load_si128, \ ACCUM_FP32_STEP_AVX) \ } \ } else { \ for (const Float16 *qe_aligned = q + ((dim >> 1) << 1); q != qe_aligned; \ m += 8, q += 2) { \ MATRIX_FP16_ITER_4X1_AVX(m, q, ymm_sum, _mm_loadu_si128, \ ACCUM_FP32_STEP_AVX) \ } \ } \ __m128 xmm_sum_0_0 = _mm_add_ps(_mm256_castps256_ps128(ymm_sum_0_0), \ _mm256_extractf128_ps(ymm_sum_0_0, 1)); \ if (q != qe) { \ __m128 xmm_m = _mm_cvtph_ps(_mm_set1_epi64(*(const __m64 *)(m))); \ __m128 xmm_q = _mm_cvtph_ps(_mm_set1_epi16(*(const short *)(q))); \ ACCUM_FP32_STEP_SSE(xmm_m, xmm_q, xmm_sum_0_0) \ } \ if (((uintptr_t)out & 0xf) == 0) { \ MATRIX_VAR_STORE(1, 1, 4, xmm_sum, out, _mm_store_ps, _NORM) \ } else { \ MATRIX_VAR_STORE(1, 1, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \ } //! 
Compute the distance between matrix and query (FP16, M=4, N=2; m advances 4 values and q 2 values per dimension) #define ACCUM_FP16_4X2_AVX(m, q, dim, out, _NORM) \ MATRIX_VAR_INIT(1, 2, __m256, ymm_sum, _mm256_setzero_ps()) \ const Float16 *qe = q + (dim << 1); \ if (((uintptr_t)m & 0xf) == 0) { \ for (const Float16 *qe_aligned = q + ((dim >> 1) << 2); q != qe_aligned; \ m += 8, q += 4) { \ MATRIX_FP16_ITER_4X2_AVX(m, q, ymm_sum, _mm_load_si128, \ ACCUM_FP32_STEP_AVX) \ } \ } else { \ for (const Float16 *qe_aligned = q + ((dim >> 1) << 2); q != qe_aligned; \ m += 8, q += 4) { \ MATRIX_FP16_ITER_4X2_AVX(m, q, ymm_sum, _mm_loadu_si128, \ ACCUM_FP32_STEP_AVX) \ } \ } \ __m128 xmm_sum_0_0 = _mm_add_ps(_mm256_castps256_ps128(ymm_sum_0_0), \ _mm256_extractf128_ps(ymm_sum_0_0, 1)); \ __m128 xmm_sum_0_1 = _mm_add_ps(_mm256_castps256_ps128(ymm_sum_0_1), \ _mm256_extractf128_ps(ymm_sum_0_1, 1)); \ if (q != qe) { \ __m128 xmm_q_0 = _mm_cvtph_ps(_mm_set1_epi16(*(const short *)(q + 0))); \ __m128 xmm_q_1 = _mm_cvtph_ps(_mm_set1_epi16(*(const short *)(q + 1))); \ __m128 xmm_m = _mm_cvtph_ps(_mm_set1_epi64(*(const __m64 *)(m))); \ MATRIX_VAR_PROC(1, 2, 0, xmm_m, xmm_q, xmm_sum, ACCUM_FP32_STEP_SSE) \ } \ if (((uintptr_t)out & 0xf) == 0) { \ MATRIX_VAR_STORE(1, 2, 4, xmm_sum, out, _mm_store_ps, _NORM) \ } else { \ MATRIX_VAR_STORE(1, 2, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \ } //! 
Compute the distance between matrix and query (FP16, M=4, N=4; m and q both advance 4 values per dimension, aligned loads only when both are 16-byte aligned) #define ACCUM_FP16_4X4_AVX(m, q, dim, out, _NORM) \ MATRIX_VAR_INIT(1, 4, __m256, ymm_sum, _mm256_setzero_ps()) \ const Float16 *qe = q + (dim << 2); \ if (((uintptr_t)m & 0xf) == 0 && ((uintptr_t)q & 0xf) == 0) { \ for (const Float16 *qe_aligned = q + ((dim >> 1) << 3); q != qe_aligned; \ m += 8, q += 8) { \ MATRIX_FP16_ITER_4X4_AVX(m, q, ymm_sum, _mm_load_si128, \ ACCUM_FP32_STEP_AVX) \ } \ } else { \ for (const Float16 *qe_aligned = q + ((dim >> 1) << 3); q != qe_aligned; \ m += 8, q += 8) { \ MATRIX_FP16_ITER_4X4_AVX(m, q, ymm_sum, _mm_loadu_si128, \ ACCUM_FP32_STEP_AVX) \ } \ } \ __m128 xmm_sum_0_0 = _mm_add_ps(_mm256_castps256_ps128(ymm_sum_0_0), \ _mm256_extractf128_ps(ymm_sum_0_0, 1)); \ __m128 xmm_sum_0_1 = _mm_add_ps(_mm256_castps256_ps128(ymm_sum_0_1), \ _mm256_extractf128_ps(ymm_sum_0_1, 1)); \ __m128 xmm_sum_0_2 = _mm_add_ps(_mm256_castps256_ps128(ymm_sum_0_2), \ _mm256_extractf128_ps(ymm_sum_0_2, 1)); \ __m128 xmm_sum_0_3 = _mm_add_ps(_mm256_castps256_ps128(ymm_sum_0_3), \ _mm256_extractf128_ps(ymm_sum_0_3, 1)); \ if (q != qe) { \ __m128 xmm_m = _mm_cvtph_ps(_mm_set1_epi64(*(const __m64 *)(m))); \ __m128 xmm_q = _mm_cvtph_ps(_mm_set1_epi64(*(const __m64 *)(q))); \ __m128 xmm_p = _mm_permute_ps(xmm_q, _MM_SHUFFLE(0, 0, 0, 0)); \ ACCUM_FP32_STEP_SSE(xmm_m, xmm_p, xmm_sum_0_0) \ xmm_p = _mm_permute_ps(xmm_q, _MM_SHUFFLE(1, 1, 1, 1)); \ ACCUM_FP32_STEP_SSE(xmm_m, xmm_p, xmm_sum_0_1) \ xmm_p = _mm_permute_ps(xmm_q, _MM_SHUFFLE(2, 2, 2, 2)); \ ACCUM_FP32_STEP_SSE(xmm_m, xmm_p, xmm_sum_0_2) \ xmm_p = _mm_permute_ps(xmm_q, _MM_SHUFFLE(3, 3, 3, 3)); \ ACCUM_FP32_STEP_SSE(xmm_m, xmm_p, xmm_sum_0_3) \ } \ if (((uintptr_t)out & 0xf) == 0) { \ MATRIX_VAR_STORE(1, 4, 4, xmm_sum, out, _mm_store_ps, _NORM) \ } else { \ MATRIX_VAR_STORE(1, 4, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \ } //! 
Compute the distance between matrix and query (FP16, M=8, N=1; m advances 8 values per query value, aligned loads only when m is 16-byte aligned) #define ACCUM_FP16_8X1_AVX(m, q, dim, out, _NORM) \ MATRIX_VAR_INIT(1, 1, __m256, ymm_sum, _mm256_setzero_ps()) \ if (((uintptr_t)m & 0xf) == 0) { \ for (const Float16 *qe = q + dim; q != qe; m += 8, ++q) { \ MATRIX_FP16_ITER_8X1_AVX(m, q, ymm_sum, _mm_load_si128, \ ACCUM_FP32_STEP_AVX) \ } \ } else { \ for (const Float16 *qe = q + dim; q != qe; m += 8, ++q) { \ MATRIX_FP16_ITER_8X1_AVX(m, q, ymm_sum, _mm_loadu_si128, \ ACCUM_FP32_STEP_AVX) \ } \ } \ if (((uintptr_t)out & 0x1f) == 0) { \ MATRIX_VAR_STORE(1, 1, 8, ymm_sum, out, _mm256_store_ps, _NORM) \ } else { \ MATRIX_VAR_STORE(1, 1, 8, ymm_sum, out, _mm256_storeu_ps, _NORM) \ } //! Compute the distance between matrix and query (FP16, M=8, N=2; m advances 8 values and q 2 values per dimension) #define ACCUM_FP16_8X2_AVX(m, q, dim, out, _NORM) \ MATRIX_VAR_INIT(1, 2, __m256, ymm_sum, _mm256_setzero_ps()) \ if (((uintptr_t)m & 0xf) == 0) { \ for (const Float16 *qe = q + (dim << 1); q != qe; m += 8, q += 2) { \ MATRIX_FP16_ITER_8X2_AVX(m, q, ymm_sum, _mm_load_si128, \ ACCUM_FP32_STEP_AVX) \ } \ } else { \ for (const Float16 *qe = q + (dim << 1); q != qe; m += 8, q += 2) { \ MATRIX_FP16_ITER_8X2_AVX(m, q, ymm_sum, _mm_loadu_si128, \ ACCUM_FP32_STEP_AVX) \ } \ } \ if (((uintptr_t)out & 0x1f) == 0) { \ MATRIX_VAR_STORE(1, 2, 8, ymm_sum, out, _mm256_store_ps, _NORM) \ } else { \ MATRIX_VAR_STORE(1, 2, 8, ymm_sum, out, _mm256_storeu_ps, _NORM) \ } //! 
Compute the distance between matrix and query (FP16, M=8, N=4; m advances 8 values and q 4 values per dimension) #define ACCUM_FP16_8X4_AVX(m, q, dim, out, _NORM) \ MATRIX_VAR_INIT(1, 4, __m256, ymm_sum, _mm256_setzero_ps()) \ if (((uintptr_t)m & 0xf) == 0) { \ for (const Float16 *qe = q + (dim << 2); q != qe; m += 8, q += 4) { \ MATRIX_FP16_ITER_8X4_AVX(m, q, ymm_sum, _mm_load_si128, \ ACCUM_FP32_STEP_AVX) \ } \ } else { \ for (const Float16 *qe = q + (dim << 2); q != qe; m += 8, q += 4) { \ MATRIX_FP16_ITER_8X4_AVX(m, q, ymm_sum, _mm_loadu_si128, \ ACCUM_FP32_STEP_AVX) \ } \ } \ if (((uintptr_t)out & 0x1f) == 0) { \ MATRIX_VAR_STORE(1, 4, 8, ymm_sum, out, _mm256_store_ps, _NORM) \ } else { \ MATRIX_VAR_STORE(1, 4, 8, ymm_sum, out, _mm256_storeu_ps, _NORM) \ } //! Compute the distance between matrix and query (FP16, M=8, N=8; m and q both advance 8 values per dimension, aligned loads only when both are 16-byte aligned) #define ACCUM_FP16_8X8_AVX(m, q, dim, out, _NORM) \ MATRIX_VAR_INIT(1, 8, __m256, ymm_sum, _mm256_setzero_ps()) \ if (((uintptr_t)m & 0xf) == 0 && ((uintptr_t)q & 0xf) == 0) { \ for (const Float16 *qe = q + (dim << 3); q != qe; m += 8, q += 8) { \ MATRIX_FP16_ITER_8X8_AVX(m, q, ymm_sum, _mm_load_si128, \ ACCUM_FP32_STEP_AVX) \ } \ } else { \ for (const Float16 *qe = q + (dim << 3); q != qe; m += 8, q += 8) { \ MATRIX_FP16_ITER_8X8_AVX(m, q, ymm_sum, _mm_loadu_si128, \ ACCUM_FP32_STEP_AVX) \ } \ } \ if (((uintptr_t)out & 0x1f) == 0) { \ MATRIX_VAR_STORE(1, 8, 8, ymm_sum, out, _mm256_store_ps, _NORM) \ } else { \ MATRIX_VAR_STORE(1, 8, 8, ymm_sum, out, _mm256_storeu_ps, _NORM) \ } //! 
//! Matrix-vs-query distance, FP16 inputs, AVX kernel (M=16, N=1).
//! Unaligned `m` (32-byte) -> `_mm256_loadu_si256`; unaligned `out`
//! (32-byte) -> `_mm256_storeu_ps`.
#define ACCUM_FP16_16X1_AVX(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(2, 1, __m256, ymm_sum, _mm256_setzero_ps()) \
  if ((uintptr_t)m & 0x1f) { \
    for (const Float16 *qe = q + dim; q != qe; m += 16, q += 1) { \
      MATRIX_FP16_ITER_16X1_AVX(m, q, ymm_sum, _mm256_loadu_si256, \
                                ACCUM_FP32_STEP_AVX) \
    } \
  } else { \
    for (const Float16 *qe = q + dim; q != qe; m += 16, q += 1) { \
      MATRIX_FP16_ITER_16X1_AVX(m, q, ymm_sum, _mm256_load_si256, \
                                ACCUM_FP32_STEP_AVX) \
    } \
  } \
  if ((uintptr_t)out & 0x1f) { \
    MATRIX_VAR_STORE(2, 1, 8, ymm_sum, out, _mm256_storeu_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(2, 1, 8, ymm_sum, out, _mm256_store_ps, _NORM) \
  }

//! Matrix-vs-query distance, FP16 inputs, AVX kernel (M=16, N=2).
//! Unaligned `m` (32-byte) -> loadu path; unaligned `out` (32-byte) -> storeu.
#define ACCUM_FP16_16X2_AVX(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(2, 2, __m256, ymm_sum, _mm256_setzero_ps()) \
  if ((uintptr_t)m & 0x1f) { \
    for (const Float16 *qe = q + (dim << 1); q != qe; m += 16, q += 2) { \
      MATRIX_FP16_ITER_16X2_AVX(m, q, ymm_sum, _mm256_loadu_si256, \
                                ACCUM_FP32_STEP_AVX) \
    } \
  } else { \
    for (const Float16 *qe = q + (dim << 1); q != qe; m += 16, q += 2) { \
      MATRIX_FP16_ITER_16X2_AVX(m, q, ymm_sum, _mm256_load_si256, \
                                ACCUM_FP32_STEP_AVX) \
    } \
  } \
  if ((uintptr_t)out & 0x1f) { \
    MATRIX_VAR_STORE(2, 2, 8, ymm_sum, out, _mm256_storeu_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(2, 2, 8, ymm_sum, out, _mm256_store_ps, _NORM) \
  }
//! Matrix-vs-query distance, FP16 inputs, AVX kernel (M=16, N=4).
//! Unaligned `m` (32-byte) -> loadu path; unaligned `out` (32-byte) -> storeu.
#define ACCUM_FP16_16X4_AVX(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(2, 4, __m256, ymm_sum, _mm256_setzero_ps()) \
  if ((uintptr_t)m & 0x1f) { \
    for (const Float16 *qe = q + (dim << 2); q != qe; m += 16, q += 4) { \
      MATRIX_FP16_ITER_16X4_AVX(m, q, ymm_sum, _mm256_loadu_si256, \
                                ACCUM_FP32_STEP_AVX) \
    } \
  } else { \
    for (const Float16 *qe = q + (dim << 2); q != qe; m += 16, q += 4) { \
      MATRIX_FP16_ITER_16X4_AVX(m, q, ymm_sum, _mm256_load_si256, \
                                ACCUM_FP32_STEP_AVX) \
    } \
  } \
  if ((uintptr_t)out & 0x1f) { \
    MATRIX_VAR_STORE(2, 4, 8, ymm_sum, out, _mm256_storeu_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(2, 4, 8, ymm_sum, out, _mm256_store_ps, _NORM) \
  }

//! Matrix-vs-query distance, FP16 inputs, AVX kernel (M=16, N=8).
//! Unaligned `m` (32-byte) -> loadu path; unaligned `out` (32-byte) -> storeu.
#define ACCUM_FP16_16X8_AVX(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(2, 8, __m256, ymm_sum, _mm256_setzero_ps()) \
  if ((uintptr_t)m & 0x1f) { \
    for (const Float16 *qe = q + (dim << 3); q != qe; m += 16, q += 8) { \
      MATRIX_FP16_ITER_16X8_AVX(m, q, ymm_sum, _mm256_loadu_si256, \
                                ACCUM_FP32_STEP_AVX) \
    } \
  } else { \
    for (const Float16 *qe = q + (dim << 3); q != qe; m += 16, q += 8) { \
      MATRIX_FP16_ITER_16X8_AVX(m, q, ymm_sum, _mm256_load_si256, \
                                ACCUM_FP32_STEP_AVX) \
    } \
  } \
  if ((uintptr_t)out & 0x1f) { \
    MATRIX_VAR_STORE(2, 8, 8, ymm_sum, out, _mm256_storeu_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(2, 8, 8, ymm_sum, out, _mm256_store_ps, _NORM) \
  }
//! Matrix-vs-query distance, FP16 inputs, AVX kernel (M=16, N=16).
//! Unaligned `m` (32-byte) -> loadu path; unaligned `out` (32-byte) -> storeu.
//! NOTE(review): only `m` is alignment-checked here, whereas
//! ACCUM_FP16_16X16_AVX512 (same q stride, same `_mm256_load_si256` loader)
//! also checks `q`. If MATRIX_FP16_ITER_16X16_AVX reads `q` through the load
//! argument, an unaligned `q` would fault on the aligned path. Confirm.
#define ACCUM_FP16_16X16_AVX(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(2, 16, __m256, ymm_sum, _mm256_setzero_ps()) \
  if ((uintptr_t)m & 0x1f) { \
    for (const Float16 *qe = q + (dim << 4); q != qe; m += 16, q += 16) { \
      MATRIX_FP16_ITER_16X16_AVX(m, q, ymm_sum, _mm256_loadu_si256, \
                                 ACCUM_FP32_STEP_AVX) \
    } \
  } else { \
    for (const Float16 *qe = q + (dim << 4); q != qe; m += 16, q += 16) { \
      MATRIX_FP16_ITER_16X16_AVX(m, q, ymm_sum, _mm256_load_si256, \
                                 ACCUM_FP32_STEP_AVX) \
    } \
  } \
  if ((uintptr_t)out & 0x1f) { \
    MATRIX_VAR_STORE(2, 16, 8, ymm_sum, out, _mm256_storeu_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(2, 16, 8, ymm_sum, out, _mm256_store_ps, _NORM) \
  }

//! Matrix-vs-query distance, FP16 inputs, AVX kernel (M=32, N=1).
//! Unaligned `m` (32-byte) -> loadu path; unaligned `out` (32-byte) -> storeu.
#define ACCUM_FP16_32X1_AVX(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(4, 1, __m256, ymm_sum, _mm256_setzero_ps()) \
  if ((uintptr_t)m & 0x1f) { \
    for (const Float16 *qe = q + dim; q != qe; m += 32, q += 1) { \
      MATRIX_FP16_ITER_32X1_AVX(m, q, ymm_sum, _mm256_loadu_si256, \
                                ACCUM_FP32_STEP_AVX) \
    } \
  } else { \
    for (const Float16 *qe = q + dim; q != qe; m += 32, q += 1) { \
      MATRIX_FP16_ITER_32X1_AVX(m, q, ymm_sum, _mm256_load_si256, \
                                ACCUM_FP32_STEP_AVX) \
    } \
  } \
  if ((uintptr_t)out & 0x1f) { \
    MATRIX_VAR_STORE(4, 1, 8, ymm_sum, out, _mm256_storeu_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(4, 1, 8, ymm_sum, out, _mm256_store_ps, _NORM) \
  }
//! Matrix-vs-query distance, FP16 inputs, AVX kernel (M=32, N=2).
//! Unaligned `m` (32-byte) -> loadu path; unaligned `out` (32-byte) -> storeu.
#define ACCUM_FP16_32X2_AVX(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(4, 2, __m256, ymm_sum, _mm256_setzero_ps()) \
  if ((uintptr_t)m & 0x1f) { \
    for (const Float16 *qe = q + (dim << 1); q != qe; m += 32, q += 2) { \
      MATRIX_FP16_ITER_32X2_AVX(m, q, ymm_sum, _mm256_loadu_si256, \
                                ACCUM_FP32_STEP_AVX) \
    } \
  } else { \
    for (const Float16 *qe = q + (dim << 1); q != qe; m += 32, q += 2) { \
      MATRIX_FP16_ITER_32X2_AVX(m, q, ymm_sum, _mm256_load_si256, \
                                ACCUM_FP32_STEP_AVX) \
    } \
  } \
  if ((uintptr_t)out & 0x1f) { \
    MATRIX_VAR_STORE(4, 2, 8, ymm_sum, out, _mm256_storeu_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(4, 2, 8, ymm_sum, out, _mm256_store_ps, _NORM) \
  }

//! Matrix-vs-query distance, FP16 inputs, AVX kernel (M=32, N=4).
//! Unaligned `m` (32-byte) -> loadu path; unaligned `out` (32-byte) -> storeu.
#define ACCUM_FP16_32X4_AVX(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(4, 4, __m256, ymm_sum, _mm256_setzero_ps()) \
  if ((uintptr_t)m & 0x1f) { \
    for (const Float16 *qe = q + (dim << 2); q != qe; m += 32, q += 4) { \
      MATRIX_FP16_ITER_32X4_AVX(m, q, ymm_sum, _mm256_loadu_si256, \
                                ACCUM_FP32_STEP_AVX) \
    } \
  } else { \
    for (const Float16 *qe = q + (dim << 2); q != qe; m += 32, q += 4) { \
      MATRIX_FP16_ITER_32X4_AVX(m, q, ymm_sum, _mm256_load_si256, \
                                ACCUM_FP32_STEP_AVX) \
    } \
  } \
  if ((uintptr_t)out & 0x1f) { \
    MATRIX_VAR_STORE(4, 4, 8, ymm_sum, out, _mm256_storeu_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(4, 4, 8, ymm_sum, out, _mm256_store_ps, _NORM) \
  }
//! Matrix-vs-query distance, FP16 inputs, AVX kernel (M=32, N=8).
//! Unaligned `m` (32-byte) -> loadu path; unaligned `out` (32-byte) -> storeu.
#define ACCUM_FP16_32X8_AVX(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(4, 8, __m256, ymm_sum, _mm256_setzero_ps()) \
  if ((uintptr_t)m & 0x1f) { \
    for (const Float16 *qe = q + (dim << 3); q != qe; m += 32, q += 8) { \
      MATRIX_FP16_ITER_32X8_AVX(m, q, ymm_sum, _mm256_loadu_si256, \
                                ACCUM_FP32_STEP_AVX) \
    } \
  } else { \
    for (const Float16 *qe = q + (dim << 3); q != qe; m += 32, q += 8) { \
      MATRIX_FP16_ITER_32X8_AVX(m, q, ymm_sum, _mm256_load_si256, \
                                ACCUM_FP32_STEP_AVX) \
    } \
  } \
  if ((uintptr_t)out & 0x1f) { \
    MATRIX_VAR_STORE(4, 8, 8, ymm_sum, out, _mm256_storeu_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(4, 8, 8, ymm_sum, out, _mm256_store_ps, _NORM) \
  }

//! Matrix-vs-query distance, FP16 inputs, AVX kernel (M=32, N=16).
//! Unaligned `m` (32-byte) -> loadu path; unaligned `out` (32-byte) -> storeu.
#define ACCUM_FP16_32X16_AVX(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(4, 16, __m256, ymm_sum, _mm256_setzero_ps()) \
  if ((uintptr_t)m & 0x1f) { \
    for (const Float16 *qe = q + (dim << 4); q != qe; m += 32, q += 16) { \
      MATRIX_FP16_ITER_32X16_AVX(m, q, ymm_sum, _mm256_loadu_si256, \
                                 ACCUM_FP32_STEP_AVX) \
    } \
  } else { \
    for (const Float16 *qe = q + (dim << 4); q != qe; m += 32, q += 16) { \
      MATRIX_FP16_ITER_32X16_AVX(m, q, ymm_sum, _mm256_load_si256, \
                                 ACCUM_FP32_STEP_AVX) \
    } \
  } \
  if ((uintptr_t)out & 0x1f) { \
    MATRIX_VAR_STORE(4, 16, 8, ymm_sum, out, _mm256_storeu_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(4, 16, 8, ymm_sum, out, _mm256_store_ps, _NORM) \
  }
//! Matrix-vs-query distance, FP16 inputs, AVX kernel (M=32, N=32).
//! Unaligned `m` (32-byte) -> loadu path; unaligned `out` (32-byte) -> storeu.
//! NOTE(review): only `m` is alignment-checked; if MATRIX_FP16_ITER_32X32_AVX
//! reads `q` through the load argument, an unaligned `q` would fault on the
//! aligned path. Confirm against the iterator definition.
#define ACCUM_FP16_32X32_AVX(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(4, 32, __m256, ymm_sum, _mm256_setzero_ps()) \
  if ((uintptr_t)m & 0x1f) { \
    for (const Float16 *qe = q + (dim << 5); q != qe; m += 32, q += 32) { \
      MATRIX_FP16_ITER_32X32_AVX(m, q, ymm_sum, _mm256_loadu_si256, \
                                 ACCUM_FP32_STEP_AVX) \
    } \
  } else { \
    for (const Float16 *qe = q + (dim << 5); q != qe; m += 32, q += 32) { \
      MATRIX_FP16_ITER_32X32_AVX(m, q, ymm_sum, _mm256_load_si256, \
                                 ACCUM_FP32_STEP_AVX) \
    } \
  } \
  if ((uintptr_t)out & 0x1f) { \
    MATRIX_VAR_STORE(4, 32, 8, ymm_sum, out, _mm256_storeu_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(4, 32, 8, ymm_sum, out, _mm256_store_ps, _NORM) \
  }

//! Matrix-vs-query distance, FP16 inputs, AVX-512 kernel (M=1, N=1).
//! Processing stages:
//!  1. 32 elements at a time in a zmm accumulator; aligned loads only when
//!     both `m` and `q` are 64-byte aligned.
//!  2. An optional 16-element step.
//!  3. The zmm accumulator is folded into a ymm.
//!  4. An optional 8-element step.
//!  5. MATRIX_FP16_MASK_AVX with `_MASK` handles the final `qe - q` values.
//!  6. The horizontal sum is passed through `_NORM` into `*out`.
#define ACCUM_FP16_1X1_AVX512(m, q, dim, out, _MASK, _NORM) \
  MATRIX_VAR_INIT(1, 1, __m512, zmm_sum, _mm512_setzero_ps()) \
  const Float16 *qe = q + dim; \
  const Float16 *qe_aligned = q + ((dim >> 5) << 5); \
  if (((uintptr_t)m | (uintptr_t)q) & 0x3f) { \
    for (; q != qe_aligned; m += 32, q += 32) { \
      MATRIX_FP16_ITER_1X1_AVX512(m, q, zmm_sum, _mm512_loadu_si512, \
                                  ACCUM_FP32_STEP_AVX512) \
    } \
    if (qe >= qe_aligned + 16) { \
      __m512 zmm_m = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)m)); \
      __m512 zmm_q = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)q)); \
      ACCUM_FP32_STEP_AVX512(zmm_m, zmm_q, zmm_sum_0_0) \
      m += 16; \
      q += 16; \
    } \
  } else { \
    for (; q != qe_aligned; m += 32, q += 32) { \
      MATRIX_FP16_ITER_1X1_AVX512(m, q, zmm_sum, _mm512_load_si512, \
                                  ACCUM_FP32_STEP_AVX512) \
    } \
    if (qe >= qe_aligned + 16) { \
      __m512 zmm_m = _mm512_cvtph_ps(_mm256_load_si256((const __m256i *)m)); \
      __m512 zmm_q = _mm512_cvtph_ps(_mm256_load_si256((const __m256i *)q)); \
      ACCUM_FP32_STEP_AVX512(zmm_m, zmm_q, zmm_sum_0_0) \
      m += 16; \
      q += 16; \
    } \
  } \
  __m256 ymm_sum_0_0 = _mm256_add_ps(_mm512_castps512_ps256(zmm_sum_0_0), \
                                     _mm256_castpd_ps(_mm512_extractf64x4_pd( \
                                         _mm512_castps_pd(zmm_sum_0_0), 1))); \
  if (qe >= q + 8) { \
    __m256 ymm_m = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)m)); \
    __m256 ymm_q = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q)); \
    ACCUM_FP32_STEP_AVX(ymm_m, ymm_q, ymm_sum_0_0) \
    m += 8; \
    q += 8; \
  } \
  MATRIX_FP16_MASK_AVX(m, q, (qe - q), _MASK, ymm_sum, ACCUM_FP32_STEP_AVX) \
  *out = _NORM(HorizontalAdd_FP32_V256(ymm_sum_0_0));

//! Matrix-vs-query distance, FP16 inputs, AVX-512 kernel (M=16, N=1).
//! Unaligned `m` (32-byte) -> `_mm256_loadu_si256`; unaligned `out`
//! (64-byte) -> `_mm512_storeu_ps`.
#define ACCUM_FP16_16X1_AVX512(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(1, 1, __m512, zmm_sum, _mm512_setzero_ps()) \
  if ((uintptr_t)m & 0x1f) { \
    for (const Float16 *qe = q + dim; q != qe; m += 16, q += 1) { \
      MATRIX_FP16_ITER_16X1_AVX512(m, q, zmm_sum, _mm256_loadu_si256, \
                                   ACCUM_FP32_STEP_AVX512) \
    } \
  } else { \
    for (const Float16 *qe = q + dim; q != qe; m += 16, q += 1) { \
      MATRIX_FP16_ITER_16X1_AVX512(m, q, zmm_sum, _mm256_load_si256, \
                                   ACCUM_FP32_STEP_AVX512) \
    } \
  } \
  if ((uintptr_t)out & 0x3f) { \
    MATRIX_VAR_STORE(1, 1, 16, zmm_sum, out, _mm512_storeu_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(1, 1, 16, zmm_sum, out, _mm512_store_ps, _NORM) \
  }

//! Matrix-vs-query distance, FP16 inputs, AVX-512 kernel (M=16, N=2).
//! Unaligned `m` (32-byte) -> loadu path; unaligned `out` (64-byte) -> storeu.
#define ACCUM_FP16_16X2_AVX512(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(1, 2, __m512, zmm_sum, _mm512_setzero_ps()) \
  if ((uintptr_t)m & 0x1f) { \
    for (const Float16 *qe = q + (dim << 1); q != qe; m += 16, q += 2) { \
      MATRIX_FP16_ITER_16X2_AVX512(m, q, zmm_sum, _mm256_loadu_si256, \
                                   ACCUM_FP32_STEP_AVX512) \
    } \
  } else { \
    for (const Float16 *qe = q + (dim << 1); q != qe; m += 16, q += 2) { \
      MATRIX_FP16_ITER_16X2_AVX512(m, q, zmm_sum, _mm256_load_si256, \
                                   ACCUM_FP32_STEP_AVX512) \
    } \
  } \
  if ((uintptr_t)out & 0x3f) { \
    MATRIX_VAR_STORE(1, 2, 16, zmm_sum, out, _mm512_storeu_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(1, 2, 16, zmm_sum, out, _mm512_store_ps, _NORM) \
  }
//! Matrix-vs-query distance, FP16 inputs, AVX-512 kernel (M=16, N=4).
//! Unaligned `m` (32-byte) -> loadu path; unaligned `out` (64-byte) -> storeu.
#define ACCUM_FP16_16X4_AVX512(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(1, 4, __m512, zmm_sum, _mm512_setzero_ps()) \
  if ((uintptr_t)m & 0x1f) { \
    for (const Float16 *qe = q + (dim << 2); q != qe; m += 16, q += 4) { \
      MATRIX_FP16_ITER_16X4_AVX512(m, q, zmm_sum, _mm256_loadu_si256, \
                                   ACCUM_FP32_STEP_AVX512) \
    } \
  } else { \
    for (const Float16 *qe = q + (dim << 2); q != qe; m += 16, q += 4) { \
      MATRIX_FP16_ITER_16X4_AVX512(m, q, zmm_sum, _mm256_load_si256, \
                                   ACCUM_FP32_STEP_AVX512) \
    } \
  } \
  if ((uintptr_t)out & 0x3f) { \
    MATRIX_VAR_STORE(1, 4, 16, zmm_sum, out, _mm512_storeu_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(1, 4, 16, zmm_sum, out, _mm512_store_ps, _NORM) \
  }

//! Matrix-vs-query distance, FP16 inputs, AVX-512 kernel (M=16, N=8).
//! Unaligned `m` (32-byte) -> loadu path; unaligned `out` (64-byte) -> storeu.
#define ACCUM_FP16_16X8_AVX512(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(1, 8, __m512, zmm_sum, _mm512_setzero_ps()) \
  if ((uintptr_t)m & 0x1f) { \
    for (const Float16 *qe = q + (dim << 3); q != qe; m += 16, q += 8) { \
      MATRIX_FP16_ITER_16X8_AVX512(m, q, zmm_sum, _mm256_loadu_si256, \
                                   ACCUM_FP32_STEP_AVX512) \
    } \
  } else { \
    for (const Float16 *qe = q + (dim << 3); q != qe; m += 16, q += 8) { \
      MATRIX_FP16_ITER_16X8_AVX512(m, q, zmm_sum, _mm256_load_si256, \
                                   ACCUM_FP32_STEP_AVX512) \
    } \
  } \
  if ((uintptr_t)out & 0x3f) { \
    MATRIX_VAR_STORE(1, 8, 16, zmm_sum, out, _mm512_storeu_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(1, 8, 16, zmm_sum, out, _mm512_store_ps, _NORM) \
  }
//! Matrix-vs-query distance, FP16 inputs, AVX-512 kernel (M=16, N=16).
//! The aligned-load path requires BOTH `m` and `q` to be 32-byte aligned;
//! unaligned `out` (64-byte) -> storeu.
#define ACCUM_FP16_16X16_AVX512(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(1, 16, __m512, zmm_sum, _mm512_setzero_ps()) \
  if (((uintptr_t)m | (uintptr_t)q) & 0x1f) { \
    for (const Float16 *qe = q + (dim << 4); q != qe; m += 16, q += 16) { \
      MATRIX_FP16_ITER_16X16_AVX512(m, q, zmm_sum, _mm256_loadu_si256, \
                                    ACCUM_FP32_STEP_AVX512) \
    } \
  } else { \
    for (const Float16 *qe = q + (dim << 4); q != qe; m += 16, q += 16) { \
      MATRIX_FP16_ITER_16X16_AVX512(m, q, zmm_sum, _mm256_load_si256, \
                                    ACCUM_FP32_STEP_AVX512) \
    } \
  } \
  if ((uintptr_t)out & 0x3f) { \
    MATRIX_VAR_STORE(1, 16, 16, zmm_sum, out, _mm512_storeu_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(1, 16, 16, zmm_sum, out, _mm512_store_ps, _NORM) \
  }

//! Matrix-vs-query distance, FP16 inputs, AVX-512 kernel (M=32, N=1).
//! Unaligned `m` (64-byte) -> `_mm512_loadu_si512`; unaligned `out`
//! (64-byte) -> storeu.
#define ACCUM_FP16_32X1_AVX512(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(2, 1, __m512, zmm_sum, _mm512_setzero_ps()) \
  if ((uintptr_t)m & 0x3f) { \
    for (const Float16 *qe = q + dim; q != qe; m += 32, q += 1) { \
      MATRIX_FP16_ITER_32X1_AVX512(m, q, zmm_sum, _mm512_loadu_si512, \
                                   ACCUM_FP32_STEP_AVX512) \
    } \
  } else { \
    for (const Float16 *qe = q + dim; q != qe; m += 32, q += 1) { \
      MATRIX_FP16_ITER_32X1_AVX512(m, q, zmm_sum, _mm512_load_si512, \
                                   ACCUM_FP32_STEP_AVX512) \
    } \
  } \
  if ((uintptr_t)out & 0x3f) { \
    MATRIX_VAR_STORE(2, 1, 16, zmm_sum, out, _mm512_storeu_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(2, 1, 16, zmm_sum, out, _mm512_store_ps, _NORM) \
  }
//! Matrix-vs-query distance, FP16 inputs, AVX-512 kernel (M=32, N=2).
//! Unaligned `m` (64-byte) -> loadu path; unaligned `out` (64-byte) -> storeu.
#define ACCUM_FP16_32X2_AVX512(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(2, 2, __m512, zmm_sum, _mm512_setzero_ps()) \
  if ((uintptr_t)m & 0x3f) { \
    for (const Float16 *qe = q + (dim << 1); q != qe; m += 32, q += 2) { \
      MATRIX_FP16_ITER_32X2_AVX512(m, q, zmm_sum, _mm512_loadu_si512, \
                                   ACCUM_FP32_STEP_AVX512) \
    } \
  } else { \
    for (const Float16 *qe = q + (dim << 1); q != qe; m += 32, q += 2) { \
      MATRIX_FP16_ITER_32X2_AVX512(m, q, zmm_sum, _mm512_load_si512, \
                                   ACCUM_FP32_STEP_AVX512) \
    } \
  } \
  if ((uintptr_t)out & 0x3f) { \
    MATRIX_VAR_STORE(2, 2, 16, zmm_sum, out, _mm512_storeu_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(2, 2, 16, zmm_sum, out, _mm512_store_ps, _NORM) \
  }

//! Matrix-vs-query distance, FP16 inputs, AVX-512 kernel (M=32, N=4).
//! Unaligned `m` (64-byte) -> loadu path; unaligned `out` (64-byte) -> storeu.
#define ACCUM_FP16_32X4_AVX512(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(2, 4, __m512, zmm_sum, _mm512_setzero_ps()) \
  if ((uintptr_t)m & 0x3f) { \
    for (const Float16 *qe = q + (dim << 2); q != qe; m += 32, q += 4) { \
      MATRIX_FP16_ITER_32X4_AVX512(m, q, zmm_sum, _mm512_loadu_si512, \
                                   ACCUM_FP32_STEP_AVX512) \
    } \
  } else { \
    for (const Float16 *qe = q + (dim << 2); q != qe; m += 32, q += 4) { \
      MATRIX_FP16_ITER_32X4_AVX512(m, q, zmm_sum, _mm512_load_si512, \
                                   ACCUM_FP32_STEP_AVX512) \
    } \
  } \
  if ((uintptr_t)out & 0x3f) { \
    MATRIX_VAR_STORE(2, 4, 16, zmm_sum, out, _mm512_storeu_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(2, 4, 16, zmm_sum, out, _mm512_store_ps, _NORM) \
  }
//! Matrix-vs-query distance, FP16 inputs, AVX-512 kernel (M=32, N=8).
//! Unaligned `m` (64-byte) -> loadu path; unaligned `out` (64-byte) -> storeu.
#define ACCUM_FP16_32X8_AVX512(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(2, 8, __m512, zmm_sum, _mm512_setzero_ps()) \
  if ((uintptr_t)m & 0x3f) { \
    for (const Float16 *qe = q + (dim << 3); q != qe; m += 32, q += 8) { \
      MATRIX_FP16_ITER_32X8_AVX512(m, q, zmm_sum, _mm512_loadu_si512, \
                                   ACCUM_FP32_STEP_AVX512) \
    } \
  } else { \
    for (const Float16 *qe = q + (dim << 3); q != qe; m += 32, q += 8) { \
      MATRIX_FP16_ITER_32X8_AVX512(m, q, zmm_sum, _mm512_load_si512, \
                                   ACCUM_FP32_STEP_AVX512) \
    } \
  } \
  if ((uintptr_t)out & 0x3f) { \
    MATRIX_VAR_STORE(2, 8, 16, zmm_sum, out, _mm512_storeu_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(2, 8, 16, zmm_sum, out, _mm512_store_ps, _NORM) \
  }

//! Matrix-vs-query distance, FP16 inputs, AVX-512 kernel (M=32, N=16).
//! Unaligned `m` (64-byte) -> loadu path; unaligned `out` (64-byte) -> storeu.
#define ACCUM_FP16_32X16_AVX512(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(2, 16, __m512, zmm_sum, _mm512_setzero_ps()) \
  if ((uintptr_t)m & 0x3f) { \
    for (const Float16 *qe = q + (dim << 4); q != qe; m += 32, q += 16) { \
      MATRIX_FP16_ITER_32X16_AVX512(m, q, zmm_sum, _mm512_loadu_si512, \
                                    ACCUM_FP32_STEP_AVX512) \
    } \
  } else { \
    for (const Float16 *qe = q + (dim << 4); q != qe; m += 32, q += 16) { \
      MATRIX_FP16_ITER_32X16_AVX512(m, q, zmm_sum, _mm512_load_si512, \
                                    ACCUM_FP32_STEP_AVX512) \
    } \
  } \
  if ((uintptr_t)out & 0x3f) { \
    MATRIX_VAR_STORE(2, 16, 16, zmm_sum, out, _mm512_storeu_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(2, 16, 16, zmm_sum, out, _mm512_store_ps, _NORM) \
  }
Compute the distance between matrix and query (FP16, M=32, N=32) #define ACCUM_FP16_32X32_AVX512(m, q, dim, out, _NORM) \ MATRIX_VAR_INIT(2, 32, __m512, zmm_sum, _mm512_setzero_ps()) \ if (((uintptr_t)m & 0x3f) == 0) { \ for (const Float16 *qe = q + (dim << 5); q != qe; m += 32, q += 32) { \ MATRIX_FP16_ITER_32X32_AVX512(m, q, zmm_sum, _mm512_load_si512, \ ACCUM_FP32_STEP_AVX512) \ } \ } else { \ for (const Float16 *qe = q + (dim << 5); q != qe; m += 32, q += 32) { \ MATRIX_FP16_ITER_32X32_AVX512(m, q, zmm_sum, _mm512_loadu_si512, \ ACCUM_FP32_STEP_AVX512) \ } \ } \ if (((uintptr_t)out & 0x3f) == 0) { \ MATRIX_VAR_STORE(2, 32, 16, zmm_sum, out, _mm512_store_ps, _NORM) \ } else { \ MATRIX_VAR_STORE(2, 32, 16, zmm_sum, out, _mm512_storeu_ps, _NORM) \ } #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) //! Compute the distance between matrix and query (FP16, M=1, N=1) #define ACCUM_FP16_1X1_NEON(m, q, dim, out, _MASK, _NORM) \ MATRIX_VAR_INIT(1, 1, float16x8_t, v_sum, vdupq_n_f16(0)) \ const Float16 *qe = q + dim; \ const Float16 *qe_aligned = q + ((dim >> 3) << 3); \ for (; q != qe_aligned; m += 8, q += 8) { \ MATRIX_FP16_ITER_1X1_NEON(m, q, v_sum, ACCUM_FP16_STEP_NEON) \ } \ if (qe >= qe_aligned + 4) { \ float16x8_t v_m = \ vcombine_f16(vld1_f16((const float16_t *)m), \ vreinterpret_f16_u64(vdup_n_u64((uint64_t)(_MASK)))); \ float16x8_t v_q = \ vcombine_f16(vld1_f16((const float16_t *)q), \ vreinterpret_f16_u64(vdup_n_u64((uint64_t)(_MASK)))); \ ACCUM_FP16_STEP_NEON(v_m, v_q, v_sum_0_0) \ m += 4; \ q += 4; \ } \ float result = vaddvq_f32(vaddq_f32(vcvt_f32_f16(vget_low_f16(v_sum_0_0)), \ vcvt_high_f32_f16(v_sum_0_0))); \ switch (qe - q) { \ case 3: \ ACCUM_FP16_STEP_GENERAL(m[2], q[2], result) \ /* FALLTHRU */ \ case 2: \ ACCUM_FP16_STEP_GENERAL(m[1], q[1], result) \ /* FALLTHRU */ \ case 1: \ ACCUM_FP16_STEP_GENERAL(m[0], q[0], result) \ } \ *out = _NORM(result); #else //! 
Compute the distance between matrix and query (FP16, M=1, N=1) #define ACCUM_FP16_1X1_NEON(m, q, dim, out, _MASK, _NORM) \ MATRIX_VAR_INIT(1, 1, float32x4_t, v_sum, vdupq_n_f32(0)) \ const Float16 *qe = q + dim; \ const Float16 *qe_aligned = q + ((dim >> 3) << 3); \ for (; q != qe_aligned; m += 8, q += 8) { \ MATRIX_FP16_ITER_1X1_NEON(m, q, v_sum, ACCUM_FP32_STEP_NEON) \ } \ if (qe >= qe_aligned + 4) { \ float32x4_t v_m = vcvt_f32_f16(vld1_f16((const float16_t *)m)); \ float32x4_t v_q = vcvt_f32_f16(vld1_f16((const float16_t *)q)); \ ACCUM_FP32_STEP_NEON(v_m, v_q, v_sum_0_0) \ m += 4; \ q += 4; \ } \ float result = vaddvq_f32(v_sum_0_0); \ switch (qe - q) { \ case 3: \ ACCUM_FP16_STEP_GENERAL(m[2], q[2], result) \ /* FALLTHRU */ \ case 2: \ ACCUM_FP16_STEP_GENERAL(m[1], q[1], result) \ /* FALLTHRU */ \ case 1: \ ACCUM_FP16_STEP_GENERAL(m[0], q[0], result) \ } \ *out = _NORM(result); #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC ================================================ FILE: src/ailego/math/distance_matrix_accum_fp32.i ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "distance_matrix_fp32.i"
#include "matrix_utility.i"

// Without FMA, emulate the fused multiply-add intrinsics with a separate
// multiply followed by an add.
#ifndef __FMA__
#define _mm_fmadd_ps(a, b, c) _mm_add_ps(_mm_mul_ps((a), (b)), (c))
#define _mm256_fmadd_ps(a, b, c) _mm256_add_ps(_mm256_mul_ps((a), (b)), (c))
#endif  // !__FMA__

// AVX512F without AVX512DQ lacks the float AND forms; route them through the
// 32-bit integer AND on bit-cast registers.
#if defined(__AVX512F__) && !defined(__AVX512DQ__)
#define _mm512_and_ps(a, b) \
  _mm512_castsi512_ps( \
      _mm512_and_epi32(_mm512_castps_si512(a), _mm512_castps_si512(b)))
#define _mm512_mask_and_ps(src, k, a, b) \
  _mm512_castsi512_ps(_mm512_mask_and_epi32(_mm512_castps_si512(src), (k), \
                                            _mm512_castps_si512(a), \
                                            _mm512_castps_si512(b)))
#endif  // __AVX512DQ__

// 32-bit ARM NEON has no across-vector add; emulate vaddvq_* by adding the
// low and high halves, then pairwise-adding the resulting two lanes.
#if defined(__ARM_NEON) && !defined(__aarch64__)
static inline float32_t vaddvq_f32(float32x4_t v) {
  const float32x2_t halves = vadd_f32(vget_low_f32(v), vget_high_f32(v));
  const float32x2_t total = vpadd_f32(halves, halves);
  return vget_lane_f32(total, 0);
}

static inline int32_t vaddvq_s32(int32x4_t v) {
  const int32x2_t halves = vadd_s32(vget_low_s32(v), vget_high_s32(v));
  const int32x2_t total = vpadd_s32(halves, halves);
  return vget_lane_s32(total, 0);
}
#endif  //__ARM_NEON && !__aarch64__

// Pick the 2x1 NEON kernel flavour matching the target ISA.
#ifndef __aarch64__
#define ACCUM_FP32_2X1_NEON ACCUM_FP32_2X1_NEON_A32
#else
#define ACCUM_FP32_2X1_NEON ACCUM_FP32_2X1_NEON_A64
#endif  // __aarch64__
//! Matrix-vs-query distance, FP32 inputs, SSE kernel (M=2, N=1).
//! Processing stages:
//!  1. 4 query values (8 matrix values) per step; aligned loads only when
//!     both `m` and `q` are 16-byte aligned.
//!  2. An optional 2-value step.
//!  3. The two accumulators are folded together.
//!  4. An optional single trailing value.
//! The low two lanes are then written to `out` with `_mm_storel_pi`.
#define ACCUM_FP32_2X1_SSE(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(1, 2, __m128, xmm_sum, _mm_setzero_ps()) \
  const float *qe = q + dim; \
  const float *qe_aligned = q + ((dim >> 2) << 2); \
  if (((uintptr_t)m | (uintptr_t)q) & 0xf) { \
    for (; q != qe_aligned; m += 8, q += 4) { \
      MATRIX_FP32_ITER_2X1_SSE(m, q, xmm_sum, _mm_loadu_ps, \
                               ACCUM_FP32_STEP_SSE) \
    } \
    if (qe >= qe_aligned + 2) { \
      __m128 xmm_m = _mm_loadu_ps(m); \
      __m128 xmm_q = _mm_set_ps(q[1], q[1], q[0], q[0]); \
      ACCUM_FP32_STEP_SSE(xmm_m, xmm_q, xmm_sum_0_0) \
      m += 4; \
      q += 2; \
    } \
  } else { \
    for (; q != qe_aligned; m += 8, q += 4) { \
      MATRIX_FP32_ITER_2X1_SSE(m, q, xmm_sum, _mm_load_ps, \
                               ACCUM_FP32_STEP_SSE) \
    } \
    if (qe >= qe_aligned + 2) { \
      __m128 xmm_m = _mm_load_ps(m); \
      __m128 xmm_q = _mm_set_ps(q[1], q[1], q[0], q[0]); \
      ACCUM_FP32_STEP_SSE(xmm_m, xmm_q, xmm_sum_0_0) \
      m += 4; \
      q += 2; \
    } \
  } \
  xmm_sum_0_0 = _mm_add_ps(xmm_sum_0_0, xmm_sum_0_1); \
  xmm_sum_0_0 = \
      _mm_add_ps(xmm_sum_0_0, _mm_movehl_ps(xmm_sum_0_0, xmm_sum_0_0)); \
  if (q != qe) { \
    __m128 xmm_m = _mm_set_ps(0.0f, 0.0f, m[1], m[0]); \
    __m128 xmm_q = _mm_broadcast_ss(q); \
    ACCUM_FP32_STEP_SSE(xmm_m, xmm_q, xmm_sum_0_0) \
  } \
  _mm_storel_pi((__m64 *)out, _NORM(xmm_sum_0_0));
//! Matrix-vs-query distance, FP32 inputs, SSE kernel (M=2, N=2).
//! Processing stages:
//!  1. Two dimensions (4 matrix values, 4 query values) per loop step;
//!     aligned loads only when both `m` and `q` are 16-byte aligned.
//!  2. The two accumulators are merged into one vector.
//!  3. An odd trailing dimension is added by a set-based step.
//! Four results are stored; storeu is used when `out` is not 16-byte aligned.
#define ACCUM_FP32_2X2_SSE(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(1, 2, __m128, xmm_sum, _mm_setzero_ps()) \
  const float *qe = q + (dim << 1); \
  if (((uintptr_t)m | (uintptr_t)q) & 0xf) { \
    for (const float *qe_aligned = q + ((dim >> 1) << 2); q != qe_aligned; \
         m += 4, q += 4) { \
      MATRIX_FP32_ITER_2X2_SSE(m, q, xmm_sum, _mm_loadu_ps, \
                               ACCUM_FP32_STEP_SSE) \
    } \
  } else { \
    for (const float *qe_aligned = q + ((dim >> 1) << 2); q != qe_aligned; \
         m += 4, q += 4) { \
      MATRIX_FP32_ITER_2X2_SSE(m, q, xmm_sum, _mm_load_ps, \
                               ACCUM_FP32_STEP_SSE) \
    } \
  } \
  xmm_sum_0_0 = _mm_add_ps(_mm_movelh_ps(xmm_sum_0_0, xmm_sum_0_1), \
                           _mm_movehl_ps(xmm_sum_0_1, xmm_sum_0_0)); \
  if (q != qe) { \
    __m128 xmm_m = _mm_set_ps(m[1], m[0], m[1], m[0]); \
    __m128 xmm_q = _mm_set_ps(q[1], q[1], q[0], q[0]); \
    ACCUM_FP32_STEP_SSE(xmm_m, xmm_q, xmm_sum_0_0) \
  } \
  if ((uintptr_t)out & 0xf) { \
    MATRIX_VAR_STORE(1, 1, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(1, 1, 4, xmm_sum, out, _mm_store_ps, _NORM) \
  }
//! Matrix-vs-query distance, FP32 inputs, SSE kernel (M=4, N=1).
//! Two query values per loop step into two accumulators. When `dim` is odd,
//! one trailing value is added via a broadcast. The accumulators are summed
//! before the store. Unaligned `m`/`out` (16-byte) take the loadu/storeu
//! paths.
#define ACCUM_FP32_4X1_SSE(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(1, 2, __m128, xmm_sum, _mm_setzero_ps()) \
  const float *qe = q + dim; \
  if ((uintptr_t)m & 0xf) { \
    for (const float *qe_aligned = q + ((dim >> 1) << 1); q != qe_aligned; \
         m += 8, q += 2) { \
      MATRIX_FP32_ITER_4X1_SSE(m, q, xmm_sum, _mm_loadu_ps, \
                               ACCUM_FP32_STEP_SSE) \
    } \
    if (q != qe) { \
      __m128 xmm_m = _mm_loadu_ps(m); \
      __m128 xmm_q = _mm_broadcast_ss(q); \
      ACCUM_FP32_STEP_SSE(xmm_m, xmm_q, xmm_sum_0_0) \
    } \
  } else { \
    for (const float *qe_aligned = q + ((dim >> 1) << 1); q != qe_aligned; \
         m += 8, q += 2) { \
      MATRIX_FP32_ITER_4X1_SSE(m, q, xmm_sum, _mm_load_ps, \
                               ACCUM_FP32_STEP_SSE) \
    } \
    if (q != qe) { \
      __m128 xmm_m = _mm_load_ps(m); \
      __m128 xmm_q = _mm_broadcast_ss(q); \
      ACCUM_FP32_STEP_SSE(xmm_m, xmm_q, xmm_sum_0_0) \
    } \
  } \
  xmm_sum_0_0 = _mm_add_ps(xmm_sum_0_0, xmm_sum_0_1); \
  if ((uintptr_t)out & 0xf) { \
    MATRIX_VAR_STORE(1, 1, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(1, 1, 4, xmm_sum, out, _mm_store_ps, _NORM) \
  }

//! Matrix-vs-query distance, FP32 inputs, SSE kernel (M=4, N=2).
//! Unaligned `m` (16-byte) -> loadu path; unaligned `out` (16-byte) -> storeu.
#define ACCUM_FP32_4X2_SSE(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(1, 2, __m128, xmm_sum, _mm_setzero_ps()) \
  if ((uintptr_t)m & 0xf) { \
    for (const float *qe = q + (dim << 1); q != qe; m += 4, q += 2) { \
      MATRIX_FP32_ITER_4X2_SSE(m, q, xmm_sum, _mm_loadu_ps, \
                               ACCUM_FP32_STEP_SSE) \
    } \
  } else { \
    for (const float *qe = q + (dim << 1); q != qe; m += 4, q += 2) { \
      MATRIX_FP32_ITER_4X2_SSE(m, q, xmm_sum, _mm_load_ps, \
                               ACCUM_FP32_STEP_SSE) \
    } \
  } \
  if ((uintptr_t)out & 0xf) { \
    MATRIX_VAR_STORE(1, 2, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(1, 2, 4, xmm_sum, out, _mm_store_ps, _NORM) \
  }
//! Matrix-vs-query distance, FP32 inputs, SSE kernel (M=4, N=4).
//! Unaligned `m` (16-byte) -> loadu path; unaligned `out` (16-byte) -> storeu.
#define ACCUM_FP32_4X4_SSE(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(1, 4, __m128, xmm_sum, _mm_setzero_ps()) \
  if ((uintptr_t)m & 0xf) { \
    for (const float *qe = q + (dim << 2); q != qe; m += 4, q += 4) { \
      MATRIX_FP32_ITER_4X4_SSE(m, q, xmm_sum, _mm_loadu_ps, \
                               ACCUM_FP32_STEP_SSE) \
    } \
  } else { \
    for (const float *qe = q + (dim << 2); q != qe; m += 4, q += 4) { \
      MATRIX_FP32_ITER_4X4_SSE(m, q, xmm_sum, _mm_load_ps, \
                               ACCUM_FP32_STEP_SSE) \
    } \
  } \
  if ((uintptr_t)out & 0xf) { \
    MATRIX_VAR_STORE(1, 4, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(1, 4, 4, xmm_sum, out, _mm_store_ps, _NORM) \
  }

//! Matrix-vs-query distance, FP32 inputs, SSE kernel (M=8, N=1).
//! Unaligned `m` (16-byte) -> loadu path; unaligned `out` (16-byte) -> storeu.
#define ACCUM_FP32_8X1_SSE(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(2, 1, __m128, xmm_sum, _mm_setzero_ps()) \
  if ((uintptr_t)m & 0xf) { \
    for (const float *qe = q + dim; q != qe; m += 8, q += 1) { \
      MATRIX_FP32_ITER_8X1_SSE(m, q, xmm_sum, _mm_loadu_ps, \
                               ACCUM_FP32_STEP_SSE) \
    } \
  } else { \
    for (const float *qe = q + dim; q != qe; m += 8, q += 1) { \
      MATRIX_FP32_ITER_8X1_SSE(m, q, xmm_sum, _mm_load_ps, \
                               ACCUM_FP32_STEP_SSE) \
    } \
  } \
  if ((uintptr_t)out & 0xf) { \
    MATRIX_VAR_STORE(2, 1, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(2, 1, 4, xmm_sum, out, _mm_store_ps, _NORM) \
  }
//! Matrix-vs-query distance, FP32 inputs, SSE kernel (M=8, N=2).
//! Unaligned `m` (16-byte) -> loadu path; unaligned `out` (16-byte) -> storeu.
#define ACCUM_FP32_8X2_SSE(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(2, 2, __m128, xmm_sum, _mm_setzero_ps()) \
  if ((uintptr_t)m & 0xf) { \
    for (const float *qe = q + (dim << 1); q != qe; m += 8, q += 2) { \
      MATRIX_FP32_ITER_8X2_SSE(m, q, xmm_sum, _mm_loadu_ps, \
                               ACCUM_FP32_STEP_SSE) \
    } \
  } else { \
    for (const float *qe = q + (dim << 1); q != qe; m += 8, q += 2) { \
      MATRIX_FP32_ITER_8X2_SSE(m, q, xmm_sum, _mm_load_ps, \
                               ACCUM_FP32_STEP_SSE) \
    } \
  } \
  if ((uintptr_t)out & 0xf) { \
    MATRIX_VAR_STORE(2, 2, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(2, 2, 4, xmm_sum, out, _mm_store_ps, _NORM) \
  }

//! Matrix-vs-query distance, FP32 inputs, SSE kernel (M=8, N=4).
//! Unaligned `m` (16-byte) -> loadu path; unaligned `out` (16-byte) -> storeu.
#define ACCUM_FP32_8X4_SSE(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(2, 4, __m128, xmm_sum, _mm_setzero_ps()) \
  if ((uintptr_t)m & 0xf) { \
    for (const float *qe = q + (dim << 2); q != qe; m += 8, q += 4) { \
      MATRIX_FP32_ITER_8X4_SSE(m, q, xmm_sum, _mm_loadu_ps, \
                               ACCUM_FP32_STEP_SSE) \
    } \
  } else { \
    for (const float *qe = q + (dim << 2); q != qe; m += 8, q += 4) { \
      MATRIX_FP32_ITER_8X4_SSE(m, q, xmm_sum, _mm_load_ps, \
                               ACCUM_FP32_STEP_SSE) \
    } \
  } \
  if ((uintptr_t)out & 0xf) { \
    MATRIX_VAR_STORE(2, 4, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(2, 4, 4, xmm_sum, out, _mm_store_ps, _NORM) \
  }
//! Matrix-vs-query distance, FP32 inputs, SSE kernel (M=8, N=8).
//! Unaligned `m` (16-byte) -> loadu path; unaligned `out` (16-byte) -> storeu.
#define ACCUM_FP32_8X8_SSE(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(2, 8, __m128, xmm_sum, _mm_setzero_ps()) \
  if ((uintptr_t)m & 0xf) { \
    for (const float *qe = q + (dim << 3); q != qe; m += 8, q += 8) { \
      MATRIX_FP32_ITER_8X8_SSE(m, q, xmm_sum, _mm_loadu_ps, \
                               ACCUM_FP32_STEP_SSE) \
    } \
  } else { \
    for (const float *qe = q + (dim << 3); q != qe; m += 8, q += 8) { \
      MATRIX_FP32_ITER_8X8_SSE(m, q, xmm_sum, _mm_load_ps, \
                               ACCUM_FP32_STEP_SSE) \
    } \
  } \
  if ((uintptr_t)out & 0xf) { \
    MATRIX_VAR_STORE(2, 8, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(2, 8, 4, xmm_sum, out, _mm_store_ps, _NORM) \
  }

//! Matrix-vs-query distance, FP32 inputs, SSE kernel (M=16, N=1).
//! Unaligned `m` (16-byte) -> loadu path; unaligned `out` (16-byte) -> storeu.
#define ACCUM_FP32_16X1_SSE(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(4, 1, __m128, xmm_sum, _mm_setzero_ps()) \
  if ((uintptr_t)m & 0xf) { \
    for (const float *qe = q + dim; q != qe; m += 16, q += 1) { \
      MATRIX_FP32_ITER_16X1_SSE(m, q, xmm_sum, _mm_loadu_ps, \
                                ACCUM_FP32_STEP_SSE) \
    } \
  } else { \
    for (const float *qe = q + dim; q != qe; m += 16, q += 1) { \
      MATRIX_FP32_ITER_16X1_SSE(m, q, xmm_sum, _mm_load_ps, \
                                ACCUM_FP32_STEP_SSE) \
    } \
  } \
  if ((uintptr_t)out & 0xf) { \
    MATRIX_VAR_STORE(4, 1, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(4, 1, 4, xmm_sum, out, _mm_store_ps, _NORM) \
  }
//! Matrix-vs-query distance, FP32 inputs, SSE kernel (M=16, N=2).
//! Unaligned `m` (16-byte) -> loadu path; unaligned `out` (16-byte) -> storeu.
#define ACCUM_FP32_16X2_SSE(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(4, 2, __m128, xmm_sum, _mm_setzero_ps()) \
  if ((uintptr_t)m & 0xf) { \
    for (const float *qe = q + (dim << 1); q != qe; m += 16, q += 2) { \
      MATRIX_FP32_ITER_16X2_SSE(m, q, xmm_sum, _mm_loadu_ps, \
                                ACCUM_FP32_STEP_SSE) \
    } \
  } else { \
    for (const float *qe = q + (dim << 1); q != qe; m += 16, q += 2) { \
      MATRIX_FP32_ITER_16X2_SSE(m, q, xmm_sum, _mm_load_ps, \
                                ACCUM_FP32_STEP_SSE) \
    } \
  } \
  if ((uintptr_t)out & 0xf) { \
    MATRIX_VAR_STORE(4, 2, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(4, 2, 4, xmm_sum, out, _mm_store_ps, _NORM) \
  }

//! Matrix-vs-query distance, FP32 inputs, SSE kernel (M=16, N=4).
//! Unaligned `m` (16-byte) -> loadu path; unaligned `out` (16-byte) -> storeu.
#define ACCUM_FP32_16X4_SSE(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(4, 4, __m128, xmm_sum, _mm_setzero_ps()) \
  if ((uintptr_t)m & 0xf) { \
    for (const float *qe = q + (dim << 2); q != qe; m += 16, q += 4) { \
      MATRIX_FP32_ITER_16X4_SSE(m, q, xmm_sum, _mm_loadu_ps, \
                                ACCUM_FP32_STEP_SSE) \
    } \
  } else { \
    for (const float *qe = q + (dim << 2); q != qe; m += 16, q += 4) { \
      MATRIX_FP32_ITER_16X4_SSE(m, q, xmm_sum, _mm_load_ps, \
                                ACCUM_FP32_STEP_SSE) \
    } \
  } \
  if ((uintptr_t)out & 0xf) { \
    MATRIX_VAR_STORE(4, 4, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(4, 4, 4, xmm_sum, out, _mm_store_ps, _NORM) \
  }
//! Compute the distance between matrix and query (FP32, M=16, N=8)
//! Loops `dim` times, advancing `m` by 16 floats and `q` by 8 per step (the
//! caller's pointers are modified). Aligned loads are used only when `m` is
//! 16-byte aligned, aligned stores only when `out` is. The fallback store
//! must be `_mm_storeu_ps`: `_mm_store_ps` faults on a misaligned address.
#define ACCUM_FP32_16X8_SSE(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(4, 8, __m128, xmm_sum, _mm_setzero_ps()) \
  if (((uintptr_t)m & 0xf) == 0) { \
    for (const float *qe = q + (dim << 3); q != qe; m += 16, q += 8) { \
      MATRIX_FP32_ITER_16X8_SSE(m, q, xmm_sum, _mm_load_ps, \
                                ACCUM_FP32_STEP_SSE) \
    } \
  } else { \
    for (const float *qe = q + (dim << 3); q != qe; m += 16, q += 8) { \
      MATRIX_FP32_ITER_16X8_SSE(m, q, xmm_sum, _mm_loadu_ps, \
                                ACCUM_FP32_STEP_SSE) \
    } \
  } \
  if (((uintptr_t)out & 0xf) == 0) { \
    MATRIX_VAR_STORE(4, 8, 4, xmm_sum, out, _mm_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(4, 8, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \
  }

//! Compute the distance between matrix and query (FP32, M=16, N=16)
//! Loops `dim` times, advancing `m` by 16 floats and `q` by 16 per step (the
//! caller's pointers are modified). Aligned loads are used only when `m` is
//! 16-byte aligned, aligned stores only when `out` is.
#define ACCUM_FP32_16X16_SSE(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(4, 16, __m128, xmm_sum, _mm_setzero_ps()) \
  if (((uintptr_t)m & 0xf) == 0) { \
    for (const float *qe = q + (dim << 4); q != qe; m += 16, q += 16) { \
      MATRIX_FP32_ITER_16X16_SSE(m, q, xmm_sum, _mm_load_ps, \
                                 ACCUM_FP32_STEP_SSE) \
    } \
  } else { \
    for (const float *qe = q + (dim << 4); q != qe; m += 16, q += 16) { \
      MATRIX_FP32_ITER_16X16_SSE(m, q, xmm_sum, _mm_loadu_ps, \
                                 ACCUM_FP32_STEP_SSE) \
    } \
  } \
  if (((uintptr_t)out & 0xf) == 0) { \
    MATRIX_VAR_STORE(4, 16, 4, xmm_sum, out, _mm_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(4, 16, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \
  }
//! Compute the distance between matrix and query (FP32, M=32, N=1)
//! Loops `dim` times, advancing `m` by 32 floats and `q` by 1 per step (the
//! caller's pointers are modified). Aligned loads are used only when `m` is
//! 16-byte aligned, aligned stores only when `out` is.
#define ACCUM_FP32_32X1_SSE(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(8, 1, __m128, xmm_sum, _mm_setzero_ps()) \
  if (((uintptr_t)m & 0xf) == 0) { \
    for (const float *qe = q + dim; q != qe; m += 32, ++q) { \
      MATRIX_FP32_ITER_32X1_SSE(m, q, xmm_sum, _mm_load_ps, \
                                ACCUM_FP32_STEP_SSE) \
    } \
  } else { \
    for (const float *qe = q + dim; q != qe; m += 32, ++q) { \
      MATRIX_FP32_ITER_32X1_SSE(m, q, xmm_sum, _mm_loadu_ps, \
                                ACCUM_FP32_STEP_SSE) \
    } \
  } \
  if (((uintptr_t)out & 0xf) == 0) { \
    MATRIX_VAR_STORE(8, 1, 4, xmm_sum, out, _mm_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(8, 1, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \
  }

//! Compute the distance between matrix and query (FP32, M=32, N=2)
//! Loops `dim` times, advancing `m` by 32 floats and `q` by 2 per step (the
//! caller's pointers are modified). Aligned loads are used only when `m` is
//! 16-byte aligned, aligned stores only when `out` is.
#define ACCUM_FP32_32X2_SSE(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(8, 2, __m128, xmm_sum, _mm_setzero_ps()) \
  if (((uintptr_t)m & 0xf) == 0) { \
    for (const float *qe = q + (dim << 1); q != qe; m += 32, q += 2) { \
      MATRIX_FP32_ITER_32X2_SSE(m, q, xmm_sum, _mm_load_ps, \
                                ACCUM_FP32_STEP_SSE) \
    } \
  } else { \
    for (const float *qe = q + (dim << 1); q != qe; m += 32, q += 2) { \
      MATRIX_FP32_ITER_32X2_SSE(m, q, xmm_sum, _mm_loadu_ps, \
                                ACCUM_FP32_STEP_SSE) \
    } \
  } \
  if (((uintptr_t)out & 0xf) == 0) { \
    MATRIX_VAR_STORE(8, 2, 4, xmm_sum, out, _mm_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(8, 2, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \
  }
//! Compute the distance between matrix and query (FP32, M=32, N=4)
//! Loops `dim` times, advancing `m` by 32 floats and `q` by 4 per step (the
//! caller's pointers are modified). Aligned loads are used only when `m` is
//! 16-byte aligned, aligned stores only when `out` is.
#define ACCUM_FP32_32X4_SSE(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(8, 4, __m128, xmm_sum, _mm_setzero_ps()) \
  if (((uintptr_t)m & 0xf) == 0) { \
    for (const float *qe = q + (dim << 2); q != qe; m += 32, q += 4) { \
      MATRIX_FP32_ITER_32X4_SSE(m, q, xmm_sum, _mm_load_ps, \
                                ACCUM_FP32_STEP_SSE) \
    } \
  } else { \
    for (const float *qe = q + (dim << 2); q != qe; m += 32, q += 4) { \
      MATRIX_FP32_ITER_32X4_SSE(m, q, xmm_sum, _mm_loadu_ps, \
                                ACCUM_FP32_STEP_SSE) \
    } \
  } \
  if (((uintptr_t)out & 0xf) == 0) { \
    MATRIX_VAR_STORE(8, 4, 4, xmm_sum, out, _mm_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(8, 4, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \
  }

//! Compute the distance between matrix and query (FP32, M=32, N=8)
//! Loops `dim` times, advancing `m` by 32 floats and `q` by 8 per step (the
//! caller's pointers are modified). Aligned loads are used only when `m` is
//! 16-byte aligned, aligned stores only when `out` is.
#define ACCUM_FP32_32X8_SSE(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(8, 8, __m128, xmm_sum, _mm_setzero_ps()) \
  if (((uintptr_t)m & 0xf) == 0) { \
    for (const float *qe = q + (dim << 3); q != qe; m += 32, q += 8) { \
      MATRIX_FP32_ITER_32X8_SSE(m, q, xmm_sum, _mm_load_ps, \
                                ACCUM_FP32_STEP_SSE) \
    } \
  } else { \
    for (const float *qe = q + (dim << 3); q != qe; m += 32, q += 8) { \
      MATRIX_FP32_ITER_32X8_SSE(m, q, xmm_sum, _mm_loadu_ps, \
                                ACCUM_FP32_STEP_SSE) \
    } \
  } \
  if (((uintptr_t)out & 0xf) == 0) { \
    MATRIX_VAR_STORE(8, 8, 4, xmm_sum, out, _mm_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(8, 8, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \
  }
//! Compute the distance between matrix and query (FP32, M=32, N=16)
//! Loops `dim` times, advancing `m` by 32 floats and `q` by 16 per step (the
//! caller's pointers are modified). Aligned loads are used only when `m` is
//! 16-byte aligned, aligned stores only when `out` is.
#define ACCUM_FP32_32X16_SSE(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(8, 16, __m128, xmm_sum, _mm_setzero_ps()) \
  if (((uintptr_t)m & 0xf) == 0) { \
    for (const float *qe = q + (dim << 4); q != qe; m += 32, q += 16) { \
      MATRIX_FP32_ITER_32X16_SSE(m, q, xmm_sum, _mm_load_ps, \
                                 ACCUM_FP32_STEP_SSE) \
    } \
  } else { \
    for (const float *qe = q + (dim << 4); q != qe; m += 32, q += 16) { \
      MATRIX_FP32_ITER_32X16_SSE(m, q, xmm_sum, _mm_loadu_ps, \
                                 ACCUM_FP32_STEP_SSE) \
    } \
  } \
  if (((uintptr_t)out & 0xf) == 0) { \
    MATRIX_VAR_STORE(8, 16, 4, xmm_sum, out, _mm_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(8, 16, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \
  }

//! Compute the distance between matrix and query (FP32, M=32, N=32)
//! Loops `dim` times, advancing `m` by 32 floats and `q` by 32 per step (the
//! caller's pointers are modified). Aligned loads are used only when `m` is
//! 16-byte aligned, aligned stores only when `out` is.
#define ACCUM_FP32_32X32_SSE(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(8, 32, __m128, xmm_sum, _mm_setzero_ps()) \
  if (((uintptr_t)m & 0xf) == 0) { \
    for (const float *qe = q + (dim << 5); q != qe; m += 32, q += 32) { \
      MATRIX_FP32_ITER_32X32_SSE(m, q, xmm_sum, _mm_load_ps, \
                                 ACCUM_FP32_STEP_SSE) \
    } \
  } else { \
    for (const float *qe = q + (dim << 5); q != qe; m += 32, q += 32) { \
      MATRIX_FP32_ITER_32X32_SSE(m, q, xmm_sum, _mm_loadu_ps, \
                                 ACCUM_FP32_STEP_SSE) \
    } \
  } \
  if (((uintptr_t)out & 0xf) == 0) { \
    MATRIX_VAR_STORE(8, 32, 4, xmm_sum, out, _mm_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(8, 32, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \
  }
//! Compute the distance between matrix and query (FP32, M=2, N=1)
//! The main AVX loop consumes 4 dims per iteration (`m` += 8, `q` += 4).
//! The two 128-bit halves are then folded, and the remaining dims are handled
//! with SSE: first an optional 2-dim step, then an optional 1-dim step. The
//! macro writes 2 floats to `out` via `_mm_storel_pi`. It modifies the
//! caller's `m`/`q` and declares `qe`, `qe_aligned` and `xmm_sum_0_0` in the
//! enclosing scope, so expand it at most once per scope.
//! The tail test compares a pointer difference rather than `qe_aligned + 2`.
//! The latter could point more than one past the end of `q` (e.g. when
//! `dim % 4 == 0`), which is undefined behavior even if only compared.
#define ACCUM_FP32_2X1_AVX(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(1, 1, __m256, ymm_sum, _mm256_setzero_ps()) \
  const float *qe_aligned = q + ((dim >> 2) << 2); \
  const float *qe = q + dim; \
  if (((uintptr_t)m & 0x1f) == 0) { \
    for (; q != qe_aligned; m += 8, q += 4) { \
      MATRIX_FP32_ITER_2X1_AVX(m, q, ymm_sum, _mm256_load_ps, \
                               ACCUM_FP32_STEP_AVX) \
    } \
  } else { \
    for (; q != qe_aligned; m += 8, q += 4) { \
      MATRIX_FP32_ITER_2X1_AVX(m, q, ymm_sum, _mm256_loadu_ps, \
                               ACCUM_FP32_STEP_AVX) \
    } \
  } \
  __m128 xmm_sum_0_0 = _mm_add_ps(_mm256_castps256_ps128(ymm_sum_0_0), \
                                  _mm256_extractf128_ps(ymm_sum_0_0, 1)); \
  if (qe - qe_aligned >= 2) { \
    __m128 xmm_m = _mm_loadu_ps(m); \
    __m128 xmm_q = _mm_set_ps(q[1], q[1], q[0], q[0]); \
    ACCUM_FP32_STEP_SSE(xmm_m, xmm_q, xmm_sum_0_0) \
    m += 4; \
    q += 2; \
  } \
  xmm_sum_0_0 = \
      _mm_add_ps(xmm_sum_0_0, _mm_movehl_ps(xmm_sum_0_0, xmm_sum_0_0)); \
  if (q != qe) { \
    __m128 xmm_m = _mm_set_ps(0.0f, 0.0f, m[1], m[0]); \
    __m128 xmm_q = _mm_broadcast_ss(q); \
    ACCUM_FP32_STEP_SSE(xmm_m, xmm_q, xmm_sum_0_0) \
  } \
  _mm_storel_pi((__m64 *)out, _NORM(xmm_sum_0_0));
//! Compute the distance between matrix and query (FP32, M=2, N=2)
//! The main AVX loop consumes 4 dims per iteration (`m` and `q` each += 8).
//! It uses aligned loads only when both `m` and `q` are 32-byte aligned. The
//! two accumulators are then reduced to 128 bits, an optional 2-dim SSE step
//! and an optional 1-dim step handle the remainder, and 4 floats are stored
//! (aligned store only when `out` is 16-byte aligned). It modifies the
//! caller's `m`/`q` and declares `qe`, `qe_aligned`, `xmm_sum_0_0` and
//! `xmm_sum_0_1` in the enclosing scope.
//! The tail test compares a pointer difference rather than `qe_aligned + 4`.
//! The latter could point more than one past the end of `q`, which is
//! undefined behavior even if only compared.
#define ACCUM_FP32_2X2_AVX(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(1, 2, __m256, ymm_sum, _mm256_setzero_ps()) \
  const float *qe_aligned = q + ((dim >> 2) << 3); \
  const float *qe = q + (dim << 1); \
  if (((uintptr_t)m & 0x1f) == 0 && ((uintptr_t)q & 0x1f) == 0) { \
    for (; q != qe_aligned; m += 8, q += 8) { \
      MATRIX_FP32_ITER_2X2_AVX(m, q, ymm_sum, _mm256_load_ps, \
                               ACCUM_FP32_STEP_AVX) \
    } \
  } else { \
    for (; q != qe_aligned; m += 8, q += 8) { \
      MATRIX_FP32_ITER_2X2_AVX(m, q, ymm_sum, _mm256_loadu_ps, \
                               ACCUM_FP32_STEP_AVX) \
    } \
  } \
  __m128 xmm_sum_0_0 = _mm_add_ps(_mm256_castps256_ps128(ymm_sum_0_0), \
                                  _mm256_extractf128_ps(ymm_sum_0_0, 1)); \
  __m128 xmm_sum_0_1 = _mm_add_ps(_mm256_castps256_ps128(ymm_sum_0_1), \
                                  _mm256_extractf128_ps(ymm_sum_0_1, 1)); \
  if (qe - qe_aligned >= 4) { \
    __m128 xmm_q = _mm_loadu_ps(q); \
    __m128 xmm_m = _mm_loadu_ps(m); \
    __m128 xmm_p = _mm_permute_ps(xmm_q, _MM_SHUFFLE(2, 2, 0, 0)); \
    ACCUM_FP32_STEP_SSE(xmm_m, xmm_p, xmm_sum_0_0) \
    xmm_p = _mm_permute_ps(xmm_q, _MM_SHUFFLE(3, 3, 1, 1)); \
    ACCUM_FP32_STEP_SSE(xmm_m, xmm_p, xmm_sum_0_1) \
    m += 4; \
    q += 4; \
  } \
  xmm_sum_0_0 = _mm_add_ps(_mm_movelh_ps(xmm_sum_0_0, xmm_sum_0_1), \
                           _mm_movehl_ps(xmm_sum_0_1, xmm_sum_0_0)); \
  if (q != qe) { \
    __m128 xmm_m = _mm_set_ps(m[1], m[0], m[1], m[0]); \
    __m128 xmm_q = _mm_set_ps(q[1], q[1], q[0], q[0]); \
    ACCUM_FP32_STEP_SSE(xmm_m, xmm_q, xmm_sum_0_0) \
  } \
  if (((uintptr_t)out & 0xf) == 0) { \
    MATRIX_VAR_STORE(1, 1, 4, xmm_sum, out, _mm_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(1, 1, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \
  }
//! Compute the distance between matrix and query (FP32, M=4, N=1)
//! The main AVX loop consumes 2 dims per iteration (`m` += 8, `q` += 2). The
//! two 128-bit halves are then added, and an odd trailing dim is handled with
//! SSE. 4 floats are stored (aligned store only when `out` is 16-byte
//! aligned). It modifies the caller's `m`/`q` and declares `qe` and
//! `xmm_sum_0_0` in the enclosing scope.
#define ACCUM_FP32_4X1_AVX(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(1, 1, __m256, ymm_sum, _mm256_setzero_ps()) \
  const float *qe = q + dim; \
  if (((uintptr_t)m & 0x1f) == 0) { \
    for (const float *qe_aligned = q + ((dim >> 1) << 1); q != qe_aligned; \
         m += 8, q += 2) { \
      MATRIX_FP32_ITER_4X1_AVX(m, q, ymm_sum, _mm256_load_ps, \
                               ACCUM_FP32_STEP_AVX) \
    } \
  } else { \
    for (const float *qe_aligned = q + ((dim >> 1) << 1); q != qe_aligned; \
         m += 8, q += 2) { \
      MATRIX_FP32_ITER_4X1_AVX(m, q, ymm_sum, _mm256_loadu_ps, \
                               ACCUM_FP32_STEP_AVX) \
    } \
  } \
  __m128 xmm_sum_0_0 = _mm_add_ps(_mm256_castps256_ps128(ymm_sum_0_0), \
                                  _mm256_extractf128_ps(ymm_sum_0_0, 1)); \
  if (q != qe) { \
    __m128 xmm_m = _mm_loadu_ps(m); \
    __m128 xmm_q = _mm_broadcast_ss(q); \
    ACCUM_FP32_STEP_SSE(xmm_m, xmm_q, xmm_sum_0_0) \
  } \
  if (((uintptr_t)out & 0xf) == 0) { \
    MATRIX_VAR_STORE(1, 1, 4, xmm_sum, out, _mm_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(1, 1, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \
  }
//! Compute the distance between matrix and query (FP32, M=4, N=2)
//! The main AVX loop consumes 2 dims per iteration (`m` += 8, `q` += 4). Each
//! accumulator's 128-bit halves are then added, and an odd trailing dim is
//! handled with SSE broadcasts of `q[0]` and `q[1]`. The aligned store is used
//! only when `out` is 16-byte aligned. It modifies the caller's `m`/`q` and
//! declares `qe`, `xmm_sum_0_0` and `xmm_sum_0_1` in the enclosing scope.
#define ACCUM_FP32_4X2_AVX(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(1, 2, __m256, ymm_sum, _mm256_setzero_ps()) \
  const float *qe = q + (dim << 1); \
  if (((uintptr_t)m & 0x1f) == 0) { \
    for (const float *qe_aligned = q + ((dim >> 1) << 2); q != qe_aligned; \
         m += 8, q += 4) { \
      MATRIX_FP32_ITER_4X2_AVX(m, q, ymm_sum, _mm256_load_ps, \
                               ACCUM_FP32_STEP_AVX) \
    } \
  } else { \
    for (const float *qe_aligned = q + ((dim >> 1) << 2); q != qe_aligned; \
         m += 8, q += 4) { \
      MATRIX_FP32_ITER_4X2_AVX(m, q, ymm_sum, _mm256_loadu_ps, \
                               ACCUM_FP32_STEP_AVX) \
    } \
  } \
  __m128 xmm_sum_0_0 = _mm_add_ps(_mm256_castps256_ps128(ymm_sum_0_0), \
                                  _mm256_extractf128_ps(ymm_sum_0_0, 1)); \
  __m128 xmm_sum_0_1 = _mm_add_ps(_mm256_castps256_ps128(ymm_sum_0_1), \
                                  _mm256_extractf128_ps(ymm_sum_0_1, 1)); \
  if (q != qe) { \
    __m128 xmm_m = _mm_loadu_ps(m); \
    __m128 xmm_q = _mm_broadcast_ss(q); \
    ACCUM_FP32_STEP_SSE(xmm_m, xmm_q, xmm_sum_0_0) \
    xmm_q = _mm_broadcast_ss(q + 1); \
    ACCUM_FP32_STEP_SSE(xmm_m, xmm_q, xmm_sum_0_1) \
  } \
  if (((uintptr_t)out & 0xf) == 0) { \
    MATRIX_VAR_STORE(1, 2, 4, xmm_sum, out, _mm_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(1, 2, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \
  }
//! Compute the distance between matrix and query (FP32, M=4, N=4)
//! The main AVX loop consumes 2 dims per iteration (`m` and `q` each += 8).
//! It uses aligned loads only when both `m` and `q` are 32-byte aligned. Each
//! of the 4 accumulators' 128-bit halves are then added, and an odd trailing
//! dim is handled with SSE broadcasts of `q[0..3]`. The aligned store is used
//! only when `out` is 16-byte aligned. It modifies the caller's `m`/`q` and
//! declares `qe` and `xmm_sum_0_0..3` in the enclosing scope.
#define ACCUM_FP32_4X4_AVX(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(1, 4, __m256, ymm_sum, _mm256_setzero_ps()) \
  const float *qe = q + (dim << 2); \
  if (((uintptr_t)m & 0x1f) == 0 && ((uintptr_t)q & 0x1f) == 0) { \
    for (const float *qe_aligned = q + ((dim >> 1) << 3); q != qe_aligned; \
         m += 8, q += 8) { \
      MATRIX_FP32_ITER_4X4_AVX(m, q, ymm_sum, _mm256_load_ps, \
                               ACCUM_FP32_STEP_AVX) \
    } \
  } else { \
    for (const float *qe_aligned = q + ((dim >> 1) << 3); q != qe_aligned; \
         m += 8, q += 8) { \
      MATRIX_FP32_ITER_4X4_AVX(m, q, ymm_sum, _mm256_loadu_ps, \
                               ACCUM_FP32_STEP_AVX) \
    } \
  } \
  __m128 xmm_sum_0_0 = _mm_add_ps(_mm256_castps256_ps128(ymm_sum_0_0), \
                                  _mm256_extractf128_ps(ymm_sum_0_0, 1)); \
  __m128 xmm_sum_0_1 = _mm_add_ps(_mm256_castps256_ps128(ymm_sum_0_1), \
                                  _mm256_extractf128_ps(ymm_sum_0_1, 1)); \
  __m128 xmm_sum_0_2 = _mm_add_ps(_mm256_castps256_ps128(ymm_sum_0_2), \
                                  _mm256_extractf128_ps(ymm_sum_0_2, 1)); \
  __m128 xmm_sum_0_3 = _mm_add_ps(_mm256_castps256_ps128(ymm_sum_0_3), \
                                  _mm256_extractf128_ps(ymm_sum_0_3, 1)); \
  if (q != qe) { \
    __m128 xmm_m = _mm_loadu_ps(m); \
    __m128 xmm_q = _mm_broadcast_ss(q); \
    ACCUM_FP32_STEP_SSE(xmm_m, xmm_q, xmm_sum_0_0) \
    xmm_q = _mm_broadcast_ss(q + 1); \
    ACCUM_FP32_STEP_SSE(xmm_m, xmm_q, xmm_sum_0_1) \
    xmm_q = _mm_broadcast_ss(q + 2); \
    ACCUM_FP32_STEP_SSE(xmm_m, xmm_q, xmm_sum_0_2) \
    xmm_q = _mm_broadcast_ss(q + 3); \
    ACCUM_FP32_STEP_SSE(xmm_m, xmm_q, xmm_sum_0_3) \
  } \
  if (((uintptr_t)out & 0xf) == 0) { \
    MATRIX_VAR_STORE(1, 4, 4, xmm_sum, out, _mm_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(1, 4, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \
  }
//! Compute the distance between matrix and query (FP32, M=8, N=1)
//! Loops `dim` times, advancing `m` by 8 floats and `q` by 1 per step (the
//! caller's pointers are modified). Aligned loads are used only when `m` is
//! 32-byte aligned, aligned stores only when `out` is.
#define ACCUM_FP32_8X1_AVX(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(1, 1, __m256, ymm_sum, _mm256_setzero_ps()) \
  if (((uintptr_t)m & 0x1f) == 0) { \
    for (const float *qe = q + dim; q != qe; m += 8, ++q) { \
      MATRIX_FP32_ITER_8X1_AVX(m, q, ymm_sum, _mm256_load_ps, \
                               ACCUM_FP32_STEP_AVX) \
    } \
  } else { \
    for (const float *qe = q + dim; q != qe; m += 8, ++q) { \
      MATRIX_FP32_ITER_8X1_AVX(m, q, ymm_sum, _mm256_loadu_ps, \
                               ACCUM_FP32_STEP_AVX) \
    } \
  } \
  if (((uintptr_t)out & 0x1f) == 0) { \
    MATRIX_VAR_STORE(1, 1, 8, ymm_sum, out, _mm256_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(1, 1, 8, ymm_sum, out, _mm256_storeu_ps, _NORM) \
  }

//! Compute the distance between matrix and query (FP32, M=8, N=2)
//! Loops `dim` times, advancing `m` by 8 floats and `q` by 2 per step (the
//! caller's pointers are modified). Aligned loads are used only when `m` is
//! 32-byte aligned, aligned stores only when `out` is.
#define ACCUM_FP32_8X2_AVX(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(1, 2, __m256, ymm_sum, _mm256_setzero_ps()) \
  if (((uintptr_t)m & 0x1f) == 0) { \
    for (const float *qe = q + (dim << 1); q != qe; m += 8, q += 2) { \
      MATRIX_FP32_ITER_8X2_AVX(m, q, ymm_sum, _mm256_load_ps, \
                               ACCUM_FP32_STEP_AVX) \
    } \
  } else { \
    for (const float *qe = q + (dim << 1); q != qe; m += 8, q += 2) { \
      MATRIX_FP32_ITER_8X2_AVX(m, q, ymm_sum, _mm256_loadu_ps, \
                               ACCUM_FP32_STEP_AVX) \
    } \
  } \
  if (((uintptr_t)out & 0x1f) == 0) { \
    MATRIX_VAR_STORE(1, 2, 8, ymm_sum, out, _mm256_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(1, 2, 8, ymm_sum, out, _mm256_storeu_ps, _NORM) \
  }
//! Compute the distance between matrix and query (FP32, M=8, N=4)
//! Loops `dim` times, advancing `m` by 8 floats and `q` by 4 per step (the
//! caller's pointers are modified). Aligned loads are used only when `m` is
//! 32-byte aligned, aligned stores only when `out` is.
#define ACCUM_FP32_8X4_AVX(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(1, 4, __m256, ymm_sum, _mm256_setzero_ps()) \
  if (((uintptr_t)m & 0x1f) == 0) { \
    for (const float *qe = q + (dim << 2); q != qe; m += 8, q += 4) { \
      MATRIX_FP32_ITER_8X4_AVX(m, q, ymm_sum, _mm256_load_ps, \
                               ACCUM_FP32_STEP_AVX) \
    } \
  } else { \
    for (const float *qe = q + (dim << 2); q != qe; m += 8, q += 4) { \
      MATRIX_FP32_ITER_8X4_AVX(m, q, ymm_sum, _mm256_loadu_ps, \
                               ACCUM_FP32_STEP_AVX) \
    } \
  } \
  if (((uintptr_t)out & 0x1f) == 0) { \
    MATRIX_VAR_STORE(1, 4, 8, ymm_sum, out, _mm256_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(1, 4, 8, ymm_sum, out, _mm256_storeu_ps, _NORM) \
  }

//! Compute the distance between matrix and query (FP32, M=8, N=8)
//! Loops `dim` times, advancing `m` by 8 floats and `q` by 8 per step (the
//! caller's pointers are modified). Aligned loads are used only when `m` is
//! 32-byte aligned, aligned stores only when `out` is.
#define ACCUM_FP32_8X8_AVX(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(1, 8, __m256, ymm_sum, _mm256_setzero_ps()) \
  if (((uintptr_t)m & 0x1f) == 0) { \
    for (const float *qe = q + (dim << 3); q != qe; m += 8, q += 8) { \
      MATRIX_FP32_ITER_8X8_AVX(m, q, ymm_sum, _mm256_load_ps, \
                               ACCUM_FP32_STEP_AVX) \
    } \
  } else { \
    for (const float *qe = q + (dim << 3); q != qe; m += 8, q += 8) { \
      MATRIX_FP32_ITER_8X8_AVX(m, q, ymm_sum, _mm256_loadu_ps, \
                               ACCUM_FP32_STEP_AVX) \
    } \
  } \
  if (((uintptr_t)out & 0x1f) == 0) { \
    MATRIX_VAR_STORE(1, 8, 8, ymm_sum, out, _mm256_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(1, 8, 8, ymm_sum, out, _mm256_storeu_ps, _NORM) \
  }
//! Compute the distance between matrix and query (FP32, M=16, N=1)
//! Loops `dim` times, advancing `m` by 16 floats and `q` by 1 per step (the
//! caller's pointers are modified). Aligned loads are used only when `m` is
//! 32-byte aligned, aligned stores only when `out` is.
#define ACCUM_FP32_16X1_AVX(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(2, 1, __m256, ymm_sum, _mm256_setzero_ps()) \
  if (((uintptr_t)m & 0x1f) == 0) { \
    for (const float *qe = q + dim; q != qe; m += 16, ++q) { \
      MATRIX_FP32_ITER_16X1_AVX(m, q, ymm_sum, _mm256_load_ps, \
                                ACCUM_FP32_STEP_AVX) \
    } \
  } else { \
    for (const float *qe = q + dim; q != qe; m += 16, ++q) { \
      MATRIX_FP32_ITER_16X1_AVX(m, q, ymm_sum, _mm256_loadu_ps, \
                                ACCUM_FP32_STEP_AVX) \
    } \
  } \
  if (((uintptr_t)out & 0x1f) == 0) { \
    MATRIX_VAR_STORE(2, 1, 8, ymm_sum, out, _mm256_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(2, 1, 8, ymm_sum, out, _mm256_storeu_ps, _NORM) \
  }

//! Compute the distance between matrix and query (FP32, M=16, N=2)
//! Loops `dim` times, advancing `m` by 16 floats and `q` by 2 per step (the
//! caller's pointers are modified). Aligned loads are used only when `m` is
//! 32-byte aligned, aligned stores only when `out` is.
#define ACCUM_FP32_16X2_AVX(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(2, 2, __m256, ymm_sum, _mm256_setzero_ps()) \
  if (((uintptr_t)m & 0x1f) == 0) { \
    for (const float *qe = q + (dim << 1); q != qe; m += 16, q += 2) { \
      MATRIX_FP32_ITER_16X2_AVX(m, q, ymm_sum, _mm256_load_ps, \
                                ACCUM_FP32_STEP_AVX) \
    } \
  } else { \
    for (const float *qe = q + (dim << 1); q != qe; m += 16, q += 2) { \
      MATRIX_FP32_ITER_16X2_AVX(m, q, ymm_sum, _mm256_loadu_ps, \
                                ACCUM_FP32_STEP_AVX) \
    } \
  } \
  if (((uintptr_t)out & 0x1f) == 0) { \
    MATRIX_VAR_STORE(2, 2, 8, ymm_sum, out, _mm256_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(2, 2, 8, ymm_sum, out, _mm256_storeu_ps, _NORM) \
  }
//! Compute the distance between matrix and query (FP32, M=16, N=4)
//! Loops `dim` times, advancing `m` by 16 floats and `q` by 4 per step (the
//! caller's pointers are modified). Aligned loads are used only when `m` is
//! 32-byte aligned, aligned stores only when `out` is.
#define ACCUM_FP32_16X4_AVX(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(2, 4, __m256, ymm_sum, _mm256_setzero_ps()) \
  if (((uintptr_t)m & 0x1f) == 0) { \
    for (const float *qe = q + (dim << 2); q != qe; m += 16, q += 4) { \
      MATRIX_FP32_ITER_16X4_AVX(m, q, ymm_sum, _mm256_load_ps, \
                                ACCUM_FP32_STEP_AVX) \
    } \
  } else { \
    for (const float *qe = q + (dim << 2); q != qe; m += 16, q += 4) { \
      MATRIX_FP32_ITER_16X4_AVX(m, q, ymm_sum, _mm256_loadu_ps, \
                                ACCUM_FP32_STEP_AVX) \
    } \
  } \
  if (((uintptr_t)out & 0x1f) == 0) { \
    MATRIX_VAR_STORE(2, 4, 8, ymm_sum, out, _mm256_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(2, 4, 8, ymm_sum, out, _mm256_storeu_ps, _NORM) \
  }

//! Compute the distance between matrix and query (FP32, M=16, N=8)
//! Loops `dim` times, advancing `m` by 16 floats and `q` by 8 per step (the
//! caller's pointers are modified). Aligned loads are used only when `m` is
//! 32-byte aligned, aligned stores only when `out` is.
#define ACCUM_FP32_16X8_AVX(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(2, 8, __m256, ymm_sum, _mm256_setzero_ps()) \
  if (((uintptr_t)m & 0x1f) == 0) { \
    for (const float *qe = q + (dim << 3); q != qe; m += 16, q += 8) { \
      MATRIX_FP32_ITER_16X8_AVX(m, q, ymm_sum, _mm256_load_ps, \
                                ACCUM_FP32_STEP_AVX) \
    } \
  } else { \
    for (const float *qe = q + (dim << 3); q != qe; m += 16, q += 8) { \
      MATRIX_FP32_ITER_16X8_AVX(m, q, ymm_sum, _mm256_loadu_ps, \
                                ACCUM_FP32_STEP_AVX) \
    } \
  } \
  if (((uintptr_t)out & 0x1f) == 0) { \
    MATRIX_VAR_STORE(2, 8, 8, ymm_sum, out, _mm256_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(2, 8, 8, ymm_sum, out, _mm256_storeu_ps, _NORM) \
  }
//! Compute the distance between matrix and query (FP32, M=16, N=16)
//! Loops `dim` times, advancing `m` by 16 floats and `q` by 16 per step (the
//! caller's pointers are modified). Aligned loads are used only when `m` is
//! 32-byte aligned, aligned stores only when `out` is.
#define ACCUM_FP32_16X16_AVX(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(2, 16, __m256, ymm_sum, _mm256_setzero_ps()) \
  if (((uintptr_t)m & 0x1f) == 0) { \
    for (const float *qe = q + (dim << 4); q != qe; m += 16, q += 16) { \
      MATRIX_FP32_ITER_16X16_AVX(m, q, ymm_sum, _mm256_load_ps, \
                                 ACCUM_FP32_STEP_AVX) \
    } \
  } else { \
    for (const float *qe = q + (dim << 4); q != qe; m += 16, q += 16) { \
      MATRIX_FP32_ITER_16X16_AVX(m, q, ymm_sum, _mm256_loadu_ps, \
                                 ACCUM_FP32_STEP_AVX) \
    } \
  } \
  if (((uintptr_t)out & 0x1f) == 0) { \
    MATRIX_VAR_STORE(2, 16, 8, ymm_sum, out, _mm256_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(2, 16, 8, ymm_sum, out, _mm256_storeu_ps, _NORM) \
  }

//! Compute the distance between matrix and query (FP32, M=32, N=1)
//! Loops `dim` times, advancing `m` by 32 floats and `q` by 1 per step (the
//! caller's pointers are modified). Aligned loads are used only when `m` is
//! 32-byte aligned, aligned stores only when `out` is.
#define ACCUM_FP32_32X1_AVX(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(4, 1, __m256, ymm_sum, _mm256_setzero_ps()) \
  if (((uintptr_t)m & 0x1f) == 0) { \
    for (const float *qe = q + dim; q != qe; m += 32, ++q) { \
      MATRIX_FP32_ITER_32X1_AVX(m, q, ymm_sum, _mm256_load_ps, \
                                ACCUM_FP32_STEP_AVX) \
    } \
  } else { \
    for (const float *qe = q + dim; q != qe; m += 32, ++q) { \
      MATRIX_FP32_ITER_32X1_AVX(m, q, ymm_sum, _mm256_loadu_ps, \
                                ACCUM_FP32_STEP_AVX) \
    } \
  } \
  if (((uintptr_t)out & 0x1f) == 0) { \
    MATRIX_VAR_STORE(4, 1, 8, ymm_sum, out, _mm256_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(4, 1, 8, ymm_sum, out, _mm256_storeu_ps, _NORM) \
  }
//! Compute the distance between matrix and query (FP32, M=32, N=2)
//! Loops `dim` times, advancing `m` by 32 floats and `q` by 2 per step (the
//! caller's pointers are modified). Aligned loads are used only when `m` is
//! 32-byte aligned, aligned stores only when `out` is.
#define ACCUM_FP32_32X2_AVX(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(4, 2, __m256, ymm_sum, _mm256_setzero_ps()) \
  if (((uintptr_t)m & 0x1f) == 0) { \
    for (const float *qe = q + (dim << 1); q != qe; m += 32, q += 2) { \
      MATRIX_FP32_ITER_32X2_AVX(m, q, ymm_sum, _mm256_load_ps, \
                                ACCUM_FP32_STEP_AVX) \
    } \
  } else { \
    for (const float *qe = q + (dim << 1); q != qe; m += 32, q += 2) { \
      MATRIX_FP32_ITER_32X2_AVX(m, q, ymm_sum, _mm256_loadu_ps, \
                                ACCUM_FP32_STEP_AVX) \
    } \
  } \
  if (((uintptr_t)out & 0x1f) == 0) { \
    MATRIX_VAR_STORE(4, 2, 8, ymm_sum, out, _mm256_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(4, 2, 8, ymm_sum, out, _mm256_storeu_ps, _NORM) \
  }

//! Compute the distance between matrix and query (FP32, M=32, N=4)
//! Loops `dim` times, advancing `m` by 32 floats and `q` by 4 per step (the
//! caller's pointers are modified). Aligned loads are used only when `m` is
//! 32-byte aligned, aligned stores only when `out` is.
#define ACCUM_FP32_32X4_AVX(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(4, 4, __m256, ymm_sum, _mm256_setzero_ps()) \
  if (((uintptr_t)m & 0x1f) == 0) { \
    for (const float *qe = q + (dim << 2); q != qe; m += 32, q += 4) { \
      MATRIX_FP32_ITER_32X4_AVX(m, q, ymm_sum, _mm256_load_ps, \
                                ACCUM_FP32_STEP_AVX) \
    } \
  } else { \
    for (const float *qe = q + (dim << 2); q != qe; m += 32, q += 4) { \
      MATRIX_FP32_ITER_32X4_AVX(m, q, ymm_sum, _mm256_loadu_ps, \
                                ACCUM_FP32_STEP_AVX) \
    } \
  } \
  if (((uintptr_t)out & 0x1f) == 0) { \
    MATRIX_VAR_STORE(4, 4, 8, ymm_sum, out, _mm256_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(4, 4, 8, ymm_sum, out, _mm256_storeu_ps, _NORM) \
  }
//! Compute the distance between matrix and query (FP32, M=32, N=8)
//! Loops `dim` times, advancing `m` by 32 floats and `q` by 8 per step (the
//! caller's pointers are modified). Aligned loads are used only when `m` is
//! 32-byte aligned, aligned stores only when `out` is.
#define ACCUM_FP32_32X8_AVX(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(4, 8, __m256, ymm_sum, _mm256_setzero_ps()) \
  if (((uintptr_t)m & 0x1f) == 0) { \
    for (const float *qe = q + (dim << 3); q != qe; m += 32, q += 8) { \
      MATRIX_FP32_ITER_32X8_AVX(m, q, ymm_sum, _mm256_load_ps, \
                                ACCUM_FP32_STEP_AVX) \
    } \
  } else { \
    for (const float *qe = q + (dim << 3); q != qe; m += 32, q += 8) { \
      MATRIX_FP32_ITER_32X8_AVX(m, q, ymm_sum, _mm256_loadu_ps, \
                                ACCUM_FP32_STEP_AVX) \
    } \
  } \
  if (((uintptr_t)out & 0x1f) == 0) { \
    MATRIX_VAR_STORE(4, 8, 8, ymm_sum, out, _mm256_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(4, 8, 8, ymm_sum, out, _mm256_storeu_ps, _NORM) \
  }

//! Compute the distance between matrix and query (FP32, M=32, N=16)
//! Loops `dim` times, advancing `m` by 32 floats and `q` by 16 per step (the
//! caller's pointers are modified). Aligned loads are used only when `m` is
//! 32-byte aligned, aligned stores only when `out` is.
#define ACCUM_FP32_32X16_AVX(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(4, 16, __m256, ymm_sum, _mm256_setzero_ps()) \
  if (((uintptr_t)m & 0x1f) == 0) { \
    for (const float *qe = q + (dim << 4); q != qe; m += 32, q += 16) { \
      MATRIX_FP32_ITER_32X16_AVX(m, q, ymm_sum, _mm256_load_ps, \
                                 ACCUM_FP32_STEP_AVX) \
    } \
  } else { \
    for (const float *qe = q + (dim << 4); q != qe; m += 32, q += 16) { \
      MATRIX_FP32_ITER_32X16_AVX(m, q, ymm_sum, _mm256_loadu_ps, \
                                 ACCUM_FP32_STEP_AVX) \
    } \
  } \
  if (((uintptr_t)out & 0x1f) == 0) { \
    MATRIX_VAR_STORE(4, 16, 8, ymm_sum, out, _mm256_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(4, 16, 8, ymm_sum, out, _mm256_storeu_ps, _NORM) \
  }
//! Compute the distance between matrix and query (FP32, M=32, N=32)
//! Loops `dim` times, advancing `m` by 32 floats and `q` by 32 per step (the
//! caller's pointers are modified). Aligned loads are used only when `m` is
//! 32-byte aligned, aligned stores only when `out` is.
#define ACCUM_FP32_32X32_AVX(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(4, 32, __m256, ymm_sum, _mm256_setzero_ps()) \
  if (((uintptr_t)m & 0x1f) == 0) { \
    for (const float *qe = q + (dim << 5); q != qe; m += 32, q += 32) { \
      MATRIX_FP32_ITER_32X32_AVX(m, q, ymm_sum, _mm256_load_ps, \
                                 ACCUM_FP32_STEP_AVX) \
    } \
  } else { \
    for (const float *qe = q + (dim << 5); q != qe; m += 32, q += 32) { \
      MATRIX_FP32_ITER_32X32_AVX(m, q, ymm_sum, _mm256_loadu_ps, \
                                 ACCUM_FP32_STEP_AVX) \
    } \
  } \
  if (((uintptr_t)out & 0x1f) == 0) { \
    MATRIX_VAR_STORE(4, 32, 8, ymm_sum, out, _mm256_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(4, 32, 8, ymm_sum, out, _mm256_storeu_ps, _NORM) \
  }

//! Compute the distance between matrix and query (FP32, M=16, N=1)
//! Loops `dim` times, advancing `m` by 16 floats and `q` by 1 per step (the
//! caller's pointers are modified). Aligned loads are used only when `m` is
//! 64-byte aligned, aligned stores only when `out` is.
#define ACCUM_FP32_16X1_AVX512(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(1, 1, __m512, zmm_sum, _mm512_setzero_ps()) \
  if (((uintptr_t)m & 0x3f) == 0) { \
    for (const float *qe = q + dim; q != qe; m += 16, ++q) { \
      MATRIX_FP32_ITER_16X1_AVX512(m, q, zmm_sum, _mm512_load_ps, \
                                   ACCUM_FP32_STEP_AVX512) \
    } \
  } else { \
    for (const float *qe = q + dim; q != qe; m += 16, ++q) { \
      MATRIX_FP32_ITER_16X1_AVX512(m, q, zmm_sum, _mm512_loadu_ps, \
                                   ACCUM_FP32_STEP_AVX512) \
    } \
  } \
  if (((uintptr_t)out & 0x3f) == 0) { \
    MATRIX_VAR_STORE(1, 1, 16, zmm_sum, out, _mm512_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(1, 1, 16, zmm_sum, out, _mm512_storeu_ps, _NORM) \
  }
//! Compute the distance between matrix and query (FP32, M=16, N=2)
//! Loops `dim` times, advancing `m` by 16 floats and `q` by 2 per step (the
//! caller's pointers are modified). Aligned loads are used only when `m` is
//! 64-byte aligned, aligned stores only when `out` is.
#define ACCUM_FP32_16X2_AVX512(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(1, 2, __m512, zmm_sum, _mm512_setzero_ps()) \
  if (((uintptr_t)m & 0x3f) == 0) { \
    for (const float *qe = q + (dim << 1); q != qe; m += 16, q += 2) { \
      MATRIX_FP32_ITER_16X2_AVX512(m, q, zmm_sum, _mm512_load_ps, \
                                   ACCUM_FP32_STEP_AVX512) \
    } \
  } else { \
    for (const float *qe = q + (dim << 1); q != qe; m += 16, q += 2) { \
      MATRIX_FP32_ITER_16X2_AVX512(m, q, zmm_sum, _mm512_loadu_ps, \
                                   ACCUM_FP32_STEP_AVX512) \
    } \
  } \
  if (((uintptr_t)out & 0x3f) == 0) { \
    MATRIX_VAR_STORE(1, 2, 16, zmm_sum, out, _mm512_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(1, 2, 16, zmm_sum, out, _mm512_storeu_ps, _NORM) \
  }

//! Compute the distance between matrix and query (FP32, M=16, N=4)
//! Loops `dim` times, advancing `m` by 16 floats and `q` by 4 per step (the
//! caller's pointers are modified). Aligned loads are used only when `m` is
//! 64-byte aligned, aligned stores only when `out` is.
#define ACCUM_FP32_16X4_AVX512(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(1, 4, __m512, zmm_sum, _mm512_setzero_ps()) \
  if (((uintptr_t)m & 0x3f) == 0) { \
    for (const float *qe = q + (dim << 2); q != qe; m += 16, q += 4) { \
      MATRIX_FP32_ITER_16X4_AVX512(m, q, zmm_sum, _mm512_load_ps, \
                                   ACCUM_FP32_STEP_AVX512) \
    } \
  } else { \
    for (const float *qe = q + (dim << 2); q != qe; m += 16, q += 4) { \
      MATRIX_FP32_ITER_16X4_AVX512(m, q, zmm_sum, _mm512_loadu_ps, \
                                   ACCUM_FP32_STEP_AVX512) \
    } \
  } \
  if (((uintptr_t)out & 0x3f) == 0) { \
    MATRIX_VAR_STORE(1, 4, 16, zmm_sum, out, _mm512_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(1, 4, 16, zmm_sum, out, _mm512_storeu_ps, _NORM) \
  }
//! Compute the distance between matrix and query (FP32, M=16, N=8)
//! Loops `dim` times, advancing `m` by 16 floats and `q` by 8 per step (the
//! caller's pointers are modified). Aligned loads are used only when `m` is
//! 64-byte aligned, aligned stores only when `out` is.
#define ACCUM_FP32_16X8_AVX512(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(1, 8, __m512, zmm_sum, _mm512_setzero_ps()) \
  if (((uintptr_t)m & 0x3f) == 0) { \
    for (const float *qe = q + (dim << 3); q != qe; m += 16, q += 8) { \
      MATRIX_FP32_ITER_16X8_AVX512(m, q, zmm_sum, _mm512_load_ps, \
                                   ACCUM_FP32_STEP_AVX512) \
    } \
  } else { \
    for (const float *qe = q + (dim << 3); q != qe; m += 16, q += 8) { \
      MATRIX_FP32_ITER_16X8_AVX512(m, q, zmm_sum, _mm512_loadu_ps, \
                                   ACCUM_FP32_STEP_AVX512) \
    } \
  } \
  if (((uintptr_t)out & 0x3f) == 0) { \
    MATRIX_VAR_STORE(1, 8, 16, zmm_sum, out, _mm512_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(1, 8, 16, zmm_sum, out, _mm512_storeu_ps, _NORM) \
  }

//! Compute the distance between matrix and query (FP32, M=16, N=16)
//! Loops `dim` times, advancing `m` by 16 floats and `q` by 16 per step (the
//! caller's pointers are modified). Aligned loads are used only when `m` is
//! 64-byte aligned, aligned stores only when `out` is.
#define ACCUM_FP32_16X16_AVX512(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(1, 16, __m512, zmm_sum, _mm512_setzero_ps()) \
  if (((uintptr_t)m & 0x3f) == 0) { \
    for (const float *qe = q + (dim << 4); q != qe; m += 16, q += 16) { \
      MATRIX_FP32_ITER_16X16_AVX512(m, q, zmm_sum, _mm512_load_ps, \
                                    ACCUM_FP32_STEP_AVX512) \
    } \
  } else { \
    for (const float *qe = q + (dim << 4); q != qe; m += 16, q += 16) { \
      MATRIX_FP32_ITER_16X16_AVX512(m, q, zmm_sum, _mm512_loadu_ps, \
                                    ACCUM_FP32_STEP_AVX512) \
    } \
  } \
  if (((uintptr_t)out & 0x3f) == 0) { \
    MATRIX_VAR_STORE(1, 16, 16, zmm_sum, out, _mm512_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(1, 16, 16, zmm_sum, out, _mm512_storeu_ps, _NORM) \
  }
//! Compute the distance between matrix and query (FP32, M=32, N=1)
//! Loops `dim` times, advancing `m` by 32 floats and `q` by 1 per step (the
//! caller's pointers are modified). Aligned loads are used only when `m` is
//! 64-byte aligned, aligned stores only when `out` is.
#define ACCUM_FP32_32X1_AVX512(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(2, 1, __m512, zmm_sum, _mm512_setzero_ps()) \
  if (((uintptr_t)m & 0x3f) == 0) { \
    for (const float *qe = q + dim; q != qe; m += 32, ++q) { \
      MATRIX_FP32_ITER_32X1_AVX512(m, q, zmm_sum, _mm512_load_ps, \
                                   ACCUM_FP32_STEP_AVX512) \
    } \
  } else { \
    for (const float *qe = q + dim; q != qe; m += 32, ++q) { \
      MATRIX_FP32_ITER_32X1_AVX512(m, q, zmm_sum, _mm512_loadu_ps, \
                                   ACCUM_FP32_STEP_AVX512) \
    } \
  } \
  if (((uintptr_t)out & 0x3f) == 0) { \
    MATRIX_VAR_STORE(2, 1, 16, zmm_sum, out, _mm512_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(2, 1, 16, zmm_sum, out, _mm512_storeu_ps, _NORM) \
  }

//! Compute the distance between matrix and query (FP32, M=32, N=2)
//! Loops `dim` times, advancing `m` by 32 floats and `q` by 2 per step (the
//! caller's pointers are modified). Aligned loads are used only when `m` is
//! 64-byte aligned, aligned stores only when `out` is.
#define ACCUM_FP32_32X2_AVX512(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(2, 2, __m512, zmm_sum, _mm512_setzero_ps()) \
  if (((uintptr_t)m & 0x3f) == 0) { \
    for (const float *qe = q + (dim << 1); q != qe; m += 32, q += 2) { \
      MATRIX_FP32_ITER_32X2_AVX512(m, q, zmm_sum, _mm512_load_ps, \
                                   ACCUM_FP32_STEP_AVX512) \
    } \
  } else { \
    for (const float *qe = q + (dim << 1); q != qe; m += 32, q += 2) { \
      MATRIX_FP32_ITER_32X2_AVX512(m, q, zmm_sum, _mm512_loadu_ps, \
                                   ACCUM_FP32_STEP_AVX512) \
    } \
  } \
  if (((uintptr_t)out & 0x3f) == 0) { \
    MATRIX_VAR_STORE(2, 2, 16, zmm_sum, out, _mm512_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(2, 2, 16, zmm_sum, out, _mm512_storeu_ps, _NORM) \
  }
//! Compute the distance between matrix and query (FP32, M=32, N=4)
//! Loops `dim` times, advancing `m` by 32 floats and `q` by 4 per step (the
//! caller's pointers are modified). Aligned loads are used only when `m` is
//! 64-byte aligned, aligned stores only when `out` is.
#define ACCUM_FP32_32X4_AVX512(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(2, 4, __m512, zmm_sum, _mm512_setzero_ps()) \
  if (((uintptr_t)m & 0x3f) == 0) { \
    for (const float *qe = q + (dim << 2); q != qe; m += 32, q += 4) { \
      MATRIX_FP32_ITER_32X4_AVX512(m, q, zmm_sum, _mm512_load_ps, \
                                   ACCUM_FP32_STEP_AVX512) \
    } \
  } else { \
    for (const float *qe = q + (dim << 2); q != qe; m += 32, q += 4) { \
      MATRIX_FP32_ITER_32X4_AVX512(m, q, zmm_sum, _mm512_loadu_ps, \
                                   ACCUM_FP32_STEP_AVX512) \
    } \
  } \
  if (((uintptr_t)out & 0x3f) == 0) { \
    MATRIX_VAR_STORE(2, 4, 16, zmm_sum, out, _mm512_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(2, 4, 16, zmm_sum, out, _mm512_storeu_ps, _NORM) \
  }

//! Compute the distance between matrix and query (FP32, M=32, N=8)
//! Loops `dim` times, advancing `m` by 32 floats and `q` by 8 per step (the
//! caller's pointers are modified). Aligned loads are used only when `m` is
//! 64-byte aligned, aligned stores only when `out` is.
#define ACCUM_FP32_32X8_AVX512(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(2, 8, __m512, zmm_sum, _mm512_setzero_ps()) \
  if (((uintptr_t)m & 0x3f) == 0) { \
    for (const float *qe = q + (dim << 3); q != qe; m += 32, q += 8) { \
      MATRIX_FP32_ITER_32X8_AVX512(m, q, zmm_sum, _mm512_load_ps, \
                                   ACCUM_FP32_STEP_AVX512) \
    } \
  } else { \
    for (const float *qe = q + (dim << 3); q != qe; m += 32, q += 8) { \
      MATRIX_FP32_ITER_32X8_AVX512(m, q, zmm_sum, _mm512_loadu_ps, \
                                   ACCUM_FP32_STEP_AVX512) \
    } \
  } \
  if (((uintptr_t)out & 0x3f) == 0) { \
    MATRIX_VAR_STORE(2, 8, 16, zmm_sum, out, _mm512_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(2, 8, 16, zmm_sum, out, _mm512_storeu_ps, _NORM) \
  }
//! Compute the distance between matrix and query (FP32, M=32, N=16)
//! Loops `dim` times, advancing `m` by 32 floats and `q` by 16 per step (the
//! caller's pointers are modified). Aligned loads are used only when `m` is
//! 64-byte aligned, aligned stores only when `out` is.
#define ACCUM_FP32_32X16_AVX512(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(2, 16, __m512, zmm_sum, _mm512_setzero_ps()) \
  if (((uintptr_t)m & 0x3f) == 0) { \
    for (const float *qe = q + (dim << 4); q != qe; m += 32, q += 16) { \
      MATRIX_FP32_ITER_32X16_AVX512(m, q, zmm_sum, _mm512_load_ps, \
                                    ACCUM_FP32_STEP_AVX512) \
    } \
  } else { \
    for (const float *qe = q + (dim << 4); q != qe; m += 32, q += 16) { \
      MATRIX_FP32_ITER_32X16_AVX512(m, q, zmm_sum, _mm512_loadu_ps, \
                                    ACCUM_FP32_STEP_AVX512) \
    } \
  } \
  if (((uintptr_t)out & 0x3f) == 0) { \
    MATRIX_VAR_STORE(2, 16, 16, zmm_sum, out, _mm512_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(2, 16, 16, zmm_sum, out, _mm512_storeu_ps, _NORM) \
  }

//! Compute the distance between matrix and query (FP32, M=32, N=32)
//! Loops `dim` times, advancing `m` by 32 floats and `q` by 32 per step (the
//! caller's pointers are modified). Aligned loads are used only when `m` is
//! 64-byte aligned, aligned stores only when `out` is.
#define ACCUM_FP32_32X32_AVX512(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(2, 32, __m512, zmm_sum, _mm512_setzero_ps()) \
  if (((uintptr_t)m & 0x3f) == 0) { \
    for (const float *qe = q + (dim << 5); q != qe; m += 32, q += 32) { \
      MATRIX_FP32_ITER_32X32_AVX512(m, q, zmm_sum, _mm512_load_ps, \
                                    ACCUM_FP32_STEP_AVX512) \
    } \
  } else { \
    for (const float *qe = q + (dim << 5); q != qe; m += 32, q += 32) { \
      MATRIX_FP32_ITER_32X32_AVX512(m, q, zmm_sum, _mm512_loadu_ps, \
                                    ACCUM_FP32_STEP_AVX512) \
    } \
  } \
  if (((uintptr_t)out & 0x3f) == 0) { \
    MATRIX_VAR_STORE(2, 32, 16, zmm_sum, out, _mm512_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(2, 32, 16, zmm_sum, out, _mm512_storeu_ps, _NORM) \
  }
//! Compute the distance between matrix and query (FP32, M=2, N=1) on A64
//! The main loop consumes 2 dims per iteration (`m` += 4, `q` += 2). The upper
//! 64 bits of `v_sum` are then folded onto the lower half. An odd trailing dim
//! loads `m[0..1]` as one 64-bit lane, and 2 floats are written to `out` with
//! `vst1_f32`. It modifies the caller's `m`/`q` and declares `v_sum`,
//! `qe_aligned` and `qe` in the enclosing scope.
#define ACCUM_FP32_2X1_NEON_A64(m, q, dim, out, _NORM) \
  float32x4_t v_sum = vdupq_n_f32(0); \
  const float *qe_aligned = q + ((dim >> 1) << 1); \
  const float *qe = q + dim; \
  for (; q != qe_aligned; m += 4, q += 2) { \
    MATRIX_FP32_ITER_2X1_NEON(m, q, v_sum, ACCUM_FP32_STEP_NEON) \
  } \
  v_sum = vaddq_f32( \
      vreinterpretq_f32_u64(vdupq_laneq_u64(vreinterpretq_u64_f32(v_sum), 1)), \
      v_sum); \
  if (q != qe) { \
    float32x4_t v_m = vreinterpretq_f32_u64( \
        vdupq_lane_u64(vld1_u64((const uint64_t *)m), 0)); \
    float32x4_t v_q = vld1q_dup_f32(q); \
    ACCUM_FP32_STEP_NEON(v_m, v_q, v_sum) \
  } \
  vst1_f32(out, _NORM(vget_low_f32(v_sum)));

//! Compute the distance between matrix and query (FP32, M=2, N=1) on A32
//! Same flow as the A64 variant, but it folds the halves with
//! `vget_low_f32`/`vget_high_f32` + `vcombine_f32` instead of
//! `vdupq_laneq_u64`, which is only available on AArch64.
#define ACCUM_FP32_2X1_NEON_A32(m, q, dim, out, _NORM) \
  float32x4_t v_sum = vdupq_n_f32(0); \
  const float *qe_aligned = q + ((dim >> 1) << 1); \
  const float *qe = q + dim; \
  for (; q != qe_aligned; m += 4, q += 2) { \
    MATRIX_FP32_ITER_2X1_NEON(m, q, v_sum, ACCUM_FP32_STEP_NEON) \
  } \
  float32x2_t sum = vadd_f32(vget_low_f32(v_sum), vget_high_f32(v_sum)); \
  v_sum = vcombine_f32(sum, sum); \
  if (q != qe) { \
    float32x4_t v_m = vreinterpretq_f32_u64( \
        vdupq_lane_u64(vld1_u64((const uint64_t *)m), 0)); \
    float32x4_t v_q = vld1q_dup_f32(q); \
    ACCUM_FP32_STEP_NEON(v_m, v_q, v_sum) \
  } \
  vst1_f32(out, _NORM(vget_low_f32(v_sum)));
//! Compute the distance between matrix and query (FP32, M=2, N=2)
//!
//! The main loop consumes two dimensions per iteration (4 floats of `m`,
//! 4 floats of `q`). The low and high halves of the two accumulators are
//! then folded so that `v_sum_0_0` holds all four results, and an odd
//! trailing dimension, if any, is added into it. `m` and `q` are advanced
//! in place. The locals `qe_aligned`/`qe` (and the accumulators from
//! MATRIX_VAR_INIT) are declared in the enclosing scope.
#define ACCUM_FP32_2X2_NEON(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(1, 2, float32x4_t, v_sum, vdupq_n_f32(0)) \
  const float *qe_aligned = q + ((dim >> 1) << 2); \
  const float *qe = q + (dim << 1); \
  for (; q != qe_aligned; m += 4, q += 4) { \
    MATRIX_FP32_ITER_2X2_NEON(m, q, v_sum, ACCUM_FP32_STEP_NEON) \
  } \
  v_sum_0_0 = vaddq_f32( \
      vcombine_f32(vget_low_f32(v_sum_0_0), vget_low_f32(v_sum_0_1)), \
      vcombine_f32(vget_high_f32(v_sum_0_0), vget_high_f32(v_sum_0_1))); \
  if (q != qe) { \
    /* m lanes: [m0, m1, m0, m1]; q lanes: [q0, q0, q1, q1]. */ \
    float32x2_t v_m_0 = vld1_f32(m); \
    float32x2_t v_q_0 = vld1_f32(q); \
    float32x4_t v_m = vcombine_f32(v_m_0, v_m_0); \
    float32x4_t v_q = \
        vcombine_f32(vdup_lane_f32(v_q_0, 0), vdup_lane_f32(v_q_0, 1)); \
    ACCUM_FP32_STEP_NEON(v_m, v_q, v_sum_0_0) \
  } \
  MATRIX_VAR_STORE(1, 1, 4, v_sum, out, vst1q_f32, _NORM)

//! Compute the distance between matrix and query (FP32, M=4, N=1)
//!
//! Two dimensions per iteration (8 floats of `m`, 2 of `q`) across two
//! accumulators. An odd trailing dimension goes into `v_sum_0_0` before
//! the accumulators are summed. `m` and `q` are advanced in place.
#define ACCUM_FP32_4X1_NEON(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(1, 2, float32x4_t, v_sum, vdupq_n_f32(0)) \
  const float *qe_aligned = q + ((dim >> 1) << 1); \
  const float *qe = q + dim; \
  for (; q != qe_aligned; m += 8, q += 2) { \
    MATRIX_FP32_ITER_4X1_NEON(m, q, v_sum, ACCUM_FP32_STEP_NEON) \
  } \
  if (q != qe) { \
    float32x4_t v_m = vld1q_f32(m); \
    float32x4_t v_q = vld1q_dup_f32(q); \
    ACCUM_FP32_STEP_NEON(v_m, v_q, v_sum_0_0) \
  } \
  v_sum_0_0 = vaddq_f32(v_sum_0_0, v_sum_0_1); \
  MATRIX_VAR_STORE(1, 1, 4, v_sum, out, vst1q_f32, _NORM)

//! Compute the distance between matrix and query (FP32, M=4, N=2)
//!
//! One dimension per iteration (4 floats of `m`, 2 of `q`); `m` and `q`
//! are advanced in place.
#define ACCUM_FP32_4X2_NEON(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(1, 2, float32x4_t, v_sum, vdupq_n_f32(0)) \
  for (const float *qe = q + (dim << 1); q != qe; m += 4, q += 2) { \
    MATRIX_FP32_ITER_4X2_NEON(m, q, v_sum, ACCUM_FP32_STEP_NEON) \
  } \
  MATRIX_VAR_STORE(1, 2, 4, v_sum, out, vst1q_f32, _NORM)

//!
//! Compute the distance between matrix and query (FP32, M=4, N=4)
//!
//! One dimension per iteration: `m` advances by the row count (M floats)
//! and `q` by the column count (N floats), both in place. The accumulators
//! set up by MATRIX_VAR_INIT are written to `out` through `_NORM` with
//! vst1q_f32.
#define ACCUM_FP32_4X4_NEON(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(1, 4, float32x4_t, v_sum, vdupq_n_f32(0)) \
  for (const float *qe = q + (dim << 2); q != qe; m += 4, q += 4) { \
    MATRIX_FP32_ITER_4X4_NEON(m, q, v_sum, ACCUM_FP32_STEP_NEON) \
  } \
  MATRIX_VAR_STORE(1, 4, 4, v_sum, out, vst1q_f32, _NORM)

//! Compute the distance between matrix and query (FP32, M=8, N=1)
#define ACCUM_FP32_8X1_NEON(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(2, 1, float32x4_t, v_sum, vdupq_n_f32(0)) \
  for (const float *qe = q + dim; q != qe; m += 8, ++q) { \
    MATRIX_FP32_ITER_8X1_NEON(m, q, v_sum, ACCUM_FP32_STEP_NEON) \
  } \
  MATRIX_VAR_STORE(2, 1, 4, v_sum, out, vst1q_f32, _NORM)

//! Compute the distance between matrix and query (FP32, M=8, N=2)
#define ACCUM_FP32_8X2_NEON(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(2, 2, float32x4_t, v_sum, vdupq_n_f32(0)) \
  for (const float *qe = q + (dim << 1); q != qe; m += 8, q += 2) { \
    MATRIX_FP32_ITER_8X2_NEON(m, q, v_sum, ACCUM_FP32_STEP_NEON) \
  } \
  MATRIX_VAR_STORE(2, 2, 4, v_sum, out, vst1q_f32, _NORM)

//! Compute the distance between matrix and query (FP32, M=8, N=4)
#define ACCUM_FP32_8X4_NEON(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(2, 4, float32x4_t, v_sum, vdupq_n_f32(0)) \
  for (const float *qe = q + (dim << 2); q != qe; m += 8, q += 4) { \
    MATRIX_FP32_ITER_8X4_NEON(m, q, v_sum, ACCUM_FP32_STEP_NEON) \
  } \
  MATRIX_VAR_STORE(2, 4, 4, v_sum, out, vst1q_f32, _NORM)

//! Compute the distance between matrix and query (FP32, M=8, N=8)
#define ACCUM_FP32_8X8_NEON(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(2, 8, float32x4_t, v_sum, vdupq_n_f32(0)) \
  for (const float *qe = q + (dim << 3); q != qe; m += 8, q += 8) { \
    MATRIX_FP32_ITER_8X8_NEON(m, q, v_sum, ACCUM_FP32_STEP_NEON) \
  } \
  MATRIX_VAR_STORE(2, 8, 4, v_sum, out, vst1q_f32, _NORM)

//!
//! Compute the distance between matrix and query (FP32, M=16, N=1)
//!
//! One dimension per iteration: `m` advances by 16 floats and `q` by the
//! column count N, both in place. The 4xN float32x4_t accumulators set up
//! by MATRIX_VAR_INIT are written to `out` through `_NORM` with vst1q_f32.
#define ACCUM_FP32_16X1_NEON(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(4, 1, float32x4_t, v_sum, vdupq_n_f32(0)) \
  for (const float *qe = q + dim; q != qe; m += 16, ++q) { \
    MATRIX_FP32_ITER_16X1_NEON(m, q, v_sum, ACCUM_FP32_STEP_NEON) \
  } \
  MATRIX_VAR_STORE(4, 1, 4, v_sum, out, vst1q_f32, _NORM)

//! Compute the distance between matrix and query (FP32, M=16, N=2)
#define ACCUM_FP32_16X2_NEON(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(4, 2, float32x4_t, v_sum, vdupq_n_f32(0)) \
  for (const float *qe = q + (dim << 1); q != qe; m += 16, q += 2) { \
    MATRIX_FP32_ITER_16X2_NEON(m, q, v_sum, ACCUM_FP32_STEP_NEON) \
  } \
  MATRIX_VAR_STORE(4, 2, 4, v_sum, out, vst1q_f32, _NORM)

//! Compute the distance between matrix and query (FP32, M=16, N=4)
#define ACCUM_FP32_16X4_NEON(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(4, 4, float32x4_t, v_sum, vdupq_n_f32(0)) \
  for (const float *qe = q + (dim << 2); q != qe; m += 16, q += 4) { \
    MATRIX_FP32_ITER_16X4_NEON(m, q, v_sum, ACCUM_FP32_STEP_NEON) \
  } \
  MATRIX_VAR_STORE(4, 4, 4, v_sum, out, vst1q_f32, _NORM)

//! Compute the distance between matrix and query (FP32, M=16, N=8)
#define ACCUM_FP32_16X8_NEON(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(4, 8, float32x4_t, v_sum, vdupq_n_f32(0)) \
  for (const float *qe = q + (dim << 3); q != qe; m += 16, q += 8) { \
    MATRIX_FP32_ITER_16X8_NEON(m, q, v_sum, ACCUM_FP32_STEP_NEON) \
  } \
  MATRIX_VAR_STORE(4, 8, 4, v_sum, out, vst1q_f32, _NORM)

//! Compute the distance between matrix and query (FP32, M=16, N=16)
#define ACCUM_FP32_16X16_NEON(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(4, 16, float32x4_t, v_sum, vdupq_n_f32(0)) \
  for (const float *qe = q + (dim << 4); q != qe; m += 16, q += 16) { \
    MATRIX_FP32_ITER_16X16_NEON(m, q, v_sum, ACCUM_FP32_STEP_NEON) \
  } \
  MATRIX_VAR_STORE(4, 16, 4, v_sum, out, vst1q_f32, _NORM)

//!
//! Compute the distance between matrix and query (FP32, M=32, N=1)
//!
//! One dimension per iteration: `m` advances by 32 floats and `q` by the
//! column count N, both in place. The 8xN float32x4_t accumulators set up
//! by MATRIX_VAR_INIT are written to `out` through `_NORM` with vst1q_f32.
#define ACCUM_FP32_32X1_NEON(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(8, 1, float32x4_t, v_sum, vdupq_n_f32(0)) \
  for (const float *qe = q + dim; q != qe; m += 32, ++q) { \
    MATRIX_FP32_ITER_32X1_NEON(m, q, v_sum, ACCUM_FP32_STEP_NEON) \
  } \
  MATRIX_VAR_STORE(8, 1, 4, v_sum, out, vst1q_f32, _NORM)

//! Compute the distance between matrix and query (FP32, M=32, N=2)
#define ACCUM_FP32_32X2_NEON(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(8, 2, float32x4_t, v_sum, vdupq_n_f32(0)) \
  for (const float *qe = q + (dim << 1); q != qe; m += 32, q += 2) { \
    MATRIX_FP32_ITER_32X2_NEON(m, q, v_sum, ACCUM_FP32_STEP_NEON) \
  } \
  MATRIX_VAR_STORE(8, 2, 4, v_sum, out, vst1q_f32, _NORM)

//! Compute the distance between matrix and query (FP32, M=32, N=4)
#define ACCUM_FP32_32X4_NEON(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(8, 4, float32x4_t, v_sum, vdupq_n_f32(0)) \
  for (const float *qe = q + (dim << 2); q != qe; m += 32, q += 4) { \
    MATRIX_FP32_ITER_32X4_NEON(m, q, v_sum, ACCUM_FP32_STEP_NEON) \
  } \
  MATRIX_VAR_STORE(8, 4, 4, v_sum, out, vst1q_f32, _NORM)

//! Compute the distance between matrix and query (FP32, M=32, N=8)
#define ACCUM_FP32_32X8_NEON(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(8, 8, float32x4_t, v_sum, vdupq_n_f32(0)) \
  for (const float *qe = q + (dim << 3); q != qe; m += 32, q += 8) { \
    MATRIX_FP32_ITER_32X8_NEON(m, q, v_sum, ACCUM_FP32_STEP_NEON) \
  } \
  MATRIX_VAR_STORE(8, 8, 4, v_sum, out, vst1q_f32, _NORM)

//! Compute the distance between matrix and query (FP32, M=32, N=16)
#define ACCUM_FP32_32X16_NEON(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(8, 16, float32x4_t, v_sum, vdupq_n_f32(0)) \
  for (const float *qe = q + (dim << 4); q != qe; m += 32, q += 16) { \
    MATRIX_FP32_ITER_32X16_NEON(m, q, v_sum, ACCUM_FP32_STEP_NEON) \
  } \
  MATRIX_VAR_STORE(8, 16, 4, v_sum, out, vst1q_f32, _NORM)

//!
Compute the distance between matrix and query (FP32, M=32, N=32) #define ACCUM_FP32_32X32_NEON(m, q, dim, out, _NORM) \ MATRIX_VAR_INIT(8, 32, float32x4_t, v_sum, vdupq_n_f32(0)) \ for (const float *qe = q + (dim << 5); q != qe; m += 32, q += 32) { \ MATRIX_FP32_ITER_32X32_NEON(m, q, v_sum, ACCUM_FP32_STEP_NEON) \ } \ MATRIX_VAR_STORE(8, 32, 4, v_sum, out, vst1q_f32, _NORM) ================================================ FILE: src/ailego/math/distance_matrix_accum_int4.i ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "distance_matrix_int32.i" #include "matrix_utility.i" //! 
//! Compute the distance between matrix and query (INT4, M=2, N=1)
//!
//! `m` and `q` are read through local uint32_t cursors `mi`/`qi`, which the
//! macro declares in the enclosing scope (together with `qe`/`qe_aligned`),
//! so expand at most once per scope. The caller's pointers are not
//! advanced. The query spans `dim >> 3` words, i.e. 8 INT4 values per word,
//! so any `dim % 8` tail is not processed (callers presumably pass a
//! multiple of 8).
//!
//! Processing order:
//!  - the main loop handles 4 query words (8 matrix words) per iteration;
//!  - a 2-word remainder is handled next;
//!  - the two accumulators are reduced;
//!  - a final single word, if any, is added last.
//!
//! Aligned loads are used only when both `mi` and `qi` are 16-byte aligned.
//! The two results (one per matrix row) are written with _mm_storel_pi
//! after `_NORM`.
#define ACCUM_INT4_2X1_SSE(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(1, 2, __m128i, xmm_sum, _mm_setzero_si128()) \
  const uint32_t *qi = reinterpret_cast<const uint32_t *>(q); \
  const uint32_t *mi = reinterpret_cast<const uint32_t *>(m); \
  const uint32_t *qe_aligned = qi + ((dim >> 5) << 2); \
  const uint32_t *qe = qi + (dim >> 3); \
  if (((uintptr_t)mi & 0xf) == 0 && ((uintptr_t)qi & 0xf) == 0) { \
    for (; qi != qe_aligned; mi += 8, qi += 4) { \
      MATRIX_INT32_ITER_2X1_SSE(mi, qi, xmm_sum, _mm_load_si128, \
                                ACCUM_INT4_STEP_SSE) \
    } \
    /* Compare the distance: qe_aligned + 2 may lie past the buffer end. */ \
    if (qe - qe_aligned >= 2) { \
      __m128i xmm_mi = _mm_load_si128((const __m128i *)(mi)); \
      __m128i xmm_qi = _mm_set_epi32(qi[1], qi[1], qi[0], qi[0]); \
      ACCUM_INT4_STEP_SSE(xmm_mi, xmm_qi, xmm_sum_0_0) \
      mi += 4; \
      qi += 2; \
    } \
  } else { \
    for (; qi != qe_aligned; mi += 8, qi += 4) { \
      MATRIX_INT32_ITER_2X1_SSE(mi, qi, xmm_sum, _mm_loadu_si128, \
                                ACCUM_INT4_STEP_SSE) \
    } \
    if (qe - qe_aligned >= 2) { \
      __m128i xmm_mi = _mm_loadu_si128((const __m128i *)(mi)); \
      __m128i xmm_qi = _mm_set_epi32(qi[1], qi[1], qi[0], qi[0]); \
      ACCUM_INT4_STEP_SSE(xmm_mi, xmm_qi, xmm_sum_0_0) \
      mi += 4; \
      qi += 2; \
    } \
  } \
  xmm_sum_0_0 = _mm_add_epi32(xmm_sum_0_0, xmm_sum_0_1); \
  /* Fold lanes 2-3 onto lanes 0-1. */ \
  xmm_sum_0_0 = _mm_add_epi32( \
      xmm_sum_0_0, _mm_shuffle_epi32(xmm_sum_0_0, _MM_SHUFFLE(0, 0, 3, 2))); \
  if (qi != qe) { \
    __m128i xmm_mi = _mm_set_epi32(0, 0, mi[1], mi[0]); \
    __m128i xmm_qi = _mm_broadcast_si32(qi); \
    ACCUM_INT4_STEP_SSE(xmm_mi, xmm_qi, xmm_sum_0_0) \
  } \
  _mm_storel_pi((__m64 *)out, _NORM(xmm_sum_0_0));

//!
//! Compute the distance between matrix and query (INT4, M=2, N=2)
//!
//! `m` and `q` are read through local uint32_t cursors `mi`/`qi`, which the
//! macro declares in the enclosing scope (together with `qe`), so expand
//! at most once per scope. The caller's pointers are not advanced. The
//! query spans `(dim >> 3) << 1` words (8 INT4 values per word, two
//! columns), so any `dim % 8` tail is not processed.
//!
//! The main loop handles two word-groups (4 query words, 4 matrix words)
//! per iteration. The two accumulators are then folded into `xmm_sum_0_0`,
//! and a final single group, if any, is added. Aligned loads are used only
//! when both `mi` and `qi` are 16-byte aligned. The four results are
//! stored to `out` through `_NORM`, aligned only when `out` is 16-byte
//! aligned.
#define ACCUM_INT4_2X2_SSE(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(1, 2, __m128i, xmm_sum, _mm_setzero_si128()) \
  const uint32_t *qi = reinterpret_cast<const uint32_t *>(q); \
  const uint32_t *mi = reinterpret_cast<const uint32_t *>(m); \
  const uint32_t *qe = qi + ((dim >> 3) << 1); \
  if (((uintptr_t)mi & 0xf) == 0 && ((uintptr_t)qi & 0xf) == 0) { \
    for (const uint32_t *qe_aligned = qi + ((dim >> 4) << 2); \
         qi != qe_aligned; mi += 4, qi += 4) { \
      MATRIX_INT32_ITER_2X2_SSE(mi, qi, xmm_sum, _mm_load_si128, \
                                ACCUM_INT4_STEP_SSE) \
    } \
  } else { \
    for (const uint32_t *qe_aligned = qi + ((dim >> 4) << 2); \
         qi != qe_aligned; mi += 4, qi += 4) { \
      MATRIX_INT32_ITER_2X2_SSE(mi, qi, xmm_sum, _mm_loadu_si128, \
                                ACCUM_INT4_STEP_SSE) \
    } \
  } \
  xmm_sum_0_0 = _mm_add_epi32(_mm_unpacklo_epi64(xmm_sum_0_0, xmm_sum_0_1), \
                              _mm_unpackhi_epi64(xmm_sum_0_0, xmm_sum_0_1)); \
  if (qi != qe) { \
    __m128i xmm_mi = _mm_set_epi32(mi[1], mi[0], mi[1], mi[0]); \
    __m128i xmm_qi = _mm_set_epi32(qi[1], qi[1], qi[0], qi[0]); \
    ACCUM_INT4_STEP_SSE(xmm_mi, xmm_qi, xmm_sum_0_0) \
  } \
  if (((uintptr_t)out & 0xf) == 0) { \
    MATRIX_VAR_STORE(1, 1, 4, xmm_sum, out, _mm_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(1, 1, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \
  }

//!
//! Compute the distance between matrix and query (INT4, M=4, N=1)
//!
//! `m` and `q` are read through local uint32_t cursors `mi`/`qi`, which the
//! macro declares in the enclosing scope (together with `qe`), so expand
//! at most once per scope. The caller's pointers are not advanced. The
//! query spans `dim >> 3` words (8 INT4 values per word), so any `dim % 8`
//! tail is not processed.
//!
//! The main loop handles 2 query words (8 matrix words) per iteration into
//! two accumulators. A trailing single word goes into `xmm_sum_0_0`, and
//! the accumulators are summed before the store. Aligned loads are used
//! only when `mi` is 16-byte aligned. Results are stored to `out` through
//! `_NORM`, aligned only when `out` is 16-byte aligned.
#define ACCUM_INT4_4X1_SSE(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(2, 1, __m128i, xmm_sum, _mm_setzero_si128()) \
  const uint32_t *qi = reinterpret_cast<const uint32_t *>(q); \
  const uint32_t *mi = reinterpret_cast<const uint32_t *>(m); \
  const uint32_t *qe = qi + (dim >> 3); \
  if (((uintptr_t)mi & 0xf) == 0) { \
    for (const uint32_t *qe_aligned = qi + ((dim >> 4) << 1); \
         qi != qe_aligned; mi += 8, qi += 2) { \
      MATRIX_INT32_ITER_4X1_SSE(mi, qi, xmm_sum, _mm_load_si128, \
                                ACCUM_INT4_STEP_SSE) \
    } \
    if (qi != qe) { \
      __m128i xmm_mi = _mm_load_si128((const __m128i *)(mi)); \
      __m128i xmm_qi = _mm_broadcast_si32(qi); \
      ACCUM_INT4_STEP_SSE(xmm_mi, xmm_qi, xmm_sum_0_0) \
    } \
  } else { \
    for (const uint32_t *qe_aligned = qi + ((dim >> 4) << 1); \
         qi != qe_aligned; mi += 8, qi += 2) { \
      MATRIX_INT32_ITER_4X1_SSE(mi, qi, xmm_sum, _mm_loadu_si128, \
                                ACCUM_INT4_STEP_SSE) \
    } \
    if (qi != qe) { \
      __m128i xmm_mi = _mm_loadu_si128((const __m128i *)(mi)); \
      __m128i xmm_qi = _mm_broadcast_si32(qi); \
      ACCUM_INT4_STEP_SSE(xmm_mi, xmm_qi, xmm_sum_0_0) \
    } \
  } \
  xmm_sum_0_0 = _mm_add_epi32(xmm_sum_0_0, xmm_sum_1_0); \
  if (((uintptr_t)out & 0xf) == 0) { \
    MATRIX_VAR_STORE(1, 1, 4, xmm_sum, out, _mm_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(1, 1, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \
  }

//!
//! Compute the distance between matrix and query (INT4, M=4, N=2)
//!
//! `m` and `q` are read through local uint32_t cursors `mi`/`qi`, which the
//! macro declares in the enclosing scope, so expand at most once per scope.
//! The caller's pointers are not advanced. Each query column spans
//! `dim >> 3` words (8 INT4 values per word), so any `dim % 8` tail is not
//! processed. Aligned loads are used only when `mi` is 16-byte aligned.
//! Results are stored to `out` through `_NORM`, aligned only when `out` is
//! 16-byte aligned.
#define ACCUM_INT4_4X2_SSE(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(1, 2, __m128i, xmm_sum, _mm_setzero_si128()) \
  const uint32_t *qi = reinterpret_cast<const uint32_t *>(q); \
  const uint32_t *mi = reinterpret_cast<const uint32_t *>(m); \
  if (((uintptr_t)mi & 0xf) == 0) { \
    for (const uint32_t *qe = qi + ((dim >> 3) << 1); qi != qe; \
         mi += 4, qi += 2) { \
      MATRIX_INT32_ITER_4X2_SSE(mi, qi, xmm_sum, _mm_load_si128, \
                                ACCUM_INT4_STEP_SSE) \
    } \
  } else { \
    for (const uint32_t *qe = qi + ((dim >> 3) << 1); qi != qe; \
         mi += 4, qi += 2) { \
      MATRIX_INT32_ITER_4X2_SSE(mi, qi, xmm_sum, _mm_loadu_si128, \
                                ACCUM_INT4_STEP_SSE) \
    } \
  } \
  if (((uintptr_t)out & 0xf) == 0) { \
    MATRIX_VAR_STORE(1, 2, 4, xmm_sum, out, _mm_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(1, 2, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \
  }

//! Compute the distance between matrix and query (INT4, M=4, N=4)
//!
//! Same contract as the M=4, N=2 kernel above, with 4 query words per
//! word-group.
#define ACCUM_INT4_4X4_SSE(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(1, 4, __m128i, xmm_sum, _mm_setzero_si128()) \
  const uint32_t *qi = reinterpret_cast<const uint32_t *>(q); \
  const uint32_t *mi = reinterpret_cast<const uint32_t *>(m); \
  if (((uintptr_t)mi & 0xf) == 0) { \
    for (const uint32_t *qe = qi + ((dim >> 3) << 2); qi != qe; \
         mi += 4, qi += 4) { \
      MATRIX_INT32_ITER_4X4_SSE(mi, qi, xmm_sum, _mm_load_si128, \
                                ACCUM_INT4_STEP_SSE) \
    } \
  } else { \
    for (const uint32_t *qe = qi + ((dim >> 3) << 2); qi != qe; \
         mi += 4, qi += 4) { \
      MATRIX_INT32_ITER_4X4_SSE(mi, qi, xmm_sum, _mm_loadu_si128, \
                                ACCUM_INT4_STEP_SSE) \
    } \
  } \
  if (((uintptr_t)out & 0xf) == 0) { \
    MATRIX_VAR_STORE(1, 4, 4, xmm_sum, out, _mm_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(1, 4, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \
  }

//!
//! Compute the distance between matrix and query (INT4, M=8, N=1)
//!
//! `m` and `q` are read through local uint32_t cursors `mi`/`qi`, which the
//! macro declares in the enclosing scope, so expand at most once per scope.
//! The caller's pointers are not advanced. Each query column spans
//! `dim >> 3` words (8 INT4 values per word), so any `dim % 8` tail is not
//! processed. Each iteration consumes one word-group (8 matrix words).
//! Aligned loads are used only when `mi` is 16-byte aligned. Results are
//! stored to `out` through `_NORM`, aligned only when `out` is 16-byte
//! aligned.
#define ACCUM_INT4_8X1_SSE(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(2, 1, __m128i, xmm_sum, _mm_setzero_si128()) \
  const uint32_t *qi = reinterpret_cast<const uint32_t *>(q); \
  const uint32_t *mi = reinterpret_cast<const uint32_t *>(m); \
  if (((uintptr_t)mi & 0xf) == 0) { \
    for (const uint32_t *qe = qi + (dim >> 3); qi != qe; mi += 8, ++qi) { \
      MATRIX_INT32_ITER_8X1_SSE(mi, qi, xmm_sum, _mm_load_si128, \
                                ACCUM_INT4_STEP_SSE) \
    } \
  } else { \
    for (const uint32_t *qe = qi + (dim >> 3); qi != qe; mi += 8, ++qi) { \
      MATRIX_INT32_ITER_8X1_SSE(mi, qi, xmm_sum, _mm_loadu_si128, \
                                ACCUM_INT4_STEP_SSE) \
    } \
  } \
  if (((uintptr_t)out & 0xf) == 0) { \
    MATRIX_VAR_STORE(2, 1, 4, xmm_sum, out, _mm_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(2, 1, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \
  }

//! Compute the distance between matrix and query (INT4, M=8, N=2)
//!
//! Same contract as the M=8, N=1 kernel above, with 2 query words per
//! word-group.
#define ACCUM_INT4_8X2_SSE(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(2, 2, __m128i, xmm_sum, _mm_setzero_si128()) \
  const uint32_t *qi = reinterpret_cast<const uint32_t *>(q); \
  const uint32_t *mi = reinterpret_cast<const uint32_t *>(m); \
  if (((uintptr_t)mi & 0xf) == 0) { \
    for (const uint32_t *qe = qi + ((dim >> 3) << 1); qi != qe; \
         mi += 8, qi += 2) { \
      MATRIX_INT32_ITER_8X2_SSE(mi, qi, xmm_sum, _mm_load_si128, \
                                ACCUM_INT4_STEP_SSE) \
    } \
  } else { \
    for (const uint32_t *qe = qi + ((dim >> 3) << 1); qi != qe; \
         mi += 8, qi += 2) { \
      MATRIX_INT32_ITER_8X2_SSE(mi, qi, xmm_sum, _mm_loadu_si128, \
                                ACCUM_INT4_STEP_SSE) \
    } \
  } \
  if (((uintptr_t)out & 0xf) == 0) { \
    MATRIX_VAR_STORE(2, 2, 4, xmm_sum, out, _mm_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(2, 2, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \
  }

//!
//! Compute the distance between matrix and query (INT4, M=8, N=4)
//!
//! `m` and `q` are read through local uint32_t cursors `mi`/`qi`, which the
//! macro declares in the enclosing scope, so expand at most once per scope.
//! The caller's pointers are not advanced. Each query column spans
//! `dim >> 3` words (8 INT4 values per word), so any `dim % 8` tail is not
//! processed. Aligned loads are used only when `mi` is 16-byte aligned.
//! Results are stored to `out` through `_NORM`, aligned only when `out` is
//! 16-byte aligned.
#define ACCUM_INT4_8X4_SSE(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(2, 4, __m128i, xmm_sum, _mm_setzero_si128()) \
  const uint32_t *qi = reinterpret_cast<const uint32_t *>(q); \
  const uint32_t *mi = reinterpret_cast<const uint32_t *>(m); \
  if (((uintptr_t)mi & 0xf) == 0) { \
    for (const uint32_t *qe = qi + ((dim >> 3) << 2); qi != qe; \
         mi += 8, qi += 4) { \
      MATRIX_INT32_ITER_8X4_SSE(mi, qi, xmm_sum, _mm_load_si128, \
                                ACCUM_INT4_STEP_SSE) \
    } \
  } else { \
    for (const uint32_t *qe = qi + ((dim >> 3) << 2); qi != qe; \
         mi += 8, qi += 4) { \
      MATRIX_INT32_ITER_8X4_SSE(mi, qi, xmm_sum, _mm_loadu_si128, \
                                ACCUM_INT4_STEP_SSE) \
    } \
  } \
  if (((uintptr_t)out & 0xf) == 0) { \
    MATRIX_VAR_STORE(2, 4, 4, xmm_sum, out, _mm_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(2, 4, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \
  }

//! Compute the distance between matrix and query (INT4, M=8, N=8)
//!
//! Same contract as the M=8, N=4 kernel above, with 8 query words per
//! word-group.
#define ACCUM_INT4_8X8_SSE(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(2, 8, __m128i, xmm_sum, _mm_setzero_si128()) \
  const uint32_t *qi = reinterpret_cast<const uint32_t *>(q); \
  const uint32_t *mi = reinterpret_cast<const uint32_t *>(m); \
  if (((uintptr_t)mi & 0xf) == 0) { \
    for (const uint32_t *qe = qi + ((dim >> 3) << 3); qi != qe; \
         mi += 8, qi += 8) { \
      MATRIX_INT32_ITER_8X8_SSE(mi, qi, xmm_sum, _mm_load_si128, \
                                ACCUM_INT4_STEP_SSE) \
    } \
  } else { \
    for (const uint32_t *qe = qi + ((dim >> 3) << 3); qi != qe; \
         mi += 8, qi += 8) { \
      MATRIX_INT32_ITER_8X8_SSE(mi, qi, xmm_sum, _mm_loadu_si128, \
                                ACCUM_INT4_STEP_SSE) \
    } \
  } \
  if (((uintptr_t)out & 0xf) == 0) { \
    MATRIX_VAR_STORE(2, 8, 4, xmm_sum, out, _mm_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(2, 8, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \
  }

//!
//! Compute the distance between matrix and query (INT4, M=16, N=1)
//!
//! `m` and `q` are read through local uint32_t cursors `mi`/`qi`, which the
//! macro declares in the enclosing scope, so expand at most once per scope.
//! The caller's pointers are not advanced. Each query column spans
//! `dim >> 3` words (8 INT4 values per word), so any `dim % 8` tail is not
//! processed. Each iteration consumes 16 matrix words. Aligned loads are
//! used only when `mi` is 16-byte aligned. Results are stored to `out`
//! through `_NORM`, aligned only when `out` is 16-byte aligned.
#define ACCUM_INT4_16X1_SSE(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(4, 1, __m128i, xmm_sum, _mm_setzero_si128()) \
  const uint32_t *qi = reinterpret_cast<const uint32_t *>(q); \
  const uint32_t *mi = reinterpret_cast<const uint32_t *>(m); \
  if (((uintptr_t)mi & 0xf) == 0) { \
    for (const uint32_t *qe = qi + (dim >> 3); qi != qe; mi += 16, ++qi) { \
      MATRIX_INT32_ITER_16X1_SSE(mi, qi, xmm_sum, _mm_load_si128, \
                                 ACCUM_INT4_STEP_SSE) \
    } \
  } else { \
    for (const uint32_t *qe = qi + (dim >> 3); qi != qe; mi += 16, ++qi) { \
      MATRIX_INT32_ITER_16X1_SSE(mi, qi, xmm_sum, _mm_loadu_si128, \
                                 ACCUM_INT4_STEP_SSE) \
    } \
  } \
  if (((uintptr_t)out & 0xf) == 0) { \
    MATRIX_VAR_STORE(4, 1, 4, xmm_sum, out, _mm_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(4, 1, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \
  }

//! Compute the distance between matrix and query (INT4, M=16, N=2)
//!
//! Same contract as the M=16, N=1 kernel above, with 2 query words per
//! word-group.
#define ACCUM_INT4_16X2_SSE(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(4, 2, __m128i, xmm_sum, _mm_setzero_si128()) \
  const uint32_t *qi = reinterpret_cast<const uint32_t *>(q); \
  const uint32_t *mi = reinterpret_cast<const uint32_t *>(m); \
  if (((uintptr_t)mi & 0xf) == 0) { \
    for (const uint32_t *qe = qi + ((dim >> 3) << 1); qi != qe; \
         mi += 16, qi += 2) { \
      MATRIX_INT32_ITER_16X2_SSE(mi, qi, xmm_sum, _mm_load_si128, \
                                 ACCUM_INT4_STEP_SSE) \
    } \
  } else { \
    for (const uint32_t *qe = qi + ((dim >> 3) << 1); qi != qe; \
         mi += 16, qi += 2) { \
      MATRIX_INT32_ITER_16X2_SSE(mi, qi, xmm_sum, _mm_loadu_si128, \
                                 ACCUM_INT4_STEP_SSE) \
    } \
  } \
  if (((uintptr_t)out & 0xf) == 0) { \
    MATRIX_VAR_STORE(4, 2, 4, xmm_sum, out, _mm_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(4, 2, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \
  }

//!
//! Compute the distance between matrix and query (INT4, M=16, N=4)
//!
//! `m` and `q` are read through local uint32_t cursors `mi`/`qi`, which the
//! macro declares in the enclosing scope, so expand at most once per scope.
//! The caller's pointers are not advanced. Each query column spans
//! `dim >> 3` words (8 INT4 values per word), so any `dim % 8` tail is not
//! processed. Aligned loads are used only when `mi` is 16-byte aligned.
//! Results are stored to `out` through `_NORM`, aligned only when `out` is
//! 16-byte aligned.
#define ACCUM_INT4_16X4_SSE(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(4, 4, __m128i, xmm_sum, _mm_setzero_si128()) \
  const uint32_t *qi = reinterpret_cast<const uint32_t *>(q); \
  const uint32_t *mi = reinterpret_cast<const uint32_t *>(m); \
  if (((uintptr_t)mi & 0xf) == 0) { \
    for (const uint32_t *qe = qi + ((dim >> 3) << 2); qi != qe; \
         mi += 16, qi += 4) { \
      MATRIX_INT32_ITER_16X4_SSE(mi, qi, xmm_sum, _mm_load_si128, \
                                 ACCUM_INT4_STEP_SSE) \
    } \
  } else { \
    for (const uint32_t *qe = qi + ((dim >> 3) << 2); qi != qe; \
         mi += 16, qi += 4) { \
      MATRIX_INT32_ITER_16X4_SSE(mi, qi, xmm_sum, _mm_loadu_si128, \
                                 ACCUM_INT4_STEP_SSE) \
    } \
  } \
  if (((uintptr_t)out & 0xf) == 0) { \
    MATRIX_VAR_STORE(4, 4, 4, xmm_sum, out, _mm_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(4, 4, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \
  }

//! Compute the distance between matrix and query (INT4, M=16, N=8)
//!
//! Same contract as the M=16, N=4 kernel above, with 8 query words per
//! word-group.
#define ACCUM_INT4_16X8_SSE(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(4, 8, __m128i, xmm_sum, _mm_setzero_si128()) \
  const uint32_t *qi = reinterpret_cast<const uint32_t *>(q); \
  const uint32_t *mi = reinterpret_cast<const uint32_t *>(m); \
  if (((uintptr_t)mi & 0xf) == 0) { \
    for (const uint32_t *qe = qi + ((dim >> 3) << 3); qi != qe; \
         mi += 16, qi += 8) { \
      MATRIX_INT32_ITER_16X8_SSE(mi, qi, xmm_sum, _mm_load_si128, \
                                 ACCUM_INT4_STEP_SSE) \
    } \
  } else { \
    for (const uint32_t *qe = qi + ((dim >> 3) << 3); qi != qe; \
         mi += 16, qi += 8) { \
      MATRIX_INT32_ITER_16X8_SSE(mi, qi, xmm_sum, _mm_loadu_si128, \
                                 ACCUM_INT4_STEP_SSE) \
    } \
  } \
  if (((uintptr_t)out & 0xf) == 0) { \
    MATRIX_VAR_STORE(4, 8, 4, xmm_sum, out, _mm_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(4, 8, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \
  }

//!
//! Compute the distance between matrix and query (INT4, M=16, N=16)
//!
//! `m` and `q` are read through local uint32_t cursors `mi`/`qi`, which the
//! macro declares in the enclosing scope, so expand at most once per scope.
//! The caller's pointers are not advanced. Each query column spans
//! `dim >> 3` words (8 INT4 values per word), so any `dim % 8` tail is not
//! processed. Aligned loads are used only when `mi` is 16-byte aligned.
//! Results are stored to `out` through `_NORM`, aligned only when `out` is
//! 16-byte aligned.
#define ACCUM_INT4_16X16_SSE(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(4, 16, __m128i, xmm_sum, _mm_setzero_si128()) \
  const uint32_t *qi = reinterpret_cast<const uint32_t *>(q); \
  const uint32_t *mi = reinterpret_cast<const uint32_t *>(m); \
  if (((uintptr_t)mi & 0xf) == 0) { \
    for (const uint32_t *qe = qi + ((dim >> 3) << 4); qi != qe; \
         mi += 16, qi += 16) { \
      MATRIX_INT32_ITER_16X16_SSE(mi, qi, xmm_sum, _mm_load_si128, \
                                  ACCUM_INT4_STEP_SSE) \
    } \
  } else { \
    for (const uint32_t *qe = qi + ((dim >> 3) << 4); qi != qe; \
         mi += 16, qi += 16) { \
      MATRIX_INT32_ITER_16X16_SSE(mi, qi, xmm_sum, _mm_loadu_si128, \
                                  ACCUM_INT4_STEP_SSE) \
    } \
  } \
  if (((uintptr_t)out & 0xf) == 0) { \
    MATRIX_VAR_STORE(4, 16, 4, xmm_sum, out, _mm_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(4, 16, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \
  }

//! Compute the distance between matrix and query (INT4, M=32, N=1)
//!
//! Same contract as the kernel above. Each iteration consumes one query
//! word and 32 matrix words.
#define ACCUM_INT4_32X1_SSE(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(8, 1, __m128i, xmm_sum, _mm_setzero_si128()) \
  const uint32_t *qi = reinterpret_cast<const uint32_t *>(q); \
  const uint32_t *mi = reinterpret_cast<const uint32_t *>(m); \
  if (((uintptr_t)mi & 0xf) == 0) { \
    for (const uint32_t *qe = qi + (dim >> 3); qi != qe; mi += 32, ++qi) { \
      MATRIX_INT32_ITER_32X1_SSE(mi, qi, xmm_sum, _mm_load_si128, \
                                 ACCUM_INT4_STEP_SSE) \
    } \
  } else { \
    for (const uint32_t *qe = qi + (dim >> 3); qi != qe; mi += 32, ++qi) { \
      MATRIX_INT32_ITER_32X1_SSE(mi, qi, xmm_sum, _mm_loadu_si128, \
                                 ACCUM_INT4_STEP_SSE) \
    } \
  } \
  if (((uintptr_t)out & 0xf) == 0) { \
    MATRIX_VAR_STORE(8, 1, 4, xmm_sum, out, _mm_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(8, 1, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \
  }

//!
//! Compute the distance between matrix and query (INT4, M=32, N=2)
//!
//! `m` and `q` are read through local uint32_t cursors `mi`/`qi`, which the
//! macro declares in the enclosing scope, so expand at most once per scope.
//! The caller's pointers are not advanced. Each query column spans
//! `dim >> 3` words (8 INT4 values per word), so any `dim % 8` tail is not
//! processed. Aligned loads are used only when `mi` is 16-byte aligned.
//! Results are stored to `out` through `_NORM`, aligned only when `out` is
//! 16-byte aligned.
#define ACCUM_INT4_32X2_SSE(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(8, 2, __m128i, xmm_sum, _mm_setzero_si128()) \
  const uint32_t *qi = reinterpret_cast<const uint32_t *>(q); \
  const uint32_t *mi = reinterpret_cast<const uint32_t *>(m); \
  if (((uintptr_t)mi & 0xf) == 0) { \
    for (const uint32_t *qe = qi + ((dim >> 3) << 1); qi != qe; \
         mi += 32, qi += 2) { \
      MATRIX_INT32_ITER_32X2_SSE(mi, qi, xmm_sum, _mm_load_si128, \
                                 ACCUM_INT4_STEP_SSE) \
    } \
  } else { \
    for (const uint32_t *qe = qi + ((dim >> 3) << 1); qi != qe; \
         mi += 32, qi += 2) { \
      MATRIX_INT32_ITER_32X2_SSE(mi, qi, xmm_sum, _mm_loadu_si128, \
                                 ACCUM_INT4_STEP_SSE) \
    } \
  } \
  if (((uintptr_t)out & 0xf) == 0) { \
    MATRIX_VAR_STORE(8, 2, 4, xmm_sum, out, _mm_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(8, 2, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \
  }

//! Compute the distance between matrix and query (INT4, M=32, N=4)
//!
//! Same contract as the M=32, N=2 kernel above, with 4 query words per
//! word-group.
#define ACCUM_INT4_32X4_SSE(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(8, 4, __m128i, xmm_sum, _mm_setzero_si128()) \
  const uint32_t *qi = reinterpret_cast<const uint32_t *>(q); \
  const uint32_t *mi = reinterpret_cast<const uint32_t *>(m); \
  if (((uintptr_t)mi & 0xf) == 0) { \
    for (const uint32_t *qe = qi + ((dim >> 3) << 2); qi != qe; \
         mi += 32, qi += 4) { \
      MATRIX_INT32_ITER_32X4_SSE(mi, qi, xmm_sum, _mm_load_si128, \
                                 ACCUM_INT4_STEP_SSE) \
    } \
  } else { \
    for (const uint32_t *qe = qi + ((dim >> 3) << 2); qi != qe; \
         mi += 32, qi += 4) { \
      MATRIX_INT32_ITER_32X4_SSE(mi, qi, xmm_sum, _mm_loadu_si128, \
                                 ACCUM_INT4_STEP_SSE) \
    } \
  } \
  if (((uintptr_t)out & 0xf) == 0) { \
    MATRIX_VAR_STORE(8, 4, 4, xmm_sum, out, _mm_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(8, 4, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \
  }

//!
//! Compute the distance between matrix and query (INT4, M=32, N=8)
//!
//! `m` and `q` are read through local uint32_t cursors `mi`/`qi`, which the
//! macro declares in the enclosing scope, so expand at most once per scope.
//! The caller's pointers are not advanced. Each query column spans
//! `dim >> 3` words (8 INT4 values per word), so any `dim % 8` tail is not
//! processed. Aligned loads are used only when `mi` is 16-byte aligned.
//! Results are stored to `out` through `_NORM`, aligned only when `out` is
//! 16-byte aligned.
#define ACCUM_INT4_32X8_SSE(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(8, 8, __m128i, xmm_sum, _mm_setzero_si128()) \
  const uint32_t *qi = reinterpret_cast<const uint32_t *>(q); \
  const uint32_t *mi = reinterpret_cast<const uint32_t *>(m); \
  if (((uintptr_t)mi & 0xf) == 0) { \
    for (const uint32_t *qe = qi + ((dim >> 3) << 3); qi != qe; \
         mi += 32, qi += 8) { \
      MATRIX_INT32_ITER_32X8_SSE(mi, qi, xmm_sum, _mm_load_si128, \
                                 ACCUM_INT4_STEP_SSE) \
    } \
  } else { \
    for (const uint32_t *qe = qi + ((dim >> 3) << 3); qi != qe; \
         mi += 32, qi += 8) { \
      MATRIX_INT32_ITER_32X8_SSE(mi, qi, xmm_sum, _mm_loadu_si128, \
                                 ACCUM_INT4_STEP_SSE) \
    } \
  } \
  if (((uintptr_t)out & 0xf) == 0) { \
    MATRIX_VAR_STORE(8, 8, 4, xmm_sum, out, _mm_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(8, 8, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \
  }

//! Compute the distance between matrix and query (INT4, M=32, N=16)
//!
//! Same contract as the M=32, N=8 kernel above, with 16 query words per
//! word-group.
#define ACCUM_INT4_32X16_SSE(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(8, 16, __m128i, xmm_sum, _mm_setzero_si128()) \
  const uint32_t *qi = reinterpret_cast<const uint32_t *>(q); \
  const uint32_t *mi = reinterpret_cast<const uint32_t *>(m); \
  if (((uintptr_t)mi & 0xf) == 0) { \
    for (const uint32_t *qe = qi + ((dim >> 3) << 4); qi != qe; \
         mi += 32, qi += 16) { \
      MATRIX_INT32_ITER_32X16_SSE(mi, qi, xmm_sum, _mm_load_si128, \
                                  ACCUM_INT4_STEP_SSE) \
    } \
  } else { \
    for (const uint32_t *qe = qi + ((dim >> 3) << 4); qi != qe; \
         mi += 32, qi += 16) { \
      MATRIX_INT32_ITER_32X16_SSE(mi, qi, xmm_sum, _mm_loadu_si128, \
                                  ACCUM_INT4_STEP_SSE) \
    } \
  } \
  if (((uintptr_t)out & 0xf) == 0) { \
    MATRIX_VAR_STORE(8, 16, 4, xmm_sum, out, _mm_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(8, 16, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \
  }

//!
//! Compute the distance between matrix and query (INT4, M=32, N=32)
//!
//! `m` and `q` are read through local uint32_t cursors `mi`/`qi`, which the
//! macro declares in the enclosing scope, so expand at most once per scope.
//! The caller's pointers are not advanced. Each query column spans
//! `dim >> 3` words (8 INT4 values per word), so any `dim % 8` tail is not
//! processed. Aligned loads are used only when `mi` is 16-byte aligned.
//! Results are stored to `out` through `_NORM`, aligned only when `out` is
//! 16-byte aligned.
#define ACCUM_INT4_32X32_SSE(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(8, 32, __m128i, xmm_sum, _mm_setzero_si128()) \
  const uint32_t *qi = reinterpret_cast<const uint32_t *>(q); \
  const uint32_t *mi = reinterpret_cast<const uint32_t *>(m); \
  if (((uintptr_t)mi & 0xf) == 0) { \
    for (const uint32_t *qe = qi + ((dim >> 3) << 5); qi != qe; \
         mi += 32, qi += 32) { \
      MATRIX_INT32_ITER_32X32_SSE(mi, qi, xmm_sum, _mm_load_si128, \
                                  ACCUM_INT4_STEP_SSE) \
    } \
  } else { \
    for (const uint32_t *qe = qi + ((dim >> 3) << 5); qi != qe; \
         mi += 32, qi += 32) { \
      MATRIX_INT32_ITER_32X32_SSE(mi, qi, xmm_sum, _mm_loadu_si128, \
                                  ACCUM_INT4_STEP_SSE) \
    } \
  } \
  if (((uintptr_t)out & 0xf) == 0) { \
    MATRIX_VAR_STORE(8, 32, 4, xmm_sum, out, _mm_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(8, 32, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \
  }

//! Compute the distance between matrix and query (INT4, M=2, N=1)
//!
//! AVX variant of the M=2, N=1 kernel. The main loop handles 4 query words
//! (8 matrix words) per iteration in a 256-bit accumulator, which is then
//! narrowed to 128 bits. A 2-word remainder and a final single word are
//! handled with SSE steps. Aligned loads are used only when `mi` is 32-byte
//! aligned. The two results are written with _mm_storel_pi after `_NORM`.
//! Locals `qi`, `mi`, `qe`, `qe_aligned` and `xmm_sum_0` are declared in the
//! enclosing scope.
#define ACCUM_INT4_2X1_AVX(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(1, 1, __m256i, ymm_sum, _mm256_setzero_si256()) \
  const uint32_t *qi = reinterpret_cast<const uint32_t *>(q); \
  const uint32_t *mi = reinterpret_cast<const uint32_t *>(m); \
  const uint32_t *qe_aligned = qi + ((dim >> 5) << 2); \
  const uint32_t *qe = qi + (dim >> 3); \
  if (((uintptr_t)mi & 0x1f) == 0) { \
    for (; qi != qe_aligned; mi += 8, qi += 4) { \
      MATRIX_INT32_ITER_2X1_AVX(mi, qi, ymm_sum, _mm256_load_si256, \
                                ACCUM_INT4_STEP_AVX) \
    } \
  } else { \
    for (; qi != qe_aligned; mi += 8, qi += 4) { \
      MATRIX_INT32_ITER_2X1_AVX(mi, qi, ymm_sum, _mm256_loadu_si256, \
                                ACCUM_INT4_STEP_AVX) \
    } \
  } \
  __m128i xmm_sum_0 = _mm_add_epi32(_mm256_castsi256_si128(ymm_sum_0_0), \
                                    _mm256_extracti128_si256(ymm_sum_0_0, 1)); \
  /* Compare the distance: qe_aligned + 2 may lie past the buffer end. */ \
  if (qe - qe_aligned >= 2) { \
    __m128i xmm_mi = _mm_loadu_si128((const __m128i *)(mi)); \
    __m128i xmm_qi = _mm_set_epi32(qi[1], qi[1], qi[0], qi[0]); \
    ACCUM_INT4_STEP_SSE(xmm_mi, xmm_qi, xmm_sum_0) \
    mi += 4; \
    qi += 2; \
  } \
  xmm_sum_0 = _mm_add_epi32( \
      xmm_sum_0, _mm_shuffle_epi32(xmm_sum_0, _MM_SHUFFLE(0, 0, 3, 2))); \
  if (qi != qe) { \
    __m128i xmm_mi = _mm_set_epi32(0, 0, mi[1], mi[0]); \
    __m128i xmm_qi = _mm_broadcast_si32(qi); \
    ACCUM_INT4_STEP_SSE(xmm_mi, xmm_qi, xmm_sum_0) \
  } \
  _mm_storel_pi((__m64 *)out, _NORM(xmm_sum_0));

//! Compute the distance between matrix and query (INT4, M=2, N=2)
//!
//! AVX variant. The main loop handles 8 query words (4 word-groups) per
//! iteration; aligned loads require both `mi` and `qi` to be 32-byte
//! aligned. The 256-bit accumulators are narrowed to 128 bits. A 2-group
//! remainder, the fold into `xmm_sum_0_0`, and a final single group are
//! then handled with SSE steps. Results are stored to `out` through
//! `_NORM`, aligned only when `out` is 16-byte aligned.
#define ACCUM_INT4_2X2_AVX(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(1, 2, __m256i, ymm_sum, _mm256_setzero_si256()) \
  const uint32_t *qi = reinterpret_cast<const uint32_t *>(q); \
  const uint32_t *mi = reinterpret_cast<const uint32_t *>(m); \
  const uint32_t *qe_aligned = qi + ((dim >> 5) << 3); \
  const uint32_t *qe = qi + ((dim >> 3) << 1); \
  if (((uintptr_t)mi & 0x1f) == 0 && ((uintptr_t)qi & 0x1f) == 0) { \
    for (; qi != qe_aligned; mi += 8, qi += 8) { \
      MATRIX_INT32_ITER_2X2_AVX(mi, qi, ymm_sum, _mm256_load_si256, \
                                ACCUM_INT4_STEP_AVX) \
    } \
  } else { \
    for (; qi != qe_aligned; mi += 8, qi += 8) { \
      MATRIX_INT32_ITER_2X2_AVX(mi, qi, ymm_sum, _mm256_loadu_si256, \
                                ACCUM_INT4_STEP_AVX) \
    } \
  } \
  __m128i xmm_sum_0_0 = \
      _mm_add_epi32(_mm256_castsi256_si128(ymm_sum_0_0), \
                    _mm256_extracti128_si256(ymm_sum_0_0, 1)); \
  __m128i xmm_sum_0_1 = \
      _mm_add_epi32(_mm256_castsi256_si128(ymm_sum_0_1), \
                    _mm256_extracti128_si256(ymm_sum_0_1, 1)); \
  /* Compare the distance: qe_aligned + 4 may lie past the buffer end. */ \
  if (qe - qe_aligned >= 4) { \
    __m128i xmm_qi = _mm_loadu_si128((const __m128i *)(qi)); \
    __m128i xmm_mi = _mm_loadu_si128((const __m128i *)(mi)); \
    __m128i xmm_pi = _mm_shuffle_epi32(xmm_qi, _MM_SHUFFLE(2, 2, 0, 0)); \
    ACCUM_INT4_STEP_SSE(xmm_mi, xmm_pi, xmm_sum_0_0) \
    xmm_pi = _mm_shuffle_epi32(xmm_qi, _MM_SHUFFLE(3, 3, 1, 1)); \
    ACCUM_INT4_STEP_SSE(xmm_mi, xmm_pi, xmm_sum_0_1) \
    mi += 4; \
    qi += 4; \
  } \
  xmm_sum_0_0 = _mm_add_epi32(_mm_unpacklo_epi64(xmm_sum_0_0, xmm_sum_0_1), \
                              _mm_unpackhi_epi64(xmm_sum_0_0, xmm_sum_0_1)); \
  if (qi != qe) { \
    __m128i xmm_mi = _mm_set_epi32(mi[1], mi[0], mi[1], mi[0]); \
    __m128i xmm_qi = _mm_set_epi32(qi[1], qi[1], qi[0], qi[0]); \
    ACCUM_INT4_STEP_SSE(xmm_mi, xmm_qi, xmm_sum_0_0) \
  } \
  if (((uintptr_t)out & 0xf) == 0) { \
    MATRIX_VAR_STORE(1, 1, 4, xmm_sum, out, _mm_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(1, 1, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \
  }

//! Compute the distance between matrix and query (INT4, M=4, N=1)
//!
//! AVX variant. The main loop handles 2 query words (8 matrix words) per
//! iteration in a 256-bit accumulator, which is narrowed to 128 bits before
//! a trailing single word is added. Aligned loads are used only when `mi`
//! is 32-byte aligned. Results are stored to `out` through `_NORM`, aligned
//! only when `out` is 16-byte aligned.
#define ACCUM_INT4_4X1_AVX(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(1, 1, __m256i, ymm_sum, _mm256_setzero_si256()) \
  const uint32_t *qi = reinterpret_cast<const uint32_t *>(q); \
  const uint32_t *mi = reinterpret_cast<const uint32_t *>(m); \
  const uint32_t *qe = qi + (dim >> 3); \
  if (((uintptr_t)mi & 0x1f) == 0) { \
    for (const uint32_t *qe_aligned = qi + ((dim >> 4) << 1); \
         qi != qe_aligned; mi += 8, qi += 2) { \
      MATRIX_INT32_ITER_4X1_AVX(mi, qi, ymm_sum, _mm256_load_si256, \
                                ACCUM_INT4_STEP_AVX) \
    } \
  } else { \
    for (const uint32_t *qe_aligned = qi + ((dim >> 4) << 1); \
         qi != qe_aligned; mi += 8, qi += 2) { \
      MATRIX_INT32_ITER_4X1_AVX(mi, qi, ymm_sum, _mm256_loadu_si256, \
                                ACCUM_INT4_STEP_AVX) \
    } \
  } \
  __m128i xmm_sum_0_0 = \
      _mm_add_epi32(_mm256_castsi256_si128(ymm_sum_0_0), \
                    _mm256_extracti128_si256(ymm_sum_0_0, 1)); \
  if (qi != qe) { \
    __m128i xmm_mi = _mm_loadu_si128((const __m128i *)(mi)); \
    __m128i xmm_qi = _mm_broadcast_si32(qi); \
    ACCUM_INT4_STEP_SSE(xmm_mi, xmm_qi, xmm_sum_0_0) \
  } \
  if (((uintptr_t)out & 0xf) == 0) { \
    MATRIX_VAR_STORE(1, 1, 4, xmm_sum, out, _mm_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(1, 1, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \
  }

//!
//! Compute the distance between matrix and query (INT4, M=4, N=2)
//!
//! AVX kernel. `m` and `q` are read through local uint32_t cursors
//! `mi`/`qi`, declared in the enclosing scope (with `qe` and the narrowed
//! `xmm_sum_0_*`), so expand at most once per scope. The caller's pointers
//! are not advanced. Each query column spans `dim >> 3` words (8 INT4
//! values per word), so any `dim % 8` tail is not processed.
//!
//! The main loop handles two word-groups (4 query words, 8 matrix words)
//! per iteration. The 256-bit accumulators are then narrowed to 128 bits,
//! and a trailing single group is added with SSE steps. Aligned loads are
//! used only when `mi` is 32-byte aligned. Results are stored to `out`
//! through `_NORM`, aligned only when `out` is 16-byte aligned.
#define ACCUM_INT4_4X2_AVX(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(1, 2, __m256i, ymm_sum, _mm256_setzero_si256()) \
  const uint32_t *qi = reinterpret_cast<const uint32_t *>(q); \
  const uint32_t *mi = reinterpret_cast<const uint32_t *>(m); \
  const uint32_t *qe = qi + ((dim >> 3) << 1); \
  if (((uintptr_t)mi & 0x1f) == 0) { \
    for (const uint32_t *qe_aligned = qi + ((dim >> 4) << 2); \
         qi != qe_aligned; mi += 8, qi += 4) { \
      MATRIX_INT32_ITER_4X2_AVX(mi, qi, ymm_sum, _mm256_load_si256, \
                                ACCUM_INT4_STEP_AVX) \
    } \
  } else { \
    for (const uint32_t *qe_aligned = qi + ((dim >> 4) << 2); \
         qi != qe_aligned; mi += 8, qi += 4) { \
      MATRIX_INT32_ITER_4X2_AVX(mi, qi, ymm_sum, _mm256_loadu_si256, \
                                ACCUM_INT4_STEP_AVX) \
    } \
  } \
  __m128i xmm_sum_0_0 = \
      _mm_add_epi32(_mm256_castsi256_si128(ymm_sum_0_0), \
                    _mm256_extracti128_si256(ymm_sum_0_0, 1)); \
  __m128i xmm_sum_0_1 = \
      _mm_add_epi32(_mm256_castsi256_si128(ymm_sum_0_1), \
                    _mm256_extracti128_si256(ymm_sum_0_1, 1)); \
  if (qi != qe) { \
    __m128i xmm_mi = _mm_loadu_si128((const __m128i *)(mi)); \
    __m128i xmm_qi = _mm_broadcast_si32(qi); \
    ACCUM_INT4_STEP_SSE(xmm_mi, xmm_qi, xmm_sum_0_0) \
    xmm_qi = _mm_broadcast_si32(qi + 1); \
    ACCUM_INT4_STEP_SSE(xmm_mi, xmm_qi, xmm_sum_0_1) \
  } \
  if (((uintptr_t)out & 0xf) == 0) { \
    MATRIX_VAR_STORE(1, 2, 4, xmm_sum, out, _mm_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(1, 2, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \
  }

//!
Compute the distance between matrix and query (INT4, M=4, N=4)
//!
//! Walks `q`/`m` as 32-bit words (INT4: presumably 8 values per word). The
//! main loop covers 16 dimensions per iteration using 256-bit loads. The
//! aligned-load path is taken only when BOTH `mi` and `qi` are 32-byte
//! aligned; presumably the 4x4 iterator also vector-loads the query words
//! (TODO confirm against MATRIX_INT32_ITER_4X4_AVX). The four 256-bit
//! accumulators are folded to 128 bits, then a leftover 8-dimension word
//! per query (present when `dim >> 3` is odd) is added with 128-bit SSE
//! steps.
#define ACCUM_INT4_4X4_AVX(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(1, 4, __m256i, ymm_sum, _mm256_setzero_si256()) \
  const uint32_t *qi = reinterpret_cast(q); \
  const uint32_t *mi = reinterpret_cast(m); \
  const uint32_t *qe = qi + ((dim >> 3) << 2); \
  if (((uintptr_t)mi & 0x1f) == 0 && ((uintptr_t)qi & 0x1f) == 0) { \
    for (const uint32_t *qe_aligned = qi + ((dim >> 4) << 3); \
         qi != qe_aligned; mi += 8, qi += 8) { \
      MATRIX_INT32_ITER_4X4_AVX(mi, qi, ymm_sum, _mm256_load_si256, \
                                ACCUM_INT4_STEP_AVX) \
    } \
  } else { \
    for (const uint32_t *qe_aligned = qi + ((dim >> 4) << 3); \
         qi != qe_aligned; mi += 8, qi += 8) { \
      MATRIX_INT32_ITER_4X4_AVX(mi, qi, ymm_sum, _mm256_loadu_si256, \
                                ACCUM_INT4_STEP_AVX) \
    } \
  } \
  /* Fold each 256-bit accumulator into 128 bits (low + high lane). */ \
  __m128i xmm_sum_0_0 = \
      _mm_add_epi32(_mm256_castsi256_si128(ymm_sum_0_0), \
                    _mm256_extracti128_si256(ymm_sum_0_0, 1)); \
  __m128i xmm_sum_0_1 = \
      _mm_add_epi32(_mm256_castsi256_si128(ymm_sum_0_1), \
                    _mm256_extracti128_si256(ymm_sum_0_1, 1)); \
  __m128i xmm_sum_0_2 = \
      _mm_add_epi32(_mm256_castsi256_si128(ymm_sum_0_2), \
                    _mm256_extracti128_si256(ymm_sum_0_2, 1)); \
  __m128i xmm_sum_0_3 = \
      _mm_add_epi32(_mm256_castsi256_si128(ymm_sum_0_3), \
                    _mm256_extracti128_si256(ymm_sum_0_3, 1)); \
  /* Tail: one remaining word per query; the same 128-bit matrix */ \
  /* block is reused for each of the four queries. */ \
  if (qi != qe) { \
    __m128i xmm_mi = _mm_loadu_si128((const __m128i *)(mi)); \
    __m128i xmm_qi = _mm_broadcast_si32(qi); \
    ACCUM_INT4_STEP_SSE(xmm_mi, xmm_qi, xmm_sum_0_0) \
    xmm_qi = _mm_broadcast_si32(qi + 1); \
    ACCUM_INT4_STEP_SSE(xmm_mi, xmm_qi, xmm_sum_0_1) \
    xmm_qi = _mm_broadcast_si32(qi + 2); \
    ACCUM_INT4_STEP_SSE(xmm_mi, xmm_qi, xmm_sum_0_2) \
    xmm_qi = _mm_broadcast_si32(qi + 3); \
    ACCUM_INT4_STEP_SSE(xmm_mi, xmm_qi, xmm_sum_0_3) \
  } \
  if (((uintptr_t)out & 0xf) == 0) { \
    MATRIX_VAR_STORE(1, 4, 4, xmm_sum, out, _mm_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(1, 4, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \
  }
//! 
Compute the distance between matrix and query (INT4, M=8, N=1)
//!
//! `dim >> 3` iterations, each consuming one 32-bit query word (presumably
//! 8 INT4 values) and 8 matrix words. There is no tail step, so any
//! `dim % 8` remainder is not accumulated; presumably callers pass a
//! multiple of 8 (TODO confirm). Aligned loads are used only when `mi` is
//! 32-byte aligned; aligned stores only when `out` is 32-byte aligned.
#define ACCUM_INT4_8X1_AVX(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(1, 1, __m256i, ymm_sum, _mm256_setzero_si256()) \
  const uint32_t *qi = reinterpret_cast(q); \
  const uint32_t *mi = reinterpret_cast(m); \
  if (((uintptr_t)mi & 0x1f) == 0) { \
    for (const uint32_t *qe = qi + (dim >> 3); qi != qe; mi += 8, ++qi) { \
      MATRIX_INT32_ITER_8X1_AVX(mi, qi, ymm_sum, _mm256_load_si256, \
                                ACCUM_INT4_STEP_AVX) \
    } \
  } else { \
    for (const uint32_t *qe = qi + (dim >> 3); qi != qe; mi += 8, ++qi) { \
      MATRIX_INT32_ITER_8X1_AVX(mi, qi, ymm_sum, _mm256_loadu_si256, \
                                ACCUM_INT4_STEP_AVX) \
    } \
  } \
  if (((uintptr_t)out & 0x1f) == 0) { \
    MATRIX_VAR_STORE(1, 1, 8, ymm_sum, out, _mm256_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(1, 1, 8, ymm_sum, out, _mm256_storeu_ps, _NORM) \
  }
//! Compute the distance between matrix and query (INT4, M=8, N=2)
//!
//! `dim >> 3` iterations, each consuming 2 query words and 8 matrix words;
//! no tail step, so `dim` is presumably a multiple of 8.
#define ACCUM_INT4_8X2_AVX(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(1, 2, __m256i, ymm_sum, _mm256_setzero_si256()) \
  const uint32_t *qi = reinterpret_cast(q); \
  const uint32_t *mi = reinterpret_cast(m); \
  if (((uintptr_t)mi & 0x1f) == 0) { \
    for (const uint32_t *qe = qi + ((dim >> 3) << 1); qi != qe; \
         mi += 8, qi += 2) { \
      MATRIX_INT32_ITER_8X2_AVX(mi, qi, ymm_sum, _mm256_load_si256, \
                                ACCUM_INT4_STEP_AVX) \
    } \
  } else { \
    for (const uint32_t *qe = qi + ((dim >> 3) << 1); qi != qe; \
         mi += 8, qi += 2) { \
      MATRIX_INT32_ITER_8X2_AVX(mi, qi, ymm_sum, _mm256_loadu_si256, \
                                ACCUM_INT4_STEP_AVX) \
    } \
  } \
  if (((uintptr_t)out & 0x1f) == 0) { \
    MATRIX_VAR_STORE(1, 2, 8, ymm_sum, out, _mm256_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(1, 2, 8, ymm_sum, out, _mm256_storeu_ps, _NORM) \
  }
//! 
Compute the distance between matrix and query (INT4, M=8, N=4)
//!
//! `dim >> 3` iterations, each consuming 4 query words and 8 matrix words;
//! no tail step, so `dim` is presumably a multiple of 8.
#define ACCUM_INT4_8X4_AVX(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(1, 4, __m256i, ymm_sum, _mm256_setzero_si256()) \
  const uint32_t *qi = reinterpret_cast(q); \
  const uint32_t *mi = reinterpret_cast(m); \
  if (((uintptr_t)mi & 0x1f) == 0) { \
    for (const uint32_t *qe = qi + ((dim >> 3) << 2); qi != qe; \
         mi += 8, qi += 4) { \
      MATRIX_INT32_ITER_8X4_AVX(mi, qi, ymm_sum, _mm256_load_si256, \
                                ACCUM_INT4_STEP_AVX) \
    } \
  } else { \
    for (const uint32_t *qe = qi + ((dim >> 3) << 2); qi != qe; \
         mi += 8, qi += 4) { \
      MATRIX_INT32_ITER_8X4_AVX(mi, qi, ymm_sum, _mm256_loadu_si256, \
                                ACCUM_INT4_STEP_AVX) \
    } \
  } \
  if (((uintptr_t)out & 0x1f) == 0) { \
    MATRIX_VAR_STORE(1, 4, 8, ymm_sum, out, _mm256_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(1, 4, 8, ymm_sum, out, _mm256_storeu_ps, _NORM) \
  }
//! Compute the distance between matrix and query (INT4, M=8, N=8)
//!
//! `dim >> 3` iterations, each consuming 8 query words and 8 matrix words;
//! no tail step, so `dim` is presumably a multiple of 8. Only `mi`
//! alignment selects the aligned-load path here.
#define ACCUM_INT4_8X8_AVX(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(1, 8, __m256i, ymm_sum, _mm256_setzero_si256()) \
  const uint32_t *qi = reinterpret_cast(q); \
  const uint32_t *mi = reinterpret_cast(m); \
  if (((uintptr_t)mi & 0x1f) == 0) { \
    for (const uint32_t *qe = qi + ((dim >> 3) << 3); qi != qe; \
         mi += 8, qi += 8) { \
      MATRIX_INT32_ITER_8X8_AVX(mi, qi, ymm_sum, _mm256_load_si256, \
                                ACCUM_INT4_STEP_AVX) \
    } \
  } else { \
    for (const uint32_t *qe = qi + ((dim >> 3) << 3); qi != qe; \
         mi += 8, qi += 8) { \
      MATRIX_INT32_ITER_8X8_AVX(mi, qi, ymm_sum, _mm256_loadu_si256, \
                                ACCUM_INT4_STEP_AVX) \
    } \
  } \
  if (((uintptr_t)out & 0x1f) == 0) { \
    MATRIX_VAR_STORE(1, 8, 8, ymm_sum, out, _mm256_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(1, 8, 8, ymm_sum, out, _mm256_storeu_ps, _NORM) \
  }
//! 
Compute the distance between matrix and query (INT4, M=16, N=1)
//!
//! `dim >> 3` iterations, each consuming 1 query word and 16 matrix words
//! into a 2x1 grid of 256-bit accumulators; no tail step, so `dim` is
//! presumably a multiple of 8.
#define ACCUM_INT4_16X1_AVX(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(2, 1, __m256i, ymm_sum, _mm256_setzero_si256()) \
  const uint32_t *qi = reinterpret_cast(q); \
  const uint32_t *mi = reinterpret_cast(m); \
  if (((uintptr_t)mi & 0x1f) == 0) { \
    for (const uint32_t *qe = qi + (dim >> 3); qi != qe; mi += 16, ++qi) { \
      MATRIX_INT32_ITER_16X1_AVX(mi, qi, ymm_sum, _mm256_load_si256, \
                                 ACCUM_INT4_STEP_AVX) \
    } \
  } else { \
    for (const uint32_t *qe = qi + (dim >> 3); qi != qe; mi += 16, ++qi) { \
      MATRIX_INT32_ITER_16X1_AVX(mi, qi, ymm_sum, _mm256_loadu_si256, \
                                 ACCUM_INT4_STEP_AVX) \
    } \
  } \
  if (((uintptr_t)out & 0x1f) == 0) { \
    MATRIX_VAR_STORE(2, 1, 8, ymm_sum, out, _mm256_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(2, 1, 8, ymm_sum, out, _mm256_storeu_ps, _NORM) \
  }
//! Compute the distance between matrix and query (INT4, M=16, N=2)
//!
//! `dim >> 3` iterations, each consuming 2 query words and 16 matrix words;
//! no tail step, so `dim` is presumably a multiple of 8.
#define ACCUM_INT4_16X2_AVX(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(2, 2, __m256i, ymm_sum, _mm256_setzero_si256()) \
  const uint32_t *qi = reinterpret_cast(q); \
  const uint32_t *mi = reinterpret_cast(m); \
  if (((uintptr_t)mi & 0x1f) == 0) { \
    for (const uint32_t *qe = qi + ((dim >> 3) << 1); qi != qe; \
         mi += 16, qi += 2) { \
      MATRIX_INT32_ITER_16X2_AVX(mi, qi, ymm_sum, _mm256_load_si256, \
                                 ACCUM_INT4_STEP_AVX) \
    } \
  } else { \
    for (const uint32_t *qe = qi + ((dim >> 3) << 1); qi != qe; \
         mi += 16, qi += 2) { \
      MATRIX_INT32_ITER_16X2_AVX(mi, qi, ymm_sum, _mm256_loadu_si256, \
                                 ACCUM_INT4_STEP_AVX) \
    } \
  } \
  if (((uintptr_t)out & 0x1f) == 0) { \
    MATRIX_VAR_STORE(2, 2, 8, ymm_sum, out, _mm256_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(2, 2, 8, ymm_sum, out, _mm256_storeu_ps, _NORM) \
  }
//! 
Compute the distance between matrix and query (INT4, M=16, N=4)
//!
//! `dim >> 3` iterations, each consuming 4 query words and 16 matrix words;
//! no tail step, so `dim` is presumably a multiple of 8.
#define ACCUM_INT4_16X4_AVX(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(2, 4, __m256i, ymm_sum, _mm256_setzero_si256()) \
  const uint32_t *qi = reinterpret_cast(q); \
  const uint32_t *mi = reinterpret_cast(m); \
  if (((uintptr_t)mi & 0x1f) == 0) { \
    for (const uint32_t *qe = qi + ((dim >> 3) << 2); qi != qe; \
         mi += 16, qi += 4) { \
      MATRIX_INT32_ITER_16X4_AVX(mi, qi, ymm_sum, _mm256_load_si256, \
                                 ACCUM_INT4_STEP_AVX) \
    } \
  } else { \
    for (const uint32_t *qe = qi + ((dim >> 3) << 2); qi != qe; \
         mi += 16, qi += 4) { \
      MATRIX_INT32_ITER_16X4_AVX(mi, qi, ymm_sum, _mm256_loadu_si256, \
                                 ACCUM_INT4_STEP_AVX) \
    } \
  } \
  if (((uintptr_t)out & 0x1f) == 0) { \
    MATRIX_VAR_STORE(2, 4, 8, ymm_sum, out, _mm256_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(2, 4, 8, ymm_sum, out, _mm256_storeu_ps, _NORM) \
  }
//! Compute the distance between matrix and query (INT4, M=16, N=8)
//!
//! `dim >> 3` iterations, each consuming 8 query words and 16 matrix words;
//! no tail step, so `dim` is presumably a multiple of 8.
#define ACCUM_INT4_16X8_AVX(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(2, 8, __m256i, ymm_sum, _mm256_setzero_si256()) \
  const uint32_t *qi = reinterpret_cast(q); \
  const uint32_t *mi = reinterpret_cast(m); \
  if (((uintptr_t)mi & 0x1f) == 0) { \
    for (const uint32_t *qe = qi + ((dim >> 3) << 3); qi != qe; \
         mi += 16, qi += 8) { \
      MATRIX_INT32_ITER_16X8_AVX(mi, qi, ymm_sum, _mm256_load_si256, \
                                 ACCUM_INT4_STEP_AVX) \
    } \
  } else { \
    for (const uint32_t *qe = qi + ((dim >> 3) << 3); qi != qe; \
         mi += 16, qi += 8) { \
      MATRIX_INT32_ITER_16X8_AVX(mi, qi, ymm_sum, _mm256_loadu_si256, \
                                 ACCUM_INT4_STEP_AVX) \
    } \
  } \
  if (((uintptr_t)out & 0x1f) == 0) { \
    MATRIX_VAR_STORE(2, 8, 8, ymm_sum, out, _mm256_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(2, 8, 8, ymm_sum, out, _mm256_storeu_ps, _NORM) \
  }
//! 
Compute the distance between matrix and query (INT4, M=16, N=16)
//!
//! `dim >> 3` iterations, each consuming 16 query words and 16 matrix
//! words; no tail step, so `dim` is presumably a multiple of 8.
#define ACCUM_INT4_16X16_AVX(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(2, 16, __m256i, ymm_sum, _mm256_setzero_si256()) \
  const uint32_t *qi = reinterpret_cast(q); \
  const uint32_t *mi = reinterpret_cast(m); \
  if (((uintptr_t)mi & 0x1f) == 0) { \
    for (const uint32_t *qe = qi + ((dim >> 3) << 4); qi != qe; \
         mi += 16, qi += 16) { \
      MATRIX_INT32_ITER_16X16_AVX(mi, qi, ymm_sum, _mm256_load_si256, \
                                  ACCUM_INT4_STEP_AVX) \
    } \
  } else { \
    for (const uint32_t *qe = qi + ((dim >> 3) << 4); qi != qe; \
         mi += 16, qi += 16) { \
      MATRIX_INT32_ITER_16X16_AVX(mi, qi, ymm_sum, _mm256_loadu_si256, \
                                  ACCUM_INT4_STEP_AVX) \
    } \
  } \
  if (((uintptr_t)out & 0x1f) == 0) { \
    MATRIX_VAR_STORE(2, 16, 8, ymm_sum, out, _mm256_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(2, 16, 8, ymm_sum, out, _mm256_storeu_ps, _NORM) \
  }
//! Compute the distance between matrix and query (INT4, M=32, N=1)
//!
//! `dim >> 3` iterations, each consuming 1 query word and 32 matrix words
//! into a 4x1 grid of 256-bit accumulators; no tail step, so `dim` is
//! presumably a multiple of 8.
#define ACCUM_INT4_32X1_AVX(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(4, 1, __m256i, ymm_sum, _mm256_setzero_si256()) \
  const uint32_t *qi = reinterpret_cast(q); \
  const uint32_t *mi = reinterpret_cast(m); \
  if (((uintptr_t)mi & 0x1f) == 0) { \
    for (const uint32_t *qe = qi + (dim >> 3); qi != qe; mi += 32, ++qi) { \
      MATRIX_INT32_ITER_32X1_AVX(mi, qi, ymm_sum, _mm256_load_si256, \
                                 ACCUM_INT4_STEP_AVX) \
    } \
  } else { \
    for (const uint32_t *qe = qi + (dim >> 3); qi != qe; mi += 32, ++qi) { \
      MATRIX_INT32_ITER_32X1_AVX(mi, qi, ymm_sum, _mm256_loadu_si256, \
                                 ACCUM_INT4_STEP_AVX) \
    } \
  } \
  if (((uintptr_t)out & 0x1f) == 0) { \
    MATRIX_VAR_STORE(4, 1, 8, ymm_sum, out, _mm256_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(4, 1, 8, ymm_sum, out, _mm256_storeu_ps, _NORM) \
  }
//! 
Compute the distance between matrix and query (INT4, M=32, N=2)
//!
//! `dim >> 3` iterations, each consuming 2 query words and 32 matrix words;
//! no tail step, so `dim` is presumably a multiple of 8.
#define ACCUM_INT4_32X2_AVX(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(4, 2, __m256i, ymm_sum, _mm256_setzero_si256()) \
  const uint32_t *qi = reinterpret_cast(q); \
  const uint32_t *mi = reinterpret_cast(m); \
  if (((uintptr_t)mi & 0x1f) == 0) { \
    for (const uint32_t *qe = qi + ((dim >> 3) << 1); qi != qe; \
         mi += 32, qi += 2) { \
      MATRIX_INT32_ITER_32X2_AVX(mi, qi, ymm_sum, _mm256_load_si256, \
                                 ACCUM_INT4_STEP_AVX) \
    } \
  } else { \
    for (const uint32_t *qe = qi + ((dim >> 3) << 1); qi != qe; \
         mi += 32, qi += 2) { \
      MATRIX_INT32_ITER_32X2_AVX(mi, qi, ymm_sum, _mm256_loadu_si256, \
                                 ACCUM_INT4_STEP_AVX) \
    } \
  } \
  if (((uintptr_t)out & 0x1f) == 0) { \
    MATRIX_VAR_STORE(4, 2, 8, ymm_sum, out, _mm256_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(4, 2, 8, ymm_sum, out, _mm256_storeu_ps, _NORM) \
  }
//! Compute the distance between matrix and query (INT4, M=32, N=4)
//!
//! `dim >> 3` iterations, each consuming 4 query words and 32 matrix words;
//! no tail step, so `dim` is presumably a multiple of 8.
#define ACCUM_INT4_32X4_AVX(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(4, 4, __m256i, ymm_sum, _mm256_setzero_si256()) \
  const uint32_t *qi = reinterpret_cast(q); \
  const uint32_t *mi = reinterpret_cast(m); \
  if (((uintptr_t)mi & 0x1f) == 0) { \
    for (const uint32_t *qe = qi + ((dim >> 3) << 2); qi != qe; \
         mi += 32, qi += 4) { \
      MATRIX_INT32_ITER_32X4_AVX(mi, qi, ymm_sum, _mm256_load_si256, \
                                 ACCUM_INT4_STEP_AVX) \
    } \
  } else { \
    for (const uint32_t *qe = qi + ((dim >> 3) << 2); qi != qe; \
         mi += 32, qi += 4) { \
      MATRIX_INT32_ITER_32X4_AVX(mi, qi, ymm_sum, _mm256_loadu_si256, \
                                 ACCUM_INT4_STEP_AVX) \
    } \
  } \
  if (((uintptr_t)out & 0x1f) == 0) { \
    MATRIX_VAR_STORE(4, 4, 8, ymm_sum, out, _mm256_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(4, 4, 8, ymm_sum, out, _mm256_storeu_ps, _NORM) \
  }
//! 
Compute the distance between matrix and query (INT4, M=32, N=8)
//!
//! `dim >> 3` iterations, each consuming 8 query words and 32 matrix words;
//! no tail step, so `dim` is presumably a multiple of 8.
#define ACCUM_INT4_32X8_AVX(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(4, 8, __m256i, ymm_sum, _mm256_setzero_si256()) \
  const uint32_t *qi = reinterpret_cast(q); \
  const uint32_t *mi = reinterpret_cast(m); \
  if (((uintptr_t)mi & 0x1f) == 0) { \
    for (const uint32_t *qe = qi + ((dim >> 3) << 3); qi != qe; \
         mi += 32, qi += 8) { \
      MATRIX_INT32_ITER_32X8_AVX(mi, qi, ymm_sum, _mm256_load_si256, \
                                 ACCUM_INT4_STEP_AVX) \
    } \
  } else { \
    for (const uint32_t *qe = qi + ((dim >> 3) << 3); qi != qe; \
         mi += 32, qi += 8) { \
      MATRIX_INT32_ITER_32X8_AVX(mi, qi, ymm_sum, _mm256_loadu_si256, \
                                 ACCUM_INT4_STEP_AVX) \
    } \
  } \
  if (((uintptr_t)out & 0x1f) == 0) { \
    MATRIX_VAR_STORE(4, 8, 8, ymm_sum, out, _mm256_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(4, 8, 8, ymm_sum, out, _mm256_storeu_ps, _NORM) \
  }
//! Compute the distance between matrix and query (INT4, M=32, N=16)
//!
//! `dim >> 3` iterations, each consuming 16 query words and 32 matrix
//! words; no tail step, so `dim` is presumably a multiple of 8.
#define ACCUM_INT4_32X16_AVX(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(4, 16, __m256i, ymm_sum, _mm256_setzero_si256()) \
  const uint32_t *qi = reinterpret_cast(q); \
  const uint32_t *mi = reinterpret_cast(m); \
  if (((uintptr_t)mi & 0x1f) == 0) { \
    for (const uint32_t *qe = qi + ((dim >> 3) << 4); qi != qe; \
         mi += 32, qi += 16) { \
      MATRIX_INT32_ITER_32X16_AVX(mi, qi, ymm_sum, _mm256_load_si256, \
                                  ACCUM_INT4_STEP_AVX) \
    } \
  } else { \
    for (const uint32_t *qe = qi + ((dim >> 3) << 4); qi != qe; \
         mi += 32, qi += 16) { \
      MATRIX_INT32_ITER_32X16_AVX(mi, qi, ymm_sum, _mm256_loadu_si256, \
                                  ACCUM_INT4_STEP_AVX) \
    } \
  } \
  if (((uintptr_t)out & 0x1f) == 0) { \
    MATRIX_VAR_STORE(4, 16, 8, ymm_sum, out, _mm256_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(4, 16, 8, ymm_sum, out, _mm256_storeu_ps, _NORM) \
  }
//! 
Compute the distance between matrix and query (INT4, M=32, N=32)
//!
//! `dim >> 3` iterations, each consuming 32 query words and 32 matrix
//! words; no tail step, so `dim` is presumably a multiple of 8.
#define ACCUM_INT4_32X32_AVX(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(4, 32, __m256i, ymm_sum, _mm256_setzero_si256()) \
  const uint32_t *qi = reinterpret_cast(q); \
  const uint32_t *mi = reinterpret_cast(m); \
  if (((uintptr_t)mi & 0x1f) == 0) { \
    for (const uint32_t *qe = qi + ((dim >> 3) << 5); qi != qe; \
         mi += 32, qi += 32) { \
      MATRIX_INT32_ITER_32X32_AVX(mi, qi, ymm_sum, _mm256_load_si256, \
                                  ACCUM_INT4_STEP_AVX) \
    } \
  } else { \
    for (const uint32_t *qe = qi + ((dim >> 3) << 5); qi != qe; \
         mi += 32, qi += 32) { \
      MATRIX_INT32_ITER_32X32_AVX(mi, qi, ymm_sum, _mm256_loadu_si256, \
                                  ACCUM_INT4_STEP_AVX) \
    } \
  } \
  if (((uintptr_t)out & 0x1f) == 0) { \
    MATRIX_VAR_STORE(4, 32, 8, ymm_sum, out, _mm256_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(4, 32, 8, ymm_sum, out, _mm256_storeu_ps, _NORM) \
  }

================================================
FILE: src/ailego/math/distance_matrix_accum_int8.i
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "distance_matrix_int32.i"
#include "matrix_utility.i"

//! 
Compute the distance between matrix and query (INT8, M=2, N=1)
//!
//! `q`/`m` are walked as 32-bit words (INT8: presumably 4 values per word,
//! hence `dim >> 2`). The main loop runs `dim >> 4` iterations, each taking
//! 4 query words and 8 matrix words into two partial accumulators. The
//! aligned path requires both `mi` and `qi` to be 16-byte aligned.
//! Leftover words are consumed in up to two steps:
//!  - a 2-word step whose query vector is laid out as {q0, q0, q1, q1}
//!    across lanes 0..3;
//!  - after merging the partial sums and folding lanes 2-3 onto lanes 0-1,
//!    a final 1-word step that only fills lanes 0-1.
//! Only the low two lanes are written (`_mm_storel_pi`), presumably one
//! result per matrix row. `_NORM` maps the integer sums to the stored float
//! vector.
#define ACCUM_INT8_2X1_SSE(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(1, 2, __m128i, xmm_sum, _mm_setzero_si128()) \
  const uint32_t *qi = reinterpret_cast(q); \
  const uint32_t *mi = reinterpret_cast(m); \
  const uint32_t *qe_aligned = qi + ((dim >> 4) << 2); \
  const uint32_t *qe = qi + (dim >> 2); \
  if (((uintptr_t)mi & 0xf) == 0 && ((uintptr_t)qi & 0xf) == 0) { \
    for (; qi != qe_aligned; mi += 8, qi += 4) { \
      MATRIX_INT32_ITER_2X1_SSE(mi, qi, xmm_sum, _mm_load_si128, \
                                ACCUM_INT8_STEP_SSE) \
    } \
    /* Two or more leftover query words: consume two at once. */ \
    if (qe >= qe_aligned + 2) { \
      __m128i xmm_mi = _mm_load_si128((const __m128i *)(mi)); \
      __m128i xmm_qi = _mm_set_epi32(qi[1], qi[1], qi[0], qi[0]); \
      ACCUM_INT8_STEP_SSE(xmm_mi, xmm_qi, xmm_sum_0_0) \
      mi += 4; \
      qi += 2; \
    } \
  } else { \
    for (; qi != qe_aligned; mi += 8, qi += 4) { \
      MATRIX_INT32_ITER_2X1_SSE(mi, qi, xmm_sum, _mm_loadu_si128, \
                                ACCUM_INT8_STEP_SSE) \
    } \
    /* Two or more leftover query words: consume two at once. */ \
    if (qe >= qe_aligned + 2) { \
      __m128i xmm_mi = _mm_loadu_si128((const __m128i *)(mi)); \
      __m128i xmm_qi = _mm_set_epi32(qi[1], qi[1], qi[0], qi[0]); \
      ACCUM_INT8_STEP_SSE(xmm_mi, xmm_qi, xmm_sum_0_0) \
      mi += 4; \
      qi += 2; \
    } \
  } \
  /* Merge partial sums, then add lanes 2-3 onto lanes 0-1. */ \
  xmm_sum_0_0 = _mm_add_epi32(xmm_sum_0_0, xmm_sum_0_1); \
  xmm_sum_0_0 = _mm_add_epi32( \
      xmm_sum_0_0, _mm_shuffle_epi32(xmm_sum_0_0, _MM_SHUFFLE(0, 0, 3, 2))); \
  /* At most one query word remains; only lanes 0-1 get matrix data. */ \
  if (qi != qe) { \
    __m128i xmm_mi = _mm_set_epi32(0, 0, mi[1], mi[0]); \
    __m128i xmm_qi = _mm_broadcast_si32(qi); \
    ACCUM_INT8_STEP_SSE(xmm_mi, xmm_qi, xmm_sum_0_0) \
  } \
  _mm_storel_pi((__m64 *)out, _NORM(xmm_sum_0_0));
//! 
Compute the distance between matrix and query (INT8, M=2, N=2)
//!
//! Main loop: `dim >> 3` iterations of 4 matrix words and 4 query words.
//! The aligned path requires both `mi` and `qi` to be 16-byte aligned.
//! The two accumulators are merged by pairing their 64-bit halves. A
//! leftover word per row/query (present when `dim >> 2` is odd) is then
//! added with matrix lanes {m0, m1, m0, m1} against query lanes
//! {q0, q0, q1, q1}, which covers every row/query pair. Four results are
//! stored.
#define ACCUM_INT8_2X2_SSE(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(1, 2, __m128i, xmm_sum, _mm_setzero_si128()) \
  const uint32_t *qi = reinterpret_cast(q); \
  const uint32_t *mi = reinterpret_cast(m); \
  const uint32_t *qe = qi + ((dim >> 2) << 1); \
  if (((uintptr_t)mi & 0xf) == 0 && ((uintptr_t)qi & 0xf) == 0) { \
    for (const uint32_t *qe_aligned = qi + ((dim >> 3) << 2); \
         qi != qe_aligned; mi += 4, qi += 4) { \
      MATRIX_INT32_ITER_2X2_SSE(mi, qi, xmm_sum, _mm_load_si128, \
                                ACCUM_INT8_STEP_SSE) \
    } \
  } else { \
    for (const uint32_t *qe_aligned = qi + ((dim >> 3) << 2); \
         qi != qe_aligned; mi += 4, qi += 4) { \
      MATRIX_INT32_ITER_2X2_SSE(mi, qi, xmm_sum, _mm_loadu_si128, \
                                ACCUM_INT8_STEP_SSE) \
    } \
  } \
  /* Sum low 64-bit halves with high 64-bit halves of both accumulators. */ \
  xmm_sum_0_0 = _mm_add_epi32(_mm_unpacklo_epi64(xmm_sum_0_0, xmm_sum_0_1), \
                              _mm_unpackhi_epi64(xmm_sum_0_0, xmm_sum_0_1)); \
  if (qi != qe) { \
    __m128i xmm_mi = _mm_set_epi32(mi[1], mi[0], mi[1], mi[0]); \
    __m128i xmm_qi = _mm_set_epi32(qi[1], qi[1], qi[0], qi[0]); \
    ACCUM_INT8_STEP_SSE(xmm_mi, xmm_qi, xmm_sum_0_0) \
  } \
  if (((uintptr_t)out & 0xf) == 0) { \
    MATRIX_VAR_STORE(1, 1, 4, xmm_sum, out, _mm_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(1, 1, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \
  }
//! 
Compute the distance between matrix and query (INT8, M=4, N=1)
//!
//! Main loop: `dim >> 3` iterations of 2 query words and 8 matrix words
//! into two partial accumulators. A leftover query word (present when
//! `dim >> 2` is odd) is broadcast against the next 4 matrix words. The tail
//! load in the aligned branch stays aligned because `mi` only advances in
//! 8-word steps. The partial accumulators are summed before storing four
//! results.
#define ACCUM_INT8_4X1_SSE(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(2, 1, __m128i, xmm_sum, _mm_setzero_si128()) \
  const uint32_t *qi = reinterpret_cast(q); \
  const uint32_t *mi = reinterpret_cast(m); \
  const uint32_t *qe = qi + (dim >> 2); \
  if (((uintptr_t)mi & 0xf) == 0) { \
    for (const uint32_t *qe_aligned = qi + ((dim >> 3) << 1); \
         qi != qe_aligned; mi += 8, qi += 2) { \
      MATRIX_INT32_ITER_4X1_SSE(mi, qi, xmm_sum, _mm_load_si128, \
                                ACCUM_INT8_STEP_SSE) \
    } \
    if (qi != qe) { \
      __m128i xmm_mi = _mm_load_si128((const __m128i *)(mi)); \
      __m128i xmm_qi = _mm_broadcast_si32(qi); \
      ACCUM_INT8_STEP_SSE(xmm_mi, xmm_qi, xmm_sum_0_0) \
    } \
  } else { \
    for (const uint32_t *qe_aligned = qi + ((dim >> 3) << 1); \
         qi != qe_aligned; mi += 8, qi += 2) { \
      MATRIX_INT32_ITER_4X1_SSE(mi, qi, xmm_sum, _mm_loadu_si128, \
                                ACCUM_INT8_STEP_SSE) \
    } \
    if (qi != qe) { \
      __m128i xmm_mi = _mm_loadu_si128((const __m128i *)(mi)); \
      __m128i xmm_qi = _mm_broadcast_si32(qi); \
      ACCUM_INT8_STEP_SSE(xmm_mi, xmm_qi, xmm_sum_0_0) \
    } \
  } \
  /* Combine the two partial accumulators. */ \
  xmm_sum_0_0 = _mm_add_epi32(xmm_sum_0_0, xmm_sum_1_0); \
  if (((uintptr_t)out & 0xf) == 0) { \
    MATRIX_VAR_STORE(1, 1, 4, xmm_sum, out, _mm_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(1, 1, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \
  }
//! 
Compute the distance between matrix and query (INT8, M=4, N=2)
//!
//! `dim >> 2` iterations, each consuming 2 query words and 4 matrix words
//! (INT8: presumably 4 values per word). There is no tail step, so `dim` is
//! presumably a multiple of 4. Aligned 128-bit loads are used only when
//! `mi` is 16-byte aligned.
#define ACCUM_INT8_4X2_SSE(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(1, 2, __m128i, xmm_sum, _mm_setzero_si128()) \
  const uint32_t *qi = reinterpret_cast(q); \
  const uint32_t *mi = reinterpret_cast(m); \
  if (((uintptr_t)mi & 0xf) == 0) { \
    for (const uint32_t *qe = qi + ((dim >> 2) << 1); qi != qe; \
         mi += 4, qi += 2) { \
      MATRIX_INT32_ITER_4X2_SSE(mi, qi, xmm_sum, _mm_load_si128, \
                                ACCUM_INT8_STEP_SSE) \
    } \
  } else { \
    for (const uint32_t *qe = qi + ((dim >> 2) << 1); qi != qe; \
         mi += 4, qi += 2) { \
      MATRIX_INT32_ITER_4X2_SSE(mi, qi, xmm_sum, _mm_loadu_si128, \
                                ACCUM_INT8_STEP_SSE) \
    } \
  } \
  if (((uintptr_t)out & 0xf) == 0) { \
    MATRIX_VAR_STORE(1, 2, 4, xmm_sum, out, _mm_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(1, 2, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \
  }
//! Compute the distance between matrix and query (INT8, M=4, N=4)
//!
//! `dim >> 2` iterations, each consuming 4 query words and 4 matrix words;
//! no tail step, so `dim` is presumably a multiple of 4.
#define ACCUM_INT8_4X4_SSE(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(1, 4, __m128i, xmm_sum, _mm_setzero_si128()) \
  const uint32_t *qi = reinterpret_cast(q); \
  const uint32_t *mi = reinterpret_cast(m); \
  if (((uintptr_t)mi & 0xf) == 0) { \
    for (const uint32_t *qe = qi + ((dim >> 2) << 2); qi != qe; \
         mi += 4, qi += 4) { \
      MATRIX_INT32_ITER_4X4_SSE(mi, qi, xmm_sum, _mm_load_si128, \
                                ACCUM_INT8_STEP_SSE) \
    } \
  } else { \
    for (const uint32_t *qe = qi + ((dim >> 2) << 2); qi != qe; \
         mi += 4, qi += 4) { \
      MATRIX_INT32_ITER_4X4_SSE(mi, qi, xmm_sum, _mm_loadu_si128, \
                                ACCUM_INT8_STEP_SSE) \
    } \
  } \
  if (((uintptr_t)out & 0xf) == 0) { \
    MATRIX_VAR_STORE(1, 4, 4, xmm_sum, out, _mm_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(1, 4, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \
  }
//! 
Compute the distance between matrix and query (INT8, M=8, N=1)
//!
//! `dim >> 2` iterations, each consuming 1 query word and 8 matrix words
//! into a 2x1 grid of 128-bit accumulators; no tail step, so `dim` is
//! presumably a multiple of 4.
#define ACCUM_INT8_8X1_SSE(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(2, 1, __m128i, xmm_sum, _mm_setzero_si128()) \
  const uint32_t *qi = reinterpret_cast(q); \
  const uint32_t *mi = reinterpret_cast(m); \
  if (((uintptr_t)mi & 0xf) == 0) { \
    for (const uint32_t *qe = qi + (dim >> 2); qi != qe; mi += 8, ++qi) { \
      MATRIX_INT32_ITER_8X1_SSE(mi, qi, xmm_sum, _mm_load_si128, \
                                ACCUM_INT8_STEP_SSE) \
    } \
  } else { \
    for (const uint32_t *qe = qi + (dim >> 2); qi != qe; mi += 8, ++qi) { \
      MATRIX_INT32_ITER_8X1_SSE(mi, qi, xmm_sum, _mm_loadu_si128, \
                                ACCUM_INT8_STEP_SSE) \
    } \
  } \
  if (((uintptr_t)out & 0xf) == 0) { \
    MATRIX_VAR_STORE(2, 1, 4, xmm_sum, out, _mm_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(2, 1, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \
  }
//! Compute the distance between matrix and query (INT8, M=8, N=2)
//!
//! `dim >> 2` iterations, each consuming 2 query words and 8 matrix words;
//! no tail step, so `dim` is presumably a multiple of 4.
#define ACCUM_INT8_8X2_SSE(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(2, 2, __m128i, xmm_sum, _mm_setzero_si128()) \
  const uint32_t *qi = reinterpret_cast(q); \
  const uint32_t *mi = reinterpret_cast(m); \
  if (((uintptr_t)mi & 0xf) == 0) { \
    for (const uint32_t *qe = qi + ((dim >> 2) << 1); qi != qe; \
         mi += 8, qi += 2) { \
      MATRIX_INT32_ITER_8X2_SSE(mi, qi, xmm_sum, _mm_load_si128, \
                                ACCUM_INT8_STEP_SSE) \
    } \
  } else { \
    for (const uint32_t *qe = qi + ((dim >> 2) << 1); qi != qe; \
         mi += 8, qi += 2) { \
      MATRIX_INT32_ITER_8X2_SSE(mi, qi, xmm_sum, _mm_loadu_si128, \
                                ACCUM_INT8_STEP_SSE) \
    } \
  } \
  if (((uintptr_t)out & 0xf) == 0) { \
    MATRIX_VAR_STORE(2, 2, 4, xmm_sum, out, _mm_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(2, 2, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \
  }
//! 
Compute the distance between matrix and query (INT8, M=8, N=4)
//!
//! `dim >> 2` iterations, each consuming 4 query words and 8 matrix words;
//! no tail step, so `dim` is presumably a multiple of 4.
#define ACCUM_INT8_8X4_SSE(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(2, 4, __m128i, xmm_sum, _mm_setzero_si128()) \
  const uint32_t *qi = reinterpret_cast(q); \
  const uint32_t *mi = reinterpret_cast(m); \
  if (((uintptr_t)mi & 0xf) == 0) { \
    for (const uint32_t *qe = qi + ((dim >> 2) << 2); qi != qe; \
         mi += 8, qi += 4) { \
      MATRIX_INT32_ITER_8X4_SSE(mi, qi, xmm_sum, _mm_load_si128, \
                                ACCUM_INT8_STEP_SSE) \
    } \
  } else { \
    for (const uint32_t *qe = qi + ((dim >> 2) << 2); qi != qe; \
         mi += 8, qi += 4) { \
      MATRIX_INT32_ITER_8X4_SSE(mi, qi, xmm_sum, _mm_loadu_si128, \
                                ACCUM_INT8_STEP_SSE) \
    } \
  } \
  if (((uintptr_t)out & 0xf) == 0) { \
    MATRIX_VAR_STORE(2, 4, 4, xmm_sum, out, _mm_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(2, 4, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \
  }
//! Compute the distance between matrix and query (INT8, M=8, N=8)
//!
//! `dim >> 2` iterations, each consuming 8 query words and 8 matrix words;
//! no tail step, so `dim` is presumably a multiple of 4.
#define ACCUM_INT8_8X8_SSE(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(2, 8, __m128i, xmm_sum, _mm_setzero_si128()) \
  const uint32_t *qi = reinterpret_cast(q); \
  const uint32_t *mi = reinterpret_cast(m); \
  if (((uintptr_t)mi & 0xf) == 0) { \
    for (const uint32_t *qe = qi + ((dim >> 2) << 3); qi != qe; \
         mi += 8, qi += 8) { \
      MATRIX_INT32_ITER_8X8_SSE(mi, qi, xmm_sum, _mm_load_si128, \
                                ACCUM_INT8_STEP_SSE) \
    } \
  } else { \
    for (const uint32_t *qe = qi + ((dim >> 2) << 3); qi != qe; \
         mi += 8, qi += 8) { \
      MATRIX_INT32_ITER_8X8_SSE(mi, qi, xmm_sum, _mm_loadu_si128, \
                                ACCUM_INT8_STEP_SSE) \
    } \
  } \
  if (((uintptr_t)out & 0xf) == 0) { \
    MATRIX_VAR_STORE(2, 8, 4, xmm_sum, out, _mm_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(2, 8, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \
  }
//! 
Compute the distance between matrix and query (INT8, M=16, N=1)
//!
//! `dim >> 2` iterations, each consuming 1 query word and 16 matrix words
//! into a 4x1 grid of 128-bit accumulators; no tail step, so `dim` is
//! presumably a multiple of 4.
#define ACCUM_INT8_16X1_SSE(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(4, 1, __m128i, xmm_sum, _mm_setzero_si128()) \
  const uint32_t *qi = reinterpret_cast(q); \
  const uint32_t *mi = reinterpret_cast(m); \
  if (((uintptr_t)mi & 0xf) == 0) { \
    for (const uint32_t *qe = qi + (dim >> 2); qi != qe; mi += 16, ++qi) { \
      MATRIX_INT32_ITER_16X1_SSE(mi, qi, xmm_sum, _mm_load_si128, \
                                 ACCUM_INT8_STEP_SSE) \
    } \
  } else { \
    for (const uint32_t *qe = qi + (dim >> 2); qi != qe; mi += 16, ++qi) { \
      MATRIX_INT32_ITER_16X1_SSE(mi, qi, xmm_sum, _mm_loadu_si128, \
                                 ACCUM_INT8_STEP_SSE) \
    } \
  } \
  if (((uintptr_t)out & 0xf) == 0) { \
    MATRIX_VAR_STORE(4, 1, 4, xmm_sum, out, _mm_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(4, 1, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \
  }
//! Compute the distance between matrix and query (INT8, M=16, N=2)
//!
//! `dim >> 2` iterations, each consuming 2 query words and 16 matrix words;
//! no tail step, so `dim` is presumably a multiple of 4.
#define ACCUM_INT8_16X2_SSE(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(4, 2, __m128i, xmm_sum, _mm_setzero_si128()) \
  const uint32_t *qi = reinterpret_cast(q); \
  const uint32_t *mi = reinterpret_cast(m); \
  if (((uintptr_t)mi & 0xf) == 0) { \
    for (const uint32_t *qe = qi + ((dim >> 2) << 1); qi != qe; \
         mi += 16, qi += 2) { \
      MATRIX_INT32_ITER_16X2_SSE(mi, qi, xmm_sum, _mm_load_si128, \
                                 ACCUM_INT8_STEP_SSE) \
    } \
  } else { \
    for (const uint32_t *qe = qi + ((dim >> 2) << 1); qi != qe; \
         mi += 16, qi += 2) { \
      MATRIX_INT32_ITER_16X2_SSE(mi, qi, xmm_sum, _mm_loadu_si128, \
                                 ACCUM_INT8_STEP_SSE) \
    } \
  } \
  if (((uintptr_t)out & 0xf) == 0) { \
    MATRIX_VAR_STORE(4, 2, 4, xmm_sum, out, _mm_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(4, 2, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \
  }
//! 
Compute the distance between matrix and query (INT8, M=16, N=4)
//!
//! `dim >> 2` iterations, each consuming 4 query words and 16 matrix words;
//! no tail step, so `dim` is presumably a multiple of 4.
#define ACCUM_INT8_16X4_SSE(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(4, 4, __m128i, xmm_sum, _mm_setzero_si128()) \
  const uint32_t *qi = reinterpret_cast(q); \
  const uint32_t *mi = reinterpret_cast(m); \
  if (((uintptr_t)mi & 0xf) == 0) { \
    for (const uint32_t *qe = qi + ((dim >> 2) << 2); qi != qe; \
         mi += 16, qi += 4) { \
      MATRIX_INT32_ITER_16X4_SSE(mi, qi, xmm_sum, _mm_load_si128, \
                                 ACCUM_INT8_STEP_SSE) \
    } \
  } else { \
    for (const uint32_t *qe = qi + ((dim >> 2) << 2); qi != qe; \
         mi += 16, qi += 4) { \
      MATRIX_INT32_ITER_16X4_SSE(mi, qi, xmm_sum, _mm_loadu_si128, \
                                 ACCUM_INT8_STEP_SSE) \
    } \
  } \
  if (((uintptr_t)out & 0xf) == 0) { \
    MATRIX_VAR_STORE(4, 4, 4, xmm_sum, out, _mm_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(4, 4, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \
  }
//! Compute the distance between matrix and query (INT8, M=16, N=8)
//!
//! `dim >> 2` iterations, each consuming 8 query words and 16 matrix words;
//! no tail step, so `dim` is presumably a multiple of 4.
#define ACCUM_INT8_16X8_SSE(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(4, 8, __m128i, xmm_sum, _mm_setzero_si128()) \
  const uint32_t *qi = reinterpret_cast(q); \
  const uint32_t *mi = reinterpret_cast(m); \
  if (((uintptr_t)mi & 0xf) == 0) { \
    for (const uint32_t *qe = qi + ((dim >> 2) << 3); qi != qe; \
         mi += 16, qi += 8) { \
      MATRIX_INT32_ITER_16X8_SSE(mi, qi, xmm_sum, _mm_load_si128, \
                                 ACCUM_INT8_STEP_SSE) \
    } \
  } else { \
    for (const uint32_t *qe = qi + ((dim >> 2) << 3); qi != qe; \
         mi += 16, qi += 8) { \
      MATRIX_INT32_ITER_16X8_SSE(mi, qi, xmm_sum, _mm_loadu_si128, \
                                 ACCUM_INT8_STEP_SSE) \
    } \
  } \
  if (((uintptr_t)out & 0xf) == 0) { \
    MATRIX_VAR_STORE(4, 8, 4, xmm_sum, out, _mm_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(4, 8, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \
  }
//! 
Compute the distance between matrix and query (INT8, M=16, N=16)
//!
//! `dim >> 2` iterations, each consuming 16 query words and 16 matrix
//! words; no tail step, so `dim` is presumably a multiple of 4.
#define ACCUM_INT8_16X16_SSE(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(4, 16, __m128i, xmm_sum, _mm_setzero_si128()) \
  const uint32_t *qi = reinterpret_cast(q); \
  const uint32_t *mi = reinterpret_cast(m); \
  if (((uintptr_t)mi & 0xf) == 0) { \
    for (const uint32_t *qe = qi + ((dim >> 2) << 4); qi != qe; \
         mi += 16, qi += 16) { \
      MATRIX_INT32_ITER_16X16_SSE(mi, qi, xmm_sum, _mm_load_si128, \
                                  ACCUM_INT8_STEP_SSE) \
    } \
  } else { \
    for (const uint32_t *qe = qi + ((dim >> 2) << 4); qi != qe; \
         mi += 16, qi += 16) { \
      MATRIX_INT32_ITER_16X16_SSE(mi, qi, xmm_sum, _mm_loadu_si128, \
                                  ACCUM_INT8_STEP_SSE) \
    } \
  } \
  if (((uintptr_t)out & 0xf) == 0) { \
    MATRIX_VAR_STORE(4, 16, 4, xmm_sum, out, _mm_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(4, 16, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \
  }
//! Compute the distance between matrix and query (INT8, M=32, N=1)
//!
//! `dim >> 2` iterations, each consuming 1 query word and 32 matrix words
//! into an 8x1 grid of 128-bit accumulators; no tail step, so `dim` is
//! presumably a multiple of 4.
#define ACCUM_INT8_32X1_SSE(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(8, 1, __m128i, xmm_sum, _mm_setzero_si128()) \
  const uint32_t *qi = reinterpret_cast(q); \
  const uint32_t *mi = reinterpret_cast(m); \
  if (((uintptr_t)mi & 0xf) == 0) { \
    for (const uint32_t *qe = qi + (dim >> 2); qi != qe; mi += 32, ++qi) { \
      MATRIX_INT32_ITER_32X1_SSE(mi, qi, xmm_sum, _mm_load_si128, \
                                 ACCUM_INT8_STEP_SSE) \
    } \
  } else { \
    for (const uint32_t *qe = qi + (dim >> 2); qi != qe; mi += 32, ++qi) { \
      MATRIX_INT32_ITER_32X1_SSE(mi, qi, xmm_sum, _mm_loadu_si128, \
                                 ACCUM_INT8_STEP_SSE) \
    } \
  } \
  if (((uintptr_t)out & 0xf) == 0) { \
    MATRIX_VAR_STORE(8, 1, 4, xmm_sum, out, _mm_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(8, 1, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \
  }
//! 
Compute the distance between matrix and query (INT8, M=32, N=2)
//!
//! `dim >> 2` iterations, each consuming 2 query words and 32 matrix words;
//! no tail step, so `dim` is presumably a multiple of 4.
#define ACCUM_INT8_32X2_SSE(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(8, 2, __m128i, xmm_sum, _mm_setzero_si128()) \
  const uint32_t *qi = reinterpret_cast(q); \
  const uint32_t *mi = reinterpret_cast(m); \
  if (((uintptr_t)mi & 0xf) == 0) { \
    for (const uint32_t *qe = qi + ((dim >> 2) << 1); qi != qe; \
         mi += 32, qi += 2) { \
      MATRIX_INT32_ITER_32X2_SSE(mi, qi, xmm_sum, _mm_load_si128, \
                                 ACCUM_INT8_STEP_SSE) \
    } \
  } else { \
    for (const uint32_t *qe = qi + ((dim >> 2) << 1); qi != qe; \
         mi += 32, qi += 2) { \
      MATRIX_INT32_ITER_32X2_SSE(mi, qi, xmm_sum, _mm_loadu_si128, \
                                 ACCUM_INT8_STEP_SSE) \
    } \
  } \
  if (((uintptr_t)out & 0xf) == 0) { \
    MATRIX_VAR_STORE(8, 2, 4, xmm_sum, out, _mm_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(8, 2, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \
  }
//! Compute the distance between matrix and query (INT8, M=32, N=4)
//!
//! `dim >> 2` iterations, each consuming 4 query words and 32 matrix words;
//! no tail step, so `dim` is presumably a multiple of 4.
#define ACCUM_INT8_32X4_SSE(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(8, 4, __m128i, xmm_sum, _mm_setzero_si128()) \
  const uint32_t *qi = reinterpret_cast(q); \
  const uint32_t *mi = reinterpret_cast(m); \
  if (((uintptr_t)mi & 0xf) == 0) { \
    for (const uint32_t *qe = qi + ((dim >> 2) << 2); qi != qe; \
         mi += 32, qi += 4) { \
      MATRIX_INT32_ITER_32X4_SSE(mi, qi, xmm_sum, _mm_load_si128, \
                                 ACCUM_INT8_STEP_SSE) \
    } \
  } else { \
    for (const uint32_t *qe = qi + ((dim >> 2) << 2); qi != qe; \
         mi += 32, qi += 4) { \
      MATRIX_INT32_ITER_32X4_SSE(mi, qi, xmm_sum, _mm_loadu_si128, \
                                 ACCUM_INT8_STEP_SSE) \
    } \
  } \
  if (((uintptr_t)out & 0xf) == 0) { \
    MATRIX_VAR_STORE(8, 4, 4, xmm_sum, out, _mm_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(8, 4, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \
  }
//! 
Compute the distance between matrix and query (INT8, M=32, N=8)
//!
//! `dim >> 2` iterations, each consuming 8 query words and 32 matrix words;
//! no tail step, so `dim` is presumably a multiple of 4.
#define ACCUM_INT8_32X8_SSE(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(8, 8, __m128i, xmm_sum, _mm_setzero_si128()) \
  const uint32_t *qi = reinterpret_cast(q); \
  const uint32_t *mi = reinterpret_cast(m); \
  if (((uintptr_t)mi & 0xf) == 0) { \
    for (const uint32_t *qe = qi + ((dim >> 2) << 3); qi != qe; \
         mi += 32, qi += 8) { \
      MATRIX_INT32_ITER_32X8_SSE(mi, qi, xmm_sum, _mm_load_si128, \
                                 ACCUM_INT8_STEP_SSE) \
    } \
  } else { \
    for (const uint32_t *qe = qi + ((dim >> 2) << 3); qi != qe; \
         mi += 32, qi += 8) { \
      MATRIX_INT32_ITER_32X8_SSE(mi, qi, xmm_sum, _mm_loadu_si128, \
                                 ACCUM_INT8_STEP_SSE) \
    } \
  } \
  if (((uintptr_t)out & 0xf) == 0) { \
    MATRIX_VAR_STORE(8, 8, 4, xmm_sum, out, _mm_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(8, 8, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \
  }
//! Compute the distance between matrix and query (INT8, M=32, N=16)
//!
//! `dim >> 2` iterations, each consuming 16 query words and 32 matrix
//! words; no tail step, so `dim` is presumably a multiple of 4.
#define ACCUM_INT8_32X16_SSE(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(8, 16, __m128i, xmm_sum, _mm_setzero_si128()) \
  const uint32_t *qi = reinterpret_cast(q); \
  const uint32_t *mi = reinterpret_cast(m); \
  if (((uintptr_t)mi & 0xf) == 0) { \
    for (const uint32_t *qe = qi + ((dim >> 2) << 4); qi != qe; \
         mi += 32, qi += 16) { \
      MATRIX_INT32_ITER_32X16_SSE(mi, qi, xmm_sum, _mm_load_si128, \
                                  ACCUM_INT8_STEP_SSE) \
    } \
  } else { \
    for (const uint32_t *qe = qi + ((dim >> 2) << 4); qi != qe; \
         mi += 32, qi += 16) { \
      MATRIX_INT32_ITER_32X16_SSE(mi, qi, xmm_sum, _mm_loadu_si128, \
                                  ACCUM_INT8_STEP_SSE) \
    } \
  } \
  if (((uintptr_t)out & 0xf) == 0) { \
    MATRIX_VAR_STORE(8, 16, 4, xmm_sum, out, _mm_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(8, 16, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \
  }
//! 
//! Compute the distance between matrix and query (INT8, M=32, N=32)
//!
//! `m` and `q` are read as arrays of uint32 words, each packing four int8
//! values. The loop walks `dim >> 2` words per query column, so a `dim % 4`
//! remainder is never visited (assumes `dim` is a multiple of 4 -- TODO
//! confirm with callers). Each iteration feeds 32 matrix words and 32 query
//! words to MATRIX_INT32_ITER_32X32_SSE with the ACCUM_INT8_STEP_SSE step.
//! Aligned loads/stores are used when `m` / `out` are 16-byte aligned.
//! The expansion declares locals directly in the enclosing scope (it is not
//! wrapped in braces), so expand it at most once per block.
#define ACCUM_INT8_32X32_SSE(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(8, 32, __m128i, xmm_sum, _mm_setzero_si128()) \
  const uint32_t *qi = reinterpret_cast<const uint32_t *>(q); \
  const uint32_t *mi = reinterpret_cast<const uint32_t *>(m); \
  if (((uintptr_t)mi & 0xf) == 0) { \
    for (const uint32_t *qe = qi + (((dim) >> 2) << 5); qi != qe; \
         mi += 32, qi += 32) { \
      MATRIX_INT32_ITER_32X32_SSE(mi, qi, xmm_sum, _mm_load_si128, \
                                  ACCUM_INT8_STEP_SSE) \
    } \
  } else { \
    for (const uint32_t *qe = qi + (((dim) >> 2) << 5); qi != qe; \
         mi += 32, qi += 32) { \
      MATRIX_INT32_ITER_32X32_SSE(mi, qi, xmm_sum, _mm_loadu_si128, \
                                  ACCUM_INT8_STEP_SSE) \
    } \
  } \
  if (((uintptr_t)(out) & 0xf) == 0) { \
    MATRIX_VAR_STORE(8, 32, 4, xmm_sum, out, _mm_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(8, 32, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \
  }

//! Compute the distance between matrix and query (INT8, M=2, N=1)
//!
//! `m` and `q` are read as arrays of uint32 words, each packing four int8
//! values (a `dim % 4` remainder is never visited -- assumes `dim` is a
//! multiple of 4, TODO confirm). The AVX loop consumes 8 matrix words and 4
//! query words per step. The up-to-three leftover query words are handled
//! with 128-bit SSE steps: a pair of words first, then a single word (whose
//! zero-padded lanes 2-3 are discarded). The fold with
//! _MM_SHUFFLE(0, 0, 3, 2) adds lanes 2-3 onto lanes 0-1. Only those two
//! floats are written, via _mm_storel_pi. The tail tests compare
//! `qe - qe_aligned` rather than forming `qe_aligned + k`, which could point
//! past the end of the query buffer.
#define ACCUM_INT8_2X1_AVX(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(1, 1, __m256i, ymm_sum, _mm256_setzero_si256()) \
  const uint32_t *qi = reinterpret_cast<const uint32_t *>(q); \
  const uint32_t *mi = reinterpret_cast<const uint32_t *>(m); \
  const uint32_t *qe_aligned = qi + (((dim) >> 4) << 2); \
  const uint32_t *qe = qi + ((dim) >> 2); \
  if (((uintptr_t)mi & 0x1f) == 0) { \
    for (; qi != qe_aligned; mi += 8, qi += 4) { \
      MATRIX_INT32_ITER_2X1_AVX(mi, qi, ymm_sum, _mm256_load_si256, \
                                ACCUM_INT8_STEP_AVX) \
    } \
  } else { \
    for (; qi != qe_aligned; mi += 8, qi += 4) { \
      MATRIX_INT32_ITER_2X1_AVX(mi, qi, ymm_sum, _mm256_loadu_si256, \
                                ACCUM_INT8_STEP_AVX) \
    } \
  } \
  __m128i xmm_sum_0 = _mm_add_epi32(_mm256_castsi256_si128(ymm_sum_0_0), \
                                    _mm256_extracti128_si256(ymm_sum_0_0, 1)); \
  if (qe - qe_aligned >= 2) { \
    __m128i xmm_mi = _mm_loadu_si128((const __m128i *)(mi)); \
    __m128i xmm_qi = _mm_set_epi32(qi[1], qi[1], qi[0], qi[0]); \
    ACCUM_INT8_STEP_SSE(xmm_mi, xmm_qi, xmm_sum_0) \
    mi += 4; \
    qi += 2; \
  } \
  xmm_sum_0 = _mm_add_epi32( \
      xmm_sum_0, _mm_shuffle_epi32(xmm_sum_0, _MM_SHUFFLE(0, 0, 3, 2))); \
  if (qi != qe) { \
    __m128i xmm_mi = _mm_set_epi32(0, 0, mi[1], mi[0]); \
    __m128i xmm_qi = _mm_broadcast_si32(qi); \
    ACCUM_INT8_STEP_SSE(xmm_mi, xmm_qi, xmm_sum_0) \
  } \
  _mm_storel_pi((__m64 *)(out), _NORM(xmm_sum_0));

//! Compute the distance between matrix and query (INT8, M=2, N=2)
//!
//! Same packing as ACCUM_INT8_2X1_AVX. The aligned-load path requires both
//! `m` and `q` to be 32-byte aligned. The AVX loop consumes 8 matrix words
//! and 8 query words per step. Leftovers (up to 6 query words) are handled
//! with SSE: four words at once, then a final pair. The two 128-bit partial
//! sums are folded with _mm_unpacklo/_mm_unpackhi_epi64 into one vector,
//! which is stored as four floats. The tail test uses `qe - qe_aligned`, so
//! no pointer past the end of the query buffer is formed.
#define ACCUM_INT8_2X2_AVX(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(1, 2, __m256i, ymm_sum, _mm256_setzero_si256()) \
  const uint32_t *qi = reinterpret_cast<const uint32_t *>(q); \
  const uint32_t *mi = reinterpret_cast<const uint32_t *>(m); \
  const uint32_t *qe_aligned = qi + (((dim) >> 4) << 3); \
  const uint32_t *qe = qi + (((dim) >> 2) << 1); \
  if (((uintptr_t)mi & 0x1f) == 0 && ((uintptr_t)qi & 0x1f) == 0) { \
    for (; qi != qe_aligned; mi += 8, qi += 8) { \
      MATRIX_INT32_ITER_2X2_AVX(mi, qi, ymm_sum, _mm256_load_si256, \
                                ACCUM_INT8_STEP_AVX) \
    } \
  } else { \
    for (; qi != qe_aligned; mi += 8, qi += 8) { \
      MATRIX_INT32_ITER_2X2_AVX(mi, qi, ymm_sum, _mm256_loadu_si256, \
                                ACCUM_INT8_STEP_AVX) \
    } \
  } \
  __m128i xmm_sum_0_0 = \
      _mm_add_epi32(_mm256_castsi256_si128(ymm_sum_0_0), \
                    _mm256_extracti128_si256(ymm_sum_0_0, 1)); \
  __m128i xmm_sum_0_1 = \
      _mm_add_epi32(_mm256_castsi256_si128(ymm_sum_0_1), \
                    _mm256_extracti128_si256(ymm_sum_0_1, 1)); \
  if (qe - qe_aligned >= 4) { \
    __m128i xmm_qi = _mm_loadu_si128((const __m128i *)(qi)); \
    __m128i xmm_mi = _mm_loadu_si128((const __m128i *)(mi)); \
    __m128i xmm_pi = _mm_shuffle_epi32(xmm_qi, _MM_SHUFFLE(2, 2, 0, 0)); \
    ACCUM_INT8_STEP_SSE(xmm_mi, xmm_pi, xmm_sum_0_0) \
    xmm_pi = _mm_shuffle_epi32(xmm_qi, _MM_SHUFFLE(3, 3, 1, 1)); \
    ACCUM_INT8_STEP_SSE(xmm_mi, xmm_pi, xmm_sum_0_1) \
    mi += 4; \
    qi += 4; \
  } \
  xmm_sum_0_0 = _mm_add_epi32(_mm_unpacklo_epi64(xmm_sum_0_0, xmm_sum_0_1), \
                              _mm_unpackhi_epi64(xmm_sum_0_0, xmm_sum_0_1)); \
  if (qi != qe) { \
    __m128i xmm_mi = _mm_set_epi32(mi[1], mi[0], mi[1], mi[0]); \
    __m128i xmm_qi = _mm_set_epi32(qi[1], qi[1], qi[0], qi[0]); \
    ACCUM_INT8_STEP_SSE(xmm_mi, xmm_qi, xmm_sum_0_0) \
  } \
  if (((uintptr_t)(out) & 0xf) == 0) { \
    MATRIX_VAR_STORE(1, 1, 4, xmm_sum, out, _mm_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(1, 1, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \
  }

//! Compute the distance between matrix and query (INT8, M=4, N=1)
//!
//! Same packing as ACCUM_INT8_2X1_AVX. The AVX loop consumes 8 matrix words
//! and 2 query words per step. A single leftover query word, if any, is
//! broadcast and processed with one SSE step against the next 4 matrix
//! words. The 256-bit accumulator is folded to 128 bits and stored as four
//! floats.
#define ACCUM_INT8_4X1_AVX(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(1, 1, __m256i, ymm_sum, _mm256_setzero_si256()) \
  const uint32_t *qi = reinterpret_cast<const uint32_t *>(q); \
  const uint32_t *mi = reinterpret_cast<const uint32_t *>(m); \
  const uint32_t *qe = qi + ((dim) >> 2); \
  if (((uintptr_t)mi & 0x1f) == 0) { \
    for (const uint32_t *qe_aligned = qi + (((dim) >> 3) << 1); \
         qi != qe_aligned; mi += 8, qi += 2) { \
      MATRIX_INT32_ITER_4X1_AVX(mi, qi, ymm_sum, _mm256_load_si256, \
                                ACCUM_INT8_STEP_AVX) \
    } \
  } else { \
    for (const uint32_t *qe_aligned = qi + (((dim) >> 3) << 1); \
         qi != qe_aligned; mi += 8, qi += 2) { \
      MATRIX_INT32_ITER_4X1_AVX(mi, qi, ymm_sum, _mm256_loadu_si256, \
                                ACCUM_INT8_STEP_AVX) \
    } \
  } \
  __m128i xmm_sum_0_0 = \
      _mm_add_epi32(_mm256_castsi256_si128(ymm_sum_0_0), \
                    _mm256_extracti128_si256(ymm_sum_0_0, 1)); \
  if (qi != qe) { \
    __m128i xmm_mi = _mm_loadu_si128((const __m128i *)(mi)); \
    __m128i xmm_qi = _mm_broadcast_si32(qi); \
    ACCUM_INT8_STEP_SSE(xmm_mi, xmm_qi, xmm_sum_0_0) \
  } \
  if (((uintptr_t)(out) & 0xf) == 0) { \
    MATRIX_VAR_STORE(1, 1, 4, xmm_sum, out, _mm_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(1, 1, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \
  }

//!
//! Compute the distance between matrix and query (INT8, M=4, N=2)
//!
//! `m` and `q` are read as arrays of uint32 words, each packing four int8
//! values (a `dim % 4` remainder is never visited -- assumes `dim` is a
//! multiple of 4, TODO confirm). The AVX loop consumes 8 matrix words and 4
//! query words per step. A leftover pair of query words, if any, is handled
//! with SSE: `qi[0]` and `qi[1]` are each broadcast against the same 4
//! matrix words. Both 128-bit accumulators are stored as floats, with aligned
//! stores when `out` is 16-byte aligned. The expansion declares locals in
//! the enclosing scope, so expand it at most once per block.
#define ACCUM_INT8_4X2_AVX(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(1, 2, __m256i, ymm_sum, _mm256_setzero_si256()) \
  const uint32_t *qi = reinterpret_cast<const uint32_t *>(q); \
  const uint32_t *mi = reinterpret_cast<const uint32_t *>(m); \
  const uint32_t *qe = qi + (((dim) >> 2) << 1); \
  if (((uintptr_t)mi & 0x1f) == 0) { \
    for (const uint32_t *qe_aligned = qi + (((dim) >> 3) << 2); \
         qi != qe_aligned; mi += 8, qi += 4) { \
      MATRIX_INT32_ITER_4X2_AVX(mi, qi, ymm_sum, _mm256_load_si256, \
                                ACCUM_INT8_STEP_AVX) \
    } \
  } else { \
    for (const uint32_t *qe_aligned = qi + (((dim) >> 3) << 2); \
         qi != qe_aligned; mi += 8, qi += 4) { \
      MATRIX_INT32_ITER_4X2_AVX(mi, qi, ymm_sum, _mm256_loadu_si256, \
                                ACCUM_INT8_STEP_AVX) \
    } \
  } \
  __m128i xmm_sum_0_0 = \
      _mm_add_epi32(_mm256_castsi256_si128(ymm_sum_0_0), \
                    _mm256_extracti128_si256(ymm_sum_0_0, 1)); \
  __m128i xmm_sum_0_1 = \
      _mm_add_epi32(_mm256_castsi256_si128(ymm_sum_0_1), \
                    _mm256_extracti128_si256(ymm_sum_0_1, 1)); \
  if (qi != qe) { \
    __m128i xmm_mi = _mm_loadu_si128((const __m128i *)(mi)); \
    __m128i xmm_qi = _mm_broadcast_si32(qi); \
    ACCUM_INT8_STEP_SSE(xmm_mi, xmm_qi, xmm_sum_0_0) \
    xmm_qi = _mm_broadcast_si32(qi + 1); \
    ACCUM_INT8_STEP_SSE(xmm_mi, xmm_qi, xmm_sum_0_1) \
  } \
  if (((uintptr_t)(out) & 0xf) == 0) { \
    MATRIX_VAR_STORE(1, 2, 4, xmm_sum, out, _mm_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(1, 2, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \
  }

//!
//! Compute the distance between matrix and query (INT8, M=4, N=4)
//!
//! `m` and `q` are read as arrays of uint32 words, each packing four int8
//! values (a `dim % 4` remainder is never visited -- assumes `dim` is a
//! multiple of 4, TODO confirm). The aligned-load path requires both `m` and
//! `q` to be 32-byte aligned. The AVX loop consumes 8 matrix words and 8
//! query words per step. A leftover group of 4 query words, if any, is
//! handled with SSE by broadcasting `qi[0..3]` against the same 4 matrix
//! words. The expansion declares locals in the enclosing scope, so expand it
//! at most once per block.
#define ACCUM_INT8_4X4_AVX(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(1, 4, __m256i, ymm_sum, _mm256_setzero_si256()) \
  const uint32_t *qi = reinterpret_cast<const uint32_t *>(q); \
  const uint32_t *mi = reinterpret_cast<const uint32_t *>(m); \
  const uint32_t *qe = qi + (((dim) >> 2) << 2); \
  if (((uintptr_t)mi & 0x1f) == 0 && ((uintptr_t)qi & 0x1f) == 0) { \
    for (const uint32_t *qe_aligned = qi + (((dim) >> 3) << 3); \
         qi != qe_aligned; mi += 8, qi += 8) { \
      MATRIX_INT32_ITER_4X4_AVX(mi, qi, ymm_sum, _mm256_load_si256, \
                                ACCUM_INT8_STEP_AVX) \
    } \
  } else { \
    for (const uint32_t *qe_aligned = qi + (((dim) >> 3) << 3); \
         qi != qe_aligned; mi += 8, qi += 8) { \
      MATRIX_INT32_ITER_4X4_AVX(mi, qi, ymm_sum, _mm256_loadu_si256, \
                                ACCUM_INT8_STEP_AVX) \
    } \
  } \
  __m128i xmm_sum_0_0 = \
      _mm_add_epi32(_mm256_castsi256_si128(ymm_sum_0_0), \
                    _mm256_extracti128_si256(ymm_sum_0_0, 1)); \
  __m128i xmm_sum_0_1 = \
      _mm_add_epi32(_mm256_castsi256_si128(ymm_sum_0_1), \
                    _mm256_extracti128_si256(ymm_sum_0_1, 1)); \
  __m128i xmm_sum_0_2 = \
      _mm_add_epi32(_mm256_castsi256_si128(ymm_sum_0_2), \
                    _mm256_extracti128_si256(ymm_sum_0_2, 1)); \
  __m128i xmm_sum_0_3 = \
      _mm_add_epi32(_mm256_castsi256_si128(ymm_sum_0_3), \
                    _mm256_extracti128_si256(ymm_sum_0_3, 1)); \
  if (qi != qe) { \
    __m128i xmm_mi = _mm_loadu_si128((const __m128i *)(mi)); \
    __m128i xmm_qi = _mm_broadcast_si32(qi); \
    ACCUM_INT8_STEP_SSE(xmm_mi, xmm_qi, xmm_sum_0_0) \
    xmm_qi = _mm_broadcast_si32(qi + 1); \
    ACCUM_INT8_STEP_SSE(xmm_mi, xmm_qi, xmm_sum_0_1) \
    xmm_qi = _mm_broadcast_si32(qi + 2); \
    ACCUM_INT8_STEP_SSE(xmm_mi, xmm_qi, xmm_sum_0_2) \
    xmm_qi = _mm_broadcast_si32(qi + 3); \
    ACCUM_INT8_STEP_SSE(xmm_mi, xmm_qi, xmm_sum_0_3) \
  } \
  if (((uintptr_t)(out) & 0xf) == 0) { \
    MATRIX_VAR_STORE(1, 4, 4, xmm_sum, out, _mm_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(1, 4, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \
  }

//!
//! Compute the distance between matrix and query (INT8, M=8, N=1)
//!
//! `m` and `q` are read as arrays of uint32 words, each packing four int8
//! values. The loop walks `dim >> 2` words per query column, so a `dim % 4`
//! remainder is never visited (assumes `dim` is a multiple of 4 -- TODO
//! confirm with callers). Each iteration feeds 8 matrix words and 1 query
//! word to MATRIX_INT32_ITER_8X1_AVX with the ACCUM_INT8_STEP_AVX step.
//! Aligned 256-bit loads/stores are used when `m` / `out` are 32-byte
//! aligned. The expansion declares locals directly in the enclosing scope,
//! so expand it at most once per block.
#define ACCUM_INT8_8X1_AVX(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(1, 1, __m256i, ymm_sum, _mm256_setzero_si256()) \
  const uint32_t *qi = reinterpret_cast<const uint32_t *>(q); \
  const uint32_t *mi = reinterpret_cast<const uint32_t *>(m); \
  if (((uintptr_t)mi & 0x1f) == 0) { \
    for (const uint32_t *qe = qi + ((dim) >> 2); qi != qe; mi += 8, ++qi) { \
      MATRIX_INT32_ITER_8X1_AVX(mi, qi, ymm_sum, _mm256_load_si256, \
                                ACCUM_INT8_STEP_AVX) \
    } \
  } else { \
    for (const uint32_t *qe = qi + ((dim) >> 2); qi != qe; mi += 8, ++qi) { \
      MATRIX_INT32_ITER_8X1_AVX(mi, qi, ymm_sum, _mm256_loadu_si256, \
                                ACCUM_INT8_STEP_AVX) \
    } \
  } \
  if (((uintptr_t)(out) & 0x1f) == 0) { \
    MATRIX_VAR_STORE(1, 1, 8, ymm_sum, out, _mm256_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(1, 1, 8, ymm_sum, out, _mm256_storeu_ps, _NORM) \
  }

//! Compute the distance between matrix and query (INT8, M=8, N=2)
//!
//! Same scheme as ACCUM_INT8_8X1_AVX with 8 matrix words and 2 query words
//! per step.
#define ACCUM_INT8_8X2_AVX(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(1, 2, __m256i, ymm_sum, _mm256_setzero_si256()) \
  const uint32_t *qi = reinterpret_cast<const uint32_t *>(q); \
  const uint32_t *mi = reinterpret_cast<const uint32_t *>(m); \
  if (((uintptr_t)mi & 0x1f) == 0) { \
    for (const uint32_t *qe = qi + (((dim) >> 2) << 1); qi != qe; \
         mi += 8, qi += 2) { \
      MATRIX_INT32_ITER_8X2_AVX(mi, qi, ymm_sum, _mm256_load_si256, \
                                ACCUM_INT8_STEP_AVX) \
    } \
  } else { \
    for (const uint32_t *qe = qi + (((dim) >> 2) << 1); qi != qe; \
         mi += 8, qi += 2) { \
      MATRIX_INT32_ITER_8X2_AVX(mi, qi, ymm_sum, _mm256_loadu_si256, \
                                ACCUM_INT8_STEP_AVX) \
    } \
  } \
  if (((uintptr_t)(out) & 0x1f) == 0) { \
    MATRIX_VAR_STORE(1, 2, 8, ymm_sum, out, _mm256_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(1, 2, 8, ymm_sum, out, _mm256_storeu_ps, _NORM) \
  }

//!
//! Compute the distance between matrix and query (INT8, M=8, N=4)
//!
//! `m` and `q` are read as arrays of uint32 words, each packing four int8
//! values. The loop walks `dim >> 2` words per query column, so a `dim % 4`
//! remainder is never visited (assumes `dim` is a multiple of 4 -- TODO
//! confirm with callers). Each iteration feeds 8 matrix words and 4 query
//! words to MATRIX_INT32_ITER_8X4_AVX with the ACCUM_INT8_STEP_AVX step.
//! Aligned 256-bit loads/stores are used when `m` / `out` are 32-byte
//! aligned. The expansion declares locals directly in the enclosing scope,
//! so expand it at most once per block.
#define ACCUM_INT8_8X4_AVX(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(1, 4, __m256i, ymm_sum, _mm256_setzero_si256()) \
  const uint32_t *qi = reinterpret_cast<const uint32_t *>(q); \
  const uint32_t *mi = reinterpret_cast<const uint32_t *>(m); \
  if (((uintptr_t)mi & 0x1f) == 0) { \
    for (const uint32_t *qe = qi + (((dim) >> 2) << 2); qi != qe; \
         mi += 8, qi += 4) { \
      MATRIX_INT32_ITER_8X4_AVX(mi, qi, ymm_sum, _mm256_load_si256, \
                                ACCUM_INT8_STEP_AVX) \
    } \
  } else { \
    for (const uint32_t *qe = qi + (((dim) >> 2) << 2); qi != qe; \
         mi += 8, qi += 4) { \
      MATRIX_INT32_ITER_8X4_AVX(mi, qi, ymm_sum, _mm256_loadu_si256, \
                                ACCUM_INT8_STEP_AVX) \
    } \
  } \
  if (((uintptr_t)(out) & 0x1f) == 0) { \
    MATRIX_VAR_STORE(1, 4, 8, ymm_sum, out, _mm256_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(1, 4, 8, ymm_sum, out, _mm256_storeu_ps, _NORM) \
  }

//! Compute the distance between matrix and query (INT8, M=8, N=8)
//!
//! Same scheme as ACCUM_INT8_8X4_AVX with 8 matrix words and 8 query words
//! per step.
#define ACCUM_INT8_8X8_AVX(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(1, 8, __m256i, ymm_sum, _mm256_setzero_si256()) \
  const uint32_t *qi = reinterpret_cast<const uint32_t *>(q); \
  const uint32_t *mi = reinterpret_cast<const uint32_t *>(m); \
  if (((uintptr_t)mi & 0x1f) == 0) { \
    for (const uint32_t *qe = qi + (((dim) >> 2) << 3); qi != qe; \
         mi += 8, qi += 8) { \
      MATRIX_INT32_ITER_8X8_AVX(mi, qi, ymm_sum, _mm256_load_si256, \
                                ACCUM_INT8_STEP_AVX) \
    } \
  } else { \
    for (const uint32_t *qe = qi + (((dim) >> 2) << 3); qi != qe; \
         mi += 8, qi += 8) { \
      MATRIX_INT32_ITER_8X8_AVX(mi, qi, ymm_sum, _mm256_loadu_si256, \
                                ACCUM_INT8_STEP_AVX) \
    } \
  } \
  if (((uintptr_t)(out) & 0x1f) == 0) { \
    MATRIX_VAR_STORE(1, 8, 8, ymm_sum, out, _mm256_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(1, 8, 8, ymm_sum, out, _mm256_storeu_ps, _NORM) \
  }

//!
//! Compute the distance between matrix and query (INT8, M=16, N=1)
//!
//! `m` and `q` are read as arrays of uint32 words, each packing four int8
//! values. The loop walks `dim >> 2` words per query column, so a `dim % 4`
//! remainder is never visited (assumes `dim` is a multiple of 4 -- TODO
//! confirm with callers). Each iteration feeds 16 matrix words and 1 query
//! word to MATRIX_INT32_ITER_16X1_AVX with the ACCUM_INT8_STEP_AVX step.
//! Aligned 256-bit loads/stores are used when `m` / `out` are 32-byte
//! aligned. The expansion declares locals directly in the enclosing scope,
//! so expand it at most once per block.
#define ACCUM_INT8_16X1_AVX(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(2, 1, __m256i, ymm_sum, _mm256_setzero_si256()) \
  const uint32_t *qi = reinterpret_cast<const uint32_t *>(q); \
  const uint32_t *mi = reinterpret_cast<const uint32_t *>(m); \
  if (((uintptr_t)mi & 0x1f) == 0) { \
    for (const uint32_t *qe = qi + ((dim) >> 2); qi != qe; mi += 16, ++qi) { \
      MATRIX_INT32_ITER_16X1_AVX(mi, qi, ymm_sum, _mm256_load_si256, \
                                 ACCUM_INT8_STEP_AVX) \
    } \
  } else { \
    for (const uint32_t *qe = qi + ((dim) >> 2); qi != qe; mi += 16, ++qi) { \
      MATRIX_INT32_ITER_16X1_AVX(mi, qi, ymm_sum, _mm256_loadu_si256, \
                                 ACCUM_INT8_STEP_AVX) \
    } \
  } \
  if (((uintptr_t)(out) & 0x1f) == 0) { \
    MATRIX_VAR_STORE(2, 1, 8, ymm_sum, out, _mm256_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(2, 1, 8, ymm_sum, out, _mm256_storeu_ps, _NORM) \
  }

//! Compute the distance between matrix and query (INT8, M=16, N=2)
//!
//! Same scheme as ACCUM_INT8_16X1_AVX with 16 matrix words and 2 query
//! words per step.
#define ACCUM_INT8_16X2_AVX(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(2, 2, __m256i, ymm_sum, _mm256_setzero_si256()) \
  const uint32_t *qi = reinterpret_cast<const uint32_t *>(q); \
  const uint32_t *mi = reinterpret_cast<const uint32_t *>(m); \
  if (((uintptr_t)mi & 0x1f) == 0) { \
    for (const uint32_t *qe = qi + (((dim) >> 2) << 1); qi != qe; \
         mi += 16, qi += 2) { \
      MATRIX_INT32_ITER_16X2_AVX(mi, qi, ymm_sum, _mm256_load_si256, \
                                 ACCUM_INT8_STEP_AVX) \
    } \
  } else { \
    for (const uint32_t *qe = qi + (((dim) >> 2) << 1); qi != qe; \
         mi += 16, qi += 2) { \
      MATRIX_INT32_ITER_16X2_AVX(mi, qi, ymm_sum, _mm256_loadu_si256, \
                                 ACCUM_INT8_STEP_AVX) \
    } \
  } \
  if (((uintptr_t)(out) & 0x1f) == 0) { \
    MATRIX_VAR_STORE(2, 2, 8, ymm_sum, out, _mm256_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(2, 2, 8, ymm_sum, out, _mm256_storeu_ps, _NORM) \
  }

//!
//! Compute the distance between matrix and query (INT8, M=16, N=4)
//!
//! `m` and `q` are read as arrays of uint32 words, each packing four int8
//! values. The loop walks `dim >> 2` words per query column, so a `dim % 4`
//! remainder is never visited (assumes `dim` is a multiple of 4 -- TODO
//! confirm with callers). Each iteration feeds 16 matrix words and 4 query
//! words to MATRIX_INT32_ITER_16X4_AVX with the ACCUM_INT8_STEP_AVX step.
//! Aligned 256-bit loads/stores are used when `m` / `out` are 32-byte
//! aligned. The expansion declares locals directly in the enclosing scope,
//! so expand it at most once per block.
#define ACCUM_INT8_16X4_AVX(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(2, 4, __m256i, ymm_sum, _mm256_setzero_si256()) \
  const uint32_t *qi = reinterpret_cast<const uint32_t *>(q); \
  const uint32_t *mi = reinterpret_cast<const uint32_t *>(m); \
  if (((uintptr_t)mi & 0x1f) == 0) { \
    for (const uint32_t *qe = qi + (((dim) >> 2) << 2); qi != qe; \
         mi += 16, qi += 4) { \
      MATRIX_INT32_ITER_16X4_AVX(mi, qi, ymm_sum, _mm256_load_si256, \
                                 ACCUM_INT8_STEP_AVX) \
    } \
  } else { \
    for (const uint32_t *qe = qi + (((dim) >> 2) << 2); qi != qe; \
         mi += 16, qi += 4) { \
      MATRIX_INT32_ITER_16X4_AVX(mi, qi, ymm_sum, _mm256_loadu_si256, \
                                 ACCUM_INT8_STEP_AVX) \
    } \
  } \
  if (((uintptr_t)(out) & 0x1f) == 0) { \
    MATRIX_VAR_STORE(2, 4, 8, ymm_sum, out, _mm256_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(2, 4, 8, ymm_sum, out, _mm256_storeu_ps, _NORM) \
  }

//! Compute the distance between matrix and query (INT8, M=16, N=8)
//!
//! Same scheme as ACCUM_INT8_16X4_AVX with 16 matrix words and 8 query
//! words per step.
#define ACCUM_INT8_16X8_AVX(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(2, 8, __m256i, ymm_sum, _mm256_setzero_si256()) \
  const uint32_t *qi = reinterpret_cast<const uint32_t *>(q); \
  const uint32_t *mi = reinterpret_cast<const uint32_t *>(m); \
  if (((uintptr_t)mi & 0x1f) == 0) { \
    for (const uint32_t *qe = qi + (((dim) >> 2) << 3); qi != qe; \
         mi += 16, qi += 8) { \
      MATRIX_INT32_ITER_16X8_AVX(mi, qi, ymm_sum, _mm256_load_si256, \
                                 ACCUM_INT8_STEP_AVX) \
    } \
  } else { \
    for (const uint32_t *qe = qi + (((dim) >> 2) << 3); qi != qe; \
         mi += 16, qi += 8) { \
      MATRIX_INT32_ITER_16X8_AVX(mi, qi, ymm_sum, _mm256_loadu_si256, \
                                 ACCUM_INT8_STEP_AVX) \
    } \
  } \
  if (((uintptr_t)(out) & 0x1f) == 0) { \
    MATRIX_VAR_STORE(2, 8, 8, ymm_sum, out, _mm256_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(2, 8, 8, ymm_sum, out, _mm256_storeu_ps, _NORM) \
  }

//!
//! Compute the distance between matrix and query (INT8, M=16, N=16)
//!
//! `m` and `q` are read as arrays of uint32 words, each packing four int8
//! values. The loop walks `dim >> 2` words per query column, so a `dim % 4`
//! remainder is never visited (assumes `dim` is a multiple of 4 -- TODO
//! confirm with callers). Each iteration feeds 16 matrix words and 16 query
//! words to MATRIX_INT32_ITER_16X16_AVX with the ACCUM_INT8_STEP_AVX step.
//! Aligned 256-bit loads/stores are used when `m` / `out` are 32-byte
//! aligned. The expansion declares locals directly in the enclosing scope,
//! so expand it at most once per block.
#define ACCUM_INT8_16X16_AVX(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(2, 16, __m256i, ymm_sum, _mm256_setzero_si256()) \
  const uint32_t *qi = reinterpret_cast<const uint32_t *>(q); \
  const uint32_t *mi = reinterpret_cast<const uint32_t *>(m); \
  if (((uintptr_t)mi & 0x1f) == 0) { \
    for (const uint32_t *qe = qi + (((dim) >> 2) << 4); qi != qe; \
         mi += 16, qi += 16) { \
      MATRIX_INT32_ITER_16X16_AVX(mi, qi, ymm_sum, _mm256_load_si256, \
                                  ACCUM_INT8_STEP_AVX) \
    } \
  } else { \
    for (const uint32_t *qe = qi + (((dim) >> 2) << 4); qi != qe; \
         mi += 16, qi += 16) { \
      MATRIX_INT32_ITER_16X16_AVX(mi, qi, ymm_sum, _mm256_loadu_si256, \
                                  ACCUM_INT8_STEP_AVX) \
    } \
  } \
  if (((uintptr_t)(out) & 0x1f) == 0) { \
    MATRIX_VAR_STORE(2, 16, 8, ymm_sum, out, _mm256_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(2, 16, 8, ymm_sum, out, _mm256_storeu_ps, _NORM) \
  }

//! Compute the distance between matrix and query (INT8, M=32, N=1)
//!
//! Same scheme as ACCUM_INT8_16X16_AVX with 32 matrix words and 1 query
//! word per step.
#define ACCUM_INT8_32X1_AVX(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(4, 1, __m256i, ymm_sum, _mm256_setzero_si256()) \
  const uint32_t *qi = reinterpret_cast<const uint32_t *>(q); \
  const uint32_t *mi = reinterpret_cast<const uint32_t *>(m); \
  if (((uintptr_t)mi & 0x1f) == 0) { \
    for (const uint32_t *qe = qi + ((dim) >> 2); qi != qe; mi += 32, ++qi) { \
      MATRIX_INT32_ITER_32X1_AVX(mi, qi, ymm_sum, _mm256_load_si256, \
                                 ACCUM_INT8_STEP_AVX) \
    } \
  } else { \
    for (const uint32_t *qe = qi + ((dim) >> 2); qi != qe; mi += 32, ++qi) { \
      MATRIX_INT32_ITER_32X1_AVX(mi, qi, ymm_sum, _mm256_loadu_si256, \
                                 ACCUM_INT8_STEP_AVX) \
    } \
  } \
  if (((uintptr_t)(out) & 0x1f) == 0) { \
    MATRIX_VAR_STORE(4, 1, 8, ymm_sum, out, _mm256_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(4, 1, 8, ymm_sum, out, _mm256_storeu_ps, _NORM) \
  }

//!
//! Compute the distance between matrix and query (INT8, M=32, N=2)
//!
//! `m` and `q` are read as arrays of uint32 words, each packing four int8
//! values. The loop walks `dim >> 2` words per query column, so a `dim % 4`
//! remainder is never visited (assumes `dim` is a multiple of 4 -- TODO
//! confirm with callers). Each iteration feeds 32 matrix words and 2 query
//! words to MATRIX_INT32_ITER_32X2_AVX with the ACCUM_INT8_STEP_AVX step.
//! Aligned 256-bit loads/stores are used when `m` / `out` are 32-byte
//! aligned. The expansion declares locals directly in the enclosing scope,
//! so expand it at most once per block.
#define ACCUM_INT8_32X2_AVX(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(4, 2, __m256i, ymm_sum, _mm256_setzero_si256()) \
  const uint32_t *qi = reinterpret_cast<const uint32_t *>(q); \
  const uint32_t *mi = reinterpret_cast<const uint32_t *>(m); \
  if (((uintptr_t)mi & 0x1f) == 0) { \
    for (const uint32_t *qe = qi + (((dim) >> 2) << 1); qi != qe; \
         mi += 32, qi += 2) { \
      MATRIX_INT32_ITER_32X2_AVX(mi, qi, ymm_sum, _mm256_load_si256, \
                                 ACCUM_INT8_STEP_AVX) \
    } \
  } else { \
    for (const uint32_t *qe = qi + (((dim) >> 2) << 1); qi != qe; \
         mi += 32, qi += 2) { \
      MATRIX_INT32_ITER_32X2_AVX(mi, qi, ymm_sum, _mm256_loadu_si256, \
                                 ACCUM_INT8_STEP_AVX) \
    } \
  } \
  if (((uintptr_t)(out) & 0x1f) == 0) { \
    MATRIX_VAR_STORE(4, 2, 8, ymm_sum, out, _mm256_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(4, 2, 8, ymm_sum, out, _mm256_storeu_ps, _NORM) \
  }

//! Compute the distance between matrix and query (INT8, M=32, N=4)
//!
//! Same scheme as ACCUM_INT8_32X2_AVX with 32 matrix words and 4 query
//! words per step.
#define ACCUM_INT8_32X4_AVX(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(4, 4, __m256i, ymm_sum, _mm256_setzero_si256()) \
  const uint32_t *qi = reinterpret_cast<const uint32_t *>(q); \
  const uint32_t *mi = reinterpret_cast<const uint32_t *>(m); \
  if (((uintptr_t)mi & 0x1f) == 0) { \
    for (const uint32_t *qe = qi + (((dim) >> 2) << 2); qi != qe; \
         mi += 32, qi += 4) { \
      MATRIX_INT32_ITER_32X4_AVX(mi, qi, ymm_sum, _mm256_load_si256, \
                                 ACCUM_INT8_STEP_AVX) \
    } \
  } else { \
    for (const uint32_t *qe = qi + (((dim) >> 2) << 2); qi != qe; \
         mi += 32, qi += 4) { \
      MATRIX_INT32_ITER_32X4_AVX(mi, qi, ymm_sum, _mm256_loadu_si256, \
                                 ACCUM_INT8_STEP_AVX) \
    } \
  } \
  if (((uintptr_t)(out) & 0x1f) == 0) { \
    MATRIX_VAR_STORE(4, 4, 8, ymm_sum, out, _mm256_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(4, 4, 8, ymm_sum, out, _mm256_storeu_ps, _NORM) \
  }

//!
//! Compute the distance between matrix and query (INT8, M=32, N=8)
//!
//! `m` and `q` are read as arrays of uint32 words, each packing four int8
//! values. The loop walks `dim >> 2` words per query column, so a `dim % 4`
//! remainder is never visited (assumes `dim` is a multiple of 4 -- TODO
//! confirm with callers). Each iteration feeds 32 matrix words and 8 query
//! words to MATRIX_INT32_ITER_32X8_AVX with the ACCUM_INT8_STEP_AVX step.
//! Aligned 256-bit loads/stores are used when `m` / `out` are 32-byte
//! aligned. The expansion declares locals directly in the enclosing scope,
//! so expand it at most once per block.
#define ACCUM_INT8_32X8_AVX(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(4, 8, __m256i, ymm_sum, _mm256_setzero_si256()) \
  const uint32_t *qi = reinterpret_cast<const uint32_t *>(q); \
  const uint32_t *mi = reinterpret_cast<const uint32_t *>(m); \
  if (((uintptr_t)mi & 0x1f) == 0) { \
    for (const uint32_t *qe = qi + (((dim) >> 2) << 3); qi != qe; \
         mi += 32, qi += 8) { \
      MATRIX_INT32_ITER_32X8_AVX(mi, qi, ymm_sum, _mm256_load_si256, \
                                 ACCUM_INT8_STEP_AVX) \
    } \
  } else { \
    for (const uint32_t *qe = qi + (((dim) >> 2) << 3); qi != qe; \
         mi += 32, qi += 8) { \
      MATRIX_INT32_ITER_32X8_AVX(mi, qi, ymm_sum, _mm256_loadu_si256, \
                                 ACCUM_INT8_STEP_AVX) \
    } \
  } \
  if (((uintptr_t)(out) & 0x1f) == 0) { \
    MATRIX_VAR_STORE(4, 8, 8, ymm_sum, out, _mm256_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(4, 8, 8, ymm_sum, out, _mm256_storeu_ps, _NORM) \
  }

//! Compute the distance between matrix and query (INT8, M=32, N=16)
//!
//! Same scheme as ACCUM_INT8_32X8_AVX with 32 matrix words and 16 query
//! words per step.
#define ACCUM_INT8_32X16_AVX(m, q, dim, out, _NORM) \
  MATRIX_VAR_INIT(4, 16, __m256i, ymm_sum, _mm256_setzero_si256()) \
  const uint32_t *qi = reinterpret_cast<const uint32_t *>(q); \
  const uint32_t *mi = reinterpret_cast<const uint32_t *>(m); \
  if (((uintptr_t)mi & 0x1f) == 0) { \
    for (const uint32_t *qe = qi + (((dim) >> 2) << 4); qi != qe; \
         mi += 32, qi += 16) { \
      MATRIX_INT32_ITER_32X16_AVX(mi, qi, ymm_sum, _mm256_load_si256, \
                                  ACCUM_INT8_STEP_AVX) \
    } \
  } else { \
    for (const uint32_t *qe = qi + (((dim) >> 2) << 4); qi != qe; \
         mi += 32, qi += 16) { \
      MATRIX_INT32_ITER_32X16_AVX(mi, qi, ymm_sum, _mm256_loadu_si256, \
                                  ACCUM_INT8_STEP_AVX) \
    } \
  } \
  if (((uintptr_t)(out) & 0x1f) == 0) { \
    MATRIX_VAR_STORE(4, 16, 8, ymm_sum, out, _mm256_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(4, 16, 8, ymm_sum, out, _mm256_storeu_ps, _NORM) \
  }

//!
Compute the distance between matrix and query (INT8, M=32, N=32) #define ACCUM_INT8_32X32_AVX(m, q, dim, out, _NORM) \ MATRIX_VAR_INIT(4, 32, __m256i, ymm_sum, _mm256_setzero_si256()) \ const uint32_t *qi = reinterpret_cast(q); \ const uint32_t *mi = reinterpret_cast(m); \ if (((uintptr_t)mi & 0x1f) == 0) { \ for (const uint32_t *qe = qi + ((dim >> 2) << 5); qi != qe; \ mi += 32, qi += 32) { \ MATRIX_INT32_ITER_32X32_AVX(mi, qi, ymm_sum, _mm256_load_si256, \ ACCUM_INT8_STEP_AVX) \ } \ } else { \ for (const uint32_t *qe = qi + ((dim >> 2) << 5); qi != qe; \ mi += 32, qi += 32) { \ MATRIX_INT32_ITER_32X32_AVX(mi, qi, ymm_sum, _mm256_loadu_si256, \ ACCUM_INT8_STEP_AVX) \ } \ } \ if (((uintptr_t)out & 0x1f) == 0) { \ MATRIX_VAR_STORE(4, 32, 8, ymm_sum, out, _mm256_store_ps, _NORM) \ } else { \ MATRIX_VAR_STORE(4, 32, 8, ymm_sum, out, _mm256_storeu_ps, _NORM) \ } ================================================ FILE: src/ailego/math/distance_matrix_euclidean_utility.i ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. //! Calculate sum of squared difference (GENERAL) #define SSD_FP32_GENERAL(m, q, sum) \ { \ float x = m - q; \ sum += (x * x); \ } //! Calculate sum of squared difference (SSE) #define SSD_FP32_SSE(xmm_m, xmm_q, xmm_sum) \ { \ __m128 xmm_d = _mm_sub_ps(xmm_m, xmm_q); \ xmm_sum = _mm_fmadd_ps(xmm_d, xmm_d, xmm_sum); \ } //! 
//! Calculate sum of squared difference (AVX)
//!
//! Adds `(m - q)^2` lane-wise into `ymm_sum` using a fused multiply-add.
#define SSD_FP32_AVX(ymm_m, ymm_q, ymm_sum) \
  { \
    __m256 ymm_d = _mm256_sub_ps(ymm_m, ymm_q); \
    ymm_sum = _mm256_fmadd_ps(ymm_d, ymm_d, ymm_sum); \
  }

//! Calculate sum of squared difference (NEON)
#define SSD_FP32_NEON(v_m, v_q, v_sum) \
  { \
    float32x4_t v_d = vsubq_f32(v_m, v_q); \
    v_sum = vfmaq_f32(v_sum, v_d, v_d); \
  }

//! Calculate sum of squared difference (GENERAL)
//!
//! Scalar step: adds `(m - q)^2`, with the difference held in a float, to
//! `sum`. The arguments are parenthesized so compound expressions expand
//! correctly. The temporary has a name that is unlikely to collide with a
//! caller's argument.
#define SSD_FP16_GENERAL(m, q, sum) \
  { \
    float ssd_diff = (m) - (q); \
    (sum) += (ssd_diff * ssd_diff); \
  }

//! Calculate sum of squared difference (NEON)
//!
//! Accumulates in half precision (vfmaq_f16 on float16x8_t).
#define SSD_FP16_NEON(v_m, v_q, v_sum) \
  { \
    float16x8_t v_d = vsubq_f16(v_m, v_q); \
    v_sum = vfmaq_f16(v_sum, v_d, v_d); \
  }

//! Calculate sum of squared difference (AVX512)
#define SSD_FP32_AVX512(zmm_m, zmm_q, zmm_sum) \
  { \
    __m512 zmm_d = _mm512_sub_ps(zmm_m, zmm_q); \
    zmm_sum = _mm512_fmadd_ps(zmm_d, zmm_d, zmm_sum); \
  }

//! Calculate sum of squared difference (GENERAL)
//!
//! `m` and `q` each hold two packed 4-bit values. The low nibbles of `m`
//! and `q` form one 8-bit index into Int4SquaredDiffTable and the high
//! nibbles form another; both table entries are added to `sum`. The
//! expansion ends with its own `;`, so it may be invoked without one.
#define SSD_INT4_GENERAL(m, q, sum) \
  (sum) += Int4SquaredDiffTable[(((m) << 4) & 0xf0) | (((q) >> 0) & 0xf)] + \
           Int4SquaredDiffTable[(((m) >> 0) & 0xf0) | (((q) >> 4) & 0xf)];

#if defined(__SSE4_1__)
static const __m128i MASK_INT4_SSE = _mm_set1_epi32(0xf0f0f0f0);
static const __m128i ONES_INT16_SSE = _mm_set1_epi32(0x00010001);
#endif  // __SSE4_1__

//! Compute the square root of value (SSE)
//!
//! Converts int32 lanes to float, then takes the square root.
#define SQRT_FP32_SSE(v, ...) _mm_sqrt_ps(_mm_cvtepi32_ps(v))

#if defined(__AVX2__)
static const __m256i MASK_INT4_AVX = _mm256_set1_epi32(0xf0f0f0f0);
static const __m256i ONES_INT16_AVX = _mm256_set1_epi32(0x00010001);
#endif  // __AVX2__

//!
//! Calculate sum of squared difference (SSE)
//!
//! Each byte of `xmm_m` / `xmm_q` holds two packed 4-bit values. The low
//! nibbles are handled first by shifting them into the high-nibble position
//! and masking with MASK_INT4_SSE; the high nibbles are then masked in
//! place. For each half:
//!  - |a - b| is formed bytewise as max_epi8 - min_epi8. These are signed
//!    byte compares, so the nibble's top bit acts as a sign bit.
//!  - The result is shifted right by 4 inside 32-bit lanes. The low nibble
//!    of every byte is zero after masking, so no bits leak between bytes.
//!  - The difference is squared with _mm_maddubs_epi16, widened to int32 via
//!    _mm_madd_epi16 against ONES_INT16_SSE, and added into `xmm_sum`.
//! Requires MASK_INT4_SSE / ONES_INT16_SSE, which are only defined when
//! __SSE4_1__ is.
#define SSD_INT4_SSE(xmm_m, xmm_q, xmm_sum) \
  { \
    __m128i xmm_lhs = \
        _mm_and_si128(_mm_slli_epi32((xmm_m), 4), MASK_INT4_SSE); \
    __m128i xmm_rhs = \
        _mm_and_si128(_mm_slli_epi32((xmm_q), 4), MASK_INT4_SSE); \
    xmm_lhs = _mm_srli_epi32(_mm_sub_epi8(_mm_max_epi8(xmm_lhs, xmm_rhs), \
                                          _mm_min_epi8(xmm_lhs, xmm_rhs)), \
                             4); \
    xmm_sum = _mm_add_epi32( \
        _mm_madd_epi16(_mm_maddubs_epi16(xmm_lhs, xmm_lhs), ONES_INT16_SSE), \
        xmm_sum); \
    xmm_lhs = _mm_and_si128((xmm_m), MASK_INT4_SSE); \
    xmm_rhs = _mm_and_si128((xmm_q), MASK_INT4_SSE); \
    xmm_lhs = _mm_srli_epi32(_mm_sub_epi8(_mm_max_epi8(xmm_lhs, xmm_rhs), \
                                          _mm_min_epi8(xmm_lhs, xmm_rhs)), \
                             4); \
    xmm_sum = _mm_add_epi32( \
        _mm_madd_epi16(_mm_maddubs_epi16(xmm_lhs, xmm_lhs), ONES_INT16_SSE), \
        xmm_sum); \
  }

//! Compute the distance between matrix and query
//!
//! Same nibble-wise squared-difference computation as SSD_INT4_SSE. Here
//! both halves are computed into separate temporaries and summed before a
//! single add into `xmm_sum`. Note: the local names (`xmm_lhs_0`, ...)
//! derive from the parameter names. Callers should not pass arguments that
//! are spelled like these locals.
#define SSD_INT4_ITER_SSE(xmm_lhs, xmm_rhs, xmm_sum) \
  { \
    __m128i xmm_lhs_0 = \
        _mm_and_si128(_mm_slli_epi32((xmm_lhs), 4), MASK_INT4_SSE); \
    __m128i xmm_rhs_0 = \
        _mm_and_si128(_mm_slli_epi32((xmm_rhs), 4), MASK_INT4_SSE); \
    __m128i xmm_lhs_1 = _mm_and_si128((xmm_lhs), MASK_INT4_SSE); \
    __m128i xmm_rhs_1 = _mm_and_si128((xmm_rhs), MASK_INT4_SSE); \
    xmm_lhs_0 = \
        _mm_srli_epi32(_mm_sub_epi8(_mm_max_epi8(xmm_lhs_0, xmm_rhs_0), \
                                    _mm_min_epi8(xmm_lhs_0, xmm_rhs_0)), \
                       4); \
    xmm_rhs_0 = \
        _mm_srli_epi32(_mm_sub_epi8(_mm_max_epi8(xmm_lhs_1, xmm_rhs_1), \
                                    _mm_min_epi8(xmm_lhs_1, xmm_rhs_1)), \
                       4); \
    xmm_lhs_0 = _mm_madd_epi16(_mm_maddubs_epi16(xmm_lhs_0, xmm_lhs_0), \
                               ONES_INT16_SSE); \
    xmm_rhs_0 = _mm_madd_epi16(_mm_maddubs_epi16(xmm_rhs_0, xmm_rhs_0), \
                               ONES_INT16_SSE); \
    xmm_sum = _mm_add_epi32(_mm_add_epi32(xmm_lhs_0, xmm_rhs_0), xmm_sum); \
  }

//!
//! Calculate sum of squared difference (AVX)
//!
//! 256-bit version of the packed-nibble squared difference:
//!  - Low nibbles are shifted into the high-nibble position and masked with
//!    MASK_INT4_AVX; high nibbles are then masked in place.
//!  - |a - b| is formed as max_epi8 - min_epi8 and shifted right by 4 inside
//!    32-bit lanes (byte low nibbles are zero, so nothing leaks across
//!    bytes).
//!  - The difference is squared with _mm256_maddubs_epi16, widened with
//!    _mm256_madd_epi16 against ONES_INT16_AVX, and accumulated into
//!    `ymm_sum`.
//! Requires MASK_INT4_AVX / ONES_INT16_AVX, which are only defined when
//! __AVX2__ is.
#define SSD_INT4_AVX(ymm_m, ymm_q, ymm_sum) \
  { \
    __m256i ymm_lhs = \
        _mm256_and_si256(_mm256_slli_epi32((ymm_m), 4), MASK_INT4_AVX); \
    __m256i ymm_rhs = \
        _mm256_and_si256(_mm256_slli_epi32((ymm_q), 4), MASK_INT4_AVX); \
    ymm_lhs = \
        _mm256_srli_epi32(_mm256_sub_epi8(_mm256_max_epi8(ymm_lhs, ymm_rhs), \
                                          _mm256_min_epi8(ymm_lhs, ymm_rhs)), \
                          4); \
    ymm_sum = _mm256_add_epi32( \
        _mm256_madd_epi16(_mm256_maddubs_epi16(ymm_lhs, ymm_lhs), \
                          ONES_INT16_AVX), \
        ymm_sum); \
    ymm_lhs = _mm256_and_si256((ymm_m), MASK_INT4_AVX); \
    ymm_rhs = _mm256_and_si256((ymm_q), MASK_INT4_AVX); \
    ymm_lhs = \
        _mm256_srli_epi32(_mm256_sub_epi8(_mm256_max_epi8(ymm_lhs, ymm_rhs), \
                                          _mm256_min_epi8(ymm_lhs, ymm_rhs)), \
                          4); \
    ymm_sum = _mm256_add_epi32( \
        _mm256_madd_epi16(_mm256_maddubs_epi16(ymm_lhs, ymm_lhs), \
                          ONES_INT16_AVX), \
        ymm_sum); \
  }

//! Compute the distance between matrix and query
//!
//! Same computation as SSD_INT4_AVX, with both nibble halves summed before a
//! single add into `ymm_sum`.
#define SSD_INT4_ITER_AVX(ymm_lhs, ymm_rhs, ymm_sum) \
  { \
    __m256i ymm_lhs_0 = \
        _mm256_and_si256(_mm256_slli_epi32((ymm_lhs), 4), MASK_INT4_AVX); \
    __m256i ymm_rhs_0 = \
        _mm256_and_si256(_mm256_slli_epi32((ymm_rhs), 4), MASK_INT4_AVX); \
    __m256i ymm_lhs_1 = _mm256_and_si256((ymm_lhs), MASK_INT4_AVX); \
    __m256i ymm_rhs_1 = _mm256_and_si256((ymm_rhs), MASK_INT4_AVX); \
    ymm_lhs_0 = _mm256_srli_epi32( \
        _mm256_sub_epi8(_mm256_max_epi8(ymm_lhs_0, ymm_rhs_0), \
                        _mm256_min_epi8(ymm_lhs_0, ymm_rhs_0)), \
        4); \
    ymm_rhs_0 = _mm256_srli_epi32( \
        _mm256_sub_epi8(_mm256_max_epi8(ymm_lhs_1, ymm_rhs_1), \
                        _mm256_min_epi8(ymm_lhs_1, ymm_rhs_1)), \
        4); \
    ymm_lhs_0 = _mm256_madd_epi16(_mm256_maddubs_epi16(ymm_lhs_0, ymm_lhs_0), \
                                  ONES_INT16_AVX); \
    ymm_rhs_0 = _mm256_madd_epi16(_mm256_maddubs_epi16(ymm_rhs_0, ymm_rhs_0), \
                                  ONES_INT16_AVX); \
    ymm_sum = \
        _mm256_add_epi32(_mm256_add_epi32(ymm_lhs_0, ymm_rhs_0), ymm_sum); \
  }

//! Calculate sum of squared difference (GENERAL)
//!
//! Scalar step: computes `m - q` as int32 and adds its square, converted to
//! float, to `sum`. The arguments are parenthesized so compound expressions
//! expand correctly. The temporary has a name unlikely to collide with a
//! caller's argument.
#define SSD_INT8_GENERAL(m, q, sum) \
  { \
    int32_t ssd_diff = (m) - (q); \
    (sum) += static_cast<float>(ssd_diff * ssd_diff); \
  }

//!
//! Calculate sum of squared difference (SSE)
//!
//! Computes (m - q)^2 per byte as m*m + q*q - 2*m*q, summed into the int32
//! lanes of `xmm_sum`. Each product a*b is formed as
//! _mm_maddubs_epi16(_mm_abs_epi8(a), _mm_sign_epi8(b, a)), i.e.
//! |a| * (b * sign(a)). The 16-bit pairs are then widened to int32 with
//! _mm_madd_epi16 against ONES_INT16_SSE; the cross term is doubled with a
//! left shift by 1.
//! NOTE(review): for a byte equal to -128, _mm_sign_epi8 cannot negate it,
//! so the products involving that byte come out with the wrong sign. This
//! presumably assumes inputs lie in [-127, 127] -- confirm.
#define SSD_INT8_SSE(xmm_m, xmm_q, xmm_sum) \
  { \
    xmm_sum = _mm_add_epi32( \
        _mm_madd_epi16(_mm_maddubs_epi16(_mm_abs_epi8(xmm_m), \
                                         _mm_sign_epi8(xmm_m, xmm_m)), \
                       ONES_INT16_SSE), \
        xmm_sum); \
    xmm_sum = _mm_add_epi32( \
        _mm_madd_epi16(_mm_maddubs_epi16(_mm_abs_epi8(xmm_q), \
                                         _mm_sign_epi8(xmm_q, xmm_q)), \
                       ONES_INT16_SSE), \
        xmm_sum); \
    xmm_sum = _mm_sub_epi32( \
        xmm_sum, \
        _mm_slli_epi32( \
            _mm_madd_epi16(_mm_maddubs_epi16(_mm_abs_epi8(xmm_q), \
                                             _mm_sign_epi8(xmm_m, xmm_q)), \
                           ONES_INT16_SSE), \
            1)); \
  }

//! Calculate sum of squared difference (AVX)
//!
//! 256-bit version of SSD_INT8_SSE, with the same m*m + q*q - 2*m*q
//! decomposition and the same -128 caveat.
#define SSD_INT8_AVX(ymm_m, ymm_q, ymm_sum) \
  { \
    ymm_sum = _mm256_add_epi32( \
        _mm256_madd_epi16( \
            _mm256_maddubs_epi16(_mm256_abs_epi8(ymm_m), \
                                 _mm256_sign_epi8(ymm_m, ymm_m)), \
            ONES_INT16_AVX), \
        ymm_sum); \
    ymm_sum = _mm256_add_epi32( \
        _mm256_madd_epi16( \
            _mm256_maddubs_epi16(_mm256_abs_epi8(ymm_q), \
                                 _mm256_sign_epi8(ymm_q, ymm_q)), \
            ONES_INT16_AVX), \
        ymm_sum); \
    ymm_sum = _mm256_sub_epi32( \
        ymm_sum, _mm256_slli_epi32( \
                     _mm256_madd_epi16( \
                         _mm256_maddubs_epi16(_mm256_abs_epi8(ymm_q), \
                                              _mm256_sign_epi8(ymm_m, ymm_q)), \
                         ONES_INT16_AVX), \
                     1)); \
  }

//! Compute the square root of value (AVX)
//!
//! Converts int32 lanes to float, then takes the square root.
#define SQRT_FP32_AVX(v, ...) _mm256_sqrt_ps(_mm256_cvtepi32_ps(v))

//! Compute the square root of value (AVX512)
//!
//! Converts int32 lanes to float, then takes the square root.
#define SQRT_FP32_AVX512(v, ...) _mm512_sqrt_ps(_mm512_cvtepi32_ps(v))

// Bind the generic ACCUM_*_STEP_* names used by the matrix kernels to the
// SSD_* squared-difference steps defined in this file.
#define ACCUM_FP32_STEP_SSE SSD_FP32_SSE
#define ACCUM_FP32_STEP_AVX SSD_FP32_AVX
#define ACCUM_FP32_STEP_AVX512 SSD_FP32_AVX512
#define ACCUM_FP16_STEP_GENERAL SSD_FP16_GENERAL
#define ACCUM_FP16_STEP_NEON SSD_FP16_NEON
#define ACCUM_FP32_STEP_NEON SSD_FP32_NEON
#define ACCUM_INT4_STEP_SSE SSD_INT4_SSE
#define ACCUM_INT4_STEP_AVX SSD_INT4_AVX
#define ACCUM_INT8_STEP_SSE SSD_INT8_SSE
#define ACCUM_INT8_STEP_AVX SSD_INT8_AVX

================================================
FILE: src/ailego/math/distance_matrix_fp16.i
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// NOTE(review): the first and last #include below have no header operand
// here; the intended header names need to be restored.
#include
#include "matrix_define.i"
#include

// Broadcast one 32-bit word read from memory to every lane (as an integer
// vector). Without AVX this goes through _mm_load1_ps; with AVX it uses
// _mm_broadcast_ss. The 256-bit variant is only defined when __AVX__ is.
#if !defined(__AVX__)
#define _mm_broadcast_si32(a) _mm_castps_si128(_mm_load1_ps((const float *)(a)))
#else
#define _mm_broadcast_si32(a) \
  _mm_castps_si128(_mm_broadcast_ss((const float *)(a)))
#define _mm256_broadcast_si32(a) \
  _mm256_castps_si256(_mm256_broadcast_ss((const float *)(a)))
#endif  // !__AVX__

//!
//! Mask process of computing distance (FP16)
//!
//! Handles a partial group of `cnt` FP16 values. Cases 1..7 exist; any
//! other `cnt` matches no case and does nothing. For each case, exactly
//! `cnt` 16-bit values are read from `lhs` and from `rhs`, through `short`,
//! `int` or `__m64` pointer casts depending on the count. The remaining
//! 16-bit lanes of the 128-bit vector are filled with `_MASK`. Both vectors
//! are widened with _mm256_cvtph_ps and passed to
//! `_PROC(lhs, rhs, _RES##_0_0)`. `_MASK` is presumably chosen so that the
//! padded lanes contribute nothing under `_PROC` -- confirm at call sites.
//! NOTE(review): case 1 places the data element in the highest 16-bit lane
//! and the padding in the low lanes, unlike the other cases. lhs and rhs are
//! built the same way, so they stay paired, but confirm this is intended.
#define MATRIX_FP16_MASK_AVX(lhs, rhs, cnt, _MASK, _RES, _PROC) \
  switch (cnt) { \
    case 7: { \
      __m256 ymm_lhs = _mm256_cvtph_ps(_mm_set_epi16( \
          (short)(_MASK), *((const short *)(lhs) + 6), \
          *((const short *)(lhs) + 5), *((const short *)(lhs) + 4), \
          *((const short *)(lhs) + 3), *((const short *)(lhs) + 2), \
          *((const short *)(lhs) + 1), *((const short *)(lhs)))); \
      __m256 ymm_rhs = _mm256_cvtph_ps(_mm_set_epi16( \
          (short)(_MASK), *((const short *)(rhs) + 6), \
          *((const short *)(rhs) + 5), *((const short *)(rhs) + 4), \
          *((const short *)(rhs) + 3), *((const short *)(rhs) + 2), \
          *((const short *)(rhs) + 1), *((const short *)(rhs)))); \
      _PROC(ymm_lhs, ymm_rhs, _RES##_0_0) \
      break; \
    } \
    case 6: { \
      __m256 ymm_lhs = _mm256_cvtph_ps( \
          _mm_set_epi32((int)(_MASK), *((const int *)(lhs) + 2), \
                        *((const int *)(lhs) + 1), *((const int *)(lhs)))); \
      __m256 ymm_rhs = _mm256_cvtph_ps( \
          _mm_set_epi32((int)(_MASK), *((const int *)(rhs) + 2), \
                        *((const int *)(rhs) + 1), *((const int *)(rhs)))); \
      _PROC(ymm_lhs, ymm_rhs, _RES##_0_0) \
      break; \
    } \
    case 5: { \
      __m256 ymm_lhs = _mm256_cvtph_ps(_mm_set_epi16( \
          (short)(_MASK), (short)(_MASK), (short)(_MASK), \
          *((const short *)(lhs) + 4), *((const short *)(lhs) + 3), \
          *((const short *)(lhs) + 2), *((const short *)(lhs) + 1), \
          *((const short *)(lhs)))); \
      __m256 ymm_rhs = _mm256_cvtph_ps(_mm_set_epi16( \
          (short)(_MASK), (short)(_MASK), (short)(_MASK), \
          *((const short *)(rhs) + 4), *((const short *)(rhs) + 3), \
          *((const short *)(rhs) + 2), *((const short *)(rhs) + 1), \
          *((const short *)(rhs)))); \
      _PROC(ymm_lhs, ymm_rhs, _RES##_0_0) \
      break; \
    } \
    case 4: { \
      __m256 ymm_lhs = _mm256_cvtph_ps( \
          _mm_set_epi64((__m64)(_MASK), *((const __m64 *)(lhs)))); \
      __m256 ymm_rhs = _mm256_cvtph_ps( \
          _mm_set_epi64((__m64)(_MASK), *((const __m64 *)(rhs)))); \
      _PROC(ymm_lhs, ymm_rhs, _RES##_0_0) \
      break; \
    } \
    case 3: { \
      __m256 ymm_lhs = _mm256_cvtph_ps(_mm_set_epi16( \
          (short)(_MASK), (short)(_MASK), (short)(_MASK), (short)(_MASK), \
          (short)(_MASK), *((const short *)(lhs) + 2), \
          *((const short *)(lhs) + 1), *((const short *)(lhs)))); \
      __m256 ymm_rhs = _mm256_cvtph_ps(_mm_set_epi16( \
          (short)(_MASK), (short)(_MASK), (short)(_MASK), (short)(_MASK), \
          (short)(_MASK), *((const short *)(rhs) + 2), \
          *((const short *)(rhs) + 1), *((const short *)(rhs)))); \
      _PROC(ymm_lhs, ymm_rhs, _RES##_0_0) \
      break; \
    } \
    case 2: { \
      __m256 ymm_lhs = _mm256_cvtph_ps(_mm_set_epi32( \
          (int)(_MASK), (int)(_MASK), (int)(_MASK), *((const int *)(lhs)))); \
      __m256 ymm_rhs = _mm256_cvtph_ps(_mm_set_epi32( \
          (int)(_MASK), (int)(_MASK), (int)(_MASK), *((const int *)(rhs)))); \
      _PROC(ymm_lhs, ymm_rhs, _RES##_0_0) \
      break; \
    } \
    case 1: { \
      __m256 ymm_lhs = _mm256_cvtph_ps( \
          _mm_set_epi16(*((const short *)(lhs)), (short)(_MASK), \
                        (short)(_MASK), (short)(_MASK), (short)(_MASK), \
                        (short)(_MASK), (short)(_MASK), (short)(_MASK))); \
      __m256 ymm_rhs = _mm256_cvtph_ps( \
          _mm_set_epi16(*((const short *)(rhs)), (short)(_MASK), \
                        (short)(_MASK), (short)(_MASK), (short)(_MASK), \
                        (short)(_MASK), (short)(_MASK), (short)(_MASK))); \
      _PROC(ymm_lhs, ymm_rhs, _RES##_0_0) \
      break; \
    } \
  }

//! Iterative process of computing distance (FP16, M=1, N=1)
//!
//! Loads 256 bits (16 FP16 values) from each of `m` and `q` with `_LOAD`.
//! The low and high 128-bit halves are converted to FP32 separately, and
//! `_PROC` is applied twice into `_RES##_0_0`.
#define MATRIX_FP16_ITER_1X1_AVX(m, q, _RES, _LOAD, _PROC) \
  { \
    __m256i ymm_mi = _LOAD((const __m256i *)m); \
    __m256i ymm_qi = _LOAD((const __m256i *)q); \
    __m256 ymm_m = _mm256_cvtph_ps(_mm256_castsi256_si128(ymm_mi)); \
    __m256 ymm_q = _mm256_cvtph_ps(_mm256_castsi256_si128(ymm_qi)); \
    _PROC(ymm_m, ymm_q, _RES##_0_0); \
    ymm_m = _mm256_cvtph_ps(_mm256_extractf128_si256(ymm_mi, 1)); \
    ymm_q = _mm256_cvtph_ps(_mm256_extractf128_si256(ymm_qi, 1)); \
    _PROC(ymm_m, ymm_q, _RES##_0_0); \
  }

//!
Iterative process of computing distance (FP16, M=2, N=1)
//!
//! Loads 8 FP16 values from `m` (one 128-bit _LOAD) and 4 FP16 values
//! q0..q3 from `q`. The shufflelo/shufflehi pair expands q to
//! q0,q0,q1,q1,q2,q2,q3,q3, so each q value meets two adjacent m lanes
//! (presumably two interleaved rows -- confirm against the packing code).
//! The single _PROC step accumulates into _RES##_0_0.
#define MATRIX_FP16_ITER_2X1_AVX(m, q, _RES, _LOAD, _PROC) \
  { \
    __m256 ymm_m = _mm256_cvtph_ps(_LOAD((const __m128i *)(m))); \
    __m256 ymm_q = _mm256_cvtph_ps(_mm_shufflehi_epi16( \
        _mm_shufflelo_epi16(_mm_set1_epi64(*(const __m64 *)(q)), \
                            _MM_SHUFFLE(1, 1, 0, 0)), \
        _MM_SHUFFLE(3, 3, 2, 2))); \
    _PROC(ymm_m, ymm_q, _RES##_0_0) \
  }

//! Iterative process of computing distance (FP16, M=2, N=2)
//!
//! Loads 8 FP16 values from each of `q` and `m` through _LOAD. After
//! widening, _mm256_moveldup_ps duplicates q's even lanes (0, 2, 4, 6) for
//! _RES##_0_0, and _mm256_movehdup_ps duplicates its odd lanes (1, 3, 5, 7)
//! for _RES##_0_1.
#define MATRIX_FP16_ITER_2X2_AVX(m, q, _RES, _LOAD, _PROC) \
  { \
    __m256 ymm_q = _mm256_cvtph_ps(_LOAD((const __m128i *)(q))); \
    __m256 ymm_m = _mm256_cvtph_ps(_LOAD((const __m128i *)(m))); \
    __m256 ymm_p = _mm256_moveldup_ps(ymm_q); \
    _PROC(ymm_m, ymm_p, _RES##_0_0) \
    ymm_p = _mm256_movehdup_ps(ymm_q); \
    _PROC(ymm_m, ymm_p, _RES##_0_1) \
  }

//! Iterative process of computing distance (FP16, M=4, N=1)
//!
//! _mm_broadcast_si32 reads 2 FP16 values q0,q1 from `q`. The shuffles turn
//! them into q0 x4 followed by q1 x4, matching the 8 FP16 values loaded from
//! `m`. Accumulates into _RES##_0_0.
#define MATRIX_FP16_ITER_4X1_AVX(m, q, _RES, _LOAD, _PROC) \
  { \
    __m256 ymm_m = _mm256_cvtph_ps(_LOAD((const __m128i *)(m))); \
    __m256 ymm_q = _mm256_cvtph_ps( \
        _mm_shufflehi_epi16(_mm_shufflelo_epi16(_mm_broadcast_si32(q), 0), \
                            _MM_SHUFFLE(1, 1, 1, 1))); \
    _PROC(ymm_m, ymm_q, _RES##_0_0) \
  }

//! Iterative process of computing distance (FP16, M=4, N=2)
//!
//! Reads 4 FP16 values q0..q3 from `q` with one 64-bit load, repeated into
//! both halves. This gives ymm_q_0 = {q0 x4, q2 x4} and
//! ymm_q_1 = {q1 x4, q3 x4}, both consumed by MATRIX_VAR_PROC (defined
//! elsewhere), which is handed the ymm_m / ymm_q name prefixes.
//! NOTE(review): MATRIX_VAR_PROC presumably token-pastes the _0/_1
//! suffixes, so these local names must stay in sync with it.
#define MATRIX_FP16_ITER_4X2_AVX(m, q, _RES, _LOAD, _PROC) \
  { \
    __m128i xmm_qi = _mm_set1_epi64(*(const __m64 *)(q)); \
    __m256 ymm_m = _mm256_cvtph_ps(_LOAD((const __m128i *)(m))); \
    __m256 ymm_q_0 = _mm256_cvtph_ps(_mm_shufflehi_epi16( \
        _mm_shufflelo_epi16(xmm_qi, _MM_SHUFFLE(0, 0, 0, 0)), \
        _MM_SHUFFLE(2, 2, 2, 2))); \
    __m256 ymm_q_1 = _mm256_cvtph_ps(_mm_shufflehi_epi16( \
        _mm_shufflelo_epi16(xmm_qi, _MM_SHUFFLE(1, 1, 1, 1)), \
        _MM_SHUFFLE(3, 3, 3, 3))); \
    MATRIX_VAR_PROC(1, 2, 0, ymm_m, ymm_q, _RES, _PROC) \
  }
//!
Iterative process of computing distance (FP16, M=4, N=4)
//!
//! Loads 8 FP16 values from each of `m` and `q` through _LOAD.
//! _mm256_permute_ps selects within each 128-bit lane, so the k-th ymm_p
//! (k = 0..3) holds q[k] x4 followed by q[4 + k] x4 and feeds _RES##_0_k.
#define MATRIX_FP16_ITER_4X4_AVX(m, q, _RES, _LOAD, _PROC) \
  { \
    __m256 ymm_m = _mm256_cvtph_ps(_LOAD((const __m128i *)(m))); \
    __m256 ymm_q = _mm256_cvtph_ps(_LOAD((const __m128i *)(q))); \
    __m256 ymm_p = _mm256_permute_ps(ymm_q, _MM_SHUFFLE(0, 0, 0, 0)); \
    _PROC(ymm_m, ymm_p, _RES##_0_0) \
    ymm_p = _mm256_permute_ps(ymm_q, _MM_SHUFFLE(1, 1, 1, 1)); \
    _PROC(ymm_m, ymm_p, _RES##_0_1) \
    ymm_p = _mm256_permute_ps(ymm_q, _MM_SHUFFLE(2, 2, 2, 2)); \
    _PROC(ymm_m, ymm_p, _RES##_0_2) \
    ymm_p = _mm256_permute_ps(ymm_q, _MM_SHUFFLE(3, 3, 3, 3)); \
    _PROC(ymm_m, ymm_p, _RES##_0_3) \
  }

//! Iterative process of computing distance (FP16, M=8, N=1)
//!
//! Broadcasts the single FP16 value at `q` to all 8 lanes and pairs it with
//! 8 FP16 values loaded from `m`. Accumulates into _RES##_0_0.
#define MATRIX_FP16_ITER_8X1_AVX(m, q, _RES, _LOAD, _PROC) \
  { \
    __m256 ymm_m = _mm256_cvtph_ps(_LOAD((const __m128i *)(m))); \
    __m256 ymm_q = _mm256_cvtph_ps(_mm_set1_epi16(*(const short *)(q))); \
    _PROC(ymm_m, ymm_q, _RES##_0_0) \
  }

//! Iterative process of computing distance (FP16, M=8, N=2)
//!
//! Reads 2 FP16 values from `q` (_mm_broadcast_si32) and broadcasts each
//! into ymm_q_0 / ymm_q_1 for MATRIX_VAR_PROC (defined elsewhere).
//! `xmm_p[i]` relies on the GCC/Clang vector-subscript extension.
#define MATRIX_FP16_ITER_8X2_AVX(m, q, _RES, _LOAD, _PROC) \
  { \
    __m256 ymm_m = _mm256_cvtph_ps(_LOAD((const __m128i *)(m))); \
    __m128 xmm_p = _mm_cvtph_ps(_mm_broadcast_si32(q)); \
    __m256 ymm_q_0 = _mm256_set1_ps(xmm_p[0]); \
    __m256 ymm_q_1 = _mm256_set1_ps(xmm_p[1]); \
    MATRIX_VAR_PROC(1, 2, 0, ymm_m, ymm_q, _RES, _PROC) \
  }

//! Iterative process of computing distance (FP16, M=8, N=4)
//!
//! Reads 4 FP16 values from `q` with a 64-bit load. Each q[k] is broadcast
//! and paired with the 8 m values into _RES##_0_k (k = 0..3).
#define MATRIX_FP16_ITER_8X4_AVX(m, q, _RES, _LOAD, _PROC) \
  { \
    __m256 ymm_m = _mm256_cvtph_ps(_LOAD((const __m128i *)(m))); \
    __m128 xmm_p = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)(q))); \
    __m256 ymm_q = _mm256_set1_ps(xmm_p[0]); \
    _PROC(ymm_m, ymm_q, _RES##_0_0) \
    ymm_q = _mm256_set1_ps(xmm_p[1]); \
    _PROC(ymm_m, ymm_q, _RES##_0_1) \
    ymm_q = _mm256_set1_ps(xmm_p[2]); \
    _PROC(ymm_m, ymm_q, _RES##_0_2) \
    ymm_q = _mm256_set1_ps(xmm_p[3]); \
    _PROC(ymm_m, ymm_q, _RES##_0_3) \
  }
//!
Iterative process of computing distance (FP16, M=8, N=8)
//!
//! Loads 8 FP16 values from each of `m` and `q` through _LOAD. Each q[k] is
//! broadcast and paired with all 8 m lanes into _RES##_0_k (k = 0..7).
//! Note: q goes through _LOAD here, whereas the M=16 / M=32 variants read q
//! with _mm_loadu_si128.
#define MATRIX_FP16_ITER_8X8_AVX(m, q, _RES, _LOAD, _PROC) \
  { \
    __m256 ymm_m = _mm256_cvtph_ps(_LOAD((const __m128i *)(m))); \
    __m256 ymm_p = _mm256_cvtph_ps(_LOAD((const __m128i *)(q))); \
    __m256 ymm_q = _mm256_set1_ps(ymm_p[0]); \
    _PROC(ymm_m, ymm_q, _RES##_0_0) \
    ymm_q = _mm256_set1_ps(ymm_p[1]); \
    _PROC(ymm_m, ymm_q, _RES##_0_1) \
    ymm_q = _mm256_set1_ps(ymm_p[2]); \
    _PROC(ymm_m, ymm_q, _RES##_0_2) \
    ymm_q = _mm256_set1_ps(ymm_p[3]); \
    _PROC(ymm_m, ymm_q, _RES##_0_3) \
    ymm_q = _mm256_set1_ps(ymm_p[4]); \
    _PROC(ymm_m, ymm_q, _RES##_0_4) \
    ymm_q = _mm256_set1_ps(ymm_p[5]); \
    _PROC(ymm_m, ymm_q, _RES##_0_5) \
    ymm_q = _mm256_set1_ps(ymm_p[6]); \
    _PROC(ymm_m, ymm_q, _RES##_0_6) \
    ymm_q = _mm256_set1_ps(ymm_p[7]); \
    _PROC(ymm_m, ymm_q, _RES##_0_7) \
  }

//! Iterative process of computing distance (FP16, M=16, N=1)
//!
//! Loads 16 FP16 values from `m` with a 256-bit _LOAD and splits them into
//! ymm_m_0 (low 8) and ymm_m_1 (high 8). One q value is broadcast.
//! NOTE(review): MATRIX_VAR_PROC (defined elsewhere) is handed the
//! ymm_m / ymm_q prefixes and presumably token-pastes the _0/_1 suffixes,
//! so these local names must not be renamed independently.
#define MATRIX_FP16_ITER_16X1_AVX(m, q, _RES, _LOAD, _PROC) \
  { \
    __m256i ymm_mi = _LOAD((const __m256i *)(m)); \
    __m256 ymm_q = _mm256_cvtph_ps(_mm_set1_epi16(*(const short *)q)); \
    __m256 ymm_m_0 = _mm256_cvtph_ps(_mm256_castsi256_si128(ymm_mi)); \
    __m256 ymm_m_1 = _mm256_cvtph_ps(_mm256_extractf128_si256(ymm_mi, 1)); \
    MATRIX_VAR_PROC(2, 1, 0, ymm_m, ymm_q, _RES, _PROC) \
  }

//! Iterative process of computing distance (FP16, M=16, N=2)
//!
//! Loads 16 FP16 values from `m` and 2 FP16 values from `q`. Each 8-lane
//! half of m is processed against both broadcast q vectors. The third
//! MATRIX_VAR_PROC argument is 0 for the low half and 1 for the high half.
#define MATRIX_FP16_ITER_16X2_AVX(m, q, _RES, _LOAD, _PROC) \
  { \
    __m256i ymm_mi = _LOAD((const __m256i *)(m)); \
    __m128 xmm_p = _mm_cvtph_ps(_mm_broadcast_si32(q)); \
    __m256 ymm_q_0 = _mm256_set1_ps(xmm_p[0]); \
    __m256 ymm_q_1 = _mm256_set1_ps(xmm_p[1]); \
    __m256 ymm_m = _mm256_cvtph_ps(_mm256_castsi256_si128(ymm_mi)); \
    MATRIX_VAR_PROC(1, 2, 0, ymm_m, ymm_q, _RES, _PROC) \
    ymm_m = _mm256_cvtph_ps(_mm256_extractf128_si256(ymm_mi, 1)); \
    MATRIX_VAR_PROC(1, 2, 1, ymm_m, ymm_q, _RES, _PROC) \
  }
//!
Iterative process of computing distance (FP16, M=16, N=4)
//!
//! Loads 16 FP16 values from `m` with a 256-bit _LOAD, split into ymm_m_0
//! (low 8) and ymm_m_1 (high 8). Reads 4 FP16 values from `q` with a 64-bit
//! load. For k = 0..3, q[k] is broadcast and passed to
//! MATRIX_VAR_PROC(2, 1, k, ...), which is defined elsewhere.
#define MATRIX_FP16_ITER_16X4_AVX(m, q, _RES, _LOAD, _PROC) \
  { \
    __m256i ymm_mi = _LOAD((const __m256i *)(m)); \
    __m256 ymm_m_0 = _mm256_cvtph_ps(_mm256_castsi256_si128(ymm_mi)); \
    __m256 ymm_m_1 = _mm256_cvtph_ps(_mm256_extractf128_si256(ymm_mi, 1)); \
    __m128 xmm_p = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)(q))); \
    __m256 ymm_q = _mm256_set1_ps(xmm_p[0]); \
    MATRIX_VAR_PROC(2, 1, 0, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_set1_ps(xmm_p[1]); \
    MATRIX_VAR_PROC(2, 1, 1, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_set1_ps(xmm_p[2]); \
    MATRIX_VAR_PROC(2, 1, 2, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_set1_ps(xmm_p[3]); \
    MATRIX_VAR_PROC(2, 1, 3, ymm_m, ymm_q, _RES, _PROC) \
  }

//! Iterative process of computing distance (FP16, M=16, N=8)
//!
//! Same m handling as the N=4 variant. Reads 8 FP16 values from `q` with an
//! unaligned 128-bit load. For k = 0..7, q[k] is broadcast and passed to
//! MATRIX_VAR_PROC(2, 1, k, ...).
#define MATRIX_FP16_ITER_16X8_AVX(m, q, _RES, _LOAD, _PROC) \
  { \
    __m256i ymm_mi = _LOAD((const __m256i *)(m)); \
    __m256 ymm_m_0 = _mm256_cvtph_ps(_mm256_castsi256_si128(ymm_mi)); \
    __m256 ymm_m_1 = _mm256_cvtph_ps(_mm256_extractf128_si256(ymm_mi, 1)); \
    __m256 ymm_p = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(q))); \
    __m256 ymm_q = _mm256_set1_ps(ymm_p[0]); \
    MATRIX_VAR_PROC(2, 1, 0, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_set1_ps(ymm_p[1]); \
    MATRIX_VAR_PROC(2, 1, 1, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_set1_ps(ymm_p[2]); \
    MATRIX_VAR_PROC(2, 1, 2, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_set1_ps(ymm_p[3]); \
    MATRIX_VAR_PROC(2, 1, 3, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_set1_ps(ymm_p[4]); \
    MATRIX_VAR_PROC(2, 1, 4, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_set1_ps(ymm_p[5]); \
    MATRIX_VAR_PROC(2, 1, 5, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_set1_ps(ymm_p[6]); \
    MATRIX_VAR_PROC(2, 1, 6, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_set1_ps(ymm_p[7]); \
    MATRIX_VAR_PROC(2, 1, 7, ymm_m, ymm_q, _RES, _PROC) \
  }
//!
Iterative process of computing distance (FP16, M=16, N=16)
//!
//! Same m handling as the N=8 variant, but covers 16 query values using two
//! unaligned 128-bit loads at `q` and `q + 8` (k = 0..15).
//! NOTE(review): `q + 8` advances by 8 elements of q's pointee type, which
//! is presumably a 16-bit FP16 type -- confirm at call sites.
#define MATRIX_FP16_ITER_16X16_AVX(m, q, _RES, _LOAD, _PROC) \
  { \
    __m256i ymm_mi = _LOAD((const __m256i *)(m)); \
    __m256 ymm_m_0 = _mm256_cvtph_ps(_mm256_castsi256_si128(ymm_mi)); \
    __m256 ymm_m_1 = _mm256_cvtph_ps(_mm256_extractf128_si256(ymm_mi, 1)); \
    __m256 ymm_p = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(q))); \
    __m256 ymm_q = _mm256_set1_ps(ymm_p[0]); \
    MATRIX_VAR_PROC(2, 1, 0, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_set1_ps(ymm_p[1]); \
    MATRIX_VAR_PROC(2, 1, 1, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_set1_ps(ymm_p[2]); \
    MATRIX_VAR_PROC(2, 1, 2, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_set1_ps(ymm_p[3]); \
    MATRIX_VAR_PROC(2, 1, 3, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_set1_ps(ymm_p[4]); \
    MATRIX_VAR_PROC(2, 1, 4, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_set1_ps(ymm_p[5]); \
    MATRIX_VAR_PROC(2, 1, 5, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_set1_ps(ymm_p[6]); \
    MATRIX_VAR_PROC(2, 1, 6, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_set1_ps(ymm_p[7]); \
    MATRIX_VAR_PROC(2, 1, 7, ymm_m, ymm_q, _RES, _PROC) \
    ymm_p = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(q + 8))); \
    ymm_q = _mm256_set1_ps(ymm_p[0]); \
    MATRIX_VAR_PROC(2, 1, 8, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_set1_ps(ymm_p[1]); \
    MATRIX_VAR_PROC(2, 1, 9, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_set1_ps(ymm_p[2]); \
    MATRIX_VAR_PROC(2, 1, 10, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_set1_ps(ymm_p[3]); \
    MATRIX_VAR_PROC(2, 1, 11, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_set1_ps(ymm_p[4]); \
    MATRIX_VAR_PROC(2, 1, 12, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_set1_ps(ymm_p[5]); \
    MATRIX_VAR_PROC(2, 1, 13, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_set1_ps(ymm_p[6]); \
    MATRIX_VAR_PROC(2, 1, 14, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_set1_ps(ymm_p[7]); \
    MATRIX_VAR_PROC(2, 1, 15, ymm_m, ymm_q, _RES, _PROC) \
  }
//!
Iterative process of computing distance (FP16, M=32, N=1)
//!
//! Loads 32 FP16 values from `m` with two 256-bit _LOADs at `m` and
//! `m + 16` (assumes m points to 16-bit elements -- confirm), producing
//! ymm_m_0..ymm_m_3. One q value is broadcast for
//! MATRIX_VAR_PROC(4, 1, 0, ...).
#define MATRIX_FP16_ITER_32X1_AVX(m, q, _RES, _LOAD, _PROC) \
  { \
    __m256i ymm_mi = _LOAD((const __m256i *)(m)); \
    __m256 ymm_m_0 = _mm256_cvtph_ps(_mm256_castsi256_si128(ymm_mi)); \
    __m256 ymm_m_1 = _mm256_cvtph_ps(_mm256_extractf128_si256(ymm_mi, 1)); \
    ymm_mi = _LOAD((const __m256i *)(m + 16)); \
    __m256 ymm_m_2 = _mm256_cvtph_ps(_mm256_castsi256_si128(ymm_mi)); \
    __m256 ymm_m_3 = _mm256_cvtph_ps(_mm256_extractf128_si256(ymm_mi, 1)); \
    __m256 ymm_q = _mm256_cvtph_ps(_mm_set1_epi16(*(const short *)q)); \
    MATRIX_VAR_PROC(4, 1, 0, ymm_m, ymm_q, _RES, _PROC) \
  }

//! Iterative process of computing distance (FP16, M=32, N=2)
//!
//! Reads 2 FP16 values from `q` into ymm_q_0 / ymm_q_1. It then walks the
//! four 8-lane groups of the 32 m values in order, calling
//! MATRIX_VAR_PROC(1, 2, g, ...) with g = 0..3.
#define MATRIX_FP16_ITER_32X2_AVX(m, q, _RES, _LOAD, _PROC) \
  { \
    __m128 xmm_p = _mm_cvtph_ps(_mm_broadcast_si32(q)); \
    __m256 ymm_q_0 = _mm256_set1_ps(xmm_p[0]); \
    __m256 ymm_q_1 = _mm256_set1_ps(xmm_p[1]); \
    __m256i ymm_mi = _LOAD((const __m256i *)(m)); \
    __m256 ymm_m = _mm256_cvtph_ps(_mm256_castsi256_si128(ymm_mi)); \
    MATRIX_VAR_PROC(1, 2, 0, ymm_m, ymm_q, _RES, _PROC) \
    ymm_m = _mm256_cvtph_ps(_mm256_extractf128_si256(ymm_mi, 1)); \
    MATRIX_VAR_PROC(1, 2, 1, ymm_m, ymm_q, _RES, _PROC) \
    ymm_mi = _LOAD((const __m256i *)(m + 16)); \
    ymm_m = _mm256_cvtph_ps(_mm256_castsi256_si128(ymm_mi)); \
    MATRIX_VAR_PROC(1, 2, 2, ymm_m, ymm_q, _RES, _PROC) \
    ymm_m = _mm256_cvtph_ps(_mm256_extractf128_si256(ymm_mi, 1)); \
    MATRIX_VAR_PROC(1, 2, 3, ymm_m, ymm_q, _RES, _PROC) \
  }
//!
Iterative process of computing distance (FP16, M=32, N=4)
//!
//! Like the N=2 variant, but with 4 broadcast q vectors taken from a 64-bit
//! load of `q`. Each 8-lane group g = 0..3 of m is passed to
//! MATRIX_VAR_PROC(1, 4, g, ...).
#define MATRIX_FP16_ITER_32X4_AVX(m, q, _RES, _LOAD, _PROC) \
  { \
    __m128 xmm_p = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)(q))); \
    __m256 ymm_q_0 = _mm256_set1_ps(xmm_p[0]); \
    __m256 ymm_q_1 = _mm256_set1_ps(xmm_p[1]); \
    __m256 ymm_q_2 = _mm256_set1_ps(xmm_p[2]); \
    __m256 ymm_q_3 = _mm256_set1_ps(xmm_p[3]); \
    __m256i ymm_mi = _LOAD((const __m256i *)(m)); \
    __m256 ymm_m = _mm256_cvtph_ps(_mm256_castsi256_si128(ymm_mi)); \
    MATRIX_VAR_PROC(1, 4, 0, ymm_m, ymm_q, _RES, _PROC) \
    ymm_m = _mm256_cvtph_ps(_mm256_extractf128_si256(ymm_mi, 1)); \
    MATRIX_VAR_PROC(1, 4, 1, ymm_m, ymm_q, _RES, _PROC) \
    ymm_mi = _LOAD((const __m256i *)(m + 16)); \
    ymm_m = _mm256_cvtph_ps(_mm256_castsi256_si128(ymm_mi)); \
    MATRIX_VAR_PROC(1, 4, 2, ymm_m, ymm_q, _RES, _PROC) \
    ymm_m = _mm256_cvtph_ps(_mm256_extractf128_si256(ymm_mi, 1)); \
    MATRIX_VAR_PROC(1, 4, 3, ymm_m, ymm_q, _RES, _PROC) \
  }

//! Iterative process of computing distance (FP16, M=32, N=8)
//!
//! Converts all 32 m values up front (ymm_m_0..ymm_m_3). Reads 8 FP16
//! q values with an unaligned 128-bit load. For k = 0..7, q[k] is broadcast
//! and passed to MATRIX_VAR_PROC(4, 1, k, ...).
#define MATRIX_FP16_ITER_32X8_AVX(m, q, _RES, _LOAD, _PROC) \
  { \
    __m256i ymm_mi = _LOAD((const __m256i *)(m)); \
    __m256 ymm_m_0 = _mm256_cvtph_ps(_mm256_castsi256_si128(ymm_mi)); \
    __m256 ymm_m_1 = _mm256_cvtph_ps(_mm256_extractf128_si256(ymm_mi, 1)); \
    ymm_mi = _LOAD((const __m256i *)(m + 16)); \
    __m256 ymm_m_2 = _mm256_cvtph_ps(_mm256_castsi256_si128(ymm_mi)); \
    __m256 ymm_m_3 = _mm256_cvtph_ps(_mm256_extractf128_si256(ymm_mi, 1)); \
    __m256 ymm_p = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(q))); \
    __m256 ymm_q = _mm256_set1_ps(ymm_p[0]); \
    MATRIX_VAR_PROC(4, 1, 0, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_set1_ps(ymm_p[1]); \
    MATRIX_VAR_PROC(4, 1, 1, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_set1_ps(ymm_p[2]); \
    MATRIX_VAR_PROC(4, 1, 2, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_set1_ps(ymm_p[3]); \
    MATRIX_VAR_PROC(4, 1, 3, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_set1_ps(ymm_p[4]); \
    MATRIX_VAR_PROC(4, 1, 4, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_set1_ps(ymm_p[5]); \
    MATRIX_VAR_PROC(4, 1, 5, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_set1_ps(ymm_p[6]); \
    MATRIX_VAR_PROC(4, 1, 6, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_set1_ps(ymm_p[7]); \
    MATRIX_VAR_PROC(4, 1, 7, ymm_m, ymm_q, _RES, _PROC) \
  }

//! Iterative process of computing distance (FP16, M=32, N=16)
//!
//! Same as the N=8 variant, with two unaligned 128-bit q loads at `q` and
//! `q + 8` (k = 0..15).
#define MATRIX_FP16_ITER_32X16_AVX(m, q, _RES, _LOAD, _PROC) \
  { \
    __m256i ymm_mi = _LOAD((const __m256i *)(m)); \
    __m256 ymm_m_0 = _mm256_cvtph_ps(_mm256_castsi256_si128(ymm_mi)); \
    __m256 ymm_m_1 = _mm256_cvtph_ps(_mm256_extractf128_si256(ymm_mi, 1)); \
    ymm_mi = _LOAD((const __m256i *)(m + 16)); \
    __m256 ymm_m_2 = _mm256_cvtph_ps(_mm256_castsi256_si128(ymm_mi)); \
    __m256 ymm_m_3 = _mm256_cvtph_ps(_mm256_extractf128_si256(ymm_mi, 1)); \
    __m256 ymm_p = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(q))); \
    __m256 ymm_q = _mm256_set1_ps(ymm_p[0]); \
    MATRIX_VAR_PROC(4, 1, 0, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_set1_ps(ymm_p[1]); \
    MATRIX_VAR_PROC(4, 1, 1, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_set1_ps(ymm_p[2]); \
    MATRIX_VAR_PROC(4, 1, 2, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_set1_ps(ymm_p[3]); \
    MATRIX_VAR_PROC(4, 1, 3, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_set1_ps(ymm_p[4]); \
    MATRIX_VAR_PROC(4, 1, 4, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_set1_ps(ymm_p[5]); \
    MATRIX_VAR_PROC(4, 1, 5, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_set1_ps(ymm_p[6]); \
    MATRIX_VAR_PROC(4, 1, 6, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_set1_ps(ymm_p[7]); \
    MATRIX_VAR_PROC(4, 1, 7, ymm_m, ymm_q, _RES, _PROC) \
    ymm_p = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(q + 8))); \
    ymm_q = _mm256_set1_ps(ymm_p[0]); \
    MATRIX_VAR_PROC(4, 1, 8, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_set1_ps(ymm_p[1]); \
    MATRIX_VAR_PROC(4, 1, 9, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_set1_ps(ymm_p[2]); \
    MATRIX_VAR_PROC(4, 1, 10, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_set1_ps(ymm_p[3]); \
    MATRIX_VAR_PROC(4, 1, 11, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_set1_ps(ymm_p[4]); \
    MATRIX_VAR_PROC(4, 1, 12, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_set1_ps(ymm_p[5]); \
    MATRIX_VAR_PROC(4, 1, 13, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_set1_ps(ymm_p[6]); \
    MATRIX_VAR_PROC(4, 1, 14, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_set1_ps(ymm_p[7]); \
    MATRIX_VAR_PROC(4, 1, 15, ymm_m, ymm_q, _RES, _PROC) \
  }

//! Iterative process of computing distance (FP16, M=32, N=32)
//!
//! Same as the N=16 variant, with four unaligned 128-bit q loads at `q`,
//! `q + 8`, `q + 16` and `q + 24` (k = 0..31).
#define MATRIX_FP16_ITER_32X32_AVX(m, q, _RES, _LOAD, _PROC) \
  { \
    __m256i ymm_mi = _LOAD((const __m256i *)(m)); \
    __m256 ymm_m_0 = _mm256_cvtph_ps(_mm256_castsi256_si128(ymm_mi)); \
    __m256 ymm_m_1 = _mm256_cvtph_ps(_mm256_extractf128_si256(ymm_mi, 1)); \
    ymm_mi = _LOAD((const __m256i *)(m + 16)); \
    __m256 ymm_m_2 = _mm256_cvtph_ps(_mm256_castsi256_si128(ymm_mi)); \
    __m256 ymm_m_3 = _mm256_cvtph_ps(_mm256_extractf128_si256(ymm_mi, 1)); \
    __m256 ymm_p = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(q))); \
    __m256 ymm_q = _mm256_set1_ps(ymm_p[0]); \
    MATRIX_VAR_PROC(4, 1, 0, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_set1_ps(ymm_p[1]); \
    MATRIX_VAR_PROC(4, 1, 1, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_set1_ps(ymm_p[2]); \
    MATRIX_VAR_PROC(4, 1, 2, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_set1_ps(ymm_p[3]); \
    MATRIX_VAR_PROC(4, 1, 3, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_set1_ps(ymm_p[4]); \
    MATRIX_VAR_PROC(4, 1, 4, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_set1_ps(ymm_p[5]); \
    MATRIX_VAR_PROC(4, 1, 5, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_set1_ps(ymm_p[6]); \
    MATRIX_VAR_PROC(4, 1, 6, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_set1_ps(ymm_p[7]); \
    MATRIX_VAR_PROC(4, 1, 7, ymm_m, ymm_q, _RES, _PROC) \
    ymm_p = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(q + 8))); \
    ymm_q = _mm256_set1_ps(ymm_p[0]); \
    MATRIX_VAR_PROC(4, 1, 8, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_set1_ps(ymm_p[1]); \
    MATRIX_VAR_PROC(4, 1, 9, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_set1_ps(ymm_p[2]); \
    MATRIX_VAR_PROC(4, 1, 10, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_set1_ps(ymm_p[3]); \
    MATRIX_VAR_PROC(4, 1, 11, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_set1_ps(ymm_p[4]); \
    MATRIX_VAR_PROC(4, 1, 12, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_set1_ps(ymm_p[5]); \
    MATRIX_VAR_PROC(4, 1, 13, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_set1_ps(ymm_p[6]); \
    MATRIX_VAR_PROC(4, 1, 14, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_set1_ps(ymm_p[7]); \
    MATRIX_VAR_PROC(4, 1, 15, ymm_m, ymm_q, _RES, _PROC) \
    ymm_p = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(q + 16))); \
    ymm_q = _mm256_set1_ps(ymm_p[0]); \
    MATRIX_VAR_PROC(4, 1, 16, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_set1_ps(ymm_p[1]); \
    MATRIX_VAR_PROC(4, 1, 17, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_set1_ps(ymm_p[2]); \
    MATRIX_VAR_PROC(4, 1, 18, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_set1_ps(ymm_p[3]); \
    MATRIX_VAR_PROC(4, 1, 19, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_set1_ps(ymm_p[4]); \
    MATRIX_VAR_PROC(4, 1, 20, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_set1_ps(ymm_p[5]); \
    MATRIX_VAR_PROC(4, 1, 21, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_set1_ps(ymm_p[6]); \
    MATRIX_VAR_PROC(4, 1, 22, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_set1_ps(ymm_p[7]); \
    MATRIX_VAR_PROC(4, 1, 23, ymm_m, ymm_q, _RES, _PROC) \
    ymm_p = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(q + 24))); \
    ymm_q = _mm256_set1_ps(ymm_p[0]); \
    MATRIX_VAR_PROC(4, 1, 24, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_set1_ps(ymm_p[1]); \
    MATRIX_VAR_PROC(4, 1, 25, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_set1_ps(ymm_p[2]); \
    MATRIX_VAR_PROC(4, 1, 26, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_set1_ps(ymm_p[3]); \
    MATRIX_VAR_PROC(4, 1, 27, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_set1_ps(ymm_p[4]); \
    MATRIX_VAR_PROC(4, 1, 28, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_set1_ps(ymm_p[5]); \
    MATRIX_VAR_PROC(4, 1, 29, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_set1_ps(ymm_p[6]); \
    MATRIX_VAR_PROC(4, 1, 30, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_set1_ps(ymm_p[7]); \
    MATRIX_VAR_PROC(4, 1, 31, ymm_m, ymm_q, _RES, _PROC) \
  }

//! Iterative process of computing distance (FP16, M=1, N=1)
//!
//! AVX-512 form. Loads 32 FP16 values (512 bits) from each of `m` and `q`
//! through _LOAD and widens the low and then the high 256-bit half to
//! 16 FP32 lanes. Applies _PROC to each pair; both calls accumulate into
//! _RES##_0_0.
#define MATRIX_FP16_ITER_1X1_AVX512(m, q, _RES, _LOAD, _PROC) \
  { \
    __m512i zmm_mi = _LOAD((const __m512i *)m); \
    __m512i zmm_qi = _LOAD((const __m512i *)q); \
    __m512 zmm_m = _mm512_cvtph_ps(_mm512_castsi512_si256(zmm_mi)); \
    __m512 zmm_q = _mm512_cvtph_ps(_mm512_castsi512_si256(zmm_qi)); \
    _PROC(zmm_m, zmm_q, _RES##_0_0); \
    zmm_m = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(zmm_mi, 1)); \
    zmm_q = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(zmm_qi, 1)); \
    _PROC(zmm_m, zmm_q, _RES##_0_0); \
  }

//! Iterative process of computing distance (FP16, M=16, N=1)
//!
//! Loads 16 FP16 values from `m` (256-bit _LOAD) and broadcasts the single
//! FP16 value at `q` to 16 lanes. Accumulates into _RES##_0_0.
#define MATRIX_FP16_ITER_16X1_AVX512(m, q, _RES, _LOAD, _PROC) \
  { \
    __m512 zmm_m = _mm512_cvtph_ps(_LOAD((const __m256i *)(m))); \
    __m512 zmm_q = _mm512_cvtph_ps(_mm256_set1_epi16(*(const short *)q)); \
    _PROC(zmm_m, zmm_q, _RES##_0_0) \
  }

//! Iterative process of computing distance (FP16, M=16, N=2)
//!
//! Reads 2 FP16 values from `q` into zmm_q_0 / zmm_q_1 for
//! MATRIX_VAR_PROC(1, 2, 0, ...), which is defined elsewhere.
#define MATRIX_FP16_ITER_16X2_AVX512(m, q, _RES, _LOAD, _PROC) \
  { \
    __m512 zmm_m = _mm512_cvtph_ps(_LOAD((const __m256i *)(m))); \
    __m128 xmm_p = _mm_cvtph_ps(_mm_broadcast_si32(q)); \
    __m512 zmm_q_0 = _mm512_set1_ps(xmm_p[0]); \
    __m512 zmm_q_1 = _mm512_set1_ps(xmm_p[1]); \
    MATRIX_VAR_PROC(1, 2, 0, zmm_m, zmm_q, _RES, _PROC) \
  }

//! Iterative process of computing distance (FP16, M=16, N=4)
//!
//! Reads 4 FP16 values from `q` with a 64-bit load. Each q[k] is broadcast
//! and paired with the 16 m values into _RES##_0_k (k = 0..3).
#define MATRIX_FP16_ITER_16X4_AVX512(m, q, _RES, _LOAD, _PROC) \
  { \
    __m512 zmm_m = _mm512_cvtph_ps(_LOAD((const __m256i *)(m))); \
    __m128 xmm_p = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)(q))); \
    __m512 zmm_q = _mm512_set1_ps(xmm_p[0]); \
    _PROC(zmm_m, zmm_q, _RES##_0_0) \
    zmm_q = _mm512_set1_ps(xmm_p[1]); \
    _PROC(zmm_m, zmm_q, _RES##_0_1) \
    zmm_q = _mm512_set1_ps(xmm_p[2]); \
    _PROC(zmm_m, zmm_q, _RES##_0_2) \
    zmm_q = _mm512_set1_ps(xmm_p[3]); \
    _PROC(zmm_m, zmm_q, _RES##_0_3) \
  }
//!
Iterative process of computing distance (FP16, M=16, N=8)
//!
//! AVX-512 form. Loads 16 FP16 values from `m` (256-bit _LOAD) and 8 FP16
//! values from `q` (unaligned 128-bit load). Each q[k] is broadcast into
//! _RES##_0_k (k = 0..7).
#define MATRIX_FP16_ITER_16X8_AVX512(m, q, _RES, _LOAD, _PROC) \
  { \
    __m512 zmm_m = _mm512_cvtph_ps(_LOAD((const __m256i *)(m))); \
    __m256 ymm_p = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(q))); \
    __m512 zmm_q = _mm512_set1_ps(ymm_p[0]); \
    _PROC(zmm_m, zmm_q, _RES##_0_0) \
    zmm_q = _mm512_set1_ps(ymm_p[1]); \
    _PROC(zmm_m, zmm_q, _RES##_0_1) \
    zmm_q = _mm512_set1_ps(ymm_p[2]); \
    _PROC(zmm_m, zmm_q, _RES##_0_2) \
    zmm_q = _mm512_set1_ps(ymm_p[3]); \
    _PROC(zmm_m, zmm_q, _RES##_0_3) \
    zmm_q = _mm512_set1_ps(ymm_p[4]); \
    _PROC(zmm_m, zmm_q, _RES##_0_4) \
    zmm_q = _mm512_set1_ps(ymm_p[5]); \
    _PROC(zmm_m, zmm_q, _RES##_0_5) \
    zmm_q = _mm512_set1_ps(ymm_p[6]); \
    _PROC(zmm_m, zmm_q, _RES##_0_6) \
    zmm_q = _mm512_set1_ps(ymm_p[7]); \
    _PROC(zmm_m, zmm_q, _RES##_0_7) \
  }

//! Iterative process of computing distance (FP16, M=16, N=16)
//!
//! Loads 16 FP16 values from each of `m` and `q` through _LOAD (256 bits
//! each). Each q[k] is broadcast into _RES##_0_k (k = 0..15). `zmm_p[k]`
//! relies on the GCC/Clang vector-subscript extension.
#define MATRIX_FP16_ITER_16X16_AVX512(m, q, _RES, _LOAD, _PROC) \
  { \
    __m512 zmm_m = _mm512_cvtph_ps(_LOAD((const __m256i *)(m))); \
    __m512 zmm_p = _mm512_cvtph_ps(_LOAD((const __m256i *)(q))); \
    __m512 zmm_q = _mm512_set1_ps(zmm_p[0]); \
    _PROC(zmm_m, zmm_q, _RES##_0_0) \
    zmm_q = _mm512_set1_ps(zmm_p[1]); \
    _PROC(zmm_m, zmm_q, _RES##_0_1) \
    zmm_q = _mm512_set1_ps(zmm_p[2]); \
    _PROC(zmm_m, zmm_q, _RES##_0_2) \
    zmm_q = _mm512_set1_ps(zmm_p[3]); \
    _PROC(zmm_m, zmm_q, _RES##_0_3) \
    zmm_q = _mm512_set1_ps(zmm_p[4]); \
    _PROC(zmm_m, zmm_q, _RES##_0_4) \
    zmm_q = _mm512_set1_ps(zmm_p[5]); \
    _PROC(zmm_m, zmm_q, _RES##_0_5) \
    zmm_q = _mm512_set1_ps(zmm_p[6]); \
    _PROC(zmm_m, zmm_q, _RES##_0_6) \
    zmm_q = _mm512_set1_ps(zmm_p[7]); \
    _PROC(zmm_m, zmm_q, _RES##_0_7) \
    zmm_q = _mm512_set1_ps(zmm_p[8]); \
    _PROC(zmm_m, zmm_q, _RES##_0_8) \
    zmm_q = _mm512_set1_ps(zmm_p[9]); \
    _PROC(zmm_m, zmm_q, _RES##_0_9) \
    zmm_q = _mm512_set1_ps(zmm_p[10]); \
    _PROC(zmm_m, zmm_q, _RES##_0_10) \
    zmm_q = _mm512_set1_ps(zmm_p[11]); \
    _PROC(zmm_m, zmm_q, _RES##_0_11) \
    zmm_q = _mm512_set1_ps(zmm_p[12]); \
    _PROC(zmm_m, zmm_q, _RES##_0_12) \
    zmm_q = _mm512_set1_ps(zmm_p[13]); \
    _PROC(zmm_m, zmm_q, _RES##_0_13) \
    zmm_q = _mm512_set1_ps(zmm_p[14]); \
    _PROC(zmm_m, zmm_q, _RES##_0_14) \
    zmm_q = _mm512_set1_ps(zmm_p[15]); \
    _PROC(zmm_m, zmm_q, _RES##_0_15) \
  }

//! Iterative process of computing distance (FP16, M=32, N=1)
//!
//! Loads 32 FP16 values from `m` (512-bit _LOAD), split into zmm_m_0 (low
//! 16) and zmm_m_1 (high 16). One q value is broadcast for
//! MATRIX_VAR_PROC(2, 1, 0, ...).
#define MATRIX_FP16_ITER_32X1_AVX512(m, q, _RES, _LOAD, _PROC) \
  { \
    __m512i zmm_mi = _LOAD((const __m512i *)(m)); \
    __m512 zmm_m_0 = _mm512_cvtph_ps(_mm512_castsi512_si256(zmm_mi)); \
    __m512 zmm_m_1 = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(zmm_mi, 1)); \
    __m512 zmm_q = _mm512_cvtph_ps(_mm256_set1_epi16(*(const short *)q)); \
    MATRIX_VAR_PROC(2, 1, 0, zmm_m, zmm_q, _RES, _PROC) \
  }

//! Iterative process of computing distance (FP16, M=32, N=2)
//!
//! Same m handling as the N=1 variant. Reads 2 FP16 values from `q`; for
//! k = 0, 1, q[k] is broadcast and passed to MATRIX_VAR_PROC(2, 1, k, ...).
#define MATRIX_FP16_ITER_32X2_AVX512(m, q, _RES, _LOAD, _PROC) \
  { \
    __m512i zmm_mi = _LOAD((const __m512i *)(m)); \
    __m512 zmm_m_0 = _mm512_cvtph_ps(_mm512_castsi512_si256(zmm_mi)); \
    __m512 zmm_m_1 = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(zmm_mi, 1)); \
    __m128 xmm_p = _mm_cvtph_ps(_mm_broadcast_si32(q)); \
    __m512 zmm_q = _mm512_set1_ps(xmm_p[0]); \
    MATRIX_VAR_PROC(2, 1, 0, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(xmm_p[1]); \
    MATRIX_VAR_PROC(2, 1, 1, zmm_m, zmm_q, _RES, _PROC) \
  }
//!
Iterative process of computing distance (FP16, M=32, N=4)
//!
//! AVX-512 form. Loads 32 FP16 values from `m` (512-bit _LOAD), split into
//! zmm_m_0 / zmm_m_1. Reads 4 FP16 values from `q` with a 64-bit load; for
//! k = 0..3, q[k] is broadcast and passed to MATRIX_VAR_PROC(2, 1, k, ...).
#define MATRIX_FP16_ITER_32X4_AVX512(m, q, _RES, _LOAD, _PROC) \
  { \
    __m512i zmm_mi = _LOAD((const __m512i *)(m)); \
    __m512 zmm_m_0 = _mm512_cvtph_ps(_mm512_castsi512_si256(zmm_mi)); \
    __m512 zmm_m_1 = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(zmm_mi, 1)); \
    __m128 xmm_p = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)(q))); \
    __m512 zmm_q = _mm512_set1_ps(xmm_p[0]); \
    MATRIX_VAR_PROC(2, 1, 0, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(xmm_p[1]); \
    MATRIX_VAR_PROC(2, 1, 1, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(xmm_p[2]); \
    MATRIX_VAR_PROC(2, 1, 2, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(xmm_p[3]); \
    MATRIX_VAR_PROC(2, 1, 3, zmm_m, zmm_q, _RES, _PROC) \
  }

//! Iterative process of computing distance (FP16, M=32, N=8)
//!
//! Same m handling as the N=4 variant. Reads 8 FP16 values from `q` with an
//! unaligned 128-bit load (k = 0..7).
#define MATRIX_FP16_ITER_32X8_AVX512(m, q, _RES, _LOAD, _PROC) \
  { \
    __m512i zmm_mi = _LOAD((const __m512i *)(m)); \
    __m512 zmm_m_0 = _mm512_cvtph_ps(_mm512_castsi512_si256(zmm_mi)); \
    __m512 zmm_m_1 = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(zmm_mi, 1)); \
    __m256 ymm_p = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(q))); \
    __m512 zmm_q = _mm512_set1_ps(ymm_p[0]); \
    MATRIX_VAR_PROC(2, 1, 0, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(ymm_p[1]); \
    MATRIX_VAR_PROC(2, 1, 1, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(ymm_p[2]); \
    MATRIX_VAR_PROC(2, 1, 2, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(ymm_p[3]); \
    MATRIX_VAR_PROC(2, 1, 3, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(ymm_p[4]); \
    MATRIX_VAR_PROC(2, 1, 4, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(ymm_p[5]); \
    MATRIX_VAR_PROC(2, 1, 5, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(ymm_p[6]); \
    MATRIX_VAR_PROC(2, 1, 6, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(ymm_p[7]); \
    MATRIX_VAR_PROC(2, 1, 7, zmm_m, zmm_q, _RES, _PROC) \
  }
//!
Iterative process of computing distance (FP16, M=32, N=16)
//!
//! AVX-512 form. Loads 32 FP16 values from `m` (512-bit _LOAD), split into
//! zmm_m_0 / zmm_m_1. Reads 16 FP16 values from `q` with an unaligned
//! 256-bit load; for k = 0..15, q[k] is broadcast and passed to
//! MATRIX_VAR_PROC(2, 1, k, ...).
#define MATRIX_FP16_ITER_32X16_AVX512(m, q, _RES, _LOAD, _PROC) \
  { \
    __m512i zmm_mi = _LOAD((const __m512i *)(m)); \
    __m512 zmm_m_0 = _mm512_cvtph_ps(_mm512_castsi512_si256(zmm_mi)); \
    __m512 zmm_m_1 = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(zmm_mi, 1)); \
    __m512 zmm_p = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(q))); \
    __m512 zmm_q = _mm512_set1_ps(zmm_p[0]); \
    MATRIX_VAR_PROC(2, 1, 0, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(zmm_p[1]); \
    MATRIX_VAR_PROC(2, 1, 1, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(zmm_p[2]); \
    MATRIX_VAR_PROC(2, 1, 2, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(zmm_p[3]); \
    MATRIX_VAR_PROC(2, 1, 3, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(zmm_p[4]); \
    MATRIX_VAR_PROC(2, 1, 4, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(zmm_p[5]); \
    MATRIX_VAR_PROC(2, 1, 5, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(zmm_p[6]); \
    MATRIX_VAR_PROC(2, 1, 6, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(zmm_p[7]); \
    MATRIX_VAR_PROC(2, 1, 7, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(zmm_p[8]); \
    MATRIX_VAR_PROC(2, 1, 8, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(zmm_p[9]); \
    MATRIX_VAR_PROC(2, 1, 9, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(zmm_p[10]); \
    MATRIX_VAR_PROC(2, 1, 10, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(zmm_p[11]); \
    MATRIX_VAR_PROC(2, 1, 11, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(zmm_p[12]); \
    MATRIX_VAR_PROC(2, 1, 12, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(zmm_p[13]); \
    MATRIX_VAR_PROC(2, 1, 13, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(zmm_p[14]); \
    MATRIX_VAR_PROC(2, 1, 14, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(zmm_p[15]); \
    MATRIX_VAR_PROC(2, 1, 15, zmm_m, zmm_q, _RES, _PROC) \
  }
//!
Iterative process of computing distance (FP16, M=32, N=32)
//!
//! Same as the N=16 variant, with two unaligned 256-bit q loads at `q` and
//! `q + 16` (k = 0..31). `q + 16` presumably advances by 16 FP16 elements --
//! confirm q's pointee type at call sites.
#define MATRIX_FP16_ITER_32X32_AVX512(m, q, _RES, _LOAD, _PROC) \
  { \
    __m512i zmm_mi = _LOAD((const __m512i *)(m)); \
    __m512 zmm_m_0 = _mm512_cvtph_ps(_mm512_castsi512_si256(zmm_mi)); \
    __m512 zmm_m_1 = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(zmm_mi, 1)); \
    __m512 zmm_p = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(q))); \
    __m512 zmm_q = _mm512_set1_ps(zmm_p[0]); \
    MATRIX_VAR_PROC(2, 1, 0, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(zmm_p[1]); \
    MATRIX_VAR_PROC(2, 1, 1, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(zmm_p[2]); \
    MATRIX_VAR_PROC(2, 1, 2, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(zmm_p[3]); \
    MATRIX_VAR_PROC(2, 1, 3, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(zmm_p[4]); \
    MATRIX_VAR_PROC(2, 1, 4, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(zmm_p[5]); \
    MATRIX_VAR_PROC(2, 1, 5, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(zmm_p[6]); \
    MATRIX_VAR_PROC(2, 1, 6, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(zmm_p[7]); \
    MATRIX_VAR_PROC(2, 1, 7, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(zmm_p[8]); \
    MATRIX_VAR_PROC(2, 1, 8, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(zmm_p[9]); \
    MATRIX_VAR_PROC(2, 1, 9, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(zmm_p[10]); \
    MATRIX_VAR_PROC(2, 1, 10, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(zmm_p[11]); \
    MATRIX_VAR_PROC(2, 1, 11, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(zmm_p[12]); \
    MATRIX_VAR_PROC(2, 1, 12, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(zmm_p[13]); \
    MATRIX_VAR_PROC(2, 1, 13, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(zmm_p[14]); \
    MATRIX_VAR_PROC(2, 1, 14, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(zmm_p[15]); \
    MATRIX_VAR_PROC(2, 1, 15, zmm_m, zmm_q, _RES, _PROC) \
    zmm_p = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(q + 16))); \
    zmm_q = _mm512_set1_ps(zmm_p[0]); \
    MATRIX_VAR_PROC(2, 1, 16, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(zmm_p[1]); \
    MATRIX_VAR_PROC(2, 1, 17, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(zmm_p[2]); \
    MATRIX_VAR_PROC(2, 1, 18, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(zmm_p[3]); \
    MATRIX_VAR_PROC(2, 1, 19, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(zmm_p[4]); \
    MATRIX_VAR_PROC(2, 1, 20, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(zmm_p[5]); \
    MATRIX_VAR_PROC(2, 1, 21, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(zmm_p[6]); \
    MATRIX_VAR_PROC(2, 1, 22, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(zmm_p[7]); \
    MATRIX_VAR_PROC(2, 1, 23, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(zmm_p[8]); \
    MATRIX_VAR_PROC(2, 1, 24, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(zmm_p[9]); \
    MATRIX_VAR_PROC(2, 1, 25, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(zmm_p[10]); \
    MATRIX_VAR_PROC(2, 1, 26, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(zmm_p[11]); \
    MATRIX_VAR_PROC(2, 1, 27, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(zmm_p[12]); \
    MATRIX_VAR_PROC(2, 1, 28, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(zmm_p[13]); \
    MATRIX_VAR_PROC(2, 1, 29, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(zmm_p[14]); \
    MATRIX_VAR_PROC(2, 1, 30, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(zmm_p[15]); \
    MATRIX_VAR_PROC(2, 1, 31, zmm_m, zmm_q, _RES, _PROC) \
  }

// NEON M=1, N=1 step. With FP16 vector arithmetic available, the FP16
// vectors are handed to _PROC directly without widening to FP32
// (NOTE(review): accumulation precision then depends on _PROC -- confirm).
#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
//! Iterative process of computing distance (FP16, M=1, N=1)
//!
//! Loads 8 FP16 values from each of `m` and `q` and applies _PROC once into
//! _RES##_0_0.
#define MATRIX_FP16_ITER_1X1_NEON(m, q, _RES, _PROC) \
  { \
    float16x8_t v_m = vld1q_f16((const float16_t *)m); \
    float16x8_t v_q = vld1q_f16((const float16_t *)q); \
    _PROC(v_m, v_q, _RES##_0_0) \
  }
#else
//!
Iterative process of computing distance (FP16, M=1, N=1) #define MATRIX_FP16_ITER_1X1_NEON(m, q, _RES, _PROC) \ { \ float16x8_t v_m = vld1q_f16((const float16_t *)m); \ float16x8_t v_q = vld1q_f16((const float16_t *)q); \ float32x4_t v_m_0 = vcvt_f32_f16(vget_low_f16(v_m)); \ float32x4_t v_q_0 = vcvt_f32_f16(vget_low_f16(v_q)); \ _PROC(v_m_0, v_q_0, _RES##_0_0) \ v_m_0 = vcvt_high_f32_f16(v_m); \ v_q_0 = vcvt_high_f32_f16(v_q); \ _PROC(v_m_0, v_q_0, _RES##_0_0) \ } #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC ================================================ FILE: src/ailego/math/distance_matrix_fp32.i ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include "matrix_define.i" #if !defined(__AVX__) #undef _mm_permute_ps #define _mm_permute_ps(a, b) _mm_shuffle_ps((a), (a), (b)) #define _mm_broadcast_ss(a) _mm_load1_ps(a) #endif // !__AVX__ #if defined(__AVX__) && defined(__GNUC__) #define _mm256_set_m128(a, b) \ _mm256_insertf128_ps(_mm256_castps128_ps256(b), (a), 1) #endif // __AVX__ #if defined(__ARM_NEON) && !defined(__aarch64__) #define vdupq_laneq_f32(a, b) vdupq_n_f32(vgetq_lane_f32(a, b)) #endif // __ARM_NEON && __aarch64__ //! 
//! Iterative process of computing distance (FP32, M=2, N=1)
//!
//! Shared conventions for the SSE macros in this group: `m` and `q` point at
//! the two float operands, `_LOAD` loads one 4-float `__m128`, `_PROC(lhs,
//! rhs, acc)` is invoked with two vectors and an accumulator name, and `_RES`
//! is the accumulator prefix (`_RES##_i_j`).
//!
//! Loads 8 floats from `m` (two vectors) and 4 from `q`. The lanes of `q` are
//! duplicated pairwise ({q0,q0,q1,q1}, then {q2,q2,q3,q3}) to line up with
//! `xmm_m_0` and `xmm_m_1`. The two halves go to separate accumulators
//! `_RES##_0_0` and `_RES##_0_1`.
//! NOTE(review): presumably the caller sums those two accumulators — confirm.
#define MATRIX_FP32_ITER_2X1_SSE(m, q, _RES, _LOAD, _PROC) \
  { \
    __m128 xmm_m_0 = _LOAD(m + 0); \
    __m128 xmm_m_1 = _LOAD(m + 4); \
    __m128 xmm_q = _LOAD(q); \
    __m128 xmm_p = _mm_permute_ps(xmm_q, _MM_SHUFFLE(1, 1, 0, 0)); \
    _PROC(xmm_m_0, xmm_p, _RES##_0_0) \
    xmm_p = _mm_permute_ps(xmm_q, _MM_SHUFFLE(3, 3, 2, 2)); \
    _PROC(xmm_m_1, xmm_p, _RES##_0_1) \
  }

//! Iterative process of computing distance (FP32, M=2, N=2)
//!
//! Loads 4 floats each from `m` and `q`. The even lanes of `q`
//! ({q0,q0,q2,q2}) feed `_RES##_0_0`; the odd lanes ({q1,q1,q3,q3}) feed
//! `_RES##_0_1`.
#define MATRIX_FP32_ITER_2X2_SSE(m, q, _RES, _LOAD, _PROC) \
  { \
    __m128 xmm_q = _LOAD(q); \
    __m128 xmm_m = _LOAD(m); \
    __m128 xmm_p = _mm_permute_ps(xmm_q, _MM_SHUFFLE(2, 2, 0, 0)); \
    _PROC(xmm_m, xmm_p, _RES##_0_0) \
    xmm_p = _mm_permute_ps(xmm_q, _MM_SHUFFLE(3, 3, 1, 1)); \
    _PROC(xmm_m, xmm_p, _RES##_0_1) \
  }

//! Iterative process of computing distance (FP32, M=4, N=1)
//!
//! Loads 8 floats from `m` (two vectors). `q[0]` is broadcast against
//! `xmm_m_0` into `_RES##_0_0`, and `q[1]` against `xmm_m_1` into
//! `_RES##_0_1` (two partial accumulators, as in the 2x1 case).
#define MATRIX_FP32_ITER_4X1_SSE(m, q, _RES, _LOAD, _PROC) \
  { \
    __m128 xmm_m_0 = _LOAD(m + 0); \
    __m128 xmm_m_1 = _LOAD(m + 4); \
    __m128 xmm_q = _mm_broadcast_ss(q + 0); \
    _PROC(xmm_m_0, xmm_q, _RES##_0_0) \
    xmm_q = _mm_broadcast_ss(q + 1); \
    _PROC(xmm_m_1, xmm_q, _RES##_0_1) \
  }

//! Iterative process of computing distance (FP32, M=4, N=2)
//!
//! One 4-float vector from `m`; `q[0]` and `q[1]` are each broadcast against
//! it, into `_RES##_0_0` and `_RES##_0_1`.
#define MATRIX_FP32_ITER_4X2_SSE(m, q, _RES, _LOAD, _PROC) \
  { \
    __m128 xmm_m = _LOAD(m); \
    __m128 xmm_q = _mm_broadcast_ss(q + 0); \
    _PROC(xmm_m, xmm_q, _RES##_0_0) \
    xmm_q = _mm_broadcast_ss(q + 1); \
    _PROC(xmm_m, xmm_q, _RES##_0_1) \
  }

//! Iterative process of computing distance (FP32, M=4, N=4)
//!
//! One 4-float vector from `m`; `q[j]` (j = 0..3) is broadcast against it
//! into `_RES##_0_j`.
#define MATRIX_FP32_ITER_4X4_SSE(m, q, _RES, _LOAD, _PROC) \
  { \
    __m128 xmm_m = _LOAD(m); \
    __m128 xmm_q = _mm_broadcast_ss(q + 0); \
    _PROC(xmm_m, xmm_q, _RES##_0_0) \
    xmm_q = _mm_broadcast_ss(q + 1); \
    _PROC(xmm_m, xmm_q, _RES##_0_1) \
    xmm_q = _mm_broadcast_ss(q + 2); \
    _PROC(xmm_m, xmm_q, _RES##_0_2) \
    xmm_q = _mm_broadcast_ss(q + 3); \
    _PROC(xmm_m, xmm_q, _RES##_0_3) \
  }

//!
//! Iterative process of computing distance (FP32, M=8, N=1)
//!
//! Shared conventions for the SSE macros in this group: `m` and `q` point at
//! the two float operands, `_LOAD` loads one 4-float `__m128`, `_PROC(lhs,
//! rhs, acc)` is invoked with two vectors and an accumulator name, and `_RES`
//! is the accumulator prefix (`_RES##_i_j`). `MATRIX_VAR_PROC` is defined
//! outside this section (presumably matrix_define.i) and receives the
//! `xmm_m` prefix, so the local names `xmm_m_0`, `xmm_m_1`, ... are part of
//! its contract.
//!
//! Two 4-float vectors from `m`; `q[0]` is broadcast against both, into
//! `_RES##_0_0` and `_RES##_1_0`.
#define MATRIX_FP32_ITER_8X1_SSE(m, q, _RES, _LOAD, _PROC) \
  { \
    __m128 xmm_m_0 = _LOAD(m + 0); \
    __m128 xmm_m_1 = _LOAD(m + 4); \
    __m128 xmm_q = _mm_broadcast_ss(q); \
    _PROC(xmm_m_0, xmm_q, _RES##_0_0) \
    _PROC(xmm_m_1, xmm_q, _RES##_1_0) \
  }

//! Iterative process of computing distance (FP32, M=8, N=2)
//!
//! Two 4-float vectors from `m`; for j = 0..1, `q[j]` is broadcast and passed
//! with them to `MATRIX_VAR_PROC(2, 1, j, ...)`.
#define MATRIX_FP32_ITER_8X2_SSE(m, q, _RES, _LOAD, _PROC) \
  { \
    __m128 xmm_m_0 = _LOAD(m + 0); \
    __m128 xmm_m_1 = _LOAD(m + 4); \
    __m128 xmm_q = _mm_broadcast_ss(q + 0); \
    MATRIX_VAR_PROC(2, 1, 0, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 1); \
    MATRIX_VAR_PROC(2, 1, 1, xmm_m, xmm_q, _RES, _PROC) \
  }

//! Iterative process of computing distance (FP32, M=8, N=4)
//!
//! Two 4-float vectors from `m`; for j = 0..3, `q[j]` is broadcast and passed
//! with them to `MATRIX_VAR_PROC(2, 1, j, ...)`.
#define MATRIX_FP32_ITER_8X4_SSE(m, q, _RES, _LOAD, _PROC) \
  { \
    __m128 xmm_m_0 = _LOAD(m + 0); \
    __m128 xmm_m_1 = _LOAD(m + 4); \
    __m128 xmm_q = _mm_broadcast_ss(q + 0); \
    MATRIX_VAR_PROC(2, 1, 0, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 1); \
    MATRIX_VAR_PROC(2, 1, 1, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 2); \
    MATRIX_VAR_PROC(2, 1, 2, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 3); \
    MATRIX_VAR_PROC(2, 1, 3, xmm_m, xmm_q, _RES, _PROC) \
  }

//!
//! Iterative process of computing distance (FP32, M=8, N=8)
//!
//! Shared conventions for the SSE macros in this group: `m` and `q` point at
//! the two float operands, `_LOAD` loads one 4-float `__m128`, `_RES` is the
//! accumulator prefix and `_PROC` the per-pair operation; both are forwarded
//! to `MATRIX_VAR_PROC` (defined outside this section, presumably in
//! matrix_define.i), which receives the `xmm_m` prefix of the loaded vectors.
//!
//! Two 4-float vectors from `m`; for j = 0..7, `q[j]` is broadcast and passed
//! with them to `MATRIX_VAR_PROC(2, 1, j, ...)`.
#define MATRIX_FP32_ITER_8X8_SSE(m, q, _RES, _LOAD, _PROC) \
  { \
    __m128 xmm_m_0 = _LOAD(m + 0); \
    __m128 xmm_m_1 = _LOAD(m + 4); \
    __m128 xmm_q = _mm_broadcast_ss(q); \
    MATRIX_VAR_PROC(2, 1, 0, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 1); \
    MATRIX_VAR_PROC(2, 1, 1, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 2); \
    MATRIX_VAR_PROC(2, 1, 2, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 3); \
    MATRIX_VAR_PROC(2, 1, 3, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 4); \
    MATRIX_VAR_PROC(2, 1, 4, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 5); \
    MATRIX_VAR_PROC(2, 1, 5, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 6); \
    MATRIX_VAR_PROC(2, 1, 6, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 7); \
    MATRIX_VAR_PROC(2, 1, 7, xmm_m, xmm_q, _RES, _PROC) \
  }

//! Iterative process of computing distance (FP32, M=16, N=1)
//!
//! Four 4-float vectors from `m`; `q[0]` is broadcast and passed with them to
//! `MATRIX_VAR_PROC(4, 1, 0, ...)`.
#define MATRIX_FP32_ITER_16X1_SSE(m, q, _RES, _LOAD, _PROC) \
  { \
    __m128 xmm_m_0 = _LOAD(m + 0); \
    __m128 xmm_m_1 = _LOAD(m + 4); \
    __m128 xmm_m_2 = _LOAD(m + 8); \
    __m128 xmm_m_3 = _LOAD(m + 12); \
    __m128 xmm_q = _mm_broadcast_ss(q); \
    MATRIX_VAR_PROC(4, 1, 0, xmm_m, xmm_q, _RES, _PROC) \
  }

//! Iterative process of computing distance (FP32, M=16, N=2)
//!
//! Four 4-float vectors from `m`; for j = 0..1, `q[j]` is broadcast and
//! passed with them to `MATRIX_VAR_PROC(4, 1, j, ...)`.
#define MATRIX_FP32_ITER_16X2_SSE(m, q, _RES, _LOAD, _PROC) \
  { \
    __m128 xmm_m_0 = _LOAD(m + 0); \
    __m128 xmm_m_1 = _LOAD(m + 4); \
    __m128 xmm_m_2 = _LOAD(m + 8); \
    __m128 xmm_m_3 = _LOAD(m + 12); \
    __m128 xmm_q = _mm_broadcast_ss(q + 0); \
    MATRIX_VAR_PROC(4, 1, 0, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 1); \
    MATRIX_VAR_PROC(4, 1, 1, xmm_m, xmm_q, _RES, _PROC) \
  }

//!
//! Iterative process of computing distance (FP32, M=16, N=4)
//!
//! Shared conventions for the SSE macros in this group: `m` and `q` point at
//! the two float operands, `_LOAD` loads one 4-float `__m128`, `_RES` is the
//! accumulator prefix and `_PROC` the per-pair operation; both are forwarded
//! to `MATRIX_VAR_PROC` (defined outside this section, presumably in
//! matrix_define.i), which receives the `xmm_m` prefix of the loaded vectors.
//!
//! Four 4-float vectors from `m`; for j = 0..3, `q[j]` is broadcast and
//! passed with them to `MATRIX_VAR_PROC(4, 1, j, ...)`.
#define MATRIX_FP32_ITER_16X4_SSE(m, q, _RES, _LOAD, _PROC) \
  { \
    __m128 xmm_m_0 = _LOAD(m + 0); \
    __m128 xmm_m_1 = _LOAD(m + 4); \
    __m128 xmm_m_2 = _LOAD(m + 8); \
    __m128 xmm_m_3 = _LOAD(m + 12); \
    __m128 xmm_q = _mm_broadcast_ss(q + 0); \
    MATRIX_VAR_PROC(4, 1, 0, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 1); \
    MATRIX_VAR_PROC(4, 1, 1, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 2); \
    MATRIX_VAR_PROC(4, 1, 2, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 3); \
    MATRIX_VAR_PROC(4, 1, 3, xmm_m, xmm_q, _RES, _PROC) \
  }

//! Iterative process of computing distance (FP32, M=16, N=8)
//!
//! Four 4-float vectors from `m`; for j = 0..7, `q[j]` is broadcast and
//! passed with them to `MATRIX_VAR_PROC(4, 1, j, ...)`.
#define MATRIX_FP32_ITER_16X8_SSE(m, q, _RES, _LOAD, _PROC) \
  { \
    __m128 xmm_m_0 = _LOAD(m + 0); \
    __m128 xmm_m_1 = _LOAD(m + 4); \
    __m128 xmm_m_2 = _LOAD(m + 8); \
    __m128 xmm_m_3 = _LOAD(m + 12); \
    __m128 xmm_q = _mm_broadcast_ss(q); \
    MATRIX_VAR_PROC(4, 1, 0, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 1); \
    MATRIX_VAR_PROC(4, 1, 1, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 2); \
    MATRIX_VAR_PROC(4, 1, 2, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 3); \
    MATRIX_VAR_PROC(4, 1, 3, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 4); \
    MATRIX_VAR_PROC(4, 1, 4, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 5); \
    MATRIX_VAR_PROC(4, 1, 5, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 6); \
    MATRIX_VAR_PROC(4, 1, 6, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 7); \
    MATRIX_VAR_PROC(4, 1, 7, xmm_m, xmm_q, _RES, _PROC) \
  }

//!
//! Iterative process of computing distance (FP32, M=16, N=16)
//!
//! `m` and `q` point at the two float operands; `_LOAD` loads one 4-float
//! `__m128`. Four vectors (16 floats) are loaded from `m`; for j = 0..15,
//! `q[j]` is broadcast and passed with them (by the `xmm_m` prefix) to
//! `MATRIX_VAR_PROC(4, 1, j, ...)`, together with the accumulator prefix
//! `_RES` and the per-pair operation `_PROC`. `MATRIX_VAR_PROC` is defined
//! outside this section (presumably matrix_define.i).
#define MATRIX_FP32_ITER_16X16_SSE(m, q, _RES, _LOAD, _PROC) \
  { \
    __m128 xmm_m_0 = _LOAD(m + 0); \
    __m128 xmm_m_1 = _LOAD(m + 4); \
    __m128 xmm_m_2 = _LOAD(m + 8); \
    __m128 xmm_m_3 = _LOAD(m + 12); \
    __m128 xmm_q = _mm_broadcast_ss(q); \
    MATRIX_VAR_PROC(4, 1, 0, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 1); \
    MATRIX_VAR_PROC(4, 1, 1, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 2); \
    MATRIX_VAR_PROC(4, 1, 2, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 3); \
    MATRIX_VAR_PROC(4, 1, 3, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 4); \
    MATRIX_VAR_PROC(4, 1, 4, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 5); \
    MATRIX_VAR_PROC(4, 1, 5, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 6); \
    MATRIX_VAR_PROC(4, 1, 6, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 7); \
    MATRIX_VAR_PROC(4, 1, 7, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 8); \
    MATRIX_VAR_PROC(4, 1, 8, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 9); \
    MATRIX_VAR_PROC(4, 1, 9, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 10); \
    MATRIX_VAR_PROC(4, 1, 10, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 11); \
    MATRIX_VAR_PROC(4, 1, 11, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 12); \
    MATRIX_VAR_PROC(4, 1, 12, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 13); \
    MATRIX_VAR_PROC(4, 1, 13, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 14); \
    MATRIX_VAR_PROC(4, 1, 14, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 15); \
    MATRIX_VAR_PROC(4, 1, 15, xmm_m, xmm_q, _RES, _PROC) \
  }

//!
//! Iterative process of computing distance (FP32, M=32, N=1)
//!
//! Shared conventions for the SSE macros in this group: `m` and `q` point at
//! the two float operands, `_LOAD` loads one 4-float `__m128`, `_PROC(lhs,
//! rhs, acc)` is invoked with two vectors and an accumulator name, and `_RES`
//! is the accumulator prefix (`_RES##_i_j`).
//!
//! Only `q[0]` is used: it is broadcast once and applied to the 8 consecutive
//! 4-float vectors of `m`, loaded in two batches of four (reusing the same
//! four registers), with results going to `_RES##_0_0` .. `_RES##_7_0`.
#define MATRIX_FP32_ITER_32X1_SSE(m, q, _RES, _LOAD, _PROC) \
  { \
    __m128 xmm_q = _mm_broadcast_ss(q); \
    __m128 xmm_m_0 = _LOAD(m + 0); \
    __m128 xmm_m_1 = _LOAD(m + 4); \
    __m128 xmm_m_2 = _LOAD(m + 8); \
    __m128 xmm_m_3 = _LOAD(m + 12); \
    _PROC(xmm_m_0, xmm_q, _RES##_0_0) \
    _PROC(xmm_m_1, xmm_q, _RES##_1_0) \
    _PROC(xmm_m_2, xmm_q, _RES##_2_0) \
    _PROC(xmm_m_3, xmm_q, _RES##_3_0) \
    xmm_m_0 = _LOAD(m + 16); \
    xmm_m_1 = _LOAD(m + 20); \
    xmm_m_2 = _LOAD(m + 24); \
    xmm_m_3 = _LOAD(m + 28); \
    _PROC(xmm_m_0, xmm_q, _RES##_4_0) \
    _PROC(xmm_m_1, xmm_q, _RES##_5_0) \
    _PROC(xmm_m_2, xmm_q, _RES##_6_0) \
    _PROC(xmm_m_3, xmm_q, _RES##_7_0) \
  }

//! Iterative process of computing distance (FP32, M=32, N=2)
//!
//! `q[0]` and `q[1]` are broadcast up front into `xmm_q_0`/`xmm_q_1`. The 32
//! floats of `m` are processed as 8 vectors in two batches of four; vector i
//! is passed to `MATRIX_VAR_PROC(1, 2, i, ...)` with the `xmm_q` prefix.
//! `MATRIX_VAR_PROC` is defined outside this section (presumably
//! matrix_define.i).
#define MATRIX_FP32_ITER_32X2_SSE(m, q, _RES, _LOAD, _PROC) \
  { \
    __m128 xmm_q_0 = _mm_broadcast_ss(q + 0); \
    __m128 xmm_q_1 = _mm_broadcast_ss(q + 1); \
    __m128 xmm_m_0 = _LOAD(m + 0); \
    __m128 xmm_m_1 = _LOAD(m + 4); \
    __m128 xmm_m_2 = _LOAD(m + 8); \
    __m128 xmm_m_3 = _LOAD(m + 12); \
    MATRIX_VAR_PROC(1, 2, 0, xmm_m_0, xmm_q, _RES, _PROC) \
    MATRIX_VAR_PROC(1, 2, 1, xmm_m_1, xmm_q, _RES, _PROC) \
    MATRIX_VAR_PROC(1, 2, 2, xmm_m_2, xmm_q, _RES, _PROC) \
    MATRIX_VAR_PROC(1, 2, 3, xmm_m_3, xmm_q, _RES, _PROC) \
    xmm_m_0 = _LOAD(m + 16); \
    xmm_m_1 = _LOAD(m + 20); \
    xmm_m_2 = _LOAD(m + 24); \
    xmm_m_3 = _LOAD(m + 28); \
    MATRIX_VAR_PROC(1, 2, 4, xmm_m_0, xmm_q, _RES, _PROC) \
    MATRIX_VAR_PROC(1, 2, 5, xmm_m_1, xmm_q, _RES, _PROC) \
    MATRIX_VAR_PROC(1, 2, 6, xmm_m_2, xmm_q, _RES, _PROC) \
    MATRIX_VAR_PROC(1, 2, 7, xmm_m_3, xmm_q, _RES, _PROC) \
  }

//!
//! Iterative process of computing distance (FP32, M=32, N=4)
//!
//! Shared conventions for the macros in this group: `m` and `q` point at the
//! two float operands, `_LOAD` loads one vector (`__m128` for SSE, `__m256`
//! for AVX), `_PROC(lhs, rhs, acc)` is invoked with two vectors and an
//! accumulator name, and `_RES` is the accumulator prefix (`_RES##_i_j`).
//! `MATRIX_VAR_PROC` is defined outside this section (presumably
//! matrix_define.i) and receives the `xmm_m`/`xmm_q` prefixes, so those local
//! names are part of its contract.
//!
//! `q[0]`..`q[3]` are broadcast up front into `xmm_q_0`..`xmm_q_3`. The 32
//! floats of `m` are processed as 8 vectors in two batches of four; vector i
//! is passed to `MATRIX_VAR_PROC(1, 4, i, ...)`.
#define MATRIX_FP32_ITER_32X4_SSE(m, q, _RES, _LOAD, _PROC) \
  { \
    __m128 xmm_q_0 = _mm_broadcast_ss(q + 0); \
    __m128 xmm_q_1 = _mm_broadcast_ss(q + 1); \
    __m128 xmm_q_2 = _mm_broadcast_ss(q + 2); \
    __m128 xmm_q_3 = _mm_broadcast_ss(q + 3); \
    __m128 xmm_m_0 = _LOAD(m + 0); \
    __m128 xmm_m_1 = _LOAD(m + 4); \
    __m128 xmm_m_2 = _LOAD(m + 8); \
    __m128 xmm_m_3 = _LOAD(m + 12); \
    MATRIX_VAR_PROC(1, 4, 0, xmm_m_0, xmm_q, _RES, _PROC) \
    MATRIX_VAR_PROC(1, 4, 1, xmm_m_1, xmm_q, _RES, _PROC) \
    MATRIX_VAR_PROC(1, 4, 2, xmm_m_2, xmm_q, _RES, _PROC) \
    MATRIX_VAR_PROC(1, 4, 3, xmm_m_3, xmm_q, _RES, _PROC) \
    xmm_m_0 = _LOAD(m + 16); \
    xmm_m_1 = _LOAD(m + 20); \
    xmm_m_2 = _LOAD(m + 24); \
    xmm_m_3 = _LOAD(m + 28); \
    MATRIX_VAR_PROC(1, 4, 4, xmm_m_0, xmm_q, _RES, _PROC) \
    MATRIX_VAR_PROC(1, 4, 5, xmm_m_1, xmm_q, _RES, _PROC) \
    MATRIX_VAR_PROC(1, 4, 6, xmm_m_2, xmm_q, _RES, _PROC) \
    MATRIX_VAR_PROC(1, 4, 7, xmm_m_3, xmm_q, _RES, _PROC) \
  }

//! Iterative process of computing distance (FP32, M=32, N=8)
//!
//! `q[0]`..`q[7]` are broadcast up front into `xmm_q_0`..`xmm_q_7`. The 32
//! floats of `m` are processed as 8 vectors in two batches of four; vector i
//! is passed to `MATRIX_VAR_PROC(1, 8, i, ...)`.
#define MATRIX_FP32_ITER_32X8_SSE(m, q, _RES, _LOAD, _PROC) \
  { \
    __m128 xmm_q_0 = _mm_broadcast_ss(q + 0); \
    __m128 xmm_q_1 = _mm_broadcast_ss(q + 1); \
    __m128 xmm_q_2 = _mm_broadcast_ss(q + 2); \
    __m128 xmm_q_3 = _mm_broadcast_ss(q + 3); \
    __m128 xmm_q_4 = _mm_broadcast_ss(q + 4); \
    __m128 xmm_q_5 = _mm_broadcast_ss(q + 5); \
    __m128 xmm_q_6 = _mm_broadcast_ss(q + 6); \
    __m128 xmm_q_7 = _mm_broadcast_ss(q + 7); \
    __m128 xmm_m_0 = _LOAD(m + 0); \
    __m128 xmm_m_1 = _LOAD(m + 4); \
    __m128 xmm_m_2 = _LOAD(m + 8); \
    __m128 xmm_m_3 = _LOAD(m + 12); \
    MATRIX_VAR_PROC(1, 8, 0, xmm_m_0, xmm_q, _RES, _PROC) \
    MATRIX_VAR_PROC(1, 8, 1, xmm_m_1, xmm_q, _RES, _PROC) \
    MATRIX_VAR_PROC(1, 8, 2, xmm_m_2, xmm_q, _RES, _PROC) \
    MATRIX_VAR_PROC(1, 8, 3, xmm_m_3, xmm_q, _RES, _PROC) \
    xmm_m_0 = _LOAD(m + 16); \
    xmm_m_1 = _LOAD(m + 20); \
    xmm_m_2 = _LOAD(m + 24); \
    xmm_m_3 = _LOAD(m + 28); \
    MATRIX_VAR_PROC(1, 8, 4, xmm_m_0, xmm_q, _RES, _PROC) \
    MATRIX_VAR_PROC(1, 8, 5, xmm_m_1, xmm_q, _RES, _PROC) \
    MATRIX_VAR_PROC(1, 8, 6, xmm_m_2, xmm_q, _RES, _PROC) \
    MATRIX_VAR_PROC(1, 8, 7, xmm_m_3, xmm_q, _RES, _PROC) \
  }

//! Iterative process of computing distance (FP32, M=32, N=16)
//!
//! All 32 floats of `m` are held in eight vectors `xmm_m_0`..`xmm_m_7`; for
//! j = 0..15, `q[j]` is broadcast and passed with them to
//! `MATRIX_VAR_PROC(8, 1, j, ...)`.
#define MATRIX_FP32_ITER_32X16_SSE(m, q, _RES, _LOAD, _PROC) \
  { \
    __m128 xmm_m_0 = _LOAD(m + 0); \
    __m128 xmm_m_1 = _LOAD(m + 4); \
    __m128 xmm_m_2 = _LOAD(m + 8); \
    __m128 xmm_m_3 = _LOAD(m + 12); \
    __m128 xmm_m_4 = _LOAD(m + 16); \
    __m128 xmm_m_5 = _LOAD(m + 20); \
    __m128 xmm_m_6 = _LOAD(m + 24); \
    __m128 xmm_m_7 = _LOAD(m + 28); \
    __m128 xmm_q = _mm_broadcast_ss(q); \
    MATRIX_VAR_PROC(8, 1, 0, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 1); \
    MATRIX_VAR_PROC(8, 1, 1, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 2); \
    MATRIX_VAR_PROC(8, 1, 2, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 3); \
    MATRIX_VAR_PROC(8, 1, 3, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 4); \
    MATRIX_VAR_PROC(8, 1, 4, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 5); \
    MATRIX_VAR_PROC(8, 1, 5, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 6); \
    MATRIX_VAR_PROC(8, 1, 6, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 7); \
    MATRIX_VAR_PROC(8, 1, 7, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 8); \
    MATRIX_VAR_PROC(8, 1, 8, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 9); \
    MATRIX_VAR_PROC(8, 1, 9, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 10); \
    MATRIX_VAR_PROC(8, 1, 10, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 11); \
    MATRIX_VAR_PROC(8, 1, 11, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 12); \
    MATRIX_VAR_PROC(8, 1, 12, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 13); \
    MATRIX_VAR_PROC(8, 1, 13, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 14); \
    MATRIX_VAR_PROC(8, 1, 14, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 15); \
    MATRIX_VAR_PROC(8, 1, 15, xmm_m, xmm_q, _RES, _PROC) \
  }

//! Iterative process of computing distance (FP32, M=32, N=32)
//!
//! All 32 floats of `m` are held in eight vectors `xmm_m_0`..`xmm_m_7`; for
//! j = 0..31, `q[j]` is broadcast and passed with them to
//! `MATRIX_VAR_PROC(8, 1, j, ...)`.
#define MATRIX_FP32_ITER_32X32_SSE(m, q, _RES, _LOAD, _PROC) \
  { \
    __m128 xmm_m_0 = _LOAD(m + 0); \
    __m128 xmm_m_1 = _LOAD(m + 4); \
    __m128 xmm_m_2 = _LOAD(m + 8); \
    __m128 xmm_m_3 = _LOAD(m + 12); \
    __m128 xmm_m_4 = _LOAD(m + 16); \
    __m128 xmm_m_5 = _LOAD(m + 20); \
    __m128 xmm_m_6 = _LOAD(m + 24); \
    __m128 xmm_m_7 = _LOAD(m + 28); \
    __m128 xmm_q = _mm_broadcast_ss(q); \
    MATRIX_VAR_PROC(8, 1, 0, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 1); \
    MATRIX_VAR_PROC(8, 1, 1, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 2); \
    MATRIX_VAR_PROC(8, 1, 2, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 3); \
    MATRIX_VAR_PROC(8, 1, 3, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 4); \
    MATRIX_VAR_PROC(8, 1, 4, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 5); \
    MATRIX_VAR_PROC(8, 1, 5, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 6); \
    MATRIX_VAR_PROC(8, 1, 6, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 7); \
    MATRIX_VAR_PROC(8, 1, 7, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 8); \
    MATRIX_VAR_PROC(8, 1, 8, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 9); \
    MATRIX_VAR_PROC(8, 1, 9, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 10); \
    MATRIX_VAR_PROC(8, 1, 10, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 11); \
    MATRIX_VAR_PROC(8, 1, 11, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 12); \
    MATRIX_VAR_PROC(8, 1, 12, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 13); \
    MATRIX_VAR_PROC(8, 1, 13, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 14); \
    MATRIX_VAR_PROC(8, 1, 14, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 15); \
    MATRIX_VAR_PROC(8, 1, 15, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 16); \
    MATRIX_VAR_PROC(8, 1, 16, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 17); \
    MATRIX_VAR_PROC(8, 1, 17, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 18); \
    MATRIX_VAR_PROC(8, 1, 18, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 19); \
    MATRIX_VAR_PROC(8, 1, 19, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 20); \
    MATRIX_VAR_PROC(8, 1, 20, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 21); \
    MATRIX_VAR_PROC(8, 1, 21, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 22); \
    MATRIX_VAR_PROC(8, 1, 22, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 23); \
    MATRIX_VAR_PROC(8, 1, 23, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 24); \
    MATRIX_VAR_PROC(8, 1, 24, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 25); \
    MATRIX_VAR_PROC(8, 1, 25, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 26); \
    MATRIX_VAR_PROC(8, 1, 26, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 27); \
    MATRIX_VAR_PROC(8, 1, 27, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 28); \
    MATRIX_VAR_PROC(8, 1, 28, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 29); \
    MATRIX_VAR_PROC(8, 1, 29, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 30); \
    MATRIX_VAR_PROC(8, 1, 30, xmm_m, xmm_q, _RES, _PROC) \
    xmm_q = _mm_broadcast_ss(q + 31); \
    MATRIX_VAR_PROC(8, 1, 31, xmm_m, xmm_q, _RES, _PROC) \
  }

//! Iterative process of computing distance (FP32, M=2, N=1)
//!
//! One 8-float vector from `m`; `q[0]`..`q[3]` are each duplicated into
//! adjacent lane pairs ({q0,q0,q1,q1,q2,q2,q3,q3}, lowest lane first) and the
//! pair is folded into `_RES##_0_0`.
//! `q` is parenthesized before indexing so that pointer-arithmetic arguments
//! such as `base + offset` index the intended element (unparenthesized,
//! `base + offset[3]` would not compile).
#define MATRIX_FP32_ITER_2X1_AVX(m, q, _RES, _LOAD, _PROC) \
  { \
    __m256 ymm_m = _LOAD(m); \
    __m256 ymm_q = _mm256_set_ps((q)[3], (q)[3], (q)[2], (q)[2], (q)[1], \
                                 (q)[1], (q)[0], (q)[0]); \
    _PROC(ymm_m, ymm_q, _RES##_0_0) \
  }

//! Iterative process of computing distance (FP32, M=2, N=2)
//!
//! Loads 8 floats each from `m` and `q`. `_mm256_moveldup_ps` duplicates the
//! even lanes of `q` (feeding `_RES##_0_0`) and `_mm256_movehdup_ps` the odd
//! lanes (feeding `_RES##_0_1`).
#define MATRIX_FP32_ITER_2X2_AVX(m, q, _RES, _LOAD, _PROC) \
  { \
    __m256 ymm_q = _LOAD(q); \
    __m256 ymm_m = _LOAD(m); \
    __m256 ymm_p = _mm256_moveldup_ps(ymm_q); \
    _PROC(ymm_m, ymm_p, _RES##_0_0) \
    ymm_p = _mm256_movehdup_ps(ymm_q); \
    _PROC(ymm_m, ymm_p, _RES##_0_1) \
  }

//!
//! Iterative process of computing distance (FP32, M=4, N=1)
//!
//! Shared conventions for the AVX macros in this group: `m` and `q` point at
//! the two float operands, `_LOAD` loads one 8-float `__m256`, `_PROC(lhs,
//! rhs, acc)` is invoked with two vectors and an accumulator name, and `_RES`
//! is the accumulator prefix (`_RES##_i_j`).
//!
//! One 8-float vector from `m`; the query vector holds `q[0]` broadcast in
//! its low 128 bits and `q[1]` broadcast in its high 128 bits. The result
//! goes to `_RES##_0_0`.
#define MATRIX_FP32_ITER_4X1_AVX(m, q, _RES, _LOAD, _PROC) \
  { \
    __m256 ymm_m = _LOAD(m); \
    __m256 ymm_q = \
        _mm256_set_m128(_mm_broadcast_ss(q + 1), _mm_broadcast_ss(q)); \
    _PROC(ymm_m, ymm_q, _RES##_0_0) \
  }

//! Iterative process of computing distance (FP32, M=4, N=2)
//!
//! One 8-float vector from `m`. Low half `q[0]` / high half `q[2]` feed
//! `_RES##_0_0`; low half `q[1]` / high half `q[3]` feed `_RES##_0_1`.
#define MATRIX_FP32_ITER_4X2_AVX(m, q, _RES, _LOAD, _PROC) \
  { \
    __m256 ymm_m = _LOAD(m); \
    __m256 ymm_q = \
        _mm256_set_m128(_mm_broadcast_ss(q + 2), _mm_broadcast_ss(q + 0)); \
    _PROC(ymm_m, ymm_q, _RES##_0_0) \
    ymm_q = _mm256_set_m128(_mm_broadcast_ss(q + 3), _mm_broadcast_ss(q + 1)); \
    _PROC(ymm_m, ymm_q, _RES##_0_1) \
  }

//! Iterative process of computing distance (FP32, M=4, N=4)
//!
//! Loads 8 floats each from `m` and `q`. `_mm256_permute_ps` works within
//! each 128-bit lane, so selector k broadcasts `q[k]` into the low half and
//! `q[4 + k]` into the high half; selector k feeds `_RES##_0_k` (k = 0..3).
#define MATRIX_FP32_ITER_4X4_AVX(m, q, _RES, _LOAD, _PROC) \
  { \
    __m256 ymm_q = _LOAD(q); \
    __m256 ymm_m = _LOAD(m); \
    __m256 ymm_p = _mm256_permute_ps(ymm_q, _MM_SHUFFLE(0, 0, 0, 0)); \
    _PROC(ymm_m, ymm_p, _RES##_0_0) \
    ymm_p = _mm256_permute_ps(ymm_q, _MM_SHUFFLE(1, 1, 1, 1)); \
    _PROC(ymm_m, ymm_p, _RES##_0_1) \
    ymm_p = _mm256_permute_ps(ymm_q, _MM_SHUFFLE(2, 2, 2, 2)); \
    _PROC(ymm_m, ymm_p, _RES##_0_2) \
    ymm_p = _mm256_permute_ps(ymm_q, _MM_SHUFFLE(3, 3, 3, 3)); \
    _PROC(ymm_m, ymm_p, _RES##_0_3) \
  }

//! Iterative process of computing distance (FP32, M=8, N=1)
//!
//! One 8-float vector from `m`; `q[0]` is broadcast against it into
//! `_RES##_0_0`.
#define MATRIX_FP32_ITER_8X1_AVX(m, q, _RES, _LOAD, _PROC) \
  { \
    __m256 ymm_m = _LOAD(m); \
    __m256 ymm_q = _mm256_broadcast_ss(q); \
    _PROC(ymm_m, ymm_q, _RES##_0_0) \
  }

//! Iterative process of computing distance (FP32, M=8, N=2)
//!
//! One 8-float vector from `m`; `q[j]` (j = 0..1) is broadcast against it
//! into `_RES##_0_j`.
#define MATRIX_FP32_ITER_8X2_AVX(m, q, _RES, _LOAD, _PROC) \
  { \
    __m256 ymm_m = _LOAD(m); \
    __m256 ymm_q = _mm256_broadcast_ss(q); \
    _PROC(ymm_m, ymm_q, _RES##_0_0) \
    ymm_q = _mm256_broadcast_ss(q + 1); \
    _PROC(ymm_m, ymm_q, _RES##_0_1) \
  }

//!
//! Iterative process of computing distance (FP32, M=8, N=4)
//!
//! Shared conventions for the AVX macros in this group: `m` and `q` point at
//! the two float operands, `_LOAD` loads one 8-float `__m256`, `_PROC(lhs,
//! rhs, acc)` is invoked with two vectors and an accumulator name, and `_RES`
//! is the accumulator prefix (`_RES##_i_j`). `MATRIX_VAR_PROC` is defined
//! outside this section (presumably matrix_define.i) and receives the
//! `ymm_m` prefix of the loaded vectors.
//!
//! One 8-float vector from `m`; `q[j]` (j = 0..3) is broadcast against it
//! into `_RES##_0_j`.
#define MATRIX_FP32_ITER_8X4_AVX(m, q, _RES, _LOAD, _PROC) \
  { \
    __m256 ymm_m = _LOAD(m); \
    __m256 ymm_q = _mm256_broadcast_ss(q); \
    _PROC(ymm_m, ymm_q, _RES##_0_0) \
    ymm_q = _mm256_broadcast_ss(q + 1); \
    _PROC(ymm_m, ymm_q, _RES##_0_1) \
    ymm_q = _mm256_broadcast_ss(q + 2); \
    _PROC(ymm_m, ymm_q, _RES##_0_2) \
    ymm_q = _mm256_broadcast_ss(q + 3); \
    _PROC(ymm_m, ymm_q, _RES##_0_3) \
  }

//! Iterative process of computing distance (FP32, M=8, N=8)
//!
//! One 8-float vector from `m`; `q[j]` (j = 0..7) is broadcast against it
//! into `_RES##_0_j`.
#define MATRIX_FP32_ITER_8X8_AVX(m, q, _RES, _LOAD, _PROC) \
  { \
    __m256 ymm_m = _LOAD(m); \
    __m256 ymm_q = _mm256_broadcast_ss(q); \
    _PROC(ymm_m, ymm_q, _RES##_0_0) \
    ymm_q = _mm256_broadcast_ss(q + 1); \
    _PROC(ymm_m, ymm_q, _RES##_0_1) \
    ymm_q = _mm256_broadcast_ss(q + 2); \
    _PROC(ymm_m, ymm_q, _RES##_0_2) \
    ymm_q = _mm256_broadcast_ss(q + 3); \
    _PROC(ymm_m, ymm_q, _RES##_0_3) \
    ymm_q = _mm256_broadcast_ss(q + 4); \
    _PROC(ymm_m, ymm_q, _RES##_0_4) \
    ymm_q = _mm256_broadcast_ss(q + 5); \
    _PROC(ymm_m, ymm_q, _RES##_0_5) \
    ymm_q = _mm256_broadcast_ss(q + 6); \
    _PROC(ymm_m, ymm_q, _RES##_0_6) \
    ymm_q = _mm256_broadcast_ss(q + 7); \
    _PROC(ymm_m, ymm_q, _RES##_0_7) \
  }

//! Iterative process of computing distance (FP32, M=16, N=1)
//!
//! Two 8-float vectors from `m`; `q[0]` is broadcast and passed with them to
//! `MATRIX_VAR_PROC(2, 1, 0, ...)`.
#define MATRIX_FP32_ITER_16X1_AVX(m, q, _RES, _LOAD, _PROC) \
  { \
    __m256 ymm_m_0 = _LOAD(m + 0); \
    __m256 ymm_m_1 = _LOAD(m + 8); \
    __m256 ymm_q = _mm256_broadcast_ss(q); \
    MATRIX_VAR_PROC(2, 1, 0, ymm_m, ymm_q, _RES, _PROC) \
  }

//! Iterative process of computing distance (FP32, M=16, N=2)
//!
//! Two 8-float vectors from `m`; for j = 0..1, `q[j]` is broadcast and passed
//! with them to `MATRIX_VAR_PROC(2, 1, j, ...)`.
#define MATRIX_FP32_ITER_16X2_AVX(m, q, _RES, _LOAD, _PROC) \
  { \
    __m256 ymm_m_0 = _LOAD(m + 0); \
    __m256 ymm_m_1 = _LOAD(m + 8); \
    __m256 ymm_q = _mm256_broadcast_ss(q); \
    MATRIX_VAR_PROC(2, 1, 0, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 1); \
    MATRIX_VAR_PROC(2, 1, 1, ymm_m, ymm_q, _RES, _PROC) \
  }

//!
//! Iterative process of computing distance (FP32, M=16, N=4)
//!
//! Shared conventions for the AVX macros in this group: `m` and `q` point at
//! the two float operands, `_LOAD` loads one 8-float `__m256`, `_RES` is the
//! accumulator prefix and `_PROC` the per-pair operation; both are forwarded
//! to `MATRIX_VAR_PROC` (defined outside this section, presumably in
//! matrix_define.i), which receives the `ymm_m` prefix of the loaded vectors.
//!
//! Two 8-float vectors from `m`; for j = 0..3, `q[j]` is broadcast and passed
//! with them to `MATRIX_VAR_PROC(2, 1, j, ...)`.
#define MATRIX_FP32_ITER_16X4_AVX(m, q, _RES, _LOAD, _PROC) \
  { \
    __m256 ymm_m_0 = _LOAD(m + 0); \
    __m256 ymm_m_1 = _LOAD(m + 8); \
    __m256 ymm_q = _mm256_broadcast_ss(q); \
    MATRIX_VAR_PROC(2, 1, 0, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 1); \
    MATRIX_VAR_PROC(2, 1, 1, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 2); \
    MATRIX_VAR_PROC(2, 1, 2, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 3); \
    MATRIX_VAR_PROC(2, 1, 3, ymm_m, ymm_q, _RES, _PROC) \
  }

//! Iterative process of computing distance (FP32, M=16, N=8)
//!
//! Two 8-float vectors from `m`; for j = 0..7, `q[j]` is broadcast and passed
//! with them to `MATRIX_VAR_PROC(2, 1, j, ...)`.
#define MATRIX_FP32_ITER_16X8_AVX(m, q, _RES, _LOAD, _PROC) \
  { \
    __m256 ymm_m_0 = _LOAD(m + 0); \
    __m256 ymm_m_1 = _LOAD(m + 8); \
    __m256 ymm_q = _mm256_broadcast_ss(q); \
    MATRIX_VAR_PROC(2, 1, 0, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 1); \
    MATRIX_VAR_PROC(2, 1, 1, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 2); \
    MATRIX_VAR_PROC(2, 1, 2, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 3); \
    MATRIX_VAR_PROC(2, 1, 3, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 4); \
    MATRIX_VAR_PROC(2, 1, 4, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 5); \
    MATRIX_VAR_PROC(2, 1, 5, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 6); \
    MATRIX_VAR_PROC(2, 1, 6, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 7); \
    MATRIX_VAR_PROC(2, 1, 7, ymm_m, ymm_q, _RES, _PROC) \
  }

//!
//! Iterative process of computing distance (FP32, M=16, N=16)
//!
//! `m` and `q` point at the two float operands; `_LOAD` loads one 8-float
//! `__m256`. Two vectors (16 floats) are loaded from `m`; for j = 0..15,
//! `q[j]` is broadcast and passed with them (by the `ymm_m` prefix) to
//! `MATRIX_VAR_PROC(2, 1, j, ...)`, together with the accumulator prefix
//! `_RES` and the per-pair operation `_PROC`. `MATRIX_VAR_PROC` is defined
//! outside this section (presumably matrix_define.i).
#define MATRIX_FP32_ITER_16X16_AVX(m, q, _RES, _LOAD, _PROC) \
  { \
    __m256 ymm_m_0 = _LOAD(m + 0); \
    __m256 ymm_m_1 = _LOAD(m + 8); \
    __m256 ymm_q = _mm256_broadcast_ss(q); \
    MATRIX_VAR_PROC(2, 1, 0, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 1); \
    MATRIX_VAR_PROC(2, 1, 1, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 2); \
    MATRIX_VAR_PROC(2, 1, 2, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 3); \
    MATRIX_VAR_PROC(2, 1, 3, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 4); \
    MATRIX_VAR_PROC(2, 1, 4, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 5); \
    MATRIX_VAR_PROC(2, 1, 5, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 6); \
    MATRIX_VAR_PROC(2, 1, 6, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 7); \
    MATRIX_VAR_PROC(2, 1, 7, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 8); \
    MATRIX_VAR_PROC(2, 1, 8, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 9); \
    MATRIX_VAR_PROC(2, 1, 9, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 10); \
    MATRIX_VAR_PROC(2, 1, 10, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 11); \
    MATRIX_VAR_PROC(2, 1, 11, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 12); \
    MATRIX_VAR_PROC(2, 1, 12, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 13); \
    MATRIX_VAR_PROC(2, 1, 13, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 14); \
    MATRIX_VAR_PROC(2, 1, 14, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 15); \
    MATRIX_VAR_PROC(2, 1, 15, ymm_m, ymm_q, _RES, _PROC) \
  }

//!
//! Iterative process of computing distance (FP32, M=32, N=1)
//!
//! Shared conventions for the AVX macros in this group: `m` and `q` point at
//! the two float operands, `_LOAD` loads one 8-float `__m256`, `_RES` is the
//! accumulator prefix and `_PROC` the per-pair operation; both are forwarded
//! to `MATRIX_VAR_PROC` (defined outside this section, presumably in
//! matrix_define.i), which receives the `ymm_m` prefix of the loaded vectors.
//!
//! Four 8-float vectors (32 floats) from `m`; `q[0]` is broadcast and passed
//! with them to `MATRIX_VAR_PROC(4, 1, 0, ...)`.
#define MATRIX_FP32_ITER_32X1_AVX(m, q, _RES, _LOAD, _PROC) \
  { \
    __m256 ymm_m_0 = _LOAD(m + 0); \
    __m256 ymm_m_1 = _LOAD(m + 8); \
    __m256 ymm_m_2 = _LOAD(m + 16); \
    __m256 ymm_m_3 = _LOAD(m + 24); \
    __m256 ymm_q = _mm256_broadcast_ss(q); \
    MATRIX_VAR_PROC(4, 1, 0, ymm_m, ymm_q, _RES, _PROC) \
  }

//! Iterative process of computing distance (FP32, M=32, N=2)
//!
//! Four 8-float vectors from `m`; for j = 0..1, `q[j]` is broadcast and
//! passed with them to `MATRIX_VAR_PROC(4, 1, j, ...)`.
#define MATRIX_FP32_ITER_32X2_AVX(m, q, _RES, _LOAD, _PROC) \
  { \
    __m256 ymm_m_0 = _LOAD(m + 0); \
    __m256 ymm_m_1 = _LOAD(m + 8); \
    __m256 ymm_m_2 = _LOAD(m + 16); \
    __m256 ymm_m_3 = _LOAD(m + 24); \
    __m256 ymm_q = _mm256_broadcast_ss(q + 0); \
    MATRIX_VAR_PROC(4, 1, 0, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 1); \
    MATRIX_VAR_PROC(4, 1, 1, ymm_m, ymm_q, _RES, _PROC) \
  }

//! Iterative process of computing distance (FP32, M=32, N=4)
//!
//! Four 8-float vectors from `m`; for j = 0..3, `q[j]` is broadcast and
//! passed with them to `MATRIX_VAR_PROC(4, 1, j, ...)`.
#define MATRIX_FP32_ITER_32X4_AVX(m, q, _RES, _LOAD, _PROC) \
  { \
    __m256 ymm_m_0 = _LOAD(m + 0); \
    __m256 ymm_m_1 = _LOAD(m + 8); \
    __m256 ymm_m_2 = _LOAD(m + 16); \
    __m256 ymm_m_3 = _LOAD(m + 24); \
    __m256 ymm_q = _mm256_broadcast_ss(q); \
    MATRIX_VAR_PROC(4, 1, 0, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 1); \
    MATRIX_VAR_PROC(4, 1, 1, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 2); \
    MATRIX_VAR_PROC(4, 1, 2, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 3); \
    MATRIX_VAR_PROC(4, 1, 3, ymm_m, ymm_q, _RES, _PROC) \
  }

//!
//! Iterative process of computing distance (FP32, M=32, N=8)
//!
//! `m` and `q` point at the two float operands; `_LOAD` loads one 8-float
//! `__m256`. Four vectors (32 floats) are loaded from `m`; for j = 0..7,
//! `q[j]` is broadcast and passed with them (by the `ymm_m` prefix) to
//! `MATRIX_VAR_PROC(4, 1, j, ...)`, together with the accumulator prefix
//! `_RES` and the per-pair operation `_PROC`. `MATRIX_VAR_PROC` is defined
//! outside this section (presumably matrix_define.i).
#define MATRIX_FP32_ITER_32X8_AVX(m, q, _RES, _LOAD, _PROC) \
  { \
    __m256 ymm_m_0 = _LOAD(m + 0); \
    __m256 ymm_m_1 = _LOAD(m + 8); \
    __m256 ymm_m_2 = _LOAD(m + 16); \
    __m256 ymm_m_3 = _LOAD(m + 24); \
    __m256 ymm_q = _mm256_broadcast_ss(q); \
    MATRIX_VAR_PROC(4, 1, 0, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 1); \
    MATRIX_VAR_PROC(4, 1, 1, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 2); \
    MATRIX_VAR_PROC(4, 1, 2, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 3); \
    MATRIX_VAR_PROC(4, 1, 3, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 4); \
    MATRIX_VAR_PROC(4, 1, 4, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 5); \
    MATRIX_VAR_PROC(4, 1, 5, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 6); \
    MATRIX_VAR_PROC(4, 1, 6, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 7); \
    MATRIX_VAR_PROC(4, 1, 7, ymm_m, ymm_q, _RES, _PROC) \
  }

//!
//! Iterative process of computing distance (FP32, M=32, N=16)
//!
//! `m` and `q` point at the two float operands; `_LOAD` loads one 8-float
//! `__m256`. Four vectors (32 floats) are loaded from `m`; for j = 0..15,
//! `q[j]` is broadcast and passed with them (by the `ymm_m` prefix) to
//! `MATRIX_VAR_PROC(4, 1, j, ...)`, together with the accumulator prefix
//! `_RES` and the per-pair operation `_PROC`. `MATRIX_VAR_PROC` is defined
//! outside this section (presumably matrix_define.i).
#define MATRIX_FP32_ITER_32X16_AVX(m, q, _RES, _LOAD, _PROC) \
  { \
    __m256 ymm_m_0 = _LOAD(m + 0); \
    __m256 ymm_m_1 = _LOAD(m + 8); \
    __m256 ymm_m_2 = _LOAD(m + 16); \
    __m256 ymm_m_3 = _LOAD(m + 24); \
    __m256 ymm_q = _mm256_broadcast_ss(q); \
    MATRIX_VAR_PROC(4, 1, 0, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 1); \
    MATRIX_VAR_PROC(4, 1, 1, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 2); \
    MATRIX_VAR_PROC(4, 1, 2, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 3); \
    MATRIX_VAR_PROC(4, 1, 3, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 4); \
    MATRIX_VAR_PROC(4, 1, 4, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 5); \
    MATRIX_VAR_PROC(4, 1, 5, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 6); \
    MATRIX_VAR_PROC(4, 1, 6, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 7); \
    MATRIX_VAR_PROC(4, 1, 7, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 8); \
    MATRIX_VAR_PROC(4, 1, 8, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 9); \
    MATRIX_VAR_PROC(4, 1, 9, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 10); \
    MATRIX_VAR_PROC(4, 1, 10, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 11); \
    MATRIX_VAR_PROC(4, 1, 11, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 12); \
    MATRIX_VAR_PROC(4, 1, 12, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 13); \
    MATRIX_VAR_PROC(4, 1, 13, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 14); \
    MATRIX_VAR_PROC(4, 1, 14, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 15); \
    MATRIX_VAR_PROC(4, 1, 15, ymm_m, ymm_q, _RES, _PROC) \
  }

//!
//! Iterative process of computing distance (FP32, M=32, N=32)
//!
//! Shared conventions for the macros in this group: `m` and `q` point at the
//! two float operands, `_LOAD` loads one vector (`__m256` for AVX, `__m512`
//! for AVX-512), `_PROC(lhs, rhs, acc)` is invoked with two vectors and an
//! accumulator name, and `_RES` is the accumulator prefix (`_RES##_i_j`).
//! `MATRIX_VAR_PROC` is defined outside this section (presumably
//! matrix_define.i) and receives the `ymm_m` prefix of the loaded vectors.
//!
//! Four 8-float vectors (32 floats) from `m`; for j = 0..31, `q[j]` is
//! broadcast and passed with them to `MATRIX_VAR_PROC(4, 1, j, ...)`.
#define MATRIX_FP32_ITER_32X32_AVX(m, q, _RES, _LOAD, _PROC) \
  { \
    __m256 ymm_m_0 = _LOAD(m + 0); \
    __m256 ymm_m_1 = _LOAD(m + 8); \
    __m256 ymm_m_2 = _LOAD(m + 16); \
    __m256 ymm_m_3 = _LOAD(m + 24); \
    __m256 ymm_q = _mm256_broadcast_ss(q); \
    MATRIX_VAR_PROC(4, 1, 0, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 1); \
    MATRIX_VAR_PROC(4, 1, 1, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 2); \
    MATRIX_VAR_PROC(4, 1, 2, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 3); \
    MATRIX_VAR_PROC(4, 1, 3, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 4); \
    MATRIX_VAR_PROC(4, 1, 4, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 5); \
    MATRIX_VAR_PROC(4, 1, 5, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 6); \
    MATRIX_VAR_PROC(4, 1, 6, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 7); \
    MATRIX_VAR_PROC(4, 1, 7, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 8); \
    MATRIX_VAR_PROC(4, 1, 8, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 9); \
    MATRIX_VAR_PROC(4, 1, 9, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 10); \
    MATRIX_VAR_PROC(4, 1, 10, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 11); \
    MATRIX_VAR_PROC(4, 1, 11, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 12); \
    MATRIX_VAR_PROC(4, 1, 12, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 13); \
    MATRIX_VAR_PROC(4, 1, 13, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 14); \
    MATRIX_VAR_PROC(4, 1, 14, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 15); \
    MATRIX_VAR_PROC(4, 1, 15, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 16); \
    MATRIX_VAR_PROC(4, 1, 16, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 17); \
    MATRIX_VAR_PROC(4, 1, 17, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 18); \
    MATRIX_VAR_PROC(4, 1, 18, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 19); \
    MATRIX_VAR_PROC(4, 1, 19, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 20); \
    MATRIX_VAR_PROC(4, 1, 20, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 21); \
    MATRIX_VAR_PROC(4, 1, 21, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 22); \
    MATRIX_VAR_PROC(4, 1, 22, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 23); \
    MATRIX_VAR_PROC(4, 1, 23, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 24); \
    MATRIX_VAR_PROC(4, 1, 24, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 25); \
    MATRIX_VAR_PROC(4, 1, 25, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 26); \
    MATRIX_VAR_PROC(4, 1, 26, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 27); \
    MATRIX_VAR_PROC(4, 1, 27, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 28); \
    MATRIX_VAR_PROC(4, 1, 28, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 29); \
    MATRIX_VAR_PROC(4, 1, 29, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 30); \
    MATRIX_VAR_PROC(4, 1, 30, ymm_m, ymm_q, _RES, _PROC) \
    ymm_q = _mm256_broadcast_ss(q + 31); \
    MATRIX_VAR_PROC(4, 1, 31, ymm_m, ymm_q, _RES, _PROC) \
  }

//! Iterative process of computing distance (FP32, M=16, N=1)
//!
//! One 16-float vector from `m`; `q[0]` is broadcast against it into
//! `_RES##_0_0`.
//! `q` is parenthesized before dereferencing: unparenthesized, an argument
//! such as `base + k` would expand to `*base + k` and silently broadcast the
//! wrong value.
#define MATRIX_FP32_ITER_16X1_AVX512(m, q, _RES, _LOAD, _PROC) \
  { \
    __m512 zmm_m = _LOAD(m); \
    __m512 zmm_q = _mm512_set1_ps(*(q)); \
    _PROC(zmm_m, zmm_q, _RES##_0_0) \
  }

//! Iterative process of computing distance (FP32, M=16, N=2)
//!
//! One 16-float vector from `m`; `q[j]` (j = 0..1) is broadcast against it
//! into `_RES##_0_j`. `q` is parenthesized before indexing so that
//! pointer-arithmetic arguments index the intended element.
#define MATRIX_FP32_ITER_16X2_AVX512(m, q, _RES, _LOAD, _PROC) \
  { \
    __m512 zmm_m = _LOAD(m); \
    __m512 zmm_q = _mm512_set1_ps((q)[0]); \
    _PROC(zmm_m, zmm_q, _RES##_0_0) \
    zmm_q = _mm512_set1_ps((q)[1]); \
    _PROC(zmm_m, zmm_q, _RES##_0_1) \
  }

//!
Iterative process of computing distance (FP32, M=16, N=4)
//! One 16-lane _LOAD of `m`, reused against broadcasts of q[0..3]; the j-th
//! broadcast uses result register _RES##_0_j.
#define MATRIX_FP32_ITER_16X4_AVX512(m, q, _RES, _LOAD, _PROC) \
  { \
    __m512 zmm_m = _LOAD(m); \
    __m512 zmm_q = _mm512_set1_ps(q[0]); \
    _PROC(zmm_m, zmm_q, _RES##_0_0) \
    zmm_q = _mm512_set1_ps(q[1]); \
    _PROC(zmm_m, zmm_q, _RES##_0_1) \
    zmm_q = _mm512_set1_ps(q[2]); \
    _PROC(zmm_m, zmm_q, _RES##_0_2) \
    zmm_q = _mm512_set1_ps(q[3]); \
    _PROC(zmm_m, zmm_q, _RES##_0_3) \
  }

//! Iterative process of computing distance (FP32, M=16, N=8)
//! One 16-lane _LOAD of `m`, reused against broadcasts of q[0..7]; the j-th
//! broadcast uses result register _RES##_0_j.
#define MATRIX_FP32_ITER_16X8_AVX512(m, q, _RES, _LOAD, _PROC) \
  { \
    __m512 zmm_m = _LOAD(m); \
    __m512 zmm_q = _mm512_set1_ps(q[0]); \
    _PROC(zmm_m, zmm_q, _RES##_0_0) \
    zmm_q = _mm512_set1_ps(q[1]); \
    _PROC(zmm_m, zmm_q, _RES##_0_1) \
    zmm_q = _mm512_set1_ps(q[2]); \
    _PROC(zmm_m, zmm_q, _RES##_0_2) \
    zmm_q = _mm512_set1_ps(q[3]); \
    _PROC(zmm_m, zmm_q, _RES##_0_3) \
    zmm_q = _mm512_set1_ps(q[4]); \
    _PROC(zmm_m, zmm_q, _RES##_0_4) \
    zmm_q = _mm512_set1_ps(q[5]); \
    _PROC(zmm_m, zmm_q, _RES##_0_5) \
    zmm_q = _mm512_set1_ps(q[6]); \
    _PROC(zmm_m, zmm_q, _RES##_0_6) \
    zmm_q = _mm512_set1_ps(q[7]); \
    _PROC(zmm_m, zmm_q, _RES##_0_7) \
  }

//!
Iterative process of computing distance (FP32, M=16, N=16)
//! One 16-lane _LOAD of `m`, reused against broadcasts of q[0..15]; the j-th
//! broadcast uses result register _RES##_0_j.
#define MATRIX_FP32_ITER_16X16_AVX512(m, q, _RES, _LOAD, _PROC) \
  { \
    __m512 zmm_m = _LOAD(m); \
    __m512 zmm_q = _mm512_set1_ps(q[0]); \
    _PROC(zmm_m, zmm_q, _RES##_0_0) \
    zmm_q = _mm512_set1_ps(q[1]); \
    _PROC(zmm_m, zmm_q, _RES##_0_1) \
    zmm_q = _mm512_set1_ps(q[2]); \
    _PROC(zmm_m, zmm_q, _RES##_0_2) \
    zmm_q = _mm512_set1_ps(q[3]); \
    _PROC(zmm_m, zmm_q, _RES##_0_3) \
    zmm_q = _mm512_set1_ps(q[4]); \
    _PROC(zmm_m, zmm_q, _RES##_0_4) \
    zmm_q = _mm512_set1_ps(q[5]); \
    _PROC(zmm_m, zmm_q, _RES##_0_5) \
    zmm_q = _mm512_set1_ps(q[6]); \
    _PROC(zmm_m, zmm_q, _RES##_0_6) \
    zmm_q = _mm512_set1_ps(q[7]); \
    _PROC(zmm_m, zmm_q, _RES##_0_7) \
    zmm_q = _mm512_set1_ps(q[8]); \
    _PROC(zmm_m, zmm_q, _RES##_0_8) \
    zmm_q = _mm512_set1_ps(q[9]); \
    _PROC(zmm_m, zmm_q, _RES##_0_9) \
    zmm_q = _mm512_set1_ps(q[10]); \
    _PROC(zmm_m, zmm_q, _RES##_0_10) \
    zmm_q = _mm512_set1_ps(q[11]); \
    _PROC(zmm_m, zmm_q, _RES##_0_11) \
    zmm_q = _mm512_set1_ps(q[12]); \
    _PROC(zmm_m, zmm_q, _RES##_0_12) \
    zmm_q = _mm512_set1_ps(q[13]); \
    _PROC(zmm_m, zmm_q, _RES##_0_13) \
    zmm_q = _mm512_set1_ps(q[14]); \
    _PROC(zmm_m, zmm_q, _RES##_0_14) \
    zmm_q = _mm512_set1_ps(q[15]); \
    _PROC(zmm_m, zmm_q, _RES##_0_15) \
  }

//! Iterative process of computing distance (FP32, M=32, N=1)
//! q[0] is broadcast once. `m` is read as two 16-lane halves (offsets 0 and
//! 16), with result registers _RES##_0_0 and _RES##_1_0.
#define MATRIX_FP32_ITER_32X1_AVX512(m, q, _RES, _LOAD, _PROC) \
  { \
    __m512 zmm_q = _mm512_set1_ps(*q); \
    __m512 zmm_m = _LOAD(m); \
    _PROC(zmm_m, zmm_q, _RES##_0_0) \
    zmm_m = _LOAD(m + 16); \
    _PROC(zmm_m, zmm_q, _RES##_1_0) \
  }

//! Iterative process of computing distance (FP32, M=32, N=2)
//! Loads 32 floats of `m` into zmm_m_0/zmm_m_1 and expands
//! MATRIX_VAR_PROC(2, 1, j, zmm_m, zmm_q, ...) for broadcasts of q[0], q[1].
//! MATRIX_VAR_PROC is defined elsewhere and token-pastes the zmm_m_* names.
#define MATRIX_FP32_ITER_32X2_AVX512(m, q, _RES, _LOAD, _PROC) \
  { \
    __m512 zmm_m_0 = _LOAD(m + 0); \
    __m512 zmm_m_1 = _LOAD(m + 16); \
    __m512 zmm_q = _mm512_set1_ps(q[0]); \
    MATRIX_VAR_PROC(2, 1, 0, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(q[1]); \
    MATRIX_VAR_PROC(2, 1, 1, zmm_m, zmm_q, _RES, _PROC) \
  }

//!
Iterative process of computing distance (FP32, M=32, N=4)
//! Loads 32 floats of `m` into zmm_m_0/zmm_m_1 and expands
//! MATRIX_VAR_PROC(2, 1, j, zmm_m, zmm_q, ...) for broadcasts of q[0..3].
//! MATRIX_VAR_PROC is defined elsewhere and token-pastes the zmm_m_* names.
#define MATRIX_FP32_ITER_32X4_AVX512(m, q, _RES, _LOAD, _PROC) \
  { \
    __m512 zmm_m_0 = _LOAD(m + 0); \
    __m512 zmm_m_1 = _LOAD(m + 16); \
    __m512 zmm_q = _mm512_set1_ps(q[0]); \
    MATRIX_VAR_PROC(2, 1, 0, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(q[1]); \
    MATRIX_VAR_PROC(2, 1, 1, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(q[2]); \
    MATRIX_VAR_PROC(2, 1, 2, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(q[3]); \
    MATRIX_VAR_PROC(2, 1, 3, zmm_m, zmm_q, _RES, _PROC) \
  }

//! Iterative process of computing distance (FP32, M=32, N=8)
//! Same as the N=4 variant, extended to broadcasts of q[0..7].
#define MATRIX_FP32_ITER_32X8_AVX512(m, q, _RES, _LOAD, _PROC) \
  { \
    __m512 zmm_m_0 = _LOAD(m + 0); \
    __m512 zmm_m_1 = _LOAD(m + 16); \
    __m512 zmm_q = _mm512_set1_ps(q[0]); \
    MATRIX_VAR_PROC(2, 1, 0, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(q[1]); \
    MATRIX_VAR_PROC(2, 1, 1, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(q[2]); \
    MATRIX_VAR_PROC(2, 1, 2, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(q[3]); \
    MATRIX_VAR_PROC(2, 1, 3, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(q[4]); \
    MATRIX_VAR_PROC(2, 1, 4, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(q[5]); \
    MATRIX_VAR_PROC(2, 1, 5, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(q[6]); \
    MATRIX_VAR_PROC(2, 1, 6, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(q[7]); \
    MATRIX_VAR_PROC(2, 1, 7, zmm_m, zmm_q, _RES, _PROC) \
  }

//!
Iterative process of computing distance (FP32, M=32, N=16)
//! Loads 32 floats of `m` into zmm_m_0/zmm_m_1 and expands
//! MATRIX_VAR_PROC(2, 1, j, zmm_m, zmm_q, ...) for broadcasts of q[0..15].
//! MATRIX_VAR_PROC is defined elsewhere and token-pastes the zmm_m_* names.
#define MATRIX_FP32_ITER_32X16_AVX512(m, q, _RES, _LOAD, _PROC) \
  { \
    __m512 zmm_m_0 = _LOAD(m + 0); \
    __m512 zmm_m_1 = _LOAD(m + 16); \
    __m512 zmm_q = _mm512_set1_ps(q[0]); \
    MATRIX_VAR_PROC(2, 1, 0, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(q[1]); \
    MATRIX_VAR_PROC(2, 1, 1, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(q[2]); \
    MATRIX_VAR_PROC(2, 1, 2, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(q[3]); \
    MATRIX_VAR_PROC(2, 1, 3, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(q[4]); \
    MATRIX_VAR_PROC(2, 1, 4, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(q[5]); \
    MATRIX_VAR_PROC(2, 1, 5, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(q[6]); \
    MATRIX_VAR_PROC(2, 1, 6, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(q[7]); \
    MATRIX_VAR_PROC(2, 1, 7, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(q[8]); \
    MATRIX_VAR_PROC(2, 1, 8, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(q[9]); \
    MATRIX_VAR_PROC(2, 1, 9, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(q[10]); \
    MATRIX_VAR_PROC(2, 1, 10, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(q[11]); \
    MATRIX_VAR_PROC(2, 1, 11, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(q[12]); \
    MATRIX_VAR_PROC(2, 1, 12, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(q[13]); \
    MATRIX_VAR_PROC(2, 1, 13, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(q[14]); \
    MATRIX_VAR_PROC(2, 1, 14, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(q[15]); \
    MATRIX_VAR_PROC(2, 1, 15, zmm_m, zmm_q, _RES, _PROC) \
  }

//!
Iterative process of computing distance (FP32, M=32, N=32)
//! Loads 32 floats of `m` into zmm_m_0/zmm_m_1 and expands
//! MATRIX_VAR_PROC(2, 1, j, zmm_m, zmm_q, ...) for broadcasts of q[0..31].
//! MATRIX_VAR_PROC is defined elsewhere and token-pastes the zmm_m_* names.
#define MATRIX_FP32_ITER_32X32_AVX512(m, q, _RES, _LOAD, _PROC) \
  { \
    __m512 zmm_m_0 = _LOAD(m + 0); \
    __m512 zmm_m_1 = _LOAD(m + 16); \
    __m512 zmm_q = _mm512_set1_ps(q[0]); \
    MATRIX_VAR_PROC(2, 1, 0, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(q[1]); \
    MATRIX_VAR_PROC(2, 1, 1, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(q[2]); \
    MATRIX_VAR_PROC(2, 1, 2, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(q[3]); \
    MATRIX_VAR_PROC(2, 1, 3, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(q[4]); \
    MATRIX_VAR_PROC(2, 1, 4, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(q[5]); \
    MATRIX_VAR_PROC(2, 1, 5, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(q[6]); \
    MATRIX_VAR_PROC(2, 1, 6, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(q[7]); \
    MATRIX_VAR_PROC(2, 1, 7, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(q[8]); \
    MATRIX_VAR_PROC(2, 1, 8, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(q[9]); \
    MATRIX_VAR_PROC(2, 1, 9, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(q[10]); \
    MATRIX_VAR_PROC(2, 1, 10, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(q[11]); \
    MATRIX_VAR_PROC(2, 1, 11, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(q[12]); \
    MATRIX_VAR_PROC(2, 1, 12, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(q[13]); \
    MATRIX_VAR_PROC(2, 1, 13, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(q[14]); \
    MATRIX_VAR_PROC(2, 1, 14, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(q[15]); \
    MATRIX_VAR_PROC(2, 1, 15, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(q[16]); \
    MATRIX_VAR_PROC(2, 1, 16, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(q[17]); \
    MATRIX_VAR_PROC(2, 1, 17, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(q[18]); \
    MATRIX_VAR_PROC(2, 1, 18, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(q[19]); \
    MATRIX_VAR_PROC(2, 1, 19, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(q[20]); \
    MATRIX_VAR_PROC(2, 1, 20, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(q[21]); \
    MATRIX_VAR_PROC(2, 1, 21, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(q[22]); \
    MATRIX_VAR_PROC(2, 1, 22, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(q[23]); \
    MATRIX_VAR_PROC(2, 1, 23, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(q[24]); \
    MATRIX_VAR_PROC(2, 1, 24, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(q[25]); \
    MATRIX_VAR_PROC(2, 1, 25, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(q[26]); \
    MATRIX_VAR_PROC(2, 1, 26, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(q[27]); \
    MATRIX_VAR_PROC(2, 1, 27, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(q[28]); \
    MATRIX_VAR_PROC(2, 1, 28, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(q[29]); \
    MATRIX_VAR_PROC(2, 1, 29, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(q[30]); \
    MATRIX_VAR_PROC(2, 1, 30, zmm_m, zmm_q, _RES, _PROC) \
    zmm_q = _mm512_set1_ps(q[31]); \
    MATRIX_VAR_PROC(2, 1, 31, zmm_m, zmm_q, _RES, _PROC) \
  }

//! Iterative process of computing distance (FP32, M=2, N=1)
//! Reads 4 floats of `m` and 2 floats of `q`. The query vector is laid out as
//! {q[0], q[0], q[1], q[1]} and a single _PROC call uses result operand _RES
//! (no _i_j suffix, unlike the other variants). NOTE(review): this presumably
//! covers two consecutive dimensions of a 2-row block — confirm with caller.
#define MATRIX_FP32_ITER_2X1_NEON(m, q, _RES, _PROC) \
  { \
    float32x4_t v_m = vld1q_f32(m); \
    float32x2_t v_q = vld1_f32(q); \
    float32x4_t v_p = \
        vcombine_f32(vdup_lane_f32(v_q, 0), vdup_lane_f32(v_q, 1)); \
    _PROC(v_m, v_p, _RES) \
  }

//! Iterative process of computing distance (FP32, M=2, N=2)
//! Reads 4 floats each of `m` and `q` ({q0, q1, q2, q3}). The first _PROC
//! pairs `m` with {q0, q0, q2, q2} (_RES##_0_0). The second pairs it with
//! {q1, q1, q3, q3} (_RES##_0_1).
#define MATRIX_FP32_ITER_2X2_NEON(m, q, _RES, _PROC) \
  { \
    float32x4_t v_q = vld1q_f32(q); \
    float32x4_t v_m = vld1q_f32(m); \
    float32x2_t v_q_0 = vget_low_f32(v_q); \
    float32x2_t v_q_1 = vget_high_f32(v_q); \
    v_q = vcombine_f32(vdup_lane_f32(v_q_0, 0), vdup_lane_f32(v_q_1, 0)); \
    _PROC(v_m, v_q, _RES##_0_0) \
    v_q = vcombine_f32(vdup_lane_f32(v_q_0, 1), vdup_lane_f32(v_q_1, 1)); \
    _PROC(v_m, v_q, _RES##_0_1) \
  }

//!
Iterative process of computing distance (FP32, M=4, N=1)
//! Reads 8 floats of `m` and 2 floats of `q`. m[0..3] is paired with a
//! broadcast of q[0] (_RES##_0_0), and m[4..7] with a broadcast of q[1]
//! (_RES##_0_1). NOTE(review): presumably two consecutive dimensions whose
//! partial sums the caller combines — confirm.
#define MATRIX_FP32_ITER_4X1_NEON(m, q, _RES, _PROC) \
  { \
    float32x4_t v_m_0 = vld1q_f32(m + 0); \
    float32x4_t v_m_1 = vld1q_f32(m + 4); \
    float32x2_t v_p = vld1_f32(q); \
    float32x4_t v_q = vdupq_lane_f32(v_p, 0); \
    _PROC(v_m_0, v_q, _RES##_0_0) \
    v_q = vdupq_lane_f32(v_p, 1); \
    _PROC(v_m_1, v_q, _RES##_0_1) \
  }

//! Iterative process of computing distance (FP32, M=4, N=2)
//! One 4-lane load of `m`, reused against broadcasts of q[0] and q[1]
//! (_RES##_0_0, _RES##_0_1).
#define MATRIX_FP32_ITER_4X2_NEON(m, q, _RES, _PROC) \
  { \
    float32x4_t v_m = vld1q_f32(m); \
    float32x2_t v_p = vld1_f32(q); \
    float32x4_t v_q = vdupq_lane_f32(v_p, 0); \
    _PROC(v_m, v_q, _RES##_0_0) \
    v_q = vdupq_lane_f32(v_p, 1); \
    _PROC(v_m, v_q, _RES##_0_1) \
  }

//! Iterative process of computing distance (FP32, M=4, N=4)
//! One 4-lane load of `m`, reused against lane broadcasts of q[0..3]
//! (_RES##_0_j for lane j).
#define MATRIX_FP32_ITER_4X4_NEON(m, q, _RES, _PROC) \
  { \
    float32x4_t v_m = vld1q_f32(m); \
    float32x4_t v_p = vld1q_f32(q); \
    float32x4_t v_q = vdupq_laneq_f32(v_p, 0); \
    _PROC(v_m, v_q, _RES##_0_0) \
    v_q = vdupq_laneq_f32(v_p, 1); \
    _PROC(v_m, v_q, _RES##_0_1) \
    v_q = vdupq_laneq_f32(v_p, 2); \
    _PROC(v_m, v_q, _RES##_0_2) \
    v_q = vdupq_laneq_f32(v_p, 3); \
    _PROC(v_m, v_q, _RES##_0_3) \
  }

//! Iterative process of computing distance (FP32, M=8, N=1)
//! Two 4-lane loads of `m` against a broadcast of q[0]
//! (_RES##_0_0, _RES##_1_0).
#define MATRIX_FP32_ITER_8X1_NEON(m, q, _RES, _PROC) \
  { \
    float32x4_t v_m_0 = vld1q_f32(m + 0); \
    float32x4_t v_m_1 = vld1q_f32(m + 4); \
    float32x4_t v_q = vld1q_dup_f32(q); \
    _PROC(v_m_0, v_q, _RES##_0_0) \
    _PROC(v_m_1, v_q, _RES##_1_0) \
  }

//! Iterative process of computing distance (FP32, M=8, N=2)
//! Two 4-lane loads of `m` (v_m_0, v_m_1); MATRIX_VAR_PROC(2, 1, j, v_m, ...)
//! for broadcasts of q[0], q[1]. MATRIX_VAR_PROC is defined elsewhere.
#define MATRIX_FP32_ITER_8X2_NEON(m, q, _RES, _PROC) \
  { \
    float32x4_t v_m_0 = vld1q_f32(m + 0); \
    float32x4_t v_m_1 = vld1q_f32(m + 4); \
    float32x2_t v_p = vld1_f32(q); \
    float32x4_t v_q = vdupq_lane_f32(v_p, 0); \
    MATRIX_VAR_PROC(2, 1, 0, v_m, v_q, _RES, _PROC) \
    v_q = vdupq_lane_f32(v_p, 1); \
    MATRIX_VAR_PROC(2, 1, 1, v_m, v_q, _RES, _PROC) \
  }

//!
Iterative process of computing distance (FP32, M=8, N=4)
//! Two 4-lane loads of `m` (v_m_0, v_m_1); MATRIX_VAR_PROC(2, 1, j, v_m, ...)
//! for lane broadcasts of q[0..3]. MATRIX_VAR_PROC is defined elsewhere.
#define MATRIX_FP32_ITER_8X4_NEON(m, q, _RES, _PROC) \
  { \
    float32x4_t v_m_0 = vld1q_f32(m + 0); \
    float32x4_t v_m_1 = vld1q_f32(m + 4); \
    float32x4_t v_p = vld1q_f32(q); \
    float32x4_t v_q = vdupq_laneq_f32(v_p, 0); \
    MATRIX_VAR_PROC(2, 1, 0, v_m, v_q, _RES, _PROC) \
    v_q = vdupq_laneq_f32(v_p, 1); \
    MATRIX_VAR_PROC(2, 1, 1, v_m, v_q, _RES, _PROC) \
    v_q = vdupq_laneq_f32(v_p, 2); \
    MATRIX_VAR_PROC(2, 1, 2, v_m, v_q, _RES, _PROC) \
    v_q = vdupq_laneq_f32(v_p, 3); \
    MATRIX_VAR_PROC(2, 1, 3, v_m, v_q, _RES, _PROC) \
  }

//! Iterative process of computing distance (FP32, M=8, N=8)
//! As the N=4 variant, with q read in two 4-float chunks (q+0, q+4) for
//! column indices 0..7.
#define MATRIX_FP32_ITER_8X8_NEON(m, q, _RES, _PROC) \
  { \
    float32x4_t v_m_0 = vld1q_f32(m + 0); \
    float32x4_t v_m_1 = vld1q_f32(m + 4); \
    float32x4_t v_p = vld1q_f32(q + 0); \
    float32x4_t v_q = vdupq_laneq_f32(v_p, 0); \
    MATRIX_VAR_PROC(2, 1, 0, v_m, v_q, _RES, _PROC) \
    v_q = vdupq_laneq_f32(v_p, 1); \
    MATRIX_VAR_PROC(2, 1, 1, v_m, v_q, _RES, _PROC) \
    v_q = vdupq_laneq_f32(v_p, 2); \
    MATRIX_VAR_PROC(2, 1, 2, v_m, v_q, _RES, _PROC) \
    v_q = vdupq_laneq_f32(v_p, 3); \
    MATRIX_VAR_PROC(2, 1, 3, v_m, v_q, _RES, _PROC) \
    v_p = vld1q_f32(q + 4); \
    v_q = vdupq_laneq_f32(v_p, 0); \
    MATRIX_VAR_PROC(2, 1, 4, v_m, v_q, _RES, _PROC) \
    v_q = vdupq_laneq_f32(v_p, 1); \
    MATRIX_VAR_PROC(2, 1, 5, v_m, v_q, _RES, _PROC) \
    v_q = vdupq_laneq_f32(v_p, 2); \
    MATRIX_VAR_PROC(2, 1, 6, v_m, v_q, _RES, _PROC) \
    v_q = vdupq_laneq_f32(v_p, 3); \
    MATRIX_VAR_PROC(2, 1, 7, v_m, v_q, _RES, _PROC) \
  }

//! Iterative process of computing distance (FP32, M=16, N=1)
//! Four 4-lane loads of `m` (v_m_0..v_m_3) against a broadcast of q[0] via
//! MATRIX_VAR_PROC(4, 1, 0, ...) (defined elsewhere).
#define MATRIX_FP32_ITER_16X1_NEON(m, q, _RES, _PROC) \
  { \
    float32x4_t v_m_0 = vld1q_f32(m + 0); \
    float32x4_t v_m_1 = vld1q_f32(m + 4); \
    float32x4_t v_m_2 = vld1q_f32(m + 8); \
    float32x4_t v_m_3 = vld1q_f32(m + 12); \
    float32x4_t v_q = vld1q_dup_f32(q); \
    MATRIX_VAR_PROC(4, 1, 0, v_m, v_q, _RES, _PROC) \
  }

//!
Iterative process of computing distance (FP32, M=16, N=2)
//! Four 4-lane loads of `m` (v_m_0..v_m_3); MATRIX_VAR_PROC(4, 1, j, v_m, ...)
//! for broadcasts of q[0], q[1]. MATRIX_VAR_PROC is defined elsewhere.
#define MATRIX_FP32_ITER_16X2_NEON(m, q, _RES, _PROC) \
  { \
    float32x4_t v_m_0 = vld1q_f32(m + 0); \
    float32x4_t v_m_1 = vld1q_f32(m + 4); \
    float32x4_t v_m_2 = vld1q_f32(m + 8); \
    float32x4_t v_m_3 = vld1q_f32(m + 12); \
    float32x2_t v_p = vld1_f32(q); \
    float32x4_t v_q = vdupq_lane_f32(v_p, 0); \
    MATRIX_VAR_PROC(4, 1, 0, v_m, v_q, _RES, _PROC) \
    v_q = vdupq_lane_f32(v_p, 1); \
    MATRIX_VAR_PROC(4, 1, 1, v_m, v_q, _RES, _PROC) \
  }

//! Iterative process of computing distance (FP32, M=16, N=4)
//! Four 4-lane loads of `m`; q[0..3] is loaded into v_q and each lane is
//! broadcast into v_p for MATRIX_VAR_PROC(4, 1, j, v_m, v_p, ...). Note that
//! the v_q/v_p roles are swapped relative to the N=2 variant.
#define MATRIX_FP32_ITER_16X4_NEON(m, q, _RES, _PROC) \
  { \
    float32x4_t v_m_0 = vld1q_f32(m + 0); \
    float32x4_t v_m_1 = vld1q_f32(m + 4); \
    float32x4_t v_m_2 = vld1q_f32(m + 8); \
    float32x4_t v_m_3 = vld1q_f32(m + 12); \
    float32x4_t v_q = vld1q_f32(q); \
    float32x4_t v_p = vdupq_laneq_f32(v_q, 0); \
    MATRIX_VAR_PROC(4, 1, 0, v_m, v_p, _RES, _PROC) \
    v_p = vdupq_laneq_f32(v_q, 1); \
    MATRIX_VAR_PROC(4, 1, 1, v_m, v_p, _RES, _PROC) \
    v_p = vdupq_laneq_f32(v_q, 2); \
    MATRIX_VAR_PROC(4, 1, 2, v_m, v_p, _RES, _PROC) \
    v_p = vdupq_laneq_f32(v_q, 3); \
    MATRIX_VAR_PROC(4, 1, 3, v_m, v_p, _RES, _PROC) \
  }

//!
Iterative process of computing distance (FP32, M=16, N=8)
//! Four 4-lane loads of `m` (v_m_0..v_m_3); q is read in 4-float chunks
//! (q+0, q+4), and each lane is broadcast into v_p for
//! MATRIX_VAR_PROC(4, 1, j, v_m, v_p, ...), j = 0..7 (defined elsewhere).
#define MATRIX_FP32_ITER_16X8_NEON(m, q, _RES, _PROC) \
  { \
    float32x4_t v_m_0 = vld1q_f32(m + 0); \
    float32x4_t v_m_1 = vld1q_f32(m + 4); \
    float32x4_t v_m_2 = vld1q_f32(m + 8); \
    float32x4_t v_m_3 = vld1q_f32(m + 12); \
    float32x4_t v_q = vld1q_f32(q + 0); \
    float32x4_t v_p = vdupq_laneq_f32(v_q, 0); \
    MATRIX_VAR_PROC(4, 1, 0, v_m, v_p, _RES, _PROC) \
    v_p = vdupq_laneq_f32(v_q, 1); \
    MATRIX_VAR_PROC(4, 1, 1, v_m, v_p, _RES, _PROC) \
    v_p = vdupq_laneq_f32(v_q, 2); \
    MATRIX_VAR_PROC(4, 1, 2, v_m, v_p, _RES, _PROC) \
    v_p = vdupq_laneq_f32(v_q, 3); \
    MATRIX_VAR_PROC(4, 1, 3, v_m, v_p, _RES, _PROC) \
    v_q = vld1q_f32(q + 4); \
    v_p = vdupq_laneq_f32(v_q, 0); \
    MATRIX_VAR_PROC(4, 1, 4, v_m, v_p, _RES, _PROC) \
    v_p = vdupq_laneq_f32(v_q, 1); \
    MATRIX_VAR_PROC(4, 1, 5, v_m, v_p, _RES, _PROC) \
    v_p = vdupq_laneq_f32(v_q, 2); \
    MATRIX_VAR_PROC(4, 1, 6, v_m, v_p, _RES, _PROC) \
    v_p = vdupq_laneq_f32(v_q, 3); \
    MATRIX_VAR_PROC(4, 1, 7, v_m, v_p, _RES, _PROC) \
  }

//!
Iterative process of computing distance (FP32, M=16, N=16)
//! Four 4-lane loads of `m` (v_m_0..v_m_3); q is read in 4-float chunks
//! (q+0, q+4, q+8, q+12), and each lane is broadcast into v_p for
//! MATRIX_VAR_PROC(4, 1, j, v_m, v_p, ...), j = 0..15 (defined elsewhere).
#define MATRIX_FP32_ITER_16X16_NEON(m, q, _RES, _PROC) \
  { \
    float32x4_t v_m_0 = vld1q_f32(m + 0); \
    float32x4_t v_m_1 = vld1q_f32(m + 4); \
    float32x4_t v_m_2 = vld1q_f32(m + 8); \
    float32x4_t v_m_3 = vld1q_f32(m + 12); \
    float32x4_t v_q = vld1q_f32(q + 0); \
    float32x4_t v_p = vdupq_laneq_f32(v_q, 0); \
    MATRIX_VAR_PROC(4, 1, 0, v_m, v_p, _RES, _PROC) \
    v_p = vdupq_laneq_f32(v_q, 1); \
    MATRIX_VAR_PROC(4, 1, 1, v_m, v_p, _RES, _PROC) \
    v_p = vdupq_laneq_f32(v_q, 2); \
    MATRIX_VAR_PROC(4, 1, 2, v_m, v_p, _RES, _PROC) \
    v_p = vdupq_laneq_f32(v_q, 3); \
    MATRIX_VAR_PROC(4, 1, 3, v_m, v_p, _RES, _PROC) \
    v_q = vld1q_f32(q + 4); \
    v_p = vdupq_laneq_f32(v_q, 0); \
    MATRIX_VAR_PROC(4, 1, 4, v_m, v_p, _RES, _PROC) \
    v_p = vdupq_laneq_f32(v_q, 1); \
    MATRIX_VAR_PROC(4, 1, 5, v_m, v_p, _RES, _PROC) \
    v_p = vdupq_laneq_f32(v_q, 2); \
    MATRIX_VAR_PROC(4, 1, 6, v_m, v_p, _RES, _PROC) \
    v_p = vdupq_laneq_f32(v_q, 3); \
    MATRIX_VAR_PROC(4, 1, 7, v_m, v_p, _RES, _PROC) \
    v_q = vld1q_f32(q + 8); \
    v_p = vdupq_laneq_f32(v_q, 0); \
    MATRIX_VAR_PROC(4, 1, 8, v_m, v_p, _RES, _PROC) \
    v_p = vdupq_laneq_f32(v_q, 1); \
    MATRIX_VAR_PROC(4, 1, 9, v_m, v_p, _RES, _PROC) \
    v_p = vdupq_laneq_f32(v_q, 2); \
    MATRIX_VAR_PROC(4, 1, 10, v_m, v_p, _RES, _PROC) \
    v_p = vdupq_laneq_f32(v_q, 3); \
    MATRIX_VAR_PROC(4, 1, 11, v_m, v_p, _RES, _PROC) \
    v_q = vld1q_f32(q + 12); \
    v_p = vdupq_laneq_f32(v_q, 0); \
    MATRIX_VAR_PROC(4, 1, 12, v_m, v_p, _RES, _PROC) \
    v_p = vdupq_laneq_f32(v_q, 1); \
    MATRIX_VAR_PROC(4, 1, 13, v_m, v_p, _RES, _PROC) \
    v_p = vdupq_laneq_f32(v_q, 2); \
    MATRIX_VAR_PROC(4, 1, 14, v_m, v_p, _RES, _PROC) \
    v_p = vdupq_laneq_f32(v_q, 3); \
    MATRIX_VAR_PROC(4, 1, 15, v_m, v_p, _RES, _PROC) \
  }

//!
Iterative process of computing distance (FP32, M=32, N=1)
//! Broadcasts q[0] once, then reads `m` in two batches of four 4-lane loads
//! (m+0..m+12, then m+16..m+28). The i-th 4-float chunk uses result register
//! _RES##_i_0 (i = 0..7).
#define MATRIX_FP32_ITER_32X1_NEON(m, q, _RES, _PROC) \
  { \
    float32x4_t v_q = vld1q_dup_f32(q); \
    float32x4_t v_m_0 = vld1q_f32(m + 0); \
    float32x4_t v_m_1 = vld1q_f32(m + 4); \
    float32x4_t v_m_2 = vld1q_f32(m + 8); \
    float32x4_t v_m_3 = vld1q_f32(m + 12); \
    _PROC(v_m_0, v_q, _RES##_0_0) \
    _PROC(v_m_1, v_q, _RES##_1_0) \
    _PROC(v_m_2, v_q, _RES##_2_0) \
    _PROC(v_m_3, v_q, _RES##_3_0) \
    v_m_0 = vld1q_f32(m + 16); \
    v_m_1 = vld1q_f32(m + 20); \
    v_m_2 = vld1q_f32(m + 24); \
    v_m_3 = vld1q_f32(m + 28); \
    _PROC(v_m_0, v_q, _RES##_4_0) \
    _PROC(v_m_1, v_q, _RES##_5_0) \
    _PROC(v_m_2, v_q, _RES##_6_0) \
    _PROC(v_m_3, v_q, _RES##_7_0) \
  }

//! Iterative process of computing distance (FP32, M=32, N=2)
//! Pre-broadcasts q[0], q[1] into v_q_0/v_q_1, then reads `m` in two batches
//! of four 4-lane loads. Each chunk i expands
//! MATRIX_VAR_PROC(1, 2, i, v_m_k, v_q, ...) (defined elsewhere; it
//! token-pastes the v_q_* names).
#define MATRIX_FP32_ITER_32X2_NEON(m, q, _RES, _PROC) \
  { \
    float32x2_t v_p = vld1_f32(q); \
    float32x4_t v_q_0 = vdupq_lane_f32(v_p, 0); \
    float32x4_t v_q_1 = vdupq_lane_f32(v_p, 1); \
    float32x4_t v_m_0 = vld1q_f32(m + 0); \
    float32x4_t v_m_1 = vld1q_f32(m + 4); \
    float32x4_t v_m_2 = vld1q_f32(m + 8); \
    float32x4_t v_m_3 = vld1q_f32(m + 12); \
    MATRIX_VAR_PROC(1, 2, 0, v_m_0, v_q, _RES, _PROC) \
    MATRIX_VAR_PROC(1, 2, 1, v_m_1, v_q, _RES, _PROC) \
    MATRIX_VAR_PROC(1, 2, 2, v_m_2, v_q, _RES, _PROC) \
    MATRIX_VAR_PROC(1, 2, 3, v_m_3, v_q, _RES, _PROC) \
    v_m_0 = vld1q_f32(m + 16); \
    v_m_1 = vld1q_f32(m + 20); \
    v_m_2 = vld1q_f32(m + 24); \
    v_m_3 = vld1q_f32(m + 28); \
    MATRIX_VAR_PROC(1, 2, 4, v_m_0, v_q, _RES, _PROC) \
    MATRIX_VAR_PROC(1, 2, 5, v_m_1, v_q, _RES, _PROC) \
    MATRIX_VAR_PROC(1, 2, 6, v_m_2, v_q, _RES, _PROC) \
    MATRIX_VAR_PROC(1, 2, 7, v_m_3, v_q, _RES, _PROC) \
  }

//!
Iterative process of computing distance (FP32, M=32, N=4)
//! Pre-broadcasts q[0..3] into v_q_0..v_q_3, then reads `m` in two batches of
//! four 4-lane loads. Each chunk i expands
//! MATRIX_VAR_PROC(1, 4, i, v_m_k, v_q, ...) (defined elsewhere; it
//! token-pastes the v_q_* names).
#define MATRIX_FP32_ITER_32X4_NEON(m, q, _RES, _PROC) \
  { \
    float32x4_t v_p = vld1q_f32(q); \
    float32x4_t v_q_0 = vdupq_laneq_f32(v_p, 0); \
    float32x4_t v_q_1 = vdupq_laneq_f32(v_p, 1); \
    float32x4_t v_q_2 = vdupq_laneq_f32(v_p, 2); \
    float32x4_t v_q_3 = vdupq_laneq_f32(v_p, 3); \
    float32x4_t v_m_0 = vld1q_f32(m + 0); \
    float32x4_t v_m_1 = vld1q_f32(m + 4); \
    float32x4_t v_m_2 = vld1q_f32(m + 8); \
    float32x4_t v_m_3 = vld1q_f32(m + 12); \
    MATRIX_VAR_PROC(1, 4, 0, v_m_0, v_q, _RES, _PROC) \
    MATRIX_VAR_PROC(1, 4, 1, v_m_1, v_q, _RES, _PROC) \
    MATRIX_VAR_PROC(1, 4, 2, v_m_2, v_q, _RES, _PROC) \
    MATRIX_VAR_PROC(1, 4, 3, v_m_3, v_q, _RES, _PROC) \
    v_m_0 = vld1q_f32(m + 16); \
    v_m_1 = vld1q_f32(m + 20); \
    v_m_2 = vld1q_f32(m + 24); \
    v_m_3 = vld1q_f32(m + 28); \
    MATRIX_VAR_PROC(1, 4, 4, v_m_0, v_q, _RES, _PROC) \
    MATRIX_VAR_PROC(1, 4, 5, v_m_1, v_q, _RES, _PROC) \
    MATRIX_VAR_PROC(1, 4, 6, v_m_2, v_q, _RES, _PROC) \
    MATRIX_VAR_PROC(1, 4, 7, v_m_3, v_q, _RES, _PROC) \
  }

//!
Iterative process of computing distance (FP32, M=32, N=8)
//! Pre-broadcasts q[0..7] into v_q_0..v_q_7 (from two 4-float loads), then
//! reads `m` in two batches of four 4-lane loads. Each chunk i expands
//! MATRIX_VAR_PROC(1, 8, i, v_m_k, v_q, ...) (defined elsewhere; it
//! token-pastes the v_q_* names).
#define MATRIX_FP32_ITER_32X8_NEON(m, q, _RES, _PROC) \
  { \
    float32x4_t v_p_0 = vld1q_f32(q + 0); \
    float32x4_t v_p_1 = vld1q_f32(q + 4); \
    float32x4_t v_q_0 = vdupq_laneq_f32(v_p_0, 0); \
    float32x4_t v_q_1 = vdupq_laneq_f32(v_p_0, 1); \
    float32x4_t v_q_2 = vdupq_laneq_f32(v_p_0, 2); \
    float32x4_t v_q_3 = vdupq_laneq_f32(v_p_0, 3); \
    float32x4_t v_q_4 = vdupq_laneq_f32(v_p_1, 0); \
    float32x4_t v_q_5 = vdupq_laneq_f32(v_p_1, 1); \
    float32x4_t v_q_6 = vdupq_laneq_f32(v_p_1, 2); \
    float32x4_t v_q_7 = vdupq_laneq_f32(v_p_1, 3); \
    float32x4_t v_m_0 = vld1q_f32(m + 0); \
    float32x4_t v_m_1 = vld1q_f32(m + 4); \
    float32x4_t v_m_2 = vld1q_f32(m + 8); \
    float32x4_t v_m_3 = vld1q_f32(m + 12); \
    MATRIX_VAR_PROC(1, 8, 0, v_m_0, v_q, _RES, _PROC) \
    MATRIX_VAR_PROC(1, 8, 1, v_m_1, v_q, _RES, _PROC) \
    MATRIX_VAR_PROC(1, 8, 2, v_m_2, v_q, _RES, _PROC) \
    MATRIX_VAR_PROC(1, 8, 3, v_m_3, v_q, _RES, _PROC) \
    v_m_0 = vld1q_f32(m + 16); \
    v_m_1 = vld1q_f32(m + 20); \
    v_m_2 = vld1q_f32(m + 24); \
    v_m_3 = vld1q_f32(m + 28); \
    MATRIX_VAR_PROC(1, 8, 4, v_m_0, v_q, _RES, _PROC) \
    MATRIX_VAR_PROC(1, 8, 5, v_m_1, v_q, _RES, _PROC) \
    MATRIX_VAR_PROC(1, 8, 6, v_m_2, v_q, _RES, _PROC) \
    MATRIX_VAR_PROC(1, 8, 7, v_m_3, v_q, _RES, _PROC) \
  }

//!
Iterative process of computing distance (FP32, M=32, N=16)
//! Loads all 32 floats of `m` into v_m_0..v_m_7. q is read in 4-float chunks
//! (q+0..q+12), and each lane is broadcast into v_p for
//! MATRIX_VAR_PROC(8, 1, j, v_m, v_p, ...), j = 0..15 (defined elsewhere).
#define MATRIX_FP32_ITER_32X16_NEON(m, q, _RES, _PROC) \
  { \
    float32x4_t v_m_0 = vld1q_f32(m + 0); \
    float32x4_t v_m_1 = vld1q_f32(m + 4); \
    float32x4_t v_m_2 = vld1q_f32(m + 8); \
    float32x4_t v_m_3 = vld1q_f32(m + 12); \
    float32x4_t v_m_4 = vld1q_f32(m + 16); \
    float32x4_t v_m_5 = vld1q_f32(m + 20); \
    float32x4_t v_m_6 = vld1q_f32(m + 24); \
    float32x4_t v_m_7 = vld1q_f32(m + 28); \
    float32x4_t v_q = vld1q_f32(q + 0); \
    float32x4_t v_p = vdupq_laneq_f32(v_q, 0); \
    MATRIX_VAR_PROC(8, 1, 0, v_m, v_p, _RES, _PROC) \
    v_p = vdupq_laneq_f32(v_q, 1); \
    MATRIX_VAR_PROC(8, 1, 1, v_m, v_p, _RES, _PROC) \
    v_p = vdupq_laneq_f32(v_q, 2); \
    MATRIX_VAR_PROC(8, 1, 2, v_m, v_p, _RES, _PROC) \
    v_p = vdupq_laneq_f32(v_q, 3); \
    MATRIX_VAR_PROC(8, 1, 3, v_m, v_p, _RES, _PROC) \
    v_q = vld1q_f32(q + 4); \
    v_p = vdupq_laneq_f32(v_q, 0); \
    MATRIX_VAR_PROC(8, 1, 4, v_m, v_p, _RES, _PROC) \
    v_p = vdupq_laneq_f32(v_q, 1); \
    MATRIX_VAR_PROC(8, 1, 5, v_m, v_p, _RES, _PROC) \
    v_p = vdupq_laneq_f32(v_q, 2); \
    MATRIX_VAR_PROC(8, 1, 6, v_m, v_p, _RES, _PROC) \
    v_p = vdupq_laneq_f32(v_q, 3); \
    MATRIX_VAR_PROC(8, 1, 7, v_m, v_p, _RES, _PROC) \
    v_q = vld1q_f32(q + 8); \
    v_p = vdupq_laneq_f32(v_q, 0); \
    MATRIX_VAR_PROC(8, 1, 8, v_m, v_p, _RES, _PROC) \
    v_p = vdupq_laneq_f32(v_q, 1); \
    MATRIX_VAR_PROC(8, 1, 9, v_m, v_p, _RES, _PROC) \
    v_p = vdupq_laneq_f32(v_q, 2); \
    MATRIX_VAR_PROC(8, 1, 10, v_m, v_p, _RES, _PROC) \
    v_p = vdupq_laneq_f32(v_q, 3); \
    MATRIX_VAR_PROC(8, 1, 11, v_m, v_p, _RES, _PROC) \
    v_q = vld1q_f32(q + 12); \
    v_p = vdupq_laneq_f32(v_q, 0); \
    MATRIX_VAR_PROC(8, 1, 12, v_m, v_p, _RES, _PROC) \
    v_p = vdupq_laneq_f32(v_q, 1); \
    MATRIX_VAR_PROC(8, 1, 13, v_m, v_p, _RES, _PROC) \
    v_p = vdupq_laneq_f32(v_q, 2); \
    MATRIX_VAR_PROC(8, 1, 14, v_m, v_p, _RES, _PROC) \
    v_p = vdupq_laneq_f32(v_q, 3); \
    MATRIX_VAR_PROC(8, 1, 15, v_m, v_p, _RES, _PROC) \
  }

//!
Iterative process of computing distance (FP32, M=32, N=32)
//! Loads all 32 floats of `m` into v_m_0..v_m_7. q is read in 4-float chunks
//! (q+0..q+28), and each lane is broadcast into v_p for
//! MATRIX_VAR_PROC(8, 1, j, v_m, v_p, ...), j = 0..31 (defined elsewhere).
#define MATRIX_FP32_ITER_32X32_NEON(m, q, _RES, _PROC) \
  { \
    float32x4_t v_m_0 = vld1q_f32(m + 0); \
    float32x4_t v_m_1 = vld1q_f32(m + 4); \
    float32x4_t v_m_2 = vld1q_f32(m + 8); \
    float32x4_t v_m_3 = vld1q_f32(m + 12); \
    float32x4_t v_m_4 = vld1q_f32(m + 16); \
    float32x4_t v_m_5 = vld1q_f32(m + 20); \
    float32x4_t v_m_6 = vld1q_f32(m + 24); \
    float32x4_t v_m_7 = vld1q_f32(m + 28); \
    float32x4_t v_q = vld1q_f32(q + 0); \
    float32x4_t v_p = vdupq_laneq_f32(v_q, 0); \
    MATRIX_VAR_PROC(8, 1, 0, v_m, v_p, _RES, _PROC) \
    v_p = vdupq_laneq_f32(v_q, 1); \
    MATRIX_VAR_PROC(8, 1, 1, v_m, v_p, _RES, _PROC) \
    v_p = vdupq_laneq_f32(v_q, 2); \
    MATRIX_VAR_PROC(8, 1, 2, v_m, v_p, _RES, _PROC) \
    v_p = vdupq_laneq_f32(v_q, 3); \
    MATRIX_VAR_PROC(8, 1, 3, v_m, v_p, _RES, _PROC) \
    v_q = vld1q_f32(q + 4); \
    v_p = vdupq_laneq_f32(v_q, 0); \
    MATRIX_VAR_PROC(8, 1, 4, v_m, v_p, _RES, _PROC) \
    v_p = vdupq_laneq_f32(v_q, 1); \
    MATRIX_VAR_PROC(8, 1, 5, v_m, v_p, _RES, _PROC) \
    v_p = vdupq_laneq_f32(v_q, 2); \
    MATRIX_VAR_PROC(8, 1, 6, v_m, v_p, _RES, _PROC) \
    v_p = vdupq_laneq_f32(v_q, 3); \
    MATRIX_VAR_PROC(8, 1, 7, v_m, v_p, _RES, _PROC) \
    v_q = vld1q_f32(q + 8); \
    v_p = vdupq_laneq_f32(v_q, 0); \
    MATRIX_VAR_PROC(8, 1, 8, v_m, v_p, _RES, _PROC) \
    v_p = vdupq_laneq_f32(v_q, 1); \
    MATRIX_VAR_PROC(8, 1, 9, v_m, v_p, _RES, _PROC) \
    v_p = vdupq_laneq_f32(v_q, 2); \
    MATRIX_VAR_PROC(8, 1, 10, v_m, v_p, _RES, _PROC) \
    v_p = vdupq_laneq_f32(v_q, 3); \
    MATRIX_VAR_PROC(8, 1, 11, v_m, v_p, _RES, _PROC) \
    v_q = vld1q_f32(q + 12); \
    v_p = vdupq_laneq_f32(v_q, 0); \
    MATRIX_VAR_PROC(8, 1, 12, v_m, v_p, _RES, _PROC) \
    v_p = vdupq_laneq_f32(v_q, 1); \
    MATRIX_VAR_PROC(8, 1, 13, v_m, v_p, _RES, _PROC) \
    v_p = vdupq_laneq_f32(v_q, 2); \
    MATRIX_VAR_PROC(8, 1, 14, v_m, v_p, _RES, _PROC) \
    v_p = vdupq_laneq_f32(v_q, 3); \
    MATRIX_VAR_PROC(8, 1, 15, v_m, v_p, _RES, _PROC) \
    v_q = vld1q_f32(q + 16); \
    v_p = vdupq_laneq_f32(v_q, 0); \
    MATRIX_VAR_PROC(8, 1, 16, v_m, v_p, _RES, _PROC) \
    v_p = vdupq_laneq_f32(v_q, 1); \
    MATRIX_VAR_PROC(8, 1, 17, v_m, v_p, _RES, _PROC) \
    v_p = vdupq_laneq_f32(v_q, 2); \
    MATRIX_VAR_PROC(8, 1, 18, v_m, v_p, _RES, _PROC) \
    v_p = vdupq_laneq_f32(v_q, 3); \
    MATRIX_VAR_PROC(8, 1, 19, v_m, v_p, _RES, _PROC) \
    v_q = vld1q_f32(q + 20); \
    v_p = vdupq_laneq_f32(v_q, 0); \
    MATRIX_VAR_PROC(8, 1, 20, v_m, v_p, _RES, _PROC) \
    v_p = vdupq_laneq_f32(v_q, 1); \
    MATRIX_VAR_PROC(8, 1, 21, v_m, v_p, _RES, _PROC) \
    v_p = vdupq_laneq_f32(v_q, 2); \
    MATRIX_VAR_PROC(8, 1, 22, v_m, v_p, _RES, _PROC) \
    v_p = vdupq_laneq_f32(v_q, 3); \
    MATRIX_VAR_PROC(8, 1, 23, v_m, v_p, _RES, _PROC) \
    v_q = vld1q_f32(q + 24); \
    v_p = vdupq_laneq_f32(v_q, 0); \
    MATRIX_VAR_PROC(8, 1, 24, v_m, v_p, _RES, _PROC) \
    v_p = vdupq_laneq_f32(v_q, 1); \
    MATRIX_VAR_PROC(8, 1, 25, v_m, v_p, _RES, _PROC) \
    v_p = vdupq_laneq_f32(v_q, 2); \
    MATRIX_VAR_PROC(8, 1, 26, v_m, v_p, _RES, _PROC) \
    v_p = vdupq_laneq_f32(v_q, 3); \
    MATRIX_VAR_PROC(8, 1, 27, v_m, v_p, _RES, _PROC) \
    v_q = vld1q_f32(q + 28); \
    v_p = vdupq_laneq_f32(v_q, 0); \
    MATRIX_VAR_PROC(8, 1, 28, v_m, v_p, _RES, _PROC) \
    v_p = vdupq_laneq_f32(v_q, 1); \
    MATRIX_VAR_PROC(8, 1, 29, v_m, v_p, _RES, _PROC) \
    v_p = vdupq_laneq_f32(v_q, 2); \
    MATRIX_VAR_PROC(8, 1, 30, v_m, v_p, _RES, _PROC) \
    v_p = vdupq_laneq_f32(v_q, 3); \
    MATRIX_VAR_PROC(8, 1, 31, v_m, v_p, _RES, _PROC) \
  }

================================================
FILE: src/ailego/math/distance_matrix_inner_product_utility.i
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// The table is needed by both the SSE and the AVX lookup macros below, so it
// is declared whenever either instruction set is enabled. Some compilers
// define __AVX__ without __SSE4_1__.
#if defined(__SSE4_1__) || defined(__AVX__)
//! Four-bits Convert Table
//! Maps a 4-bit code 0..15 to its two's-complement int4 value -8..7. The
//! 16-entry pattern is stored twice so the table can be loaded either as a
//! 128-bit lookup or as a 256-bit lookup. _mm256_shuffle_epi8 indexes each
//! 128-bit lane separately, so each lane needs its own copy.
static const AILEGO_ALIGNED(32) int8_t Int4ConvertTable[32] = {
    0, 1, 2, 3, 4, 5, 6, 7, -8, -7, -6, -5, -4, -3, -2, -1,
    0, 1, 2, 3, 4, 5, 6, 7, -8, -7, -6, -5, -4, -3, -2, -1};
#endif  // __SSE4_1__ || __AVX__

#if defined(__SSE4_1__)
#define NEGZEROS_FP32_SSE _mm_set1_ps(-0.0f)
#define MASK_INT4_SSE _mm_set1_epi32(0x0f0f0f0f)
#define ONES_INT16_SSE _mm_set1_epi32(0x00010001)
#define INT4_LOOKUP_SSE _mm_load_si128((const __m128i *)Int4ConvertTable)
#endif  // __SSE4_1__

#if defined(__AVX__)
#define MASK_INT4_AVX _mm256_set1_epi32(0x0f0f0f0f)
#define ONES_INT16_AVX _mm256_set1_epi32(0x00010001)
#define INT4_LOOKUP_AVX _mm256_load_si256((const __m256i *)Int4ConvertTable)
#endif  // __AVX__

#if defined(__AVX512F__) && !defined(__AVX512DQ__)
//! _mm512_xor_ps is an AVX512DQ intrinsic. Without DQ, emulate it by doing
//! the XOR in the integer domain (AVX512F) on the reinterpreted bits.
#define _mm512_xor_ps(a, b) \
  _mm512_castsi512_ps(      \
      _mm512_xor_epi32(_mm512_castps_si512(a), _mm512_castps_si512(b)))
#endif  // __AVX512F__ && !__AVX512DQ__

//! Reverse sign of value (GENERAL)
#define NEGATE_FP32_GENERAL(v) (-(v))

// Note: every FMA_* macro below carries its own trailing semicolon, so it is
// invoked as a statement without one.

//! Calculate Fused-Multiply-Add (SSE)
#define FMA_FP32_SSE(xmm_m, xmm_q, xmm_sum) \
  xmm_sum = _mm_fmadd_ps(xmm_m, xmm_q, xmm_sum);

//! Calculate Fused-Multiply-Add (AVX)
#define FMA_FP32_AVX(ymm_m, ymm_q, ymm_sum) \
  ymm_sum = _mm256_fmadd_ps(ymm_m, ymm_q, ymm_sum);

//! Calculate Fused-Multiply-Add (AVX512)
#define FMA_FP32_AVX512(zmm_m, zmm_q, zmm_sum) \
  zmm_sum = _mm512_fmadd_ps(zmm_m, zmm_q, zmm_sum);

//! Calculate Fused-Multiply-Add (AVX512FP16)
#define FMA_FP16_AVX512FP16(zmm_m, zmm_q, zmm_sum) \
  zmm_sum = _mm512_fmadd_ph(zmm_m, zmm_q, zmm_sum);
//! Calculate Fused-Multiply-Add (GENERAL)
//! Arguments are parenthesized so compound expressions such as `a + b` are
//! multiplied as a whole. The trailing semicolon is part of the macro.
#define FMA_FP16_GENERAL(m, q, sum) (sum) += ((m) * (q));

//! Calculate Fused-Multiply-Add (GENERAL)
#define FMA_FP32_GENERAL(m, q, sum) (sum) += ((m) * (q));

//! Calculate Fused-Multiply-Add (NEON)
#define FMA_FP16_NEON(v_m, v_q, v_sum) v_sum = vfmaq_f16(v_sum, v_m, v_q);

//! Calculate Fused-Multiply-Add (NEON)
#define FMA_FP32_NEON(v_m, v_q, v_sum) v_sum = vfmaq_f32(v_sum, v_m, v_q);

//! Calculate Fused-Multiply-Add (GENERAL)
//! `m` and `q` each pack two 4-bit codes per byte. The first lookup pairs the
//! low nibbles and the second pairs the high nibbles; each table index is
//! (m_nibble << 4) | q_nibble. `Int4MulTable` is not defined in this section
//! and must be visible wherever the macro is expanded.
//! The low nibble of `m` is masked before shifting. This means a negative
//! (signed) `m` is never left-shifted, which is undefined before C++20.
#define FMA_INT4_GENERAL(m, q, sum)                             \
  (sum) += Int4MulTable[(((m) & 0x0f) << 4) | ((q) & 0x0f)] +   \
           Int4MulTable[((m) & 0xf0) | (((q) >> 4) & 0x0f)];

//! Calculate Fused-Multiply-Add (GENERAL)
//! NOTE(review): the cast's target type was lost in the extracted copy of this
//! file, so `static_cast(m * q)` did not compile. It is restored as `float` on
//! the assumption that int8 inner products accumulate into float. Confirm.
#define FMA_INT8_GENERAL(m, q, sum) (sum) += static_cast<float>((m) * (q));

//! Calculate Fused-Multiply-Add (SSE)
//! _mm_sign_epi8 moves the sign of `xmm_q` onto `xmm_m`, so _mm_maddubs_epi16
//! can use |q| as its unsigned operand. The int16 pair sums are then widened
//! to int32 by _mm_madd_epi16 against ONES_INT16_SSE.
//! NOTE: a -128 byte in `xmm_m` paired with a negative `xmm_q` byte wraps in
//! _mm_sign_epi8, so that single product gets the wrong sign.
#define FMA_INT8_SSE(xmm_m, xmm_q, xmm_sum)                                   \
  xmm_sum = _mm_add_epi32(                                                    \
      _mm_madd_epi16(                                                         \
          _mm_maddubs_epi16(_mm_abs_epi8(xmm_q), _mm_sign_epi8(xmm_m, xmm_q)), \
          ONES_INT16_SSE),                                                    \
      xmm_sum);

//! Calculate Fused-Multiply-Add (AVX)
//! 256-bit version of FMA_INT8_SSE (same -128 caveat).
#define FMA_INT8_AVX(ymm_m, ymm_q, ymm_sum)                                   \
  ymm_sum = _mm256_add_epi32(                                                 \
      _mm256_madd_epi16(_mm256_maddubs_epi16(_mm256_abs_epi8(ymm_q),          \
                                             _mm256_sign_epi8(ymm_m, ymm_q)), \
                        ONES_INT16_AVX),                                      \
      ymm_sum);
//! Calculate Fused-Multiply-Add (SSE)
//! Every byte of `xmm_m` / `xmm_q` holds two 4-bit codes. The low nibbles are
//! decoded to int8 through INT4_LOOKUP_SSE, multiplied and accumulated into
//! the int32 lanes of `xmm_sum`. The high nibbles are then handled the same way.
#define FMA_INT4_SSE(xmm_m, xmm_q, xmm_sum)                                   \
  {                                                                           \
    /* low nibbles */                                                         \
    __m128i xmm_lhs = _mm_and_si128((xmm_m), MASK_INT4_SSE);                  \
    __m128i xmm_rhs = _mm_and_si128((xmm_q), MASK_INT4_SSE);                  \
    xmm_lhs = _mm_shuffle_epi8(INT4_LOOKUP_SSE, xmm_lhs);                     \
    xmm_rhs = _mm_shuffle_epi8(INT4_LOOKUP_SSE, xmm_rhs);                     \
    xmm_lhs = _mm_sign_epi8(xmm_lhs, xmm_rhs);                                \
    xmm_rhs = _mm_abs_epi8(xmm_rhs);                                          \
    xmm_sum = _mm_add_epi32(                                                  \
        _mm_madd_epi16(_mm_maddubs_epi16(xmm_rhs, xmm_lhs), ONES_INT16_SSE),  \
        xmm_sum);                                                             \
    /* high nibbles */                                                        \
    xmm_lhs = _mm_and_si128(_mm_srli_epi32((xmm_m), 4), MASK_INT4_SSE);       \
    xmm_rhs = _mm_and_si128(_mm_srli_epi32((xmm_q), 4), MASK_INT4_SSE);       \
    xmm_lhs = _mm_shuffle_epi8(INT4_LOOKUP_SSE, xmm_lhs);                     \
    xmm_rhs = _mm_shuffle_epi8(INT4_LOOKUP_SSE, xmm_rhs);                     \
    xmm_lhs = _mm_sign_epi8(xmm_lhs, xmm_rhs);                                \
    xmm_rhs = _mm_abs_epi8(xmm_rhs);                                          \
    xmm_sum = _mm_add_epi32(                                                  \
        _mm_madd_epi16(_mm_maddubs_epi16(xmm_rhs, xmm_lhs), ONES_INT16_SSE),  \
        xmm_sum);                                                             \
  }

//! Calculate Fused-Multiply-Add (AVX)
//! 256-bit counterpart of FMA_INT4_SSE. The lookup and nibble masks come from
//! INT4_LOOKUP_AVX / MASK_INT4_AVX.
#define FMA_INT4_AVX(ymm_m, ymm_q, ymm_sum)                                   \
  {                                                                           \
    /* low nibbles */                                                         \
    __m256i ymm_lhs = _mm256_and_si256((ymm_m), MASK_INT4_AVX);               \
    __m256i ymm_rhs = _mm256_and_si256((ymm_q), MASK_INT4_AVX);               \
    ymm_lhs = _mm256_shuffle_epi8(INT4_LOOKUP_AVX, ymm_lhs);                  \
    ymm_rhs = _mm256_shuffle_epi8(INT4_LOOKUP_AVX, ymm_rhs);                  \
    ymm_lhs = _mm256_sign_epi8(ymm_lhs, ymm_rhs);                             \
    ymm_rhs = _mm256_abs_epi8(ymm_rhs);                                       \
    ymm_sum = _mm256_add_epi32(                                               \
        _mm256_madd_epi16(_mm256_maddubs_epi16(ymm_rhs, ymm_lhs),             \
                          ONES_INT16_AVX),                                    \
        ymm_sum);                                                             \
    /* high nibbles */                                                        \
    ymm_lhs = _mm256_and_si256(_mm256_srli_epi32((ymm_m), 4), MASK_INT4_AVX); \
    ymm_rhs = _mm256_and_si256(_mm256_srli_epi32((ymm_q), 4), MASK_INT4_AVX); \
    ymm_lhs = _mm256_shuffle_epi8(INT4_LOOKUP_AVX, ymm_lhs);                  \
    ymm_rhs = _mm256_shuffle_epi8(INT4_LOOKUP_AVX, ymm_rhs);                  \
    ymm_lhs = _mm256_sign_epi8(ymm_lhs, ymm_rhs);                             \
    ymm_rhs = _mm256_abs_epi8(ymm_rhs);                                       \
    ymm_sum = _mm256_add_epi32(                                               \
        _mm256_madd_epi16(_mm256_maddubs_epi16(ymm_rhs, ymm_lhs),             \
                          ONES_INT16_AVX),                                    \
        ymm_sum);                                                             \
  }
Compute the distance between matrix and query #define FMA_INT4_ITER_SSE(xmm_lhs, xmm_rhs, xmm_sum) \ { \ __m128i xmm_lhs_0 = _mm_shuffle_epi8( \ INT4_LOOKUP_SSE, _mm_and_si128((xmm_lhs), MASK_INT4_SSE)); \ __m128i xmm_rhs_0 = _mm_shuffle_epi8( \ INT4_LOOKUP_SSE, _mm_and_si128((xmm_rhs), MASK_INT4_SSE)); \ __m128i xmm_lhs_1 = _mm_shuffle_epi8( \ INT4_LOOKUP_SSE, \ _mm_and_si128(_mm_srli_epi32((xmm_lhs), 4), MASK_INT4_SSE)); \ __m128i xmm_rhs_1 = _mm_shuffle_epi8( \ INT4_LOOKUP_SSE, \ _mm_and_si128(_mm_srli_epi32((xmm_rhs), 4), MASK_INT4_SSE)); \ xmm_lhs_0 = _mm_sign_epi8(xmm_lhs_0, xmm_rhs_0); \ xmm_lhs_1 = _mm_sign_epi8(xmm_lhs_1, xmm_rhs_1); \ xmm_rhs_0 = _mm_abs_epi8(xmm_rhs_0); \ xmm_rhs_1 = _mm_abs_epi8(xmm_rhs_1); \ xmm_lhs_0 = _mm_madd_epi16(_mm_maddubs_epi16(xmm_rhs_0, xmm_lhs_0), \ ONES_INT16_SSE); \ xmm_lhs_1 = _mm_madd_epi16(_mm_maddubs_epi16(xmm_rhs_1, xmm_lhs_1), \ ONES_INT16_SSE); \ xmm_sum = _mm_add_epi32(_mm_add_epi32(xmm_lhs_0, xmm_lhs_1), xmm_sum); \ } //! Compute the distance between matrix and query #define FMA_INT4_ITER_AVX(ymm_lhs, ymm_rhs, ymm_sum) \ { \ __m256i ymm_lhs_0 = _mm256_shuffle_epi8( \ INT4_LOOKUP_AVX, _mm256_and_si256((ymm_lhs), MASK_INT4_AVX)); \ __m256i ymm_rhs_0 = _mm256_shuffle_epi8( \ INT4_LOOKUP_AVX, _mm256_and_si256((ymm_rhs), MASK_INT4_AVX)); \ __m256i ymm_lhs_1 = _mm256_shuffle_epi8( \ INT4_LOOKUP_AVX, \ _mm256_and_si256(_mm256_srli_epi32((ymm_lhs), 4), MASK_INT4_AVX)); \ __m256i ymm_rhs_1 = _mm256_shuffle_epi8( \ INT4_LOOKUP_AVX, \ _mm256_and_si256(_mm256_srli_epi32((ymm_rhs), 4), MASK_INT4_AVX)); \ ymm_lhs_0 = _mm256_sign_epi8(ymm_lhs_0, ymm_rhs_0); \ ymm_lhs_1 = _mm256_sign_epi8(ymm_lhs_1, ymm_rhs_1); \ ymm_rhs_0 = _mm256_abs_epi8(ymm_rhs_0); \ ymm_rhs_1 = _mm256_abs_epi8(ymm_rhs_1); \ ymm_lhs_0 = _mm256_madd_epi16(_mm256_maddubs_epi16(ymm_rhs_0, ymm_lhs_0), \ ONES_INT16_AVX); \ ymm_lhs_1 = _mm256_madd_epi16(_mm256_maddubs_epi16(ymm_rhs_1, ymm_lhs_1), \ ONES_INT16_AVX); \ ymm_sum = \ 
_mm256_add_epi32(_mm256_add_epi32(ymm_lhs_0, ymm_lhs_1), ymm_sum); \ } #define ACCUM_FP16_STEP_GENERAL FMA_FP16_GENERAL #define ACCUM_FP16_STEP_NEON FMA_FP16_NEON #define ACCUM_FP32_STEP_SSE FMA_FP32_SSE #define ACCUM_FP32_STEP_AVX FMA_FP32_AVX #define ACCUM_FP32_STEP_AVX512 FMA_FP32_AVX512 #define ACCUM_FP32_STEP_NEON FMA_FP32_NEON #define ACCUM_INT4_STEP_SSE FMA_INT4_SSE #define ACCUM_INT4_STEP_AVX FMA_INT4_AVX #define ACCUM_INT8_STEP_SSE FMA_INT8_SSE #define ACCUM_INT8_STEP_AVX FMA_INT8_AVX ================================================ FILE: src/ailego/math/distance_matrix_int32.i ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include "matrix_define.i" #if defined(__AVX__) && defined(__GNUC__) #define _mm256_set_m128i(a, b) \ _mm256_inserti128_si256(_mm256_castsi128_si256(b), (a), 1) #endif // __AVX__ #if !defined(__AVX__) #define _mm_broadcast_si32(a) _mm_castps_si128(_mm_load1_ps((const float *)(a))) #else #define _mm_broadcast_si32(a) \ _mm_castps_si128(_mm_broadcast_ss((const float *)(a))) #define _mm256_broadcast_si32(a) \ _mm256_castps_si256(_mm256_broadcast_ss((const float *)(a))) #endif // !__AVX__ //! 
//! Iterative process of computing distance (INT32, M=2, N=1)
//! Reads 8 int32 values from `mi` and 4 from `qi`.
#define MATRIX_INT32_ITER_2X1_SSE(mi, qi, _RES, _LOAD, _PROC)                 \
  {                                                                           \
    __m128i xmm_mi = _LOAD((const __m128i *)(mi));                            \
    __m128i xmm_qi = _LOAD((const __m128i *)(qi));                            \
    /* lanes {qi[0], qi[0], qi[1], qi[1]} against mi[0..3] */                 \
    __m128i xmm_pi = _mm_shuffle_epi32(xmm_qi, _MM_SHUFFLE(1, 1, 0, 0));      \
    _PROC(xmm_mi, xmm_pi, _RES##_0_0)                                         \
    /* lanes {qi[2], qi[2], qi[3], qi[3]} against mi[4..7] */                 \
    xmm_pi = _mm_shuffle_epi32(xmm_qi, _MM_SHUFFLE(3, 3, 2, 2));              \
    xmm_mi = _LOAD((const __m128i *)(mi + 4));                                \
    _PROC(xmm_mi, xmm_pi, _RES##_0_1)                                         \
  }

//! Iterative process of computing distance (INT32, M=2, N=2)
//! Reads 4 int32 values from `mi` and 4 from `qi`.
#define MATRIX_INT32_ITER_2X2_SSE(mi, qi, _RES, _LOAD, _PROC)                 \
  {                                                                           \
    __m128i xmm_mi = _LOAD((const __m128i *)(mi));                            \
    __m128i xmm_qi = _LOAD((const __m128i *)(qi));                            \
    /* lanes {qi[0], qi[0], qi[2], qi[2]} */                                  \
    __m128i xmm_pi = _mm_shuffle_epi32(xmm_qi, _MM_SHUFFLE(2, 2, 0, 0));      \
    _PROC(xmm_mi, xmm_pi, _RES##_0_0)                                         \
    /* lanes {qi[1], qi[1], qi[3], qi[3]} */                                  \
    xmm_pi = _mm_shuffle_epi32(xmm_qi, _MM_SHUFFLE(3, 3, 1, 1));              \
    _PROC(xmm_mi, xmm_pi, _RES##_0_1)                                         \
  }

//! Iterative process of computing distance (INT32, M=4, N=1)
//! Reads 8 int32 values from `mi` and 2 from `qi`.
#define MATRIX_INT32_ITER_4X1_SSE(mi, qi, _RES, _LOAD, _PROC)                 \
  {                                                                           \
    __m128i xmm_mi_0 = _LOAD((const __m128i *)(mi + 0));                      \
    __m128i xmm_mi_1 = _LOAD((const __m128i *)(mi + 4));                      \
    __m128i xmm_qi;                                                           \
    xmm_qi = _mm_broadcast_si32(qi + 0);                                      \
    _PROC(xmm_mi_0, xmm_qi, _RES##_0_0)                                       \
    xmm_qi = _mm_broadcast_si32(qi + 1);                                      \
    _PROC(xmm_mi_1, xmm_qi, _RES##_1_0)                                       \
  }

//! Iterative process of computing distance (INT32, M=4, N=2)
//! Reads 4 int32 values from `mi` and 2 from `qi`.
#define MATRIX_INT32_ITER_4X2_SSE(mi, qi, _RES, _LOAD, _PROC)                 \
  {                                                                           \
    __m128i xmm_mi = _LOAD((const __m128i *)(mi));                            \
    __m128i xmm_qi_0 = _mm_broadcast_si32(qi + 0);                            \
    __m128i xmm_qi_1 = _mm_broadcast_si32(qi + 1);                            \
    MATRIX_VAR_PROC(1, 2, 0, xmm_mi, xmm_qi, _RES, _PROC)                     \
  }
//! Iterative process of computing distance (INT32, M=4, N=4)
//! Reads 4 int32 values from `mi` and 4 from `qi`.
#define MATRIX_INT32_ITER_4X4_SSE(mi, qi, _RES, _LOAD, _PROC)                 \
  {                                                                           \
    __m128i xmm_qi;                                                           \
    __m128i xmm_mi = _LOAD((const __m128i *)(mi));                            \
    xmm_qi = _mm_broadcast_si32(qi + 0);                                      \
    _PROC(xmm_mi, xmm_qi, _RES##_0_0)                                         \
    xmm_qi = _mm_broadcast_si32(qi + 1);                                      \
    _PROC(xmm_mi, xmm_qi, _RES##_0_1)                                         \
    xmm_qi = _mm_broadcast_si32(qi + 2);                                      \
    _PROC(xmm_mi, xmm_qi, _RES##_0_2)                                         \
    xmm_qi = _mm_broadcast_si32(qi + 3);                                      \
    _PROC(xmm_mi, xmm_qi, _RES##_0_3)                                         \
  }

//! Iterative process of computing distance (INT32, M=8, N=1)
//! Reads 8 int32 values from `mi` and 1 from `qi`.
#define MATRIX_INT32_ITER_8X1_SSE(mi, qi, _RES, _LOAD, _PROC)                 \
  {                                                                           \
    __m128i xmm_qi = _mm_broadcast_si32(qi);                                  \
    __m128i xmm_mi_0 = _LOAD((const __m128i *)(mi + 0));                      \
    __m128i xmm_mi_1 = _LOAD((const __m128i *)(mi + 4));                      \
    MATRIX_VAR_PROC(2, 1, 0, xmm_mi, xmm_qi, _RES, _PROC)                     \
  }

//! Iterative process of computing distance (INT32, M=8, N=2)
//! Reads 8 int32 values from `mi` and 2 from `qi`.
#define MATRIX_INT32_ITER_8X2_SSE(mi, qi, _RES, _LOAD, _PROC)                 \
  {                                                                           \
    __m128i xmm_qi_0 = _mm_broadcast_si32(qi + 0);                            \
    __m128i xmm_qi_1 = _mm_broadcast_si32(qi + 1);                            \
    __m128i xmm_mi_0 = _LOAD((const __m128i *)(mi + 0));                      \
    __m128i xmm_mi_1 = _LOAD((const __m128i *)(mi + 4));                      \
    MATRIX_VAR_PROC(1, 2, 0, xmm_mi_0, xmm_qi, _RES, _PROC)                   \
    MATRIX_VAR_PROC(1, 2, 1, xmm_mi_1, xmm_qi, _RES, _PROC)                   \
  }

//! Iterative process of computing distance (INT32, M=8, N=4)
//! Reads 8 int32 values from `mi` and 4 from `qi`.
#define MATRIX_INT32_ITER_8X4_SSE(mi, qi, _RES, _LOAD, _PROC)                 \
  {                                                                           \
    __m128i xmm_qi;                                                           \
    __m128i xmm_mi_0 = _LOAD((const __m128i *)(mi + 0));                      \
    __m128i xmm_mi_1 = _LOAD((const __m128i *)(mi + 4));                      \
    xmm_qi = _mm_broadcast_si32(qi + 0);                                      \
    MATRIX_VAR_PROC(2, 1, 0, xmm_mi, xmm_qi, _RES, _PROC)                     \
    xmm_qi = _mm_broadcast_si32(qi + 1);                                      \
    MATRIX_VAR_PROC(2, 1, 1, xmm_mi, xmm_qi, _RES, _PROC)                     \
    xmm_qi = _mm_broadcast_si32(qi + 2);                                      \
    MATRIX_VAR_PROC(2, 1, 2, xmm_mi, xmm_qi, _RES, _PROC)                     \
    xmm_qi = _mm_broadcast_si32(qi + 3);                                      \
    MATRIX_VAR_PROC(2, 1, 3, xmm_mi, xmm_qi, _RES, _PROC)                     \
  }
//! Iterative process of computing distance (INT32, M=8, N=8)
//! Reads 8 int32 values from `mi` and 8 from `qi`.
#define MATRIX_INT32_ITER_8X8_SSE(mi, qi, _RES, _LOAD, _PROC)                 \
  {                                                                           \
    __m128i xmm_qi;                                                           \
    __m128i xmm_mi_0 = _LOAD((const __m128i *)(mi + 0));                      \
    __m128i xmm_mi_1 = _LOAD((const __m128i *)(mi + 4));                      \
    /* qi[0..3] */                                                            \
    xmm_qi = _mm_broadcast_si32(qi + 0);                                      \
    MATRIX_VAR_PROC(2, 1, 0, xmm_mi, xmm_qi, _RES, _PROC)                     \
    xmm_qi = _mm_broadcast_si32(qi + 1);                                      \
    MATRIX_VAR_PROC(2, 1, 1, xmm_mi, xmm_qi, _RES, _PROC)                     \
    xmm_qi = _mm_broadcast_si32(qi + 2);                                      \
    MATRIX_VAR_PROC(2, 1, 2, xmm_mi, xmm_qi, _RES, _PROC)                     \
    xmm_qi = _mm_broadcast_si32(qi + 3);                                      \
    MATRIX_VAR_PROC(2, 1, 3, xmm_mi, xmm_qi, _RES, _PROC)                     \
    /* qi[4..7] */                                                            \
    xmm_qi = _mm_broadcast_si32(qi + 4);                                      \
    MATRIX_VAR_PROC(2, 1, 4, xmm_mi, xmm_qi, _RES, _PROC)                     \
    xmm_qi = _mm_broadcast_si32(qi + 5);                                      \
    MATRIX_VAR_PROC(2, 1, 5, xmm_mi, xmm_qi, _RES, _PROC)                     \
    xmm_qi = _mm_broadcast_si32(qi + 6);                                      \
    MATRIX_VAR_PROC(2, 1, 6, xmm_mi, xmm_qi, _RES, _PROC)                     \
    xmm_qi = _mm_broadcast_si32(qi + 7);                                      \
    MATRIX_VAR_PROC(2, 1, 7, xmm_mi, xmm_qi, _RES, _PROC)                     \
  }

//! Iterative process of computing distance (INT32, M=16, N=1)
//! Reads 16 int32 values from `mi` and 1 from `qi`.
#define MATRIX_INT32_ITER_16X1_SSE(mi, qi, _RES, _LOAD, _PROC)                \
  {                                                                           \
    __m128i xmm_qi = _mm_broadcast_si32(qi);                                  \
    __m128i xmm_mi_0 = _LOAD((const __m128i *)(mi + 0));                      \
    __m128i xmm_mi_1 = _LOAD((const __m128i *)(mi + 4));                      \
    __m128i xmm_mi_2 = _LOAD((const __m128i *)(mi + 8));                      \
    __m128i xmm_mi_3 = _LOAD((const __m128i *)(mi + 12));                     \
    MATRIX_VAR_PROC(4, 1, 0, xmm_mi, xmm_qi, _RES, _PROC)                     \
  }
//! Iterative process of computing distance (INT32, M=16, N=2)
//! Reads 16 int32 values from `mi` and 2 from `qi`.
#define MATRIX_INT32_ITER_16X2_SSE(mi, qi, _RES, _LOAD, _PROC)                \
  {                                                                           \
    __m128i xmm_qi;                                                           \
    __m128i xmm_mi_0 = _LOAD((const __m128i *)(mi + 0));                      \
    __m128i xmm_mi_1 = _LOAD((const __m128i *)(mi + 4));                      \
    __m128i xmm_mi_2 = _LOAD((const __m128i *)(mi + 8));                      \
    __m128i xmm_mi_3 = _LOAD((const __m128i *)(mi + 12));                     \
    xmm_qi = _mm_broadcast_si32(qi + 0);                                      \
    MATRIX_VAR_PROC(4, 1, 0, xmm_mi, xmm_qi, _RES, _PROC)                     \
    xmm_qi = _mm_broadcast_si32(qi + 1);                                      \
    MATRIX_VAR_PROC(4, 1, 1, xmm_mi, xmm_qi, _RES, _PROC)                     \
  }

//! Iterative process of computing distance (INT32, M=16, N=4)
//! Reads 16 int32 values from `mi` and 4 from `qi`.
#define MATRIX_INT32_ITER_16X4_SSE(mi, qi, _RES, _LOAD, _PROC)                \
  {                                                                           \
    __m128i xmm_qi;                                                           \
    __m128i xmm_mi_0 = _LOAD((const __m128i *)(mi + 0));                      \
    __m128i xmm_mi_1 = _LOAD((const __m128i *)(mi + 4));                      \
    __m128i xmm_mi_2 = _LOAD((const __m128i *)(mi + 8));                      \
    __m128i xmm_mi_3 = _LOAD((const __m128i *)(mi + 12));                     \
    xmm_qi = _mm_broadcast_si32(qi + 0);                                      \
    MATRIX_VAR_PROC(4, 1, 0, xmm_mi, xmm_qi, _RES, _PROC)                     \
    xmm_qi = _mm_broadcast_si32(qi + 1);                                      \
    MATRIX_VAR_PROC(4, 1, 1, xmm_mi, xmm_qi, _RES, _PROC)                     \
    xmm_qi = _mm_broadcast_si32(qi + 2);                                      \
    MATRIX_VAR_PROC(4, 1, 2, xmm_mi, xmm_qi, _RES, _PROC)                     \
    xmm_qi = _mm_broadcast_si32(qi + 3);                                      \
    MATRIX_VAR_PROC(4, 1, 3, xmm_mi, xmm_qi, _RES, _PROC)                     \
  }
//! Iterative process of computing distance (INT32, M=16, N=8)
//! Reads 16 int32 values from `mi` and 8 from `qi`.
#define MATRIX_INT32_ITER_16X8_SSE(mi, qi, _RES, _LOAD, _PROC)                \
  {                                                                           \
    __m128i xmm_qi;                                                           \
    __m128i xmm_mi_0 = _LOAD((const __m128i *)(mi + 0));                      \
    __m128i xmm_mi_1 = _LOAD((const __m128i *)(mi + 4));                      \
    __m128i xmm_mi_2 = _LOAD((const __m128i *)(mi + 8));                      \
    __m128i xmm_mi_3 = _LOAD((const __m128i *)(mi + 12));                     \
    /* qi[0..3] */                                                            \
    xmm_qi = _mm_broadcast_si32(qi + 0);                                      \
    MATRIX_VAR_PROC(4, 1, 0, xmm_mi, xmm_qi, _RES, _PROC)                     \
    xmm_qi = _mm_broadcast_si32(qi + 1);                                      \
    MATRIX_VAR_PROC(4, 1, 1, xmm_mi, xmm_qi, _RES, _PROC)                     \
    xmm_qi = _mm_broadcast_si32(qi + 2);                                      \
    MATRIX_VAR_PROC(4, 1, 2, xmm_mi, xmm_qi, _RES, _PROC)                     \
    xmm_qi = _mm_broadcast_si32(qi + 3);                                      \
    MATRIX_VAR_PROC(4, 1, 3, xmm_mi, xmm_qi, _RES, _PROC)                     \
    /* qi[4..7] */                                                            \
    xmm_qi = _mm_broadcast_si32(qi + 4);                                      \
    MATRIX_VAR_PROC(4, 1, 4, xmm_mi, xmm_qi, _RES, _PROC)                     \
    xmm_qi = _mm_broadcast_si32(qi + 5);                                      \
    MATRIX_VAR_PROC(4, 1, 5, xmm_mi, xmm_qi, _RES, _PROC)                     \
    xmm_qi = _mm_broadcast_si32(qi + 6);                                      \
    MATRIX_VAR_PROC(4, 1, 6, xmm_mi, xmm_qi, _RES, _PROC)                     \
    xmm_qi = _mm_broadcast_si32(qi + 7);                                      \
    MATRIX_VAR_PROC(4, 1, 7, xmm_mi, xmm_qi, _RES, _PROC)                     \
  }
//! Iterative process of computing distance (INT32, M=16, N=16)
//! Reads 16 int32 values from `mi` and 16 from `qi`.
#define MATRIX_INT32_ITER_16X16_SSE(mi, qi, _RES, _LOAD, _PROC)               \
  {                                                                           \
    __m128i xmm_qi;                                                           \
    __m128i xmm_mi_0 = _LOAD((const __m128i *)(mi + 0));                      \
    __m128i xmm_mi_1 = _LOAD((const __m128i *)(mi + 4));                      \
    __m128i xmm_mi_2 = _LOAD((const __m128i *)(mi + 8));                      \
    __m128i xmm_mi_3 = _LOAD((const __m128i *)(mi + 12));                     \
    /* qi[0..3] */                                                            \
    xmm_qi = _mm_broadcast_si32(qi + 0);                                      \
    MATRIX_VAR_PROC(4, 1, 0, xmm_mi, xmm_qi, _RES, _PROC)                     \
    xmm_qi = _mm_broadcast_si32(qi + 1);                                      \
    MATRIX_VAR_PROC(4, 1, 1, xmm_mi, xmm_qi, _RES, _PROC)                     \
    xmm_qi = _mm_broadcast_si32(qi + 2);                                      \
    MATRIX_VAR_PROC(4, 1, 2, xmm_mi, xmm_qi, _RES, _PROC)                     \
    xmm_qi = _mm_broadcast_si32(qi + 3);                                      \
    MATRIX_VAR_PROC(4, 1, 3, xmm_mi, xmm_qi, _RES, _PROC)                     \
    /* qi[4..7] */                                                            \
    xmm_qi = _mm_broadcast_si32(qi + 4);                                      \
    MATRIX_VAR_PROC(4, 1, 4, xmm_mi, xmm_qi, _RES, _PROC)                     \
    xmm_qi = _mm_broadcast_si32(qi + 5);                                      \
    MATRIX_VAR_PROC(4, 1, 5, xmm_mi, xmm_qi, _RES, _PROC)                     \
    xmm_qi = _mm_broadcast_si32(qi + 6);                                      \
    MATRIX_VAR_PROC(4, 1, 6, xmm_mi, xmm_qi, _RES, _PROC)                     \
    xmm_qi = _mm_broadcast_si32(qi + 7);                                      \
    MATRIX_VAR_PROC(4, 1, 7, xmm_mi, xmm_qi, _RES, _PROC)                     \
    /* qi[8..11] */                                                           \
    xmm_qi = _mm_broadcast_si32(qi + 8);                                      \
    MATRIX_VAR_PROC(4, 1, 8, xmm_mi, xmm_qi, _RES, _PROC)                     \
    xmm_qi = _mm_broadcast_si32(qi + 9);                                      \
    MATRIX_VAR_PROC(4, 1, 9, xmm_mi, xmm_qi, _RES, _PROC)                     \
    xmm_qi = _mm_broadcast_si32(qi + 10);                                     \
    MATRIX_VAR_PROC(4, 1, 10, xmm_mi, xmm_qi, _RES, _PROC)                    \
    xmm_qi = _mm_broadcast_si32(qi + 11);                                     \
    MATRIX_VAR_PROC(4, 1, 11, xmm_mi, xmm_qi, _RES, _PROC)                    \
    /* qi[12..15] */                                                          \
    xmm_qi = _mm_broadcast_si32(qi + 12);                                     \
    MATRIX_VAR_PROC(4, 1, 12, xmm_mi, xmm_qi, _RES, _PROC)                    \
    xmm_qi = _mm_broadcast_si32(qi + 13);                                     \
    MATRIX_VAR_PROC(4, 1, 13, xmm_mi, xmm_qi, _RES, _PROC)                    \
    xmm_qi = _mm_broadcast_si32(qi + 14);                                     \
    MATRIX_VAR_PROC(4, 1, 14, xmm_mi, xmm_qi, _RES, _PROC)                    \
    xmm_qi = _mm_broadcast_si32(qi + 15);                                     \
    MATRIX_VAR_PROC(4, 1, 15, xmm_mi, xmm_qi, _RES, _PROC)                    \
  }
//! Iterative process of computing distance (INT32, M=32, N=1)
//! Reads 32 int32 values from `mi` and 1 from `qi`.
#define MATRIX_INT32_ITER_32X1_SSE(mi, qi, _RES, _LOAD, _PROC)                \
  {                                                                           \
    __m128i xmm_mi_0 = _LOAD((const __m128i *)(mi + 0));                      \
    __m128i xmm_mi_1 = _LOAD((const __m128i *)(mi + 4));                      \
    __m128i xmm_mi_2 = _LOAD((const __m128i *)(mi + 8));                      \
    __m128i xmm_mi_3 = _LOAD((const __m128i *)(mi + 12));                     \
    __m128i xmm_mi_4 = _LOAD((const __m128i *)(mi + 16));                     \
    __m128i xmm_mi_5 = _LOAD((const __m128i *)(mi + 20));                     \
    __m128i xmm_mi_6 = _LOAD((const __m128i *)(mi + 24));                     \
    __m128i xmm_mi_7 = _LOAD((const __m128i *)(mi + 28));                     \
    __m128i xmm_qi = _mm_broadcast_si32(qi);                                  \
    _PROC(xmm_mi_0, xmm_qi, _RES##_0_0)                                       \
    _PROC(xmm_mi_1, xmm_qi, _RES##_1_0)                                       \
    _PROC(xmm_mi_2, xmm_qi, _RES##_2_0)                                       \
    _PROC(xmm_mi_3, xmm_qi, _RES##_3_0)                                       \
    _PROC(xmm_mi_4, xmm_qi, _RES##_4_0)                                       \
    _PROC(xmm_mi_5, xmm_qi, _RES##_5_0)                                       \
    _PROC(xmm_mi_6, xmm_qi, _RES##_6_0)                                       \
    _PROC(xmm_mi_7, xmm_qi, _RES##_7_0)                                       \
  }

//! Iterative process of computing distance (INT32, M=32, N=2)
//! Reads 32 int32 values from `mi` and 2 from `qi`.
#define MATRIX_INT32_ITER_32X2_SSE(mi, qi, _RES, _LOAD, _PROC)                \
  {                                                                           \
    __m128i xmm_mi_0 = _LOAD((const __m128i *)(mi + 0));                      \
    __m128i xmm_mi_1 = _LOAD((const __m128i *)(mi + 4));                      \
    __m128i xmm_mi_2 = _LOAD((const __m128i *)(mi + 8));                      \
    __m128i xmm_mi_3 = _LOAD((const __m128i *)(mi + 12));                     \
    __m128i xmm_mi_4 = _LOAD((const __m128i *)(mi + 16));                     \
    __m128i xmm_mi_5 = _LOAD((const __m128i *)(mi + 20));                     \
    __m128i xmm_mi_6 = _LOAD((const __m128i *)(mi + 24));                     \
    __m128i xmm_mi_7 = _LOAD((const __m128i *)(mi + 28));                     \
    __m128i xmm_qi_0 = _mm_broadcast_si32(qi + 0);                            \
    __m128i xmm_qi_1 = _mm_broadcast_si32(qi + 1);                            \
    MATRIX_VAR_PROC(1, 2, 0, xmm_mi_0, xmm_qi, _RES, _PROC)                   \
    MATRIX_VAR_PROC(1, 2, 1, xmm_mi_1, xmm_qi, _RES, _PROC)                   \
    MATRIX_VAR_PROC(1, 2, 2, xmm_mi_2, xmm_qi, _RES, _PROC)                   \
    MATRIX_VAR_PROC(1, 2, 3, xmm_mi_3, xmm_qi, _RES, _PROC)                   \
    MATRIX_VAR_PROC(1, 2, 4, xmm_mi_4, xmm_qi, _RES, _PROC)                   \
    MATRIX_VAR_PROC(1, 2, 5, xmm_mi_5, xmm_qi, _RES, _PROC)                   \
    MATRIX_VAR_PROC(1, 2, 6, xmm_mi_6, xmm_qi, _RES, _PROC)                   \
    MATRIX_VAR_PROC(1, 2, 7, xmm_mi_7, xmm_qi, _RES, _PROC)                   \
  }
//! Iterative process of computing distance (INT32, M=32, N=4)
//! Reads 32 int32 values from `mi` and 4 from `qi`.
#define MATRIX_INT32_ITER_32X4_SSE(mi, qi, _RES, _LOAD, _PROC)                \
  {                                                                           \
    __m128i xmm_mi_0 = _LOAD((const __m128i *)(mi + 0));                      \
    __m128i xmm_mi_1 = _LOAD((const __m128i *)(mi + 4));                      \
    __m128i xmm_mi_2 = _LOAD((const __m128i *)(mi + 8));                      \
    __m128i xmm_mi_3 = _LOAD((const __m128i *)(mi + 12));                     \
    __m128i xmm_mi_4 = _LOAD((const __m128i *)(mi + 16));                     \
    __m128i xmm_mi_5 = _LOAD((const __m128i *)(mi + 20));                     \
    __m128i xmm_mi_6 = _LOAD((const __m128i *)(mi + 24));                     \
    __m128i xmm_mi_7 = _LOAD((const __m128i *)(mi + 28));                     \
    __m128i xmm_qi_0 = _mm_broadcast_si32(qi + 0);                            \
    __m128i xmm_qi_1 = _mm_broadcast_si32(qi + 1);                            \
    __m128i xmm_qi_2 = _mm_broadcast_si32(qi + 2);                            \
    __m128i xmm_qi_3 = _mm_broadcast_si32(qi + 3);                            \
    MATRIX_VAR_PROC(1, 4, 0, xmm_mi_0, xmm_qi, _RES, _PROC)                   \
    MATRIX_VAR_PROC(1, 4, 1, xmm_mi_1, xmm_qi, _RES, _PROC)                   \
    MATRIX_VAR_PROC(1, 4, 2, xmm_mi_2, xmm_qi, _RES, _PROC)                   \
    MATRIX_VAR_PROC(1, 4, 3, xmm_mi_3, xmm_qi, _RES, _PROC)                   \
    MATRIX_VAR_PROC(1, 4, 4, xmm_mi_4, xmm_qi, _RES, _PROC)                   \
    MATRIX_VAR_PROC(1, 4, 5, xmm_mi_5, xmm_qi, _RES, _PROC)                   \
    MATRIX_VAR_PROC(1, 4, 6, xmm_mi_6, xmm_qi, _RES, _PROC)                   \
    MATRIX_VAR_PROC(1, 4, 7, xmm_mi_7, xmm_qi, _RES, _PROC)                   \
  }
//! Iterative process of computing distance (INT32, M=32, N=8)
//! Reads 32 int32 values from `mi` and 8 from `qi`.
#define MATRIX_INT32_ITER_32X8_SSE(mi, qi, _RES, _LOAD, _PROC)                \
  {                                                                           \
    __m128i xmm_mi_0 = _LOAD((const __m128i *)(mi + 0));                      \
    __m128i xmm_mi_1 = _LOAD((const __m128i *)(mi + 4));                      \
    __m128i xmm_mi_2 = _LOAD((const __m128i *)(mi + 8));                      \
    __m128i xmm_mi_3 = _LOAD((const __m128i *)(mi + 12));                     \
    __m128i xmm_mi_4 = _LOAD((const __m128i *)(mi + 16));                     \
    __m128i xmm_mi_5 = _LOAD((const __m128i *)(mi + 20));                     \
    __m128i xmm_mi_6 = _LOAD((const __m128i *)(mi + 24));                     \
    __m128i xmm_mi_7 = _LOAD((const __m128i *)(mi + 28));                     \
    __m128i xmm_qi_0 = _mm_broadcast_si32(qi + 0);                            \
    __m128i xmm_qi_1 = _mm_broadcast_si32(qi + 1);                            \
    __m128i xmm_qi_2 = _mm_broadcast_si32(qi + 2);                            \
    __m128i xmm_qi_3 = _mm_broadcast_si32(qi + 3);                            \
    __m128i xmm_qi_4 = _mm_broadcast_si32(qi + 4);                            \
    __m128i xmm_qi_5 = _mm_broadcast_si32(qi + 5);                            \
    __m128i xmm_qi_6 = _mm_broadcast_si32(qi + 6);                            \
    __m128i xmm_qi_7 = _mm_broadcast_si32(qi + 7);                            \
    MATRIX_VAR_PROC(1, 8, 0, xmm_mi_0, xmm_qi, _RES, _PROC)                   \
    MATRIX_VAR_PROC(1, 8, 1, xmm_mi_1, xmm_qi, _RES, _PROC)                   \
    MATRIX_VAR_PROC(1, 8, 2, xmm_mi_2, xmm_qi, _RES, _PROC)                   \
    MATRIX_VAR_PROC(1, 8, 3, xmm_mi_3, xmm_qi, _RES, _PROC)                   \
    MATRIX_VAR_PROC(1, 8, 4, xmm_mi_4, xmm_qi, _RES, _PROC)                   \
    MATRIX_VAR_PROC(1, 8, 5, xmm_mi_5, xmm_qi, _RES, _PROC)                   \
    MATRIX_VAR_PROC(1, 8, 6, xmm_mi_6, xmm_qi, _RES, _PROC)                   \
    MATRIX_VAR_PROC(1, 8, 7, xmm_mi_7, xmm_qi, _RES, _PROC)                   \
  }
//! Iterative process of computing distance (INT32, M=32, N=16)
//! Reads 32 int32 values from `mi` and 16 from `qi`.
#define MATRIX_INT32_ITER_32X16_SSE(mi, qi, _RES, _LOAD, _PROC)               \
  {                                                                           \
    __m128i xmm_qi;                                                           \
    __m128i xmm_mi_0 = _LOAD((const __m128i *)(mi + 0));                      \
    __m128i xmm_mi_1 = _LOAD((const __m128i *)(mi + 4));                      \
    __m128i xmm_mi_2 = _LOAD((const __m128i *)(mi + 8));                      \
    __m128i xmm_mi_3 = _LOAD((const __m128i *)(mi + 12));                     \
    __m128i xmm_mi_4 = _LOAD((const __m128i *)(mi + 16));                     \
    __m128i xmm_mi_5 = _LOAD((const __m128i *)(mi + 20));                     \
    __m128i xmm_mi_6 = _LOAD((const __m128i *)(mi + 24));                     \
    __m128i xmm_mi_7 = _LOAD((const __m128i *)(mi + 28));                     \
    /* qi[0..3] */                                                            \
    xmm_qi = _mm_broadcast_si32(qi + 0);                                      \
    MATRIX_VAR_PROC(8, 1, 0, xmm_mi, xmm_qi, _RES, _PROC)                     \
    xmm_qi = _mm_broadcast_si32(qi + 1);                                      \
    MATRIX_VAR_PROC(8, 1, 1, xmm_mi, xmm_qi, _RES, _PROC)                     \
    xmm_qi = _mm_broadcast_si32(qi + 2);                                      \
    MATRIX_VAR_PROC(8, 1, 2, xmm_mi, xmm_qi, _RES, _PROC)                     \
    xmm_qi = _mm_broadcast_si32(qi + 3);                                      \
    MATRIX_VAR_PROC(8, 1, 3, xmm_mi, xmm_qi, _RES, _PROC)                     \
    /* qi[4..7] */                                                            \
    xmm_qi = _mm_broadcast_si32(qi + 4);                                      \
    MATRIX_VAR_PROC(8, 1, 4, xmm_mi, xmm_qi, _RES, _PROC)                     \
    xmm_qi = _mm_broadcast_si32(qi + 5);                                      \
    MATRIX_VAR_PROC(8, 1, 5, xmm_mi, xmm_qi, _RES, _PROC)                     \
    xmm_qi = _mm_broadcast_si32(qi + 6);                                      \
    MATRIX_VAR_PROC(8, 1, 6, xmm_mi, xmm_qi, _RES, _PROC)                     \
    xmm_qi = _mm_broadcast_si32(qi + 7);                                      \
    MATRIX_VAR_PROC(8, 1, 7, xmm_mi, xmm_qi, _RES, _PROC)                     \
    /* qi[8..11] */                                                           \
    xmm_qi = _mm_broadcast_si32(qi + 8);                                      \
    MATRIX_VAR_PROC(8, 1, 8, xmm_mi, xmm_qi, _RES, _PROC)                     \
    xmm_qi = _mm_broadcast_si32(qi + 9);                                      \
    MATRIX_VAR_PROC(8, 1, 9, xmm_mi, xmm_qi, _RES, _PROC)                     \
    xmm_qi = _mm_broadcast_si32(qi + 10);                                     \
    MATRIX_VAR_PROC(8, 1, 10, xmm_mi, xmm_qi, _RES, _PROC)                    \
    xmm_qi = _mm_broadcast_si32(qi + 11);                                     \
    MATRIX_VAR_PROC(8, 1, 11, xmm_mi, xmm_qi, _RES, _PROC)                    \
    /* qi[12..15] */                                                          \
    xmm_qi = _mm_broadcast_si32(qi + 12);                                     \
    MATRIX_VAR_PROC(8, 1, 12, xmm_mi, xmm_qi, _RES, _PROC)                    \
    xmm_qi = _mm_broadcast_si32(qi + 13);                                     \
    MATRIX_VAR_PROC(8, 1, 13, xmm_mi, xmm_qi, _RES, _PROC)                    \
    xmm_qi = _mm_broadcast_si32(qi + 14);                                     \
    MATRIX_VAR_PROC(8, 1, 14, xmm_mi, xmm_qi, _RES, _PROC)                    \
    xmm_qi = _mm_broadcast_si32(qi + 15);                                     \
    MATRIX_VAR_PROC(8, 1, 15, xmm_mi, xmm_qi, _RES, _PROC)                    \
  }

//! Iterative process of computing distance (INT32, M=32, N=32)
//! Reads 32 int32 values from `mi` and 32 from `qi`.
#define MATRIX_INT32_ITER_32X32_SSE(mi, qi, _RES, _LOAD, _PROC)               \
  {                                                                           \
    __m128i xmm_qi;                                                           \
    __m128i xmm_mi_0 = _LOAD((const __m128i *)(mi + 0));                      \
    __m128i xmm_mi_1 = _LOAD((const __m128i *)(mi + 4));                      \
    __m128i xmm_mi_2 = _LOAD((const __m128i *)(mi + 8));                      \
    __m128i xmm_mi_3 = _LOAD((const __m128i *)(mi + 12));                     \
    __m128i xmm_mi_4 = _LOAD((const __m128i *)(mi + 16));                     \
    __m128i xmm_mi_5 = _LOAD((const __m128i *)(mi + 20));                     \
    __m128i xmm_mi_6 = _LOAD((const __m128i *)(mi + 24));                     \
    __m128i xmm_mi_7 = _LOAD((const __m128i *)(mi + 28));                     \
    /* qi[0..7] */                                                            \
    xmm_qi = _mm_broadcast_si32(qi + 0);                                      \
    MATRIX_VAR_PROC(8, 1, 0, xmm_mi, xmm_qi, _RES, _PROC)                     \
    xmm_qi = _mm_broadcast_si32(qi + 1);                                      \
    MATRIX_VAR_PROC(8, 1, 1, xmm_mi, xmm_qi, _RES, _PROC)                     \
    xmm_qi = _mm_broadcast_si32(qi + 2);                                      \
    MATRIX_VAR_PROC(8, 1, 2, xmm_mi, xmm_qi, _RES, _PROC)                     \
    xmm_qi = _mm_broadcast_si32(qi + 3);                                      \
    MATRIX_VAR_PROC(8, 1, 3, xmm_mi, xmm_qi, _RES, _PROC)                     \
    xmm_qi = _mm_broadcast_si32(qi + 4);                                      \
    MATRIX_VAR_PROC(8, 1, 4, xmm_mi, xmm_qi, _RES, _PROC)                     \
    xmm_qi = _mm_broadcast_si32(qi + 5);                                      \
    MATRIX_VAR_PROC(8, 1, 5, xmm_mi, xmm_qi, _RES, _PROC)                     \
    xmm_qi = _mm_broadcast_si32(qi + 6);                                      \
    MATRIX_VAR_PROC(8, 1, 6, xmm_mi, xmm_qi, _RES, _PROC)                     \
    xmm_qi = _mm_broadcast_si32(qi + 7);                                      \
    MATRIX_VAR_PROC(8, 1, 7, xmm_mi, xmm_qi, _RES, _PROC)                     \
    /* qi[8..15] */                                                           \
    xmm_qi = _mm_broadcast_si32(qi + 8);                                      \
    MATRIX_VAR_PROC(8, 1, 8, xmm_mi, xmm_qi, _RES, _PROC)                     \
    xmm_qi = _mm_broadcast_si32(qi + 9);                                      \
    MATRIX_VAR_PROC(8, 1, 9, xmm_mi, xmm_qi, _RES, _PROC)                     \
    xmm_qi = _mm_broadcast_si32(qi + 10);                                     \
    MATRIX_VAR_PROC(8, 1, 10, xmm_mi, xmm_qi, _RES, _PROC)                    \
    xmm_qi = _mm_broadcast_si32(qi + 11);                                     \
    MATRIX_VAR_PROC(8, 1, 11, xmm_mi, xmm_qi, _RES, _PROC)                    \
    xmm_qi = _mm_broadcast_si32(qi + 12);                                     \
    MATRIX_VAR_PROC(8, 1, 12, xmm_mi, xmm_qi, _RES, _PROC)                    \
    xmm_qi = _mm_broadcast_si32(qi + 13);                                     \
    MATRIX_VAR_PROC(8, 1, 13, xmm_mi, xmm_qi, _RES, _PROC)                    \
    xmm_qi = _mm_broadcast_si32(qi + 14);                                     \
    MATRIX_VAR_PROC(8, 1, 14, xmm_mi, xmm_qi, _RES, _PROC)                    \
    xmm_qi = _mm_broadcast_si32(qi + 15);                                     \
    MATRIX_VAR_PROC(8, 1, 15, xmm_mi, xmm_qi, _RES, _PROC)                    \
    /* qi[16..23] */                                                          \
    xmm_qi = _mm_broadcast_si32(qi + 16);                                     \
    MATRIX_VAR_PROC(8, 1, 16, xmm_mi, xmm_qi, _RES, _PROC)                    \
    xmm_qi = _mm_broadcast_si32(qi + 17);                                     \
    MATRIX_VAR_PROC(8, 1, 17, xmm_mi, xmm_qi, _RES, _PROC)                    \
    xmm_qi = _mm_broadcast_si32(qi + 18);                                     \
    MATRIX_VAR_PROC(8, 1, 18, xmm_mi, xmm_qi, _RES, _PROC)                    \
    xmm_qi = _mm_broadcast_si32(qi + 19);                                     \
    MATRIX_VAR_PROC(8, 1, 19, xmm_mi, xmm_qi, _RES, _PROC)                    \
    xmm_qi = _mm_broadcast_si32(qi + 20);                                     \
    MATRIX_VAR_PROC(8, 1, 20, xmm_mi, xmm_qi, _RES, _PROC)                    \
    xmm_qi = _mm_broadcast_si32(qi + 21);                                     \
    MATRIX_VAR_PROC(8, 1, 21, xmm_mi, xmm_qi, _RES, _PROC)                    \
    xmm_qi = _mm_broadcast_si32(qi + 22);                                     \
    MATRIX_VAR_PROC(8, 1, 22, xmm_mi, xmm_qi, _RES, _PROC)                    \
    xmm_qi = _mm_broadcast_si32(qi + 23);                                     \
    MATRIX_VAR_PROC(8, 1, 23, xmm_mi, xmm_qi, _RES, _PROC)                    \
    /* qi[24..31] */                                                          \
    xmm_qi = _mm_broadcast_si32(qi + 24);                                     \
    MATRIX_VAR_PROC(8, 1, 24, xmm_mi, xmm_qi, _RES, _PROC)                    \
    xmm_qi = _mm_broadcast_si32(qi + 25);                                     \
    MATRIX_VAR_PROC(8, 1, 25, xmm_mi, xmm_qi, _RES, _PROC)                    \
    xmm_qi = _mm_broadcast_si32(qi + 26);                                     \
    MATRIX_VAR_PROC(8, 1, 26, xmm_mi, xmm_qi, _RES, _PROC)                    \
    xmm_qi = _mm_broadcast_si32(qi + 27);                                     \
    MATRIX_VAR_PROC(8, 1, 27, xmm_mi, xmm_qi, _RES, _PROC)                    \
    xmm_qi = _mm_broadcast_si32(qi + 28);                                     \
    MATRIX_VAR_PROC(8, 1, 28, xmm_mi, xmm_qi, _RES, _PROC)                    \
    xmm_qi = _mm_broadcast_si32(qi + 29);                                     \
    MATRIX_VAR_PROC(8, 1, 29, xmm_mi, xmm_qi, _RES, _PROC)                    \
    xmm_qi = _mm_broadcast_si32(qi + 30);                                     \
    MATRIX_VAR_PROC(8, 1, 30, xmm_mi, xmm_qi, _RES, _PROC)                    \
    xmm_qi = _mm_broadcast_si32(qi + 31);                                     \
    MATRIX_VAR_PROC(8, 1, 31, xmm_mi, xmm_qi, _RES, _PROC)                    \
  }
//! Iterative process of computing distance (INT32, M=2, N=1)
//! Reads 8 int32 values from `mi` and qi[0..3].
#define MATRIX_INT32_ITER_2X1_AVX(mi, qi, _RES, _LOAD, _PROC)                 \
  {                                                                           \
    /* lanes, lowest first: {q0, q0, q1, q1, q2, q2, q3, q3} */               \
    __m256i ymm_qi = _mm256_setr_epi32(qi[0], qi[0], qi[1], qi[1], qi[2],     \
                                       qi[2], qi[3], qi[3]);                  \
    __m256i ymm_mi = _LOAD((const __m256i *)(mi));                            \
    _PROC(ymm_mi, ymm_qi, _RES##_0_0)                                         \
  }

//! Iterative process of computing distance (INT32, M=2, N=2)
//! Reads 8 int32 values from `mi` and 8 from `qi`.
#define MATRIX_INT32_ITER_2X2_AVX(mi, qi, _RES, _LOAD, _PROC)                 \
  {                                                                           \
    __m256i ymm_mi = _LOAD((const __m256i *)(mi));                            \
    __m256i ymm_qi = _LOAD((const __m256i *)(qi));                            \
    /* per 128-bit half: lanes {0, 0, 2, 2} of qi */                          \
    __m256i ymm_pi = _mm256_shuffle_epi32(ymm_qi, _MM_SHUFFLE(2, 2, 0, 0));   \
    _PROC(ymm_mi, ymm_pi, _RES##_0_0)                                         \
    /* per 128-bit half: lanes {1, 1, 3, 3} of qi */                          \
    ymm_pi = _mm256_shuffle_epi32(ymm_qi, _MM_SHUFFLE(3, 3, 1, 1));           \
    _PROC(ymm_mi, ymm_pi, _RES##_0_1)                                         \
  }

//! Iterative process of computing distance (INT32, M=4, N=1)
//! Reads 8 int32 values from `mi` and 2 from `qi`.
#define MATRIX_INT32_ITER_4X1_AVX(mi, qi, _RES, _LOAD, _PROC)                 \
  {                                                                           \
    __m256i ymm_mi = _LOAD((const __m256i *)(mi));                            \
    /* low half: qi[0] x4, high half: qi[1] x4 */                             \
    __m256i ymm_qi = _mm256_set_m128i(_mm_broadcast_si32(qi + 1),             \
                                      _mm_broadcast_si32(qi));                \
    _PROC(ymm_mi, ymm_qi, _RES##_0_0)                                         \
  }

//! Iterative process of computing distance (INT32, M=4, N=2)
//! Reads 8 int32 values from `mi` and 4 from `qi`.
#define MATRIX_INT32_ITER_4X2_AVX(mi, qi, _RES, _LOAD, _PROC)                 \
  {                                                                           \
    __m256i ymm_mi = _LOAD((const __m256i *)(mi));                            \
    __m256i ymm_qi;                                                           \
    /* low half: qi[0] x4, high half: qi[2] x4 */                             \
    ymm_qi = _mm256_set_m128i(_mm_broadcast_si32(qi + 2),                     \
                              _mm_broadcast_si32(qi + 0));                    \
    _PROC(ymm_mi, ymm_qi, _RES##_0_0)                                         \
    /* low half: qi[1] x4, high half: qi[3] x4 */                             \
    ymm_qi = _mm256_set_m128i(_mm_broadcast_si32(qi + 3),                     \
                              _mm_broadcast_si32(qi + 1));                    \
    _PROC(ymm_mi, ymm_qi, _RES##_0_1)                                         \
  }
//! Iterative process of computing distance (INT32, M=4, N=4)
//! Reads 8 int32 values from `mi` and 8 from `qi`. _mm256_shuffle_epi32
//! works inside each 128-bit half, so selector k replicates qi[k] in the low
//! half and qi[4 + k] in the high half.
#define MATRIX_INT32_ITER_4X4_AVX(mi, qi, _RES, _LOAD, _PROC)                 \
  {                                                                           \
    __m256i ymm_mi = _LOAD((const __m256i *)(mi));                            \
    __m256i ymm_qi = _LOAD((const __m256i *)(qi));                            \
    __m256i ymm_pi;                                                           \
    ymm_pi = _mm256_shuffle_epi32(ymm_qi, _MM_SHUFFLE(0, 0, 0, 0));           \
    _PROC(ymm_mi, ymm_pi, _RES##_0_0)                                         \
    ymm_pi = _mm256_shuffle_epi32(ymm_qi, _MM_SHUFFLE(1, 1, 1, 1));           \
    _PROC(ymm_mi, ymm_pi, _RES##_0_1)                                         \
    ymm_pi = _mm256_shuffle_epi32(ymm_qi, _MM_SHUFFLE(2, 2, 2, 2));           \
    _PROC(ymm_mi, ymm_pi, _RES##_0_2)                                         \
    ymm_pi = _mm256_shuffle_epi32(ymm_qi, _MM_SHUFFLE(3, 3, 3, 3));           \
    _PROC(ymm_mi, ymm_pi, _RES##_0_3)                                         \
  }

//! Iterative process of computing distance (INT32, M=8, N=1)
//! Reads 8 int32 values from `mi` and 1 from `qi`.
#define MATRIX_INT32_ITER_8X1_AVX(mi, qi, _RES, _LOAD, _PROC)                 \
  {                                                                           \
    __m256i ymm_mi = _LOAD((const __m256i *)(mi));                            \
    __m256i ymm_qi = _mm256_broadcast_si32(qi);                               \
    _PROC(ymm_mi, ymm_qi, _RES##_0_0)                                         \
  }

//! Iterative process of computing distance (INT32, M=8, N=2)
//! Reads 8 int32 values from `mi` and 2 from `qi`.
#define MATRIX_INT32_ITER_8X2_AVX(mi, qi, _RES, _LOAD, _PROC)                 \
  {                                                                           \
    __m256i ymm_mi = _LOAD((const __m256i *)(mi));                            \
    __m256i ymm_qi_0 = _mm256_broadcast_si32(qi + 0);                         \
    __m256i ymm_qi_1 = _mm256_broadcast_si32(qi + 1);                         \
    MATRIX_VAR_PROC(1, 2, 0, ymm_mi, ymm_qi, _RES, _PROC)                     \
  }

//! Iterative process of computing distance (INT32, M=8, N=4)
//! Reads 8 int32 values from `mi` and 4 from `qi`.
#define MATRIX_INT32_ITER_8X4_AVX(mi, qi, _RES, _LOAD, _PROC)                 \
  {                                                                           \
    __m256i ymm_qi;                                                           \
    __m256i ymm_mi = _LOAD((const __m256i *)(mi));                            \
    ymm_qi = _mm256_broadcast_si32(qi + 0);                                   \
    _PROC(ymm_mi, ymm_qi, _RES##_0_0)                                         \
    ymm_qi = _mm256_broadcast_si32(qi + 1);                                   \
    _PROC(ymm_mi, ymm_qi, _RES##_0_1)                                         \
    ymm_qi = _mm256_broadcast_si32(qi + 2);                                   \
    _PROC(ymm_mi, ymm_qi, _RES##_0_2)                                         \
    ymm_qi = _mm256_broadcast_si32(qi + 3);                                   \
    _PROC(ymm_mi, ymm_qi, _RES##_0_3)                                         \
  }
//! Iterative process of computing distance (INT32, M=8, N=8)
//! Reads 8 int32 values from `mi` and 8 from `qi`.
#define MATRIX_INT32_ITER_8X8_AVX(mi, qi, _RES, _LOAD, _PROC)                 \
  {                                                                           \
    __m256i ymm_qi;                                                           \
    __m256i ymm_mi = _LOAD((const __m256i *)(mi));                            \
    /* qi[0..3] */                                                            \
    ymm_qi = _mm256_broadcast_si32(qi + 0);                                   \
    _PROC(ymm_mi, ymm_qi, _RES##_0_0)                                         \
    ymm_qi = _mm256_broadcast_si32(qi + 1);                                   \
    _PROC(ymm_mi, ymm_qi, _RES##_0_1)                                         \
    ymm_qi = _mm256_broadcast_si32(qi + 2);                                   \
    _PROC(ymm_mi, ymm_qi, _RES##_0_2)                                         \
    ymm_qi = _mm256_broadcast_si32(qi + 3);                                   \
    _PROC(ymm_mi, ymm_qi, _RES##_0_3)                                         \
    /* qi[4..7] */                                                            \
    ymm_qi = _mm256_broadcast_si32(qi + 4);                                   \
    _PROC(ymm_mi, ymm_qi, _RES##_0_4)                                         \
    ymm_qi = _mm256_broadcast_si32(qi + 5);                                   \
    _PROC(ymm_mi, ymm_qi, _RES##_0_5)                                         \
    ymm_qi = _mm256_broadcast_si32(qi + 6);                                   \
    _PROC(ymm_mi, ymm_qi, _RES##_0_6)                                         \
    ymm_qi = _mm256_broadcast_si32(qi + 7);                                   \
    _PROC(ymm_mi, ymm_qi, _RES##_0_7)                                         \
  }

//! Iterative process of computing distance (INT32, M=16, N=1)
//! Reads 16 int32 values from `mi` and 1 from `qi`.
#define MATRIX_INT32_ITER_16X1_AVX(mi, qi, _RES, _LOAD, _PROC)                \
  {                                                                           \
    __m256i ymm_qi = _mm256_broadcast_si32(qi + 0);                           \
    __m256i ymm_mi_0 = _LOAD((const __m256i *)(mi + 0));                      \
    __m256i ymm_mi_1 = _LOAD((const __m256i *)(mi + 8));                      \
    MATRIX_VAR_PROC(2, 1, 0, ymm_mi, ymm_qi, _RES, _PROC)                     \
  }

//! Iterative process of computing distance (INT32, M=16, N=2)
//! Reads 16 int32 values from `mi` and 2 from `qi`.
#define MATRIX_INT32_ITER_16X2_AVX(mi, qi, _RES, _LOAD, _PROC)                \
  {                                                                           \
    __m256i ymm_qi;                                                           \
    __m256i ymm_mi_0 = _LOAD((const __m256i *)(mi + 0));                      \
    __m256i ymm_mi_1 = _LOAD((const __m256i *)(mi + 8));                      \
    ymm_qi = _mm256_broadcast_si32(qi + 0);                                   \
    MATRIX_VAR_PROC(2, 1, 0, ymm_mi, ymm_qi, _RES, _PROC)                     \
    ymm_qi = _mm256_broadcast_si32(qi + 1);                                   \
    MATRIX_VAR_PROC(2, 1, 1, ymm_mi, ymm_qi, _RES, _PROC)                     \
  }
//! INT32 tile step, M=16 x N=4. `mi` is loaded once as two 8-lane vectors
//! (offsets 0 and 8). For each j in 0..3, `qi[j]` is broadcast and both
//! vectors go through MATRIX_VAR_PROC(2, 1, j, ...), which is defined
//! outside this section.
#define MATRIX_INT32_ITER_16X4_AVX(mi, qi, _RES, _LOAD, _PROC) { \
  __m256i ymm_mi_0 = _LOAD((const __m256i *)(mi + 0)); \
  __m256i ymm_mi_1 = _LOAD((const __m256i *)(mi + 8)); \
  __m256i ymm_qi = _mm256_broadcast_si32(qi + 0); \
  MATRIX_VAR_PROC(2, 1, 0, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si32(qi + 1); \
  MATRIX_VAR_PROC(2, 1, 1, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si32(qi + 2); \
  MATRIX_VAR_PROC(2, 1, 2, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si32(qi + 3); \
  MATRIX_VAR_PROC(2, 1, 3, ymm_mi, ymm_qi, _RES, _PROC) \
}

//! INT32 tile step, M=16 x N=8. Same scheme as 16x4, iterating j over 0..7.
#define MATRIX_INT32_ITER_16X8_AVX(mi, qi, _RES, _LOAD, _PROC) { \
  __m256i ymm_mi_0 = _LOAD((const __m256i *)(mi + 0)); \
  __m256i ymm_mi_1 = _LOAD((const __m256i *)(mi + 8)); \
  __m256i ymm_qi = _mm256_broadcast_si32(qi + 0); \
  MATRIX_VAR_PROC(2, 1, 0, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si32(qi + 1); \
  MATRIX_VAR_PROC(2, 1, 1, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si32(qi + 2); \
  MATRIX_VAR_PROC(2, 1, 2, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si32(qi + 3); \
  MATRIX_VAR_PROC(2, 1, 3, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si32(qi + 4); \
  MATRIX_VAR_PROC(2, 1, 4, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si32(qi + 5); \
  MATRIX_VAR_PROC(2, 1, 5, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si32(qi + 6); \
  MATRIX_VAR_PROC(2, 1, 6, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si32(qi + 7); \
  MATRIX_VAR_PROC(2, 1, 7, ymm_mi, ymm_qi, _RES, _PROC) \
}

//!
//! INT32 tile step, M=16 x N=16. `mi` is loaded once as two 8-lane vectors
//! (offsets 0 and 8). For each j in 0..15, `qi[j]` is broadcast and both
//! vectors go through MATRIX_VAR_PROC(2, 1, j, ...), which is defined
//! outside this section.
#define MATRIX_INT32_ITER_16X16_AVX(mi, qi, _RES, _LOAD, _PROC) { \
  __m256i ymm_mi_0 = _LOAD((const __m256i *)(mi + 0)); \
  __m256i ymm_mi_1 = _LOAD((const __m256i *)(mi + 8)); \
  __m256i ymm_qi = _mm256_broadcast_si32(qi + 0); \
  MATRIX_VAR_PROC(2, 1, 0, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si32(qi + 1); \
  MATRIX_VAR_PROC(2, 1, 1, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si32(qi + 2); \
  MATRIX_VAR_PROC(2, 1, 2, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si32(qi + 3); \
  MATRIX_VAR_PROC(2, 1, 3, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si32(qi + 4); \
  MATRIX_VAR_PROC(2, 1, 4, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si32(qi + 5); \
  MATRIX_VAR_PROC(2, 1, 5, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si32(qi + 6); \
  MATRIX_VAR_PROC(2, 1, 6, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si32(qi + 7); \
  MATRIX_VAR_PROC(2, 1, 7, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si32(qi + 8); \
  MATRIX_VAR_PROC(2, 1, 8, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si32(qi + 9); \
  MATRIX_VAR_PROC(2, 1, 9, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si32(qi + 10); \
  MATRIX_VAR_PROC(2, 1, 10, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si32(qi + 11); \
  MATRIX_VAR_PROC(2, 1, 11, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si32(qi + 12); \
  MATRIX_VAR_PROC(2, 1, 12, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si32(qi + 13); \
  MATRIX_VAR_PROC(2, 1, 13, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si32(qi + 14); \
  MATRIX_VAR_PROC(2, 1, 14, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si32(qi + 15); \
  MATRIX_VAR_PROC(2, 1, 15, ymm_mi, ymm_qi, _RES, _PROC) \
}

//!
//! INT32 tile step, M=32 x N=1. `mi` is loaded as four 8-lane vectors
//! (offsets 0, 8, 16, 24) and processed against the broadcast scalar at
//! `qi` via MATRIX_VAR_PROC(4, 1, 0, ...), which is defined outside this
//! section.
#define MATRIX_INT32_ITER_32X1_AVX(mi, qi, _RES, _LOAD, _PROC) { \
  __m256i ymm_mi_0 = _LOAD((const __m256i *)(mi + 0)); \
  __m256i ymm_mi_1 = _LOAD((const __m256i *)(mi + 8)); \
  __m256i ymm_mi_2 = _LOAD((const __m256i *)(mi + 16)); \
  __m256i ymm_mi_3 = _LOAD((const __m256i *)(mi + 24)); \
  __m256i ymm_qi = _mm256_broadcast_si32(qi); \
  MATRIX_VAR_PROC(4, 1, 0, ymm_mi, ymm_qi, _RES, _PROC) \
}

//! INT32 tile step, M=32 x N=2. Four `mi` vectors are loaded once, then
//! `qi[0]` and `qi[1]` are broadcast in turn and fed to
//! MATRIX_VAR_PROC(4, 1, j, ...).
#define MATRIX_INT32_ITER_32X2_AVX(mi, qi, _RES, _LOAD, _PROC) { \
  __m256i ymm_mi_0 = _LOAD((const __m256i *)(mi + 0)); \
  __m256i ymm_mi_1 = _LOAD((const __m256i *)(mi + 8)); \
  __m256i ymm_mi_2 = _LOAD((const __m256i *)(mi + 16)); \
  __m256i ymm_mi_3 = _LOAD((const __m256i *)(mi + 24)); \
  __m256i ymm_qi = _mm256_broadcast_si32(qi + 0); \
  MATRIX_VAR_PROC(4, 1, 0, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si32(qi + 1); \
  MATRIX_VAR_PROC(4, 1, 1, ymm_mi, ymm_qi, _RES, _PROC) \
}

//! INT32 tile step, M=32 x N=4. Same scheme as 32x2, iterating j over 0..3.
#define MATRIX_INT32_ITER_32X4_AVX(mi, qi, _RES, _LOAD, _PROC) { \
  __m256i ymm_mi_0 = _LOAD((const __m256i *)(mi + 0)); \
  __m256i ymm_mi_1 = _LOAD((const __m256i *)(mi + 8)); \
  __m256i ymm_mi_2 = _LOAD((const __m256i *)(mi + 16)); \
  __m256i ymm_mi_3 = _LOAD((const __m256i *)(mi + 24)); \
  __m256i ymm_qi = _mm256_broadcast_si32(qi + 0); \
  MATRIX_VAR_PROC(4, 1, 0, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si32(qi + 1); \
  MATRIX_VAR_PROC(4, 1, 1, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si32(qi + 2); \
  MATRIX_VAR_PROC(4, 1, 2, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si32(qi + 3); \
  MATRIX_VAR_PROC(4, 1, 3, ymm_mi, ymm_qi, _RES, _PROC) \
}

//!
//! INT32 tile step, M=32 x N=8. `mi` is loaded once as four 8-lane vectors
//! (offsets 0, 8, 16, 24). For each j in 0..7, `qi[j]` is broadcast and the
//! four vectors go through MATRIX_VAR_PROC(4, 1, j, ...), which is defined
//! outside this section.
#define MATRIX_INT32_ITER_32X8_AVX(mi, qi, _RES, _LOAD, _PROC) { \
  __m256i ymm_mi_0 = _LOAD((const __m256i *)(mi + 0)); \
  __m256i ymm_mi_1 = _LOAD((const __m256i *)(mi + 8)); \
  __m256i ymm_mi_2 = _LOAD((const __m256i *)(mi + 16)); \
  __m256i ymm_mi_3 = _LOAD((const __m256i *)(mi + 24)); \
  __m256i ymm_qi = _mm256_broadcast_si32(qi + 0); \
  MATRIX_VAR_PROC(4, 1, 0, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si32(qi + 1); \
  MATRIX_VAR_PROC(4, 1, 1, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si32(qi + 2); \
  MATRIX_VAR_PROC(4, 1, 2, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si32(qi + 3); \
  MATRIX_VAR_PROC(4, 1, 3, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si32(qi + 4); \
  MATRIX_VAR_PROC(4, 1, 4, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si32(qi + 5); \
  MATRIX_VAR_PROC(4, 1, 5, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si32(qi + 6); \
  MATRIX_VAR_PROC(4, 1, 6, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si32(qi + 7); \
  MATRIX_VAR_PROC(4, 1, 7, ymm_mi, ymm_qi, _RES, _PROC) \
}

//!
//! INT32 tile step, M=32 x N=16. `mi` is loaded once as four 8-lane vectors
//! (offsets 0, 8, 16, 24). For each j in 0..15, `qi[j]` is broadcast and
//! the four vectors go through MATRIX_VAR_PROC(4, 1, j, ...), which is
//! defined outside this section.
#define MATRIX_INT32_ITER_32X16_AVX(mi, qi, _RES, _LOAD, _PROC) { \
  __m256i ymm_mi_0 = _LOAD((const __m256i *)(mi + 0)); \
  __m256i ymm_mi_1 = _LOAD((const __m256i *)(mi + 8)); \
  __m256i ymm_mi_2 = _LOAD((const __m256i *)(mi + 16)); \
  __m256i ymm_mi_3 = _LOAD((const __m256i *)(mi + 24)); \
  __m256i ymm_qi = _mm256_broadcast_si32(qi + 0); \
  MATRIX_VAR_PROC(4, 1, 0, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si32(qi + 1); \
  MATRIX_VAR_PROC(4, 1, 1, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si32(qi + 2); \
  MATRIX_VAR_PROC(4, 1, 2, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si32(qi + 3); \
  MATRIX_VAR_PROC(4, 1, 3, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si32(qi + 4); \
  MATRIX_VAR_PROC(4, 1, 4, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si32(qi + 5); \
  MATRIX_VAR_PROC(4, 1, 5, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si32(qi + 6); \
  MATRIX_VAR_PROC(4, 1, 6, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si32(qi + 7); \
  MATRIX_VAR_PROC(4, 1, 7, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si32(qi + 8); \
  MATRIX_VAR_PROC(4, 1, 8, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si32(qi + 9); \
  MATRIX_VAR_PROC(4, 1, 9, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si32(qi + 10); \
  MATRIX_VAR_PROC(4, 1, 10, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si32(qi + 11); \
  MATRIX_VAR_PROC(4, 1, 11, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si32(qi + 12); \
  MATRIX_VAR_PROC(4, 1, 12, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si32(qi + 13); \
  MATRIX_VAR_PROC(4, 1, 13, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si32(qi + 14); \
  MATRIX_VAR_PROC(4, 1, 14, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si32(qi + 15); \
  MATRIX_VAR_PROC(4, 1, 15, ymm_mi, ymm_qi, _RES, _PROC) \
}

//!
//! Iterative process of computing distance (INT32, M=32, N=32)
//! Loads 32 int32 values from `mi` once, as four 8-lane vectors (offsets 0,
//! 8, 16, 24). Then, for each j in 0..31, broadcasts `qi[j]` and passes the
//! four vectors to MATRIX_VAR_PROC(4, 1, j, ...), which is defined outside
//! this section.
#define MATRIX_INT32_ITER_32X32_AVX(mi, qi, _RES, _LOAD, _PROC) \
  { \
    __m256i ymm_mi_0 = _LOAD((const __m256i *)(mi + 0)); \
    __m256i ymm_mi_1 = _LOAD((const __m256i *)(mi + 8)); \
    __m256i ymm_mi_2 = _LOAD((const __m256i *)(mi + 16)); \
    __m256i ymm_mi_3 = _LOAD((const __m256i *)(mi + 24)); \
    __m256i ymm_qi = _mm256_broadcast_si32(qi + 0); \
    MATRIX_VAR_PROC(4, 1, 0, ymm_mi, ymm_qi, _RES, _PROC) \
    ymm_qi = _mm256_broadcast_si32(qi + 1); \
    MATRIX_VAR_PROC(4, 1, 1, ymm_mi, ymm_qi, _RES, _PROC) \
    ymm_qi = _mm256_broadcast_si32(qi + 2); \
    MATRIX_VAR_PROC(4, 1, 2, ymm_mi, ymm_qi, _RES, _PROC) \
    ymm_qi = _mm256_broadcast_si32(qi + 3); \
    MATRIX_VAR_PROC(4, 1, 3, ymm_mi, ymm_qi, _RES, _PROC) \
    ymm_qi = _mm256_broadcast_si32(qi + 4); \
    MATRIX_VAR_PROC(4, 1, 4, ymm_mi, ymm_qi, _RES, _PROC) \
    ymm_qi = _mm256_broadcast_si32(qi + 5); \
    MATRIX_VAR_PROC(4, 1, 5, ymm_mi, ymm_qi, _RES, _PROC) \
    ymm_qi = _mm256_broadcast_si32(qi + 6); \
    MATRIX_VAR_PROC(4, 1, 6, ymm_mi, ymm_qi, _RES, _PROC) \
    ymm_qi = _mm256_broadcast_si32(qi + 7); \
    MATRIX_VAR_PROC(4, 1, 7, ymm_mi, ymm_qi, _RES, _PROC) \
    ymm_qi = _mm256_broadcast_si32(qi + 8); \
    MATRIX_VAR_PROC(4, 1, 8, ymm_mi, ymm_qi, _RES, _PROC) \
    ymm_qi = _mm256_broadcast_si32(qi + 9); \
    MATRIX_VAR_PROC(4, 1, 9, ymm_mi, ymm_qi, _RES, _PROC) \
    ymm_qi = _mm256_broadcast_si32(qi + 10); \
    MATRIX_VAR_PROC(4, 1, 10, ymm_mi, ymm_qi, _RES, _PROC) \
    ymm_qi = _mm256_broadcast_si32(qi + 11); \
    MATRIX_VAR_PROC(4, 1, 11, ymm_mi, ymm_qi, _RES, _PROC) \
    ymm_qi = _mm256_broadcast_si32(qi + 12); \
    MATRIX_VAR_PROC(4, 1, 12, ymm_mi, ymm_qi, _RES, _PROC) \
    ymm_qi = _mm256_broadcast_si32(qi + 13); \
    MATRIX_VAR_PROC(4, 1, 13, ymm_mi, ymm_qi, _RES, _PROC) \
    ymm_qi = _mm256_broadcast_si32(qi + 14); \
    MATRIX_VAR_PROC(4, 1, 14, ymm_mi, ymm_qi, _RES, _PROC) \
    ymm_qi = _mm256_broadcast_si32(qi + 15); \
    MATRIX_VAR_PROC(4, 1, 15, ymm_mi, ymm_qi, _RES, _PROC) \
    ymm_qi = _mm256_broadcast_si32(qi + 16); \
    MATRIX_VAR_PROC(4, 1, 16, ymm_mi, ymm_qi, _RES, _PROC) \
    ymm_qi = _mm256_broadcast_si32(qi + 17); \
    MATRIX_VAR_PROC(4, 1, 17, ymm_mi, ymm_qi, _RES, _PROC) \
    ymm_qi = _mm256_broadcast_si32(qi + 18); \
    MATRIX_VAR_PROC(4, 1, 18, ymm_mi, ymm_qi, _RES, _PROC) \
    ymm_qi = _mm256_broadcast_si32(qi + 19); \
    MATRIX_VAR_PROC(4, 1, 19, ymm_mi, ymm_qi, _RES, _PROC) \
    ymm_qi = _mm256_broadcast_si32(qi + 20); \
    MATRIX_VAR_PROC(4, 1, 20, ymm_mi, ymm_qi, _RES, _PROC) \
    ymm_qi = _mm256_broadcast_si32(qi + 21); \
    MATRIX_VAR_PROC(4, 1, 21, ymm_mi, ymm_qi, _RES, _PROC) \
    ymm_qi = _mm256_broadcast_si32(qi + 22); \
    MATRIX_VAR_PROC(4, 1, 22, ymm_mi, ymm_qi, _RES, _PROC) \
    ymm_qi = _mm256_broadcast_si32(qi + 23); \
    MATRIX_VAR_PROC(4, 1, 23, ymm_mi, ymm_qi, _RES, _PROC) \
    ymm_qi = _mm256_broadcast_si32(qi + 24); \
    MATRIX_VAR_PROC(4, 1, 24, ymm_mi, ymm_qi, _RES, _PROC) \
    ymm_qi = _mm256_broadcast_si32(qi + 25); \
    MATRIX_VAR_PROC(4, 1, 25, ymm_mi, ymm_qi, _RES, _PROC) \
    ymm_qi = _mm256_broadcast_si32(qi + 26); \
    MATRIX_VAR_PROC(4, 1, 26, ymm_mi, ymm_qi, _RES, _PROC) \
    ymm_qi = _mm256_broadcast_si32(qi + 27); \
    MATRIX_VAR_PROC(4, 1, 27, ymm_mi, ymm_qi, _RES, _PROC) \
    ymm_qi = _mm256_broadcast_si32(qi + 28); \
    MATRIX_VAR_PROC(4, 1, 28, ymm_mi, ymm_qi, _RES, _PROC) \
    ymm_qi = _mm256_broadcast_si32(qi + 29); \
    MATRIX_VAR_PROC(4, 1, 29, ymm_mi, ymm_qi, _RES, _PROC) \
    ymm_qi = _mm256_broadcast_si32(qi + 30); \
    MATRIX_VAR_PROC(4, 1, 30, ymm_mi, ymm_qi, _RES, _PROC) \
    ymm_qi = _mm256_broadcast_si32(qi + 31); \
    MATRIX_VAR_PROC(4, 1, 31, ymm_mi, ymm_qi, _RES, _PROC) \
  }
================================================
FILE: src/ailego/math/distance_matrix_int64.i
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// NOTE(review): the first #include below has no header operand in this copy
// (the bracketed name appears to have been stripped); restore it from the
// upstream file before compiling.
#include 
#include "matrix_define.i"

#if defined(__AVX__)
//! Broadcast the 64-bit integer stored at address `a` to all four lanes of a
//! __m256i, using the double-precision broadcast followed by a bit-cast.
#define _mm256_broadcast_si64(a) \
  _mm256_castpd_si256(_mm256_broadcast_sd((const double *)(a)))
#endif // __AVX__

//! Iterative process of computing distance (INT64, M=2, N=1)
//! Loads four int64 values from `qi` and eight from `mi` (two vectors). The
//! first `mi` vector is paired with `qi` lanes {0, 0, 1, 1} into
//! `_RES##_0_0`; the second (`mi + 4`) with lanes {2, 2, 3, 3} into
//! `_RES##_0_1`. _mm256_permute4x64_epi64 requires AVX2.
#define MATRIX_INT64_ITER_2X1_AVX(mi, qi, _RES, _LOAD, _PROC) \
  { \
    __m256i ymm_qi = _LOAD((const __m256i *)(qi)); \
    __m256i ymm_mi = _LOAD((const __m256i *)(mi)); \
    __m256i ymm_pi = \
        _mm256_permute4x64_epi64(ymm_qi, _MM_SHUFFLE(1, 1, 0, 0)); \
    _PROC(ymm_mi, ymm_pi, _RES##_0_0) \
    ymm_mi = _LOAD((const __m256i *)(mi + 4)); \
    ymm_pi = _mm256_permute4x64_epi64(ymm_qi, _MM_SHUFFLE(3, 3, 2, 2)); \
    _PROC(ymm_mi, ymm_pi, _RES##_0_1) \
  }

//! Iterative process of computing distance (INT64, M=2, N=2)
//! Loads one vector each from `qi` and `mi`. `qi` lanes {0, 0, 2, 2} are
//! paired with `mi` into `_RES##_0_0`, and lanes {1, 1, 3, 3} into
//! `_RES##_0_1` (AVX2 permute).
#define MATRIX_INT64_ITER_2X2_AVX(mi, qi, _RES, _LOAD, _PROC) \
  { \
    __m256i ymm_qi = _LOAD((const __m256i *)(qi)); \
    __m256i ymm_mi = _LOAD((const __m256i *)(mi)); \
    __m256i ymm_pi = \
        _mm256_permute4x64_epi64(ymm_qi, _MM_SHUFFLE(2, 2, 0, 0)); \
    _PROC(ymm_mi, ymm_pi, _RES##_0_0) \
    ymm_pi = _mm256_permute4x64_epi64(ymm_qi, _MM_SHUFFLE(3, 3, 1, 1)); \
    _PROC(ymm_mi, ymm_pi, _RES##_0_1) \
  }

//! Iterative process of computing distance (INT64, M=4, N=1)
//! Pairs the `mi` vector at offset 0 with broadcast `qi[0]` into
//! `_RES##_0_0`, and the vector at offset 4 with broadcast `qi[1]` into
//! `_RES##_1_0`.
#define MATRIX_INT64_ITER_4X1_AVX(mi, qi, _RES, _LOAD, _PROC) \
  { \
    __m256i ymm_mi = _LOAD((const __m256i *)(mi + 0)); \
    __m256i ymm_qi = _mm256_broadcast_si64(qi + 0); \
    _PROC(ymm_mi, ymm_qi, _RES##_0_0) \
    ymm_mi = _LOAD((const __m256i *)(mi + 4)); \
    ymm_qi = _mm256_broadcast_si64(qi + 1); \
    _PROC(ymm_mi, ymm_qi, _RES##_1_0) \
  }

//!
//! INT64 tile step, M=4 x N=2. One 4-lane vector from `mi` is processed
//! against the broadcasts of `qi[0]` and `qi[1]` via
//! MATRIX_VAR_PROC(1, 2, 0, ...), which is defined outside this section.
#define MATRIX_INT64_ITER_4X2_AVX(mi, qi, _RES, _LOAD, _PROC) { \
  __m256i ymm_mi = _LOAD((const __m256i *)(mi)); \
  __m256i ymm_qi_0 = _mm256_broadcast_si64(qi + 0); \
  __m256i ymm_qi_1 = _mm256_broadcast_si64(qi + 1); \
  MATRIX_VAR_PROC(1, 2, 0, ymm_mi, ymm_qi, _RES, _PROC) \
}

//! INT64 tile step, M=4 x N=4. One 4-lane vector is loaded from `mi`. Each
//! `qi[j]` (j = 0..3) is broadcast in turn and combined with it through
//! `_PROC` into `_RES##_0_j`.
#define MATRIX_INT64_ITER_4X4_AVX(mi, qi, _RES, _LOAD, _PROC) { \
  __m256i ymm_mi = _LOAD((const __m256i *)(mi)); \
  __m256i ymm_qi = _mm256_broadcast_si64(qi + 0); \
  _PROC(ymm_mi, ymm_qi, _RES##_0_0) \
  ymm_qi = _mm256_broadcast_si64(qi + 1); \
  _PROC(ymm_mi, ymm_qi, _RES##_0_1) \
  ymm_qi = _mm256_broadcast_si64(qi + 2); \
  _PROC(ymm_mi, ymm_qi, _RES##_0_2) \
  ymm_qi = _mm256_broadcast_si64(qi + 3); \
  _PROC(ymm_mi, ymm_qi, _RES##_0_3) \
}

//! INT64 tile step, M=8 x N=1. The scalar at `qi` is broadcast once. The
//! `mi` vectors at offsets 0 and 4 feed `_RES##_0_0` and `_RES##_1_0`.
#define MATRIX_INT64_ITER_8X1_AVX(mi, qi, _RES, _LOAD, _PROC) { \
  __m256i ymm_mi = _LOAD((const __m256i *)(mi + 0)); \
  __m256i ymm_qi = _mm256_broadcast_si64(qi); \
  _PROC(ymm_mi, ymm_qi, _RES##_0_0) \
  ymm_mi = _LOAD((const __m256i *)(mi + 4)); \
  _PROC(ymm_mi, ymm_qi, _RES##_1_0) \
}

//! INT64 tile step, M=8 x N=2. `qi[0]` and `qi[1]` are broadcast once. The
//! `mi` vectors at offsets 0 and 4 are processed in turn via
//! MATRIX_VAR_PROC(1, 2, i, ...).
#define MATRIX_INT64_ITER_8X2_AVX(mi, qi, _RES, _LOAD, _PROC) { \
  __m256i ymm_qi_0 = _mm256_broadcast_si64(qi + 0); \
  __m256i ymm_qi_1 = _mm256_broadcast_si64(qi + 1); \
  __m256i ymm_mi = _LOAD((const __m256i *)(mi + 0)); \
  MATRIX_VAR_PROC(1, 2, 0, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_mi = _LOAD((const __m256i *)(mi + 4)); \
  MATRIX_VAR_PROC(1, 2, 1, ymm_mi, ymm_qi, _RES, _PROC) \
}

//!
//! INT64 tile step, M=8 x N=4. `mi` is loaded once as two 4-lane vectors
//! (offsets 0 and 4). For each j in 0..3, `qi[j]` is broadcast and both
//! vectors go through MATRIX_VAR_PROC(2, 1, j, ...), which is defined
//! outside this section.
#define MATRIX_INT64_ITER_8X4_AVX(mi, qi, _RES, _LOAD, _PROC) { \
  __m256i ymm_mi_0 = _LOAD((const __m256i *)(mi + 0)); \
  __m256i ymm_mi_1 = _LOAD((const __m256i *)(mi + 4)); \
  __m256i ymm_qi = _mm256_broadcast_si64(qi + 0); \
  MATRIX_VAR_PROC(2, 1, 0, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si64(qi + 1); \
  MATRIX_VAR_PROC(2, 1, 1, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si64(qi + 2); \
  MATRIX_VAR_PROC(2, 1, 2, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si64(qi + 3); \
  MATRIX_VAR_PROC(2, 1, 3, ymm_mi, ymm_qi, _RES, _PROC) \
}

//! INT64 tile step, M=8 x N=8. Same scheme as 8x4, iterating j over 0..7.
#define MATRIX_INT64_ITER_8X8_AVX(mi, qi, _RES, _LOAD, _PROC) { \
  __m256i ymm_mi_0 = _LOAD((const __m256i *)(mi + 0)); \
  __m256i ymm_mi_1 = _LOAD((const __m256i *)(mi + 4)); \
  __m256i ymm_qi = _mm256_broadcast_si64(qi + 0); \
  MATRIX_VAR_PROC(2, 1, 0, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si64(qi + 1); \
  MATRIX_VAR_PROC(2, 1, 1, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si64(qi + 2); \
  MATRIX_VAR_PROC(2, 1, 2, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si64(qi + 3); \
  MATRIX_VAR_PROC(2, 1, 3, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si64(qi + 4); \
  MATRIX_VAR_PROC(2, 1, 4, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si64(qi + 5); \
  MATRIX_VAR_PROC(2, 1, 5, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si64(qi + 6); \
  MATRIX_VAR_PROC(2, 1, 6, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si64(qi + 7); \
  MATRIX_VAR_PROC(2, 1, 7, ymm_mi, ymm_qi, _RES, _PROC) \
}

//!
//! INT64 tile step, M=16 x N=1. The scalar at `qi` is broadcast once. The
//! four `mi` vectors at offsets 0, 4, 8, 12 feed `_RES##_0_0` through
//! `_RES##_3_0`.
#define MATRIX_INT64_ITER_16X1_AVX(mi, qi, _RES, _LOAD, _PROC) { \
  __m256i ymm_mi = _LOAD((const __m256i *)(mi + 0)); \
  __m256i ymm_qi = _mm256_broadcast_si64(qi); \
  _PROC(ymm_mi, ymm_qi, _RES##_0_0) \
  ymm_mi = _LOAD((const __m256i *)(mi + 4)); \
  _PROC(ymm_mi, ymm_qi, _RES##_1_0) \
  ymm_mi = _LOAD((const __m256i *)(mi + 8)); \
  _PROC(ymm_mi, ymm_qi, _RES##_2_0) \
  ymm_mi = _LOAD((const __m256i *)(mi + 12)); \
  _PROC(ymm_mi, ymm_qi, _RES##_3_0) \
}

//! INT64 tile step, M=16 x N=2. `qi[0]` and `qi[1]` are broadcast once. The
//! `mi` vectors at offsets 0, 4, 8, 12 are processed in turn via
//! MATRIX_VAR_PROC(1, 2, i, ...), which is defined outside this section.
#define MATRIX_INT64_ITER_16X2_AVX(mi, qi, _RES, _LOAD, _PROC) { \
  __m256i ymm_qi_0 = _mm256_broadcast_si64(qi + 0); \
  __m256i ymm_qi_1 = _mm256_broadcast_si64(qi + 1); \
  __m256i ymm_mi = _LOAD((const __m256i *)(mi + 0)); \
  MATRIX_VAR_PROC(1, 2, 0, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_mi = _LOAD((const __m256i *)(mi + 4)); \
  MATRIX_VAR_PROC(1, 2, 1, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_mi = _LOAD((const __m256i *)(mi + 8)); \
  MATRIX_VAR_PROC(1, 2, 2, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_mi = _LOAD((const __m256i *)(mi + 12)); \
  MATRIX_VAR_PROC(1, 2, 3, ymm_mi, ymm_qi, _RES, _PROC) \
}

//! INT64 tile step, M=16 x N=4. `mi` is loaded once as four 4-lane vectors
//! (offsets 0, 4, 8, 12). For each j in 0..3, `qi[j]` is broadcast and fed
//! with them to MATRIX_VAR_PROC(4, 1, j, ...).
#define MATRIX_INT64_ITER_16X4_AVX(mi, qi, _RES, _LOAD, _PROC) { \
  __m256i ymm_mi_0 = _LOAD((const __m256i *)(mi + 0)); \
  __m256i ymm_mi_1 = _LOAD((const __m256i *)(mi + 4)); \
  __m256i ymm_mi_2 = _LOAD((const __m256i *)(mi + 8)); \
  __m256i ymm_mi_3 = _LOAD((const __m256i *)(mi + 12)); \
  __m256i ymm_qi = _mm256_broadcast_si64(qi + 0); \
  MATRIX_VAR_PROC(4, 1, 0, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si64(qi + 1); \
  MATRIX_VAR_PROC(4, 1, 1, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si64(qi + 2); \
  MATRIX_VAR_PROC(4, 1, 2, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si64(qi + 3); \
  MATRIX_VAR_PROC(4, 1, 3, ymm_mi, ymm_qi, _RES, _PROC) \
}

//!
//! INT64 tile step, M=16 x N=8. `mi` is loaded once as four 4-lane vectors
//! (offsets 0, 4, 8, 12). For each j in 0..7, `qi[j]` is broadcast and fed
//! with them to MATRIX_VAR_PROC(4, 1, j, ...), which is defined outside
//! this section.
#define MATRIX_INT64_ITER_16X8_AVX(mi, qi, _RES, _LOAD, _PROC) { \
  __m256i ymm_mi_0 = _LOAD((const __m256i *)(mi + 0)); \
  __m256i ymm_mi_1 = _LOAD((const __m256i *)(mi + 4)); \
  __m256i ymm_mi_2 = _LOAD((const __m256i *)(mi + 8)); \
  __m256i ymm_mi_3 = _LOAD((const __m256i *)(mi + 12)); \
  __m256i ymm_qi = _mm256_broadcast_si64(qi + 0); \
  MATRIX_VAR_PROC(4, 1, 0, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si64(qi + 1); \
  MATRIX_VAR_PROC(4, 1, 1, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si64(qi + 2); \
  MATRIX_VAR_PROC(4, 1, 2, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si64(qi + 3); \
  MATRIX_VAR_PROC(4, 1, 3, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si64(qi + 4); \
  MATRIX_VAR_PROC(4, 1, 4, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si64(qi + 5); \
  MATRIX_VAR_PROC(4, 1, 5, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si64(qi + 6); \
  MATRIX_VAR_PROC(4, 1, 6, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si64(qi + 7); \
  MATRIX_VAR_PROC(4, 1, 7, ymm_mi, ymm_qi, _RES, _PROC) \
}

//!
//! INT64 tile step, M=16 x N=16. `mi` is loaded once as four 4-lane vectors
//! (offsets 0, 4, 8, 12). For each j in 0..15, `qi[j]` is broadcast and fed
//! with them to MATRIX_VAR_PROC(4, 1, j, ...), which is defined outside
//! this section.
#define MATRIX_INT64_ITER_16X16_AVX(mi, qi, _RES, _LOAD, _PROC) { \
  __m256i ymm_mi_0 = _LOAD((const __m256i *)(mi + 0)); \
  __m256i ymm_mi_1 = _LOAD((const __m256i *)(mi + 4)); \
  __m256i ymm_mi_2 = _LOAD((const __m256i *)(mi + 8)); \
  __m256i ymm_mi_3 = _LOAD((const __m256i *)(mi + 12)); \
  __m256i ymm_qi = _mm256_broadcast_si64(qi + 0); \
  MATRIX_VAR_PROC(4, 1, 0, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si64(qi + 1); \
  MATRIX_VAR_PROC(4, 1, 1, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si64(qi + 2); \
  MATRIX_VAR_PROC(4, 1, 2, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si64(qi + 3); \
  MATRIX_VAR_PROC(4, 1, 3, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si64(qi + 4); \
  MATRIX_VAR_PROC(4, 1, 4, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si64(qi + 5); \
  MATRIX_VAR_PROC(4, 1, 5, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si64(qi + 6); \
  MATRIX_VAR_PROC(4, 1, 6, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si64(qi + 7); \
  MATRIX_VAR_PROC(4, 1, 7, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si64(qi + 8); \
  MATRIX_VAR_PROC(4, 1, 8, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si64(qi + 9); \
  MATRIX_VAR_PROC(4, 1, 9, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si64(qi + 10); \
  MATRIX_VAR_PROC(4, 1, 10, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si64(qi + 11); \
  MATRIX_VAR_PROC(4, 1, 11, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si64(qi + 12); \
  MATRIX_VAR_PROC(4, 1, 12, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si64(qi + 13); \
  MATRIX_VAR_PROC(4, 1, 13, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si64(qi + 14); \
  MATRIX_VAR_PROC(4, 1, 14, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_qi = _mm256_broadcast_si64(qi + 15); \
  MATRIX_VAR_PROC(4, 1, 15, ymm_mi, ymm_qi, _RES, _PROC) \
}

//!
//! INT64 tile step, M=32 x N=1. The scalar at `qi` is broadcast once. The
//! eight `mi` vectors at offsets 0, 4, ..., 28 feed `_RES##_0_0` through
//! `_RES##_7_0`.
#define MATRIX_INT64_ITER_32X1_AVX(mi, qi, _RES, _LOAD, _PROC) { \
  __m256i ymm_mi = _LOAD((const __m256i *)(mi + 0)); \
  __m256i ymm_qi = _mm256_broadcast_si64(qi); \
  _PROC(ymm_mi, ymm_qi, _RES##_0_0) \
  ymm_mi = _LOAD((const __m256i *)(mi + 4)); \
  _PROC(ymm_mi, ymm_qi, _RES##_1_0) \
  ymm_mi = _LOAD((const __m256i *)(mi + 8)); \
  _PROC(ymm_mi, ymm_qi, _RES##_2_0) \
  ymm_mi = _LOAD((const __m256i *)(mi + 12)); \
  _PROC(ymm_mi, ymm_qi, _RES##_3_0) \
  ymm_mi = _LOAD((const __m256i *)(mi + 16)); \
  _PROC(ymm_mi, ymm_qi, _RES##_4_0) \
  ymm_mi = _LOAD((const __m256i *)(mi + 20)); \
  _PROC(ymm_mi, ymm_qi, _RES##_5_0) \
  ymm_mi = _LOAD((const __m256i *)(mi + 24)); \
  _PROC(ymm_mi, ymm_qi, _RES##_6_0) \
  ymm_mi = _LOAD((const __m256i *)(mi + 28)); \
  _PROC(ymm_mi, ymm_qi, _RES##_7_0) \
}

//! INT64 tile step, M=32 x N=2. `qi[0]` and `qi[1]` are broadcast once. The
//! eight `mi` vectors (offsets 0..28, step 4) are processed in turn via
//! MATRIX_VAR_PROC(1, 2, i, ...), which is defined outside this section.
#define MATRIX_INT64_ITER_32X2_AVX(mi, qi, _RES, _LOAD, _PROC) { \
  __m256i ymm_qi_0 = _mm256_broadcast_si64(qi + 0); \
  __m256i ymm_qi_1 = _mm256_broadcast_si64(qi + 1); \
  __m256i ymm_mi = _LOAD((const __m256i *)(mi + 0)); \
  MATRIX_VAR_PROC(1, 2, 0, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_mi = _LOAD((const __m256i *)(mi + 4)); \
  MATRIX_VAR_PROC(1, 2, 1, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_mi = _LOAD((const __m256i *)(mi + 8)); \
  MATRIX_VAR_PROC(1, 2, 2, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_mi = _LOAD((const __m256i *)(mi + 12)); \
  MATRIX_VAR_PROC(1, 2, 3, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_mi = _LOAD((const __m256i *)(mi + 16)); \
  MATRIX_VAR_PROC(1, 2, 4, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_mi = _LOAD((const __m256i *)(mi + 20)); \
  MATRIX_VAR_PROC(1, 2, 5, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_mi = _LOAD((const __m256i *)(mi + 24)); \
  MATRIX_VAR_PROC(1, 2, 6, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_mi = _LOAD((const __m256i *)(mi + 28)); \
  MATRIX_VAR_PROC(1, 2, 7, ymm_mi, ymm_qi, _RES, _PROC) \
}

//!
//! INT64 tile step, M=32 x N=4. `qi[0..3]` are broadcast once into four
//! vectors. The eight `mi` vectors (offsets 0..28, step 4) are then
//! processed in turn via MATRIX_VAR_PROC(1, 4, i, ...), which is defined
//! outside this section.
#define MATRIX_INT64_ITER_32X4_AVX(mi, qi, _RES, _LOAD, _PROC) { \
  __m256i ymm_qi_0 = _mm256_broadcast_si64(qi + 0); \
  __m256i ymm_qi_1 = _mm256_broadcast_si64(qi + 1); \
  __m256i ymm_qi_2 = _mm256_broadcast_si64(qi + 2); \
  __m256i ymm_qi_3 = _mm256_broadcast_si64(qi + 3); \
  __m256i ymm_mi = _LOAD((const __m256i *)(mi + 0)); \
  MATRIX_VAR_PROC(1, 4, 0, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_mi = _LOAD((const __m256i *)(mi + 4)); \
  MATRIX_VAR_PROC(1, 4, 1, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_mi = _LOAD((const __m256i *)(mi + 8)); \
  MATRIX_VAR_PROC(1, 4, 2, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_mi = _LOAD((const __m256i *)(mi + 12)); \
  MATRIX_VAR_PROC(1, 4, 3, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_mi = _LOAD((const __m256i *)(mi + 16)); \
  MATRIX_VAR_PROC(1, 4, 4, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_mi = _LOAD((const __m256i *)(mi + 20)); \
  MATRIX_VAR_PROC(1, 4, 5, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_mi = _LOAD((const __m256i *)(mi + 24)); \
  MATRIX_VAR_PROC(1, 4, 6, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_mi = _LOAD((const __m256i *)(mi + 28)); \
  MATRIX_VAR_PROC(1, 4, 7, ymm_mi, ymm_qi, _RES, _PROC) \
}

//!
//! INT64 tile step, M=32 x N=8. `qi[0..7]` are broadcast once into eight
//! vectors. The eight `mi` vectors (offsets 0..28, step 4) are then
//! processed in turn via MATRIX_VAR_PROC(1, 8, i, ...), which is defined
//! outside this section.
#define MATRIX_INT64_ITER_32X8_AVX(mi, qi, _RES, _LOAD, _PROC) { \
  __m256i ymm_qi_0 = _mm256_broadcast_si64(qi + 0); \
  __m256i ymm_qi_1 = _mm256_broadcast_si64(qi + 1); \
  __m256i ymm_qi_2 = _mm256_broadcast_si64(qi + 2); \
  __m256i ymm_qi_3 = _mm256_broadcast_si64(qi + 3); \
  __m256i ymm_qi_4 = _mm256_broadcast_si64(qi + 4); \
  __m256i ymm_qi_5 = _mm256_broadcast_si64(qi + 5); \
  __m256i ymm_qi_6 = _mm256_broadcast_si64(qi + 6); \
  __m256i ymm_qi_7 = _mm256_broadcast_si64(qi + 7); \
  __m256i ymm_mi = _LOAD((const __m256i *)(mi + 0)); \
  MATRIX_VAR_PROC(1, 8, 0, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_mi = _LOAD((const __m256i *)(mi + 4)); \
  MATRIX_VAR_PROC(1, 8, 1, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_mi = _LOAD((const __m256i *)(mi + 8)); \
  MATRIX_VAR_PROC(1, 8, 2, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_mi = _LOAD((const __m256i *)(mi + 12)); \
  MATRIX_VAR_PROC(1, 8, 3, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_mi = _LOAD((const __m256i *)(mi + 16)); \
  MATRIX_VAR_PROC(1, 8, 4, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_mi = _LOAD((const __m256i *)(mi + 20)); \
  MATRIX_VAR_PROC(1, 8, 5, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_mi = _LOAD((const __m256i *)(mi + 24)); \
  MATRIX_VAR_PROC(1, 8, 6, ymm_mi, ymm_qi, _RES, _PROC) \
  ymm_mi = _LOAD((const __m256i *)(mi + 28)); \
  MATRIX_VAR_PROC(1, 8, 7, ymm_mi, ymm_qi, _RES, _PROC) \
}

//!
Iterative process of computing distance (INT64, M=32, N=16) #define MATRIX_INT64_ITER_32X16_AVX(mi, qi, _RES, _LOAD, _PROC) \ { \ __m256i ymm_mi_0 = _LOAD((const __m256i *)(mi + 0)); \ __m256i ymm_mi_1 = _LOAD((const __m256i *)(mi + 4)); \ __m256i ymm_mi_2 = _LOAD((const __m256i *)(mi + 8)); \ __m256i ymm_mi_3 = _LOAD((const __m256i *)(mi + 12)); \ __m256i ymm_mi_4 = _LOAD((const __m256i *)(mi + 16)); \ __m256i ymm_mi_5 = _LOAD((const __m256i *)(mi + 20)); \ __m256i ymm_mi_6 = _LOAD((const __m256i *)(mi + 24)); \ __m256i ymm_mi_7 = _LOAD((const __m256i *)(mi + 28)); \ __m256i ymm_qi = _mm256_broadcast_si64(qi + 0); \ MATRIX_VAR_PROC(8, 1, 0, ymm_mi, ymm_qi, _RES, _PROC) \ ymm_qi = _mm256_broadcast_si64(qi + 1); \ MATRIX_VAR_PROC(8, 1, 1, ymm_mi, ymm_qi, _RES, _PROC) \ ymm_qi = _mm256_broadcast_si64(qi + 2); \ MATRIX_VAR_PROC(8, 1, 2, ymm_mi, ymm_qi, _RES, _PROC) \ ymm_qi = _mm256_broadcast_si64(qi + 3); \ MATRIX_VAR_PROC(8, 1, 3, ymm_mi, ymm_qi, _RES, _PROC) \ ymm_qi = _mm256_broadcast_si64(qi + 4); \ MATRIX_VAR_PROC(8, 1, 4, ymm_mi, ymm_qi, _RES, _PROC) \ ymm_qi = _mm256_broadcast_si64(qi + 5); \ MATRIX_VAR_PROC(8, 1, 5, ymm_mi, ymm_qi, _RES, _PROC) \ ymm_qi = _mm256_broadcast_si64(qi + 6); \ MATRIX_VAR_PROC(8, 1, 6, ymm_mi, ymm_qi, _RES, _PROC) \ ymm_qi = _mm256_broadcast_si64(qi + 7); \ MATRIX_VAR_PROC(8, 1, 7, ymm_mi, ymm_qi, _RES, _PROC) \ ymm_qi = _mm256_broadcast_si64(qi + 8); \ MATRIX_VAR_PROC(8, 1, 8, ymm_mi, ymm_qi, _RES, _PROC) \ ymm_qi = _mm256_broadcast_si64(qi + 9); \ MATRIX_VAR_PROC(8, 1, 9, ymm_mi, ymm_qi, _RES, _PROC) \ ymm_qi = _mm256_broadcast_si64(qi + 10); \ MATRIX_VAR_PROC(8, 1, 10, ymm_mi, ymm_qi, _RES, _PROC) \ ymm_qi = _mm256_broadcast_si64(qi + 11); \ MATRIX_VAR_PROC(8, 1, 11, ymm_mi, ymm_qi, _RES, _PROC) \ ymm_qi = _mm256_broadcast_si64(qi + 12); \ MATRIX_VAR_PROC(8, 1, 12, ymm_mi, ymm_qi, _RES, _PROC) \ ymm_qi = _mm256_broadcast_si64(qi + 13); \ MATRIX_VAR_PROC(8, 1, 13, ymm_mi, ymm_qi, _RES, _PROC) \ ymm_qi = _mm256_broadcast_si64(qi + 
14); \ MATRIX_VAR_PROC(8, 1, 14, ymm_mi, ymm_qi, _RES, _PROC) \ ymm_qi = _mm256_broadcast_si64(qi + 15); \ MATRIX_VAR_PROC(8, 1, 15, ymm_mi, ymm_qi, _RES, _PROC) \ } //! Iterative process of computing distance (INT64, M=32, N=32) #define MATRIX_INT64_ITER_32X32_AVX(mi, qi, _RES, _LOAD, _PROC) \ { \ __m256i ymm_mi_0 = _LOAD((const __m256i *)(mi + 0)); \ __m256i ymm_mi_1 = _LOAD((const __m256i *)(mi + 4)); \ __m256i ymm_mi_2 = _LOAD((const __m256i *)(mi + 8)); \ __m256i ymm_mi_3 = _LOAD((const __m256i *)(mi + 12)); \ __m256i ymm_mi_4 = _LOAD((const __m256i *)(mi + 16)); \ __m256i ymm_mi_5 = _LOAD((const __m256i *)(mi + 20)); \ __m256i ymm_mi_6 = _LOAD((const __m256i *)(mi + 24)); \ __m256i ymm_mi_7 = _LOAD((const __m256i *)(mi + 28)); \ __m256i ymm_qi = _mm256_broadcast_si64(qi + 0); \ MATRIX_VAR_PROC(8, 1, 0, ymm_mi, ymm_qi, _RES, _PROC) \ ymm_qi = _mm256_broadcast_si64(qi + 1); \ MATRIX_VAR_PROC(8, 1, 1, ymm_mi, ymm_qi, _RES, _PROC) \ ymm_qi = _mm256_broadcast_si64(qi + 2); \ MATRIX_VAR_PROC(8, 1, 2, ymm_mi, ymm_qi, _RES, _PROC) \ ymm_qi = _mm256_broadcast_si64(qi + 3); \ MATRIX_VAR_PROC(8, 1, 3, ymm_mi, ymm_qi, _RES, _PROC) \ ymm_qi = _mm256_broadcast_si64(qi + 4); \ MATRIX_VAR_PROC(8, 1, 4, ymm_mi, ymm_qi, _RES, _PROC) \ ymm_qi = _mm256_broadcast_si64(qi + 5); \ MATRIX_VAR_PROC(8, 1, 5, ymm_mi, ymm_qi, _RES, _PROC) \ ymm_qi = _mm256_broadcast_si64(qi + 6); \ MATRIX_VAR_PROC(8, 1, 6, ymm_mi, ymm_qi, _RES, _PROC) \ ymm_qi = _mm256_broadcast_si64(qi + 7); \ MATRIX_VAR_PROC(8, 1, 7, ymm_mi, ymm_qi, _RES, _PROC) \ ymm_qi = _mm256_broadcast_si64(qi + 8); \ MATRIX_VAR_PROC(8, 1, 8, ymm_mi, ymm_qi, _RES, _PROC) \ ymm_qi = _mm256_broadcast_si64(qi + 9); \ MATRIX_VAR_PROC(8, 1, 9, ymm_mi, ymm_qi, _RES, _PROC) \ ymm_qi = _mm256_broadcast_si64(qi + 10); \ MATRIX_VAR_PROC(8, 1, 10, ymm_mi, ymm_qi, _RES, _PROC) \ ymm_qi = _mm256_broadcast_si64(qi + 11); \ MATRIX_VAR_PROC(8, 1, 11, ymm_mi, ymm_qi, _RES, _PROC) \ ymm_qi = _mm256_broadcast_si64(qi + 12); \ MATRIX_VAR_PROC(8, 1, 
12, ymm_mi, ymm_qi, _RES, _PROC) \ ymm_qi = _mm256_broadcast_si64(qi + 13); \ MATRIX_VAR_PROC(8, 1, 13, ymm_mi, ymm_qi, _RES, _PROC) \ ymm_qi = _mm256_broadcast_si64(qi + 14); \ MATRIX_VAR_PROC(8, 1, 14, ymm_mi, ymm_qi, _RES, _PROC) \ ymm_qi = _mm256_broadcast_si64(qi + 15); \ MATRIX_VAR_PROC(8, 1, 15, ymm_mi, ymm_qi, _RES, _PROC) \ ymm_qi = _mm256_broadcast_si64(qi + 16); \ MATRIX_VAR_PROC(8, 1, 16, ymm_mi, ymm_qi, _RES, _PROC) \ ymm_qi = _mm256_broadcast_si64(qi + 17); \ MATRIX_VAR_PROC(8, 1, 17, ymm_mi, ymm_qi, _RES, _PROC) \ ymm_qi = _mm256_broadcast_si64(qi + 18); \ MATRIX_VAR_PROC(8, 1, 18, ymm_mi, ymm_qi, _RES, _PROC) \ ymm_qi = _mm256_broadcast_si64(qi + 19); \ MATRIX_VAR_PROC(8, 1, 19, ymm_mi, ymm_qi, _RES, _PROC) \ ymm_qi = _mm256_broadcast_si64(qi + 20); \ MATRIX_VAR_PROC(8, 1, 20, ymm_mi, ymm_qi, _RES, _PROC) \ ymm_qi = _mm256_broadcast_si64(qi + 21); \ MATRIX_VAR_PROC(8, 1, 21, ymm_mi, ymm_qi, _RES, _PROC) \ ymm_qi = _mm256_broadcast_si64(qi + 22); \ MATRIX_VAR_PROC(8, 1, 22, ymm_mi, ymm_qi, _RES, _PROC) \ ymm_qi = _mm256_broadcast_si64(qi + 23); \ MATRIX_VAR_PROC(8, 1, 23, ymm_mi, ymm_qi, _RES, _PROC) \ ymm_qi = _mm256_broadcast_si64(qi + 24); \ MATRIX_VAR_PROC(8, 1, 24, ymm_mi, ymm_qi, _RES, _PROC) \ ymm_qi = _mm256_broadcast_si64(qi + 25); \ MATRIX_VAR_PROC(8, 1, 25, ymm_mi, ymm_qi, _RES, _PROC) \ ymm_qi = _mm256_broadcast_si64(qi + 26); \ MATRIX_VAR_PROC(8, 1, 26, ymm_mi, ymm_qi, _RES, _PROC) \ ymm_qi = _mm256_broadcast_si64(qi + 27); \ MATRIX_VAR_PROC(8, 1, 27, ymm_mi, ymm_qi, _RES, _PROC) \ ymm_qi = _mm256_broadcast_si64(qi + 28); \ MATRIX_VAR_PROC(8, 1, 28, ymm_mi, ymm_qi, _RES, _PROC) \ ymm_qi = _mm256_broadcast_si64(qi + 29); \ MATRIX_VAR_PROC(8, 1, 29, ymm_mi, ymm_qi, _RES, _PROC) \ ymm_qi = _mm256_broadcast_si64(qi + 30); \ MATRIX_VAR_PROC(8, 1, 30, ymm_mi, ymm_qi, _RES, _PROC) \ ymm_qi = _mm256_broadcast_si64(qi + 31); \ MATRIX_VAR_PROC(8, 1, 31, ymm_mi, ymm_qi, _RES, _PROC) \ } ================================================ FILE: 
src/ailego/math/distance_matrix_mips_utility.i
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Calculate Fused-Multiply-Add (AVX512): zmm_sum = zmm_m * zmm_q + zmm_sum
#define FMA_FP32_AVX512(zmm_m, zmm_q, zmm_sum) \
  zmm_sum = _mm512_fmadd_ps(zmm_m, zmm_q, zmm_sum);

//! Calculate masked Fused-Multiply-Add (AVX512): lanes whose bit in `mask` is
//! clear keep their previous zmm_sum value.
#define FMA_MASK_FP32_AVX512(zmm_m, zmm_q, zmm_sum, mask) \
  zmm_sum = _mm512_mask3_fmadd_ps(zmm_m, zmm_q, zmm_sum, mask);

//! Horizontal sum (NEON) of an 8-lane FP16 vector; both halves are widened
//! to FP32 before being added.
#define HorizontalAdd_FP16_NEON(v) \
  vaddvq_f32(vaddq_f32(vcvt_f32_f16(vget_low_f16(v)), vcvt_high_f32_f16(v)))

//! Reduce a 512-bit FP32 vector to 256 bits by adding its lower and upper
//! 256-bit halves (AVX512).
#define HorizontalAdd_FP32_V512_TO_V256(zmm) \
  _mm256_add_ps( \
      _mm512_castps512_ps256(zmm), \
      _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(zmm), 1)))

//! Calculate Fused-Multiply-Add (GENERAL, FP16): converts both operands to
//! float, then accumulates their product into `sum` and their squares into
//! `norm1` / `norm2`.  Plain scalar code; no SIMD instructions are used.
#define FMA_FP16_GENERAL(lhs, rhs, sum, norm1, norm2) \
  { \
    float v1 = lhs; \
    float v2 = rhs; \
    sum += v1 * v2; \
    norm1 += v1 * v1; \
    norm2 += v2 * v2; \
  }

//! Calculate Fused-Multiply-Add (GENERAL): accumulates lhs * rhs into `sum`
//! and the squares of lhs and rhs into `norm1` / `norm2`.
#define FMA_FP32_GENERAL(lhs, rhs, sum, norm1, norm2) \
  { \
    sum += (lhs) * (rhs); \
    norm1 += (lhs) * (lhs); \
    norm2 += (rhs) * (rhs); \
  }

#if defined(__SSE4_1__)
//! Four-bits Convert Table: maps a 4-bit two's-complement nibble (index
//! 0..15) to its signed value (-8..7).  The 16 entries are stored twice so the
//! table can also be loaded as one 256-bit vector, because _mm256_shuffle_epi8
//! looks indices up within each 128-bit lane separately.
static const AILEGO_ALIGNED(32) int8_t Int4ConvertTable[32] = {
    0, 1, 2, 3, 4, 5, 6, 7, -8, -7, -6, -5, -4, -3, -2, -1,
    0, 1, 2, 3, 4, 5, 6, 7, -8, -7, -6, -5, -4, -3, -2, -1};
#endif  // __SSE4_1__

// SIMD constants used by the INT4/INT8 kernels: a per-byte low-nibble mask,
// 16-bit ones for _mm_madd_epi16 pair sums, and the nibble lookup table.
// They are `static` with non-constant initializers, so every translation unit
// that includes this file holds its own copy, initialized at start-up.
#if defined(__SSE4_1__)
static const __m128i MASK_INT4_SSE = _mm_set1_epi32(0x0f0f0f0f);
static const __m128i ONES_INT16_SSE = _mm_set1_epi32(0x00010001);
static const __m128i INT4_LOOKUP_SSE =
    _mm_load_si128((const __m128i *)Int4ConvertTable);
#endif  // __SSE4_1__

#if defined(__AVX2__)
static const __m256i MASK_INT4_AVX = _mm256_set1_epi32(0x0f0f0f0f);
static const __m256i ONES_INT16_AVX = _mm256_set1_epi32(0x00010001);
static const __m256i INT4_LOOKUP_AVX =
    _mm256_load_si256((const __m256i *)Int4ConvertTable);
#endif  // __AVX2__

//! Calculate Fused-Multiply-Add (GENERAL, INT4): `lhs` and `rhs` each pack two
//! signed 4-bit values (low and high nibble).  `sum` receives two lookups in
//! Int4MulTable (defined elsewhere), indexed by (lhs_nibble << 4) | rhs_nibble
//! for the low-nibble pair and for the high-nibble pair -- presumably the
//! nibble products; confirm against the table definition.  The nibbles are
//! sign-extended through int8_t shifts, squared, and added to `norm1` /
//! `norm2`.
#define FMA_INT4_GENERAL(lhs, rhs, sum, norm1, norm2) \
  { \
    sum += Int4MulTable[(((lhs) << 4) & 0xf0) | (((rhs) >> 0) & 0xf)] + \
           Int4MulTable[(((lhs) >> 0) & 0xf0) | (((rhs) >> 4) & 0xf)]; \
    norm1 += static_cast<float>( \
        ((int8_t)((lhs) << 4) >> 4) * ((int8_t)((lhs) << 4) >> 4) + \
        ((int8_t)((lhs) & 0xf0) >> 4) * ((int8_t)((lhs) & 0xf0) >> 4)); \
    norm2 += static_cast<float>( \
        ((int8_t)((rhs) << 4) >> 4) * ((int8_t)((rhs) << 4) >> 4) + \
        ((int8_t)((rhs) & 0xf0) >> 4) * ((int8_t)((rhs) & 0xf0) >> 4)); \
  }

//!
//! Compute the distance between matrix and query (SSE)
//! Decodes the two signed 4-bit values packed in every byte of `xmm_lhs` and
//! `xmm_rhs` through INT4_LOOKUP_SSE (low nibbles directly, high nibbles after
//! a 4-bit right shift), then accumulates the lhs*rhs products into
//! `xmm_sum_0` and the squared lhs / rhs values into `xmm_sum_norm1` /
//! `xmm_sum_norm2` as int32 lanes.
#define FMA_INT4_ITER_SSE(xmm_lhs, xmm_rhs, xmm_sum_0, xmm_sum_norm1, \
                          xmm_sum_norm2) \
  { \
    __m128i xmm_lhs_0 = _mm_shuffle_epi8( \
        INT4_LOOKUP_SSE, _mm_and_si128((xmm_lhs), MASK_INT4_SSE)); \
    __m128i xmm_rhs_0 = _mm_shuffle_epi8( \
        INT4_LOOKUP_SSE, _mm_and_si128((xmm_rhs), MASK_INT4_SSE)); \
    __m128i xmm_lhs_1 = _mm_shuffle_epi8( \
        INT4_LOOKUP_SSE, \
        _mm_and_si128(_mm_srli_epi32((xmm_lhs), 4), MASK_INT4_SSE)); \
    __m128i xmm_rhs_1 = _mm_shuffle_epi8( \
        INT4_LOOKUP_SSE, \
        _mm_and_si128(_mm_srli_epi32((xmm_rhs), 4), MASK_INT4_SSE)); \
    FMA_INT8_SSE(xmm_lhs_0, xmm_rhs_0, xmm_sum_0); \
    FMA_INT8_SSE(xmm_lhs_0, xmm_lhs_0, xmm_sum_norm1); \
    FMA_INT8_SSE(xmm_rhs_0, xmm_rhs_0, xmm_sum_norm2); \
    FMA_INT8_SSE(xmm_lhs_1, xmm_rhs_1, xmm_sum_0); \
    FMA_INT8_SSE(xmm_lhs_1, xmm_lhs_1, xmm_sum_norm1); \
    FMA_INT8_SSE(xmm_rhs_1, xmm_rhs_1, xmm_sum_norm2); \
  }

//! Calculate Fused-Multiply-Add (GENERAL): accumulates lhs * rhs into `sum`
//! and the squares of lhs and rhs into `norm1` / `norm2`, converted to float.
//! The arguments are parenthesized so expressions such as `a + 1` are
//! multiplied as a whole.
#define FMA_INT8_GENERAL(lhs, rhs, sum, norm1, norm2) \
  { \
    sum += static_cast<float>((lhs) * (rhs)); \
    norm1 += static_cast<float>((lhs) * (lhs)); \
    norm2 += static_cast<float>((rhs) * (rhs)); \
  }

//! Calculate Fused-Multiply-Add (SSE): signed INT8 dot product of
//! `xmm_lhs` and `xmm_rhs`, added into the int32 lanes of `xmm_sum`.
//! _mm_maddubs_epi16 needs an unsigned operand, so |rhs| is multiplied by lhs
//! with rhs's sign applied (_mm_sign_epi8); adjacent byte products are summed
//! into saturating int16 lanes, then _mm_madd_epi16 with ONES_INT16_SSE sums
//! adjacent int16 pairs into int32.
#define FMA_INT8_SSE(xmm_lhs, xmm_rhs, xmm_sum) \
  xmm_sum = _mm_add_epi32( \
      _mm_madd_epi16(_mm_maddubs_epi16(_mm_abs_epi8(xmm_rhs), \
                                       _mm_sign_epi8(xmm_lhs, xmm_rhs)), \
                     ONES_INT16_SSE), \
      xmm_sum)

//! Calculate Fused-Multiply-Add (AVX): 256-bit version of FMA_INT8_SSE.
#define FMA_INT8_AVX(ymm_lhs, ymm_rhs, ymm_sum) \
  ymm_sum = _mm256_add_epi32( \
      _mm256_madd_epi16( \
          _mm256_maddubs_epi16(_mm256_abs_epi8(ymm_rhs), \
                               _mm256_sign_epi8(ymm_lhs, ymm_rhs)), \
          ONES_INT16_AVX), \
      ymm_sum)

//! Calculate Fused-Multiply-Add on 128-bit inputs into a 256-bit accumulator:
//! the SSE INT8 dot-product lanes are placed in the low 128 bits (the high
//! 128 bits are zero) and added to `ymm_sum`.
#define FMA_INT8_AVX_SSE_HYBRID(xmm_lhs, xmm_rhs, ymm_sum) \
  ymm_sum = _mm256_add_epi32( \
      _mm256_set_m128i( \
          _mm_setzero_si128(), \
          _mm_madd_epi16(_mm_maddubs_epi16(_mm_abs_epi8(xmm_rhs), \
                                           _mm_sign_epi8(xmm_lhs, xmm_rhs)), \
                         ONES_INT16_SSE)), \
      ymm_sum)

//!
Compute the distance between matrix and query (AVX) #define FMA_INT4_ITER_AVX(ymm_lhs, ymm_rhs, ymm_sum_0, ymm_sum1, \ ymm_sum_norm1, ymm_sum_norm2) \ { \ __m256i ymm_lhs_0 = _mm256_shuffle_epi8( \ INT4_LOOKUP_AVX, _mm256_and_si256((ymm_lhs), MASK_INT4_AVX)); \ __m256i ymm_rhs_0 = _mm256_shuffle_epi8( \ INT4_LOOKUP_AVX, _mm256_and_si256((ymm_rhs), MASK_INT4_AVX)); \ __m256i ymm_lhs_1 = _mm256_shuffle_epi8( \ INT4_LOOKUP_AVX, \ _mm256_and_si256(_mm256_srli_epi32((ymm_lhs), 4), MASK_INT4_AVX)); \ __m256i ymm_rhs_1 = _mm256_shuffle_epi8( \ INT4_LOOKUP_AVX, \ _mm256_and_si256(_mm256_srli_epi32((ymm_rhs), 4), MASK_INT4_AVX)); \ FMA_INT8_AVX(ymm_lhs_0, ymm_rhs_0, ymm_sum_0); \ FMA_INT8_AVX(ymm_lhs_1, ymm_rhs_1, ymm_sum_1); \ FMA_INT8_AVX(ymm_lhs_0, ymm_lhs_0, ymm_sum_norm1); \ FMA_INT8_AVX(ymm_lhs_1, ymm_lhs_1, ymm_sum_norm1); \ FMA_INT8_AVX(ymm_rhs_0, ymm_rhs_0, ymm_sum_norm2); \ FMA_INT8_AVX(ymm_rhs_1, ymm_rhs_1, ymm_sum_norm2); \ } ================================================ FILE: src/ailego/math/distance_matrix_popcnt.i ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "distance_matrix_int32.i" #include "distance_matrix_int64.i" #include "matrix_utility.i" //! Calculate population count (UINT32 Permute 1 SSE) #define POPCNT_UINT32_PERMUTE1_SSE(v, ...) \ _mm_add_epi16(_mm_srli_epi16(v, 8), _mm_and_si128(v, _mm_set1_epi16(0xff))) //! 
//! Calculate population count (UINT32 Permute 2 SSE): adds the two 16-bit
//! halves of every 32-bit lane, folding 16-bit counters into 32-bit counters.
#define POPCNT_UINT32_PERMUTE2_SSE(v, ...) \
  _mm_add_epi32(_mm_srli_epi32(v, 16), _mm_and_si128(v, _mm_set1_epi32(0xffff)))

//! Calculate population count (UINT32 Permute 1 AVX): 256-bit version of
//! POPCNT_UINT32_PERMUTE1_SSE (byte pairs -> 16-bit lanes).
#define POPCNT_UINT32_PERMUTE1_AVX(v, ...) \
  _mm256_add_epi16(_mm256_srli_epi16(v, 8), \
                   _mm256_and_si256(v, _mm256_set1_epi16(0xff)))

//! Calculate population count (UINT32 Permute 2 AVX): 256-bit version of
//! POPCNT_UINT32_PERMUTE2_SSE (16-bit halves -> 32-bit lanes).
#define POPCNT_UINT32_PERMUTE2_AVX(v, ...) \
  _mm256_add_epi32(_mm256_srli_epi32(v, 16), \
                   _mm256_and_si256(v, _mm256_set1_epi32(0xffff)))

//! Calculate population count (UINT64 Permute AVX): sums the eight bytes of
//! every 64-bit lane (SAD against zero).
#define POPCNT_UINT64_PERMUTE_AVX(v, ...) \
  _mm256_sad_epu8(v, _mm256_setzero_si256())

//! Shared accumulation loops of the UINT32 SSE popcount kernels below.
//! Runs `_ITER` (a MATRIX_INT32_ITER_*_SSE kernel) with STEP1 until `qe_1`,
//! folds the counters with PERMUTE1, continues with STEP2 until `qe_2`,
//! folds again with PERMUTE2 and finishes with STEP3 until `qe_0`.  The
//! 31 / 4095 limits chosen by the callers presumably keep the narrower
//! STEP1 / STEP2 counters from overflowing.  Every iteration advances `m` by
//! `_MSTEP` and `q` by `_QSTEP` elements; `_LOAD` is forwarded to `_ITER`.
//! The caller must already declare `qe_0`, `qe_1`, `qe_2` and the `xmm_sum`
//! accumulators (MATRIX_VAR_INIT(_VM, _VN, ...)).
#define POPCNT_UINT32_SSE_PHASES_(_ITER, _VM, _VN, _MSTEP, _QSTEP, _LOAD, m, \
                                  q) \
  for (; q != qe_1; m += _MSTEP, q += _QSTEP) { \
    _ITER(m, q, xmm_sum, _LOAD, POPCNT_UINT32_STEP1_SSE) \
  } \
  MATRIX_VAR_PERMUTE(_VM, _VN, xmm_sum, POPCNT_UINT32_PERMUTE1_SSE) \
  for (; q != qe_2; m += _MSTEP, q += _QSTEP) { \
    _ITER(m, q, xmm_sum, _LOAD, POPCNT_UINT32_STEP2_SSE) \
  } \
  MATRIX_VAR_PERMUTE(_VM, _VN, xmm_sum, POPCNT_UINT32_PERMUTE2_SSE) \
  for (; q != qe_0; m += _MSTEP, q += _QSTEP) { \
    _ITER(m, q, xmm_sum, _LOAD, POPCNT_UINT32_STEP3_SSE) \
  }

//! Compute the distance between matrix and query (UINT32, M=2, N=1)
//! Four query words per loop iteration; up to two leftover words are handled
//! before the horizontal reduction and one more after it.  Only the two low
//! result lanes are written to `out`.
#define POPCNT_UINT32_2X1_SSE(m, q, cnt, out, _NORM) \
  MATRIX_VAR_INIT(1, 2, __m128i, xmm_sum, _mm_setzero_si128()) \
  const uint32_t *qe_0 = q + ((cnt >> 2) << 2); \
  const uint32_t *qe_1 = (cnt > 31 ? q + ((31 >> 2) << 2) : qe_0); \
  const uint32_t *qe_2 = (cnt > 4095 ? q + ((4095 >> 2) << 2) : qe_0); \
  const uint32_t *qe_3 = q + cnt; \
  if (((uintptr_t)m & 0xf) == 0 && ((uintptr_t)q & 0xf) == 0) { \
    POPCNT_UINT32_SSE_PHASES_(MATRIX_INT32_ITER_2X1_SSE, 1, 2, 8, 4, \
                              _mm_load_si128, m, q) \
    if (qe_3 >= qe_0 + 2) { \
      __m128i xmm_m = _mm_load_si128((const __m128i *)(m)); \
      __m128i xmm_q = _mm_set_epi32(q[1], q[1], q[0], q[0]); \
      POPCNT_UINT32_STEP3_SSE(xmm_m, xmm_q, xmm_sum_0_0) \
      m += 4; \
      q += 2; \
    } \
  } else { \
    POPCNT_UINT32_SSE_PHASES_(MATRIX_INT32_ITER_2X1_SSE, 1, 2, 8, 4, \
                              _mm_loadu_si128, m, q) \
    if (qe_3 >= qe_0 + 2) { \
      __m128i xmm_m = _mm_loadu_si128((const __m128i *)(m)); \
      __m128i xmm_q = _mm_set_epi32(q[1], q[1], q[0], q[0]); \
      POPCNT_UINT32_STEP3_SSE(xmm_m, xmm_q, xmm_sum_0_0) \
      m += 4; \
      q += 2; \
    } \
  } \
  xmm_sum_0_0 = _mm_add_epi32(xmm_sum_0_0, xmm_sum_0_1); \
  xmm_sum_0_0 = _mm_add_epi32( \
      xmm_sum_0_0, _mm_shuffle_epi32(xmm_sum_0_0, _MM_SHUFFLE(0, 0, 3, 2))); \
  if (q != qe_3) { \
    __m128i xmm_m = _mm_set_epi32(0, 0, m[1], m[0]); \
    __m128i xmm_q = _mm_broadcast_si32(q); \
    POPCNT_UINT32_STEP3_SSE(xmm_m, xmm_q, xmm_sum_0_0) \
  } \
  _mm_storel_pi((__m64 *)out, _NORM(xmm_sum_0_0));

//! Compute the distance between matrix and query (UINT32, M=2, N=2)
//! Two query steps (four words) per loop iteration; an odd final step is
//! handled after the two accumulators are merged.
#define POPCNT_UINT32_2X2_SSE(m, q, cnt, out, _NORM) \
  MATRIX_VAR_INIT(1, 2, __m128i, xmm_sum, _mm_setzero_si128()) \
  const uint32_t *qe_0 = q + ((cnt >> 1) << 2); \
  const uint32_t *qe_1 = (cnt > 31 ? q + ((31 >> 1) << 2) : qe_0); \
  const uint32_t *qe_2 = (cnt > 4095 ? q + ((4095 >> 1) << 2) : qe_0); \
  const uint32_t *qe_3 = q + (cnt << 1); \
  if (((uintptr_t)m & 0xf) == 0 && ((uintptr_t)q & 0xf) == 0) { \
    POPCNT_UINT32_SSE_PHASES_(MATRIX_INT32_ITER_2X2_SSE, 1, 2, 4, 4, \
                              _mm_load_si128, m, q) \
  } else { \
    POPCNT_UINT32_SSE_PHASES_(MATRIX_INT32_ITER_2X2_SSE, 1, 2, 4, 4, \
                              _mm_loadu_si128, m, q) \
  } \
  xmm_sum_0_0 = _mm_add_epi32(_mm_unpacklo_epi64(xmm_sum_0_0, xmm_sum_0_1), \
                              _mm_unpackhi_epi64(xmm_sum_0_0, xmm_sum_0_1)); \
  if (q != qe_3) { \
    __m128i xmm_m = _mm_set_epi32(m[1], m[0], m[1], m[0]); \
    __m128i xmm_q = _mm_set_epi32(q[1], q[1], q[0], q[0]); \
    POPCNT_UINT32_STEP3_SSE(xmm_m, xmm_q, xmm_sum_0_0) \
  } \
  if (((uintptr_t)out & 0xf) == 0) { \
    MATRIX_VAR_STORE(1, 1, 4, xmm_sum, out, _mm_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(1, 1, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \
  }

//! Compute the distance between matrix and query (UINT32, M=4, N=1)
//! Two query words per loop iteration; an odd final word is handled inside
//! each load branch before the two accumulators are merged.
#define POPCNT_UINT32_4X1_SSE(m, q, cnt, out, _NORM) \
  MATRIX_VAR_INIT(2, 1, __m128i, xmm_sum, _mm_setzero_si128()) \
  const uint32_t *qe_0 = q + ((cnt >> 1) << 1); \
  const uint32_t *qe_1 = (cnt > 31 ? q + ((31 >> 1) << 1) : qe_0); \
  const uint32_t *qe_2 = (cnt > 4095 ? q + ((4095 >> 1) << 1) : qe_0); \
  const uint32_t *qe_3 = q + cnt; \
  if (((uintptr_t)m & 0xf) == 0) { \
    POPCNT_UINT32_SSE_PHASES_(MATRIX_INT32_ITER_4X1_SSE, 2, 1, 8, 2, \
                              _mm_load_si128, m, q) \
    if (q != qe_3) { \
      __m128i xmm_m = _mm_load_si128((const __m128i *)(m)); \
      __m128i xmm_q = _mm_broadcast_si32(q); \
      POPCNT_UINT32_STEP3_SSE(xmm_m, xmm_q, xmm_sum_0_0) \
    } \
  } else { \
    POPCNT_UINT32_SSE_PHASES_(MATRIX_INT32_ITER_4X1_SSE, 2, 1, 8, 2, \
                              _mm_loadu_si128, m, q) \
    if (q != qe_3) { \
      __m128i xmm_m = _mm_loadu_si128((const __m128i *)(m)); \
      __m128i xmm_q = _mm_broadcast_si32(q); \
      POPCNT_UINT32_STEP3_SSE(xmm_m, xmm_q, xmm_sum_0_0) \
    } \
  } \
  xmm_sum_0_0 = _mm_add_epi32(xmm_sum_0_0, xmm_sum_1_0); \
  if (((uintptr_t)out & 0xf) == 0) { \
    MATRIX_VAR_STORE(1, 1, 4, xmm_sum, out, _mm_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(1, 1, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \
  }

//! Generic body of the UINT32 SSE popcount kernels that consume exactly one
//! query step per loop iteration and need no tail handling.  `_ITER` is the
//! matching MATRIX_INT32_ITER_*_SSE kernel and (`_VM`, `_VN`) the accumulator
//! grid.  `m` advances by `_MSTEP` and `q` by `_VN` elements per iteration,
//! and the query range is `cnt << _QSHIFT` words.  Every caller below uses
//! M == 4 * _VM, N == _VN == 1 << _QSHIFT.  Only the matrix pointer selects
//! aligned vs. unaligned loads; `out` selects aligned vs. unaligned stores.
#define POPCNT_UINT32_MXN_SSE_(_ITER, _VM, _VN, _MSTEP, _QSHIFT, m, q, cnt, \
                               out, _NORM) \
  MATRIX_VAR_INIT(_VM, _VN, __m128i, xmm_sum, _mm_setzero_si128()) \
  const uint32_t *qe_0 = q + (cnt << _QSHIFT); \
  const uint32_t *qe_1 = (cnt > 31 ? q + (31 << _QSHIFT) : qe_0); \
  const uint32_t *qe_2 = (cnt > 4095 ? q + (4095 << _QSHIFT) : qe_0); \
  if (((uintptr_t)m & 0xf) == 0) { \
    POPCNT_UINT32_SSE_PHASES_(_ITER, _VM, _VN, _MSTEP, _VN, _mm_load_si128, \
                              m, q) \
  } else { \
    POPCNT_UINT32_SSE_PHASES_(_ITER, _VM, _VN, _MSTEP, _VN, _mm_loadu_si128, \
                              m, q) \
  } \
  if (((uintptr_t)out & 0xf) == 0) { \
    MATRIX_VAR_STORE(_VM, _VN, 4, xmm_sum, out, _mm_store_ps, _NORM) \
  } else { \
    MATRIX_VAR_STORE(_VM, _VN, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \
  }

//! Compute the distance between matrix and query (UINT32, M=4, N=2)
#define POPCNT_UINT32_4X2_SSE(m, q, cnt, out, _NORM) \
  POPCNT_UINT32_MXN_SSE_(MATRIX_INT32_ITER_4X2_SSE, 1, 2, 4, 1, m, q, cnt, \
                         out, _NORM)

//! Compute the distance between matrix and query (UINT32, M=4, N=4)
#define POPCNT_UINT32_4X4_SSE(m, q, cnt, out, _NORM) \
  POPCNT_UINT32_MXN_SSE_(MATRIX_INT32_ITER_4X4_SSE, 1, 4, 4, 2, m, q, cnt, \
                         out, _NORM)

//! Compute the distance between matrix and query (UINT32, M=8, N=1)
#define POPCNT_UINT32_8X1_SSE(m, q, cnt, out, _NORM) \
  POPCNT_UINT32_MXN_SSE_(MATRIX_INT32_ITER_8X1_SSE, 2, 1, 8, 0, m, q, cnt, \
                         out, _NORM)

//! Compute the distance between matrix and query (UINT32, M=8, N=2)
#define POPCNT_UINT32_8X2_SSE(m, q, cnt, out, _NORM) \
  POPCNT_UINT32_MXN_SSE_(MATRIX_INT32_ITER_8X2_SSE, 2, 2, 8, 1, m, q, cnt, \
                         out, _NORM)

//! Compute the distance between matrix and query (UINT32, M=8, N=4)
#define POPCNT_UINT32_8X4_SSE(m, q, cnt, out, _NORM) \
  POPCNT_UINT32_MXN_SSE_(MATRIX_INT32_ITER_8X4_SSE, 2, 4, 8, 2, m, q, cnt, \
                         out, _NORM)

//! Compute the distance between matrix and query (UINT32, M=8, N=8)
#define POPCNT_UINT32_8X8_SSE(m, q, cnt, out, _NORM) \
  POPCNT_UINT32_MXN_SSE_(MATRIX_INT32_ITER_8X8_SSE, 2, 8, 8, 3, m, q, cnt, \
                         out, _NORM)

//! Compute the distance between matrix and query (UINT32, M=16, N=1)
#define POPCNT_UINT32_16X1_SSE(m, q, cnt, out, _NORM) \
  POPCNT_UINT32_MXN_SSE_(MATRIX_INT32_ITER_16X1_SSE, 4, 1, 16, 0, m, q, cnt, \
                         out, _NORM)

//! Compute the distance between matrix and query (UINT32, M=16, N=2)
#define POPCNT_UINT32_16X2_SSE(m, q, cnt, out, _NORM) \
  POPCNT_UINT32_MXN_SSE_(MATRIX_INT32_ITER_16X2_SSE, 4, 2, 16, 1, m, q, cnt, \
                         out, _NORM)

//! Compute the distance between matrix and query (UINT32, M=16, N=4)
#define POPCNT_UINT32_16X4_SSE(m, q, cnt, out, _NORM) \
  POPCNT_UINT32_MXN_SSE_(MATRIX_INT32_ITER_16X4_SSE, 4, 4, 16, 2, m, q, cnt, \
                         out, _NORM)

//! Compute the distance between matrix and query (UINT32, M=16, N=8)
#define POPCNT_UINT32_16X8_SSE(m, q, cnt, out, _NORM) \
  POPCNT_UINT32_MXN_SSE_(MATRIX_INT32_ITER_16X8_SSE, 4, 8, 16, 3, m, q, cnt, \
                         out, _NORM)

//! Compute the distance between matrix and query (UINT32, M=16, N=16)
#define POPCNT_UINT32_16X16_SSE(m, q, cnt, out, _NORM) \
  POPCNT_UINT32_MXN_SSE_(MATRIX_INT32_ITER_16X16_SSE, 4, 16, 16, 4, m, q, \
                         cnt, out, _NORM)

//! Compute the distance between matrix and query (UINT32, M=32, N=1)
#define POPCNT_UINT32_32X1_SSE(m, q, cnt, out, _NORM) \
  POPCNT_UINT32_MXN_SSE_(MATRIX_INT32_ITER_32X1_SSE, 8, 1, 32, 0, m, q, cnt, \
                         out, _NORM)

//! Compute the distance between matrix and query (UINT32, M=32, N=2)
#define POPCNT_UINT32_32X2_SSE(m, q, cnt, out, _NORM) \
  POPCNT_UINT32_MXN_SSE_(MATRIX_INT32_ITER_32X2_SSE, 8, 2, 32, 1, m, q, cnt, \
                         out, _NORM)

//! Compute the distance between matrix and query (UINT32, M=32, N=4)
#define POPCNT_UINT32_32X4_SSE(m, q, cnt, out, _NORM) \
  POPCNT_UINT32_MXN_SSE_(MATRIX_INT32_ITER_32X4_SSE, 8, 4, 32, 2, m, q, cnt, \
                         out, _NORM)

//! Compute the distance between matrix and query (UINT32, M=32, N=8)
#define POPCNT_UINT32_32X8_SSE(m, q, cnt, out, _NORM) \
  MATRIX_VAR_INIT(8, 8, __m128i, xmm_sum, _mm_setzero_si128()) \
  const uint32_t *qe_0 = q + (cnt << 3); \
  const uint32_t *qe_1 = (cnt > 31 ? q + (31 << 3) : qe_0); \
  const uint32_t *qe_2 = (cnt > 4095 ?
q + (4095 << 3) : qe_0); \ if (((uintptr_t)m & 0xf) == 0) { \ for (; q != qe_1; m += 32, q += 8) { \ MATRIX_INT32_ITER_32X8_SSE(m, q, xmm_sum, _mm_load_si128, \ POPCNT_UINT32_STEP1_SSE) \ } \ MATRIX_VAR_PERMUTE(8, 8, xmm_sum, POPCNT_UINT32_PERMUTE1_SSE) \ for (; q != qe_2; m += 32, q += 8) { \ MATRIX_INT32_ITER_32X8_SSE(m, q, xmm_sum, _mm_load_si128, \ POPCNT_UINT32_STEP2_SSE) \ } \ MATRIX_VAR_PERMUTE(8, 8, xmm_sum, POPCNT_UINT32_PERMUTE2_SSE) \ for (; q != qe_0; m += 32, q += 8) { \ MATRIX_INT32_ITER_32X8_SSE(m, q, xmm_sum, _mm_load_si128, \ POPCNT_UINT32_STEP3_SSE) \ } \ } else { \ for (; q != qe_1; m += 32, q += 8) { \ MATRIX_INT32_ITER_32X8_SSE(m, q, xmm_sum, _mm_loadu_si128, \ POPCNT_UINT32_STEP1_SSE) \ } \ MATRIX_VAR_PERMUTE(8, 8, xmm_sum, POPCNT_UINT32_PERMUTE1_SSE) \ for (; q != qe_2; m += 32, q += 8) { \ MATRIX_INT32_ITER_32X8_SSE(m, q, xmm_sum, _mm_loadu_si128, \ POPCNT_UINT32_STEP2_SSE) \ } \ MATRIX_VAR_PERMUTE(8, 8, xmm_sum, POPCNT_UINT32_PERMUTE2_SSE) \ for (; q != qe_0; m += 32, q += 8) { \ MATRIX_INT32_ITER_32X8_SSE(m, q, xmm_sum, _mm_loadu_si128, \ POPCNT_UINT32_STEP3_SSE) \ } \ } \ if (((uintptr_t)out & 0xf) == 0) { \ MATRIX_VAR_STORE(8, 8, 4, xmm_sum, out, _mm_store_ps, _NORM) \ } else { \ MATRIX_VAR_STORE(8, 8, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \ } //! Compute the distance between matrix and query (UINT32, M=32, N=16) #define POPCNT_UINT32_32X16_SSE(m, q, cnt, out, _NORM) \ MATRIX_VAR_INIT(8, 16, __m128i, xmm_sum, _mm_setzero_si128()) \ const uint32_t *qe_0 = q + (cnt << 4); \ const uint32_t *qe_1 = (cnt > 31 ? q + (31 << 4) : qe_0); \ const uint32_t *qe_2 = (cnt > 4095 ? 
q + (4095 << 4) : qe_0); \ if (((uintptr_t)m & 0xf) == 0) { \ for (; q != qe_1; m += 32, q += 16) { \ MATRIX_INT32_ITER_32X16_SSE(m, q, xmm_sum, _mm_load_si128, \ POPCNT_UINT32_STEP1_SSE) \ } \ MATRIX_VAR_PERMUTE(8, 16, xmm_sum, POPCNT_UINT32_PERMUTE1_SSE) \ for (; q != qe_2; m += 32, q += 16) { \ MATRIX_INT32_ITER_32X16_SSE(m, q, xmm_sum, _mm_load_si128, \ POPCNT_UINT32_STEP2_SSE) \ } \ MATRIX_VAR_PERMUTE(8, 16, xmm_sum, POPCNT_UINT32_PERMUTE2_SSE) \ for (; q != qe_0; m += 32, q += 16) { \ MATRIX_INT32_ITER_32X16_SSE(m, q, xmm_sum, _mm_load_si128, \ POPCNT_UINT32_STEP3_SSE) \ } \ } else { \ for (; q != qe_1; m += 32, q += 16) { \ MATRIX_INT32_ITER_32X16_SSE(m, q, xmm_sum, _mm_loadu_si128, \ POPCNT_UINT32_STEP1_SSE) \ } \ MATRIX_VAR_PERMUTE(8, 16, xmm_sum, POPCNT_UINT32_PERMUTE1_SSE) \ for (; q != qe_2; m += 32, q += 16) { \ MATRIX_INT32_ITER_32X16_SSE(m, q, xmm_sum, _mm_loadu_si128, \ POPCNT_UINT32_STEP2_SSE) \ } \ MATRIX_VAR_PERMUTE(8, 16, xmm_sum, POPCNT_UINT32_PERMUTE2_SSE) \ for (; q != qe_0; m += 32, q += 16) { \ MATRIX_INT32_ITER_32X16_SSE(m, q, xmm_sum, _mm_loadu_si128, \ POPCNT_UINT32_STEP3_SSE) \ } \ } \ if (((uintptr_t)out & 0xf) == 0) { \ MATRIX_VAR_STORE(8, 16, 4, xmm_sum, out, _mm_store_ps, _NORM) \ } else { \ MATRIX_VAR_STORE(8, 16, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \ } //! Compute the distance between matrix and query (UINT32, M=32, N=32) #define POPCNT_UINT32_32X32_SSE(m, q, cnt, out, _NORM) \ MATRIX_VAR_INIT(8, 32, __m128i, xmm_sum, _mm_setzero_si128()) \ const uint32_t *qe_0 = q + (cnt << 5); \ const uint32_t *qe_1 = (cnt > 31 ? q + (31 << 5) : qe_0); \ const uint32_t *qe_2 = (cnt > 4095 ? 
q + (4095 << 5) : qe_0); \ if (((uintptr_t)m & 0xf) == 0) { \ for (; q != qe_1; m += 32, q += 32) { \ MATRIX_INT32_ITER_32X32_SSE(m, q, xmm_sum, _mm_load_si128, \ POPCNT_UINT32_STEP1_SSE) \ } \ MATRIX_VAR_PERMUTE(8, 32, xmm_sum, POPCNT_UINT32_PERMUTE1_SSE) \ for (; q != qe_2; m += 32, q += 32) { \ MATRIX_INT32_ITER_32X32_SSE(m, q, xmm_sum, _mm_load_si128, \ POPCNT_UINT32_STEP2_SSE) \ } \ MATRIX_VAR_PERMUTE(8, 32, xmm_sum, POPCNT_UINT32_PERMUTE2_SSE) \ for (; q != qe_0; m += 32, q += 32) { \ MATRIX_INT32_ITER_32X32_SSE(m, q, xmm_sum, _mm_load_si128, \ POPCNT_UINT32_STEP3_SSE) \ } \ } else { \ for (; q != qe_1; m += 32, q += 32) { \ MATRIX_INT32_ITER_32X32_SSE(m, q, xmm_sum, _mm_loadu_si128, \ POPCNT_UINT32_STEP1_SSE) \ } \ MATRIX_VAR_PERMUTE(8, 32, xmm_sum, POPCNT_UINT32_PERMUTE1_SSE) \ for (; q != qe_2; m += 32, q += 32) { \ MATRIX_INT32_ITER_32X32_SSE(m, q, xmm_sum, _mm_loadu_si128, \ POPCNT_UINT32_STEP2_SSE) \ } \ MATRIX_VAR_PERMUTE(8, 32, xmm_sum, POPCNT_UINT32_PERMUTE2_SSE) \ for (; q != qe_0; m += 32, q += 32) { \ MATRIX_INT32_ITER_32X32_SSE(m, q, xmm_sum, _mm_loadu_si128, \ POPCNT_UINT32_STEP3_SSE) \ } \ } \ if (((uintptr_t)out & 0xf) == 0) { \ MATRIX_VAR_STORE(8, 32, 4, xmm_sum, out, _mm_store_ps, _NORM) \ } else { \ MATRIX_VAR_STORE(8, 32, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \ } //! Compute the distance between matrix and query (UINT32, M=2, N=1) #define POPCNT_UINT32_2X1_AVX(m, q, cnt, out, _NORM) \ MATRIX_VAR_INIT(1, 1, __m256i, ymm_sum, _mm256_setzero_si256()) \ const uint32_t *qe_0 = q + ((cnt >> 2) << 2); \ const uint32_t *qe_1 = (cnt > 31 ? q + ((31 >> 2) << 2) : qe_0); \ const uint32_t *qe_2 = (cnt > 4095 ? 
q + ((4095 >> 2) << 2) : qe_0); \ const uint32_t *qe_3 = q + cnt; \ if (((uintptr_t)m & 0x1f) == 0) { \ for (; q != qe_1; m += 8, q += 4) { \ MATRIX_INT32_ITER_2X1_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT32_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(1, 1, ymm_sum, POPCNT_UINT32_PERMUTE1_AVX) \ for (; q != qe_2; m += 8, q += 4) { \ MATRIX_INT32_ITER_2X1_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT32_STEP2_AVX) \ } \ MATRIX_VAR_PERMUTE(1, 1, ymm_sum, POPCNT_UINT32_PERMUTE2_AVX) \ for (; q != qe_0; m += 8, q += 4) { \ MATRIX_INT32_ITER_2X1_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT32_STEP3_AVX) \ } \ } else { \ for (; q != qe_1; m += 8, q += 4) { \ MATRIX_INT32_ITER_2X1_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT32_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(1, 1, ymm_sum, POPCNT_UINT32_PERMUTE1_AVX) \ for (; q != qe_2; m += 8, q += 4) { \ MATRIX_INT32_ITER_2X1_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT32_STEP2_AVX) \ } \ MATRIX_VAR_PERMUTE(1, 1, ymm_sum, POPCNT_UINT32_PERMUTE2_AVX) \ for (; q != qe_0; m += 8, q += 4) { \ MATRIX_INT32_ITER_2X1_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT32_STEP3_AVX) \ } \ } \ __m128i xmm_sum_0 = _mm_add_epi32(_mm256_castsi256_si128(ymm_sum_0_0), \ _mm256_extracti128_si256(ymm_sum_0_0, 1)); \ if (qe_3 >= qe_0 + 2) { \ __m128i xmm_m = _mm_loadu_si128((const __m128i *)(m)); \ __m128i xmm_q = _mm_set_epi32(q[1], q[1], q[0], q[0]); \ POPCNT_UINT32_STEP3_SSE(xmm_m, xmm_q, xmm_sum_0) \ m += 4; \ q += 2; \ } \ xmm_sum_0 = _mm_add_epi32( \ xmm_sum_0, _mm_shuffle_epi32(xmm_sum_0, _MM_SHUFFLE(0, 0, 3, 2))); \ if (q != qe_3) { \ __m128i xmm_m = _mm_set_epi32(0, 0, m[1], m[0]); \ __m128i xmm_q = _mm_broadcast_si32(q); \ POPCNT_UINT32_STEP3_SSE(xmm_m, xmm_q, xmm_sum_0) \ } \ _mm_storel_pi((__m64 *)out, _NORM(xmm_sum_0)); //! 
Compute the distance between matrix and query (UINT32, M=2, N=2) #define POPCNT_UINT32_2X2_AVX(m, q, cnt, out, _NORM) \ /* Each of the cnt steps reads 2 uint32 from m and 2 uint32 from q (qe_3 = q + 2 * cnt). The AVX loops consume 4 steps per iteration (m += 8, q += 8) up to qe_0, the end of the last whole 4-step group; the remaining cnt % 4 steps are finished with SSE after the loops. */ MATRIX_VAR_INIT(1, 2, __m256i, ymm_sum, _mm256_setzero_si256()) \ const uint32_t *qe_0 = q + ((cnt >> 2) << 3); \ /* When cnt is larger, qe_1 / qe_2 end the STEP1 / STEP2 phases after 28 / 4092 steps (31 / 4095 rounded down to whole 4-step groups); MATRIX_VAR_PERMUTE rewrites the accumulators between phases. NOTE(review): presumably this widens the popcount counters before they can overflow; confirm against POPCNT_UINT32_PERMUTE1_AVX / POPCNT_UINT32_PERMUTE2_AVX. */ const uint32_t *qe_1 = (cnt > 31 ? q + ((31 >> 2) << 3) : qe_0); \ const uint32_t *qe_2 = (cnt > 4095 ? q + ((4095 >> 2) << 3) : qe_0); \ const uint32_t *qe_3 = q + (cnt << 1); \ /* The aligned-load path requires both m and q to be 32-byte aligned. */ if (((uintptr_t)m & 0x1f) == 0 && ((uintptr_t)q & 0x1f) == 0) { \ for (; q != qe_1; m += 8, q += 8) { \ MATRIX_INT32_ITER_2X2_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT32_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(1, 2, ymm_sum, POPCNT_UINT32_PERMUTE1_AVX) \ for (; q != qe_2; m += 8, q += 8) { \ MATRIX_INT32_ITER_2X2_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT32_STEP2_AVX) \ } \ MATRIX_VAR_PERMUTE(1, 2, ymm_sum, POPCNT_UINT32_PERMUTE2_AVX) \ for (; q != qe_0; m += 8, q += 8) { \ MATRIX_INT32_ITER_2X2_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT32_STEP3_AVX) \ } \ } else { \ for (; q != qe_1; m += 8, q += 8) { \ MATRIX_INT32_ITER_2X2_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT32_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(1, 2, ymm_sum, POPCNT_UINT32_PERMUTE1_AVX) \ for (; q != qe_2; m += 8, q += 8) { \ MATRIX_INT32_ITER_2X2_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT32_STEP2_AVX) \ } \ MATRIX_VAR_PERMUTE(1, 2, ymm_sum, POPCNT_UINT32_PERMUTE2_AVX) \ for (; q != qe_0; m += 8, q += 8) { \ MATRIX_INT32_ITER_2X2_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT32_STEP3_AVX) \ } \ } \ /* Fold the upper 128-bit half of each 256-bit accumulator onto its lower half. */ __m128i xmm_sum_0_0 = \ _mm_add_epi32(_mm256_castsi256_si128(ymm_sum_0_0), \ _mm256_extracti128_si256(ymm_sum_0_0, 1)); \ __m128i xmm_sum_0_1 = \ _mm_add_epi32(_mm256_castsi256_si128(ymm_sum_0_1), \ _mm256_extracti128_si256(ymm_sum_0_1, 1)); \ /* At least 2 leftover steps: handle two of them with one 128-bit load each of m and q. xmm_p lanes are q[0], q[0], q[2], q[2] for xmm_sum_0_0 and then q[1], q[1], q[3], q[3] for xmm_sum_0_1. */ if (qe_3 >= qe_0 + 4) { \ __m128i xmm_q = _mm_loadu_si128((const __m128i *)(q)); \ __m128i xmm_m = _mm_loadu_si128((const __m128i *)(m)); \ __m128i xmm_p = _mm_shuffle_epi32(xmm_q, _MM_SHUFFLE(2, 2, 0, 0)); \
POPCNT_UINT32_STEP3_SSE(xmm_m, xmm_p, xmm_sum_0_0) \ xmm_p = _mm_shuffle_epi32(xmm_q, _MM_SHUFFLE(3, 3, 1, 1)); \ POPCNT_UINT32_STEP3_SSE(xmm_m, xmm_p, xmm_sum_0_1) \ m += 4; \ q += 4; \ } \ /* Add the two 64-bit halves of each accumulator. With s0 = xmm_sum_0_0 and s1 = xmm_sum_0_1, the result is (s0[0]+s0[2], s0[1]+s0[3], s1[0]+s1[2], s1[1]+s1[3]). */ xmm_sum_0_0 = _mm_add_epi32(_mm_unpacklo_epi64(xmm_sum_0_0, xmm_sum_0_1), \ _mm_unpackhi_epi64(xmm_sum_0_0, xmm_sum_0_1)); \ /* One leftover step: the lanes pair (m[0], q[0]), (m[1], q[0]), (m[0], q[1]), (m[1], q[1]), matching the lane order of the merged xmm_sum_0_0. */ if (q != qe_3) { \ __m128i xmm_m = _mm_set_epi32(m[1], m[0], m[1], m[0]); \ __m128i xmm_q = _mm_set_epi32(q[1], q[1], q[0], q[0]); \ POPCNT_UINT32_STEP3_SSE(xmm_m, xmm_q, xmm_sum_0_0) \ } \ /* Only the 4 lanes of xmm_sum_0_0 are stored (MATRIX_VAR_STORE with 1 x 1 accumulators). The aligned store is used when out is 16-byte aligned. */ if (((uintptr_t)out & 0xf) == 0) { \ MATRIX_VAR_STORE(1, 1, 4, xmm_sum, out, _mm_store_ps, _NORM) \ } else { \ MATRIX_VAR_STORE(1, 1, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \ } //! Compute the distance between matrix and query (UINT32, M=4, N=1) #define POPCNT_UINT32_4X1_AVX(m, q, cnt, out, _NORM) \ /* Each of the cnt steps reads 4 uint32 from m and 1 uint32 from q (qe_3 = q + cnt). The AVX loops consume 2 steps per iteration (m += 8, q += 2) up to qe_0, and switch from STEP1 to STEP2 to STEP3 after 30 / 4094 steps, with MATRIX_VAR_PERMUTE between phases. With odd cnt the final step is finished with SSE. */ MATRIX_VAR_INIT(1, 1, __m256i, ymm_sum, _mm256_setzero_si256()) \ const uint32_t *qe_0 = q + ((cnt >> 1) << 1); \ const uint32_t *qe_1 = (cnt > 31 ? q + ((31 >> 1) << 1) : qe_0); \ const uint32_t *qe_2 = (cnt > 4095 ?
q + ((4095 >> 1) << 1) : qe_0); \ const uint32_t *qe_3 = q + cnt; \ if (((uintptr_t)m & 0x1f) == 0) { \ for (; q != qe_1; m += 8, q += 2) { \ MATRIX_INT32_ITER_4X1_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT32_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(1, 1, ymm_sum, POPCNT_UINT32_PERMUTE1_AVX) \ for (; q != qe_2; m += 8, q += 2) { \ MATRIX_INT32_ITER_4X1_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT32_STEP2_AVX) \ } \ MATRIX_VAR_PERMUTE(1, 1, ymm_sum, POPCNT_UINT32_PERMUTE2_AVX) \ for (; q != qe_0; m += 8, q += 2) { \ MATRIX_INT32_ITER_4X1_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT32_STEP3_AVX) \ } \ } else { \ for (; q != qe_1; m += 8, q += 2) { \ MATRIX_INT32_ITER_4X1_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT32_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(1, 1, ymm_sum, POPCNT_UINT32_PERMUTE1_AVX) \ for (; q != qe_2; m += 8, q += 2) { \ MATRIX_INT32_ITER_4X1_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT32_STEP2_AVX) \ } \ MATRIX_VAR_PERMUTE(1, 1, ymm_sum, POPCNT_UINT32_PERMUTE2_AVX) \ for (; q != qe_0; m += 8, q += 2) { \ MATRIX_INT32_ITER_4X1_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT32_STEP3_AVX) \ } \ } \ /* Fold the upper 128-bit half of the accumulator onto its lower half. */ __m128i xmm_sum_0_0 = \ _mm_add_epi32(_mm256_castsi256_si128(ymm_sum_0_0), \ _mm256_extracti128_si256(ymm_sum_0_0, 1)); \ /* Odd cnt: one more step, the next 4 m values against *q via _mm_broadcast_si32. */ if (q != qe_3) { \ __m128i xmm_m = _mm_loadu_si128((const __m128i *)(m)); \ __m128i xmm_q = _mm_broadcast_si32(q); \ POPCNT_UINT32_STEP3_SSE(xmm_m, xmm_q, xmm_sum_0_0) \ } \ if (((uintptr_t)out & 0xf) == 0) { \ MATRIX_VAR_STORE(1, 1, 4, xmm_sum, out, _mm_store_ps, _NORM) \ } else { \ MATRIX_VAR_STORE(1, 1, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \ } //! Compute the distance between matrix and query (UINT32, M=4, N=2) #define POPCNT_UINT32_4X2_AVX(m, q, cnt, out, _NORM) \ /* Each of the cnt steps reads 4 uint32 from m and 2 uint32 from q (qe_3 = q + 2 * cnt). The AVX loops consume 2 steps per iteration (m += 8, q += 4) up to qe_0, and switch from STEP1 to STEP2 to STEP3 after 30 / 4094 steps, with MATRIX_VAR_PERMUTE between phases. With odd cnt the final step is finished with SSE. */ MATRIX_VAR_INIT(1, 2, __m256i, ymm_sum, _mm256_setzero_si256()) \ const uint32_t *qe_0 = q + ((cnt >> 1) << 2); \ const uint32_t *qe_1 = (cnt > 31 ? q + ((31 >> 1) << 2) : qe_0); \ const uint32_t *qe_2 = (cnt > 4095 ?
q + ((4095 >> 1) << 2) : qe_0); \ const uint32_t *qe_3 = q + (cnt << 1); \ if (((uintptr_t)m & 0x1f) == 0) { \ for (; q != qe_1; m += 8, q += 4) { \ MATRIX_INT32_ITER_4X2_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT32_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(1, 2, ymm_sum, POPCNT_UINT32_PERMUTE1_AVX) \ for (; q != qe_2; m += 8, q += 4) { \ MATRIX_INT32_ITER_4X2_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT32_STEP2_AVX) \ } \ MATRIX_VAR_PERMUTE(1, 2, ymm_sum, POPCNT_UINT32_PERMUTE2_AVX) \ for (; q != qe_0; m += 8, q += 4) { \ MATRIX_INT32_ITER_4X2_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT32_STEP3_AVX) \ } \ } else { \ for (; q != qe_1; m += 8, q += 4) { \ MATRIX_INT32_ITER_4X2_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT32_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(1, 2, ymm_sum, POPCNT_UINT32_PERMUTE1_AVX) \ for (; q != qe_2; m += 8, q += 4) { \ MATRIX_INT32_ITER_4X2_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT32_STEP2_AVX) \ } \ MATRIX_VAR_PERMUTE(1, 2, ymm_sum, POPCNT_UINT32_PERMUTE2_AVX) \ for (; q != qe_0; m += 8, q += 4) { \ MATRIX_INT32_ITER_4X2_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT32_STEP3_AVX) \ } \ } \ /* Fold the upper 128-bit half of each accumulator onto its lower half. */ __m128i xmm_sum_0_0 = \ _mm_add_epi32(_mm256_castsi256_si128(ymm_sum_0_0), \ _mm256_extracti128_si256(ymm_sum_0_0, 1)); \ __m128i xmm_sum_0_1 = \ _mm_add_epi32(_mm256_castsi256_si128(ymm_sum_0_1), \ _mm256_extracti128_si256(ymm_sum_0_1, 1)); \ /* Odd cnt: the last 4 m values against q[0] (into xmm_sum_0_0) and q[1] (into xmm_sum_0_1). */ if (q != qe_3) { \ __m128i xmm_m = _mm_loadu_si128((const __m128i *)(m)); \ __m128i xmm_q = _mm_broadcast_si32(q); \ POPCNT_UINT32_STEP3_SSE(xmm_m, xmm_q, xmm_sum_0_0) \ xmm_q = _mm_broadcast_si32(q + 1); \ POPCNT_UINT32_STEP3_SSE(xmm_m, xmm_q, xmm_sum_0_1) \ } \ if (((uintptr_t)out & 0xf) == 0) { \ MATRIX_VAR_STORE(1, 2, 4, xmm_sum, out, _mm_store_ps, _NORM) \ } else { \ MATRIX_VAR_STORE(1, 2, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \ } //!
Compute the distance between matrix and query (UINT32, M=4, N=4) #define POPCNT_UINT32_4X4_AVX(m, q, cnt, out, _NORM) \ /* Each of the cnt steps reads 4 uint32 from m and 4 uint32 from q (qe_3 = q + 4 * cnt). The AVX loops consume 2 steps per iteration (m += 8, q += 8) up to qe_0, and switch from STEP1 to STEP2 to STEP3 after 30 / 4094 steps, with MATRIX_VAR_PERMUTE between phases. With odd cnt the final step is finished with SSE. The aligned-load path requires both m and q to be 32-byte aligned. */ MATRIX_VAR_INIT(1, 4, __m256i, ymm_sum, _mm256_setzero_si256()) \ const uint32_t *qe_0 = q + ((cnt >> 1) << 3); \ const uint32_t *qe_1 = (cnt > 31 ? q + ((31 >> 1) << 3) : qe_0); \ const uint32_t *qe_2 = (cnt > 4095 ? q + ((4095 >> 1) << 3) : qe_0); \ const uint32_t *qe_3 = q + (cnt << 2); \ if (((uintptr_t)m & 0x1f) == 0 && ((uintptr_t)q & 0x1f) == 0) { \ for (; q != qe_1; m += 8, q += 8) { \ MATRIX_INT32_ITER_4X4_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT32_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(1, 4, ymm_sum, POPCNT_UINT32_PERMUTE1_AVX) \ for (; q != qe_2; m += 8, q += 8) { \ MATRIX_INT32_ITER_4X4_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT32_STEP2_AVX) \ } \ MATRIX_VAR_PERMUTE(1, 4, ymm_sum, POPCNT_UINT32_PERMUTE2_AVX) \ for (; q != qe_0; m += 8, q += 8) { \ MATRIX_INT32_ITER_4X4_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT32_STEP3_AVX) \ } \ } else { \ for (; q != qe_1; m += 8, q += 8) { \ MATRIX_INT32_ITER_4X4_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT32_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(1, 4, ymm_sum, POPCNT_UINT32_PERMUTE1_AVX) \ for (; q != qe_2; m += 8, q += 8) { \ MATRIX_INT32_ITER_4X4_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT32_STEP2_AVX) \ } \ MATRIX_VAR_PERMUTE(1, 4, ymm_sum, POPCNT_UINT32_PERMUTE2_AVX) \ for (; q != qe_0; m += 8, q += 8) { \ MATRIX_INT32_ITER_4X4_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT32_STEP3_AVX) \ } \ } \ /* Fold the upper 128-bit half of each of the four accumulators onto its lower half. */ __m128i xmm_sum_0_0 = \ _mm_add_epi32(_mm256_castsi256_si128(ymm_sum_0_0), \ _mm256_extracti128_si256(ymm_sum_0_0, 1)); \ __m128i xmm_sum_0_1 = \ _mm_add_epi32(_mm256_castsi256_si128(ymm_sum_0_1), \ _mm256_extracti128_si256(ymm_sum_0_1, 1)); \ __m128i xmm_sum_0_2 = \ _mm_add_epi32(_mm256_castsi256_si128(ymm_sum_0_2), \ _mm256_extracti128_si256(ymm_sum_0_2, 1)); \ __m128i xmm_sum_0_3 = \ _mm_add_epi32(_mm256_castsi256_si128(ymm_sum_0_3), \
_mm256_extracti128_si256(ymm_sum_0_3, 1)); \ /* Odd cnt: the last 4 m values against q[0] .. q[3], one accumulator (xmm_sum_0_0 .. xmm_sum_0_3) per q value. */ if (q != qe_3) { \ __m128i xmm_m = _mm_loadu_si128((const __m128i *)(m)); \ __m128i xmm_q = _mm_broadcast_si32(q); \ POPCNT_UINT32_STEP3_SSE(xmm_m, xmm_q, xmm_sum_0_0) \ xmm_q = _mm_broadcast_si32(q + 1); \ POPCNT_UINT32_STEP3_SSE(xmm_m, xmm_q, xmm_sum_0_1) \ xmm_q = _mm_broadcast_si32(q + 2); \ POPCNT_UINT32_STEP3_SSE(xmm_m, xmm_q, xmm_sum_0_2) \ xmm_q = _mm_broadcast_si32(q + 3); \ POPCNT_UINT32_STEP3_SSE(xmm_m, xmm_q, xmm_sum_0_3) \ } \ if (((uintptr_t)out & 0xf) == 0) { \ MATRIX_VAR_STORE(1, 4, 4, xmm_sum, out, _mm_store_ps, _NORM) \ } else { \ MATRIX_VAR_STORE(1, 4, 4, xmm_sum, out, _mm_storeu_ps, _NORM) \ } //! Compute the distance between matrix and query (UINT32, M=8, N=1) #define POPCNT_UINT32_8X1_AVX(m, q, cnt, out, _NORM) \ /* One step per iteration: 8 uint32 from m against 1 uint32 from q (m += 8, ++q). The loops therefore cover all cnt steps and no remainder pass is needed. The phases switch from STEP1 to STEP2 to STEP3 after 31 / 4095 steps, with MATRIX_VAR_PERMUTE between phases. */ MATRIX_VAR_INIT(1, 1, __m256i, ymm_sum, _mm256_setzero_si256()) \ const uint32_t *qe_0 = q + cnt; \ const uint32_t *qe_1 = (cnt > 31 ? q + 31 : qe_0); \ const uint32_t *qe_2 = (cnt > 4095 ? q + 4095 : qe_0); \ if (((uintptr_t)m & 0x1f) == 0) { \ for (; q != qe_1; m += 8, ++q) { \ MATRIX_INT32_ITER_8X1_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT32_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(1, 1, ymm_sum, POPCNT_UINT32_PERMUTE1_AVX) \ for (; q != qe_2; m += 8, ++q) { \ MATRIX_INT32_ITER_8X1_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT32_STEP2_AVX) \ } \ MATRIX_VAR_PERMUTE(1, 1, ymm_sum, POPCNT_UINT32_PERMUTE2_AVX) \ for (; q != qe_0; m += 8, ++q) { \ MATRIX_INT32_ITER_8X1_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT32_STEP3_AVX) \ } \ } else { \ for (; q != qe_1; m += 8, ++q) { \ MATRIX_INT32_ITER_8X1_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT32_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(1, 1, ymm_sum, POPCNT_UINT32_PERMUTE1_AVX) \ for (; q != qe_2; m += 8, ++q) { \ MATRIX_INT32_ITER_8X1_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT32_STEP2_AVX) \ } \ MATRIX_VAR_PERMUTE(1, 1, ymm_sum, POPCNT_UINT32_PERMUTE2_AVX) \ for (; q != qe_0; m += 8, ++q) { \
MATRIX_INT32_ITER_8X1_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT32_STEP3_AVX) \ } \ } \ /* Use the 256-bit aligned store when out is 32-byte aligned. */ if (((uintptr_t)out & 0x1f) == 0) { \ MATRIX_VAR_STORE(1, 1, 8, ymm_sum, out, _mm256_store_ps, _NORM) \ } else { \ MATRIX_VAR_STORE(1, 1, 8, ymm_sum, out, _mm256_storeu_ps, _NORM) \ } //! Compute the distance between matrix and query (UINT32, M=8, N=2) #define POPCNT_UINT32_8X2_AVX(m, q, cnt, out, _NORM) \ /* One step per iteration: 8 uint32 from m against 2 uint32 from q (m += 8, q += 2). The loops therefore cover all cnt steps and no remainder pass is needed. The phases switch from STEP1 to STEP2 to STEP3 after 31 / 4095 steps, with MATRIX_VAR_PERMUTE between phases. */ MATRIX_VAR_INIT(1, 2, __m256i, ymm_sum, _mm256_setzero_si256()) \ const uint32_t *qe_0 = q + (cnt << 1); \ const uint32_t *qe_1 = (cnt > 31 ? q + (31 << 1) : qe_0); \ const uint32_t *qe_2 = (cnt > 4095 ? q + (4095 << 1) : qe_0); \ if (((uintptr_t)m & 0x1f) == 0) { \ for (; q != qe_1; m += 8, q += 2) { \ MATRIX_INT32_ITER_8X2_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT32_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(1, 2, ymm_sum, POPCNT_UINT32_PERMUTE1_AVX) \ for (; q != qe_2; m += 8, q += 2) { \ MATRIX_INT32_ITER_8X2_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT32_STEP2_AVX) \ } \ MATRIX_VAR_PERMUTE(1, 2, ymm_sum, POPCNT_UINT32_PERMUTE2_AVX) \ for (; q != qe_0; m += 8, q += 2) { \ MATRIX_INT32_ITER_8X2_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT32_STEP3_AVX) \ } \ } else { \ for (; q != qe_1; m += 8, q += 2) { \ MATRIX_INT32_ITER_8X2_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT32_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(1, 2, ymm_sum, POPCNT_UINT32_PERMUTE1_AVX) \ for (; q != qe_2; m += 8, q += 2) { \ MATRIX_INT32_ITER_8X2_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT32_STEP2_AVX) \ } \ MATRIX_VAR_PERMUTE(1, 2, ymm_sum, POPCNT_UINT32_PERMUTE2_AVX) \ for (; q != qe_0; m += 8, q += 2) { \ MATRIX_INT32_ITER_8X2_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT32_STEP3_AVX) \ } \ } \ if (((uintptr_t)out & 0x1f) == 0) { \ MATRIX_VAR_STORE(1, 2, 8, ymm_sum, out, _mm256_store_ps, _NORM) \ } else { \ MATRIX_VAR_STORE(1, 2, 8, ymm_sum, out, _mm256_storeu_ps, _NORM) \ } //!
Compute the distance between matrix and query (UINT32, M=8, N=4) #define POPCNT_UINT32_8X4_AVX(m, q, cnt, out, _NORM) \ MATRIX_VAR_INIT(1, 4, __m256i, ymm_sum, _mm256_setzero_si256()) \ const uint32_t *qe_0 = q + (cnt << 2); \ const uint32_t *qe_1 = (cnt > 31 ? q + (31 << 2) : qe_0); \ const uint32_t *qe_2 = (cnt > 4095 ? q + (4095 << 2) : qe_0); \ if (((uintptr_t)m & 0x1f) == 0) { \ for (; q != qe_1; m += 8, q += 4) { \ MATRIX_INT32_ITER_8X4_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT32_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(1, 4, ymm_sum, POPCNT_UINT32_PERMUTE1_AVX) \ for (; q != qe_2; m += 8, q += 4) { \ MATRIX_INT32_ITER_8X4_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT32_STEP2_AVX) \ } \ MATRIX_VAR_PERMUTE(1, 4, ymm_sum, POPCNT_UINT32_PERMUTE2_AVX) \ for (; q != qe_0; m += 8, q += 4) { \ MATRIX_INT32_ITER_8X4_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT32_STEP3_AVX) \ } \ } else { \ for (; q != qe_1; m += 8, q += 4) { \ MATRIX_INT32_ITER_8X4_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT32_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(1, 4, ymm_sum, POPCNT_UINT32_PERMUTE1_AVX) \ for (; q != qe_2; m += 8, q += 4) { \ MATRIX_INT32_ITER_8X4_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT32_STEP2_AVX) \ } \ MATRIX_VAR_PERMUTE(1, 4, ymm_sum, POPCNT_UINT32_PERMUTE2_AVX) \ for (; q != qe_0; m += 8, q += 4) { \ MATRIX_INT32_ITER_8X4_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT32_STEP3_AVX) \ } \ } \ if (((uintptr_t)out & 0x1f) == 0) { \ MATRIX_VAR_STORE(1, 4, 8, ymm_sum, out, _mm256_store_ps, _NORM) \ } else { \ MATRIX_VAR_STORE(1, 4, 8, ymm_sum, out, _mm256_storeu_ps, _NORM) \ } //! Compute the distance between matrix and query (UINT32, M=8, N=8) #define POPCNT_UINT32_8X8_AVX(m, q, cnt, out, _NORM) \ MATRIX_VAR_INIT(1, 8, __m256i, ymm_sum, _mm256_setzero_si256()) \ const uint32_t *qe_0 = q + (cnt << 3); \ const uint32_t *qe_1 = (cnt > 31 ? q + (31 << 3) : qe_0); \ const uint32_t *qe_2 = (cnt > 4095 ? 
q + (4095 << 3) : qe_0); \ if (((uintptr_t)m & 0x1f) == 0) { \ for (; q != qe_1; m += 8, q += 8) { \ MATRIX_INT32_ITER_8X8_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT32_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(1, 8, ymm_sum, POPCNT_UINT32_PERMUTE1_AVX) \ for (; q != qe_2; m += 8, q += 8) { \ MATRIX_INT32_ITER_8X8_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT32_STEP2_AVX) \ } \ MATRIX_VAR_PERMUTE(1, 8, ymm_sum, POPCNT_UINT32_PERMUTE2_AVX) \ for (; q != qe_0; m += 8, q += 8) { \ MATRIX_INT32_ITER_8X8_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT32_STEP3_AVX) \ } \ } else { \ for (; q != qe_1; m += 8, q += 8) { \ MATRIX_INT32_ITER_8X8_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT32_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(1, 8, ymm_sum, POPCNT_UINT32_PERMUTE1_AVX) \ for (; q != qe_2; m += 8, q += 8) { \ MATRIX_INT32_ITER_8X8_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT32_STEP2_AVX) \ } \ MATRIX_VAR_PERMUTE(1, 8, ymm_sum, POPCNT_UINT32_PERMUTE2_AVX) \ for (; q != qe_0; m += 8, q += 8) { \ MATRIX_INT32_ITER_8X8_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT32_STEP3_AVX) \ } \ } \ if (((uintptr_t)out & 0x1f) == 0) { \ MATRIX_VAR_STORE(1, 8, 8, ymm_sum, out, _mm256_store_ps, _NORM) \ } else { \ MATRIX_VAR_STORE(1, 8, 8, ymm_sum, out, _mm256_storeu_ps, _NORM) \ } //! Compute the distance between matrix and query (UINT32, M=16, N=1) #define POPCNT_UINT32_16X1_AVX(m, q, cnt, out, _NORM) \ MATRIX_VAR_INIT(2, 1, __m256i, ymm_sum, _mm256_setzero_si256()) \ const uint32_t *qe_0 = q + cnt; \ const uint32_t *qe_1 = (cnt > 31 ? q + 31 : qe_0); \ const uint32_t *qe_2 = (cnt > 4095 ? 
q + 4095 : qe_0); \ if (((uintptr_t)m & 0x1f) == 0) { \ for (; q != qe_1; m += 16, ++q) { \ MATRIX_INT32_ITER_16X1_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT32_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(2, 1, ymm_sum, POPCNT_UINT32_PERMUTE1_AVX) \ for (; q != qe_2; m += 16, ++q) { \ MATRIX_INT32_ITER_16X1_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT32_STEP2_AVX) \ } \ MATRIX_VAR_PERMUTE(2, 1, ymm_sum, POPCNT_UINT32_PERMUTE2_AVX) \ for (; q != qe_0; m += 16, ++q) { \ MATRIX_INT32_ITER_16X1_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT32_STEP3_AVX) \ } \ } else { \ for (; q != qe_1; m += 16, ++q) { \ MATRIX_INT32_ITER_16X1_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT32_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(2, 1, ymm_sum, POPCNT_UINT32_PERMUTE1_AVX) \ for (; q != qe_2; m += 16, ++q) { \ MATRIX_INT32_ITER_16X1_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT32_STEP2_AVX) \ } \ MATRIX_VAR_PERMUTE(2, 1, ymm_sum, POPCNT_UINT32_PERMUTE2_AVX) \ for (; q != qe_0; m += 16, ++q) { \ MATRIX_INT32_ITER_16X1_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT32_STEP3_AVX) \ } \ } \ if (((uintptr_t)out & 0x1f) == 0) { \ MATRIX_VAR_STORE(2, 1, 8, ymm_sum, out, _mm256_store_ps, _NORM) \ } else { \ MATRIX_VAR_STORE(2, 1, 8, ymm_sum, out, _mm256_storeu_ps, _NORM) \ } //! Compute the distance between matrix and query (UINT32, M=16, N=2) #define POPCNT_UINT32_16X2_AVX(m, q, cnt, out, _NORM) \ MATRIX_VAR_INIT(2, 2, __m256i, ymm_sum, _mm256_setzero_si256()) \ const uint32_t *qe_0 = q + (cnt << 1); \ const uint32_t *qe_1 = (cnt > 31 ? q + (31 << 1) : qe_0); \ const uint32_t *qe_2 = (cnt > 4095 ? 
q + (4095 << 1) : qe_0); \ if (((uintptr_t)m & 0x1f) == 0) { \ for (; q != qe_1; m += 16, q += 2) { \ MATRIX_INT32_ITER_16X2_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT32_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(2, 2, ymm_sum, POPCNT_UINT32_PERMUTE1_AVX) \ for (; q != qe_2; m += 16, q += 2) { \ MATRIX_INT32_ITER_16X2_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT32_STEP2_AVX) \ } \ MATRIX_VAR_PERMUTE(2, 2, ymm_sum, POPCNT_UINT32_PERMUTE2_AVX) \ for (; q != qe_0; m += 16, q += 2) { \ MATRIX_INT32_ITER_16X2_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT32_STEP3_AVX) \ } \ } else { \ for (; q != qe_1; m += 16, q += 2) { \ MATRIX_INT32_ITER_16X2_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT32_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(2, 2, ymm_sum, POPCNT_UINT32_PERMUTE1_AVX) \ for (; q != qe_2; m += 16, q += 2) { \ MATRIX_INT32_ITER_16X2_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT32_STEP2_AVX) \ } \ MATRIX_VAR_PERMUTE(2, 2, ymm_sum, POPCNT_UINT32_PERMUTE2_AVX) \ for (; q != qe_0; m += 16, q += 2) { \ MATRIX_INT32_ITER_16X2_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT32_STEP3_AVX) \ } \ } \ if (((uintptr_t)out & 0x1f) == 0) { \ MATRIX_VAR_STORE(2, 2, 8, ymm_sum, out, _mm256_store_ps, _NORM) \ } else { \ MATRIX_VAR_STORE(2, 2, 8, ymm_sum, out, _mm256_storeu_ps, _NORM) \ } //! Compute the distance between matrix and query (UINT32, M=16, N=4) #define POPCNT_UINT32_16X4_AVX(m, q, cnt, out, _NORM) \ MATRIX_VAR_INIT(2, 4, __m256i, ymm_sum, _mm256_setzero_si256()) \ const uint32_t *qe_0 = q + (cnt << 2); \ const uint32_t *qe_1 = (cnt > 31 ? q + (31 << 2) : qe_0); \ const uint32_t *qe_2 = (cnt > 4095 ? 
q + (4095 << 2) : qe_0); \ if (((uintptr_t)m & 0x1f) == 0) { \ for (; q != qe_1; m += 16, q += 4) { \ MATRIX_INT32_ITER_16X4_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT32_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(2, 4, ymm_sum, POPCNT_UINT32_PERMUTE1_AVX) \ for (; q != qe_2; m += 16, q += 4) { \ MATRIX_INT32_ITER_16X4_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT32_STEP2_AVX) \ } \ MATRIX_VAR_PERMUTE(2, 4, ymm_sum, POPCNT_UINT32_PERMUTE2_AVX) \ for (; q != qe_0; m += 16, q += 4) { \ MATRIX_INT32_ITER_16X4_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT32_STEP3_AVX) \ } \ } else { \ for (; q != qe_1; m += 16, q += 4) { \ MATRIX_INT32_ITER_16X4_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT32_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(2, 4, ymm_sum, POPCNT_UINT32_PERMUTE1_AVX) \ for (; q != qe_2; m += 16, q += 4) { \ MATRIX_INT32_ITER_16X4_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT32_STEP2_AVX) \ } \ MATRIX_VAR_PERMUTE(2, 4, ymm_sum, POPCNT_UINT32_PERMUTE2_AVX) \ for (; q != qe_0; m += 16, q += 4) { \ MATRIX_INT32_ITER_16X4_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT32_STEP3_AVX) \ } \ } \ if (((uintptr_t)out & 0x1f) == 0) { \ MATRIX_VAR_STORE(2, 4, 8, ymm_sum, out, _mm256_store_ps, _NORM) \ } else { \ MATRIX_VAR_STORE(2, 4, 8, ymm_sum, out, _mm256_storeu_ps, _NORM) \ } //! Compute the distance between matrix and query (UINT32, M=16, N=8) #define POPCNT_UINT32_16X8_AVX(m, q, cnt, out, _NORM) \ MATRIX_VAR_INIT(2, 8, __m256i, ymm_sum, _mm256_setzero_si256()) \ const uint32_t *qe_0 = q + (cnt << 3); \ const uint32_t *qe_1 = (cnt > 31 ? q + (31 << 3) : qe_0); \ const uint32_t *qe_2 = (cnt > 4095 ? 
q + (4095 << 3) : qe_0); \ if (((uintptr_t)m & 0x1f) == 0) { \ for (; q != qe_1; m += 16, q += 8) { \ MATRIX_INT32_ITER_16X8_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT32_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(2, 8, ymm_sum, POPCNT_UINT32_PERMUTE1_AVX) \ for (; q != qe_2; m += 16, q += 8) { \ MATRIX_INT32_ITER_16X8_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT32_STEP2_AVX) \ } \ MATRIX_VAR_PERMUTE(2, 8, ymm_sum, POPCNT_UINT32_PERMUTE2_AVX) \ for (; q != qe_0; m += 16, q += 8) { \ MATRIX_INT32_ITER_16X8_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT32_STEP3_AVX) \ } \ } else { \ for (; q != qe_1; m += 16, q += 8) { \ MATRIX_INT32_ITER_16X8_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT32_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(2, 8, ymm_sum, POPCNT_UINT32_PERMUTE1_AVX) \ for (; q != qe_2; m += 16, q += 8) { \ MATRIX_INT32_ITER_16X8_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT32_STEP2_AVX) \ } \ MATRIX_VAR_PERMUTE(2, 8, ymm_sum, POPCNT_UINT32_PERMUTE2_AVX) \ for (; q != qe_0; m += 16, q += 8) { \ MATRIX_INT32_ITER_16X8_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT32_STEP3_AVX) \ } \ } \ if (((uintptr_t)out & 0x1f) == 0) { \ MATRIX_VAR_STORE(2, 8, 8, ymm_sum, out, _mm256_store_ps, _NORM) \ } else { \ MATRIX_VAR_STORE(2, 8, 8, ymm_sum, out, _mm256_storeu_ps, _NORM) \ } //! Compute the distance between matrix and query (UINT32, M=16, N=16) #define POPCNT_UINT32_16X16_AVX(m, q, cnt, out, _NORM) \ MATRIX_VAR_INIT(2, 16, __m256i, ymm_sum, _mm256_setzero_si256()) \ const uint32_t *qe_0 = q + (cnt << 4); \ const uint32_t *qe_1 = (cnt > 31 ? q + (31 << 4) : qe_0); \ const uint32_t *qe_2 = (cnt > 4095 ? 
q + (4095 << 4) : qe_0); \ if (((uintptr_t)m & 0x1f) == 0) { \ for (; q != qe_1; m += 16, q += 16) { \ MATRIX_INT32_ITER_16X16_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT32_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(2, 16, ymm_sum, POPCNT_UINT32_PERMUTE1_AVX) \ for (; q != qe_2; m += 16, q += 16) { \ MATRIX_INT32_ITER_16X16_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT32_STEP2_AVX) \ } \ MATRIX_VAR_PERMUTE(2, 16, ymm_sum, POPCNT_UINT32_PERMUTE2_AVX) \ for (; q != qe_0; m += 16, q += 16) { \ MATRIX_INT32_ITER_16X16_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT32_STEP3_AVX) \ } \ } else { \ for (; q != qe_1; m += 16, q += 16) { \ MATRIX_INT32_ITER_16X16_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT32_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(2, 16, ymm_sum, POPCNT_UINT32_PERMUTE1_AVX) \ for (; q != qe_2; m += 16, q += 16) { \ MATRIX_INT32_ITER_16X16_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT32_STEP2_AVX) \ } \ MATRIX_VAR_PERMUTE(2, 16, ymm_sum, POPCNT_UINT32_PERMUTE2_AVX) \ for (; q != qe_0; m += 16, q += 16) { \ MATRIX_INT32_ITER_16X16_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT32_STEP3_AVX) \ } \ } \ if (((uintptr_t)out & 0x1f) == 0) { \ MATRIX_VAR_STORE(2, 16, 8, ymm_sum, out, _mm256_store_ps, _NORM) \ } else { \ MATRIX_VAR_STORE(2, 16, 8, ymm_sum, out, _mm256_storeu_ps, _NORM) \ } //! Compute the distance between matrix and query (UINT32, M=32, N=1). Advances `m` and `q` in place over `cnt` steps: the first 31 steps use STEP1, steps 32-4095 use STEP2, later steps STEP3, with PERMUTE1/PERMUTE2 between stages (presumably widening counters before overflow; confirm). Declares qe_* locals, so expand once per scope. Aligned loads/stores only when `m`/`out` are 32-byte aligned. #define POPCNT_UINT32_32X1_AVX(m, q, cnt, out, _NORM) \ MATRIX_VAR_INIT(4, 1, __m256i, ymm_sum, _mm256_setzero_si256()) \ const uint32_t *qe_0 = q + cnt; \ const uint32_t *qe_1 = (cnt > 31 ? q + 31 : qe_0); \ const uint32_t *qe_2 = (cnt > 4095 ?
q + 4095 : qe_0); \ if (((uintptr_t)m & 0x1f) == 0) { \ for (; q != qe_1; m += 32, ++q) { \ MATRIX_INT32_ITER_32X1_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT32_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(4, 1, ymm_sum, POPCNT_UINT32_PERMUTE1_AVX) \ for (; q != qe_2; m += 32, ++q) { \ MATRIX_INT32_ITER_32X1_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT32_STEP2_AVX) \ } \ MATRIX_VAR_PERMUTE(4, 1, ymm_sum, POPCNT_UINT32_PERMUTE2_AVX) \ for (; q != qe_0; m += 32, ++q) { \ MATRIX_INT32_ITER_32X1_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT32_STEP3_AVX) \ } \ } else { \ for (; q != qe_1; m += 32, ++q) { \ MATRIX_INT32_ITER_32X1_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT32_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(4, 1, ymm_sum, POPCNT_UINT32_PERMUTE1_AVX) \ for (; q != qe_2; m += 32, ++q) { \ MATRIX_INT32_ITER_32X1_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT32_STEP2_AVX) \ } \ MATRIX_VAR_PERMUTE(4, 1, ymm_sum, POPCNT_UINT32_PERMUTE2_AVX) \ for (; q != qe_0; m += 32, ++q) { \ MATRIX_INT32_ITER_32X1_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT32_STEP3_AVX) \ } \ } \ if (((uintptr_t)out & 0x1f) == 0) { \ MATRIX_VAR_STORE(4, 1, 8, ymm_sum, out, _mm256_store_ps, _NORM) \ } else { \ MATRIX_VAR_STORE(4, 1, 8, ymm_sum, out, _mm256_storeu_ps, _NORM) \ } //! Compute the distance between matrix and query (UINT32, M=32, N=2). Advances `m` and `q` in place over `cnt` steps: the first 31 steps use STEP1, steps 32-4095 use STEP2, later steps STEP3, with PERMUTE1/PERMUTE2 between stages (presumably widening counters before overflow; confirm). Declares qe_* locals, so expand once per scope. Aligned loads/stores only when `m`/`out` are 32-byte aligned. #define POPCNT_UINT32_32X2_AVX(m, q, cnt, out, _NORM) \ MATRIX_VAR_INIT(4, 2, __m256i, ymm_sum, _mm256_setzero_si256()) \ const uint32_t *qe_0 = q + (cnt << 1); \ const uint32_t *qe_1 = (cnt > 31 ? q + (31 << 1) : qe_0); \ const uint32_t *qe_2 = (cnt > 4095 ?
q + (4095 << 1) : qe_0); \ if (((uintptr_t)m & 0x1f) == 0) { \ for (; q != qe_1; m += 32, q += 2) { \ MATRIX_INT32_ITER_32X2_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT32_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(4, 2, ymm_sum, POPCNT_UINT32_PERMUTE1_AVX) \ for (; q != qe_2; m += 32, q += 2) { \ MATRIX_INT32_ITER_32X2_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT32_STEP2_AVX) \ } \ MATRIX_VAR_PERMUTE(4, 2, ymm_sum, POPCNT_UINT32_PERMUTE2_AVX) \ for (; q != qe_0; m += 32, q += 2) { \ MATRIX_INT32_ITER_32X2_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT32_STEP3_AVX) \ } \ } else { \ for (; q != qe_1; m += 32, q += 2) { \ MATRIX_INT32_ITER_32X2_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT32_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(4, 2, ymm_sum, POPCNT_UINT32_PERMUTE1_AVX) \ for (; q != qe_2; m += 32, q += 2) { \ MATRIX_INT32_ITER_32X2_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT32_STEP2_AVX) \ } \ MATRIX_VAR_PERMUTE(4, 2, ymm_sum, POPCNT_UINT32_PERMUTE2_AVX) \ for (; q != qe_0; m += 32, q += 2) { \ MATRIX_INT32_ITER_32X2_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT32_STEP3_AVX) \ } \ } \ if (((uintptr_t)out & 0x1f) == 0) { \ MATRIX_VAR_STORE(4, 2, 8, ymm_sum, out, _mm256_store_ps, _NORM) \ } else { \ MATRIX_VAR_STORE(4, 2, 8, ymm_sum, out, _mm256_storeu_ps, _NORM) \ } //! Compute the distance between matrix and query (UINT32, M=32, N=4). Advances `m` and `q` in place over `cnt` steps: the first 31 steps use STEP1, steps 32-4095 use STEP2, later steps STEP3, with PERMUTE1/PERMUTE2 between stages (presumably widening counters before overflow; confirm). Declares qe_* locals, so expand once per scope. Aligned loads/stores only when `m`/`out` are 32-byte aligned. #define POPCNT_UINT32_32X4_AVX(m, q, cnt, out, _NORM) \ MATRIX_VAR_INIT(4, 4, __m256i, ymm_sum, _mm256_setzero_si256()) \ const uint32_t *qe_0 = q + (cnt << 2); \ const uint32_t *qe_1 = (cnt > 31 ? q + (31 << 2) : qe_0); \ const uint32_t *qe_2 = (cnt > 4095 ?
q + (4095 << 2) : qe_0); \ if (((uintptr_t)m & 0x1f) == 0) { \ for (; q != qe_1; m += 32, q += 4) { \ MATRIX_INT32_ITER_32X4_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT32_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(4, 4, ymm_sum, POPCNT_UINT32_PERMUTE1_AVX) \ for (; q != qe_2; m += 32, q += 4) { \ MATRIX_INT32_ITER_32X4_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT32_STEP2_AVX) \ } \ MATRIX_VAR_PERMUTE(4, 4, ymm_sum, POPCNT_UINT32_PERMUTE2_AVX) \ for (; q != qe_0; m += 32, q += 4) { \ MATRIX_INT32_ITER_32X4_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT32_STEP3_AVX) \ } \ } else { \ for (; q != qe_1; m += 32, q += 4) { \ MATRIX_INT32_ITER_32X4_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT32_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(4, 4, ymm_sum, POPCNT_UINT32_PERMUTE1_AVX) \ for (; q != qe_2; m += 32, q += 4) { \ MATRIX_INT32_ITER_32X4_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT32_STEP2_AVX) \ } \ MATRIX_VAR_PERMUTE(4, 4, ymm_sum, POPCNT_UINT32_PERMUTE2_AVX) \ for (; q != qe_0; m += 32, q += 4) { \ MATRIX_INT32_ITER_32X4_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT32_STEP3_AVX) \ } \ } \ if (((uintptr_t)out & 0x1f) == 0) { \ MATRIX_VAR_STORE(4, 4, 8, ymm_sum, out, _mm256_store_ps, _NORM) \ } else { \ MATRIX_VAR_STORE(4, 4, 8, ymm_sum, out, _mm256_storeu_ps, _NORM) \ } //! Compute the distance between matrix and query (UINT32, M=32, N=8). Advances `m` and `q` in place over `cnt` steps: the first 31 steps use STEP1, steps 32-4095 use STEP2, later steps STEP3, with PERMUTE1/PERMUTE2 between stages (presumably widening counters before overflow; confirm). Declares qe_* locals, so expand once per scope. Aligned loads/stores only when `m`/`out` are 32-byte aligned. #define POPCNT_UINT32_32X8_AVX(m, q, cnt, out, _NORM) \ MATRIX_VAR_INIT(4, 8, __m256i, ymm_sum, _mm256_setzero_si256()) \ const uint32_t *qe_0 = q + (cnt << 3); \ const uint32_t *qe_1 = (cnt > 31 ? q + (31 << 3) : qe_0); \ const uint32_t *qe_2 = (cnt > 4095 ?
q + (4095 << 3) : qe_0); \ if (((uintptr_t)m & 0x1f) == 0) { \ for (; q != qe_1; m += 32, q += 8) { \ MATRIX_INT32_ITER_32X8_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT32_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(4, 8, ymm_sum, POPCNT_UINT32_PERMUTE1_AVX) \ for (; q != qe_2; m += 32, q += 8) { \ MATRIX_INT32_ITER_32X8_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT32_STEP2_AVX) \ } \ MATRIX_VAR_PERMUTE(4, 8, ymm_sum, POPCNT_UINT32_PERMUTE2_AVX) \ for (; q != qe_0; m += 32, q += 8) { \ MATRIX_INT32_ITER_32X8_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT32_STEP3_AVX) \ } \ } else { \ for (; q != qe_1; m += 32, q += 8) { \ MATRIX_INT32_ITER_32X8_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT32_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(4, 8, ymm_sum, POPCNT_UINT32_PERMUTE1_AVX) \ for (; q != qe_2; m += 32, q += 8) { \ MATRIX_INT32_ITER_32X8_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT32_STEP2_AVX) \ } \ MATRIX_VAR_PERMUTE(4, 8, ymm_sum, POPCNT_UINT32_PERMUTE2_AVX) \ for (; q != qe_0; m += 32, q += 8) { \ MATRIX_INT32_ITER_32X8_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT32_STEP3_AVX) \ } \ } \ if (((uintptr_t)out & 0x1f) == 0) { \ MATRIX_VAR_STORE(4, 8, 8, ymm_sum, out, _mm256_store_ps, _NORM) \ } else { \ MATRIX_VAR_STORE(4, 8, 8, ymm_sum, out, _mm256_storeu_ps, _NORM) \ } //! Compute the distance between matrix and query (UINT32, M=32, N=16). Advances `m` and `q` in place over `cnt` steps: the first 31 steps use STEP1, steps 32-4095 use STEP2, later steps STEP3, with PERMUTE1/PERMUTE2 between stages (presumably widening counters before overflow; confirm). Declares qe_* locals, so expand once per scope. Aligned loads/stores only when `m`/`out` are 32-byte aligned. #define POPCNT_UINT32_32X16_AVX(m, q, cnt, out, _NORM) \ MATRIX_VAR_INIT(4, 16, __m256i, ymm_sum, _mm256_setzero_si256()) \ const uint32_t *qe_0 = q + (cnt << 4); \ const uint32_t *qe_1 = (cnt > 31 ? q + (31 << 4) : qe_0); \ const uint32_t *qe_2 = (cnt > 4095 ?
q + (4095 << 4) : qe_0); \ if (((uintptr_t)m & 0x1f) == 0) { \ for (; q != qe_1; m += 32, q += 16) { \ MATRIX_INT32_ITER_32X16_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT32_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(4, 16, ymm_sum, POPCNT_UINT32_PERMUTE1_AVX) \ for (; q != qe_2; m += 32, q += 16) { \ MATRIX_INT32_ITER_32X16_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT32_STEP2_AVX) \ } \ MATRIX_VAR_PERMUTE(4, 16, ymm_sum, POPCNT_UINT32_PERMUTE2_AVX) \ for (; q != qe_0; m += 32, q += 16) { \ MATRIX_INT32_ITER_32X16_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT32_STEP3_AVX) \ } \ } else { \ for (; q != qe_1; m += 32, q += 16) { \ MATRIX_INT32_ITER_32X16_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT32_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(4, 16, ymm_sum, POPCNT_UINT32_PERMUTE1_AVX) \ for (; q != qe_2; m += 32, q += 16) { \ MATRIX_INT32_ITER_32X16_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT32_STEP2_AVX) \ } \ MATRIX_VAR_PERMUTE(4, 16, ymm_sum, POPCNT_UINT32_PERMUTE2_AVX) \ for (; q != qe_0; m += 32, q += 16) { \ MATRIX_INT32_ITER_32X16_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT32_STEP3_AVX) \ } \ } \ if (((uintptr_t)out & 0x1f) == 0) { \ MATRIX_VAR_STORE(4, 16, 8, ymm_sum, out, _mm256_store_ps, _NORM) \ } else { \ MATRIX_VAR_STORE(4, 16, 8, ymm_sum, out, _mm256_storeu_ps, _NORM) \ } //! Compute the distance between matrix and query (UINT32, M=32, N=32). Advances `m` and `q` in place over `cnt` steps: the first 31 steps use STEP1, steps 32-4095 use STEP2, later steps STEP3, with PERMUTE1/PERMUTE2 between stages (presumably widening counters before overflow; confirm). Declares qe_* locals, so expand once per scope. Aligned loads/stores only when `m`/`out` are 32-byte aligned. #define POPCNT_UINT32_32X32_AVX(m, q, cnt, out, _NORM) \ MATRIX_VAR_INIT(4, 32, __m256i, ymm_sum, _mm256_setzero_si256()) \ const uint32_t *qe_0 = q + (cnt << 5); \ const uint32_t *qe_1 = (cnt > 31 ? q + (31 << 5) : qe_0); \ const uint32_t *qe_2 = (cnt > 4095 ?
q + (4095 << 5) : qe_0); \ if (((uintptr_t)m & 0x1f) == 0) { \ for (; q != qe_1; m += 32, q += 32) { \ MATRIX_INT32_ITER_32X32_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT32_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(4, 32, ymm_sum, POPCNT_UINT32_PERMUTE1_AVX) \ for (; q != qe_2; m += 32, q += 32) { \ MATRIX_INT32_ITER_32X32_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT32_STEP2_AVX) \ } \ MATRIX_VAR_PERMUTE(4, 32, ymm_sum, POPCNT_UINT32_PERMUTE2_AVX) \ for (; q != qe_0; m += 32, q += 32) { \ MATRIX_INT32_ITER_32X32_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT32_STEP3_AVX) \ } \ } else { \ for (; q != qe_1; m += 32, q += 32) { \ MATRIX_INT32_ITER_32X32_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT32_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(4, 32, ymm_sum, POPCNT_UINT32_PERMUTE1_AVX) \ for (; q != qe_2; m += 32, q += 32) { \ MATRIX_INT32_ITER_32X32_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT32_STEP2_AVX) \ } \ MATRIX_VAR_PERMUTE(4, 32, ymm_sum, POPCNT_UINT32_PERMUTE2_AVX) \ for (; q != qe_0; m += 32, q += 32) { \ MATRIX_INT32_ITER_32X32_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT32_STEP3_AVX) \ } \ } \ if (((uintptr_t)out & 0x1f) == 0) { \ MATRIX_VAR_STORE(4, 32, 8, ymm_sum, out, _mm256_store_ps, _NORM) \ } else { \ MATRIX_VAR_STORE(4, 32, 8, ymm_sum, out, _mm256_storeu_ps, _NORM) \ } //! Compute the distance between matrix and query (UINT64, M=2, N=1). Advances `m` and `q` in place, 4 steps per loop iteration: at most the first 28 steps use STEP1, then PERMUTE, then STEP2 for the remaining full groups and for a trailing 2-step and/or 1-step remainder. Declares qe_* locals, so expand once per scope. Aligned loads only when both `m` and `q` are 32-byte aligned; writes 2 floats to `out` via _mm_storel_pi. #define POPCNT_UINT64_2X1_AVX(m, q, cnt, out, _NORM) \ MATRIX_VAR_INIT(1, 2, __m256i, ymm_sum, _mm256_setzero_si256()) \ const uint64_t *qe_0 = q + ((cnt >> 2) << 2); \ const uint64_t *qe_1 = (cnt > 31 ?
q + ((31 >> 2) << 2) : qe_0); \ const uint64_t *qe_2 = q + cnt; \ if (((uintptr_t)m & 0x1f) == 0 && ((uintptr_t)q & 0x1f) == 0) { \ for (; q != qe_1; m += 8, q += 4) { \ MATRIX_INT64_ITER_2X1_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT64_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(1, 2, ymm_sum, POPCNT_UINT64_PERMUTE_AVX) \ for (; q != qe_0; m += 8, q += 4) { \ MATRIX_INT64_ITER_2X1_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT64_STEP2_AVX) \ } \ if (qe_2 >= qe_0 + 2) { \ __m256i ymm_m = _mm256_load_si256((const __m256i *)(m)); \ __m256i ymm_q = _mm256_set_epi64x(q[1], q[1], q[0], q[0]); \ POPCNT_UINT64_STEP2_AVX(ymm_m, ymm_q, ymm_sum_0_0) \ m += 4; \ q += 2; \ } \ } else { \ for (; q != qe_1; m += 8, q += 4) { \ MATRIX_INT64_ITER_2X1_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT64_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(1, 2, ymm_sum, POPCNT_UINT64_PERMUTE_AVX) \ for (; q != qe_0; m += 8, q += 4) { \ MATRIX_INT64_ITER_2X1_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT64_STEP2_AVX) \ } \ if (qe_2 >= qe_0 + 2) { \ __m256i ymm_m = _mm256_loadu_si256((const __m256i *)(m)); \ __m256i ymm_q = _mm256_set_epi64x(q[1], q[1], q[0], q[0]); \ POPCNT_UINT64_STEP2_AVX(ymm_m, ymm_q, ymm_sum_0_0) \ m += 4; \ q += 2; \ } \ } \ ymm_sum_0_0 = _mm256_add_epi64(ymm_sum_0_0, ymm_sum_0_1); \ ymm_sum_0_0 = _mm256_add_epi64( \ ymm_sum_0_0, \ _mm256_permute4x64_epi64(ymm_sum_0_0, _MM_SHUFFLE(0, 0, 3, 2))); \ if (q != qe_2) { \ __m256i ymm_m = _mm256_set_epi64x(0, 0, m[1], m[0]); \ __m256i ymm_q = _mm256_broadcast_si64(q); \ POPCNT_UINT64_STEP2_AVX(ymm_m, ymm_q, ymm_sum_0_0) \ } \ _mm_storel_pi((__m64 *)out, _NORM(ymm_sum_0_0)); //! Compute the distance between matrix and query (UINT64, M=2, N=2). Advances `m` and `q` in place, 2 steps per loop iteration: at most the first 30 steps use STEP1, then PERMUTE, then STEP2 for the remaining pairs and for a final single step when `cnt` is odd. Declares qe_* locals, so expand once per scope. Aligned loads only when both `m` and `q` are 32-byte aligned; aligned stores only when `out` is 16-byte aligned. #define POPCNT_UINT64_2X2_AVX(m, q, cnt, out, _NORM) \ MATRIX_VAR_INIT(1, 2, __m256i, ymm_sum, _mm256_setzero_si256()) \ const uint64_t *qe_0 = q + ((cnt >> 1) << 2); \ const uint64_t *qe_1 = (cnt > 31 ?
q + ((31 >> 1) << 2) : qe_0); \ const uint64_t *qe_2 = q + (cnt << 1); \ if (((uintptr_t)m & 0x1f) == 0 && ((uintptr_t)q & 0x1f) == 0) { \ for (; q != qe_1; m += 4, q += 4) { \ MATRIX_INT64_ITER_2X2_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT64_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(1, 2, ymm_sum, POPCNT_UINT64_PERMUTE_AVX) \ for (; q != qe_0; m += 4, q += 4) { \ MATRIX_INT64_ITER_2X2_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT64_STEP2_AVX) \ } \ } else { \ for (; q != qe_1; m += 4, q += 4) { \ MATRIX_INT64_ITER_2X2_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT64_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(1, 2, ymm_sum, POPCNT_UINT64_PERMUTE_AVX) \ for (; q != qe_0; m += 4, q += 4) { \ MATRIX_INT64_ITER_2X2_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT64_STEP2_AVX) \ } \ } \ ymm_sum_0_0 = _mm256_add_epi64( \ _mm256_inserti128_si256(ymm_sum_0_0, \ _mm256_castsi256_si128(ymm_sum_0_1), 1), \ _mm256_inserti128_si256(ymm_sum_0_1, \ _mm256_extractf128_si256(ymm_sum_0_0, 1), 0)); \ if (q != qe_2) { \ __m256i ymm_m = _mm256_set_epi64x(m[1], m[0], m[1], m[0]); \ __m256i ymm_q = _mm256_set_epi64x(q[1], q[1], q[0], q[0]); \ POPCNT_UINT64_STEP2_AVX(ymm_m, ymm_q, ymm_sum_0_0) \ } \ if (((uintptr_t)out & 0xf) == 0) { \ MATRIX_VAR_STORE(1, 1, 4, ymm_sum, out, _mm_store_ps, _NORM) \ } else { \ MATRIX_VAR_STORE(1, 1, 4, ymm_sum, out, _mm_storeu_ps, _NORM) \ } //! Compute the distance between matrix and query (UINT64, M=4, N=1). Advances `m` and `q` in place, 2 steps per loop iteration: at most the first 30 steps use STEP1, then PERMUTE, then STEP2 for the remaining pairs and for a final single step when `cnt` is odd. Declares qe_* locals, so expand once per scope. Aligned loads only when `m` is 32-byte aligned; aligned stores only when `out` is 16-byte aligned. #define POPCNT_UINT64_4X1_AVX(m, q, cnt, out, _NORM) \ MATRIX_VAR_INIT(2, 1, __m256i, ymm_sum, _mm256_setzero_si256()) \ const uint64_t *qe_0 = q + ((cnt >> 1) << 1); \ const uint64_t *qe_1 = (cnt > 31 ?
q + ((31 >> 1) << 1) : qe_0); \ const uint64_t *qe_2 = q + cnt; \ if (((uintptr_t)m & 0x1f) == 0) { \ for (; q != qe_1; m += 8, q += 2) { \ MATRIX_INT64_ITER_4X1_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT64_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(2, 1, ymm_sum, POPCNT_UINT64_PERMUTE_AVX) \ for (; q != qe_0; m += 8, q += 2) { \ MATRIX_INT64_ITER_4X1_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT64_STEP2_AVX) \ } \ if (q != qe_2) { \ __m256i ymm_m = _mm256_load_si256((const __m256i *)(m)); \ __m256i ymm_q = _mm256_broadcast_si64(q); \ POPCNT_UINT64_STEP2_AVX(ymm_m, ymm_q, ymm_sum_0_0) \ } \ } else { \ for (; q != qe_1; m += 8, q += 2) { \ MATRIX_INT64_ITER_4X1_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT64_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(2, 1, ymm_sum, POPCNT_UINT64_PERMUTE_AVX) \ for (; q != qe_0; m += 8, q += 2) { \ MATRIX_INT64_ITER_4X1_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT64_STEP2_AVX) \ } \ if (q != qe_2) { \ __m256i ymm_m = _mm256_loadu_si256((const __m256i *)(m)); \ __m256i ymm_q = _mm256_broadcast_si64(q); \ POPCNT_UINT64_STEP2_AVX(ymm_m, ymm_q, ymm_sum_0_0) \ } \ } \ ymm_sum_0_0 = _mm256_add_epi64(ymm_sum_0_0, ymm_sum_1_0); \ if (((uintptr_t)out & 0xf) == 0) { \ MATRIX_VAR_STORE(1, 1, 4, ymm_sum, out, _mm_store_ps, _NORM) \ } else { \ MATRIX_VAR_STORE(1, 1, 4, ymm_sum, out, _mm_storeu_ps, _NORM) \ } //! Compute the distance between matrix and query (UINT64, M=4, N=2). Advances `m` and `q` in place over `cnt` steps: the first 31 steps use STEP1, later steps STEP2, with PERMUTE between stages (presumably widening counters before overflow; confirm). Declares qe_* locals, so expand once per scope. Aligned loads only when `m` is 32-byte aligned; aligned stores only when `out` is 16-byte aligned. #define POPCNT_UINT64_4X2_AVX(m, q, cnt, out, _NORM) \ MATRIX_VAR_INIT(1, 2, __m256i, ymm_sum, _mm256_setzero_si256()) \ const uint64_t *qe_0 = q + (cnt << 1); \ const uint64_t *qe_1 = (cnt > 31 ?
q + (31 << 1) : qe_0); \ if (((uintptr_t)m & 0x1f) == 0) { \ for (; q != qe_1; m += 4, q += 2) { \ MATRIX_INT64_ITER_4X2_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT64_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(1, 2, ymm_sum, POPCNT_UINT64_PERMUTE_AVX) \ for (; q != qe_0; m += 4, q += 2) { \ MATRIX_INT64_ITER_4X2_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT64_STEP2_AVX) \ } \ } else { \ for (; q != qe_1; m += 4, q += 2) { \ MATRIX_INT64_ITER_4X2_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT64_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(1, 2, ymm_sum, POPCNT_UINT64_PERMUTE_AVX) \ for (; q != qe_0; m += 4, q += 2) { \ MATRIX_INT64_ITER_4X2_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT64_STEP2_AVX) \ } \ } \ if (((uintptr_t)out & 0xf) == 0) { \ MATRIX_VAR_STORE(1, 2, 4, ymm_sum, out, _mm_store_ps, _NORM) \ } else { \ MATRIX_VAR_STORE(1, 2, 4, ymm_sum, out, _mm_storeu_ps, _NORM) \ } //! Compute the distance between matrix and query (UINT64, M=4, N=4). Advances `m` and `q` in place over `cnt` steps: the first 31 steps use STEP1, later steps STEP2, with PERMUTE between stages (presumably widening counters before overflow; confirm). Declares qe_* locals, so expand once per scope. Aligned loads only when `m` is 32-byte aligned; aligned stores only when `out` is 16-byte aligned. #define POPCNT_UINT64_4X4_AVX(m, q, cnt, out, _NORM) \ MATRIX_VAR_INIT(1, 4, __m256i, ymm_sum, _mm256_setzero_si256()) \ const uint64_t *qe_0 = q + (cnt << 2); \ const uint64_t *qe_1 = (cnt > 31 ?
q + (31 << 2) : qe_0); \ if (((uintptr_t)m & 0x1f) == 0) { \ for (; q != qe_1; m += 4, q += 4) { \ MATRIX_INT64_ITER_4X4_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT64_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(1, 4, ymm_sum, POPCNT_UINT64_PERMUTE_AVX) \ for (; q != qe_0; m += 4, q += 4) { \ MATRIX_INT64_ITER_4X4_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT64_STEP2_AVX) \ } \ } else { \ for (; q != qe_1; m += 4, q += 4) { \ MATRIX_INT64_ITER_4X4_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT64_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(1, 4, ymm_sum, POPCNT_UINT64_PERMUTE_AVX) \ for (; q != qe_0; m += 4, q += 4) { \ MATRIX_INT64_ITER_4X4_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT64_STEP2_AVX) \ } \ } \ if (((uintptr_t)out & 0xf) == 0) { \ MATRIX_VAR_STORE(1, 4, 4, ymm_sum, out, _mm_store_ps, _NORM) \ } else { \ MATRIX_VAR_STORE(1, 4, 4, ymm_sum, out, _mm_storeu_ps, _NORM) \ } //! Compute the distance between matrix and query (UINT64, M=8, N=1). Advances `m` and `q` in place over `cnt` steps: the first 31 steps use STEP1, later steps STEP2, with PERMUTE between stages (presumably widening counters before overflow; confirm). Declares qe_* locals, so expand once per scope. Aligned loads only when `m` is 32-byte aligned; aligned stores only when `out` is 16-byte aligned. #define POPCNT_UINT64_8X1_AVX(m, q, cnt, out, _NORM) \ MATRIX_VAR_INIT(2, 1, __m256i, ymm_sum, _mm256_setzero_si256()) \ const uint64_t *qe_0 = q + cnt; \ const uint64_t *qe_1 = (cnt > 31 ?
q + 31 : qe_0); \ if (((uintptr_t)m & 0x1f) == 0) { \ for (; q != qe_1; m += 8, ++q) { \ MATRIX_INT64_ITER_8X1_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT64_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(2, 1, ymm_sum, POPCNT_UINT64_PERMUTE_AVX) \ for (; q != qe_0; m += 8, ++q) { \ MATRIX_INT64_ITER_8X1_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT64_STEP2_AVX) \ } \ } else { \ for (; q != qe_1; m += 8, ++q) { \ MATRIX_INT64_ITER_8X1_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT64_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(2, 1, ymm_sum, POPCNT_UINT64_PERMUTE_AVX) \ for (; q != qe_0; m += 8, ++q) { \ MATRIX_INT64_ITER_8X1_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT64_STEP2_AVX) \ } \ } \ if (((uintptr_t)out & 0xf) == 0) { \ MATRIX_VAR_STORE(2, 1, 4, ymm_sum, out, _mm_store_ps, _NORM) \ } else { \ MATRIX_VAR_STORE(2, 1, 4, ymm_sum, out, _mm_storeu_ps, _NORM) \ } //! Compute the distance between matrix and query (UINT64, M=8, N=2). Advances `m` and `q` in place over `cnt` steps: the first 31 steps use STEP1, later steps STEP2, with PERMUTE between stages (presumably widening counters before overflow; confirm). Declares qe_* locals, so expand once per scope. Aligned loads only when `m` is 32-byte aligned; aligned stores only when `out` is 16-byte aligned. #define POPCNT_UINT64_8X2_AVX(m, q, cnt, out, _NORM) \ MATRIX_VAR_INIT(2, 2, __m256i, ymm_sum, _mm256_setzero_si256()) \ const uint64_t *qe_0 = q + (cnt << 1); \ const uint64_t *qe_1 = (cnt > 31 ?
q + (31 << 1) : qe_0); \ if (((uintptr_t)m & 0x1f) == 0) { \ for (; q != qe_1; m += 8, q += 2) { \ MATRIX_INT64_ITER_8X2_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT64_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(2, 2, ymm_sum, POPCNT_UINT64_PERMUTE_AVX) \ for (; q != qe_0; m += 8, q += 2) { \ MATRIX_INT64_ITER_8X2_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT64_STEP2_AVX) \ } \ } else { \ for (; q != qe_1; m += 8, q += 2) { \ MATRIX_INT64_ITER_8X2_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT64_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(2, 2, ymm_sum, POPCNT_UINT64_PERMUTE_AVX) \ for (; q != qe_0; m += 8, q += 2) { \ MATRIX_INT64_ITER_8X2_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT64_STEP2_AVX) \ } \ } \ if (((uintptr_t)out & 0xf) == 0) { \ MATRIX_VAR_STORE(2, 2, 4, ymm_sum, out, _mm_store_ps, _NORM) \ } else { \ MATRIX_VAR_STORE(2, 2, 4, ymm_sum, out, _mm_storeu_ps, _NORM) \ } //! Compute the distance between matrix and query (UINT64, M=8, N=4). Advances `m` and `q` in place over `cnt` steps: the first 31 steps use STEP1, later steps STEP2, with PERMUTE between stages (presumably widening counters before overflow; confirm). Declares qe_* locals, so expand once per scope. Aligned loads only when `m` is 32-byte aligned; aligned stores only when `out` is 16-byte aligned. #define POPCNT_UINT64_8X4_AVX(m, q, cnt, out, _NORM) \ MATRIX_VAR_INIT(2, 4, __m256i, ymm_sum, _mm256_setzero_si256()) \ const uint64_t *qe_0 = q + (cnt << 2); \ const uint64_t *qe_1 = (cnt > 31 ?
q + (31 << 2) : qe_0); \ if (((uintptr_t)m & 0x1f) == 0) { \ for (; q != qe_1; m += 8, q += 4) { \ MATRIX_INT64_ITER_8X4_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT64_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(2, 4, ymm_sum, POPCNT_UINT64_PERMUTE_AVX) \ for (; q != qe_0; m += 8, q += 4) { \ MATRIX_INT64_ITER_8X4_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT64_STEP2_AVX) \ } \ } else { \ for (; q != qe_1; m += 8, q += 4) { \ MATRIX_INT64_ITER_8X4_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT64_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(2, 4, ymm_sum, POPCNT_UINT64_PERMUTE_AVX) \ for (; q != qe_0; m += 8, q += 4) { \ MATRIX_INT64_ITER_8X4_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT64_STEP2_AVX) \ } \ } \ if (((uintptr_t)out & 0xf) == 0) { \ MATRIX_VAR_STORE(2, 4, 4, ymm_sum, out, _mm_store_ps, _NORM) \ } else { \ MATRIX_VAR_STORE(2, 4, 4, ymm_sum, out, _mm_storeu_ps, _NORM) \ } //! Compute the distance between matrix and query (UINT64, M=8, N=8). Advances `m` and `q` in place over `cnt` steps: the first 31 steps use STEP1, later steps STEP2, with PERMUTE between stages (presumably widening counters before overflow; confirm). Declares qe_* locals, so expand once per scope. Aligned loads only when `m` is 32-byte aligned; aligned stores only when `out` is 16-byte aligned. #define POPCNT_UINT64_8X8_AVX(m, q, cnt, out, _NORM) \ MATRIX_VAR_INIT(2, 8, __m256i, ymm_sum, _mm256_setzero_si256()) \ const uint64_t *qe_0 = q + (cnt << 3); \ const uint64_t *qe_1 = (cnt > 31 ?
q + (31 << 3) : qe_0); \ if (((uintptr_t)m & 0x1f) == 0) { \ for (; q != qe_1; m += 8, q += 8) { \ MATRIX_INT64_ITER_8X8_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT64_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(2, 8, ymm_sum, POPCNT_UINT64_PERMUTE_AVX) \ for (; q != qe_0; m += 8, q += 8) { \ MATRIX_INT64_ITER_8X8_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT64_STEP2_AVX) \ } \ } else { \ for (; q != qe_1; m += 8, q += 8) { \ MATRIX_INT64_ITER_8X8_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT64_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(2, 8, ymm_sum, POPCNT_UINT64_PERMUTE_AVX) \ for (; q != qe_0; m += 8, q += 8) { \ MATRIX_INT64_ITER_8X8_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT64_STEP2_AVX) \ } \ } \ if (((uintptr_t)out & 0xf) == 0) { \ MATRIX_VAR_STORE(2, 8, 4, ymm_sum, out, _mm_store_ps, _NORM) \ } else { \ MATRIX_VAR_STORE(2, 8, 4, ymm_sum, out, _mm_storeu_ps, _NORM) \ } //! Compute the distance between matrix and query (UINT64, M=16, N=1). Advances `m` and `q` in place over `cnt` steps: the first 31 steps use STEP1, later steps STEP2, with PERMUTE between stages (presumably widening counters before overflow; confirm). Declares qe_* locals, so expand once per scope. Aligned loads only when `m` is 32-byte aligned; aligned stores only when `out` is 16-byte aligned. #define POPCNT_UINT64_16X1_AVX(m, q, cnt, out, _NORM) \ MATRIX_VAR_INIT(4, 1, __m256i, ymm_sum, _mm256_setzero_si256()) \ const uint64_t *qe_0 = q + cnt; \ const uint64_t *qe_1 = (cnt > 31 ?
q + 31 : qe_0); \ if (((uintptr_t)m & 0x1f) == 0) { \ for (; q != qe_1; m += 16, ++q) { \ MATRIX_INT64_ITER_16X1_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT64_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(4, 1, ymm_sum, POPCNT_UINT64_PERMUTE_AVX) \ for (; q != qe_0; m += 16, ++q) { \ MATRIX_INT64_ITER_16X1_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT64_STEP2_AVX) \ } \ } else { \ for (; q != qe_1; m += 16, ++q) { \ MATRIX_INT64_ITER_16X1_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT64_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(4, 1, ymm_sum, POPCNT_UINT64_PERMUTE_AVX) \ for (; q != qe_0; m += 16, ++q) { \ MATRIX_INT64_ITER_16X1_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT64_STEP2_AVX) \ } \ } \ if (((uintptr_t)out & 0xf) == 0) { \ MATRIX_VAR_STORE(4, 1, 4, ymm_sum, out, _mm_store_ps, _NORM) \ } else { \ MATRIX_VAR_STORE(4, 1, 4, ymm_sum, out, _mm_storeu_ps, _NORM) \ } //! Compute the distance between matrix and query (UINT64, M=16, N=2). Advances `m` and `q` in place over `cnt` steps: the first 31 steps use STEP1, later steps STEP2, with PERMUTE between stages (presumably widening counters before overflow; confirm). Declares qe_* locals, so expand once per scope. Aligned loads only when `m` is 32-byte aligned; aligned stores only when `out` is 16-byte aligned. #define POPCNT_UINT64_16X2_AVX(m, q, cnt, out, _NORM) \ MATRIX_VAR_INIT(4, 2, __m256i, ymm_sum, _mm256_setzero_si256()) \ const uint64_t *qe_0 = q + (cnt << 1); \ const uint64_t *qe_1 = (cnt > 31 ?
q + (31 << 1) : qe_0); \ if (((uintptr_t)m & 0x1f) == 0) { \ for (; q != qe_1; m += 16, q += 2) { \ MATRIX_INT64_ITER_16X2_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT64_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(4, 2, ymm_sum, POPCNT_UINT64_PERMUTE_AVX) \ for (; q != qe_0; m += 16, q += 2) { \ MATRIX_INT64_ITER_16X2_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT64_STEP2_AVX) \ } \ } else { \ for (; q != qe_1; m += 16, q += 2) { \ MATRIX_INT64_ITER_16X2_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT64_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(4, 2, ymm_sum, POPCNT_UINT64_PERMUTE_AVX) \ for (; q != qe_0; m += 16, q += 2) { \ MATRIX_INT64_ITER_16X2_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT64_STEP2_AVX) \ } \ } \ if (((uintptr_t)out & 0xf) == 0) { \ MATRIX_VAR_STORE(4, 2, 4, ymm_sum, out, _mm_store_ps, _NORM) \ } else { \ MATRIX_VAR_STORE(4, 2, 4, ymm_sum, out, _mm_storeu_ps, _NORM) \ } //! Compute the distance between matrix and query (UINT64, M=16, N=4). Advances `m` and `q` in place over `cnt` steps: the first 31 steps use STEP1, later steps STEP2, with PERMUTE between stages (presumably widening counters before overflow; confirm). Declares qe_* locals, so expand once per scope. Aligned loads only when `m` is 32-byte aligned; aligned stores only when `out` is 16-byte aligned. #define POPCNT_UINT64_16X4_AVX(m, q, cnt, out, _NORM) \ MATRIX_VAR_INIT(4, 4, __m256i, ymm_sum, _mm256_setzero_si256()) \ const uint64_t *qe_0 = q + (cnt << 2); \ const uint64_t *qe_1 = (cnt > 31 ?
q + (31 << 2) : qe_0); \ if (((uintptr_t)m & 0x1f) == 0) { \ for (; q != qe_1; m += 16, q += 4) { \ MATRIX_INT64_ITER_16X4_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT64_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(4, 4, ymm_sum, POPCNT_UINT64_PERMUTE_AVX) \ for (; q != qe_0; m += 16, q += 4) { \ MATRIX_INT64_ITER_16X4_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT64_STEP2_AVX) \ } \ } else { \ for (; q != qe_1; m += 16, q += 4) { \ MATRIX_INT64_ITER_16X4_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT64_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(4, 4, ymm_sum, POPCNT_UINT64_PERMUTE_AVX) \ for (; q != qe_0; m += 16, q += 4) { \ MATRIX_INT64_ITER_16X4_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT64_STEP2_AVX) \ } \ } \ if (((uintptr_t)out & 0xf) == 0) { \ MATRIX_VAR_STORE(4, 4, 4, ymm_sum, out, _mm_store_ps, _NORM) \ } else { \ MATRIX_VAR_STORE(4, 4, 4, ymm_sum, out, _mm_storeu_ps, _NORM) \ } //! Compute the distance between matrix and query (UINT64, M=16, N=8). Advances `m` and `q` in place over `cnt` steps: the first 31 steps use STEP1, later steps STEP2, with PERMUTE between stages (presumably widening counters before overflow; confirm). Declares qe_* locals, so expand once per scope. Aligned loads only when `m` is 32-byte aligned; aligned stores only when `out` is 16-byte aligned. #define POPCNT_UINT64_16X8_AVX(m, q, cnt, out, _NORM) \ MATRIX_VAR_INIT(4, 8, __m256i, ymm_sum, _mm256_setzero_si256()) \ const uint64_t *qe_0 = q + (cnt << 3); \ const uint64_t *qe_1 = (cnt > 31 ?
q + (31 << 3) : qe_0); \ if (((uintptr_t)m & 0x1f) == 0) { \ for (; q != qe_1; m += 16, q += 8) { \ MATRIX_INT64_ITER_16X8_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT64_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(4, 8, ymm_sum, POPCNT_UINT64_PERMUTE_AVX) \ for (; q != qe_0; m += 16, q += 8) { \ MATRIX_INT64_ITER_16X8_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT64_STEP2_AVX) \ } \ } else { \ for (; q != qe_1; m += 16, q += 8) { \ MATRIX_INT64_ITER_16X8_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT64_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(4, 8, ymm_sum, POPCNT_UINT64_PERMUTE_AVX) \ for (; q != qe_0; m += 16, q += 8) { \ MATRIX_INT64_ITER_16X8_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT64_STEP2_AVX) \ } \ } \ if (((uintptr_t)out & 0xf) == 0) { \ MATRIX_VAR_STORE(4, 8, 4, ymm_sum, out, _mm_store_ps, _NORM) \ } else { \ MATRIX_VAR_STORE(4, 8, 4, ymm_sum, out, _mm_storeu_ps, _NORM) \ } //! Compute the distance between matrix and query (UINT64, M=16, N=16). Advances `m` and `q` in place over `cnt` steps: the first 31 steps use STEP1, later steps STEP2, with PERMUTE between stages (presumably widening counters before overflow; confirm). Declares qe_* locals, so expand once per scope. Aligned loads only when `m` is 32-byte aligned; aligned stores only when `out` is 16-byte aligned. #define POPCNT_UINT64_16X16_AVX(m, q, cnt, out, _NORM) \ MATRIX_VAR_INIT(4, 16, __m256i, ymm_sum, _mm256_setzero_si256()) \ const uint64_t *qe_0 = q + (cnt << 4); \ const uint64_t *qe_1 = (cnt > 31 ?
q + (31 << 4) : qe_0); \ if (((uintptr_t)m & 0x1f) == 0) { \ for (; q != qe_1; m += 16, q += 16) { \ MATRIX_INT64_ITER_16X16_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT64_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(4, 16, ymm_sum, POPCNT_UINT64_PERMUTE_AVX) \ for (; q != qe_0; m += 16, q += 16) { \ MATRIX_INT64_ITER_16X16_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT64_STEP2_AVX) \ } \ } else { \ for (; q != qe_1; m += 16, q += 16) { \ MATRIX_INT64_ITER_16X16_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT64_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(4, 16, ymm_sum, POPCNT_UINT64_PERMUTE_AVX) \ for (; q != qe_0; m += 16, q += 16) { \ MATRIX_INT64_ITER_16X16_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT64_STEP2_AVX) \ } \ } \ if (((uintptr_t)out & 0xf) == 0) { \ MATRIX_VAR_STORE(4, 16, 4, ymm_sum, out, _mm_store_ps, _NORM) \ } else { \ MATRIX_VAR_STORE(4, 16, 4, ymm_sum, out, _mm_storeu_ps, _NORM) \ } //! Compute the distance between matrix and query (UINT64, M=32, N=1). Advances `m` and `q` in place over `cnt` steps: the first 31 steps use STEP1, later steps STEP2, with PERMUTE between stages (presumably widening counters before overflow; confirm). Declares qe_* locals, so expand once per scope. Aligned loads only when `m` is 32-byte aligned; aligned stores only when `out` is 16-byte aligned. #define POPCNT_UINT64_32X1_AVX(m, q, cnt, out, _NORM) \ MATRIX_VAR_INIT(8, 1, __m256i, ymm_sum, _mm256_setzero_si256()) \ const uint64_t *qe_0 = q + cnt; \ const uint64_t *qe_1 = (cnt > 31 ?
q + 31 : qe_0); \ if (((uintptr_t)m & 0x1f) == 0) { \ for (; q != qe_1; m += 32, ++q) { \ MATRIX_INT64_ITER_32X1_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT64_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(8, 1, ymm_sum, POPCNT_UINT64_PERMUTE_AVX) \ for (; q != qe_0; m += 32, ++q) { \ MATRIX_INT64_ITER_32X1_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT64_STEP2_AVX) \ } \ } else { \ for (; q != qe_1; m += 32, ++q) { \ MATRIX_INT64_ITER_32X1_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT64_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(8, 1, ymm_sum, POPCNT_UINT64_PERMUTE_AVX) \ for (; q != qe_0; m += 32, ++q) { \ MATRIX_INT64_ITER_32X1_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT64_STEP2_AVX) \ } \ } \ if (((uintptr_t)out & 0xf) == 0) { \ MATRIX_VAR_STORE(8, 1, 4, ymm_sum, out, _mm_store_ps, _NORM) \ } else { \ MATRIX_VAR_STORE(8, 1, 4, ymm_sum, out, _mm_storeu_ps, _NORM) \ } //! Compute the distance between matrix and query (UINT64, M=32, N=2). Advances `m` and `q` in place over `cnt` steps: the first 31 steps use STEP1, later steps STEP2, with PERMUTE between stages (presumably widening counters before overflow; confirm). Declares qe_* locals, so expand once per scope. Aligned loads only when `m` is 32-byte aligned; aligned stores only when `out` is 16-byte aligned. #define POPCNT_UINT64_32X2_AVX(m, q, cnt, out, _NORM) \ MATRIX_VAR_INIT(8, 2, __m256i, ymm_sum, _mm256_setzero_si256()) \ const uint64_t *qe_0 = q + (cnt << 1); \ const uint64_t *qe_1 = (cnt > 31 ?
q + (31 << 1) : qe_0); \ if (((uintptr_t)m & 0x1f) == 0) { \ for (; q != qe_1; m += 32, q += 2) { \ MATRIX_INT64_ITER_32X2_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT64_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(8, 2, ymm_sum, POPCNT_UINT64_PERMUTE_AVX) \ for (; q != qe_0; m += 32, q += 2) { \ MATRIX_INT64_ITER_32X2_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT64_STEP2_AVX) \ } \ } else { \ for (; q != qe_1; m += 32, q += 2) { \ MATRIX_INT64_ITER_32X2_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT64_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(8, 2, ymm_sum, POPCNT_UINT64_PERMUTE_AVX) \ for (; q != qe_0; m += 32, q += 2) { \ MATRIX_INT64_ITER_32X2_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT64_STEP2_AVX) \ } \ } \ if (((uintptr_t)out & 0xf) == 0) { \ MATRIX_VAR_STORE(8, 2, 4, ymm_sum, out, _mm_store_ps, _NORM) \ } else { \ MATRIX_VAR_STORE(8, 2, 4, ymm_sum, out, _mm_storeu_ps, _NORM) \ } //! Compute the distance between matrix and query (UINT64, M=32, N=4). Advances `m` and `q` in place over `cnt` steps: the first 31 steps use STEP1, later steps STEP2, with PERMUTE between stages (presumably widening counters before overflow; confirm). Declares qe_* locals, so expand once per scope. Aligned loads only when `m` is 32-byte aligned; aligned stores only when `out` is 16-byte aligned. #define POPCNT_UINT64_32X4_AVX(m, q, cnt, out, _NORM) \ MATRIX_VAR_INIT(8, 4, __m256i, ymm_sum, _mm256_setzero_si256()) \ const uint64_t *qe_0 = q + (cnt << 2); \ const uint64_t *qe_1 = (cnt > 31 ?
q + (31 << 2) : qe_0); \ if (((uintptr_t)m & 0x1f) == 0) { \ for (; q != qe_1; m += 32, q += 4) { \ MATRIX_INT64_ITER_32X4_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT64_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(8, 4, ymm_sum, POPCNT_UINT64_PERMUTE_AVX) \ for (; q != qe_0; m += 32, q += 4) { \ MATRIX_INT64_ITER_32X4_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT64_STEP2_AVX) \ } \ } else { \ for (; q != qe_1; m += 32, q += 4) { \ MATRIX_INT64_ITER_32X4_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT64_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(8, 4, ymm_sum, POPCNT_UINT64_PERMUTE_AVX) \ for (; q != qe_0; m += 32, q += 4) { \ MATRIX_INT64_ITER_32X4_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT64_STEP2_AVX) \ } \ } \ if (((uintptr_t)out & 0xf) == 0) { \ MATRIX_VAR_STORE(8, 4, 4, ymm_sum, out, _mm_store_ps, _NORM) \ } else { \ MATRIX_VAR_STORE(8, 4, 4, ymm_sum, out, _mm_storeu_ps, _NORM) \ } //! Compute the distance between matrix and query (UINT64, M=32, N=8). Advances `m` and `q` in place over `cnt` steps: the first 31 steps use STEP1, later steps STEP2, with PERMUTE between stages (presumably widening counters before overflow; confirm). Declares qe_* locals, so expand once per scope. Aligned loads only when `m` is 32-byte aligned; aligned stores only when `out` is 16-byte aligned. #define POPCNT_UINT64_32X8_AVX(m, q, cnt, out, _NORM) \ MATRIX_VAR_INIT(8, 8, __m256i, ymm_sum, _mm256_setzero_si256()) \ const uint64_t *qe_0 = q + (cnt << 3); \ const uint64_t *qe_1 = (cnt > 31 ?
q + (31 << 3) : qe_0); \ if (((uintptr_t)m & 0x1f) == 0) { \ for (; q != qe_1; m += 32, q += 8) { \ MATRIX_INT64_ITER_32X8_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT64_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(8, 8, ymm_sum, POPCNT_UINT64_PERMUTE_AVX) \ for (; q != qe_0; m += 32, q += 8) { \ MATRIX_INT64_ITER_32X8_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT64_STEP2_AVX) \ } \ } else { \ for (; q != qe_1; m += 32, q += 8) { \ MATRIX_INT64_ITER_32X8_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT64_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(8, 8, ymm_sum, POPCNT_UINT64_PERMUTE_AVX) \ for (; q != qe_0; m += 32, q += 8) { \ MATRIX_INT64_ITER_32X8_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT64_STEP2_AVX) \ } \ } \ if (((uintptr_t)out & 0xf) == 0) { \ MATRIX_VAR_STORE(8, 8, 4, ymm_sum, out, _mm_store_ps, _NORM) \ } else { \ MATRIX_VAR_STORE(8, 8, 4, ymm_sum, out, _mm_storeu_ps, _NORM) \ } //! Compute the distance between matrix and query (UINT64, M=32, N=16) #define POPCNT_UINT64_32X16_AVX(m, q, cnt, out, _NORM) \ MATRIX_VAR_INIT(8, 16, __m256i, ymm_sum, _mm256_setzero_si256()) \ const uint64_t *qe_0 = q + (cnt << 4); \ const uint64_t *qe_1 = (cnt > 31 ? 
q + (31 << 4) : qe_0); \ if (((uintptr_t)m & 0x1f) == 0) { \ for (; q != qe_1; m += 32, q += 16) { \ MATRIX_INT64_ITER_32X16_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT64_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(8, 16, ymm_sum, POPCNT_UINT64_PERMUTE_AVX) \ for (; q != qe_0; m += 32, q += 16) { \ MATRIX_INT64_ITER_32X16_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT64_STEP2_AVX) \ } \ } else { \ for (; q != qe_1; m += 32, q += 16) { \ MATRIX_INT64_ITER_32X16_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT64_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(8, 16, ymm_sum, POPCNT_UINT64_PERMUTE_AVX) \ for (; q != qe_0; m += 32, q += 16) { \ MATRIX_INT64_ITER_32X16_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT64_STEP2_AVX) \ } \ } \ if (((uintptr_t)out & 0xf) == 0) { \ MATRIX_VAR_STORE(8, 16, 4, ymm_sum, out, _mm_store_ps, _NORM) \ } else { \ MATRIX_VAR_STORE(8, 16, 4, ymm_sum, out, _mm_storeu_ps, _NORM) \ } //! Compute the distance between matrix and query (UINT64, M=32, N=32) #define POPCNT_UINT64_32X32_AVX(m, q, cnt, out, _NORM) \ MATRIX_VAR_INIT(8, 32, __m256i, ymm_sum, _mm256_setzero_si256()) \ const uint64_t *qe_0 = q + (cnt << 5); \ const uint64_t *qe_1 = (cnt > 31 ? 
q + (31 << 5) : qe_0); \ if (((uintptr_t)m & 0x1f) == 0) { \ for (; q != qe_1; m += 32, q += 32) { \ MATRIX_INT64_ITER_32X32_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT64_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(8, 32, ymm_sum, POPCNT_UINT64_PERMUTE_AVX) \ for (; q != qe_0; m += 32, q += 32) { \ MATRIX_INT64_ITER_32X32_AVX(m, q, ymm_sum, _mm256_load_si256, \ POPCNT_UINT64_STEP2_AVX) \ } \ } else { \ for (; q != qe_1; m += 32, q += 32) { \ MATRIX_INT64_ITER_32X32_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT64_STEP1_AVX) \ } \ MATRIX_VAR_PERMUTE(8, 32, ymm_sum, POPCNT_UINT64_PERMUTE_AVX) \ for (; q != qe_0; m += 32, q += 32) { \ MATRIX_INT64_ITER_32X32_AVX(m, q, ymm_sum, _mm256_loadu_si256, \ POPCNT_UINT64_STEP2_AVX) \ } \ } \ if (((uintptr_t)out & 0xf) == 0) { \ MATRIX_VAR_STORE(8, 32, 4, ymm_sum, out, _mm_store_ps, _NORM) \ } else { \ MATRIX_VAR_STORE(8, 32, 4, ymm_sum, out, _mm_storeu_ps, _NORM) \ } ================================================ FILE: src/ailego/math/distance_utility.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include namespace zvec { namespace ailego { /*! 
Four-bits Squared Difference Table
 *
 * Entry [(a << 4) | b] holds (a - b)^2. The nibbles a and b are read as
 * two's-complement 4-bit integers: 0..7 are 0..7 and 8..15 are -8..-1.
 * For example, entry 0x08 is (0 - (-8))^2 = 64 and entry 0x87 is
 * (-8 - 7)^2 = 225. The largest value (225) fits in uint8_t.
 *
 * The table is `static` at namespace scope in a header, so every translation
 * unit that includes it gets its own 64-byte-aligned copy.
 */
static const AILEGO_ALIGNED(64) uint8_t Int4SquaredDiffTable[256] = {
    0,  1,  4,   9,   16,  25,  36,  49,  64,  49,  36,  25,  16,  9,   4,   1,
    1,  0,  1,   4,   9,   16,  25,  36,  81,  64,  49,  36,  25,  16,  9,   4,
    4,  1,  0,   1,   4,   9,   16,  25,  100, 81,  64,  49,  36,  25,  16,  9,
    9,  4,  1,   0,   1,   4,   9,   16,  121, 100, 81,  64,  49,  36,  25,  16,
    16, 9,  4,   1,   0,   1,   4,   9,   144, 121, 100, 81,  64,  49,  36,  25,
    25, 16, 9,   4,   1,   0,   1,   4,   169, 144, 121, 100, 81,  64,  49,  36,
    36, 25, 16,  9,   4,   1,   0,   1,   196, 169, 144, 121, 100, 81,  64,  49,
    49, 36, 25,  16,  9,   4,   1,   0,   225, 196, 169, 144, 121, 100, 81,  64,
    64, 81, 100, 121, 144, 169, 196, 225, 0,   1,   4,   9,   16,  25,  36,  49,
    49, 64, 81,  100, 121, 144, 169, 196, 1,   0,   1,   4,   9,   16,  25,  36,
    36, 49, 64,  81,  100, 121, 144, 169, 4,   1,   0,   1,   4,   9,   16,  25,
    25, 36, 49,  64,  81,  100, 121, 144, 9,   4,   1,   0,   1,   4,   9,   16,
    16, 25, 36,  49,  64,  81,  100, 121, 16,  9,   4,   1,   0,   1,   4,   9,
    9,  16, 25,  36,  49,  64,  81,  100, 25,  16,  9,   4,   1,   0,   1,   4,
    4,  9,  16,  25,  36,  49,  64,  81,  36,  25,  16,  9,   4,   1,   0,   1,
    1,  4,  9,   16,  25,  36,  49,  64,  49,  36,  25,  16,  9,   4,   1,   0,
};

/*!
Four-bits Integer Multiplication Table
 *
 * Entry [(a << 4) | b] holds a * b, with both nibbles read as
 * two's-complement 4-bit integers (8..15 are -8..-1). Values range from
 * -56 (7 * -8) to 64 (-8 * -8), so they fit in int8_t.
 */
static const AILEGO_ALIGNED(64) int8_t Int4MulTable[256] = {
    0, 0,  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
    0, 1,  2,   3,   4,   5,   6,   7,   -8,  -7,  -6,  -5,  -4,  -3,  -2,  -1,
    0, 2,  4,   6,   8,   10,  12,  14,  -16, -14, -12, -10, -8,  -6,  -4,  -2,
    0, 3,  6,   9,   12,  15,  18,  21,  -24, -21, -18, -15, -12, -9,  -6,  -3,
    0, 4,  8,   12,  16,  20,  24,  28,  -32, -28, -24, -20, -16, -12, -8,  -4,
    0, 5,  10,  15,  20,  25,  30,  35,  -40, -35, -30, -25, -20, -15, -10, -5,
    0, 6,  12,  18,  24,  30,  36,  42,  -48, -42, -36, -30, -24, -18, -12, -6,
    0, 7,  14,  21,  28,  35,  42,  49,  -56, -49, -42, -35, -28, -21, -14, -7,
    0, -8, -16, -24, -32, -40, -48, -56, 64,  56,  48,  40,  32,  24,  16,  8,
    0, -7, -14, -21, -28, -35, -42, -49, 56,  49,  42,  35,  28,  21,  14,  7,
    0, -6, -12, -18, -24, -30, -36, -42, 48,  42,  36,  30,  24,  18,  12,  6,
    0, -5, -10, -15, -20, -25, -30, -35, 40,  35,  30,  25,  20,  15,  10,  5,
    0, -4, -8,  -12, -16, -20, -24, -28, 32,  28,  24,  20,  16,  12,  8,   4,
    0, -3, -6,  -9,  -12, -15, -18, -21, 24,  21,  18,  15,  12,  9,   6,   3,
    0, -2, -4,  -6,  -8,  -10, -12, -14, 16,  14,  12,  10,  8,   6,   4,   2,
    0, -1, -2,  -3,  -4,  -5,  -6,  -7,  8,   7,   6,   5,   4,   3,   2,   1,
};

}  // namespace ailego
}  // namespace zvec
================================================
FILE: src/ailego/math/euclidean_distance_matrix.h
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

// NOTE(review): in this copy, every `<identifier ...>` span appears to have
// been stripped during extraction. This covers the #include targets below,
// template parameter/argument lists (e.g. `template struct X;`,
// `std::remove_cv::type`, `MathHelper::`) and `reinterpret_cast` target
// types. Restore them from upstream before compiling.
#include
#include
#include
#include "distance_utility.h"

namespace zvec {
namespace ailego {

//--------------------------------------------------
// Dense
//--------------------------------------------------
//
// Layout shared by the batched (M x N) kernels below:
//   - The matrix holds M vectors interleaved dimension-major. Element k of
//     vector i is at m[k * M + i].
//   - The query holds N vectors interleaved the same way, at q[k * N + j].
//   - out[j * M + i] receives the distance between matrix vector i and query
//     vector j.
// The first dimension *assigns* `out` and later dimensions *accumulate*, so
// `out` does not need to be zero-initialized by the caller.

/*! Squared Euclidean Distance Matrix */
template
struct SquaredEuclideanDistanceMatrix;

/*! Squared Euclidean Distance Matrix (M=1, N=1) */
template
struct SquaredEuclideanDistanceMatrix<
    T, 1, 1, typename std::enable_if::value>::type> {
  //! Type of value
  using ValueType = typename std::remove_cv::type;

  //! Compute the distance between matrix and query.
  //! Scalar reference loop: sums the per-element squared differences of the
  //! two `dim`-element vectors into a float.
  static inline void Compute(const ValueType *m, const ValueType *q,
                             size_t dim, float *out) {
    ailego_assert(m && q && dim && out);

    float sum = 0.0;
    for (size_t i = 0; i < dim; ++i) {
      sum += MathHelper::SquaredDifference(m[i], q[i]);
    }
    *out = sum;
  }
};

// The following single-pair specializations are only declared here; their
// Compute bodies live in separate translation units (for FP16, see
// euclidean_distance_matrix_fp16_dispatch.cc).

template <>
struct SquaredEuclideanDistanceMatrix {
  //! Type of value
  using ValueType = uint8_t;

  //! Compute the distance between matrix and query
  static void Compute(const ValueType *m, const ValueType *q, size_t dim,
                      float *out);
};

template <>
struct SquaredEuclideanDistanceMatrix {
  //! Type of value
  using ValueType = int8_t;

  //! Compute the distance between matrix and query
  static void Compute(const ValueType *m, const ValueType *q, size_t dim,
                      float *out);
};

template <>
struct SquaredEuclideanDistanceMatrix {
  //! Type of value
  using ValueType = Float16;

  //! Compute the distance between matrix and query
  static void Compute(const ValueType *m, const ValueType *q, size_t dim,
                      float *out);
};

template <>
struct SquaredEuclideanDistanceMatrix {
  //! Type of value
  using ValueType = float;

  //! Compute the distance between matrix and query
  static void Compute(const ValueType *m, const ValueType *q, size_t dim,
                      float *out);
};

/*! Squared Euclidean Distance Matrix */
// Generic batched kernel for element types of at least 2 bytes with
// M >= 2 and N >= 2. Uses the interleaved layout described above.
template
struct SquaredEuclideanDistanceMatrix<
    T, M, N,
    typename std::enable_if::value && sizeof(T) >= 2 && M >= 2 &&
                            N >= 2>::type> {
  //! Type of value
  using ValueType = typename std::remove_cv::type;

  //! Compute the distance between matrix and query
  static inline void Compute(const ValueType *m, const ValueType *q,
                             size_t dim, float *out) {
    ailego_assert(m && q && dim && out);

    // First dimension: initialize every output cell.
    if (dim > 0) {
      for (size_t i = 0; i < M; ++i) {
        ValueType m_val = m[i];
        float *r = out + i;

        for (size_t j = 0; j < N; ++j) {
          *r = MathHelper::SquaredDifference(m_val, q[j]);
          r += M;
        }
      }
      m += M;
      q += N;
    }

    // Remaining dimensions: accumulate into the initialized cells.
    for (size_t k = 1; k < dim; ++k) {
      for (size_t i = 0; i < M; ++i) {
        ValueType m_val = m[i];
        float *r = out + i;

        for (size_t j = 0; j < N; ++j) {
          *r += MathHelper::SquaredDifference(m_val, q[j]);
          r += M;
        }
      }
      m += M;
      q += N;
    }
  }
};

/*! Squared Euclidean Distance Matrix (N=1) */
// Single query against M interleaved matrix vectors: out[i] receives the
// distance to matrix vector i.
template
struct SquaredEuclideanDistanceMatrix<
    T, M, 1,
    typename std::enable_if::value && sizeof(T) >= 2 && M >= 2>::type> {
  //! Type of value
  using ValueType = typename std::remove_cv::type;

  //! Compute the distance between matrix and query
  static inline void Compute(const ValueType *m, const ValueType *q,
                             size_t dim, float *out) {
    ailego_assert(m && q && dim && out);

    const ValueType *q_end = q + dim;
    if (q != q_end) {
      ValueType q_val = *q++;

      for (size_t i = 0; i < M; ++i) {
        *(out + i) = MathHelper::SquaredDifference(m[i], q_val);
      }
      m += M;
    }

    while (q != q_end) {
      ValueType q_val = *q++;

      for (size_t i = 0; i < M; ++i) {
        *(out + i) += MathHelper::SquaredDifference(m[i], q_val);
      }
      m += M;
    }
  }
};

/*! Squared Euclidean Distance Matrix (INT8) */
// The data is consumed as uint32_t words, each packing four int8 elements,
// so `dim` must be a multiple of 4. The interleaving is per 4-byte group:
// group k of matrix vector i is m_it[k * M + i].
template
struct SquaredEuclideanDistanceMatrix<
    int8_t, M, N, typename std::enable_if= 2 && N >= 2>::type> {
  //! Type of value
  using ValueType = int8_t;

  //! Compute the distance between matrix and query
  static inline void Compute(const ValueType *m, const ValueType *q,
                             size_t dim, float *out) {
    ailego_assert(m && q && dim && !(dim & 3) && out);

    const uint32_t *m_it = reinterpret_cast(m);
    const uint32_t *q_it = reinterpret_cast(q);
    // From here on `dim` counts 4-byte groups, not elements.
    dim >>= 2;

    if (dim > 0) {
      for (size_t i = 0; i < M; ++i) {
        uint32_t m_val = m_it[i];
        float *r = out + i;

        for (size_t j = 0; j < N; ++j) {
          *r = SquaredDifference(m_val, q_it[j]);
          r += M;
        }
      }
      m_it += M;
      q_it += N;
    }

    for (size_t k = 1; k < dim; ++k) {
      for (size_t i = 0; i < M; ++i) {
        uint32_t m_val = m_it[i];
        float *r = out + i;

        for (size_t j = 0; j < N; ++j) {
          *r += SquaredDifference(m_val, q_it[j]);
          r += M;
        }
      }
      m_it += M;
      q_it += N;
    }
  }

 protected:
  //! Calculate the squared difference.
  //! Sums the squared differences of the four signed bytes packed in each
  //! word (byte b of `lhs` is paired with byte b of `rhs`).
  static inline float SquaredDifference(uint32_t lhs, uint32_t rhs) {
    // NOTE(review): the reason for `volatile` is not documented here. It
    // presumably suppresses a problematic compiler optimization; confirm
    // before removing.
    volatile int32_t sum = MathHelper::SquaredDifference(
                               (int8_t)(lhs >> 0), (int8_t)(rhs >> 0)) +
                           MathHelper::SquaredDifference(
                               (int8_t)(lhs >> 8), (int8_t)(rhs >> 8)) +
                           MathHelper::SquaredDifference(
                               (int8_t)(lhs >> 16), (int8_t)(rhs >> 16)) +
                           MathHelper::SquaredDifference(
                               (int8_t)(lhs >> 24), (int8_t)(rhs >> 24));
    return static_cast(sum);
  }
};

/*! Squared Euclidean Distance Matrix (INT8, N=1) */
// Single-query variant of the INT8 kernel; `dim` must be a multiple of 4.
template
struct SquaredEuclideanDistanceMatrix= 2>::type> {
  //! Type of value
  using ValueType = int8_t;

  //! Compute the distance between matrix and query
  static inline void Compute(const ValueType *m, const ValueType *q,
                             size_t dim, float *out) {
    ailego_assert(m && q && dim && !(dim & 3) && out);

    const uint32_t *m_it = reinterpret_cast(m);
    const uint32_t *q_it = reinterpret_cast(q);
    const uint32_t *q_end = q_it + (dim >> 2);

    if (q_it != q_end) {
      uint32_t q_val = *q_it++;

      for (size_t i = 0; i < M; ++i) {
        *(out + i) = SquaredDifference(m_it[i], q_val);
      }
      m_it += M;
    }

    while (q_it != q_end) {
      uint32_t q_val = *q_it++;

      for (size_t i = 0; i < M; ++i) {
        *(out + i) += SquaredDifference(m_it[i], q_val);
      }
      m_it += M;
    }
  }

 protected:
  //! Calculate the squared difference of four packed signed bytes.
  static inline float SquaredDifference(uint32_t lhs, uint32_t rhs) {
    // NOTE(review): `volatile` is kept as in the MxN INT8 kernel; its
    // purpose is undocumented here.
    volatile int32_t sum = MathHelper::SquaredDifference(
                               (int8_t)(lhs >> 0), (int8_t)(rhs >> 0)) +
                           MathHelper::SquaredDifference(
                               (int8_t)(lhs >> 8), (int8_t)(rhs >> 8)) +
                           MathHelper::SquaredDifference(
                               (int8_t)(lhs >> 16), (int8_t)(rhs >> 16)) +
                           MathHelper::SquaredDifference(
                               (int8_t)(lhs >> 24), (int8_t)(rhs >> 24));
    return static_cast(sum);
  }
};

/*! Squared Euclidean Distance Matrix (INT4) */
// uint8_t storage holding packed signed 4-bit values: each uint32_t word
// carries eight nibbles, so `dim` (in nibbles) must be a multiple of 8.
template
struct SquaredEuclideanDistanceMatrix<
    uint8_t, M, N, typename std::enable_if= 2 && N >= 2>::type> {
  //! Type of value
  using ValueType = uint8_t;

  //! Compute the distance between matrix and query
  static inline void Compute(const ValueType *m, const ValueType *q,
                             size_t dim, float *out) {
    ailego_assert(m && q && dim && !(dim & 7) && out);

    const uint32_t *m_it = reinterpret_cast(m);
    const uint32_t *q_it = reinterpret_cast(q);
    // From here on `dim` counts 32-bit words (8 nibbles each).
    dim >>= 3;

    if (dim > 0) {
      for (size_t i = 0; i < M; ++i) {
        uint32_t m_val = m_it[i];
        float *r = out + i;

        for (size_t j = 0; j < N; ++j) {
          *r = SquaredDifference(m_val, q_it[j]);
          r += M;
        }
      }
      m_it += M;
      q_it += N;
    }

    for (size_t k = 1; k < dim; ++k) {
      for (size_t i = 0; i < M; ++i) {
        uint32_t m_val = m_it[i];
        float *r = out + i;

        for (size_t j = 0; j < N; ++j) {
          *r += SquaredDifference(m_val, q_it[j]);
          r += M;
        }
      }
      m_it += M;
      q_it += N;
    }
  }

 protected:
  //! Calculate the squared difference.
  //! For each nibble position p (0..7), Int4SquaredDiffTable is indexed by
  //! (lhs nibble p << 4) | (rhs nibble p).
  static inline float SquaredDifference(uint32_t lhs, uint32_t rhs) {
    return static_cast(
        Int4SquaredDiffTable[((lhs << 4) & 0xf0) | ((rhs >> 0) & 0xf)] +
        Int4SquaredDiffTable[((lhs >> 0) & 0xf0) | ((rhs >> 4) & 0xf)] +
        Int4SquaredDiffTable[((lhs >> 4) & 0xf0) | ((rhs >> 8) & 0xf)] +
        Int4SquaredDiffTable[((lhs >> 8) & 0xf0) | ((rhs >> 12) & 0xf)] +
        Int4SquaredDiffTable[((lhs >> 12) & 0xf0) | ((rhs >> 16) & 0xf)] +
        Int4SquaredDiffTable[((lhs >> 16) & 0xf0) | ((rhs >> 20) & 0xf)] +
        Int4SquaredDiffTable[((lhs >> 20) & 0xf0) | ((rhs >> 24) & 0xf)] +
        Int4SquaredDiffTable[((lhs >> 24) & 0xf0) | ((rhs >> 28) & 0xf)]);
  }
};

/*! Squared Euclidean Distance Matrix (INT4, N=1) */
// Single-query variant of the INT4 kernel; `dim` must be a multiple of 8.
template
struct SquaredEuclideanDistanceMatrix= 2>::type> {
  //! Type of value
  using ValueType = uint8_t;

  //! Compute the distance between matrix and query
  static inline void Compute(const ValueType *m, const ValueType *q,
                             size_t dim, float *out) {
    ailego_assert(m && q && dim && !(dim & 7) && out);

    const uint32_t *m_it = reinterpret_cast(m);
    const uint32_t *q_it = reinterpret_cast(q);
    const uint32_t *q_end = q_it + (dim >> 3);

    if (q_it != q_end) {
      uint32_t q_val = *q_it++;

      for (size_t i = 0; i < M; ++i) {
        *(out + i) = SquaredDifference(m_it[i], q_val);
      }
      m_it += M;
    }

    while (q_it != q_end) {
      uint32_t q_val = *q_it++;

      for (size_t i = 0; i < M; ++i) {
        *(out + i) += SquaredDifference(m_it[i], q_val);
      }
      m_it += M;
    }
  }

 protected:
  //! Calculate the squared difference of eight packed signed nibbles via
  //! Int4SquaredDiffTable.
  static inline float SquaredDifference(uint32_t lhs, uint32_t rhs) {
    return static_cast(
        Int4SquaredDiffTable[((lhs << 4) & 0xf0) | ((rhs >> 0) & 0xf)] +
        Int4SquaredDiffTable[((lhs >> 0) & 0xf0) | ((rhs >> 4) & 0xf)] +
        Int4SquaredDiffTable[((lhs >> 4) & 0xf0) | ((rhs >> 8) & 0xf)] +
        Int4SquaredDiffTable[((lhs >> 8) & 0xf0) | ((rhs >> 12) & 0xf)] +
        Int4SquaredDiffTable[((lhs >> 12) & 0xf0) | ((rhs >> 16) & 0xf)] +
        Int4SquaredDiffTable[((lhs >> 16) & 0xf0) | ((rhs >> 20) & 0xf)] +
        Int4SquaredDiffTable[((lhs >> 20) & 0xf0) | ((rhs >> 24) & 0xf)] +
        Int4SquaredDiffTable[((lhs >> 24) & 0xf0) | ((rhs >> 28) & 0xf)]);
  }
};

/*! Euclidean Distance Matrix */
// Computes the squared distances, then replaces each of the N * M results in
// `out` with its square root. The enable_if restricts T to two specific
// types. NOTE(review): the type names were stripped in this copy;
// presumably they are float and Float16.
template ::value || std::is_same::value) && M >= 1 && N >= 1>::type>
struct EuclideanDistanceMatrix {
  //! Type of value
  using ValueType = typename std::remove_cv::type;

  //! Compute the distance between matrix and query
  static inline void Compute(const ValueType *m, const ValueType *q,
                             size_t dim, float *out) {
    ailego_assert(m && q && dim && out);

    SquaredEuclideanDistanceMatrix::Compute(m, q, dim, out);
    for (size_t i = 0; i < N * M; ++i) {
      float val = *out;
      *out++ = std::sqrt(val);
    }
  }
};

/*! Euclidean Distance Matrix (M=1, N=1) */
template
struct EuclideanDistanceMatrix<
    T, 1, 1, typename std::enable_if::value>::type> {
  //! Type of value
  using ValueType = typename std::remove_cv::type;

  //! Compute the distance between matrix and query: sqrt of the summed
  //! per-element squared differences.
  static inline void Compute(const ValueType *m, const ValueType *q,
                             size_t dim, float *out) {
    ailego_assert(m && q && dim && out);

    float sum = 0.0;
    for (size_t i = 0; i < dim; ++i) {
      sum += MathHelper::SquaredDifference(m[i], q[i]);
    }
    *out = std::sqrt(sum);
  }
};

// Out-of-line single-pair specializations (definitions are not in this
// header).

template <>
struct EuclideanDistanceMatrix {
  //! Type of value
  using ValueType = uint8_t;

  //! Compute the distance between matrix and query
  static void Compute(const ValueType *m, const ValueType *q, size_t dim,
                      float *out);
};

template <>
struct EuclideanDistanceMatrix {
  //! Type of value
  using ValueType = int8_t;

  //! Compute the distance between matrix and query
  static void Compute(const ValueType *m, const ValueType *q, size_t dim,
                      float *out);
};

template <>
struct EuclideanDistanceMatrix {
  //! Type of value
  using ValueType = Float16;

  //! Compute the distance between matrix and query
  static void Compute(const ValueType *m, const ValueType *q, size_t dim,
                      float *out);
};

template <>
struct EuclideanDistanceMatrix {
  //! Type of value
  using ValueType = float;

  //! Compute the distance between matrix and query
  static void Compute(const ValueType *m, const ValueType *q, size_t dim,
                      float *out);
};

//--------------------------------------------------
// Sparse
//--------------------------------------------------

/*! Squared Euclidean Distance Sparse Matrix */
// Serialized sparse vector layout, as read by Compute below:
//   uint32_t  sparse_count                  total non-zero entries
//   uint32_t  seg_count                     number of segments
//   uint32_t  seg_id[seg_count]
//   uint32_t  seg_vec_cnt[seg_count]        entries per segment
//   uint16_t  index[sparse_count]           in-segment indices
//   ValueType value[sparse_count]
// The merges assume segment ids, and indices within a segment, are sorted
// ascending. An id/index present on only one side contributes value^2.
template
struct SquaredEuclideanSparseDistanceMatrix {
  //! Type of value
  using ValueType = typename std::remove_cv::type;

  //! Squared distance between the entries of one matching segment.
  static float ComputeSquaredEuclideanSparseDistanceInSegment(
      uint32_t m_sparse_count, const uint16_t *m_sparse_index,
      const ValueType *m_sparse_value, uint32_t q_sparse_count,
      const uint16_t *q_sparse_index, const ValueType *q_sparse_value);

  //! Compute the distance between matrix and query
  static inline void Compute(const void *m_sparse_data_in,
                             const void *q_sparse_data_in, float *out) {
    ailego_assert(out);

    const uint8_t *m_sparse_data =
        reinterpret_cast(m_sparse_data_in);
    const uint8_t *q_sparse_data =
        reinterpret_cast(q_sparse_data_in);

    const uint32_t m_sparse_count =
        *reinterpret_cast(m_sparse_data);
    const uint32_t q_sparse_count =
        *reinterpret_cast(q_sparse_data);

    const uint32_t m_seg_count =
        *reinterpret_cast(m_sparse_data + sizeof(uint32_t));
    const uint32_t q_seg_count =
        *reinterpret_cast(q_sparse_data + sizeof(uint32_t));

    const uint32_t *m_seg_id = reinterpret_cast(
        m_sparse_data + 2 * sizeof(uint32_t));
    const uint32_t *q_seg_id = reinterpret_cast(
        q_sparse_data + 2 * sizeof(uint32_t));

    const uint32_t *m_seg_vec_cnt = reinterpret_cast(
        m_sparse_data + 2 * sizeof(uint32_t) + m_seg_count * sizeof(uint32_t));
    const uint32_t *q_seg_vec_cnt = reinterpret_cast(
        q_sparse_data + 2 * sizeof(uint32_t) + q_seg_count * sizeof(uint32_t));

    const uint16_t *m_sparse_index = reinterpret_cast(
        m_sparse_data + 2 * sizeof(uint32_t) +
        m_seg_count * 2 * sizeof(uint32_t));
    const uint16_t *q_sparse_index = reinterpret_cast(
        q_sparse_data + 2 * sizeof(uint32_t) +
        q_seg_count * 2 * sizeof(uint32_t));

    // NOTE(review): when sparse_count is odd, the value array starts only
    // 2-byte aligned. Reading 4-byte ValueTypes through it is then a
    // misaligned access (tolerated on x86, formally UB).
    const ValueType *m_sparse_value = reinterpret_cast(
        m_sparse_data + 2 * sizeof(uint32_t) +
        m_seg_count * 2 * sizeof(uint32_t) + m_sparse_count * sizeof(uint16_t));
    const ValueType *q_sparse_value = reinterpret_cast(
        q_sparse_data + 2 * sizeof(uint32_t) +
        q_seg_count * 2 * sizeof(uint32_t) + q_sparse_count * sizeof(uint16_t));

    float sum = 0.0f;

    size_t m_s = 0;
    size_t q_s = 0;

    // Running offsets into the index/value arrays of each side.
    size_t m_count = 0;
    size_t q_count = 0;

    // Merge the two sorted segment-id lists.
    while (m_s < m_seg_count && q_s < q_seg_count) {
      if (m_seg_id[m_s] == q_seg_id[q_s]) {
        sum += ComputeSquaredEuclideanSparseDistanceInSegment(
            m_seg_vec_cnt[m_s], m_sparse_index + m_count,
            m_sparse_value + m_count, q_seg_vec_cnt[q_s],
            q_sparse_index + q_count, q_sparse_value + q_count);

        m_count += m_seg_vec_cnt[m_s];
        q_count += q_seg_vec_cnt[q_s];

        ++m_s;
        ++q_s;
      } else if (m_seg_id[m_s] < q_seg_id[q_s]) {
        // Segment only in the matrix vector: contributes its squared norm.
        for (size_t i = 0; i < m_seg_vec_cnt[m_s]; i++) {
          float value = (m_sparse_value + m_count)[i];
          sum += value * value;
        }

        m_count += m_seg_vec_cnt[m_s];

        ++m_s;
      } else {
        // Segment only in the query vector: contributes its squared norm.
        for (size_t i = 0; i < q_seg_vec_cnt[q_s]; i++) {
          float value = (q_sparse_value + q_count)[i];
          sum += value * value;
        }

        q_count += q_seg_vec_cnt[q_s];

        ++q_s;
      }
    }

    // Leftover segments on either side also contribute their squared norms.
    for (; m_s < m_seg_count; m_s++) {
      for (size_t i = 0; i < m_seg_vec_cnt[m_s]; i++) {
        float diff = (m_sparse_value + m_count)[i];
        sum += diff * diff;
      }

      m_count += m_seg_vec_cnt[m_s];
    }

    for (; q_s < q_seg_count; q_s++) {
      for (size_t i = 0; i < q_seg_vec_cnt[q_s]; i++) {
        float diff = (q_sparse_value + q_count)[i];
        sum += diff * diff;
      }

      q_count += q_seg_vec_cnt[q_s];
    }

    *out = sum;
  }
};

// Merges two sorted index lists of one segment. Matching indices contribute
// (m - q)^2; unmatched ones contribute value^2.
template
float SquaredEuclideanSparseDistanceMatrix::
    ComputeSquaredEuclideanSparseDistanceInSegment(
        uint32_t m_sparse_count, const uint16_t *m_sparse_index,
        const ValueType *m_sparse_value, uint32_t q_sparse_count,
        const uint16_t *q_sparse_index, const ValueType *q_sparse_value) {
  float sum = 0.0f;

  size_t m_i = 0;
  size_t q_i = 0;
  while (m_i < m_sparse_count && q_i < q_sparse_count) {
    if (m_sparse_index[m_i] == q_sparse_index[q_i]) {
      float diff = m_sparse_value[m_i] - q_sparse_value[q_i];
      sum += diff * diff;

      ++m_i;
      ++q_i;
    } else if (m_sparse_index[m_i] < q_sparse_index[q_i]) {
      float diff = m_sparse_value[m_i];
      sum += diff * diff;

      ++m_i;
    } else {
      float diff = q_sparse_value[q_i];
      sum += diff * diff;

      ++q_i;
    }
  }

  for (; m_i < m_sparse_count; m_i++) {
    float diff = m_sparse_value[m_i];
    sum += diff * diff;
  }

  for (; q_i < q_sparse_count; q_i++) {
    float diff = q_sparse_value[q_i];
    sum += diff * diff;
  }

  return sum;
}

}  // namespace ailego
}  // namespace zvec
================================================
FILE: src/ailego/math/euclidean_distance_matrix_fp16_avx.cc
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache
// License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "distance_matrix_accum_fp16.i"
#include "distance_matrix_euclidean_utility.i"
#include "euclidean_distance_matrix.h"

namespace zvec {
namespace ailego {

#if defined(__AVX__)
//! Squared Euclidean distance between two FP16 vectors of `size` elements.
//! The work is done by the ACCUM_FP16_1X1_AVX kernel macro, which writes the
//! result through `&score`. The final macro argument is left empty;
//! presumably no post-normalization is applied (NOTE(review): confirm
//! against the macro definition).
float SquaredEuclideanDistanceFp16AVX(const Float16 *lhs, const Float16 *rhs,
                                      size_t size) {
  float score{0.0f};

  ACCUM_FP16_1X1_AVX(lhs, rhs, size, &score, 0ull, )

  return score;
}
#endif  // __AVX__

}  // namespace ailego
}  // namespace zvec
================================================
FILE: src/ailego/math/euclidean_distance_matrix_fp16_avx512.cc
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "distance_matrix_accum_fp16.i"
#include "distance_matrix_euclidean_utility.i"
#include "euclidean_distance_matrix.h"

namespace zvec {
namespace ailego {

#if defined(__AVX512F__)
//! Squared Euclidean distance between two FP16 vectors of `size` elements,
//! computed by the ACCUM_FP16_1X1_AVX512 kernel macro (AVX-512F builds only).
float SquaredEuclideanDistanceFp16AVX512(const Float16 *lhs,
                                         const Float16 *rhs, size_t size) {
  float score{0.0f};

  ACCUM_FP16_1X1_AVX512(lhs, rhs, size, &score, 0ull, )

  return score;
}
#endif

}  // namespace ailego
}  // namespace zvec
================================================
FILE: src/ailego/math/euclidean_distance_matrix_fp16_avx512fp16.cc
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "distance_matrix_accum_fp16.i"
#include "distance_matrix_euclidean_utility.i"
#include "euclidean_distance_matrix.h"

namespace zvec {
namespace ailego {

#if defined(__AVX512FP16__)
//! Squared Euclidean Distance
//!
//! Native AVX512-FP16 kernel:
//!  - Processes 64 halves per iteration in two independent accumulators.
//!  - Then handles one optional 32-half block.
//!  - Then handles a masked tail of at most 31 halves.
//! Aligned loads are used only when both inputs are 64-byte aligned.
//!
//! NOTE(review): products and sums are accumulated in FP16 (`__m512h`) and
//! only reduced by HorizontalAdd_FP16_V512 at the end. Long or large-valued
//! vectors can therefore lose precision, or overflow past the FP16 maximum
//! (65504) to infinity. Confirm this is acceptable for the intended inputs.
float SquaredEuclideanDistanceFp16AVX512FP16(const Float16 *lhs,
                                             const Float16 *rhs, size_t size) {
  const Float16 *last = lhs + size;
  const Float16 *last_aligned = lhs + ((size >> 6) << 6);

  __m512h zmm_sum_0 = _mm512_setzero_ph();
  __m512h zmm_sum_1 = _mm512_setzero_ph();

  if (((uintptr_t)lhs & 0x3f) == 0 && ((uintptr_t)rhs & 0x3f) == 0) {
    for (; lhs != last_aligned; lhs += 64, rhs += 64) {
      __m512h zmm_d_0 =
          _mm512_sub_ph(_mm512_load_ph(lhs + 0), _mm512_load_ph(rhs + 0));
      __m512h zmm_d_1 =
          _mm512_sub_ph(_mm512_load_ph(lhs + 32), _mm512_load_ph(rhs + 32));
      zmm_sum_0 = _mm512_fmadd_ph(zmm_d_0, zmm_d_0, zmm_sum_0);
      zmm_sum_1 = _mm512_fmadd_ph(zmm_d_1, zmm_d_1, zmm_sum_1);
    }

    if (last >= last_aligned + 32) {
      __m512h zmm_d = _mm512_sub_ph(_mm512_load_ph(lhs), _mm512_load_ph(rhs));
      zmm_sum_0 = _mm512_fmadd_ph(zmm_d, zmm_d, zmm_sum_0);
      lhs += 32;
      rhs += 32;
    }
  } else {
    for (; lhs != last_aligned; lhs += 64, rhs += 64) {
      __m512h zmm_d_0 =
          _mm512_sub_ph(_mm512_loadu_ph(lhs + 0), _mm512_loadu_ph(rhs + 0));
      __m512h zmm_d_1 =
          _mm512_sub_ph(_mm512_loadu_ph(lhs + 32), _mm512_loadu_ph(rhs + 32));
      zmm_sum_0 = _mm512_fmadd_ph(zmm_d_0, zmm_d_0, zmm_sum_0);
      zmm_sum_1 = _mm512_fmadd_ph(zmm_d_1, zmm_d_1, zmm_sum_1);
    }

    if (last >= last_aligned + 32) {
      __m512h zmm_d =
          _mm512_sub_ph(_mm512_loadu_ph(lhs), _mm512_loadu_ph(rhs));
      zmm_sum_0 = _mm512_fmadd_ph(zmm_d, zmm_d, zmm_sum_0);
      lhs += 32;
      rhs += 32;
    }
  }

  zmm_sum_0 = _mm512_add_ph(zmm_sum_0, zmm_sum_1);

  if (lhs != last) {
    // 1..31 halves remain. Build the lane mask with unsigned arithmetic: with
    // a signed `1`, a 31-element tail evaluates (1 << 31) - 1, which
    // overflows `int` (undefined behavior).
    __mmask32 mask = (__mmask32)((1u << (last - lhs)) - 1u);
    __m512i zmm_undefined = _mm512_undefined_epi32();
    __m512h zmm_undefined_ph = _mm512_undefined_ph();
    // Masked loads never touch memory past `last`. Inactive lanes hold
    // undefined values, but mask3_fmadd keeps the old sum in those lanes.
    __m512h zmm_d = _mm512_mask_sub_ph(
        zmm_undefined_ph, mask,
        _mm512_castsi512_ph(_mm512_mask_loadu_epi16(zmm_undefined, mask, lhs)),
        _mm512_castsi512_ph(_mm512_mask_loadu_epi16(zmm_undefined, mask, rhs)));
    zmm_sum_0 = _mm512_mask3_fmadd_ph(zmm_d, zmm_d, zmm_sum_0, mask);
  }

  return HorizontalAdd_FP16_V512(zmm_sum_0);
}
#endif

}  // namespace ailego
}
// namespace zvec
================================================
FILE: src/ailego/math/euclidean_distance_matrix_fp16_dispatch.cc
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// NOTE(review): the target of the first #include was lost when this copy
// was extracted. Presumably it is the header declaring
// internal::CpuFeatures; restore it from upstream.
#include
#include "euclidean_distance_matrix.h"

namespace zvec {
namespace ailego {

// Per-ISA kernels defined in other translation units. Each declaration is
// only visible when the matching ISA macro is enabled for this build.
#if defined(__ARM_NEON)
float SquaredEuclideanDistanceFp16NEON(const Float16 *lhs, const Float16 *rhs,
                                       size_t size);
#endif

#if defined(__AVX512FP16__)
float SquaredEuclideanDistanceFp16AVX512FP16(const Float16 *lhs,
                                             const Float16 *rhs, size_t size);
#endif

#if defined(__AVX512F__)
float SquaredEuclideanDistanceFp16AVX512(const Float16 *lhs,
                                         const Float16 *rhs, size_t size);
#endif

#if defined(__AVX__)
float SquaredEuclideanDistanceFp16AVX(const Float16 *lhs, const Float16 *rhs,
                                      size_t size);
#endif

float SquaredEuclideanDistanceFp16Scalar(const Float16 *lhs,
                                         const Float16 *rhs, size_t size);

//! Compute the distance between matrix and query (FP16, M=1, N=1)
//!
//! On NEON builds the NEON kernel is used unconditionally. Otherwise the
//! first kernel whose ISA is compiled in *and* reported by
//! CpuFeatures::static_flags_ at run time wins, in this order:
//! AVX512_FP16, AVX512F, AVX, and finally the scalar fallback.
void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m,
                                             const ValueType *q, size_t dim,
                                             float *out) {
#if defined(__ARM_NEON)
  *out = SquaredEuclideanDistanceFp16NEON(m, q, dim);
#else
#if defined(__AVX512FP16__)
  if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512_FP16) {
    *out = SquaredEuclideanDistanceFp16AVX512FP16(m, q, dim);
    return;
  }
#endif
#if defined(__AVX512F__)
  if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) {
    *out = SquaredEuclideanDistanceFp16AVX512(m, q, dim);
    return;
  }
#endif
#if defined(__AVX__)
  if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX) {
    *out = SquaredEuclideanDistanceFp16AVX(m, q, dim);
    return;
  }
#endif
  *out = SquaredEuclideanDistanceFp16Scalar(m, q, dim);
#endif  //__ARM_NEON
}

//! Compute the distance between matrix and query (FP16, M=1, N=1)
//! Euclidean distance = sqrt of the dispatched squared distance.
void EuclideanDistanceMatrix::Compute(const ValueType *m,
                                      const ValueType *q, size_t dim,
                                      float *out) {
  SquaredEuclideanDistanceMatrix::Compute(m, q, dim, out);
  *out = std::sqrt(*out);
}

}  // namespace ailego
}  // namespace zvec
================================================
FILE: src/ailego/math/euclidean_distance_matrix_fp16_neon.cc
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "distance_matrix_accum_fp16.i"
#include "distance_matrix_euclidean_utility.i"
#include "euclidean_distance_matrix.h"

namespace zvec {
namespace ailego {

#if defined(__ARM_NEON)
//! Squared Euclidean distance between two FP16 vectors of `size` elements,
//! computed by the ACCUM_FP16_1X1_NEON kernel macro. The result is written
//! through `&score` and returned.
float SquaredEuclideanDistanceFp16NEON(const Float16 *lhs, const Float16 *rhs,
                                       size_t size) {
  float score{0.0f};

  ACCUM_FP16_1X1_NEON(lhs, rhs, size, &score, 0ull, )

  return score;
}
#endif

}  // namespace ailego
}  // namespace zvec
================================================
FILE: src/ailego/math/euclidean_distance_matrix_fp32_avx.cc
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "distance_matrix_accum_fp32.i"
#include "distance_matrix_euclidean_utility.i"
#include "euclidean_distance_matrix.h"

namespace zvec {
namespace ailego {

#if defined(__AVX__)
// Defined in euclidean_distance_matrix_fp32_sse.cc; used below for inputs
// shorter than one 8-float AVX register.
float SquaredEuclideanDistanceFp32SSEInternal(const float *lhs,
                                              const float *rhs, size_t size);

//! AVX kernel: returns sum_i (lhs[i] - rhs[i])^2 over `size` floats.
//!
//! The main loop consumes 16 floats per iteration into two independent
//! accumulators. At most one further 8-float step follows, and the final
//! 0..7 elements are folded in with SSD_FP32_GENERAL. When both pointers are
//! 32-byte aligned, the aligned-load variant of the same loops is used.
//!
//! NOTE(review): _mm256_fmadd_ps is an FMA instruction, but this file is
//! guarded only by __AVX__ — confirm FMA is guaranteed wherever __AVX__ is,
//! or that an included .i header provides a non-FMA fallback.
float SquaredEuclideanDistanceFp32AVXInternal(const float *lhs,
                                              const float *rhs, size_t size) {
  const float *last = lhs + size;
  // End of the prefix whose length is a multiple of 16.
  const float *last_aligned = lhs + ((size >> 4) << 4);

  __m256 ymm_sum_0 = _mm256_setzero_ps();
  __m256 ymm_sum_1 = _mm256_setzero_ps();

  if (((uintptr_t)lhs & 0x1f) == 0 && ((uintptr_t)rhs & 0x1f) == 0) {
    for (; lhs != last_aligned; lhs += 16, rhs += 16) {
      __m256 ymm_d_0 =
          _mm256_sub_ps(_mm256_load_ps(lhs + 0), _mm256_load_ps(rhs + 0));
      __m256 ymm_d_1 =
          _mm256_sub_ps(_mm256_load_ps(lhs + 8), _mm256_load_ps(rhs + 8));
      ymm_sum_0 = _mm256_fmadd_ps(ymm_d_0, ymm_d_0, ymm_sum_0);
      ymm_sum_1 = _mm256_fmadd_ps(ymm_d_1, ymm_d_1, ymm_sum_1);
    }

    // One more full 8-float register if at least 8 elements remain.
    if (last >= last_aligned + 8) {
      __m256 ymm_d = _mm256_sub_ps(_mm256_load_ps(lhs), _mm256_load_ps(rhs));
      ymm_sum_0 = _mm256_fmadd_ps(ymm_d, ymm_d, ymm_sum_0);
      lhs += 8;
      rhs += 8;
    }
  } else {
    // Same as the aligned branch, using unaligned loads.
    for (; lhs != last_aligned; lhs += 16, rhs += 16) {
      __m256 ymm_d_0 =
          _mm256_sub_ps(_mm256_loadu_ps(lhs + 0), _mm256_loadu_ps(rhs + 0));
      __m256 ymm_d_1 =
          _mm256_sub_ps(_mm256_loadu_ps(lhs + 8), _mm256_loadu_ps(rhs + 8));
      ymm_sum_0 = _mm256_fmadd_ps(ymm_d_0, ymm_d_0, ymm_sum_0);
      ymm_sum_1 = _mm256_fmadd_ps(ymm_d_1, ymm_d_1, ymm_sum_1);
    }

    if (last >= last_aligned + 8) {
      __m256 ymm_d =
          _mm256_sub_ps(_mm256_loadu_ps(lhs), _mm256_loadu_ps(rhs));
      ymm_sum_0 = _mm256_fmadd_ps(ymm_d, ymm_d, ymm_sum_0);
      lhs += 8;
      rhs += 8;
    }
  }
  float result = HorizontalAdd_FP32_V256(_mm256_add_ps(ymm_sum_0, ymm_sum_1));

  // Scalar tail: 0..7 elements remain. Cases fall through on purpose so that
  // every remaining element is accumulated, highest index first.
  switch (last - lhs) {
    case 7:
      SSD_FP32_GENERAL(lhs[6], rhs[6], result)
      /* FALLTHRU */
    case 6:
      SSD_FP32_GENERAL(lhs[5], rhs[5], result)
      /* FALLTHRU */
    case 5:
      SSD_FP32_GENERAL(lhs[4], rhs[4], result)
      /* FALLTHRU */
    case 4:
      SSD_FP32_GENERAL(lhs[3], rhs[3], result)
      /* FALLTHRU */
    case 3:
      SSD_FP32_GENERAL(lhs[2], rhs[2], result)
      /* FALLTHRU */
    case 2:
      SSD_FP32_GENERAL(lhs[1], rhs[1], result)
      /* FALLTHRU */
    case 1:
      SSD_FP32_GENERAL(lhs[0], rhs[0], result)
  }
  return result;
}

//! AVX entry point: inputs of 8 or more floats use the AVX kernel; shorter
//! inputs are handled by the SSE kernel.
float SquaredEuclideanDistanceFp32AVX(const float *lhs, const float *rhs,
                                      size_t size) {
  if (size > 7) {
    return SquaredEuclideanDistanceFp32AVXInternal(lhs, rhs, size);
  }
  return SquaredEuclideanDistanceFp32SSEInternal(lhs, rhs, size);
}
#endif  // __AVX__

}  // namespace ailego
}  // namespace zvec
================================================
FILE: src/ailego/math/euclidean_distance_matrix_fp32_avx512.cc
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "distance_matrix_accum_fp32.i"
#include "distance_matrix_euclidean_utility.i"
#include "euclidean_distance_matrix.h"

namespace zvec {
namespace ailego {

#if defined(__AVX512F__)
// Narrower kernels defined in the SSE and AVX files; used below for short
// inputs.
float SquaredEuclideanDistanceFp32SSEInternal(const float *lhs,
                                              const float *rhs, size_t size);
float SquaredEuclideanDistanceFp32AVXInternal(const float *lhs,
                                              const float *rhs, size_t size);

//! AVX-512 kernel: returns sum_i (lhs[i] - rhs[i])^2 over `size` floats.
//!
//! The main loop consumes 32 floats per iteration into two accumulators,
//! followed by at most one 16-float step. The final 0..15 elements are
//! handled with a single masked load/sub/FMA, so there is no scalar tail.
//! Aligned loads are used when both pointers are 64-byte aligned.
float SquaredEuclideanDistanceFp32AVX512Internal(const float *lhs,
                                                 const float *rhs,
                                                 size_t size) {
  const float *last = lhs + size;
  // End of the prefix whose length is a multiple of 32.
  const float *last_aligned = lhs + ((size >> 5) << 5);

  __m512 zmm_sum_0 = _mm512_setzero_ps();
  __m512 zmm_sum_1 = _mm512_setzero_ps();

  if (((uintptr_t)lhs & 0x3f) == 0 && ((uintptr_t)rhs & 0x3f) == 0) {
    for (; lhs != last_aligned; lhs += 32, rhs += 32) {
      __m512 zmm_d_0 =
          _mm512_sub_ps(_mm512_load_ps(lhs + 0), _mm512_load_ps(rhs + 0));
      __m512 zmm_d_1 =
          _mm512_sub_ps(_mm512_load_ps(lhs + 16), _mm512_load_ps(rhs + 16));
      zmm_sum_0 = _mm512_fmadd_ps(zmm_d_0, zmm_d_0, zmm_sum_0);
      zmm_sum_1 = _mm512_fmadd_ps(zmm_d_1, zmm_d_1, zmm_sum_1);
    }

    // One more full 16-float register if at least 16 elements remain.
    if (last >= last_aligned + 16) {
      __m512 zmm_d = _mm512_sub_ps(_mm512_load_ps(lhs), _mm512_load_ps(rhs));
      zmm_sum_0 = _mm512_fmadd_ps(zmm_d, zmm_d, zmm_sum_0);
      lhs += 16;
      rhs += 16;
    }
  } else {
    // Same as the aligned branch, using unaligned loads.
    for (; lhs != last_aligned; lhs += 32, rhs += 32) {
      __m512 zmm_d_0 =
          _mm512_sub_ps(_mm512_loadu_ps(lhs + 0), _mm512_loadu_ps(rhs + 0));
      __m512 zmm_d_1 =
          _mm512_sub_ps(_mm512_loadu_ps(lhs + 16), _mm512_loadu_ps(rhs + 16));
      zmm_sum_0 = _mm512_fmadd_ps(zmm_d_0, zmm_d_0, zmm_sum_0);
      zmm_sum_1 = _mm512_fmadd_ps(zmm_d_1, zmm_d_1, zmm_sum_1);
    }

    if (last >= last_aligned + 16) {
      __m512 zmm_d =
          _mm512_sub_ps(_mm512_loadu_ps(lhs), _mm512_loadu_ps(rhs));
      zmm_sum_0 = _mm512_fmadd_ps(zmm_d, zmm_d, zmm_sum_0);
      lhs += 16;
      rhs += 16;
    }
  }

  zmm_sum_0 = _mm512_add_ps(zmm_sum_0, zmm_sum_1);

  // Masked tail. Fewer than 16 floats remain here, so the shift is < 16 and
  // the mask selects exactly the remaining lanes. Masked-off lanes are not
  // loaded from memory, and mask3_fmadd leaves zmm_sum_0 unchanged in those
  // lanes, so the undefined register contents never reach the sum.
  if (lhs != last) {
    __mmask16 mask = (__mmask16)((1 << (last - lhs)) - 1);
    __m512 zmm_undefined = _mm512_undefined_ps();
    __m512 zmm_d = _mm512_mask_sub_ps(
        zmm_undefined, mask, _mm512_mask_loadu_ps(zmm_undefined, mask, lhs),
        _mm512_mask_loadu_ps(zmm_undefined, mask, rhs));
    zmm_sum_0 = _mm512_mask3_fmadd_ps(zmm_d, zmm_d, zmm_sum_0, mask);
  }
  return HorizontalAdd_FP32_V512(zmm_sum_0);
}

//! AVX-512 entry point: 16 or more floats use the AVX-512 kernel, 8..15 use
//! the AVX kernel, and fewer than 8 use the SSE kernel.
float SquaredEuclideanDistanceFp32AVX512(const float *lhs, const float *rhs,
                                         size_t size) {
  if (size > 15) {
    return SquaredEuclideanDistanceFp32AVX512Internal(lhs, rhs, size);
  }
  if (size > 7) {
    return SquaredEuclideanDistanceFp32AVXInternal(lhs, rhs, size);
  }
  return SquaredEuclideanDistanceFp32SSEInternal(lhs, rhs, size);
}
#endif

}  // namespace ailego
}  // namespace zvec
================================================
FILE: src/ailego/math/euclidean_distance_matrix_fp32_dispatch.cc
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// NOTE(review): the header name of the next #include is missing in this copy
// (it looks stripped) — restore it before building.
#include
#include "euclidean_distance_matrix.h"

namespace zvec {
namespace ailego {

#if defined(__ARM_NEON)
// NEON kernel writes its result through `out` instead of returning it.
void SquaredEuclideanDistanceFp32NEON(const float *lhs, const float *rhs,
                                      size_t size, float *out);
#endif

#if defined(__AVX512F__)
float SquaredEuclideanDistanceFp32AVX512(const float *lhs, const float *rhs,
                                         size_t size);
#endif

#if defined(__AVX__)
float SquaredEuclideanDistanceFp32AVX(const float *lhs, const float *rhs,
                                      size_t size);
#endif

#if defined(__SSE__)
float SquaredEuclideanDistanceFp32SSE(const float *lhs, const float *rhs,
                                      size_t size);
#endif

float SquaredEuclideanDistanceFp32Scalar(const float *lhs, const float *rhs,
                                         size_t size);

//-----------------------------------------------------------
//  SquaredEuclideanDistance
//-----------------------------------------------------------

//! Compute the distance between matrix and query (FP32, M=1, N=1)
//!
//! Squared Euclidean distance between two FP32 vectors of length `dim`,
//! written to `*out`. On ARM NEON builds the NEON kernel is always used.
//! Otherwise each SIMD kernel is compiled in only when its target macro is
//! defined and taken only when the matching runtime CpuFeatures flag is set;
//! candidates are tried in the order AVX512F, AVX, SSE, and the scalar kernel
//! is the final fallback.
void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m,
                                             const ValueType *q, size_t dim,
                                             float *out) {
#if defined(__ARM_NEON)
  SquaredEuclideanDistanceFp32NEON(m, q, dim, out);
#else
#if defined(__AVX512F__)
  if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) {
    *out = SquaredEuclideanDistanceFp32AVX512(m, q, dim);
    return;
  }
#endif  // __AVX512F__
#if defined(__AVX__)
  if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX) {
    *out = SquaredEuclideanDistanceFp32AVX(m, q, dim);
    return;
  }
#endif  // __AVX__
#if defined(__SSE__)
  if (zvec::ailego::internal::CpuFeatures::static_flags_.SSE) {
    *out = SquaredEuclideanDistanceFp32SSE(m, q, dim);
    return;
  }
#endif  // __SSE__
  // Reached when no SIMD kernel was compiled in or none is supported at
  // runtime.
  *out = SquaredEuclideanDistanceFp32Scalar(m, q, dim);
#endif  // __ARM_NEON
}

//-----------------------------------------------------------
//  EuclideanDistance
//-----------------------------------------------------------

//! Compute the distance between matrix and query (FP32, M=1, N=1)
//!
//! Euclidean distance: computes the squared distance into `*out`, then
//! replaces it in place with its square root.
void EuclideanDistanceMatrix::Compute(const ValueType *m, const ValueType *q,
                                      size_t dim, float *out) {
  SquaredEuclideanDistanceMatrix::Compute(m, q, dim, out);
  *out = std::sqrt(*out);
}

}  // namespace ailego
}  // namespace zvec
================================================
FILE: src/ailego/math/euclidean_distance_matrix_fp32_neon.cc
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "distance_matrix_accum_fp32.i"
#include "distance_matrix_euclidean_utility.i"
#include "euclidean_distance_matrix.h"

namespace zvec {
namespace ailego {

#if defined(__ARM_NEON)
//!
Squared Euclidean Distance void SquaredEuclideanDistanceFp32NEON(const float *lhs, const float *rhs, size_t size, float *out) { const float *last = lhs + size; const float *last_aligned = lhs + ((size >> 3) << 3); float32x4_t v_sum_0 = vdupq_n_f32(0); float32x4_t v_sum_1 = vdupq_n_f32(0); for (; lhs != last_aligned; lhs += 8, rhs += 8) { float32x4_t v_d_0 = vsubq_f32(vld1q_f32(lhs + 0), vld1q_f32(rhs + 0)); float32x4_t v_d_1 = vsubq_f32(vld1q_f32(lhs + 4), vld1q_f32(rhs + 4)); v_sum_0 = vfmaq_f32(v_sum_0, v_d_0, v_d_0); v_sum_1 = vfmaq_f32(v_sum_1, v_d_1, v_d_1); } if (last >= last_aligned + 4) { float32x4_t v_d = vsubq_f32(vld1q_f32(lhs), vld1q_f32(rhs)); v_sum_0 = vfmaq_f32(v_sum_0, v_d, v_d); lhs += 4; rhs += 4; } float result = vaddvq_f32(vaddq_f32(v_sum_0, v_sum_1)); switch (last - lhs) { case 3: SSD_FP32_GENERAL(lhs[2], rhs[2], result) /* FALLTHRU */ case 2: SSD_FP32_GENERAL(lhs[1], rhs[1], result) /* FALLTHRU */ case 1: SSD_FP32_GENERAL(lhs[0], rhs[0], result) } *out = result; } #endif // __ARM_NEON } // namespace ailego } // namespace zvec ================================================ FILE: src/ailego/math/euclidean_distance_matrix_fp32_sse.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "distance_matrix_accum_fp32.i"
#include "distance_matrix_euclidean_utility.i"
#include "euclidean_distance_matrix.h"

namespace zvec {
namespace ailego {

#if defined(__SSE__)
//! SSE kernel: returns sum_i (lhs[i] - rhs[i])^2 over `size` floats.
//!
//! Blocks of 8 floats feed two independent accumulators, one optional 4-float
//! block follows, and the last 0..3 elements are folded in with
//! SSD_FP32_GENERAL, highest index first. Aligned loads are used only when
//! both pointers are 16-byte aligned.
//!
//! NOTE(review): _mm_fmadd_ps is an FMA instruction while this file is guarded
//! only by __SSE__ — confirm FMA is available wherever this is dispatched, or
//! that an included .i header supplies a fallback.
float SquaredEuclideanDistanceFp32SSEInternal(const float *lhs,
                                              const float *rhs, size_t size) {
  const float *const end = lhs + size;
  const float *const block_end = lhs + ((size >> 3) << 3);
  const bool both_aligned = (((uintptr_t)lhs | (uintptr_t)rhs) & 0xf) == 0;

  __m128 acc_lo = _mm_setzero_ps();
  __m128 acc_hi = _mm_setzero_ps();

  if (both_aligned) {
    while (lhs != block_end) {
      const __m128 diff_lo = _mm_sub_ps(_mm_load_ps(lhs), _mm_load_ps(rhs));
      const __m128 diff_hi =
          _mm_sub_ps(_mm_load_ps(lhs + 4), _mm_load_ps(rhs + 4));
      acc_lo = _mm_fmadd_ps(diff_lo, diff_lo, acc_lo);
      acc_hi = _mm_fmadd_ps(diff_hi, diff_hi, acc_hi);
      lhs += 8;
      rhs += 8;
    }
    if (end - block_end >= 4) {
      const __m128 diff = _mm_sub_ps(_mm_load_ps(lhs), _mm_load_ps(rhs));
      acc_lo = _mm_fmadd_ps(diff, diff, acc_lo);
      lhs += 4;
      rhs += 4;
    }
  } else {
    while (lhs != block_end) {
      const __m128 diff_lo = _mm_sub_ps(_mm_loadu_ps(lhs), _mm_loadu_ps(rhs));
      const __m128 diff_hi =
          _mm_sub_ps(_mm_loadu_ps(lhs + 4), _mm_loadu_ps(rhs + 4));
      acc_lo = _mm_fmadd_ps(diff_lo, diff_lo, acc_lo);
      acc_hi = _mm_fmadd_ps(diff_hi, diff_hi, acc_hi);
      lhs += 8;
      rhs += 8;
    }
    if (end - block_end >= 4) {
      const __m128 diff = _mm_sub_ps(_mm_loadu_ps(lhs), _mm_loadu_ps(rhs));
      acc_lo = _mm_fmadd_ps(diff, diff, acc_lo);
      lhs += 4;
      rhs += 4;
    }
  }

  float total = HorizontalAdd_FP32_V128(_mm_add_ps(acc_lo, acc_hi));

  // Scalar tail (0..3 elements), walked from the highest index downwards.
  for (auto remaining = end - lhs; remaining > 0; --remaining) {
    SSD_FP32_GENERAL(lhs[remaining - 1], rhs[remaining - 1], total)
  }
  return total;
}

//! SSE entry point: forwards directly to the SSE kernel for any length.
float SquaredEuclideanDistanceFp32SSE(const float *lhs, const float *rhs,
                                      size_t size) {
  return SquaredEuclideanDistanceFp32SSEInternal(lhs, rhs, size);
}
#endif  // __SSE__

}  // namespace ailego
}  // namespace zvec
================================================
FILE: src/ailego/math/euclidean_distance_matrix_int4_avx2.cc
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "distance_matrix_accum_int4.i"
#include "distance_matrix_euclidean_utility.i"
#include "euclidean_distance_matrix.h"

namespace zvec {
namespace ailego {

#if defined(__AVX2__)
// Defined in euclidean_distance_matrix_int4_sse.cc; takes a BYTE count.
float SquaredEuclideanDistanceInt4SSEInternal(const uint8_t *lhs,
                                              const uint8_t *rhs, size_t size);

//! AVX2 kernel over `size` BYTES of packed int4 data (the entry point below
//! passes the element count halved).
//!
//! It accumulates 32 bytes per iteration with SSD_INT4_ITER_AVX, then at most
//! one 16-byte step with SSD_INT4_ITER_SSE. The final 0..15 bytes are handled
//! with SSD_INT4_GENERAL. Aligned loads are used when both pointers are
//! 32-byte aligned.
inline float SquaredEuclideanDistanceInt4AVX2Internal(const uint8_t *lhs,
                                                      const uint8_t *rhs,
                                                      size_t size) {
  const uint8_t *last = lhs + size;
  // End of the prefix whose byte length is a multiple of 32.
  const uint8_t *last_aligned = lhs + ((size >> 5) << 5);

  __m256i ymm_sum = _mm256_setzero_si256();

  if (((uintptr_t)lhs & 0x1f) == 0 && ((uintptr_t)rhs & 0x1f) == 0) {
    for (; lhs != last_aligned; lhs += 32, rhs += 32) {
      __m256i ymm_lhs = _mm256_load_si256((const __m256i *)(lhs));
      __m256i ymm_rhs = _mm256_load_si256((const __m256i *)(rhs));
      SSD_INT4_ITER_AVX(ymm_lhs, ymm_rhs, ymm_sum)
    }

    // One 16-byte step if at least 16 bytes remain. Its 128-bit partial sum
    // is placed in the low half of a 256-bit vector (high half zero) and
    // added into ymm_sum.
    if (last >= lhs + 16) {
      __m128i xmm_lhs = _mm_load_si128((const __m128i *)lhs);
      __m128i xmm_rhs = _mm_load_si128((const __m128i *)rhs);
      __m128i xmm_sum = _mm_setzero_si128();
      SSD_INT4_ITER_SSE(xmm_lhs, xmm_rhs, xmm_sum)
      ymm_sum = _mm256_add_epi32(_mm256_set_m128i(_mm_setzero_si128(), xmm_sum),
                                 ymm_sum);
      lhs += 16;
      rhs += 16;
    }
  } else {
    // Same as the aligned branch, using unaligned loads.
    for (; lhs != last_aligned; lhs += 32, rhs += 32) {
      __m256i ymm_lhs = _mm256_loadu_si256((const __m256i *)(lhs));
      __m256i ymm_rhs = _mm256_loadu_si256((const __m256i *)(rhs));
      SSD_INT4_ITER_AVX(ymm_lhs, ymm_rhs, ymm_sum)
    }

    if (last >= lhs + 16) {
      __m128i xmm_lhs = _mm_loadu_si128((const __m128i *)lhs);
      __m128i xmm_rhs = _mm_loadu_si128((const __m128i *)rhs);
      __m128i xmm_sum = _mm_setzero_si128();
      SSD_INT4_ITER_SSE(xmm_lhs, xmm_rhs, xmm_sum)
      ymm_sum = _mm256_add_epi32(_mm256_set_m128i(_mm_setzero_si128(), xmm_sum),
                                 ymm_sum);
      lhs += 16;
      rhs += 16;
    }
  }
  // NOTE(review): `static_cast` has no template argument in this copy
  // (presumably `<float>`, which looks stripped) — restore it.
  float result = static_cast(HorizontalAdd_INT32_V256(ymm_sum));

  // Scalar tail: 0..15 bytes remain. Cases fall through on purpose.
  switch (last - lhs) {
    case 15:
      SSD_INT4_GENERAL(lhs[14], rhs[14], result)
      /* FALLTHRU */
    case 14:
      SSD_INT4_GENERAL(lhs[13], rhs[13], result)
      /* FALLTHRU */
    case 13:
      SSD_INT4_GENERAL(lhs[12], rhs[12], result)
      /* FALLTHRU */
    case 12:
      SSD_INT4_GENERAL(lhs[11], rhs[11], result)
      /* FALLTHRU */
    case 11:
      SSD_INT4_GENERAL(lhs[10], rhs[10], result)
      /* FALLTHRU */
    case 10:
      SSD_INT4_GENERAL(lhs[9], rhs[9], result)
      /* FALLTHRU */
    case 9:
      SSD_INT4_GENERAL(lhs[8], rhs[8], result)
      /* FALLTHRU */
    case 8:
      SSD_INT4_GENERAL(lhs[7], rhs[7], result)
      /* FALLTHRU */
    case 7:
      SSD_INT4_GENERAL(lhs[6], rhs[6], result)
      /* FALLTHRU */
    case 6:
      SSD_INT4_GENERAL(lhs[5], rhs[5], result)
      /* FALLTHRU */
    case 5:
      SSD_INT4_GENERAL(lhs[4], rhs[4], result)
      /* FALLTHRU */
    case 4:
      SSD_INT4_GENERAL(lhs[3], rhs[3], result)
      /* FALLTHRU */
    case 3:
      SSD_INT4_GENERAL(lhs[2], rhs[2], result)
      /* FALLTHRU */
    case 2:
      SSD_INT4_GENERAL(lhs[1], rhs[1], result)
      /* FALLTHRU */
    case 1:
      SSD_INT4_GENERAL(lhs[0], rhs[0], result)
  }
  return result;
}

//! AVX2 entry point. `size` is halved before it reaches the kernels, which
//! iterate over bytes, so `size` is presumably the number of 4-bit values
//! (two per byte). An odd `size` would drop the last value — confirm that
//! callers always pass even sizes. `size > 63` guarantees at least 32 bytes,
//! i.e. one full AVX2 iteration; shorter inputs use the SSE kernel.
float SquaredEuclideanDistanceInt4AVX2(const uint8_t *lhs, const uint8_t *rhs,
                                       size_t size) {
  if (size > 63) {
    return SquaredEuclideanDistanceInt4AVX2Internal(lhs, rhs, size >> 1);
  }
  return SquaredEuclideanDistanceInt4SSEInternal(lhs, rhs, size >> 1);
}
#endif  // __AVX2__

}  // namespace ailego
}  // namespace zvec
================================================
FILE: src/ailego/math/euclidean_distance_matrix_int4_dispatch.cc
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// NOTE(review): the header name of the next #include is missing in this copy
// (it looks stripped) — restore it before building.
#include
#include "euclidean_distance_matrix.h"

namespace zvec {
namespace ailego {

#if defined(__AVX2__)
float SquaredEuclideanDistanceInt4AVX2(const uint8_t *lhs, const uint8_t *rhs,
                                       size_t size);
#endif

#if defined(__SSE4_1__)
float SquaredEuclideanDistanceInt4SSE(const uint8_t *lhs, const uint8_t *rhs,
                                      size_t size);
#endif

float SquaredEuclideanDistanceInt4Scalar(const uint8_t *lhs,
                                         const uint8_t *rhs, size_t size);

//! Compute the distance between matrix and query (INT4, M=1, N=1)
//!
//! Squared Euclidean distance for packed int4 vectors; `dim` is forwarded
//! unchanged to the selected kernel and the result is written to `*out`.
//! There is no NEON path here. AVX2 is tried first, then SSE4.1, each only
//! when compiled in and reported by the runtime CpuFeatures flags; the scalar
//! kernel is the fallback.
void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m,
                                             const ValueType *q, size_t dim,
                                             float *out) {
#if defined(__AVX2__)
  if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX2) {
    *out = SquaredEuclideanDistanceInt4AVX2(m, q, dim);
    return;
  }
#endif  // __AVX2__
#if defined(__SSE4_1__)
  if (zvec::ailego::internal::CpuFeatures::static_flags_.SSE4_1) {
    *out = SquaredEuclideanDistanceInt4SSE(m, q, dim);
    return;
  }
#endif
  *out = SquaredEuclideanDistanceInt4Scalar(m, q, dim);
}

//! Compute the distance between matrix and query (INT4, M=1, N=1)
//!
//! Euclidean distance: computes the squared distance into `*out`, then
//! replaces it in place with its square root.
void EuclideanDistanceMatrix::Compute(const ValueType *m, const ValueType *q,
                                      size_t dim, float *out) {
  SquaredEuclideanDistanceMatrix::Compute(m, q, dim, out);
  *out = std::sqrt(*out);
}

}  // namespace ailego
}  // namespace zvec
================================================
FILE: src/ailego/math/euclidean_distance_matrix_int4_sse.cc
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "distance_matrix_accum_int4.i"
#include "distance_matrix_euclidean_utility.i"
#include "euclidean_distance_matrix.h"

namespace zvec {
namespace ailego {

#if defined(__SSE4_1__)
//! SSE4.1 kernel over `size` BYTES of packed int4 data.
//!
//! It accumulates 16 bytes per iteration with SSD_INT4_ITER_SSE. The final
//! 0..15 bytes are handled with SSD_INT4_GENERAL. Aligned loads are used when
//! both pointers are 16-byte aligned.
float SquaredEuclideanDistanceInt4SSEInternal(const uint8_t *lhs,
                                              const uint8_t *rhs,
                                              size_t size) {
  const uint8_t *last = lhs + size;
  // End of the prefix whose byte length is a multiple of 16.
  const uint8_t *last_aligned = lhs + ((size >> 4) << 4);
  __m128i xmm_sum = _mm_setzero_si128();

  if (((uintptr_t)lhs & 0xf) == 0 && ((uintptr_t)rhs & 0xf) == 0) {
    for (; lhs != last_aligned; lhs += 16, rhs += 16) {
      __m128i xmm_lhs = _mm_load_si128((const __m128i *)(lhs));
      __m128i xmm_rhs = _mm_load_si128((const __m128i *)(rhs));
      SSD_INT4_ITER_SSE(xmm_lhs, xmm_rhs, xmm_sum)
    }
  } else {
    for (; lhs != last_aligned; lhs += 16, rhs += 16) {
      __m128i xmm_lhs = _mm_loadu_si128((const __m128i *)(lhs));
      __m128i xmm_rhs = _mm_loadu_si128((const __m128i *)(rhs));
      SSD_INT4_ITER_SSE(xmm_lhs, xmm_rhs, xmm_sum)
    }
  }
  // NOTE(review): `static_cast` has no template argument in this copy
  // (presumably `<float>`, which looks stripped) — restore it.
  float result = static_cast(HorizontalAdd_INT32_V128(xmm_sum));

  // Scalar tail: 0..15 bytes remain. Cases fall through on purpose.
  switch (last - lhs) {
    case 15:
      SSD_INT4_GENERAL(lhs[14], rhs[14], result)
      /* FALLTHRU */
    case 14:
      SSD_INT4_GENERAL(lhs[13], rhs[13], result)
      /* FALLTHRU */
    case 13:
      SSD_INT4_GENERAL(lhs[12], rhs[12], result)
      /* FALLTHRU */
    case 12:
      SSD_INT4_GENERAL(lhs[11], rhs[11], result)
      /* FALLTHRU */
    case 11:
      SSD_INT4_GENERAL(lhs[10], rhs[10], result)
      /* FALLTHRU */
    case 10:
      SSD_INT4_GENERAL(lhs[9], rhs[9], result)
      /* FALLTHRU */
    case 9:
      SSD_INT4_GENERAL(lhs[8], rhs[8], result)
      /* FALLTHRU */
    case 8:
      SSD_INT4_GENERAL(lhs[7], rhs[7], result)
      /* FALLTHRU */
    case 7:
      SSD_INT4_GENERAL(lhs[6], rhs[6], result)
      /* FALLTHRU */
    case 6:
      SSD_INT4_GENERAL(lhs[5], rhs[5], result)
      /* FALLTHRU */
    case 5:
      SSD_INT4_GENERAL(lhs[4], rhs[4], result)
      /* FALLTHRU */
    case 4:
      SSD_INT4_GENERAL(lhs[3], rhs[3], result)
      /* FALLTHRU */
    case 3:
      SSD_INT4_GENERAL(lhs[2], rhs[2], result)
      /* FALLTHRU */
    case 2:
      SSD_INT4_GENERAL(lhs[1], rhs[1], result)
      /* FALLTHRU */
    case 1:
      SSD_INT4_GENERAL(lhs[0], rhs[0], result)
  }
  return result;
}

//! SSE4.1 entry point. Halves `size` before calling the byte-oriented kernel
//! (presumably `size` counts 4-bit values packed two per byte).
float SquaredEuclideanDistanceInt4SSE(const uint8_t *lhs, const uint8_t *rhs,
                                      size_t size) {
  return SquaredEuclideanDistanceInt4SSEInternal(lhs, rhs, size >> 1);
}
#endif  // __SSE4_1__

}  // namespace ailego
}  // namespace zvec
================================================
FILE: src/ailego/math/euclidean_distance_matrix_int8_avx2.cc
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "distance_matrix_accum_int8.i"
#include "distance_matrix_euclidean_utility.i"
#include "euclidean_distance_matrix.h"

namespace zvec {
namespace ailego {

#if defined(__AVX2__)
// Used below for inputs shorter than 32 bytes.
// NOTE(review): in euclidean_distance_matrix_int8_sse.cc this function is
// defined `inline`. An inline function that is odr-used here without a
// definition in this translation unit is ill-formed, and it can fail to link
// if the compiler emits no out-of-line copy there. Confirm, and consider
// dropping `inline` at the definition.
float SquaredEuclideanDistanceInt8SSEInternal(const int8_t *lhs,
                                              const int8_t *rhs, size_t size);

//! AVX2 kernel: squared Euclidean distance over `size` int8 elements.
//!
//! For each 32-byte chunk, |lhs - rhs| is computed as max - min of the signed
//! bytes; that difference is in 0..255 and is read as unsigned by
//! _mm256_cvtepu8_epi16. The widened 16-bit values are squared and pair-summed
//! into 32-bit lanes with _mm256_madd_epi16.
//!
//! The main loop handles 64 bytes per iteration. It is followed by at most one
//! 32-byte step, then at most one 16-byte SSE step (always using unaligned
//! loads), and finally the last 0..15 elements via SSD_INT8_GENERAL.
float SquaredEuclideanDistanceInt8AVX2Internal(const int8_t *lhs,
                                               const int8_t *rhs,
                                               size_t size) {
  const int8_t *last = lhs + size;
  // End of the prefix whose length is a multiple of 64.
  const int8_t *last_aligned = lhs + ((size >> 6) << 6);
  // Overwritten by the horizontal sum below before it is read.
  float result = 0.0;

  __m256i ymm_sum_0 = _mm256_setzero_si256();
  __m256i ymm_sum_1 = _mm256_setzero_si256();

  if (((uintptr_t)lhs & 0x1f) == 0 && ((uintptr_t)rhs & 0x1f) == 0) {
    for (; lhs != last_aligned; lhs += 64, rhs += 64) {
      __m256i ymm_lhs_0 = _mm256_load_si256((const __m256i *)(lhs + 0));
      __m256i ymm_lhs_1 = _mm256_load_si256((const __m256i *)(lhs + 32));
      __m256i ymm_rhs_0 = _mm256_load_si256((const __m256i *)(rhs + 0));
      __m256i ymm_rhs_1 = _mm256_load_si256((const __m256i *)(rhs + 32));

      __m256i ymm_d = _mm256_sub_epi8(_mm256_max_epi8(ymm_lhs_0, ymm_rhs_0),
                                      _mm256_min_epi8(ymm_lhs_0, ymm_rhs_0));
      // Low / high 16 bytes of |d| widened to 16-bit lanes (the input
      // registers are reused as scratch).
      ymm_lhs_0 = _mm256_cvtepu8_epi16(_mm256_castsi256_si128(ymm_d));
      ymm_rhs_0 = _mm256_cvtepu8_epi16(_mm256_extractf128_si256(ymm_d, 1));
      ymm_sum_0 =
          _mm256_add_epi32(_mm256_madd_epi16(ymm_lhs_0, ymm_lhs_0), ymm_sum_0);
      ymm_sum_1 =
          _mm256_add_epi32(_mm256_madd_epi16(ymm_rhs_0, ymm_rhs_0), ymm_sum_1);

      ymm_d = _mm256_sub_epi8(_mm256_max_epi8(ymm_lhs_1, ymm_rhs_1),
                              _mm256_min_epi8(ymm_lhs_1, ymm_rhs_1));
      ymm_lhs_1 = _mm256_cvtepu8_epi16(_mm256_castsi256_si128(ymm_d));
      ymm_rhs_1 = _mm256_cvtepu8_epi16(_mm256_extractf128_si256(ymm_d, 1));
      ymm_sum_0 =
          _mm256_add_epi32(_mm256_madd_epi16(ymm_lhs_1, ymm_lhs_1), ymm_sum_0);
      ymm_sum_1 =
          _mm256_add_epi32(_mm256_madd_epi16(ymm_rhs_1, ymm_rhs_1), ymm_sum_1);
    }

    // One more 32-byte chunk if at least 32 elements remain.
    if (last >= last_aligned + 32) {
      __m256i ymm_lhs = _mm256_load_si256((const __m256i *)lhs);
      __m256i ymm_rhs = _mm256_load_si256((const __m256i *)rhs);
      __m256i ymm_d = _mm256_sub_epi8(_mm256_max_epi8(ymm_lhs, ymm_rhs),
                                      _mm256_min_epi8(ymm_lhs, ymm_rhs));
      ymm_lhs = _mm256_cvtepu8_epi16(_mm256_castsi256_si128(ymm_d));
      ymm_rhs = _mm256_cvtepu8_epi16(_mm256_extractf128_si256(ymm_d, 1));
      ymm_sum_0 =
          _mm256_add_epi32(_mm256_madd_epi16(ymm_lhs, ymm_lhs), ymm_sum_0);
      ymm_sum_1 =
          _mm256_add_epi32(_mm256_madd_epi16(ymm_rhs, ymm_rhs), ymm_sum_1);
      lhs += 32;
      rhs += 32;
    }
  } else {
    // Same as the aligned branch, using unaligned loads.
    for (; lhs != last_aligned; lhs += 64, rhs += 64) {
      __m256i ymm_lhs_0 = _mm256_loadu_si256((const __m256i *)(lhs + 0));
      __m256i ymm_lhs_1 = _mm256_loadu_si256((const __m256i *)(lhs + 32));
      __m256i ymm_rhs_0 = _mm256_loadu_si256((const __m256i *)(rhs + 0));
      __m256i ymm_rhs_1 = _mm256_loadu_si256((const __m256i *)(rhs + 32));

      __m256i ymm_d = _mm256_sub_epi8(_mm256_max_epi8(ymm_lhs_0, ymm_rhs_0),
                                      _mm256_min_epi8(ymm_lhs_0, ymm_rhs_0));
      ymm_lhs_0 = _mm256_cvtepu8_epi16(_mm256_castsi256_si128(ymm_d));
      ymm_rhs_0 = _mm256_cvtepu8_epi16(_mm256_extractf128_si256(ymm_d, 1));
      ymm_sum_0 =
          _mm256_add_epi32(_mm256_madd_epi16(ymm_lhs_0, ymm_lhs_0), ymm_sum_0);
      ymm_sum_1 =
          _mm256_add_epi32(_mm256_madd_epi16(ymm_rhs_0, ymm_rhs_0), ymm_sum_1);

      ymm_d = _mm256_sub_epi8(_mm256_max_epi8(ymm_lhs_1, ymm_rhs_1),
                              _mm256_min_epi8(ymm_lhs_1, ymm_rhs_1));
      ymm_lhs_1 = _mm256_cvtepu8_epi16(_mm256_castsi256_si128(ymm_d));
      ymm_rhs_1 = _mm256_cvtepu8_epi16(_mm256_extractf128_si256(ymm_d, 1));
      ymm_sum_0 =
          _mm256_add_epi32(_mm256_madd_epi16(ymm_lhs_1, ymm_lhs_1), ymm_sum_0);
      ymm_sum_1 =
          _mm256_add_epi32(_mm256_madd_epi16(ymm_rhs_1, ymm_rhs_1), ymm_sum_1);
    }

    if (last >= last_aligned + 32) {
      __m256i ymm_lhs = _mm256_loadu_si256((const __m256i *)lhs);
      __m256i ymm_rhs = _mm256_loadu_si256((const __m256i *)rhs);
      __m256i ymm_d = _mm256_sub_epi8(_mm256_max_epi8(ymm_lhs, ymm_rhs),
                                      _mm256_min_epi8(ymm_lhs, ymm_rhs));
      ymm_lhs = _mm256_cvtepu8_epi16(_mm256_castsi256_si128(ymm_d));
      ymm_rhs = _mm256_cvtepu8_epi16(_mm256_extractf128_si256(ymm_d, 1));
      ymm_sum_0 =
          _mm256_add_epi32(_mm256_madd_epi16(ymm_lhs, ymm_lhs), ymm_sum_0);
      ymm_sum_1 =
          _mm256_add_epi32(_mm256_madd_epi16(ymm_rhs, ymm_rhs), ymm_sum_1);
      lhs += 32;
      rhs += 32;
    }
  }
  // NOTE(review): both `static_cast`s below have no template argument in this
  // copy (presumably `<float>`, which looks stripped) — restore them.
  result = static_cast(
      HorizontalAdd_INT32_V256(_mm256_add_epi32(ymm_sum_0, ymm_sum_1)));

  // One 16-byte SSE step if at least 16 elements remain; same |d| technique.
  if (last >= lhs + 16) {
    __m128i xmm_lhs = _mm_loadu_si128((const __m128i *)lhs);
    __m128i xmm_rhs = _mm_loadu_si128((const __m128i *)rhs);
    __m128i xmm_sum = _mm_sub_epi8(_mm_max_epi8(xmm_lhs, xmm_rhs),
                                   _mm_min_epi8(xmm_lhs, xmm_rhs));
    xmm_lhs = _mm_cvtepu8_epi16(xmm_sum);
    xmm_rhs = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(xmm_sum, xmm_sum));
    xmm_sum = _mm_add_epi32(_mm_madd_epi16(xmm_lhs, xmm_lhs),
                            _mm_madd_epi16(xmm_rhs, xmm_rhs));
    result += static_cast(HorizontalAdd_INT32_V128(xmm_sum));
    lhs += 16;
    rhs += 16;
  }

  // Scalar tail: 0..15 elements remain. Cases fall through on purpose.
  switch (last - lhs) {
    case 15:
      SSD_INT8_GENERAL(lhs[14], rhs[14], result)
      /* FALLTHRU */
    case 14:
      SSD_INT8_GENERAL(lhs[13], rhs[13], result)
      /* FALLTHRU */
    case 13:
      SSD_INT8_GENERAL(lhs[12], rhs[12], result)
      /* FALLTHRU */
    case 12:
      SSD_INT8_GENERAL(lhs[11], rhs[11], result)
      /* FALLTHRU */
    case 11:
      SSD_INT8_GENERAL(lhs[10], rhs[10], result)
      /* FALLTHRU */
    case 10:
      SSD_INT8_GENERAL(lhs[9], rhs[9], result)
      /* FALLTHRU */
    case 9:
      SSD_INT8_GENERAL(lhs[8], rhs[8], result)
      /* FALLTHRU */
    case 8:
      SSD_INT8_GENERAL(lhs[7], rhs[7], result)
      /* FALLTHRU */
    case 7:
      SSD_INT8_GENERAL(lhs[6], rhs[6], result)
      /* FALLTHRU */
    case 6:
      SSD_INT8_GENERAL(lhs[5], rhs[5], result)
      /* FALLTHRU */
    case 5:
      SSD_INT8_GENERAL(lhs[4], rhs[4], result)
      /* FALLTHRU */
    case 4:
      SSD_INT8_GENERAL(lhs[3], rhs[3], result)
      /* FALLTHRU */
    case 3:
      SSD_INT8_GENERAL(lhs[2], rhs[2], result)
      /* FALLTHRU */
    case 2:
      SSD_INT8_GENERAL(lhs[1], rhs[1], result)
      /* FALLTHRU */
    case 1:
      SSD_INT8_GENERAL(lhs[0], rhs[0], result)
  }
  return result;
}

//! AVX2 entry point: 32 or more elements use the AVX2 kernel; shorter inputs
//! use the SSE kernel.
float SquaredEuclideanDistanceInt8AVX2(const int8_t *lhs, const int8_t *rhs,
                                       size_t size) {
  if (size > 31) {
    return SquaredEuclideanDistanceInt8AVX2Internal(lhs, rhs, size);
  }
  return SquaredEuclideanDistanceInt8SSEInternal(lhs, rhs, size);
}
#endif  // __AVX2__

}  // namespace ailego
}  // namespace zvec
================================================
FILE: src/ailego/math/euclidean_distance_matrix_int8_dispatch.cc
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// NOTE(review): the header name of the next #include is missing in this copy
// (it looks stripped) — restore it before building.
#include
#include "euclidean_distance_matrix.h"

namespace zvec {
namespace ailego {

#if defined(__AVX2__)
float SquaredEuclideanDistanceInt8AVX2(const int8_t *lhs, const int8_t *rhs,
                                       size_t size);
#endif

#if defined(__SSE4_1__)
float SquaredEuclideanDistanceInt8SSE(const int8_t *lhs, const int8_t *rhs,
                                      size_t size);
#endif

float SquaredEuclideanDistanceInt8Scalar(const int8_t *lhs, const int8_t *rhs,
                                         size_t size);

//! Compute the distance between matrix and query (INT8, M=1, N=1)
//!
//! Squared Euclidean distance between two int8 vectors of length `dim`,
//! written to `*out`. There is no NEON path here. AVX2 is tried first, then
//! SSE4.1, each only when compiled in and reported by the runtime CpuFeatures
//! flags; the scalar kernel is the fallback.
void SquaredEuclideanDistanceMatrix::Compute(const ValueType *m,
                                             const ValueType *q, size_t dim,
                                             float *out) {
#if defined(__AVX2__)
  if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX2) {
    *out = SquaredEuclideanDistanceInt8AVX2(m, q, dim);
    return;
  }
#endif  // __AVX2__
#if defined(__SSE4_1__)
  if (zvec::ailego::internal::CpuFeatures::static_flags_.SSE4_1) {
    *out = SquaredEuclideanDistanceInt8SSE(m, q, dim);
    return;
  }
#endif
  *out = SquaredEuclideanDistanceInt8Scalar(m, q, dim);
}

//! Compute the distance between matrix and query (INT8, M=1, N=1)
//!
//! Euclidean distance: computes the squared distance into `*out`, then
//! replaces it in place with its square root.
void EuclideanDistanceMatrix::Compute(const ValueType *m, const ValueType *q,
                                      size_t dim, float *out) {
  SquaredEuclideanDistanceMatrix::Compute(m, q, dim, out);
  *out = std::sqrt(*out);
}

}  // namespace ailego
}  // namespace zvec
================================================
FILE: src/ailego/math/euclidean_distance_matrix_int8_sse.cc
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // limitations under the License. #include "distance_matrix_accum_int8.i" #include "distance_matrix_euclidean_utility.i" #include "euclidean_distance_matrix.h" namespace zvec { namespace ailego { #if defined(__SSE4_1__) inline float SquaredEuclideanDistanceInt8SSEInternal(const int8_t *lhs, const int8_t *rhs, size_t size) { const int8_t *last = lhs + size; const int8_t *last_aligned = lhs + ((size >> 5) << 5); __m128i xmm_sum_0 = _mm_setzero_si128(); __m128i xmm_sum_1 = _mm_setzero_si128(); if (((uintptr_t)lhs & 0xf) == 0 && ((uintptr_t)rhs & 0xf) == 0) { for (; lhs != last_aligned; lhs += 32, rhs += 32) { __m128i xmm_lhs_0 = _mm_load_si128((const __m128i *)(lhs + 0)); __m128i xmm_lhs_1 = _mm_load_si128((const __m128i *)(lhs + 16)); __m128i xmm_rhs_0 = _mm_load_si128((const __m128i *)(rhs + 0)); __m128i xmm_rhs_1 = _mm_load_si128((const __m128i *)(rhs + 16)); __m128i xmm_d = _mm_sub_epi8(_mm_max_epi8(xmm_lhs_0, xmm_rhs_0), _mm_min_epi8(xmm_lhs_0, xmm_rhs_0)); xmm_lhs_0 = _mm_cvtepu8_epi16(xmm_d); xmm_rhs_0 = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(xmm_d, xmm_d)); xmm_d = _mm_sub_epi8(_mm_max_epi8(xmm_lhs_1, xmm_rhs_1), _mm_min_epi8(xmm_lhs_1, xmm_rhs_1)); xmm_lhs_1 = _mm_cvtepu8_epi16(xmm_d); xmm_rhs_1 = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(xmm_d, xmm_d)); xmm_sum_0 = _mm_add_epi32(_mm_madd_epi16(xmm_lhs_0, xmm_lhs_0), xmm_sum_0); xmm_sum_1 = _mm_add_epi32(_mm_madd_epi16(xmm_rhs_0, xmm_rhs_0), xmm_sum_1); xmm_sum_0 = _mm_add_epi32(_mm_madd_epi16(xmm_lhs_1, xmm_lhs_1), xmm_sum_0); xmm_sum_1 = _mm_add_epi32(_mm_madd_epi16(xmm_rhs_1, xmm_rhs_1), xmm_sum_1); } if (last >= last_aligned + 16) { __m128i xmm_lhs = _mm_load_si128((const __m128i *)lhs); __m128i xmm_rhs = _mm_load_si128((const __m128i *)rhs); __m128i xmm_d = _mm_sub_epi8(_mm_max_epi8(xmm_lhs, xmm_rhs), _mm_min_epi8(xmm_lhs, xmm_rhs)); xmm_lhs = _mm_cvtepu8_epi16(xmm_d); xmm_rhs = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(xmm_d, xmm_d)); xmm_sum_0 
= _mm_add_epi32(_mm_madd_epi16(xmm_lhs, xmm_lhs), xmm_sum_0); xmm_sum_1 = _mm_add_epi32(_mm_madd_epi16(xmm_rhs, xmm_rhs), xmm_sum_1); lhs += 16; rhs += 16; } } else { for (; lhs != last_aligned; lhs += 32, rhs += 32) { __m128i xmm_lhs_0 = _mm_loadu_si128((const __m128i *)(lhs + 0)); __m128i xmm_lhs_1 = _mm_loadu_si128((const __m128i *)(lhs + 16)); __m128i xmm_rhs_0 = _mm_loadu_si128((const __m128i *)(rhs + 0)); __m128i xmm_rhs_1 = _mm_loadu_si128((const __m128i *)(rhs + 16)); __m128i xmm_d = _mm_sub_epi8(_mm_max_epi8(xmm_lhs_0, xmm_rhs_0), _mm_min_epi8(xmm_lhs_0, xmm_rhs_0)); xmm_lhs_0 = _mm_cvtepu8_epi16(xmm_d); xmm_rhs_0 = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(xmm_d, xmm_d)); xmm_d = _mm_sub_epi8(_mm_max_epi8(xmm_lhs_1, xmm_rhs_1), _mm_min_epi8(xmm_lhs_1, xmm_rhs_1)); xmm_lhs_1 = _mm_cvtepu8_epi16(xmm_d); xmm_rhs_1 = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(xmm_d, xmm_d)); xmm_sum_0 = _mm_add_epi32(_mm_madd_epi16(xmm_lhs_0, xmm_lhs_0), xmm_sum_0); xmm_sum_1 = _mm_add_epi32(_mm_madd_epi16(xmm_rhs_0, xmm_rhs_0), xmm_sum_1); xmm_sum_0 = _mm_add_epi32(_mm_madd_epi16(xmm_lhs_1, xmm_lhs_1), xmm_sum_0); xmm_sum_1 = _mm_add_epi32(_mm_madd_epi16(xmm_rhs_1, xmm_rhs_1), xmm_sum_1); } if (last >= last_aligned + 16) { __m128i xmm_lhs = _mm_loadu_si128((const __m128i *)lhs); __m128i xmm_rhs = _mm_loadu_si128((const __m128i *)rhs); __m128i xmm_d = _mm_sub_epi8(_mm_max_epi8(xmm_lhs, xmm_rhs), _mm_min_epi8(xmm_lhs, xmm_rhs)); xmm_lhs = _mm_cvtepu8_epi16(xmm_d); xmm_rhs = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(xmm_d, xmm_d)); xmm_sum_0 = _mm_add_epi32(_mm_madd_epi16(xmm_lhs, xmm_lhs), xmm_sum_0); xmm_sum_1 = _mm_add_epi32(_mm_madd_epi16(xmm_rhs, xmm_rhs), xmm_sum_1); lhs += 16; rhs += 16; } } float result = static_cast( HorizontalAdd_INT32_V128(_mm_add_epi32(xmm_sum_0, xmm_sum_1))); switch (last - lhs) { case 15: SSD_INT8_GENERAL(lhs[14], rhs[14], result) /* FALLTHRU */ case 14: SSD_INT8_GENERAL(lhs[13], rhs[13], result) /* FALLTHRU */ case 13: SSD_INT8_GENERAL(lhs[12], rhs[12], result) 
/* FALLTHRU */ case 12: SSD_INT8_GENERAL(lhs[11], rhs[11], result) /* FALLTHRU */ case 11: SSD_INT8_GENERAL(lhs[10], rhs[10], result) /* FALLTHRU */ case 10: SSD_INT8_GENERAL(lhs[9], rhs[9], result) /* FALLTHRU */ case 9: SSD_INT8_GENERAL(lhs[8], rhs[8], result) /* FALLTHRU */ case 8: SSD_INT8_GENERAL(lhs[7], rhs[7], result) /* FALLTHRU */ case 7: SSD_INT8_GENERAL(lhs[6], rhs[6], result) /* FALLTHRU */ case 6: SSD_INT8_GENERAL(lhs[5], rhs[5], result) /* FALLTHRU */ case 5: SSD_INT8_GENERAL(lhs[4], rhs[4], result) /* FALLTHRU */ case 4: SSD_INT8_GENERAL(lhs[3], rhs[3], result) /* FALLTHRU */ case 3: SSD_INT8_GENERAL(lhs[2], rhs[2], result) /* FALLTHRU */ case 2: SSD_INT8_GENERAL(lhs[1], rhs[1], result) /* FALLTHRU */ case 1: SSD_INT8_GENERAL(lhs[0], rhs[0], result) } return result; } //! Squared Euclidean Distance float SquaredEuclideanDistanceInt8SSE(const int8_t *lhs, const int8_t *rhs, size_t size) { return SquaredEuclideanDistanceInt8SSEInternal(lhs, rhs, size); } #endif // __SSE4_1__ } // namespace ailego } // namespace zvec ================================================ FILE: src/ailego/math/euclidean_distance_matrix_scalar.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include #include #include #include "distance_utility.h" namespace zvec { namespace ailego { //-------------------------------------------------- // Dense //-------------------------------------------------- template inline float SquaredEuclideanDistanceScalar(const T *m, const T *q, size_t dim) { ailego_assert(m && q && dim); float sum = 0.0; for (size_t i = 0; i < dim; ++i) { sum += MathHelper::SquaredDifference(m[i], q[i]); } return sum; } template inline float EuclideanDistanceScalar(const T *m, const T *q, size_t dim) { ailego_assert(m && q && dim); float sum = 0.0; for (size_t i = 0; i < dim; ++i) { sum += MathHelper::SquaredDifference(m[i], q[i]); } return std::sqrt(sum); } float SquaredEuclideanDistanceInt4Scalar(const uint8_t *m, const uint8_t *q, size_t dim) { ailego_assert(m && q && dim && !(dim & 1)); float sum = 0.0; for (size_t i = 0; i < (dim >> 1); ++i) { uint8_t m_val = m[i]; uint8_t q_val = q[i]; sum += Int4SquaredDiffTable[((m_val << 4) & 0xf0) | ((q_val >> 0) & 0xf)] + Int4SquaredDiffTable[((m_val >> 0) & 0xf0) | ((q_val >> 4) & 0xf)]; } return sum; } float EuclideanDistanceInt4Scalar(const uint8_t *m, const uint8_t *q, size_t dim) { ailego_assert(m && q && dim && !(dim & 1)); float sum = 0.0; for (size_t i = 0; i < (dim >> 1); ++i) { uint8_t m_val = m[i]; uint8_t q_val = q[i]; sum += Int4SquaredDiffTable[((m_val << 4) & 0xf0) | ((q_val >> 0) & 0xf)] + Int4SquaredDiffTable[((m_val >> 0) & 0xf0) | ((q_val >> 4) & 0xf)]; } return std::sqrt(sum); } float SquaredEuclideanDistanceInt8Scalar(const int8_t *m, const int8_t *q, size_t dim) { return SquaredEuclideanDistanceScalar(m, q, dim); } float EuclideanDistanceInt8Scalar(const int8_t *m, const int8_t *q, size_t dim) { return EuclideanDistanceScalar(m, q, dim); } float SquaredEuclideanDistanceFp16Scalar(const ailego::Float16 *m, const ailego::Float16 *q, size_t dim) { return SquaredEuclideanDistanceScalar(m, q, dim); } float EuclideanDistanceFp16Scalar(const ailego::Float16 *m, const ailego::Float16 
*q, size_t dim) { return EuclideanDistanceScalar(m, q, dim); } float SquaredEuclideanDistanceFp32Scalar(const float *m, const float *q, size_t dim) { return SquaredEuclideanDistanceScalar(m, q, dim); } float EuclideanDistanceFp32Scalar(const float *m, const float *q, size_t dim) { return EuclideanDistanceScalar(m, q, dim); } } // namespace ailego } // namespace zvec ================================================ FILE: src/ailego/math/hamming_distance_matrix.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "hamming_distance_matrix.h" #include #include #include "distance_matrix_popcnt.i" namespace zvec { namespace ailego { #define POPCNT_UINT32_STEP1_SSE HAMMING_UINT32_STEP1_SSE #define POPCNT_UINT32_STEP2_SSE HAMMING_UINT32_STEP2_SSE #define POPCNT_UINT32_STEP3_SSE HAMMING_UINT32_STEP3_SSE #define POPCNT_UINT32_STEP1_AVX HAMMING_UINT32_STEP1_AVX #define POPCNT_UINT32_STEP2_AVX HAMMING_UINT32_STEP2_AVX #define POPCNT_UINT32_STEP3_AVX HAMMING_UINT32_STEP3_AVX #define POPCNT_UINT64_STEP1_AVX HAMMING_UINT64_STEP1_AVX #define POPCNT_UINT64_STEP2_AVX HAMMING_UINT64_STEP2_AVX //! Calculate population count (Step 1 SSE) #define HAMMING_UINT32_STEP1_SSE(xmm_m, xmm_q, xmm_sum) \ xmm_sum = _mm_add_epi8( \ VerticalPopCount_INT8_V128(_mm_xor_si128(xmm_m, xmm_q)), xmm_sum); //! 
Calculate population count (Step 2 SSE) #define HAMMING_UINT32_STEP2_SSE(xmm_m, xmm_q, xmm_sum) \ xmm_sum = _mm_add_epi16( \ VerticalPopCount_INT16_V128(_mm_xor_si128(xmm_m, xmm_q)), xmm_sum); //! Calculate population count (Step 3 SSE) #define HAMMING_UINT32_STEP3_SSE(xmm_m, xmm_q, xmm_sum) \ xmm_sum = _mm_add_epi32( \ VerticalPopCount_INT32_V128(_mm_xor_si128(xmm_m, xmm_q)), xmm_sum); //! Calculate population count (Step 1 AVX) #define HAMMING_UINT32_STEP1_AVX(ymm_m, ymm_q, ymm_sum) \ ymm_sum = _mm256_add_epi8( \ VerticalPopCount_INT8_V256(_mm256_xor_si256(ymm_m, ymm_q)), ymm_sum); //! Calculate population count (Step 2 AVX) #define HAMMING_UINT32_STEP2_AVX(ymm_m, ymm_q, ymm_sum) \ ymm_sum = _mm256_add_epi16( \ VerticalPopCount_INT16_V256(_mm256_xor_si256(ymm_m, ymm_q)), ymm_sum); //! Calculate population count (Step 3 AVX) #define HAMMING_UINT32_STEP3_AVX(ymm_m, ymm_q, ymm_sum) \ ymm_sum = _mm256_add_epi32( \ VerticalPopCount_INT32_V256(_mm256_xor_si256(ymm_m, ymm_q)), ymm_sum); //! Calculate population count (Step 1 AVX) #define HAMMING_UINT64_STEP1_AVX(ymm_m, ymm_q, ymm_sum) \ ymm_sum = _mm256_add_epi8( \ VerticalPopCount_INT8_V256(_mm256_xor_si256(ymm_m, ymm_q)), ymm_sum); //! Calculate population count (Step 2 AVX) #define HAMMING_UINT64_STEP2_AVX(ymm_m, ymm_q, ymm_sum) \ ymm_sum = _mm256_add_epi64( \ VerticalPopCount_INT64_V256(_mm256_xor_si256(ymm_m, ymm_q)), ymm_sum); #if defined(__AVX512VL__) && defined(__AVX512DQ__) #define CONVERT_UINT64_TO_FP32(v, ...) _mm256_cvtepu64_ps(v) #elif defined(__AVX2__) static const __m256i CONVERT_UINT32_MASK_AVX = _mm256_set_epi32(0, 0, 0, 0, 6, 4, 2, 0); #define CONVERT_UINT64_TO_FP32(v, ...) \ _mm_cvtepi32_ps(_mm256_castsi256_si128( \ _mm256_permutevar8x32_epi32(v, CONVERT_UINT32_MASK_AVX))) #endif // __AVX512VL__ && __AVX512DQ__ #define SQRT_UINT64_TO_FP32(v, ...) _mm_sqrt_ps(CONVERT_UINT64_TO_FP32(v)) #define SQRT_UINT32_TO_FP32_SSE(v, ...) _mm_sqrt_ps(_mm_cvtepi32_ps(v)) #define SQRT_UINT32_TO_FP32_AVX(v, ...) 
_mm256_sqrt_ps(_mm256_cvtepi32_ps(v)) #if defined(__AVX2__) static inline size_t HammingDistanceAVX(const uint32_t *lhs, const uint32_t *rhs, size_t size) { __m256i ymm_sum_0 = _mm256_setzero_si256(); __m256i ymm_sum_1 = _mm256_setzero_si256(); const uint32_t *lhs_0 = lhs + ((size >> 4) << 4); const uint32_t *lhs_1 = (size > 496 ? lhs + 496 : lhs_0); const uint32_t *lhs_2 = lhs + size; if (((uintptr_t)lhs & 0x1f) == 0 && ((uintptr_t)rhs & 0x1f) == 0) { for (; lhs != lhs_1; lhs += 16, rhs += 16) { __m256i ymm_lhs_0 = _mm256_load_si256((__m256i *)(lhs + 0)); __m256i ymm_lhs_1 = _mm256_load_si256((__m256i *)(lhs + 8)); __m256i ymm_rhs_0 = _mm256_load_si256((__m256i *)(rhs + 0)); __m256i ymm_rhs_1 = _mm256_load_si256((__m256i *)(rhs + 8)); ymm_sum_0 = _mm256_add_epi8( VerticalPopCount_INT8_V256(_mm256_xor_si256(ymm_lhs_0, ymm_rhs_0)), ymm_sum_0); ymm_sum_1 = _mm256_add_epi8( VerticalPopCount_INT8_V256(_mm256_xor_si256(ymm_lhs_1, ymm_rhs_1)), ymm_sum_1); } ymm_sum_0 = _mm256_sad_epu8(ymm_sum_0, POPCNT_ZERO_AVX); ymm_sum_1 = _mm256_sad_epu8(ymm_sum_1, POPCNT_ZERO_AVX); for (; lhs != lhs_0; lhs += 16, rhs += 16) { __m256i ymm_lhs_0 = _mm256_load_si256((__m256i *)(lhs + 0)); __m256i ymm_lhs_1 = _mm256_load_si256((__m256i *)(lhs + 8)); __m256i ymm_rhs_0 = _mm256_load_si256((__m256i *)(rhs + 0)); __m256i ymm_rhs_1 = _mm256_load_si256((__m256i *)(rhs + 8)); ymm_sum_0 = _mm256_add_epi64( VerticalPopCount_INT64_V256(_mm256_xor_si256(ymm_lhs_0, ymm_rhs_0)), ymm_sum_0); ymm_sum_1 = _mm256_add_epi64( VerticalPopCount_INT64_V256(_mm256_xor_si256(ymm_lhs_1, ymm_rhs_1)), ymm_sum_1); } if (lhs_2 >= lhs + 8) { __m256i ymm_lhs = _mm256_load_si256((__m256i *)(lhs)); __m256i ymm_rhs = _mm256_load_si256((__m256i *)(rhs)); ymm_sum_0 = _mm256_add_epi64( VerticalPopCount_INT64_V256(_mm256_xor_si256(ymm_lhs, ymm_rhs)), ymm_sum_0); lhs += 8; rhs += 8; } } else { for (; lhs != lhs_1; lhs += 16, rhs += 16) { __m256i ymm_lhs_0 = _mm256_loadu_si256((__m256i *)(lhs + 0)); __m256i ymm_lhs_1 = 
_mm256_loadu_si256((__m256i *)(lhs + 8)); __m256i ymm_rhs_0 = _mm256_loadu_si256((__m256i *)(rhs + 0)); __m256i ymm_rhs_1 = _mm256_loadu_si256((__m256i *)(rhs + 8)); ymm_sum_0 = _mm256_add_epi8( VerticalPopCount_INT8_V256(_mm256_xor_si256(ymm_lhs_0, ymm_rhs_0)), ymm_sum_0); ymm_sum_1 = _mm256_add_epi8( VerticalPopCount_INT8_V256(_mm256_xor_si256(ymm_lhs_1, ymm_rhs_1)), ymm_sum_1); } ymm_sum_0 = _mm256_sad_epu8(ymm_sum_0, POPCNT_ZERO_AVX); ymm_sum_1 = _mm256_sad_epu8(ymm_sum_1, POPCNT_ZERO_AVX); for (; lhs != lhs_0; lhs += 16, rhs += 16) { __m256i ymm_lhs_0 = _mm256_loadu_si256((__m256i *)(lhs + 0)); __m256i ymm_lhs_1 = _mm256_loadu_si256((__m256i *)(lhs + 8)); __m256i ymm_rhs_0 = _mm256_loadu_si256((__m256i *)(rhs + 0)); __m256i ymm_rhs_1 = _mm256_loadu_si256((__m256i *)(rhs + 8)); ymm_sum_0 = _mm256_add_epi64( VerticalPopCount_INT64_V256(_mm256_xor_si256(ymm_lhs_0, ymm_rhs_0)), ymm_sum_0); ymm_sum_1 = _mm256_add_epi64( VerticalPopCount_INT64_V256(_mm256_xor_si256(ymm_lhs_1, ymm_rhs_1)), ymm_sum_1); } if (lhs_2 >= lhs + 8) { __m256i ymm_lhs = _mm256_loadu_si256((__m256i *)(lhs)); __m256i ymm_rhs = _mm256_loadu_si256((__m256i *)(rhs)); ymm_sum_0 = _mm256_add_epi64( VerticalPopCount_INT64_V256(_mm256_xor_si256(ymm_lhs, ymm_rhs)), ymm_sum_0); lhs += 8; rhs += 8; } } size_t count = (size_t)HorizontalAdd_INT64_V256(_mm256_add_epi64(ymm_sum_0, ymm_sum_1)); switch (lhs_2 - lhs) { case 7: count += ailego_popcount32(lhs[6] ^ rhs[6]); /* FALLTHRU */ case 6: count += ailego_popcount32(lhs[5] ^ rhs[5]); /* FALLTHRU */ case 5: count += ailego_popcount32(lhs[4] ^ rhs[4]); /* FALLTHRU */ case 4: count += ailego_popcount32(lhs[3] ^ rhs[3]); /* FALLTHRU */ case 3: count += ailego_popcount32(lhs[2] ^ rhs[2]); /* FALLTHRU */ case 2: count += ailego_popcount32(lhs[1] ^ rhs[1]); /* FALLTHRU */ case 1: count += ailego_popcount32(lhs[0] ^ rhs[0]); } return count; } static inline size_t HammingDistanceAVX(const uint64_t *lhs, const uint64_t *rhs, size_t size) { return 
HammingDistanceAVX(reinterpret_cast(lhs), reinterpret_cast(rhs), (size << 1)); } #endif // __AVX2__ #if defined(AILEGO_M64) static inline size_t HammingDistance(const uint32_t *lhs, const uint32_t *rhs, size_t size) { const uint32_t *last = lhs + size; const uint32_t *last_aligned = lhs + ((size >> 3) << 3); size_t count = 0; for (; lhs != last_aligned; lhs += 8, rhs += 8) { count += ailego_popcount64(*(uint64_t *)(&lhs[6]) ^ *(uint64_t *)(&rhs[6])); count += ailego_popcount64(*(uint64_t *)(&lhs[4]) ^ *(uint64_t *)(&rhs[4])); count += ailego_popcount64(*(uint64_t *)(&lhs[2]) ^ *(uint64_t *)(&rhs[2])); count += ailego_popcount64(*(uint64_t *)(&lhs[0]) ^ *(uint64_t *)(&rhs[0])); } switch (last - last_aligned) { case 7: count += ailego_popcount32(lhs[6] ^ rhs[6]); /* FALLTHRU */ case 6: count += ailego_popcount32(lhs[5] ^ rhs[5]); /* FALLTHRU */ case 5: count += ailego_popcount32(lhs[4] ^ rhs[4]); /* FALLTHRU */ case 4: count += ailego_popcount32(lhs[3] ^ rhs[3]); /* FALLTHRU */ case 3: count += ailego_popcount32(lhs[2] ^ rhs[2]); /* FALLTHRU */ case 2: count += ailego_popcount32(lhs[1] ^ rhs[1]); /* FALLTHRU */ case 1: count += ailego_popcount32(lhs[0] ^ rhs[0]); } return count; } static inline size_t HammingDistance(const uint64_t *lhs, const uint64_t *rhs, size_t size) { const uint64_t *last = lhs + size; const uint64_t *last_aligned = lhs + ((size >> 2) << 2); size_t count = 0; for (; lhs != last_aligned; lhs += 4, rhs += 4) { count += ailego_popcount64(lhs[3] ^ rhs[3]); count += ailego_popcount64(lhs[2] ^ rhs[2]); count += ailego_popcount64(lhs[1] ^ rhs[1]); count += ailego_popcount64(lhs[0] ^ rhs[0]); } switch (last - last_aligned) { case 3: count += ailego_popcount64(lhs[2] ^ rhs[2]); /* FALLTHRU */ case 2: count += ailego_popcount64(lhs[1] ^ rhs[1]); /* FALLTHRU */ case 1: count += ailego_popcount64(lhs[0] ^ rhs[0]); } return count; } #else static inline size_t HammingDistance(const uint32_t *lhs, const uint32_t *rhs, size_t size) { const uint32_t *last = lhs 
+ size; const uint32_t *last_aligned = lhs + ((size >> 2) << 2); size_t count = 0; for (; lhs != last_aligned; lhs += 4, rhs += 4) { count += ailego_popcount32(lhs[3] ^ rhs[3]); count += ailego_popcount32(lhs[2] ^ rhs[2]); count += ailego_popcount32(lhs[1] ^ rhs[1]); count += ailego_popcount32(lhs[0] ^ rhs[0]); } switch (last - last_aligned) { case 3: count += ailego_popcount32(lhs[2] ^ rhs[2]); /* FALLTHRU */ case 2: count += ailego_popcount32(lhs[1] ^ rhs[1]); /* FALLTHRU */ case 1: count += ailego_popcount32(lhs[0] ^ rhs[0]); } return count; } #endif // AILEGO_M64 //! Compute the distance between matrix and query (UINT32, M=1, N=1) void HammingDistanceMatrix::Compute(const ValueType *m, const ValueType *q, size_t dim, float *out) { size_t cnt = (dim >> 5); #if defined(__AVX2__) if (cnt > 63) { *out = static_cast(HammingDistanceAVX(m, q, cnt)); return; } #endif *out = static_cast(HammingDistance(m, q, cnt)); } #if defined(AILEGO_M64) //! Compute the distance between matrix and query (UINT64, M=1, N=1) void HammingDistanceMatrix::Compute(const ValueType *m, const ValueType *q, size_t dim, float *out) { size_t cnt = (dim >> 6); #if defined(__AVX2__) if (cnt > 31) { *out = static_cast(HammingDistanceAVX(m, q, cnt)); return; } #endif *out = static_cast(HammingDistance(m, q, cnt)); } #endif // AILEGO_M64 //! Compute the distance between matrix and query (UINT32, M=1, N=1) void HammingSquareRootDistanceMatrix::Compute( const ValueType *m, const ValueType *q, size_t dim, float *out) { size_t cnt = (dim >> 5); #if defined(__AVX2__) if (cnt > 63) { *out = std::sqrt(static_cast(HammingDistanceAVX(m, q, cnt))); return; } #endif *out = std::sqrt(static_cast(HammingDistance(m, q, cnt))); } #if defined(AILEGO_M64) //! 
//! Compute the distance between matrix and query (UINT64, M=1, N=1)
//! Square root of the 64-bit-word Hamming distance: `dim` is a bit count, so
//! dim >> 6 words are compared; the AVX2 kernel is used above 31 words.
void HammingSquareRootDistanceMatrix::Compute(
    const ValueType *m, const ValueType *q, size_t dim, float *out) {
  size_t cnt = (dim >> 6);
#if defined(__AVX2__)
  if (cnt > 31) {
    *out = std::sqrt(static_cast(HammingDistanceAVX(m, q, cnt)));
    return;
  }
#endif
  *out = std::sqrt(static_cast(HammingDistance(m, q, cnt)));
}
#endif  // AILEGO_M64

}  // namespace ailego
}  // namespace zvec

================================================
FILE: src/ailego/math/hamming_distance_matrix.h
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include 
#include 
#include 

namespace zvec {
namespace ailego {

/*! Hamming Distance Matrix */
template  // NOTE: useless 'typename=void' to avoid clang
          // compile error
struct HammingDistanceMatrix;

/*! Hamming Distance Matrix (UINT32) */
template
struct HammingDistanceMatrix {
  //! Type of value
  using ValueType = uint32_t;

  //! Compute the distance between matrix and query
  //! `dim` is a bit count (asserted to be a multiple of 32), so dim >> 5
  //! words are compared. Word k of matrix row i is read from m[k * M + i] and
  //! word k of query j from q[k * N + j]; the distance for (i, j) is written
  //! to out[j * M + i]. The first word initializes every output and later
  //! words accumulate into it.
  static inline void Compute(const ValueType *m, const ValueType *q,
                             size_t dim, float *out) {
    ailego_assert(m && q && !(dim & 31) && out);

    size_t cnt = (dim >> 5);
    if (cnt > 0) {
      for (size_t i = 0; i < M; ++i) {
        ValueType m_val = m[i];
        float *r = out + i;

        for (size_t j = 0; j < N; ++j) {
          *r = static_cast(ailego_popcount32(m_val ^ q[j]));
          r += M;  // next query's output for this row
        }
      }
      m += M;
      q += N;
    }

    for (size_t k = 1; k < cnt; ++k) {
      for (size_t i = 0; i < M; ++i) {
        ValueType m_val = m[i];
        float *r = out + i;

        for (size_t j = 0; j < N; ++j) {
          *r += static_cast(ailego_popcount32(m_val ^ q[j]));
          r += M;
        }
      }
      m += M;
      q += N;
    }
  }
};

/*! Hamming Distance Matrix (UINT32, M=1, N=1) */
template <>
struct HammingDistanceMatrix {
  //! Type of value
  using ValueType = uint32_t;

  //! Compute the distance between matrix and query
  //! (defined out of line in hamming_distance_matrix.cc)
  static void Compute(const ValueType *m, const ValueType *q, size_t dim,
                      float *out);
};

#if defined(AILEGO_M64)
/*! Hamming Distance Matrix (UINT64) */
template
struct HammingDistanceMatrix {
  //! Type of value
  using ValueType = uint64_t;

  //! Compute the distance between matrix and query
  //! Same interleaved layout as the UINT32 version, but over 64-bit words:
  //! `dim` must be a multiple of 64 and dim >> 6 words are compared.
  static inline void Compute(const ValueType *m, const ValueType *q,
                             size_t dim, float *out) {
    ailego_assert(m && q && !(dim & 63) && out);

    size_t cnt = (dim >> 6);
    if (cnt > 0) {
      for (size_t i = 0; i < M; ++i) {
        ValueType m_val = m[i];
        float *r = out + i;

        for (size_t j = 0; j < N; ++j) {
          *r = static_cast(ailego_popcount64(m_val ^ q[j]));
          r += M;
        }
      }
      m += M;
      q += N;
    }

    for (size_t k = 1; k < cnt; ++k) {
      for (size_t i = 0; i < M; ++i) {
        ValueType m_val = m[i];
        float *r = out + i;

        for (size_t j = 0; j < N; ++j) {
          *r += static_cast(ailego_popcount64(m_val ^ q[j]));
          r += M;
        }
      }
      m += M;
      q += N;
    }
  }
};

/*! Hamming Distance Matrix (UINT64, M=1, N=1) */
template <>
struct HammingDistanceMatrix {
  //! Type of value
  using ValueType = uint64_t;

  //! Compute the distance between matrix and query
  //! (defined out of line in hamming_distance_matrix.cc)
  static void Compute(const ValueType *m, const ValueType *q, size_t dim,
                      float *out);
};
#endif  // AILEGO_M64

/*! Hamming Square Root Distance Matrix */
template
struct HammingSquareRootDistanceMatrix {
  //! Type of value
  using ValueType = typename std::remove_cv::type;

  //! Compute the distance between matrix and query
  //! Runs HammingDistanceMatrix::Compute, then replaces each of the N * M
  //! outputs in place with its square root.
  static inline void Compute(const ValueType *m, const ValueType *q,
                             size_t dim, float *out) {
    ailego_assert(m && q && dim && out);

    HammingDistanceMatrix::Compute(m, q, dim, out);
    for (size_t i = 0; i < N * M; ++i) {
      float val = *out;
      *out++ = std::sqrt(val);
    }
  }
};

/*! Hamming Square Root Distance Matrix (UINT32, M=1, N=1) */
template <>
struct HammingSquareRootDistanceMatrix {
  //! Type of value
  using ValueType = uint32_t;

  //! Compute the distance between matrix and query
  //! (defined out of line in hamming_distance_matrix.cc)
  static void Compute(const ValueType *m, const ValueType *q, size_t dim,
                      float *out);
};

#if defined(AILEGO_M64)
/*! Hamming Square Root Distance Matrix (UINT64, M=1, N=1) */
template <>
struct HammingSquareRootDistanceMatrix {
  //! Type of value
  using ValueType = uint64_t;

  //! Compute the distance between matrix and query
  //! (defined out of line in hamming_distance_matrix.cc)
  static void Compute(const ValueType *m, const ValueType *q, size_t dim,
                      float *out);
};
#endif  // AILEGO_M64

}  // namespace ailego
}  // namespace zvec

================================================
FILE: src/ailego/math/inner_product_matrix.h
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include #include #include #include #include "distance_utility.h" namespace zvec { namespace ailego { //-------------------------------------------------- // Dense //-------------------------------------------------- /*! Inner Product Matrix */ template struct InnerProductMatrix; /*! Inner Product Matrix */ template struct MinusInnerProductMatrix; /*! Inner Product Matrix (M=1, N=1) */ template struct InnerProductMatrix< T, 1, 1, typename std::enable_if::value>::type> { //! Type of value using ValueType = typename std::remove_cv::type; //! Compute the distance between matrix and query static inline void Compute(const ValueType *m, const ValueType *q, size_t dim, float *out) { ailego_assert(m && q && dim && out); float sum = 0.0; for (size_t i = 0; i < dim; ++i) { sum += static_cast(m[i] * q[i]); } *out = sum; } }; /*! Minus Inner Product Matrix (M=1, N=1) */ template struct MinusInnerProductMatrix< T, 1, 1, typename std::enable_if::value>::type> { //! Type of value using ValueType = typename std::remove_cv::type; //! Compute the distance between matrix and query static inline void Compute(const ValueType *m, const ValueType *q, size_t dim, float *out) { ailego_assert(m && q && dim && out); float sum = 0.0; for (size_t i = 0; i < dim; ++i) { sum += static_cast(m[i] * q[i]); } *out = -sum; } }; template <> struct InnerProductMatrix { //! Type of value using ValueType = uint8_t; //! Compute the distance between matrix and query static void Compute(const ValueType *m, const ValueType *q, size_t dim, float *out); }; template <> struct InnerProductMatrix { //! Type of value using ValueType = int8_t; //! Compute the distance between matrix and query static void Compute(const ValueType *m, const ValueType *q, size_t dim, float *out); }; template <> struct InnerProductMatrix { //! Type of value using ValueType = Float16; //! 
Compute the distance between matrix and query static void Compute(const ValueType *m, const ValueType *q, size_t dim, float *out); }; template <> struct InnerProductMatrix { //! Type of value using ValueType = float; //! Compute the distance between matrix and query static void Compute(const ValueType *m, const ValueType *q, size_t dim, float *out); }; template <> struct MinusInnerProductMatrix { //! Type of value using ValueType = uint8_t; //! Compute the distance between matrix and query static void Compute(const ValueType *m, const ValueType *q, size_t dim, float *out); }; template <> struct MinusInnerProductMatrix { //! Type of value using ValueType = int8_t; //! Compute the distance between matrix and query static void Compute(const ValueType *m, const ValueType *q, size_t dim, float *out); }; template <> struct MinusInnerProductMatrix { //! Type of value using ValueType = Float16; //! Compute the distance between matrix and query static void Compute(const ValueType *m, const ValueType *q, size_t dim, float *out); }; template <> struct MinusInnerProductMatrix { //! Type of value using ValueType = float; //! Compute the distance between matrix and query static void Compute(const ValueType *m, const ValueType *q, size_t dim, float *out); }; /*! Inner Product Matrix */ template struct InnerProductMatrix< T, M, N, typename std::enable_if::value && sizeof(T) >= 2 && M >= 2 && N >= 2>::type> { //! Type of value using ValueType = typename std::remove_cv::type; //! 
Compute the distance between matrix and query static inline void Compute(const ValueType *m, const ValueType *q, size_t dim, float *out) { ailego_assert(m && q && dim && out); if (dim > 0) { for (size_t i = 0; i < M; ++i) { ValueType m_val = m[i]; float *r = out + i; for (size_t j = 0; j < N; ++j) { *r = static_cast(m_val * q[j]); r += M; } } m += M; q += N; } for (size_t k = 1; k < dim; ++k) { for (size_t i = 0; i < M; ++i) { ValueType m_val = m[i]; float *r = out + i; for (size_t j = 0; j < N; ++j) { *r += m_val * q[j]; r += M; } } m += M; q += N; } } }; /*! Inner Product Matrix (N=1) */ template struct InnerProductMatrix< T, M, 1, typename std::enable_if::value && sizeof(T) >= 2 && M >= 2>::type> { //! Type of value using ValueType = typename std::remove_cv::type; //! Compute the distance between matrix and query static inline void Compute(const ValueType *m, const ValueType *q, size_t dim, float *out) { ailego_assert(m && q && dim && out); const ValueType *q_end = q + dim; if (q != q_end) { ValueType q_val = *q++; for (size_t i = 0; i < M; ++i) { *(out + i) = static_cast(m[i] * q_val); } m += M; } while (q != q_end) { ValueType q_val = *q++; for (size_t i = 0; i < M; ++i) { *(out + i) += m[i] * q_val; } m += M; } } }; /*! Inner Product Matrix (INT8) */ template struct InnerProductMatrix= 2 && N >= 2>::type> { //! Type of value using ValueType = int8_t; //! 
Compute the distance between matrix and query
  // NOTE(review): the line above continues a `//!` comment that ends the
  // preceding line (outside this edit). What follows is the tail of the INT8
  // InnerProductMatrix specialization for M >= 2, N >= 2; its header lies
  // before this view.
  //
  // NOTE(review): the `<...>` template argument lists in this region appear
  // to have been stripped by text extraction (e.g. `reinterpret_cast(m)`,
  // `static_cast(sum)`, `std::vector seg_infos`,
  // `template struct InnerProductMatrix= 2>::type>`). The tokens are left
  // exactly as found; restore them from the upstream source before compiling.
  //
  // Layout (from the pointer arithmetic below): `m` holds M vectors
  // interleaved one 32-bit word (4 int8 values) at a time, so word k of
  // vector i is m_it[k * M + i]; `q` holds N vectors interleaved the same
  // way. The result for (matrix vector i, query vector j) is written to
  // out[j * M + i].
  static inline void Compute(const ValueType *m, const ValueType *q, size_t dim,
                             float *out) {
    // dim must be a non-zero multiple of 4: elements are consumed 4 per word.
    ailego_assert(m && q && dim && !(dim & 3) && out);
    const uint32_t *m_it = reinterpret_cast(m);
    const uint32_t *q_it = reinterpret_cast(q);
    dim >>= 2;  // now counts 32-bit words per vector
    // First word: plain assignment, so `out` needs no prior zeroing.
    if (dim > 0) {
      for (size_t i = 0; i < M; ++i) {
        uint32_t m_val = m_it[i];
        float *r = out + i;
        for (size_t j = 0; j < N; ++j) {
          *r = FusedMultiplyAdd(m_val, q_it[j]);
          r += M;
        }
      }
      m_it += M;
      q_it += N;
    }
    // Remaining words: accumulate into the same output slots.
    for (size_t k = 1; k < dim; ++k) {
      for (size_t i = 0; i < M; ++i) {
        uint32_t m_val = m_it[i];
        float *r = out + i;
        for (size_t j = 0; j < N; ++j) {
          *r += FusedMultiplyAdd(m_val, q_it[j]);
          r += M;
        }
      }
      m_it += M;
      q_it += N;
    }
  }

 protected:
  //! Calculate Fused-Multiply-Add
  // Despite the name, this is the 4-lane dot product of the signed bytes
  // packed in `lhs` and `rhs` (byte b of lhs times byte b of rhs), summed as
  // int32 and returned as float.
  // NOTE(review): the reason for `volatile` is not visible here; presumably
  // it inhibits a particular compiler optimisation - confirm before removing.
  static inline float FusedMultiplyAdd(uint32_t lhs, uint32_t rhs) {
    volatile int32_t sum = ((int8_t)(lhs >> 0) * (int8_t)(rhs >> 0) +
                            (int8_t)(lhs >> 8) * (int8_t)(rhs >> 8) +
                            (int8_t)(lhs >> 16) * (int8_t)(rhs >> 16) +
                            (int8_t)(lhs >> 24) * (int8_t)(rhs >> 24));
    return static_cast(sum);
  }
};

/*! Inner Product Matrix (INT8, N=1) */
// Single-query variant: out[i] receives the dot product of matrix vector i
// with the one query vector; `m` is interleaved word-by-word (M words per
// step) as in the N >= 2 version.
template struct InnerProductMatrix= 2>::type> {
  //! Type of value
  using ValueType = int8_t;

  //! Compute the distance between matrix and query
  static inline void Compute(const ValueType *m, const ValueType *q, size_t dim,
                             float *out) {
    // dim must be a non-zero multiple of 4.
    ailego_assert(m && q && dim && !(dim & 3) && out);
    const uint32_t *m_it = reinterpret_cast(m);
    const uint32_t *q_it = reinterpret_cast(q);
    const uint32_t *q_end = q_it + (dim >> 2);
    // First query word initialises out[0..M); later words accumulate.
    if (q_it != q_end) {
      uint32_t q_val = *q_it++;
      for (size_t i = 0; i < M; ++i) {
        *(out + i) = FusedMultiplyAdd(m_it[i], q_val);
      }
      m_it += M;
    }
    while (q_it != q_end) {
      uint32_t q_val = *q_it++;
      for (size_t i = 0; i < M; ++i) {
        *(out + i) += FusedMultiplyAdd(m_it[i], q_val);
      }
      m_it += M;
    }
  }

 protected:
  //! Calculate Fused-Multiply-Add
  // 4-lane signed-byte dot product of `lhs` and `rhs`, returned as float.
  static inline float FusedMultiplyAdd(uint32_t lhs, uint32_t rhs) {
    volatile int32_t sum = ((int8_t)(lhs >> 0) * (int8_t)(rhs >> 0) +
                            (int8_t)(lhs >> 8) * (int8_t)(rhs >> 8) +
                            (int8_t)(lhs >> 16) * (int8_t)(rhs >> 16) +
                            (int8_t)(lhs >> 24) * (int8_t)(rhs >> 24));
    return static_cast(sum);
  }
};

/*! Inner Product Matrix (INT4) */
// Packed 4-bit variant: each 32-bit word carries 8 elements (two per byte),
// so `dim` counts 4-bit elements and must be a non-zero multiple of 8.
// Interleaving of `m`/`q` and the out[j * M + i] layout match the INT8
// version.
template struct InnerProductMatrix= 2 && N >= 2>::type> {
  //! Type of value
  using ValueType = uint8_t;

  //! Compute the distance between matrix and query
  static inline void Compute(const ValueType *m, const ValueType *q, size_t dim,
                             float *out) {
    ailego_assert(m && q && dim && !(dim & 7) && out);
    const uint32_t *m_it = reinterpret_cast(m);
    const uint32_t *q_it = reinterpret_cast(q);
    dim >>= 3;  // now counts 32-bit words (8 nibbles each) per vector
    // First word assigns, remaining words accumulate.
    if (dim > 0) {
      for (size_t i = 0; i < M; ++i) {
        uint32_t m_val = m_it[i];
        float *r = out + i;
        for (size_t j = 0; j < N; ++j) {
          *r = FusedMultiplyAdd(m_val, q_it[j]);
          r += M;
        }
      }
      m_it += M;
      q_it += N;
    }
    for (size_t k = 1; k < dim; ++k) {
      for (size_t i = 0; i < M; ++i) {
        uint32_t m_val = m_it[i];
        float *r = out + i;
        for (size_t j = 0; j < N; ++j) {
          *r += FusedMultiplyAdd(m_val, q_it[j]);
          r += M;
        }
      }
      m_it += M;
      q_it += N;
    }
  }

 protected:
  //! Calculate Fused-Multiply-Add
  // Sums eight Int4MulTable lookups, one per nibble lane; each index is
  // (lhs nibble << 4) | rhs nibble for the same lane. Int4MulTable is defined
  // elsewhere - presumably the products of signed 4-bit pairs; confirm.
  static inline float FusedMultiplyAdd(uint32_t lhs, uint32_t rhs) {
    return static_cast(
        Int4MulTable[((lhs << 4) & 0xf0) | ((rhs >> 0) & 0xf)] +
        Int4MulTable[((lhs >> 0) & 0xf0) | ((rhs >> 4) & 0xf)] +
        Int4MulTable[((lhs >> 4) & 0xf0) | ((rhs >> 8) & 0xf)] +
        Int4MulTable[((lhs >> 8) & 0xf0) | ((rhs >> 12) & 0xf)] +
        Int4MulTable[((lhs >> 12) & 0xf0) | ((rhs >> 16) & 0xf)] +
        Int4MulTable[((lhs >> 16) & 0xf0) | ((rhs >> 20) & 0xf)] +
        Int4MulTable[((lhs >> 20) & 0xf0) | ((rhs >> 24) & 0xf)] +
        Int4MulTable[((lhs >> 24) & 0xf0) | ((rhs >> 28) & 0xf)]);
  }
};

/*! Inner Product Matrix (INT4, N=1) */
// Single-query packed 4-bit variant: out[i] receives the result for matrix
// vector i; `dim` must be a non-zero multiple of 8.
template struct InnerProductMatrix= 2>::type> {
  //! Type of value
  using ValueType = uint8_t;

  //! Compute the distance between matrix and query
  static inline void Compute(const ValueType *m, const ValueType *q, size_t dim,
                             float *out) {
    ailego_assert(m && q && dim && !(dim & 7) && out);
    const uint32_t *m_it = reinterpret_cast(m);
    const uint32_t *q_it = reinterpret_cast(q);
    const uint32_t *q_end = q_it + (dim >> 3);
    // First query word initialises out[0..M); later words accumulate.
    if (q_it != q_end) {
      uint32_t q_val = *q_it++;
      for (size_t i = 0; i < M; ++i) {
        *(out + i) = FusedMultiplyAdd(m_it[i], q_val);
      }
      m_it += M;
    }
    while (q_it != q_end) {
      uint32_t q_val = *q_it++;
      for (size_t i = 0; i < M; ++i) {
        *(out + i) += FusedMultiplyAdd(m_it[i], q_val);
      }
      m_it += M;
    }
  }

 protected:
  //! Calculate Fused-Multiply-Add
  // Eight-lane nibble dot product via Int4MulTable (see the INT4 N >= 2
  // specialization for the index scheme).
  static inline float FusedMultiplyAdd(uint32_t lhs, uint32_t rhs) {
    return static_cast(
        Int4MulTable[((lhs << 4) & 0xf0) | ((rhs >> 0) & 0xf)] +
        Int4MulTable[((lhs >> 0) & 0xf0) | ((rhs >> 4) & 0xf)] +
        Int4MulTable[((lhs >> 4) & 0xf0) | ((rhs >> 8) & 0xf)] +
        Int4MulTable[((lhs >> 8) & 0xf0) | ((rhs >> 12) & 0xf)] +
        Int4MulTable[((lhs >> 12) & 0xf0) | ((rhs >> 16) & 0xf)] +
        Int4MulTable[((lhs >> 16) & 0xf0) | ((rhs >> 20) & 0xf)] +
        Int4MulTable[((lhs >> 20) & 0xf0) | ((rhs >> 24) & 0xf)] +
        Int4MulTable[((lhs >> 24) & 0xf0) | ((rhs >> 28) & 0xf)]);
  }
};

/*! Minus Inner Product Matrix */
// Generic (unpacked) variant: element k of matrix vector i is m[k * M + i]
// and of query vector j is q[k * N + j]; out[j * M + i] receives the
// NEGATED dot product.
template struct MinusInnerProductMatrix<
    T, M, N,
    typename std::enable_if::value && sizeof(T) >= 2 && M >= 2 &&
                            N >= 2>::type> {
  //! Type of value
  using ValueType = typename std::remove_cv::type;

  //! Compute the distance between matrix and query
  static inline void Compute(const ValueType *m, const ValueType *q, size_t dim,
                             float *out) {
    ailego_assert(m && q && dim && out);
    // First element: assign the negated product.
    if (dim > 0) {
      for (size_t i = 0; i < M; ++i) {
        ValueType m_val = m[i];
        float *r = out + i;
        for (size_t j = 0; j < N; ++j) {
          *r = -static_cast(m_val * q[j]);
          r += M;
        }
      }
      m += M;
      q += N;
    }
    // Remaining elements: subtract each product.
    for (size_t k = 1; k < dim; ++k) {
      for (size_t i = 0; i < M; ++i) {
        ValueType m_val = m[i];
        float *r = out + i;
        for (size_t j = 0; j < N; ++j) {
          *r -= m_val * q[j];
          r += M;
        }
      }
      m += M;
      q += N;
    }
  }
};

/*! Minus Inner Product Matrix (N=1) */
// Single-query generic variant: out[i] = -(dot product of matrix vector i
// with q), where element k of matrix vector i is m[k * M + i].
template struct MinusInnerProductMatrix<
    T, M, 1,
    typename std::enable_if::value && sizeof(T) >= 2 && M >= 2>::type> {
  //! Type of value
  using ValueType = typename std::remove_cv::type;

  //! Compute the distance between matrix and query
  static inline void Compute(const ValueType *m, const ValueType *q, size_t dim,
                             float *out) {
    ailego_assert(m && q && dim && out);
    const ValueType *q_end = q + dim;
    // First query element assigns the negated product; later ones subtract.
    if (q != q_end) {
      ValueType q_val = *q++;
      for (size_t i = 0; i < M; ++i) {
        *(out + i) = -static_cast(m[i] * q_val);
      }
      m += M;
    }
    while (q != q_end) {
      ValueType q_val = *q++;
      for (size_t i = 0; i < M; ++i) {
        *(out + i) -= m[i] * q_val;
      }
      m += M;
    }
  }
};

/*! Minus Inner Product Matrix (INT8) */
// Negated counterpart of the INT8 InnerProductMatrix (M, N >= 2): same
// word-interleaved layout, out[j * M + i] = -dot(matrix i, query j);
// `dim` must be a non-zero multiple of 4.
template struct MinusInnerProductMatrix<
    int8_t, M, N, typename std::enable_if= 2 && N >= 2>::type> {
  //! Type of value
  using ValueType = int8_t;

  //! Compute the distance between matrix and query
  static inline void Compute(const ValueType *m, const ValueType *q, size_t dim,
                             float *out) {
    ailego_assert(m && q && dim && !(dim & 3) && out);
    const uint32_t *m_it = reinterpret_cast(m);
    const uint32_t *q_it = reinterpret_cast(q);
    dim >>= 2;  // now counts 32-bit words per vector
    // First word assigns the negated lane sum, remaining words subtract.
    if (dim > 0) {
      for (size_t i = 0; i < M; ++i) {
        uint32_t m_val = m_it[i];
        float *r = out + i;
        for (size_t j = 0; j < N; ++j) {
          *r = -FusedMultiplyAdd(m_val, q_it[j]);
          r += M;
        }
      }
      m_it += M;
      q_it += N;
    }
    for (size_t k = 1; k < dim; ++k) {
      for (size_t i = 0; i < M; ++i) {
        uint32_t m_val = m_it[i];
        float *r = out + i;
        for (size_t j = 0; j < N; ++j) {
          *r -= FusedMultiplyAdd(m_val, q_it[j]);
          r += M;
        }
      }
      m_it += M;
      q_it += N;
    }
  }

 protected:
  //! Calculate Fused-Multiply-Add
  // 4-lane signed-byte dot product of `lhs` and `rhs`, returned as float.
  static inline float FusedMultiplyAdd(uint32_t lhs, uint32_t rhs) {
    volatile int32_t sum = ((int8_t)(lhs >> 0) * (int8_t)(rhs >> 0) +
                            (int8_t)(lhs >> 8) * (int8_t)(rhs >> 8) +
                            (int8_t)(lhs >> 16) * (int8_t)(rhs >> 16) +
                            (int8_t)(lhs >> 24) * (int8_t)(rhs >> 24));
    return static_cast(sum);
  }
};

/*! Minus Inner Product Matrix (INT8, N=1) */
// Single-query negated INT8 variant: out[i] = -dot(matrix i, q).
template struct MinusInnerProductMatrix= 2>::type> {
  //! Type of value
  using ValueType = int8_t;

  //! Compute the distance between matrix and query
  static inline void Compute(const ValueType *m, const ValueType *q, size_t dim,
                             float *out) {
    ailego_assert(m && q && dim && !(dim & 3) && out);
    const uint32_t *m_it = reinterpret_cast(m);
    const uint32_t *q_it = reinterpret_cast(q);
    const uint32_t *q_end = q_it + (dim >> 2);
    if (q_it != q_end) {
      uint32_t q_val = *q_it++;
      for (size_t i = 0; i < M; ++i) {
        *(out + i) = -FusedMultiplyAdd(m_it[i], q_val);
      }
      m_it += M;
    }
    while (q_it != q_end) {
      uint32_t q_val = *q_it++;
      for (size_t i = 0; i < M; ++i) {
        *(out + i) -= FusedMultiplyAdd(m_it[i], q_val);
      }
      m_it += M;
    }
  }

 protected:
  //! Calculate Fused-Multiply-Add
  // 4-lane signed-byte dot product of `lhs` and `rhs`, returned as float.
  static inline float FusedMultiplyAdd(uint32_t lhs, uint32_t rhs) {
    volatile int32_t sum = ((int8_t)(lhs >> 0) * (int8_t)(rhs >> 0) +
                            (int8_t)(lhs >> 8) * (int8_t)(rhs >> 8) +
                            (int8_t)(lhs >> 16) * (int8_t)(rhs >> 16) +
                            (int8_t)(lhs >> 24) * (int8_t)(rhs >> 24));
    return static_cast(sum);
  }
};

/*! Minus Inner Product Matrix (INT4) */
// Negated counterpart of the INT4 InnerProductMatrix (M, N >= 2);
// `dim` counts 4-bit elements and must be a non-zero multiple of 8.
template struct MinusInnerProductMatrix<
    uint8_t, M, N, typename std::enable_if= 2 && N >= 2>::type> {
  //! Type of value
  using ValueType = uint8_t;

  //! Compute the distance between matrix and query
  static inline void Compute(const ValueType *m, const ValueType *q, size_t dim,
                             float *out) {
    ailego_assert(m && q && dim && !(dim & 7) && out);
    const uint32_t *m_it = reinterpret_cast(m);
    const uint32_t *q_it = reinterpret_cast(q);
    dim >>= 3;  // now counts 32-bit words (8 nibbles each) per vector
    if (dim > 0) {
      for (size_t i = 0; i < M; ++i) {
        uint32_t m_val = m_it[i];
        float *r = out + i;
        for (size_t j = 0; j < N; ++j) {
          *r = -FusedMultiplyAdd(m_val, q_it[j]);
          r += M;
        }
      }
      m_it += M;
      q_it += N;
    }
    for (size_t k = 1; k < dim; ++k) {
      for (size_t i = 0; i < M; ++i) {
        uint32_t m_val = m_it[i];
        float *r = out + i;
        for (size_t j = 0; j < N; ++j) {
          *r -= FusedMultiplyAdd(m_val, q_it[j]);
          r += M;
        }
      }
      m_it += M;
      q_it += N;
    }
  }

 protected:
  //! Calculate Fused-Multiply-Add
  // Eight-lane nibble dot product: index = (lhs nibble << 4) | rhs nibble.
  static inline float FusedMultiplyAdd(uint32_t lhs, uint32_t rhs) {
    return static_cast(
        Int4MulTable[((lhs << 4) & 0xf0) | ((rhs >> 0) & 0xf)] +
        Int4MulTable[((lhs >> 0) & 0xf0) | ((rhs >> 4) & 0xf)] +
        Int4MulTable[((lhs >> 4) & 0xf0) | ((rhs >> 8) & 0xf)] +
        Int4MulTable[((lhs >> 8) & 0xf0) | ((rhs >> 12) & 0xf)] +
        Int4MulTable[((lhs >> 12) & 0xf0) | ((rhs >> 16) & 0xf)] +
        Int4MulTable[((lhs >> 16) & 0xf0) | ((rhs >> 20) & 0xf)] +
        Int4MulTable[((lhs >> 20) & 0xf0) | ((rhs >> 24) & 0xf)] +
        Int4MulTable[((lhs >> 24) & 0xf0) | ((rhs >> 28) & 0xf)]);
  }
};

/*! Minus Inner Product Matrix (INT4, N=1) */
// Single-query negated INT4 variant: out[i] = -dot(matrix i, q).
template struct MinusInnerProductMatrix= 2>::type> {
  //! Type of value
  using ValueType = uint8_t;

  //! Compute the distance between matrix and query
  static inline void Compute(const ValueType *m, const ValueType *q, size_t dim,
                             float *out) {
    ailego_assert(m && q && dim && !(dim & 7) && out);
    const uint32_t *m_it = reinterpret_cast(m);
    const uint32_t *q_it = reinterpret_cast(q);
    const uint32_t *q_end = q_it + (dim >> 3);
    if (q_it != q_end) {
      uint32_t q_val = *q_it++;
      for (size_t i = 0; i < M; ++i) {
        *(out + i) = -FusedMultiplyAdd(m_it[i], q_val);
      }
      m_it += M;
    }
    while (q_it != q_end) {
      uint32_t q_val = *q_it++;
      for (size_t i = 0; i < M; ++i) {
        *(out + i) -= FusedMultiplyAdd(m_it[i], q_val);
      }
      m_it += M;
    }
  }

 protected:
  //! Calculate Fused-Multiply-Add
  // Eight-lane nibble dot product: index = (lhs nibble << 4) | rhs nibble.
  static inline float FusedMultiplyAdd(uint32_t lhs, uint32_t rhs) {
    return static_cast(
        Int4MulTable[((lhs << 4) & 0xf0) | ((rhs >> 0) & 0xf)] +
        Int4MulTable[((lhs >> 0) & 0xf0) | ((rhs >> 4) & 0xf)] +
        Int4MulTable[((lhs >> 4) & 0xf0) | ((rhs >> 8) & 0xf)] +
        Int4MulTable[((lhs >> 8) & 0xf0) | ((rhs >> 12) & 0xf)] +
        Int4MulTable[((lhs >> 12) & 0xf0) | ((rhs >> 16) & 0xf)] +
        Int4MulTable[((lhs >> 16) & 0xf0) | ((rhs >> 20) & 0xf)] +
        Int4MulTable[((lhs >> 20) & 0xf0) | ((rhs >> 24) & 0xf)] +
        Int4MulTable[((lhs >> 24) & 0xf0) | ((rhs >> 28) & 0xf)]);
  }
};

//--------------------------------------------------
// Sparse
//--------------------------------------------------

// (segment id, number of entries in that segment) pair collected by
// transform_sparse_format. seg_id_ == -1U is the "unset" value.
struct SparseSegmentInfo {
 public:
  uint32_t seg_id_{-1U};
  uint32_t vec_cnt_{0};

 public:
  SparseSegmentInfo() : seg_id_{-1U}, vec_cnt_{0} {}
  SparseSegmentInfo(uint32_t seg_id, uint32_t vec_cnt)
      : seg_id_{seg_id}, vec_cnt_{vec_cnt} {}
};

// A 32-bit sparse index is split into a segment id (index >> SEGMENT_ID_BITS)
// and a 16-bit offset within that segment (index & SEGMENT_ID_MASK).
constexpr static uint32_t SEGMENT_ID_BITS = 16;
constexpr static uint32_t SEGMENT_ID_MASK = 0xFFFF;

// Primary template: declarations only. No definitions are visible here.
template struct MinusInnerProductSparseMatrix {
  //! Type of value
  using ValueType = typename std::remove_cv::type;

  static inline float ComputeInnerProductSparseInSegment(
      uint32_t m_sparse_count, const uint16_t *m_sparse_index,
      const ValueType *m_sparse_value, uint32_t q_sparse_count,
      const uint16_t *q_sparse_index, const ValueType *q_sparse_value);

  //! Compute the distance between matrix and query
  static inline void Compute(const void *m_sparse_data_in,
                             const void *q_sparse_data_in, float *out);

  static inline void transform_sparse_format(uint32_t sparse_count,
                                             const uint32_t *sparse_index,
                                             const void *sparse_value,
                                             std::string &buffer);
};

// Float16 specialization (the template argument was lost in extraction;
// `using ValueType = Float16` identifies it).
template <> struct MinusInnerProductSparseMatrix {
  //! Type of value
  using ValueType = Float16;

  // Declared here; defined elsewhere (not visible).
  static float ComputeInnerProductSparseInSegment(
      uint32_t m_sparse_count, const uint16_t *m_sparse_index,
      const Float16 *m_sparse_value, uint32_t q_sparse_count,
      const uint16_t *q_sparse_index, const Float16 *q_sparse_value);

  //! Compute the distance between matrix and query
  static void Compute(const void *m_sparse_data_in,
                      const void *q_sparse_data_in, float *out);

  // Appends the segmented encoding of a sparse vector to `buffer` (existing
  // contents are kept). Fields are raw native-byte-order copies:
  //   uint32_t  sparse_count
  //   uint32_t  seg_count
  //   uint32_t  seg_id[seg_count]
  //   uint32_t  vec_cnt[seg_count]
  //   uint16_t  offset[sparse_count]   // sparse_index[i] & SEGMENT_ID_MASK
  //   ValueType value[sparse_count]    // copied from sparse_value
  // An empty input encodes as just the two zero counts.
  //
  // Entries are expected grouped by non-decreasing segment id
  // (sparse_index[i] >> SEGMENT_ID_BITS). NOTE(review): an entry whose
  // segment id is lower than the current one falls into the empty `else`
  // (abort commented out). It is not counted in any vec_cnt, yet its offset
  // and value are still appended, so the vec_cnt values then no longer sum
  // to sparse_count.
  static void transform_sparse_format(uint32_t sparse_count,
                                      const uint32_t *sparse_index,
                                      const void *sparse_value,
                                      std::string &buffer) {
    uint32_t unit_size = sizeof(ValueType);
    uint32_t seg_count = 0;
    if (sparse_count == 0) {
      buffer.reserve(sizeof(uint32_t) + sizeof(uint32_t));
      buffer.append(reinterpret_cast(&sparse_count), sizeof(uint32_t));
      buffer.append(reinterpret_cast(&seg_count), sizeof(uint32_t));
      return;
    }

    // Pass 1: run-length count the entries of each segment.
    std::vector seg_infos;
    uint32_t cur_seg_id = -1U;  // "no segment yet"; real ids are <= 0xFFFF
    uint32_t cur_vec_cnt = 0;
    for (size_t i = 0; i < sparse_count; ++i) {
      uint32_t seg_id = sparse_index[i] >> SEGMENT_ID_BITS;
      if (cur_seg_id == -1U) {
        cur_seg_id = seg_id;
        cur_vec_cnt++;
      } else {
        if (seg_id == cur_seg_id) {
          cur_vec_cnt++;
        } else if (seg_id > cur_seg_id) {
          seg_infos.emplace_back(cur_seg_id, cur_vec_cnt);
          cur_seg_id = seg_id;
          cur_vec_cnt = 1;
        } else {
          // std::abort();
        }
      }
    }
    if (cur_vec_cnt > 0) {
      seg_infos.emplace_back(cur_seg_id, cur_vec_cnt);
    }

    // Pass 2: serialize header, segment ids, counts, offsets, then values.
    // NOTE(review): reserve() is only a capacity hint that ignores bytes
    // already in `buffer`, and the size_t total is narrowed to uint32_t.
    uint32_t buffer_len = 2 * sizeof(uint32_t) +
                          seg_infos.size() * 2 * sizeof(uint32_t) +
                          sparse_count * (sizeof(uint16_t) + sizeof(ValueType));
    buffer.reserve(buffer_len);
    buffer.append(reinterpret_cast(&sparse_count), sizeof(uint32_t));
    seg_count = seg_infos.size();
    buffer.append(reinterpret_cast(&seg_count), sizeof(uint32_t));
    for (size_t i = 0; i < seg_count; ++i) {
      uint32_t seg_id = seg_infos[i].seg_id_;
      buffer.append(reinterpret_cast(&seg_id), sizeof(uint32_t));
    }
    for (size_t i = 0; i < seg_count; ++i) {
      uint32_t vec_cnt = seg_infos[i].vec_cnt_;
      buffer.append(reinterpret_cast(&vec_cnt), sizeof(uint32_t));
    }
    for (size_t i = 0; i < sparse_count; ++i) {
      uint16_t temp_dim = sparse_index[i] & SEGMENT_ID_MASK;
      buffer.append(reinterpret_cast(&temp_dim), sizeof(uint16_t));
    }
    const char *sparse_value_ptr = reinterpret_cast(sparse_value);
    for (size_t i = 0; i < sparse_count; ++i) {
      buffer.append(sparse_value_ptr, unit_size);
      sparse_value_ptr += unit_size;
    }
  }
};

// float specialization (the template argument was lost in extraction;
// `using ValueType = float` identifies it). Same encoding as the Float16
// specialization above, with float values.
template <> struct MinusInnerProductSparseMatrix {
  //! Type of value
  using ValueType = float;

  // Declared here; defined elsewhere (not visible).
  static float ComputeInnerProductSparseInSegment(
      uint32_t m_sparse_count, const uint16_t *m_sparse_index,
      const float *m_sparse_value, uint32_t q_sparse_count,
      const uint16_t *q_sparse_index, const float *q_sparse_value);

  //! Compute the distance between matrix and query
  static void Compute(const void *m_sparse_data_in,
                      const void *q_sparse_data_in, float *out);

  // Appends the segmented encoding of a sparse vector to `buffer`:
  // [sparse_count][seg_count][seg_id x seg_count][vec_cnt x seg_count]
  // [uint16 offset x sparse_count][float value x sparse_count].
  // NOTE(review): out-of-order segment ids are silently skipped in the
  // counts but still appended (see the empty `else`).
  static void transform_sparse_format(uint32_t sparse_count,
                                      const uint32_t *sparse_index,
                                      const void *sparse_value,
                                      std::string &buffer) {
    uint32_t unit_size = sizeof(ValueType);
    uint32_t seg_count = 0;
    if (sparse_count == 0) {
      buffer.reserve(sizeof(uint32_t) + sizeof(uint32_t));
      buffer.append(reinterpret_cast(&sparse_count), sizeof(uint32_t));
      buffer.append(reinterpret_cast(&seg_count), sizeof(uint32_t));
      return;
    }

    // Pass 1: run-length count the entries of each segment.
    std::vector seg_infos;
    uint32_t cur_seg_id = -1U;  // "no segment yet"; real ids are <= 0xFFFF
    uint32_t cur_vec_cnt = 0;
    for (size_t i = 0; i < sparse_count; ++i) {
      uint32_t seg_id = sparse_index[i] >> SEGMENT_ID_BITS;
      if (cur_seg_id == -1U) {
        cur_seg_id = seg_id;
        cur_vec_cnt++;
      } else {
        if (seg_id == cur_seg_id) {
          cur_vec_cnt++;
        } else if (seg_id > cur_seg_id) {
          seg_infos.emplace_back(cur_seg_id, cur_vec_cnt);
          cur_seg_id = seg_id;
          cur_vec_cnt = 1;
        } else {
          // std::abort();
        }
      }
    }
    if (cur_vec_cnt > 0) {
      seg_infos.emplace_back(cur_seg_id, cur_vec_cnt);
    }

    // Pass 2: serialize header, segment ids, counts, offsets, then values.
    uint32_t buffer_len = 2 * sizeof(uint32_t) +
                          seg_infos.size() * 2 * sizeof(uint32_t) +
                          sparse_count * (sizeof(uint16_t) + sizeof(ValueType));
    buffer.reserve(buffer_len);
    buffer.append(reinterpret_cast(&sparse_count), sizeof(uint32_t));
    seg_count = seg_infos.size();
    buffer.append(reinterpret_cast(&seg_count), sizeof(uint32_t));
    for (size_t i = 0; i < seg_count; ++i) {
      uint32_t seg_id = seg_infos[i].seg_id_;
      buffer.append(reinterpret_cast(&seg_id), sizeof(uint32_t));
    }
    for (size_t i = 0; i < seg_count; ++i) {
      uint32_t vec_cnt = seg_infos[i].vec_cnt_;
      buffer.append(reinterpret_cast(&vec_cnt), sizeof(uint32_t));
    }
    for (size_t i = 0; i < sparse_count; ++i) {
      uint16_t temp_dim = sparse_index[i] & SEGMENT_ID_MASK;
      buffer.append(reinterpret_cast(&temp_dim), sizeof(uint16_t));
    }
    const char *sparse_value_ptr = reinterpret_cast(sparse_value);
    for (size_t i = 0; i < sparse_count; ++i) {
      buffer.append(sparse_value_ptr,
                    unit_size);
      sparse_value_ptr += unit_size;
    }
  }
};

}  // namespace ailego
}  // namespace zvec

================================================
FILE: src/ailego/math/inner_product_matrix_fp16_avx.cc
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "distance_matrix_accum_fp16.i"
#include "distance_matrix_inner_product_utility.i"
#include "inner_product_matrix.h"

namespace zvec {
namespace ailego {

//--------------------------------------------------
// Dense
//--------------------------------------------------

#if defined(__AVX__)
// Inner product of two fp16 vectors of `size` elements, accumulated into a
// float by the ACCUM_FP16_1X1_AVX macro (defined in one of the included .i
// files - presumably distance_matrix_accum_fp16.i). Only compiled when
// __AVX__ is defined.
float InnerProductFp16AVX(const Float16 *lhs, const Float16 *rhs, size_t size) {
  float score{0.0f};
  ACCUM_FP16_1X1_AVX(lhs, rhs, size, &score, 0ull, )
  return score;
}

// Same accumulation as InnerProductFp16AVX, but passes NEGATE_FP32_GENERAL as
// the final macro argument - presumably negating the score; confirm in the
// macro definition.
float MinusInnerProductFp16AVX(const Float16 *lhs, const Float16 *rhs,
                               size_t size) {
  float score{0.0f};
  ACCUM_FP16_1X1_AVX(lhs, rhs, size, &score, 0ull, NEGATE_FP32_GENERAL)
  return score;
}
#endif

//--------------------------------------------------
// Sparse
//--------------------------------------------------

#if defined(__AVX__)
// 256 byte-shuffle masks, one per 8-bit lane-selection pattern. Judging by
// the entries, bit b of the index selects 16-bit lane b (bytes 2b+1:2b), and
// the selected lanes are packed toward the low end. -127 (0x81, high bit
// set) fills the unused bytes. NOTE(review): presumably consumed by
// _mm_shuffle_epi8, which zeroes every byte whose mask has the high bit
// set - confirm at the use site.
const static __m128i SHUFFLE_MASK256[256] = {
    _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127,
                 -127, -127, -127, -127, -127, -127),
    _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127,
                 -127, -127, -127, -127, 1, 0),
    _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127,
                 -127, -127, -127, -127,
3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 5, 4), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 5, 4, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 5, 4, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 5, 4, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 7, 6), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 7, 6, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 7, 6, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 7, 6, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 7, 6, 5, 4), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 7, 6, 5, 4, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 7, 6, 5, 4, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 7, 6, 5, 4, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 9, 8), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 9, 8, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 9, 8, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 9, 8, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 9, 8, 5, 4), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 9, 8, 5, 4, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 9, 8, 5, 4, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 9, 8, 
5, 4, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 9, 8, 7, 6), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 9, 8, 7, 6, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 9, 8, 7, 6, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 9, 8, 7, 6, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 9, 8, 7, 6, 5, 4), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 9, 8, 7, 6, 5, 4, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 9, 8, 7, 6, 5, 4, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 11, 10), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 5, 4), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 5, 4, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 5, 4, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 5, 4, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 7, 6), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 7, 6, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 7, 6, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 7, 6, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 7, 6, 5, 4), _mm_set_epi8(-127, 
-127, -127, -127, -127, -127, -127, -127, 11, 10, 7, 6, 5, 4, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 7, 6, 5, 4, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 11, 10, 7, 6, 5, 4, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 9, 8), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 9, 8, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 9, 8, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 9, 8, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 9, 8, 5, 4), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 9, 8, 5, 4, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 9, 8, 5, 4, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 11, 10, 9, 8, 5, 4, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 9, 8, 7, 6), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 9, 8, 7, 6, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 9, 8, 7, 6, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 11, 10, 9, 8, 7, 6, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 9, 8, 7, 6, 5, 4), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 11, 10, 9, 8, 7, 6, 5, 4, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2), _mm_set_epi8(-127, -127, -127, -127, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, 12), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, 
-127, -127, -127, -127, -127, 13, 12, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 5, 4), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 5, 4, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 5, 4, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 5, 4, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 7, 6), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 7, 6, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 7, 6, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 7, 6, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 7, 6, 5, 4), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 7, 6, 5, 4, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 7, 6, 5, 4, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 7, 6, 5, 4, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 9, 8), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 9, 8, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 9, 8, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 9, 8, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 9, 8, 5, 4), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 9, 8, 5, 4, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 9, 8, 5, 4, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 9, 8, 5, 4, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 9, 8, 7, 6), _mm_set_epi8(-127, 
-127, -127, -127, -127, -127, -127, -127, 13, 12, 9, 8, 7, 6, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 9, 8, 7, 6, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 9, 8, 7, 6, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 9, 8, 7, 6, 5, 4), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 9, 8, 7, 6, 5, 4, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 9, 8, 7, 6, 5, 4, 3, 2), _mm_set_epi8(-127, -127, -127, -127, 13, 12, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 11, 10), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 5, 4), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 5, 4, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 5, 4, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 5, 4, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 7, 6), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 7, 6, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 7, 6, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 7, 6, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 7, 6, 5, 4), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 7, 6, 5, 4, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 7, 6, 5, 4, 3, 2), _mm_set_epi8(-127, -127, -127, -127, 13, 12, 11, 10, 7, 6, 5, 4, 3, 2, 
1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 9, 8), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 9, 8, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 9, 8, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 9, 8, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 9, 8, 5, 4), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 9, 8, 5, 4, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 9, 8, 5, 4, 3, 2), _mm_set_epi8(-127, -127, -127, -127, 13, 12, 11, 10, 9, 8, 5, 4, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 9, 8, 7, 6), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 9, 8, 7, 6, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 9, 8, 7, 6, 3, 2), _mm_set_epi8(-127, -127, -127, -127, 13, 12, 11, 10, 9, 8, 7, 6, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4), _mm_set_epi8(-127, -127, -127, -127, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 1, 0), _mm_set_epi8(-127, -127, -127, -127, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2), _mm_set_epi8(-127, -127, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, 14), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 5, 4), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 5, 4, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, 
-127, -127, -127, 15, 14, 5, 4, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 5, 4, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 7, 6), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 7, 6, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 7, 6, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 7, 6, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 7, 6, 5, 4), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 7, 6, 5, 4, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 7, 6, 5, 4, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 7, 6, 5, 4, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 9, 8), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 9, 8, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 9, 8, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 9, 8, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 9, 8, 5, 4), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 9, 8, 5, 4, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 9, 8, 5, 4, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 9, 8, 5, 4, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 9, 8, 7, 6), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 9, 8, 7, 6, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 9, 8, 7, 6, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 9, 8, 7, 6, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 
14, 9, 8, 7, 6, 5, 4), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 9, 8, 7, 6, 5, 4, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 9, 8, 7, 6, 5, 4, 3, 2), _mm_set_epi8(-127, -127, -127, -127, 15, 14, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 11, 10), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 5, 4), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 5, 4, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 5, 4, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 5, 4, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 7, 6), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 7, 6, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 7, 6, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 7, 6, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 7, 6, 5, 4), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 7, 6, 5, 4, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 7, 6, 5, 4, 3, 2), _mm_set_epi8(-127, -127, -127, -127, 15, 14, 11, 10, 7, 6, 5, 4, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 9, 8), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 9, 8, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 9, 8, 3, 2), _mm_set_epi8(-127, 
-127, -127, -127, -127, -127, 15, 14, 11, 10, 9, 8, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 9, 8, 5, 4), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 9, 8, 5, 4, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 9, 8, 5, 4, 3, 2), _mm_set_epi8(-127, -127, -127, -127, 15, 14, 11, 10, 9, 8, 5, 4, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 9, 8, 7, 6), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 9, 8, 7, 6, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 9, 8, 7, 6, 3, 2), _mm_set_epi8(-127, -127, -127, -127, 15, 14, 11, 10, 9, 8, 7, 6, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 9, 8, 7, 6, 5, 4), _mm_set_epi8(-127, -127, -127, -127, 15, 14, 11, 10, 9, 8, 7, 6, 5, 4, 1, 0), _mm_set_epi8(-127, -127, -127, -127, 15, 14, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2), _mm_set_epi8(-127, -127, 15, 14, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 5, 4), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 5, 4, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 5, 4, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 5, 4, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 7, 6), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 7, 6, 1, 0), _mm_set_epi8(-127, 
-127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 7, 6, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 7, 6, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 7, 6, 5, 4), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 7, 6, 5, 4, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 7, 6, 5, 4, 3, 2), _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 7, 6, 5, 4, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 9, 8), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 9, 8, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 9, 8, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 9, 8, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 9, 8, 5, 4), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 9, 8, 5, 4, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 9, 8, 5, 4, 3, 2), _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 9, 8, 5, 4, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 9, 8, 7, 6), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 9, 8, 7, 6, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 9, 8, 7, 6, 3, 2), _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 9, 8, 7, 6, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 9, 8, 7, 6, 5, 4), _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 9, 8, 7, 6, 5, 4, 1, 0), _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 9, 8, 7, 6, 5, 4, 3, 2), _mm_set_epi8(-127, -127, 15, 14, 13, 12, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 11, 10), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 
14, 13, 12, 11, 10, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 5, 4), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 5, 4, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 5, 4, 3, 2), _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 5, 4, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 7, 6), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 7, 6, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 7, 6, 3, 2), _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 7, 6, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 7, 6, 5, 4), _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 7, 6, 5, 4, 1, 0), _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 7, 6, 5, 4, 3, 2), _mm_set_epi8(-127, -127, 15, 14, 13, 12, 11, 10, 7, 6, 5, 4, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 9, 8), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 3, 2), _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 5, 4), _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 5, 4, 1, 0), _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 5, 4, 3, 2), _mm_set_epi8(-127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 5, 4, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6), _mm_set_epi8(-127, -127, -127, -127, 
15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 1, 0), _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 3, 2), _mm_set_epi8(-127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4), _mm_set_epi8(-127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 1, 0), _mm_set_epi8(-127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2), _mm_set_epi8(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0), }; constexpr uint32_t MAX_SPARSE_BUFFER_LENGTH = 65536; float InnerProductSparseInSegmentFp16AVX(uint32_t m_sparse_count, const uint16_t *m_sparse_index, const Float16 *m_sparse_value, uint32_t q_sparse_count, const uint16_t *q_sparse_index, const Float16 *q_sparse_value) { float sum = 0.0f; // handle if the first dim is zero bool m_zero = false; Float16 m_zero_value{0.0f}; if (m_sparse_count > 0 && m_sparse_index[0] == 0) { m_sparse_count--; m_sparse_index++; m_zero_value = *m_sparse_value++; m_zero = true; } bool q_zero = false; Float16 q_zero_value{0.0f}; if (q_sparse_count > 0 && q_sparse_index[0] == 0) { q_sparse_count--; q_sparse_index++; q_zero_value = *q_sparse_value++; q_zero = true; } if (m_zero && q_zero) { sum = m_zero_value * q_zero_value; } size_t i1 = 0, i2 = 0; size_t end1 = m_sparse_count / 8 * 8; size_t end2 = q_sparse_count / 8 * 8; uint16_t fixed_buffer_1[MAX_SPARSE_BUFFER_LENGTH]; uint16_t fixed_buffer_2[MAX_SPARSE_BUFFER_LENGTH]; Float16 *val_start_1 = reinterpret_cast(fixed_buffer_1); Float16 *val_start_2 = reinterpret_cast(fixed_buffer_2); Float16 *val_1 = val_start_1; Float16 *val_2 = val_start_2; if (i1 < end1 && i2 < end2) { while (m_sparse_index[i1 + 7] < q_sparse_index[i2]) { i1 += 8; if (i1 >= end1) goto do_scalar; } while (q_sparse_index[i2 + 7] < m_sparse_index[i1]) { i2 += 8; if (i2 >= end2) goto do_scalar; } __m128i mm_index_m = _mm_loadu_si128(reinterpret_cast(&m_sparse_index[i1])); __m128i mm_index_q = 
_mm_loadu_si128(reinterpret_cast(&q_sparse_index[i2])); while (true) { #ifdef DEBUG_PRINT std::cout << "index 1: " << std::endl; print_data16(&mm_index_m); std::cout << "index 2: " << std::endl; print_data16(&mm_index_q); #endif __m128i mm_cmp_res = _mm_cmpistrm(mm_index_q, mm_index_m, _SIDD_UWORD_OPS | _SIDD_CMP_EQUAL_ANY | _SIDD_BIT_MASK); #ifdef DEBUG_PRINT std::cout << "cmp res: " << std::endl; print_data16(&mm_cmp_res); #endif int r = _mm_extract_epi32(mm_cmp_res, 0); if (r) { int r1 = r; __m128i v = _mm_loadu_si128( reinterpret_cast(&m_sparse_value[i1])); __m128i vs = _mm_shuffle_epi8(v, SHUFFLE_MASK256[r1]); _mm_storeu_si128(reinterpret_cast<__m128i *>(val_1), vs); val_1 += _mm_popcnt_u32(r1); mm_cmp_res = _mm_cmpistrm( mm_index_m, mm_index_q, _SIDD_UWORD_OPS | _SIDD_CMP_EQUAL_ANY | _SIDD_BIT_MASK); r = _mm_extract_epi32(mm_cmp_res, 0); r1 = r; v = _mm_loadu_si128( reinterpret_cast(&q_sparse_value[i2])); vs = _mm_shuffle_epi8(v, SHUFFLE_MASK256[r1]); _mm_storeu_si128(reinterpret_cast<__m128i *>(val_2), vs); val_2 += _mm_popcnt_u32(r1); } const uint16_t id1_max = m_sparse_index[i1 + 7]; if (id1_max <= q_sparse_index[i2 + 7]) { i1 += 8; if (i1 >= end1) goto do_scalar; mm_index_m = _mm_loadu_si128( reinterpret_cast(&m_sparse_index[i1])); } if (id1_max >= q_sparse_index[i2 + 7]) { i2 += 8; if (i2 >= end2) goto do_scalar; mm_index_q = _mm_loadu_si128( reinterpret_cast(&q_sparse_index[i2])); } } } do_scalar: while (i1 < m_sparse_count && i2 < q_sparse_count) { if (m_sparse_index[i1] == q_sparse_index[i2]) { *val_1++ = m_sparse_value[i1]; *val_2++ = q_sparse_value[i2]; ++i1; ++i2; } else if (m_sparse_index[i1] < q_sparse_index[i2]) { ++i1; } else { ++i2; } } size_t res_num = val_1 - val_start_1; size_t res_num8 = res_num / 8 * 8; if (res_num8) { __m256 sum256 = _mm256_setzero_ps(); for (size_t k = 0; k < res_num8; k += 8) { __m256 ymm_1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(val_start_1 + k))); __m256 ymm_2 = _mm256_cvtph_ps(_mm_loadu_si128((const 
__m128i *)(val_start_2 + k))); ACCUM_FP32_STEP_AVX(ymm_1, ymm_2, sum256); } sum += HorizontalAdd_FP32_V256(sum256); } for (size_t k = res_num8; k < res_num; ++k) sum += val_start_1[k] * val_start_2[k]; return sum; } #endif // __AVX__ } // namespace ailego } // namespace zvec ================================================ FILE: src/ailego/math/inner_product_matrix_fp16_avx512.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "distance_matrix_accum_fp16.i" #include "distance_matrix_inner_product_utility.i" #include "inner_product_matrix.h" namespace zvec { namespace ailego { #if defined(__AVX512F__) float InnerProductFp16AVX512(const Float16 *lhs, const Float16 *rhs, size_t size) { float score{0.0f}; ACCUM_FP16_1X1_AVX512(lhs, rhs, size, &score, 0ull, ) return score; } float MinusInnerProductFp16AVX512(const Float16 *lhs, const Float16 *rhs, size_t size) { float score{0.0f}; ACCUM_FP16_1X1_AVX512(lhs, rhs, size, &score, 0ull, NEGATE_FP32_GENERAL) return score; } #endif //__AVX512F__ } // namespace ailego } // namespace zvec ================================================ FILE: src/ailego/math/inner_product_matrix_fp16_avx512fp16.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "distance_matrix_accum_fp16.i" #include "distance_matrix_inner_product_utility.i" #include "inner_product_matrix.h" namespace zvec { namespace ailego { #if defined(__AVX512FP16__) //! Inner Product float InnerProductFp16AVX512FP16(const Float16 *lhs, const Float16 *rhs, size_t size) { const Float16 *last = lhs + size; const Float16 *last_aligned = lhs + ((size >> 6) << 6); __m512h zmm_sum_0 = _mm512_setzero_ph(); __m512h zmm_sum_1 = _mm512_setzero_ph(); if (((uintptr_t)lhs & 0x3f) == 0 && ((uintptr_t)rhs & 0x3f) == 0) { for (; lhs != last_aligned; lhs += 64, rhs += 64) { FMA_FP16_AVX512FP16(_mm512_load_ph(lhs + 0), _mm512_load_ph(rhs + 0), zmm_sum_0) FMA_FP16_AVX512FP16(_mm512_load_ph(lhs + 32), _mm512_load_ph(rhs + 32), zmm_sum_1) } if (last >= last_aligned + 32) { FMA_FP16_AVX512FP16(_mm512_load_ph(lhs), _mm512_load_ph(rhs), zmm_sum_0) lhs += 32; rhs += 32; } } else { for (; lhs != last_aligned; lhs += 64, rhs += 64) { FMA_FP16_AVX512FP16(_mm512_loadu_ph(lhs + 0), _mm512_loadu_ph(rhs + 0), zmm_sum_0) FMA_FP16_AVX512FP16(_mm512_loadu_ph(lhs + 32), _mm512_loadu_ph(rhs + 32), zmm_sum_1) } if (last >= last_aligned + 32) { FMA_FP16_AVX512FP16(_mm512_loadu_ph(lhs), _mm512_loadu_ph(rhs), zmm_sum_0) lhs += 32; rhs += 32; } } zmm_sum_0 = _mm512_add_ph(zmm_sum_0, zmm_sum_1); if (lhs != last) { __mmask32 mask = (__mmask32)((1 << (last - lhs)) - 1); __m512i zmm_undefined = _mm512_undefined_epi32(); zmm_sum_0 = _mm512_mask3_fmadd_ph( _mm512_castsi512_ph(_mm512_mask_loadu_epi16(zmm_undefined, mask, lhs)), 
_mm512_castsi512_ph(_mm512_mask_loadu_epi16(zmm_undefined, mask, rhs)), zmm_sum_0, mask); } return HorizontalAdd_FP16_V512(zmm_sum_0); } float MinusInnerProductFp16AVX512FP16(const Float16 *lhs, const Float16 *rhs, size_t size) { return -1 * InnerProductFp16AVX512FP16(lhs, rhs, size); } #endif // sparse #if defined(__AVX512FP16__) constexpr uint32_t MAX_SPARSE_BUFFER_LENGTH = 65536; float InnerProductSparseInSegmentFp16AVX512FP16(uint32_t m_sparse_count, const uint16_t *m_sparse_index, const Float16 *m_sparse_value, uint32_t q_sparse_count, const uint16_t *q_sparse_index, const Float16 *q_sparse_value) { const static __m128i SHUFFLE_MASK256[256] = { _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 5, 4), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 5, 4, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 5, 4, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 5, 4, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 7, 6), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 7, 6, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 7, 6, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 7, 6, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 7, 6, 5, 4), _mm_set_epi8(-127, -127, -127, -127, -127, 
-127, -127, -127, -127, -127, 7, 6, 5, 4, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 7, 6, 5, 4, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 7, 6, 5, 4, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 9, 8), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 9, 8, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 9, 8, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 9, 8, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 9, 8, 5, 4), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 9, 8, 5, 4, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 9, 8, 5, 4, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 9, 8, 5, 4, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 9, 8, 7, 6), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 9, 8, 7, 6, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 9, 8, 7, 6, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 9, 8, 7, 6, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 9, 8, 7, 6, 5, 4), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 9, 8, 7, 6, 5, 4, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 9, 8, 7, 6, 5, 4, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 11, 10), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 
3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 5, 4), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 5, 4, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 5, 4, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 5, 4, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 7, 6), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 7, 6, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 7, 6, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 7, 6, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 7, 6, 5, 4), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 7, 6, 5, 4, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 7, 6, 5, 4, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 11, 10, 7, 6, 5, 4, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 9, 8), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 9, 8, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 9, 8, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 9, 8, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 9, 8, 5, 4), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 9, 8, 5, 4, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 9, 8, 5, 4, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 11, 10, 9, 8, 5, 4, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, 
-127, 11, 10, 9, 8, 7, 6), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 9, 8, 7, 6, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 9, 8, 7, 6, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 11, 10, 9, 8, 7, 6, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 9, 8, 7, 6, 5, 4), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 11, 10, 9, 8, 7, 6, 5, 4, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2), _mm_set_epi8(-127, -127, -127, -127, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, 12), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 5, 4), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 5, 4, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 5, 4, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 5, 4, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 7, 6), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 7, 6, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 7, 6, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 7, 6, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 7, 6, 5, 4), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 7, 6, 5, 4, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 
-127, -127, 13, 12, 7, 6, 5, 4, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 7, 6, 5, 4, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 9, 8), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 9, 8, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 9, 8, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 9, 8, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 9, 8, 5, 4), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 9, 8, 5, 4, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 9, 8, 5, 4, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 9, 8, 5, 4, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 9, 8, 7, 6), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 9, 8, 7, 6, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 9, 8, 7, 6, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 9, 8, 7, 6, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 9, 8, 7, 6, 5, 4), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 9, 8, 7, 6, 5, 4, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 9, 8, 7, 6, 5, 4, 3, 2), _mm_set_epi8(-127, -127, -127, -127, 13, 12, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 11, 10), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 11, 
10, 5, 4), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 5, 4, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 5, 4, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 5, 4, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 7, 6), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 7, 6, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 7, 6, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 7, 6, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 7, 6, 5, 4), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 7, 6, 5, 4, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 7, 6, 5, 4, 3, 2), _mm_set_epi8(-127, -127, -127, -127, 13, 12, 11, 10, 7, 6, 5, 4, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 9, 8), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 9, 8, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 9, 8, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 9, 8, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 9, 8, 5, 4), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 9, 8, 5, 4, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 9, 8, 5, 4, 3, 2), _mm_set_epi8(-127, -127, -127, -127, 13, 12, 11, 10, 9, 8, 5, 4, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 9, 8, 7, 6), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 9, 8, 7, 6, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 9, 8, 7, 6, 3, 2), _mm_set_epi8(-127, -127, -127, -127, 13, 12, 11, 10, 9, 8, 7, 6, 3, 2, 1, 
0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4), _mm_set_epi8(-127, -127, -127, -127, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 1, 0), _mm_set_epi8(-127, -127, -127, -127, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2), _mm_set_epi8(-127, -127, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, 14), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 5, 4), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 5, 4, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 5, 4, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 5, 4, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 7, 6), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 7, 6, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 7, 6, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 7, 6, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 7, 6, 5, 4), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 7, 6, 5, 4, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 7, 6, 5, 4, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 7, 6, 5, 4, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 9, 8), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, 
14, 9, 8, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 9, 8, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 9, 8, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 9, 8, 5, 4), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 9, 8, 5, 4, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 9, 8, 5, 4, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 9, 8, 5, 4, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 9, 8, 7, 6), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 9, 8, 7, 6, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 9, 8, 7, 6, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 9, 8, 7, 6, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 9, 8, 7, 6, 5, 4), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 9, 8, 7, 6, 5, 4, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 9, 8, 7, 6, 5, 4, 3, 2), _mm_set_epi8(-127, -127, -127, -127, 15, 14, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 11, 10), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 5, 4), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 5, 4, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 5, 4, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 5, 4, 3, 2, 1, 0), _mm_set_epi8(-127, -127, 
-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 7, 6), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 7, 6, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 7, 6, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 7, 6, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 7, 6, 5, 4), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 7, 6, 5, 4, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 7, 6, 5, 4, 3, 2), _mm_set_epi8(-127, -127, -127, -127, 15, 14, 11, 10, 7, 6, 5, 4, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 9, 8), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 9, 8, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 9, 8, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 9, 8, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 9, 8, 5, 4), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 9, 8, 5, 4, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 9, 8, 5, 4, 3, 2), _mm_set_epi8(-127, -127, -127, -127, 15, 14, 11, 10, 9, 8, 5, 4, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 9, 8, 7, 6), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 9, 8, 7, 6, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 9, 8, 7, 6, 3, 2), _mm_set_epi8(-127, -127, -127, -127, 15, 14, 11, 10, 9, 8, 7, 6, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 11, 10, 9, 8, 7, 6, 5, 4), _mm_set_epi8(-127, -127, -127, -127, 15, 14, 11, 10, 9, 8, 7, 6, 5, 4, 1, 0), _mm_set_epi8(-127, -127, -127, -127, 15, 14, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2), _mm_set_epi8(-127, -127, 15, 14, 11, 10, 9, 8, 7, 6, 5, 
4, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 5, 4), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 5, 4, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 5, 4, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 5, 4, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 7, 6), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 7, 6, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 7, 6, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 7, 6, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 7, 6, 5, 4), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 7, 6, 5, 4, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 7, 6, 5, 4, 3, 2), _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 7, 6, 5, 4, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 9, 8), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 9, 8, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 9, 8, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 9, 8, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 9, 8, 5, 4), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 9, 8, 5, 4, 1, 0), 
_mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 9, 8, 5, 4, 3, 2), _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 9, 8, 5, 4, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 9, 8, 7, 6), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 9, 8, 7, 6, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 9, 8, 7, 6, 3, 2), _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 9, 8, 7, 6, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 9, 8, 7, 6, 5, 4), _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 9, 8, 7, 6, 5, 4, 1, 0), _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 9, 8, 7, 6, 5, 4, 3, 2), _mm_set_epi8(-127, -127, 15, 14, 13, 12, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 11, 10), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 3, 2), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 5, 4), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 5, 4, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 5, 4, 3, 2), _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 5, 4, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 7, 6), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 7, 6, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 7, 6, 3, 2), _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 7, 6, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 7, 6, 5, 4), _mm_set_epi8(-127, -127, -127, -127, 15, 14, 
13, 12, 11, 10, 7, 6, 5, 4, 1, 0), _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 7, 6, 5, 4, 3, 2), _mm_set_epi8(-127, -127, 15, 14, 13, 12, 11, 10, 7, 6, 5, 4, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 9, 8), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 3, 2), _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 5, 4), _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 5, 4, 1, 0), _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 5, 4, 3, 2), _mm_set_epi8(-127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 5, 4, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6), _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 1, 0), _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 3, 2), _mm_set_epi8(-127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4), _mm_set_epi8(-127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 1, 0), _mm_set_epi8(-127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2), _mm_set_epi8(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0), }; float sum = 0.0f; // handle if the first dim is zero bool m_zero = false; Float16 m_zero_value{0.0f}; if (m_sparse_count > 0 && m_sparse_index[0] == 0) { m_sparse_count--; m_sparse_index++; m_zero_value = *m_sparse_value++; m_zero = true; } bool q_zero = false; Float16 q_zero_value{0.0f}; if (q_sparse_count > 0 && q_sparse_index[0] == 0) { q_sparse_count--; q_sparse_index++; q_zero_value = *q_sparse_value++; q_zero = true; } if (m_zero && q_zero) { sum = m_zero_value * q_zero_value; } size_t i1 = 0, i2 = 0; size_t end1 = 
m_sparse_count / 8 * 8; size_t end2 = q_sparse_count / 8 * 8; uint16_t fixed_buffer_1[MAX_SPARSE_BUFFER_LENGTH]; uint16_t fixed_buffer_2[MAX_SPARSE_BUFFER_LENGTH]; Float16 *val_start_1 = reinterpret_cast(fixed_buffer_1); Float16 *val_start_2 = reinterpret_cast(fixed_buffer_2); Float16 *val_1 = val_start_1; Float16 *val_2 = val_start_2; if (i1 < end1 && i2 < end2) { while (m_sparse_index[i1 + 7] < q_sparse_index[i2]) { i1 += 8; if (i1 >= end1) goto do_scalar; } while (q_sparse_index[i2 + 7] < m_sparse_index[i1]) { i2 += 8; if (i2 >= end2) goto do_scalar; } __m128i mm_index_m = _mm_loadu_si128(reinterpret_cast(&m_sparse_index[i1])); __m128i mm_index_q = _mm_loadu_si128(reinterpret_cast(&q_sparse_index[i2])); while (true) { #ifdef DEBUG_PRINT std::cout << "index 1: " << std::endl; print_data16(&mm_index_m); std::cout << "index 2: " << std::endl; print_data16(&mm_index_q); #endif __m128i mm_cmp_res = _mm_cmpistrm(mm_index_q, mm_index_m, _SIDD_UWORD_OPS | _SIDD_CMP_EQUAL_ANY | _SIDD_BIT_MASK); #ifdef DEBUG_PRINT std::cout << "cmp res: " << std::endl; print_data16(&mm_cmp_res); #endif int r = _mm_extract_epi32(mm_cmp_res, 0); if (r) { int r1 = r; __m128i v = _mm_loadu_si128( reinterpret_cast(&m_sparse_value[i1])); __m128h vs = _mm_castsi128_ph(_mm_shuffle_epi8(v, SHUFFLE_MASK256[r1])); _mm_storeu_ph(val_1, vs); val_1 += _mm_popcnt_u32(r1); mm_cmp_res = _mm_cmpistrm( mm_index_m, mm_index_q, _SIDD_UWORD_OPS | _SIDD_CMP_EQUAL_ANY | _SIDD_BIT_MASK); r = _mm_extract_epi32(mm_cmp_res, 0); r1 = r; v = _mm_loadu_si128( reinterpret_cast(&q_sparse_value[i2])); vs = _mm_castsi128_ph(_mm_shuffle_epi8(v, SHUFFLE_MASK256[r1])); _mm_storeu_ph(val_2, vs); val_2 += _mm_popcnt_u32(r1); } const uint16_t id1_max = m_sparse_index[i1 + 7]; if (id1_max <= q_sparse_index[i2 + 7]) { i1 += 8; if (i1 >= end1) goto do_scalar; mm_index_m = _mm_loadu_si128( reinterpret_cast(&m_sparse_index[i1])); } if (id1_max >= q_sparse_index[i2 + 7]) { i2 += 8; if (i2 >= end2) goto do_scalar; mm_index_q = 
_mm_loadu_si128( reinterpret_cast(&q_sparse_index[i2])); } } } do_scalar: while (i1 < m_sparse_count && i2 < q_sparse_count) { if (m_sparse_index[i1] == q_sparse_index[i2]) { *val_1++ = m_sparse_value[i1]; *val_2++ = q_sparse_value[i2]; ++i1; ++i2; } else if (m_sparse_index[i1] < q_sparse_index[i2]) { ++i1; } else { ++i2; } } size_t res_num = val_1 - val_start_1; size_t res_num8 = res_num / 8 * 8; if (res_num8) { __m128h sum128 = _mm_set1_ph(0); for (size_t k = 0; k < res_num8; k += 8) { sum128 = _mm_add_ph(sum128, _mm_mul_ph(_mm_loadu_ph(val_start_1 + k), _mm_loadu_ph(val_start_2 + k))); } Float16 __attribute__((aligned(16))) tmp_res[8]; _mm_store_ph(tmp_res, sum128); sum += (tmp_res[0] + tmp_res[1] + tmp_res[2] + tmp_res[3] + tmp_res[4] + tmp_res[5] + tmp_res[6] + tmp_res[7]); } for (size_t k = res_num8; k < res_num; ++k) sum += val_start_1[k] * val_start_2[k]; return sum; } #endif // __AVX512FP16__ } // namespace ailego } // namespace zvec ================================================ FILE: src/ailego/math/inner_product_matrix_fp16_dispatch.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
// NOTE(review): this chunk comes from a flattened dump; the system header
// name after the first #include (the text inside <...>) appears to have been
// stripped by extraction. Presumably it is the header declaring
// internal::CpuFeatures used below -- confirm against the real file.
#include 
#include "inner_product_matrix.h"

namespace zvec {
namespace ailego {

//--------------------------------------------------
// Dense
//--------------------------------------------------
// Forward declarations of the per-ISA FP16 dense kernels. Each pair is only
// declared when the matching ISA macro is defined at compile time; the
// definitions live in other translation units (e.g. the NEON pair in
// inner_product_matrix_fp16_neon.cc).
#if defined(__ARM_NEON)
float InnerProductFp16NEON(const Float16 *lhs, const Float16 *rhs,
                           size_t size);
float MinusInnerProductFp16NEON(const Float16 *lhs, const Float16 *rhs,
                                size_t size);
#endif

#if defined(__AVX__)
float InnerProductFp16AVX(const Float16 *lhs, const Float16 *rhs,
                          size_t size);
float MinusInnerProductFp16AVX(const Float16 *lhs, const Float16 *rhs,
                               size_t size);
#endif

#if defined(__AVX512F__)
float InnerProductFp16AVX512(const Float16 *lhs, const Float16 *rhs,
                             size_t size);
float MinusInnerProductFp16AVX512(const Float16 *lhs, const Float16 *rhs,
                                  size_t size);
#endif

#if defined(__AVX512FP16__)
float InnerProductFp16AVX512FP16(const Float16 *lhs, const Float16 *rhs,
                                 size_t size);
float MinusInnerProductFp16AVX512FP16(const Float16 *lhs, const Float16 *rhs,
                                      size_t size);
#endif

// Portable fallbacks; declared unconditionally.
float InnerProductFp16Scalar(const Float16 *lhs, const Float16 *rhs,
                             size_t size);
float MinusInnerProductFp16Scalar(const Float16 *lhs, const Float16 *rhs,
                                  size_t size);

//! Compute the distance between matrix and query (FP16, M=1, N=1)
//!
//! Writes the result of an InnerProductFp16* kernel over `dim` elements of
//! `m` and `q` to `*out`. On NEON builds the NEON kernel is used
//! unconditionally. Otherwise the first kernel that is both compiled in and
//! reported by CpuFeatures::static_flags_ at runtime is used, in the order
//! AVX512-FP16 -> AVX512F -> AVX, falling back to the scalar kernel.
//! NOTE(review): the specialization's template argument list looks stripped
//! by extraction (presumably FP16 with M=1, N=1 per the line above) -- confirm.
void InnerProductMatrix::Compute(const ValueType *m, const ValueType *q,
                                 size_t dim, float *out) {
#if defined(__ARM_NEON)
  *out = InnerProductFp16NEON(m, q, dim);
#else
#if defined(__AVX512FP16__)
  if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512_FP16) {
    *out = InnerProductFp16AVX512FP16(m, q, dim);
    return;
  }
#endif  //__AVX512FP16__
#if defined(__AVX512F__)
  if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) {
    *out = InnerProductFp16AVX512(m, q, dim);
    return;
  }
#endif  //__AVX512F__
#if defined(__AVX__)
  if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX) {
    *out = InnerProductFp16AVX(m, q, dim);
    return;
  }
#endif  //__AVX__
  *out = InnerProductFp16Scalar(m, q, dim);
#endif  //__ARM_NEON
}
//! Compute the distance between matrix and query (FP16, M=1, N=1)
//!
//! Same dispatch as the FP16 InnerProductMatrix::Compute, but calls the
//! MinusInnerProductFp16* kernels: NEON unconditionally on ARM builds;
//! otherwise AVX512-FP16 -> AVX512F -> AVX (each gated by both the compile-time
//! macro and the runtime CpuFeatures flag), then the scalar kernel.
void MinusInnerProductMatrix::Compute(const ValueType *m,
                                      const ValueType *q, size_t dim,
                                      float *out) {
#if defined(__ARM_NEON)
  *out = MinusInnerProductFp16NEON(m, q, dim);
#else
#if defined(__AVX512FP16__)
  if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512_FP16) {
    *out = MinusInnerProductFp16AVX512FP16(m, q, dim);
    return;
  }
#endif  //__AVX512FP16__
#if defined(__AVX512F__)
  if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) {
    *out = MinusInnerProductFp16AVX512(m, q, dim);
    return;
  }
#endif  //__AVX512F__
#if defined(__AVX__)
  if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX) {
    *out = MinusInnerProductFp16AVX(m, q, dim);
    return;
  }
#endif  //__AVX__
  *out = MinusInnerProductFp16Scalar(m, q, dim);
#endif  //__ARM_NEON
}

//--------------------------------------------------
// Sparse
//--------------------------------------------------
// Forward declarations of the FP16 sparse kernels. The "InSegment" variants
// take two lists, each given as an element count, a uint16_t index array and
// a Float16 value array.
#if defined(__AVX512FP16__)
float InnerProductSparseInSegmentFp16AVX512FP16(uint32_t m_sparse_count,
                                                const uint16_t *m_sparse_index,
                                                const Float16 *m_sparse_value,
                                                uint32_t q_sparse_count,
                                                const uint16_t *q_sparse_index,
                                                const Float16 *q_sparse_value);
#endif  //__AVX512FP16__

#if defined(__AVX__)
float InnerProductSparseInSegmentFp16AVX(uint32_t m_sparse_count,
                                         const uint16_t *m_sparse_index,
                                         const Float16 *m_sparse_value,
                                         uint32_t q_sparse_count,
                                         const uint16_t *q_sparse_index,
                                         const Float16 *q_sparse_value);
#endif  //__AVX__

float InnerProductSparseInSegmentFp16Scalar(uint32_t m_sparse_count,
                                            const uint16_t *m_sparse_index,
                                            const Float16 *m_sparse_value,
                                            uint32_t q_sparse_count,
                                            const uint16_t *q_sparse_index,
                                            const Float16 *q_sparse_value);

// Operates on opaque serialized sparse records; layout is defined elsewhere.
float MinusInnerProductSparseFp16Scalar(const void *m_sparse_data_in,
                                        const void *q_sparse_data_in);
//! Compute the distance between matrix and query
//!
//! Always delegates to the scalar kernel MinusInnerProductSparseFp16Scalar;
//! no SIMD dispatch is done for this entry point.
void MinusInnerProductSparseMatrix::Compute(
    const void *m_sparse_data_in, const void *q_sparse_data_in, float *out) {
  *out = MinusInnerProductSparseFp16Scalar(m_sparse_data_in, q_sparse_data_in);
}

//! Sparse FP16 inner product over one segment.
//!
//! Each side is an element count plus parallel uint16_t index / Float16 value
//! arrays. Uses the AVX512-FP16 kernel, then the AVX kernel, when the kernel
//! is compiled in and CpuFeatures::static_flags_ reports the ISA at runtime;
//! otherwise the scalar kernel. All arguments are forwarded unchanged.
float ComputeInnerProductSparseInSegmentFp16(uint32_t m_sparse_count,
                                             const uint16_t *m_sparse_index,
                                             const Float16 *m_sparse_value,
                                             uint32_t q_sparse_count,
                                             const uint16_t *q_sparse_index,
                                             const Float16 *q_sparse_value) {
#if defined(__AVX512FP16__)
  if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512_FP16) {
    return InnerProductSparseInSegmentFp16AVX512FP16(
        m_sparse_count, m_sparse_index, m_sparse_value, q_sparse_count,
        q_sparse_index, q_sparse_value);
  }
#endif
#if defined(__AVX__)
  if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX) {
    return InnerProductSparseInSegmentFp16AVX(m_sparse_count, m_sparse_index,
                                              m_sparse_value, q_sparse_count,
                                              q_sparse_index, q_sparse_value);
  }
#endif
  return InnerProductSparseInSegmentFp16Scalar(m_sparse_count, m_sparse_index,
                                               m_sparse_value, q_sparse_count,
                                               q_sparse_index, q_sparse_value);
}

}  // namespace ailego
}  // namespace zvec

================================================
FILE: src/ailego/math/inner_product_matrix_fp16_neon.cc
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "distance_matrix_accum_fp16.i"
#include "distance_matrix_inner_product_utility.i"
#include "inner_product_matrix.h"

namespace zvec {
namespace ailego {

#if defined(__ARM_NEON)
//! FP16 inner product on ARM NEON.
//!
//! The work is done by the ACCUM_FP16_1X1_NEON macro from
//! distance_matrix_accum_fp16.i (not visible here). Presumably it writes the
//! accumulated product of `lhs` and `rhs` (`size` elements) to `score`, and
//! its last argument is an optional post-processing hook (left empty here).
//! Confirm against the macro definition.
float InnerProductFp16NEON(const Float16 *lhs, const Float16 *rhs,
                           size_t size) {
  float score;
  ACCUM_FP16_1X1_NEON(lhs, rhs, size, &score, 0ull, )
  return score;
}

//! Same as InnerProductFp16NEON, but passes NEGATE_FP32_GENERAL as the final
//! macro argument -- presumably negating the accumulated result (confirm).
float MinusInnerProductFp16NEON(const Float16 *lhs, const Float16 *rhs,
                                size_t size) {
  float score;
  ACCUM_FP16_1X1_NEON(lhs, rhs, size, &score, 0ull, NEGATE_FP32_GENERAL)
  return score;
}
#endif

}  // namespace ailego
}  // namespace zvec

================================================
FILE: src/ailego/math/inner_product_matrix_fp32_avx.cc
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "distance_matrix_accum_fp32.i"
#include "distance_matrix_inner_product_utility.i"
#include "inner_product_matrix.h"

namespace zvec {
namespace ailego {

//--------------------------------------------------
// Dense
//--------------------------------------------------
#if defined(__AVX__)
// SSE kernel defined in inner_product_matrix_fp32_sse.cc; the AVX entry point
// in this file falls back to it for inputs shorter than 8 elements.
float InnerProductFp32SSEInternal(const float *lhs, const float *rhs,
                                  size_t size);
Inner Product float InnerProductFp32AVXInternal(const float *lhs, const float *rhs, size_t size) { const float *last = lhs + size; const float *last_aligned = lhs + ((size >> 4) << 4); __m256 ymm_sum_0 = _mm256_setzero_ps(); __m256 ymm_sum_1 = _mm256_setzero_ps(); if (((uintptr_t)lhs & 0x1f) == 0 && ((uintptr_t)rhs & 0x1f) == 0) { for (; lhs != last_aligned; lhs += 16, rhs += 16) { __m256 ymm_lhs_0 = _mm256_load_ps(lhs + 0); __m256 ymm_lhs_1 = _mm256_load_ps(lhs + 8); __m256 ymm_rhs_0 = _mm256_load_ps(rhs + 0); __m256 ymm_rhs_1 = _mm256_load_ps(rhs + 8); ymm_sum_0 = _mm256_fmadd_ps(ymm_lhs_0, ymm_rhs_0, ymm_sum_0); ymm_sum_1 = _mm256_fmadd_ps(ymm_lhs_1, ymm_rhs_1, ymm_sum_1); } if (last >= last_aligned + 8) { ymm_sum_0 = _mm256_fmadd_ps(_mm256_load_ps(lhs), _mm256_load_ps(rhs), ymm_sum_0); lhs += 8; rhs += 8; } } else { for (; lhs != last_aligned; lhs += 16, rhs += 16) { __m256 ymm_lhs_0 = _mm256_loadu_ps(lhs + 0); __m256 ymm_lhs_1 = _mm256_loadu_ps(lhs + 8); __m256 ymm_rhs_0 = _mm256_loadu_ps(rhs + 0); __m256 ymm_rhs_1 = _mm256_loadu_ps(rhs + 8); ymm_sum_0 = _mm256_fmadd_ps(ymm_lhs_0, ymm_rhs_0, ymm_sum_0); ymm_sum_1 = _mm256_fmadd_ps(ymm_lhs_1, ymm_rhs_1, ymm_sum_1); } if (last >= last_aligned + 8) { ymm_sum_0 = _mm256_fmadd_ps(_mm256_loadu_ps(lhs), _mm256_loadu_ps(rhs), ymm_sum_0); lhs += 8; rhs += 8; } } float result = HorizontalAdd_FP32_V256(_mm256_add_ps(ymm_sum_0, ymm_sum_1)); switch (last - lhs) { case 7: FMA_FP32_GENERAL(lhs[6], rhs[6], result) /* FALLTHRU */ case 6: FMA_FP32_GENERAL(lhs[5], rhs[5], result) /* FALLTHRU */ case 5: FMA_FP32_GENERAL(lhs[4], rhs[4], result) /* FALLTHRU */ case 4: FMA_FP32_GENERAL(lhs[3], rhs[3], result) /* FALLTHRU */ case 3: FMA_FP32_GENERAL(lhs[2], rhs[2], result) /* FALLTHRU */ case 2: FMA_FP32_GENERAL(lhs[1], rhs[1], result) /* FALLTHRU */ case 1: FMA_FP32_GENERAL(lhs[0], rhs[0], result) } return result; } float InnerProductFp32AVX(const float *lhs, const float *rhs, size_t size) { if (size > 7) { return 
InnerProductFp32AVXInternal(lhs, rhs, size); } return InnerProductFp32SSEInternal(lhs, rhs, size); } float MinusInnerProductFp32AVX(const float *lhs, const float *rhs, size_t size) { return -1 * InnerProductFp32AVX(lhs, rhs, size); } #endif // __AVX__ } // namespace ailego } // namespace zvec ================================================ FILE: src/ailego/math/inner_product_matrix_fp32_avx512.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "distance_matrix_accum_fp32.i" #include "distance_matrix_inner_product_utility.i" #include "inner_product_matrix.h" namespace zvec { namespace ailego { //-------------------------------------------------- // Dense //-------------------------------------------------- #if defined(__AVX512F__) float InnerProductFp32AVXInternal(const float *lhs, const float *rhs, size_t size); float InnerProductFp32SSEInternal(const float *lhs, const float *rhs, size_t size); //! 
Inner Product float InnerProductFp32AVX512Internal(const float *lhs, const float *rhs, size_t size) { const float *last = lhs + size; const float *last_aligned = lhs + ((size >> 5) << 5); __m512 zmm_sum_0 = _mm512_setzero_ps(); __m512 zmm_sum_1 = _mm512_setzero_ps(); if (((uintptr_t)lhs & 0x3f) == 0 && ((uintptr_t)rhs & 0x3f) == 0) { for (; lhs != last_aligned; lhs += 32, rhs += 32) { FMA_FP32_AVX512(_mm512_load_ps(lhs + 0), _mm512_load_ps(rhs + 0), zmm_sum_0) FMA_FP32_AVX512(_mm512_load_ps(lhs + 16), _mm512_load_ps(rhs + 16), zmm_sum_1) } if (last >= last_aligned + 16) { FMA_FP32_AVX512(_mm512_load_ps(lhs), _mm512_load_ps(rhs), zmm_sum_0) lhs += 16; rhs += 16; } } else { for (; lhs != last_aligned; lhs += 32, rhs += 32) { FMA_FP32_AVX512(_mm512_loadu_ps(lhs + 0), _mm512_loadu_ps(rhs + 0), zmm_sum_0) FMA_FP32_AVX512(_mm512_loadu_ps(lhs + 16), _mm512_loadu_ps(rhs + 16), zmm_sum_1) } if (last >= last_aligned + 16) { FMA_FP32_AVX512(_mm512_loadu_ps(lhs), _mm512_loadu_ps(rhs), zmm_sum_0) lhs += 16; rhs += 16; } } zmm_sum_0 = _mm512_add_ps(zmm_sum_0, zmm_sum_1); if (lhs != last) { __mmask16 mask = (__mmask16)((1 << (last - lhs)) - 1); __m512 zmm_undefined = _mm512_undefined_ps(); zmm_sum_0 = _mm512_mask3_fmadd_ps( _mm512_mask_loadu_ps(zmm_undefined, mask, lhs), _mm512_mask_loadu_ps(zmm_undefined, mask, rhs), zmm_sum_0, mask); } return HorizontalAdd_FP32_V512(zmm_sum_0); } float InnerProductFp32AVX512(const float *lhs, const float *rhs, size_t size) { if (size > 15) { return InnerProductFp32AVX512Internal(lhs, rhs, size); } if (size > 7) { return InnerProductFp32AVXInternal(lhs, rhs, size); } return InnerProductFp32SSEInternal(lhs, rhs, size); } float MinusInnerProductFp32AVX512(const float *lhs, const float *rhs, size_t size) { return -1 * InnerProductFp32AVX512(lhs, rhs, size); } #endif } // namespace ailego } // namespace zvec ================================================ FILE: src/ailego/math/inner_product_matrix_fp32_dispatch.cc 
================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include "inner_product_matrix.h" namespace zvec { namespace ailego { //-------------------------------------------------- // Dense //-------------------------------------------------- #if defined(__ARM_NEON) float InnerProductFp32NEON(const float *lhs, const float *rhs, size_t size); float MinusInnerProductFp32NEON(const float *lhs, const float *rhs, size_t size); #endif #if defined(__AVX512F__) float InnerProductFp32AVX512(const float *lhs, const float *rhs, size_t size); float MinusInnerProductFp32AVX512(const float *lhs, const float *rhs, size_t size); #endif #if defined(__AVX__) float InnerProductFp32AVX(const float *lhs, const float *rhs, size_t size); float MinusInnerProductFp32AVX(const float *lhs, const float *rhs, size_t size); #endif #if defined(__SSE__) float InnerProductFp32SSE(const float *lhs, const float *rhs, size_t size); float MinusInnerProductFp32SSE(const float *lhs, const float *rhs, size_t size); #endif float InnerProductFp32Scalar(const float *lhs, const float *rhs, size_t size); float MinusInnerProductFp32Scalar(const float *lhs, const float *rhs, size_t size); //! 
//! Compute the distance between matrix and query (FP32, M=1, N=1)
//!
//! Writes the result of an InnerProductFp32* kernel over `dim` elements of
//! `m` and `q` to `*out`. On NEON builds the NEON kernel is used
//! unconditionally. Otherwise the first kernel that is both compiled in and
//! reported by CpuFeatures::static_flags_ at runtime is used, in the order
//! AVX512F -> AVX -> SSE, falling back to the scalar kernel.
//! NOTE(review): the specialization's template argument list looks stripped
//! by extraction (presumably FP32 with M=1, N=1) -- confirm.
void InnerProductMatrix::Compute(const float *m, const float *q,
                                 size_t dim, float *out) {
#if defined(__ARM_NEON)
  *out = InnerProductFp32NEON(m, q, dim);
#else
#if defined(__AVX512F__)
  if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) {
    *out = InnerProductFp32AVX512(m, q, dim);
    return;
  }
#endif  // __AVX512F__
#if defined(__AVX__)
  if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX) {
    *out = InnerProductFp32AVX(m, q, dim);
    return;
  }
#endif  // __AVX__
#if defined(__SSE__)
  if (zvec::ailego::internal::CpuFeatures::static_flags_.SSE) {
    *out = InnerProductFp32SSE(m, q, dim);
    return;
  }
#endif  // __SSE__
  *out = InnerProductFp32Scalar(m, q, dim);
#endif  // __ARM_NEON
}

//! Compute the distance between matrix and query (FP32, M=1, N=1)
//!
//! Same dispatch as the FP32 InnerProductMatrix::Compute above, but calls the
//! MinusInnerProductFp32* kernels.
void MinusInnerProductMatrix::Compute(const float *m, const float *q,
                                      size_t dim, float *out) {
#if defined(__ARM_NEON)
  *out = MinusInnerProductFp32NEON(m, q, dim);
#else
#if defined(__AVX512F__)
  if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) {
    *out = MinusInnerProductFp32AVX512(m, q, dim);
    return;
  }
#endif  // __AVX512F__
#if defined(__AVX__)
  if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX) {
    *out = MinusInnerProductFp32AVX(m, q, dim);
    return;
  }
#endif  // __AVX__
#if defined(__SSE__)
  if (zvec::ailego::internal::CpuFeatures::static_flags_.SSE) {
    *out = MinusInnerProductFp32SSE(m, q, dim);
    return;
  }
#endif  // __SSE__
  *out = MinusInnerProductFp32Scalar(m, q, dim);
#endif  // __ARM_NEON
}

//--------------------------------------------------
// Sparse
//--------------------------------------------------
// Forward declarations of the FP32 sparse kernels. The "InSegment" variants
// take, for each side, an element count plus uint16_t index and float value
// arrays.
#if defined(__SSE4_1__)
float InnerProductSparseInSegmentFp32SSE(uint32_t m_sparse_count,
                                         const uint16_t *m_sparse_index,
                                         const float *m_sparse_value,
                                         uint32_t q_sparse_count,
                                         const uint16_t *q_sparse_index,
                                         const float *q_sparse_value);
#endif

float InnerProductSparseInSegmentFp32Scalar(uint32_t m_sparse_count,
                                            const uint16_t *m_sparse_index,
                                            const float *m_sparse_value,
                                            uint32_t q_sparse_count,
                                            const uint16_t *q_sparse_index,
                                            const float *q_sparse_value);

// Operates on opaque serialized sparse records; layout is defined elsewhere.
float MinusInnerProductSparseFp32Scalar(const void *m_sparse_data_in,
                                        const void *q_sparse_data_in);

//! Compute the distance between matrix and query (sparse FP32).
//! Always delegates to the scalar kernel; no SIMD dispatch here.
void MinusInnerProductSparseMatrix::Compute(const void *m_sparse_data_in,
                                            const void *q_sparse_data_in,
                                            float *out) {
  *out = MinusInnerProductSparseFp32Scalar(m_sparse_data_in, q_sparse_data_in);
}

//! Sparse FP32 inner product over one segment.
//! Uses the SSE4.1 kernel when it is compiled in and CpuFeatures reports
//! SSE4_1 at runtime, otherwise the scalar kernel. Arguments are forwarded
//! unchanged.
float ComputeInnerProductSparseInSegmentFp32(uint32_t m_sparse_count,
                                             const uint16_t *m_sparse_index,
                                             const float *m_sparse_value,
                                             uint32_t q_sparse_count,
                                             const uint16_t *q_sparse_index,
                                             const float *q_sparse_value) {
#if defined(__SSE4_1__)
  if (zvec::ailego::internal::CpuFeatures::static_flags_.SSE4_1) {
    return InnerProductSparseInSegmentFp32SSE(m_sparse_count, m_sparse_index,
                                              m_sparse_value, q_sparse_count,
                                              q_sparse_index, q_sparse_value);
  }
#endif
  return InnerProductSparseInSegmentFp32Scalar(m_sparse_count, m_sparse_index,
                                               m_sparse_value, q_sparse_count,
                                               q_sparse_index, q_sparse_value);
}

}  // namespace ailego
}  // namespace zvec

================================================
FILE: src/ailego/math/inner_product_matrix_fp32_neon.cc
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "distance_matrix_accum_fp32.i"
#include "distance_matrix_inner_product_utility.i"
#include "inner_product_matrix.h"

namespace zvec {
namespace ailego {

//--------------------------------------------------
// Dense
//--------------------------------------------------
#if defined(__ARM_NEON)
//! Dot product of two fp32 arrays of `size` elements using 128-bit NEON
//! registers. The main loop consumes 8 floats per iteration into two
//! accumulators. An optional 4-float step and a scalar tail of at most
//! 3 elements then finish the sum.
float InnerProductFp32NEON(const float *lhs, const float *rhs, size_t size) {
  const float *last = lhs + size;
  const float *last_aligned = lhs + ((size >> 3) << 3);

  float32x4_t v_sum_0 = vdupq_n_f32(0);
  float32x4_t v_sum_1 = vdupq_n_f32(0);

  for (; lhs != last_aligned; lhs += 8, rhs += 8) {
    v_sum_0 = vfmaq_f32(v_sum_0, vld1q_f32(lhs + 0), vld1q_f32(rhs + 0));
    v_sum_1 = vfmaq_f32(v_sum_1, vld1q_f32(lhs + 4), vld1q_f32(rhs + 4));
  }
  // Test the remaining element count instead of forming `last_aligned + 4`,
  // which would point more than one past the end of the array when fewer
  // than 4 elements remain (undefined behavior even if never dereferenced).
  if (last - lhs >= 4) {
    v_sum_0 = vfmaq_f32(v_sum_0, vld1q_f32(lhs), vld1q_f32(rhs));
    lhs += 4;
    rhs += 4;
  }

  float result = vaddvq_f32(vaddq_f32(v_sum_0, v_sum_1));

  // Scalar tail: 0..3 remaining elements, accumulated via fall-through.
  switch (last - lhs) {
    case 3:
      FMA_FP32_GENERAL(lhs[2], rhs[2], result)
      /* FALLTHRU */
    case 2:
      FMA_FP32_GENERAL(lhs[1], rhs[1], result)
      /* FALLTHRU */
    case 1:
      FMA_FP32_GENERAL(lhs[0], rhs[0], result)
  }
  return result;
}

//! Negated inner product (NEON).
float MinusInnerProductFp32NEON(const float *lhs, const float *rhs,
                                size_t size) {
  return -1 * InnerProductFp32NEON(lhs, rhs, size);
}
#endif  // __ARM_NEON

}  // namespace ailego
}  // namespace zvec

================================================
FILE: src/ailego/math/inner_product_matrix_fp32_sse.cc
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // limitations under the License. #include "distance_matrix_accum_fp32.i" #include "distance_matrix_inner_product_utility.i" #include "inner_product_matrix.h" namespace zvec { namespace ailego { //-------------------------------------------------- // Dense //-------------------------------------------------- #if defined(__SSE__) float InnerProductFp32SSEInternal(const float *lhs, const float *rhs, size_t size) { const float *last = lhs + size; const float *last_aligned = lhs + ((size >> 3) << 3); __m128 xmm_sum_0 = _mm_setzero_ps(); __m128 xmm_sum_1 = _mm_setzero_ps(); if (((uintptr_t)lhs & 0xf) == 0 && ((uintptr_t)rhs & 0xf) == 0) { for (; lhs != last_aligned; lhs += 8, rhs += 8) { __m128 xmm_lhs_0 = _mm_load_ps(lhs + 0); __m128 xmm_lhs_1 = _mm_load_ps(lhs + 4); __m128 xmm_rhs_0 = _mm_load_ps(rhs + 0); __m128 xmm_rhs_1 = _mm_load_ps(rhs + 4); xmm_sum_0 = _mm_fmadd_ps(xmm_lhs_0, xmm_rhs_0, xmm_sum_0); xmm_sum_1 = _mm_fmadd_ps(xmm_lhs_1, xmm_rhs_1, xmm_sum_1); } if (last >= last_aligned + 4) { xmm_sum_0 = _mm_fmadd_ps(_mm_load_ps(lhs), _mm_load_ps(rhs), xmm_sum_0); lhs += 4; rhs += 4; } } else { for (; lhs != last_aligned; lhs += 8, rhs += 8) { __m128 xmm_lhs_0 = _mm_loadu_ps(lhs + 0); __m128 xmm_lhs_1 = _mm_loadu_ps(lhs + 4); __m128 xmm_rhs_0 = _mm_loadu_ps(rhs + 0); __m128 xmm_rhs_1 = _mm_loadu_ps(rhs + 4); xmm_sum_0 = _mm_fmadd_ps(xmm_lhs_0, xmm_rhs_0, xmm_sum_0); xmm_sum_1 = _mm_fmadd_ps(xmm_lhs_1, xmm_rhs_1, xmm_sum_1); } if (last >= last_aligned + 4) { xmm_sum_0 = _mm_fmadd_ps(_mm_loadu_ps(lhs), _mm_loadu_ps(rhs), xmm_sum_0); lhs += 4; rhs += 4; } } float result = HorizontalAdd_FP32_V128(_mm_add_ps(xmm_sum_0, xmm_sum_1)); switch (last - lhs) { case 3: FMA_FP32_GENERAL(lhs[2], rhs[2], result) /* FALLTHRU */ case 2: FMA_FP32_GENERAL(lhs[1], rhs[1], result) /* FALLTHRU */ case 1: FMA_FP32_GENERAL(lhs[0], rhs[0], result) } return result; } float InnerProductFp32SSE(const float *lhs, const float 
*rhs, size_t size) { return InnerProductFp32SSEInternal(lhs, rhs, size); } float MinusInnerProductFp32SSE(const float *lhs, const float *rhs, size_t size) { return -1 * InnerProductFp32SSE(lhs, rhs, size); } #endif // __SSE__ //-------------------------------------------------- // Sparse //-------------------------------------------------- #if defined(__SSE4_1__) const static __m128i SHUFFLE_MASK16[16] = { _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 7, 6, 5, 4), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 7, 6, 5, 4, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 9, 8), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 9, 8, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 9, 8, 7, 6, 5, 4), _mm_set_epi8(-127, -127, -127, -127, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 7, 6, 5, 4), _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 7, 6, 5, 4, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 9, 8), _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4), _mm_set_epi8(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0), }; constexpr uint32_t MAX_SPARSE_BUFFER_LENGTH = 65536; float InnerProductSparseInSegmentFp32SSE(uint32_t m_sparse_count, const uint16_t *m_sparse_index, const float 
*m_sparse_value, uint32_t q_sparse_count, const uint16_t *q_sparse_index, const float *q_sparse_value) { float sum = 0.0f; // handle if the first dim is zero bool m_zero = false; float m_zero_value = 0.0f; if (m_sparse_count > 0 && m_sparse_index[0] == 0) { m_sparse_count--; m_sparse_index++; m_zero_value = *m_sparse_value++; m_zero = true; } bool q_zero = false; float q_zero_value = 0.0f; if (q_sparse_count > 0 && q_sparse_index[0] == 0) { q_sparse_count--; q_sparse_index++; q_zero_value = *q_sparse_value++; q_zero = true; } if (m_zero && q_zero) { sum = m_zero_value * q_zero_value; } size_t i1 = 0, i2 = 0; size_t end1 = m_sparse_count / 8 * 8; size_t end2 = q_sparse_count / 8 * 8; // std::vector mem1; // std::vector mem2; float fixed_buffer_1[MAX_SPARSE_BUFFER_LENGTH]; float fixed_buffer_2[MAX_SPARSE_BUFFER_LENGTH]; float *val_start_1 = fixed_buffer_1; float *val_start_2 = fixed_buffer_2; // uint32_t max_count = std::max(m_sparse_count, q_sparse_count); // if (MAX_SPARSE_BUFFER_LENGTH < max_count) { // mem1.reserve(max_count); // mem2.reserve(max_count); // val_start_1 = mem1.data(); // val_start_2 = mem2.data(); // } float *val_1 = val_start_1; float *val_2 = val_start_2; if (i1 < end1 && i2 < end2) { while (m_sparse_index[i1 + 7] < q_sparse_index[i2]) { i1 += 8; if (i1 >= end1) goto do_scalar; } while (q_sparse_index[i2 + 7] < m_sparse_index[i1]) { i2 += 8; if (i2 >= end2) goto do_scalar; } __m128i mm_index_m = _mm_loadu_si128(reinterpret_cast(&m_sparse_index[i1])); __m128i mm_index_q = _mm_loadu_si128(reinterpret_cast(&q_sparse_index[i2])); while (true) { #ifdef DEBUG_PRINT std::cout << "index 1: " << std::endl; print_data16(&mm_index_m); std::cout << "index 2: " << std::endl; print_data16(&mm_index_q); #endif __m128i mm_cmp_res = _mm_cmpistrm(mm_index_q, mm_index_m, _SIDD_UWORD_OPS | _SIDD_CMP_EQUAL_ANY | _SIDD_BIT_MASK); #ifdef DEBUG_PRINT std::cout << "cmp res: " << std::endl; print_data16(&mm_cmp_res); #endif int r = _mm_extract_epi32(mm_cmp_res, 0); if 
(r) { int r1 = r & 15; __m128i v = _mm_loadu_si128( reinterpret_cast(&m_sparse_value[i1])); __m128 vs = _mm_castsi128_ps(_mm_shuffle_epi8(v, SHUFFLE_MASK16[r1])); _mm_storeu_ps(val_1, vs); val_1 += _mm_popcnt_u32(r1); int r2 = (r >> 4) & 15; v = _mm_loadu_si128( reinterpret_cast(&m_sparse_value[i1 + 4])); vs = _mm_castsi128_ps(_mm_shuffle_epi8(v, SHUFFLE_MASK16[r2])); _mm_storeu_ps(val_1, vs); val_1 += _mm_popcnt_u32(r2); mm_cmp_res = _mm_cmpistrm( mm_index_m, mm_index_q, _SIDD_UWORD_OPS | _SIDD_CMP_EQUAL_ANY | _SIDD_BIT_MASK); r = _mm_extract_epi32(mm_cmp_res, 0); r1 = r & 15; v = _mm_loadu_si128( reinterpret_cast(&q_sparse_value[i2])); vs = _mm_castsi128_ps(_mm_shuffle_epi8(v, SHUFFLE_MASK16[r1])); _mm_storeu_ps(val_2, vs); val_2 += _mm_popcnt_u32(r1); r2 = (r >> 4) & 15; v = _mm_loadu_si128( reinterpret_cast(&q_sparse_value[i2 + 4])); vs = _mm_castsi128_ps(_mm_shuffle_epi8(v, SHUFFLE_MASK16[r2])); _mm_storeu_ps(val_2, vs); val_2 += _mm_popcnt_u32(r2); } const uint16_t id1_max = m_sparse_index[i1 + 7]; if (id1_max <= q_sparse_index[i2 + 7]) { i1 += 8; if (i1 >= end1) goto do_scalar; mm_index_m = _mm_loadu_si128( reinterpret_cast(&m_sparse_index[i1])); } if (id1_max >= q_sparse_index[i2 + 7]) { i2 += 8; if (i2 >= end2) goto do_scalar; mm_index_q = _mm_loadu_si128( reinterpret_cast(&q_sparse_index[i2])); } } } do_scalar: while (i1 < m_sparse_count && i2 < q_sparse_count) { if (m_sparse_index[i1] == q_sparse_index[i2]) { *val_1++ = m_sparse_value[i1]; *val_2++ = q_sparse_value[i2]; ++i1; ++i2; } else if (m_sparse_index[i1] < q_sparse_index[i2]) { ++i1; } else { ++i2; } } size_t res_num = val_1 - val_start_1; // if (res_num != val_2 - val_start_2) { // std::cerr << "size mismatch!" 
<< std::endl; // } size_t res_num4 = res_num / 4 * 4; if (res_num4) { __m128 sum128 = _mm_set1_ps(0); for (size_t k = 0; k < res_num4; k += 4) { sum128 = _mm_add_ps(sum128, _mm_mul_ps(_mm_loadu_ps(val_start_1 + k), _mm_loadu_ps(val_start_2 + k))); } float __attribute__((aligned(16))) tmp_res[4]; _mm_store_ps(tmp_res, sum128); sum += (tmp_res[0] + tmp_res[1] + tmp_res[2] + tmp_res[3]); } for (size_t k = res_num4; k < res_num; ++k) sum += val_start_1[k] * val_start_2[k]; return sum; } #endif // __SSE4_1__ } // namespace ailego } // namespace zvec ================================================ FILE: src/ailego/math/inner_product_matrix_int4_avx2.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "distance_matrix_accum_int4.i" #include "distance_matrix_inner_product_utility.i" #include "inner_product_matrix.h" namespace zvec { namespace ailego { //-------------------------------------------------- // Dense //-------------------------------------------------- #if defined(__AVX2__) float InnerProductInt4SSEInternal(const uint8_t *lhs, const uint8_t *rhs, size_t size); //! 
Inner Product float InnerProductInt4AVX2Internal(const uint8_t *lhs, const uint8_t *rhs, size_t size) { const uint8_t *last = lhs + size; const uint8_t *last_aligned = lhs + ((size >> 5) << 5); __m256i ymm_sum = _mm256_setzero_si256(); if (((uintptr_t)lhs & 0x1f) == 0 && ((uintptr_t)rhs & 0x1f) == 0) { for (; lhs != last_aligned; lhs += 32, rhs += 32) { __m256i ymm_lhs = _mm256_load_si256((const __m256i *)(lhs)); __m256i ymm_rhs = _mm256_load_si256((const __m256i *)(rhs)); FMA_INT4_ITER_AVX(ymm_lhs, ymm_rhs, ymm_sum) } if (last >= lhs + 16) { __m128i xmm_lhs = _mm_load_si128((const __m128i *)lhs); __m128i xmm_rhs = _mm_load_si128((const __m128i *)rhs); __m128i xmm_sum = _mm_setzero_si128(); FMA_INT4_ITER_SSE(xmm_lhs, xmm_rhs, xmm_sum) ymm_sum = _mm256_add_epi32(_mm256_set_m128i(_mm_setzero_si128(), xmm_sum), ymm_sum); lhs += 16; rhs += 16; } } else { for (; lhs != last_aligned; lhs += 32, rhs += 32) { __m256i ymm_lhs = _mm256_loadu_si256((const __m256i *)(lhs)); __m256i ymm_rhs = _mm256_loadu_si256((const __m256i *)(rhs)); FMA_INT4_ITER_AVX(ymm_lhs, ymm_rhs, ymm_sum) } if (last >= lhs + 16) { __m128i xmm_lhs = _mm_loadu_si128((const __m128i *)lhs); __m128i xmm_rhs = _mm_loadu_si128((const __m128i *)rhs); __m128i xmm_sum = _mm_setzero_si128(); FMA_INT4_ITER_SSE(xmm_lhs, xmm_rhs, xmm_sum) ymm_sum = _mm256_add_epi32(_mm256_set_m128i(_mm_setzero_si128(), xmm_sum), ymm_sum); lhs += 16; rhs += 16; } } float result = static_cast(HorizontalAdd_INT32_V256(ymm_sum)); switch (last - lhs) { case 15: FMA_INT4_GENERAL(lhs[14], rhs[14], result) /* FALLTHRU */ case 14: FMA_INT4_GENERAL(lhs[13], rhs[13], result) /* FALLTHRU */ case 13: FMA_INT4_GENERAL(lhs[12], rhs[12], result) /* FALLTHRU */ case 12: FMA_INT4_GENERAL(lhs[11], rhs[11], result) /* FALLTHRU */ case 11: FMA_INT4_GENERAL(lhs[10], rhs[10], result) /* FALLTHRU */ case 10: FMA_INT4_GENERAL(lhs[9], rhs[9], result) /* FALLTHRU */ case 9: FMA_INT4_GENERAL(lhs[8], rhs[8], result) /* FALLTHRU */ case 8: 
FMA_INT4_GENERAL(lhs[7], rhs[7], result) /* FALLTHRU */ case 7: FMA_INT4_GENERAL(lhs[6], rhs[6], result) /* FALLTHRU */ case 6: FMA_INT4_GENERAL(lhs[5], rhs[5], result) /* FALLTHRU */ case 5: FMA_INT4_GENERAL(lhs[4], rhs[4], result) /* FALLTHRU */ case 4: FMA_INT4_GENERAL(lhs[3], rhs[3], result) /* FALLTHRU */ case 3: FMA_INT4_GENERAL(lhs[2], rhs[2], result) /* FALLTHRU */ case 2: FMA_INT4_GENERAL(lhs[1], rhs[1], result) /* FALLTHRU */ case 1: FMA_INT4_GENERAL(lhs[0], rhs[0], result) } return result; } float InnerProductInt4AVX2(const uint8_t *lhs, const uint8_t *rhs, size_t size) { if (size > 63) { return InnerProductInt4AVX2Internal(lhs, rhs, size >> 1); } return InnerProductInt4SSEInternal(lhs, rhs, size >> 1); } float MinusInnerProductInt4AVX2(const uint8_t *lhs, const uint8_t *rhs, size_t size) { return -InnerProductInt4AVX2(lhs, rhs, size); } #endif // __AVX2__ } // namespace ailego } // namespace zvec ================================================ FILE: src/ailego/math/inner_product_matrix_int4_dispatch.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include 
#include "inner_product_matrix.h"

namespace zvec {
namespace ailego {

//--------------------------------------------------
// Dense
//--------------------------------------------------
// Runtime dispatch for packed INT4 inner products. Each SIMD kernel is only
// declared when this translation unit is compiled with the matching ISA
// macro; their definitions live in the per-ISA source files.
#if defined(__AVX2__)
float InnerProductInt4AVX2(const uint8_t *lhs, const uint8_t *rhs,
                           size_t size);
float MinusInnerProductInt4AVX2(const uint8_t *lhs, const uint8_t *rhs,
                                size_t size);
#endif

#if defined(__SSE4_1__)
float InnerProductInt4SSE(const uint8_t *lhs, const uint8_t *rhs,
                          size_t size);
float MinusInnerProductInt4SSE(const uint8_t *lhs, const uint8_t *rhs,
                               size_t size);
#endif

// Scalar fallbacks, declared unconditionally.
float InnerProductInt4Scalar(const uint8_t *m, const uint8_t *q, size_t dim);
float MinusInnerProductInt4Scalar(const uint8_t *m, const uint8_t *q,
                                  size_t dim);

//! Compute the distance between matrix and query (INT4, M=1, N=1)
//! Chooses the first kernel that is both compiled in and reported by
//! CpuFeatures at runtime: AVX2, then SSE4.1, then scalar. `dim` is passed
//! through unchanged; the kernels themselves convert it to a byte count.
void InnerProductMatrix::Compute(const uint8_t *m, const uint8_t *q,
                                 size_t dim, float *out) {
#if defined(__AVX2__)
  if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX2) {
    *out = InnerProductInt4AVX2(m, q, dim);
    return;
  }
#endif  // __AVX2__

#if defined(__SSE4_1__)
  if (zvec::ailego::internal::CpuFeatures::static_flags_.SSE4_1) {
    *out = InnerProductInt4SSE(m, q, dim);
    return;
  }
#endif  //__SSE4_1__
  *out = InnerProductInt4Scalar(m, q, dim);
}

//! Compute the distance between matrix and query (INT4, M=1, N=1)
//! Negated variant: same dispatch order (AVX2, SSE4.1, scalar), writing
//! -inner_product into `*out`.
void MinusInnerProductMatrix::Compute(const uint8_t *m, const uint8_t *q,
                                      size_t dim, float *out) {
#if defined(__AVX2__)
  if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX2) {
    *out = MinusInnerProductInt4AVX2(m, q, dim);
    return;
  }
#endif  // __AVX2__

#if defined(__SSE4_1__)
  if (zvec::ailego::internal::CpuFeatures::static_flags_.SSE4_1) {
    *out = MinusInnerProductInt4SSE(m, q, dim);
    return;
  }
#endif  //__SSE4_1__
  *out = MinusInnerProductInt4Scalar(m, q, dim);
}

}  // namespace ailego
}  // namespace zvec

================================================
FILE: src/ailego/math/inner_product_matrix_int4_sse.cc
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "distance_matrix_accum_int4.i"
#include "distance_matrix_inner_product_utility.i"
#include "inner_product_matrix.h"

namespace zvec {
namespace ailego {

//--------------------------------------------------
// Dense
//--------------------------------------------------
#if defined(__SSE4_1__)
//! Inner product over `size` bytes of packed INT4 data (two values per
//! byte): 16 bytes per step through FMA_INT4_ITER_SSE, then a byte-wise
//! tail through FMA_INT4_GENERAL.
float InnerProductInt4SSEInternal(const uint8_t *lhs, const uint8_t *rhs,
                                  size_t size) {
  const uint8_t *const end = lhs + size;
  const uint8_t *const end16 = lhs + (size & ~static_cast<size_t>(15));
  __m128i xmm_sum = _mm_setzero_si128();

  // Aligned loads are only legal when both inputs sit on 16-byte boundaries.
  const bool aligned16 = ((reinterpret_cast<uintptr_t>(lhs) |
                           reinterpret_cast<uintptr_t>(rhs)) &
                          0xf) == 0;

  if (aligned16) {
    while (lhs != end16) {
      __m128i xmm_lhs =
          _mm_load_si128(reinterpret_cast<const __m128i *>(lhs));
      __m128i xmm_rhs =
          _mm_load_si128(reinterpret_cast<const __m128i *>(rhs));
      FMA_INT4_ITER_SSE(xmm_lhs, xmm_rhs, xmm_sum)
      lhs += 16;
      rhs += 16;
    }
  } else {
    while (lhs != end16) {
      __m128i xmm_lhs =
          _mm_loadu_si128(reinterpret_cast<const __m128i *>(lhs));
      __m128i xmm_rhs =
          _mm_loadu_si128(reinterpret_cast<const __m128i *>(rhs));
      FMA_INT4_ITER_SSE(xmm_lhs, xmm_rhs, xmm_sum)
      lhs += 16;
      rhs += 16;
    }
  }

  float result = static_cast<float>(HorizontalAdd_INT32_V128(xmm_sum));

  // At most 15 bytes remain; fold them from the highest offset down to 0.
  for (size_t remaining = static_cast<size_t>(end - lhs); remaining > 0;
       --remaining) {
    const size_t idx = remaining - 1;
    FMA_INT4_GENERAL(lhs[idx], rhs[idx], result)
  }
  return result;
}

//! INT4 inner product (SSE). `size` counts 4-bit elements; two are packed in
//! each byte, so the kernel reads `size >> 1` bytes.
float InnerProductInt4SSE(const uint8_t *lhs, const uint8_t *rhs,
                          size_t size) {
  const size_t byte_count = size >> 1;
  return InnerProductInt4SSEInternal(lhs, rhs, byte_count);
}

//! Negated INT4 inner product (SSE).
float MinusInnerProductInt4SSE(const uint8_t *lhs, const uint8_t *rhs,
                               size_t size) {
  const float ip = InnerProductInt4SSE(lhs, rhs, size);
  return -ip;
}
#endif  // __SSE4_1__

}  // namespace ailego
}  // namespace zvec

================================================
FILE: src/ailego/math/inner_product_matrix_int8_avx2.cc
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "distance_matrix_accum_int8.i"
#include "distance_matrix_inner_product_utility.i"
#include "inner_product_matrix.h"

namespace zvec {
namespace ailego {

//--------------------------------------------------
// Dense
//--------------------------------------------------
#if defined(__AVX2__)
//! SSE kernel (defined in the SSE translation unit), used for short inputs.
float InnerProductInt8SSEInternal(const int8_t *lhs, const int8_t *rhs,
                                  size_t size);

//! Signed 8-bit products of 32 byte pairs, summed into eight int32 lanes.
//! `_mm256_maddubs_epi16` needs one unsigned operand, so the sign of `rhs` is
//! moved onto `lhs` and |rhs| becomes the unsigned side.
//! NOTE(review): when both bytes are -128 the negation wraps and the pair
//! contributes -16384 instead of +16384; presumably quantized inputs stay in
//! [-127, 127] -- confirm with the encoder.
static inline __m256i Int8DotLanesAVX2(__m256i lhs_vec, __m256i rhs_vec) {
  const __m256i lhs_signed = _mm256_sign_epi8(lhs_vec, rhs_vec);
  const __m256i rhs_abs = _mm256_abs_epi8(rhs_vec);
  return _mm256_madd_epi16(_mm256_maddubs_epi16(rhs_abs, lhs_signed),
                           ONES_INT16_AVX);
}

//! 128-bit counterpart of Int8DotLanesAVX2 (16 byte pairs -> 4 int32 lanes).
static inline __m128i Int8DotLanesSSE(__m128i lhs_vec, __m128i rhs_vec) {
  const __m128i lhs_signed = _mm_sign_epi8(lhs_vec, rhs_vec);
  const __m128i rhs_abs = _mm_abs_epi8(rhs_vec);
  return _mm_madd_epi16(_mm_maddubs_epi16(rhs_abs, lhs_signed),
                        ONES_INT16_SSE);
}

//! INT8 inner product with 256-bit integer arithmetic: 64 bytes per step on
//! two accumulators, then optional 32- and 16-byte steps, then a byte tail.
inline float InnerProductInt8AVX2Internal(const int8_t *lhs, const int8_t *rhs,
                                          size_t size) {
  const int8_t *const end = lhs + size;
  const int8_t *const end64 = lhs + (size & ~static_cast<size_t>(63));

  __m256i acc_even = _mm256_setzero_si256();
  __m256i acc_odd = _mm256_setzero_si256();

  // Aligned loads are only legal when both inputs sit on 32-byte boundaries.
  const bool aligned32 = ((reinterpret_cast<uintptr_t>(lhs) |
                           reinterpret_cast<uintptr_t>(rhs)) &
                          0x1f) == 0;

  if (aligned32) {
    while (lhs != end64) {
      const __m256i lhs_lo =
          _mm256_load_si256(reinterpret_cast<const __m256i *>(lhs));
      const __m256i lhs_hi =
          _mm256_load_si256(reinterpret_cast<const __m256i *>(lhs + 32));
      const __m256i rhs_lo =
          _mm256_load_si256(reinterpret_cast<const __m256i *>(rhs));
      const __m256i rhs_hi =
          _mm256_load_si256(reinterpret_cast<const __m256i *>(rhs + 32));
      acc_even = _mm256_add_epi32(Int8DotLanesAVX2(lhs_lo, rhs_lo), acc_even);
      acc_odd = _mm256_add_epi32(Int8DotLanesAVX2(lhs_hi, rhs_hi), acc_odd);
      lhs += 64;
      rhs += 64;
    }
  } else {
    while (lhs != end64) {
      const __m256i lhs_lo =
          _mm256_loadu_si256(reinterpret_cast<const __m256i *>(lhs));
      const __m256i lhs_hi =
          _mm256_loadu_si256(reinterpret_cast<const __m256i *>(lhs + 32));
      const __m256i rhs_lo =
          _mm256_loadu_si256(reinterpret_cast<const __m256i *>(rhs));
      const __m256i rhs_hi =
          _mm256_loadu_si256(reinterpret_cast<const __m256i *>(rhs + 32));
      acc_even = _mm256_add_epi32(Int8DotLanesAVX2(lhs_lo, rhs_lo), acc_even);
      acc_odd = _mm256_add_epi32(Int8DotLanesAVX2(lhs_hi, rhs_hi), acc_odd);
      lhs += 64;
      rhs += 64;
    }
  }

  // 0-63 bytes remain. Unaligned loads accept aligned addresses as well, so a
  // single code path serves both cases; integer lane sums are exact.
  if (end - lhs >= 32) {
    const __m256i lhs_vec =
        _mm256_loadu_si256(reinterpret_cast<const __m256i *>(lhs));
    const __m256i rhs_vec =
        _mm256_loadu_si256(reinterpret_cast<const __m256i *>(rhs));
    acc_even = _mm256_add_epi32(Int8DotLanesAVX2(lhs_vec, rhs_vec), acc_even);
    lhs += 32;
    rhs += 32;
  }
  if (end - lhs >= 16) {
    const __m128i lhs_vec =
        _mm_loadu_si128(reinterpret_cast<const __m128i *>(lhs));
    const __m128i rhs_vec =
        _mm_loadu_si128(reinterpret_cast<const __m128i *>(rhs));
    // Widen to 256 bits with a zero upper half before accumulating.
    acc_even = _mm256_add_epi32(
        _mm256_set_m128i(_mm_setzero_si128(), Int8DotLanesSSE(lhs_vec, rhs_vec)),
        acc_even);
    lhs += 16;
    rhs += 16;
  }

  float result = static_cast<float>(
      HorizontalAdd_INT32_V256(_mm256_add_epi32(acc_even, acc_odd)));

  // At most 15 bytes remain; fold them from the highest offset down to 0.
  for (size_t remaining = static_cast<size_t>(end - lhs); remaining > 0;
       --remaining) {
    const size_t idx = remaining - 1;
    FMA_INT8_GENERAL(lhs[idx], rhs[idx], result)
  }
  return result;
}

//! INT8 inner product (AVX2 entry point). Inputs shorter than one 32-byte
//! register are handed to the SSE kernel.
float InnerProductInt8AVX2(const int8_t *lhs, const int8_t *rhs, size_t size) {
  return size >= 32 ? InnerProductInt8AVX2Internal(lhs, rhs, size)
                    : InnerProductInt8SSEInternal(lhs, rhs, size);
}

//! Negated INT8 inner product (AVX2).
float MinusInnerProductInt8AVX2(const int8_t *lhs, const int8_t *rhs,
                                size_t size) {
  const float ip = InnerProductInt8AVX2(lhs, rhs, size);
  return -ip;
}
#endif  // __AVX2__

}  // namespace ailego
}  // namespace zvec

================================================
FILE: src/ailego/math/inner_product_matrix_int8_dispatch.cc
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include 
#include "inner_product_matrix.h"

namespace zvec {
namespace ailego {

//--------------------------------------------------
// Dense
//--------------------------------------------------
// Runtime dispatch for INT8 inner products. Each SIMD kernel is only
// declared when this translation unit is compiled with the matching ISA
// macro; their definitions live in the per-ISA source files.
#if defined(__AVX2__)
float InnerProductInt8AVX2(const int8_t *lhs, const int8_t *rhs, size_t size);
float MinusInnerProductInt8AVX2(const int8_t *lhs, const int8_t *rhs,
                                size_t size);
#endif

#if defined(__SSE4_1__)
float InnerProductInt8SSE(const int8_t *lhs, const int8_t *rhs, size_t size);
float MinusInnerProductInt8SSE(const int8_t *lhs, const int8_t *rhs,
                               size_t size);
#endif

// Scalar fallbacks, declared unconditionally.
float InnerProductInt8Scalar(const int8_t *m, const int8_t *q, size_t dim);
float MinusInnerProductInt8Scalar(const int8_t *m, const int8_t *q,
                                  size_t dim);

//! Compute the distance between matrix and query (INT8, M=1, N=1)
//! Chooses the first kernel that is both compiled in and reported by
//! CpuFeatures at runtime: AVX2, then SSE4.1, then scalar.
void InnerProductMatrix::Compute(const int8_t *m, const int8_t *q,
                                 size_t dim, float *out) {
#if defined(__AVX2__)
  if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX2) {
    *out = InnerProductInt8AVX2(m, q, dim);
    return;
  }
#endif  // __AVX2__

#if defined(__SSE4_1__)
  if (zvec::ailego::internal::CpuFeatures::static_flags_.SSE4_1) {
    *out = InnerProductInt8SSE(m, q, dim);
    return;
  }
#endif  //__SSE4_1__
  *out = InnerProductInt8Scalar(m, q, dim);
}

//! Compute the distance between matrix and query (INT8, M=1, N=1)
//! Negated variant: same dispatch order (AVX2, SSE4.1, scalar), writing
//! -inner_product into `*out`.
void MinusInnerProductMatrix::Compute(const int8_t *m, const int8_t *q,
                                      size_t dim, float *out) {
#if defined(__AVX2__)
  if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX2) {
    *out = MinusInnerProductInt8AVX2(m, q, dim);
    return;
  }
#endif  // __AVX2__

#if defined(__SSE4_1__)
  if (zvec::ailego::internal::CpuFeatures::static_flags_.SSE4_1) {
    *out = MinusInnerProductInt8SSE(m, q, dim);
    return;
  }
#endif  //__SSE4_1__
  *out = MinusInnerProductInt8Scalar(m, q, dim);
}

}  // namespace ailego
}  // namespace zvec

================================================
FILE: src/ailego/math/inner_product_matrix_int8_sse.cc
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "distance_matrix_accum_int8.i"
#include "distance_matrix_inner_product_utility.i"
#include "inner_product_matrix.h"

namespace zvec {
namespace ailego {

//--------------------------------------------------
// Dense
//--------------------------------------------------
#if defined(__SSE4_1__)
//! Inner Product
//! INT8 inner product with 128-bit integer arithmetic: 32 bytes per step on
//! two accumulators, an optional 16-byte step, then up to 15 scalar bytes.
//! `_mm_maddubs_epi16` multiplies unsigned by signed bytes, so the sign of
//! `rhs` is moved onto `lhs` (`_mm_sign_epi8`) and |rhs| is the unsigned side.
//! NOTE(review): for lhs == rhs == -128 that negation wraps and the pair
//! contributes -16384 instead of +16384; presumably inputs stay in
//! [-127, 127] -- confirm with the encoder.
float InnerProductInt8SSEInternal(const int8_t *lhs, const int8_t *rhs,
                                  size_t size) {
  const int8_t *last = lhs + size;
  const int8_t *last_aligned = lhs + ((size >> 5) << 5);

  __m128i xmm_sum_0 = _mm_setzero_si128();
  __m128i xmm_sum_1 = _mm_setzero_si128();

  // Aligned loads are only legal when both inputs are 16-byte aligned.
  if (((uintptr_t)lhs & 0xf) == 0 && ((uintptr_t)rhs & 0xf) == 0) {
    for (; lhs != last_aligned; lhs += 32, rhs += 32) {
      __m128i xmm_lhs_0 = _mm_load_si128((const __m128i *)(lhs + 0));
      __m128i xmm_lhs_1 = _mm_load_si128((const __m128i *)(lhs + 16));
      __m128i xmm_rhs_0 = _mm_load_si128((const __m128i *)(rhs + 0));
      __m128i xmm_rhs_1 = _mm_load_si128((const __m128i *)(rhs + 16));

      xmm_lhs_0 = _mm_sign_epi8(xmm_lhs_0, xmm_rhs_0);
      xmm_lhs_1 = _mm_sign_epi8(xmm_lhs_1, xmm_rhs_1);
      xmm_rhs_0 = _mm_abs_epi8(xmm_rhs_0);
      xmm_rhs_1 = _mm_abs_epi8(xmm_rhs_1);
      // maddubs -> int16 pair sums; madd with ONES_INT16_SSE -> int32 lanes.
      xmm_sum_0 =
          _mm_add_epi32(_mm_madd_epi16(_mm_maddubs_epi16(xmm_rhs_0, xmm_lhs_0),
                                       ONES_INT16_SSE),
                        xmm_sum_0);
      xmm_sum_1 =
          _mm_add_epi32(_mm_madd_epi16(_mm_maddubs_epi16(xmm_rhs_1, xmm_lhs_1),
                                       ONES_INT16_SSE),
                        xmm_sum_1);
    }

    if (last >= last_aligned + 16) {
      __m128i xmm_lhs = _mm_load_si128((const __m128i *)lhs);
      __m128i xmm_rhs = _mm_load_si128((const __m128i *)rhs);
      xmm_lhs = _mm_sign_epi8(xmm_lhs, xmm_rhs);
      xmm_rhs = _mm_abs_epi8(xmm_rhs);
      xmm_sum_0 = _mm_add_epi32(
          _mm_madd_epi16(_mm_maddubs_epi16(xmm_rhs, xmm_lhs), ONES_INT16_SSE),
          xmm_sum_0);
      lhs += 16;
      rhs += 16;
    }
  } else {
    for (; lhs != last_aligned; lhs += 32, rhs += 32) {
      __m128i xmm_lhs_0 = _mm_loadu_si128((const __m128i *)(lhs + 0));
      __m128i xmm_lhs_1 = _mm_loadu_si128((const __m128i *)(lhs + 16));
      __m128i xmm_rhs_0 = _mm_loadu_si128((const __m128i *)(rhs + 0));
      __m128i xmm_rhs_1 = _mm_loadu_si128((const __m128i *)(rhs + 16));

      xmm_lhs_0 = _mm_sign_epi8(xmm_lhs_0, xmm_rhs_0);
      xmm_lhs_1 = _mm_sign_epi8(xmm_lhs_1, xmm_rhs_1);
      xmm_rhs_0 = _mm_abs_epi8(xmm_rhs_0);
      xmm_rhs_1 = _mm_abs_epi8(xmm_rhs_1);
      xmm_sum_0 =
          _mm_add_epi32(_mm_madd_epi16(_mm_maddubs_epi16(xmm_rhs_0, xmm_lhs_0),
                                       ONES_INT16_SSE),
                        xmm_sum_0);
      xmm_sum_1 =
          _mm_add_epi32(_mm_madd_epi16(_mm_maddubs_epi16(xmm_rhs_1, xmm_lhs_1),
                                       ONES_INT16_SSE),
                        xmm_sum_1);
    }

    if (last >= last_aligned + 16) {
      __m128i xmm_lhs = _mm_loadu_si128((const __m128i *)lhs);
      __m128i xmm_rhs = _mm_loadu_si128((const __m128i *)rhs);
      xmm_lhs = _mm_sign_epi8(xmm_lhs, xmm_rhs);
      xmm_rhs = _mm_abs_epi8(xmm_rhs);
      xmm_sum_0 = _mm_add_epi32(
          _mm_madd_epi16(_mm_maddubs_epi16(xmm_rhs, xmm_lhs), ONES_INT16_SSE),
          xmm_sum_0);
      lhs += 16;
      rhs += 16;
    }
  }
  float result = static_cast(
      HorizontalAdd_INT32_V128(_mm_add_epi32(xmm_sum_0, xmm_sum_1)));

  // Fewer than 16 bytes remain; fall through from the highest index to 0.
  switch (last - lhs) {
    case 15:
      FMA_INT8_GENERAL(lhs[14], rhs[14], result)
      /* FALLTHRU */
    case 14:
      FMA_INT8_GENERAL(lhs[13], rhs[13], result)
      /* FALLTHRU */
    case 13:
      FMA_INT8_GENERAL(lhs[12], rhs[12], result)
      /* FALLTHRU */
    case 12:
      FMA_INT8_GENERAL(lhs[11], rhs[11], result)
      /* FALLTHRU */
    case 11:
      FMA_INT8_GENERAL(lhs[10], rhs[10], result)
      /* FALLTHRU */
    case 10:
      FMA_INT8_GENERAL(lhs[9], rhs[9], result)
      /* FALLTHRU */
    case 9:
      FMA_INT8_GENERAL(lhs[8], rhs[8], result)
      /* FALLTHRU */
    case 8:
      FMA_INT8_GENERAL(lhs[7], rhs[7], result)
      /* FALLTHRU */
    case 7:
      FMA_INT8_GENERAL(lhs[6], rhs[6], result)
      /* FALLTHRU */
    case 6:
      FMA_INT8_GENERAL(lhs[5], rhs[5], result)
      /* FALLTHRU */
    case 5:
      FMA_INT8_GENERAL(lhs[4], rhs[4], result)
      /* FALLTHRU */
    case 4:
      FMA_INT8_GENERAL(lhs[3], rhs[3], result)
      /* FALLTHRU */
    case 3:
      FMA_INT8_GENERAL(lhs[2], rhs[2], result)
      /* FALLTHRU */
    case 2:
      FMA_INT8_GENERAL(lhs[1], rhs[1], result)
      /* FALLTHRU */
    case 1:
      FMA_INT8_GENERAL(lhs[0], rhs[0], result)
  }
  return result;
}

//! INT8 inner product (SSE entry point); `size` is the number of int8 pairs.
float InnerProductInt8SSE(const int8_t *lhs, const int8_t *rhs, size_t size) {
  return InnerProductInt8SSEInternal(lhs, rhs, size);
}

//! Negated INT8 inner product (SSE).
float MinusInnerProductInt8SSE(const int8_t *lhs, const int8_t *rhs,
                               size_t size) {
  return -InnerProductInt8SSEInternal(lhs, rhs, size);
}
#endif  // __SSE4_1__

}  // namespace ailego
}  // namespace zvec

================================================
FILE:
src/ailego/math/inner_product_matrix_scalar.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include #include #include #include #include #include "distance_utility.h" #include "inner_product_matrix.h" namespace zvec { namespace ailego { //-------------------------------------------------- // Dense //-------------------------------------------------- template inline float InnerProductScalar(const T *m, const T *q, size_t dim) { ailego_assert(m && q && dim); float sum = 0.0; for (size_t i = 0; i < dim; ++i) { sum += static_cast(m[i] * q[i]); } return sum; } template inline float MinusInnerProductScalar(const T *m, const T *q, size_t dim) { ailego_assert(m && q && dim); float sum = 0.0; for (size_t i = 0; i < dim; ++i) { sum += static_cast(m[i] * q[i]); } return -sum; } float InnerProductInt4Scalar(const uint8_t *m, const uint8_t *q, size_t dim) { ailego_assert(m && q && dim && !(dim & 1)); float sum = 0.0; for (size_t i = 0; i < (dim >> 1); ++i) { uint8_t m_val = m[i]; uint8_t q_val = q[i]; sum += Int4MulTable[((m_val << 4) & 0xf0) | ((q_val >> 0) & 0xf)] + Int4MulTable[((m_val >> 0) & 0xf0) | ((q_val >> 4) & 0xf)]; } return sum; } float MinusInnerProductInt4Scalar(const uint8_t *m, const uint8_t *q, size_t dim) { ailego_assert(m && q && dim && !(dim & 1)); float sum = 0.0; for (size_t i = 0; i < (dim >> 1); ++i) { uint8_t m_val = m[i]; uint8_t q_val = q[i]; sum -= 
Int4MulTable[((m_val << 4) & 0xf0) | ((q_val >> 0) & 0xf)] + Int4MulTable[((m_val >> 0) & 0xf0) | ((q_val >> 4) & 0xf)]; } return sum; } float InnerProductInt8Scalar(const int8_t *m, const int8_t *q, size_t dim) { return InnerProductScalar(m, q, dim); } float MinusInnerProductInt8Scalar(const int8_t *m, const int8_t *q, size_t dim) { return MinusInnerProductScalar(m, q, dim); } float InnerProductFp16Scalar(const ailego::Float16 *m, const ailego::Float16 *q, size_t dim) { return InnerProductScalar(m, q, dim); } float MinusInnerProductFp16Scalar(const ailego::Float16 *m, const ailego::Float16 *q, size_t dim) { return MinusInnerProductScalar(m, q, dim); } float InnerProductFp32Scalar(const float *m, const float *q, size_t dim) { return InnerProductScalar(m, q, dim); } float MinusInnerProductFp32Scalar(const float *m, const float *q, size_t dim) { return MinusInnerProductScalar(m, q, dim); } //-------------------------------------------------- // Sparse //-------------------------------------------------- float ComputeInnerProductSparseInSegmentFp32(uint32_t m_sparse_count, const uint16_t *m_sparse_index, const float *m_sparse_value, uint32_t q_sparse_count, const uint16_t *q_sparse_index, const float *q_sparse_value); float ComputeInnerProductSparseInSegmentFp16(uint32_t m_sparse_count, const uint16_t *m_sparse_index, const Float16 *m_sparse_value, uint32_t q_sparse_count, const uint16_t *q_sparse_index, const Float16 *q_sparse_value); template float ComputeInnerProductSparseInSegment(uint32_t m_sparse_count, const uint16_t *m_sparse_index, const T *m_sparse_value, uint32_t q_sparse_count, const uint16_t *q_sparse_index, const T *q_sparse_value); template <> float ComputeInnerProductSparseInSegment(uint32_t m_sparse_count, const uint16_t *m_sparse_index, const float *m_sparse_value, uint32_t q_sparse_count, const uint16_t *q_sparse_index, const float *q_sparse_value) { return ComputeInnerProductSparseInSegmentFp32(m_sparse_count, m_sparse_index, m_sparse_value, 
q_sparse_count, q_sparse_index, q_sparse_value); } template <> float ComputeInnerProductSparseInSegment( uint32_t m_sparse_count, const uint16_t *m_sparse_index, const Float16 *m_sparse_value, uint32_t q_sparse_count, const uint16_t *q_sparse_index, const Float16 *q_sparse_value) { return ComputeInnerProductSparseInSegmentFp16(m_sparse_count, m_sparse_index, m_sparse_value, q_sparse_count, q_sparse_index, q_sparse_value); } template float ComputeSegments(const void *m_sparse_data_in, const void *q_sparse_data_in) { ailego_assert(m_sparse_data_in && q_sparse_data_in); float sum{0.0f}; const uint8_t *m_sparse_data = reinterpret_cast(m_sparse_data_in); const uint8_t *q_sparse_data = reinterpret_cast(q_sparse_data_in); const uint32_t m_sparse_count = *reinterpret_cast(m_sparse_data); const uint32_t q_sparse_count = *reinterpret_cast(q_sparse_data); if (m_sparse_count == 0 || q_sparse_count == 0) { return 0.0f; } const uint32_t m_seg_count = *reinterpret_cast(m_sparse_data + sizeof(uint32_t)); const uint32_t q_seg_count = *reinterpret_cast(q_sparse_data + sizeof(uint32_t)); const uint32_t *m_seg_id = reinterpret_cast(m_sparse_data + 2 * sizeof(uint32_t)); const uint32_t *q_seg_id = reinterpret_cast(q_sparse_data + 2 * sizeof(uint32_t)); const uint32_t *m_seg_vec_cnt = reinterpret_cast( m_sparse_data + 2 * sizeof(uint32_t) + m_seg_count * sizeof(uint32_t)); const uint32_t *q_seg_vec_cnt = reinterpret_cast( q_sparse_data + 2 * sizeof(uint32_t) + q_seg_count * sizeof(uint32_t)); const uint16_t *m_sparse_index = reinterpret_cast(m_sparse_data + 2 * sizeof(uint32_t) + m_seg_count * 2 * sizeof(uint32_t)); const uint16_t *q_sparse_index = reinterpret_cast(q_sparse_data + 2 * sizeof(uint32_t) + q_seg_count * 2 * sizeof(uint32_t)); const T *m_sparse_value = reinterpret_cast( m_sparse_data + 2 * sizeof(uint32_t) + m_seg_count * 2 * sizeof(uint32_t) + m_sparse_count * sizeof(uint16_t)); const T *q_sparse_value = reinterpret_cast( q_sparse_data + 2 * sizeof(uint32_t) + q_seg_count 
* 2 * sizeof(uint32_t) + q_sparse_count * sizeof(uint16_t)); size_t m_s = 0; size_t q_s = 0; size_t m_count = 0; size_t q_count = 0; while (m_s < m_seg_count && q_s < q_seg_count) { if (m_seg_id[m_s] == q_seg_id[q_s]) { sum += ComputeInnerProductSparseInSegment( m_seg_vec_cnt[m_s], m_sparse_index + m_count, m_sparse_value + m_count, q_seg_vec_cnt[q_s], q_sparse_index + q_count, q_sparse_value + q_count); m_count += m_seg_vec_cnt[m_s]; q_count += q_seg_vec_cnt[q_s]; ++m_s; ++q_s; } else if (m_seg_id[m_s] < q_seg_id[q_s]) { m_count += m_seg_vec_cnt[m_s]; ++m_s; } else { q_count += q_seg_vec_cnt[q_s]; ++q_s; } } return -sum; } float MinusInnerProductSparseFp16Scalar(const void *m_sparse_data_in, const void *q_sparse_data_in) { return ComputeSegments(m_sparse_data_in, q_sparse_data_in); } float MinusInnerProductSparseFp32Scalar(const void *m_sparse_data_in, const void *q_sparse_data_in) { return ComputeSegments(m_sparse_data_in, q_sparse_data_in); } float InnerProductSparseInSegmentFp16Scalar(uint32_t m_sparse_count, const uint16_t *m_sparse_index, const Float16 *m_sparse_value, uint32_t q_sparse_count, const uint16_t *q_sparse_index, const Float16 *q_sparse_value) { float sum = 0.0f; size_t m_i = 0; size_t q_i = 0; while (m_i < m_sparse_count && q_i < q_sparse_count) { if (m_sparse_index[m_i] == q_sparse_index[q_i]) { sum += m_sparse_value[m_i] * q_sparse_value[q_i]; ++m_i; ++q_i; } else if (m_sparse_index[m_i] < q_sparse_index[q_i]) { ++m_i; } else { ++q_i; } } return sum; } float InnerProductSparseInSegmentFp32Scalar(uint32_t m_sparse_count, const uint16_t *m_sparse_index, const float *m_sparse_value, uint32_t q_sparse_count, const uint16_t *q_sparse_index, const float *q_sparse_value) { float sum = 0.0f; size_t m_i = 0; size_t q_i = 0; while (m_i < m_sparse_count && q_i < q_sparse_count) { if (m_sparse_index[m_i] == q_sparse_index[q_i]) { sum += m_sparse_value[m_i] * q_sparse_value[q_i]; ++m_i; ++q_i; } else if (m_sparse_index[m_i] < q_sparse_index[q_i]) { ++m_i; } 
else { ++q_i; } } return sum; } } // namespace ailego } // namespace zvec ================================================ FILE: src/ailego/math/matrix_define.i ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #define MATRIX_VAR_INIT_1X1(_VAR_TYPE, _VAR_NAME, _VAR_INIT) \ _VAR_TYPE _VAR_NAME##_0_0 = (_VAR_INIT); #define MATRIX_VAR_INIT_1X2(_VAR_TYPE, _VAR_NAME, _VAR_INIT) \ MATRIX_VAR_INIT_1X1(_VAR_TYPE, _VAR_NAME, _VAR_INIT) \ _VAR_TYPE _VAR_NAME##_0_1 = (_VAR_INIT); #define MATRIX_VAR_INIT_1X4(_VAR_TYPE, _VAR_NAME, _VAR_INIT) \ MATRIX_VAR_INIT_1X2(_VAR_TYPE, _VAR_NAME, _VAR_INIT) \ _VAR_TYPE _VAR_NAME##_0_2 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_3 = (_VAR_INIT); #define MATRIX_VAR_INIT_1X8(_VAR_TYPE, _VAR_NAME, _VAR_INIT) \ MATRIX_VAR_INIT_1X4(_VAR_TYPE, _VAR_NAME, _VAR_INIT) \ _VAR_TYPE _VAR_NAME##_0_4 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_5 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_6 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_7 = (_VAR_INIT); #define MATRIX_VAR_INIT_1X16(_VAR_TYPE, _VAR_NAME, _VAR_INIT) \ MATRIX_VAR_INIT_1X8(_VAR_TYPE, _VAR_NAME, _VAR_INIT) \ _VAR_TYPE _VAR_NAME##_0_8 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_9 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_10 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_11 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_12 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_13 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_14 = (_VAR_INIT); \ _VAR_TYPE 
_VAR_NAME##_0_15 = (_VAR_INIT); #define MATRIX_VAR_INIT_2X1(_VAR_TYPE, _VAR_NAME, _VAR_INIT) \ MATRIX_VAR_INIT_1X1(_VAR_TYPE, _VAR_NAME, _VAR_INIT) \ _VAR_TYPE _VAR_NAME##_1_0 = (_VAR_INIT); #define MATRIX_VAR_INIT_2X2(_VAR_TYPE, _VAR_NAME, _VAR_INIT) \ MATRIX_VAR_INIT_2X1(_VAR_TYPE, _VAR_NAME, _VAR_INIT) \ _VAR_TYPE _VAR_NAME##_0_1 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_1 = (_VAR_INIT); #define MATRIX_VAR_INIT_2X4(_VAR_TYPE, _VAR_NAME, _VAR_INIT) \ MATRIX_VAR_INIT_2X2(_VAR_TYPE, _VAR_NAME, _VAR_INIT) \ _VAR_TYPE _VAR_NAME##_0_2 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_2 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_3 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_3 = (_VAR_INIT); #define MATRIX_VAR_INIT_2X8(_VAR_TYPE, _VAR_NAME, _VAR_INIT) \ MATRIX_VAR_INIT_2X4(_VAR_TYPE, _VAR_NAME, _VAR_INIT) \ _VAR_TYPE _VAR_NAME##_0_4 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_4 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_5 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_5 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_6 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_6 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_7 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_7 = (_VAR_INIT); #define MATRIX_VAR_INIT_2X16(_VAR_TYPE, _VAR_NAME, _VAR_INIT) \ MATRIX_VAR_INIT_2X8(_VAR_TYPE, _VAR_NAME, _VAR_INIT) \ _VAR_TYPE _VAR_NAME##_0_8 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_8 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_9 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_9 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_10 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_10 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_11 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_11 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_12 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_12 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_13 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_13 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_14 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_14 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_15 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_15 = (_VAR_INIT); #define MATRIX_VAR_INIT_2X32(_VAR_TYPE, 
_VAR_NAME, _VAR_INIT) \ MATRIX_VAR_INIT_2X16(_VAR_TYPE, _VAR_NAME, _VAR_INIT) \ _VAR_TYPE _VAR_NAME##_0_16 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_16 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_17 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_17 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_18 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_18 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_19 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_19 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_20 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_20 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_21 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_21 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_22 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_22 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_23 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_23 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_24 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_24 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_25 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_25 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_26 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_26 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_27 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_27 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_28 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_28 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_29 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_29 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_30 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_30 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_31 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_31 = (_VAR_INIT); #define MATRIX_VAR_INIT_4X1(_VAR_TYPE, _VAR_NAME, _VAR_INIT) \ MATRIX_VAR_INIT_2X1(_VAR_TYPE, _VAR_NAME, _VAR_INIT) \ _VAR_TYPE _VAR_NAME##_2_0 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_3_0 = (_VAR_INIT); #define MATRIX_VAR_INIT_4X2(_VAR_TYPE, _VAR_NAME, _VAR_INIT) \ MATRIX_VAR_INIT_4X1(_VAR_TYPE, _VAR_NAME, _VAR_INIT) \ _VAR_TYPE _VAR_NAME##_0_1 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_1 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_2_1 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_3_1 = (_VAR_INIT); #define 
MATRIX_VAR_INIT_4X4(_VAR_TYPE, _VAR_NAME, _VAR_INIT) \ MATRIX_VAR_INIT_4X2(_VAR_TYPE, _VAR_NAME, _VAR_INIT) \ _VAR_TYPE _VAR_NAME##_0_2 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_2 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_2_2 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_3_2 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_3 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_3 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_2_3 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_3_3 = (_VAR_INIT); #define MATRIX_VAR_INIT_4X8(_VAR_TYPE, _VAR_NAME, _VAR_INIT) \ MATRIX_VAR_INIT_4X4(_VAR_TYPE, _VAR_NAME, _VAR_INIT) \ _VAR_TYPE _VAR_NAME##_0_4 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_4 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_2_4 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_3_4 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_5 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_5 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_2_5 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_3_5 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_6 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_6 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_2_6 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_3_6 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_7 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_7 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_2_7 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_3_7 = (_VAR_INIT); #define MATRIX_VAR_INIT_4X16(_VAR_TYPE, _VAR_NAME, _VAR_INIT) \ MATRIX_VAR_INIT_4X8(_VAR_TYPE, _VAR_NAME, _VAR_INIT) \ _VAR_TYPE _VAR_NAME##_0_8 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_8 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_2_8 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_3_8 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_9 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_9 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_2_9 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_3_9 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_10 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_10 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_2_10 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_3_10 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_11 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_11 = (_VAR_INIT); \ _VAR_TYPE 
_VAR_NAME##_2_11 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_3_11 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_12 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_12 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_2_12 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_3_12 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_13 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_13 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_2_13 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_3_13 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_14 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_14 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_2_14 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_3_14 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_15 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_15 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_2_15 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_3_15 = (_VAR_INIT); #define MATRIX_VAR_INIT_4X32(_VAR_TYPE, _VAR_NAME, _VAR_INIT) \ MATRIX_VAR_INIT_4X16(_VAR_TYPE, _VAR_NAME, _VAR_INIT) \ _VAR_TYPE _VAR_NAME##_0_16 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_16 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_2_16 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_3_16 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_17 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_17 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_2_17 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_3_17 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_18 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_18 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_2_18 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_3_18 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_19 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_19 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_2_19 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_3_19 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_20 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_20 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_2_20 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_3_20 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_21 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_21 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_2_21 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_3_21 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_22 = (_VAR_INIT); \ 
_VAR_TYPE _VAR_NAME##_1_22 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_2_22 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_3_22 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_23 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_23 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_2_23 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_3_23 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_24 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_24 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_2_24 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_3_24 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_25 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_25 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_2_25 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_3_25 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_26 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_26 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_2_26 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_3_26 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_27 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_27 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_2_27 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_3_27 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_28 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_28 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_2_28 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_3_28 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_29 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_29 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_2_29 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_3_29 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_30 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_30 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_2_30 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_3_30 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_31 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_31 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_2_31 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_3_31 = (_VAR_INIT); #define MATRIX_VAR_INIT_8X1(_VAR_TYPE, _VAR_NAME, _VAR_INIT) \ MATRIX_VAR_INIT_4X1(_VAR_TYPE, _VAR_NAME, _VAR_INIT) \ _VAR_TYPE _VAR_NAME##_4_0 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_5_0 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_6_0 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_7_0 = 
(_VAR_INIT); #define MATRIX_VAR_INIT_8X2(_VAR_TYPE, _VAR_NAME, _VAR_INIT) \ MATRIX_VAR_INIT_8X1(_VAR_TYPE, _VAR_NAME, _VAR_INIT) \ _VAR_TYPE _VAR_NAME##_0_1 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_1 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_2_1 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_3_1 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_4_1 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_5_1 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_6_1 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_7_1 = (_VAR_INIT); #define MATRIX_VAR_INIT_8X4(_VAR_TYPE, _VAR_NAME, _VAR_INIT) \ MATRIX_VAR_INIT_8X2(_VAR_TYPE, _VAR_NAME, _VAR_INIT) \ _VAR_TYPE _VAR_NAME##_0_2 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_2 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_2_2 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_3_2 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_4_2 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_5_2 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_6_2 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_7_2 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_3 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_3 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_2_3 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_3_3 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_4_3 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_5_3 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_6_3 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_7_3 = (_VAR_INIT); #define MATRIX_VAR_INIT_8X8(_VAR_TYPE, _VAR_NAME, _VAR_INIT) \ MATRIX_VAR_INIT_8X4(_VAR_TYPE, _VAR_NAME, _VAR_INIT) \ _VAR_TYPE _VAR_NAME##_0_4 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_4 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_2_4 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_3_4 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_4_4 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_5_4 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_6_4 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_7_4 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_5 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_5 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_2_5 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_3_5 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_4_5 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_5_5 = (_VAR_INIT); \ 
_VAR_TYPE _VAR_NAME##_6_5 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_7_5 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_6 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_6 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_2_6 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_3_6 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_4_6 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_5_6 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_6_6 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_7_6 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_7 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_7 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_2_7 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_3_7 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_4_7 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_5_7 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_6_7 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_7_7 = (_VAR_INIT); #define MATRIX_VAR_INIT_8X16(_VAR_TYPE, _VAR_NAME, _VAR_INIT) \ MATRIX_VAR_INIT_8X8(_VAR_TYPE, _VAR_NAME, _VAR_INIT) \ _VAR_TYPE _VAR_NAME##_0_8 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_8 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_2_8 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_3_8 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_4_8 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_5_8 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_6_8 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_7_8 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_9 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_9 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_2_9 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_3_9 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_4_9 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_5_9 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_6_9 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_7_9 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_10 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_10 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_2_10 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_3_10 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_4_10 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_5_10 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_6_10 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_7_10 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_11 = (_VAR_INIT); \ _VAR_TYPE 
_VAR_NAME##_1_11 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_2_11 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_3_11 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_4_11 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_5_11 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_6_11 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_7_11 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_12 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_12 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_2_12 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_3_12 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_4_12 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_5_12 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_6_12 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_7_12 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_13 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_13 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_2_13 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_3_13 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_4_13 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_5_13 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_6_13 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_7_13 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_14 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_14 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_2_14 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_3_14 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_4_14 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_5_14 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_6_14 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_7_14 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_15 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_15 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_2_15 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_3_15 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_4_15 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_5_15 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_6_15 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_7_15 = (_VAR_INIT); #define MATRIX_VAR_INIT_8X32(_VAR_TYPE, _VAR_NAME, _VAR_INIT) \ MATRIX_VAR_INIT_8X16(_VAR_TYPE, _VAR_NAME, _VAR_INIT) \ _VAR_TYPE _VAR_NAME##_0_16 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_16 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_2_16 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_3_16 = (_VAR_INIT); \ 
_VAR_TYPE _VAR_NAME##_4_16 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_5_16 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_6_16 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_7_16 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_17 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_17 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_2_17 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_3_17 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_4_17 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_5_17 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_6_17 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_7_17 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_18 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_18 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_2_18 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_3_18 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_4_18 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_5_18 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_6_18 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_7_18 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_19 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_19 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_2_19 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_3_19 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_4_19 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_5_19 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_6_19 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_7_19 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_20 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_20 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_2_20 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_3_20 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_4_20 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_5_20 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_6_20 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_7_20 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_21 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_21 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_2_21 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_3_21 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_4_21 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_5_21 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_6_21 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_7_21 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_22 = (_VAR_INIT); \ _VAR_TYPE 
_VAR_NAME##_1_22 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_2_22 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_3_22 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_4_22 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_5_22 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_6_22 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_7_22 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_23 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_23 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_2_23 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_3_23 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_4_23 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_5_23 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_6_23 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_7_23 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_24 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_24 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_2_24 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_3_24 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_4_24 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_5_24 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_6_24 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_7_24 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_25 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_25 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_2_25 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_3_25 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_4_25 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_5_25 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_6_25 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_7_25 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_26 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_26 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_2_26 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_3_26 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_4_26 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_5_26 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_6_26 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_7_26 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_27 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_27 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_2_27 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_3_27 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_4_27 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_5_27 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_6_27 = 
(_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_7_27 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_28 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_28 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_2_28 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_3_28 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_4_28 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_5_28 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_6_28 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_7_28 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_29 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_29 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_2_29 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_3_29 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_4_29 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_5_29 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_6_29 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_7_29 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_30 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_30 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_2_30 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_3_30 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_4_30 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_5_30 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_6_30 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_7_30 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_0_31 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_1_31 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_2_31 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_3_31 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_4_31 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_5_31 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_6_31 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_7_31 = (_VAR_INIT); #define MATRIX_VAR_INIT_16X1(_VAR_TYPE, _VAR_NAME, _VAR_INIT) \ MATRIX_VAR_INIT_8X1(_VAR_TYPE, _VAR_NAME, _VAR_INIT) \ _VAR_TYPE _VAR_NAME##_8_0 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_9_0 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_10_0 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_11_0 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_12_0 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_13_0 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_14_0 = (_VAR_INIT); \ _VAR_TYPE _VAR_NAME##_15_0 = (_VAR_INIT); #define MATRIX_VAR_INIT_16X2(_VAR_TYPE, _VAR_NAME, _VAR_INIT) \ 
MATRIX_VAR_INIT_16X1(_VAR_TYPE, _VAR_NAME, _VAR_INIT) \
  _VAR_TYPE _VAR_NAME##_0_1 = (_VAR_INIT); \
  _VAR_TYPE _VAR_NAME##_1_1 = (_VAR_INIT); \
  _VAR_TYPE _VAR_NAME##_2_1 = (_VAR_INIT); \
  _VAR_TYPE _VAR_NAME##_3_1 = (_VAR_INIT); \
  _VAR_TYPE _VAR_NAME##_4_1 = (_VAR_INIT); \
  _VAR_TYPE _VAR_NAME##_5_1 = (_VAR_INIT); \
  _VAR_TYPE _VAR_NAME##_6_1 = (_VAR_INIT); \
  _VAR_TYPE _VAR_NAME##_7_1 = (_VAR_INIT); \
  _VAR_TYPE _VAR_NAME##_8_1 = (_VAR_INIT); \
  _VAR_TYPE _VAR_NAME##_9_1 = (_VAR_INIT); \
  _VAR_TYPE _VAR_NAME##_10_1 = (_VAR_INIT); \
  _VAR_TYPE _VAR_NAME##_11_1 = (_VAR_INIT); \
  _VAR_TYPE _VAR_NAME##_12_1 = (_VAR_INIT); \
  _VAR_TYPE _VAR_NAME##_13_1 = (_VAR_INIT); \
  _VAR_TYPE _VAR_NAME##_14_1 = (_VAR_INIT); \
  _VAR_TYPE _VAR_NAME##_15_1 = (_VAR_INIT);

// Declares the 16x4 variable tile `_VAR_NAME##_i_j` (i in [0, 16),
// j in [0, 4)), each of type `_VAR_TYPE` and initialized to `(_VAR_INIT)`.
// Columns j = 0..1 are delegated to MATRIX_VAR_INIT_16X2; columns j = 2..3
// are declared here. These are plain declarations (no enclosing braces), so
// the variables live in whatever scope the macro is expanded in.
#define MATRIX_VAR_INIT_16X4(_VAR_TYPE, _VAR_NAME, _VAR_INIT) \
  MATRIX_VAR_INIT_16X2(_VAR_TYPE, _VAR_NAME, _VAR_INIT) \
  _VAR_TYPE _VAR_NAME##_0_2 = (_VAR_INIT); \
  _VAR_TYPE _VAR_NAME##_1_2 = (_VAR_INIT); \
  _VAR_TYPE _VAR_NAME##_2_2 = (_VAR_INIT); \
  _VAR_TYPE _VAR_NAME##_3_2 = (_VAR_INIT); \
  _VAR_TYPE _VAR_NAME##_4_2 = (_VAR_INIT); \
  _VAR_TYPE _VAR_NAME##_5_2 = (_VAR_INIT); \
  _VAR_TYPE _VAR_NAME##_6_2 = (_VAR_INIT); \
  _VAR_TYPE _VAR_NAME##_7_2 = (_VAR_INIT); \
  _VAR_TYPE _VAR_NAME##_8_2 = (_VAR_INIT); \
  _VAR_TYPE _VAR_NAME##_9_2 = (_VAR_INIT); \
  _VAR_TYPE _VAR_NAME##_10_2 = (_VAR_INIT); \
  _VAR_TYPE _VAR_NAME##_11_2 = (_VAR_INIT); \
  _VAR_TYPE _VAR_NAME##_12_2 = (_VAR_INIT); \
  _VAR_TYPE _VAR_NAME##_13_2 = (_VAR_INIT); \
  _VAR_TYPE _VAR_NAME##_14_2 = (_VAR_INIT); \
  _VAR_TYPE _VAR_NAME##_15_2 = (_VAR_INIT); \
  _VAR_TYPE _VAR_NAME##_0_3 = (_VAR_INIT); \
  _VAR_TYPE _VAR_NAME##_1_3 = (_VAR_INIT); \
  _VAR_TYPE _VAR_NAME##_2_3 = (_VAR_INIT); \
  _VAR_TYPE _VAR_NAME##_3_3 = (_VAR_INIT); \
  _VAR_TYPE _VAR_NAME##_4_3 = (_VAR_INIT); \
  _VAR_TYPE _VAR_NAME##_5_3 = (_VAR_INIT); \
  _VAR_TYPE _VAR_NAME##_6_3 = (_VAR_INIT); \
  _VAR_TYPE _VAR_NAME##_7_3 = (_VAR_INIT); \
  _VAR_TYPE _VAR_NAME##_8_3 = (_VAR_INIT); \
  _VAR_TYPE _VAR_NAME##_9_3 = (_VAR_INIT); \
  _VAR_TYPE _VAR_NAME##_10_3 = (_VAR_INIT); \
  _VAR_TYPE _VAR_NAME##_11_3 = (_VAR_INIT); \
  _VAR_TYPE _VAR_NAME##_12_3 = (_VAR_INIT); \
  _VAR_TYPE _VAR_NAME##_13_3 = (_VAR_INIT); \
  _VAR_TYPE _VAR_NAME##_14_3 = (_VAR_INIT); \
  _VAR_TYPE _VAR_NAME##_15_3 = (_VAR_INIT);

// MATRIX_VAR_STORE_RxC(_STEP, _VAR, _ARRAY, _STORE, _NORM, ...)
//
// Writes the R x C variable tile `_VAR##_i_j` (i in [0, R), j in [0, C)) out
// through the caller-supplied `_STORE` operation. For every element it emits
//   _STORE((_ARRAY) + (_STEP) * (j * R + i), _NORM((_VAR##_i_j), ...));
// so the tile is laid out in column-major order (i varies fastest) and
// consecutive elements are `_STEP` units apart from `_ARRAY`.
//
// Any arguments after `_NORM` are forwarded to every `_NORM` call. The
// `, ##__VA_ARGS__` form (a GNU preprocessor extension) removes the comma when
// no extra arguments are given.
//
// Each macro expands its next-smaller sibling first (RxC -> Rx(C/2), and
// Rx1 -> (R/2)x1), then appends only the new stores. A store for every
// `_VAR##_i_j` in the tile must therefore be backed by a declared variable,
// e.g. from a matching MATRIX_VAR_INIT_* expansion.

// 1x1 tile: offset 0.
#define MATRIX_VAR_STORE_1X1(_STEP, _VAR, _ARRAY, _STORE, _NORM, ...) \
  _STORE((_ARRAY) + (_STEP) * (0), _NORM((_VAR##_0_0), ##__VA_ARGS__));

// 1x2 tile: offsets [0, 2).
#define MATRIX_VAR_STORE_1X2(_STEP, _VAR, _ARRAY, _STORE, _NORM, ...) \
  MATRIX_VAR_STORE_1X1(_STEP, _VAR, _ARRAY, _STORE, _NORM, ##__VA_ARGS__) \
  _STORE((_ARRAY) + (_STEP) * (1), _NORM((_VAR##_0_1), ##__VA_ARGS__));

// 1x4 tile: offsets [0, 4).
#define MATRIX_VAR_STORE_1X4(_STEP, _VAR, _ARRAY, _STORE, _NORM, ...) \
  MATRIX_VAR_STORE_1X2(_STEP, _VAR, _ARRAY, _STORE, _NORM, ##__VA_ARGS__) \
  _STORE((_ARRAY) + (_STEP) * (2), _NORM((_VAR##_0_2), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (3), _NORM((_VAR##_0_3), ##__VA_ARGS__));

// 1x8 tile: offsets [0, 8).
#define MATRIX_VAR_STORE_1X8(_STEP, _VAR, _ARRAY, _STORE, _NORM, ...) \
  MATRIX_VAR_STORE_1X4(_STEP, _VAR, _ARRAY, _STORE, _NORM, ##__VA_ARGS__) \
  _STORE((_ARRAY) + (_STEP) * (4), _NORM((_VAR##_0_4), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (5), _NORM((_VAR##_0_5), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (6), _NORM((_VAR##_0_6), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (7), _NORM((_VAR##_0_7), ##__VA_ARGS__));

// 1x16 tile: offsets [0, 16).
#define MATRIX_VAR_STORE_1X16(_STEP, _VAR, _ARRAY, _STORE, _NORM, ...) \
  MATRIX_VAR_STORE_1X8(_STEP, _VAR, _ARRAY, _STORE, _NORM, ##__VA_ARGS__) \
  _STORE((_ARRAY) + (_STEP) * (8), _NORM((_VAR##_0_8), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (9), _NORM((_VAR##_0_9), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (10), _NORM((_VAR##_0_10), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (11), _NORM((_VAR##_0_11), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (12), _NORM((_VAR##_0_12), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (13), _NORM((_VAR##_0_13), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (14), _NORM((_VAR##_0_14), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (15), _NORM((_VAR##_0_15), ##__VA_ARGS__));

// 2x1 tile: offsets [0, 2).
#define MATRIX_VAR_STORE_2X1(_STEP, _VAR, _ARRAY, _STORE, _NORM, ...) \
  MATRIX_VAR_STORE_1X1(_STEP, _VAR, _ARRAY, _STORE, _NORM, ##__VA_ARGS__) \
  _STORE((_ARRAY) + (_STEP) * (1), _NORM((_VAR##_1_0), ##__VA_ARGS__));

// 2x2 tile: offsets [0, 4).
#define MATRIX_VAR_STORE_2X2(_STEP, _VAR, _ARRAY, _STORE, _NORM, ...) \
  MATRIX_VAR_STORE_2X1(_STEP, _VAR, _ARRAY, _STORE, _NORM, ##__VA_ARGS__) \
  _STORE((_ARRAY) + (_STEP) * (2), _NORM((_VAR##_0_1), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (3), _NORM((_VAR##_1_1), ##__VA_ARGS__));

// 2x4 tile: offsets [0, 8).
#define MATRIX_VAR_STORE_2X4(_STEP, _VAR, _ARRAY, _STORE, _NORM, ...) \
  MATRIX_VAR_STORE_2X2(_STEP, _VAR, _ARRAY, _STORE, _NORM, ##__VA_ARGS__) \
  _STORE((_ARRAY) + (_STEP) * (4), _NORM((_VAR##_0_2), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (5), _NORM((_VAR##_1_2), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (6), _NORM((_VAR##_0_3), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (7), _NORM((_VAR##_1_3), ##__VA_ARGS__));

// 2x8 tile: offsets [0, 16).
#define MATRIX_VAR_STORE_2X8(_STEP, _VAR, _ARRAY, _STORE, _NORM, ...) \
  MATRIX_VAR_STORE_2X4(_STEP, _VAR, _ARRAY, _STORE, _NORM, ##__VA_ARGS__) \
  _STORE((_ARRAY) + (_STEP) * (8), _NORM((_VAR##_0_4), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (9), _NORM((_VAR##_1_4), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (10), _NORM((_VAR##_0_5), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (11), _NORM((_VAR##_1_5), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (12), _NORM((_VAR##_0_6), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (13), _NORM((_VAR##_1_6), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (14), _NORM((_VAR##_0_7), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (15), _NORM((_VAR##_1_7), ##__VA_ARGS__));

// 2x16 tile: offsets [0, 32).
#define MATRIX_VAR_STORE_2X16(_STEP, _VAR, _ARRAY, _STORE, _NORM, ...) \
  MATRIX_VAR_STORE_2X8(_STEP, _VAR, _ARRAY, _STORE, _NORM, ##__VA_ARGS__) \
  _STORE((_ARRAY) + (_STEP) * (16), _NORM((_VAR##_0_8), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (17), _NORM((_VAR##_1_8), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (18), _NORM((_VAR##_0_9), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (19), _NORM((_VAR##_1_9), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (20), _NORM((_VAR##_0_10), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (21), _NORM((_VAR##_1_10), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (22), _NORM((_VAR##_0_11), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (23), _NORM((_VAR##_1_11), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (24), _NORM((_VAR##_0_12), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (25), _NORM((_VAR##_1_12), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (26), _NORM((_VAR##_0_13), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (27), _NORM((_VAR##_1_13), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (28), _NORM((_VAR##_0_14), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (29), _NORM((_VAR##_1_14), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (30), _NORM((_VAR##_0_15), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (31), _NORM((_VAR##_1_15), ##__VA_ARGS__));

// 2x32 tile: offsets [0, 64).
#define MATRIX_VAR_STORE_2X32(_STEP, _VAR, _ARRAY, _STORE, _NORM, ...) \
  MATRIX_VAR_STORE_2X16(_STEP, _VAR, _ARRAY, _STORE, _NORM, ##__VA_ARGS__) \
  _STORE((_ARRAY) + (_STEP) * (32), _NORM((_VAR##_0_16), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (33), _NORM((_VAR##_1_16), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (34), _NORM((_VAR##_0_17), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (35), _NORM((_VAR##_1_17), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (36), _NORM((_VAR##_0_18), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (37), _NORM((_VAR##_1_18), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (38), _NORM((_VAR##_0_19), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (39), _NORM((_VAR##_1_19), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (40), _NORM((_VAR##_0_20), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (41), _NORM((_VAR##_1_20), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (42), _NORM((_VAR##_0_21), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (43), _NORM((_VAR##_1_21), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (44), _NORM((_VAR##_0_22), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (45), _NORM((_VAR##_1_22), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (46), _NORM((_VAR##_0_23), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (47), _NORM((_VAR##_1_23), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (48), _NORM((_VAR##_0_24), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (49), _NORM((_VAR##_1_24), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (50), _NORM((_VAR##_0_25), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (51), _NORM((_VAR##_1_25), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (52), _NORM((_VAR##_0_26), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (53), _NORM((_VAR##_1_26), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (54), _NORM((_VAR##_0_27), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (55), _NORM((_VAR##_1_27), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (56), _NORM((_VAR##_0_28), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (57), _NORM((_VAR##_1_28), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (58), _NORM((_VAR##_0_29), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (59), _NORM((_VAR##_1_29), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (60), _NORM((_VAR##_0_30), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (61), _NORM((_VAR##_1_30), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (62), _NORM((_VAR##_0_31), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (63), _NORM((_VAR##_1_31), ##__VA_ARGS__));

// 4x1 tile: offsets [0, 4).
#define MATRIX_VAR_STORE_4X1(_STEP, _VAR, _ARRAY, _STORE, _NORM, ...) \
  MATRIX_VAR_STORE_2X1(_STEP, _VAR, _ARRAY, _STORE, _NORM, ##__VA_ARGS__) \
  _STORE((_ARRAY) + (_STEP) * (2), _NORM((_VAR##_2_0), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (3), _NORM((_VAR##_3_0), ##__VA_ARGS__));

// 4x2 tile: offsets [0, 8).
#define MATRIX_VAR_STORE_4X2(_STEP, _VAR, _ARRAY, _STORE, _NORM, ...) \
  MATRIX_VAR_STORE_4X1(_STEP, _VAR, _ARRAY, _STORE, _NORM, ##__VA_ARGS__) \
  _STORE((_ARRAY) + (_STEP) * (4), _NORM((_VAR##_0_1), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (5), _NORM((_VAR##_1_1), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (6), _NORM((_VAR##_2_1), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (7), _NORM((_VAR##_3_1), ##__VA_ARGS__));

// 4x4 tile: offsets [0, 16).
#define MATRIX_VAR_STORE_4X4(_STEP, _VAR, _ARRAY, _STORE, _NORM, ...) \
  MATRIX_VAR_STORE_4X2(_STEP, _VAR, _ARRAY, _STORE, _NORM, ##__VA_ARGS__) \
  _STORE((_ARRAY) + (_STEP) * (8), _NORM((_VAR##_0_2), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (9), _NORM((_VAR##_1_2), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (10), _NORM((_VAR##_2_2), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (11), _NORM((_VAR##_3_2), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (12), _NORM((_VAR##_0_3), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (13), _NORM((_VAR##_1_3), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (14), _NORM((_VAR##_2_3), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (15), _NORM((_VAR##_3_3), ##__VA_ARGS__));

// 4x8 tile: offsets [0, 32).
#define MATRIX_VAR_STORE_4X8(_STEP, _VAR, _ARRAY, _STORE, _NORM, ...) \
  MATRIX_VAR_STORE_4X4(_STEP, _VAR, _ARRAY, _STORE, _NORM, ##__VA_ARGS__) \
  _STORE((_ARRAY) + (_STEP) * (16), _NORM((_VAR##_0_4), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (17), _NORM((_VAR##_1_4), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (18), _NORM((_VAR##_2_4), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (19), _NORM((_VAR##_3_4), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (20), _NORM((_VAR##_0_5), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (21), _NORM((_VAR##_1_5), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (22), _NORM((_VAR##_2_5), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (23), _NORM((_VAR##_3_5), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (24), _NORM((_VAR##_0_6), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (25), _NORM((_VAR##_1_6), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (26), _NORM((_VAR##_2_6), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (27), _NORM((_VAR##_3_6), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (28), _NORM((_VAR##_0_7), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (29), _NORM((_VAR##_1_7), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (30), _NORM((_VAR##_2_7), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (31), _NORM((_VAR##_3_7), ##__VA_ARGS__));

// 4x16 tile: offsets [0, 64).
#define MATRIX_VAR_STORE_4X16(_STEP, _VAR, _ARRAY, _STORE, _NORM, ...) \
  MATRIX_VAR_STORE_4X8(_STEP, _VAR, _ARRAY, _STORE, _NORM, ##__VA_ARGS__) \
  _STORE((_ARRAY) + (_STEP) * (32), _NORM((_VAR##_0_8), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (33), _NORM((_VAR##_1_8), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (34), _NORM((_VAR##_2_8), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (35), _NORM((_VAR##_3_8), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (36), _NORM((_VAR##_0_9), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (37), _NORM((_VAR##_1_9), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (38), _NORM((_VAR##_2_9), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (39), _NORM((_VAR##_3_9), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (40), _NORM((_VAR##_0_10), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (41), _NORM((_VAR##_1_10), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (42), _NORM((_VAR##_2_10), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (43), _NORM((_VAR##_3_10), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (44), _NORM((_VAR##_0_11), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (45), _NORM((_VAR##_1_11), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (46), _NORM((_VAR##_2_11), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (47), _NORM((_VAR##_3_11), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (48), _NORM((_VAR##_0_12), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (49), _NORM((_VAR##_1_12), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (50), _NORM((_VAR##_2_12), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (51), _NORM((_VAR##_3_12), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (52), _NORM((_VAR##_0_13), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (53), _NORM((_VAR##_1_13), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (54), _NORM((_VAR##_2_13), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (55), _NORM((_VAR##_3_13), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (56), _NORM((_VAR##_0_14), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (57), _NORM((_VAR##_1_14), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (58), _NORM((_VAR##_2_14), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (59), _NORM((_VAR##_3_14), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (60), _NORM((_VAR##_0_15), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (61), _NORM((_VAR##_1_15), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (62), _NORM((_VAR##_2_15), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (63), _NORM((_VAR##_3_15), ##__VA_ARGS__));

// 4x32 tile: offsets [0, 128).
#define MATRIX_VAR_STORE_4X32(_STEP, _VAR, _ARRAY, _STORE, _NORM, ...) \
  MATRIX_VAR_STORE_4X16(_STEP, _VAR, _ARRAY, _STORE, _NORM, ##__VA_ARGS__) \
  _STORE((_ARRAY) + (_STEP) * (64), _NORM((_VAR##_0_16), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (65), _NORM((_VAR##_1_16), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (66), _NORM((_VAR##_2_16), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (67), _NORM((_VAR##_3_16), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (68), _NORM((_VAR##_0_17), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (69), _NORM((_VAR##_1_17), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (70), _NORM((_VAR##_2_17), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (71), _NORM((_VAR##_3_17), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (72), _NORM((_VAR##_0_18), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (73), _NORM((_VAR##_1_18), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (74), _NORM((_VAR##_2_18), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (75), _NORM((_VAR##_3_18), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (76), _NORM((_VAR##_0_19), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (77), _NORM((_VAR##_1_19), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (78), _NORM((_VAR##_2_19), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (79), _NORM((_VAR##_3_19), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (80), _NORM((_VAR##_0_20), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (81), _NORM((_VAR##_1_20), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (82), _NORM((_VAR##_2_20), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (83), _NORM((_VAR##_3_20), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (84), _NORM((_VAR##_0_21), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (85), _NORM((_VAR##_1_21), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (86), _NORM((_VAR##_2_21), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (87), _NORM((_VAR##_3_21), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (88), _NORM((_VAR##_0_22), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (89), _NORM((_VAR##_1_22), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (90), _NORM((_VAR##_2_22), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (91), _NORM((_VAR##_3_22), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (92), _NORM((_VAR##_0_23), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (93), _NORM((_VAR##_1_23), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (94), _NORM((_VAR##_2_23), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (95), _NORM((_VAR##_3_23), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (96), _NORM((_VAR##_0_24), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (97), _NORM((_VAR##_1_24), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (98), _NORM((_VAR##_2_24), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (99), _NORM((_VAR##_3_24), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (100), _NORM((_VAR##_0_25), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (101), _NORM((_VAR##_1_25), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (102), _NORM((_VAR##_2_25), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (103), _NORM((_VAR##_3_25), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (104), _NORM((_VAR##_0_26), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (105), _NORM((_VAR##_1_26), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (106), _NORM((_VAR##_2_26), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (107), _NORM((_VAR##_3_26), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (108), _NORM((_VAR##_0_27), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (109), _NORM((_VAR##_1_27), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (110), _NORM((_VAR##_2_27), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (111), _NORM((_VAR##_3_27), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (112), _NORM((_VAR##_0_28), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (113), _NORM((_VAR##_1_28), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (114), _NORM((_VAR##_2_28), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (115), _NORM((_VAR##_3_28), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (116), _NORM((_VAR##_0_29), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (117), _NORM((_VAR##_1_29), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (118), _NORM((_VAR##_2_29), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (119), _NORM((_VAR##_3_29), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (120), _NORM((_VAR##_0_30), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (121), _NORM((_VAR##_1_30), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (122), _NORM((_VAR##_2_30), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (123), _NORM((_VAR##_3_30), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (124), _NORM((_VAR##_0_31), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (125), _NORM((_VAR##_1_31), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (126), _NORM((_VAR##_2_31), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (127), _NORM((_VAR##_3_31), ##__VA_ARGS__));

// 8x1 tile: offsets [0, 8).
#define MATRIX_VAR_STORE_8X1(_STEP, _VAR, _ARRAY, _STORE, _NORM, ...) \
  MATRIX_VAR_STORE_4X1(_STEP, _VAR, _ARRAY, _STORE, _NORM, ##__VA_ARGS__) \
  _STORE((_ARRAY) + (_STEP) * (4), _NORM((_VAR##_4_0), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (5), _NORM((_VAR##_5_0), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (6), _NORM((_VAR##_6_0), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (7), _NORM((_VAR##_7_0), ##__VA_ARGS__));

// 8x2 tile: offsets [0, 16).
#define MATRIX_VAR_STORE_8X2(_STEP, _VAR, _ARRAY, _STORE, _NORM, ...) \
  MATRIX_VAR_STORE_8X1(_STEP, _VAR, _ARRAY, _STORE, _NORM, ##__VA_ARGS__) \
  _STORE((_ARRAY) + (_STEP) * (8), _NORM((_VAR##_0_1), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (9), _NORM((_VAR##_1_1), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (10), _NORM((_VAR##_2_1), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (11), _NORM((_VAR##_3_1), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (12), _NORM((_VAR##_4_1), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (13), _NORM((_VAR##_5_1), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (14), _NORM((_VAR##_6_1), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (15), _NORM((_VAR##_7_1), ##__VA_ARGS__));

// 8x4 tile: offsets [0, 32).
#define MATRIX_VAR_STORE_8X4(_STEP, _VAR, _ARRAY, _STORE, _NORM, ...) \
  MATRIX_VAR_STORE_8X2(_STEP, _VAR, _ARRAY, _STORE, _NORM, ##__VA_ARGS__) \
  _STORE((_ARRAY) + (_STEP) * (16), _NORM((_VAR##_0_2), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (17), _NORM((_VAR##_1_2), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (18), _NORM((_VAR##_2_2), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (19), _NORM((_VAR##_3_2), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (20), _NORM((_VAR##_4_2), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (21), _NORM((_VAR##_5_2), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (22), _NORM((_VAR##_6_2), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (23), _NORM((_VAR##_7_2), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (24), _NORM((_VAR##_0_3), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (25), _NORM((_VAR##_1_3), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (26), _NORM((_VAR##_2_3), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (27), _NORM((_VAR##_3_3), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (28), _NORM((_VAR##_4_3), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (29), _NORM((_VAR##_5_3), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (30), _NORM((_VAR##_6_3), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (31), _NORM((_VAR##_7_3), ##__VA_ARGS__));

// 8x8 tile: offsets [0, 64).
#define MATRIX_VAR_STORE_8X8(_STEP, _VAR, _ARRAY, _STORE, _NORM, ...) \
  MATRIX_VAR_STORE_8X4(_STEP, _VAR, _ARRAY, _STORE, _NORM, ##__VA_ARGS__) \
  _STORE((_ARRAY) + (_STEP) * (32), _NORM((_VAR##_0_4), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (33), _NORM((_VAR##_1_4), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (34), _NORM((_VAR##_2_4), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (35), _NORM((_VAR##_3_4), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (36), _NORM((_VAR##_4_4), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (37), _NORM((_VAR##_5_4), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (38), _NORM((_VAR##_6_4), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (39), _NORM((_VAR##_7_4), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (40), _NORM((_VAR##_0_5), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (41), _NORM((_VAR##_1_5), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (42), _NORM((_VAR##_2_5), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (43), _NORM((_VAR##_3_5), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (44), _NORM((_VAR##_4_5), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (45), _NORM((_VAR##_5_5), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (46), _NORM((_VAR##_6_5), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (47), _NORM((_VAR##_7_5), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (48), _NORM((_VAR##_0_6), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (49), _NORM((_VAR##_1_6), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (50), _NORM((_VAR##_2_6), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (51), _NORM((_VAR##_3_6), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (52), _NORM((_VAR##_4_6), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (53), _NORM((_VAR##_5_6), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (54), _NORM((_VAR##_6_6), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (55), _NORM((_VAR##_7_6), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (56), _NORM((_VAR##_0_7), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (57), _NORM((_VAR##_1_7), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (58), _NORM((_VAR##_2_7), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (59), _NORM((_VAR##_3_7), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (60), _NORM((_VAR##_4_7), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (61), _NORM((_VAR##_5_7), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (62), _NORM((_VAR##_6_7), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (63), _NORM((_VAR##_7_7), ##__VA_ARGS__));

// 8x16 tile: offsets [0, 128).
#define MATRIX_VAR_STORE_8X16(_STEP, _VAR, _ARRAY, _STORE, _NORM, ...) \
  MATRIX_VAR_STORE_8X8(_STEP, _VAR, _ARRAY, _STORE, _NORM, ##__VA_ARGS__) \
  _STORE((_ARRAY) + (_STEP) * (64), _NORM((_VAR##_0_8), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (65), _NORM((_VAR##_1_8), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (66), _NORM((_VAR##_2_8), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (67), _NORM((_VAR##_3_8), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (68), _NORM((_VAR##_4_8), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (69), _NORM((_VAR##_5_8), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (70), _NORM((_VAR##_6_8), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (71), _NORM((_VAR##_7_8), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (72), _NORM((_VAR##_0_9), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (73), _NORM((_VAR##_1_9), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (74), _NORM((_VAR##_2_9), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (75), _NORM((_VAR##_3_9), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (76), _NORM((_VAR##_4_9), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (77), _NORM((_VAR##_5_9), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (78), _NORM((_VAR##_6_9), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (79), _NORM((_VAR##_7_9), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (80), _NORM((_VAR##_0_10), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (81), _NORM((_VAR##_1_10), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (82), _NORM((_VAR##_2_10), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (83), _NORM((_VAR##_3_10), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (84), _NORM((_VAR##_4_10), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (85), _NORM((_VAR##_5_10), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (86), _NORM((_VAR##_6_10), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (87), _NORM((_VAR##_7_10), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (88), _NORM((_VAR##_0_11), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (89), _NORM((_VAR##_1_11), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (90), _NORM((_VAR##_2_11), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (91), _NORM((_VAR##_3_11), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (92), _NORM((_VAR##_4_11), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (93), _NORM((_VAR##_5_11), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (94), _NORM((_VAR##_6_11), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (95), _NORM((_VAR##_7_11), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (96), _NORM((_VAR##_0_12), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (97), _NORM((_VAR##_1_12), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (98), _NORM((_VAR##_2_12), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (99), _NORM((_VAR##_3_12), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (100), _NORM((_VAR##_4_12), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (101), _NORM((_VAR##_5_12), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (102), _NORM((_VAR##_6_12), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (103), _NORM((_VAR##_7_12), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (104), _NORM((_VAR##_0_13), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (105), _NORM((_VAR##_1_13), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (106), _NORM((_VAR##_2_13), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (107), _NORM((_VAR##_3_13), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (108), _NORM((_VAR##_4_13), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (109), _NORM((_VAR##_5_13), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (110), _NORM((_VAR##_6_13), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (111), _NORM((_VAR##_7_13), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (112), _NORM((_VAR##_0_14), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (113), _NORM((_VAR##_1_14), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (114), _NORM((_VAR##_2_14), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (115), _NORM((_VAR##_3_14), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (116), _NORM((_VAR##_4_14), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (117), _NORM((_VAR##_5_14), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (118), _NORM((_VAR##_6_14), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (119), _NORM((_VAR##_7_14), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (120), _NORM((_VAR##_0_15), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (121), _NORM((_VAR##_1_15), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (122), _NORM((_VAR##_2_15), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (123), _NORM((_VAR##_3_15), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (124), _NORM((_VAR##_4_15), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (125), _NORM((_VAR##_5_15), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (126), _NORM((_VAR##_6_15), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (127), _NORM((_VAR##_7_15), ##__VA_ARGS__));

// 8x32 tile: offsets [0, 256).
#define MATRIX_VAR_STORE_8X32(_STEP, _VAR, _ARRAY, _STORE, _NORM, ...) \
  MATRIX_VAR_STORE_8X16(_STEP, _VAR, _ARRAY, _STORE, _NORM, ##__VA_ARGS__) \
  _STORE((_ARRAY) + (_STEP) * (128), _NORM((_VAR##_0_16), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (129), _NORM((_VAR##_1_16), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (130), _NORM((_VAR##_2_16), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (131), _NORM((_VAR##_3_16), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (132), _NORM((_VAR##_4_16), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (133), _NORM((_VAR##_5_16), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (134), _NORM((_VAR##_6_16), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (135), _NORM((_VAR##_7_16), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (136), _NORM((_VAR##_0_17), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (137), _NORM((_VAR##_1_17), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (138), _NORM((_VAR##_2_17), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (139), _NORM((_VAR##_3_17), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (140), _NORM((_VAR##_4_17), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (141), _NORM((_VAR##_5_17), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (142), _NORM((_VAR##_6_17), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (143), _NORM((_VAR##_7_17), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (144), _NORM((_VAR##_0_18), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (145), _NORM((_VAR##_1_18), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (146), _NORM((_VAR##_2_18), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (147), _NORM((_VAR##_3_18), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (148), _NORM((_VAR##_4_18), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (149), _NORM((_VAR##_5_18), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (150), _NORM((_VAR##_6_18), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (151), _NORM((_VAR##_7_18), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (152), _NORM((_VAR##_0_19), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (153), _NORM((_VAR##_1_19), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (154), _NORM((_VAR##_2_19), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (155), _NORM((_VAR##_3_19), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (156), _NORM((_VAR##_4_19), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (157), _NORM((_VAR##_5_19), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (158), _NORM((_VAR##_6_19), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (159), _NORM((_VAR##_7_19), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (160), _NORM((_VAR##_0_20), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (161), _NORM((_VAR##_1_20), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (162), _NORM((_VAR##_2_20), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (163), _NORM((_VAR##_3_20), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (164), _NORM((_VAR##_4_20), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (165), _NORM((_VAR##_5_20), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (166), _NORM((_VAR##_6_20), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (167), _NORM((_VAR##_7_20), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (168), _NORM((_VAR##_0_21), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (169), _NORM((_VAR##_1_21), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (170), _NORM((_VAR##_2_21), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (171), _NORM((_VAR##_3_21), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (172), _NORM((_VAR##_4_21), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (173), _NORM((_VAR##_5_21), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (174), _NORM((_VAR##_6_21), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (175), _NORM((_VAR##_7_21), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (176), _NORM((_VAR##_0_22), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (177), _NORM((_VAR##_1_22), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (178), _NORM((_VAR##_2_22), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (179), _NORM((_VAR##_3_22), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (180), _NORM((_VAR##_4_22), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (181), _NORM((_VAR##_5_22), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (182), _NORM((_VAR##_6_22), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (183), _NORM((_VAR##_7_22), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (184), _NORM((_VAR##_0_23), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (185), _NORM((_VAR##_1_23), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (186), _NORM((_VAR##_2_23), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (187), _NORM((_VAR##_3_23), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (188), _NORM((_VAR##_4_23), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (189), _NORM((_VAR##_5_23), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (190), _NORM((_VAR##_6_23), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (191), _NORM((_VAR##_7_23), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (192), _NORM((_VAR##_0_24), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (193), _NORM((_VAR##_1_24), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (194), _NORM((_VAR##_2_24), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (195), _NORM((_VAR##_3_24), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (196), _NORM((_VAR##_4_24), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (197), _NORM((_VAR##_5_24), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (198), _NORM((_VAR##_6_24), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (199), _NORM((_VAR##_7_24), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (200), _NORM((_VAR##_0_25), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (201), _NORM((_VAR##_1_25), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (202), _NORM((_VAR##_2_25), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (203), _NORM((_VAR##_3_25), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (204), _NORM((_VAR##_4_25), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (205), _NORM((_VAR##_5_25), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (206), _NORM((_VAR##_6_25), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (207), _NORM((_VAR##_7_25), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (208), _NORM((_VAR##_0_26), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (209), _NORM((_VAR##_1_26), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (210), _NORM((_VAR##_2_26), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (211), _NORM((_VAR##_3_26), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (212), _NORM((_VAR##_4_26), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (213), _NORM((_VAR##_5_26), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (214), _NORM((_VAR##_6_26), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (215), _NORM((_VAR##_7_26), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (216), _NORM((_VAR##_0_27), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (217), _NORM((_VAR##_1_27), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (218), _NORM((_VAR##_2_27), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (219), _NORM((_VAR##_3_27), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (220), _NORM((_VAR##_4_27), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (221), _NORM((_VAR##_5_27), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (222), _NORM((_VAR##_6_27), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (223), _NORM((_VAR##_7_27), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (224), _NORM((_VAR##_0_28), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (225), _NORM((_VAR##_1_28), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (226), _NORM((_VAR##_2_28), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (227), _NORM((_VAR##_3_28), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (228), _NORM((_VAR##_4_28), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (229), _NORM((_VAR##_5_28), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (230), _NORM((_VAR##_6_28), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (231), _NORM((_VAR##_7_28), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (232), _NORM((_VAR##_0_29), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (233), _NORM((_VAR##_1_29), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (234), _NORM((_VAR##_2_29), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (235), _NORM((_VAR##_3_29), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (236), _NORM((_VAR##_4_29), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (237), _NORM((_VAR##_5_29), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (238), _NORM((_VAR##_6_29), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (239), _NORM((_VAR##_7_29), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (240), _NORM((_VAR##_0_30), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (241), _NORM((_VAR##_1_30), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (242), _NORM((_VAR##_2_30), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (243), _NORM((_VAR##_3_30), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (244), _NORM((_VAR##_4_30), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (245), _NORM((_VAR##_5_30), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (246), _NORM((_VAR##_6_30), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (247), _NORM((_VAR##_7_30), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (248), _NORM((_VAR##_0_31), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (249), _NORM((_VAR##_1_31), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (250), _NORM((_VAR##_2_31), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (251), _NORM((_VAR##_3_31), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (252), _NORM((_VAR##_4_31), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (253), _NORM((_VAR##_5_31), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (254), _NORM((_VAR##_6_31), ##__VA_ARGS__)); \
  _STORE((_ARRAY) + (_STEP) * (255), _NORM((_VAR##_7_31), ##__VA_ARGS__));

#define MATRIX_VAR_STORE_16X1(_STEP, _VAR, _ARRAY, _STORE, _NORM, ...)
/* MATRIX_VAR_STORE_16X1 (name and parameters are on the preceding line): expands MATRIX_VAR_STORE_8X1 (defined elsewhere; presumably covers _VAR##_0_0 .. _VAR##_7_0 at offsets 0..7), then writes _VAR##_8_0 .. _VAR##_15_0, each as _STORE((_ARRAY) + (_STEP) * k, _NORM(elem, ...)) with k = 8..15. Across the MATRIX_VAR_STORE_16XC macros, _VAR##_<r>_<c> is written at offset k = r + 16 * c. The expansion is bare semicolon-terminated statements with no do/while wrapper. */ \ MATRIX_VAR_STORE_8X1(_STEP, _VAR, _ARRAY, _STORE, _NORM, ##__VA_ARGS__) \ _STORE((_ARRAY) + (_STEP) * (8), _NORM((_VAR##_8_0), ##__VA_ARGS__)); \ _STORE((_ARRAY) + (_STEP) * (9), _NORM((_VAR##_9_0), ##__VA_ARGS__)); \ _STORE((_ARRAY) + (_STEP) * (10), _NORM((_VAR##_10_0), ##__VA_ARGS__)); \ _STORE((_ARRAY) + (_STEP) * (11), _NORM((_VAR##_11_0), ##__VA_ARGS__)); \ _STORE((_ARRAY) + (_STEP) * (12), _NORM((_VAR##_12_0), ##__VA_ARGS__)); \ _STORE((_ARRAY) + (_STEP) * (13), _NORM((_VAR##_13_0), ##__VA_ARGS__)); \ _STORE((_ARRAY) + (_STEP) * (14), _NORM((_VAR##_14_0), ##__VA_ARGS__)); \ _STORE((_ARRAY) + (_STEP) * (15), _NORM((_VAR##_15_0), ##__VA_ARGS__)); #define MATRIX_VAR_STORE_16X2(_STEP, _VAR, _ARRAY, _STORE, _NORM, ...) /* MATRIX_VAR_STORE_16X1, then _VAR##_<r>_1 (r = 0..15) at offsets 16..31. */ \ MATRIX_VAR_STORE_16X1(_STEP, _VAR, _ARRAY, _STORE, _NORM, ##__VA_ARGS__) \ _STORE((_ARRAY) + (_STEP) * (16), _NORM((_VAR##_0_1), ##__VA_ARGS__)); \ _STORE((_ARRAY) + (_STEP) * (17), _NORM((_VAR##_1_1), ##__VA_ARGS__)); \ _STORE((_ARRAY) + (_STEP) * (18), _NORM((_VAR##_2_1), ##__VA_ARGS__)); \ _STORE((_ARRAY) + (_STEP) * (19), _NORM((_VAR##_3_1), ##__VA_ARGS__)); \ _STORE((_ARRAY) + (_STEP) * (20), _NORM((_VAR##_4_1), ##__VA_ARGS__)); \ _STORE((_ARRAY) + (_STEP) * (21), _NORM((_VAR##_5_1), ##__VA_ARGS__)); \ _STORE((_ARRAY) + (_STEP) * (22), _NORM((_VAR##_6_1), ##__VA_ARGS__)); \ _STORE((_ARRAY) + (_STEP) * (23), _NORM((_VAR##_7_1), ##__VA_ARGS__)); \ _STORE((_ARRAY) + (_STEP) * (24), _NORM((_VAR##_8_1), ##__VA_ARGS__)); \ _STORE((_ARRAY) + (_STEP) * (25), _NORM((_VAR##_9_1), ##__VA_ARGS__)); \ _STORE((_ARRAY) + (_STEP) * (26), _NORM((_VAR##_10_1), ##__VA_ARGS__)); \ _STORE((_ARRAY) + (_STEP) * (27), _NORM((_VAR##_11_1), ##__VA_ARGS__)); \ _STORE((_ARRAY) + (_STEP) * (28), _NORM((_VAR##_12_1), ##__VA_ARGS__)); \ _STORE((_ARRAY) + (_STEP) * (29), _NORM((_VAR##_13_1), ##__VA_ARGS__)); \ _STORE((_ARRAY) + (_STEP) * (30), _NORM((_VAR##_14_1), ##__VA_ARGS__)); \ _STORE((_ARRAY) + (_STEP) * (31), _NORM((_VAR##_15_1), ##__VA_ARGS__)); #define
MATRIX_VAR_STORE_16X4(_STEP, _VAR, _ARRAY, _STORE, _NORM, ...) /* MATRIX_VAR_STORE_16X2, then _VAR##_<r>_2 at offsets 32..47 and _VAR##_<r>_3 at offsets 48..63 (r = 0..15). */ \ MATRIX_VAR_STORE_16X2(_STEP, _VAR, _ARRAY, _STORE, _NORM, ##__VA_ARGS__) \ _STORE((_ARRAY) + (_STEP) * (32), _NORM((_VAR##_0_2), ##__VA_ARGS__)); \ _STORE((_ARRAY) + (_STEP) * (33), _NORM((_VAR##_1_2), ##__VA_ARGS__)); \ _STORE((_ARRAY) + (_STEP) * (34), _NORM((_VAR##_2_2), ##__VA_ARGS__)); \ _STORE((_ARRAY) + (_STEP) * (35), _NORM((_VAR##_3_2), ##__VA_ARGS__)); \ _STORE((_ARRAY) + (_STEP) * (36), _NORM((_VAR##_4_2), ##__VA_ARGS__)); \ _STORE((_ARRAY) + (_STEP) * (37), _NORM((_VAR##_5_2), ##__VA_ARGS__)); \ _STORE((_ARRAY) + (_STEP) * (38), _NORM((_VAR##_6_2), ##__VA_ARGS__)); \ _STORE((_ARRAY) + (_STEP) * (39), _NORM((_VAR##_7_2), ##__VA_ARGS__)); \ _STORE((_ARRAY) + (_STEP) * (40), _NORM((_VAR##_8_2), ##__VA_ARGS__)); \ _STORE((_ARRAY) + (_STEP) * (41), _NORM((_VAR##_9_2), ##__VA_ARGS__)); \ _STORE((_ARRAY) + (_STEP) * (42), _NORM((_VAR##_10_2), ##__VA_ARGS__)); \ _STORE((_ARRAY) + (_STEP) * (43), _NORM((_VAR##_11_2), ##__VA_ARGS__)); \ _STORE((_ARRAY) + (_STEP) * (44), _NORM((_VAR##_12_2), ##__VA_ARGS__)); \ _STORE((_ARRAY) + (_STEP) * (45), _NORM((_VAR##_13_2), ##__VA_ARGS__)); \ _STORE((_ARRAY) + (_STEP) * (46), _NORM((_VAR##_14_2), ##__VA_ARGS__)); \ _STORE((_ARRAY) + (_STEP) * (47), _NORM((_VAR##_15_2), ##__VA_ARGS__)); \ _STORE((_ARRAY) + (_STEP) * (48), _NORM((_VAR##_0_3), ##__VA_ARGS__)); \ _STORE((_ARRAY) + (_STEP) * (49), _NORM((_VAR##_1_3), ##__VA_ARGS__)); \ _STORE((_ARRAY) + (_STEP) * (50), _NORM((_VAR##_2_3), ##__VA_ARGS__)); \ _STORE((_ARRAY) + (_STEP) * (51), _NORM((_VAR##_3_3), ##__VA_ARGS__)); \ _STORE((_ARRAY) + (_STEP) * (52), _NORM((_VAR##_4_3), ##__VA_ARGS__)); \ _STORE((_ARRAY) + (_STEP) * (53), _NORM((_VAR##_5_3), ##__VA_ARGS__)); \ _STORE((_ARRAY) + (_STEP) * (54), _NORM((_VAR##_6_3), ##__VA_ARGS__)); \ _STORE((_ARRAY) + (_STEP) * (55), _NORM((_VAR##_7_3), ##__VA_ARGS__)); \ _STORE((_ARRAY) + (_STEP) * (56), _NORM((_VAR##_8_3), ##__VA_ARGS__)); \ _STORE((_ARRAY) + (_STEP) *
(57), _NORM((_VAR##_9_3), ##__VA_ARGS__)); \ _STORE((_ARRAY) + (_STEP) * (58), _NORM((_VAR##_10_3), ##__VA_ARGS__)); \ _STORE((_ARRAY) + (_STEP) * (59), _NORM((_VAR##_11_3), ##__VA_ARGS__)); \ _STORE((_ARRAY) + (_STEP) * (60), _NORM((_VAR##_12_3), ##__VA_ARGS__)); \ _STORE((_ARRAY) + (_STEP) * (61), _NORM((_VAR##_13_3), ##__VA_ARGS__)); \ _STORE((_ARRAY) + (_STEP) * (62), _NORM((_VAR##_14_3), ##__VA_ARGS__)); \ _STORE((_ARRAY) + (_STEP) * (63), _NORM((_VAR##_15_3), ##__VA_ARGS__)); #define MATRIX_VAR_PERMUTE_1X1(_VAR, _PERMUTE, ...) /* In-place update (_VAR##_0_0) = _PERMUTE((_VAR##_0_0), ...). Each MATRIX_VAR_PERMUTE_RxC macro in this group applies the same in-place update to every _VAR##_<r>_<c> with r < R and c < C; the larger ones expand a smaller macro and append the new elements. The expansion is bare semicolon-terminated statements (no do/while wrapper), so brace it when it is used as an if/else body. */ \ (_VAR##_0_0) = _PERMUTE((_VAR##_0_0), ##__VA_ARGS__); #define MATRIX_VAR_PERMUTE_1X2(_VAR, _PERMUTE, ...) /* MATRIX_VAR_PERMUTE_1X1 plus _VAR##_0_1. */ \ MATRIX_VAR_PERMUTE_1X1(_VAR, _PERMUTE, ##__VA_ARGS__) \ (_VAR##_0_1) = _PERMUTE((_VAR##_0_1), ##__VA_ARGS__); #define MATRIX_VAR_PERMUTE_1X4(_VAR, _PERMUTE, ...) /* MATRIX_VAR_PERMUTE_1X2 plus _VAR##_0_2 and _VAR##_0_3. */ \ MATRIX_VAR_PERMUTE_1X2(_VAR, _PERMUTE, ##__VA_ARGS__) \ (_VAR##_0_2) = _PERMUTE((_VAR##_0_2), ##__VA_ARGS__); \ (_VAR##_0_3) = _PERMUTE((_VAR##_0_3), ##__VA_ARGS__); #define MATRIX_VAR_PERMUTE_1X8(_VAR, _PERMUTE, ...) /* MATRIX_VAR_PERMUTE_1X4 plus _VAR##_0_4 .. _VAR##_0_7. */ \ MATRIX_VAR_PERMUTE_1X4(_VAR, _PERMUTE, ##__VA_ARGS__) \ (_VAR##_0_4) = _PERMUTE((_VAR##_0_4), ##__VA_ARGS__); \ (_VAR##_0_5) = _PERMUTE((_VAR##_0_5), ##__VA_ARGS__); \ (_VAR##_0_6) = _PERMUTE((_VAR##_0_6), ##__VA_ARGS__); \ (_VAR##_0_7) = _PERMUTE((_VAR##_0_7), ##__VA_ARGS__); #define MATRIX_VAR_PERMUTE_1X16(_VAR, _PERMUTE, ...) /* MATRIX_VAR_PERMUTE_1X8 plus _VAR##_0_8 .. _VAR##_0_15. */ \ MATRIX_VAR_PERMUTE_1X8(_VAR, _PERMUTE, ##__VA_ARGS__) \ (_VAR##_0_8) = _PERMUTE((_VAR##_0_8), ##__VA_ARGS__); \ (_VAR##_0_9) = _PERMUTE((_VAR##_0_9), ##__VA_ARGS__); \ (_VAR##_0_10) = _PERMUTE((_VAR##_0_10), ##__VA_ARGS__); \ (_VAR##_0_11) = _PERMUTE((_VAR##_0_11), ##__VA_ARGS__); \ (_VAR##_0_12) = _PERMUTE((_VAR##_0_12), ##__VA_ARGS__); \ (_VAR##_0_13) = _PERMUTE((_VAR##_0_13), ##__VA_ARGS__); \ (_VAR##_0_14) = _PERMUTE((_VAR##_0_14), ##__VA_ARGS__); \ (_VAR##_0_15) = _PERMUTE((_VAR##_0_15), ##__VA_ARGS__); #define MATRIX_VAR_PERMUTE_2X1(_VAR, _PERMUTE, ...) /* MATRIX_VAR_PERMUTE_1X1 plus _VAR##_1_0. */
\ MATRIX_VAR_PERMUTE_1X1(_VAR, _PERMUTE, ##__VA_ARGS__) \ (_VAR##_1_0) = _PERMUTE((_VAR##_1_0), ##__VA_ARGS__); #define MATRIX_VAR_PERMUTE_2X2(_VAR, _PERMUTE, ...) /* MATRIX_VAR_PERMUTE_2X1 plus _VAR##_<r>_1 for r = 0..1. */ \ MATRIX_VAR_PERMUTE_2X1(_VAR, _PERMUTE, ##__VA_ARGS__) \ (_VAR##_0_1) = _PERMUTE((_VAR##_0_1), ##__VA_ARGS__); \ (_VAR##_1_1) = _PERMUTE((_VAR##_1_1), ##__VA_ARGS__); #define MATRIX_VAR_PERMUTE_2X4(_VAR, _PERMUTE, ...) /* MATRIX_VAR_PERMUTE_2X2 plus _VAR##_<r>_<c> for r = 0..1, c = 2..3. */ \ MATRIX_VAR_PERMUTE_2X2(_VAR, _PERMUTE, ##__VA_ARGS__) \ (_VAR##_0_2) = _PERMUTE((_VAR##_0_2), ##__VA_ARGS__); \ (_VAR##_1_2) = _PERMUTE((_VAR##_1_2), ##__VA_ARGS__); \ (_VAR##_0_3) = _PERMUTE((_VAR##_0_3), ##__VA_ARGS__); \ (_VAR##_1_3) = _PERMUTE((_VAR##_1_3), ##__VA_ARGS__); #define MATRIX_VAR_PERMUTE_2X8(_VAR, _PERMUTE, ...) /* MATRIX_VAR_PERMUTE_2X4 plus _VAR##_<r>_<c> for r = 0..1, c = 4..7. */ \ MATRIX_VAR_PERMUTE_2X4(_VAR, _PERMUTE, ##__VA_ARGS__) \ (_VAR##_0_4) = _PERMUTE((_VAR##_0_4), ##__VA_ARGS__); \ (_VAR##_1_4) = _PERMUTE((_VAR##_1_4), ##__VA_ARGS__); \ (_VAR##_0_5) = _PERMUTE((_VAR##_0_5), ##__VA_ARGS__); \ (_VAR##_1_5) = _PERMUTE((_VAR##_1_5), ##__VA_ARGS__); \ (_VAR##_0_6) = _PERMUTE((_VAR##_0_6), ##__VA_ARGS__); \ (_VAR##_1_6) = _PERMUTE((_VAR##_1_6), ##__VA_ARGS__); \ (_VAR##_0_7) = _PERMUTE((_VAR##_0_7), ##__VA_ARGS__); \ (_VAR##_1_7) = _PERMUTE((_VAR##_1_7), ##__VA_ARGS__); #define MATRIX_VAR_PERMUTE_2X16(_VAR, _PERMUTE, ...) /* MATRIX_VAR_PERMUTE_2X8 plus _VAR##_<r>_<c> for r = 0..1, c = 8..15. */
\ MATRIX_VAR_PERMUTE_2X8(_VAR, _PERMUTE, ##__VA_ARGS__) \ (_VAR##_0_8) = _PERMUTE((_VAR##_0_8), ##__VA_ARGS__); \ (_VAR##_1_8) = _PERMUTE((_VAR##_1_8), ##__VA_ARGS__); \ (_VAR##_0_9) = _PERMUTE((_VAR##_0_9), ##__VA_ARGS__); \ (_VAR##_1_9) = _PERMUTE((_VAR##_1_9), ##__VA_ARGS__); \ (_VAR##_0_10) = _PERMUTE((_VAR##_0_10), ##__VA_ARGS__); \ (_VAR##_1_10) = _PERMUTE((_VAR##_1_10), ##__VA_ARGS__); \ (_VAR##_0_11) = _PERMUTE((_VAR##_0_11), ##__VA_ARGS__); \ (_VAR##_1_11) = _PERMUTE((_VAR##_1_11), ##__VA_ARGS__); \ (_VAR##_0_12) = _PERMUTE((_VAR##_0_12), ##__VA_ARGS__); \ (_VAR##_1_12) = _PERMUTE((_VAR##_1_12), ##__VA_ARGS__); \ (_VAR##_0_13) = _PERMUTE((_VAR##_0_13), ##__VA_ARGS__); \ (_VAR##_1_13) = _PERMUTE((_VAR##_1_13), ##__VA_ARGS__); \ (_VAR##_0_14) = _PERMUTE((_VAR##_0_14), ##__VA_ARGS__); \ (_VAR##_1_14) = _PERMUTE((_VAR##_1_14), ##__VA_ARGS__); \ (_VAR##_0_15) = _PERMUTE((_VAR##_0_15), ##__VA_ARGS__); \ (_VAR##_1_15) = _PERMUTE((_VAR##_1_15), ##__VA_ARGS__); #define MATRIX_VAR_PERMUTE_2X32(_VAR, _PERMUTE, ...) /* MATRIX_VAR_PERMUTE_2X16 plus _VAR##_<r>_<c> for r = 0..1, c = 16..31. */
\ MATRIX_VAR_PERMUTE_2X16(_VAR, _PERMUTE, ##__VA_ARGS__) \ (_VAR##_0_16) = _PERMUTE((_VAR##_0_16), ##__VA_ARGS__); \ (_VAR##_1_16) = _PERMUTE((_VAR##_1_16), ##__VA_ARGS__); \ (_VAR##_0_17) = _PERMUTE((_VAR##_0_17), ##__VA_ARGS__); \ (_VAR##_1_17) = _PERMUTE((_VAR##_1_17), ##__VA_ARGS__); \ (_VAR##_0_18) = _PERMUTE((_VAR##_0_18), ##__VA_ARGS__); \ (_VAR##_1_18) = _PERMUTE((_VAR##_1_18), ##__VA_ARGS__); \ (_VAR##_0_19) = _PERMUTE((_VAR##_0_19), ##__VA_ARGS__); \ (_VAR##_1_19) = _PERMUTE((_VAR##_1_19), ##__VA_ARGS__); \ (_VAR##_0_20) = _PERMUTE((_VAR##_0_20), ##__VA_ARGS__); \ (_VAR##_1_20) = _PERMUTE((_VAR##_1_20), ##__VA_ARGS__); \ (_VAR##_0_21) = _PERMUTE((_VAR##_0_21), ##__VA_ARGS__); \ (_VAR##_1_21) = _PERMUTE((_VAR##_1_21), ##__VA_ARGS__); \ (_VAR##_0_22) = _PERMUTE((_VAR##_0_22), ##__VA_ARGS__); \ (_VAR##_1_22) = _PERMUTE((_VAR##_1_22), ##__VA_ARGS__); \ (_VAR##_0_23) = _PERMUTE((_VAR##_0_23), ##__VA_ARGS__); \ (_VAR##_1_23) = _PERMUTE((_VAR##_1_23), ##__VA_ARGS__); \ (_VAR##_0_24) = _PERMUTE((_VAR##_0_24), ##__VA_ARGS__); \ (_VAR##_1_24) = _PERMUTE((_VAR##_1_24), ##__VA_ARGS__); \ (_VAR##_0_25) = _PERMUTE((_VAR##_0_25), ##__VA_ARGS__); \ (_VAR##_1_25) = _PERMUTE((_VAR##_1_25), ##__VA_ARGS__); \ (_VAR##_0_26) = _PERMUTE((_VAR##_0_26), ##__VA_ARGS__); \ (_VAR##_1_26) = _PERMUTE((_VAR##_1_26), ##__VA_ARGS__); \ (_VAR##_0_27) = _PERMUTE((_VAR##_0_27), ##__VA_ARGS__); \ (_VAR##_1_27) = _PERMUTE((_VAR##_1_27), ##__VA_ARGS__); \ (_VAR##_0_28) = _PERMUTE((_VAR##_0_28), ##__VA_ARGS__); \ (_VAR##_1_28) = _PERMUTE((_VAR##_1_28), ##__VA_ARGS__); \ (_VAR##_0_29) = _PERMUTE((_VAR##_0_29), ##__VA_ARGS__); \ (_VAR##_1_29) = _PERMUTE((_VAR##_1_29), ##__VA_ARGS__); \ (_VAR##_0_30) = _PERMUTE((_VAR##_0_30), ##__VA_ARGS__); \ (_VAR##_1_30) = _PERMUTE((_VAR##_1_30), ##__VA_ARGS__); \ (_VAR##_0_31) = _PERMUTE((_VAR##_0_31), ##__VA_ARGS__); \ (_VAR##_1_31) = _PERMUTE((_VAR##_1_31), ##__VA_ARGS__); #define MATRIX_VAR_PERMUTE_4X1(_VAR, _PERMUTE, ...) /* MATRIX_VAR_PERMUTE_2X1 plus _VAR##_2_0 and _VAR##_3_0. */
\ MATRIX_VAR_PERMUTE_2X1(_VAR, _PERMUTE, ##__VA_ARGS__) \ (_VAR##_2_0) = _PERMUTE((_VAR##_2_0), ##__VA_ARGS__); \ (_VAR##_3_0) = _PERMUTE((_VAR##_3_0), ##__VA_ARGS__); #define MATRIX_VAR_PERMUTE_4X2(_VAR, _PERMUTE, ...) /* MATRIX_VAR_PERMUTE_4X1 plus _VAR##_<r>_1 for r = 0..3. */ \ MATRIX_VAR_PERMUTE_4X1(_VAR, _PERMUTE, ##__VA_ARGS__) \ (_VAR##_0_1) = _PERMUTE((_VAR##_0_1), ##__VA_ARGS__); \ (_VAR##_1_1) = _PERMUTE((_VAR##_1_1), ##__VA_ARGS__); \ (_VAR##_2_1) = _PERMUTE((_VAR##_2_1), ##__VA_ARGS__); \ (_VAR##_3_1) = _PERMUTE((_VAR##_3_1), ##__VA_ARGS__); #define MATRIX_VAR_PERMUTE_4X4(_VAR, _PERMUTE, ...) /* MATRIX_VAR_PERMUTE_4X2 plus _VAR##_<r>_<c> for r = 0..3, c = 2..3. */ \ MATRIX_VAR_PERMUTE_4X2(_VAR, _PERMUTE, ##__VA_ARGS__) \ (_VAR##_0_2) = _PERMUTE((_VAR##_0_2), ##__VA_ARGS__); \ (_VAR##_1_2) = _PERMUTE((_VAR##_1_2), ##__VA_ARGS__); \ (_VAR##_2_2) = _PERMUTE((_VAR##_2_2), ##__VA_ARGS__); \ (_VAR##_3_2) = _PERMUTE((_VAR##_3_2), ##__VA_ARGS__); \ (_VAR##_0_3) = _PERMUTE((_VAR##_0_3), ##__VA_ARGS__); \ (_VAR##_1_3) = _PERMUTE((_VAR##_1_3), ##__VA_ARGS__); \ (_VAR##_2_3) = _PERMUTE((_VAR##_2_3), ##__VA_ARGS__); \ (_VAR##_3_3) = _PERMUTE((_VAR##_3_3), ##__VA_ARGS__); #define MATRIX_VAR_PERMUTE_4X8(_VAR, _PERMUTE, ...) /* MATRIX_VAR_PERMUTE_4X4 plus _VAR##_<r>_<c> for r = 0..3, c = 4..7. */
\ MATRIX_VAR_PERMUTE_4X4(_VAR, _PERMUTE, ##__VA_ARGS__) \ (_VAR##_0_4) = _PERMUTE((_VAR##_0_4), ##__VA_ARGS__); \ (_VAR##_1_4) = _PERMUTE((_VAR##_1_4), ##__VA_ARGS__); \ (_VAR##_2_4) = _PERMUTE((_VAR##_2_4), ##__VA_ARGS__); \ (_VAR##_3_4) = _PERMUTE((_VAR##_3_4), ##__VA_ARGS__); \ (_VAR##_0_5) = _PERMUTE((_VAR##_0_5), ##__VA_ARGS__); \ (_VAR##_1_5) = _PERMUTE((_VAR##_1_5), ##__VA_ARGS__); \ (_VAR##_2_5) = _PERMUTE((_VAR##_2_5), ##__VA_ARGS__); \ (_VAR##_3_5) = _PERMUTE((_VAR##_3_5), ##__VA_ARGS__); \ (_VAR##_0_6) = _PERMUTE((_VAR##_0_6), ##__VA_ARGS__); \ (_VAR##_1_6) = _PERMUTE((_VAR##_1_6), ##__VA_ARGS__); \ (_VAR##_2_6) = _PERMUTE((_VAR##_2_6), ##__VA_ARGS__); \ (_VAR##_3_6) = _PERMUTE((_VAR##_3_6), ##__VA_ARGS__); \ (_VAR##_0_7) = _PERMUTE((_VAR##_0_7), ##__VA_ARGS__); \ (_VAR##_1_7) = _PERMUTE((_VAR##_1_7), ##__VA_ARGS__); \ (_VAR##_2_7) = _PERMUTE((_VAR##_2_7), ##__VA_ARGS__); \ (_VAR##_3_7) = _PERMUTE((_VAR##_3_7), ##__VA_ARGS__); #define MATRIX_VAR_PERMUTE_4X16(_VAR, _PERMUTE, ...) /* MATRIX_VAR_PERMUTE_4X8 plus _VAR##_<r>_<c> for r = 0..3, c = 8..15. */
\ MATRIX_VAR_PERMUTE_4X8(_VAR, _PERMUTE, ##__VA_ARGS__) \ (_VAR##_0_8) = _PERMUTE((_VAR##_0_8), ##__VA_ARGS__); \ (_VAR##_1_8) = _PERMUTE((_VAR##_1_8), ##__VA_ARGS__); \ (_VAR##_2_8) = _PERMUTE((_VAR##_2_8), ##__VA_ARGS__); \ (_VAR##_3_8) = _PERMUTE((_VAR##_3_8), ##__VA_ARGS__); \ (_VAR##_0_9) = _PERMUTE((_VAR##_0_9), ##__VA_ARGS__); \ (_VAR##_1_9) = _PERMUTE((_VAR##_1_9), ##__VA_ARGS__); \ (_VAR##_2_9) = _PERMUTE((_VAR##_2_9), ##__VA_ARGS__); \ (_VAR##_3_9) = _PERMUTE((_VAR##_3_9), ##__VA_ARGS__); \ (_VAR##_0_10) = _PERMUTE((_VAR##_0_10), ##__VA_ARGS__); \ (_VAR##_1_10) = _PERMUTE((_VAR##_1_10), ##__VA_ARGS__); \ (_VAR##_2_10) = _PERMUTE((_VAR##_2_10), ##__VA_ARGS__); \ (_VAR##_3_10) = _PERMUTE((_VAR##_3_10), ##__VA_ARGS__); \ (_VAR##_0_11) = _PERMUTE((_VAR##_0_11), ##__VA_ARGS__); \ (_VAR##_1_11) = _PERMUTE((_VAR##_1_11), ##__VA_ARGS__); \ (_VAR##_2_11) = _PERMUTE((_VAR##_2_11), ##__VA_ARGS__); \ (_VAR##_3_11) = _PERMUTE((_VAR##_3_11), ##__VA_ARGS__); \ (_VAR##_0_12) = _PERMUTE((_VAR##_0_12), ##__VA_ARGS__); \ (_VAR##_1_12) = _PERMUTE((_VAR##_1_12), ##__VA_ARGS__); \ (_VAR##_2_12) = _PERMUTE((_VAR##_2_12), ##__VA_ARGS__); \ (_VAR##_3_12) = _PERMUTE((_VAR##_3_12), ##__VA_ARGS__); \ (_VAR##_0_13) = _PERMUTE((_VAR##_0_13), ##__VA_ARGS__); \ (_VAR##_1_13) = _PERMUTE((_VAR##_1_13), ##__VA_ARGS__); \ (_VAR##_2_13) = _PERMUTE((_VAR##_2_13), ##__VA_ARGS__); \ (_VAR##_3_13) = _PERMUTE((_VAR##_3_13), ##__VA_ARGS__); \ (_VAR##_0_14) = _PERMUTE((_VAR##_0_14), ##__VA_ARGS__); \ (_VAR##_1_14) = _PERMUTE((_VAR##_1_14), ##__VA_ARGS__); \ (_VAR##_2_14) = _PERMUTE((_VAR##_2_14), ##__VA_ARGS__); \ (_VAR##_3_14) = _PERMUTE((_VAR##_3_14), ##__VA_ARGS__); \ (_VAR##_0_15) = _PERMUTE((_VAR##_0_15), ##__VA_ARGS__); \ (_VAR##_1_15) = _PERMUTE((_VAR##_1_15), ##__VA_ARGS__); \ (_VAR##_2_15) = _PERMUTE((_VAR##_2_15), ##__VA_ARGS__); \ (_VAR##_3_15) = _PERMUTE((_VAR##_3_15), ##__VA_ARGS__); #define MATRIX_VAR_PERMUTE_4X32(_VAR, _PERMUTE, ...) /* MATRIX_VAR_PERMUTE_4X16 plus _VAR##_<r>_<c> for r = 0..3, c = 16..31. */
\ MATRIX_VAR_PERMUTE_4X16(_VAR, _PERMUTE, ##__VA_ARGS__) \ (_VAR##_0_16) = _PERMUTE((_VAR##_0_16), ##__VA_ARGS__); \ (_VAR##_1_16) = _PERMUTE((_VAR##_1_16), ##__VA_ARGS__); \ (_VAR##_2_16) = _PERMUTE((_VAR##_2_16), ##__VA_ARGS__); \ (_VAR##_3_16) = _PERMUTE((_VAR##_3_16), ##__VA_ARGS__); \ (_VAR##_0_17) = _PERMUTE((_VAR##_0_17), ##__VA_ARGS__); \ (_VAR##_1_17) = _PERMUTE((_VAR##_1_17), ##__VA_ARGS__); \ (_VAR##_2_17) = _PERMUTE((_VAR##_2_17), ##__VA_ARGS__); \ (_VAR##_3_17) = _PERMUTE((_VAR##_3_17), ##__VA_ARGS__); \ (_VAR##_0_18) = _PERMUTE((_VAR##_0_18), ##__VA_ARGS__); \ (_VAR##_1_18) = _PERMUTE((_VAR##_1_18), ##__VA_ARGS__); \ (_VAR##_2_18) = _PERMUTE((_VAR##_2_18), ##__VA_ARGS__); \ (_VAR##_3_18) = _PERMUTE((_VAR##_3_18), ##__VA_ARGS__); \ (_VAR##_0_19) = _PERMUTE((_VAR##_0_19), ##__VA_ARGS__); \ (_VAR##_1_19) = _PERMUTE((_VAR##_1_19), ##__VA_ARGS__); \ (_VAR##_2_19) = _PERMUTE((_VAR##_2_19), ##__VA_ARGS__); \ (_VAR##_3_19) = _PERMUTE((_VAR##_3_19), ##__VA_ARGS__); \ (_VAR##_0_20) = _PERMUTE((_VAR##_0_20), ##__VA_ARGS__); \ (_VAR##_1_20) = _PERMUTE((_VAR##_1_20), ##__VA_ARGS__); \ (_VAR##_2_20) = _PERMUTE((_VAR##_2_20), ##__VA_ARGS__); \ (_VAR##_3_20) = _PERMUTE((_VAR##_3_20), ##__VA_ARGS__); \ (_VAR##_0_21) = _PERMUTE((_VAR##_0_21), ##__VA_ARGS__); \ (_VAR##_1_21) = _PERMUTE((_VAR##_1_21), ##__VA_ARGS__); \ (_VAR##_2_21) = _PERMUTE((_VAR##_2_21), ##__VA_ARGS__); \ (_VAR##_3_21) = _PERMUTE((_VAR##_3_21), ##__VA_ARGS__); \ (_VAR##_0_22) = _PERMUTE((_VAR##_0_22), ##__VA_ARGS__); \ (_VAR##_1_22) = _PERMUTE((_VAR##_1_22), ##__VA_ARGS__); \ (_VAR##_2_22) = _PERMUTE((_VAR##_2_22), ##__VA_ARGS__); \ (_VAR##_3_22) = _PERMUTE((_VAR##_3_22), ##__VA_ARGS__); \ (_VAR##_0_23) = _PERMUTE((_VAR##_0_23), ##__VA_ARGS__); \ (_VAR##_1_23) = _PERMUTE((_VAR##_1_23), ##__VA_ARGS__); \ (_VAR##_2_23) = _PERMUTE((_VAR##_2_23), ##__VA_ARGS__); \ (_VAR##_3_23) = _PERMUTE((_VAR##_3_23), ##__VA_ARGS__); \ (_VAR##_0_24) = _PERMUTE((_VAR##_0_24), ##__VA_ARGS__); \ (_VAR##_1_24) =
_PERMUTE((_VAR##_1_24), ##__VA_ARGS__); \ (_VAR##_2_24) = _PERMUTE((_VAR##_2_24), ##__VA_ARGS__); \ (_VAR##_3_24) = _PERMUTE((_VAR##_3_24), ##__VA_ARGS__); \ (_VAR##_0_25) = _PERMUTE((_VAR##_0_25), ##__VA_ARGS__); \ (_VAR##_1_25) = _PERMUTE((_VAR##_1_25), ##__VA_ARGS__); \ (_VAR##_2_25) = _PERMUTE((_VAR##_2_25), ##__VA_ARGS__); \ (_VAR##_3_25) = _PERMUTE((_VAR##_3_25), ##__VA_ARGS__); \ (_VAR##_0_26) = _PERMUTE((_VAR##_0_26), ##__VA_ARGS__); \ (_VAR##_1_26) = _PERMUTE((_VAR##_1_26), ##__VA_ARGS__); \ (_VAR##_2_26) = _PERMUTE((_VAR##_2_26), ##__VA_ARGS__); \ (_VAR##_3_26) = _PERMUTE((_VAR##_3_26), ##__VA_ARGS__); \ (_VAR##_0_27) = _PERMUTE((_VAR##_0_27), ##__VA_ARGS__); \ (_VAR##_1_27) = _PERMUTE((_VAR##_1_27), ##__VA_ARGS__); \ (_VAR##_2_27) = _PERMUTE((_VAR##_2_27), ##__VA_ARGS__); \ (_VAR##_3_27) = _PERMUTE((_VAR##_3_27), ##__VA_ARGS__); \ (_VAR##_0_28) = _PERMUTE((_VAR##_0_28), ##__VA_ARGS__); \ (_VAR##_1_28) = _PERMUTE((_VAR##_1_28), ##__VA_ARGS__); \ (_VAR##_2_28) = _PERMUTE((_VAR##_2_28), ##__VA_ARGS__); \ (_VAR##_3_28) = _PERMUTE((_VAR##_3_28), ##__VA_ARGS__); \ (_VAR##_0_29) = _PERMUTE((_VAR##_0_29), ##__VA_ARGS__); \ (_VAR##_1_29) = _PERMUTE((_VAR##_1_29), ##__VA_ARGS__); \ (_VAR##_2_29) = _PERMUTE((_VAR##_2_29), ##__VA_ARGS__); \ (_VAR##_3_29) = _PERMUTE((_VAR##_3_29), ##__VA_ARGS__); \ (_VAR##_0_30) = _PERMUTE((_VAR##_0_30), ##__VA_ARGS__); \ (_VAR##_1_30) = _PERMUTE((_VAR##_1_30), ##__VA_ARGS__); \ (_VAR##_2_30) = _PERMUTE((_VAR##_2_30), ##__VA_ARGS__); \ (_VAR##_3_30) = _PERMUTE((_VAR##_3_30), ##__VA_ARGS__); \ (_VAR##_0_31) = _PERMUTE((_VAR##_0_31), ##__VA_ARGS__); \ (_VAR##_1_31) = _PERMUTE((_VAR##_1_31), ##__VA_ARGS__); \ (_VAR##_2_31) = _PERMUTE((_VAR##_2_31), ##__VA_ARGS__); \ (_VAR##_3_31) = _PERMUTE((_VAR##_3_31), ##__VA_ARGS__); #define MATRIX_VAR_PERMUTE_8X1(_VAR, _PERMUTE, ...) /* MATRIX_VAR_PERMUTE_4X1 plus _VAR##_<r>_0 for r = 4..7. */
\ MATRIX_VAR_PERMUTE_4X1(_VAR, _PERMUTE, ##__VA_ARGS__) \ (_VAR##_4_0) = _PERMUTE((_VAR##_4_0), ##__VA_ARGS__); \ (_VAR##_5_0) = _PERMUTE((_VAR##_5_0), ##__VA_ARGS__); \ (_VAR##_6_0) = _PERMUTE((_VAR##_6_0), ##__VA_ARGS__); \ (_VAR##_7_0) = _PERMUTE((_VAR##_7_0), ##__VA_ARGS__); #define MATRIX_VAR_PERMUTE_8X2(_VAR, _PERMUTE, ...) /* MATRIX_VAR_PERMUTE_8X1 plus _VAR##_<r>_1 for r = 0..7. */ \ MATRIX_VAR_PERMUTE_8X1(_VAR, _PERMUTE, ##__VA_ARGS__) \ (_VAR##_0_1) = _PERMUTE((_VAR##_0_1), ##__VA_ARGS__); \ (_VAR##_1_1) = _PERMUTE((_VAR##_1_1), ##__VA_ARGS__); \ (_VAR##_2_1) = _PERMUTE((_VAR##_2_1), ##__VA_ARGS__); \ (_VAR##_3_1) = _PERMUTE((_VAR##_3_1), ##__VA_ARGS__); \ (_VAR##_4_1) = _PERMUTE((_VAR##_4_1), ##__VA_ARGS__); \ (_VAR##_5_1) = _PERMUTE((_VAR##_5_1), ##__VA_ARGS__); \ (_VAR##_6_1) = _PERMUTE((_VAR##_6_1), ##__VA_ARGS__); \ (_VAR##_7_1) = _PERMUTE((_VAR##_7_1), ##__VA_ARGS__); #define MATRIX_VAR_PERMUTE_8X4(_VAR, _PERMUTE, ...) /* MATRIX_VAR_PERMUTE_8X2 plus _VAR##_<r>_<c> for r = 0..7, c = 2..3. */ \ MATRIX_VAR_PERMUTE_8X2(_VAR, _PERMUTE, ##__VA_ARGS__) \ (_VAR##_0_2) = _PERMUTE((_VAR##_0_2), ##__VA_ARGS__); \ (_VAR##_1_2) = _PERMUTE((_VAR##_1_2), ##__VA_ARGS__); \ (_VAR##_2_2) = _PERMUTE((_VAR##_2_2), ##__VA_ARGS__); \ (_VAR##_3_2) = _PERMUTE((_VAR##_3_2), ##__VA_ARGS__); \ (_VAR##_4_2) = _PERMUTE((_VAR##_4_2), ##__VA_ARGS__); \ (_VAR##_5_2) = _PERMUTE((_VAR##_5_2), ##__VA_ARGS__); \ (_VAR##_6_2) = _PERMUTE((_VAR##_6_2), ##__VA_ARGS__); \ (_VAR##_7_2) = _PERMUTE((_VAR##_7_2), ##__VA_ARGS__); \ (_VAR##_0_3) = _PERMUTE((_VAR##_0_3), ##__VA_ARGS__); \ (_VAR##_1_3) = _PERMUTE((_VAR##_1_3), ##__VA_ARGS__); \ (_VAR##_2_3) = _PERMUTE((_VAR##_2_3), ##__VA_ARGS__); \ (_VAR##_3_3) = _PERMUTE((_VAR##_3_3), ##__VA_ARGS__); \ (_VAR##_4_3) = _PERMUTE((_VAR##_4_3), ##__VA_ARGS__); \ (_VAR##_5_3) = _PERMUTE((_VAR##_5_3), ##__VA_ARGS__); \ (_VAR##_6_3) = _PERMUTE((_VAR##_6_3), ##__VA_ARGS__); \ (_VAR##_7_3) = _PERMUTE((_VAR##_7_3), ##__VA_ARGS__); #define MATRIX_VAR_PERMUTE_8X8(_VAR, _PERMUTE, ...) /* MATRIX_VAR_PERMUTE_8X4 plus _VAR##_<r>_<c> for r = 0..7, c = 4..7. */
\ MATRIX_VAR_PERMUTE_8X4(_VAR, _PERMUTE, ##__VA_ARGS__) \ (_VAR##_0_4) = _PERMUTE((_VAR##_0_4), ##__VA_ARGS__); \ (_VAR##_1_4) = _PERMUTE((_VAR##_1_4), ##__VA_ARGS__); \ (_VAR##_2_4) = _PERMUTE((_VAR##_2_4), ##__VA_ARGS__); \ (_VAR##_3_4) = _PERMUTE((_VAR##_3_4), ##__VA_ARGS__); \ (_VAR##_4_4) = _PERMUTE((_VAR##_4_4), ##__VA_ARGS__); \ (_VAR##_5_4) = _PERMUTE((_VAR##_5_4), ##__VA_ARGS__); \ (_VAR##_6_4) = _PERMUTE((_VAR##_6_4), ##__VA_ARGS__); \ (_VAR##_7_4) = _PERMUTE((_VAR##_7_4), ##__VA_ARGS__); \ (_VAR##_0_5) = _PERMUTE((_VAR##_0_5), ##__VA_ARGS__); \ (_VAR##_1_5) = _PERMUTE((_VAR##_1_5), ##__VA_ARGS__); \ (_VAR##_2_5) = _PERMUTE((_VAR##_2_5), ##__VA_ARGS__); \ (_VAR##_3_5) = _PERMUTE((_VAR##_3_5), ##__VA_ARGS__); \ (_VAR##_4_5) = _PERMUTE((_VAR##_4_5), ##__VA_ARGS__); \ (_VAR##_5_5) = _PERMUTE((_VAR##_5_5), ##__VA_ARGS__); \ (_VAR##_6_5) = _PERMUTE((_VAR##_6_5), ##__VA_ARGS__); \ (_VAR##_7_5) = _PERMUTE((_VAR##_7_5), ##__VA_ARGS__); \ (_VAR##_0_6) = _PERMUTE((_VAR##_0_6), ##__VA_ARGS__); \ (_VAR##_1_6) = _PERMUTE((_VAR##_1_6), ##__VA_ARGS__); \ (_VAR##_2_6) = _PERMUTE((_VAR##_2_6), ##__VA_ARGS__); \ (_VAR##_3_6) = _PERMUTE((_VAR##_3_6), ##__VA_ARGS__); \ (_VAR##_4_6) = _PERMUTE((_VAR##_4_6), ##__VA_ARGS__); \ (_VAR##_5_6) = _PERMUTE((_VAR##_5_6), ##__VA_ARGS__); \ (_VAR##_6_6) = _PERMUTE((_VAR##_6_6), ##__VA_ARGS__); \ (_VAR##_7_6) = _PERMUTE((_VAR##_7_6), ##__VA_ARGS__); \ (_VAR##_0_7) = _PERMUTE((_VAR##_0_7), ##__VA_ARGS__); \ (_VAR##_1_7) = _PERMUTE((_VAR##_1_7), ##__VA_ARGS__); \ (_VAR##_2_7) = _PERMUTE((_VAR##_2_7), ##__VA_ARGS__); \ (_VAR##_3_7) = _PERMUTE((_VAR##_3_7), ##__VA_ARGS__); \ (_VAR##_4_7) = _PERMUTE((_VAR##_4_7), ##__VA_ARGS__); \ (_VAR##_5_7) = _PERMUTE((_VAR##_5_7), ##__VA_ARGS__); \ (_VAR##_6_7) = _PERMUTE((_VAR##_6_7), ##__VA_ARGS__); \ (_VAR##_7_7) = _PERMUTE((_VAR##_7_7), ##__VA_ARGS__); #define MATRIX_VAR_PERMUTE_8X16(_VAR, _PERMUTE, ...) /* MATRIX_VAR_PERMUTE_8X8 plus _VAR##_<r>_<c> for r = 0..7, c = 8..15. */
\ MATRIX_VAR_PERMUTE_8X8(_VAR, _PERMUTE, ##__VA_ARGS__) \ (_VAR##_0_8) = _PERMUTE((_VAR##_0_8), ##__VA_ARGS__); \ (_VAR##_1_8) = _PERMUTE((_VAR##_1_8), ##__VA_ARGS__); \ (_VAR##_2_8) = _PERMUTE((_VAR##_2_8), ##__VA_ARGS__); \ (_VAR##_3_8) = _PERMUTE((_VAR##_3_8), ##__VA_ARGS__); \ (_VAR##_4_8) = _PERMUTE((_VAR##_4_8), ##__VA_ARGS__); \ (_VAR##_5_8) = _PERMUTE((_VAR##_5_8), ##__VA_ARGS__); \ (_VAR##_6_8) = _PERMUTE((_VAR##_6_8), ##__VA_ARGS__); \ (_VAR##_7_8) = _PERMUTE((_VAR##_7_8), ##__VA_ARGS__); \ (_VAR##_0_9) = _PERMUTE((_VAR##_0_9), ##__VA_ARGS__); \ (_VAR##_1_9) = _PERMUTE((_VAR##_1_9), ##__VA_ARGS__); \ (_VAR##_2_9) = _PERMUTE((_VAR##_2_9), ##__VA_ARGS__); \ (_VAR##_3_9) = _PERMUTE((_VAR##_3_9), ##__VA_ARGS__); \ (_VAR##_4_9) = _PERMUTE((_VAR##_4_9), ##__VA_ARGS__); \ (_VAR##_5_9) = _PERMUTE((_VAR##_5_9), ##__VA_ARGS__); \ (_VAR##_6_9) = _PERMUTE((_VAR##_6_9), ##__VA_ARGS__); \ (_VAR##_7_9) = _PERMUTE((_VAR##_7_9), ##__VA_ARGS__); \ (_VAR##_0_10) = _PERMUTE((_VAR##_0_10), ##__VA_ARGS__); \ (_VAR##_1_10) = _PERMUTE((_VAR##_1_10), ##__VA_ARGS__); \ (_VAR##_2_10) = _PERMUTE((_VAR##_2_10), ##__VA_ARGS__); \ (_VAR##_3_10) = _PERMUTE((_VAR##_3_10), ##__VA_ARGS__); \ (_VAR##_4_10) = _PERMUTE((_VAR##_4_10), ##__VA_ARGS__); \ (_VAR##_5_10) = _PERMUTE((_VAR##_5_10), ##__VA_ARGS__); \ (_VAR##_6_10) = _PERMUTE((_VAR##_6_10), ##__VA_ARGS__); \ (_VAR##_7_10) = _PERMUTE((_VAR##_7_10), ##__VA_ARGS__); \ (_VAR##_0_11) = _PERMUTE((_VAR##_0_11), ##__VA_ARGS__); \ (_VAR##_1_11) = _PERMUTE((_VAR##_1_11), ##__VA_ARGS__); \ (_VAR##_2_11) = _PERMUTE((_VAR##_2_11), ##__VA_ARGS__); \ (_VAR##_3_11) = _PERMUTE((_VAR##_3_11), ##__VA_ARGS__); \ (_VAR##_4_11) = _PERMUTE((_VAR##_4_11), ##__VA_ARGS__); \ (_VAR##_5_11) = _PERMUTE((_VAR##_5_11), ##__VA_ARGS__); \ (_VAR##_6_11) = _PERMUTE((_VAR##_6_11), ##__VA_ARGS__); \ (_VAR##_7_11) = _PERMUTE((_VAR##_7_11), ##__VA_ARGS__); \ (_VAR##_0_12) = _PERMUTE((_VAR##_0_12), ##__VA_ARGS__); \ (_VAR##_1_12) = _PERMUTE((_VAR##_1_12), ##__VA_ARGS__); \
(_VAR##_2_12) = _PERMUTE((_VAR##_2_12), ##__VA_ARGS__); \ (_VAR##_3_12) = _PERMUTE((_VAR##_3_12), ##__VA_ARGS__); \ (_VAR##_4_12) = _PERMUTE((_VAR##_4_12), ##__VA_ARGS__); \ (_VAR##_5_12) = _PERMUTE((_VAR##_5_12), ##__VA_ARGS__); \ (_VAR##_6_12) = _PERMUTE((_VAR##_6_12), ##__VA_ARGS__); \ (_VAR##_7_12) = _PERMUTE((_VAR##_7_12), ##__VA_ARGS__); \ (_VAR##_0_13) = _PERMUTE((_VAR##_0_13), ##__VA_ARGS__); \ (_VAR##_1_13) = _PERMUTE((_VAR##_1_13), ##__VA_ARGS__); \ (_VAR##_2_13) = _PERMUTE((_VAR##_2_13), ##__VA_ARGS__); \ (_VAR##_3_13) = _PERMUTE((_VAR##_3_13), ##__VA_ARGS__); \ (_VAR##_4_13) = _PERMUTE((_VAR##_4_13), ##__VA_ARGS__); \ (_VAR##_5_13) = _PERMUTE((_VAR##_5_13), ##__VA_ARGS__); \ (_VAR##_6_13) = _PERMUTE((_VAR##_6_13), ##__VA_ARGS__); \ (_VAR##_7_13) = _PERMUTE((_VAR##_7_13), ##__VA_ARGS__); \ (_VAR##_0_14) = _PERMUTE((_VAR##_0_14), ##__VA_ARGS__); \ (_VAR##_1_14) = _PERMUTE((_VAR##_1_14), ##__VA_ARGS__); \ (_VAR##_2_14) = _PERMUTE((_VAR##_2_14), ##__VA_ARGS__); \ (_VAR##_3_14) = _PERMUTE((_VAR##_3_14), ##__VA_ARGS__); \ (_VAR##_4_14) = _PERMUTE((_VAR##_4_14), ##__VA_ARGS__); \ (_VAR##_5_14) = _PERMUTE((_VAR##_5_14), ##__VA_ARGS__); \ (_VAR##_6_14) = _PERMUTE((_VAR##_6_14), ##__VA_ARGS__); \ (_VAR##_7_14) = _PERMUTE((_VAR##_7_14), ##__VA_ARGS__); \ (_VAR##_0_15) = _PERMUTE((_VAR##_0_15), ##__VA_ARGS__); \ (_VAR##_1_15) = _PERMUTE((_VAR##_1_15), ##__VA_ARGS__); \ (_VAR##_2_15) = _PERMUTE((_VAR##_2_15), ##__VA_ARGS__); \ (_VAR##_3_15) = _PERMUTE((_VAR##_3_15), ##__VA_ARGS__); \ (_VAR##_4_15) = _PERMUTE((_VAR##_4_15), ##__VA_ARGS__); \ (_VAR##_5_15) = _PERMUTE((_VAR##_5_15), ##__VA_ARGS__); \ (_VAR##_6_15) = _PERMUTE((_VAR##_6_15), ##__VA_ARGS__); \ (_VAR##_7_15) = _PERMUTE((_VAR##_7_15), ##__VA_ARGS__); #define MATRIX_VAR_PERMUTE_8X32(_VAR, _PERMUTE, ...) /* MATRIX_VAR_PERMUTE_8X16 plus _VAR##_<r>_<c> for r = 0..7, c = 16..31. */
\ MATRIX_VAR_PERMUTE_8X16(_VAR, _PERMUTE, ##__VA_ARGS__) \ (_VAR##_0_16) = _PERMUTE((_VAR##_0_16), ##__VA_ARGS__); \ (_VAR##_1_16) = _PERMUTE((_VAR##_1_16), ##__VA_ARGS__); \ (_VAR##_2_16) = _PERMUTE((_VAR##_2_16), ##__VA_ARGS__); \ (_VAR##_3_16) = _PERMUTE((_VAR##_3_16), ##__VA_ARGS__); \ (_VAR##_4_16) = _PERMUTE((_VAR##_4_16), ##__VA_ARGS__); \ (_VAR##_5_16) = _PERMUTE((_VAR##_5_16), ##__VA_ARGS__); \ (_VAR##_6_16) = _PERMUTE((_VAR##_6_16), ##__VA_ARGS__); \ (_VAR##_7_16) = _PERMUTE((_VAR##_7_16), ##__VA_ARGS__); \ (_VAR##_0_17) = _PERMUTE((_VAR##_0_17), ##__VA_ARGS__); \ (_VAR##_1_17) = _PERMUTE((_VAR##_1_17), ##__VA_ARGS__); \ (_VAR##_2_17) = _PERMUTE((_VAR##_2_17), ##__VA_ARGS__); \ (_VAR##_3_17) = _PERMUTE((_VAR##_3_17), ##__VA_ARGS__); \ (_VAR##_4_17) = _PERMUTE((_VAR##_4_17), ##__VA_ARGS__); \ (_VAR##_5_17) = _PERMUTE((_VAR##_5_17), ##__VA_ARGS__); \ (_VAR##_6_17) = _PERMUTE((_VAR##_6_17), ##__VA_ARGS__); \ (_VAR##_7_17) = _PERMUTE((_VAR##_7_17), ##__VA_ARGS__); \ (_VAR##_0_18) = _PERMUTE((_VAR##_0_18), ##__VA_ARGS__); \ (_VAR##_1_18) = _PERMUTE((_VAR##_1_18), ##__VA_ARGS__); \ (_VAR##_2_18) = _PERMUTE((_VAR##_2_18), ##__VA_ARGS__); \ (_VAR##_3_18) = _PERMUTE((_VAR##_3_18), ##__VA_ARGS__); \ (_VAR##_4_18) = _PERMUTE((_VAR##_4_18), ##__VA_ARGS__); \ (_VAR##_5_18) = _PERMUTE((_VAR##_5_18), ##__VA_ARGS__); \ (_VAR##_6_18) = _PERMUTE((_VAR##_6_18), ##__VA_ARGS__); \ (_VAR##_7_18) = _PERMUTE((_VAR##_7_18), ##__VA_ARGS__); \ (_VAR##_0_19) = _PERMUTE((_VAR##_0_19), ##__VA_ARGS__); \ (_VAR##_1_19) = _PERMUTE((_VAR##_1_19), ##__VA_ARGS__); \ (_VAR##_2_19) = _PERMUTE((_VAR##_2_19), ##__VA_ARGS__); \ (_VAR##_3_19) = _PERMUTE((_VAR##_3_19), ##__VA_ARGS__); \ (_VAR##_4_19) = _PERMUTE((_VAR##_4_19), ##__VA_ARGS__); \ (_VAR##_5_19) = _PERMUTE((_VAR##_5_19), ##__VA_ARGS__); \ (_VAR##_6_19) = _PERMUTE((_VAR##_6_19), ##__VA_ARGS__); \ (_VAR##_7_19) = _PERMUTE((_VAR##_7_19), ##__VA_ARGS__); \ (_VAR##_0_20) = _PERMUTE((_VAR##_0_20), ##__VA_ARGS__); \ (_VAR##_1_20) =
_PERMUTE((_VAR##_1_20), ##__VA_ARGS__); \ (_VAR##_2_20) = _PERMUTE((_VAR##_2_20), ##__VA_ARGS__); \ (_VAR##_3_20) = _PERMUTE((_VAR##_3_20), ##__VA_ARGS__); \ (_VAR##_4_20) = _PERMUTE((_VAR##_4_20), ##__VA_ARGS__); \ (_VAR##_5_20) = _PERMUTE((_VAR##_5_20), ##__VA_ARGS__); \ (_VAR##_6_20) = _PERMUTE((_VAR##_6_20), ##__VA_ARGS__); \ (_VAR##_7_20) = _PERMUTE((_VAR##_7_20), ##__VA_ARGS__); \ (_VAR##_0_21) = _PERMUTE((_VAR##_0_21), ##__VA_ARGS__); \ (_VAR##_1_21) = _PERMUTE((_VAR##_1_21), ##__VA_ARGS__); \ (_VAR##_2_21) = _PERMUTE((_VAR##_2_21), ##__VA_ARGS__); \ (_VAR##_3_21) = _PERMUTE((_VAR##_3_21), ##__VA_ARGS__); \ (_VAR##_4_21) = _PERMUTE((_VAR##_4_21), ##__VA_ARGS__); \ (_VAR##_5_21) = _PERMUTE((_VAR##_5_21), ##__VA_ARGS__); \ (_VAR##_6_21) = _PERMUTE((_VAR##_6_21), ##__VA_ARGS__); \ (_VAR##_7_21) = _PERMUTE((_VAR##_7_21), ##__VA_ARGS__); \ (_VAR##_0_22) = _PERMUTE((_VAR##_0_22), ##__VA_ARGS__); \ (_VAR##_1_22) = _PERMUTE((_VAR##_1_22), ##__VA_ARGS__); \ (_VAR##_2_22) = _PERMUTE((_VAR##_2_22), ##__VA_ARGS__); \ (_VAR##_3_22) = _PERMUTE((_VAR##_3_22), ##__VA_ARGS__); \ (_VAR##_4_22) = _PERMUTE((_VAR##_4_22), ##__VA_ARGS__); \ (_VAR##_5_22) = _PERMUTE((_VAR##_5_22), ##__VA_ARGS__); \ (_VAR##_6_22) = _PERMUTE((_VAR##_6_22), ##__VA_ARGS__); \ (_VAR##_7_22) = _PERMUTE((_VAR##_7_22), ##__VA_ARGS__); \ (_VAR##_0_23) = _PERMUTE((_VAR##_0_23), ##__VA_ARGS__); \ (_VAR##_1_23) = _PERMUTE((_VAR##_1_23), ##__VA_ARGS__); \ (_VAR##_2_23) = _PERMUTE((_VAR##_2_23), ##__VA_ARGS__); \ (_VAR##_3_23) = _PERMUTE((_VAR##_3_23), ##__VA_ARGS__); \ (_VAR##_4_23) = _PERMUTE((_VAR##_4_23), ##__VA_ARGS__); \ (_VAR##_5_23) = _PERMUTE((_VAR##_5_23), ##__VA_ARGS__); \ (_VAR##_6_23) = _PERMUTE((_VAR##_6_23), ##__VA_ARGS__); \ (_VAR##_7_23) = _PERMUTE((_VAR##_7_23), ##__VA_ARGS__); \ (_VAR##_0_24) = _PERMUTE((_VAR##_0_24), ##__VA_ARGS__); \ (_VAR##_1_24) = _PERMUTE((_VAR##_1_24), ##__VA_ARGS__); \ (_VAR##_2_24) = _PERMUTE((_VAR##_2_24), ##__VA_ARGS__); \ (_VAR##_3_24) = _PERMUTE((_VAR##_3_24),
##__VA_ARGS__); \ (_VAR##_4_24) = _PERMUTE((_VAR##_4_24), ##__VA_ARGS__); \ (_VAR##_5_24) = _PERMUTE((_VAR##_5_24), ##__VA_ARGS__); \ (_VAR##_6_24) = _PERMUTE((_VAR##_6_24), ##__VA_ARGS__); \ (_VAR##_7_24) = _PERMUTE((_VAR##_7_24), ##__VA_ARGS__); \ (_VAR##_0_25) = _PERMUTE((_VAR##_0_25), ##__VA_ARGS__); \ (_VAR##_1_25) = _PERMUTE((_VAR##_1_25), ##__VA_ARGS__); \ (_VAR##_2_25) = _PERMUTE((_VAR##_2_25), ##__VA_ARGS__); \ (_VAR##_3_25) = _PERMUTE((_VAR##_3_25), ##__VA_ARGS__); \ (_VAR##_4_25) = _PERMUTE((_VAR##_4_25), ##__VA_ARGS__); \ (_VAR##_5_25) = _PERMUTE((_VAR##_5_25), ##__VA_ARGS__); \ (_VAR##_6_25) = _PERMUTE((_VAR##_6_25), ##__VA_ARGS__); \ (_VAR##_7_25) = _PERMUTE((_VAR##_7_25), ##__VA_ARGS__); \ (_VAR##_0_26) = _PERMUTE((_VAR##_0_26), ##__VA_ARGS__); \ (_VAR##_1_26) = _PERMUTE((_VAR##_1_26), ##__VA_ARGS__); \ (_VAR##_2_26) = _PERMUTE((_VAR##_2_26), ##__VA_ARGS__); \ (_VAR##_3_26) = _PERMUTE((_VAR##_3_26), ##__VA_ARGS__); \ (_VAR##_4_26) = _PERMUTE((_VAR##_4_26), ##__VA_ARGS__); \ (_VAR##_5_26) = _PERMUTE((_VAR##_5_26), ##__VA_ARGS__); \ (_VAR##_6_26) = _PERMUTE((_VAR##_6_26), ##__VA_ARGS__); \ (_VAR##_7_26) = _PERMUTE((_VAR##_7_26), ##__VA_ARGS__); \ (_VAR##_0_27) = _PERMUTE((_VAR##_0_27), ##__VA_ARGS__); \ (_VAR##_1_27) = _PERMUTE((_VAR##_1_27), ##__VA_ARGS__); \ (_VAR##_2_27) = _PERMUTE((_VAR##_2_27), ##__VA_ARGS__); \ (_VAR##_3_27) = _PERMUTE((_VAR##_3_27), ##__VA_ARGS__); \ (_VAR##_4_27) = _PERMUTE((_VAR##_4_27), ##__VA_ARGS__); \ (_VAR##_5_27) = _PERMUTE((_VAR##_5_27), ##__VA_ARGS__); \ (_VAR##_6_27) = _PERMUTE((_VAR##_6_27), ##__VA_ARGS__); \ (_VAR##_7_27) = _PERMUTE((_VAR##_7_27), ##__VA_ARGS__); \ (_VAR##_0_28) = _PERMUTE((_VAR##_0_28), ##__VA_ARGS__); \ (_VAR##_1_28) = _PERMUTE((_VAR##_1_28), ##__VA_ARGS__); \ (_VAR##_2_28) = _PERMUTE((_VAR##_2_28), ##__VA_ARGS__); \ (_VAR##_3_28) = _PERMUTE((_VAR##_3_28), ##__VA_ARGS__); \ (_VAR##_4_28) = _PERMUTE((_VAR##_4_28), ##__VA_ARGS__); \ (_VAR##_5_28) = _PERMUTE((_VAR##_5_28), ##__VA_ARGS__); \
(_VAR##_6_28) = _PERMUTE((_VAR##_6_28), ##__VA_ARGS__); \ (_VAR##_7_28) = _PERMUTE((_VAR##_7_28), ##__VA_ARGS__); \ (_VAR##_0_29) = _PERMUTE((_VAR##_0_29), ##__VA_ARGS__); \ (_VAR##_1_29) = _PERMUTE((_VAR##_1_29), ##__VA_ARGS__); \ (_VAR##_2_29) = _PERMUTE((_VAR##_2_29), ##__VA_ARGS__); \ (_VAR##_3_29) = _PERMUTE((_VAR##_3_29), ##__VA_ARGS__); \ (_VAR##_4_29) = _PERMUTE((_VAR##_4_29), ##__VA_ARGS__); \ (_VAR##_5_29) = _PERMUTE((_VAR##_5_29), ##__VA_ARGS__); \ (_VAR##_6_29) = _PERMUTE((_VAR##_6_29), ##__VA_ARGS__); \ (_VAR##_7_29) = _PERMUTE((_VAR##_7_29), ##__VA_ARGS__); \ (_VAR##_0_30) = _PERMUTE((_VAR##_0_30), ##__VA_ARGS__); \ (_VAR##_1_30) = _PERMUTE((_VAR##_1_30), ##__VA_ARGS__); \ (_VAR##_2_30) = _PERMUTE((_VAR##_2_30), ##__VA_ARGS__); \ (_VAR##_3_30) = _PERMUTE((_VAR##_3_30), ##__VA_ARGS__); \ (_VAR##_4_30) = _PERMUTE((_VAR##_4_30), ##__VA_ARGS__); \ (_VAR##_5_30) = _PERMUTE((_VAR##_5_30), ##__VA_ARGS__); \ (_VAR##_6_30) = _PERMUTE((_VAR##_6_30), ##__VA_ARGS__); \ (_VAR##_7_30) = _PERMUTE((_VAR##_7_30), ##__VA_ARGS__); \ (_VAR##_0_31) = _PERMUTE((_VAR##_0_31), ##__VA_ARGS__); \ (_VAR##_1_31) = _PERMUTE((_VAR##_1_31), ##__VA_ARGS__); \ (_VAR##_2_31) = _PERMUTE((_VAR##_2_31), ##__VA_ARGS__); \ (_VAR##_3_31) = _PERMUTE((_VAR##_3_31), ##__VA_ARGS__); \ (_VAR##_4_31) = _PERMUTE((_VAR##_4_31), ##__VA_ARGS__); \ (_VAR##_5_31) = _PERMUTE((_VAR##_5_31), ##__VA_ARGS__); \ (_VAR##_6_31) = _PERMUTE((_VAR##_6_31), ##__VA_ARGS__); \ (_VAR##_7_31) = _PERMUTE((_VAR##_7_31), ##__VA_ARGS__); #define MATRIX_VAR_PERMUTE_16X1(_VAR, _PERMUTE, ...)
\ MATRIX_VAR_PERMUTE_8X1(_VAR, _PERMUTE, ##__VA_ARGS__) \ (_VAR##_8_0) = _PERMUTE((_VAR##_8_0), ##__VA_ARGS__); \ (_VAR##_9_0) = _PERMUTE((_VAR##_9_0), ##__VA_ARGS__); \ (_VAR##_10_0) = _PERMUTE((_VAR##_10_0), ##__VA_ARGS__); \ (_VAR##_11_0) = _PERMUTE((_VAR##_11_0), ##__VA_ARGS__); \ (_VAR##_12_0) = _PERMUTE((_VAR##_12_0), ##__VA_ARGS__); \ (_VAR##_13_0) = _PERMUTE((_VAR##_13_0), ##__VA_ARGS__); \ (_VAR##_14_0) = _PERMUTE((_VAR##_14_0), ##__VA_ARGS__); \ (_VAR##_15_0) = _PERMUTE((_VAR##_15_0), ##__VA_ARGS__); #define MATRIX_VAR_PERMUTE_16X2(_VAR, _PERMUTE, ...) \ MATRIX_VAR_PERMUTE_16X1(_VAR, _PERMUTE, ##__VA_ARGS__) \ (_VAR##_0_1) = _PERMUTE((_VAR##_0_1), ##__VA_ARGS__); \ (_VAR##_1_1) = _PERMUTE((_VAR##_1_1), ##__VA_ARGS__); \ (_VAR##_2_1) = _PERMUTE((_VAR##_2_1), ##__VA_ARGS__); \ (_VAR##_3_1) = _PERMUTE((_VAR##_3_1), ##__VA_ARGS__); \ (_VAR##_4_1) = _PERMUTE((_VAR##_4_1), ##__VA_ARGS__); \ (_VAR##_5_1) = _PERMUTE((_VAR##_5_1), ##__VA_ARGS__); \ (_VAR##_6_1) = _PERMUTE((_VAR##_6_1), ##__VA_ARGS__); \ (_VAR##_7_1) = _PERMUTE((_VAR##_7_1), ##__VA_ARGS__); \ (_VAR##_8_1) = _PERMUTE((_VAR##_8_1), ##__VA_ARGS__); \ (_VAR##_9_1) = _PERMUTE((_VAR##_9_1), ##__VA_ARGS__); \ (_VAR##_10_1) = _PERMUTE((_VAR##_10_1), ##__VA_ARGS__); \ (_VAR##_11_1) = _PERMUTE((_VAR##_11_1), ##__VA_ARGS__); \ (_VAR##_12_1) = _PERMUTE((_VAR##_12_1), ##__VA_ARGS__); \ (_VAR##_13_1) = _PERMUTE((_VAR##_13_1), ##__VA_ARGS__); \ (_VAR##_14_1) = _PERMUTE((_VAR##_14_1), ##__VA_ARGS__); \ (_VAR##_15_1) = _PERMUTE((_VAR##_15_1), ##__VA_ARGS__); #define MATRIX_VAR_PERMUTE_16X4(_VAR, _PERMUTE, ...) 
\ MATRIX_VAR_PERMUTE_16X2(_VAR, _PERMUTE, ##__VA_ARGS__) \ (_VAR##_0_2) = _PERMUTE((_VAR##_0_2), ##__VA_ARGS__); \ (_VAR##_1_2) = _PERMUTE((_VAR##_1_2), ##__VA_ARGS__); \ (_VAR##_2_2) = _PERMUTE((_VAR##_2_2), ##__VA_ARGS__); \ (_VAR##_3_2) = _PERMUTE((_VAR##_3_2), ##__VA_ARGS__); \ (_VAR##_4_2) = _PERMUTE((_VAR##_4_2), ##__VA_ARGS__); \ (_VAR##_5_2) = _PERMUTE((_VAR##_5_2), ##__VA_ARGS__); \ (_VAR##_6_2) = _PERMUTE((_VAR##_6_2), ##__VA_ARGS__); \ (_VAR##_7_2) = _PERMUTE((_VAR##_7_2), ##__VA_ARGS__); \ (_VAR##_8_2) = _PERMUTE((_VAR##_8_2), ##__VA_ARGS__); \ (_VAR##_9_2) = _PERMUTE((_VAR##_9_2), ##__VA_ARGS__); \ (_VAR##_10_2) = _PERMUTE((_VAR##_10_2), ##__VA_ARGS__); \ (_VAR##_11_2) = _PERMUTE((_VAR##_11_2), ##__VA_ARGS__); \ (_VAR##_12_2) = _PERMUTE((_VAR##_12_2), ##__VA_ARGS__); \ (_VAR##_13_2) = _PERMUTE((_VAR##_13_2), ##__VA_ARGS__); \ (_VAR##_14_2) = _PERMUTE((_VAR##_14_2), ##__VA_ARGS__); \ (_VAR##_15_2) = _PERMUTE((_VAR##_15_2), ##__VA_ARGS__); \ (_VAR##_0_3) = _PERMUTE((_VAR##_0_3), ##__VA_ARGS__); \ (_VAR##_1_3) = _PERMUTE((_VAR##_1_3), ##__VA_ARGS__); \ (_VAR##_2_3) = _PERMUTE((_VAR##_2_3), ##__VA_ARGS__); \ (_VAR##_3_3) = _PERMUTE((_VAR##_3_3), ##__VA_ARGS__); \ (_VAR##_4_3) = _PERMUTE((_VAR##_4_3), ##__VA_ARGS__); \ (_VAR##_5_3) = _PERMUTE((_VAR##_5_3), ##__VA_ARGS__); \ (_VAR##_6_3) = _PERMUTE((_VAR##_6_3), ##__VA_ARGS__); \ (_VAR##_7_3) = _PERMUTE((_VAR##_7_3), ##__VA_ARGS__); \ (_VAR##_8_3) = _PERMUTE((_VAR##_8_3), ##__VA_ARGS__); \ (_VAR##_9_3) = _PERMUTE((_VAR##_9_3), ##__VA_ARGS__); \ (_VAR##_10_3) = _PERMUTE((_VAR##_10_3), ##__VA_ARGS__); \ (_VAR##_11_3) = _PERMUTE((_VAR##_11_3), ##__VA_ARGS__); \ (_VAR##_12_3) = _PERMUTE((_VAR##_12_3), ##__VA_ARGS__); \ (_VAR##_13_3) = _PERMUTE((_VAR##_13_3), ##__VA_ARGS__); \ (_VAR##_14_3) = _PERMUTE((_VAR##_14_3), ##__VA_ARGS__); \ (_VAR##_15_3) = _PERMUTE((_VAR##_15_3), ##__VA_ARGS__); #define MATRIX_VAR_PROC_2X1(_K, _LHS, _RHS, _RES, _PROCESS) \ _PROCESS((_LHS##_0), (_RHS), (_RES##_0_##_K)) \ 
_PROCESS((_LHS##_1), (_RHS), (_RES##_1_##_K)) #define MATRIX_VAR_PROC_4X1(_K, _LHS, _RHS, _RES, _PROCESS) \ MATRIX_VAR_PROC_2X1(_K, _LHS, _RHS, _RES, _PROCESS) \ _PROCESS((_LHS##_2), (_RHS), (_RES##_2_##_K)) \ _PROCESS((_LHS##_3), (_RHS), (_RES##_3_##_K)) #define MATRIX_VAR_PROC_8X1(_K, _LHS, _RHS, _RES, _PROCESS) \ MATRIX_VAR_PROC_4X1(_K, _LHS, _RHS, _RES, _PROCESS) \ _PROCESS((_LHS##_4), (_RHS), (_RES##_4_##_K)) \ _PROCESS((_LHS##_5), (_RHS), (_RES##_5_##_K)) \ _PROCESS((_LHS##_6), (_RHS), (_RES##_6_##_K)) \ _PROCESS((_LHS##_7), (_RHS), (_RES##_7_##_K)) #define MATRIX_VAR_PROC_16X1(_K, _LHS, _RHS, _RES, _PROCESS) \ MATRIX_VAR_PROC_8X1(_K, _LHS, _RHS, _RES, _PROCESS) \ _PROCESS((_LHS##_8), (_RHS), (_RES##_8_##_K)) \ _PROCESS((_LHS##_9), (_RHS), (_RES##_9_##_K)) \ _PROCESS((_LHS##_10), (_RHS), (_RES##_10_##_K)) \ _PROCESS((_LHS##_11), (_RHS), (_RES##_11_##_K)) \ _PROCESS((_LHS##_12), (_RHS), (_RES##_12_##_K)) \ _PROCESS((_LHS##_13), (_RHS), (_RES##_13_##_K)) \ _PROCESS((_LHS##_14), (_RHS), (_RES##_14_##_K)) \ _PROCESS((_LHS##_15), (_RHS), (_RES##_15_##_K)) #define MATRIX_VAR_PROC_1X2(_K, _LHS, _RHS, _RES, _PROCESS) \ _PROCESS((_LHS), (_RHS##_0), (_RES##_##_K##_0)) \ _PROCESS((_LHS), (_RHS##_1), (_RES##_##_K##_1)) #define MATRIX_VAR_PROC_1X4(_K, _LHS, _RHS, _RES, _PROCESS) \ MATRIX_VAR_PROC_1X2(_K, _LHS, _RHS, _RES, _PROCESS) \ _PROCESS((_LHS), (_RHS##_2), (_RES##_##_K##_2)) \ _PROCESS((_LHS), (_RHS##_3), (_RES##_##_K##_3)) #define MATRIX_VAR_PROC_1X8(_K, _LHS, _RHS, _RES, _PROCESS) \ MATRIX_VAR_PROC_1X4(_K, _LHS, _RHS, _RES, _PROCESS) \ _PROCESS((_LHS), (_RHS##_4), (_RES##_##_K##_4)) \ _PROCESS((_LHS), (_RHS##_5), (_RES##_##_K##_5)) \ _PROCESS((_LHS), (_RHS##_6), (_RES##_##_K##_6)) \ _PROCESS((_LHS), (_RHS##_7), (_RES##_##_K##_7)) #define MATRIX_VAR_PROC_1X16(_K, _LHS, _RHS, _RES, _PROCESS) \ MATRIX_VAR_PROC_1X8(_K, _LHS, _RHS, _RES, _PROCESS) \ _PROCESS((_LHS), (_RHS##_8), (_RES##_##_K##_8)) \ _PROCESS((_LHS), (_RHS##_9), (_RES##_##_K##_9)) \ _PROCESS((_LHS), 
(_RHS##_10), (_RES##_##_K##_10)) \ _PROCESS((_LHS), (_RHS##_11), (_RES##_##_K##_11)) \ _PROCESS((_LHS), (_RHS##_12), (_RES##_##_K##_12)) \ _PROCESS((_LHS), (_RHS##_13), (_RES##_##_K##_13)) \ _PROCESS((_LHS), (_RHS##_14), (_RES##_##_K##_14)) \ _PROCESS((_LHS), (_RHS##_15), (_RES##_##_K##_15)) #define MATRIX_VAR_INIT(_M, _N, _VAR_TYPE, _VAR_NAME, _VAR_INIT) \ MATRIX_VAR_INIT_##_M##X##_N(_VAR_TYPE, _VAR_NAME, _VAR_INIT) #define MATRIX_VAR_STORE(_M, _N, _STEP, _VAR, _ARRAY, _STORE, _NORM, ...) \ MATRIX_VAR_STORE_##_M##X##_N(_STEP, _VAR, _ARRAY, _STORE, _NORM, \ ##__VA_ARGS__) #define MATRIX_VAR_PERMUTE(_M, _N, _VAR, _PERMUTE, ...) \ MATRIX_VAR_PERMUTE_##_M##X##_N(_VAR, _PERMUTE, ##__VA_ARGS__) #define MATRIX_VAR_PROC(_M, _N, _K, _LHS, _RHS, _RES, _PROCESS) \ MATRIX_VAR_PROC_##_M##X##_N(_K, _LHS, _RHS, _RES, _PROCESS) ================================================ FILE: src/ailego/math/matrix_utility.i ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include namespace zvec { namespace ailego { //! 
Absolute value of a float static inline float FastAbs(float x) { uint32_t *p = reinterpret_cast(&x); *p &= 0x7fffffffu; return *reinterpret_cast(p); } #if defined(__SSE__) static inline float HorizontalMax_FP32_V128(__m128 v) { __m128 x1 = _mm_shuffle_ps(v, v, _MM_SHUFFLE(0, 0, 3, 2)); __m128 x2 = _mm_max_ps(v, x1); __m128 x3 = _mm_shuffle_ps(x2, x2, _MM_SHUFFLE(0, 0, 0, 1)); __m128 x4 = _mm_max_ps(x2, x3); return _mm_cvtss_f32(x4); } static inline float HorizontalAdd_FP32_V128(__m128 v) { #ifdef __SSE3__ __m128 x1 = _mm_hadd_ps(v, v); __m128 x2 = _mm_hadd_ps(x1, x1); return _mm_cvtss_f32(x2); #else __m128 x1 = _mm_movehl_ps(v, v); __m128 x2 = _mm_add_ps(v, x1); __m128 x3 = _mm_shuffle_ps(x2, x2, 1); __m128 x4 = _mm_add_ss(x2, x3); return _mm_cvtss_f32(x4); #endif } #endif // __SSE__ #if defined(__SSE2__) static inline int32_t HorizontalAdd_INT32_V128(__m128i v) { #ifdef __SSE3__ __m128i x1 = _mm_hadd_epi32(v, v); __m128i x2 = _mm_hadd_epi32(x1, x1); return _mm_cvtsi128_si32(x2); #else __m128i x1 = _mm_shuffle_epi32(v, _MM_SHUFFLE(0, 0, 3, 2)); __m128i x2 = _mm_add_epi32(v, x1); __m128i x3 = _mm_shuffle_epi32(x2, _MM_SHUFFLE(0, 0, 0, 1)); __m128i x4 = _mm_add_epi32(x2, x3); return _mm_cvtsi128_si32(x4); #endif } static inline int64_t HorizontalAdd_INT64_V128(__m128i v) { #ifdef __SSE4_1__ return (_mm_extract_epi64(v, 0) + _mm_extract_epi64(v, 1)); #else return _mm_cvtsi128_si64( _mm_add_epi64(_mm_shuffle_epi32(v, _MM_SHUFFLE(0, 0, 3, 2)), v)); #endif } #endif // __SSE2__ #if defined(__SSSE3__) static const __m128i POPCNT_LOOKUP_SSE = _mm_setr_epi8(0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4); static inline __m128i VerticalPopCount_INT8_V128(__m128i v) { #if defined(__AVX512VL__) && defined(__AVX512BITALG__) return _mm_popcnt_epi8(v); #else const __m128i low_mask = _mm_set1_epi8(0x0f); __m128i lo = _mm_shuffle_epi8(POPCNT_LOOKUP_SSE, _mm_and_si128(v, low_mask)); __m128i hi = _mm_shuffle_epi8(POPCNT_LOOKUP_SSE, _mm_and_si128(_mm_srli_epi32(v, 4), low_mask)); 
return _mm_add_epi8(lo, hi); #endif // __AVX512VL__ && __AVX512BITALG__ } static inline __m128i VerticalPopCount_INT16_V128(__m128i v) { #if defined(__AVX512VL__) && defined(__AVX512BITALG__) return _mm_popcnt_epi16(v); #else __m128i total = VerticalPopCount_INT8_V128(v); return _mm_add_epi16(_mm_srli_epi16(total, 8), _mm_and_si128(total, _mm_set1_epi16(0xff))); #endif // __AVX512VL__ && __AVX512BITALG__ } static inline __m128i VerticalPopCount_INT32_V128(__m128i v) { #if defined(__AVX512VL__) && defined(__AVX512VPOPCNTDQ__) return _mm_popcnt_epi32(v); #else __m128i total = _mm_madd_epi16(VerticalPopCount_INT8_V128(v), _mm_set1_epi16(1)); return _mm_add_epi32(_mm_srli_epi32(total, 8), _mm_and_si128(total, _mm_set1_epi32(0xff))); #endif // __AVX512VL__ && __AVX512VPOPCNTDQ__ } static inline __m128i VerticalPopCount_INT64_V128(__m128i v) { #if defined(__AVX512VL__) && defined(__AVX512VPOPCNTDQ__) return _mm_popcnt_epi64(v); #else return _mm_sad_epu8(VerticalPopCount_INT8_V128(v), _mm_setzero_si128()); #endif // __AVX512VL__ && __AVX512VPOPCNTDQ__ } #endif // __SSSE3__ #if defined(__SSE4_1__) static inline int16_t HorizontalMax_UINT8_V128(__m128i v) { v = _mm_max_epu8(v, _mm_shuffle_epi32(v, _MM_SHUFFLE(3, 2, 3, 2))); v = _mm_max_epu8(v, _mm_shuffle_epi32(v, _MM_SHUFFLE(1, 1, 1, 1))); v = _mm_max_epu8(v, _mm_shufflelo_epi16(v, _MM_SHUFFLE(1, 1, 1, 1))); v = _mm_max_epu8(v, _mm_srli_epi16(v, 8)); return static_cast(_mm_cvtsi128_si32(v)); } #endif // __SSE4_1__ #if defined(__AVX__) static inline float HorizontalMax_FP32_V256(__m256 v) { __m256 x1 = _mm256_permute_ps(v, _MM_SHUFFLE(0, 0, 3, 2)); __m256 x2 = _mm256_max_ps(v, x1); __m256 x3 = _mm256_permute_ps(x2, _MM_SHUFFLE(0, 0, 0, 1)); __m256 x4 = _mm256_max_ps(x2, x3); __m128 x5 = _mm256_extractf128_ps(x4, 1); __m128 x6 = _mm_max_ss(_mm256_castps256_ps128(x4), x5); return _mm_cvtss_f32(x6); } static inline float HorizontalAdd_FP32_V256(__m256 v) { __m256 x1 = _mm256_hadd_ps(v, v); __m256 x2 = _mm256_hadd_ps(x1, x1); 
__m128 x3 = _mm256_extractf128_ps(x2, 1); __m128 x4 = _mm_add_ss(_mm256_castps256_ps128(x2), x3); return _mm_cvtss_f32(x4); } #endif // __AVX__ #if defined(__AVX2__) #define POPCNT_MASK1_INT8_AVX _mm256_set1_epi8(0x0f) #define POPCNT_MASK1_INT16_AVX _mm256_set1_epi16(1) #define POPCNT_MASK2_INT16_AVX _mm256_set1_epi16(0xff) #define POPCNT_MASK1_INT32_AVX _mm256_set1_epi32(0xff) #define POPCNT_ZERO_AVX _mm256_setzero_si256() #define POPCNT_LOOKUP_AVX _mm256_setr_epi8(0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4) static inline __m256i VerticalPopCount_INT8_V256(__m256i v) { #if defined(__AVX512VL__) && defined(__AVX512BITALG__) return _mm256_popcnt_epi8(v); #else __m256i lo = _mm256_shuffle_epi8(POPCNT_LOOKUP_AVX, _mm256_and_si256(v, POPCNT_MASK1_INT8_AVX)); __m256i hi = _mm256_shuffle_epi8( POPCNT_LOOKUP_AVX, _mm256_and_si256(_mm256_srli_epi32(v, 4), POPCNT_MASK1_INT8_AVX)); return _mm256_add_epi8(lo, hi); #endif // __AVX512VL__ && __AVX512BITALG__ } static inline __m256i VerticalPopCount_INT16_V256(__m256i v) { #if defined(__AVX512VL__) && defined(__AVX512BITALG__) return _mm256_popcnt_epi16(v); #else __m256i total = VerticalPopCount_INT8_V256(v); return _mm256_add_epi16(_mm256_srli_epi16(total, 8), _mm256_and_si256(total, POPCNT_MASK2_INT16_AVX)); #endif // __AVX512VL__ && __AVX512BITALG__ } static inline __m256i VerticalPopCount_INT32_V256(__m256i v) { #if defined(__AVX512VL__) && defined(__AVX512VPOPCNTDQ__) return _mm256_popcnt_epi32(v); #else __m256i total = _mm256_madd_epi16(VerticalPopCount_INT8_V256(v), POPCNT_MASK1_INT16_AVX); return _mm256_add_epi32(_mm256_srli_epi32(total, 8), _mm256_and_si256(total, POPCNT_MASK1_INT32_AVX)); #endif // __AVX512VL__ && __AVX512VPOPCNTDQ__ } static inline __m256i VerticalPopCount_INT64_V256(__m256i v) { #if defined(__AVX512VL__) && defined(__AVX512VPOPCNTDQ__) return _mm256_popcnt_epi64(v); #else return _mm256_sad_epu8(VerticalPopCount_INT8_V256(v), POPCNT_ZERO_AVX); #endif 
// __AVX512VL__ && __AVX512VPOPCNTDQ__ } static inline int16_t HorizontalMax_UINT8_V256(__m256i v) { v = _mm256_max_epu8(v, _mm256_shuffle_epi32(v, _MM_SHUFFLE(3, 2, 3, 2))); v = _mm256_max_epu8(v, _mm256_shuffle_epi32(v, _MM_SHUFFLE(1, 1, 1, 1))); v = _mm256_max_epu8(v, _mm256_shufflelo_epi16(v, _MM_SHUFFLE(1, 1, 1, 1))); __m128i x = _mm_max_epu8(_mm256_castsi256_si128(v), _mm256_extractf128_si256(v, 1)); x = _mm_max_epu8(x, _mm_srli_epi16(x, 8)); return static_cast(_mm_cvtsi128_si32(x)); } static inline int32_t HorizontalAdd_INT32_V256(__m256i v) { __m256i x1 = _mm256_hadd_epi32(v, v); __m256i x2 = _mm256_hadd_epi32(x1, x1); __m128i x3 = _mm256_extractf128_si256(x2, 1); __m128i x4 = _mm_add_epi32(_mm256_castsi256_si128(x2), x3); return _mm_cvtsi128_si32(x4); } static inline int64_t HorizontalAdd_INT64_V256(__m256i v) { __m256i x1 = _mm256_shuffle_epi32(v, _MM_SHUFFLE(1, 0, 3, 2)); __m256i x2 = _mm256_add_epi64(v, x1); __m128i x3 = _mm256_extractf128_si256(x2, 1); __m128i x4 = _mm_add_epi64(_mm256_extractf128_si256(x2, 0), x3); return _mm_cvtsi128_si64(x4); } #endif // __AVX2__ #if defined(__AVX512F__) static inline float HorizontalMax_FP32_V512(__m512 v) { __m256 low = _mm512_castps512_ps256(v); __m256 high = _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(v), 1)); return HorizontalMax_FP32_V256(_mm256_max_ps(low, high)); } static inline float HorizontalAdd_FP32_V512(__m512 v) { __m256 low = _mm512_castps512_ps256(v); __m256 high = _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(v), 1)); return HorizontalAdd_FP32_V256(_mm256_add_ps(low, high)); } #endif // __AVX512F__ #if defined(__AVX512FP16__) static inline float HorizontalMax_FP16_V512(__m512h v) { __m512 low = _mm512_cvtxph_ps(_mm512_castph512_ph256(v)); __m512 high = _mm512_cvtxph_ps( _mm256_castpd_ph(_mm512_extractf64x4_pd(_mm512_castph_pd(v), 1))); return HorizontalMax_FP32_V512(_mm512_max_ps(low, high)); } static inline float HorizontalAdd_FP16_V512(__m512h v) { __m512 low = 
_mm512_cvtxph_ps(_mm512_castph512_ph256(v)); __m512 high = _mm512_cvtxph_ps( _mm256_castpd_ph(_mm512_extractf64x4_pd(_mm512_castph_pd(v), 1))); return HorizontalAdd_FP32_V512(_mm512_add_ps(low, high)); } #endif // __AVX512FP16__ } // namespace ailego } // namespace zvec ================================================ FILE: src/ailego/math/mips_euclidean_distance_matrix.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include #include #include #include "distance_utility.h" namespace zvec { namespace ailego { //-------------------------------------------------- // Dense //-------------------------------------------------- /*! Compute the Mips SphericalInjection Squared Euclidean Distance with the two * vectors's InnerProduct and each squared l2-normlized value, and the e2 is * 1.0 / max_squared_l2_norm */ static float inline ComputeSphericalInjection(double ip, double u2, double v2, double e2) { if (e2 == 0.0) { // Implies *localized* spherical injection. return static_cast(2.0 - 2.0 * ip / std::max(u2, v2)); } auto v = (1.0 - e2 * u2) * (1.0 - e2 * v2); auto score = v > 0.0 ? (1.0 - e2 * ip - std::sqrt(v)) : (1.0 - e2 * ip); return static_cast(score * 2.0); } /*! Mips Squared Euclidean Distance Matrix */ template struct MipsSquaredEuclideanDistanceMatrix; /*! 
Mips Squared Euclidean Distance Matrix (M=1, N=1) */
// Single-pair kernel: `p` and `q` each point to one vector of `dim`
// elements, and one distance is written to `*out`.
// NOTE(review): every `<...>` token (template parameter lists/arguments,
// cast targets, std::array / std::remove_cv arguments) is missing in this
// copy of the file; the code below is reproduced exactly as found.
template struct MipsSquaredEuclideanDistanceMatrix {
  //! Type of value
  using ValueType = typename std::remove_cv::type;

  // Compute the distance between matrix and query by SphericalInjection
  // Accumulates, in float, the inner product of p and q (`sum`) and their
  // squared L2 norms (`u2` for p, `v2` for q), then converts them with
  // ComputeSphericalInjection.
  static inline void Compute(const ValueType *p, const ValueType *q,
                             size_t dim, float e2, float *out) {
    ailego_assert(p && q && dim && out);
    float sum = 0.0;
    float u2 = 0.0;
    float v2 = 0.0;
    for (size_t i = 0; i < dim; ++i) {
      u2 += p[i] * p[i];
      v2 += q[i] * q[i];
      sum += static_cast(p[i] * q[i]);
    }
    *out = ComputeSphericalInjection(sum, u2, v2, e2);
  }

  // Compute the distance between matrix and query by RepeatedQuadraticInjection
  // Starts from the squared Euclidean distance between p and q, scales it and
  // both squared norms by e2, then runs `m` rounds that each add
  // (u2 - v2)^2 and then square u2 and v2.
  static inline void Compute(const ValueType *p, const ValueType *q,
                             size_t dim, size_t m, float e2, float *out) {
    ailego_assert(p && q && dim && out);
    float sum = 0.0;
    float u2 = 0.0;
    float v2 = 0.0;
    for (size_t i = 0; i < dim; ++i) {
      u2 += p[i] * p[i];
      v2 += q[i] * q[i];
      sum += MathHelper::SquaredDifference(p[i], q[i]);
    }
    sum *= e2;
    u2 *= e2;
    v2 *= e2;
    for (size_t i = 0; i < m; ++i) {
      // Uses the norms from before this round's squaring.
      sum += (u2 - v2) * (u2 - v2);
      u2 = u2 * u2;
      v2 = v2 * v2;
    }
    *out = sum;
  }
};

// Explicit specialization with ValueType = uint8_t (its template arguments
// are missing in this copy). Only declarations appear here; the member
// functions are defined out of line, outside this section.
template <> struct MipsSquaredEuclideanDistanceMatrix {
  //! Type of value
  using ValueType = uint8_t;

  // Compute the distance between matrix and query by SphericalInjection
  static void Compute(const ValueType *p, const ValueType *q, size_t dim,
                      float e2, float *out);

  // Compute the distance between matrix and query by RepeatedQuadraticInjection
  static void Compute(const ValueType *p, const ValueType *q, size_t dim,
                      size_t m, float e2, float *out);
};

// Explicit specialization with ValueType = int8_t (template arguments
// missing in this copy); declarations only, defined out of line.
template <> struct MipsSquaredEuclideanDistanceMatrix {
  //! Type of value
  using ValueType = int8_t;

  // Compute the distance between matrix and query by SphericalInjection
  static void Compute(const ValueType *p, const ValueType *q, size_t dim,
                      float e2, float *out);

  // Compute the distance between matrix and query by RepeatedQuadraticInjection
  static void Compute(const ValueType *p, const ValueType *q, size_t dim,
                      size_t m, float e2, float *out);
};

// Explicit specialization with ValueType = Float16 (template arguments
// missing in this copy); declarations only, defined out of line.
template <> struct MipsSquaredEuclideanDistanceMatrix {
  //! Type of value
  using ValueType = Float16;

  // Compute the distance between matrix and query by SphericalInjection
  static void Compute(const ValueType *p, const ValueType *q, size_t dim,
                      float e2, float *out);

  // Compute the distance between matrix and query by RepeatedQuadraticInjection
  static void Compute(const ValueType *p, const ValueType *q, size_t dim,
                      size_t m, float e2, float *out);
};

// Explicit specialization with ValueType = float (template arguments
// missing in this copy); declarations only, defined out of line.
template <> struct MipsSquaredEuclideanDistanceMatrix {
  //! Type of value
  using ValueType = float;

  // Compute the distance between matrix and query by SphericalInjection
  static void Compute(const ValueType *p, const ValueType *q, size_t dim,
                      float e2, float *out);

  // Compute the distance between matrix and query by RepeatedQuadraticInjection
  static void Compute(const ValueType *p, const ValueType *q, size_t dim,
                      size_t m, float e2, float *out);
};

/*! Mips Squared Euclidean Distance Matrix (M >= 2, N >= 2) */
// Batched kernel: `p` holds M vectors and `q` holds N vectors, stored
// element-interleaved (element k of p-vector i is p[k * M + i], of q-vector
// j is q[k * N + j]). The distance for pair (i, j) is written to
// out[j * M + i], so `out` needs room for M * N floats.
template struct MipsSquaredEuclideanDistanceMatrix<
    T, M, N, typename std::enable_if= 2 && N >= 2>::type> {
  //! Type of value
  using ValueType = typename std::remove_cv::type;

  // Compute the distance between matrix and query by SphericalInjection
  static inline void Compute(const ValueType *p, const ValueType *q,
                             size_t dim, float e2, float *out) {
    ailego_assert(p && q && dim && out);
    // Early-out in case ailego_assert is compiled out: the first pass below
    // unconditionally reads element 0 of every vector.
    if (dim == 0) {
      return;
    }
    // u2[i] / v2[j]: running squared L2 norms of p-vector i / q-vector j.
    std::array u2;
    std::array v2;

    // First element initializes the accumulators (out is not pre-zeroed).
    for (size_t i = 0; i < M; ++i) {
      const ValueType p_val = p[i];
      u2[i] = static_cast(p_val * p_val);
      float *r = out + i;
      for (size_t j = 0; j < N; ++j) {
        *r = static_cast(p_val * q[j]);
        r += M;
      }
    }
    for (size_t i = 0; i < N; ++i) {
      v2[i] = static_cast(q[i] * q[i]);
    }
    p += M;
    q += N;

    // Remaining elements accumulate inner products into out.
    for (size_t k = 1; k < dim; ++k) {
      for (size_t i = 0; i < M; ++i) {
        const ValueType p_val = p[i];
        u2[i] += static_cast(p_val * p_val);
        float *r = out + i;
        for (size_t j = 0; j < N; ++j) {
          *r += static_cast(p_val * q[j]);
          r += M;
        }
      }
      for (size_t i = 0; i < N; ++i) {
        v2[i] += static_cast(q[i] * q[i]);
      }
      p += M;
      q += N;
    }

    // Compute the injection
    for (size_t i = 0; i < M; ++i) {
      float *r = out + i;
      const float u2_val = u2[i];
      for (size_t j = 0; j < N; ++j) {
        *r = ComputeSphericalInjection(*r, u2_val, v2[j], e2);
        r += M;
      }
    }
  }

  // Compute the distance between matrix and query by RepeatedQuadraticInjection
  static inline void Compute(const ValueType *p, const ValueType *q,
                             size_t dim, size_t m, float e2, float *out) {
    ailego_assert(p && q && dim && out);
    if (dim == 0) {
      return;
    }
    std::array u2;
    std::array v2;

    // First element: initialize squared differences and squared norms.
    for (size_t i = 0; i < M; ++i) {
      const ValueType p_val = p[i];
      u2[i] = static_cast(p_val * p_val);
      float *r = out + i;
      for (size_t j = 0; j < N; ++j) {
        *r = MathHelper::SquaredDifference(p_val, q[j]);
        r += M;
      }
    }
    for (size_t i = 0; i < N; ++i) {
      v2[i] = static_cast(q[i] * q[i]);
    }
    p += M;
    q += N;

    // Remaining elements: accumulate.
    for (size_t k = 1; k < dim; ++k) {
      for (size_t i = 0; i < M; ++i) {
        const ValueType p_val = p[i];
        u2[i] += static_cast(p_val * p_val);
        float *r = out + i;
        for (size_t j = 0; j < N; ++j) {
          *r += MathHelper::SquaredDifference(p_val, q[j]);
          r += M;
        }
      }
      for (size_t i = 0; i < N; ++i) {
        v2[i] += static_cast(q[i] * q[i]);
      }
      p += M;
      q += N;
    }

    // Compute the injections
    // Scale all M * N squared distances (out walked linearly) and every
    // squared norm by e2.
    float *r = out;
    for (size_t i = 0; i < M; ++i) {
      u2[i] *= e2;
      for (size_t j = 0; j < N; ++j) {
        (*r++) *= e2;
      }
    }
    for (size_t i = 0; i < N; ++i) {
      v2[i] *= e2;
    }
    // Each of the m rounds adds (u2 - v2)^2 using the pre-squaring norms,
    // then squares the norms, as in the single-pair kernel.
    for (size_t k = 0; k < m; ++k) {
      for (size_t i = 0; i < M; ++i) {
        r = out + i;
        float u2_val = u2[i];
        u2[i] = u2_val * u2_val;
        for (size_t j = 0; j < N; ++j) {
          *r += (u2_val - v2[j]) * (u2_val - v2[j]);
          r += M;
        }
      }
      for (size_t i = 0; i < N; ++i) {
        v2[i] = v2[i] * v2[i];
      }
    }
  }
};

/*! Mips Squared Euclidean Distance Matrix (N=1) */
// M vectors in `p` (element-interleaved, p[k * M + i]) against the single
// vector `q`; the distance for p-vector i is written to out[i].
template struct MipsSquaredEuclideanDistanceMatrix<
    T, M, 1, typename std::enable_if= 2>::type> {
  //! Type of value
  using ValueType = typename std::remove_cv::type;

  // Compute the distance between matrix and query by SphericalInjection
  static inline void Compute(const ValueType *p, const ValueType *q,
                             size_t dim, float e2, float *out) {
    ailego_assert(p && q && dim && out);
    const ValueType *q_end = q + dim;
    if (q == q_end) {
      return;
    }
    // u2[i]: squared norm of p-vector i; v2: squared norm of q.
    std::array u2;
    ValueType q_val = *q++;
    float v2 = static_cast(q_val * q_val);
    for (size_t i = 0; i < M; ++i) {
      u2[i] = static_cast(p[i] * p[i]);
      out[i] = static_cast(p[i] * q_val);
    }
    p += M;

    while (q != q_end) {
      q_val = *q++;
      v2 += static_cast(q_val * q_val);
      for (size_t i = 0; i < M; ++i) {
        u2[i] += static_cast(p[i] * p[i]);
        out[i] += static_cast(p[i] * q_val);
      }
      p += M;
    }

    // Compute the injection
    for (size_t i = 0; i < M; ++i) {
      out[i] = ComputeSphericalInjection(out[i], u2[i], v2, e2);
    }
  }

  // Compute the distance between matrix and query by RepeatedQuadraticInjection
  static inline void Compute(const ValueType *p, const ValueType *q,
                             size_t dim, size_t m, float e2, float *out) {
    ailego_assert(p && q && dim && out);
    const ValueType *q_end = q + dim;
    if (q == q_end) {
      return;
    }
    std::array u2;
    ValueType q_val = *q++;
    float v2 = static_cast(q_val * q_val);
    for (size_t i = 0; i < M; ++i) {
      u2[i] = static_cast(p[i] * p[i]);
      out[i] = MathHelper::SquaredDifference(p[i], q_val);
    }
    p += M;

    while (q != q_end) {
      q_val = *q++;
      v2 += static_cast(q_val * q_val);
      for (size_t i = 0; i < M; ++i) {
        u2[i] += static_cast(p[i] * p[i]);
        out[i] += MathHelper::SquaredDifference(p[i], q_val);
      }
      p += M;
    }

    // Compute the injections
    for (size_t i = 0; i < M; ++i) {
      out[i] *= e2;
      u2[i] *= e2;
    }
    v2 *= e2;
    // m rounds: add (u2 - v2)^2 with the pre-squaring norms, then square.
    for (size_t k = 0; k < m; ++k) {
      for (size_t i = 0; i < M; ++i) {
        const float u_val = u2[i];
        u2[i] = u_val * u_val;
        out[i] += (u_val - v2) * (u_val - v2);
      }
      v2 = v2 * v2;
    }
  }
};

/*! Mips Squared Euclidean Distance Matrix (INT8, M >=2, N >= 2) */
// int8 batched kernel: elements are consumed four at a time as packed
// uint32_t words, so `dim` must be a multiple of 4 (asserted) and is turned
// into a word count. Word k of p-vector i is p_it[k * M + i] and word k of
// q-vector j is q_it[k * N + j]; results go to out[j * M + i].
// NOTE(review): reading int8_t storage through uint32_t pointers assumes
// 4-byte alignment and steps outside strict aliasing - confirm the callers
// and build flags make this safe.
template struct MipsSquaredEuclideanDistanceMatrix<
    int8_t, M, N, typename std::enable_if= 2 && N >= 2>::type> {
  //! Type of value
  using ValueType = int8_t;

  // Compute the distance between matrix and query by SphericalInjection
  static inline void Compute(const ValueType *p, const ValueType *q,
                             size_t dim, float e2, float *out) {
    ailego_assert(p && q && dim && !(dim & 3) && out);
    // From here on, dim counts packed 4-element words.
    dim >>= 2;
    if (dim == 0) {
      return;
    }
    std::array u2;
    std::array v2;
    const uint32_t *p_it = reinterpret_cast(p);
    const uint32_t *q_it = reinterpret_cast(q);

    // First word initializes the accumulators.
    for (size_t i = 0; i < M; ++i) {
      const uint32_t p_val = p_it[i];
      u2[i] = Squared(p_val);
      float *r = out + i;
      for (size_t j = 0; j < N; ++j) {
        *r = FusedMultiplyAdd(p_val, q_it[j]);
        r += M;
      }
    }
    for (size_t i = 0; i < N; ++i) {
      v2[i] = Squared(q_it[i]);
    }
    p_it += M;
    q_it += N;

    // Remaining words accumulate.
    for (size_t k = 1; k < dim; ++k) {
      for (size_t i = 0; i < M; ++i) {
        const uint32_t p_val = p_it[i];
        u2[i] += Squared(p_val);
        float *r = out + i;
        for (size_t j = 0; j < N; ++j) {
          *r += FusedMultiplyAdd(p_val, q_it[j]);
          r += M;
        }
      }
      for (size_t i = 0; i < N; ++i) {
        v2[i] += Squared(q_it[i]);
      }
      p_it += M;
      q_it += N;
    }

    // Compute the injection
    for (size_t i = 0; i < M; ++i) {
      float *r = out + i;
      const float u2_val = u2[i];
      for (size_t j = 0; j < N; ++j) {
        *r = ComputeSphericalInjection(*r, u2_val, v2[j], e2);
        r += M;
      }
    }
  }

  // Compute the distance between matrix and query by RepeatedQuadraticInjection
  static inline void Compute(const ValueType *p, const ValueType *q,
                             size_t dim, size_t m, float e2, float *out) {
    ailego_assert(p && q && dim && !(dim & 3) && out);
    dim >>= 2;
    if (dim == 0) {
      return;
    }
    std::array u2;
    std::array v2;
    const uint32_t *p_it = reinterpret_cast(p);
    const uint32_t *q_it = reinterpret_cast(q);

    // First word: initialize squared differences and squared norms.
    for (size_t i = 0; i < M; ++i) {
      const uint32_t p_val = p_it[i];
      u2[i] = Squared(p_val);
      float *r = out + i;
      for (size_t j = 0; j < N; ++j) {
        *r = SquaredDifference(p_val, q_it[j]);
        r += M;
      }
    }
    for (size_t i = 0; i < N; ++i) {
      v2[i] = Squared(q_it[i]);
    }
    p_it += M;
    q_it += N;

    // Remaining words: accumulate.
    for (size_t k = 1; k < dim; ++k) {
      for (size_t i = 0; i < M; ++i) {
        const uint32_t p_val = p_it[i];
        u2[i] += Squared(p_val);
        float *r = out + i;
        for (size_t j = 0; j < N; ++j) {
          *r += SquaredDifference(p_val, q_it[j]);
          r += M;
        }
      }
      for (size_t i = 0; i < N; ++i) {
        v2[i] += Squared(q_it[i]);
      }
      p_it += M;
      q_it += N;
    }

    // Compute the injections
    // Scale all M * N squared distances and every squared norm by e2.
    float *r = out;
    for (size_t i = 0; i < M; ++i) {
      u2[i] *= e2;
      for (size_t j = 0; j < N; ++j) {
        (*r++) *= e2;
      }
    }
    for (size_t i = 0; i < N; ++i) {
      v2[i] *= e2;
    }
    for (size_t k = 0; k < m; ++k) {
      for (size_t i = 0; i < M; ++i) {
        r = out + i;
        float u2_val = u2[i];
        u2[i] = u2_val * u2_val;
        for (size_t j = 0; j < N; ++j) {
          *r += (u2_val - v2[j]) * (u2_val - v2[j]);
          r += M;
        }
      }
      for (size_t i = 0; i < N; ++i) {
        v2[i] = v2[i] * v2[i];
      }
    }
  }

 protected:
  //! Calculate Fused-Multiply-Add
  // Despite the name, returns the dot product of the four signed bytes
  // packed in `lhs` and `rhs` (products summed as int, then converted to
  // float). Both words are unpacked the same way and all four lanes are
  // summed, so byte order does not affect the result.
  static inline float FusedMultiplyAdd(uint32_t lhs, uint32_t rhs) {
    return static_cast((int8_t)(lhs >> 0) * (int8_t)(rhs >> 0) +
                       (int8_t)(lhs >> 8) * (int8_t)(rhs >> 8) +
                       (int8_t)(lhs >> 16) * (int8_t)(rhs >> 16) +
                       (int8_t)(lhs >> 24) * (int8_t)(rhs >> 24));
  }

  //! Calculate the squared difference
  // Sum over the four packed signed bytes of
  // MathHelper::SquaredDifference(lhs lane, rhs lane).
  static inline float SquaredDifference(uint32_t lhs, uint32_t rhs) {
    return static_cast(MathHelper::SquaredDifference(
                           (int8_t)(lhs >> 0), (int8_t)(rhs >> 0)) +
                       MathHelper::SquaredDifference(
                           (int8_t)(lhs >> 8), (int8_t)(rhs >> 8)) +
                       MathHelper::SquaredDifference(
                           (int8_t)(lhs >> 16), (int8_t)(rhs >> 16)) +
                       MathHelper::SquaredDifference(
                           (int8_t)(lhs >> 24), (int8_t)(rhs >> 24)));
  }

  //! Calculate sum of squared values
  // Sum of the squares of the four packed signed bytes of `v`.
  static inline float Squared(uint32_t v) {
    return static_cast((int8_t)(v >> 0) * (int8_t)(v >> 0) +
                       (int8_t)(v >> 8) * (int8_t)(v >> 8) +
                       (int8_t)(v >> 16) * (int8_t)(v >> 16) +
                       (int8_t)(v >> 24) * (int8_t)(v >> 24));
  }
};

/*! Mips Squared Euclidean Distance Matrix (INT8, N=1) */
// int8 kernel for M packed vectors against one query: `p_it` advances one
// word per p-vector per step (word-interleaved layout), `q_it` one word per
// step. `dim` must be a multiple of 4 (asserted); out[i] receives the
// distance for p-vector i.
// NOTE(review): same alignment / aliasing assumption as the batched int8
// kernel - confirm.
template struct MipsSquaredEuclideanDistanceMatrix<
    int8_t, M, 1, typename std::enable_if= 2>::type> {
  //! Type of value
  using ValueType = int8_t;

  // Compute the distance between matrix and query by SphericalInjection
  static inline void Compute(const ValueType *p, const ValueType *q,
                             size_t dim, float e2, float *out) {
    ailego_assert(p && q && dim && !(dim & 3) && out);
    const uint32_t *p_it = reinterpret_cast(p);
    const uint32_t *q_it = reinterpret_cast(q);
    const uint32_t *q_end = q_it + (dim >> 2);
    if (q_it == q_end) {
      return;
    }
    std::array u2;
    uint32_t q_val = *q_it++;
    float v2 = Squared(q_val);
    for (size_t i = 0; i < M; ++i) {
      const uint32_t p_val = *p_it++;
      u2[i] = Squared(p_val);
      out[i] = FusedMultiplyAdd(p_val, q_val);
    }

    while (q_it != q_end) {
      q_val = *q_it++;
      v2 += Squared(q_val);
      for (size_t i = 0; i < M; ++i) {
        const uint32_t p_val = *p_it++;
        u2[i] += Squared(p_val);
        out[i] += FusedMultiplyAdd(p_val, q_val);
      }
    }

    // Compute the injection
    for (size_t i = 0; i < M; ++i) {
      out[i] = ComputeSphericalInjection(out[i], u2[i], v2, e2);
    }
  }

  // Compute the distance between matrix and query by RepeatedQuadraticInjection
  static inline void Compute(const ValueType *p, const ValueType *q,
                             size_t dim, size_t m, float e2, float *out) {
    ailego_assert(p && q && dim && !(dim & 3) && out);
    const uint32_t *p_it = reinterpret_cast(p);
    const uint32_t *q_it = reinterpret_cast(q);
    const uint32_t *q_end = q_it + (dim >> 2);
    if (q_it == q_end) {
      return;
    }
    std::array u2;
    uint32_t q_val = *q_it++;
    float v2 = Squared(q_val);
    for (size_t i = 0; i < M; ++i) {
      const uint32_t p_val = *p_it++;
      u2[i] = Squared(p_val);
      out[i] = SquaredDifference(p_val, q_val);
    }

    while (q_it != q_end) {
      q_val = *q_it++;
      v2 += Squared(q_val);
      for (size_t i = 0; i < M; ++i) {
        const uint32_t p_val = *p_it++;
        u2[i] += Squared(p_val);
        out[i] += SquaredDifference(p_val, q_val);
      }
    }

    // Compute the injections
    for (size_t i = 0; i < M; ++i) {
      out[i] *= e2;
      u2[i] *= e2;
    }
    v2 *= e2;
    for (size_t k = 0; k < m; ++k) {
      for (size_t i = 0; i < M; ++i) {
        const float u_val = u2[i];
        u2[i] = u_val * u_val;
        out[i] += (u_val - v2) * (u_val - v2);
      }
      v2 = v2 * v2;
    }
  }

 protected:
  //! Calculate Fused-Multiply-Add
  // Despite the name, returns the dot product of the four signed bytes
  // packed in `lhs` and `rhs` (no accumulator argument).
  static inline float FusedMultiplyAdd(uint32_t lhs, uint32_t rhs) {
    return static_cast((int8_t)(lhs >> 0) * (int8_t)(rhs >> 0) +
                       (int8_t)(lhs >> 8) * (int8_t)(rhs >> 8) +
                       (int8_t)(lhs >> 16) * (int8_t)(rhs >> 16) +
                       (int8_t)(lhs >> 24) * (int8_t)(rhs >> 24));
  }

  //! Calculate the squared difference
  // Sum over the four packed signed bytes of
  // MathHelper::SquaredDifference(lhs lane, rhs lane).
  static inline float SquaredDifference(uint32_t lhs, uint32_t rhs) {
    return static_cast(MathHelper::SquaredDifference(
                           (int8_t)(lhs >> 0), (int8_t)(rhs >> 0)) +
                       MathHelper::SquaredDifference(
                           (int8_t)(lhs >> 8), (int8_t)(rhs >> 8)) +
                       MathHelper::SquaredDifference(
                           (int8_t)(lhs >> 16), (int8_t)(rhs >> 16)) +
                       MathHelper::SquaredDifference(
                           (int8_t)(lhs >> 24), (int8_t)(rhs >> 24)));
  }

  //! Calculate sum of squared values
  // Sum of the squares of the four packed signed bytes of `v`.
  static inline float Squared(uint32_t v) {
    return static_cast((int8_t)(v >> 0) * (int8_t)(v >> 0) +
                       (int8_t)(v >> 8) * (int8_t)(v >> 8) +
                       (int8_t)(v >> 16) * (int8_t)(v >> 16) +
                       (int8_t)(v >> 24) * (int8_t)(v >> 24));
  }
};

/*!
Mips Squared Euclidean Distance Matrix (INT4, M >=2, N >= 2) */
// NOTE(review): text between '<' and '>' (template parameter/argument lists,
// cast target types, std::array arguments) appears to be missing from this
// copy -- e.g. `template struct`, `static_cast(`, `reinterpret_cast(p)`,
// `std::array u2;`. Restore from upstream before compiling.
//
// Specialization scoring M packed int4 vectors against N packed int4
// queries. Each uint32_t word holds eight 4-bit lanes, so `dim` (counted in
// lanes) must be a non-zero multiple of 8 (see ailego_assert). Inputs are
// consumed word-interleaved: word k of vector i is p_it[k * M + i] and word k
// of query j is q_it[k * N + j]. The score of (vector i, query j) is written
// to out[j * M + i].
template struct MipsSquaredEuclideanDistanceMatrix<
    uint8_t, M, N, typename std::enable_if= 2 && N >= 2>::type> {
  //! Type of value
  using ValueType = uint8_t;

  // Compute the distance between matrix and query by SphericalInjection
  //
  // Accumulates inner products into `out`, squared norms into u2 (matrix) and
  // v2 (queries), then replaces each out entry with
  // ComputeSphericalInjection(ip, u2[i], v2[j], e2). Returns without writing
  // `out` when dim / 8 == 0.
  static inline void Compute(const ValueType *p, const ValueType *q,
                             size_t dim, float e2, float *out) {
    ailego_assert(p && q && dim && !(dim & 7) && out);

    dim >>= 3;  // number of uint32_t words per vector
    if (dim == 0) {
      return;
    }

    std::array u2;
    std::array v2;
    const uint32_t *p_it = reinterpret_cast(p);
    const uint32_t *q_it = reinterpret_cast(q);

    // First word: initialize the accumulators.
    for (size_t i = 0; i < M; ++i) {
      const uint32_t p_val = p_it[i];
      u2[i] = Squared(p_val);
      float *r = out + i;
      for (size_t j = 0; j < N; ++j) {
        *r = FusedMultiplyAdd(p_val, q_it[j]);
        r += M;
      }
    }
    for (size_t i = 0; i < N; ++i) {
      v2[i] = Squared(q_it[i]);
    }
    p_it += M;
    q_it += N;

    // Remaining words: accumulate.
    for (size_t k = 1; k < dim; ++k) {
      for (size_t i = 0; i < M; ++i) {
        const uint32_t p_val = p_it[i];
        u2[i] += Squared(p_val);
        float *r = out + i;
        for (size_t j = 0; j < N; ++j) {
          *r += FusedMultiplyAdd(p_val, q_it[j]);
          r += M;
        }
      }
      for (size_t i = 0; i < N; ++i) {
        v2[i] += Squared(q_it[i]);
      }
      p_it += M;
      q_it += N;
    }

    // Compute the injection
    for (size_t i = 0; i < M; ++i) {
      float *r = out + i;
      const float u2_val = u2[i];
      for (size_t j = 0; j < N; ++j) {
        *r = ComputeSphericalInjection(*r, u2_val, v2[j], e2);
        r += M;
      }
    }
  }

  // Compute the distance between matrix and query by RepeatedQuadraticInjection
  //
  // With a_i = e2 * ||p_i||^2 and b_j = e2 * ||q_j||^2 this writes
  //   out[j * M + i] = e2 * ||p_i - q_j||^2
  //                    + sum_{k=0}^{m-1} (a_i^(2^k) - b_j^(2^k))^2.
  // Returns without writing `out` when dim / 8 == 0.
  static inline void Compute(const ValueType *p, const ValueType *q,
                             size_t dim, size_t m, float e2, float *out) {
    ailego_assert(p && q && dim && !(dim & 7) && out);

    dim >>= 3;  // number of uint32_t words per vector
    if (dim == 0) {
      return;
    }

    std::array u2;
    std::array v2;
    const uint32_t *p_it = reinterpret_cast(p);
    const uint32_t *q_it = reinterpret_cast(q);

    for (size_t i = 0; i < M; ++i) {
      const uint32_t p_val = p_it[i];
      u2[i] = Squared(p_val);
      float *r = out + i;
      for (size_t j = 0; j < N; ++j) {
        *r = SquaredDifference(p_val, q_it[j]);
        r += M;
      }
    }
    for (size_t i = 0; i < N; ++i) {
      v2[i] = Squared(q_it[i]);
    }
    p_it += M;
    q_it += N;

    for (size_t k = 1; k < dim; ++k) {
      for (size_t i = 0; i < M; ++i) {
        const uint32_t p_val = p_it[i];
        u2[i] += Squared(p_val);
        float *r = out + i;
        for (size_t j = 0; j < N; ++j) {
          *r += SquaredDifference(p_val, q_it[j]);
          r += M;
        }
      }
      for (size_t i = 0; i < N; ++i) {
        v2[i] += Squared(q_it[i]);
      }
      p_it += M;
      q_it += N;
    }

    // Compute the injections
    // Scale all M * N outputs and the norms by e2. The first loop walks `out`
    // linearly and touches every entry exactly once.
    float *r = out;
    for (size_t i = 0; i < M; ++i) {
      u2[i] *= e2;
      for (size_t j = 0; j < N; ++j) {
        (*r++) *= e2;
      }
    }
    for (size_t i = 0; i < N; ++i) {
      v2[i] *= e2;
    }

    // Each round adds (a - b)^2, then squares a and b for the next round.
    for (size_t k = 0; k < m; ++k) {
      for (size_t i = 0; i < M; ++i) {
        r = out + i;
        float u2_val = u2[i];
        u2[i] = u2_val * u2_val;
        for (size_t j = 0; j < N; ++j) {
          *r += (u2_val - v2[j]) * (u2_val - v2[j]);
          r += M;
        }
      }
      for (size_t i = 0; i < N; ++i) {
        v2[i] = v2[i] * v2[i];
      }
    }
  }

 protected:
  //! Calculate Fused-Multiply-Add
  // Sums eight table lookups, one per pair of corresponding 4-bit lanes; each
  // index is (lhs nibble << 4) | rhs nibble. Int4MulTable is not defined here
  // and presumably holds the 256 nibble-pair products -- confirm.
  static inline float FusedMultiplyAdd(uint32_t lhs, uint32_t rhs) {
    return static_cast(
        Int4MulTable[((lhs << 4) & 0xf0) | ((rhs >> 0) & 0xf)] +
        Int4MulTable[((lhs >> 0) & 0xf0) | ((rhs >> 4) & 0xf)] +
        Int4MulTable[((lhs >> 4) & 0xf0) | ((rhs >> 8) & 0xf)] +
        Int4MulTable[((lhs >> 8) & 0xf0) | ((rhs >> 12) & 0xf)] +
        Int4MulTable[((lhs >> 12) & 0xf0) | ((rhs >> 16) & 0xf)] +
        Int4MulTable[((lhs >> 16) & 0xf0) | ((rhs >> 20) & 0xf)] +
        Int4MulTable[((lhs >> 20) & 0xf0) | ((rhs >> 24) & 0xf)] +
        Int4MulTable[((lhs >> 24) & 0xf0) | ((rhs >> 28) & 0xf)]);
  }

  //! Calculate the squared difference
  // Same nibble pairing as FusedMultiplyAdd, but using Int4SquaredDiffTable
  // (not defined here; presumably nibble-pair squared differences).
  static inline float SquaredDifference(uint32_t lhs, uint32_t rhs) {
    return static_cast(
        Int4SquaredDiffTable[((lhs << 4) & 0xf0) | ((rhs >> 0) & 0xf)] +
        Int4SquaredDiffTable[((lhs >> 0) & 0xf0) | ((rhs >> 4) & 0xf)] +
        Int4SquaredDiffTable[((lhs >> 4) & 0xf0) | ((rhs >> 8) & 0xf)] +
        Int4SquaredDiffTable[((lhs >> 8) & 0xf0) | ((rhs >> 12) & 0xf)] +
        Int4SquaredDiffTable[((lhs >> 12) & 0xf0) | ((rhs >> 16) & 0xf)] +
        Int4SquaredDiffTable[((lhs >> 16) & 0xf0) | ((rhs >> 20) & 0xf)] +
        Int4SquaredDiffTable[((lhs >> 20) & 0xf0) | ((rhs >> 24) & 0xf)] +
        Int4SquaredDiffTable[((lhs >> 24) & 0xf0) | ((rhs >> 28) & 0xf)]);
  }

  //! Calculate sum of squared values
  // Splits each byte into its low and high nibble, sign-extends each to int8
  // via the shift pair, and sums their squares.
  static inline float Squared(uint32_t u) {
    float sum = 0.0f;
    for (size_t i = 0; i < 32; i += 8) {
      uint8_t v = (uint8_t)(u >> i);
      int8_t lo = (int8_t)(v << 4) >> 4;
      int8_t hi = (int8_t)(v & 0xf0) >> 4;
      sum += hi * hi + lo * lo;
    }
    return sum;
  }
};

/*! Mips Squared Euclidean Distance Matrix (INT4, N=1) */
// Specialization scoring M packed int4 vectors against a single query. Each
// uint32_t word holds eight 4-bit lanes, so `dim` (in lanes) must be a
// non-zero multiple of 8 (see ailego_assert). The matrix is consumed M words
// per query word: word k of vector i is read from uint32 index k * M + i.
// out[i] receives the score of vector i.
template struct MipsSquaredEuclideanDistanceMatrix<
    uint8_t, M, 1, typename std::enable_if= 2>::type> {
  //! Type of value
  using ValueType = uint8_t;

  // Compute the distance between matrix and query by SphericalInjection
  //
  // Accumulates <p_i, q>, ||p_i||^2 and ||q||^2, then maps each triple through
  // ComputeSphericalInjection. Returns without writing `out` when
  // dim >> 3 == 0.
  static inline void Compute(const ValueType *p, const ValueType *q,
                             size_t dim, float e2, float *out) {
    ailego_assert(p && q && dim && !(dim & 7) && out);

    const uint32_t *p_it = reinterpret_cast(p);
    const uint32_t *q_it = reinterpret_cast(q);
    const uint32_t *q_end = q_it + (dim >> 3);

    if (q_it == q_end) {
      return;
    }

    std::array u2;
    // The first query word initializes the accumulators...
    uint32_t q_val = *q_it++;
    float v2 = Squared(q_val);
    for (size_t i = 0; i < M; ++i) {
      const uint32_t p_val = *p_it++;
      u2[i] = Squared(p_val);
      out[i] = FusedMultiplyAdd(p_val, q_val);
    }

    // ...and every following word adds to them.
    while (q_it != q_end) {
      q_val = *q_it++;
      v2 += Squared(q_val);
      for (size_t i = 0; i < M; ++i) {
        const uint32_t p_val = *p_it++;
        u2[i] += Squared(p_val);
        out[i] += FusedMultiplyAdd(p_val, q_val);
      }
    }

    // Compute the injection
    for (size_t i = 0; i < M; ++i) {
      out[i] = ComputeSphericalInjection(out[i], u2[i], v2, e2);
    }
  }

  // Compute the distance between matrix and query by RepeatedQuadraticInjection
  //
  // With a_i = e2 * ||p_i||^2 and b = e2 * ||q||^2 this writes
  //   out[i] = e2 * ||p_i - q||^2 + sum_{k=0}^{m-1} (a_i^(2^k) - b^(2^k))^2.
  // Returns without writing `out` when dim >> 3 == 0.
  static inline void Compute(const ValueType *p, const ValueType *q,
                             size_t dim, size_t m, float e2, float *out) {
    ailego_assert(p && q && dim && !(dim & 7) && out);

    const uint32_t *p_it = reinterpret_cast(p);
    const uint32_t *q_it = reinterpret_cast(q);
    const uint32_t *q_end = q_it + (dim >> 3);

    if (q_it == q_end) {
      return;
    }

    std::array u2;
    uint32_t q_val = *q_it++;
    float v2 = Squared(q_val);
    for (size_t i = 0; i < M; ++i) {
      const uint32_t p_val = *p_it++;
      u2[i] = Squared(p_val);
      out[i] = SquaredDifference(p_val, q_val);
    }

    while (q_it != q_end) {
      q_val = *q_it++;
      v2 += Squared(q_val);
      for (size_t i = 0; i < M; ++i) {
        const uint32_t p_val = *p_it++;
        u2[i] += Squared(p_val);
        out[i] += SquaredDifference(p_val, q_val);
      }
    }

    // Compute the injections
    for (size_t i = 0; i < M; ++i) {
      out[i] *= e2;
      u2[i] *= e2;
    }
    v2 *= e2;

    // Each round adds (a - b)^2, then squares a and b for the next round.
    for (size_t k = 0; k < m; ++k) {
      for (size_t i = 0; i < M; ++i) {
        const float u_val = u2[i];
        u2[i] = u_val * u_val;
        out[i] += (u_val - v2) * (u_val - v2);
      }
      v2 = v2 * v2;
    }
  }

 protected:
  //! Calculate Fused-Multiply-Add
  // Sums Int4MulTable lookups indexed by (lhs nibble << 4) | rhs nibble for
  // the eight corresponding 4-bit lanes (table defined elsewhere).
  static inline float FusedMultiplyAdd(uint32_t lhs, uint32_t rhs) {
    return static_cast(
        Int4MulTable[((lhs << 4) & 0xf0) | ((rhs >> 0) & 0xf)] +
        Int4MulTable[((lhs >> 0) & 0xf0) | ((rhs >> 4) & 0xf)] +
        Int4MulTable[((lhs >> 4) & 0xf0) | ((rhs >> 8) & 0xf)] +
        Int4MulTable[((lhs >> 8) & 0xf0) | ((rhs >> 12) & 0xf)] +
        Int4MulTable[((lhs >> 12) & 0xf0) | ((rhs >> 16) & 0xf)] +
        Int4MulTable[((lhs >> 16) & 0xf0) | ((rhs >> 20) & 0xf)] +
        Int4MulTable[((lhs >> 20) & 0xf0) | ((rhs >> 24) & 0xf)] +
        Int4MulTable[((lhs >> 24) & 0xf0) | ((rhs >> 28) & 0xf)]);
  }

  //! Calculate the squared difference
  // Same nibble pairing, using Int4SquaredDiffTable (defined elsewhere).
  static inline float SquaredDifference(uint32_t lhs, uint32_t rhs) {
    return static_cast(
        Int4SquaredDiffTable[((lhs << 4) & 0xf0) | ((rhs >> 0) & 0xf)] +
        Int4SquaredDiffTable[((lhs >> 0) & 0xf0) | ((rhs >> 4) & 0xf)] +
        Int4SquaredDiffTable[((lhs >> 4) & 0xf0) | ((rhs >> 8) & 0xf)] +
        Int4SquaredDiffTable[((lhs >> 8) & 0xf0) | ((rhs >> 12) & 0xf)] +
        Int4SquaredDiffTable[((lhs >> 12) & 0xf0) | ((rhs >> 16) & 0xf)] +
        Int4SquaredDiffTable[((lhs >> 16) & 0xf0) | ((rhs >> 20) & 0xf)] +
        Int4SquaredDiffTable[((lhs >> 20) & 0xf0) | ((rhs >> 24) & 0xf)] +
        Int4SquaredDiffTable[((lhs >> 24) & 0xf0) | ((rhs >> 28) & 0xf)]);
  }

  //! Calculate sum of squared values
  // Sum of squares of the eight sign-extended 4-bit lanes of u.
  static inline float Squared(uint32_t u) {
    float sum = 0.0f;
    for (size_t i = 0; i < 32; i += 8) {
      uint8_t v = (uint8_t)(u >> i);
      int8_t lo = (int8_t)(v << 4) >> 4;
      int8_t hi = (int8_t)(v & 0xf0) >> 4;
      sum += hi * hi + lo * lo;
    }
    return sum;
  }
};

//--------------------------------------------------
// Sparse
//--------------------------------------------------

/*! Mips Squared Euclidean Sparse Distance Matrix */
// Scores one encoded sparse vector against another. Each encoded vector is a
// byte buffer whose layout, as implied by the offset arithmetic in Compute,
// is:
//   uint32_t   sparse_count               (total stored entries)
//   uint32_t   seg_count
//   uint32_t   seg_id[seg_count]
//   uint32_t   seg_vec_cnt[seg_count]     (entries per segment)
//   uint16_t   index[sparse_count]
//   ValueType  value[sparse_count]
// NOTE(review): these fields are read through reinterpret_cast'ed pointers;
// presumably the encoder guarantees adequate alignment -- confirm.
template struct MipsSquaredEuclideanSparseDistanceMatrix {
  //! Type of value
  using ValueType = typename std::remove_cv::type;

  // Inner product of two sparse runs belonging to the same segment. Declared
  // here; the generic definition follows the class, and a float
  // specialization is declared (but not defined) after it.
  static float ComputeInnerProductSparseInSegment(
      uint32_t m_sparse_count, const uint16_t *m_sparse_index,
      const ValueType *m_sparse_value, uint32_t q_sparse_count,
      const uint16_t *q_sparse_index, const ValueType *q_sparse_value);

  // Compute the distance between matrix and query by SphericalInjection
  //
  // Writes 0 when both vectors are empty and 2 when exactly one is empty.
  // Otherwise sums per-segment inner products over segments whose ids match,
  // computes both squared norms over all stored values, and writes
  // ComputeSphericalInjection(ip, l2_m, l2_q, 0.0f).
  static inline void Compute(const void *m_sparse_data_in,
                             const void *q_sparse_data_in, float *out) {
    ailego_assert(m_sparse_data_in && q_sparse_data_in && out);

    const uint8_t *m_sparse_data = reinterpret_cast(m_sparse_data_in);
    const uint8_t *q_sparse_data = reinterpret_cast(q_sparse_data_in);

    const uint32_t m_sparse_count = *reinterpret_cast(m_sparse_data);
    const uint32_t q_sparse_count = *reinterpret_cast(q_sparse_data);

    if (m_sparse_count == 0 && q_sparse_count == 0) {
      *out = 0;
      return;
    }

    if (m_sparse_count == 0 || q_sparse_count == 0) {
      *out = 2;
      return;
    }

    const uint32_t m_seg_count =
        *reinterpret_cast(m_sparse_data + sizeof(uint32_t));
    const uint32_t q_seg_count =
        *reinterpret_cast(q_sparse_data + sizeof(uint32_t));

    const uint32_t *m_seg_id = reinterpret_cast(
        m_sparse_data + 2 * sizeof(uint32_t));
    const uint32_t *q_seg_id = reinterpret_cast(
        q_sparse_data + 2 * sizeof(uint32_t));

    const uint32_t *m_seg_vec_cnt = reinterpret_cast(
        m_sparse_data + 2 * sizeof(uint32_t) +
        m_seg_count * sizeof(uint32_t));
    const uint32_t *q_seg_vec_cnt = reinterpret_cast(
        q_sparse_data + 2 * sizeof(uint32_t) +
        q_seg_count * sizeof(uint32_t));

    const uint16_t *m_sparse_index = reinterpret_cast(
        m_sparse_data + 2 * sizeof(uint32_t) +
        m_seg_count * 2 * sizeof(uint32_t));
    const uint16_t *q_sparse_index = reinterpret_cast(
        q_sparse_data + 2 * sizeof(uint32_t) +
        q_seg_count * 2 * sizeof(uint32_t));

    const ValueType *m_sparse_value = reinterpret_cast(
        m_sparse_data + 2 * sizeof(uint32_t) +
        m_seg_count * 2 * sizeof(uint32_t) +
        m_sparse_count * sizeof(uint16_t));
    const ValueType *q_sparse_value = reinterpret_cast(
        q_sparse_data + 2 * sizeof(uint32_t) +
        q_seg_count * 2 * sizeof(uint32_t) +
        q_sparse_count * sizeof(uint16_t));

    // Two-pointer merge over segment ids; m_count / q_count track the offset
    // of the current segment's entries in the index/value arrays.
    // NOTE(review): only correct if seg_id is sorted ascending in both
    // inputs -- presumably guaranteed by the encoder; confirm.
    float ip = 0.0f;
    size_t m_s = 0;
    size_t q_s = 0;
    size_t m_count = 0;
    size_t q_count = 0;
    while (m_s < m_seg_count && q_s < q_seg_count) {
      if (m_seg_id[m_s] == q_seg_id[q_s]) {
        ip += ComputeInnerProductSparseInSegment(
            m_seg_vec_cnt[m_s], m_sparse_index + m_count,
            m_sparse_value + m_count, q_seg_vec_cnt[q_s],
            q_sparse_index + q_count, q_sparse_value + q_count);

        m_count += m_seg_vec_cnt[m_s];
        q_count += q_seg_vec_cnt[q_s];
        ++m_s;
        ++q_s;
      } else if (m_seg_id[m_s] < q_seg_id[q_s]) {
        m_count += m_seg_vec_cnt[m_s];
        ++m_s;
      } else {
        q_count += q_seg_vec_cnt[q_s];
        ++q_s;
      }
    }

    // Norms cover every stored value, not only the overlapping segments.
    float l2_m{0.0f};
    SquaredNorm2Matrix::Compute(m_sparse_value, m_sparse_count,
                                &l2_m);
    float l2_q{0.0f};
    SquaredNorm2Matrix::Compute(q_sparse_value, q_sparse_count,
                                &l2_q);

    *out = ComputeSphericalInjection(ip, l2_m, l2_q, 0.0f);
  }
};

// Generic in-segment inner product: a two-pointer merge over the two index
// arrays that multiplies values whose indices match.
// NOTE(review): assumes each index array is sorted ascending -- confirm.
template
float MipsSquaredEuclideanSparseDistanceMatrix<
    T>::ComputeInnerProductSparseInSegment(uint32_t m_sparse_count,
                                           const uint16_t *m_sparse_index,
                                           const ValueType *m_sparse_value,
                                           uint32_t q_sparse_count,
                                           const uint16_t *q_sparse_index,
                                           const ValueType *q_sparse_value) {
  float sum = 0.0f;

  size_t m_i = 0;
  size_t q_i = 0;
  while (m_i < m_sparse_count && q_i < q_sparse_count) {
    if (m_sparse_index[m_i] == q_sparse_index[q_i]) {
      sum += m_sparse_value[m_i] * q_sparse_value[q_i];

      ++m_i;
      ++q_i;
    } else if (m_sparse_index[m_i] < q_sparse_index[q_i]) {
      ++m_i;
    } else {
      ++q_i;
    }
  }

  return sum;
}

// Explicit specialization for float. It is only declared here; the
// definition is not part of this header (presumably a separate source file).
template <>
float MipsSquaredEuclideanSparseDistanceMatrix<
    float>::ComputeInnerProductSparseInSegment(uint32_t m_sparse_count,
                                               const uint16_t *m_sparse_index,
                                               const ValueType *m_sparse_value,
                                               uint32_t q_sparse_count,
                                               const uint16_t *q_sparse_index,
                                               const ValueType *q_sparse_value);

}  // namespace ailego
}  // namespace zvec

================================================
FILE: src/ailego/math/mips_euclidean_distance_matrix_fp16_avx.cc
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "distance_matrix_accum_fp16.i"
#include "distance_matrix_mips_utility.i"
#include "mips_euclidean_distance_matrix.h"

namespace zvec {
namespace ailego {

#if defined(__AVX__) && defined(__F16C__)

// NOTE(review): the kernel below calls _mm256_fmadd_ps, which requires the
// FMA extension, while this guard only tests __AVX__ and __F16C__; presumably
// the build always enables FMA together with F16C -- confirm.

//! Compute the Inner Product between p and q, and each Squared L2-Norm value
// Returns sum(lhs[i] * rhs[i]) over `size` fp16 elements and stores
// sum(lhs[i]^2) in *sql and sum(rhs[i]^2) in *sqr. The vector part widens
// fp16 to fp32 (F16C) and accumulates in fp32:
//   1. 16 elements per iteration -- aligned loads when both pointers are
//      32-byte aligned, unaligned loads otherwise;
//   2. at most one further 8-element step;
//   3. the remaining 0..7 elements go through FMA_FP16_GENERAL (a macro not
//      defined in this file; presumably it updates the same three sums).
float InnerProductAndSquaredNormFp16AVX(const Float16 *lhs,
                                        const Float16 *rhs, size_t size,
                                        float *sql, float *sqr) {
  __m256 ymm_sum_0 = _mm256_setzero_ps();
  __m256 ymm_sum_1 = _mm256_setzero_ps();
  __m256 ymm_sum_norm1 = _mm256_setzero_ps();
  __m256 ymm_sum_norm2 = _mm256_setzero_ps();

  const Float16 *last = lhs + size;
  const Float16 *last_aligned = lhs + ((size >> 4) << 4);

  if (((uintptr_t)lhs & 0x1f) == 0 && ((uintptr_t)rhs & 0x1f) == 0) {
    for (; lhs != last_aligned; lhs += 16, rhs += 16) {
      __m256i ymm_lhs = _mm256_load_si256((const __m256i *)lhs);
      __m256i ymm_rhs = _mm256_load_si256((const __m256i *)rhs);
      __m256 ymm_lhs_0 = _mm256_cvtph_ps(_mm256_castsi256_si128(ymm_lhs));
      __m256 ymm_lhs_1 =
          _mm256_cvtph_ps(_mm256_extractf128_si256(ymm_lhs, 1));
      __m256 ymm_rhs_0 = _mm256_cvtph_ps(_mm256_castsi256_si128(ymm_rhs));
      __m256 ymm_rhs_1 =
          _mm256_cvtph_ps(_mm256_extractf128_si256(ymm_rhs, 1));

      ymm_sum_0 = _mm256_fmadd_ps(ymm_lhs_0, ymm_rhs_0, ymm_sum_0);
      ymm_sum_1 = _mm256_fmadd_ps(ymm_lhs_1, ymm_rhs_1, ymm_sum_1);
      ymm_sum_norm1 = _mm256_fmadd_ps(ymm_lhs_0, ymm_lhs_0, ymm_sum_norm1);
      ymm_sum_norm1 = _mm256_fmadd_ps(ymm_lhs_1, ymm_lhs_1, ymm_sum_norm1);
      ymm_sum_norm2 = _mm256_fmadd_ps(ymm_rhs_0, ymm_rhs_0, ymm_sum_norm2);
      ymm_sum_norm2 = _mm256_fmadd_ps(ymm_rhs_1, ymm_rhs_1, ymm_sum_norm2);
    }

    // lhs/rhs are still 32-byte aligned here, so aligned 128-bit loads are
    // valid.
    if (last >= last_aligned + 8) {
      __m256 ymm_lhs_0 =
          _mm256_cvtph_ps(_mm_load_si128((const __m128i *)lhs));
      __m256 ymm_rhs_0 =
          _mm256_cvtph_ps(_mm_load_si128((const __m128i *)rhs));
      ymm_sum_0 = _mm256_fmadd_ps(ymm_lhs_0, ymm_rhs_0, ymm_sum_0);
      ymm_sum_norm1 = _mm256_fmadd_ps(ymm_lhs_0, ymm_lhs_0, ymm_sum_norm1);
      ymm_sum_norm2 = _mm256_fmadd_ps(ymm_rhs_0, ymm_rhs_0, ymm_sum_norm2);
      lhs += 8;
      rhs += 8;
    }
  } else {
    for (; lhs != last_aligned; lhs += 16, rhs += 16) {
      __m256i ymm_lhs = _mm256_loadu_si256((const __m256i *)lhs);
      __m256i ymm_rhs = _mm256_loadu_si256((const __m256i *)rhs);
      __m256 ymm_lhs_0 = _mm256_cvtph_ps(_mm256_castsi256_si128(ymm_lhs));
      __m256 ymm_lhs_1 =
          _mm256_cvtph_ps(_mm256_extractf128_si256(ymm_lhs, 1));
      __m256 ymm_rhs_0 = _mm256_cvtph_ps(_mm256_castsi256_si128(ymm_rhs));
      __m256 ymm_rhs_1 =
          _mm256_cvtph_ps(_mm256_extractf128_si256(ymm_rhs, 1));

      ymm_sum_0 = _mm256_fmadd_ps(ymm_lhs_0, ymm_rhs_0, ymm_sum_0);
      ymm_sum_1 = _mm256_fmadd_ps(ymm_lhs_1, ymm_rhs_1, ymm_sum_1);
      ymm_sum_norm1 = _mm256_fmadd_ps(ymm_lhs_0, ymm_lhs_0, ymm_sum_norm1);
      ymm_sum_norm1 = _mm256_fmadd_ps(ymm_lhs_1, ymm_lhs_1, ymm_sum_norm1);
      ymm_sum_norm2 = _mm256_fmadd_ps(ymm_rhs_0, ymm_rhs_0, ymm_sum_norm2);
      ymm_sum_norm2 = _mm256_fmadd_ps(ymm_rhs_1, ymm_rhs_1, ymm_sum_norm2);
    }

    if (last >= last_aligned + 8) {
      __m256 ymm_lhs_0 =
          _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)lhs));
      __m256 ymm_rhs_0 =
          _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)rhs));
      ymm_sum_0 = _mm256_fmadd_ps(ymm_lhs_0, ymm_rhs_0, ymm_sum_0);
      ymm_sum_norm1 = _mm256_fmadd_ps(ymm_lhs_0, ymm_lhs_0, ymm_sum_norm1);
      ymm_sum_norm2 = _mm256_fmadd_ps(ymm_rhs_0, ymm_rhs_0, ymm_sum_norm2);
      lhs += 8;
      rhs += 8;
    }
  }

  float result = HorizontalAdd_FP32_V256(_mm256_add_ps(ymm_sum_0, ymm_sum_1));
  float norm1 = HorizontalAdd_FP32_V256(ymm_sum_norm1);
  float norm2 = HorizontalAdd_FP32_V256(ymm_sum_norm2);

  // Scalar tail: 0..7 leftover elements, handled by deliberate fallthrough.
  switch (last - lhs) {
    case 7:
      FMA_FP16_GENERAL(lhs[6], rhs[6], result, norm1, norm2);
      /* FALLTHRU */
    case 6:
      FMA_FP16_GENERAL(lhs[5], rhs[5], result, norm1, norm2);
      /* FALLTHRU */
    case 5:
      FMA_FP16_GENERAL(lhs[4], rhs[4], result, norm1, norm2);
      /* FALLTHRU */
    case 4:
      FMA_FP16_GENERAL(lhs[3], rhs[3], result, norm1, norm2);
      /* FALLTHRU */
    case 3:
      FMA_FP16_GENERAL(lhs[2], rhs[2], result, norm1, norm2);
      /* FALLTHRU */
    case 2:
      FMA_FP16_GENERAL(lhs[1], rhs[1], result, norm1, norm2);
      /* FALLTHRU */
    case 1:
      FMA_FP16_GENERAL(lhs[0], rhs[0], result, norm1, norm2);
  }

  *sql = norm1;
  *sqr = norm2;
  return result;
}

// Spherical-injection score: feeds the fused inner product and both squared
// norms into ComputeSphericalInjection.
float MipsEuclideanDistanceSphericalInjectionFp16AVX(const Float16 *lhs,
                                                     const Float16 *rhs,
                                                     size_t size, float e2) {
  float u2{0.0f};
  float v2{0.0f};
  float sum{0.0f};

  sum = InnerProductAndSquaredNormFp16AVX(lhs, rhs, size, &u2, &v2);

  return ComputeSphericalInjection(sum, u2, v2, e2);
}

// Repeated-quadratic-injection score. With a = e2 * ||lhs||^2 and
// b = e2 * ||rhs||^2 this returns
//   e2 * (||lhs||^2 + ||rhs||^2 - 2 <lhs, rhs>)
//   + sum_{i=0}^{m-1} (a^(2^i) - b^(2^i))^2.
float MipsEuclideanDistanceRepeatedQuadraticInjectionFp16AVX(
    const Float16 *lhs, const Float16 *rhs, size_t size, size_t m, float e2) {
  float u2{0.0f};
  float v2{0.0f};
  float sum{0.0f};

  sum = InnerProductAndSquaredNormFp16AVX(lhs, rhs, size, &u2, &v2);

  // ||l - r||^2 expressed through the inner product and norms.
  sum = e2 * (u2 + v2 - 2 * sum);
  u2 *= e2;
  v2 *= e2;
  for (size_t i = 0; i < m; ++i) {
    sum += (u2 - v2) * (u2 - v2);
    u2 = u2 * u2;
    v2 = v2 * v2;
  }
  return sum;
}

#endif  // __AVX__ && __F16C__

}  // namespace ailego
}  // namespace zvec

================================================
FILE: src/ailego/math/mips_euclidean_distance_matrix_fp16_avx512.cc
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "distance_matrix_accum_fp16.i"
#include "distance_matrix_mips_utility.i"
#include "mips_euclidean_distance_matrix.h"

namespace zvec {
namespace ailego {

#if defined(__AVX512F__)

//! Compute the Inner Product between p and q, and each Squared L2-Norm value
// Returns sum(lhs[i] * rhs[i]) over `size` fp16 elements and stores
// sum(lhs[i]^2) in *sql and sum(rhs[i]^2) in *sqr. The vector part widens
// fp16 to fp32 and accumulates in fp32:
//   1. 32 elements per iteration (aligned loads when both pointers are
//      64-byte aligned, unaligned otherwise), then at most one 16-element
//      step;
//   2. the 512-bit accumulators are folded to 256 bits and at most one more
//      8-element step runs with 256-bit instructions;
//   3. the remaining 0..7 elements go through FMA_FP16_GENERAL (a macro not
//      defined in this file; presumably it updates the same three sums).
// NOTE(review): step 2 uses _mm256_cvtph_ps (F16C) and _mm256_fmadd_ps (FMA)
// while the guard only tests __AVX512F__; confirm the toolchain enables those
// extensions whenever AVX-512F is enabled.
float InnerProductAndSquaredNormFp16AVX512(const Float16 *lhs,
                                           const Float16 *rhs, size_t size,
                                           float *sql, float *sqr) {
  __m512 zmm_sum_0 = _mm512_setzero_ps();
  __m512 zmm_sum_1 = _mm512_setzero_ps();
  __m512 zmm_sum_norm1 = _mm512_setzero_ps();
  __m512 zmm_sum_norm2 = _mm512_setzero_ps();

  const Float16 *last = lhs + size;
  const Float16 *last_aligned = lhs + ((size >> 5) << 5);

  if (((uintptr_t)lhs & 0x3f) == 0 && ((uintptr_t)rhs & 0x3f) == 0) {
    for (; lhs != last_aligned; lhs += 32, rhs += 32) {
      __m512i zmm_lhs = _mm512_load_si512((const __m512i *)lhs);
      __m512i zmm_rhs = _mm512_load_si512((const __m512i *)rhs);
      __m512 zmm_lhs_0 = _mm512_cvtph_ps(_mm512_castsi512_si256(zmm_lhs));
      __m512 zmm_lhs_1 =
          _mm512_cvtph_ps(_mm512_extracti64x4_epi64(zmm_lhs, 1));
      __m512 zmm_rhs_0 = _mm512_cvtph_ps(_mm512_castsi512_si256(zmm_rhs));
      __m512 zmm_rhs_1 =
          _mm512_cvtph_ps(_mm512_extracti64x4_epi64(zmm_rhs, 1));

      FMA_FP32_AVX512(zmm_lhs_0, zmm_rhs_0, zmm_sum_0)
      FMA_FP32_AVX512(zmm_lhs_1, zmm_rhs_1, zmm_sum_1)
      FMA_FP32_AVX512(zmm_lhs_0, zmm_lhs_0, zmm_sum_norm1)
      FMA_FP32_AVX512(zmm_lhs_1, zmm_lhs_1, zmm_sum_norm1)
      FMA_FP32_AVX512(zmm_rhs_0, zmm_rhs_0, zmm_sum_norm2)
      FMA_FP32_AVX512(zmm_rhs_1, zmm_rhs_1, zmm_sum_norm2)
    }

    // lhs/rhs remain 64-byte aligned here, so aligned 256-bit loads are
    // valid.
    if (last >= last_aligned + 16) {
      __m512 zmm_lhs_0 =
          _mm512_cvtph_ps(_mm256_load_si256((const __m256i *)lhs));
      __m512 zmm_rhs_0 =
          _mm512_cvtph_ps(_mm256_load_si256((const __m256i *)rhs));
      FMA_FP32_AVX512(zmm_lhs_0, zmm_rhs_0, zmm_sum_0)
      FMA_FP32_AVX512(zmm_lhs_0, zmm_lhs_0, zmm_sum_norm1)
      FMA_FP32_AVX512(zmm_rhs_0, zmm_rhs_0, zmm_sum_norm2)
      lhs += 16;
      rhs += 16;
    }
  } else {
    for (; lhs != last_aligned; lhs += 32, rhs += 32) {
      __m512i zmm_lhs = _mm512_loadu_si512((const __m512i *)lhs);
      __m512i zmm_rhs = _mm512_loadu_si512((const __m512i *)rhs);
      __m512 zmm_lhs_0 = _mm512_cvtph_ps(_mm512_castsi512_si256(zmm_lhs));
      __m512 zmm_lhs_1 =
          _mm512_cvtph_ps(_mm512_extracti64x4_epi64(zmm_lhs, 1));
      __m512 zmm_rhs_0 = _mm512_cvtph_ps(_mm512_castsi512_si256(zmm_rhs));
      __m512 zmm_rhs_1 =
          _mm512_cvtph_ps(_mm512_extracti64x4_epi64(zmm_rhs, 1));

      FMA_FP32_AVX512(zmm_lhs_0, zmm_rhs_0, zmm_sum_0)
      FMA_FP32_AVX512(zmm_lhs_1, zmm_rhs_1, zmm_sum_1)
      FMA_FP32_AVX512(zmm_lhs_0, zmm_lhs_0, zmm_sum_norm1)
      FMA_FP32_AVX512(zmm_lhs_1, zmm_lhs_1, zmm_sum_norm1)
      FMA_FP32_AVX512(zmm_rhs_0, zmm_rhs_0, zmm_sum_norm2)
      FMA_FP32_AVX512(zmm_rhs_1, zmm_rhs_1, zmm_sum_norm2)
    }

    if (last >= last_aligned + 16) {
      __m512 zmm_lhs_0 =
          _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)lhs));
      __m512 zmm_rhs_0 =
          _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)rhs));
      FMA_FP32_AVX512(zmm_lhs_0, zmm_rhs_0, zmm_sum_0)
      FMA_FP32_AVX512(zmm_lhs_0, zmm_lhs_0, zmm_sum_norm1)
      FMA_FP32_AVX512(zmm_rhs_0, zmm_rhs_0, zmm_sum_norm2)
      lhs += 16;
      rhs += 16;
    }
  }

  // Fold the 512-bit accumulators to 256 bits so one more 8-wide step can be
  // added before the final horizontal reduction.
  __m256 ymm_sum_0 =
      HorizontalAdd_FP32_V512_TO_V256(_mm512_add_ps(zmm_sum_0, zmm_sum_1));
  __m256 ymm_sum_norm1 = HorizontalAdd_FP32_V512_TO_V256(zmm_sum_norm1);
  __m256 ymm_sum_norm2 = HorizontalAdd_FP32_V512_TO_V256(zmm_sum_norm2);

  if (last >= lhs + 8) {
    __m256 ymm_lhs_0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)lhs));
    __m256 ymm_rhs_0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)rhs));
    ymm_sum_0 = _mm256_fmadd_ps(ymm_lhs_0, ymm_rhs_0, ymm_sum_0);
    ymm_sum_norm1 = _mm256_fmadd_ps(ymm_lhs_0, ymm_lhs_0, ymm_sum_norm1);
    ymm_sum_norm2 = _mm256_fmadd_ps(ymm_rhs_0, ymm_rhs_0, ymm_sum_norm2);
    lhs += 8;
    rhs += 8;
  }

  float result = HorizontalAdd_FP32_V256(ymm_sum_0);
  float norm1 = HorizontalAdd_FP32_V256(ymm_sum_norm1);
  float norm2 = HorizontalAdd_FP32_V256(ymm_sum_norm2);

  // Scalar tail: 0..7 leftover elements, handled by deliberate fallthrough.
  switch (last - lhs) {
    case 7:
      FMA_FP16_GENERAL(lhs[6], rhs[6], result, norm1, norm2);
      /* FALLTHRU */
    case 6:
      FMA_FP16_GENERAL(lhs[5], rhs[5], result, norm1, norm2);
      /* FALLTHRU */
    case 5:
      FMA_FP16_GENERAL(lhs[4], rhs[4], result, norm1, norm2);
      /* FALLTHRU */
    case 4:
      FMA_FP16_GENERAL(lhs[3], rhs[3], result, norm1, norm2);
      /* FALLTHRU */
    case 3:
      FMA_FP16_GENERAL(lhs[2], rhs[2], result, norm1, norm2);
      /* FALLTHRU */
    case 2:
      FMA_FP16_GENERAL(lhs[1], rhs[1], result, norm1, norm2);
      /* FALLTHRU */
    case 1:
      FMA_FP16_GENERAL(lhs[0], rhs[0], result, norm1, norm2);
  }

  *sql = norm1;
  *sqr = norm2;
  return result;
}

// Spherical-injection score: feeds the fused inner product and both squared
// norms into ComputeSphericalInjection.
float MipsEuclideanDistanceSphericalInjectionFp16AVX512(const Float16 *lhs,
                                                        const Float16 *rhs,
                                                        size_t size,
                                                        float e2) {
  float u2{0.0f};
  float v2{0.0f};
  float sum{0.0f};

  sum = InnerProductAndSquaredNormFp16AVX512(lhs, rhs, size, &u2, &v2);

  return ComputeSphericalInjection(sum, u2, v2, e2);
}

// Repeated-quadratic-injection score. With a = e2 * ||lhs||^2 and
// b = e2 * ||rhs||^2 this returns
//   e2 * (||lhs||^2 + ||rhs||^2 - 2 <lhs, rhs>)
//   + sum_{i=0}^{m-1} (a^(2^i) - b^(2^i))^2.
float MipsEuclideanDistanceRepeatedQuadraticInjectionFp16AVX512(
    const Float16 *lhs, const Float16 *rhs, size_t size, size_t m, float e2) {
  float u2{0.0f};
  float v2{0.0f};
  float sum{0.0f};

  sum = InnerProductAndSquaredNormFp16AVX512(lhs, rhs, size, &u2, &v2);

  // ||l - r||^2 expressed through the inner product and norms.
  sum = e2 * (u2 + v2 - 2 * sum);
  u2 *= e2;
  v2 *= e2;
  for (size_t i = 0; i < m; ++i) {
    sum += (u2 - v2) * (u2 - v2);
    u2 = u2 * u2;
    v2 = v2 * v2;
  }
  return sum;
}

#endif  // __AVX512F__

}  // namespace ailego
}  // namespace zvec

================================================
FILE: src/ailego/math/mips_euclidean_distance_matrix_fp16_dispatch.cc
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include "mips_euclidean_distance_matrix.h" namespace zvec { namespace ailego { #if defined(__ARM_NEON) float MipsEuclideanDistanceRepeatedQuadraticInjectionFp16NEON( const Float16 *lhs, const Float16 *rhs, size_t size, size_t m, float e2); float MipsEuclideanDistanceSphericalInjectionFp16NEON(const Float16 *lhs, const Float16 *rhs, size_t size, float e2); #endif #if defined(__AVX512F__) float MipsEuclideanDistanceRepeatedQuadraticInjectionFp16AVX512( const Float16 *lhs, const Float16 *rhs, size_t size, size_t m, float e2); float MipsEuclideanDistanceSphericalInjectionFp16AVX512(const Float16 *lhs, const Float16 *rhs, size_t size, float e2); #endif #if defined(__AVX__) float MipsEuclideanDistanceRepeatedQuadraticInjectionFp16AVX( const Float16 *lhs, const Float16 *rhs, size_t size, size_t m, float e2); float MipsEuclideanDistanceSphericalInjectionFp16AVX(const Float16 *lhs, const Float16 *rhs, size_t size, float e2); #endif float MipsEuclideanDistanceRepeatedQuadraticInjectionFp16Scalar( const Float16 *lhs, const Float16 *rhs, size_t size, size_t m, float e2); float MipsEuclideanDistanceSphericalInjectionFp16Scalar( const ailego::Float16 *p, const ailego::Float16 *q, size_t dim, float e2); //! 
Compute the distance between matrix and query by SphericalInjection void MipsSquaredEuclideanDistanceMatrix::Compute( const ValueType *p, const ValueType *q, size_t dim, float e2, float *out) { #if defined(__ARM_NEON) *out = MipsEuclideanDistanceSphericalInjectionFp16NEON(p, q, dim, e2); #else #if defined(__AVX512F__) if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) { *out = MipsEuclideanDistanceSphericalInjectionFp16AVX512(p, q, dim, e2); return; } #endif #if defined(__AVX__) if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX) { *out = MipsEuclideanDistanceSphericalInjectionFp16AVX(p, q, dim, e2); return; } #endif //__AVX__ *out = MipsEuclideanDistanceSphericalInjectionFp16Scalar(p, q, dim, e2); return; #endif //__ARM_NEON } //! Compute the distance between matrix and query by RepeatedQuadraticInjection void MipsSquaredEuclideanDistanceMatrix::Compute( const ValueType *p, const ValueType *q, size_t dim, size_t m, float e2, float *out) { #if defined(__ARM_NEON) *out = MipsEuclideanDistanceRepeatedQuadraticInjectionFp16NEON(p, q, dim, m, e2); #else #if defined(__AVX512F__) if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) { *out = MipsEuclideanDistanceRepeatedQuadraticInjectionFp16AVX512(p, q, dim, m, e2); return; } #endif #if defined(__AVX__) if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX) { *out = MipsEuclideanDistanceRepeatedQuadraticInjectionFp16AVX(p, q, dim, m, e2); return; } #endif //__AVX__ *out = MipsEuclideanDistanceRepeatedQuadraticInjectionFp16Scalar(p, q, dim, m, e2); return; #endif //__ARM_NEON } } // namespace ailego } // namespace zvec ================================================ FILE: src/ailego/math/mips_euclidean_distance_matrix_fp16_neon.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "distance_matrix_accum_fp16.i"
#include "distance_matrix_mips_utility.i"
#include "mips_euclidean_distance_matrix.h"

namespace zvec {
namespace ailego {

#if defined(__ARM_NEON) && defined(__aarch64__)
#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)

//! Compute the Inner Product between p and q, and each Squared L2-Norm value
// Variant for targets with native FP16 vector arithmetic. Returns
// sum(lhs[i] * rhs[i]) and stores sum(lhs[i]^2) in *sql and sum(rhs[i]^2) in
// *sqr. Products are fused-multiply-accumulated directly in float16x8_t
// registers: 8 lanes per iteration, then at most one 4-lane step whose upper
// half is zero-padded. The registers are reduced with HorizontalAdd_FP16_NEON,
// and the last 0..3 elements go through FMA_FP16_GENERAL (both not defined in
// this file).
// NOTE(review): the accumulators are half precision (largest finite value
// 65504), so long or large-magnitude vectors can overflow to inf or lose
// precision. The #else variant below accumulates in fp32. Confirm this
// trade-off is intended for the expected value ranges.
float InnerProductAndSquaredNormFp16NEON(const Float16 *lhs,
                                         const Float16 *rhs, size_t size,
                                         float *sql, float *sqr) {
  const Float16 *last = lhs + size;
  const Float16 *last_aligned = lhs + ((size >> 3) << 3);

  float16x8_t v_sum = vdupq_n_f16(0);
  float16x8_t v_sum_norm1 = vdupq_n_f16(0);
  float16x8_t v_sum_norm2 = vdupq_n_f16(0);

  for (; lhs != last_aligned; lhs += 8, rhs += 8) {
    float16x8_t v_lhs = vld1q_f16((const float16_t *)lhs);
    float16x8_t v_rhs = vld1q_f16((const float16_t *)rhs);

    v_sum = vfmaq_f16(v_sum, v_lhs, v_rhs);
    v_sum_norm1 = vfmaq_f16(v_sum_norm1, v_lhs, v_lhs);
    v_sum_norm2 = vfmaq_f16(v_sum_norm2, v_rhs, v_rhs);
  }

  if (last >= last_aligned + 4) {
    // Load 4 halves and pad the upper 4 lanes with zeros.
    float16x8_t v_lhs =
        vcombine_f16(vld1_f16((const float16_t *)lhs),
                     vreinterpret_f16_u64(vdup_n_u64(0ul)));
    float16x8_t v_rhs =
        vcombine_f16(vld1_f16((const float16_t *)rhs),
                     vreinterpret_f16_u64(vdup_n_u64(0ul)));

    v_sum = vfmaq_f16(v_sum, v_lhs, v_rhs);
    v_sum_norm1 = vfmaq_f16(v_sum_norm1, v_lhs, v_lhs);
    v_sum_norm2 = vfmaq_f16(v_sum_norm2, v_rhs, v_rhs);
    lhs += 4;
    rhs += 4;
  }

  float result = HorizontalAdd_FP16_NEON(v_sum);
  float norm1 = HorizontalAdd_FP16_NEON(v_sum_norm1);
  float norm2 = HorizontalAdd_FP16_NEON(v_sum_norm2);

  // Scalar tail: 0..3 leftover elements, handled by deliberate fallthrough.
  switch (last - lhs) {
    case 3:
      FMA_FP16_GENERAL(lhs[2], rhs[2], result, norm1, norm2);
      /* FALLTHRU */
    case 2:
      FMA_FP16_GENERAL(lhs[1], rhs[1], result, norm1, norm2);
      /* FALLTHRU */
    case 1:
      FMA_FP16_GENERAL(lhs[0], rhs[0], result, norm1, norm2);
  }

  *sql = norm1;
  *sqr = norm2;
  return result;
}

#else

//! Compute the Inner Product between p and q, and each Squared L2-Norm value
// Variant without FP16 vector arithmetic. The fp16 lanes are widened to fp32
// (vcvt_f32_f16 / vcvt_high_f32_f16) and accumulated in fp32: 8 elements per
// iteration, then at most one 4-element step. vaddvq_f32 reduces the vector
// sums, and the last 0..3 elements go through FMA_FP16_GENERAL.
float InnerProductAndSquaredNormFp16NEON(const Float16 *lhs,
                                         const Float16 *rhs, size_t size,
                                         float *sql, float *sqr) {
  const Float16 *last = lhs + size;
  const Float16 *last_aligned = lhs + ((size >> 3) << 3);

  float32x4_t v_sum_0 = vdupq_n_f32(0);
  float32x4_t v_sum_1 = vdupq_n_f32(0);
  float32x4_t v_sum_norm1 = vdupq_n_f32(0);
  float32x4_t v_sum_norm2 = vdupq_n_f32(0);

  for (; lhs != last_aligned; lhs += 8, rhs += 8) {
    float16x8_t v_lhs = vld1q_f16((const float16_t *)lhs);
    float16x8_t v_rhs = vld1q_f16((const float16_t *)rhs);
    float32x4_t v_lhs_0 = vcvt_f32_f16(vget_low_f16(v_lhs));
    float32x4_t v_rhs_0 = vcvt_f32_f16(vget_low_f16(v_rhs));
    float32x4_t v_lhs_1 = vcvt_high_f32_f16(v_lhs);
    float32x4_t v_rhs_1 = vcvt_high_f32_f16(v_rhs);

    v_sum_0 = vfmaq_f32(v_sum_0, v_lhs_0, v_rhs_0);
    v_sum_1 = vfmaq_f32(v_sum_1, v_lhs_1, v_rhs_1);
    v_sum_norm1 = vfmaq_f32(v_sum_norm1, v_lhs_0, v_lhs_0);
    v_sum_norm1 = vfmaq_f32(v_sum_norm1, v_lhs_1, v_lhs_1);
    v_sum_norm2 = vfmaq_f32(v_sum_norm2, v_rhs_0, v_rhs_0);
    v_sum_norm2 = vfmaq_f32(v_sum_norm2, v_rhs_1, v_rhs_1);
  }

  if (last >= last_aligned + 4) {
    float32x4_t v_lhs_0 = vcvt_f32_f16(vld1_f16((const float16_t *)lhs));
    float32x4_t v_rhs_0 = vcvt_f32_f16(vld1_f16((const float16_t *)rhs));

    v_sum_0 = vfmaq_f32(v_sum_0, v_lhs_0, v_rhs_0);
    v_sum_norm1 = vfmaq_f32(v_sum_norm1, v_lhs_0, v_lhs_0);
    v_sum_norm2 = vfmaq_f32(v_sum_norm2, v_rhs_0, v_rhs_0);
    lhs += 4;
    rhs += 4;
  }

  float result = vaddvq_f32(vaddq_f32(v_sum_0, v_sum_1));
  float norm1 = vaddvq_f32(v_sum_norm1);
  float norm2 = vaddvq_f32(v_sum_norm2);

  // Scalar tail: 0..3 leftover elements, handled by deliberate fallthrough.
  switch (last - lhs) {
    case 3:
      FMA_FP16_GENERAL(lhs[2], rhs[2], result, norm1, norm2);
      /* FALLTHRU */
    case 2:
      FMA_FP16_GENERAL(lhs[1], rhs[1], result, norm1, norm2);
      /* FALLTHRU */
    case 1:
      FMA_FP16_GENERAL(lhs[0], rhs[0], result, norm1, norm2);
  }

  *sql = norm1;
  *sqr = norm2;
  return result;
}

#endif  // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC

// Spherical-injection score: feeds the fused inner product and both squared
// norms into ComputeSphericalInjection.
float MipsEuclideanDistanceSphericalInjectionFp16NEON(const Float16 *lhs,
                                                      const Float16 *rhs,
                                                      size_t size, float e2) {
  float u2{0.0f};
  float v2{0.0f};
  float sum{0.0f};

  sum = InnerProductAndSquaredNormFp16NEON(lhs, rhs, size, &u2, &v2);

  return ComputeSphericalInjection(sum, u2, v2, e2);
}

// Repeated-quadratic-injection score. With a = e2 * ||lhs||^2 and
// b = e2 * ||rhs||^2 this returns
//   e2 * (||lhs||^2 + ||rhs||^2 - 2 <lhs, rhs>)
//   + sum_{i=0}^{m-1} (a^(2^i) - b^(2^i))^2.
float MipsEuclideanDistanceRepeatedQuadraticInjectionFp16NEON(
    const Float16 *lhs, const Float16 *rhs, size_t size, size_t m, float e2) {
  float u2{0.0f};
  float v2{0.0f};
  float sum{0.0f};

  sum = InnerProductAndSquaredNormFp16NEON(lhs, rhs, size, &u2, &v2);

  // ||l - r||^2 expressed through the inner product and norms.
  sum = e2 * (u2 + v2 - 2 * sum);
  u2 *= e2;
  v2 *= e2;
  for (size_t i = 0; i < m; ++i) {
    sum += (u2 - v2) * (u2 - v2);
    u2 = u2 * u2;
    v2 = v2 * v2;
  }
  return sum;
}

#endif  // __ARM_NEON && __aarch64__

}  // namespace ailego
}  // namespace zvec

================================================
FILE: src/ailego/math/mips_euclidean_distance_matrix_fp32_avx.cc
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "distance_matrix_accum_fp32.i"
#include "distance_matrix_mips_utility.i"
#include "mips_euclidean_distance_matrix.h"

namespace zvec {
namespace ailego {

// Defined in the SSE translation unit; used below for inputs too short to
// fill one 8-float AVX register.
#if defined(__SSE__)
float InnerProductAndSquaredNormFp32SSE(const float *lhs, const float *rhs,
                                        size_t size, float *sql, float *sqr);
#endif

#if defined(__AVX__)
//! Compute the Inner Product between p and q, and each Squared L2-Norm value
//!
//! \param lhs   First vector, `size` floats.
//! \param rhs   Second vector, `size` floats.
//! \param size  Number of elements in each vector.
//! \param sql   Receives the sum of squares of `lhs`.
//! \param sqr   Receives the sum of squares of `rhs`.
//! \return      The inner product of `lhs` and `rhs`.
//!
//! NOTE(review): _mm256_fmadd_ps is an FMA3 intrinsic, while this section is
//! guarded only by __AVX__. This presumably relies on the build enabling FMA,
//! or on a fallback definition in the included .i files — confirm.
float InnerProductAndSquaredNormFp32AVX(const float *lhs, const float *rhs,
                                        size_t size, float *sql, float *sqr) {
  const float *last = lhs + size;
  // End of the prefix whose length is a multiple of 16 floats.
  const float *last_aligned = lhs + ((size >> 4) << 4);

  // Two dot-product accumulators (one per 8-float half of a 16-float step)
  // plus one accumulator per squared norm.
  __m256 ymm_sum_0 = _mm256_setzero_ps();
  __m256 ymm_sum_1 = _mm256_setzero_ps();
  __m256 ymm_sum_norm1 = _mm256_setzero_ps();
  __m256 ymm_sum_norm2 = _mm256_setzero_ps();

  // Both pointers 32-byte aligned: aligned loads. Otherwise the else branch
  // runs the identical computation with unaligned loads.
  if (((uintptr_t)lhs & 0x1f) == 0 && ((uintptr_t)rhs & 0x1f) == 0) {
    for (; lhs != last_aligned; lhs += 16, rhs += 16) {
      __m256 ymm_lhs_0 = _mm256_load_ps(lhs + 0);
      __m256 ymm_lhs_1 = _mm256_load_ps(lhs + 8);
      __m256 ymm_rhs_0 = _mm256_load_ps(rhs + 0);
      __m256 ymm_rhs_1 = _mm256_load_ps(rhs + 8);

      ymm_sum_0 = _mm256_fmadd_ps(ymm_lhs_0, ymm_rhs_0, ymm_sum_0);
      ymm_sum_1 = _mm256_fmadd_ps(ymm_lhs_1, ymm_rhs_1, ymm_sum_1);
      ymm_sum_norm1 = _mm256_fmadd_ps(ymm_lhs_0, ymm_lhs_0, ymm_sum_norm1);
      ymm_sum_norm1 = _mm256_fmadd_ps(ymm_lhs_1, ymm_lhs_1, ymm_sum_norm1);
      ymm_sum_norm2 = _mm256_fmadd_ps(ymm_rhs_0, ymm_rhs_0, ymm_sum_norm2);
      ymm_sum_norm2 = _mm256_fmadd_ps(ymm_rhs_1, ymm_rhs_1, ymm_sum_norm2);
    }

    // 8..15 floats left: consume one more 8-float register.
    if (last >= last_aligned + 8) {
      __m256 ymm_lhs_0 = _mm256_load_ps(lhs);
      __m256 ymm_rhs_0 = _mm256_load_ps(rhs);
      ymm_sum_0 = _mm256_fmadd_ps(ymm_lhs_0, ymm_rhs_0, ymm_sum_0);
      ymm_sum_norm1 = _mm256_fmadd_ps(ymm_lhs_0, ymm_lhs_0, ymm_sum_norm1);
      ymm_sum_norm2 = _mm256_fmadd_ps(ymm_rhs_0, ymm_rhs_0, ymm_sum_norm2);
      lhs += 8;
      rhs += 8;
    }
  } else {
    for (; lhs != last_aligned; lhs += 16, rhs += 16) {
      __m256 ymm_lhs_0 = _mm256_loadu_ps(lhs + 0);
      __m256 ymm_lhs_1 = _mm256_loadu_ps(lhs + 8);
      __m256 ymm_rhs_0 = _mm256_loadu_ps(rhs + 0);
      __m256 ymm_rhs_1 = _mm256_loadu_ps(rhs + 8);

      ymm_sum_0 = _mm256_fmadd_ps(ymm_lhs_0, ymm_rhs_0, ymm_sum_0);
      ymm_sum_1 = _mm256_fmadd_ps(ymm_lhs_1, ymm_rhs_1, ymm_sum_1);
      ymm_sum_norm1 = _mm256_fmadd_ps(ymm_lhs_0, ymm_lhs_0, ymm_sum_norm1);
      ymm_sum_norm1 = _mm256_fmadd_ps(ymm_lhs_1, ymm_lhs_1, ymm_sum_norm1);
      ymm_sum_norm2 = _mm256_fmadd_ps(ymm_rhs_0, ymm_rhs_0, ymm_sum_norm2);
      ymm_sum_norm2 = _mm256_fmadd_ps(ymm_rhs_1, ymm_rhs_1, ymm_sum_norm2);
    }

    if (last >= last_aligned + 8) {
      __m256 ymm_lhs_0 = _mm256_loadu_ps(lhs);
      __m256 ymm_rhs_0 = _mm256_loadu_ps(rhs);
      ymm_sum_0 = _mm256_fmadd_ps(ymm_lhs_0, ymm_rhs_0, ymm_sum_0);
      ymm_sum_norm1 = _mm256_fmadd_ps(ymm_lhs_0, ymm_lhs_0, ymm_sum_norm1);
      ymm_sum_norm2 = _mm256_fmadd_ps(ymm_rhs_0, ymm_rhs_0, ymm_sum_norm2);
      lhs += 8;
      rhs += 8;
    }
  }
  // Horizontal reduction of the vector accumulators.
  float result = HorizontalAdd_FP32_V256(_mm256_add_ps(ymm_sum_0, ymm_sum_1));
  float norm1 = HorizontalAdd_FP32_V256(ymm_sum_norm1);
  float norm2 = HorizontalAdd_FP32_V256(ymm_sum_norm2);

  // Remaining 0..7 floats are folded in one at a time by FMA_FP32_GENERAL
  // (defined elsewhere; presumably it updates result/norm1/norm2 in place).
  switch (last - lhs) {
    case 7:
      FMA_FP32_GENERAL(lhs[6], rhs[6], result, norm1, norm2)
      /* FALLTHRU */
    case 6:
      FMA_FP32_GENERAL(lhs[5], rhs[5], result, norm1, norm2)
      /* FALLTHRU */
    case 5:
      FMA_FP32_GENERAL(lhs[4], rhs[4], result, norm1, norm2)
      /* FALLTHRU */
    case 4:
      FMA_FP32_GENERAL(lhs[3], rhs[3], result, norm1, norm2)
      /* FALLTHRU */
    case 3:
      FMA_FP32_GENERAL(lhs[2], rhs[2], result, norm1, norm2)
      /* FALLTHRU */
    case 2:
      FMA_FP32_GENERAL(lhs[1], rhs[1], result, norm1, norm2)
      /* FALLTHRU */
    case 1:
      FMA_FP32_GENERAL(lhs[0], rhs[0], result, norm1, norm2)
  }
  *sql = norm1;
  *sqr = norm2;
  return result;
}

//! MIPS distance of two FP32 vectors via spherical injection (AVX build).
//!
//! Vectors with at least 8 elements use the AVX kernel; shorter ones use the
//! SSE kernel. The three reductions are passed, with `e2`, to
//! ComputeSphericalInjection (defined elsewhere).
float MipsEuclideanDistanceSphericalInjectionFp32AVX(const float *lhs,
                                                     const float *rhs,
                                                     size_t size, float e2) {
  float u2{0.0f};
  float v2{0.0f};
  float sum{0.0f};

  if (size > 7) {
    sum = InnerProductAndSquaredNormFp32AVX(lhs, rhs, size, &u2, &v2);
  } else {
    sum = InnerProductAndSquaredNormFp32SSE(lhs, rhs, size, &u2, &v2);
  }

  return ComputeSphericalInjection(sum, u2, v2, e2);
}

//! MIPS distance of two FP32 vectors via repeated quadratic injection (AVX).
//!
//! Starts from e2 * (u2 + v2 - 2 * <lhs, rhs>). Then, for each of the `m`
//! rounds, it adds (u2 - v2)^2 of the e2-scaled squared norms and squares
//! both of them.
float MipsEuclideanDistanceRepeatedQuadraticInjectionFp32AVX(
    const float *lhs, const float *rhs, size_t size, size_t m, float e2) {
  float u2{0.0f};
  float v2{0.0f};
  float sum{0.0f};

  if (size > 7) {
    sum = InnerProductAndSquaredNormFp32AVX(lhs, rhs, size, &u2, &v2);
  } else {
    sum = InnerProductAndSquaredNormFp32SSE(lhs, rhs, size, &u2, &v2);
  }

  sum = e2 * (u2 + v2 - 2 * sum);
  u2 *= e2;
  v2 *= e2;
  for (size_t i = 0; i < m; ++i) {
    sum += (u2 - v2) * (u2 - v2);
    u2 = u2 * u2;
    v2 = v2 * v2;
  }
  return sum;
}
#endif  // __AVX__

}  // namespace ailego
}  // namespace zvec

================================================
FILE: src/ailego/math/mips_euclidean_distance_matrix_fp32_avx512.cc
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "distance_matrix_accum_fp32.i"
#include "distance_matrix_mips_utility.i"
#include "mips_euclidean_distance_matrix.h"

namespace zvec {
namespace ailego {

// Narrower kernels from the SSE / AVX translation units, used below for
// inputs shorter than one 16-float AVX512 register.
#if defined(__SSE__)
float InnerProductAndSquaredNormFp32SSE(const float *lhs, const float *rhs,
                                        size_t size, float *sql, float *sqr);
#endif

#if defined(__AVX__)
float InnerProductAndSquaredNormFp32AVX(const float *lhs, const float *rhs,
                                        size_t size, float *sql, float *sqr);
#endif

#if defined(__AVX512F__)
//! Compute the Inner Product between p and q, and each Squared L2-Norm value
//!
//! \param lhs   First vector, `size` floats.
//! \param rhs   Second vector, `size` floats.
//! \param size  Number of elements in each vector.
//! \param sql   Receives the sum of squares of `lhs`.
//! \param sqr   Receives the sum of squares of `rhs`.
//! \return      The inner product of `lhs` and `rhs`.
float InnerProductAndSquaredNormFp32AVX512(const float *lhs, const float *rhs,
                                           size_t size, float *sql,
                                           float *sqr) {
  const float *last = lhs + size;
  // End of the prefix whose length is a multiple of 32 floats.
  const float *last_aligned = lhs + ((size >> 5) << 5);

  __m512 zmm_sum_0 = _mm512_setzero_ps();
  __m512 zmm_sum_1 = _mm512_setzero_ps();
  __m512 zmm_sum_norm1 = _mm512_setzero_ps();
  __m512 zmm_sum_norm2 = _mm512_setzero_ps();

  // Both pointers 64-byte aligned: aligned loads. Otherwise the else branch
  // runs the identical computation with unaligned loads.
  if (((uintptr_t)lhs & 0x3f) == 0 && ((uintptr_t)rhs & 0x3f) == 0) {
    for (; lhs != last_aligned; lhs += 32, rhs += 32) {
      __m512 zmm_lhs_0 = _mm512_load_ps(lhs + 0);
      __m512 zmm_lhs_1 = _mm512_load_ps(lhs + 16);
      __m512 zmm_rhs_0 = _mm512_load_ps(rhs + 0);
      __m512 zmm_rhs_1 = _mm512_load_ps(rhs + 16);

      FMA_FP32_AVX512(zmm_lhs_0, zmm_rhs_0, zmm_sum_0)
      FMA_FP32_AVX512(zmm_lhs_1, zmm_rhs_1, zmm_sum_1)
      FMA_FP32_AVX512(zmm_lhs_0, zmm_lhs_0, zmm_sum_norm1)
      FMA_FP32_AVX512(zmm_lhs_1, zmm_lhs_1, zmm_sum_norm1)
      FMA_FP32_AVX512(zmm_rhs_0, zmm_rhs_0, zmm_sum_norm2)
      FMA_FP32_AVX512(zmm_rhs_1, zmm_rhs_1, zmm_sum_norm2)
    }

    // 16..31 floats left: consume one more 16-float register.
    if (last >= last_aligned + 16) {
      __m512 zmm_lhs_0 = _mm512_load_ps(lhs);
      __m512 zmm_rhs_0 = _mm512_load_ps(rhs);
      FMA_FP32_AVX512(zmm_lhs_0, zmm_rhs_0, zmm_sum_0)
      FMA_FP32_AVX512(zmm_lhs_0, zmm_lhs_0, zmm_sum_norm1)
      FMA_FP32_AVX512(zmm_rhs_0, zmm_rhs_0, zmm_sum_norm2)
      lhs += 16;
      rhs += 16;
    }
  } else {
    for (; lhs != last_aligned; lhs += 32, rhs += 32) {
      __m512 zmm_lhs_0 = _mm512_loadu_ps(lhs + 0);
      __m512 zmm_lhs_1 = _mm512_loadu_ps(lhs + 16);
      __m512 zmm_rhs_0 = _mm512_loadu_ps(rhs + 0);
      __m512 zmm_rhs_1 = _mm512_loadu_ps(rhs + 16);

      FMA_FP32_AVX512(zmm_lhs_0, zmm_rhs_0, zmm_sum_0)
      FMA_FP32_AVX512(zmm_lhs_1, zmm_rhs_1, zmm_sum_1)
      FMA_FP32_AVX512(zmm_lhs_0, zmm_lhs_0, zmm_sum_norm1)
      FMA_FP32_AVX512(zmm_lhs_1, zmm_lhs_1, zmm_sum_norm1)
      FMA_FP32_AVX512(zmm_rhs_0, zmm_rhs_0, zmm_sum_norm2)
      FMA_FP32_AVX512(zmm_rhs_1, zmm_rhs_1, zmm_sum_norm2)
    }

    if (last >= last_aligned + 16) {
      __m512 zmm_lhs_0 = _mm512_loadu_ps(lhs);
      __m512 zmm_rhs_0 = _mm512_loadu_ps(rhs);
      FMA_FP32_AVX512(zmm_lhs_0, zmm_rhs_0, zmm_sum_0)
      FMA_FP32_AVX512(zmm_lhs_0, zmm_lhs_0, zmm_sum_norm1)
      FMA_FP32_AVX512(zmm_rhs_0, zmm_rhs_0, zmm_sum_norm2)
      lhs += 16;
      rhs += 16;
    }
  }

  zmm_sum_0 = _mm512_add_ps(zmm_sum_0, zmm_sum_1);

  // Remaining 1..15 floats: a masked load whose mask has the low
  // (last - lhs) bits set. Lanes outside the mask are left undefined by the
  // loads; FMA_MASK_FP32_AVX512 (defined elsewhere) receives the same mask,
  // presumably so those lanes do not touch the accumulators — confirm.
  if (lhs != last) {
    __mmask16 mask = (__mmask16)((1 << (last - lhs)) - 1);
    __m512 zmm_undefined = _mm512_undefined_ps();
    __m512 zmm_lhs_0 = _mm512_mask_loadu_ps(zmm_undefined, mask, lhs);
    __m512 zmm_rhs_0 = _mm512_mask_loadu_ps(zmm_undefined, mask, rhs);
    FMA_MASK_FP32_AVX512(zmm_lhs_0, zmm_rhs_0, zmm_sum_0, mask);
    FMA_MASK_FP32_AVX512(zmm_lhs_0, zmm_lhs_0, zmm_sum_norm1, mask);
    FMA_MASK_FP32_AVX512(zmm_rhs_0, zmm_rhs_0, zmm_sum_norm2, mask);
  }

  *sql = HorizontalAdd_FP32_V512(zmm_sum_norm1);
  *sqr = HorizontalAdd_FP32_V512(zmm_sum_norm2);
  return HorizontalAdd_FP32_V512(zmm_sum_0);
}

//! MIPS distance of two FP32 vectors via spherical injection (AVX512 build).
//!
//! Kernel selection by length: at least 16 elements -> AVX512, 8..15 -> AVX,
//! fewer than 8 -> SSE. The three reductions are passed, with `e2`, to
//! ComputeSphericalInjection (defined elsewhere).
float MipsEuclideanDistanceSphericalInjectionFp32AVX512(const float *lhs,
                                                        const float *rhs,
                                                        size_t size, float e2) {
  float u2{0.0f};
  float v2{0.0f};
  float sum{0.0f};

  if (size > 15) {
    sum = InnerProductAndSquaredNormFp32AVX512(lhs, rhs, size, &u2, &v2);
  } else if (size > 7) {
    sum = InnerProductAndSquaredNormFp32AVX(lhs, rhs, size, &u2, &v2);
  } else {
    sum = InnerProductAndSquaredNormFp32SSE(lhs, rhs, size, &u2, &v2);
  }

  return ComputeSphericalInjection(sum, u2, v2, e2);
}

//! MIPS distance of two FP32 vectors via repeated quadratic injection
//! (AVX512 build).
//!
//! Uses the same length-based kernel selection as the spherical variant.
//! Starts from e2 * (u2 + v2 - 2 * <lhs, rhs>). Then, for each of the `m`
//! rounds, it adds (u2 - v2)^2 of the e2-scaled squared norms and squares
//! both of them.
float MipsEuclideanDistanceRepeatedQuadraticInjectionFp32AVX512(
    const float *lhs, const float *rhs, size_t size, size_t m, float e2) {
  float u2{0.0f};
  float v2{0.0f};
  float sum{0.0f};

  if (size > 15) {
    sum = InnerProductAndSquaredNormFp32AVX512(lhs, rhs, size, &u2, &v2);
  } else if (size > 7) {
    sum = InnerProductAndSquaredNormFp32AVX(lhs, rhs, size, &u2, &v2);
  } else {
    sum = InnerProductAndSquaredNormFp32SSE(lhs, rhs, size, &u2, &v2);
  }

  sum = e2 * (u2 + v2 - 2 * sum);
  u2 *= e2;
  v2 *= e2;
  for (size_t i = 0; i < m; ++i) {
    sum += (u2 - v2) * (u2 - v2);
    u2 = u2 * u2;
    v2 = v2 * v2;
  }
  return sum;
}
#endif  // __AVX512F__

}  // namespace ailego
}  // namespace zvec
================================================ FILE: src/ailego/math/mips_euclidean_distance_matrix_fp32_dispatch.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include "mips_euclidean_distance_matrix.h" namespace zvec { namespace ailego { #if defined(__ARM_NEON) float InnerProductAndSquaredNormFp32NEON(const float *lhs, const float *rhs, size_t size, float *sql, float *sqr); #endif #if defined(__AVX512F__) float MipsEuclideanDistanceRepeatedQuadraticInjectionFp32AVX512( const float *lhs, const float *rhs, size_t size, size_t m, float e2); float MipsEuclideanDistanceSphericalInjectionFp32AVX512(const float *lhs, const float *rhs, size_t size, float e2); #endif #if defined(__AVX__) float MipsEuclideanDistanceRepeatedQuadraticInjectionFp32AVX( const float *lhs, const float *rhs, size_t size, size_t m, float e2); float MipsEuclideanDistanceSphericalInjectionFp32AVX(const float *lhs, const float *rhs, size_t size, float e2); #endif #if defined(__SSE__) float MipsEuclideanDistanceRepeatedQuadraticInjectionFp32SSE( const float *lhs, const float *rhs, size_t size, size_t m, float e2); float MipsEuclideanDistanceSphericalInjectionFp32SSE(const float *lhs, const float *rhs, size_t size, float e2); #endif float MipsEuclideanDistanceRepeatedQuadraticInjectionFp32Scalar( const float *p, const float *q, size_t dim, size_t m, float e2); float 
MipsEuclideanDistanceSphericalInjectionFp32Scalar(const float *p, const float *q, size_t dim, float e2); float MipsInnerProductSparseInSegment(uint32_t m_sparse_count, const uint16_t *m_sparse_index, const float *m_sparse_value, uint32_t q_sparse_count, const uint16_t *q_sparse_index, const float *q_sparse_value); //! Compute the distance between matrix and query by SphericalInjection void MipsSquaredEuclideanDistanceMatrix::Compute( const ValueType *p, const ValueType *q, size_t dim, float e2, float *out) { #if __ARM_NEON float u2{0.0f}; float v2{0.0f}; float sum = InnerProductAndSquaredNormFp32NEON(p, q, dim, &u2, &v2); *out = ComputeSphericalInjection(sum, u2, v2, e2); return; #else #if defined(__AVX512F__) if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) { *out = MipsEuclideanDistanceSphericalInjectionFp32AVX512(p, q, dim, e2); return; } #endif //__AVX512F__ #if defined(__AVX__) if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX) { *out = MipsEuclideanDistanceSphericalInjectionFp32AVX(p, q, dim, e2); return; } #endif // __AVX__ #if defined(__SSE__) if (zvec::ailego::internal::CpuFeatures::static_flags_.SSE) { *out = MipsEuclideanDistanceSphericalInjectionFp32SSE(p, q, dim, e2); return; } #endif // __SSE__ *out = MipsEuclideanDistanceSphericalInjectionFp32Scalar(p, q, dim, e2); return; #endif //__ARM_NEON } //! 
Compute the distance between matrix and query by RepeatedQuadraticInjection void MipsSquaredEuclideanDistanceMatrix::Compute( const ValueType *p, const ValueType *q, size_t dim, size_t m, float e2, float *out) { #if defined(__ARM_NEON) float u2{0.0f}; float v2{0.0f}; float sum = InnerProductAndSquaredNormFp32NEON(p, q, dim, &u2, &v2); sum = e2 * (u2 + v2 - 2 * sum); u2 *= e2; v2 *= e2; for (size_t i = 0; i < m; ++i) { sum += (u2 - v2) * (u2 - v2); u2 = u2 * u2; v2 = v2 * v2; } *out = sum; return; #else #if defined(__AVX512F__) if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) { *out = MipsEuclideanDistanceRepeatedQuadraticInjectionFp32AVX512(p, q, dim, m, e2); return; } #endif //__AVX512F__ #if defined(__AVX__) if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX) { *out = MipsEuclideanDistanceRepeatedQuadraticInjectionFp32AVX(p, q, dim, m, e2); return; } #endif // __AVX__ #if defined(__SSE__) if (zvec::ailego::internal::CpuFeatures::static_flags_.SSE) { *out = MipsEuclideanDistanceRepeatedQuadraticInjectionFp32SSE(p, q, dim, m, e2); return; } #endif //__SSE__ *out = MipsEuclideanDistanceRepeatedQuadraticInjectionFp32Scalar(p, q, dim, m, e2); return; #endif //__ARM_NEON } // Sparse #if defined(__SSE4_1__) float MipsInnerProductSparseInSegmentSSE(uint32_t m_sparse_count, const uint16_t *m_sparse_index, const float *m_sparse_value, uint32_t q_sparse_count, const uint16_t *q_sparse_index, const float *q_sparse_value); #endif template <> float MipsSquaredEuclideanSparseDistanceMatrix:: ComputeInnerProductSparseInSegment(uint32_t m_sparse_count, const uint16_t *m_sparse_index, const ValueType *m_sparse_value, uint32_t q_sparse_count, const uint16_t *q_sparse_index, const ValueType *q_sparse_value) { #if defined(__SSE4_1__) return MipsInnerProductSparseInSegmentSSE(m_sparse_count, m_sparse_index, m_sparse_value, q_sparse_count, q_sparse_index, q_sparse_value); #else return MipsInnerProductSparseInSegment(m_sparse_count, m_sparse_index, m_sparse_value, 
q_sparse_count, q_sparse_index, q_sparse_value); #endif } } // namespace ailego } // namespace zvec ================================================ FILE: src/ailego/math/mips_euclidean_distance_matrix_fp32_neon.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "distance_matrix_accum_fp32.i" #include "distance_matrix_mips_utility.i" #include "mips_euclidean_distance_matrix.h" namespace zvec { namespace ailego { #if defined(__ARM_NEON) //! 
Compute the Inner Product between p and q, and each Squared L2-Norm value float InnerProductAndSquaredNormFp32NEON(const float *lhs, const float *rhs, size_t size, float *sql, float *sqr) { const float *last = lhs + size; const float *last_aligned = lhs + ((size >> 3) << 3); float32x4_t v_sum_0 = vdupq_n_f32(0); float32x4_t v_sum_1 = vdupq_n_f32(0); float32x4_t v_sum_norm1 = vdupq_n_f32(0); float32x4_t v_sum_norm2 = vdupq_n_f32(0); for (; lhs != last_aligned; lhs += 8, rhs += 8) { float32x4_t v_lhs_0 = vld1q_f32(lhs + 0); float32x4_t v_lhs_1 = vld1q_f32(lhs + 4); float32x4_t v_rhs_0 = vld1q_f32(rhs + 0); float32x4_t v_rhs_1 = vld1q_f32(rhs + 4); v_sum_0 = vfmaq_f32(v_sum_0, v_lhs_0, v_rhs_0); v_sum_1 = vfmaq_f32(v_sum_1, v_lhs_1, v_rhs_1); v_sum_norm1 = vfmaq_f32(v_sum_norm1, v_lhs_0, v_lhs_0); v_sum_norm1 = vfmaq_f32(v_sum_norm1, v_lhs_1, v_lhs_1); v_sum_norm2 = vfmaq_f32(v_sum_norm2, v_rhs_0, v_rhs_0); v_sum_norm2 = vfmaq_f32(v_sum_norm2, v_rhs_1, v_rhs_1); } if (last >= last_aligned + 4) { float32x4_t v_lhs_0 = vld1q_f32(lhs); float32x4_t v_rhs_0 = vld1q_f32(rhs); v_sum_0 = vfmaq_f32(v_sum_0, v_lhs_0, v_rhs_0); v_sum_norm1 = vfmaq_f32(v_sum_norm1, v_lhs_0, v_lhs_0); v_sum_norm2 = vfmaq_f32(v_sum_norm2, v_rhs_0, v_rhs_0); lhs += 4; rhs += 4; } float result = vaddvq_f32(vaddq_f32(v_sum_0, v_sum_1)); float norm1 = vaddvq_f32(v_sum_norm1); float norm2 = vaddvq_f32(v_sum_norm2); switch (last - lhs) { case 3: FMA_FP32_GENERAL(lhs[2], rhs[2], result, norm1, norm2) /* FALLTHRU */ case 2: FMA_FP32_GENERAL(lhs[1], rhs[1], result, norm1, norm2) /* FALLTHRU */ case 1: FMA_FP32_GENERAL(lhs[0], rhs[0], result, norm1, norm2) } *sql = norm1; *sqr = norm2; return result; } #endif //__ARM_NEON } // namespace ailego } // namespace zvec ================================================ FILE: src/ailego/math/mips_euclidean_distance_matrix_fp32_sse.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, 
Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "distance_matrix_accum_fp32.i" #include "distance_matrix_mips_utility.i" #include "mips_euclidean_distance_matrix.h" namespace zvec { namespace ailego { #if defined(__SSE__) //! Compute the Inner Product between p and q, and each Squared L2-Norm value float InnerProductAndSquaredNormFp32SSE(const float *lhs, const float *rhs, size_t size, float *sql, float *sqr) { const float *last = lhs + size; const float *last_aligned = lhs + ((size >> 3) << 3); __m128 xmm_sum = _mm_setzero_ps(); __m128 xmm_sum_norm1 = _mm_setzero_ps(); __m128 xmm_sum_norm2 = _mm_setzero_ps(); if (((uintptr_t)lhs & 0xf) == 0 && ((uintptr_t)rhs & 0xf) == 0) { for (; lhs != last_aligned; lhs += 8, rhs += 8) { __m128 xmm_lhs_0 = _mm_load_ps(lhs + 0); __m128 xmm_lhs_1 = _mm_load_ps(lhs + 4); __m128 xmm_rhs_0 = _mm_load_ps(rhs + 0); __m128 xmm_rhs_1 = _mm_load_ps(rhs + 4); xmm_sum = _mm_fmadd_ps(xmm_lhs_0, xmm_rhs_0, xmm_sum); xmm_sum = _mm_fmadd_ps(xmm_lhs_1, xmm_rhs_1, xmm_sum); xmm_sum_norm1 = _mm_fmadd_ps(xmm_lhs_0, xmm_lhs_0, xmm_sum_norm1); xmm_sum_norm1 = _mm_fmadd_ps(xmm_lhs_1, xmm_lhs_1, xmm_sum_norm1); xmm_sum_norm2 = _mm_fmadd_ps(xmm_rhs_0, xmm_rhs_0, xmm_sum_norm2); xmm_sum_norm2 = _mm_fmadd_ps(xmm_rhs_1, xmm_rhs_1, xmm_sum_norm2); } if (last >= last_aligned + 4) { __m128 xmm_lhs_0 = _mm_load_ps(lhs); __m128 xmm_rhs_0 = _mm_load_ps(rhs); xmm_sum = _mm_fmadd_ps(xmm_lhs_0, xmm_rhs_0, xmm_sum); xmm_sum_norm1 = _mm_fmadd_ps(xmm_lhs_0, xmm_lhs_0, xmm_sum_norm1); xmm_sum_norm2 = 
_mm_fmadd_ps(xmm_rhs_0, xmm_rhs_0, xmm_sum_norm2); lhs += 4; rhs += 4; } } else { for (; lhs != last_aligned; lhs += 8, rhs += 8) { __m128 xmm_lhs_0 = _mm_loadu_ps(lhs + 0); __m128 xmm_lhs_1 = _mm_loadu_ps(lhs + 4); __m128 xmm_rhs_0 = _mm_loadu_ps(rhs + 0); __m128 xmm_rhs_1 = _mm_loadu_ps(rhs + 4); xmm_sum = _mm_fmadd_ps(xmm_lhs_0, xmm_rhs_0, xmm_sum); xmm_sum = _mm_fmadd_ps(xmm_lhs_1, xmm_rhs_1, xmm_sum); xmm_sum_norm1 = _mm_fmadd_ps(xmm_lhs_0, xmm_lhs_0, xmm_sum_norm1); xmm_sum_norm1 = _mm_fmadd_ps(xmm_lhs_1, xmm_lhs_1, xmm_sum_norm1); xmm_sum_norm2 = _mm_fmadd_ps(xmm_rhs_0, xmm_rhs_0, xmm_sum_norm2); xmm_sum_norm2 = _mm_fmadd_ps(xmm_rhs_1, xmm_rhs_1, xmm_sum_norm2); } if (last >= last_aligned + 4) { __m128 xmm_lhs_0 = _mm_loadu_ps(lhs); __m128 xmm_rhs_0 = _mm_loadu_ps(rhs); xmm_sum = _mm_fmadd_ps(xmm_lhs_0, xmm_rhs_0, xmm_sum); xmm_sum_norm1 = _mm_fmadd_ps(xmm_lhs_0, xmm_lhs_0, xmm_sum_norm1); xmm_sum_norm2 = _mm_fmadd_ps(xmm_rhs_0, xmm_rhs_0, xmm_sum_norm2); lhs += 4; rhs += 4; } } float result = HorizontalAdd_FP32_V128(xmm_sum); float norm1 = HorizontalAdd_FP32_V128(xmm_sum_norm1); float norm2 = HorizontalAdd_FP32_V128(xmm_sum_norm2); switch (last - lhs) { case 3: FMA_FP32_GENERAL(lhs[2], rhs[2], result, norm1, norm2) /* FALLTHRU */ case 2: FMA_FP32_GENERAL(lhs[1], rhs[1], result, norm1, norm2) /* FALLTHRU */ case 1: FMA_FP32_GENERAL(lhs[0], rhs[0], result, norm1, norm2) } *sql = norm1; *sqr = norm2; return result; } float MipsEuclideanDistanceSphericalInjectionFp32SSE(const float *lhs, const float *rhs, size_t size, float e2) { float u2{0.0f}; float v2{0.0f}; float sum{0.0f}; sum = InnerProductAndSquaredNormFp32SSE(lhs, rhs, size, &u2, &v2); return ComputeSphericalInjection(sum, u2, v2, e2); } float MipsEuclideanDistanceRepeatedQuadraticInjectionFp32SSE( const float *lhs, const float *rhs, size_t size, size_t m, float e2) { float u2{0.0f}; float v2{0.0f}; float sum{0.0f}; sum = InnerProductAndSquaredNormFp32SSE(lhs, rhs, size, &u2, &v2); sum = e2 * (u2 + v2 - 
2 * sum); u2 *= e2; v2 *= e2; for (size_t i = 0; i < m; ++i) { sum += (u2 - v2) * (u2 - v2); u2 = u2 * u2; v2 = v2 * v2; } return sum; } #endif // __SSE__ // #if 1 #if defined(__SSE4_1__) const static __m128i SHUFFLE_MASK16[16] = { _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 7, 6, 5, 4), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 7, 6, 5, 4, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 9, 8), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 9, 8, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 11, 10, 9, 8, 7, 6, 5, 4), _mm_set_epi8(-127, -127, -127, -127, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 7, 6, 5, 4), _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 7, 6, 5, 4, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, -127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 9, 8), _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 3, 2, 1, 0), _mm_set_epi8(-127, -127, -127, -127, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4), _mm_set_epi8(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0), }; constexpr uint32_t MAX_SPARSE_BUFFER_LENGTH = 65536; float MipsInnerProductSparseInSegmentSSE(uint32_t m_sparse_count, const uint16_t *m_sparse_index, const float *m_sparse_value, uint32_t q_sparse_count, const uint16_t *q_sparse_index, const float *q_sparse_value) { float sum = 0.0f; // size_t alloc_size = 0; size_t i1 = 0, i2 = 0; 
size_t end1 = m_sparse_count / 8 * 8; size_t end2 = q_sparse_count / 8 * 8; // std::vector mem1; // std::vector mem2; float fixed_buffer_1[MAX_SPARSE_BUFFER_LENGTH]; float fixed_buffer_2[MAX_SPARSE_BUFFER_LENGTH]; float *val_start_1 = fixed_buffer_1; float *val_start_2 = fixed_buffer_2; // uint32_t max_count = std::max(m_sparse_count, q_sparse_count); // if (MAX_SPARSE_BUFFER_LENGTH < max_count) { // mem1.reserve(max_count); // mem2.reserve(max_count); // val_start_1 = mem1.data(); // val_start_2 = mem2.data(); // } float *val_1 = val_start_1; float *val_2 = val_start_2; if (i1 < end1 && i2 < end2) { while (m_sparse_index[i1 + 7] < q_sparse_index[i2]) { i1 += 8; if (i1 >= end1) goto do_scalar; } while (q_sparse_index[i2 + 7] < m_sparse_index[i1]) { i2 += 8; if (i2 >= end2) goto do_scalar; } __m128i mm_index_m = _mm_loadu_si128(reinterpret_cast(&m_sparse_index[i1])); __m128i mm_index_q = _mm_loadu_si128(reinterpret_cast(&q_sparse_index[i2])); while (true) { #ifdef DEBUG_PRINT std::cout << "index 1: " << std::endl; print_data16(&mm_index_m); std::cout << "index 2: " << std::endl; print_data16(&mm_index_q); #endif __m128i mm_cmp_res = _mm_cmpistrm(mm_index_q, mm_index_m, _SIDD_UWORD_OPS | _SIDD_CMP_EQUAL_ANY | _SIDD_BIT_MASK); #ifdef DEBUG_PRINT std::cout << "cmp res: " << std::endl; print_data16(&mm_cmp_res); #endif int r = _mm_extract_epi32(mm_cmp_res, 0); if (r) { int r1 = r & 15; __m128i v = _mm_loadu_si128( reinterpret_cast(&m_sparse_value[i1])); __m128 vs = _mm_castsi128_ps(_mm_shuffle_epi8(v, SHUFFLE_MASK16[r1])); _mm_storeu_ps(val_1, vs); val_1 += _mm_popcnt_u32(r1); int r2 = (r >> 4) & 15; v = _mm_loadu_si128( reinterpret_cast(&m_sparse_value[i1 + 4])); vs = _mm_castsi128_ps(_mm_shuffle_epi8(v, SHUFFLE_MASK16[r2])); _mm_storeu_ps(val_1, vs); val_1 += _mm_popcnt_u32(r2); mm_cmp_res = _mm_cmpistrm( mm_index_m, mm_index_q, _SIDD_UWORD_OPS | _SIDD_CMP_EQUAL_ANY | _SIDD_BIT_MASK); r = _mm_extract_epi32(mm_cmp_res, 0); r1 = r & 15; v = _mm_loadu_si128( 
reinterpret_cast(&q_sparse_value[i2])); vs = _mm_castsi128_ps(_mm_shuffle_epi8(v, SHUFFLE_MASK16[r1])); _mm_storeu_ps(val_2, vs); val_2 += _mm_popcnt_u32(r1); r2 = (r >> 4) & 15; v = _mm_loadu_si128( reinterpret_cast(&q_sparse_value[i2 + 4])); vs = _mm_castsi128_ps(_mm_shuffle_epi8(v, SHUFFLE_MASK16[r2])); _mm_storeu_ps(val_2, vs); val_2 += _mm_popcnt_u32(r2); } const uint16_t id1_max = m_sparse_index[i1 + 7]; if (id1_max <= q_sparse_index[i2 + 7]) { i1 += 8; if (i1 >= end1) goto do_scalar; mm_index_m = _mm_loadu_si128( reinterpret_cast(&m_sparse_index[i1])); } if (id1_max >= q_sparse_index[i2 + 7]) { i2 += 8; if (i2 >= end2) goto do_scalar; mm_index_q = _mm_loadu_si128( reinterpret_cast(&q_sparse_index[i2])); } } } do_scalar: while (i1 < m_sparse_count && i2 < q_sparse_count) { if (m_sparse_index[i1] == q_sparse_index[i2]) { *val_1++ = m_sparse_value[i1]; *val_2++ = q_sparse_value[i2]; ++i1; ++i2; } else if (m_sparse_index[i1] < q_sparse_index[i2]) { ++i1; } else { ++i2; } } size_t res_num = val_1 - val_start_1; // if (res_num != val_2 - val_start_2) { // std::cerr << "size mismatch!" 
<< std::endl; // } size_t res_num4 = res_num / 4 * 4; if (res_num4) { __m128 sum128 = _mm_set1_ps(0); for (size_t k = 0; k < res_num4; k += 4) { sum128 = _mm_add_ps(sum128, _mm_mul_ps(_mm_loadu_ps(val_start_1 + k), _mm_loadu_ps(val_start_2 + k))); } float __attribute__((aligned(16))) tmp_res[4]; _mm_store_ps(tmp_res, sum128); sum += (tmp_res[0] + tmp_res[1] + tmp_res[2] + tmp_res[3]); } for (size_t k = res_num4; k < res_num; ++k) sum += val_start_1[k] * val_start_2[k]; return sum; } #else float MipsInnerProductSparseInSegment(uint32_t m_sparse_count, const uint16_t *m_sparse_index, const float *m_sparse_value, uint32_t q_sparse_count, const uint16_t *q_sparse_index, const float *q_sparse_value) { float sum = 0.0f; size_t m_i = 0; size_t q_i = 0; while (m_i < m_sparse_count && q_i < q_sparse_count) { if (m_sparse_index[m_i] == q_sparse_index[q_i]) { sum += m_sparse_value[m_i] * q_sparse_value[q_i]; ++m_i; ++q_i; } else if (m_sparse_index[m_i] < q_sparse_index[q_i]) { ++m_i; } else { ++q_i; } } return sum; } #endif // __SSE4_1__ } // namespace ailego } // namespace zvec ================================================ FILE: src/ailego/math/mips_euclidean_distance_matrix_int4_avx2.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "distance_matrix_accum_int8.i"
#include "distance_matrix_mips_utility.i"
#include "inner_product_matrix.h"
#include "mips_euclidean_distance_matrix.h"
#include "norm_matrix.h"

namespace zvec {
namespace ailego {

#if defined(__AVX2__)
//! Compute the Inner Product between p and q, and each Squared L2-Norm value
//!
//! Works on packed int4 data. `size` is a byte count (the wrappers below pass
//! `dim >> 1`); each byte presumably holds two 4-bit values that the
//! FMA_INT4_* macros (defined elsewhere) unpack. Sums are kept in 32-bit
//! integer lanes and converted to float only at the end.
//!
//! \param lhs   First packed vector, `size` bytes.
//! \param rhs   Second packed vector, `size` bytes.
//! \param size  Number of bytes in each vector.
//! \param sql   Receives the squared norm of `lhs`.
//! \param sqr   Receives the squared norm of `rhs`.
//! \return      The inner product of `lhs` and `rhs`.
float InnerProductAndSquaredNormInt4AVX2(const uint8_t *lhs, const uint8_t *rhs,
                                         size_t size, float *sql, float *sqr) {
  const uint8_t *last = lhs + size;
  // End of the prefix whose length is a multiple of 32 bytes.
  const uint8_t *last_aligned = lhs + ((size >> 5) << 5);

  __m256i ymm_sum_0 = _mm256_setzero_si256();
  __m256i ymm_sum_1 = _mm256_setzero_si256();
  __m256i ymm_sum_norm1 = _mm256_setzero_si256();
  __m256i ymm_sum_norm2 = _mm256_setzero_si256();

  // Both pointers 32-byte aligned: aligned loads. Otherwise the else branch
  // runs the identical computation with unaligned loads.
  if (((uintptr_t)lhs & 0x1f) == 0 && ((uintptr_t)rhs & 0x1f) == 0) {
    for (; lhs != last_aligned; lhs += 32, rhs += 32) {
      __m256i ymm_lhs = _mm256_load_si256((const __m256i *)(lhs));
      __m256i ymm_rhs = _mm256_load_si256((const __m256i *)(rhs));

      // Pass the declared second accumulator; it is folded into the result
      // below together with ymm_sum_0.
      FMA_INT4_ITER_AVX(ymm_lhs, ymm_rhs, ymm_sum_0, ymm_sum_1, ymm_sum_norm1,
                        ymm_sum_norm2)
    }

    // 16..31 bytes left: one 128-bit step, widened into the low half of the
    // 256-bit accumulators.
    if (last >= lhs + 16) {
      __m128i xmm_lhs = _mm_load_si128((const __m128i *)lhs);
      __m128i xmm_rhs = _mm_load_si128((const __m128i *)rhs);
      __m128i xmm_sum = _mm_setzero_si128();
      __m128i xmm_sum_norm1 = _mm_setzero_si128();
      __m128i xmm_sum_norm2 = _mm_setzero_si128();

      FMA_INT4_ITER_SSE(xmm_lhs, xmm_rhs, xmm_sum, xmm_sum_norm1, xmm_sum_norm2)

      ymm_sum_0 = _mm256_add_epi32(
          _mm256_set_m128i(_mm_setzero_si128(), xmm_sum), ymm_sum_0);
      ymm_sum_norm1 = _mm256_add_epi32(
          _mm256_set_m128i(_mm_setzero_si128(), xmm_sum_norm1), ymm_sum_norm1);
      ymm_sum_norm2 = _mm256_add_epi32(
          _mm256_set_m128i(_mm_setzero_si128(), xmm_sum_norm2), ymm_sum_norm2);
      lhs += 16;
      rhs += 16;
    }
  } else {
    for (; lhs != last_aligned; lhs += 32, rhs += 32) {
      __m256i ymm_lhs = _mm256_loadu_si256((const __m256i *)(lhs));
      __m256i ymm_rhs = _mm256_loadu_si256((const __m256i *)(rhs));

      FMA_INT4_ITER_AVX(ymm_lhs, ymm_rhs, ymm_sum_0, ymm_sum_1, ymm_sum_norm1,
                        ymm_sum_norm2)
    }

    if (last >= lhs + 16) {
      __m128i xmm_lhs = _mm_loadu_si128((const __m128i *)lhs);
      __m128i xmm_rhs = _mm_loadu_si128((const __m128i *)rhs);
      __m128i xmm_sum = _mm_setzero_si128();
      __m128i xmm_sum_norm1 = _mm_setzero_si128();
      __m128i xmm_sum_norm2 = _mm_setzero_si128();

      FMA_INT4_ITER_SSE(xmm_lhs, xmm_rhs, xmm_sum, xmm_sum_norm1, xmm_sum_norm2)

      ymm_sum_0 = _mm256_add_epi32(
          _mm256_set_m128i(_mm_setzero_si128(), xmm_sum), ymm_sum_0);
      ymm_sum_norm1 = _mm256_add_epi32(
          _mm256_set_m128i(_mm_setzero_si128(), xmm_sum_norm1), ymm_sum_norm1);
      ymm_sum_norm2 = _mm256_add_epi32(
          _mm256_set_m128i(_mm_setzero_si128(), xmm_sum_norm2), ymm_sum_norm2);
      lhs += 16;
      rhs += 16;
    }
  }
  float result = static_cast<float>(
      HorizontalAdd_INT32_V256(_mm256_add_epi32(ymm_sum_0, ymm_sum_1)));
  float norm1 = static_cast<float>(HorizontalAdd_INT32_V256(ymm_sum_norm1));
  float norm2 = static_cast<float>(HorizontalAdd_INT32_V256(ymm_sum_norm2));

  // Remaining 0..15 bytes, one at a time.
  switch (last - lhs) {
    case 15:
      FMA_INT4_GENERAL(lhs[14], rhs[14], result, norm1, norm2)
      /* FALLTHRU */
    case 14:
      FMA_INT4_GENERAL(lhs[13], rhs[13], result, norm1, norm2)
      /* FALLTHRU */
    case 13:
      FMA_INT4_GENERAL(lhs[12], rhs[12], result, norm1, norm2)
      /* FALLTHRU */
    case 12:
      FMA_INT4_GENERAL(lhs[11], rhs[11], result, norm1, norm2)
      /* FALLTHRU */
    case 11:
      FMA_INT4_GENERAL(lhs[10], rhs[10], result, norm1, norm2)
      /* FALLTHRU */
    case 10:
      FMA_INT4_GENERAL(lhs[9], rhs[9], result, norm1, norm2)
      /* FALLTHRU */
    case 9:
      FMA_INT4_GENERAL(lhs[8], rhs[8], result, norm1, norm2)
      /* FALLTHRU */
    case 8:
      FMA_INT4_GENERAL(lhs[7], rhs[7], result, norm1, norm2)
      /* FALLTHRU */
    case 7:
      FMA_INT4_GENERAL(lhs[6], rhs[6], result, norm1, norm2)
      /* FALLTHRU */
    case 6:
      FMA_INT4_GENERAL(lhs[5], rhs[5], result, norm1, norm2)
      /* FALLTHRU */
    case 5:
      FMA_INT4_GENERAL(lhs[4], rhs[4], result, norm1, norm2)
      /* FALLTHRU */
    case 4:
      FMA_INT4_GENERAL(lhs[3], rhs[3], result, norm1, norm2)
      /* FALLTHRU */
    case 3:
      FMA_INT4_GENERAL(lhs[2], rhs[2], result, norm1, norm2)
      /* FALLTHRU */
    case 2:
      FMA_INT4_GENERAL(lhs[1], rhs[1], result, norm1, norm2)
      /* FALLTHRU */
    case 1:
      FMA_INT4_GENERAL(lhs[0], rhs[0], result, norm1, norm2)
  }
  *sql = norm1;
  *sqr = norm2;
  return result;
}

//! MIPS distance of two int4 vectors via spherical injection (AVX2).
//!
//! `size` is the element count; `size >> 1` bytes are processed. This
//! presumably assumes two elements per byte and an even `size` — confirm.
float MipsEuclideanDistanceSphericalInjectionInt4AVX2(const uint8_t *lhs,
                                                      const uint8_t *rhs,
                                                      size_t size, float e2) {
  float u2{0.0f};
  float v2{0.0f};
  float sum{0.0f};

  sum = InnerProductAndSquaredNormInt4AVX2(lhs, rhs, size >> 1, &u2, &v2);

  return ComputeSphericalInjection(sum, u2, v2, e2);
}

//! MIPS distance of two int4 vectors via repeated quadratic injection (AVX2).
//!
//! Starts from e2 * (u2 + v2 - 2 * <lhs, rhs>). Then, for each of the `m`
//! rounds, it adds (u2 - v2)^2 of the e2-scaled squared norms and squares
//! both of them.
float MipsEuclideanDistanceRepeatedQuadraticInjectionInt4AVX2(
    const uint8_t *lhs, const uint8_t *rhs, size_t size, size_t m, float e2) {
  float u2{0.0f};
  float v2{0.0f};
  float sum{0.0f};

  sum = InnerProductAndSquaredNormInt4AVX2(lhs, rhs, size >> 1, &u2, &v2);

  sum = e2 * (u2 + v2 - 2 * sum);
  u2 *= e2;
  v2 *= e2;
  for (size_t i = 0; i < m; ++i) {
    sum += (u2 - v2) * (u2 - v2);
    u2 = u2 * u2;
    v2 = v2 * v2;
  }
  return sum;
}
#endif  // __AVX2__

}  // namespace ailego
}  // namespace zvec

================================================
FILE: src/ailego/math/mips_euclidean_distance_matrix_int4_dispatch.cc
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// NOTE(review): this span comes from a flattened repository dump in which
// text of the form `<identifier...>` appears to have been stripped: the bare
// `#include` below, `static_cast(` without a target type, and
// `MipsSquaredEuclideanDistanceMatrix::Compute` without template arguments.
// Restore those from the real sources; the code tokens are kept as found.
#include
#include "inner_product_matrix.h"
#include "mips_euclidean_distance_matrix.h"
#include "norm_matrix.h"

namespace zvec {
namespace ailego {

// Per-ISA INT4 kernels, defined in sibling translation units. The SIMD
// variants are declared only when the matching instruction set is enabled
// for this translation unit.
#if defined(__AVX2__)
float MipsEuclideanDistanceRepeatedQuadraticInjectionInt4AVX2(
    const uint8_t *lhs, const uint8_t *rhs, size_t size, size_t m, float e2);
float MipsEuclideanDistanceSphericalInjectionInt4AVX2(const uint8_t *lhs,
                                                      const uint8_t *rhs,
                                                      size_t size, float e2);
#endif

#if defined(__SSE4_1__)
float MipsEuclideanDistanceRepeatedQuadraticInjectionInt4SSE(
    const uint8_t *lhs, const uint8_t *rhs, size_t size, size_t m, float e2);
float MipsEuclideanDistanceSphericalInjectionInt4SSE(const uint8_t *lhs,
                                                     const uint8_t *rhs,
                                                     size_t size, float e2);
#endif

// Portable fallbacks; always declared and always used as the last resort.
float MipsEuclideanDistanceRepeatedQuadraticInjectionInt4Scalar(
    const uint8_t *lhs, const uint8_t *rhs, size_t size, size_t m, float e2);
float MipsEuclideanDistanceSphericalInjectionInt4Scalar(const uint8_t *lhs,
                                                        const uint8_t *rhs,
                                                        size_t size, float e2);

//! Compute the distance between matrix and query by SphericalInjection
//!
//! Uses the first kernel that is both compiled in (preprocessor guard) and
//! reported by CpuFeatures::static_flags_ at runtime, in the order
//! AVX2 -> SSE4.1 -> scalar, and stores the single result in `*out`.
//! NOTE(review): the class template arguments were lost in this extract; the
//! uint8_t kernels called below suggest the packed-INT4 specialization —
//! confirm against mips_euclidean_distance_matrix.h.
void MipsSquaredEuclideanDistanceMatrix::Compute(
    const ValueType *p, const ValueType *q, size_t dim, float e2, float *out) {
#if defined(__AVX2__)
  if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX2) {
    *out = MipsEuclideanDistanceSphericalInjectionInt4AVX2(p, q, dim, e2);
    return;
  }
#endif
#if defined(__SSE4_1__)
  if (zvec::ailego::internal::CpuFeatures::static_flags_.SSE4_1) {
    *out = MipsEuclideanDistanceSphericalInjectionInt4SSE(p, q, dim, e2);
    return;
  }
#endif
  *out = MipsEuclideanDistanceSphericalInjectionInt4Scalar(p, q, dim, e2);
}

//! Compute the distance between matrix and query by RepeatedQuadraticInjection
//!
//! Same AVX2 -> SSE4.1 -> scalar selection as the SphericalInjection
//! overload; `m` (number of injection rounds) is forwarded unchanged.
void MipsSquaredEuclideanDistanceMatrix::Compute(
    const ValueType *p, const ValueType *q, size_t dim, size_t m, float e2,
    float *out) {
#if defined(__AVX2__)
  if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX2) {
    *out =
        MipsEuclideanDistanceRepeatedQuadraticInjectionInt4AVX2(p, q, dim, m, e2);
    return;
  }
#endif
#if defined(__SSE4_1__)
  if (zvec::ailego::internal::CpuFeatures::static_flags_.SSE4_1) {
    *out =
        MipsEuclideanDistanceRepeatedQuadraticInjectionInt4SSE(p, q, dim, m, e2);
    return;
  }
#endif
  *out =
      MipsEuclideanDistanceRepeatedQuadraticInjectionInt4Scalar(p, q, dim, m, e2);
}

}  // namespace ailego
}  // namespace zvec
================================================
FILE: src/ailego/math/mips_euclidean_distance_matrix_int4_sse.cc
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "distance_matrix_accum_int8.i"
#include "distance_matrix_mips_utility.i"
#include "inner_product_matrix.h"
#include "mips_euclidean_distance_matrix.h"
#include "norm_matrix.h"

namespace zvec {
namespace ailego {

#if defined(__SSE4_1__)
//! Compute the Inner Product between p and q, and each Squared L2-Norm value
//!
//! `size` is a byte count. The INT4 wrappers below pass their element count
//! shifted right by one, i.e. two 4-bit values per byte.
//! Returns the inner product and writes the squared norms of lhs / rhs to
//! `*sql` / `*sqr`. The work is done 16 bytes at a time by FMA_INT4_ITER_SSE
//! and byte-by-byte by FMA_INT4_GENERAL for the tail; both macros are defined
//! outside this file.
//! NOTE(review): `static_cast(` lost its target type in this extract
//! (presumably float, given the destination variables).
float InnerProductAndSquaredNormInt4SSE(const uint8_t *lhs, const uint8_t *rhs,
                                        size_t size, float *sql, float *sqr) {
  const uint8_t *last = lhs + size;
  // End of the prefix that is a whole number of 16-byte vectors.
  const uint8_t *last_aligned = lhs + ((size >> 4) << 4);

  // Integer accumulators: inner product, |lhs|^2 and |rhs|^2.
  __m128i xmm_sum = _mm_setzero_si128();
  __m128i xmm_sum_norm1 = _mm_setzero_si128();
  __m128i xmm_sum_norm2 = _mm_setzero_si128();

  // Aligned loads only when both inputs start on a 16-byte boundary; the two
  // loops are otherwise identical.
  if (((uintptr_t)lhs & 0xf) == 0 && ((uintptr_t)rhs & 0xf) == 0) {
    for (; lhs != last_aligned; lhs += 16, rhs += 16) {
      __m128i xmm_lhs = _mm_load_si128((const __m128i *)(lhs));
      __m128i xmm_rhs = _mm_load_si128((const __m128i *)(rhs));
      FMA_INT4_ITER_SSE(xmm_lhs, xmm_rhs, xmm_sum, xmm_sum_norm1,
                        xmm_sum_norm2)
    }
  } else {
    for (; lhs != last_aligned; lhs += 16, rhs += 16) {
      __m128i xmm_lhs = _mm_loadu_si128((const __m128i *)(lhs));
      __m128i xmm_rhs = _mm_loadu_si128((const __m128i *)(rhs));
      FMA_INT4_ITER_SSE(xmm_lhs, xmm_rhs, xmm_sum, xmm_sum_norm1,
                        xmm_sum_norm2)
    }
  }
  float result = static_cast(HorizontalAdd_INT32_V128(xmm_sum));
  float norm1 = static_cast(HorizontalAdd_INT32_V128(xmm_sum_norm1));
  float norm2 = static_cast(HorizontalAdd_INT32_V128(xmm_sum_norm2));

  // Scalar tail for the remaining 0..15 bytes; every case intentionally
  // falls through so that all remaining bytes are accumulated.
  switch (last - lhs) {
    case 15:
      FMA_INT4_GENERAL(lhs[14], rhs[14], result, norm1, norm2)
      /* FALLTHRU */
    case 14:
      FMA_INT4_GENERAL(lhs[13], rhs[13], result, norm1, norm2)
      /* FALLTHRU */
    case 13:
      FMA_INT4_GENERAL(lhs[12], rhs[12], result, norm1, norm2)
      /* FALLTHRU */
    case 12:
      FMA_INT4_GENERAL(lhs[11], rhs[11], result, norm1, norm2)
      /* FALLTHRU */
    case 11:
      FMA_INT4_GENERAL(lhs[10], rhs[10], result, norm1, norm2)
      /* FALLTHRU */
    case 10:
      FMA_INT4_GENERAL(lhs[9], rhs[9], result, norm1, norm2)
      /* FALLTHRU */
    case 9:
      FMA_INT4_GENERAL(lhs[8], rhs[8], result, norm1, norm2)
      /* FALLTHRU */
    case 8:
      FMA_INT4_GENERAL(lhs[7], rhs[7], result, norm1, norm2)
      /* FALLTHRU */
    case 7:
      FMA_INT4_GENERAL(lhs[6], rhs[6], result, norm1, norm2)
      /* FALLTHRU */
    case 6:
      FMA_INT4_GENERAL(lhs[5], rhs[5], result, norm1, norm2)
      /* FALLTHRU */
    case 5:
      FMA_INT4_GENERAL(lhs[4], rhs[4], result, norm1, norm2)
      /* FALLTHRU */
    case 4:
      FMA_INT4_GENERAL(lhs[3], rhs[3], result, norm1, norm2)
      /* FALLTHRU */
    case 3:
      FMA_INT4_GENERAL(lhs[2], rhs[2], result, norm1, norm2)
      /* FALLTHRU */
    case 2:
      FMA_INT4_GENERAL(lhs[1], rhs[1], result, norm1, norm2)
      /* FALLTHRU */
    case 1:
      FMA_INT4_GENERAL(lhs[0], rhs[0], result, norm1, norm2)
  }
  *sql = norm1;
  *sqr = norm2;
  return result;
}

//! MIPS distance via SphericalInjection for packed INT4 vectors (SSE4.1).
//!
//! `size` is the number of 4-bit elements; the byte kernel receives
//! `size >> 1`. The final transform of (inner product, |lhs|^2, |rhs|^2, e2)
//! is delegated to ComputeSphericalInjection, which is defined outside this
//! file.
float MipsEuclideanDistanceSphericalInjectionInt4SSE(const uint8_t *lhs,
                                                     const uint8_t *rhs,
                                                     size_t size, float e2) {
  float u2{0.0f};
  float v2{0.0f};
  float sum{0.0f};

  sum = InnerProductAndSquaredNormInt4SSE(lhs, rhs, size >> 1, &u2, &v2);
  return ComputeSphericalInjection(sum, u2, v2, e2);
}

//! MIPS distance via RepeatedQuadraticInjection for packed INT4 vectors.
//!
//! Starts from e2 * |lhs - rhs|^2, expanded as
//! e2 * (|lhs|^2 + |rhs|^2 - 2 * <lhs, rhs>). Then, for each of the `m`
//! rounds, adds (U - V)^2 and squares U and V, where U and V start as the
//! e2-scaled squared norms.
float MipsEuclideanDistanceRepeatedQuadraticInjectionInt4SSE(
    const uint8_t *lhs, const uint8_t *rhs, size_t size, size_t m, float e2) {
  float u2{0.0f};
  float v2{0.0f};
  float sum{0.0f};

  sum = InnerProductAndSquaredNormInt4SSE(lhs, rhs, size >> 1, &u2, &v2);

  sum = e2 * (u2 + v2 - 2 * sum);
  u2 *= e2;
  v2 *= e2;
  for (size_t i = 0; i < m; ++i) {
    sum += (u2 - v2) * (u2 - v2);
    u2 = u2 * u2;
    v2 = v2 * v2;
  }
  return sum;
}
#endif  // __SSE4_1__

}  // namespace ailego
}  // namespace zvec
================================================
FILE: src/ailego/math/mips_euclidean_distance_matrix_int8_avx2.cc
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "distance_matrix_accum_int8.i"
#include "distance_matrix_mips_utility.i"
#include "mips_euclidean_distance_matrix.h"

namespace zvec {
namespace ailego {

#if defined(__AVX2__)
//! Compute the Inner Product between p and q, and each Squared L2-Norm value
//!
//! Processes 64 bytes per main-loop iteration, then at most one 32-byte step
//! and one 16-byte step, and finishes the last 0..15 bytes with scalar
//! FMA_INT8_GENERAL. Returns <lhs, rhs> and writes |lhs|^2 / |rhs|^2 to
//! `*sql` / `*sqr`. The FMA_INT8_* macros are defined outside this file.
//! NOTE(review): `static_cast(` lost its target type in this extract
//! (presumably float).
float InnerProductAndSquaredNormInt8AVX2(const int8_t *lhs, const int8_t *rhs,
                                         size_t size, float *sql, float *sqr) {
  const int8_t *last = lhs + size;
  // End of the prefix that is a whole number of 64-byte blocks.
  const int8_t *last_aligned = lhs + ((size >> 6) << 6);

  // Dot products alternate between two accumulators (merged at the end),
  // presumably to shorten the add dependency chain; norms use one each.
  __m256i ymm_sum_0 = _mm256_setzero_si256();
  __m256i ymm_sum_1 = _mm256_setzero_si256();
  __m256i ymm_sum_norm1 = _mm256_setzero_si256();
  __m256i ymm_sum_norm2 = _mm256_setzero_si256();

  if (((uintptr_t)lhs & 0x1f) == 0 && ((uintptr_t)rhs & 0x1f) == 0) {
    // Both inputs are 32-byte aligned. They only ever advance by 64 or 32
    // before the 16-byte step, so every load in this branch stays aligned.
    for (; lhs != last_aligned; lhs += 64, rhs += 64) {
      __m256i ymm_lhs_0 = _mm256_load_si256((const __m256i *)(lhs + 0));
      __m256i ymm_lhs_1 = _mm256_load_si256((const __m256i *)(lhs + 32));
      __m256i ymm_rhs_0 = _mm256_load_si256((const __m256i *)(rhs + 0));
      __m256i ymm_rhs_1 = _mm256_load_si256((const __m256i *)(rhs + 32));

      FMA_INT8_AVX(ymm_lhs_0, ymm_rhs_0, ymm_sum_0);
      FMA_INT8_AVX(ymm_lhs_1, ymm_rhs_1, ymm_sum_1);
      FMA_INT8_AVX(ymm_lhs_0, ymm_lhs_0, ymm_sum_norm1);
      FMA_INT8_AVX(ymm_lhs_1, ymm_lhs_1, ymm_sum_norm1);
      FMA_INT8_AVX(ymm_rhs_0, ymm_rhs_0, ymm_sum_norm2);
      FMA_INT8_AVX(ymm_rhs_1, ymm_rhs_1, ymm_sum_norm2);
    }

    // At most one remaining 32-byte chunk.
    if (last >= last_aligned + 32) {
      __m256i ymm_lhs = _mm256_load_si256((const __m256i *)lhs);
      __m256i ymm_rhs = _mm256_load_si256((const __m256i *)rhs);
      FMA_INT8_AVX(ymm_lhs, ymm_rhs, ymm_sum_0);
      FMA_INT8_AVX(ymm_lhs, ymm_lhs, ymm_sum_norm1);
      FMA_INT8_AVX(ymm_rhs, ymm_rhs, ymm_sum_norm2);
      lhs += 32;
      rhs += 32;
    }

    // At most one remaining 16-byte chunk, folded into the 256-bit
    // accumulators via the hybrid macro.
    if (last >= lhs + 16) {
      __m128i xmm_lhs = _mm_load_si128((const __m128i *)lhs);
      __m128i xmm_rhs = _mm_load_si128((const __m128i *)rhs);
      FMA_INT8_AVX_SSE_HYBRID(xmm_lhs, xmm_rhs, ymm_sum_0);
      FMA_INT8_AVX_SSE_HYBRID(xmm_lhs, xmm_lhs, ymm_sum_norm1);
      FMA_INT8_AVX_SSE_HYBRID(xmm_rhs, xmm_rhs, ymm_sum_norm2);
      lhs += 16;
      rhs += 16;
    }
  } else {
    // Same schedule as above, using unaligned loads.
    for (; lhs != last_aligned; lhs += 64, rhs += 64) {
      __m256i ymm_lhs_0 = _mm256_loadu_si256((const __m256i *)(lhs + 0));
      __m256i ymm_lhs_1 = _mm256_loadu_si256((const __m256i *)(lhs + 32));
      __m256i ymm_rhs_0 = _mm256_loadu_si256((const __m256i *)(rhs + 0));
      __m256i ymm_rhs_1 = _mm256_loadu_si256((const __m256i *)(rhs + 32));

      FMA_INT8_AVX(ymm_lhs_0, ymm_rhs_0, ymm_sum_0);
      FMA_INT8_AVX(ymm_lhs_1, ymm_rhs_1, ymm_sum_1);
      FMA_INT8_AVX(ymm_lhs_0, ymm_lhs_0, ymm_sum_norm1);
      FMA_INT8_AVX(ymm_lhs_1, ymm_lhs_1, ymm_sum_norm1);
      FMA_INT8_AVX(ymm_rhs_0, ymm_rhs_0, ymm_sum_norm2);
      FMA_INT8_AVX(ymm_rhs_1, ymm_rhs_1, ymm_sum_norm2);
    }

    if (last >= last_aligned + 32) {
      __m256i ymm_lhs = _mm256_loadu_si256((const __m256i *)lhs);
      __m256i ymm_rhs = _mm256_loadu_si256((const __m256i *)rhs);
      FMA_INT8_AVX(ymm_lhs, ymm_rhs, ymm_sum_0);
      FMA_INT8_AVX(ymm_lhs, ymm_lhs, ymm_sum_norm1);
      FMA_INT8_AVX(ymm_rhs, ymm_rhs, ymm_sum_norm2);
      lhs += 32;
      rhs += 32;
    }

    if (last >= lhs + 16) {
      __m128i xmm_lhs = _mm_loadu_si128((const __m128i *)lhs);
      __m128i xmm_rhs = _mm_loadu_si128((const __m128i *)rhs);
      FMA_INT8_AVX_SSE_HYBRID(xmm_lhs, xmm_rhs, ymm_sum_0);
      FMA_INT8_AVX_SSE_HYBRID(xmm_lhs, xmm_lhs, ymm_sum_norm1);
      FMA_INT8_AVX_SSE_HYBRID(xmm_rhs, xmm_rhs, ymm_sum_norm2);
      lhs += 16;
      rhs += 16;
    }
  }
  float result = static_cast(
      HorizontalAdd_INT32_V256(_mm256_add_epi32(ymm_sum_0, ymm_sum_1)));
  float norm1 = static_cast(HorizontalAdd_INT32_V256(ymm_sum_norm1));
  float norm2 = static_cast(HorizontalAdd_INT32_V256(ymm_sum_norm2));

  // Scalar tail for the remaining 0..15 bytes; cases fall through on purpose.
  switch (last - lhs) {
    case 15:
      FMA_INT8_GENERAL(lhs[14], rhs[14], result, norm1, norm2)
      /* FALLTHRU */
    case 14:
      FMA_INT8_GENERAL(lhs[13], rhs[13], result, norm1, norm2)
      /* FALLTHRU */
    case 13:
      FMA_INT8_GENERAL(lhs[12], rhs[12], result, norm1, norm2)
      /* FALLTHRU */
    case 12:
      FMA_INT8_GENERAL(lhs[11], rhs[11], result, norm1, norm2)
      /* FALLTHRU */
    case 11:
      FMA_INT8_GENERAL(lhs[10], rhs[10], result, norm1, norm2)
      /* FALLTHRU */
    case 10:
      FMA_INT8_GENERAL(lhs[9], rhs[9], result, norm1, norm2)
      /* FALLTHRU */
    case 9:
      FMA_INT8_GENERAL(lhs[8], rhs[8], result, norm1, norm2)
      /* FALLTHRU */
    case 8:
      FMA_INT8_GENERAL(lhs[7], rhs[7], result, norm1, norm2)
      /* FALLTHRU */
    case 7:
      FMA_INT8_GENERAL(lhs[6], rhs[6], result, norm1, norm2)
      /* FALLTHRU */
    case 6:
      FMA_INT8_GENERAL(lhs[5], rhs[5], result, norm1, norm2)
      /* FALLTHRU */
    case 5:
      FMA_INT8_GENERAL(lhs[4], rhs[4], result, norm1, norm2)
      /* FALLTHRU */
    case 4:
      FMA_INT8_GENERAL(lhs[3], rhs[3], result, norm1, norm2)
      /* FALLTHRU */
    case 3:
      FMA_INT8_GENERAL(lhs[2], rhs[2], result, norm1, norm2)
      /* FALLTHRU */
    case 2:
      FMA_INT8_GENERAL(lhs[1], rhs[1], result, norm1, norm2)
      /* FALLTHRU */
    case 1:
      FMA_INT8_GENERAL(lhs[0], rhs[0], result, norm1, norm2)
  }
  *sql = norm1;
  *sqr = norm2;
  return result;
}

//! MIPS distance via SphericalInjection for INT8 vectors (AVX2).
//!
//! `size` is the element (= byte) count and is passed through unchanged.
//! The final transform is delegated to ComputeSphericalInjection, which is
//! defined outside this file.
float MipsEuclideanDistanceSphericalInjectionInt8AVX2(const int8_t *lhs,
                                                      const int8_t *rhs,
                                                      size_t size, float e2) {
  float u2{0.0f};
  float v2{0.0f};
  float sum{0.0f};

  sum = InnerProductAndSquaredNormInt8AVX2(lhs, rhs, size, &u2, &v2);
  return ComputeSphericalInjection(sum, u2, v2, e2);
}

//! MIPS distance via RepeatedQuadraticInjection for INT8 vectors (AVX2).
//!
//! Starts from e2 * (|lhs|^2 + |rhs|^2 - 2 * <lhs, rhs>) = e2 * |lhs - rhs|^2.
//! Then, for each of the `m` rounds, adds (U - V)^2 and squares U and V,
//! where U and V start as the e2-scaled squared norms.
float MipsEuclideanDistanceRepeatedQuadraticInjectionInt8AVX2(
    const int8_t *lhs, const int8_t *rhs, size_t size, size_t m, float e2) {
  float u2{0.0f};
  float v2{0.0f};
  float sum{0.0f};

  sum = InnerProductAndSquaredNormInt8AVX2(lhs, rhs, size, &u2, &v2);

  sum = e2 * (u2 + v2 - 2 * sum);
  u2 *= e2;
  v2 *= e2;
  for (size_t i = 0; i < m; ++i) {
    sum += (u2 - v2) * (u2 - v2);
    u2 = u2 * u2;
    v2 = v2 * v2;
  }
  return sum;
}
#endif  // __AVX2__

}  // namespace ailego
}  // namespace zvec
================================================
FILE: src/ailego/math/mips_euclidean_distance_matrix_int8_dispatch.cc
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// NOTE(review): the target of the bare `#include` below, the template
// arguments of `MipsSquaredEuclideanDistanceMatrix` and the `static_cast`
// target types were stripped in this extract; restore from the real sources.
#include
#include "mips_euclidean_distance_matrix.h"

namespace zvec {
namespace ailego {

// Per-ISA INT8 kernels, defined in sibling translation units. The SIMD
// variants are declared only when the matching instruction set is enabled
// for this translation unit.
#if defined(__AVX2__)
float MipsEuclideanDistanceRepeatedQuadraticInjectionInt8AVX2(
    const int8_t *lhs, const int8_t *rhs, size_t size, size_t m, float e2);
float MipsEuclideanDistanceSphericalInjectionInt8AVX2(const int8_t *lhs,
                                                      const int8_t *rhs,
                                                      size_t size, float e2);
#endif

#if defined(__SSE4_1__)
float MipsEuclideanDistanceRepeatedQuadraticInjectionInt8SSE(
    const int8_t *lhs, const int8_t *rhs, size_t size, size_t m, float e2);
float MipsEuclideanDistanceSphericalInjectionInt8SSE(const int8_t *lhs,
                                                     const int8_t *rhs,
                                                     size_t size, float e2);
#endif

// Portable fallbacks; always declared and always used as the last resort.
float MipsEuclideanDistanceRepeatedQuadraticInjectionInt8Scalar(
    const int8_t *lhs, const int8_t *rhs, size_t size, size_t m, float e2);
float MipsEuclideanDistanceSphericalInjectionInt8Scalar(const int8_t *lhs,
                                                        const int8_t *rhs,
                                                        size_t size, float e2);

//! Compute the distance between matrix and query by SphericalInjection
//!
//! Uses the first kernel that is both compiled in and reported by
//! CpuFeatures::static_flags_ at runtime (AVX2 -> SSE4.1 -> scalar) and
//! stores the result in `*out`.
//! NOTE(review): the int8_t kernels suggest the INT8 specialization —
//! confirm the stripped template arguments.
void MipsSquaredEuclideanDistanceMatrix::Compute(
    const ValueType *p, const ValueType *q, size_t dim, float e2, float *out) {
#if defined(__AVX2__)
  if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX2) {
    *out = MipsEuclideanDistanceSphericalInjectionInt8AVX2(p, q, dim, e2);
    return;
  }
#endif
#if defined(__SSE4_1__)
  if (zvec::ailego::internal::CpuFeatures::static_flags_.SSE4_1) {
    *out = MipsEuclideanDistanceSphericalInjectionInt8SSE(p, q, dim, e2);
    return;
  }
#endif  //__SSE4_1__
  *out = MipsEuclideanDistanceSphericalInjectionInt8Scalar(p, q, dim, e2);
}

//! Compute the distance between matrix and query by RepeatedQuadraticInjection
//!
//! Same AVX2 -> SSE4.1 -> scalar selection as above; `m` (number of
//! injection rounds) is forwarded unchanged.
void MipsSquaredEuclideanDistanceMatrix::Compute(
    const ValueType *p, const ValueType *q, size_t dim, size_t m, float e2,
    float *out) {
#if defined(__AVX2__)
  if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX2) {
    *out =
        MipsEuclideanDistanceRepeatedQuadraticInjectionInt8AVX2(p, q, dim, m, e2);
    return;
  }
#endif
#if defined(__SSE4_1__)
  if (zvec::ailego::internal::CpuFeatures::static_flags_.SSE4_1) {
    *out =
        MipsEuclideanDistanceRepeatedQuadraticInjectionInt8SSE(p, q, dim, m, e2);
    return;
  }
#endif  //__SSE4_1__
  *out =
      MipsEuclideanDistanceRepeatedQuadraticInjectionInt8Scalar(p, q, dim, m, e2);
}

}  // namespace ailego
}  // namespace zvec
================================================
FILE: src/ailego/math/mips_euclidean_distance_matrix_int8_sse.cc
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "distance_matrix_accum_int8.i"
#include "distance_matrix_mips_utility.i"
#include "mips_euclidean_distance_matrix.h"

namespace zvec {
namespace ailego {

#if defined(__SSE4_1__)
//! Compute the Inner Product between p and q, and each Squared L2-Norm value
//!
//! Processes 32 bytes (two 16-byte vectors) per main-loop iteration, then at
//! most one 16-byte step, and finishes the last 0..15 bytes with scalar
//! FMA_INT8_GENERAL. Returns <lhs, rhs> and writes |lhs|^2 / |rhs|^2 to
//! `*sql` / `*sqr`. The FMA_INT8_* macros are defined outside this file.
float InnerProductAndSquaredNormInt8SSE(const int8_t *lhs, const int8_t *rhs,
                                        size_t size, float *sql, float *sqr) {
  const int8_t *last = lhs + size;
  // End of the prefix that is a whole number of 32-byte blocks.
  const int8_t *last_aligned = lhs + ((size >> 5) << 5);

  __m128i xmm_sum = _mm_setzero_si128();
  __m128i xmm_sum_norm1 = _mm_setzero_si128();
  __m128i xmm_sum_norm2 = _mm_setzero_si128();

  // Aligned loads only when both inputs start on a 16-byte boundary; the
  // pointers advance in multiples of 16, so alignment is preserved.
  if (((uintptr_t)lhs & 0xf) == 0 && ((uintptr_t)rhs & 0xf) == 0) {
    for (; lhs != last_aligned; lhs += 32, rhs += 32) {
      __m128i xmm_lhs_0 = _mm_load_si128((const __m128i *)(lhs + 0));
      __m128i xmm_lhs_1 = _mm_load_si128((const __m128i *)(lhs + 16));
      __m128i xmm_rhs_0 = _mm_load_si128((const __m128i *)(rhs + 0));
      __m128i xmm_rhs_1 = _mm_load_si128((const __m128i *)(rhs + 16));

      FMA_INT8_SSE(xmm_lhs_0, xmm_rhs_0, xmm_sum);
      FMA_INT8_SSE(xmm_lhs_0, xmm_lhs_0, xmm_sum_norm1);
      FMA_INT8_SSE(xmm_rhs_0, xmm_rhs_0, xmm_sum_norm2);
      FMA_INT8_SSE(xmm_lhs_1, xmm_rhs_1, xmm_sum);
      FMA_INT8_SSE(xmm_lhs_1, xmm_lhs_1, xmm_sum_norm1);
      FMA_INT8_SSE(xmm_rhs_1, xmm_rhs_1, xmm_sum_norm2);
    }

    // At most one remaining 16-byte chunk.
    if (last >= last_aligned + 16) {
      __m128i xmm_lhs = _mm_load_si128((const __m128i *)lhs);
      __m128i xmm_rhs = _mm_load_si128((const __m128i *)rhs);
      FMA_INT8_SSE(xmm_lhs, xmm_rhs, xmm_sum);
      FMA_INT8_SSE(xmm_lhs, xmm_lhs, xmm_sum_norm1);
      FMA_INT8_SSE(xmm_rhs, xmm_rhs, xmm_sum_norm2);
      lhs += 16;
      rhs += 16;
    }
  } else {
    // Same schedule as above, using unaligned loads.
    for (; lhs != last_aligned; lhs += 32, rhs += 32) {
      __m128i xmm_lhs_0 = _mm_loadu_si128((const __m128i *)(lhs + 0));
      __m128i xmm_lhs_1 = _mm_loadu_si128((const __m128i *)(lhs + 16));
      __m128i xmm_rhs_0 = _mm_loadu_si128((const __m128i *)(rhs + 0));
      __m128i xmm_rhs_1 = _mm_loadu_si128((const __m128i *)(rhs + 16));

      FMA_INT8_SSE(xmm_lhs_0, xmm_rhs_0, xmm_sum);
      FMA_INT8_SSE(xmm_lhs_0, xmm_lhs_0, xmm_sum_norm1);
      FMA_INT8_SSE(xmm_rhs_0, xmm_rhs_0, xmm_sum_norm2);
      FMA_INT8_SSE(xmm_lhs_1, xmm_rhs_1, xmm_sum);
      FMA_INT8_SSE(xmm_lhs_1, xmm_lhs_1, xmm_sum_norm1);
      FMA_INT8_SSE(xmm_rhs_1, xmm_rhs_1, xmm_sum_norm2);
    }

    if (last >= last_aligned + 16) {
      __m128i xmm_lhs = _mm_loadu_si128((const __m128i *)lhs);
      __m128i xmm_rhs = _mm_loadu_si128((const __m128i *)rhs);
      FMA_INT8_SSE(xmm_lhs, xmm_rhs, xmm_sum);
      FMA_INT8_SSE(xmm_lhs, xmm_lhs, xmm_sum_norm1);
      FMA_INT8_SSE(xmm_rhs, xmm_rhs, xmm_sum_norm2);
      lhs += 16;
      rhs += 16;
    }
  }
  float result = static_cast(HorizontalAdd_INT32_V128(xmm_sum));
  float norm1 = static_cast(HorizontalAdd_INT32_V128(xmm_sum_norm1));
  float norm2 = static_cast(HorizontalAdd_INT32_V128(xmm_sum_norm2));

  // Scalar tail for the remaining 0..15 bytes; cases fall through on purpose.
  switch (last - lhs) {
    case 15:
      FMA_INT8_GENERAL(lhs[14], rhs[14], result, norm1, norm2)
      /* FALLTHRU */
    case 14:
      FMA_INT8_GENERAL(lhs[13], rhs[13], result, norm1, norm2)
      /* FALLTHRU */
    case 13:
      FMA_INT8_GENERAL(lhs[12], rhs[12], result, norm1, norm2)
      /* FALLTHRU */
    case 12:
      FMA_INT8_GENERAL(lhs[11], rhs[11], result, norm1, norm2)
      /* FALLTHRU */
    case 11:
      FMA_INT8_GENERAL(lhs[10], rhs[10], result, norm1, norm2)
      /* FALLTHRU */
    case 10:
      FMA_INT8_GENERAL(lhs[9], rhs[9], result, norm1, norm2)
      /* FALLTHRU */
    case 9:
      FMA_INT8_GENERAL(lhs[8], rhs[8], result, norm1, norm2)
      /* FALLTHRU */
    case 8:
      FMA_INT8_GENERAL(lhs[7], rhs[7], result, norm1, norm2)
      /* FALLTHRU */
    case 7:
      FMA_INT8_GENERAL(lhs[6], rhs[6], result, norm1, norm2)
      /* FALLTHRU */
    case 6:
      FMA_INT8_GENERAL(lhs[5], rhs[5], result, norm1, norm2)
      /* FALLTHRU */
    case 5:
      FMA_INT8_GENERAL(lhs[4], rhs[4], result, norm1, norm2)
      /* FALLTHRU */
    case 4:
      FMA_INT8_GENERAL(lhs[3], rhs[3], result, norm1, norm2)
      /* FALLTHRU */
    case 3:
      FMA_INT8_GENERAL(lhs[2], rhs[2], result, norm1, norm2)
      /* FALLTHRU */
    case 2:
      FMA_INT8_GENERAL(lhs[1], rhs[1], result, norm1, norm2)
      /* FALLTHRU */
    case 1:
      FMA_INT8_GENERAL(lhs[0], rhs[0], result, norm1, norm2)
  }
  *sql = norm1;
  *sqr = norm2;
  return result;
}

//! MIPS distance via SphericalInjection for INT8 vectors (SSE4.1).
//!
//! `size` is the element (= byte) count. The final transform is delegated
//! to ComputeSphericalInjection, which is defined outside this file.
float MipsEuclideanDistanceSphericalInjectionInt8SSE(const int8_t *lhs,
                                                     const int8_t *rhs,
                                                     size_t size, float e2) {
  float u2{0.0f};
  float v2{0.0f};
  float sum{0.0f};

  sum = InnerProductAndSquaredNormInt8SSE(lhs, rhs, size, &u2, &v2);
  return ComputeSphericalInjection(sum, u2, v2, e2);
}

//! MIPS distance via RepeatedQuadraticInjection for INT8 vectors (SSE4.1).
//!
//! Starts from e2 * (|lhs|^2 + |rhs|^2 - 2 * <lhs, rhs>) = e2 * |lhs - rhs|^2.
//! Then, for each of the `m` rounds, adds (U - V)^2 and squares U and V,
//! where U and V start as the e2-scaled squared norms.
float MipsEuclideanDistanceRepeatedQuadraticInjectionInt8SSE(
    const int8_t *lhs, const int8_t *rhs, size_t size, size_t m, float e2) {
  float u2{0.0f};
  float v2{0.0f};
  float sum{0.0f};

  sum = InnerProductAndSquaredNormInt8SSE(lhs, rhs, size, &u2, &v2);

  sum = e2 * (u2 + v2 - 2 * sum);
  u2 *= e2;
  v2 *= e2;
  for (size_t i = 0; i < m; ++i) {
    sum += (u2 - v2) * (u2 - v2);
    u2 = u2 * u2;
    v2 = v2 * v2;
  }
  return sum;
}
#endif  // __SSE4_1__

}  // namespace ailego
}  // namespace zvec
================================================
FILE: src/ailego/math/mips_euclidean_distance_matrix_scalar.cc
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// NOTE(review): in this extract, text of the form `<identifier...>` was
// stripped. That affects the bare `#include` lines, the `template` parameter
// lists, the `static_cast` / `reinterpret_cast` target types, the
// `MathHelper` / `std::remove_cv` arguments, and the Norm1Matrix
// specialization arguments. The tokens below are kept as found; restore the
// missing parts from the real sources.
#include
#include
#include
#include
#include
#include "distance_utility.h"
#include "mips_euclidean_distance_matrix.h"

namespace zvec {
namespace ailego {

//--------------------------------------------------
// Dense
//--------------------------------------------------

// Compute the distance between matrix and query by SphericalInjection
//
// Generic scalar kernel over `dim` elements of type T: accumulates
// |p|^2, |q|^2 and <p, q> in float, then delegates to
// ComputeSphericalInjection, which is defined outside this file.
template
inline float MipsEuclideanDistanceSphericalInjectionScalar(const T *p,
                                                           const T *q,
                                                           size_t dim,
                                                           float e2) {
  ailego_assert(p && q && dim);

  float sum = 0.0;
  float u2 = 0.0;
  float v2 = 0.0;
  for (size_t i = 0; i < dim; ++i) {
    u2 += p[i] * p[i];
    v2 += q[i] * q[i];
    sum += static_cast(p[i] * q[i]);
  }

  return ComputeSphericalInjection(sum, u2, v2, e2);
}

// Compute the distance between matrix and query by RepeatedQuadraticInjection
//
// Generic scalar kernel. Takes the squared Euclidean distance directly,
// summing MathHelper::SquaredDifference, and scales it and both squared
// norms by e2. Then, for each of the `m` rounds, adds (U - V)^2 and squares
// U and V.
template
inline float MipsEuclideanDistanceRepeatedQuadraticInjectionScalar(
    const T *p, const T *q, size_t dim, size_t m, float e2) {
  ailego_assert(p && q && dim);

  float sum = 0.0;
  float u2 = 0.0;
  float v2 = 0.0;
  for (size_t i = 0; i < dim; ++i) {
    u2 += p[i] * p[i];
    v2 += q[i] * q[i];
    sum += MathHelper::SquaredDifference(p[i], q[i]);
  }

  sum *= e2;
  u2 *= e2;
  v2 *= e2;
  for (size_t i = 0; i < m; ++i) {
    sum += (u2 - v2) * (u2 - v2);
    u2 = u2 * u2;
    v2 = v2 * v2;
  }
  return sum;
}

/*! Mips Squared Euclidean Distance Matrix (INT4, M=1, N=1)
 */

//! Calculate sum of squared values
//!
//! `v` packs two signed 4-bit values. The low nibble is moved to the top of
//! an int8_t and arithmetic-shifted back, and the high nibble is masked and
//! arithmetic-shifted, so both are sign-extended before squaring.
static inline float Squared(uint8_t v) {
  return static_cast(((int8_t)(v << 4) >> 4) * ((int8_t)(v << 4) >> 4) +
                     ((int8_t)(v & 0xf0) >> 4) * ((int8_t)(v & 0xf0) >> 4));
}

// Compute the distance between matrix and query by SphericalInjection
//
// INT4 scalar kernel. `dim` is the number of 4-bit elements and must be even
// (asserted); `dim >> 1` bytes are read from each input. Each table index is
// (p nibble << 4) | q nibble, pairing low with low and high with high.
// Int4MulTable is defined outside this file and presumably holds the
// product of the two signed nibbles — TODO confirm.
float MipsEuclideanDistanceSphericalInjectionInt4Scalar(const uint8_t *p,
                                                        const uint8_t *q,
                                                        size_t dim, float e2) {
  ailego_assert(p && q && dim && !(dim & 1));

  float sum = 0.0;
  float u2 = 0.0;
  float v2 = 0.0;
  for (size_t i = 0; i < (dim >> 1); ++i) {
    const uint8_t p_val = p[i];
    const uint8_t q_val = q[i];
    u2 += Squared(p_val);
    v2 += Squared(q_val);
    sum += Int4MulTable[((p_val << 4) & 0xf0) | ((q_val >> 0) & 0xf)] +
           Int4MulTable[((p_val >> 0) & 0xf0) | ((q_val >> 4) & 0xf)];
  }

  return ComputeSphericalInjection(sum, u2, v2, e2);
}

// Compute the distance between matrix and query by RepeatedQuadraticInjection
//
// INT4 scalar kernel with the same nibble pairing as above, but summing
// Int4SquaredDiffTable entries (presumably the squared difference of the two
// signed nibbles — TODO confirm). The result is then scaled by e2 and the
// `m` injection rounds are applied.
float MipsEuclideanDistanceRepeatedQuadraticInjectionInt4Scalar(
    const uint8_t *p, const uint8_t *q, size_t dim, size_t m, float e2) {
  ailego_assert(p && q && dim && !(dim & 1));

  float sum = 0.0;
  float u2 = 0.0;
  float v2 = 0.0;
  for (size_t i = 0; i < (dim >> 1); ++i) {
    const uint8_t p_val = p[i];
    const uint8_t q_val = q[i];
    u2 += Squared(p_val);
    v2 += Squared(q_val);
    sum += Int4SquaredDiffTable[((p_val << 4) & 0xf0) | ((q_val >> 0) & 0xf)] +
           Int4SquaredDiffTable[((p_val >> 0) & 0xf0) | ((q_val >> 4) & 0xf)];
  }

  sum *= e2;
  u2 *= e2;
  v2 *= e2;
  for (size_t i = 0; i < m; ++i) {
    sum += (u2 - v2) * (u2 - v2);
    u2 = u2 * u2;
    v2 = v2 * v2;
  }
  return sum;
}

// Non-template entry points for the dispatchers; each forwards to the
// generic scalar kernel above.
float MipsEuclideanDistanceSphericalInjectionInt8Scalar(const int8_t *p,
                                                        const int8_t *q,
                                                        size_t dim, float e2) {
  return MipsEuclideanDistanceSphericalInjectionScalar(p, q, dim, e2);
}

float MipsEuclideanDistanceRepeatedQuadraticInjectionInt8Scalar(
    const int8_t *p, const int8_t *q, size_t dim, size_t m, float e2) {
  return MipsEuclideanDistanceRepeatedQuadraticInjectionScalar(
      p, q, dim, m, e2);
}

float MipsEuclideanDistanceSphericalInjectionFp16Scalar(
    const ailego::Float16 *p, const ailego::Float16 *q, size_t dim, float e2) {
  return MipsEuclideanDistanceSphericalInjectionScalar(
      p, q, dim, e2);
}

float MipsEuclideanDistanceRepeatedQuadraticInjectionFp16Scalar(
    const ailego::Float16 *p, const ailego::Float16 *q, size_t dim, size_t m,
    float e2) {
  return MipsEuclideanDistanceRepeatedQuadraticInjectionScalar(
      p, q, dim, m, e2);
}

float MipsEuclideanDistanceSphericalInjectionFp32Scalar(const float *p,
                                                        const float *q,
                                                        size_t dim, float e2) {
  return MipsEuclideanDistanceSphericalInjectionScalar(p, q, dim, e2);
}

float MipsEuclideanDistanceRepeatedQuadraticInjectionFp32Scalar(
    const float *p, const float *q, size_t dim, size_t m, float e2) {
  return MipsEuclideanDistanceRepeatedQuadraticInjectionScalar(p, q, dim, m,
                                                                e2);
}

}  // namespace ailego
}  // namespace zvec
================================================
FILE: src/ailego/math/norm1_matrix.h
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include
#include
#include

namespace zvec {
namespace ailego {

/*! L1-Norm Matrix
 *  Primary template (declaration only); behavior comes from the
 *  specializations below.
 */
template
struct Norm1Matrix;

/*! L1-Norm Matrix
 *  General case for M >= 2 interleaved vectors whose element size is at
 *  least 2 bytes (per the surviving part of the enable_if condition).
 */
template
struct Norm1Matrix::value && sizeof(T) >= 2 && M >= 2>::type> {
  //! Type of value
  using ValueType = typename std::remove_cv::type;

  //! Compute the norm of vectors
  //!
  //! `m` holds dim rows of M values; column i belongs to vector i. Writes
  //! the L1 norm of vector i to out[i]. The first row initializes `out`, so
  //! callers need not zero it.
  static inline void Compute(const ValueType *m, size_t dim, float *out) {
    ailego_assert(m && dim && out);

    const ValueType *m_end = m + dim * M;
    if (m != m_end) {
      for (size_t i = 0; i < M; ++i) {
        *(out + i) = MathHelper::Absolute(m[i]);
      }
      m += M;
    }
    while (m != m_end) {
      for (size_t i = 0; i < M; ++i) {
        *(out + i) += MathHelper::Absolute(m[i]);
      }
      m += M;
    }
  }
};

/*! L1-Norm Matrix (INT8)
 */
template
struct Norm1Matrix= 2>::type> {
  //! Type of value
  using ValueType = int8_t;

  //! Compute the norm of vectors
  //!
  //! Reads `m` as 32-bit words: each group of M words holds four consecutive
  //! int8 values for each of the M vectors (word i -> vector i). `dim` must
  //! be a multiple of 4 (asserted).
  //! NOTE(review): the reinterpret_cast to a 32-bit word type relies on `m`
  //! being suitably aligned and on the compiler tolerating the type pun.
  static inline void Compute(const ValueType *m, size_t dim, float *out) {
    ailego_assert(m && dim && !(dim & 3) && out);

    const uint32_t *m_it = reinterpret_cast(m);
    const uint32_t *m_end = m_it + (dim >> 2) * M;
    if (m_it != m_end) {
      for (size_t i = 0; i < M; ++i) {
        *(out + i) = Absolute(m_it[i]);
      }
      m_it += M;
    }
    while (m_it != m_end) {
      for (size_t i = 0; i < M; ++i) {
        *(out + i) += Absolute(m_it[i]);
      }
      m_it += M;
    }
  }

 protected:
  //! Calculate sum of absolute values
  //!
  //! Sums |b| over the four bytes of `v`, each reinterpreted as int8_t.
  //! Byte order is irrelevant because the sum is order-independent.
  static inline float Absolute(uint32_t v) {
    return static_cast(
        MathHelper::Absolute((int8_t)(v >> 0)) +
        MathHelper::Absolute((int8_t)(v >> 8)) +
        MathHelper::Absolute((int8_t)(v >> 16)) +
        MathHelper::Absolute((int8_t)(v >> 24)));
  }
};

/*! L1-Norm Matrix (M=1)
 */
template
struct Norm1Matrix<
    T, 1, typename std::enable_if::value>::type> {
  //! Type of value
  using ValueType = typename std::remove_cv::type;

  //! Compute the norm of vectors
  //!
  //! Single contiguous vector of `dim` values; writes its L1 norm to *out.
  static inline void Compute(const ValueType *m, size_t dim, float *out) {
    ailego_assert(m && dim && out);

    const ValueType *m_end = m + dim;
    if (m != m_end) {
      *out = MathHelper::Absolute(*m++);
    }
    while (m != m_end) {
      *out += MathHelper::Absolute(*m++);
    }
  }
};

#if defined(__SSE__) || (defined(__ARM_NEON) && defined(__aarch64__))
/*! L1-Norm Matrix (FP32, M=1)
 *  SIMD-backed specialization; Compute is defined out of line.
 */
template <>
struct Norm1Matrix {
  //! Type of value
  using ValueType = float;

  //! Compute the L1-norm of vectors
  static void Compute(const ValueType *m, size_t dim, float *out);
};
#endif  // __SSE__ || (__ARM_NEON && __aarch64__)

#if (defined(__F16C__) && defined(__AVX__)) || \
    (defined(__ARM_NEON) && defined(__aarch64__))
/*! L1-Norm Matrix (FP16, M=1)
 *  SIMD-backed specialization; Compute is defined out of line.
 */
template <>
struct Norm1Matrix {
  //! Type of value
  using ValueType = Float16;

  //! Compute the L1-norm of vectors
  static void Compute(const ValueType *m, size_t dim, float *out);
};
#endif  // (__F16C__ && __AVX__) || (__ARM_NEON && __aarch64__)

}  // namespace ailego
}  // namespace zvec
================================================
FILE: src/ailego/math/norm1_matrix_fp16.cc
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include #include "ailego/internal/cpu_features.h" #include "norm1_matrix.h" #include "norm_matrix_fp16.i" namespace zvec { namespace ailego { #define NORM_FP32_STEP_GENERAL SA_FP32_GENERAL #define NORM_FP32_STEP_SSE SA_FP32_SSE #define NORM_FP32_STEP_AVX SA_FP32_AVX #define NORM_FP32_STEP_AVX512 SA_FP32_AVX512 #define NORM_FP32_STEP_NEON SA_FP32_NEON #define NORM_FP16_STEP_GENERAL SA_FP16_GENERAL #define NORM_FP16_STEP_NEON SA_FP16_NEON #if defined(__SSE__) static const __m128 ABS_MASK_FP32_SSE = _mm_castsi128_ps(_mm_set1_epi32(0x7fffffffu)); #endif // __SSE__ #if defined(__AVX__) static const __m256 ABS_MASK_FP32_AVX = _mm256_castsi256_ps(_mm256_set1_epi32(0x7fffffffu)); #endif // __AVX__ #if defined(__AVX512F__) static const __m512 ABS_MASK_FP32_AVX512 = _mm512_castsi512_ps(_mm512_set1_epi32(0x7fffffffu)); #endif // __AVX512F__ //! Calculate sum of absolute (GENERAL) #define SA_FP32_GENERAL(m, sum) sum += FastAbs(m); //! Calculate sum of absolute (SSE) #define SA_FP32_SSE(xmm_m, xmm_sum) \ xmm_sum = _mm_add_ps(_mm_and_ps(xmm_m, ABS_MASK_FP32_SSE), xmm_sum); //! Calculate sum of absolute (AVX) #define SA_FP32_AVX(ymm_m, ymm_sum) \ ymm_sum = _mm256_add_ps(_mm256_and_ps(ymm_m, ABS_MASK_FP32_AVX), ymm_sum); //! Calculate sum of absolute (AVX512) #define SA_FP32_AVX512(zmm_m, zmm_sum) \ zmm_sum = _mm512_add_ps(_mm512_and_ps(zmm_m, ABS_MASK_FP32_AVX512), zmm_sum); //! Calculate sum of absolute (NEON) #define SA_FP32_NEON(v_m, v_sum) v_sum = vaddq_f32(vabsq_f32(v_m), v_sum); //! Calculate sum of absolute (GENERAL) #define SA_FP16_GENERAL(m, sum) sum += Float16::Absolute(m); //! Calculate sum of absolute (NEON) #define SA_FP16_NEON(v_m, v_sum) v_sum = vaddq_f16(vabsq_f16(v_m), v_sum); #if (defined(__F16C__) && defined(__AVX__)) || \ (defined(__ARM_NEON) && defined(__aarch64__)) //! 
Compute the L1-norm of vectors (FP16, M=1)
//! Dispatch: NEON builds use the NEON kernel. x86 builds use the AVX-512
//! kernel when it is compiled in and CpuFeatures::static_flags_.AVX512F is
//! set, otherwise the AVX kernel. The empty fourth macro argument means the
//! accumulated sum is stored into *out unchanged.
void Norm1Matrix::Compute(const ValueType *m, size_t dim, float *out) {
#if defined(__ARM_NEON)
  NORM_FP16_1_NEON(m, dim, out, )
#else
#if defined(__AVX512F__)
  if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) {
    NORM_FP16_1_AVX512(m, dim, out, )
    return;
  }
#endif
  NORM_FP16_1_AVX(m, dim, out, )
#endif
}
#endif // (__F16C__ && __AVX__) || (__ARM_NEON && __aarch64__)

}  // namespace ailego
}  // namespace zvec

================================================
FILE: src/ailego/math/norm1_matrix_fp32.cc
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include
#include "ailego/internal/cpu_features.h"
#include "norm1_matrix.h"
#include "norm_matrix_fp32.i"

namespace zvec {
namespace ailego {

// Bind the generic kernel step hooks to the sum-of-absolute (SA_*) steps.
#define NORM_FP32_STEP_GENERAL SA_FP32_GENERAL
#define NORM_FP32_STEP_SSE SA_FP32_SSE
#define NORM_FP32_STEP_AVX SA_FP32_AVX
#define NORM_FP32_STEP_AVX512 SA_FP32_AVX512
#define NORM_FP32_STEP_NEON SA_FP32_NEON

// Sign-bit-clearing masks (0x7fffffff per lane). As macros, they are only
// materialized inside the kernels that expand them.
#if defined(__SSE__)
#define ABS_MASK_FP32_SSE _mm_castsi128_ps(_mm_set1_epi32(0x7fffffffu))
#endif // __SSE__
#if defined(__AVX__)
#define ABS_MASK_FP32_AVX _mm256_castsi256_ps(_mm256_set1_epi32(0x7fffffffu))
#endif // __AVX__
#if defined(__AVX512F__)
#define ABS_MASK_FP32_AVX512 _mm512_castsi512_ps(_mm512_set1_epi32(0x7fffffffu))
#endif // __AVX512F__
//!
Calculate sum of absolute (GENERAL)
#define SA_FP32_GENERAL(m, sum) sum += FastAbs(m);
//! Calculate sum of absolute (SSE)
#define SA_FP32_SSE(xmm_m, xmm_sum) \
  xmm_sum = _mm_add_ps(_mm_and_ps(xmm_m, ABS_MASK_FP32_SSE), xmm_sum);
//! Calculate sum of absolute (AVX)
#define SA_FP32_AVX(ymm_m, ymm_sum) \
  ymm_sum = _mm256_add_ps(_mm256_and_ps(ymm_m, ABS_MASK_FP32_AVX), ymm_sum);
//! Calculate sum of absolute (AVX512)
#define SA_FP32_AVX512(zmm_m, zmm_sum) \
  zmm_sum = _mm512_add_ps(_mm512_and_ps(zmm_m, ABS_MASK_FP32_AVX512), zmm_sum);
//! Calculate sum of absolute (NEON)
#define SA_FP32_NEON(v_m, v_sum) v_sum = vaddq_f32(vabsq_f32(v_m), v_sum);

#if defined(__SSE__) || (defined(__ARM_NEON) && defined(__aarch64__))
//! Compute the L1-norm of vectors (FP32, M=1)
//! x86 dispatch order: AVX-512, then AVX (each only when compiled in and
//! flagged in CpuFeatures::static_flags_), then SSE as the fallback.
void Norm1Matrix::Compute(const ValueType *m, size_t dim, float *out) {
#if defined(__ARM_NEON)
  NORM_FP32_1_NEON(m, dim, out, )
#else
#if defined(__AVX512F__)
  if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) {
    NORM_FP32_1_AVX512(m, dim, out, )
    return;
  }
#endif
#if defined(__AVX__)
  if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX) {
    NORM_FP32_1_AVX(m, dim, out, )
    return;
  }
#endif
  NORM_FP32_1_SSE(m, dim, out, )
#endif
}
#endif // __SSE__ || (__ARM_NEON && __aarch64__)

}  // namespace ailego
}  // namespace zvec

================================================
FILE: src/ailego/math/norm2_matrix.h
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include
#include
#include
#include

namespace zvec {
namespace ailego {

/*! L2-Norm Matrix */
template struct Norm2Matrix;

/*! L2-Norm Matrix */
template struct Norm2Matrix::value && sizeof(T) >= 2 && M >= 2>::type> {
  //! Type of value
  using ValueType = typename std::remove_cv::type;

  //! Compute the norm of vectors
  //! m holds M interleaved vectors: each step reads m[0..M-1] and advances
  //! by M. out[i] receives the L2-norm of vector i.
  // NOTE(review): v * v is evaluated in ValueType (after integer promotion)
  // before the cast to float, so wide integer element types may overflow —
  // confirm the expected value range.
  static inline void Compute(const ValueType *m, size_t dim, float *out) {
    ailego_assert(m && dim && out);
    const ValueType *m_end = m + dim * M;
    if (m != m_end) {
      // The first element of each vector initializes its accumulator.
      for (size_t i = 0; i < M; ++i) {
        ValueType v = m[i];
        *(out + i) = static_cast(v * v);
      }
      m += M;
    }
    while (m != m_end) {
      for (size_t i = 0; i < M; ++i) {
        ValueType v = m[i];
        *(out + i) += static_cast(v * v);
      }
      m += M;
    }
    // Replace each accumulated sum of squares with its square root.
    for (size_t i = 0; i < M; ++i) {
      float v = *out;
      *out++ = std::sqrt(v);
    }
  }
};

/*! L2-Norm Matrix (INT8) */
template struct Norm2Matrix= 2>::type> {
  //! Type of value
  using ValueType = int8_t;

  //! Compute the norm of vectors
  //! The int8 data is read as uint32_t words (four elements each), so dim
  //! must be a multiple of 4 (asserted). Each step consumes one word per
  //! vector.
  static inline void Compute(const ValueType *m, size_t dim, float *out) {
    ailego_assert(m && dim && !(dim & 3) && out);
    const uint32_t *m_it = reinterpret_cast(m);
    const uint32_t *m_end = m_it + (dim >> 2) * M;
    if (m_it != m_end) {
      for (size_t i = 0; i < M; ++i) {
        *(out + i) = Squared(m_it[i]);
      }
      m_it += M;
    }
    while (m_it != m_end) {
      for (size_t i = 0; i < M; ++i) {
        *(out + i) += Squared(m_it[i]);
      }
      m_it += M;
    }
    for (size_t i = 0; i < M; ++i) {
      float v = *out;
      *out++ = std::sqrt(v);
    }
  }

 protected:
  //! Calculate sum of squared values
  //! Sum of b * b over the four signed bytes packed in v.
  static inline float Squared(uint32_t v) {
    return static_cast((int8_t)(v >> 0) * (int8_t)(v >> 0) +
                       (int8_t)(v >> 8) * (int8_t)(v >> 8) +
                       (int8_t)(v >> 16) * (int8_t)(v >> 16) +
                       (int8_t)(v >> 24) * (int8_t)(v >> 24));
  }
};

/*! L2-Norm Matrix (M=1) */
template struct Norm2Matrix<
    T, 1, typename std::enable_if::value>::type> {
  //! Type of value
  using ValueType = typename std::remove_cv::type;

  //! Compute the norm of vectors
  //! Single vector: *out = sqrt(sum of m[i] * m[i]) over dim elements.
  static inline void Compute(const ValueType *m, size_t dim, float *out) {
    ailego_assert(m && dim && out);
    const ValueType *m_end = m + dim;
    if (m != m_end) {
      ValueType v = *m++;
      *out = static_cast(v * v);
    }
    while (m != m_end) {
      ValueType v = *m++;
      *out += static_cast(v * v);
    }
    *out = std::sqrt(*out);
  }
};

/*! L2-Norm Matrix (M=1, INT4) */
template <>
struct Norm2Matrix {
  //! Type of value
  using ValueType = uint8_t;

  //! Compute the norm of vectors
  //! Two signed 4-bit elements are packed per byte, so dim must be even
  //! (asserted) and dim / 2 bytes are read.
  static inline void Compute(const ValueType *m, size_t dim, float *out) {
    ailego_assert(m && !(dim & 1) && dim && out);
    const uint8_t *m_end = m + (dim >> 1);
    float square = 0.0f;
    while (m != m_end) {
      square += Squared(*m++);
    }
    *out = std::sqrt(square);
  }

 protected:
  //! Calculate sum of squared values
  //! (int8_t)(v << 4) >> 4 sign-extends the low nibble;
  //! (int8_t)(v & 0xf0) >> 4 sign-extends the high nibble.
  static inline float Squared(uint8_t v) {
    return static_cast(
        ((int8_t)(v << 4) >> 4) * ((int8_t)(v << 4) >> 4) +
        ((int8_t)(v & 0xf0) >> 4) * ((int8_t)(v & 0xf0) >> 4));
  }
};

/*! L2-Norm Matrix (INT4) */
template struct Norm2Matrix= 2>::type> {
  //! Type of value
  using ValueType = uint8_t;

  //! Compute the norm of vectors
  //! Reads uint32_t words holding eight 4-bit elements each, so dim must be a
  //! multiple of 8 (asserted). Each step consumes one word per vector.
  static inline void Compute(const ValueType *m, size_t dim, float *out) {
    ailego_assert(m && dim && !(dim & 7) && out);
    const uint32_t *m_it = reinterpret_cast(m);
    const uint32_t *m_end = m_it + (dim >> 3) * M;
    if (m_it != m_end) {
      for (size_t i = 0; i < M; ++i) {
        *(out + i) = Squared(m_it[i]);
      }
      m_it += M;
    }
    while (m_it != m_end) {
      for (size_t i = 0; i < M; ++i) {
        *(out + i) += Squared(m_it[i]);
      }
      m_it += M;
    }
    for (size_t i = 0; i < M; ++i) {
      float v = *out;
      *out++ = std::sqrt(v);
    }
  }

 protected:
  //! Calculate sum of squared values
  //! Walks the four bytes of u and adds the squares of both sign-extended
  //! nibbles of each byte.
  static inline float Squared(uint32_t u) {
    float sum = 0.0f;
    for (size_t i = 0; i < 32; i += 8) {
      uint8_t v = (uint8_t)(u >> i);
      int8_t lo = (int8_t)(v << 4) >> 4;
      int8_t hi = (int8_t)(v & 0xf0) >> 4;
      sum += hi * hi + lo * lo;
    }
    return sum;
  }
};

/*! Squared L2-Norm Matrix */
template struct SquaredNorm2Matrix;

/*!
Squared L2-Norm Matrix */
template struct SquaredNorm2Matrix<
    T, M,
    typename std::enable_if::value && sizeof(T) >= 2 && M >= 2>::type> {
  //! Type of value
  using ValueType = typename std::remove_cv::type;

  //! Compute the norm of vectors
  //! m holds M interleaved vectors (each step reads m[0..M-1] and advances by
  //! M). out[i] receives the sum of squares of vector i; no square root is
  //! taken.
  static inline void Compute(const ValueType *m, size_t dim, float *out) {
    ailego_assert(m && dim && out);
    const ValueType *m_end = m + dim * M;
    if (m != m_end) {
      for (size_t i = 0; i < M; ++i) {
        ValueType v = m[i];
        *(out + i) = static_cast(v * v);
      }
      m += M;
    }
    while (m != m_end) {
      for (size_t i = 0; i < M; ++i) {
        ValueType v = m[i];
        *(out + i) += static_cast(v * v);
      }
      m += M;
    }
  }
};

/*! Squared L2-Norm Matrix (INT8) */
template struct SquaredNorm2Matrix= 2>::type> {
  //! Type of value
  using ValueType = int8_t;

  //! Compute the norm of vectors
  //! Reads four int8 elements per uint32_t word; dim must be a multiple of 4
  //! (asserted). out[i] receives the sum of squares of vector i.
  static inline void Compute(const ValueType *m, size_t dim, float *out) {
    ailego_assert(m && dim && !(dim & 3) && out);
    const uint32_t *m_it = reinterpret_cast(m);
    const uint32_t *m_end = m_it + (dim >> 2) * M;
    if (m_it != m_end) {
      for (size_t i = 0; i < M; ++i) {
        *(out + i) = Squared(m_it[i]);
      }
      m_it += M;
    }
    while (m_it != m_end) {
      for (size_t i = 0; i < M; ++i) {
        *(out + i) += Squared(m_it[i]);
      }
      m_it += M;
    }
  }

 protected:
  //! Calculate sum of squared values
  //! Sum of b * b over the four signed bytes packed in v.
  static inline float Squared(uint32_t v) {
    return static_cast((int8_t)(v >> 0) * (int8_t)(v >> 0) +
                       (int8_t)(v >> 8) * (int8_t)(v >> 8) +
                       (int8_t)(v >> 16) * (int8_t)(v >> 16) +
                       (int8_t)(v >> 24) * (int8_t)(v >> 24));
  }
};

/*! Squared L2-Norm Matrix (M=1) */
template struct SquaredNorm2Matrix<
    T, 1, typename std::enable_if::value>::type> {
  //! Type of value
  using ValueType = typename std::remove_cv::type;

  //! Compute the norm of vectors
  //! Single vector: *out = sum of m[i] * m[i] over dim elements.
  static inline void Compute(const ValueType *m, size_t dim, float *out) {
    ailego_assert(m && dim && out);
    const ValueType *m_end = m + dim;
    if (m != m_end) {
      ValueType v = *m++;
      *out = static_cast(v * v);
    }
    while (m != m_end) {
      ValueType v = *m++;
      *out += static_cast(v * v);
    }
  }
};

/*!
Squared L2-Norm Matrix (M=1, INT4) */
template <>
struct SquaredNorm2Matrix {
  //! Type of value
  using ValueType = uint8_t;

  //! Compute the norm of vectors
  //! Sum of squares of dim signed 4-bit elements packed two per byte; dim
  //! must be even (asserted).
  // NOTE(review): unlike the other specializations, this assert does not
  // reject dim == 0; in that case *out is simply set to 0.
  static inline void Compute(const ValueType *m, size_t dim, float *out) {
    ailego_assert(m && !(dim & 1) && out);
    const uint8_t *m_end = m + (dim >> 1);
    *out = 0.0f;
    while (m != m_end) {
      *out += Squared(*m++);
    }
  }

 protected:
  //! Calculate sum of squared values
  //! Squares of the sign-extended low and high nibbles of v.
  static inline float Squared(uint8_t v) {
    return static_cast(
        ((int8_t)(v << 4) >> 4) * ((int8_t)(v << 4) >> 4) +
        ((int8_t)(v & 0xf0) >> 4) * ((int8_t)(v & 0xf0) >> 4));
  }
};

/*! Squared L2-Norm Matrix (INT4) */
template struct SquaredNorm2Matrix= 2>::type> {
  //! Type of value
  using ValueType = uint8_t;

  //! Compute the norm of vectors
  //! Reads uint32_t words of eight 4-bit elements; dim must be a multiple of
  //! 8 (asserted). out[i] receives the sum of squares of vector i.
  static inline void Compute(const ValueType *m, size_t dim, float *out) {
    ailego_assert(m && dim && !(dim & 7) && out);
    const uint32_t *m_it = reinterpret_cast(m);
    const uint32_t *m_end = m_it + (dim >> 3) * M;
    if (m_it != m_end) {
      for (size_t i = 0; i < M; ++i) {
        *(out + i) = Squared(m_it[i]);
      }
      m_it += M;
    }
    while (m_it != m_end) {
      for (size_t i = 0; i < M; ++i) {
        *(out + i) += Squared(m_it[i]);
      }
      m_it += M;
    }
  }

 protected:
  //! Calculate sum of squared values
  //! Walks the four bytes of u, adding the squares of both sign-extended
  //! nibbles of each byte.
  static inline float Squared(uint32_t u) {
    float sum = 0.0f;
    for (size_t i = 0; i < 32; i += 8) {
      uint8_t v = (uint8_t)(u >> i);
      int8_t lo = (int8_t)(v << 4) >> 4;
      int8_t hi = (int8_t)(v & 0xf0) >> 4;
      sum += hi * hi + lo * lo;
    }
    return sum;
  }
};

#if defined(__SSE__) || (defined(__ARM_NEON) && defined(__aarch64__))
/*! L2-Norm Matrix (FP32, M=1) */
template <>
struct Norm2Matrix {
  //! Type of value
  using ValueType = float;

  //! Compute the L2-norm of vectors
  //! Declared only; implemented in norm2_matrix_fp32.cc.
  static void Compute(const ValueType *m, size_t dim, float *out);
};

/*! Squared L2-Norm Matrix (FP32, M=1) */
template <>
struct SquaredNorm2Matrix {
  //! Type of value
  using ValueType = float;

  //! Compute the squared L2-norm of vectors
  //! Declared only; implemented in norm2_matrix_fp32.cc.
  static void Compute(const ValueType *m, size_t dim, float *out);
};
#endif // __SSE__ || (__ARM_NEON && __aarch64__)

#if (defined(__F16C__) && defined(__AVX__)) || \
    (defined(__ARM_NEON) && defined(__aarch64__))
/*! L2-Norm Matrix (FP16, M=1) */
template <>
struct Norm2Matrix {
  //! Type of value
  using ValueType = Float16;

  //! Compute the L2-norm of vectors
  //! Declared only; implemented in norm2_matrix_fp16.cc.
  static void Compute(const ValueType *m, size_t dim, float *out);
};

/*! Squared L2-Norm Matrix (FP16, M=1) */
template <>
struct SquaredNorm2Matrix {
  //! Type of value
  using ValueType = Float16;

  //! Compute the squared L2-norm of vectors
  //! Declared only; implemented in norm2_matrix_fp16.cc.
  static void Compute(const ValueType *m, size_t dim, float *out);
};
#endif // (__F16C__ && __AVX__) || (__ARM_NEON && __aarch64__)

}  // namespace ailego
}  // namespace zvec

================================================
FILE: src/ailego/math/norm2_matrix_fp16.cc
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include
#include "ailego/internal/cpu_features.h"
#include "norm2_matrix.h"
#include "norm_matrix_fp16.i"

namespace zvec {
namespace ailego {

// Bind the generic kernel step hooks to the sum-of-squares (SS_*) steps.
#define NORM_FP32_STEP_GENERAL SS_FP32_GENERAL
#define NORM_FP32_STEP_SSE SS_FP32_SSE
#define NORM_FP32_STEP_AVX SS_FP32_AVX
#define NORM_FP32_STEP_AVX512 SS_FP32_AVX512
#define NORM_FP32_STEP_NEON SS_FP32_NEON
#define NORM_FP16_STEP_GENERAL SS_FP16_GENERAL
#define NORM_FP16_STEP_NEON SS_FP16_NEON
//!
Calculate sum of squared (GENERAL)
#define SS_FP32_GENERAL(m, sum) sum += (m) * (m);
//! Calculate sum of squared (SSE)
//! (_mm_fmadd_ps / _mm256_fmadd_ps fall back to mul + add when __FMA__ is
//! not defined; see norm_matrix_fp16.i.)
#define SS_FP32_SSE(xmm_m, xmm_sum) \
  xmm_sum = _mm_fmadd_ps(xmm_m, xmm_m, xmm_sum);
//! Calculate sum of squared (AVX)
#define SS_FP32_AVX(ymm_m, ymm_sum) \
  ymm_sum = _mm256_fmadd_ps(ymm_m, ymm_m, ymm_sum);
//! Calculate sum of squared (AVX512)
#define SS_FP32_AVX512(zmm_m, zmm_sum) \
  zmm_sum = _mm512_fmadd_ps(zmm_m, zmm_m, zmm_sum);
//! Calculate sum of squared (NEON)
#define SS_FP32_NEON(v_m, v_sum) v_sum = vfmaq_f32(v_sum, v_m, v_m);
//! Calculate sum of squared (GENERAL)
#define SS_FP16_GENERAL(m, sum) sum += (m) * (m);
//! Calculate sum of squared (NEON)
//! Operates on float16x8_t lanes, i.e. the accumulator is half precision.
#define SS_FP16_NEON(v_m, v_sum) v_sum = vfmaq_f16(v_sum, v_m, v_m);

#if (defined(__F16C__) && defined(__AVX__)) || \
    (defined(__ARM_NEON) && defined(__aarch64__))
//! Compute the L2-norm of vectors (FP16, M=1)
//! Same kernels as the squared variant; std::sqrt is applied to the sum.
void Norm2Matrix::Compute(const ValueType *m, size_t dim, float *out) {
#if defined(__ARM_NEON)
  NORM_FP16_1_NEON(m, dim, out, std::sqrt)
#else
#if defined(__AVX512F__)
  if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) {
    NORM_FP16_1_AVX512(m, dim, out, std::sqrt)
    return;
  }
#endif
  NORM_FP16_1_AVX(m, dim, out, std::sqrt)
#endif
}

//!
Compute the L2-norm of vectors (FP16, M=1) void SquaredNorm2Matrix::Compute(const ValueType *m, size_t dim, float *out) { #if defined(__ARM_NEON) NORM_FP16_1_NEON(m, dim, out, ) #else #if defined(__AVX512F__) if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) { NORM_FP16_1_AVX512(m, dim, out, ) return; } #endif NORM_FP16_1_AVX(m, dim, out, ) #endif } #endif // (__F16C__ && __AVX__) || (__ARM_NEON && __aarch64__) } // namespace ailego } // namespace zvec ================================================ FILE: src/ailego/math/norm2_matrix_fp32.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include "norm2_matrix.h" #include "norm_matrix_fp32.i" namespace zvec { namespace ailego { #define NORM_FP32_STEP_GENERAL SS_FP32_GENERAL #define NORM_FP32_STEP_SSE SS_FP32_SSE #define NORM_FP32_STEP_AVX SS_FP32_AVX #define NORM_FP32_STEP_AVX512 SS_FP32_AVX512 #define NORM_FP32_STEP_NEON SS_FP32_NEON //! Calculate sum of squared (GENERAL) #define SS_FP32_GENERAL(m, sum) sum += (m) * (m); //! Calculate sum of squared (SSE) #define SS_FP32_SSE(xmm_m, xmm_sum) \ xmm_sum = _mm_fmadd_ps(xmm_m, xmm_m, xmm_sum); //! Calculate sum of squared (AVX) #define SS_FP32_AVX(ymm_m, ymm_sum) \ ymm_sum = _mm256_fmadd_ps(ymm_m, ymm_m, ymm_sum); //! Calculate sum of squared (AVX512) #define SS_FP32_AVX512(zmm_m, zmm_sum) \ zmm_sum = _mm512_fmadd_ps(zmm_m, zmm_m, zmm_sum); //! 
Calculate sum of squared (NEON)
#define SS_FP32_NEON(v_m, v_sum) v_sum = vfmaq_f32(v_sum, v_m, v_m);

#if defined(__SSE__) || (defined(__ARM_NEON) && defined(__aarch64__))
//! Compute the L2-norm of vectors (FP32, M=1)
//! x86 dispatch order: AVX-512, then AVX (each only when compiled in and
//! flagged in CpuFeatures::static_flags_), then SSE; std::sqrt is applied to
//! the accumulated sum of squares.
void Norm2Matrix::Compute(const ValueType *m, size_t dim, float *out) {
#if defined(__ARM_NEON)
  NORM_FP32_1_NEON(m, dim, out, std::sqrt)
#else
#if defined(__AVX512F__)
  if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) {
    NORM_FP32_1_AVX512(m, dim, out, std::sqrt)
    return;
  }
#endif
#if defined(__AVX__)
  if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX) {
    NORM_FP32_1_AVX(m, dim, out, std::sqrt)
    return;
  }
#endif
  NORM_FP32_1_SSE(m, dim, out, std::sqrt)
#endif
}

//! Compute the squared L2-norm of vectors (FP32, M=1)
//! Same dispatch as above; the sum of squares is stored unchanged.
void SquaredNorm2Matrix::Compute(const ValueType *m, size_t dim, float *out) {
#if defined(__ARM_NEON)
  NORM_FP32_1_NEON(m, dim, out, )
#else
#if defined(__AVX512F__)
  if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) {
    NORM_FP32_1_AVX512(m, dim, out, )
    return;
  }
#endif
#if defined(__AVX__)
  if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX) {
    NORM_FP32_1_AVX(m, dim, out, )
    return;
  }
#endif
  NORM_FP32_1_SSE(m, dim, out, )
#endif
}
#endif // __SSE__ || (__ARM_NEON && __aarch64__)

}  // namespace ailego
}  // namespace zvec

================================================
FILE: src/ailego/math/norm_matrix.h
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // limitations under the License. #pragma once #include "norm1_matrix.h" #include "norm2_matrix.h" ================================================ FILE: src/ailego/math/norm_matrix_fp16.i ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "matrix_define.i" #include "matrix_utility.i" #if !defined(__FMA__) #define _mm_fmadd_ps(a, b, c) _mm_add_ps(_mm_mul_ps((a), (b)), (c)) #define _mm256_fmadd_ps(a, b, c) _mm256_add_ps(_mm256_mul_ps((a), (b)), (c)) #endif // !__FMA__ //! 
Mask process of computing norm (FP16)
//! Handles the final cnt (1..7) FP16 elements at m. They are packed into the
//! low lanes of a 128-bit vector (upper lanes zero), widened to FP32 with
//! _mm256_cvtph_ps, and passed to NORM_FP32_STEP_AVX on `_RES##_0_0`.
//! cnt == 0 does nothing. The zero lanes add nothing for the |x| and x * x
//! steps used with this macro.
#define NORM_FP16_MASK_AVX(m, cnt, _RES) \
  switch (cnt) { \
    case 7: { \
      __m256 ymm_m = _mm256_cvtph_ps( \
          _mm_set_epi16(0, *((const short *)(m) + 6), \
                        *((const short *)(m) + 5), *((const short *)(m) + 4), \
                        *((const short *)(m) + 3), *((const short *)(m) + 2), \
                        *((const short *)(m) + 1), *((const short *)(m)))); \
      NORM_FP32_STEP_AVX(ymm_m, _RES##_0_0) \
      break; \
    } \
    case 6: { \
      __m256 ymm_m = _mm256_cvtph_ps(_mm_set_epi32(0, *((const int *)(m) + 2), \
                                                   *((const int *)(m) + 1), \
                                                   *((const int *)(m)))); \
      NORM_FP32_STEP_AVX(ymm_m, _RES##_0_0) \
      break; \
    } \
    case 5: { \
      __m256 ymm_m = _mm256_cvtph_ps( \
          _mm_set_epi16(0, 0, 0, *((const short *)(m) + 4), \
                        *((const short *)(m) + 3), *((const short *)(m) + 2), \
                        *((const short *)(m) + 1), *((const short *)(m)))); \
      NORM_FP32_STEP_AVX(ymm_m, _RES##_0_0) \
      break; \
    } \
    case 4: { \
      __m256 ymm_m = _mm256_cvtph_ps( \
          _mm_set_epi64((__m64)(0ull), *((const __m64 *)(m)))); \
      NORM_FP32_STEP_AVX(ymm_m, _RES##_0_0) \
      break; \
    } \
    case 3: { \
      __m256 ymm_m = _mm256_cvtph_ps( \
          _mm_set_epi16(0, 0, 0, 0, 0, *((const short *)(m) + 2), \
                        *((const short *)(m) + 1), *((const short *)(m)))); \
      NORM_FP32_STEP_AVX(ymm_m, _RES##_0_0) \
      break; \
    } \
    case 2: { \
      __m256 ymm_m = \
          _mm256_cvtph_ps(_mm_set_epi32(0, 0, 0, *((const int *)(m)))); \
      NORM_FP32_STEP_AVX(ymm_m, _RES##_0_0) \
      break; \
    } \
    case 1: { \
      __m256 ymm_m = _mm256_cvtph_ps( \
          _mm_set_epi16(0, 0, 0, 0, 0, 0, 0, *((const short *)(m)))); \
      NORM_FP32_STEP_AVX(ymm_m, _RES##_0_0) \
      break; \
    } \
  }
//! Compute the norm of vectors (FP16, M=1)
//! AVX kernel. Each iteration takes 16 FP16 values (one 256-bit load, widened
//! in two 8-lane halves) into the single accumulator ymm_sum_0_0. Aligned
//! loads are used when m starts 32-byte aligned, unaligned loads otherwise.
//! An optional 8-element step follows, then NORM_FP16_MASK_AVX for the last
//! 0..7 elements. _NORM post-processes the horizontal sum before it is
//! stored to *out. Note: the macro advances `m` in place.
#define NORM_FP16_1_AVX(m, dim, out, _NORM) \
  MATRIX_VAR_INIT(1, 1, __m256, ymm_sum, _mm256_setzero_ps()) \
  const Float16 *last = m + dim; \
  const Float16 *last_aligned = m + ((dim >> 4) << 4); \
  if (((uintptr_t)m & 0x1f) == 0) { \
    for (; m != last_aligned; m += 16) { \
      __m256i ymm_mi = _mm256_load_si256((const __m256i *)m); \
      __m256 ymm_m_0 = _mm256_cvtph_ps(_mm256_castsi256_si128(ymm_mi)); \
      __m256 ymm_m_1 = _mm256_cvtph_ps(_mm256_extractf128_si256(ymm_mi, 1)); \
      NORM_FP32_STEP_AVX(ymm_m_0, ymm_sum_0_0) \
      NORM_FP32_STEP_AVX(ymm_m_1, ymm_sum_0_0) \
    } \
    if (last >= last_aligned + 8) { \
      __m256 ymm_m = _mm256_cvtph_ps(_mm_load_si128((const __m128i *)m)); \
      NORM_FP32_STEP_AVX(ymm_m, ymm_sum_0_0) \
      m += 8; \
    } \
  } else { \
    for (; m != last_aligned; m += 16) { \
      __m256i ymm_mi = _mm256_loadu_si256((const __m256i *)m); \
      __m256 ymm_m_0 = _mm256_cvtph_ps(_mm256_castsi256_si128(ymm_mi)); \
      __m256 ymm_m_1 = _mm256_cvtph_ps(_mm256_extractf128_si256(ymm_mi, 1)); \
      NORM_FP32_STEP_AVX(ymm_m_0, ymm_sum_0_0) \
      NORM_FP32_STEP_AVX(ymm_m_1, ymm_sum_0_0) \
    } \
    if (last >= last_aligned + 8) { \
      __m256 ymm_m = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)m)); \
      NORM_FP32_STEP_AVX(ymm_m, ymm_sum_0_0) \
      m += 8; \
    } \
  } \
  NORM_FP16_MASK_AVX(m, (last - m), ymm_sum) \
  *out = _NORM(HorizontalAdd_FP32_V256(ymm_sum_0_0));
//! Compute the norm of vectors (FP16, M=1)
//! AVX-512 kernel. Each iteration takes 32 FP16 values (one 512-bit load)
//! and spreads them over two accumulators, zmm_sum_0_0 and zmm_sum_0_1. An
//! optional 16-element step follows. Any remaining 1..15 elements go through
//! an AVX accumulator (an 8-element step plus NORM_FP16_MASK_AVX) and are
//! added to the scalar result before _NORM is applied. Advances `m` in place.
#define NORM_FP16_1_AVX512(m, dim, out, _NORM) \
  MATRIX_VAR_INIT(1, 2, __m512, zmm_sum, _mm512_setzero_ps()) \
  const Float16 *last = m + dim; \
  const Float16 *last_aligned = m + ((dim >> 5) << 5); \
  if (((uintptr_t)m & 0x3f) == 0) { \
    for (; m != last_aligned; m += 32) { \
      __m512i zmm_mi = _mm512_load_si512((const __m512i *)m); \
      __m512 zmm_m_0 = _mm512_cvtph_ps(_mm512_castsi512_si256(zmm_mi)); \
      __m512 zmm_m_1 = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(zmm_mi, 1)); \
      NORM_FP32_STEP_AVX512(zmm_m_0, zmm_sum_0_0) \
      NORM_FP32_STEP_AVX512(zmm_m_1, zmm_sum_0_1) \
    } \
    if (last >= last_aligned + 16) { \
      __m512 zmm_m = _mm512_cvtph_ps(_mm256_load_si256((const __m256i *)m)); \
      NORM_FP32_STEP_AVX512(zmm_m, zmm_sum_0_0) \
      m += 16; \
    } \
  } else { \
    for (; m != last_aligned; m += 32) { \
      __m512i zmm_mi = _mm512_loadu_si512((const __m512i *)m); \
      __m512 zmm_m_0 = _mm512_cvtph_ps(_mm512_castsi512_si256(zmm_mi)); \
      __m512 zmm_m_1 = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(zmm_mi, 1)); \
      NORM_FP32_STEP_AVX512(zmm_m_0, zmm_sum_0_0) \
      NORM_FP32_STEP_AVX512(zmm_m_1, zmm_sum_0_1) \
    } \
    if (last >= last_aligned + 16) { \
      __m512 zmm_m = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)m)); \
      NORM_FP32_STEP_AVX512(zmm_m, zmm_sum_0_0) \
      m += 16; \
    } \
  } \
  float result = \
      HorizontalAdd_FP32_V512(_mm512_add_ps(zmm_sum_0_0, zmm_sum_0_1)); \
  if (m != last) { \
    MATRIX_VAR_INIT(1, 1, __m256, ymm_sum, _mm256_setzero_ps()) \
    if (last >= m + 8) { \
      __m256 ymm_m = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)m)); \
      NORM_FP32_STEP_AVX(ymm_m, ymm_sum_0_0) \
      m += 8; \
    } \
    NORM_FP16_MASK_AVX(m, (last - m), ymm_sum) \
    result += HorizontalAdd_FP32_V256(ymm_sum_0_0); \
  } \
  *out = _NORM(result);

#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
//!
Compute the norm of vectors (FP16, M=1) #define NORM_FP16_1_NEON(m, dim, out, _NORM) \ MATRIX_VAR_INIT(1, 1, float16x8_t, v_sum, vdupq_n_f16(0)) \ const Float16 *last = m + dim; \ const Float16 *last_aligned = m + ((dim >> 3) << 3); \ for (; m != last_aligned; m += 8) { \ float16x8_t v_m = vld1q_f16((const float16_t *)m); \ NORM_FP16_STEP_NEON(v_m, v_sum_0_0) \ } \ if (last >= m + 4) { \ float16x8_t v_m = vreinterpretq_f16_u64( \ vld1q_lane_u64((const uint64_t *)m, vdupq_n_u64(0), 0)); \ NORM_FP16_STEP_NEON(v_m, v_sum_0_0) \ m += 4; \ } \ float result = vaddvq_f32(vaddq_f32(vcvt_f32_f16(vget_low_f16(v_sum_0_0)), \ vcvt_high_f32_f16(v_sum_0_0))); \ switch (last - m) { \ case 3: \ NORM_FP16_STEP_GENERAL(m[2], result) \ /* FALLTHRU */ \ case 2: \ NORM_FP16_STEP_GENERAL(m[1], result) \ /* FALLTHRU */ \ case 1: \ NORM_FP16_STEP_GENERAL(m[0], result) \ } \ *out = _NORM(result); #else //! Compute the norm of vectors (FP16, M=1) #define NORM_FP16_1_NEON(m, dim, out, _NORM) \ MATRIX_VAR_INIT(1, 2, float32x4_t, v_sum, vdupq_n_f32(0)) \ const Float16 *last = m + dim; \ const Float16 *last_aligned = m + ((dim >> 3) << 3); \ for (; m != last_aligned; m += 8) { \ float16x8_t v_m = vld1q_f16((const float16_t *)m); \ float32x4_t v_n_0 = vcvt_f32_f16(vget_low_f16(v_m)); \ float32x4_t v_n_1 = vcvt_high_f32_f16(v_m); \ NORM_FP32_STEP_NEON(v_n_0, v_sum_0_0) \ NORM_FP32_STEP_NEON(v_n_1, v_sum_0_1) \ } \ if (last >= m + 4) { \ float32x4_t v_m = vcvt_f32_f16(vld1_f16((const float16_t *)m)); \ NORM_FP32_STEP_NEON(v_m, v_sum_0_0) \ m += 4; \ } \ float result = vaddvq_f32(vaddq_f32(v_sum_0_0, v_sum_0_1)); \ switch (last - m) { \ case 3: \ NORM_FP16_STEP_GENERAL(m[2], result) \ /* FALLTHRU */ \ case 2: \ NORM_FP16_STEP_GENERAL(m[1], result) \ /* FALLTHRU */ \ case 1: \ NORM_FP16_STEP_GENERAL(m[0], result) \ } \ *out = _NORM(result); #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC ================================================ FILE: src/ailego/math/norm_matrix_fp32.i 
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "matrix_define.i"
#include "matrix_utility.i"

// Without FMA, emulate the fused multiply-add intrinsics with mul + add
// (two roundings instead of one).
#if !defined(__FMA__)
#define _mm_fmadd_ps(a, b, c) _mm_add_ps(_mm_mul_ps((a), (b)), (c))
#define _mm256_fmadd_ps(a, b, c) _mm256_add_ps(_mm256_mul_ps((a), (b)), (c))
#endif // !__FMA__

//! Mask process of computing norm (FP32)
//! Handles the final cnt (1..3) floats at m by packing them into the low
//! lanes of an __m128 (upper lanes 0.0f) and applying NORM_FP32_STEP_SSE to
//! `_RES##_0_0`. cnt == 0 does nothing.
#define NORM_FP32_MASK_SSE(m, cnt, _RES) \
  switch (cnt) { \
    case 3: { \
      __m128 xmm_m = _mm_set_ps(0.0f, m[2], m[1], m[0]); \
      NORM_FP32_STEP_SSE(xmm_m, _RES##_0_0) \
      break; \
    } \
    case 2: { \
      __m128 xmm_m = _mm_set_ps(0.0f, 0.0f, m[1], m[0]); \
      NORM_FP32_STEP_SSE(xmm_m, _RES##_0_0) \
      break; \
    } \
    case 1: { \
      __m128 xmm_m = _mm_set_ps(0.0f, 0.0f, 0.0f, m[0]); \
      NORM_FP32_STEP_SSE(xmm_m, _RES##_0_0) \
      break; \
    } \
  }
//! Compute the norm of vectors (FP32, M=1)
//! SSE kernel. Each iteration takes 8 floats (two 128-bit loads) into the
//! single accumulator xmm_sum_0_0. Aligned loads are used when m starts
//! 16-byte aligned. An optional 4-float step follows, then
//! NORM_FP32_MASK_SSE for the last 0..3 floats. _NORM post-processes the
//! horizontal sum. Advances `m` in place.
#define NORM_FP32_1_SSE(m, dim, out, _NORM) \
  MATRIX_VAR_INIT(1, 1, __m128, xmm_sum, _mm_setzero_ps()) \
  const float *last = m + dim; \
  const float *last_aligned = m + ((dim >> 3) << 3); \
  if (((uintptr_t)m & 0xf) == 0) { \
    for (; m != last_aligned; m += 8) { \
      __m128 xmm_m_0 = _mm_load_ps(m + 0); \
      __m128 xmm_m_1 = _mm_load_ps(m + 4); \
      NORM_FP32_STEP_SSE(xmm_m_0, xmm_sum_0_0) \
      NORM_FP32_STEP_SSE(xmm_m_1, xmm_sum_0_0) \
    } \
    if (last >= last_aligned + 4) { \
      __m128 xmm_m = _mm_load_ps(m); \
      NORM_FP32_STEP_SSE(xmm_m, xmm_sum_0_0) \
      m += 4; \
    } \
  } else { \
    for (; m != last_aligned; m += 8) { \
      __m128 xmm_m_0 = _mm_loadu_ps(m + 0); \
      __m128 xmm_m_1 = _mm_loadu_ps(m + 4); \
      NORM_FP32_STEP_SSE(xmm_m_0, xmm_sum_0_0) \
      NORM_FP32_STEP_SSE(xmm_m_1, xmm_sum_0_0) \
    } \
    if (last >= last_aligned + 4) { \
      __m128 xmm_m = _mm_loadu_ps(m); \
      NORM_FP32_STEP_SSE(xmm_m, xmm_sum_0_0) \
      m += 4; \
    } \
  } \
  NORM_FP32_MASK_SSE(m, (last - m), xmm_sum) \
  *out = _NORM(HorizontalAdd_FP32_V128(xmm_sum_0_0));
//! Compute the norm of vectors (FP32, M=1)
//! AVX kernel. Each iteration takes 16 floats (two 256-bit loads) into
//! ymm_sum_0_0, followed by an optional 8-float step. Any remaining 1..7
//! floats are summed with SSE (a 4-float step plus NORM_FP32_MASK_SSE), so
//! NORM_FP32_STEP_SSE must also be defined. Advances `m` in place.
#define NORM_FP32_1_AVX(m, dim, out, _NORM) \
  MATRIX_VAR_INIT(1, 1, __m256, ymm_sum, _mm256_setzero_ps()) \
  const float *last = m + dim; \
  const float *last_aligned = m + ((dim >> 4) << 4); \
  if (((uintptr_t)m & 0x1f) == 0) { \
    for (; m != last_aligned; m += 16) { \
      __m256 ymm_m_0 = _mm256_load_ps(m + 0); \
      __m256 ymm_m_1 = _mm256_load_ps(m + 8); \
      NORM_FP32_STEP_AVX(ymm_m_0, ymm_sum_0_0) \
      NORM_FP32_STEP_AVX(ymm_m_1, ymm_sum_0_0) \
    } \
    if (last >= last_aligned + 8) { \
      __m256 ymm_m = _mm256_load_ps(m); \
      NORM_FP32_STEP_AVX(ymm_m, ymm_sum_0_0) \
      m += 8; \
    } \
  } else { \
    for (; m != last_aligned; m += 16) { \
      __m256 ymm_m_0 = _mm256_loadu_ps(m + 0); \
      __m256 ymm_m_1 = _mm256_loadu_ps(m + 8); \
      NORM_FP32_STEP_AVX(ymm_m_0, ymm_sum_0_0) \
      NORM_FP32_STEP_AVX(ymm_m_1, ymm_sum_0_0) \
    } \
    if (last >= last_aligned + 8) { \
      __m256 ymm_m = _mm256_loadu_ps(m); \
      NORM_FP32_STEP_AVX(ymm_m, ymm_sum_0_0) \
      m += 8; \
    } \
  } \
  float result = HorizontalAdd_FP32_V256(ymm_sum_0_0); \
  if (m != last) { \
    __m128 xmm_sum_0_0 = _mm_setzero_ps(); \
    if (last >= m + 4) { \
      __m128 xmm_m = _mm_loadu_ps(m); \
      NORM_FP32_STEP_SSE(xmm_m, xmm_sum_0_0) \
      m += 4; \
    } \
    NORM_FP32_MASK_SSE(m, (last - m), xmm_sum) \
    result += HorizontalAdd_FP32_V128(xmm_sum_0_0); \
  } \
  *out = _NORM(result);
//! Compute the norm of vectors (FP32, M=1)
//! AVX-512 kernel. Each iteration takes 32 floats, split over two
//! accumulators, followed by an optional 16-float step. The last 1..15
//! floats are read with a masked load (lanes outside the mask are zero), so
//! no scalar tail is needed. Advances `m` in place.
#define NORM_FP32_1_AVX512(m, dim, out, _NORM) \
  MATRIX_VAR_INIT(1, 2, __m512, zmm_sum, _mm512_setzero_ps()) \
  const float *last = m + dim; \
  const float *last_aligned = m + ((dim >> 5) << 5); \
  if (((uintptr_t)m & 0x3f) == 0) { \
    for (; m != last_aligned; m += 32) { \
      __m512 zmm_m_0 = _mm512_load_ps(m + 0); \
      NORM_FP32_STEP_AVX512(zmm_m_0, zmm_sum_0_0) \
      __m512 zmm_m_1 = _mm512_load_ps(m + 16); \
      NORM_FP32_STEP_AVX512(zmm_m_1, zmm_sum_0_1) \
    } \
    if (last >= last_aligned + 16) { \
      __m512 zmm_m = _mm512_load_ps(m); \
      NORM_FP32_STEP_AVX512(zmm_m, zmm_sum_0_0) \
      m += 16; \
    } \
  } else { \
    for (; m != last_aligned; m += 32) { \
      __m512 zmm_m_0 = _mm512_loadu_ps(m + 0); \
      NORM_FP32_STEP_AVX512(zmm_m_0, zmm_sum_0_0) \
      __m512 zmm_m_1 = _mm512_loadu_ps(m + 16); \
      NORM_FP32_STEP_AVX512(zmm_m_1, zmm_sum_0_1) \
    } \
    if (last >= last_aligned + 16) { \
      __m512 zmm_m = _mm512_loadu_ps(m); \
      NORM_FP32_STEP_AVX512(zmm_m, zmm_sum_0_0) \
      m += 16; \
    } \
  } \
  if (m != last) { \
    __mmask16 mask = (__mmask16)((1 << (last - m)) - 1); \
    __m512 zmm_m = _mm512_mask_loadu_ps(_mm512_setzero_ps(), mask, m); \
    NORM_FP32_STEP_AVX512(zmm_m, zmm_sum_0_0) \
  } \
  float result = \
      HorizontalAdd_FP32_V512(_mm512_add_ps(zmm_sum_0_0, zmm_sum_0_1)); \
  *out = _NORM(result);
//! Compute the norm of vectors (FP32, M=1)
//! NEON kernel. Each iteration takes 8 floats over two accumulators, followed
//! by an optional 4-float step. The last 0..3 floats are added to the scalar
//! result through NORM_FP32_STEP_GENERAL, with deliberate case fallthrough.
//! Advances `m` in place.
#define NORM_FP32_1_NEON(m, dim, out, _NORM) \
  MATRIX_VAR_INIT(1, 2, float32x4_t, v_sum, vdupq_n_f32(0)) \
  const float *last = m + dim; \
  const float *last_aligned = m + ((dim >> 3) << 3); \
  for (; m != last_aligned; m += 8) { \
    float32x4_t v_m_0 = vld1q_f32(m + 0); \
    float32x4_t v_m_1 = vld1q_f32(m + 4); \
    NORM_FP32_STEP_NEON(v_m_0, v_sum_0_0) \
    NORM_FP32_STEP_NEON(v_m_1, v_sum_0_1) \
  } \
  if (last >= last_aligned + 4) { \
    float32x4_t v_m = vld1q_f32(m); \
    NORM_FP32_STEP_NEON(v_m, v_sum_0_0) \
    m += 4; \
  } \
  float result = vaddvq_f32(vaddq_f32(v_sum_0_0, v_sum_0_1)); \
  switch (last - m) { \
    case 3: \
      NORM_FP32_STEP_GENERAL(m[2], result) \
      /* FALLTHRU */ \
    case 2: \
      NORM_FP32_STEP_GENERAL(m[1], result) \
      /* FALLTHRU */ \
    case 1: \
      NORM_FP32_STEP_GENERAL(m[0], result) \
  } \
  *out = _NORM(result);

================================================
FILE: src/ailego/math/normalizer.cc
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "normalizer.h"

namespace zvec {
namespace ailego {

#if (defined(__ARM_NEON) && defined(__aarch64__))
//! Divide each of the `dim` FP32 elements of `arr` by `norm`, in place (NEON).
static inline void NormalizeNEON(float *arr, size_t dim, float norm) {
  float *last = arr + dim;
  float *last_aligned = arr + ((dim >> 3) << 3);
  float32x4_t v_norm = vdupq_n_f32(norm);

  // Two 4-lane vectors per iteration.
  for (; arr != last_aligned; arr += 8) {
    vst1q_f32(arr + 0, vdivq_f32(vld1q_f32(arr + 0), v_norm));
    vst1q_f32(arr + 4, vdivq_f32(vld1q_f32(arr + 4), v_norm));
  }
  if (last >= last_aligned + 4) {
    vst1q_f32(arr, vdivq_f32(vld1q_f32(arr), v_norm));
    arr += 4;
  }
  // At most 3 scalars remain.
  switch (last - arr) {
    case 3:
      arr[2] /= norm;
      /* FALLTHRU */
    case 2:
      arr[1] /= norm;
      /* FALLTHRU */
    case 1:
      arr[0] /= norm;
  }
}

//! Divide each of the `dim` FP16 elements of `arr` by `norm`, in place (NEON).
//! Every lane is widened to FP32 for the division and the quotient is
//! narrowed back to FP16. `norm` itself is never converted to FP16, so this
//! path is valid for any float `norm`.
static inline void NormalizeNEONWiden(float16_t *arr, size_t dim, float norm) {
  float16_t *last = arr + dim;
  float16_t *last_aligned = arr + ((dim >> 4) << 4);
  float32x4_t v_norm = vdupq_n_f32(norm);

  for (; arr != last_aligned; arr += 16) {
    float16x8_t vf16_0 = vld1q_f16(arr + 0);
    float16x8_t vf16_1 = vld1q_f16(arr + 8);
    vf16_0 = vcombine_f16(
        vcvt_f16_f32(vdivq_f32(vcvt_f32_f16(vget_low_f16(vf16_0)), v_norm)),
        vcvt_f16_f32(vdivq_f32(vcvt_high_f32_f16(vf16_0), v_norm)));
    vf16_1 = vcombine_f16(
        vcvt_f16_f32(vdivq_f32(vcvt_f32_f16(vget_low_f16(vf16_1)), v_norm)),
        vcvt_f16_f32(vdivq_f32(vcvt_high_f32_f16(vf16_1), v_norm)));
    vst1q_f16(arr + 0, vf16_0);
    vst1q_f16(arr + 8, vf16_1);
  }
  if (last >= arr + 8) {
    float16x8_t vf16 = vld1q_f16(arr);
    vf16 = vcombine_f16(
        vcvt_f16_f32(vdivq_f32(vcvt_f32_f16(vget_low_f16(vf16)), v_norm)),
        vcvt_f16_f32(vdivq_f32(vcvt_high_f32_f16(vf16), v_norm)));
    vst1q_f16(arr, vf16);
    arr += 8;
  }
  if (last >= arr + 4) {
    vst1_f16(arr,
             vcvt_f16_f32(vdivq_f32(vcvt_f32_f16(vld1_f16(arr)), v_norm)));
    arr += 4;
  }
  switch (last - arr) {
    case 3:
      arr[2] /= norm;
      /* FALLTHRU */
    case 2:
      arr[1] /= norm;
      /* FALLTHRU */
    case 1:
      arr[0] /= norm;
  }
}

#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
//! Divide each of the `dim` FP16 elements of `arr` by `norm`, in place. Uses
//! native FP16 vector division when `norm` lies in the FP16 normal range.
static inline void NormalizeNEON(float16_t *arr, size_t dim, float norm) {
  // The vector path below broadcasts `norm` as FP16, whose normal values
  // span [2^-14, 65504]. A larger norm (easily reached by the L1 norm of a
  // long vector) would round to +inf and turn every finite vector lane into
  // 0. A smaller one loses precision or rounds to 0. Divide in FP32 in those
  // cases; the negated comparison also sends NaN there.
  if (!(norm >= 6.103515625e-05f && norm <= 65504.0f)) {
    NormalizeNEONWiden(arr, dim, norm);
    return;
  }

  float16_t *last = arr + dim;
  float16_t *last_aligned = arr + ((dim >> 4) << 4);
  float16x8_t v_norm = vdupq_n_f16((float16_t)norm);

  for (; arr != last_aligned; arr += 16) {
    vst1q_f16(arr + 0, vdivq_f16(vld1q_f16(arr + 0), v_norm));
    vst1q_f16(arr + 8, vdivq_f16(vld1q_f16(arr + 8), v_norm));
  }
  if (last >= arr + 8) {
    vst1q_f16(arr, vdivq_f16(vld1q_f16(arr), v_norm));
    arr += 8;
  }
  if (last >= arr + 4) {
    vst1_f16(arr, vdiv_f16(vld1_f16(arr), vget_low_f16(v_norm)));
    arr += 4;
  }
  switch (last - arr) {
    case 3:
      arr[2] /= norm;
      /* FALLTHRU */
    case 2:
      arr[1] /= norm;
      /* FALLTHRU */
    case 1:
      arr[0] /= norm;
  }
}
#else
//! Divide each of the `dim` FP16 elements of `arr` by `norm`, in place.
//! Without native FP16 vector arithmetic the division is always done in FP32.
static inline void NormalizeNEON(float16_t *arr, size_t dim, float norm) {
  NormalizeNEONWiden(arr, dim, norm);
}
#endif  // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#endif  // __ARM_NEON && __aarch64__

#if defined(__AVX__)
#if defined(__AVX512F__)
//! Divide each of the `dim` FP32 elements of `arr` by `norm`, in place
//! (AVX-512, with 256/128-bit steps for the remainder).
static inline void NormalizeAVX512(float *arr, size_t dim, float norm) {
  float *last = arr + dim;
  float *last_aligned = arr + ((dim >> 4) << 4);
  __m512 zmm_norm = _mm512_set1_ps(norm);

  // Aligned loads/stores only when arr is 64-byte aligned. Each wider step
  // keeps the alignment required by the narrower ones that follow.
  if (((uintptr_t)arr & 0x3f) == 0) {
    for (; arr != last_aligned; arr += 16) {
      _mm512_store_ps(arr, _mm512_div_ps(_mm512_load_ps(arr), zmm_norm));
    }
    if (last >= arr + 8) {
      __m256 ymm_norm = _mm256_set1_ps(norm);
      _mm256_store_ps(arr, _mm256_div_ps(_mm256_load_ps(arr), ymm_norm));
      arr += 8;
    }
    if (last >= arr + 4) {
      __m128 xmm_norm = _mm_set1_ps(norm);
      _mm_store_ps(arr, _mm_div_ps(_mm_load_ps(arr), xmm_norm));
      arr += 4;
    }
  } else {
    for (; arr != last_aligned; arr += 16) {
      _mm512_storeu_ps(arr, _mm512_div_ps(_mm512_loadu_ps(arr), zmm_norm));
    }
    if (last >= arr + 8) {
      __m256 ymm_norm = _mm256_set1_ps(norm);
      _mm256_storeu_ps(arr, _mm256_div_ps(_mm256_loadu_ps(arr), ymm_norm));
      arr += 8;
    }
    if (last >= arr + 4) {
      __m128 xmm_norm = _mm_set1_ps(norm);
      _mm_storeu_ps(arr, _mm_div_ps(_mm_loadu_ps(arr), xmm_norm));
      arr += 4;
    }
  }
  switch (last - arr) {
    case 3:
      arr[2] /= norm;
      /* FALLTHRU */
    case 2:
      arr[1] /= norm;
      /* FALLTHRU */
    case 1:
      arr[0] /= norm;
  }
}
#endif  // __AVX512F__

//! Divide each of the `dim` FP32 elements of `arr` by `norm`, in place (AVX).
static inline void NormalizeAVX(float *arr, size_t dim, float norm) {
  float *last = arr + dim;
  float *last_aligned = arr + ((dim >> 4) << 4);
  __m256 ymm_norm = _mm256_set1_ps(norm);

  // Aligned loads/stores only when arr is 32-byte aligned.
  if (((uintptr_t)arr & 0x1f) == 0) {
    for (; arr != last_aligned; arr += 16) {
      _mm256_store_ps(arr + 0,
                      _mm256_div_ps(_mm256_load_ps(arr + 0), ymm_norm));
      _mm256_store_ps(arr + 8,
                      _mm256_div_ps(_mm256_load_ps(arr + 8), ymm_norm));
    }
    if (last >= arr + 8) {
      _mm256_store_ps(arr, _mm256_div_ps(_mm256_load_ps(arr), ymm_norm));
      arr += 8;
    }
    if (last >= arr + 4) {
      __m128 xmm_norm = _mm_set1_ps(norm);
      _mm_store_ps(arr, _mm_div_ps(_mm_load_ps(arr), xmm_norm));
      arr += 4;
    }
  } else {
    for (; arr != last_aligned; arr += 16) {
      _mm256_storeu_ps(arr + 0,
                       _mm256_div_ps(_mm256_loadu_ps(arr + 0), ymm_norm));
      _mm256_storeu_ps(arr + 8,
                       _mm256_div_ps(_mm256_loadu_ps(arr + 8), ymm_norm));
    }
    if (last >= arr + 8) {
      _mm256_storeu_ps(arr, _mm256_div_ps(_mm256_loadu_ps(arr), ymm_norm));
      arr += 8;
    }
    if (last >= arr + 4) {
      __m128 xmm_norm = _mm_set1_ps(norm);
      _mm_storeu_ps(arr, _mm_div_ps(_mm_loadu_ps(arr), xmm_norm));
      arr += 4;
    }
  }
  switch (last - arr) {
    case 3:
      arr[2] /= norm;
      /* FALLTHRU */
    case 2:
      arr[1] /= norm;
      /* FALLTHRU */
    case 1:
      arr[0] /= norm;
  }
}
#endif  // __AVX__

#if defined(__AVX__) && defined(__F16C__)
#if defined(__AVX512F__)
//! Divide each of the `dim` FP16 elements of `arr` by `norm`, in place
//! (AVX-512 + F16C). Elements are IEEE half bit patterns stored as uint16_t.
//! They are widened to FP32, divided, and rounded back to FP16.
static inline void NormalizeAVX512(uint16_t *arr, size_t dim, float norm) {
  uint16_t *last = arr + dim;
  uint16_t *last_aligned = arr + ((dim >> 4) << 4);
  __m512 zmm_norm = _mm512_set1_ps(norm);

  // 16 halves = 32 bytes per iteration, so 32-byte alignment is preserved.
  if (((uintptr_t)arr & 0x1f) == 0) {
    for (; arr != last_aligned; arr += 16) {
      __m512 values =
          _mm512_cvtph_ps(_mm256_load_si256((const __m256i *)arr));
      _mm256_store_si256(
          (__m256i *)arr,
          _mm512_cvtps_ph(_mm512_div_ps(values, zmm_norm), _MM_FROUND_NO_EXC));
    }
    if (last >= arr + 8) {
      __m256 ymm_norm = _mm256_set1_ps(norm);
      __m256 values = _mm256_cvtph_ps(_mm_load_si128((const __m128i *)arr));
      _mm_store_si128(
          (__m128i *)arr,
          _mm256_cvtps_ph(_mm256_div_ps(values, ymm_norm), _MM_FROUND_NO_EXC));
      arr += 8;
    }
  } else {
    for (; arr != last_aligned; arr += 16) {
      __m512 values =
          _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)arr));
      _mm256_storeu_si256(
          (__m256i *)arr,
          _mm512_cvtps_ph(_mm512_div_ps(values, zmm_norm), _MM_FROUND_NO_EXC));
    }
    if (last >= arr + 8) {
      __m256 ymm_norm = _mm256_set1_ps(norm);
      __m256 values = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)arr));
      _mm_storeu_si128(
          (__m128i *)arr,
          _mm256_cvtps_ph(_mm256_div_ps(values, ymm_norm), _MM_FROUND_NO_EXC));
      arr += 8;
    }
  }
  // 4 halves via a 64-bit load/store (no alignment requirement).
  if (last >= arr + 4) {
    __m128 xmm_norm = _mm_set1_ps(norm);
    __m128 values = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)arr));
    _mm_storel_epi64(
        (__m128i *)arr,
        _mm_cvtps_ph(_mm_div_ps(values, xmm_norm), _MM_FROUND_NO_EXC));
    // Exactly 4 elements were processed. Advancing by 8 would make
    // `last - arr` negative below and skip the final 1-3 elements.
    arr += 4;
  }
  switch (last - arr) {
    case 3:
      arr[2] = _cvtss_sh(_cvtsh_ss(arr[2]) / norm, _MM_FROUND_NO_EXC);
      /* FALLTHRU */
    case 2:
      arr[1] = _cvtss_sh(_cvtsh_ss(arr[1]) / norm, _MM_FROUND_NO_EXC);
      /* FALLTHRU */
    case 1:
      arr[0] = _cvtss_sh(_cvtsh_ss(arr[0]) / norm, _MM_FROUND_NO_EXC);
  }
}
#endif  // __AVX512F__

//! Divide each of the `dim` FP16 elements of `arr` by `norm`, in place
//! (AVX + F16C). Elements are IEEE half bit patterns stored as uint16_t.
static inline void NormalizeAVX(uint16_t *arr, size_t dim, float norm) {
  uint16_t *last = arr + dim;
  uint16_t *last_aligned = arr + ((dim >> 4) << 4);
  __m256 ymm_norm = _mm256_set1_ps(norm);

  // Aligned 128-bit loads/stores only when arr is 16-byte aligned.
  if (((uintptr_t)arr & 0xf) == 0) {
    for (; arr != last_aligned; arr += 16) {
      __m128i xmm_0 = _mm_load_si128((const __m128i *)(arr + 0));
      __m128i xmm_1 = _mm_load_si128((const __m128i *)(arr + 8));
      __m256 ymm_0 = _mm256_div_ps(_mm256_cvtph_ps(xmm_0), ymm_norm);
      __m256 ymm_1 = _mm256_div_ps(_mm256_cvtph_ps(xmm_1), ymm_norm);
      _mm_store_si128((__m128i *)(arr + 0),
                      _mm256_cvtps_ph(ymm_0, _MM_FROUND_NO_EXC));
      _mm_store_si128((__m128i *)(arr + 8),
                      _mm256_cvtps_ph(ymm_1, _MM_FROUND_NO_EXC));
    }
    if (last >= arr + 8) {
      __m256 values = _mm256_cvtph_ps(_mm_load_si128((const __m128i *)arr));
      _mm_store_si128(
          (__m128i *)arr,
          _mm256_cvtps_ph(_mm256_div_ps(values, ymm_norm), _MM_FROUND_NO_EXC));
      arr += 8;
    }
  } else {
    for (; arr != last_aligned; arr += 16) {
      __m128i xmm_0 = _mm_loadu_si128((const __m128i *)(arr + 0));
      __m128i xmm_1 = _mm_loadu_si128((const __m128i *)(arr + 8));
      __m256 ymm_0 = _mm256_div_ps(_mm256_cvtph_ps(xmm_0), ymm_norm);
      __m256 ymm_1 = _mm256_div_ps(_mm256_cvtph_ps(xmm_1), ymm_norm);
      _mm_storeu_si128((__m128i *)(arr + 0),
                       _mm256_cvtps_ph(ymm_0, _MM_FROUND_NO_EXC));
      _mm_storeu_si128((__m128i *)(arr + 8),
                       _mm256_cvtps_ph(ymm_1, _MM_FROUND_NO_EXC));
    }
    if (last >= arr + 8) {
      __m256 values = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)arr));
      _mm_storeu_si128(
          (__m128i *)arr,
          _mm256_cvtps_ph(_mm256_div_ps(values, ymm_norm), _MM_FROUND_NO_EXC));
      arr += 8;
    }
  }
  // 4 halves via a 64-bit load/store (no alignment requirement).
  if (last >= arr + 4) {
    __m128 xmm_norm = _mm_set1_ps(norm);
    __m128 values = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)arr));
    _mm_storel_epi64(
        (__m128i *)arr,
        _mm_cvtps_ph(_mm_div_ps(values, xmm_norm), _MM_FROUND_NO_EXC));
    // Exactly 4 elements were processed. Advancing by 8 would make
    // `last - arr` negative below and skip the final 1-3 elements.
    arr += 4;
  }
  switch (last - arr) {
    case 3:
      arr[2] = _cvtss_sh(_cvtsh_ss(arr[2]) / norm, _MM_FROUND_NO_EXC);
      /* FALLTHRU */
    case 2:
      arr[1] = _cvtss_sh(_cvtsh_ss(arr[1]) / norm, _MM_FROUND_NO_EXC);
      /* FALLTHRU */
    case 1:
      arr[0] = _cvtss_sh(_cvtsh_ss(arr[0]) / norm, _MM_FROUND_NO_EXC);
  }
}
#endif  // __AVX__ && __F16C__

#if defined(__SSE__)
//! Divide each of the `dim` FP32 elements of `arr` by `norm`, in place (SSE).
static inline void NormalizeSSE(float *arr, size_t dim, float norm) {
  float *last = arr + dim;
  float *last_aligned = arr + ((dim >> 3) << 3);
  __m128 xmm_norm = _mm_set1_ps(norm);

  // Aligned loads/stores only when arr is 16-byte aligned.
  if (((uintptr_t)arr & 0xf) == 0) {
    for (; arr != last_aligned; arr += 8) {
      _mm_store_ps(arr + 0, _mm_div_ps(_mm_load_ps(arr + 0), xmm_norm));
      _mm_store_ps(arr + 4, _mm_div_ps(_mm_load_ps(arr + 4), xmm_norm));
    }
    if (last >= last_aligned + 4) {
      _mm_store_ps(arr, _mm_div_ps(_mm_load_ps(arr), xmm_norm));
      arr += 4;
    }
  } else {
    for (; arr != last_aligned; arr += 8) {
      _mm_storeu_ps(arr + 0, _mm_div_ps(_mm_loadu_ps(arr + 0), xmm_norm));
      _mm_storeu_ps(arr + 4, _mm_div_ps(_mm_loadu_ps(arr + 4), xmm_norm));
    }
    if (last >= last_aligned + 4) {
      _mm_storeu_ps(arr, _mm_div_ps(_mm_loadu_ps(arr), xmm_norm));
      arr += 4;
    }
  }
  switch (last - arr) {
    case 3:
      arr[2] /= norm;
      /* FALLTHRU */
    case 2:
      arr[1] /= norm;
      /* FALLTHRU */
    case 1:
      arr[0] /= norm;
  }
}
#endif  // __SSE__

#if defined(__SSE__) || (defined(__ARM_NEON) && defined(__aarch64__))
//!
//! Divide each of the `dim` FP32 elements of `arr` by `norm`, in place.
//! The kernel is selected at compile time from the enabled instruction sets;
//! no runtime CPU check is done here. aarch64 uses NEON. Otherwise AVX-512
//! is used when dim > 15, AVX when dim > 7, and SSE for anything smaller.
void Normalizer::Compute(ValueType *arr, size_t dim, float norm) {
#if defined(__ARM_NEON)
  NormalizeNEON(arr, dim, norm);
#else
#if defined(__AVX512F__)
  if (dim > 15) {
    NormalizeAVX512(arr, dim, norm);
    return;
  }
#endif  // __AVX512F__
#if defined(__AVX__)
  if (dim > 7) {
    NormalizeAVX(arr, dim, norm);
    return;
  }
#endif  // __AVX__
  NormalizeSSE(arr, dim, norm);
#endif  // __ARM_NEON
}
#endif  // __SSE__ || (__ARM_NEON && __aarch64__)

#if (defined(__F16C__) && defined(__AVX__)) || \
    (defined(__ARM_NEON) && defined(__aarch64__))
//! Divide each of the `dim` FP16 elements of `arr` by `norm`, in place.
//! `arr` is reinterpreted as the raw half type the kernels take (presumably
//! float16_t on NEON and uint16_t on x86, matching the helper overloads).
//! AVX-512 is used when dim > 31, otherwise AVX + F16C.
void Normalizer::Compute(ValueType *arr, size_t dim, float norm) {
#if defined(__ARM_NEON)
  NormalizeNEON(reinterpret_cast(arr), dim, norm);
#else
#if defined(__AVX512F__)
  if (dim > 31) {
    NormalizeAVX512(reinterpret_cast(arr), dim, norm);
    return;
  }
#endif  // __AVX512F__
  NormalizeAVX(reinterpret_cast(arr), dim, norm);
#endif  // __ARM_NEON
}
#endif  // (__F16C__ && __AVX__) || (__ARM_NEON && __aarch64__)

}  // namespace ailego
}  // namespace zvec

================================================ FILE: src/ailego/math/normalizer.h ================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "norm_matrix.h"

namespace zvec {
namespace ailego {

/*! Normalizer */
//! Generic scalar normalizer. FP32 and FP16 have out-of-line SIMD
//! specializations below when the target supports them.
template ::value>::type>
struct Normalizer {
  //! Type of value
  using ValueType = typename std::remove_cv::type;

  //! Divide each of the `dim` elements of `arr` by `norm`, in place.
  static inline void Compute(ValueType *arr, size_t dim, float norm) {
    for (size_t i = 0; i < dim; ++i) {
      arr[i] /= norm;
    }
  }

  //! Normalize a vector (L1). The norm computed by Norm1Matrix is stored in
  //! `*norm`, and `arr` is divided by it only if it is positive (a zero norm
  //! leaves `arr` untouched).
  static inline void L1(ValueType *arr, size_t dim, float *norm) {
    Norm1Matrix::Compute(arr, dim, norm);
    if (*norm > 0.0f) {
      Compute(arr, dim, *norm);
    }
  }

  //! Normalize a vector (L2). The norm computed by Norm2Matrix is stored in
  //! `*norm`, and `arr` is divided by it only if it is positive.
  static inline void L2(ValueType *arr, size_t dim, float *norm) {
    Norm2Matrix::Compute(arr, dim, norm);
    if (*norm > 0.0f) {
      Compute(arr, dim, *norm);
    }
  }
};

#if defined(__SSE__) || (defined(__ARM_NEON) && defined(__aarch64__))
/*! Normalizer (FP32) */
template <>
struct Normalizer {
  //! Type of value
  using ValueType = float;

  //! Divide each element of `arr` by `norm`, in place (SIMD, in normalizer.cc)
  static void Compute(ValueType *arr, size_t dim, float norm);

  //! Normalize a vector (L1); `*norm` receives the computed norm
  static inline void L1(ValueType *arr, size_t dim, float *norm) {
    Norm1Matrix::Compute(arr, dim, norm);
    if (*norm > 0.0f) {
      Compute(arr, dim, *norm);
    }
  }

  //! Normalize a vector (L2); `*norm` receives the computed norm
  static inline void L2(ValueType *arr, size_t dim, float *norm) {
    Norm2Matrix::Compute(arr, dim, norm);
    if (*norm > 0.0f) {
      Compute(arr, dim, *norm);
    }
  }
};
#endif  // __SSE__ || (__ARM_NEON && __aarch64__)

#if (defined(__F16C__) && defined(__AVX__)) || \
    (defined(__ARM_NEON) && defined(__aarch64__))
/*! Normalizer (FP16) */
template <>
struct Normalizer {
  //! Type of value
  using ValueType = Float16;

  //! Divide each element of `arr` by `norm`, in place (SIMD, in normalizer.cc)
  static void Compute(ValueType *arr, size_t dim, float norm);

  //! Normalize a vector (L1); `*norm` receives the computed norm
  static inline void L1(ValueType *arr, size_t dim, float *norm) {
    Norm1Matrix::Compute(arr, dim, norm);
    if (*norm > 0.0f) {
      Compute(arr, dim, *norm);
    }
  }

  //! Normalize a vector (L2); `*norm` receives the computed norm
  static inline void L2(ValueType *arr, size_t dim, float *norm) {
    Norm2Matrix::Compute(arr, dim, norm);
    if (*norm > 0.0f) {
      Compute(arr, dim, *norm);
    }
  }
};
#endif  // (__F16C__ && __AVX__) || (__ARM_NEON && __aarch64__)

}  // namespace ailego
}  // namespace zvec

================================================ FILE: src/ailego/math_batch/cosine_distance_batch.h ================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once #include #include #include #include #include #include "inner_product_distance_batch.h" namespace zvec::ailego::DistanceBatch { template struct CosineDistanceBatch; template struct CosineDistanceBatch { using ValueType = typename std::remove_cv::type; static inline void ComputeBatch(const ValueType **vecs, const ValueType *query, size_t num_vecs, size_t dim, float *results) { constexpr size_t extra_dim = sizeof(float) / sizeof(ValueType); size_t _dim = dim - extra_dim; InnerProductDistanceBatch::ComputeBatch( vecs, query, num_vecs, _dim, results); for (size_t i = 0; i < num_vecs; ++i) { results[i] = 1 - results[i]; } } using IPImplType = InnerProductDistanceBatch; static void QueryPreprocess(void *query, size_t dim) { return IPImplType::QueryPreprocess(query, dim - sizeof(float) / sizeof(ValueType)); } }; } // namespace zvec::ailego::DistanceBatch ================================================ FILE: src/ailego/math_batch/distance_batch.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#pragma once #include #include "ailego/math/distance_matrix.h" #include "cosine_distance_batch.h" #include "inner_product_distance_batch.h" namespace zvec::ailego { template < template class DistanceType, typename ValueType, size_t BatchSize, size_t PrefetchStep, typename = void> struct BaseDistance { static inline void _ComputeBatch(const ValueType **m, const ValueType *q, size_t num, size_t dim, float *out) { for (size_t i = 0; i < num; ++i) { DistanceType::Compute(m[i], q, dim, out + i); } } // If Distance has ComputeBatch, use it; otherwise fall back to _ComputeBatch. static inline void ComputeBatch(const ValueType **m, const ValueType *q, size_t num, size_t dim, float *out) { if constexpr (std::is_same_v, CosineDistanceMatrix>) { return DistanceBatch::CosineDistanceBatch< ValueType, BatchSize, PrefetchStep>::ComputeBatch(m, q, num, dim, out); } _ComputeBatch(m, q, num, dim, out); } }; } // namespace zvec::ailego ================================================ FILE: src/ailego/math_batch/inner_product_distance_batch.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#pragma once

#include
#include
#include
#include
#include

namespace zvec::ailego::DistanceBatch {

template
struct InnerProductDistanceBatch;

//! Portable fallback. For each j < BatchSize, it resets `sums[j]` to 0 and
//! stores the inner product of `ptrs[j]` and `query` (over `dim` elements,
//! via InnerProductMatrix::Compute) in it.
//! NOTE(review): `ailego_prefetch(&prefetch_ptrs[j])` prefetches the array
//! slot itself, not the upcoming vector it points to. The SIMD kernels
//! prefetch `prefetch_ptrs[i] + dim`. This probably meant `prefetch_ptrs[j]`
//! (skipping null entries); confirm.
template
static void compute_one_to_many_inner_product_fallback(
    const ValueType *query, const ValueType **ptrs,
    std::array &prefetch_ptrs, size_t dim, float *sums) {
  for (size_t j = 0; j < BatchSize; ++j) {
    sums[j] = 0.0;
    InnerProductMatrix::Compute(ptrs[j], query, dim, sums + j);
    ailego_prefetch(&prefetch_ptrs[j]);
  }
}

// Function template partial specialization is not allowed,
// therefore the wrapper struct is required.
//! Default implementation: scalar fallback, no query preprocessing.
template
struct InnerProductDistanceBatchImpl {
  using ValueType = typename std::remove_cv::type;
  static void compute_one_to_many(
      const ValueType *query, const ValueType **ptrs,
      std::array &prefetch_ptrs, size_t dim, float *sums) {
    return compute_one_to_many_inner_product_fallback(query, ptrs,
                                                      prefetch_ptrs, dim, sums);
  }
  //! nullptr means the query needs no preprocessing.
  static DistanceBatchQueryPreprocessFunc GetQueryPreprocessFunc() {
    return nullptr;
  }
};

template
struct InnerProductDistanceBatch {
  using ValueType = typename std::remove_cv::type;

  //! Writes the inner product of `query` with `vecs[i]` to `results[i]` for
  //! every i < num_vecs. Full groups of BatchSize vectors go through the
  //! batched kernel. For each group, the vectors BatchSize * PrefetchStep
  //! positions ahead are offered for prefetching (nullptr once past the end).
  //! Leftover vectors are handled one at a time.
  static inline void ComputeBatch(const ValueType **vecs,
                                  const ValueType *query, size_t num_vecs,
                                  size_t dim, float *results) {
    size_t i = 0;
    for (; i + BatchSize <= num_vecs; i += BatchSize) {
      std::array prefetch_ptrs;
      for (size_t j = 0; j < BatchSize; ++j) {
        if (i + j + BatchSize * PrefetchStep < num_vecs) {
          prefetch_ptrs[j] = vecs[i + j + BatchSize * PrefetchStep];
        } else {
          prefetch_ptrs[j] = nullptr;
        }
      }
      InnerProductDistanceBatchImpl::compute_one_to_many(
          query, &vecs[i], prefetch_ptrs, dim, &results[i]);
    }
    for (; i < num_vecs; ++i) {  // TODO: unroll by 1, 2, 4, 8, etc.
      // No prefetch targets for the tail; the first slot is nullptr.
      std::array prefetch_ptrs{nullptr};
      InnerProductDistanceBatchImpl::compute_one_to_many(
          query, &vecs[i], prefetch_ptrs, dim, &results[i]);
    }
  }

  //! Forwards to the implementation's query preprocessor (may be nullptr).
  static DistanceBatchQueryPreprocessFunc GetQueryPreprocessFunc() {
    return InnerProductDistanceBatchImpl::GetQueryPreprocessFunc();
  }
};

// Explicit specializations, defined in inner_product_distance_batch_dispatch.cc.
// Judging by the kernels that file dispatches to (`*_1` / `*_12`), the first
// three appear to be for BatchSize 1 and the last three for BatchSize 12.
// Their template arguments are not visible in this copy.
template <>
struct InnerProductDistanceBatchImpl {
  using ValueType = ailego::Float16;
  static void compute_one_to_many(
      const ailego::Float16 *query, const ailego::Float16 **ptrs,
      std::array &prefetch_ptrs, size_t dim, float *sums);
};

template <>
struct InnerProductDistanceBatchImpl {
  using ValueType = float;
  static void compute_one_to_many(const float *query, const float **ptrs,
                                  std::array &prefetch_ptrs,
                                  size_t dim, float *sums);
};

template <>
struct InnerProductDistanceBatchImpl {
  using ValueType = int8_t;
  static void compute_one_to_many(const int8_t *query, const int8_t **ptrs,
                                  std::array &prefetch_ptrs,
                                  size_t dim, float *sums);
  static DistanceBatchQueryPreprocessFunc GetQueryPreprocessFunc();
};

template <>
struct InnerProductDistanceBatchImpl {
  using ValueType = ailego::Float16;
  static void compute_one_to_many(
      const ailego::Float16 *query, const ailego::Float16 **ptrs,
      std::array &prefetch_ptrs, size_t dim, float *sums);
};

template <>
struct InnerProductDistanceBatchImpl {
  using ValueType = float;
  static void compute_one_to_many(const float *query, const float **ptrs,
                                  std::array &prefetch_ptrs,
                                  size_t dim, float *sums);
};

template <>
struct InnerProductDistanceBatchImpl {
  using ValueType = int8_t;
  static void compute_one_to_many(const int8_t *query, const int8_t **ptrs,
                                  std::array &prefetch_ptrs,
                                  size_t dim, float *sums);
};

}  // namespace zvec::ailego::DistanceBatch

================================================ FILE: src/ailego/math_batch/inner_product_distance_batch_dispatch.cc ================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in
// compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include
#include
#include
#include
#include
#include
#include "inner_product_distance_batch.h"

namespace zvec::ailego::DistanceBatch {

// Kernel entry points implemented in the per-ISA translation units. Each
// group is declared only when its ISA macro is enabled at compile time.
#if defined(__AVX512VNNI__)
void compute_one_to_many_inner_product_avx512_vnni_int8_query_preprocess(
    void *query, size_t dim);

void compute_one_to_many_inner_product_avx512_vnni_int8_1(
    const int8_t *query, const int8_t **ptrs,
    std::array &prefetch_ptrs, size_t dimensionality,
    float *results);

void compute_one_to_many_inner_product_avx512_vnni_int8_12(
    const int8_t *query, const int8_t **ptrs,
    std::array &prefetch_ptrs, size_t dimensionality,
    float *results);
#endif

#if defined(__AVX512FP16__)
void compute_one_to_many_inner_product_avx512fp16_fp16_1(
    const ailego::Float16 *query, const ailego::Float16 **ptrs,
    std::array &prefetch_ptrs, size_t dimensionality,
    float *results);

void compute_one_to_many_inner_product_avx512fp16_fp16_12(
    const ailego::Float16 *query, const ailego::Float16 **ptrs,
    std::array &prefetch_ptrs, size_t dimensionality,
    float *results);
#endif  //__AVX512FP16__

#if defined(__AVX512F__)
void compute_one_to_many_inner_product_avx512f_fp16_1(
    const ailego::Float16 *query, const ailego::Float16 **ptrs,
    std::array &prefetch_ptrs, size_t dimensionality,
    float *results);

void compute_one_to_many_inner_product_avx512f_fp16_12(
    const ailego::Float16 *query, const ailego::Float16 **ptrs,
    std::array &prefetch_ptrs, size_t dimensionality,
    float *results);
#endif  //__AVX512F__

#if defined(__AVX2__)
void compute_one_to_many_inner_product_avx2_fp32_1(
    const float *query, const float **ptrs,
    std::array &prefetch_ptrs, size_t dimensionality,
    float *results);

void compute_one_to_many_inner_product_avx2_fp16_1(
    const ailego::Float16 *query, const ailego::Float16 **ptrs,
    std::array &prefetch_ptrs, size_t dimensionality,
    float *results);

void compute_one_to_many_inner_product_avx2_int8_1(
    const int8_t *query, const int8_t **ptrs,
    std::array &prefetch_ptrs, size_t dimensionality,
    float *results);

void compute_one_to_many_inner_product_avx2_fp32_12(
    const float *query, const float **ptrs,
    std::array &prefetch_ptrs, size_t dimensionality,
    float *results);

void compute_one_to_many_inner_product_avx2_fp16_12(
    const ailego::Float16 *query, const ailego::Float16 **ptrs,
    std::array &prefetch_ptrs, size_t dimensionality,
    float *results);

void compute_one_to_many_inner_product_avx2_int8_12(
    const int8_t *query, const int8_t **ptrs,
    std::array &prefetch_ptrs, size_t dimensionality,
    float *results);
#endif

// Every compute_one_to_many below picks the first kernel that is both compiled
// in (#if) and reported as supported by the running CPU
// (CpuFeatures::static_flags_). If none matches, it uses the portable
// scalar fallback.

// FP32, single-vector kernel.
void InnerProductDistanceBatchImpl::compute_one_to_many(
    const ValueType *query, const ValueType **ptrs,
    std::array &prefetch_ptrs, size_t dim, float *sums) {
#if defined(__AVX2__)
  if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX2) {
    return compute_one_to_many_inner_product_avx2_fp32_1(
        query, ptrs, prefetch_ptrs, dim, sums);
  }
#endif
  return compute_one_to_many_inner_product_fallback(query, ptrs,
                                                    prefetch_ptrs, dim, sums);
}

// FP16, single-vector kernel: AVX512-FP16, then AVX512F, then AVX2.
void InnerProductDistanceBatchImpl::compute_one_to_many(
    const ailego::Float16 *query, const ailego::Float16 **ptrs,
    std::array &prefetch_ptrs, size_t dim, float *sums) {
#if defined(__AVX512FP16__)
  if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512_FP16) {
    return compute_one_to_many_inner_product_avx512fp16_fp16_1(
        query, ptrs, prefetch_ptrs, dim, sums);
  }
#endif
#if defined(__AVX512F__)
  if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) {
    return compute_one_to_many_inner_product_avx512f_fp16_1(
        query, ptrs, prefetch_ptrs, dim, sums);
  }
#endif
#if defined(__AVX2__)
  if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX2) {
    return compute_one_to_many_inner_product_avx2_fp16_1(
        query, ptrs, prefetch_ptrs, dim, sums);
  }
#endif
  return compute_one_to_many_inner_product_fallback(query, ptrs,
                                                    prefetch_ptrs, dim, sums);
}

// INT8, single-vector kernel: AVX512-VNNI, then AVX2.
void InnerProductDistanceBatchImpl::compute_one_to_many(
    const int8_t *query, const int8_t **ptrs,
    std::array &prefetch_ptrs, size_t dim, float *sums) {
  // #if defined(__AVX512BW__)
  // TODO: this version is problematic
  //   return compute_one_to_many_avx512_int8(
  //       query, ptrs, prefetch_ptrs, dim, sums);
#if defined(__AVX512VNNI__)
  if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512_VNNI) {
    return compute_one_to_many_inner_product_avx512_vnni_int8_1(
        query, ptrs, prefetch_ptrs, dim, sums);
  }
#endif
#if defined(__AVX2__)
  if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX2) {
    return compute_one_to_many_inner_product_avx2_int8_1(
        query, ptrs, prefetch_ptrs, dim, sums);
  }
#endif
  return compute_one_to_many_inner_product_fallback(query, ptrs,
                                                    prefetch_ptrs, dim, sums);
}

// Only the AVX512-VNNI int8 kernel needs a query preprocessing step;
// nullptr means none is required.
DistanceBatchQueryPreprocessFunc
InnerProductDistanceBatchImpl::GetQueryPreprocessFunc() {
#if defined(__AVX512VNNI__)
  if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512_VNNI) {
    return compute_one_to_many_inner_product_avx512_vnni_int8_query_preprocess;
  }
#endif
  return nullptr;
}

// FP32, 12-vector kernel.
void InnerProductDistanceBatchImpl::compute_one_to_many(
    const ValueType *query, const ValueType **ptrs,
    std::array &prefetch_ptrs, size_t dim, float *sums) {
#if defined(__AVX2__)
  if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX2) {
    return compute_one_to_many_inner_product_avx2_fp32_12(
        query, ptrs, prefetch_ptrs, dim, sums);
  }
#endif
  return compute_one_to_many_inner_product_fallback(query, ptrs,
                                                    prefetch_ptrs, dim, sums);
}

// FP16, 12-vector kernel: AVX512-FP16, then AVX512F, then AVX2.
void InnerProductDistanceBatchImpl::compute_one_to_many(
    const ailego::Float16 *query, const ailego::Float16 **ptrs,
    std::array &prefetch_ptrs, size_t dim, float *sums) {
#if defined(__AVX512FP16__)
  if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512_FP16) {
    return compute_one_to_many_inner_product_avx512fp16_fp16_12(
        query, ptrs, prefetch_ptrs, dim, sums);
  }
#endif
#if defined(__AVX512F__)
  if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) {
    return compute_one_to_many_inner_product_avx512f_fp16_12(
        query, ptrs, prefetch_ptrs, dim, sums);
  }
#endif
#if defined(__AVX2__)
  if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX2) {
    return compute_one_to_many_inner_product_avx2_fp16_12(
        query, ptrs, prefetch_ptrs, dim, sums);
  }
#endif
  return compute_one_to_many_inner_product_fallback(query, ptrs,
                                                    prefetch_ptrs, dim, sums);
}

// INT8, 12-vector kernel: AVX512-VNNI, then AVX2.
void InnerProductDistanceBatchImpl::compute_one_to_many(
    const int8_t *query, const int8_t **ptrs,
    std::array &prefetch_ptrs, size_t dim, float *sums) {
  // #if defined(__AVX512BW__)
  // TODO: this version is problematic
  //   return compute_one_to_many_avx512_int8(
  //       query, ptrs, prefetch_ptrs, dim, sums);
#if defined(__AVX512VNNI__)
  if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX512_VNNI) {
    return compute_one_to_many_inner_product_avx512_vnni_int8_12(
        query, ptrs, prefetch_ptrs, dim, sums);
  }
#endif
#if defined(__AVX2__)
  if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX2) {
    return compute_one_to_many_inner_product_avx2_int8_12(
        query, ptrs, prefetch_ptrs, dim, sums);
  }
#endif
  return compute_one_to_many_inner_product_fallback(query, ptrs,
                                                    prefetch_ptrs, dim, sums);
}

}  // namespace zvec::ailego::DistanceBatch

================================================ FILE: src/ailego/math_batch/inner_product_distance_batch_impl_fp16_avx2.cc ================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include
#include
#include
#include
#include

namespace zvec::ailego::DistanceBatch {

#if defined(__AVX2__)
//! Inner products of one FP16 `query` against `dp_batch` FP16 vectors
//! (`ptrs[i]`) over `dimensionality` elements. Halves are widened with F16C
//! and accumulated in FP32 with FMA. `results[i]` receives the product with
//! `ptrs[i]`.
template
static std::enable_if_t, void>
compute_one_to_many_inner_product_avx2_fp16(
    const ailego::Float16 *query, const ailego::Float16 **ptrs,
    std::array &prefetch_ptrs, size_t dimensionality,
    float *results) {
  __m256 accs[dp_batch];
  for (size_t i = 0; i < dp_batch; ++i) {
    accs[i] = _mm256_setzero_ps();
  }
  size_t dim = 0;
  // Main loop: 16 halves per vector per iteration (two 8-lane FP32 halves).
  for (; dim + 16 <= dimensionality; dim += 16) {
    __m256i q =
        _mm256_loadu_si256(reinterpret_cast(query + dim));
    __m256 q1 = _mm256_cvtph_ps(_mm256_castsi256_si128(q));
    __m256 q2 = _mm256_cvtph_ps(_mm256_extractf128_si256(q, 1));
    __m256 data_regs_1[dp_batch];
    __m256 data_regs_2[dp_batch];
    for (size_t i = 0; i < dp_batch; ++i) {
      __m256i m =
          _mm256_loadu_si256(reinterpret_cast(ptrs[i] + dim));
      data_regs_1[i] = _mm256_cvtph_ps(_mm256_castsi256_si128(m));
      data_regs_2[i] = _mm256_cvtph_ps(_mm256_extractf128_si256(m, 1));
    }
    // NOTE(review): only prefetch_ptrs[0] is null-checked. The caller
    // (InnerProductDistanceBatch::ComputeBatch) nulls trailing entries near
    // the end of the input, in which case `prefetch_ptrs[i] + dim` is
    // arithmetic on a null pointer. Consider checking each entry.
    if (prefetch_ptrs[0]) {
      for (size_t i = 0; i < dp_batch; ++i) {
        ailego_prefetch(prefetch_ptrs[i] + dim);
      }
    }
    for (size_t i = 0; i < dp_batch; ++i) {
      accs[i] = _mm256_fmadd_ps(q1, data_regs_1[i], accs[i]);
      accs[i] = _mm256_fmadd_ps(q2, data_regs_2[i], accs[i]);
    }
  }
  // One optional 8-half step.
  if (dim + 8 <= dimensionality) {
    __m256 q = _mm256_cvtph_ps(
        _mm_loadu_si128(reinterpret_cast(query + dim)));
    __m256 data_regs[dp_batch];
    for (size_t i = 0; i < dp_batch; ++i) {
      data_regs[i] = _mm256_cvtph_ps(
          _mm_loadu_si128(reinterpret_cast(ptrs[i] + dim)));
      accs[i] = _mm256_fmadd_ps(q, data_regs[i], accs[i]);
    }
    dim += 8;
  }
  for (size_t i = 0; i < dp_batch; ++i) {
    results[i] = HorizontalAdd_FP32_V256(accs[i]);
  }
  // Remaining 0..7 elements, scalar.
  for (; dim < dimensionality; ++dim) {
    for (size_t i = 0; i < dp_batch; ++i) {
      results[i] += (*(query + dim)) * (*(ptrs[i] + dim));
    }
  }
}

//! Single-vector entry point used by the dispatcher.
void compute_one_to_many_inner_product_avx2_fp16_1(
    const ailego::Float16 *query, const ailego::Float16 **ptrs,
    std::array &prefetch_ptrs, size_t dim, float *sums) {
  return compute_one_to_many_inner_product_avx2_fp16(
      query, ptrs, prefetch_ptrs, dim, sums);
}

//! 12-vector entry point used by the dispatcher.
void compute_one_to_many_inner_product_avx2_fp16_12(
    const ailego::Float16 *query, const ailego::Float16 **ptrs,
    std::array &prefetch_ptrs, size_t dim, float *sums) {
  return compute_one_to_many_inner_product_avx2_fp16(
      query, ptrs, prefetch_ptrs, dim, sums);
}
#endif

}  // namespace zvec::ailego::DistanceBatch

================================================ FILE: src/ailego/math_batch/inner_product_distance_batch_impl_fp16_avx512.cc ================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// NOTE(review): in this copy the contents of angle brackets (include targets,
// template parameter/argument lists, cast target types) are missing; the text
// is reproduced as-is. Restore them from the repository before building.
#include
#include
#include
#include
#include

namespace zvec::ailego::DistanceBatch {

#if defined(__AVX512F__)

//! One-to-many fp16 inner product using AVX-512F.
//! fp16 chunks are widened to fp32 with _mm512_cvtph_ps and accumulated with
//! FMA. Stages: 32-element main loop (512-bit accumulators), an optional
//! 16-element step, then a fold to 256-bit accumulators, an optional
//! 8-element step, and a scalar tail.
//! @param query          query vector of `dimensionality` fp16 values
//! @param ptrs           dp_batch target vectors of `dimensionality` values
//! @param prefetch_ptrs  addresses to prefetch; skipped when prefetch_ptrs[0]
//!                       is null
//! @param dimensionality number of elements per vector
//! @param results        receives dp_batch inner products (overwritten)
template
static std::enable_if_t, void>
compute_one_to_many_inner_product_avx512f_fp16(
    const ailego::Float16 *query, const ailego::Float16 **ptrs,
    std::array &prefetch_ptrs, size_t dimensionality, float *results) {
  __m512 accs[dp_batch];
  for (size_t i = 0; i < dp_batch; ++i) {
    accs[i] = _mm512_setzero_ps();
  }
  size_t dim = 0;
  // Main loop: one 512-bit load (32 fp16 values) per vector, converted as two
  // 256-bit halves of 16 values each.
  for (; dim + 32 <= dimensionality; dim += 32) {
    __m512i q = _mm512_loadu_si512(reinterpret_cast(query + dim));
    __m512 q1 = _mm512_cvtph_ps(_mm512_castsi512_si256(q));
    __m512 q2 = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(q, 1));
    __m512 data_regs_1[dp_batch];
    __m512 data_regs_2[dp_batch];
    for (size_t i = 0; i < dp_batch; ++i) {
      __m512i m = _mm512_loadu_si512(reinterpret_cast(ptrs[i] + dim));
      data_regs_1[i] = _mm512_cvtph_ps(_mm512_castsi512_si256(m));
      data_regs_2[i] = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(m, 1));
    }
    if (prefetch_ptrs[0]) {
      for (size_t i = 0; i < dp_batch; ++i) {
        ailego_prefetch(prefetch_ptrs[i] + dim);
      }
    }
    for (size_t i = 0; i < dp_batch; ++i) {
      accs[i] = _mm512_fmadd_ps(q1, data_regs_1[i], accs[i]);
      accs[i] = _mm512_fmadd_ps(q2, data_regs_2[i], accs[i]);
    }
  }
  // At most one 16-element step remains.
  if (dim + 16 <= dimensionality) {
    __m512 q = _mm512_cvtph_ps(
        _mm256_loadu_si256(reinterpret_cast(query + dim)));
    __m512 data_regs[dp_batch];
    for (size_t i = 0; i < dp_batch; ++i) {
      data_regs[i] = _mm512_cvtph_ps(
          _mm256_loadu_si256(reinterpret_cast(ptrs[i] + dim)));
      accs[i] = _mm512_fmadd_ps(q, data_regs[i], accs[i]);
    }
    dim += 16;
  }
  // Fold each 512-bit accumulator into 256 bits (low half + high half).
  // The high half is extracted through a double view with
  // _mm512_extractf64x4_pd; presumably because the float-typed
  // _mm512_extractf32x8_ps would require AVX512DQ.
  __m256 acc_new[dp_batch];
  for (size_t i = 0; i < dp_batch; ++i) {
    acc_new[i] = _mm256_add_ps(
        _mm512_castps512_ps256(accs[i]),
        _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(accs[i]), 1)));
  }
  // At most one 8-element step remains.
  if (dim + 8 <= dimensionality) {
    __m256 q = _mm256_cvtph_ps(
        _mm_loadu_si128(reinterpret_cast(query + dim)));
    for (size_t i = 0; i < dp_batch; ++i) {
      __m256 m = _mm256_cvtph_ps(
          _mm_loadu_si128(reinterpret_cast(ptrs[i] + dim)));
      acc_new[i] = _mm256_fmadd_ps(m, q, acc_new[i]);
    }
    dim += 8;
  }
  for (size_t i = 0; i < dp_batch; ++i) {
    results[i] = HorizontalAdd_FP32_V256(acc_new[i]);
  }
  // Scalar tail (fewer than 8 elements) is added on top of the reduced sums.
  for (; dim < dimensionality; ++dim) {
    for (size_t i = 0; i < dp_batch; ++i) {
      results[i] += (*(query + dim)) * (*(ptrs[i] + dim));
    }
  }
}

//! Entry point for one target vector. The template arguments of the call are
//! not visible in this copy (presumably dp_batch = 1 — confirm).
void compute_one_to_many_inner_product_avx512f_fp16_1(
    const ailego::Float16 *query, const ailego::Float16 **ptrs,
    std::array &prefetch_ptrs, size_t dim, float *sums) {
  return compute_one_to_many_inner_product_avx512f_fp16(
      query, ptrs, prefetch_ptrs, dim, sums);
}

//! Entry point for twelve target vectors. The template arguments of the call
//! are not visible in this copy (presumably dp_batch = 12 — confirm).
void compute_one_to_many_inner_product_avx512f_fp16_12(
    const ailego::Float16 *query, const ailego::Float16 **ptrs,
    std::array &prefetch_ptrs, size_t dim, float *sums) {
  return compute_one_to_many_inner_product_avx512f_fp16(
      query, ptrs, prefetch_ptrs, dim, sums);
}

#endif

}  // namespace zvec::ailego::DistanceBatch
================================================
FILE: src/ailego/math_batch/inner_product_distance_batch_impl_fp16_avx512fp16.cc
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include #include #include #include #include namespace zvec::ailego::DistanceBatch { #if defined(__AVX512FP16__) template static std::enable_if_t, void> compute_one_to_many_inner_product_avx512fp16_fp16( const ailego::Float16 *query, const ailego::Float16 **ptrs, std::array &prefetch_ptrs, size_t dimensionality, float *results) { __m512h accs[dp_batch]; for (size_t i = 0; i < dp_batch; ++i) { accs[i] = _mm512_setzero_ph(); } size_t dim = 0; for (; dim + 32 <= dimensionality; dim += 32) { __m512h q = _mm512_loadu_ph(query + dim); __m512h data_regs[dp_batch]; for (size_t i = 0; i < dp_batch; ++i) { data_regs[i] = _mm512_loadu_ph(ptrs[i] + dim); } if (prefetch_ptrs[0]) { for (size_t i = 0; i < dp_batch; ++i) { ailego_prefetch(prefetch_ptrs[i] + dim); } } for (size_t i = 0; i < dp_batch; ++i) { accs[i] = _mm512_fmadd_ph(data_regs[i], q, accs[i]); } } if (dim < dimensionality) { __mmask32 mask = (__mmask32)((1 << (dimensionality - dim)) - 1); for (size_t i = 0; i < dp_batch; ++i) { __m512i zmm_undefined = _mm512_undefined_epi32(); accs[i] = _mm512_mask3_fmadd_ph(_mm512_castsi512_ph(_mm512_mask_loadu_epi16( zmm_undefined, mask, query + dim)), _mm512_castsi512_ph(_mm512_mask_loadu_epi16( zmm_undefined, mask, ptrs[i] + dim)), accs[i], mask); } } for (size_t i = 0; i < dp_batch; ++i) { results[i] = HorizontalAdd_FP16_V512(accs[i]); } } void compute_one_to_many_inner_product_avx512fp16_fp16_1( const ailego::Float16 *query, const ailego::Float16 **ptrs, std::array &prefetch_ptrs, size_t dim, float *sums) { return compute_one_to_many_inner_product_avx512fp16_fp16( query, ptrs, prefetch_ptrs, dim, sums); } void compute_one_to_many_inner_product_avx512fp16_fp16_12( const ailego::Float16 *query, const ailego::Float16 **ptrs, std::array &prefetch_ptrs, size_t dim, float *sums) { return compute_one_to_many_inner_product_avx512fp16_fp16( query, ptrs, prefetch_ptrs, dim, sums); } #endif } // namespace zvec::ailego::DistanceBatch ================================================ FILE: 
src/ailego/math_batch/inner_product_distance_batch_impl_fp32_avx2.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include #include #include #include namespace zvec::ailego::DistanceBatch { #if defined(__AVX2__) inline float sum4(__m128 v) { v = _mm_add_ps(v, _mm_castsi128_ps(_mm_srli_si128(_mm_castps_si128(v), 8))); return v[0] + v[1]; } inline __m128 sum_top_bottom_avx(__m256 v) { const __m128 high = _mm256_extractf128_ps(v, 1); const __m128 low = _mm256_castps256_ps128(v); return _mm_add_ps(high, low); } template static std::enable_if_t, void> compute_one_to_many_inner_product_avx2_fp32( const ValueType *query, const ValueType **ptrs, std::array &prefetch_ptrs, size_t dimensionality, float *results) { __m256 accs[dp_batch]; for (size_t i = 0; i < dp_batch; ++i) { accs[i] = _mm256_setzero_ps(); } size_t dim = 0; for (; dim + 8 <= dimensionality; dim += 8) { __m256 q = _mm256_loadu_ps(query + dim); __m256 data_regs[dp_batch]; for (size_t i = 0; i < dp_batch; ++i) { data_regs[i] = _mm256_loadu_ps(ptrs[i] + dim); } if (prefetch_ptrs[0]) { for (size_t i = 0; i < dp_batch; ++i) { ailego_prefetch(prefetch_ptrs[i] + dim); } } for (size_t i = 0; i < dp_batch; ++i) { accs[i] = _mm256_fnmadd_ps(q, data_regs[i], accs[i]); } } __m128 sum128_regs[dp_batch]; for (size_t i = 0; i < dp_batch; ++i) { sum128_regs[i] = sum_top_bottom_avx(accs[i]); } if (dim + 4 <= dimensionality) { 
__m128 q = _mm_loadu_ps(query + dim); __m128 data_regs[dp_batch]; for (size_t i = 0; i < dp_batch; ++i) { data_regs[i] = _mm_loadu_ps(ptrs[i] + dim); } if (prefetch_ptrs[0]) { for (size_t i = 0; i < dp_batch; ++i) { ailego_prefetch(prefetch_ptrs[i] + dim); } } for (size_t i = 0; i < dp_batch; ++i) { sum128_regs[i] = _mm_fnmadd_ps(q, data_regs[i], sum128_regs[i]); } dim += 4; } if (dim + 2 <= dimensionality) { __m128 q = _mm_setzero_ps(); __m128 data_regs[dp_batch]; for (size_t i = 0; i < dp_batch; ++i) { data_regs[i] = _mm_setzero_ps(); } q = _mm_loadh_pi(q, (const __m64 *)(query + dim)); for (size_t i = 0; i < dp_batch; ++i) { data_regs[i] = _mm_loadh_pi(data_regs[i], (const __m64 *)(ptrs[i] + dim)); } for (size_t i = 0; i < dp_batch; ++i) { sum128_regs[i] = _mm_fnmadd_ps(q, data_regs[i], sum128_regs[i]); } dim += 2; } float res[dp_batch]; for (size_t i = 0; i < dp_batch; ++i) { res[i] = sum4(sum128_regs[i]); } if (dim < dimensionality) { float q = query[dim]; for (size_t i = 0; i < dp_batch; ++i) { res[i] -= q * ptrs[i][dim]; } } for (size_t i = 0; i < dp_batch; ++i) { results[i] = -res[i]; } } void compute_one_to_many_inner_product_avx2_fp32_1( const float *query, const float **ptrs, std::array &prefetch_ptrs, size_t dim, float *sums) { return compute_one_to_many_inner_product_avx2_fp32( query, ptrs, prefetch_ptrs, dim, sums); } void compute_one_to_many_inner_product_avx2_fp32_12( const float *query, const float **ptrs, std::array &prefetch_ptrs, size_t dim, float *sums) { return compute_one_to_many_inner_product_avx2_fp32( query, ptrs, prefetch_ptrs, dim, sums); } #endif } // namespace zvec::ailego::DistanceBatch ================================================ FILE: src/ailego/math_batch/inner_product_distance_batch_impl_int8_avx2.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// NOTE(review): in this copy the contents of angle brackets (include targets,
// template parameter/argument lists, cast target types) are missing; the text
// is reproduced as-is. Restore them from the repository before building.
#include
#include
#include
#include

namespace zvec::ailego::DistanceBatch {

#if defined(__AVX2__)

//! One-to-many int8 inner product using AVX2.
//! Each 32-byte chunk is sign-extended to int16 (_mm256_cvtepi8_epi16). The
//! halves are multiplied and pair-summed into int32 with _mm256_madd_epi16,
//! accumulated in int32, reduced horizontally, and the scalar tail is added.
//! @param query          query vector of `dimensionality` int8 values
//! @param ptrs           dp_batch target vectors of `dimensionality` values
//! @param prefetch_ptrs  addresses to prefetch; skipped when prefetch_ptrs[0]
//!                       is null
//! @param dimensionality number of elements per vector
//! @param results        receives dp_batch inner products (overwritten)
template
static std::enable_if_t, void>
compute_one_to_many_inner_product_avx2_int8(
    const int8_t *query, const int8_t **ptrs,
    std::array &prefetch_ptrs, size_t dimensionality, float *results) {
  __m256i accs[dp_batch];
  for (size_t i = 0; i < dp_batch; ++i) {
    accs[i] = _mm256_setzero_si256();
  }
  size_t dim = 0;
  for (; dim + 32 <= dimensionality; dim += 32) {
    __m256i q = _mm256_loadu_si256((const __m256i *)(query + dim));
    __m256i data_regs[dp_batch];
    for (size_t i = 0; i < dp_batch; ++i) {
      data_regs[i] = _mm256_loadu_si256((const __m256i *)(ptrs[i] + dim));
    }
    if (prefetch_ptrs[0]) {
      for (size_t i = 0; i < dp_batch; ++i) {
        ailego_prefetch(prefetch_ptrs[i] + dim);
      }
    }
    // Widen the low and high 16 bytes to int16 so madd can multiply them.
    __m256i q_lo = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(q));
    __m256i q_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(q, 1));
    __m256i data_lo[dp_batch];
    __m256i data_hi[dp_batch];
    for (size_t i = 0; i < dp_batch; ++i) {
      data_lo[i] = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(data_regs[i]));
      data_hi[i] =
          _mm256_cvtepi8_epi16(_mm256_extracti128_si256(data_regs[i], 1));
    }
    // madd: int16 * int16 products, adjacent pairs summed into int32 lanes.
    __m256i prod_lo[dp_batch];
    __m256i prod_hi[dp_batch];
    for (size_t i = 0; i < dp_batch; ++i) {
      prod_lo[i] = _mm256_madd_epi16(q_lo, data_lo[i]);
      prod_hi[i] = _mm256_madd_epi16(q_hi, data_hi[i]);
    }
    for (size_t i = 0; i < dp_batch; ++i) {
      accs[i] =
          _mm256_add_epi32(accs[i], _mm256_add_epi32(prod_lo[i], prod_hi[i]));
    }
  }
  // Reduce 8 int32 lanes: add 128-bit halves, then two horizontal adds leave
  // the total in lane 0.
  int temp_results[dp_batch];
  for (size_t i = 0; i < dp_batch; ++i) {
    __m128i lo = _mm256_castsi256_si128(accs[i]);
    __m128i hi = _mm256_extracti128_si256(accs[i], 1);
    __m128i sum128 = _mm_add_epi32(lo, hi);
    sum128 = _mm_hadd_epi32(sum128, sum128);
    sum128 = _mm_hadd_epi32(sum128, sum128);
    temp_results[i] = _mm_cvtsi128_si32(sum128);
  }
  // Scalar tail (fewer than 32 elements).
  for (; dim < dimensionality; ++dim) {
    int8_t q = query[dim];
    for (size_t i = 0; i < dp_batch; ++i) {
      temp_results[i] += q * static_cast(ptrs[i][dim]);
    }
  }
  for (size_t i = 0; i < dp_batch; ++i) {
    results[i] = static_cast(temp_results[i]);
  }
}

//! Entry point for one target vector. The template arguments of the call are
//! not visible in this copy (presumably dp_batch = 1 — confirm).
void compute_one_to_many_inner_product_avx2_int8_1(
    const int8_t *query, const int8_t **ptrs,
    std::array &prefetch_ptrs, size_t dim, float *sums) {
  return compute_one_to_many_inner_product_avx2_int8(
      query, ptrs, prefetch_ptrs, dim, sums);
}

//! Entry point for twelve target vectors. The template arguments of the call
//! are not visible in this copy (presumably dp_batch = 12 — confirm).
void compute_one_to_many_inner_product_avx2_int8_12(
    const int8_t *query, const int8_t **ptrs,
    std::array &prefetch_ptrs, size_t dim, float *sums) {
  return compute_one_to_many_inner_product_avx2_int8(
      query, ptrs, prefetch_ptrs, dim, sums);
}

#endif

}  // namespace zvec::ailego::DistanceBatch
================================================
FILE: src/ailego/math_batch/inner_product_distance_batch_impl_int8_avx512fp16.cc
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include #include #include #include namespace zvec::ailego::DistanceBatch { #if defined(__AVX512VNNI__) void compute_one_to_many_inner_product_avx512_vnni_int8_query_preprocess( void *query, size_t dim) { const int8_t *input = reinterpret_cast(query); uint8_t *output = reinterpret_cast(query); // // AVX512 constant: 128 in each byte (cast to int8_t, which becomes -128 // // in signed representation, but addition works correctly due to two's // // complement arithmetic) const __m512i offset = _mm512_set1_epi8(static_cast(128)); // size_t i = 0; // // Process 64 bytes at a time using AVX512 for (; i + 64 <= dim; i += 64) { __m512i data = _mm512_loadu_si512(reinterpret_cast(input + i)); __m512i result = _mm512_add_epi8(data, offset); _mm512_storeu_si512(reinterpret_cast<__m512i *>(output + i), result); } // Handle remaining elements with scalar loop for (; i < dim; ++i) { output[i] = static_cast(static_cast(input[i]) + 128); } } // query is unsigned template static void compute_one_to_many_inner_product_avx512_vnni_int8( const int8_t *query, const int8_t **ptrs, std::array &prefetch_ptrs, size_t dimensionality, float *results) { __m512i accs[dp_batch]; for (size_t i = 0; i < dp_batch; ++i) { accs[i] = _mm512_setzero_si512(); } size_t dim = 0; for (; dim + 64 <= dimensionality; dim += 64) { __m512i q = _mm512_loadu_si512(reinterpret_cast(query + dim)); __m512i data_regs[dp_batch]; for (size_t i = 0; i < dp_batch; ++i) { data_regs[i] = _mm512_loadu_si512(reinterpret_cast(ptrs[i] + dim)); } if (prefetch_ptrs[0]) { for (size_t i = 0; i < dp_batch; ++i) { ailego_prefetch(prefetch_ptrs[i] + dim); } } for (size_t i = 0; i < dp_batch; ++i) { accs[i] = _mm512_dpbusd_epi32(accs[i], q, data_regs[i]); } } int temp_results[dp_batch]{}; for (size_t i = 0; i < dp_batch; ++i) { temp_results[i] = _mm512_reduce_add_epi32(accs[i]); } for (; dim < dimensionality; ++dim) { uint q = reinterpret_cast(query)[dim]; for (size_t i = 0; i < dp_batch; ++i) { temp_results[i] += q * 
static_cast(ptrs[i][dim]); } } for (size_t i = 0; i < dp_batch; ++i) { results[i] = static_cast(temp_results[i]); } } // // #elif defined(__AVX512BW__) // // TODO: this version is problematic // template // static std::enable_if_t, void> // compute_one_to_many_avx512_int8( // const int8_t *query, const int8_t **ptrs, // std::array &prefetch_ptrs, size_t // dimensionality, float *results) { // std::array<__m512i, dp_batch> accs; // size_t dim = 0; // for (; dim + 64 <= dimensionality; dim += 64) { // __m512i q = // _mm512_loadu_si512(reinterpret_cast(query + dim)); // std::array<__m512i, dp_batch> data_regs; // for (size_t i = 0; i < dp_batch; ++i) { // data_regs[i] = // _mm512_loadu_si512(reinterpret_cast(ptrs[i] + // dim)); // } // if (prefetch_ptrs[0]) { // for (size_t i = 0; i < dp_batch; ++i) { // ailego_prefetch(prefetch_ptrs[i] + dim); // } // } // __m512i q_lo = _mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(q, 0)); // __m512i q_hi = _mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(q, 1)); // std::array<__m512i, dp_batch> data_lo; // std::array<__m512i, dp_batch> data_hi; // for (size_t i = 0; i < dp_batch; ++i) { // data_lo[i] = // _mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(data_regs[i], 0)); // data_hi[i] = // _mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(data_regs[i], 1)); // } // std::array<__m512i, dp_batch> prod_lo; // std::array<__m512i, dp_batch> prod_hi; // for (size_t i = 0; i < dp_batch; ++i) { // prod_lo[i] = _mm512_madd_epi16(q_lo, data_lo[i]); // prod_hi[i] = _mm512_madd_epi16(q_hi, data_hi[i]); // } // for (size_t i = 0; i < dp_batch; ++i) { // accs[i] = _mm512_add_epi32( // accs[i], _mm512_add_epi32( // _mm512_madd_epi16(prod_lo[i], _mm512_set1_epi16(1)), // _mm512_madd_epi16(prod_hi[i], _mm512_set1_epi16(1)))); // } // } // std::array temp_results; // for (size_t i = 0; i < dp_batch; ++i) { // temp_results[i] = _mm512_reduce_add_epi32(accs[i]); // } // for (; dim < dimensionality; ++dim) { // int8_t q = query[dim]; // for (size_t i = 0; i 
< dp_batch; ++i) { // temp_results[i] += q * static_cast(ptrs[i][dim]); // } // } // for (size_t i = 0; i < dp_batch; ++i) { // results[i] = static_cast(temp_results[i]); // } // } void compute_one_to_many_inner_product_avx512_vnni_int8_1( const int8_t *query, const int8_t **ptrs, std::array &prefetch_ptrs, size_t dim, float *sums) { return compute_one_to_many_inner_product_avx512_vnni_int8<1>( query, ptrs, prefetch_ptrs, dim, sums); } void compute_one_to_many_inner_product_avx512_vnni_int8_12( const int8_t *query, const int8_t **ptrs, std::array &prefetch_ptrs, size_t dim, float *sums) { return compute_one_to_many_inner_product_avx512_vnni_int8<12>( query, ptrs, prefetch_ptrs, dim, sums); } #endif } // namespace zvec::ailego::DistanceBatch ================================================ FILE: src/ailego/parallel/lock.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include #if __cplusplus >= 201703L #include #endif #include #include namespace zvec { namespace ailego { // Test if atomic_bool is always lock free. // Arm may be always lock free using some compiler flags, // see https://stackoverflow.com/a/64253858/486350. #if ATOMIC_BOOL_LOCK_FREE == 2 /*! Spin Mutex (The atomic type is always lock-free) */ class SpinMutex { public: //! Constructor SpinMutex(void) {} //! 
Locking void lock(void) { bool expected = false; while (!flag_.compare_exchange_weak( expected, true, std::memory_order_acquire, std::memory_order_relaxed)) { expected = false; // Provide a hint to the processor that the code sequence is a spin-wait // loop. This can help improve the performance and power consumption of // spin-wait loops. ailego_yield(); } } //! Try locking bool try_lock(void) { bool expected = false; return flag_.compare_exchange_strong( expected, true, std::memory_order_acquire, std::memory_order_relaxed); } //! Unlocking void unlock(void) { flag_.store(false, std::memory_order_release); } private: //! Disable them SpinMutex(const SpinMutex &) = delete; SpinMutex(SpinMutex &&) = delete; SpinMutex &operator=(const SpinMutex &) = delete; SpinMutex &operator=(SpinMutex &&) = delete; //! Members std::atomic_bool flag_{false}; }; #else /*! Spin Mutex (General) */ class SpinMutex { public: //! Constructor SpinMutex(void) {} //! Locking void lock(void) { while (flag_.test_and_set(std::memory_order_acquire)); } //! Try locking bool try_lock(void) { return (!flag_.test_and_set(std::memory_order_acquire)); } //! Unlocking void unlock(void) { flag_.clear(std::memory_order_release); } private: //! Disable them SpinMutex(const SpinMutex &) = delete; SpinMutex(SpinMutex &&) = delete; SpinMutex &operator=(const SpinMutex &) = delete; SpinMutex &operator=(SpinMutex &&) = delete; //! Members std::atomic_flag flag_{}; }; #endif // ATOMIC_BOOL_LOCK_FREE == 2 #if __cplusplus >= 201703L using SharedMutex = std::shared_mutex; #else /*! Shared Mutex */ class SharedMutex { public: //! Constructor SharedMutex(void) {} //! Locking void lock(void) { std::unique_lock q(mutex_); ++write_count_; write_cond_.wait(q, [this]() { return (pending_count_ == 0); }); --write_count_; --pending_count_; } //! 
Try locking bool try_lock(void) { std::unique_lock q(mutex_, std::defer_lock); if (q.try_lock()) { if (pending_count_ == 0) { --pending_count_; return true; } } return false; } //! Unlocking void unlock(void) { std::lock_guard q(mutex_); ++pending_count_; if (write_count_ != 0) { write_cond_.notify_one(); } else { read_cond_.notify_all(); } } //! Locking (shared) void lock_shared(void) { std::unique_lock q(mutex_); ++read_count_; read_cond_.wait( q, [this]() { return (write_count_ == 0 && pending_count_ >= 0); }); --read_count_; ++pending_count_; } //! Try locking (shared) bool try_lock_shared(void) { std::lock_guard q(mutex_); if (write_count_ == 0 && pending_count_ >= 0) { ++pending_count_; return true; } return false; } //! Unlocking (shared) void unlock_shared(void) { std::lock_guard q(mutex_); --pending_count_; if (write_count_ != 0 && pending_count_ == 0) { write_cond_.notify_one(); } else { read_cond_.notify_all(); } } private: //! Disable them SharedMutex(const SharedMutex &) = delete; SharedMutex(SharedMutex &&) = delete; SharedMutex &operator=(const SharedMutex &) = delete; SharedMutex &operator=(SharedMutex &&) = delete; //! Members int32_t pending_count_{0}; int32_t read_count_{0}; int32_t write_count_{0}; std::mutex mutex_{}; std::condition_variable read_cond_{}; std::condition_variable write_cond_{}; }; #endif // __cplusplus >= 201703L /*! Write Lock */ class WriteLock { public: //! Constructor WriteLock(SharedMutex &mutex) : mutex_(mutex) {} //! Locking void lock(void) { mutex_.lock(); } //! Try locking bool try_lock(void) { return mutex_.try_lock(); } //! Unlocking void unlock(void) { mutex_.unlock(); } private: //! Disable them WriteLock(void) = delete; WriteLock(const WriteLock &) = delete; WriteLock(WriteLock &&) = delete; WriteLock &operator=(const WriteLock &) = delete; WriteLock &operator=(WriteLock &&) = delete; //! Members SharedMutex &mutex_; }; /*! Read Lock */ class ReadLock { public: //! 
Constructor ReadLock(SharedMutex &mutex) : mutex_(mutex) {} //! Locking void lock(void) { mutex_.lock_shared(); } //! Try locking bool try_lock(void) { return mutex_.try_lock_shared(); } //! Unlocking void unlock(void) { mutex_.unlock_shared(); } private: //! Disable them ReadLock(void) = delete; ReadLock(const ReadLock &) = delete; ReadLock(ReadLock &&) = delete; ReadLock &operator=(const ReadLock &) = delete; ReadLock &operator=(ReadLock &&) = delete; //! Members SharedMutex &mutex_; }; /* Atomic Close Lock */ #define AILEGO_SAFE_ACCESS(CLOSE_ERR) \ counter_.fetch_add(1); \ AILEGO_DEFER([this] { counter_.fetch_sub(1); }); \ if (!opened_.load()) { \ return CLOSE_ERR; \ } #define AILEGO_SAFE_CLOSE \ opened_.store(false); \ while (counter_.load() > 0) { \ std::this_thread::sleep_for(std::chrono::milliseconds(1)); \ } } // namespace ailego } // namespace zvec ================================================ FILE: src/ailego/parallel/multi_thread_list.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include #include namespace zvec { namespace ailego { /*! 
Multi-Thread list */ template class MultiThreadList { public: MultiThreadList(size_t size_limit = 1000) : size_limit_(size_limit) {} bool produce(const T &item) { std::unique_lock lk(lock_); not_full_.wait( lk, [&]() { return (list_.size() < size_limit_) || done_.load(); }); if (done_.load()) { return false; } list_.emplace_back(item); not_empty_.notify_one(); return true; } bool produce(T &&item) { std::unique_lock lk(lock_); not_full_.wait( lk, [&]() { return (list_.size() < size_limit_) || done_.load(); }); if (done_.load()) { return false; } list_.emplace_back(std::move(item)); not_empty_.notify_one(); return true; } bool consume(T *item) { std::unique_lock lk(lock_); not_empty_.wait(lk, [&]() { return !list_.empty() || done_.load() || consume_stopped_.load(); }); if ((list_.empty() && done_.load()) || consume_stopped_.load()) { return false; } *item = std::move(list_.front()); list_.pop_front(); not_full_.notify_one(); return true; } void done() { std::unique_lock lk(lock_); done_.store(true); not_empty_.notify_all(); not_full_.notify_all(); } void reset() { done_.store(false); list_.clear(); } void stop_consume() { std::unique_lock lk(lock_); consume_stopped_.store(true); not_empty_.notify_all(); } void resume_consume() { consume_stopped_.store(false); } private: std::deque list_; size_t size_limit_{0}; std::mutex lock_; std::condition_variable not_empty_, not_full_; std::atomic done_{false}; std::atomic consume_stopped_{false}; }; } // namespace ailego } // namespace zvec ================================================ FILE: src/ailego/parallel/semaphore.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include #include #include namespace zvec { namespace ailego { /*! Semaphore */ class Semaphore { public: //! Constructor Semaphore(void) : Semaphore{1} {} //! Constructor Semaphore(uint32_t count) : count_(count) {} //! Acquire a permit from this semaphore, suspending until one is available void lock(void) { while (!this->try_lock()) { std::unique_lock latch(mutex_); cond_.wait(latch, [this]() { return (count_ > 0); }); } } //! Try to acquire a permit from this semaphore without suspension bool try_lock(void) { uint32_t count = count_.load(std::memory_order_acquire); return (count > 0 ? count_.compare_exchange_strong( count, count - 1, std::memory_order_release, std::memory_order_relaxed) : false); } //! Release a permit, returning it into this semaphore void unlock(void) { ++count_; std::lock_guard latch(mutex_); cond_.notify_one(); } private: //! Disable them Semaphore(const Semaphore &) = delete; Semaphore(Semaphore &&) = delete; Semaphore &operator=(const Semaphore &) = delete; Semaphore &operator=(Semaphore &&) = delete; //! Members std::atomic count_{0}; std::mutex mutex_{}; std::condition_variable cond_{}; }; /*! Binary Semaphores */ template ::type> class BinarySemaphores { public: using BitwiseType = typename std::conditional< N <= 32u, typename std::conditional< N <= 16u, typename std::conditional::type, uint32_t>::type, uint64_t>::type; //! Constructor BinarySemaphores(void) : BinarySemaphores{1} {} //! 
Constructor BinarySemaphores(uint32_t count) { if (count == 0 || count > N) { count = N; } count_ = count; mask_ = static_cast(BitwiseType(1) << (count - 1)); mask_ |= static_cast(mask_ - 1); flags_.store(mask_); } //! Acquire a permit from this semaphore, suspending until one is available int acquire(void) { int index = -1; while ((index = this->try_acquire()) < 0) { std::unique_lock latch(mutex_); cond_.wait(latch, [this]() { return (flags_ > 0); }); } return index; } //! Try to acquire a permit from this semaphore without suspension int try_acquire(void) { BitwiseType flags = flags_.load(std::memory_order_relaxed); while (flags > 0) { int index = CountTrailingZeros(flags); if (flags_.compare_exchange_weak( flags, flags & (~(BitwiseType(1) << index)), std::memory_order_release, std::memory_order_relaxed)) { return index; } flags = flags_.load(std::memory_order_relaxed); } return -1; } //! Acquire a specified permit from this semaphore, suspending until index is //! available int acquire(int index) { if (index < 0 || (uint32_t)index >= count_) { return -1; } BitwiseType flags = flags_.load(std::memory_order_relaxed); BitwiseType mask = BitwiseType(1) << index; while (true) { if ((flags & mask) && flags_.compare_exchange_weak(flags, flags & (~mask), std::memory_order_release, std::memory_order_relaxed)) { return index; } flags = flags_.load(std::memory_order_relaxed); } } //! Release a permit, returning it into this semaphore void release(int index) { flags_.fetch_or((BitwiseType(1) << index) & mask_); std::lock_guard latch(mutex_); cond_.notify_one(); } protected: //! Count the trailing zeros (32 bits) template static inline auto CountTrailingZeros(T val) -> typename std::enable_if::type { return ailego_ctz32(val); } //! Count the trailing zeros (64 bits) template static inline auto CountTrailingZeros(T val) -> typename std::enable_if::type { return ailego_ctz64(val); } private: //! 
Disable them BinarySemaphores(const BinarySemaphores &) = delete; BinarySemaphores(BinarySemaphores &&) = delete; BinarySemaphores &operator=(const BinarySemaphores &) = delete; BinarySemaphores &operator=(BinarySemaphores &&) = delete; //! Members uint32_t count_{0}; BitwiseType mask_{0}; std::atomic flags_{0}; std::mutex mutex_{}; std::condition_variable cond_{}; }; } // namespace ailego } // namespace zvec ================================================ FILE: src/ailego/parallel/thread_pool.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include
#if (defined(__linux) || defined(__linux__)) && !defined(__ANDROID__)
#include

// Pin thread i of `pool` to CPU (i % hardware_concurrency()).
// Does nothing when hardware_concurrency() reports 0 or 1.
// The return value of pthread_setaffinity_np is not checked, so a failed
// pin is silently ignored.
static inline void BindThreads(std::vector &pool) {
  uint32_t hc = std::thread::hardware_concurrency();
  if (hc > 1) {
    cpu_set_t mask;
    for (size_t i = 0u; i < pool.size(); ++i) {
      CPU_ZERO(&mask);
      CPU_SET(i % hc, &mask);
      pthread_setaffinity_np(pool[i].native_handle(), sizeof(mask), &mask);
    }
  }
}

// Allow every thread of `pool` to run on all CPU_SETSIZE CPUs again.
// Errors from pthread_setaffinity_np are ignored.
static inline void UnbindThreads(std::vector &pool) {
  cpu_set_t mask;
  CPU_ZERO(&mask);
  for (size_t i = 0u; i < CPU_SETSIZE; ++i) {
    CPU_SET(i, &mask);
  }
  for (size_t i = 0u; i < pool.size(); ++i) {
    pthread_setaffinity_np(pool[i].native_handle(), sizeof(mask), &mask);
  }
}
#else
// CPU affinity is only implemented for non-Android Linux; elsewhere these
// are no-ops.
static inline void BindThreads(std::vector &) {}
static inline void UnbindThreads(std::vector &) {}
#endif

namespace zvec {
namespace ailego {

// Start `size` threads, each running worker(); when `binding` is true, also
// pin them to CPUs via bind().
ThreadPool::ThreadPool(uint32_t size, bool binding) {
  for (uint32_t i = 0u; i < size; ++i) {
    pool_.emplace_back(&ThreadPool::worker, this);
  }
  if (binding) {
    this->bind();
  }
}

// Pin the pool's threads to CPUs (no-op off Linux).
void ThreadPool::bind(void) {
  BindThreads(pool_);
}

// Undo bind() (no-op off Linux).
void ThreadPool::unbind(void) {
  UnbindThreads(pool_);
}

// Body of every pool thread: run tasks handed out by picking() until it
// returns false, then deregister. The last worker to leave wakes waiters on
// stopped_cond_.
void ThreadPool::worker(void) {
  // Counter of workers
  // NOTE(review): incremented without wait_mutex_ but decremented under it;
  // presumably worker_count_ is atomic - verify in thread_pool.h.
  ++worker_count_;

  ThreadPool::Task task;
  while (this->picking(&task)) {
    // Run the task
    task.handle->run();
    task.handle = nullptr;

    // Notify task finished
    if (task.control) {
      task.control->notify();
    }

    // Notify task group
    if (task.group) {
      task.group->notify();
      task.group = nullptr;
    }

    // Decrease count of active works; wake finished_cond_ waiters once no
    // task is running and none is queued.
    std::lock_guard lock(wait_mutex_);
    if (--active_count_ == 0 && pending_count_ == 0) {
      finished_cond_.notify_all();
    }
  }

  // Decrease count of workers
  std::lock_guard lock(wait_mutex_);
  if (--worker_count_ == 0) {
    stopped_cond_.notify_all();
  }
}

// Block until a task is pending or the pool is stopping.
// Returns false as soon as stopping_ is observed, even if tasks are still
// queued. Otherwise moves the queue head into *task, marks its group (if any)
// active, moves one unit from pending_count_ to active_count_, and returns
// true.
// Lock order: wait_mutex_ is taken while queue_mutex_ is held.
bool ThreadPool::picking(ThreadPool::Task *task) {
  std::unique_lock latch(queue_mutex_);
  work_cond_.wait(latch,
                  [this]() { return (pending_count_ > 0 || stopping_); });
  if (stopping_) {
    return false;
  }

  // Pop a task
  auto &head = queue_.front();
  task->control = head.control;
  task->group = std::move(head.group);
  task->handle = std::move(head.handle);
  queue_.pop();

  // Update group control
  if (task->group) {
    task->group->mark_task_actived();
  }

  // Counter of active tasks
  std::unique_lock lock(wait_mutex_);
  ++active_count_;
  --pending_count_;
  return true;
}

}  // namespace ailego
}  // namespace zvec

================================================
FILE: src/ailego/pattern/defer.h
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "scope_guard.h"

// Two-level name builder: routing through AILEGO_DEFER_NAME lets __LINE__
// expand to a number before ## pastes it, giving one unique name per line.
// NOTE(review): identifiers beginning with "__" are reserved for the
// implementation; two AILEGO_DEFER uses on one line would also collide.
#define AILEGO_DEFER_NAME_(x, y) x##y
#define AILEGO_DEFER_NAME(x) AILEGO_DEFER_NAME_(__ailegoDefer_, x)

//! Defer operator: binds the callable (plus arguments) in a scope guard that
//! runs it when the enclosing scope exits.
//! NOTE(review): `ailego::` is not fully qualified, so this presumably has
//! to be expanded where `ailego` resolves to zvec::ailego - confirm.
#define AILEGO_DEFER(...) \
  auto AILEGO_DEFER_NAME(__LINE__) = ailego::ScopeGuard::Make(__VA_ARGS__)

================================================
FILE: src/ailego/pattern/scope_guard.h
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include

namespace zvec {
namespace ailego {

/*! Scope Guard Implementation
 *  Stores an object pointer, a callback and its bound arguments, and calls
 *  Functor::Run(obj_, impl_, tuple_) on destruction unless the guard was
 *  moved from (obj_ == nullptr). Built by the member-function-pointer
 *  overloads of ScopeGuard::Make.
 */
template class ScopeGuardImpl {
 public:
  using Object = CallbackObject;
  using Functor = CallbackFunctor;

  //! Constructor (move): takes over the callback and disarms `rhs`
  ScopeGuardImpl(ScopeGuardImpl &&rhs)
      : obj_(rhs.obj_),
        impl_(std::move(rhs.impl_)),
        tuple_(std::move(rhs.tuple_)) {
    rhs.obj_ = nullptr;
  }

  //! Constructor (copies the callback, forwards the bound arguments)
  template
  ScopeGuardImpl(typename Object::Type *obj, const typename Functor::Type &impl,
                 TArgs &&...args)
      : obj_(obj), impl_(impl), tuple_(std::forward(args)...) {}

  //! Constructor (moves the callback, forwards the bound arguments)
  template
  ScopeGuardImpl(typename Object::Type *obj, typename Functor::Type &&impl,
                 TArgs &&...args)
      : obj_(obj), impl_(std::move(impl)), tuple_(std::forward(args)...) {}

  // Destructor: fires the callback unless this guard was moved from.
  // NOTE(review): destructors are implicitly noexcept, so an exception
  // escaping Functor::Run here would call std::terminate.
  ~ScopeGuardImpl(void) {
    if (obj_) {
      Functor::Run(obj_, impl_, tuple_);
    }
  }

 protected:
  //! Disable them: a guard can be move-constructed but not default
  //! constructed, copied or assigned.
  ScopeGuardImpl(void) = delete;
  ScopeGuardImpl(const ScopeGuardImpl &) = delete;
  ScopeGuardImpl &operator=(const ScopeGuardImpl &) = delete;

 private:
  //! Members
  typename Object::Type *obj_;
  typename Functor::Type impl_;
  typename Functor::TupleType tuple_;
};

/*! Scope Guard Implementation (void, TFunc)
 *  Variant without an object pointer. It calls Functor::Run(impl_, tuple_)
 *  on destruction while valid_ is true; moving from a guard clears valid_.
 */
template class ScopeGuardImpl {
 public:
  //! Callback Functor Type
  using Functor = CallbackFunctor;

  //! Constructor (move): takes over the callback and disarms `rhs`
  ScopeGuardImpl(ScopeGuardImpl &&rhs)
      : impl_(std::move(rhs.impl_)),
        tuple_(std::move(rhs.tuple_)),
        valid_(rhs.valid_) {
    rhs.valid_ = false;
  }

  //! Constructor (copies the callable, forwards the bound arguments)
  template
  ScopeGuardImpl(const typename Functor::Type &impl, TArgs &&...args)
      : impl_(impl), tuple_(std::forward(args)...), valid_(true) {}

  //! Constructor (moves the callable, forwards the bound arguments)
  template
  ScopeGuardImpl(typename Functor::Type &&impl, TArgs &&...args)
      : impl_(std::move(impl)), tuple_(std::forward(args)...), valid_(true) {}

  // Destructor: fires the callback unless this guard was moved from.
  ~ScopeGuardImpl(void) {
    if (valid_) {
      Functor::Run(impl_, tuple_);
    }
  }

 protected:
  //! Disable them
  ScopeGuardImpl(void) = delete;
  ScopeGuardImpl(const ScopeGuardImpl &) = delete;
  ScopeGuardImpl &operator=(const ScopeGuardImpl &) = delete;

 private:
  //! Members
  typename Functor::Type impl_;
  typename Functor::TupleType tuple_;
  bool valid_;
};

/*! Scope Guard
 *  Factory: each Make overload returns a guard that invokes the given
 *  callback with `args` when the guard is destroyed.
 */
struct ScopeGuard {
  //! Make a scope guard object (member function pointer)
  template
  static inline auto Make(T *obj, R (T::*impl)(TParams...), TArgs &&...args)
      -> ScopeGuardImpl::Type> {
    return ScopeGuardImpl::Type>(
        obj, impl, std::forward(args)...);
  }

  //! Make a scope guard object (constable member function pointer)
  template
  static inline auto Make(const T *obj, R (T::*impl)(TParams...) const,
                          TArgs &&...args)
      -> ScopeGuardImpl::Type> {
    return ScopeGuardImpl::Type>(
        obj, impl, std::forward(args)...);
  }

  //! Make a scope guard object (volatile member function pointer)
  template
  static inline auto Make(volatile T *obj, R (T::*impl)(TParams...) volatile,
                          TArgs &&...args)
      -> ScopeGuardImpl::Type> {
    return ScopeGuardImpl::Type>(
        obj, impl, std::forward(args)...);
  }

  //! Make a scope guard object (constable volatile member function pointer)
  template
  static inline auto Make(const volatile T *obj,
                          R (T::*impl)(TParams...) const volatile,
                          TArgs &&...args)
      -> ScopeGuardImpl::Type> {
    return ScopeGuardImpl::Type>(
        obj, impl, std::forward(args)...);
  }

  //! Make a scope guard object (function); SFINAE-constrained so it does not
  //! compete with the member-function-pointer overloads above.
  template <
      typename TFunc, typename... TArgs,
      typename = typename std::enable_if::Value>::type>
  static inline auto Make(TFunc &&impl, TArgs &&...args)
      -> ScopeGuardImpl::Type> {
    return ScopeGuardImpl::Type>(
        std::forward(impl), std::forward(args)...);
  }
};

}  // namespace ailego
}  // namespace zvec

================================================
FILE: src/ailego/utility/bit_string_helper.h
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include namespace zvec { namespace ailego { class BitStringWriter { public: BitStringWriter(uint8_t *buffer, size_t buffer_size) : buffer_(buffer), buffer_size_(buffer_size), offset_(0) { ::memset(buffer_, 0, buffer_size_); } bool write(uint64_t data, int nbit) { if (buffer_size_ * 8 < nbit + offset_) { return false; } int bits_remain = 8 - (offset_ & 7); if (nbit <= bits_remain) { buffer_[offset_ >> 3] |= data << (offset_ & 7); offset_ += nbit; } else { size_t j = offset_ >> 3; buffer_[j++] |= data << (offset_ & 7); offset_ += nbit; data >>= bits_remain; while (data != 0) { buffer_[j++] |= data; data >>= 8; } } return true; } size_t offset() { return offset_; } private: uint8_t *buffer_; size_t buffer_size_; size_t offset_; }; class BitStringReader { public: BitStringReader(const uint8_t *buffer, size_t buffer_size) : buffer_(buffer), buffer_size_(buffer_size), offset_(0) {} bool read(uint64_t &data, int nbit) { if (buffer_size_ * 8 < nbit + offset_) { return false; } int bits_remain = 8 - (offset_ & 7); uint64_t result = buffer_[offset_ >> 3] >> (offset_ & 7); if (nbit <= bits_remain) { result &= (1 << nbit) - 1; offset_ += nbit; data = result; } else { int temp = bits_remain; size_t i = (offset_ >> 3) + 1; offset_ += nbit; nbit -= bits_remain; while (nbit > 8) { result |= ((uint64_t)buffer_[i++]) << temp; temp += 8; nbit -= 8; } uint64_t last_byte = buffer_[i]; last_byte &= (1 << nbit) - 1; result |= last_byte << temp; data = result; } return true; } size_t offset() { return offset_; } private: const uint8_t *buffer_; size_t 
buffer_size_; size_t offset_; }; } // namespace ailego } // namespace zvec ================================================ FILE: src/ailego/utility/bitset_helper.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "bitset_helper.h" #include #ifndef __SSE4_2__ #define bitset_popcount32 ailego_popcount32 #define bitset_popcount64 ailego_popcount64 #else #define bitset_popcount32 _mm_popcnt_u32 #define bitset_popcount64 _mm_popcnt_u64 #endif // !__SSE4_2__ #if defined(__ARM_NEON) static inline void bitset_and(uint32_t *lhs, const uint32_t *rhs, size_t size) { uint32_t *last = lhs + size; uint32_t *last_aligned = lhs + ((size >> 2) << 2); for (; lhs != last_aligned; lhs += 4, rhs += 4) { vst1q_u32(lhs, vandq_u32(vld1q_u32(lhs), vld1q_u32(rhs))); } switch (last - last_aligned) { case 3: lhs[2] &= rhs[2]; /* FALLTHRU */ case 2: lhs[1] &= rhs[1]; /* FALLTHRU */ case 1: lhs[0] &= rhs[0]; } } static inline void bitset_andnot(uint32_t *lhs, const uint32_t *rhs, size_t size) { uint32_t *last = lhs + size; uint32_t *last_aligned = lhs + ((size >> 2) << 2); for (; lhs != last_aligned; lhs += 4, rhs += 4) { vst1q_u32(lhs, vbicq_u32(vld1q_u32(lhs), vld1q_u32(rhs))); } switch (last - last_aligned) { case 3: lhs[2] &= ~rhs[2]; /* FALLTHRU */ case 2: lhs[1] &= ~rhs[1]; /* FALLTHRU */ case 1: lhs[0] &= ~rhs[0]; } } static inline void bitset_or(uint32_t *lhs, const uint32_t *rhs, size_t 
size) {
  uint32_t *last = lhs + size;
  uint32_t *last_aligned = lhs + ((size >> 2) << 2);
  for (; lhs != last_aligned; lhs += 4, rhs += 4) {
    vst1q_u32(lhs, vorrq_u32(vld1q_u32(lhs), vld1q_u32(rhs)));
  }
  switch (last - last_aligned) {
    case 3: lhs[2] |= rhs[2]; /* FALLTHRU */
    case 2: lhs[1] |= rhs[1]; /* FALLTHRU */
    case 1: lhs[0] |= rhs[0];
  }
}

//! NEON: lhs[i] ^= rhs[i] for i in [0, size)
static inline void bitset_xor(uint32_t *lhs, const uint32_t *rhs,
                              size_t size) {
  uint32_t *last = lhs + size;
  uint32_t *last_aligned = lhs + ((size >> 2) << 2);
  for (; lhs != last_aligned; lhs += 4, rhs += 4) {
    vst1q_u32(lhs, veorq_u32(vld1q_u32(lhs), vld1q_u32(rhs)));
  }
  switch (last - last_aligned) {
    case 3: lhs[2] ^= rhs[2]; /* FALLTHRU */
    case 2: lhs[1] ^= rhs[1]; /* FALLTHRU */
    case 1: lhs[0] ^= rhs[0];
  }
}

//! NEON: lhs[i] = ~lhs[i] for i in [0, size); vornq_u32(0, x) is 0 | ~x.
static inline void bitset_not(uint32_t *lhs, size_t size) {
  uint32_t *last = lhs + size;
  uint32_t *last_aligned = lhs + ((size >> 2) << 2);
  static const uint32x4_t v_zero = vdupq_n_u32(0);
  for (; lhs != last_aligned; lhs += 4) {
    vst1q_u32(lhs, vornq_u32(v_zero, vld1q_u32(lhs)));
  }
  switch (last - last_aligned) {
    case 3: lhs[2] = ~lhs[2]; /* FALLTHRU */
    case 2: lhs[1] = ~lhs[1]; /* FALLTHRU */
    case 1: lhs[0] = ~lhs[0];
  }
}

//! NEON: true iff every bit of lhs[0, size) is set (true when size == 0).
//! Each vector step ANDs the two 64-bit lanes and compares with all-ones.
static inline bool bitset_test_all(const uint32_t *lhs, size_t size) {
  const uint32_t *last = lhs + size;
  const uint32_t *last_aligned = lhs + ((size >> 2) << 2);
  for (; lhs != last_aligned; lhs += 4) {
    uint64x2_t vu64 = vld1q_u64((const uint64_t *)lhs);
    if ((vgetq_lane_u64(vu64, 0) & vgetq_lane_u64(vu64, 1)) != (uint64_t)-1) {
      return false;
    }
  }
  switch (last - last_aligned) {
    case 3: if (lhs[2] != 0xffffffffu) { return false; } /* FALLTHRU */
    case 2: if (lhs[1] != 0xffffffffu) { return false; } /* FALLTHRU */
    case 1: if (lhs[0] != 0xffffffffu) { return false; }
  }
  return true;
}

//! NEON: true iff at least one bit of lhs[0, size) is set.
static inline bool bitset_test_any(const uint32_t *lhs, size_t size) {
  const uint32_t *last = lhs + size;
  const uint32_t *last_aligned = lhs + ((size >> 2) << 2);
  for (; lhs != last_aligned; lhs += 4) {
    uint64x2_t vu64 = vld1q_u64((const uint64_t *)lhs);
    if (vgetq_lane_u64(vu64, 0) | vgetq_lane_u64(vu64, 1)) {
      return true;
    }
  }
  switch (last - last_aligned) {
    case 3: if (lhs[2] != 0u) { return true; } /* FALLTHRU */
    case 2: if (lhs[1] != 0u) { return true; } /* FALLTHRU */
    case 1: if (lhs[0] != 0u) { return true; }
  }
  return false;
}

//! NEON: true iff no bit of lhs[0, size) is set (true when size == 0).
static inline bool bitset_test_none(const uint32_t *lhs, size_t size) {
  const uint32_t *last = lhs + size;
  const uint32_t *last_aligned = lhs + ((size >> 2) << 2);
  for (; lhs != last_aligned; lhs += 4) {
    uint64x2_t vu64 = vld1q_u64((const uint64_t *)lhs);
    if (vgetq_lane_u64(vu64, 0) | vgetq_lane_u64(vu64, 1)) {
      return false;
    }
  }
  switch (last - last_aligned) {
    case 3: if (lhs[2] != 0u) { return false; } /* FALLTHRU */
    case 2: if (lhs[1] != 0u) { return false; } /* FALLTHRU */
    case 1: if (lhs[0] != 0u) { return false; }
  }
  return true;
}
#elif defined(__AVX2__)
// AVX2 kernels: 8 words per 256-bit step, then at most one 4-word 128-bit
// step, then up to 3 scalar words via `switch (last - lhs)`. Aligned
// loads/stores are used only when every pointer involved is 32-byte aligned;
// otherwise the unaligned forms are used.

//! AVX2: lhs[i] &= rhs[i] for i in [0, size)
static inline void bitset_and(uint32_t *lhs, const uint32_t *rhs, size_t size) {
  uint32_t *last = lhs + size;
  uint32_t *last_aligned = lhs + ((size >> 3) << 3);
  if (((uintptr_t)lhs & 0x1f) == 0 && ((uintptr_t)rhs & 0x1f) == 0) {
    for (; lhs != last_aligned; lhs += 8, rhs += 8) {
      __m256i ymm0 = _mm256_load_si256((__m256i *)lhs);
      __m256i ymm1 = _mm256_load_si256((__m256i *)rhs);
      _mm256_store_si256((__m256i *)lhs, _mm256_and_si256(ymm1, ymm0));
    }
    if (last >= last_aligned + 4) {
      __m128i xmm0 = _mm_load_si128((__m128i *)lhs);
      __m128i xmm1 = _mm_load_si128((__m128i *)rhs);
      _mm_store_si128((__m128i *)lhs, _mm_and_si128(xmm1, xmm0));
      lhs += 4;
      rhs += 4;
    }
  } else {
    for (; lhs != last_aligned; lhs += 8, rhs += 8) {
      __m256i ymm0 = _mm256_loadu_si256((__m256i *)lhs);
      __m256i ymm1 = _mm256_loadu_si256((__m256i *)rhs);
      _mm256_storeu_si256((__m256i *)lhs, _mm256_and_si256(ymm1, ymm0));
    }
    if (last >= last_aligned + 4) {
      __m128i xmm0 = _mm_lddqu_si128((__m128i *)lhs);
      __m128i xmm1 = _mm_lddqu_si128((__m128i *)rhs);
      _mm_storeu_si128((__m128i *)lhs, _mm_and_si128(xmm1, xmm0));
      lhs += 4;
      rhs += 4;
    }
  }
  switch (last - lhs) {
    case 3: lhs[2] &= rhs[2]; /* FALLTHRU */
    case 2: lhs[1] &= rhs[1]; /* FALLTHRU */
    case 1: lhs[0] &= rhs[0];
  }
}

//! AVX2: lhs[i] &= ~rhs[i]; _mm256_andnot_si256(a, b) computes ~a & b.
static inline void bitset_andnot(uint32_t *lhs, const uint32_t *rhs,
                                 size_t size) {
  uint32_t *last = lhs + size;
  uint32_t *last_aligned = lhs + ((size >> 3) << 3);
  if (((uintptr_t)lhs & 0x1f) == 0 && ((uintptr_t)rhs & 0x1f) == 0) {
    for (; lhs != last_aligned; lhs += 8, rhs += 8) {
      __m256i ymm0 = _mm256_load_si256((__m256i *)lhs);
      __m256i ymm1 = _mm256_load_si256((__m256i *)rhs);
      _mm256_store_si256((__m256i *)lhs, _mm256_andnot_si256(ymm1, ymm0));
    }
    if (last >= last_aligned + 4) {
      __m128i xmm0 = _mm_load_si128((__m128i *)lhs);
      __m128i xmm1 = _mm_load_si128((__m128i *)rhs);
      _mm_store_si128((__m128i *)lhs, _mm_andnot_si128(xmm1, xmm0));
      lhs += 4;
      rhs += 4;
    }
  } else {
    for (; lhs != last_aligned; lhs += 8, rhs += 8) {
      __m256i ymm0 = _mm256_loadu_si256((__m256i *)lhs);
      __m256i ymm1 = _mm256_loadu_si256((__m256i *)rhs);
      _mm256_storeu_si256((__m256i *)lhs, _mm256_andnot_si256(ymm1, ymm0));
    }
    if (last >= last_aligned + 4) {
      __m128i xmm0 = _mm_lddqu_si128((__m128i *)lhs);
      __m128i xmm1 = _mm_lddqu_si128((__m128i *)rhs);
      _mm_storeu_si128((__m128i *)lhs, _mm_andnot_si128(xmm1, xmm0));
      lhs += 4;
      rhs += 4;
    }
  }
  switch (last - lhs) {
    case 3: lhs[2] &= ~rhs[2]; /* FALLTHRU */
    case 2: lhs[1] &= ~rhs[1]; /* FALLTHRU */
    case 1: lhs[0] &= ~rhs[0];
  }
}

//! AVX2: lhs[i] |= rhs[i] for i in [0, size)
static inline void bitset_or(uint32_t *lhs, const uint32_t *rhs, size_t size) {
  uint32_t *last = lhs + size;
  uint32_t *last_aligned = lhs + ((size >> 3) << 3);
  if (((uintptr_t)lhs & 0x1f) == 0 && ((uintptr_t)rhs & 0x1f) == 0) {
    for (; lhs != last_aligned; lhs += 8, rhs += 8) {
      __m256i ymm0 = _mm256_load_si256((__m256i *)lhs);
      __m256i ymm1 = _mm256_load_si256((__m256i *)rhs);
      _mm256_store_si256((__m256i *)lhs, _mm256_or_si256(ymm1, ymm0));
    }
    if (last >= last_aligned + 4) {
      __m128i xmm0 = _mm_load_si128((__m128i *)lhs);
      __m128i xmm1 = _mm_load_si128((__m128i *)rhs);
      _mm_store_si128((__m128i *)lhs, _mm_or_si128(xmm1, xmm0));
      lhs += 4;
      rhs += 4;
    }
  } else {
    for (; lhs != last_aligned; lhs += 8, rhs += 8) {
      __m256i ymm0 = _mm256_loadu_si256((__m256i *)lhs);
      __m256i ymm1 = _mm256_loadu_si256((__m256i *)rhs);
      _mm256_storeu_si256((__m256i *)lhs, _mm256_or_si256(ymm1, ymm0));
    }
    if (last >= last_aligned + 4) {
      __m128i xmm0 = _mm_lddqu_si128((__m128i *)lhs);
      __m128i xmm1 = _mm_lddqu_si128((__m128i *)rhs);
      _mm_storeu_si128((__m128i *)lhs, _mm_or_si128(xmm1, xmm0));
      lhs += 4;
      rhs += 4;
    }
  }
  switch (last - lhs) {
    case 3: lhs[2] |= rhs[2]; /* FALLTHRU */
    case 2: lhs[1] |= rhs[1]; /* FALLTHRU */
    case 1: lhs[0] |= rhs[0];
  }
}

//! AVX2: lhs[i] ^= rhs[i] for i in [0, size)
static inline void bitset_xor(uint32_t *lhs, const uint32_t *rhs, size_t size) {
  uint32_t *last = lhs + size;
  uint32_t *last_aligned = lhs + ((size >> 3) << 3);
  if (((uintptr_t)lhs & 0x1f) == 0 && ((uintptr_t)rhs & 0x1f) == 0) {
    for (; lhs != last_aligned; lhs += 8, rhs += 8) {
      __m256i ymm0 = _mm256_load_si256((__m256i *)lhs);
      __m256i ymm1 = _mm256_load_si256((__m256i *)rhs);
      _mm256_store_si256((__m256i *)lhs, _mm256_xor_si256(ymm1, ymm0));
    }
    if (last >= last_aligned + 4) {
      __m128i xmm0 = _mm_load_si128((__m128i *)lhs);
      __m128i xmm1 = _mm_load_si128((__m128i *)rhs);
      _mm_store_si128((__m128i *)lhs, _mm_xor_si128(xmm1, xmm0));
      lhs += 4;
      rhs += 4;
    }
  } else {
    for (; lhs != last_aligned; lhs += 8, rhs += 8) {
      __m256i ymm0 = _mm256_loadu_si256((__m256i *)lhs);
      __m256i ymm1 = _mm256_loadu_si256((__m256i *)rhs);
      _mm256_storeu_si256((__m256i *)lhs, _mm256_xor_si256(ymm1, ymm0));
    }
    if (last >= last_aligned + 4) {
      __m128i xmm0 = _mm_lddqu_si128((__m128i *)lhs);
      __m128i xmm1 = _mm_lddqu_si128((__m128i *)rhs);
      _mm_storeu_si128((__m128i *)lhs, _mm_xor_si128(xmm1, xmm0));
      lhs += 4;
      rhs += 4;
    }
  }
  switch (last - lhs) {
    case 3: lhs[2] ^= rhs[2]; /* FALLTHRU */
    case 2: lhs[1] ^= rhs[1]; /* FALLTHRU */
    case 1: lhs[0] ^= rhs[0];
  }
}

//! AVX2: lhs[i] = ~lhs[i]; computed as andnot(x, all-ones) == ~x.
static inline void bitset_not(uint32_t *lhs, size_t size) {
  uint32_t *last = lhs + size;
  uint32_t *last_aligned = lhs + ((size >> 3) << 3);
  static const __m256i mask_256 = _mm256_set1_epi32(0xffffffffu);
  static const __m128i mask_128 = _mm_set1_epi32(0xffffffffu);
  if (((uintptr_t)lhs & 0x1f) == 0) {
    for (; lhs != last_aligned; lhs += 8) {
      _mm256_store_si256(
          (__m256i *)lhs,
          _mm256_andnot_si256(_mm256_load_si256((__m256i *)lhs), mask_256));
    }
    if (last >= last_aligned + 4) {
      _mm_store_si128(
          (__m128i *)lhs,
          _mm_andnot_si128(_mm_load_si128((__m128i *)lhs), mask_128));
      lhs += 4;
    }
  } else {
    for (; lhs != last_aligned; lhs += 8) {
      _mm256_storeu_si256(
          (__m256i *)lhs,
          _mm256_andnot_si256(_mm256_loadu_si256((__m256i *)lhs), mask_256));
    }
    if (last >= last_aligned + 4) {
      _mm_storeu_si128(
          (__m128i *)lhs,
          _mm_andnot_si128(_mm_lddqu_si128((__m128i *)lhs), mask_128));
      lhs += 4;
    }
  }
  switch (last - lhs) {
    case 3: lhs[2] = ~lhs[2]; /* FALLTHRU */
    case 2: lhs[1] = ~lhs[1]; /* FALLTHRU */
    case 1: lhs[0] = ~lhs[0];
  }
}

//! AVX2: true iff every bit of lhs[0, size) is set. XOR with all-ones leaves
//! a non-zero vector exactly when some bit is clear; testz detects that.
static inline bool bitset_test_all(const uint32_t *lhs, size_t size) {
  const uint32_t *last = lhs + size;
  const uint32_t *last_aligned = lhs + ((size >> 3) << 3);
  static const __m256i mask_256 = _mm256_set1_epi32(0xffffffffu);
  static const __m128i mask_128 = _mm_set1_epi32(0xffffffffu);
  if (((uintptr_t)lhs & 0x1f) == 0) {
    for (; lhs != last_aligned; lhs += 8) {
      __m256i neq =
          _mm256_xor_si256(_mm256_load_si256((__m256i *)lhs), mask_256);
      if (!_mm256_testz_si256(neq, neq)) {
        return false;
      }
    }
    if (last >= last_aligned + 4) {
      __m128i neq = _mm_xor_si128(_mm_load_si128((__m128i *)lhs), mask_128);
      if (!_mm_testz_si128(neq, neq)) {
        return false;
      }
      lhs += 4;
    }
  } else {
    for (; lhs != last_aligned; lhs += 8) {
      __m256i neq =
          _mm256_xor_si256(_mm256_loadu_si256((__m256i *)lhs), mask_256);
      if (!_mm256_testz_si256(neq, neq)) {
        return false;
      }
    }
    if (last >= last_aligned + 4) {
      __m128i neq = _mm_xor_si128(_mm_lddqu_si128((__m128i *)lhs), mask_128);
      if (!_mm_testz_si128(neq, neq)) {
        return false;
      }
      lhs += 4;
    }
  }
  switch (last - lhs) {
    case 3: if (lhs[2] != 0xffffffffu) { return false; } /* FALLTHRU */
    case 2: if (lhs[1] != 0xffffffffu) { return false; } /* FALLTHRU */
    case 1: if (lhs[0] != 0xffffffffu) { return false; }
  }
  return true;
}

//! AVX2: true iff at least one bit of lhs[0, size) is set.
static inline bool bitset_test_any(const uint32_t *lhs, size_t size) {
  const uint32_t *last = lhs + size;
  const uint32_t *last_aligned = lhs + ((size >> 3) << 3);
  if (((uintptr_t)lhs & 0x1f) == 0) {
    for (; lhs != last_aligned; lhs += 8) {
      __m256i ymm0 = _mm256_load_si256((__m256i *)lhs);
      if (!_mm256_testz_si256(ymm0, ymm0)) {
        return true;
      }
    }
    if (last >= last_aligned + 4) {
      __m128i xmm0 = _mm_load_si128((__m128i *)lhs);
      if (!_mm_testz_si128(xmm0, xmm0)) {
        return true;
      }
      lhs += 4;
    }
  } else {
    for (; lhs != last_aligned; lhs += 8) {
      __m256i ymm0 = _mm256_loadu_si256((__m256i *)lhs);
      if (!_mm256_testz_si256(ymm0, ymm0)) {
        return true;
      }
    }
    if (last >= last_aligned + 4) {
      __m128i xmm0 = _mm_lddqu_si128((__m128i *)lhs);
      if (!_mm_testz_si128(xmm0, xmm0)) {
        return true;
      }
      lhs += 4;
    }
  }
  switch (last - lhs) {
    case 3: if (lhs[2] != 0u) { return true; } /* FALLTHRU */
    case 2: if (lhs[1] != 0u) { return true; } /* FALLTHRU */
    case 1: if (lhs[0] != 0u) { return true; }
  }
  return false;
}

//! AVX2: true iff no bit of lhs[0, size) is set.
static inline bool bitset_test_none(const uint32_t *lhs, size_t size) {
  const uint32_t *last = lhs + size;
  const uint32_t *last_aligned = lhs + ((size >> 3) << 3);
  if (((uintptr_t)lhs & 0x1f) == 0) {
    for (; lhs != last_aligned; lhs += 8) {
      __m256i ymm0 = _mm256_load_si256((__m256i *)lhs);
      if (!_mm256_testz_si256(ymm0, ymm0)) {
        return false;
      }
    }
    if (last >= last_aligned + 4) {
      __m128i xmm0 = _mm_load_si128((__m128i *)lhs);
      if (!_mm_testz_si128(xmm0, xmm0)) {
        return false;
      }
      lhs += 4;
    }
  } else {
    for (; lhs != last_aligned; lhs += 8) {
      __m256i ymm0 = _mm256_loadu_si256((__m256i *)lhs);
      if (!_mm256_testz_si256(ymm0, ymm0)) {
        return false;
      }
    }
    if (last >= last_aligned + 4) {
      __m128i xmm0 = _mm_lddqu_si128((__m128i *)lhs);
      if (!_mm_testz_si128(xmm0, xmm0)) {
        return false;
      }
      lhs += 4;
    }
  }
  switch (last - lhs) {
    case 3: if (lhs[2] != 0u) { return false; } /* FALLTHRU */
    case 2: if (lhs[1] != 0u) { return false; } /* FALLTHRU */
    case 1: if (lhs[0] != 0u) { return false; }
  }
  return true;
}
#elif defined(__SSE2__)
// Without SSE3 there is no LDDQU; fall back to the plain unaligned load.
#ifndef __SSE3__
#define _mm_lddqu_si128 _mm_loadu_si128
#endif  // !__SSE3__
// SSE2 kernels: 4 words per 128-bit step (aligned forms only when every
// pointer involved is 16-byte aligned), then up to 3 scalar words.

//! SSE2: lhs[i] &= rhs[i] for i in [0, size)
static inline void bitset_and(uint32_t *lhs, const uint32_t *rhs, size_t size) {
  uint32_t *last = lhs + size;
  uint32_t *last_aligned = lhs + ((size >> 2) << 2);
  if (((uintptr_t)lhs & 0xf) == 0 && ((uintptr_t)rhs & 0xf) == 0) {
    for (; lhs != last_aligned; lhs += 4, rhs += 4) {
      __m128i xmm0 = _mm_load_si128((__m128i *)lhs);
      __m128i xmm1 = _mm_load_si128((__m128i *)rhs);
      _mm_store_si128((__m128i *)lhs, _mm_and_si128(xmm1, xmm0));
    }
  } else {
    for (; lhs != last_aligned; lhs += 4, rhs += 4) {
      __m128i xmm0 = _mm_lddqu_si128((__m128i *)lhs);
      __m128i xmm1 = _mm_lddqu_si128((__m128i *)rhs);
      _mm_storeu_si128((__m128i *)lhs, _mm_and_si128(xmm1, xmm0));
    }
  }
  switch (last - last_aligned) {
    case 3: lhs[2] &= rhs[2]; /* FALLTHRU */
    case 2: lhs[1] &= rhs[1]; /* FALLTHRU */
    case 1: lhs[0] &= rhs[0];
  }
}

//! SSE2: lhs[i] &= ~rhs[i]; _mm_andnot_si128(a, b) computes ~a & b.
static inline void bitset_andnot(uint32_t *lhs, const uint32_t *rhs,
                                 size_t size) {
  uint32_t *last = lhs + size;
  uint32_t *last_aligned = lhs + ((size >> 2) << 2);
  if (((uintptr_t)lhs & 0xf) == 0 && ((uintptr_t)rhs & 0xf) == 0) {
    for (; lhs != last_aligned; lhs += 4, rhs += 4) {
      __m128i xmm0 = _mm_load_si128((__m128i *)lhs);
      __m128i xmm1 = _mm_load_si128((__m128i *)rhs);
      _mm_store_si128((__m128i *)lhs, _mm_andnot_si128(xmm1, xmm0));
    }
  } else {
    for (; lhs != last_aligned; lhs += 4, rhs += 4) {
      __m128i xmm0 = _mm_lddqu_si128((__m128i *)lhs);
      __m128i xmm1 = _mm_lddqu_si128((__m128i *)rhs);
      _mm_storeu_si128((__m128i *)lhs, _mm_andnot_si128(xmm1, xmm0));
    }
  }
  switch (last - last_aligned) {
    case 3: lhs[2] &= ~rhs[2]; /* FALLTHRU */
    case 2: lhs[1] &= ~rhs[1]; /* FALLTHRU */
    case 1: lhs[0] &= ~rhs[0];
  }
}

//! SSE2: lhs[i] |= rhs[i] for i in [0, size)
static inline void bitset_or(uint32_t *lhs, const uint32_t *rhs, size_t size) {
  uint32_t *last = lhs + size;
  uint32_t *last_aligned = lhs + ((size >> 2) << 2);
  if (((uintptr_t)lhs & 0xf) == 0 && ((uintptr_t)rhs & 0xf) == 0) {
    for (; lhs != last_aligned; lhs += 4, rhs += 4) {
      __m128i xmm0 = _mm_load_si128((__m128i *)lhs);
      __m128i xmm1 = _mm_load_si128((__m128i *)rhs);
      _mm_store_si128((__m128i *)lhs, _mm_or_si128(xmm1, xmm0));
    }
  } else {
    for (; lhs != last_aligned; lhs += 4, rhs += 4) {
      __m128i xmm0 = _mm_lddqu_si128((__m128i *)lhs);
      __m128i xmm1 = _mm_lddqu_si128((__m128i *)rhs);
      _mm_storeu_si128((__m128i *)lhs, _mm_or_si128(xmm1, xmm0));
    }
  }
  switch (last - last_aligned) {
    case 3: lhs[2] |= rhs[2]; /* FALLTHRU */
    case 2: lhs[1] |= rhs[1]; /* FALLTHRU */
    case 1: lhs[0] |= rhs[0];
  }
}

//! SSE2: lhs[i] ^= rhs[i] for i in [0, size)
static inline void bitset_xor(uint32_t *lhs, const uint32_t *rhs, size_t size) {
  uint32_t *last = lhs + size;
  uint32_t *last_aligned = lhs + ((size >> 2) << 2);
  if (((uintptr_t)lhs & 0xf) == 0 && ((uintptr_t)rhs & 0xf) == 0) {
    for (; lhs != last_aligned; lhs += 4, rhs += 4) {
      __m128i xmm0 = _mm_load_si128((__m128i *)lhs);
      __m128i xmm1 = _mm_load_si128((__m128i *)rhs);
      _mm_store_si128((__m128i *)lhs, _mm_xor_si128(xmm1, xmm0));
    }
  } else {
    for (; lhs != last_aligned; lhs += 4, rhs += 4) {
      __m128i xmm0 = _mm_lddqu_si128((__m128i *)lhs);
      __m128i xmm1 = _mm_lddqu_si128((__m128i *)rhs);
      _mm_storeu_si128((__m128i *)lhs, _mm_xor_si128(xmm1, xmm0));
    }
  }
  switch (last - last_aligned) {
    case 3: lhs[2] ^= rhs[2]; /* FALLTHRU */
    case 2: lhs[1] ^= rhs[1]; /* FALLTHRU */
    case 1: lhs[0] ^= rhs[0];
  }
}

//! SSE2: lhs[i] = ~lhs[i]; computed as andnot(x, all-ones) == ~x.
static inline void bitset_not(uint32_t *lhs, size_t size) {
  uint32_t *last = lhs + size;
  uint32_t *last_aligned = lhs + ((size >> 2) << 2);
  static const __m128i mask = _mm_set1_epi32(0xffffffffu);
  if (((uintptr_t)lhs & 0xf) == 0) {
    for (; lhs != last_aligned; lhs += 4) {
      _mm_store_si128((__m128i *)lhs,
                      _mm_andnot_si128(_mm_load_si128((__m128i *)lhs), mask));
    }
  } else {
    for (; lhs != last_aligned; lhs += 4) {
      _mm_storeu_si128((__m128i *)lhs,
                       _mm_andnot_si128(_mm_lddqu_si128((__m128i *)lhs), mask));
    }
  }
  switch (last - last_aligned) {
    case 3: lhs[2] = ~lhs[2]; /* FALLTHRU */
    case 2: lhs[1] = ~lhs[1]; /* FALLTHRU */
    case 1: lhs[0] = ~lhs[0];
  }
}

//! SSE2: true iff every bit of lhs[0, size) is set. Without SSE4.1, compare
//! each 32-bit lane with all-ones and require a full 16-bit byte mask;
//! with SSE4.1, XOR with all-ones and test for zero via PTEST.
static inline bool bitset_test_all(const uint32_t *lhs, size_t size) {
  const uint32_t *last = lhs + size;
  const uint32_t *last_aligned = lhs + ((size >> 2) << 2);
  static const __m128i mask = _mm_set1_epi32(0xffffffffu);
#ifndef __SSE4_1__
  if (((uintptr_t)lhs & 0xf) == 0) {
    for (; lhs != last_aligned; lhs += 4) {
      __m128i eq = _mm_cmpeq_epi32(_mm_load_si128((__m128i *)lhs), mask);
      if (_mm_movemask_epi8(eq) != 0xffffu) {
        return false;
      }
    }
  } else {
    for (; lhs != last_aligned; lhs += 4) {
      __m128i eq = _mm_cmpeq_epi32(_mm_lddqu_si128((__m128i *)lhs), mask);
      if (_mm_movemask_epi8(eq) != 0xffffu) {
        return false;
      }
    }
  }
#else
  if (((uintptr_t)lhs & 0xf) == 0) {
    for (; lhs != last_aligned; lhs += 4) {
      __m128i neq = _mm_xor_si128(_mm_load_si128((__m128i *)lhs), mask);
      if (!_mm_testz_si128(neq, neq)) {
        return false;
      }
    }
  } else {
    for (; lhs != last_aligned; lhs += 4) {
      __m128i neq = _mm_xor_si128(_mm_lddqu_si128((__m128i *)lhs), mask);
      if (!_mm_testz_si128(neq, neq)) {
        return false;
      }
    }
  }
#endif  // !__SSE4_1__
  switch (last - last_aligned) {
    case 3: if (lhs[2] != 0xffffffffu) { return false; } /* FALLTHRU */
    case 2: if (lhs[1] != 0xffffffffu) { return false; } /* FALLTHRU */
    case 1: if (lhs[0] != 0xffffffffu) { return false; }
  }
  return true;
}

//! SSE2: true iff at least one bit of lhs[0, size) is set.
static inline bool bitset_test_any(const uint32_t *lhs, size_t size) {
  const uint32_t *last = lhs + size;
  const uint32_t *last_aligned = lhs + ((size >> 2) << 2);
#ifndef __SSE4_1__
  static const __m128i zero = _mm_setzero_si128();
  if (((uintptr_t)lhs & 0xf) == 0) {
    for (; lhs != last_aligned; lhs += 4) {
      __m128i eq = _mm_cmpeq_epi32(_mm_load_si128((__m128i *)lhs), zero);
      if (_mm_movemask_epi8(eq) != 0xffffu) {
        return true;
      }
    }
  } else {
    for (; lhs != last_aligned; lhs += 4) {
      __m128i eq = _mm_cmpeq_epi32(_mm_lddqu_si128((__m128i *)lhs), zero);
      if (_mm_movemask_epi8(eq) != 0xffffu) {
        return true;
      }
    }
  }
#else
  if (((uintptr_t)lhs & 0xf) == 0) {
    for (; lhs != last_aligned; lhs += 4) {
      __m128i xmm0 = _mm_load_si128((__m128i
*)lhs); if (!_mm_testz_si128(xmm0, xmm0)) { return true; } } } else { for (; lhs != last_aligned; lhs += 4) { __m128i xmm0 = _mm_lddqu_si128((__m128i *)lhs); if (!_mm_testz_si128(xmm0, xmm0)) { return true; } } } #endif // !__SSE4_1__ switch (last - last_aligned) { case 3: if (lhs[2] != 0u) { return true; } /* FALLTHRU */ case 2: if (lhs[1] != 0u) { return true; } /* FALLTHRU */ case 1: if (lhs[0] != 0u) { return true; } } return false; } static inline bool bitset_test_none(const uint32_t *lhs, size_t size) { const uint32_t *last = lhs + size; const uint32_t *last_aligned = lhs + ((size >> 2) << 2); #ifndef __SSE4_1__ static __m128i zero = _mm_setzero_si128(); if (((uintptr_t)lhs & 0xf) == 0) { for (; lhs != last_aligned; lhs += 4) { __m128i eq = _mm_cmpeq_epi32(_mm_load_si128((__m128i *)lhs), zero); if (_mm_movemask_epi8(eq) != 0xffffu) { return false; } } } else { for (; lhs != last_aligned; lhs += 4) { __m128i eq = _mm_cmpeq_epi32(_mm_lddqu_si128((__m128i *)lhs), zero); if (_mm_movemask_epi8(eq) != 0xffffu) { return false; } } } #else if (((uintptr_t)lhs & 0xf) == 0) { for (; lhs != last_aligned; lhs += 4) { __m128i xmm0 = _mm_load_si128((__m128i *)lhs); if (!_mm_testz_si128(xmm0, xmm0)) { return false; } } } else { for (; lhs != last_aligned; lhs += 4) { __m128i xmm0 = _mm_lddqu_si128((__m128i *)lhs); if (!_mm_testz_si128(xmm0, xmm0)) { return false; } } } #endif // !__SSE4_1__ switch (last - last_aligned) { case 3: if (lhs[2] != 0u) { return false; } /* FALLTHRU */ case 2: if (lhs[1] != 0u) { return false; } /* FALLTHRU */ case 1: if (lhs[0] != 0u) { return false; } } return true; } #else #if defined(AILEGO_M64) static inline void bitset_and(uint32_t *lhs, const uint32_t *rhs, size_t size) { uint32_t *last = lhs + size; uint32_t *last_aligned = lhs + ((size >> 3) << 3); for (; lhs != last_aligned; lhs += 8, rhs += 8) { *(uint64_t *)(&lhs[6]) &= *(uint64_t *)(&rhs[6]); *(uint64_t *)(&lhs[4]) &= *(uint64_t *)(&rhs[4]); *(uint64_t *)(&lhs[2]) &= *(uint64_t 
*)(&rhs[2]); *(uint64_t *)(&lhs[0]) &= *(uint64_t *)(&rhs[0]); } switch (last - last_aligned) { case 7: lhs[6] &= rhs[6]; /* FALLTHRU */ case 6: lhs[5] &= rhs[5]; /* FALLTHRU */ case 5: lhs[4] &= rhs[4]; /* FALLTHRU */ case 4: lhs[3] &= rhs[3]; /* FALLTHRU */ case 3: lhs[2] &= rhs[2]; /* FALLTHRU */ case 2: lhs[1] &= rhs[1]; /* FALLTHRU */ case 1: lhs[0] &= rhs[0]; } } static inline void bitset_andnot(uint32_t *lhs, const uint32_t *rhs, size_t size) { uint32_t *last = lhs + size; uint32_t *last_aligned = lhs + ((size >> 3) << 3); for (; lhs != last_aligned; lhs += 8, rhs += 8) { *(uint64_t *)(&lhs[6]) &= ~(*(uint64_t *)(&rhs[6])); *(uint64_t *)(&lhs[4]) &= ~(*(uint64_t *)(&rhs[4])); *(uint64_t *)(&lhs[2]) &= ~(*(uint64_t *)(&rhs[2])); *(uint64_t *)(&lhs[0]) &= ~(*(uint64_t *)(&rhs[0])); } switch (last - last_aligned) { case 7: lhs[6] &= ~rhs[6]; /* FALLTHRU */ case 6: lhs[5] &= ~rhs[5]; /* FALLTHRU */ case 5: lhs[4] &= ~rhs[4]; /* FALLTHRU */ case 4: lhs[3] &= ~rhs[3]; /* FALLTHRU */ case 3: lhs[2] &= ~rhs[2]; /* FALLTHRU */ case 2: lhs[1] &= ~rhs[1]; /* FALLTHRU */ case 1: lhs[0] &= ~rhs[0]; } } static inline void bitset_or(uint32_t *lhs, const uint32_t *rhs, size_t size) { uint32_t *last = lhs + size; uint32_t *last_aligned = lhs + ((size >> 3) << 3); for (; lhs != last_aligned; lhs += 8, rhs += 8) { *(uint64_t *)(&lhs[6]) |= *(uint64_t *)(&rhs[6]); *(uint64_t *)(&lhs[4]) |= *(uint64_t *)(&rhs[4]); *(uint64_t *)(&lhs[2]) |= *(uint64_t *)(&rhs[2]); *(uint64_t *)(&lhs[0]) |= *(uint64_t *)(&rhs[0]); } switch (last - last_aligned) { case 7: lhs[6] |= rhs[6]; /* FALLTHRU */ case 6: lhs[5] |= rhs[5]; /* FALLTHRU */ case 5: lhs[4] |= rhs[4]; /* FALLTHRU */ case 4: lhs[3] |= rhs[3]; /* FALLTHRU */ case 3: lhs[2] |= rhs[2]; /* FALLTHRU */ case 2: lhs[1] |= rhs[1]; /* FALLTHRU */ case 1: lhs[0] |= rhs[0]; } } static inline void bitset_xor(uint32_t *lhs, const uint32_t *rhs, size_t size) { uint32_t *last = lhs + size; uint32_t *last_aligned = lhs + ((size >> 3) << 3); for 
(; lhs != last_aligned; lhs += 8, rhs += 8) { *(uint64_t *)(&lhs[6]) ^= *(uint64_t *)(&rhs[6]); *(uint64_t *)(&lhs[4]) ^= *(uint64_t *)(&rhs[4]); *(uint64_t *)(&lhs[2]) ^= *(uint64_t *)(&rhs[2]); *(uint64_t *)(&lhs[0]) ^= *(uint64_t *)(&rhs[0]); } switch (last - last_aligned) { case 7: lhs[6] ^= rhs[6]; /* FALLTHRU */ case 6: lhs[5] ^= rhs[5]; /* FALLTHRU */ case 5: lhs[4] ^= rhs[4]; /* FALLTHRU */ case 4: lhs[3] ^= rhs[3]; /* FALLTHRU */ case 3: lhs[2] ^= rhs[2]; /* FALLTHRU */ case 2: lhs[1] ^= rhs[1]; /* FALLTHRU */ case 1: lhs[0] ^= rhs[0]; } } static inline void bitset_not(uint32_t *lhs, size_t size) { uint32_t *last = lhs + size; uint32_t *last_aligned = lhs + ((size >> 3) << 3); for (; lhs != last_aligned; lhs += 8) { *(uint64_t *)(&lhs[6]) = ~(*(uint64_t *)(&lhs[6])); *(uint64_t *)(&lhs[4]) = ~(*(uint64_t *)(&lhs[4])); *(uint64_t *)(&lhs[2]) = ~(*(uint64_t *)(&lhs[2])); *(uint64_t *)(&lhs[0]) = ~(*(uint64_t *)(&lhs[0])); } switch (last - last_aligned) { case 7: lhs[6] = ~lhs[6]; /* FALLTHRU */ case 6: lhs[5] = ~lhs[5]; /* FALLTHRU */ case 5: lhs[4] = ~lhs[4]; /* FALLTHRU */ case 4: lhs[3] = ~lhs[3]; /* FALLTHRU */ case 3: lhs[2] = ~lhs[2]; /* FALLTHRU */ case 2: lhs[1] = ~lhs[1]; /* FALLTHRU */ case 1: lhs[0] = ~lhs[0]; } } static inline bool bitset_test_all(const uint32_t *lhs, size_t size) { const uint32_t *last = lhs + size; const uint32_t *last_aligned = lhs + ((size >> 3) << 3); for (; lhs != last_aligned; lhs += 8) { if (*(uint64_t *)(&lhs[6]) != (uint64_t)-1) { return false; } if (*(uint64_t *)(&lhs[4]) != (uint64_t)-1) { return false; } if (*(uint64_t *)(&lhs[2]) != (uint64_t)-1) { return false; } if (*(uint64_t *)(&lhs[0]) != (uint64_t)-1) { return false; } } switch (last - last_aligned) { case 7: if (lhs[6] != (uint32_t)-1) { return false; } /* FALLTHRU */ case 6: if (lhs[5] != (uint32_t)-1) { return false; } /* FALLTHRU */ case 5: if (lhs[4] != (uint32_t)-1) { return false; } /* FALLTHRU */ case 4: if (lhs[3] != (uint32_t)-1) { return false; } /* 
FALLTHRU */ case 3: if (lhs[2] != (uint32_t)-1) { return false; } /* FALLTHRU */ case 2: if (lhs[1] != (uint32_t)-1) { return false; } /* FALLTHRU */ case 1: if (lhs[0] != (uint32_t)-1) { return false; } } return true; } static inline bool bitset_test_any(const uint32_t *lhs, size_t size) { const uint32_t *last = lhs + size; const uint32_t *last_aligned = lhs + ((size >> 3) << 3); for (; lhs != last_aligned; lhs += 8) { if (*(uint64_t *)(&lhs[6]) != 0u) { return true; } if (*(uint64_t *)(&lhs[4]) != 0u) { return true; } if (*(uint64_t *)(&lhs[2]) != 0u) { return true; } if (*(uint64_t *)(&lhs[0]) != 0u) { return true; } } switch (last - last_aligned) { case 7: if (lhs[6] != 0u) { return true; } /* FALLTHRU */ case 6: if (lhs[5] != 0u) { return true; } /* FALLTHRU */ case 5: if (lhs[4] != 0u) { return true; } /* FALLTHRU */ case 4: if (lhs[3] != 0u) { return true; } /* FALLTHRU */ case 3: if (lhs[2] != 0u) { return true; } /* FALLTHRU */ case 2: if (lhs[1] != 0u) { return true; } /* FALLTHRU */ case 1: if (lhs[0] != 0u) { return true; } } return false; } static inline bool bitset_test_none(const uint32_t *lhs, size_t size) { const uint32_t *last = lhs + size; const uint32_t *last_aligned = lhs + ((size >> 3) << 3); for (; lhs != last_aligned; lhs += 8) { if (*(uint64_t *)(&lhs[6]) != 0u) { return false; } if (*(uint64_t *)(&lhs[4]) != 0u) { return false; } if (*(uint64_t *)(&lhs[2]) != 0u) { return false; } if (*(uint64_t *)(&lhs[0]) != 0u) { return false; } } switch (last - last_aligned) { case 7: if (lhs[6] != 0u) { return false; } /* FALLTHRU */ case 6: if (lhs[5] != 0u) { return false; } /* FALLTHRU */ case 5: if (lhs[4] != 0u) { return false; } /* FALLTHRU */ case 4: if (lhs[3] != 0u) { return false; } /* FALLTHRU */ case 3: if (lhs[2] != 0u) { return false; } /* FALLTHRU */ case 2: if (lhs[1] != 0u) { return false; } /* FALLTHRU */ case 1: if (lhs[0] != 0u) { return false; } } return true; } #else // AILEGO_M64 static inline void bitset_and(uint32_t *lhs, 
const uint32_t *rhs, size_t size) { uint32_t *last = lhs + size; uint32_t *last_aligned = lhs + ((size >> 2) << 2); for (; lhs != last_aligned; lhs += 4, rhs += 4) { lhs[3] &= rhs[3]; lhs[2] &= rhs[2]; lhs[1] &= rhs[1]; lhs[0] &= rhs[0]; } switch (last - last_aligned) { case 3: lhs[2] &= rhs[2]; /* FALLTHRU */ case 2: lhs[1] &= rhs[1]; /* FALLTHRU */ case 1: lhs[0] &= rhs[0]; } } static inline void bitset_andnot(uint32_t *lhs, const uint32_t *rhs, size_t size) { uint32_t *last = lhs + size; uint32_t *last_aligned = lhs + ((size >> 2) << 2); for (; lhs != last_aligned; lhs += 4, rhs += 4) { lhs[3] &= ~rhs[3]; lhs[2] &= ~rhs[2]; lhs[1] &= ~rhs[1]; lhs[0] &= ~rhs[0]; } switch (last - last_aligned) { case 3: lhs[2] &= ~rhs[2]; /* FALLTHRU */ case 2: lhs[1] &= ~rhs[1]; /* FALLTHRU */ case 1: lhs[0] &= ~rhs[0]; } } static inline void bitset_or(uint32_t *lhs, const uint32_t *rhs, size_t size) { uint32_t *last = lhs + size; uint32_t *last_aligned = lhs + ((size >> 2) << 2); for (; lhs != last_aligned; lhs += 4, rhs += 4) { lhs[3] |= rhs[3]; lhs[2] |= rhs[2]; lhs[1] |= rhs[1]; lhs[0] |= rhs[0]; } switch (last - last_aligned) { case 3: lhs[2] |= rhs[2]; /* FALLTHRU */ case 2: lhs[1] |= rhs[1]; /* FALLTHRU */ case 1: lhs[0] |= rhs[0]; } } static inline void bitset_xor(uint32_t *lhs, const uint32_t *rhs, size_t size) { uint32_t *last = lhs + size; uint32_t *last_aligned = lhs + ((size >> 2) << 2); for (; lhs != last_aligned; lhs += 4, rhs += 4) { lhs[3] ^= rhs[3]; lhs[2] ^= rhs[2]; lhs[1] ^= rhs[1]; lhs[0] ^= rhs[0]; } switch (last - last_aligned) { case 3: lhs[2] ^= rhs[2]; /* FALLTHRU */ case 2: lhs[1] ^= rhs[1]; /* FALLTHRU */ case 1: lhs[0] ^= rhs[0]; } } static inline void bitset_not(uint32_t *lhs, size_t size) { uint32_t *last = lhs + size; uint32_t *last_aligned = lhs + ((size >> 2) << 2); for (; lhs != last_aligned; lhs += 4) { lhs[3] = ~lhs[3]; lhs[2] = ~lhs[2]; lhs[1] = ~lhs[1]; lhs[0] = ~lhs[0]; } switch (last - last_aligned) { case 3: lhs[2] = ~lhs[2]; /* FALLTHRU 
*/ case 2: lhs[1] = ~lhs[1]; /* FALLTHRU */ case 1: lhs[0] = ~lhs[0]; } } static inline bool bitset_test_all(const uint32_t *lhs, size_t size) { const uint32_t *last = lhs + size; const uint32_t *last_aligned = lhs + ((size >> 2) << 2); for (; lhs != last_aligned; lhs += 4) { if (lhs[3] != (uint32_t)-1) { return false; } if (lhs[2] != (uint32_t)-1) { return false; } if (lhs[1] != (uint32_t)-1) { return false; } if (lhs[0] != (uint32_t)-1) { return false; } } switch (last - last_aligned) { case 3: if (lhs[2] != (uint32_t)-1) { return false; } /* FALLTHRU */ case 2: if (lhs[1] != (uint32_t)-1) { return false; } /* FALLTHRU */ case 1: if (lhs[0] != (uint32_t)-1) { return false; } } return true; } static inline bool bitset_test_any(const uint32_t *lhs, size_t size) { const uint32_t *last = lhs + size; const uint32_t *last_aligned = lhs + ((size >> 2) << 2); for (; lhs != last_aligned; lhs += 4) { if (lhs[3] != 0u) { return true; } if (lhs[2] != 0u) { return true; } if (lhs[1] != 0u) { return true; } if (lhs[0] != 0u) { return true; } } switch (last - last_aligned) { case 3: if (lhs[2] != 0u) { return true; } /* FALLTHRU */ case 2: if (lhs[1] != 0u) { return true; } /* FALLTHRU */ case 1: if (lhs[0] != 0u) { return true; } } return false; } static inline bool bitset_test_none(const uint32_t *lhs, size_t size) { const uint32_t *last = lhs + size; const uint32_t *last_aligned = lhs + ((size >> 2) << 2); for (; lhs != last_aligned; lhs += 4) { if (lhs[3] != 0u) { return false; } if (lhs[2] != 0u) { return false; } if (lhs[1] != 0u) { return false; } if (lhs[0] != 0u) { return false; } } switch (last - last_aligned) { case 3: if (lhs[2] != 0u) { return false; } /* FALLTHRU */ case 2: if (lhs[1] != 0u) { return false; } /* FALLTHRU */ case 1: if (lhs[0] != 0u) { return false; } } return true; } #endif // AILEGO_M64 #endif // __AVX2__ #if (defined(__ARM_NEON) && defined(__aarch64__)) static inline size_t bitset_cardinality(const uint32_t *lhs, size_t size) { const uint32_t 
*last = lhs + size; const uint32_t *last_aligned = lhs + ((size >> 2) << 2); size_t count = 0; while (lhs != last_aligned) { const uint32_t *last_stage = (last_aligned <= lhs + 124u) ? last_aligned : lhs + 124u; uint8x16_t v_count = vdupq_n_u8(0); for (; lhs != last_stage; lhs += 4) { v_count = vaddq_u8(vcntq_u8(vld1q_u8((const uint8_t *)lhs)), v_count); } v_count = vreinterpretq_u8_u16(vpaddlq_u8(v_count)); count += vaddvq_u16(vreinterpretq_u16_u8(v_count)); } switch (last - last_aligned) { case 3: count += bitset_popcount32(lhs[2]); /* FALLTHRU */ case 2: count += bitset_popcount32(lhs[1]); /* FALLTHRU */ case 1: count += bitset_popcount32(lhs[0]); } return count; } static inline size_t bitset_xor_cardinality(const uint32_t *lhs, const uint32_t *rhs, size_t size) { const uint32_t *last = lhs + size; const uint32_t *last_aligned = lhs + ((size >> 2) << 2); size_t count = 0; while (lhs != last_aligned) { const uint32_t *last_stage = (last_aligned <= lhs + 124u) ? last_aligned : lhs + 124u; uint8x16_t v_count = vdupq_n_u8(0); for (; lhs != last_stage; lhs += 4, rhs += 4) { v_count = vaddq_u8(vcntq_u8(veorq_u8(vld1q_u8((const uint8_t *)lhs), vld1q_u8((const uint8_t *)rhs))), v_count); } v_count = vreinterpretq_u8_u16(vpaddlq_u8(v_count)); count += vaddvq_u16(vreinterpretq_u16_u8(v_count)); } switch (last - last_aligned) { case 3: count += bitset_popcount32(lhs[2] ^ rhs[2]); /* FALLTHRU */ case 2: count += bitset_popcount32(lhs[1] ^ rhs[1]); /* FALLTHRU */ case 1: count += bitset_popcount32(lhs[0] ^ rhs[0]); } return count; } static inline size_t bitset_and_cardinality(const uint32_t *lhs, const uint32_t *rhs, size_t size) { const uint32_t *last = lhs + size; const uint32_t *last_aligned = lhs + ((size >> 2) << 2); size_t count = 0; while (lhs != last_aligned) { const uint32_t *last_stage = (last_aligned <= lhs + 124u) ? 
last_aligned : lhs + 124u; uint8x16_t v_count = vdupq_n_u8(0); for (; lhs != last_stage; lhs += 4, rhs += 4) { v_count = vaddq_u8(vcntq_u8(vandq_u8(vld1q_u8((const uint8_t *)lhs), vld1q_u8((const uint8_t *)rhs))), v_count); } v_count = vreinterpretq_u8_u16(vpaddlq_u8(v_count)); count += vaddvq_u16(vreinterpretq_u16_u8(v_count)); } switch (last - last_aligned) { case 3: count += bitset_popcount32(lhs[2] & rhs[2]); /* FALLTHRU */ case 2: count += bitset_popcount32(lhs[1] & rhs[1]); /* FALLTHRU */ case 1: count += bitset_popcount32(lhs[0] & rhs[0]); } return count; } static inline size_t bitset_andnot_cardinality(const uint32_t *lhs, const uint32_t *rhs, size_t size) { const uint32_t *last = lhs + size; const uint32_t *last_aligned = lhs + ((size >> 2) << 2); size_t count = 0; while (lhs != last_aligned) { const uint32_t *last_stage = (last_aligned <= lhs + 124u) ? last_aligned : lhs + 124u; uint8x16_t v_count = vdupq_n_u8(0); for (; lhs != last_stage; lhs += 4, rhs += 4) { v_count = vaddq_u8(vcntq_u8(vbicq_u8(vld1q_u8((const uint8_t *)lhs), vld1q_u8((const uint8_t *)rhs))), v_count); } v_count = vreinterpretq_u8_u16(vpaddlq_u8(v_count)); count += vaddvq_u16(vreinterpretq_u16_u8(v_count)); } switch (last - last_aligned) { case 3: count += bitset_popcount32(lhs[2] & ~rhs[2]); /* FALLTHRU */ case 2: count += bitset_popcount32(lhs[1] & ~rhs[1]); /* FALLTHRU */ case 1: count += bitset_popcount32(lhs[0] & ~rhs[0]); } return count; } static inline size_t bitset_or_cardinality(const uint32_t *lhs, const uint32_t *rhs, size_t size) { const uint32_t *last = lhs + size; const uint32_t *last_aligned = lhs + ((size >> 2) << 2); size_t count = 0; while (lhs != last_aligned) { const uint32_t *last_stage = (last_aligned <= lhs + 124u) ? 
last_aligned : lhs + 124u; uint8x16_t v_count = vdupq_n_u8(0); for (; lhs != last_stage; lhs += 4, rhs += 4) { v_count = vaddq_u8(vcntq_u8(vorrq_u8(vld1q_u8((const uint8_t *)lhs), vld1q_u8((const uint8_t *)rhs))), v_count); } v_count = vreinterpretq_u8_u16(vpaddlq_u8(v_count)); count += vaddvq_u16(vreinterpretq_u16_u8(v_count)); } switch (last - last_aligned) { case 3: count += bitset_popcount32(lhs[2] | rhs[2]); /* FALLTHRU */ case 2: count += bitset_popcount32(lhs[1] | rhs[1]); /* FALLTHRU */ case 1: count += bitset_popcount32(lhs[0] | rhs[0]); } return count; } #elif defined(AILEGO_M64) static inline size_t bitset_cardinality(const uint32_t *lhs, size_t size) { const uint32_t *last = lhs + size; const uint32_t *last_aligned = lhs + ((size >> 3) << 3); size_t count = 0; for (; lhs != last_aligned; lhs += 8) { count += bitset_popcount64(*(uint64_t *)(&lhs[6])); count += bitset_popcount64(*(uint64_t *)(&lhs[4])); count += bitset_popcount64(*(uint64_t *)(&lhs[2])); count += bitset_popcount64(*(uint64_t *)(&lhs[0])); } switch (last - last_aligned) { case 7: count += bitset_popcount32(lhs[6]); /* FALLTHRU */ case 6: count += bitset_popcount32(lhs[5]); /* FALLTHRU */ case 5: count += bitset_popcount32(lhs[4]); /* FALLTHRU */ case 4: count += bitset_popcount32(lhs[3]); /* FALLTHRU */ case 3: count += bitset_popcount32(lhs[2]); /* FALLTHRU */ case 2: count += bitset_popcount32(lhs[1]); /* FALLTHRU */ case 1: count += bitset_popcount32(lhs[0]); } return count; } static inline size_t bitset_xor_cardinality(const uint32_t *lhs, const uint32_t *rhs, size_t size) { const uint32_t *last = lhs + size; const uint32_t *last_aligned = lhs + ((size >> 3) << 3); size_t count = 0; for (; lhs != last_aligned; lhs += 8, rhs += 8) { count += bitset_popcount64(*(uint64_t *)(&lhs[6]) ^ *(uint64_t *)(&rhs[6])); count += bitset_popcount64(*(uint64_t *)(&lhs[4]) ^ *(uint64_t *)(&rhs[4])); count += bitset_popcount64(*(uint64_t *)(&lhs[2]) ^ *(uint64_t *)(&rhs[2])); count += 
bitset_popcount64(*(uint64_t *)(&lhs[0]) ^ *(uint64_t *)(&rhs[0])); } switch (last - last_aligned) { case 7: count += bitset_popcount32(lhs[6] ^ rhs[6]); /* FALLTHRU */ case 6: count += bitset_popcount32(lhs[5] ^ rhs[5]); /* FALLTHRU */ case 5: count += bitset_popcount32(lhs[4] ^ rhs[4]); /* FALLTHRU */ case 4: count += bitset_popcount32(lhs[3] ^ rhs[3]); /* FALLTHRU */ case 3: count += bitset_popcount32(lhs[2] ^ rhs[2]); /* FALLTHRU */ case 2: count += bitset_popcount32(lhs[1] ^ rhs[1]); /* FALLTHRU */ case 1: count += bitset_popcount32(lhs[0] ^ rhs[0]); } return count; } static inline size_t bitset_and_cardinality(const uint32_t *lhs, const uint32_t *rhs, size_t size) { const uint32_t *last = lhs + size; const uint32_t *last_aligned = lhs + ((size >> 3) << 3); size_t count = 0; for (; lhs != last_aligned; lhs += 8, rhs += 8) { count += bitset_popcount64(*(uint64_t *)(&lhs[6]) & *(uint64_t *)(&rhs[6])); count += bitset_popcount64(*(uint64_t *)(&lhs[4]) & *(uint64_t *)(&rhs[4])); count += bitset_popcount64(*(uint64_t *)(&lhs[2]) & *(uint64_t *)(&rhs[2])); count += bitset_popcount64(*(uint64_t *)(&lhs[0]) & *(uint64_t *)(&rhs[0])); } switch (last - last_aligned) { case 7: count += bitset_popcount32(lhs[6] & rhs[6]); /* FALLTHRU */ case 6: count += bitset_popcount32(lhs[5] & rhs[5]); /* FALLTHRU */ case 5: count += bitset_popcount32(lhs[4] & rhs[4]); /* FALLTHRU */ case 4: count += bitset_popcount32(lhs[3] & rhs[3]); /* FALLTHRU */ case 3: count += bitset_popcount32(lhs[2] & rhs[2]); /* FALLTHRU */ case 2: count += bitset_popcount32(lhs[1] & rhs[1]); /* FALLTHRU */ case 1: count += bitset_popcount32(lhs[0] & rhs[0]); } return count; } static inline size_t bitset_andnot_cardinality(const uint32_t *lhs, const uint32_t *rhs, size_t size) { const uint32_t *last = lhs + size; const uint32_t *last_aligned = lhs + ((size >> 3) << 3); size_t count = 0; for (; lhs != last_aligned; lhs += 8, rhs += 8) { count += bitset_popcount64(*(uint64_t *)(&lhs[6]) & ~(*(uint64_t 
*)(&rhs[6]))); count += bitset_popcount64(*(uint64_t *)(&lhs[4]) & ~(*(uint64_t *)(&rhs[4]))); count += bitset_popcount64(*(uint64_t *)(&lhs[2]) & ~(*(uint64_t *)(&rhs[2]))); count += bitset_popcount64(*(uint64_t *)(&lhs[0]) & ~(*(uint64_t *)(&rhs[0]))); } switch (last - last_aligned) { case 7: count += bitset_popcount32(lhs[6] & ~rhs[6]); /* FALLTHRU */ case 6: count += bitset_popcount32(lhs[5] & ~rhs[5]); /* FALLTHRU */ case 5: count += bitset_popcount32(lhs[4] & ~rhs[4]); /* FALLTHRU */ case 4: count += bitset_popcount32(lhs[3] & ~rhs[3]); /* FALLTHRU */ case 3: count += bitset_popcount32(lhs[2] & ~rhs[2]); /* FALLTHRU */ case 2: count += bitset_popcount32(lhs[1] & ~rhs[1]); /* FALLTHRU */ case 1: count += bitset_popcount32(lhs[0] & ~rhs[0]); } return count; } static inline size_t bitset_or_cardinality(const uint32_t *lhs, const uint32_t *rhs, size_t size) { const uint32_t *last = lhs + size; const uint32_t *last_aligned = lhs + ((size >> 3) << 3); size_t count = 0; for (; lhs != last_aligned; lhs += 8, rhs += 8) { count += bitset_popcount64(*(uint64_t *)(&lhs[6]) | *(uint64_t *)(&rhs[6])); count += bitset_popcount64(*(uint64_t *)(&lhs[4]) | *(uint64_t *)(&rhs[4])); count += bitset_popcount64(*(uint64_t *)(&lhs[2]) | *(uint64_t *)(&rhs[2])); count += bitset_popcount64(*(uint64_t *)(&lhs[0]) | *(uint64_t *)(&rhs[0])); } switch (last - last_aligned) { case 7: count += bitset_popcount32(lhs[6] | rhs[6]); /* FALLTHRU */ case 6: count += bitset_popcount32(lhs[5] | rhs[5]); /* FALLTHRU */ case 5: count += bitset_popcount32(lhs[4] | rhs[4]); /* FALLTHRU */ case 4: count += bitset_popcount32(lhs[3] | rhs[3]); /* FALLTHRU */ case 3: count += bitset_popcount32(lhs[2] | rhs[2]); /* FALLTHRU */ case 2: count += bitset_popcount32(lhs[1] | rhs[1]); /* FALLTHRU */ case 1: count += bitset_popcount32(lhs[0] | rhs[0]); } return count; } #else // !__ARM_NEON && !AILEGO_M64 static inline size_t bitset_cardinality(const uint32_t *lhs, size_t size) { const uint32_t *last = lhs + 
size; const uint32_t *last_aligned = lhs + ((size >> 2) << 2); size_t count = 0; for (; lhs != last_aligned; lhs += 4) { count += bitset_popcount32(lhs[3]); count += bitset_popcount32(lhs[2]); count += bitset_popcount32(lhs[1]); count += bitset_popcount32(lhs[0]); } switch (last - last_aligned) { case 3: count += bitset_popcount32(lhs[2]); /* FALLTHRU */ case 2: count += bitset_popcount32(lhs[1]); /* FALLTHRU */ case 1: count += bitset_popcount32(lhs[0]); } return count; } static inline size_t bitset_xor_cardinality(const uint32_t *lhs, const uint32_t *rhs, size_t size) { const uint32_t *last = lhs + size; const uint32_t *last_aligned = lhs + ((size >> 2) << 2); size_t count = 0; for (; lhs != last_aligned; lhs += 4, rhs += 4) { count += bitset_popcount32(lhs[3] ^ rhs[3]); count += bitset_popcount32(lhs[2] ^ rhs[2]); count += bitset_popcount32(lhs[1] ^ rhs[1]); count += bitset_popcount32(lhs[0] ^ rhs[0]); } switch (last - last_aligned) { case 3: count += bitset_popcount32(lhs[2] ^ rhs[2]); /* FALLTHRU */ case 2: count += bitset_popcount32(lhs[1] ^ rhs[1]); /* FALLTHRU */ case 1: count += bitset_popcount32(lhs[0] ^ rhs[0]); } return count; } static inline size_t bitset_and_cardinality(const uint32_t *lhs, const uint32_t *rhs, size_t size) { const uint32_t *last = lhs + size; const uint32_t *last_aligned = lhs + ((size >> 2) << 2); size_t count = 0; for (; lhs != last_aligned; lhs += 4, rhs += 4) { count += bitset_popcount32(lhs[3] & rhs[3]); count += bitset_popcount32(lhs[2] & rhs[2]); count += bitset_popcount32(lhs[1] & rhs[1]); count += bitset_popcount32(lhs[0] & rhs[0]); } switch (last - last_aligned) { case 3: count += bitset_popcount32(lhs[2] & rhs[2]); /* FALLTHRU */ case 2: count += bitset_popcount32(lhs[1] & rhs[1]); /* FALLTHRU */ case 1: count += bitset_popcount32(lhs[0] & rhs[0]); } return count; } static inline size_t bitset_andnot_cardinality(const uint32_t *lhs, const uint32_t *rhs, size_t size) { const uint32_t *last = lhs + size; const uint32_t 
*last_aligned = lhs + ((size >> 2) << 2); size_t count = 0; for (; lhs != last_aligned; lhs += 4, rhs += 4) { count += bitset_popcount32(lhs[3] & ~rhs[3]); count += bitset_popcount32(lhs[2] & ~rhs[2]); count += bitset_popcount32(lhs[1] & ~rhs[1]); count += bitset_popcount32(lhs[0] & ~rhs[0]); } switch (last - last_aligned) { case 3: count += bitset_popcount32(lhs[2] & ~rhs[2]); /* FALLTHRU */ case 2: count += bitset_popcount32(lhs[1] & ~rhs[1]); /* FALLTHRU */ case 1: count += bitset_popcount32(lhs[0] & ~rhs[0]); } return count; } static inline size_t bitset_or_cardinality(const uint32_t *lhs, const uint32_t *rhs, size_t size) { const uint32_t *last = lhs + size; const uint32_t *last_aligned = lhs + ((size >> 2) << 2); size_t count = 0; for (; lhs != last_aligned; lhs += 4, rhs += 4) { count += bitset_popcount32(lhs[3] | rhs[3]); count += bitset_popcount32(lhs[2] | rhs[2]); count += bitset_popcount32(lhs[1] | rhs[1]); count += bitset_popcount32(lhs[0] | rhs[0]); } switch (last - last_aligned) { case 3: count += bitset_popcount32(lhs[2] | rhs[2]); /* FALLTHRU */ case 2: count += bitset_popcount32(lhs[1] | rhs[1]); /* FALLTHRU */ case 1: count += bitset_popcount32(lhs[0] | rhs[0]); } return count; } #endif // __ARM_NEON && __aarch64__ namespace zvec { namespace ailego { void BitsetHelper::BitwiseAnd(uint32_t *lhs, const uint32_t *rhs, size_t size) { bitset_and(lhs, rhs, size); } void BitsetHelper::BitwiseAndnot(uint32_t *lhs, const uint32_t *rhs, size_t size) { bitset_andnot(lhs, rhs, size); } void BitsetHelper::BitwiseOr(uint32_t *lhs, const uint32_t *rhs, size_t size) { bitset_or(lhs, rhs, size); } void BitsetHelper::BitwiseXor(uint32_t *lhs, const uint32_t *rhs, size_t size) { bitset_xor(lhs, rhs, size); } void BitsetHelper::BitwiseNot(uint32_t *arr, size_t size) { bitset_not(arr, size); } bool BitsetHelper::TestAll(const uint32_t *arr, size_t size) { return bitset_test_all(arr, size); } bool BitsetHelper::TestAny(const uint32_t *arr, size_t size) { return 
bitset_test_any(arr, size); } bool BitsetHelper::TestNone(const uint32_t *arr, size_t size) { return bitset_test_none(arr, size); } size_t BitsetHelper::BitwiseAndCardinality(const uint32_t *lhs, const uint32_t *rhs, size_t size) { return bitset_and_cardinality(lhs, rhs, size); } size_t BitsetHelper::BitwiseOrCardinality(const uint32_t *lhs, const uint32_t *rhs, size_t size) { return bitset_or_cardinality(lhs, rhs, size); } size_t BitsetHelper::BitwiseAndnotCardinality(const uint32_t *lhs, const uint32_t *rhs, size_t size) { return bitset_andnot_cardinality(lhs, rhs, size); } size_t BitsetHelper::BitwiseXorCardinality(const uint32_t *lhs, const uint32_t *rhs, size_t size) { return bitset_xor_cardinality(lhs, rhs, size); } size_t BitsetHelper::Cardinality(const uint32_t *arr, size_t size) { return bitset_cardinality(arr, size); } bool BitsetHelper::test_all(void) const { return bitset_test_all(array_, size_); } bool BitsetHelper::test_any(void) const { return bitset_test_any(array_, size_); } bool BitsetHelper::test_none(void) const { return bitset_test_none(array_, size_); } size_t BitsetHelper::cardinality(void) const { return bitset_cardinality(array_, size_); } } // namespace ailego } // namespace zvec ================================================ FILE: src/ailego/utility/bitset_helper.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#pragma once #include #include namespace zvec { namespace ailego { /*! Bitset Helper */ class BitsetHelper { public: //! Constructor BitsetHelper(void) {} //! Constructor BitsetHelper(void *buf, size_t len) : array_(reinterpret_cast(buf)), size_(len / sizeof(uint32_t)) {} //! Mount a buffer as bitset void mount(void *buf, size_t len) { array_ = reinterpret_cast(buf); size_ = len / sizeof(uint32_t); } //! Umount the buffer void umount(void) { array_ = nullptr; size_ = 0u; } // !Clear the bitset void clear(void) { memset(array_, 0, sizeof(uint32_t) * size_); } //! Test a bit in bitset bool test(size_t num) const { ailego_assert_with((size_ << 5) > num, "overflow argument"); return ((array_[num >> 5] & (1u << (num & 0x1f))) != 0); } //! Set a bit in bitset void set(size_t num) { ailego_assert_with((size_ << 5) > num, "overflow argument"); uint32_t mask = (1u << (num & 0x1f)); array_[num >> 5] |= mask; } //! Reset a bit in bitset void reset(size_t num) { ailego_assert_with((size_ << 5) > num, "overflow argument"); uint32_t mask = (1u << (num & 0x1f)); array_[num >> 5] &= ~mask; } //! Toggle a bit in bitset void flip(size_t num) { ailego_assert_with((size_ << 5) > num, "overflow argument"); uint32_t mask = (1u << (num & 0x1f)); array_[num >> 5] ^= mask; } //! Extract the bitset to an array void extract(size_t base, std::vector *out) const { const uint32_t *iter = array_; const uint32_t *last = array_ + size_; for (; iter != last; ++iter) { uint32_t w = *iter; while (w != 0) { uint32_t c = ailego_ctz32(w); w &= ~(1u << c); out->push_back(base + c); } base += 32u; } } //! Extract the bitset to an array void extract(std::vector *out) const { this->extract(0, out); } //! Check if all bits are set to true bool test_all(void) const; //! Check if any bits are set to true bool test_any(void) const; //! Check if none of the bits are set to true bool test_none(void) const; //! Compute the cardinality of a bitset size_t cardinality(void) const; //! 
Calculate the size of buffer if it contains N bits static size_t BufferSize(size_t N) { return (((N + 0x1f) >> 5) << 2); } //! Calculate the count of bits can be contained static size_t BitsCount(size_t len) { return ((len >> 2) << 2); } //! Check if all bits are set to true static bool TestAll(const uint32_t *arr, size_t size); //! Check if cube bits are set to true static bool TestAny(const uint32_t *arr, size_t size); //! Check if none of the bits are set to true static bool TestNone(const uint32_t *arr, size_t size); //! Compute the AND cardinality between two bitsets static size_t BitwiseAndCardinality(const uint32_t *lhs, const uint32_t *rhs, size_t size); //! Compute the OR cardinality between two bitsets static size_t BitwiseOrCardinality(const uint32_t *lhs, const uint32_t *rhs, size_t size); //! Compute the ANDNOT cardinality between two bitsets static size_t BitwiseAndnotCardinality(const uint32_t *lhs, const uint32_t *rhs, size_t size); //! Compute the XOR cardinality between two bitsets static size_t BitwiseXorCardinality(const uint32_t *lhs, const uint32_t *rhs, size_t size); //! Compute the cardinality of a bitset static size_t Cardinality(const uint32_t *arr, size_t size); //! Perform binary AND static void BitwiseAnd(uint32_t *lhs, const uint32_t *rhs, size_t size); //! Perform binary AND_NOT static void BitwiseAndnot(uint32_t *lhs, const uint32_t *rhs, size_t size); //! Perform binary OR static void BitwiseOr(uint32_t *lhs, const uint32_t *rhs, size_t size); //! Perform binary XOR static void BitwiseXor(uint32_t *lhs, const uint32_t *rhs, size_t size); //! 
Perform binary NOT static void BitwiseNot(uint32_t *arr, size_t size); private: uint32_t *array_{nullptr}; size_t size_{0u}; }; } // namespace ailego } // namespace zvec ================================================ FILE: src/ailego/utility/concurrency_helper.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "concurrency_helper.h" #include #include #include #include #include namespace zvec { namespace ailego { // Refer to: // https://stackoverflow.com/questions/65551215/get-docker-cpu-memory-limit-inside-container ConcurrencyHelper::ConcurrencyHelper() { std::string cfs_quota_us = "/sys/fs/cgroup/cpu/cpu.cfs_quota_us"; std::string cfs_period_us = "/sys/fs/cgroup/cpu/cpu.cfs_period_us"; concurrency_ = std::thread::hardware_concurrency(); if (FileHelper::IsExist(cfs_quota_us.c_str()) && FileHelper::IsExist(cfs_period_us.c_str())) { std::ifstream quota_ifs; std::string quota_str{""}; uint32_t quota_val = 0; quota_ifs.open(cfs_quota_us, std::ios::in); if (quota_ifs.is_open()) { quota_ifs >> quota_str; if (quota_str != "-1") { StringHelper::ToUint32(quota_str, "a_val); } quota_ifs.close(); } if (quota_val > 0) { std::ifstream period_ifs; std::string period_str{""}; uint32_t period_val = 0; period_ifs.open(cfs_period_us, std::ios::in); if (period_ifs.is_open()) { period_ifs >> period_str; StringHelper::ToUint32(period_str, &period_val); period_ifs.close(); } if 
(period_val > 0) { concurrency_ = (quota_val + period_val - 1) / period_val; } } } } uint32_t ConcurrencyHelper::container_aware_concurrency() { static ConcurrencyHelper concurrency_helper; return concurrency_helper.concurrency_; } } // namespace ailego } // namespace zvec ================================================ FILE: src/ailego/utility/concurrency_helper.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include namespace zvec { namespace ailego { class ConcurrencyHelper { public: ConcurrencyHelper(); //! get hardware concurrency from either vm or container static uint32_t container_aware_concurrency(); private: uint32_t concurrency_{0}; }; } // namespace ailego } // namespace zvec ================================================ FILE: src/ailego/utility/dl_helper.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and // limitations under the License. #include "dl_helper.h" #if !defined(_WIN64) && !defined(_WIN32) #include #else #include #endif namespace zvec { namespace ailego { #if !defined(_WIN64) && !defined(_WIN32) void *DLHelper::Load(const char *path, std::string *err) { void *handle = dlopen(path, RTLD_NOW); if (!handle && err) { *err = dlerror(); } return handle; } void DLHelper::Unload(void *handle) { ailego_return_if_false(handle); dlclose(handle); } void *DLHelper::Symbol(void *handle, const char *symbol) { ailego_null_if_false(handle && symbol); return dlsym(handle, symbol); } #else void *DLHelper::Load(const char *path, std::string *err) { HMODULE handle = LoadLibraryA(path); if (!handle && err) { DWORD error_code = GetLastError(); LPSTR error_msg = nullptr; DWORD len = FormatMessageA( FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, nullptr, error_code, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&error_msg, 0, nullptr); err->assign(error_msg, len); LocalFree(error_msg); } return handle; } void DLHelper::Unload(void *handle) { ailego_return_if_false(handle); FreeLibrary((HMODULE)handle); } void *DLHelper::Symbol(void *handle, const char *symbol) { ailego_null_if_false(handle && symbol); return GetProcAddress((HMODULE)handle, symbol); } #endif // !_WIN64 && !_WIN32 } // namespace ailego } // namespace zvec ================================================ FILE: src/ailego/utility/dl_helper.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include namespace zvec { namespace ailego { /*! Dynamic Library Helper */ struct DLHelper { //! Load library from path static void *Load(const char *path, std::string *err); //! Unload a library static void Unload(void *handle); //! Retrieve a symbol from a library handle static void *Symbol(void *handle, const char *symbol); //! Load library from path static void *Load(const std::string &path, std::string *err) { return DLHelper::Load(path.c_str(), err); } //! Retrieve a symbol from a library handle static void *Symbol(void *handle, const std::string &symbol) { return DLHelper::Symbol(handle, symbol.c_str()); } }; } // namespace ailego } // namespace zvec ================================================ FILE: src/ailego/utility/file_helper.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include #if defined(_WIN32) || defined(_WIN64) #include #else #if defined(__APPLE__) || defined(__MACH__) #include #endif #include #include #include #include #include #include #endif namespace zvec { namespace ailego { bool FileHelper::GetSelfPath(std::string *path) { #if defined(_WIN32) || defined(_WIN64) char buf[MAX_PATH]; DWORD len = GetModuleFileNameA(NULL, buf, MAX_PATH); #elif defined(__APPLE__) || defined(__MACH__) char buf[PATH_MAX]; size_t len = 0; char dirty_buf[PATH_MAX]; uint32_t size = sizeof(dirty_buf); if (_NSGetExecutablePath(dirty_buf, &size) == 0) { realpath(dirty_buf, buf); len = strlen(buf); } #elif defined(__FreeBSD__) char buf[PATH_MAX]; size_t len = PATH_MAX; int mib[4] = {CTL_KERN, KERN_PROC, KERN_PROC_PATHNAME, -1}; if (sysctl(mib, 4, &buf, &len, NULL, 0) != 0) { len = 0; } #else char buf[PATH_MAX]; ssize_t len = readlink("/proc/self/exe", buf, PATH_MAX); #endif if (len <= 0) { return false; } path->assign(buf, len); return true; } bool FileHelper::GetFilePath(NativeHandle handle, std::string *path) { #if defined(_WIN32) || defined(_WIN64) char buf[MAX_PATH]; DWORD len = GetFinalPathNameByHandleA(handle, buf, MAX_PATH, FILE_NAME_OPENED); #elif defined(__linux) || defined(__linux__) char buf[PATH_MAX]; char src[32]; snprintf(src, sizeof(src), "/proc/self/fd/%d", handle); ssize_t len = readlink(src, buf, PATH_MAX); #else char buf[PATH_MAX]; size_t len = 0; if (fcntl(handle, F_GETPATH, buf) != -1) { len = strlen(buf); } #endif if (len <= 0) { return false; } path->assign(buf, len); return true; } #if !defined(_WIN32) && !defined(_WIN64) static inline char *JoinFilePath(const char *prefix, const char *suffix) { size_t prefix_len = strlen(prefix); size_t suffix_len = strlen(suffix); char *path = (char *)malloc(prefix_len + suffix_len + 2); if (path) { memcpy(path, prefix, prefix_len); memcpy(path + prefix_len + 1, suffix, suffix_len); path[prefix_len] = '/'; path[prefix_len + suffix_len + 1] = '\0'; } return path; } bool 
FileHelper::GetWorkingDirectory(std::string *path) { char buf[PATH_MAX]; if (!getcwd(buf, PATH_MAX)) { return false; } path->assign(buf); return !path->empty(); } bool FileHelper::GetFileSize(const char *path, size_t *psz) { struct stat buf; if (stat(path, &buf) != 0) { return false; } *psz = buf.st_size; return true; } bool FileHelper::DeleteFile(const char *path) { // Delete a file by the path return (unlink(path) == 0); } bool FileHelper::RenameFile(const char *oldpath, const char *newpath) { return (rename(oldpath, newpath) == 0); } bool FileHelper::MakePath(const char *path) { char pathbuf[PATH_MAX]; char *sp, *pp; strncpy(pathbuf, path, sizeof(pathbuf) - 1); pathbuf[PATH_MAX - 1] = '\0'; pp = pathbuf; while ((sp = strchr(pp, '/')) != nullptr) { // Neither root nor double slash in path if (sp != pp) { *sp = '\0'; if (mkdir(pathbuf, 0755) == -1 && errno != EEXIST) { return false; } *sp = '/'; } pp = sp + 1; } return !(*pp != '\0' && mkdir(pathbuf, 0755) == -1 && errno != EEXIST); } bool FileHelper::RemoveDirectory(const char *path) { DIR *dir = opendir(path); if (!dir) { return false; } struct dirent *dent; while ((dent = readdir(dir)) != nullptr) { if (!strcmp(dent->d_name, ".") || !strcmp(dent->d_name, "..")) { continue; } char *fullpath = JoinFilePath(path, dent->d_name); if (!fullpath) { continue; } if (FileHelper::IsDirectory(fullpath)) { FileHelper::RemoveDirectory(fullpath); } else { FileHelper::DeleteFile(fullpath); } free(fullpath); } closedir(dir); return (rmdir(path) == 0); } bool FileHelper::IsExist(const char *path) { return (access(path, F_OK) == 0); } bool FileHelper::IsRegular(const char *path) { struct stat buf; if (stat(path, &buf) != 0) { return false; } return ((buf.st_mode & S_IFREG) != 0); } bool FileHelper::IsDirectory(const char *path) { struct stat buf; if (stat(path, &buf) != 0) { return false; } return ((buf.st_mode & S_IFDIR) != 0); } bool FileHelper::IsSymbolicLink(const char *path) { struct stat buf; if (stat(path, &buf) != 0) { 
return false; } return ((buf.st_mode & S_IFLNK) != 0); } bool FileHelper::IsSame(const char *path1, const char *path2) { char real_path1[PATH_MAX]; char real_path2[PATH_MAX]; if (!realpath(path1, real_path1)) { return false; } if (!realpath(path2, real_path2)) { return false; } return (!strcmp(real_path1, real_path2)); } #else #undef RemoveDirectory #undef DeleteFile #undef GetFileSize static inline char *JoinFilePath(const char *prefix, const char *suffix) { size_t prefix_len = strlen(prefix); size_t suffix_len = strlen(suffix); char *path = (char *)malloc(prefix_len + suffix_len + 2); if (path) { memcpy(path, prefix, prefix_len); memcpy(path + prefix_len + 1, suffix, suffix_len); path[prefix_len] = '\\'; path[prefix_len + suffix_len + 1] = '\0'; } return path; } bool FileHelper::GetWorkingDirectory(std::string *path) { char buf[MAX_PATH]; DWORD len = GetCurrentDirectoryA(MAX_PATH, buf); if (len <= 0) { return false; } path->assign(buf, len); return true; } bool FileHelper::GetFileSize(const char *path, size_t *psz) { HANDLE handle = CreateFileA(path, GENERIC_READ, FILE_SHARE_READ, nullptr, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, nullptr); LARGE_INTEGER file_size; if (!GetFileSizeEx(handle, &file_size)) { return false; } *psz = (size_t)file_size.QuadPart; return true; } bool FileHelper::DeleteFile(const char *path) { // Delete a file by the path return (DeleteFileA(path)); } bool FileHelper::RenameFile(const char *oldpath, const char *newpath) { return (MoveFileA(oldpath, newpath)); } bool FileHelper::MakePath(const char *path) { char pathbuf[MAX_PATH]; char *sp, *pp; strncpy(pathbuf, path, sizeof(pathbuf) - 1); pathbuf[MAX_PATH - 1] = '\0'; pp = pathbuf; while ((sp = strpbrk(pp, "/\\")) != nullptr) { // Neither root nor double slash in path if (sp != pp) { *sp = '\0'; if (!CreateDirectoryA(pathbuf, nullptr) && GetLastError() != ERROR_ALREADY_EXISTS) { return false; } *sp = '\\'; } pp = sp + 1; } return !(*pp != '\0' && !CreateDirectoryA(pathbuf, nullptr) && 
GetLastError() != ERROR_ALREADY_EXISTS); } bool FileHelper::RemoveDirectory(const char *path) { char *pathbuf = JoinFilePath(path, "*.*"); ailego_false_if_false(pathbuf); WIN32_FIND_DATAA file_info; HANDLE file = FindFirstFileA(pathbuf, &file_info); ailego_do_if_false(file != INVALID_HANDLE_VALUE) { free(pathbuf); FindClose(file); return false; } do { if (!strcmp(file_info.cFileName, ".") || !strcmp(file_info.cFileName, "..")) { continue; } char *fullpath = JoinFilePath(path, file_info.cFileName); if (!fullpath) { continue; } if (file_info.dwFileAttributes & FILE_ATTRIBUTE_DIRECTORY) { FileHelper::RemoveDirectory(fullpath); } else { FileHelper::DeleteFile(fullpath); } free(fullpath); } while (FindNextFileA(file, &file_info)); free(pathbuf); FindClose(file); return (!!RemoveDirectoryA(path)); } bool FileHelper::IsExist(const char *path) { DWORD attr = GetFileAttributesA(path); return (attr != INVALID_FILE_ATTRIBUTES); } bool FileHelper::IsRegular(const char *path) { DWORD attr = GetFileAttributesA(path); return (attr != INVALID_FILE_ATTRIBUTES && !(attr & FILE_ATTRIBUTE_DIRECTORY)); } bool FileHelper::IsDirectory(const char *path) { DWORD attr = GetFileAttributesA(path); return (attr != INVALID_FILE_ATTRIBUTES && (attr & FILE_ATTRIBUTE_DIRECTORY)); } bool FileHelper::IsSymbolicLink(const char *path) { DWORD attr = GetFileAttributesA(path); return (attr != INVALID_FILE_ATTRIBUTES && (attr & FILE_ATTRIBUTE_REPARSE_POINT)); } bool FileHelper::IsSame(const char *path1, const char *path2) { char real_path1[MAX_PATH]; char real_path2[MAX_PATH]; char **part_path1 = nullptr; char **part_path2 = nullptr; DWORD path1_size = GetFullPathNameA(path1, sizeof(real_path1), real_path1, part_path1); DWORD path2_size = GetFullPathNameA(path2, sizeof(real_path2), real_path2, part_path2); if ((part_path1 && *part_path1 != 0) || (part_path2 && *part_path2 != 0) || (path1_size != path2_size)) { return false; } return (!strcmp(real_path1, real_path2)); } #endif // !_WIN32 && !_WIN64 bool 
FileHelper::RemovePath(const char *path) {
  // Directories go to the recursive RemoveDirectory(); anything else is
  // deleted as a single file.
  // NOTE(review): on POSIX, IsDirectory() uses stat(), which follows symbolic
  // links. A symlink to a directory is therefore routed to RemoveDirectory(),
  // which opendir()s through the link (emptying the target) and then fails
  // to rmdir() the link itself — confirm whether that is intended.
  if (FileHelper::IsDirectory(path)) {
    return FileHelper::RemoveDirectory(path);
  }
  return FileHelper::DeleteFile(path);
}

}  // namespace ailego
}  // namespace zvec
================================================
FILE: src/ailego/utility/float_helper.cc
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include
#include
#include

// #if defined(__F16C__) && defined(__AVX__)
// #define float16(x) _cvtss_sh((x), _MM_FROUND_NO_EXC)
// #define float32(x) _cvtsh_ss(x)
// #endif  // __F16C__ && __AVX__

#if defined(__aarch64__)
// On aarch64 the compiler's native __fp16 type does the conversions.

// Interpret the 16 bits of `val` as an __fp16 and widen the value to float.
// NOTE(review): reading a uint16_t through an __fp16 pointer is type punning
// that strict aliasing does not sanction; std::memcpy (or std::bit_cast in
// C++20) would be the portable alternative.
static inline float float32(uint16_t val) {
  __fp16 *p = reinterpret_cast<__fp16 *>(&val);
  return *p;
}

// Narrow `val` to __fp16 and return its raw 16-bit pattern.
// NOTE(review): the cast's target type is not visible in this copy
// (presumably `uint16_t *`); the same aliasing caveat as float32() applies.
static inline uint16_t float16(float val) {
  __fp16 f = static_cast<__fp16>(val);
  uint16_t *fp = reinterpret_cast(&f);
  return *fp;
}

// Convert `size` half-precision bit patterns from `arr` into floats in `out`.
static inline void convert_fp16_to_fp32(const uint16_t *arr, size_t size,
                                        float *out) {
  for (size_t i = 0; i != size; ++i) {
    out[i] = float32(arr[i]);
  }
}

// Same as above, additionally dividing every converted value by `norm`.
static inline void convert_fp16_to_fp32(const uint16_t *arr, size_t size,
                                        float norm, float *out) {
  for (size_t i = 0; i != size; ++i) {
    out[i] = float32(arr[i]) / norm;
  }
}

// Convert `size` floats from `arr` into half-precision bit patterns in `out`.
static inline void convert_fp32_to_fp16(const float *arr, size_t size,
                                        uint16_t *out) {
  for (size_t i = 0; i != size; ++i) {
    out[i] = float16(arr[i]);
  }
}

// Same as above, dividing every input by `norm` before narrowing it.
static inline void convert_fp32_to_fp16(const float *arr, size_t size,
                                        float norm, uint16_t *out) {
  for (size_t i = 0;
i != size; ++i) { out[i] = float16(arr[i] / norm); } } #else // Refer: https://github.com/Maratyszcza/FP16/blob/master/third-party/half.hpp static inline float float32(uint16_t val) { static const uint32_t mantissa_table[2048] = { 0x00000000, 0x33800000, 0x34000000, 0x34400000, 0x34800000, 0x34A00000, 0x34C00000, 0x34E00000, 0x35000000, 0x35100000, 0x35200000, 0x35300000, 0x35400000, 0x35500000, 0x35600000, 0x35700000, 0x35800000, 0x35880000, 0x35900000, 0x35980000, 0x35A00000, 0x35A80000, 0x35B00000, 0x35B80000, 0x35C00000, 0x35C80000, 0x35D00000, 0x35D80000, 0x35E00000, 0x35E80000, 0x35F00000, 0x35F80000, 0x36000000, 0x36040000, 0x36080000, 0x360C0000, 0x36100000, 0x36140000, 0x36180000, 0x361C0000, 0x36200000, 0x36240000, 0x36280000, 0x362C0000, 0x36300000, 0x36340000, 0x36380000, 0x363C0000, 0x36400000, 0x36440000, 0x36480000, 0x364C0000, 0x36500000, 0x36540000, 0x36580000, 0x365C0000, 0x36600000, 0x36640000, 0x36680000, 0x366C0000, 0x36700000, 0x36740000, 0x36780000, 0x367C0000, 0x36800000, 0x36820000, 0x36840000, 0x36860000, 0x36880000, 0x368A0000, 0x368C0000, 0x368E0000, 0x36900000, 0x36920000, 0x36940000, 0x36960000, 0x36980000, 0x369A0000, 0x369C0000, 0x369E0000, 0x36A00000, 0x36A20000, 0x36A40000, 0x36A60000, 0x36A80000, 0x36AA0000, 0x36AC0000, 0x36AE0000, 0x36B00000, 0x36B20000, 0x36B40000, 0x36B60000, 0x36B80000, 0x36BA0000, 0x36BC0000, 0x36BE0000, 0x36C00000, 0x36C20000, 0x36C40000, 0x36C60000, 0x36C80000, 0x36CA0000, 0x36CC0000, 0x36CE0000, 0x36D00000, 0x36D20000, 0x36D40000, 0x36D60000, 0x36D80000, 0x36DA0000, 0x36DC0000, 0x36DE0000, 0x36E00000, 0x36E20000, 0x36E40000, 0x36E60000, 0x36E80000, 0x36EA0000, 0x36EC0000, 0x36EE0000, 0x36F00000, 0x36F20000, 0x36F40000, 0x36F60000, 0x36F80000, 0x36FA0000, 0x36FC0000, 0x36FE0000, 0x37000000, 0x37010000, 0x37020000, 0x37030000, 0x37040000, 0x37050000, 0x37060000, 0x37070000, 0x37080000, 0x37090000, 0x370A0000, 0x370B0000, 0x370C0000, 0x370D0000, 0x370E0000, 0x370F0000, 0x37100000, 0x37110000, 0x37120000, 
0x37130000, 0x37140000, 0x37150000, 0x37160000, 0x37170000, 0x37180000, 0x37190000, 0x371A0000, 0x371B0000, 0x371C0000, 0x371D0000, 0x371E0000, 0x371F0000, 0x37200000, 0x37210000, 0x37220000, 0x37230000, 0x37240000, 0x37250000, 0x37260000, 0x37270000, 0x37280000, 0x37290000, 0x372A0000, 0x372B0000, 0x372C0000, 0x372D0000, 0x372E0000, 0x372F0000, 0x37300000, 0x37310000, 0x37320000, 0x37330000, 0x37340000, 0x37350000, 0x37360000, 0x37370000, 0x37380000, 0x37390000, 0x373A0000, 0x373B0000, 0x373C0000, 0x373D0000, 0x373E0000, 0x373F0000, 0x37400000, 0x37410000, 0x37420000, 0x37430000, 0x37440000, 0x37450000, 0x37460000, 0x37470000, 0x37480000, 0x37490000, 0x374A0000, 0x374B0000, 0x374C0000, 0x374D0000, 0x374E0000, 0x374F0000, 0x37500000, 0x37510000, 0x37520000, 0x37530000, 0x37540000, 0x37550000, 0x37560000, 0x37570000, 0x37580000, 0x37590000, 0x375A0000, 0x375B0000, 0x375C0000, 0x375D0000, 0x375E0000, 0x375F0000, 0x37600000, 0x37610000, 0x37620000, 0x37630000, 0x37640000, 0x37650000, 0x37660000, 0x37670000, 0x37680000, 0x37690000, 0x376A0000, 0x376B0000, 0x376C0000, 0x376D0000, 0x376E0000, 0x376F0000, 0x37700000, 0x37710000, 0x37720000, 0x37730000, 0x37740000, 0x37750000, 0x37760000, 0x37770000, 0x37780000, 0x37790000, 0x377A0000, 0x377B0000, 0x377C0000, 0x377D0000, 0x377E0000, 0x377F0000, 0x37800000, 0x37808000, 0x37810000, 0x37818000, 0x37820000, 0x37828000, 0x37830000, 0x37838000, 0x37840000, 0x37848000, 0x37850000, 0x37858000, 0x37860000, 0x37868000, 0x37870000, 0x37878000, 0x37880000, 0x37888000, 0x37890000, 0x37898000, 0x378A0000, 0x378A8000, 0x378B0000, 0x378B8000, 0x378C0000, 0x378C8000, 0x378D0000, 0x378D8000, 0x378E0000, 0x378E8000, 0x378F0000, 0x378F8000, 0x37900000, 0x37908000, 0x37910000, 0x37918000, 0x37920000, 0x37928000, 0x37930000, 0x37938000, 0x37940000, 0x37948000, 0x37950000, 0x37958000, 0x37960000, 0x37968000, 0x37970000, 0x37978000, 0x37980000, 0x37988000, 0x37990000, 0x37998000, 0x379A0000, 0x379A8000, 0x379B0000, 0x379B8000, 0x379C0000, 
0x379C8000, 0x379D0000, 0x379D8000, 0x379E0000, 0x379E8000, 0x379F0000, 0x379F8000, 0x37A00000, 0x37A08000, 0x37A10000, 0x37A18000, 0x37A20000, 0x37A28000, 0x37A30000, 0x37A38000, 0x37A40000, 0x37A48000, 0x37A50000, 0x37A58000, 0x37A60000, 0x37A68000, 0x37A70000, 0x37A78000, 0x37A80000, 0x37A88000, 0x37A90000, 0x37A98000, 0x37AA0000, 0x37AA8000, 0x37AB0000, 0x37AB8000, 0x37AC0000, 0x37AC8000, 0x37AD0000, 0x37AD8000, 0x37AE0000, 0x37AE8000, 0x37AF0000, 0x37AF8000, 0x37B00000, 0x37B08000, 0x37B10000, 0x37B18000, 0x37B20000, 0x37B28000, 0x37B30000, 0x37B38000, 0x37B40000, 0x37B48000, 0x37B50000, 0x37B58000, 0x37B60000, 0x37B68000, 0x37B70000, 0x37B78000, 0x37B80000, 0x37B88000, 0x37B90000, 0x37B98000, 0x37BA0000, 0x37BA8000, 0x37BB0000, 0x37BB8000, 0x37BC0000, 0x37BC8000, 0x37BD0000, 0x37BD8000, 0x37BE0000, 0x37BE8000, 0x37BF0000, 0x37BF8000, 0x37C00000, 0x37C08000, 0x37C10000, 0x37C18000, 0x37C20000, 0x37C28000, 0x37C30000, 0x37C38000, 0x37C40000, 0x37C48000, 0x37C50000, 0x37C58000, 0x37C60000, 0x37C68000, 0x37C70000, 0x37C78000, 0x37C80000, 0x37C88000, 0x37C90000, 0x37C98000, 0x37CA0000, 0x37CA8000, 0x37CB0000, 0x37CB8000, 0x37CC0000, 0x37CC8000, 0x37CD0000, 0x37CD8000, 0x37CE0000, 0x37CE8000, 0x37CF0000, 0x37CF8000, 0x37D00000, 0x37D08000, 0x37D10000, 0x37D18000, 0x37D20000, 0x37D28000, 0x37D30000, 0x37D38000, 0x37D40000, 0x37D48000, 0x37D50000, 0x37D58000, 0x37D60000, 0x37D68000, 0x37D70000, 0x37D78000, 0x37D80000, 0x37D88000, 0x37D90000, 0x37D98000, 0x37DA0000, 0x37DA8000, 0x37DB0000, 0x37DB8000, 0x37DC0000, 0x37DC8000, 0x37DD0000, 0x37DD8000, 0x37DE0000, 0x37DE8000, 0x37DF0000, 0x37DF8000, 0x37E00000, 0x37E08000, 0x37E10000, 0x37E18000, 0x37E20000, 0x37E28000, 0x37E30000, 0x37E38000, 0x37E40000, 0x37E48000, 0x37E50000, 0x37E58000, 0x37E60000, 0x37E68000, 0x37E70000, 0x37E78000, 0x37E80000, 0x37E88000, 0x37E90000, 0x37E98000, 0x37EA0000, 0x37EA8000, 0x37EB0000, 0x37EB8000, 0x37EC0000, 0x37EC8000, 0x37ED0000, 0x37ED8000, 0x37EE0000, 0x37EE8000, 0x37EF0000, 
0x37EF8000, 0x37F00000, 0x37F08000, 0x37F10000, 0x37F18000, 0x37F20000, 0x37F28000, 0x37F30000, 0x37F38000, 0x37F40000, 0x37F48000, 0x37F50000, 0x37F58000, 0x37F60000, 0x37F68000, 0x37F70000, 0x37F78000, 0x37F80000, 0x37F88000, 0x37F90000, 0x37F98000, 0x37FA0000, 0x37FA8000, 0x37FB0000, 0x37FB8000, 0x37FC0000, 0x37FC8000, 0x37FD0000, 0x37FD8000, 0x37FE0000, 0x37FE8000, 0x37FF0000, 0x37FF8000, 0x38000000, 0x38004000, 0x38008000, 0x3800C000, 0x38010000, 0x38014000, 0x38018000, 0x3801C000, 0x38020000, 0x38024000, 0x38028000, 0x3802C000, 0x38030000, 0x38034000, 0x38038000, 0x3803C000, 0x38040000, 0x38044000, 0x38048000, 0x3804C000, 0x38050000, 0x38054000, 0x38058000, 0x3805C000, 0x38060000, 0x38064000, 0x38068000, 0x3806C000, 0x38070000, 0x38074000, 0x38078000, 0x3807C000, 0x38080000, 0x38084000, 0x38088000, 0x3808C000, 0x38090000, 0x38094000, 0x38098000, 0x3809C000, 0x380A0000, 0x380A4000, 0x380A8000, 0x380AC000, 0x380B0000, 0x380B4000, 0x380B8000, 0x380BC000, 0x380C0000, 0x380C4000, 0x380C8000, 0x380CC000, 0x380D0000, 0x380D4000, 0x380D8000, 0x380DC000, 0x380E0000, 0x380E4000, 0x380E8000, 0x380EC000, 0x380F0000, 0x380F4000, 0x380F8000, 0x380FC000, 0x38100000, 0x38104000, 0x38108000, 0x3810C000, 0x38110000, 0x38114000, 0x38118000, 0x3811C000, 0x38120000, 0x38124000, 0x38128000, 0x3812C000, 0x38130000, 0x38134000, 0x38138000, 0x3813C000, 0x38140000, 0x38144000, 0x38148000, 0x3814C000, 0x38150000, 0x38154000, 0x38158000, 0x3815C000, 0x38160000, 0x38164000, 0x38168000, 0x3816C000, 0x38170000, 0x38174000, 0x38178000, 0x3817C000, 0x38180000, 0x38184000, 0x38188000, 0x3818C000, 0x38190000, 0x38194000, 0x38198000, 0x3819C000, 0x381A0000, 0x381A4000, 0x381A8000, 0x381AC000, 0x381B0000, 0x381B4000, 0x381B8000, 0x381BC000, 0x381C0000, 0x381C4000, 0x381C8000, 0x381CC000, 0x381D0000, 0x381D4000, 0x381D8000, 0x381DC000, 0x381E0000, 0x381E4000, 0x381E8000, 0x381EC000, 0x381F0000, 0x381F4000, 0x381F8000, 0x381FC000, 0x38200000, 0x38204000, 0x38208000, 0x3820C000, 0x38210000, 
0x38214000, 0x38218000, 0x3821C000, 0x38220000, 0x38224000, 0x38228000, 0x3822C000, 0x38230000, 0x38234000, 0x38238000, 0x3823C000, 0x38240000, 0x38244000, 0x38248000, 0x3824C000, 0x38250000, 0x38254000, 0x38258000, 0x3825C000, 0x38260000, 0x38264000, 0x38268000, 0x3826C000, 0x38270000, 0x38274000, 0x38278000, 0x3827C000, 0x38280000, 0x38284000, 0x38288000, 0x3828C000, 0x38290000, 0x38294000, 0x38298000, 0x3829C000, 0x382A0000, 0x382A4000, 0x382A8000, 0x382AC000, 0x382B0000, 0x382B4000, 0x382B8000, 0x382BC000, 0x382C0000, 0x382C4000, 0x382C8000, 0x382CC000, 0x382D0000, 0x382D4000, 0x382D8000, 0x382DC000, 0x382E0000, 0x382E4000, 0x382E8000, 0x382EC000, 0x382F0000, 0x382F4000, 0x382F8000, 0x382FC000, 0x38300000, 0x38304000, 0x38308000, 0x3830C000, 0x38310000, 0x38314000, 0x38318000, 0x3831C000, 0x38320000, 0x38324000, 0x38328000, 0x3832C000, 0x38330000, 0x38334000, 0x38338000, 0x3833C000, 0x38340000, 0x38344000, 0x38348000, 0x3834C000, 0x38350000, 0x38354000, 0x38358000, 0x3835C000, 0x38360000, 0x38364000, 0x38368000, 0x3836C000, 0x38370000, 0x38374000, 0x38378000, 0x3837C000, 0x38380000, 0x38384000, 0x38388000, 0x3838C000, 0x38390000, 0x38394000, 0x38398000, 0x3839C000, 0x383A0000, 0x383A4000, 0x383A8000, 0x383AC000, 0x383B0000, 0x383B4000, 0x383B8000, 0x383BC000, 0x383C0000, 0x383C4000, 0x383C8000, 0x383CC000, 0x383D0000, 0x383D4000, 0x383D8000, 0x383DC000, 0x383E0000, 0x383E4000, 0x383E8000, 0x383EC000, 0x383F0000, 0x383F4000, 0x383F8000, 0x383FC000, 0x38400000, 0x38404000, 0x38408000, 0x3840C000, 0x38410000, 0x38414000, 0x38418000, 0x3841C000, 0x38420000, 0x38424000, 0x38428000, 0x3842C000, 0x38430000, 0x38434000, 0x38438000, 0x3843C000, 0x38440000, 0x38444000, 0x38448000, 0x3844C000, 0x38450000, 0x38454000, 0x38458000, 0x3845C000, 0x38460000, 0x38464000, 0x38468000, 0x3846C000, 0x38470000, 0x38474000, 0x38478000, 0x3847C000, 0x38480000, 0x38484000, 0x38488000, 0x3848C000, 0x38490000, 0x38494000, 0x38498000, 0x3849C000, 0x384A0000, 0x384A4000, 0x384A8000, 
0x384AC000, 0x384B0000, 0x384B4000, 0x384B8000, 0x384BC000, 0x384C0000, 0x384C4000, 0x384C8000, 0x384CC000, 0x384D0000, 0x384D4000, 0x384D8000, 0x384DC000, 0x384E0000, 0x384E4000, 0x384E8000, 0x384EC000, 0x384F0000, 0x384F4000, 0x384F8000, 0x384FC000, 0x38500000, 0x38504000, 0x38508000, 0x3850C000, 0x38510000, 0x38514000, 0x38518000, 0x3851C000, 0x38520000, 0x38524000, 0x38528000, 0x3852C000, 0x38530000, 0x38534000, 0x38538000, 0x3853C000, 0x38540000, 0x38544000, 0x38548000, 0x3854C000, 0x38550000, 0x38554000, 0x38558000, 0x3855C000, 0x38560000, 0x38564000, 0x38568000, 0x3856C000, 0x38570000, 0x38574000, 0x38578000, 0x3857C000, 0x38580000, 0x38584000, 0x38588000, 0x3858C000, 0x38590000, 0x38594000, 0x38598000, 0x3859C000, 0x385A0000, 0x385A4000, 0x385A8000, 0x385AC000, 0x385B0000, 0x385B4000, 0x385B8000, 0x385BC000, 0x385C0000, 0x385C4000, 0x385C8000, 0x385CC000, 0x385D0000, 0x385D4000, 0x385D8000, 0x385DC000, 0x385E0000, 0x385E4000, 0x385E8000, 0x385EC000, 0x385F0000, 0x385F4000, 0x385F8000, 0x385FC000, 0x38600000, 0x38604000, 0x38608000, 0x3860C000, 0x38610000, 0x38614000, 0x38618000, 0x3861C000, 0x38620000, 0x38624000, 0x38628000, 0x3862C000, 0x38630000, 0x38634000, 0x38638000, 0x3863C000, 0x38640000, 0x38644000, 0x38648000, 0x3864C000, 0x38650000, 0x38654000, 0x38658000, 0x3865C000, 0x38660000, 0x38664000, 0x38668000, 0x3866C000, 0x38670000, 0x38674000, 0x38678000, 0x3867C000, 0x38680000, 0x38684000, 0x38688000, 0x3868C000, 0x38690000, 0x38694000, 0x38698000, 0x3869C000, 0x386A0000, 0x386A4000, 0x386A8000, 0x386AC000, 0x386B0000, 0x386B4000, 0x386B8000, 0x386BC000, 0x386C0000, 0x386C4000, 0x386C8000, 0x386CC000, 0x386D0000, 0x386D4000, 0x386D8000, 0x386DC000, 0x386E0000, 0x386E4000, 0x386E8000, 0x386EC000, 0x386F0000, 0x386F4000, 0x386F8000, 0x386FC000, 0x38700000, 0x38704000, 0x38708000, 0x3870C000, 0x38710000, 0x38714000, 0x38718000, 0x3871C000, 0x38720000, 0x38724000, 0x38728000, 0x3872C000, 0x38730000, 0x38734000, 0x38738000, 0x3873C000, 0x38740000, 
0x38744000, 0x38748000, 0x3874C000, 0x38750000, 0x38754000, 0x38758000, 0x3875C000, 0x38760000, 0x38764000, 0x38768000, 0x3876C000, 0x38770000, 0x38774000, 0x38778000, 0x3877C000, 0x38780000, 0x38784000, 0x38788000, 0x3878C000, 0x38790000, 0x38794000, 0x38798000, 0x3879C000, 0x387A0000, 0x387A4000, 0x387A8000, 0x387AC000, 0x387B0000, 0x387B4000, 0x387B8000, 0x387BC000, 0x387C0000, 0x387C4000, 0x387C8000, 0x387CC000, 0x387D0000, 0x387D4000, 0x387D8000, 0x387DC000, 0x387E0000, 0x387E4000, 0x387E8000, 0x387EC000, 0x387F0000, 0x387F4000, 0x387F8000, 0x387FC000, 0x38000000, 0x38002000, 0x38004000, 0x38006000, 0x38008000, 0x3800A000, 0x3800C000, 0x3800E000, 0x38010000, 0x38012000, 0x38014000, 0x38016000, 0x38018000, 0x3801A000, 0x3801C000, 0x3801E000, 0x38020000, 0x38022000, 0x38024000, 0x38026000, 0x38028000, 0x3802A000, 0x3802C000, 0x3802E000, 0x38030000, 0x38032000, 0x38034000, 0x38036000, 0x38038000, 0x3803A000, 0x3803C000, 0x3803E000, 0x38040000, 0x38042000, 0x38044000, 0x38046000, 0x38048000, 0x3804A000, 0x3804C000, 0x3804E000, 0x38050000, 0x38052000, 0x38054000, 0x38056000, 0x38058000, 0x3805A000, 0x3805C000, 0x3805E000, 0x38060000, 0x38062000, 0x38064000, 0x38066000, 0x38068000, 0x3806A000, 0x3806C000, 0x3806E000, 0x38070000, 0x38072000, 0x38074000, 0x38076000, 0x38078000, 0x3807A000, 0x3807C000, 0x3807E000, 0x38080000, 0x38082000, 0x38084000, 0x38086000, 0x38088000, 0x3808A000, 0x3808C000, 0x3808E000, 0x38090000, 0x38092000, 0x38094000, 0x38096000, 0x38098000, 0x3809A000, 0x3809C000, 0x3809E000, 0x380A0000, 0x380A2000, 0x380A4000, 0x380A6000, 0x380A8000, 0x380AA000, 0x380AC000, 0x380AE000, 0x380B0000, 0x380B2000, 0x380B4000, 0x380B6000, 0x380B8000, 0x380BA000, 0x380BC000, 0x380BE000, 0x380C0000, 0x380C2000, 0x380C4000, 0x380C6000, 0x380C8000, 0x380CA000, 0x380CC000, 0x380CE000, 0x380D0000, 0x380D2000, 0x380D4000, 0x380D6000, 0x380D8000, 0x380DA000, 0x380DC000, 0x380DE000, 0x380E0000, 0x380E2000, 0x380E4000, 0x380E6000, 0x380E8000, 0x380EA000, 0x380EC000, 
0x380EE000, 0x380F0000, 0x380F2000, 0x380F4000, 0x380F6000, 0x380F8000, 0x380FA000, 0x380FC000, 0x380FE000, 0x38100000, 0x38102000, 0x38104000, 0x38106000, 0x38108000, 0x3810A000, 0x3810C000, 0x3810E000, 0x38110000, 0x38112000, 0x38114000, 0x38116000, 0x38118000, 0x3811A000, 0x3811C000, 0x3811E000, 0x38120000, 0x38122000, 0x38124000, 0x38126000, 0x38128000, 0x3812A000, 0x3812C000, 0x3812E000, 0x38130000, 0x38132000, 0x38134000, 0x38136000, 0x38138000, 0x3813A000, 0x3813C000, 0x3813E000, 0x38140000, 0x38142000, 0x38144000, 0x38146000, 0x38148000, 0x3814A000, 0x3814C000, 0x3814E000, 0x38150000, 0x38152000, 0x38154000, 0x38156000, 0x38158000, 0x3815A000, 0x3815C000, 0x3815E000, 0x38160000, 0x38162000, 0x38164000, 0x38166000, 0x38168000, 0x3816A000, 0x3816C000, 0x3816E000, 0x38170000, 0x38172000, 0x38174000, 0x38176000, 0x38178000, 0x3817A000, 0x3817C000, 0x3817E000, 0x38180000, 0x38182000, 0x38184000, 0x38186000, 0x38188000, 0x3818A000, 0x3818C000, 0x3818E000, 0x38190000, 0x38192000, 0x38194000, 0x38196000, 0x38198000, 0x3819A000, 0x3819C000, 0x3819E000, 0x381A0000, 0x381A2000, 0x381A4000, 0x381A6000, 0x381A8000, 0x381AA000, 0x381AC000, 0x381AE000, 0x381B0000, 0x381B2000, 0x381B4000, 0x381B6000, 0x381B8000, 0x381BA000, 0x381BC000, 0x381BE000, 0x381C0000, 0x381C2000, 0x381C4000, 0x381C6000, 0x381C8000, 0x381CA000, 0x381CC000, 0x381CE000, 0x381D0000, 0x381D2000, 0x381D4000, 0x381D6000, 0x381D8000, 0x381DA000, 0x381DC000, 0x381DE000, 0x381E0000, 0x381E2000, 0x381E4000, 0x381E6000, 0x381E8000, 0x381EA000, 0x381EC000, 0x381EE000, 0x381F0000, 0x381F2000, 0x381F4000, 0x381F6000, 0x381F8000, 0x381FA000, 0x381FC000, 0x381FE000, 0x38200000, 0x38202000, 0x38204000, 0x38206000, 0x38208000, 0x3820A000, 0x3820C000, 0x3820E000, 0x38210000, 0x38212000, 0x38214000, 0x38216000, 0x38218000, 0x3821A000, 0x3821C000, 0x3821E000, 0x38220000, 0x38222000, 0x38224000, 0x38226000, 0x38228000, 0x3822A000, 0x3822C000, 0x3822E000, 0x38230000, 0x38232000, 0x38234000, 0x38236000, 0x38238000, 
0x3823A000, 0x3823C000, 0x3823E000, 0x38240000, 0x38242000, 0x38244000, 0x38246000, 0x38248000, 0x3824A000, 0x3824C000, 0x3824E000, 0x38250000, 0x38252000, 0x38254000, 0x38256000, 0x38258000, 0x3825A000, 0x3825C000, 0x3825E000, 0x38260000, 0x38262000, 0x38264000, 0x38266000, 0x38268000, 0x3826A000, 0x3826C000, 0x3826E000, 0x38270000, 0x38272000, 0x38274000, 0x38276000, 0x38278000, 0x3827A000, 0x3827C000, 0x3827E000, 0x38280000, 0x38282000, 0x38284000, 0x38286000, 0x38288000, 0x3828A000, 0x3828C000, 0x3828E000, 0x38290000, 0x38292000, 0x38294000, 0x38296000, 0x38298000, 0x3829A000, 0x3829C000, 0x3829E000, 0x382A0000, 0x382A2000, 0x382A4000, 0x382A6000, 0x382A8000, 0x382AA000, 0x382AC000, 0x382AE000, 0x382B0000, 0x382B2000, 0x382B4000, 0x382B6000, 0x382B8000, 0x382BA000, 0x382BC000, 0x382BE000, 0x382C0000, 0x382C2000, 0x382C4000, 0x382C6000, 0x382C8000, 0x382CA000, 0x382CC000, 0x382CE000, 0x382D0000, 0x382D2000, 0x382D4000, 0x382D6000, 0x382D8000, 0x382DA000, 0x382DC000, 0x382DE000, 0x382E0000, 0x382E2000, 0x382E4000, 0x382E6000, 0x382E8000, 0x382EA000, 0x382EC000, 0x382EE000, 0x382F0000, 0x382F2000, 0x382F4000, 0x382F6000, 0x382F8000, 0x382FA000, 0x382FC000, 0x382FE000, 0x38300000, 0x38302000, 0x38304000, 0x38306000, 0x38308000, 0x3830A000, 0x3830C000, 0x3830E000, 0x38310000, 0x38312000, 0x38314000, 0x38316000, 0x38318000, 0x3831A000, 0x3831C000, 0x3831E000, 0x38320000, 0x38322000, 0x38324000, 0x38326000, 0x38328000, 0x3832A000, 0x3832C000, 0x3832E000, 0x38330000, 0x38332000, 0x38334000, 0x38336000, 0x38338000, 0x3833A000, 0x3833C000, 0x3833E000, 0x38340000, 0x38342000, 0x38344000, 0x38346000, 0x38348000, 0x3834A000, 0x3834C000, 0x3834E000, 0x38350000, 0x38352000, 0x38354000, 0x38356000, 0x38358000, 0x3835A000, 0x3835C000, 0x3835E000, 0x38360000, 0x38362000, 0x38364000, 0x38366000, 0x38368000, 0x3836A000, 0x3836C000, 0x3836E000, 0x38370000, 0x38372000, 0x38374000, 0x38376000, 0x38378000, 0x3837A000, 0x3837C000, 0x3837E000, 0x38380000, 0x38382000, 0x38384000, 
0x38386000, 0x38388000, 0x3838A000, 0x3838C000, 0x3838E000, 0x38390000, 0x38392000, 0x38394000, 0x38396000, 0x38398000, 0x3839A000, 0x3839C000, 0x3839E000, 0x383A0000, 0x383A2000, 0x383A4000, 0x383A6000, 0x383A8000, 0x383AA000, 0x383AC000, 0x383AE000, 0x383B0000, 0x383B2000, 0x383B4000, 0x383B6000, 0x383B8000, 0x383BA000, 0x383BC000, 0x383BE000, 0x383C0000, 0x383C2000, 0x383C4000, 0x383C6000, 0x383C8000, 0x383CA000, 0x383CC000, 0x383CE000, 0x383D0000, 0x383D2000, 0x383D4000, 0x383D6000, 0x383D8000, 0x383DA000, 0x383DC000, 0x383DE000, 0x383E0000, 0x383E2000, 0x383E4000, 0x383E6000, 0x383E8000, 0x383EA000, 0x383EC000, 0x383EE000, 0x383F0000, 0x383F2000, 0x383F4000, 0x383F6000, 0x383F8000, 0x383FA000, 0x383FC000, 0x383FE000, 0x38400000, 0x38402000, 0x38404000, 0x38406000, 0x38408000, 0x3840A000, 0x3840C000, 0x3840E000, 0x38410000, 0x38412000, 0x38414000, 0x38416000, 0x38418000, 0x3841A000, 0x3841C000, 0x3841E000, 0x38420000, 0x38422000, 0x38424000, 0x38426000, 0x38428000, 0x3842A000, 0x3842C000, 0x3842E000, 0x38430000, 0x38432000, 0x38434000, 0x38436000, 0x38438000, 0x3843A000, 0x3843C000, 0x3843E000, 0x38440000, 0x38442000, 0x38444000, 0x38446000, 0x38448000, 0x3844A000, 0x3844C000, 0x3844E000, 0x38450000, 0x38452000, 0x38454000, 0x38456000, 0x38458000, 0x3845A000, 0x3845C000, 0x3845E000, 0x38460000, 0x38462000, 0x38464000, 0x38466000, 0x38468000, 0x3846A000, 0x3846C000, 0x3846E000, 0x38470000, 0x38472000, 0x38474000, 0x38476000, 0x38478000, 0x3847A000, 0x3847C000, 0x3847E000, 0x38480000, 0x38482000, 0x38484000, 0x38486000, 0x38488000, 0x3848A000, 0x3848C000, 0x3848E000, 0x38490000, 0x38492000, 0x38494000, 0x38496000, 0x38498000, 0x3849A000, 0x3849C000, 0x3849E000, 0x384A0000, 0x384A2000, 0x384A4000, 0x384A6000, 0x384A8000, 0x384AA000, 0x384AC000, 0x384AE000, 0x384B0000, 0x384B2000, 0x384B4000, 0x384B6000, 0x384B8000, 0x384BA000, 0x384BC000, 0x384BE000, 0x384C0000, 0x384C2000, 0x384C4000, 0x384C6000, 0x384C8000, 0x384CA000, 0x384CC000, 0x384CE000, 0x384D0000, 
0x384D2000, 0x384D4000, 0x384D6000, 0x384D8000, 0x384DA000, 0x384DC000, 0x384DE000, 0x384E0000, 0x384E2000, 0x384E4000, 0x384E6000, 0x384E8000, 0x384EA000, 0x384EC000, 0x384EE000, 0x384F0000, 0x384F2000, 0x384F4000, 0x384F6000, 0x384F8000, 0x384FA000, 0x384FC000, 0x384FE000, 0x38500000, 0x38502000, 0x38504000, 0x38506000, 0x38508000, 0x3850A000, 0x3850C000, 0x3850E000, 0x38510000, 0x38512000, 0x38514000, 0x38516000, 0x38518000, 0x3851A000, 0x3851C000, 0x3851E000, 0x38520000, 0x38522000, 0x38524000, 0x38526000, 0x38528000, 0x3852A000, 0x3852C000, 0x3852E000, 0x38530000, 0x38532000, 0x38534000, 0x38536000, 0x38538000, 0x3853A000, 0x3853C000, 0x3853E000, 0x38540000, 0x38542000, 0x38544000, 0x38546000, 0x38548000, 0x3854A000, 0x3854C000, 0x3854E000, 0x38550000, 0x38552000, 0x38554000, 0x38556000, 0x38558000, 0x3855A000, 0x3855C000, 0x3855E000, 0x38560000, 0x38562000, 0x38564000, 0x38566000, 0x38568000, 0x3856A000, 0x3856C000, 0x3856E000, 0x38570000, 0x38572000, 0x38574000, 0x38576000, 0x38578000, 0x3857A000, 0x3857C000, 0x3857E000, 0x38580000, 0x38582000, 0x38584000, 0x38586000, 0x38588000, 0x3858A000, 0x3858C000, 0x3858E000, 0x38590000, 0x38592000, 0x38594000, 0x38596000, 0x38598000, 0x3859A000, 0x3859C000, 0x3859E000, 0x385A0000, 0x385A2000, 0x385A4000, 0x385A6000, 0x385A8000, 0x385AA000, 0x385AC000, 0x385AE000, 0x385B0000, 0x385B2000, 0x385B4000, 0x385B6000, 0x385B8000, 0x385BA000, 0x385BC000, 0x385BE000, 0x385C0000, 0x385C2000, 0x385C4000, 0x385C6000, 0x385C8000, 0x385CA000, 0x385CC000, 0x385CE000, 0x385D0000, 0x385D2000, 0x385D4000, 0x385D6000, 0x385D8000, 0x385DA000, 0x385DC000, 0x385DE000, 0x385E0000, 0x385E2000, 0x385E4000, 0x385E6000, 0x385E8000, 0x385EA000, 0x385EC000, 0x385EE000, 0x385F0000, 0x385F2000, 0x385F4000, 0x385F6000, 0x385F8000, 0x385FA000, 0x385FC000, 0x385FE000, 0x38600000, 0x38602000, 0x38604000, 0x38606000, 0x38608000, 0x3860A000, 0x3860C000, 0x3860E000, 0x38610000, 0x38612000, 0x38614000, 0x38616000, 0x38618000, 0x3861A000, 0x3861C000, 
0x3861E000, 0x38620000, 0x38622000, 0x38624000, 0x38626000, 0x38628000, 0x3862A000, 0x3862C000, 0x3862E000, 0x38630000, 0x38632000, 0x38634000, 0x38636000, 0x38638000, 0x3863A000, 0x3863C000, 0x3863E000, 0x38640000, 0x38642000, 0x38644000, 0x38646000, 0x38648000, 0x3864A000, 0x3864C000, 0x3864E000, 0x38650000, 0x38652000, 0x38654000, 0x38656000, 0x38658000, 0x3865A000, 0x3865C000, 0x3865E000, 0x38660000, 0x38662000, 0x38664000, 0x38666000, 0x38668000, 0x3866A000, 0x3866C000, 0x3866E000, 0x38670000, 0x38672000, 0x38674000, 0x38676000, 0x38678000, 0x3867A000, 0x3867C000, 0x3867E000, 0x38680000, 0x38682000, 0x38684000, 0x38686000, 0x38688000, 0x3868A000, 0x3868C000, 0x3868E000, 0x38690000, 0x38692000, 0x38694000, 0x38696000, 0x38698000, 0x3869A000, 0x3869C000, 0x3869E000, 0x386A0000, 0x386A2000, 0x386A4000, 0x386A6000, 0x386A8000, 0x386AA000, 0x386AC000, 0x386AE000, 0x386B0000, 0x386B2000, 0x386B4000, 0x386B6000, 0x386B8000, 0x386BA000, 0x386BC000, 0x386BE000, 0x386C0000, 0x386C2000, 0x386C4000, 0x386C6000, 0x386C8000, 0x386CA000, 0x386CC000, 0x386CE000, 0x386D0000, 0x386D2000, 0x386D4000, 0x386D6000, 0x386D8000, 0x386DA000, 0x386DC000, 0x386DE000, 0x386E0000, 0x386E2000, 0x386E4000, 0x386E6000, 0x386E8000, 0x386EA000, 0x386EC000, 0x386EE000, 0x386F0000, 0x386F2000, 0x386F4000, 0x386F6000, 0x386F8000, 0x386FA000, 0x386FC000, 0x386FE000, 0x38700000, 0x38702000, 0x38704000, 0x38706000, 0x38708000, 0x3870A000, 0x3870C000, 0x3870E000, 0x38710000, 0x38712000, 0x38714000, 0x38716000, 0x38718000, 0x3871A000, 0x3871C000, 0x3871E000, 0x38720000, 0x38722000, 0x38724000, 0x38726000, 0x38728000, 0x3872A000, 0x3872C000, 0x3872E000, 0x38730000, 0x38732000, 0x38734000, 0x38736000, 0x38738000, 0x3873A000, 0x3873C000, 0x3873E000, 0x38740000, 0x38742000, 0x38744000, 0x38746000, 0x38748000, 0x3874A000, 0x3874C000, 0x3874E000, 0x38750000, 0x38752000, 0x38754000, 0x38756000, 0x38758000, 0x3875A000, 0x3875C000, 0x3875E000, 0x38760000, 0x38762000, 0x38764000, 0x38766000, 0x38768000, 
0x3876A000, 0x3876C000, 0x3876E000, 0x38770000, 0x38772000, 0x38774000, 0x38776000, 0x38778000, 0x3877A000, 0x3877C000, 0x3877E000, 0x38780000, 0x38782000, 0x38784000, 0x38786000, 0x38788000, 0x3878A000, 0x3878C000, 0x3878E000, 0x38790000, 0x38792000, 0x38794000, 0x38796000, 0x38798000, 0x3879A000, 0x3879C000, 0x3879E000, 0x387A0000, 0x387A2000, 0x387A4000, 0x387A6000, 0x387A8000, 0x387AA000, 0x387AC000, 0x387AE000, 0x387B0000, 0x387B2000, 0x387B4000, 0x387B6000, 0x387B8000, 0x387BA000, 0x387BC000, 0x387BE000, 0x387C0000, 0x387C2000, 0x387C4000, 0x387C6000, 0x387C8000, 0x387CA000, 0x387CC000, 0x387CE000, 0x387D0000, 0x387D2000, 0x387D4000, 0x387D6000, 0x387D8000, 0x387DA000, 0x387DC000, 0x387DE000, 0x387E0000, 0x387E2000, 0x387E4000, 0x387E6000, 0x387E8000, 0x387EA000, 0x387EC000, 0x387EE000, 0x387F0000, 0x387F2000, 0x387F4000, 0x387F6000, 0x387F8000, 0x387FA000, 0x387FC000, 0x387FE000}; static const uint32_t exponent_table[64] = { 0x00000000, 0x00800000, 0x01000000, 0x01800000, 0x02000000, 0x02800000, 0x03000000, 0x03800000, 0x04000000, 0x04800000, 0x05000000, 0x05800000, 0x06000000, 0x06800000, 0x07000000, 0x07800000, 0x08000000, 0x08800000, 0x09000000, 0x09800000, 0x0A000000, 0x0A800000, 0x0B000000, 0x0B800000, 0x0C000000, 0x0C800000, 0x0D000000, 0x0D800000, 0x0E000000, 0x0E800000, 0x0F000000, 0x47800000, 0x80000000, 0x80800000, 0x81000000, 0x81800000, 0x82000000, 0x82800000, 0x83000000, 0x83800000, 0x84000000, 0x84800000, 0x85000000, 0x85800000, 0x86000000, 0x86800000, 0x87000000, 0x87800000, 0x88000000, 0x88800000, 0x89000000, 0x89800000, 0x8A000000, 0x8A800000, 0x8B000000, 0x8B800000, 0x8C000000, 0x8C800000, 0x8D000000, 0x8D800000, 0x8E000000, 0x8E800000, 0x8F000000, 0xC7800000}; static const uint16_t offset_table[64] = { 0, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 0, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 
1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024}; uint16_t hval = static_cast(val >> 10); uint32_t bits = mantissa_table[offset_table[hval] + (val & 0x3FF)] + exponent_table[hval]; float *p = reinterpret_cast(&bits); return (*p); } // Refer: https://github.com/Maratyszcza/FP16/blob/master/third-party/half.hpp static inline uint16_t float16(float val) { static const uint16_t base_table[512] = { 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0001, 0x0002, 0x0004, 0x0008, 0x0010, 0x0020, 0x0040, 0x0080, 0x0100, 0x0200, 0x0400, 0x0800, 0x0C00, 0x1000, 0x1400, 0x1800, 0x1C00, 0x2000, 0x2400, 0x2800, 0x2C00, 0x3000, 0x3400, 0x3800, 0x3C00, 0x4000, 0x4400, 0x4800, 0x4C00, 0x5000, 0x5400, 0x5800, 0x5C00, 0x6000, 0x6400, 0x6800, 0x6C00, 0x7000, 0x7400, 0x7800, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 
0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8001, 0x8002, 0x8004, 0x8008, 0x8010, 0x8020, 0x8040, 0x8080, 0x8100, 0x8200, 0x8400, 0x8800, 0x8C00, 0x9000, 0x9400, 0x9800, 0x9C00, 0xA000, 0xA400, 0xA800, 0xAC00, 0xB000, 0xB400, 0xB800, 0xBC00, 0xC000, 0xC400, 0xC800, 0xCC00, 0xD000, 0xD400, 0xD800, 0xDC00, 0xE000, 0xE400, 0xE800, 0xEC00, 0xF000, 0xF400, 0xF800, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 
0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00}; static const uint8_t shift_table[512] = { 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 13, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 
24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 13}; uint32_t *p = reinterpret_cast(&val); uint32_t hbits = base_table[*p >> 23] + static_cast((*p & 0x7FFFFF) >> shift_table[*p >> 23]); hbits += (((*p & 0x7FFFFF) >> (shift_table[*p >> 23] - 1)) | (((*p >> 23) & 0xFF) == 102)) & ((hbits & 0x7C00) != 0x7C00); return static_cast(hbits); } #if defined(__F16C__) && defined(__AVX512F__) static inline void convert_fp16_to_fp32_avx512f(const uint16_t *arr, size_t size, float *out) { const uint16_t *last = arr + size; const uint16_t *last_aligned = arr + ((size >> 5) << 5); if (((uintptr_t)arr & 0x1f) == 0 && ((uintptr_t)out & 0x3f) == 0) { for (; arr != last_aligned; arr += 32, out += 32) { _mm512_store_ps(out + 0, _mm512_cvtph_ps(_mm256_load_si256((__m256i *)(arr + 0)))); _mm512_store_ps( out + 16, _mm512_cvtph_ps(_mm256_load_si256((__m256i *)(arr + 16)))); } if (last >= last_aligned + 16) { _mm512_store_ps(out, _mm512_cvtph_ps(_mm256_load_si256((__m256i *)arr))); arr += 16; out += 16; } if (last >= arr + 8) { _mm256_store_ps(out, _mm256_cvtph_ps(_mm_load_si128((__m128i *)arr))); arr += 8; out += 8; } } else { for (; arr != last_aligned; arr += 32, out += 32) { _mm512_storeu_ps( out + 0, _mm512_cvtph_ps(_mm256_loadu_si256((__m256i *)(arr + 0)))); _mm512_storeu_ps( out + 16, _mm512_cvtph_ps(_mm256_loadu_si256((__m256i *)(arr + 16)))); } if (last >= last_aligned 
+ 16) { _mm512_storeu_ps(out, _mm512_cvtph_ps(_mm256_loadu_si256((__m256i *)arr))); arr += 16; out += 16; } if (last >= arr + 8) { _mm256_storeu_ps(out, _mm256_cvtph_ps(_mm_loadu_si128((__m128i *)arr))); arr += 8; out += 8; } } switch (last - arr) { case 7: out[6] = float32(arr[6]); /* FALLTHRU */ case 6: out[5] = float32(arr[5]); /* FALLTHRU */ case 5: out[4] = float32(arr[4]); /* FALLTHRU */ case 4: out[3] = float32(arr[3]); /* FALLTHRU */ case 3: out[2] = float32(arr[2]); /* FALLTHRU */ case 2: out[1] = float32(arr[1]); /* FALLTHRU */ case 1: out[0] = float32(arr[0]); } } static inline void convert_fp16_to_fp32_avx512f(const uint16_t *arr, size_t size, float norm, float *out) { const uint16_t *last = arr + size; const uint16_t *last_aligned = arr + ((size >> 5) << 5); __m512 zmm_norm = _mm512_set1_ps(norm); if (((uintptr_t)arr & 0x1f) == 0 && ((uintptr_t)out & 0x3f) == 0) { for (; arr != last_aligned; arr += 32, out += 32) { __m512 zmm_0 = _mm512_div_ps( _mm512_cvtph_ps(_mm256_load_si256((__m256i *)(arr + 0))), zmm_norm); __m512 zmm_1 = _mm512_div_ps( _mm512_cvtph_ps(_mm256_load_si256((__m256i *)(arr + 16))), zmm_norm); _mm512_store_ps(out + 0, zmm_0); _mm512_store_ps(out + 16, zmm_1); } if (last >= last_aligned + 16) { _mm512_store_ps( out, _mm512_div_ps(_mm512_cvtph_ps(_mm256_load_si256((__m256i *)arr)), zmm_norm)); arr += 16; out += 16; } if (last >= arr + 8) { _mm256_store_ps( out, _mm256_div_ps(_mm256_cvtph_ps(_mm_load_si128((__m128i *)arr)), _mm256_set1_ps(norm))); arr += 8; out += 8; } } else { for (; arr != last_aligned; arr += 32, out += 32) { __m512 zmm_0 = _mm512_div_ps( _mm512_cvtph_ps(_mm256_loadu_si256((__m256i *)(arr + 0))), zmm_norm); __m512 zmm_1 = _mm512_div_ps( _mm512_cvtph_ps(_mm256_loadu_si256((__m256i *)(arr + 16))), zmm_norm); _mm512_storeu_ps(out + 0, zmm_0); _mm512_storeu_ps(out + 16, zmm_1); } if (last >= last_aligned + 16) { _mm512_storeu_ps( out, _mm512_div_ps(_mm512_cvtph_ps(_mm256_loadu_si256((__m256i *)arr)), zmm_norm)); arr += 16; 
out += 16; } if (last >= arr + 8) { _mm256_storeu_ps( out, _mm256_div_ps(_mm256_cvtph_ps(_mm_loadu_si128((__m128i *)arr)), _mm256_set1_ps(norm))); arr += 8; out += 8; } } switch (last - arr) { case 7: out[6] = float32(arr[6]) / norm; /* FALLTHRU */ case 6: out[5] = float32(arr[5]) / norm; /* FALLTHRU */ case 5: out[4] = float32(arr[4]) / norm; /* FALLTHRU */ case 4: out[3] = float32(arr[3]) / norm; /* FALLTHRU */ case 3: out[2] = float32(arr[2]) / norm; /* FALLTHRU */ case 2: out[1] = float32(arr[1]) / norm; /* FALLTHRU */ case 1: out[0] = float32(arr[0]) / norm; } } static inline void convert_fp32_to_fp16_avx512f(const float *arr, size_t size, uint16_t *out) { const float *last = arr + size; const float *last_aligned = arr + ((size >> 5) << 5); if (((uintptr_t)arr & 0x3f) == 0 && ((uintptr_t)out & 0x1f) == 0) { for (; arr != last_aligned; arr += 32, out += 32) { _mm256_store_si256( (__m256i *)(out + 0), _mm512_cvtps_ph(_mm512_load_ps(arr + 0), _MM_FROUND_NO_EXC)); _mm256_store_si256( (__m256i *)(out + 16), _mm512_cvtps_ph(_mm512_load_ps(arr + 16), _MM_FROUND_NO_EXC)); } if (last >= last_aligned + 16) { _mm256_store_si256( (__m256i *)(out + 0), _mm512_cvtps_ph(_mm512_load_ps(arr + 0), _MM_FROUND_NO_EXC)); arr += 16; out += 16; } if (last >= arr + 8) { _mm_store_si128( (__m128i *)(out + 0), _mm256_cvtps_ph(_mm256_load_ps(arr + 0), _MM_FROUND_NO_EXC)); arr += 8; out += 8; } } else { for (; arr != last_aligned; arr += 32, out += 32) { _mm256_storeu_si256( (__m256i *)(out + 0), _mm512_cvtps_ph(_mm512_loadu_ps(arr + 0), _MM_FROUND_NO_EXC)); _mm256_storeu_si256( (__m256i *)(out + 16), _mm512_cvtps_ph(_mm512_loadu_ps(arr + 16), _MM_FROUND_NO_EXC)); } if (last >= last_aligned + 16) { _mm256_storeu_si256( (__m256i *)(out + 0), _mm512_cvtps_ph(_mm512_loadu_ps(arr + 0), _MM_FROUND_NO_EXC)); arr += 16; out += 16; } if (last >= arr + 8) { _mm_storeu_si128( (__m128i *)(out + 0), _mm256_cvtps_ph(_mm256_loadu_ps(arr + 0), _MM_FROUND_NO_EXC)); arr += 8; out += 8; } } switch (last - 
arr) { case 7: out[6] = float16(arr[6]); /* FALLTHRU */ case 6: out[5] = float16(arr[5]); /* FALLTHRU */ case 5: out[4] = float16(arr[4]); /* FALLTHRU */ case 4: out[3] = float16(arr[3]); /* FALLTHRU */ case 3: out[2] = float16(arr[2]); /* FALLTHRU */ case 2: out[1] = float16(arr[1]); /* FALLTHRU */ case 1: out[0] = float16(arr[0]); } } static inline void convert_fp32_to_fp16_avx512f(const float *arr, size_t size, float norm, uint16_t *out) { const float *last = arr + size; const float *last_aligned = arr + ((size >> 5) << 5); __m512 zmm_norm = _mm512_set1_ps(norm); if (((uintptr_t)arr & 0x3f) == 0 && ((uintptr_t)out & 0x1f) == 0) { for (; arr != last_aligned; arr += 32, out += 32) { __m512 zmm_0 = _mm512_div_ps(_mm512_load_ps(arr + 0), zmm_norm); __m512 zmm_1 = _mm512_div_ps(_mm512_load_ps(arr + 16), zmm_norm); _mm256_store_si256((__m256i *)(out + 0), _mm512_cvtps_ph(zmm_0, _MM_FROUND_NO_EXC)); _mm256_store_si256((__m256i *)(out + 16), _mm512_cvtps_ph(zmm_1, _MM_FROUND_NO_EXC)); } if (last >= last_aligned + 16) { _mm256_store_si256( (__m256i *)out, _mm512_cvtps_ph(_mm512_div_ps(_mm512_load_ps(arr), zmm_norm), _MM_FROUND_NO_EXC)); arr += 16; out += 16; } if (last >= arr + 8) { _mm_store_si128((__m128i *)out, _mm256_cvtps_ph(_mm256_div_ps(_mm256_load_ps(arr), _mm256_set1_ps(norm)), _MM_FROUND_NO_EXC)); arr += 8; out += 8; } } else { for (; arr != last_aligned; arr += 32, out += 32) { __m512 zmm_0 = _mm512_div_ps(_mm512_loadu_ps(arr + 0), zmm_norm); __m512 zmm_1 = _mm512_div_ps(_mm512_loadu_ps(arr + 16), zmm_norm); _mm256_storeu_si256((__m256i *)(out + 0), _mm512_cvtps_ph(zmm_0, _MM_FROUND_NO_EXC)); _mm256_storeu_si256((__m256i *)(out + 16), _mm512_cvtps_ph(zmm_1, _MM_FROUND_NO_EXC)); } if (last >= last_aligned + 16) { _mm256_storeu_si256( (__m256i *)out, _mm512_cvtps_ph(_mm512_div_ps(_mm512_loadu_ps(arr), zmm_norm), _MM_FROUND_NO_EXC)); arr += 16; out += 16; } if (last >= arr + 8) { _mm_storeu_si128((__m128i *)out, _mm256_cvtps_ph(_mm256_div_ps(_mm256_loadu_ps(arr), 
_mm256_set1_ps(norm)), _MM_FROUND_NO_EXC)); arr += 8; out += 8; } } switch (last - arr) { case 7: out[6] = float16(arr[6] / norm); /* FALLTHRU */ case 6: out[5] = float16(arr[5] / norm); /* FALLTHRU */ case 5: out[4] = float16(arr[4] / norm); /* FALLTHRU */ case 4: out[3] = float16(arr[3] / norm); /* FALLTHRU */ case 3: out[2] = float16(arr[2] / norm); /* FALLTHRU */ case 2: out[1] = float16(arr[1] / norm); /* FALLTHRU */ case 1: out[0] = float16(arr[0] / norm); } } #endif //__F16C__ && __AVX512F__ #if defined(__F16C__) && defined(__AVX__) static inline void convert_fp16_to_fp32_avx(const uint16_t *arr, size_t size, float *out) { const uint16_t *last = arr + size; const uint16_t *last_aligned = arr + ((size >> 4) << 4); if (((uintptr_t)arr & 0xf) == 0 && ((uintptr_t)out & 0x1f) == 0) { for (; arr != last_aligned; arr += 16, out += 16) { _mm256_store_ps(out + 0, _mm256_cvtph_ps(_mm_load_si128((__m128i *)(arr + 0)))); _mm256_store_ps(out + 8, _mm256_cvtph_ps(_mm_load_si128((__m128i *)(arr + 8)))); } if (last >= last_aligned + 8) { _mm256_store_ps(out + 0, _mm256_cvtph_ps(_mm_load_si128((__m128i *)(arr + 0)))); arr += 8; out += 8; } } else { for (; arr != last_aligned; arr += 16, out += 16) { _mm256_storeu_ps(out + 0, _mm256_cvtph_ps(_mm_loadu_si128((__m128i *)(arr + 0)))); _mm256_storeu_ps(out + 8, _mm256_cvtph_ps(_mm_loadu_si128((__m128i *)(arr + 8)))); } if (last >= last_aligned + 8) { _mm256_storeu_ps(out + 0, _mm256_cvtph_ps(_mm_loadu_si128((__m128i *)(arr + 0)))); arr += 8; out += 8; } } switch (last - arr) { case 7: out[6] = _cvtsh_ss(arr[6]); /* FALLTHRU */ case 6: out[5] = _cvtsh_ss(arr[5]); /* FALLTHRU */ case 5: out[4] = _cvtsh_ss(arr[4]); /* FALLTHRU */ case 4: out[3] = _cvtsh_ss(arr[3]); /* FALLTHRU */ case 3: out[2] = _cvtsh_ss(arr[2]); /* FALLTHRU */ case 2: out[1] = _cvtsh_ss(arr[1]); /* FALLTHRU */ case 1: out[0] = _cvtsh_ss(arr[0]); } } static inline void convert_fp16_to_fp32_avx(const uint16_t *arr, size_t size, float norm, float *out) { const 
uint16_t *last = arr + size; const uint16_t *last_aligned = arr + ((size >> 4) << 4); __m256 ymm_norm = _mm256_set1_ps(norm); if (((uintptr_t)arr & 0xf) == 0 && ((uintptr_t)out & 0x1f) == 0) { for (; arr != last_aligned; arr += 16, out += 16) { __m256 ymm_0 = _mm256_cvtph_ps(_mm_load_si128((__m128i *)(arr + 0))); __m256 ymm_1 = _mm256_cvtph_ps(_mm_load_si128((__m128i *)(arr + 8))); ymm_0 = _mm256_div_ps(ymm_0, ymm_norm); ymm_1 = _mm256_div_ps(ymm_1, ymm_norm); _mm256_store_ps(out + 0, ymm_0); _mm256_store_ps(out + 8, ymm_1); } if (last >= last_aligned + 8) { _mm256_store_ps( out, _mm256_div_ps(_mm256_cvtph_ps(_mm_load_si128((__m128i *)arr)), ymm_norm)); arr += 8; out += 8; } } else { for (; arr != last_aligned; arr += 16, out += 16) { __m256 ymm_0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i *)(arr + 0))); __m256 ymm_1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i *)(arr + 8))); ymm_0 = _mm256_div_ps(ymm_0, ymm_norm); ymm_1 = _mm256_div_ps(ymm_1, ymm_norm); _mm256_storeu_ps(out + 0, ymm_0); _mm256_storeu_ps(out + 8, ymm_1); } if (last >= last_aligned + 8) { _mm256_storeu_ps( out, _mm256_div_ps(_mm256_cvtph_ps(_mm_loadu_si128((__m128i *)arr)), ymm_norm)); arr += 8; out += 8; } } switch (last - arr) { case 7: out[6] = _cvtsh_ss(arr[6]) / norm; /* FALLTHRU */ case 6: out[5] = _cvtsh_ss(arr[5]) / norm; /* FALLTHRU */ case 5: out[4] = _cvtsh_ss(arr[4]) / norm; /* FALLTHRU */ case 4: out[3] = _cvtsh_ss(arr[3]) / norm; /* FALLTHRU */ case 3: out[2] = _cvtsh_ss(arr[2]) / norm; /* FALLTHRU */ case 2: out[1] = _cvtsh_ss(arr[1]) / norm; /* FALLTHRU */ case 1: out[0] = _cvtsh_ss(arr[0]) / norm; } } static inline void convert_fp32_to_fp16_avx(const float *arr, size_t size, uint16_t *out) { const float *last = arr + size; const float *last_aligned = arr + ((size >> 4) << 4); if (((uintptr_t)arr & 0x1f) == 0 && ((uintptr_t)out & 0xf) == 0) { for (; arr != last_aligned; arr += 16, out += 16) { _mm_store_si128( (__m128i *)(out + 0), _mm256_cvtps_ph(_mm256_load_ps(arr + 0), 
_MM_FROUND_NO_EXC)); _mm_store_si128( (__m128i *)(out + 8), _mm256_cvtps_ph(_mm256_load_ps(arr + 8), _MM_FROUND_NO_EXC)); } if (last >= last_aligned + 8) { _mm_store_si128( (__m128i *)(out + 0), _mm256_cvtps_ph(_mm256_load_ps(arr + 0), _MM_FROUND_NO_EXC)); arr += 8; out += 8; } } else { for (; arr != last_aligned; arr += 16, out += 16) { _mm_storeu_si128( (__m128i *)(out + 0), _mm256_cvtps_ph(_mm256_loadu_ps(arr + 0), _MM_FROUND_NO_EXC)); _mm_storeu_si128( (__m128i *)(out + 8), _mm256_cvtps_ph(_mm256_loadu_ps(arr + 8), _MM_FROUND_NO_EXC)); } if (last >= last_aligned + 8) { _mm_storeu_si128( (__m128i *)(out + 0), _mm256_cvtps_ph(_mm256_loadu_ps(arr + 0), _MM_FROUND_NO_EXC)); arr += 8; out += 8; } } switch (last - arr) { case 7: out[6] = _cvtss_sh(arr[6], _MM_FROUND_NO_EXC); /* FALLTHRU */ case 6: out[5] = _cvtss_sh(arr[5], _MM_FROUND_NO_EXC); /* FALLTHRU */ case 5: out[4] = _cvtss_sh(arr[4], _MM_FROUND_NO_EXC); /* FALLTHRU */ case 4: out[3] = _cvtss_sh(arr[3], _MM_FROUND_NO_EXC); /* FALLTHRU */ case 3: out[2] = _cvtss_sh(arr[2], _MM_FROUND_NO_EXC); /* FALLTHRU */ case 2: out[1] = _cvtss_sh(arr[1], _MM_FROUND_NO_EXC); /* FALLTHRU */ case 1: out[0] = _cvtss_sh(arr[0], _MM_FROUND_NO_EXC); } } static inline void convert_fp32_to_fp16_avx(const float *arr, size_t size, float norm, uint16_t *out) { const float *last = arr + size; const float *last_aligned = arr + ((size >> 4) << 4); __m256 ymm_norm = _mm256_set1_ps(norm); if (((uintptr_t)arr & 0x1f) == 0 && ((uintptr_t)out & 0xf) == 0) { for (; arr != last_aligned; arr += 16, out += 16) { __m256 ymm_0 = _mm256_load_ps(arr + 0); __m256 ymm_1 = _mm256_load_ps(arr + 8); ymm_0 = _mm256_div_ps(ymm_0, ymm_norm); ymm_1 = _mm256_div_ps(ymm_1, ymm_norm); _mm_store_si128((__m128i *)(out + 0), _mm256_cvtps_ph(ymm_0, _MM_FROUND_NO_EXC)); _mm_store_si128((__m128i *)(out + 8), _mm256_cvtps_ph(ymm_1, _MM_FROUND_NO_EXC)); } if (last >= last_aligned + 8) { _mm_store_si128( (__m128i *)out, _mm256_cvtps_ph(_mm256_div_ps(_mm256_load_ps(arr), 
ymm_norm), _MM_FROUND_NO_EXC)); arr += 8; out += 8; } } else { for (; arr != last_aligned; arr += 16, out += 16) { __m256 ymm_0 = _mm256_loadu_ps(arr + 0); __m256 ymm_1 = _mm256_loadu_ps(arr + 8); ymm_0 = _mm256_div_ps(ymm_0, ymm_norm); ymm_1 = _mm256_div_ps(ymm_1, ymm_norm); _mm_storeu_si128((__m128i *)(out + 0), _mm256_cvtps_ph(ymm_0, _MM_FROUND_NO_EXC)); _mm_storeu_si128((__m128i *)(out + 8), _mm256_cvtps_ph(ymm_1, _MM_FROUND_NO_EXC)); } if (last >= last_aligned + 8) { _mm_storeu_si128( (__m128i *)out, _mm256_cvtps_ph(_mm256_div_ps(_mm256_loadu_ps(arr), ymm_norm), _MM_FROUND_NO_EXC)); arr += 8; out += 8; } } switch (last - arr) { case 7: out[6] = _cvtss_sh(arr[6] / norm, _MM_FROUND_NO_EXC); /* FALLTHRU */ case 6: out[5] = _cvtss_sh(arr[5] / norm, _MM_FROUND_NO_EXC); /* FALLTHRU */ case 5: out[4] = _cvtss_sh(arr[4] / norm, _MM_FROUND_NO_EXC); /* FALLTHRU */ case 4: out[3] = _cvtss_sh(arr[3] / norm, _MM_FROUND_NO_EXC); /* FALLTHRU */ case 3: out[2] = _cvtss_sh(arr[2] / norm, _MM_FROUND_NO_EXC); /* FALLTHRU */ case 2: out[1] = _cvtss_sh(arr[1] / norm, _MM_FROUND_NO_EXC); /* FALLTHRU */ case 1: out[0] = _cvtss_sh(arr[0] / norm, _MM_FROUND_NO_EXC); } } #endif // __F16C__ && __AVX__ static inline void convert_fp16_to_fp32_fallback(const uint16_t *arr, size_t size, float *out) { for (size_t i = 0; i != size; ++i) { out[i] = float32(arr[i]); } } static inline void convert_fp16_to_fp32_fallback(const uint16_t *arr, size_t size, float norm, float *out) { for (size_t i = 0; i != size; ++i) { out[i] = float32(arr[i]) / norm; } } static inline void convert_fp32_to_fp16_fallback(const float *arr, size_t size, uint16_t *out) { for (size_t i = 0; i != size; ++i) { out[i] = float16(arr[i]); } } static inline void convert_fp32_to_fp16_fallback(const float *arr, size_t size, float norm, uint16_t *out) { for (size_t i = 0; i != size; ++i) { out[i] = float16(arr[i] / norm); } } static inline void convert_fp16_to_fp32(const uint16_t *arr, size_t size, float *out) { #if 
defined(__F16C__) && defined(__AVX512F__) if (zvec::ailego::internal::CpuFeatures::static_flags_.F16C && zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) { return convert_fp16_to_fp32_avx512f(arr, size, out); } #endif #if defined(__F16C__) && defined(__AVX__) if (zvec::ailego::internal::CpuFeatures::static_flags_.F16C && zvec::ailego::internal::CpuFeatures::static_flags_.AVX) { return convert_fp16_to_fp32_avx(arr, size, out); } #endif return convert_fp16_to_fp32_fallback(arr, size, out); } static inline void convert_fp16_to_fp32(const uint16_t *arr, size_t size, float norm, float *out) { #if defined(__F16C__) && defined(__AVX512F__) if (zvec::ailego::internal::CpuFeatures::static_flags_.F16C && zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) { return convert_fp16_to_fp32_avx512f(arr, size, norm, out); } #endif #if defined(__F16C__) && defined(__AVX__) if (zvec::ailego::internal::CpuFeatures::static_flags_.F16C && zvec::ailego::internal::CpuFeatures::static_flags_.AVX) { return convert_fp16_to_fp32_avx(arr, size, norm, out); } #endif return convert_fp16_to_fp32_fallback(arr, size, norm, out); } static inline void convert_fp32_to_fp16(const float *arr, size_t size, uint16_t *out) { #if defined(__F16C__) && defined(__AVX512F__) if (zvec::ailego::internal::CpuFeatures::static_flags_.F16C && zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) { return convert_fp32_to_fp16_avx512f(arr, size, out); } #endif #if defined(__F16C__) && defined(__AVX__) if (zvec::ailego::internal::CpuFeatures::static_flags_.F16C && zvec::ailego::internal::CpuFeatures::static_flags_.AVX) { return convert_fp32_to_fp16_avx(arr, size, out); } #endif return convert_fp32_to_fp16_fallback(arr, size, out); } static inline void convert_fp32_to_fp16(const float *arr, size_t size, float norm, uint16_t *out) { #if defined(__F16C__) && defined(__AVX512F__) if (zvec::ailego::internal::CpuFeatures::static_flags_.F16C && 
zvec::ailego::internal::CpuFeatures::static_flags_.AVX512F) { return convert_fp32_to_fp16_avx512f(arr, size, norm, out); } #endif #if defined(__F16C__) && defined(__AVX__) if (zvec::ailego::internal::CpuFeatures::static_flags_.F16C && zvec::ailego::internal::CpuFeatures::static_flags_.AVX) { return convert_fp32_to_fp16_avx(arr, size, norm, out); } #endif return convert_fp32_to_fp16_fallback(arr, size, norm, out); } #endif // namespace zvec { namespace ailego { float FloatHelper::ToFP32(uint16_t val) { return float32(val); } void FloatHelper::ToFP32(const uint16_t *arr, size_t size, float *out) { return convert_fp16_to_fp32(arr, size, out); } void FloatHelper::ToFP32(const uint16_t *arr, size_t size, float norm, float *out) { return convert_fp16_to_fp32(arr, size, norm, out); } uint16_t FloatHelper::ToFP16(float val) { return float16(val); } void FloatHelper::ToFP16(const float *arr, size_t size, uint16_t *out) { return convert_fp32_to_fp16(arr, size, out); } void FloatHelper::ToFP16(const float *arr, size_t size, float norm, uint16_t *out) { return convert_fp32_to_fp16(arr, size, norm, out); } } // namespace ailego } // namespace zvec ================================================ FILE: src/ailego/utility/math_helper.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include #include #include namespace zvec { namespace ailego { /*! 
Math Helper */ struct MathHelper { //! Calculate the absolute value template static inline auto Absolute(const T &x) -> typename std::enable_if::value, R>::type { return static_cast(std::abs(x)); } //! Calculate the absolute value template static inline R Absolute(const Float16 &x) { return static_cast(Float16::Absolute(x)); } //! Calculate the absolute difference template static inline auto AbsoluteDifference(const T &x, const T &y) -> typename std::enable_if::value, R>::type { auto m = ((x ^ y) & -(x < y)); auto d = static_cast::type>((x ^ m) - (y ^ m)); return static_cast(d); } //! Calculate the absolute difference template static inline auto AbsoluteDifference(const T &x, const T &y) -> typename std::enable_if::value, R>::type { return static_cast(std::abs(x - y)); } //! Calculate the absolute difference template static inline R AbsoluteDifference(const Float16 &x, const Float16 &y) { return static_cast(std::abs(x - y)); } //! Calculate the squared difference template static inline auto SquaredDifference(const T &x, const T &y) -> typename std::enable_if::value, R>::type { auto m = ((x ^ y) & -(x < y)); auto d = static_cast::type>((x ^ m) - (y ^ m)); return static_cast(d * d); } //! Calculate the squared difference template static inline auto SquaredDifference(const T &x, const T &y) -> typename std::enable_if::value, R>::type { auto d = x - y; return static_cast(d * d); } //! Calculate the squared difference template static inline R SquaredDifference(const Float16 &x, const Float16 &y) { auto d = x - y; return static_cast(d * d); } //! Test whether two integral numbers are equal template static inline auto IsAlmostEqual(const T &x, const T &y, int) -> typename std::enable_if::value, bool>::type { return (x == y); } //! 
Test whether two floating point numbers are equal template static inline auto IsAlmostEqual(const T &x, const T &y, int ulp) -> typename std::enable_if::value, bool>::type { // the machine epsilon has to be scaled to the magnitude of the values used // and multiplied by the desired precision in ULPs (units in the last place) return ((std::fabs(x - y) <= std::numeric_limits::epsilon() * std::fabs(x + y) * ulp) || (std::fabs(x - y) < std::numeric_limits::min())); } }; } // namespace ailego } // namespace zvec ================================================ FILE: src/ailego/utility/matrix_helper.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include namespace zvec { namespace ailego { struct MatrixHelper { //! Transpose a matrix template static inline void Transpose(const void *src, size_t N, void *dst) { for (size_t i = 0; i < M; ++i) { for (size_t j = 0; j < N; ++j) { *(reinterpret_cast(dst) + (j * M + i)) = *(reinterpret_cast(src) + (i * N + j)); } } } //! Reverse transpose a matrix template static inline void ReverseTranspose(const void *src, size_t N, void *dst) { for (size_t i = 0; i < N; ++i) { for (size_t j = 0; j < M; ++j) { *(reinterpret_cast(dst) + (j * N + i)) = *(reinterpret_cast(src) + (i * M + j)); } } } //! 
Transpose a matrix template static inline void Transpose(const void *src, size_t M, size_t N, void *dst) { for (size_t i = 0; i < M; ++i) { for (size_t j = 0; j < N; ++j) { *(reinterpret_cast(dst) + (j * M + i)) = *(reinterpret_cast(src) + (i * N + j)); } } } //! Reverse transpose a matrix template static inline void ReverseTranspose(const void *src, size_t M, size_t N, void *dst) { for (size_t i = 0; i < N; ++i) { for (size_t j = 0; j < M; ++j) { *(reinterpret_cast(dst) + (j * N + i)) = *(reinterpret_cast(src) + (i * M + j)); } } } }; } // namespace ailego } // namespace zvec ================================================ FILE: src/ailego/utility/memory_helper.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "memory_helper.h"
// NOTE(review): the targets of the bare #include directives below appear
// stripped in this extraction; they are reproduced as-is rather than guessed.
#include
#include
#include
#include
#include
#if defined(_WIN64) || defined(_WIN32)
#include
#include
#else
#if defined(__linux__) || defined(__linux)
#include
#elif defined(__APPLE__) && defined(__MACH__)
#include
#include
#endif
#include
#endif

namespace zvec {
namespace ailego {

#if defined(__linux__) || defined(__linux)

//! Return the number that follows the ':' of a /proc/meminfo line (the
//! callers below treat it as KiB). Only called on lines whose "Name:" prefix
//! already matched, so the ':' is always present.
static size_t MeminfoValue(const char *line) {
  return static_cast<size_t>(strtoull(strchr(line, ':') + 1, NULL, 10));
}

//! Read a cgroup memory limit file holding a single decimal byte count.
//! Returns true and stores the value in `*limit` only for a positive number;
//! returns false if the file cannot be opened or holds anything else (e.g.
//! "max", which cgroup v2 uses for "no limit", or "-1").
static bool ReadCgroupMemoryLimit(const char *path, uint64_t *limit) {
  FILE *fp = fopen(path, "r");
  if (!fp) {
    return false;
  }
  char token[64] = {0};
  const int matched = fscanf(fp, "%63s", token);
  fclose(fp);
  if (matched != 1 || token[0] < '0' || token[0] > '9') {
    return false;
  }
  char *end = NULL;
  const unsigned long long value = strtoull(token, &end, 10);
  if (*end != '\0' || value == 0) {
    return false;
  }
  *limit = static_cast<uint64_t>(value);
  return true;
}

//! Fill `*vsz` / `*rss` (bytes) from the first two fields of
//! /proc/self/statm, which are page counts. Returns false (outputs untouched)
//! if the file cannot be read or parsed.
bool MemoryHelper::SelfUsage(size_t *vsz, size_t *rss) {
  FILE *fp = fopen("/proc/self/statm", "r");
  if (!fp) {
    return false;
  }
  size_t vsz_pages = 0;
  size_t rss_pages = 0;
  // fscanf reports how many fields it converted; anything short of two
  // (not only EOF) would leave the results unset.
  const int matched = fscanf(fp, "%zu %zu", &vsz_pages, &rss_pages);
  fclose(fp);
  if (matched != 2) {
    return false;
  }
  const long pagesz = sysconf(_SC_PAGESIZE);
  if (pagesz <= 0) {
    return false;
  }
  *vsz = vsz_pages * static_cast<size_t>(pagesz);
  *rss = rss_pages * static_cast<size_t>(pagesz);
  return true;
}

//! Resident set size of this process in bytes (second field of
//! /proc/self/statm), or 0 on failure.
size_t MemoryHelper::SelfRSS(void) {
  FILE *fp = fopen("/proc/self/statm", "r");
  if (!fp) {
    return 0;
  }
  size_t rss_pages = 0;
  // Skip the first field, read the second.
  const int matched = fscanf(fp, "%*s %zu", &rss_pages);
  fclose(fp);
  if (matched != 1) {
    return 0;
  }
  const long pagesz = sysconf(_SC_PAGESIZE);
  return (pagesz > 0) ? rss_pages * static_cast<size_t>(pagesz) : 0;
}

//! Peak resident set size in bytes, or 0 if getrusage() fails.
size_t MemoryHelper::SelfPeakRSS(void) {
  struct rusage rusage;
  if (getrusage(RUSAGE_SELF, &rusage) != 0) {
    return 0;
  }
  // ru_maxrss is in kilobytes on Linux.
  return static_cast<size_t>(rusage.ru_maxrss) * 1024;
}

//! Physical RAM in bytes (page count * page size), or 0 if sysconf fails.
size_t MemoryHelper::TotalRamSize(void) {
  const long pages = sysconf(_SC_PHYS_PAGES);
  const long pagesz = sysconf(_SC_PAGESIZE);
  if (pages <= 0 || pagesz <= 0) {
    return 0;
  }
  // Multiply in size_t: the signed `long * long` product could overflow.
  return static_cast<size_t>(pages) * static_cast<size_t>(pagesz);
}

//! Available RAM in bytes from /proc/meminfo: "MemAvailable" when present,
//! otherwise MemFree + Buffers + Cached. Returns 0 if the file is unreadable.
size_t MemoryHelper::AvailableRamSize(void) {
  FILE *fp = fopen("/proc/meminfo", "r");
  if (!fp) {
    return 0;
  }
  size_t avail = 0;
  char buf[128];
  while (fgets(buf, sizeof(buf), fp)) {
    if (strncmp(buf, "MemAvailable:", 13) == 0) {
      avail = MeminfoValue(buf);
      break;
    }
  }
  // No 'MemAvailable' line: approximate it from the reclaimable counters.
  if (avail == 0) {
    fseek(fp, 0L, SEEK_SET);
    size_t count = 0;
    while (count < 3 && fgets(buf, sizeof(buf), fp)) {
      if (strncmp(buf, "MemFree:", 8) == 0 ||
          strncmp(buf, "Buffers:", 8) == 0 ||
          strncmp(buf, "Cached:", 7) == 0) {
        avail += MeminfoValue(buf);
        ++count;
      }
    }
  }
  fclose(fp);
  return (avail * 1024);
}

//! Used RAM in bytes computed from /proc/meminfo as
//! MemTotal - (MemFree + Buffers + Cached + Slab). Returns 0 if the file is
//! unreadable.
size_t MemoryHelper::UsedRamSize(void) {
  FILE *fp = fopen("/proc/meminfo", "r");
  if (!fp) {
    return 0;
  }
  size_t total = 0, avail = 0, count = 0;
  char buf[128];
  // Stop once all five fields have been seen.
  while (count < 5 && fgets(buf, sizeof(buf), fp)) {
    if (strncmp(buf, "MemTotal:", 9) == 0) {
      total = MeminfoValue(buf);
      ++count;
    } else if (strncmp(buf, "MemFree:", 8) == 0 ||
               strncmp(buf, "Buffers:", 8) == 0 ||
               strncmp(buf, "Cached:", 7) == 0 ||
               strncmp(buf, "Slab:", 5) == 0) {
      avail += MeminfoValue(buf);
      ++count;
    }
  }
  fclose(fp);

  if (total == 0) {
    total = TotalRamSize() / 1024;
  }
  // Never wrap around if the subtracted counters exceed the total.
  return (total > avail ? total - avail : 0) * 1024;
}

//! Total RAM usable inside the current container: the cgroup memory limit
//! (cgroup v1 memory.limit_in_bytes, else cgroup v2 memory.max) clamped to
//! physical RAM; falls back to TotalRamSize() when no numeric limit is found.
size_t MemoryHelper::ContainerAwareTotalRamSize(void) {
  size_t total_ram_size = TotalRamSize();
  uint64_t limit = 0;
  // Refer to:
  // https://access.redhat.com/documentation/zh-cn/red_hat_enterprise_linux/7/html/resource_management_guide/sec-memory
  if (ReadCgroupMemoryLimit("/sys/fs/cgroup/memory/memory.limit_in_bytes",
                            &limit) ||
      ReadCgroupMemoryLimit("/sys/fs/cgroup/memory.max", &limit)) {
    // An unlimited cgroup v1 group reports a huge page-aligned value
    // (0x7FFFFFFFFFFFF000 with 4 KiB pages, different with larger pages).
    // A limit can also exceed the host's RAM. Clamping to physical RAM
    // handles both, independent of page size. Refer to:
    // https://stackoverflow.com/questions/70332396/why-cgroups-file-memory-limit-in-bytes-use-9223372036854771712-as-a-default-valu
    if (limit < total_ram_size) {
      total_ram_size = static_cast<size_t>(limit);
    }
  }
  return total_ram_size;
}

#elif defined(__APPLE__) && defined(__MACH__)

//! Virtual and resident size of this task in bytes via task_info().
bool MemoryHelper::SelfUsage(size_t *vsz, size_t *rss) {
  struct mach_task_basic_info info;
  mach_msg_type_number_t count = MACH_TASK_BASIC_INFO_COUNT;
  if (task_info(mach_task_self(), MACH_TASK_BASIC_INFO, (task_info_t)&info,
                &count) != KERN_SUCCESS) {
    return false;
  }
  *vsz = info.virtual_size;
  *rss = info.resident_size;
  return true;
}

//! Resident size of this task in bytes, or 0 on failure.
size_t MemoryHelper::SelfRSS(void) {
  struct mach_task_basic_info info;
  mach_msg_type_number_t count = MACH_TASK_BASIC_INFO_COUNT;
  if (task_info(mach_task_self(), MACH_TASK_BASIC_INFO, (task_info_t)&info,
                &count) != KERN_SUCCESS) {
    return 0;
  }
  return info.resident_size;
}

//! Peak resident size of this task in bytes, or 0 on failure.
size_t MemoryHelper::SelfPeakRSS(void) {
  struct mach_task_basic_info info;
  mach_msg_type_number_t count = MACH_TASK_BASIC_INFO_COUNT;
  if (task_info(mach_task_self(), MACH_TASK_BASIC_INFO, (task_info_t)&info,
                &count) != KERN_SUCCESS) {
    return 0;
  }
  return info.resident_size_max;
}

//! Physical RAM in bytes from sysctl(HW_MEMSIZE), or 0 on failure.
size_t MemoryHelper::TotalRamSize(void) {
  int mib[2] = {CTL_HW, HW_MEMSIZE};
  uint64_t size = 0;
  size_t len = sizeof(size);
  if (sysctl(mib, 2, &size, &len, 0, 0) != 0) {
    return 0;
  }
  return (size_t)size;
}

//! Free + inactive pages, in bytes, or 0 on failure.
size_t MemoryHelper::AvailableRamSize(void) {
  struct vm_statistics stat;
  mach_msg_type_number_t count = HOST_VM_INFO_COUNT;
  vm_size_t pagesize = 0;
  if (host_page_size(mach_host_self(), &pagesize) != KERN_SUCCESS) {
    return 0;
  }
  if (host_statistics(mach_host_self(), HOST_VM_INFO, (host_info_t)&stat,
                      &count) != KERN_SUCCESS) {
    return 0;
  }
  return ((stat.free_count + stat.inactive_count) * pagesize);
}

//! Active + wired pages, in bytes, or 0 on failure.
size_t MemoryHelper::UsedRamSize(void) {
  struct vm_statistics stat;
  mach_msg_type_number_t count = HOST_VM_INFO_COUNT;
  vm_size_t pagesize = 0;
  if (host_page_size(mach_host_self(), &pagesize) != KERN_SUCCESS) {
    return 0;
  }
  if (host_statistics(mach_host_self(), HOST_VM_INFO, (host_info_t)&stat,
                      &count) != KERN_SUCCESS) {
    return 0;
  }
  return ((stat.active_count + stat.wire_count) * pagesize);
}

//! Not implemented on this platform; always 0.
size_t MemoryHelper::ContainerAwareTotalRamSize(void) {
  return 0u;
}
#elif defined(_WIN64) || defined(_WIN32) static inline int getpagesize(void) { SYSTEM_INFO info; GetSystemInfo(&info); return info.dwPageSize; } bool MemoryHelper::SelfUsage(size_t *vsz, size_t *rss) { PROCESS_MEMORY_COUNTERS info; if (!GetProcessMemoryInfo(GetCurrentProcess(), &info, sizeof(info))) { return false; } *vsz = (size_t)info.PagefileUsage; *rss = (size_t)info.WorkingSetSize; return true; } size_t MemoryHelper::SelfRSS(void) { PROCESS_MEMORY_COUNTERS info; if (!GetProcessMemoryInfo(GetCurrentProcess(), &info, sizeof(info))) { return 0u; } return (size_t)info.WorkingSetSize; } size_t MemoryHelper::SelfPeakRSS(void) { PROCESS_MEMORY_COUNTERS info; GetProcessMemoryInfo(GetCurrentProcess(), &info, sizeof(info)); return (size_t)info.PeakWorkingSetSize; } size_t MemoryHelper::TotalRamSize(void) { MEMORYSTATUSEX status; status.dwLength = sizeof(status); GlobalMemoryStatusEx(&status); return (size_t)status.ullTotalPhys; } size_t MemoryHelper::AvailableRamSize(void) { MEMORYSTATUSEX status; status.dwLength = sizeof(status); GlobalMemoryStatusEx(&status); return (size_t)status.ullAvailPhys; } size_t MemoryHelper::UsedRamSize(void) { MEMORYSTATUSEX status; status.dwLength = sizeof(status); GlobalMemoryStatusEx(&status); return (size_t)(status.ullTotalPhys - status.ullAvailPhys); } size_t MemoryHelper::ContainerAwareTotalRamSize(void) { return 0u; } #else bool MemoryHelper::SelfUsage(size_t *vsz, size_t *rss) { *vsz = 0u; *rss = 0u; return false; } size_t MemoryHelper::SelfRSS(void) { return 0u; } size_t MemoryHelper::SelfPeakRSS(void) { return 0u; } size_t MemoryHelper::TotalRamSize(void) { return 0u; } size_t MemoryHelper::AvailableRamSize(void) { return 0u; } size_t MemoryHelper::UsedRamSize(void) { return 0u; } size_t MemoryHelper::ContainerAwareTotalRamSize(void) { return 0u; } #endif size_t MemoryHelper::PageSize(void) { static size_t page_size = static_cast(getpagesize()); return page_size; } size_t MemoryHelper::HugePageSize(void) { static size_t page_size = 
static_cast(2 * 1024 * 1024); return page_size; } } // namespace ailego } // namespace zvec ================================================ FILE: src/ailego/utility/memory_helper.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include namespace zvec { namespace ailego { /*! Memory Helper */ struct MemoryHelper { //! Retrieve the page size of memory static size_t PageSize(void); //! Retrieve the huge page size of memory static size_t HugePageSize(void); //! Retrieve the VSZ and RSS of self process in bytes static bool SelfUsage(size_t *vsz, size_t *rss); //! Retrieve the RSS of self process in bytes static size_t SelfRSS(void); //! Retrieve the peak RSS of self process in bytes static size_t SelfPeakRSS(void); //! Retrieve the total size of physical memory (RAM) in bytes static size_t TotalRamSize(void); //! Retrieve the available size of physical memory (RAM) in bytes static size_t AvailableRamSize(void); //! Retrieve the used size of physical memory (RAM) in bytes static size_t UsedRamSize(void); //! 
Retrieve the total size of physical memory (RAM) in bytes in container static size_t ContainerAwareTotalRamSize(void); }; } // namespace ailego } // namespace zvec ================================================ FILE: src/ailego/utility/string_helper.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include #include namespace zvec { namespace ailego { bool StringHelper::StartsWith(const std::string &ref, const std::string &prefix) { return (ref.size() >= prefix.size()) && (ref.compare(0, prefix.size(), prefix) == 0); } bool StringHelper::EndsWith(const std::string &ref, const std::string &suffix) { size_t s1 = ref.size(); size_t s2 = suffix.size(); return (s1 >= s2) && (ref.compare(s1 - s2, s2, suffix) == 0); } void StringHelper::LeftTrim(std::string &str) { str.erase(str.begin(), std::find_if(str.begin(), str.end(), [](int ch) { return !std::isspace(ch); })); } void StringHelper::RightTrim(std::string &str) { str.erase(std::find_if(str.rbegin(), str.rend(), [](int ch) { return !std::isspace(ch); }) .base(), str.end()); } void StringHelper::Trim(std::string &str) { StringHelper::RightTrim(str); StringHelper::LeftTrim(str); } std::string StringHelper::CopyLeftTrim(std::string str) { StringHelper::LeftTrim(str); return str; } std::string StringHelper::CopyRightTrim(std::string str) { StringHelper::RightTrim(str); return str; } std::string 
StringHelper::CopyTrim(std::string str) { StringHelper::Trim(str); return str; } #if defined(_MSC_VER) #define strncasecmp _strnicmp #endif bool StringHelper::CompareIgnoreCase(const std::string &a, const std::string &b) { if (a.size() != b.size()) { return false; } return (strncasecmp(a.data(), b.data(), a.size()) == 0); } void StringHelper::Append(std::string *str, const internal::Alphameric &a) { str->reserve(str->size() + a.size()); str->append(a.data(), a.size()); } void StringHelper::Append(std::string *str, const internal::Alphameric &a, const internal::Alphameric &b) { str->reserve(str->size() + a.size() + b.size()); str->append(a.data(), a.size()); str->append(b.data(), b.size()); } void StringHelper::Append(std::string *str, const internal::Alphameric &a, const internal::Alphameric &b, const internal::Alphameric &c) { str->reserve(str->size() + a.size() + b.size() + c.size()); str->append(a.data(), a.size()); str->append(b.data(), b.size()); str->append(c.data(), c.size()); } void StringHelper::Append(std::string *str, const internal::Alphameric &a, const internal::Alphameric &b, const internal::Alphameric &c, const internal::Alphameric &d) { str->reserve(str->size() + a.size() + b.size() + c.size() + d.size()); str->append(a.data(), a.size()); str->append(b.data(), b.size()); str->append(c.data(), c.size()); str->append(d.data(), d.size()); } void StringHelper::AppendViews(std::string *str, std::initializer_list views) { size_t new_size = str->size(); for (auto &v : views) { new_size += v.size(); } str->reserve(new_size); for (auto &v : views) { str->append(v.data(), v.size()); } } } // namespace ailego } // namespace zvec ================================================ FILE: src/ailego/utility/time_helper.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #if defined(_WIN64) || defined(_WIN32) #include #endif namespace zvec { namespace ailego { #if defined(_WIN64) || defined(_WIN32) uint64_t Monotime::NanoSeconds(void) { LARGE_INTEGER stamp, freq; QueryPerformanceFrequency(&freq); QueryPerformanceCounter(&stamp); return (uint64_t)((double)stamp.QuadPart * (1000000000.0 / (double)freq.QuadPart)); } uint64_t Monotime::MicroSeconds(void) { LARGE_INTEGER stamp, freq; QueryPerformanceFrequency(&freq); QueryPerformanceCounter(&stamp); return (stamp.QuadPart * 1000000u / freq.QuadPart); } uint64_t Monotime::MilliSeconds(void) { LARGE_INTEGER stamp, freq; QueryPerformanceFrequency(&freq); QueryPerformanceCounter(&stamp); return (stamp.QuadPart * 1000u / freq.QuadPart); } uint64_t Monotime::Seconds(void) { LARGE_INTEGER stamp, freq; QueryPerformanceFrequency(&freq); QueryPerformanceCounter(&stamp); return (stamp.QuadPart / freq.QuadPart); } // January 1, 1970 (start of Unix epoch) in "ticks" #define UNIX_TIME_START 0x019DB1DED53E8000ull uint64_t Realtime::NanoSeconds(void) { LARGE_INTEGER stamp; FILETIME file; GetSystemTimeAsFileTime(&file); stamp.HighPart = file.dwHighDateTime; stamp.LowPart = file.dwLowDateTime; return (stamp.QuadPart - UNIX_TIME_START) * 100u; } uint64_t Realtime::MicroSeconds(void) { LARGE_INTEGER stamp; FILETIME file; GetSystemTimeAsFileTime(&file); stamp.HighPart = file.dwHighDateTime; stamp.LowPart = file.dwLowDateTime; return (stamp.QuadPart - UNIX_TIME_START) / 10u; } uint64_t Realtime::MilliSeconds(void) { LARGE_INTEGER stamp; FILETIME file; GetSystemTimeAsFileTime(&file); 
stamp.HighPart = file.dwHighDateTime; stamp.LowPart = file.dwLowDateTime; return (stamp.QuadPart - UNIX_TIME_START) / 10000u; } uint64_t Realtime::Seconds(void) { LARGE_INTEGER stamp; FILETIME file; GetSystemTimeAsFileTime(&file); stamp.HighPart = file.dwHighDateTime; stamp.LowPart = file.dwLowDateTime; return (stamp.QuadPart - UNIX_TIME_START) / 10000000u; } size_t Realtime::Localtime(uint64_t stamp, const char *format, char *buf, size_t len) { time_t val = static_cast(stamp); return strftime(buf, len, format, localtime(&val)); } size_t Realtime::Gmtime(uint64_t stamp, const char *format, char *buf, size_t len) { time_t val = static_cast(stamp); return strftime(buf, len, format, gmtime(&val)); } size_t Realtime::Localtime(const char *format, char *buf, size_t len) { time_t now = time(0); return strftime(buf, len, format, localtime(&now)); } size_t Realtime::Gmtime(const char *format, char *buf, size_t len) { time_t now = time(0); return strftime(buf, len, format, gmtime(&now)); } #else uint64_t Monotime::NanoSeconds(void) { struct timespec tspec; clock_gettime(CLOCK_MONOTONIC, &tspec); return (tspec.tv_sec * 1000000000u + tspec.tv_nsec); } uint64_t Monotime::MicroSeconds(void) { struct timespec tspec; clock_gettime(CLOCK_MONOTONIC, &tspec); return (tspec.tv_sec * 1000000u + tspec.tv_nsec / 1000u); } uint64_t Monotime::MilliSeconds(void) { struct timespec tspec; clock_gettime(CLOCK_MONOTONIC, &tspec); return (tspec.tv_sec * 1000u + tspec.tv_nsec / 1000000u); } uint64_t Monotime::Seconds(void) { struct timespec tspec; clock_gettime(CLOCK_MONOTONIC, &tspec); return (tspec.tv_sec); } uint64_t Realtime::NanoSeconds(void) { struct timespec tspec; clock_gettime(CLOCK_REALTIME, &tspec); return (tspec.tv_sec * 1000000000u + tspec.tv_nsec); } uint64_t Realtime::MicroSeconds(void) { struct timespec tspec; clock_gettime(CLOCK_REALTIME, &tspec); return (tspec.tv_sec * 1000000u + tspec.tv_nsec / 1000u); } uint64_t Realtime::MilliSeconds(void) { struct timespec tspec; 
clock_gettime(CLOCK_REALTIME, &tspec); return (tspec.tv_sec * 1000u + tspec.tv_nsec / 1000000u); } uint64_t Realtime::Seconds(void) { struct timespec tspec; clock_gettime(CLOCK_REALTIME, &tspec); return (tspec.tv_sec); } size_t Realtime::Localtime(uint64_t stamp, const char *format, char *buf, size_t len) { struct tm tmbuf; time_t val = static_cast(stamp); return strftime(buf, len, format, localtime_r(&val, &tmbuf)); } size_t Realtime::Gmtime(uint64_t stamp, const char *format, char *buf, size_t len) { struct tm tmbuf; time_t val = static_cast(stamp); return strftime(buf, len, format, gmtime_r(&val, &tmbuf)); } size_t Realtime::Localtime(const char *format, char *buf, size_t len) { struct tm tmbuf; time_t now = time(0); return strftime(buf, len, format, localtime_r(&now, &tmbuf)); } size_t Realtime::Gmtime(const char *format, char *buf, size_t len) { struct tm tmbuf; time_t now = time(0); return strftime(buf, len, format, gmtime_r(&now, &tmbuf)); } uint64_t CPUtime::NanoSeconds(void) { struct timespec tspec; clock_gettime(CLOCK_THREAD_CPUTIME_ID, &tspec); return (tspec.tv_sec * 1000000000u + tspec.tv_nsec); } uint64_t CPUtime::MicroSeconds(void) { struct timespec tspec; clock_gettime(CLOCK_THREAD_CPUTIME_ID, &tspec); return (tspec.tv_sec * 1000000u + tspec.tv_nsec / 1000u); } uint64_t CPUtime::MilliSeconds(void) { struct timespec tspec; clock_gettime(CLOCK_THREAD_CPUTIME_ID, &tspec); return (tspec.tv_sec * 1000u + tspec.tv_nsec / 1000000u); } uint64_t CPUtime::Seconds(void) { struct timespec tspec; clock_gettime(CLOCK_THREAD_CPUTIME_ID, &tspec); return (tspec.tv_sec); } #endif // _WIN64 || _WIN32 } // namespace ailego } // namespace zvec ================================================ FILE: src/ailego/version.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "version.h"
#include "version.i"

// Version string: the `ailego_VERSION` macro when defined (presumably supplied
// by the build -- confirm), otherwise "unknown". It is concatenated with other
// string literals below, so it must expand to a string literal.
#ifdef ailego_VERSION
#define AILEGO_VERSION_STRING ailego_VERSION
#else
#define AILEGO_VERSION_STRING "unknown"
#endif

namespace zvec {
namespace ailego {

// Full banner: the library version and copyright line, followed by the
// build-environment summary produced by AILEGO_VERSION_COMPILE_DETAILS
// (defined in version.i).
static const char AILEGO_VERSION_DETAILS[] = AILEGO_VERSION_COMPILE_DETAILS(
    "AiLego Library Version " AILEGO_VERSION_STRING
    ".\nCopyright (C) The Software Authors. All rights reserved.\n");

// Returns AILEGO_VERSION_STRING (a string literal, so the pointer never dangles).
const char *Version::String(void) {
  return AILEGO_VERSION_STRING;
}

// Returns the full AILEGO_VERSION_DETAILS banner (static storage).
const char *Version::Details(void) {
  return AILEGO_VERSION_DETAILS;
}

}  // namespace ailego
}  // namespace zvec

// Disabled: an entry point that would print the version banner and exit.
// extern "C" int __wrap_main(int, char *[]) {
//   fwrite(ailego::AILEGO_VERSION_DETAILS, 1,
//          strlen(ailego::AILEGO_VERSION_DETAILS), stdout);
//   fflush(stdout);
//   _Exit(0);
// }

================================================
FILE: src/ailego/version.h
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

namespace zvec {
namespace ailego {

/*! AiLego Version
 */
struct Version {
  //! Retrieve the version number in string (e.g. the configured library
  //! version, or "unknown")
  static const char *String(void);

  //! Retrieve the detailed version information: version, copyright and the
  //! compiler/platform/feature summary captured at build time
  static const char *Details(void);
};

}  // namespace ailego
}  // namespace zvec

================================================
FILE: src/ailego/version.i
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// NOTE(review): the #include target below appears stripped in this extraction.
#include

// Two-level stringification: the outer macro lets its argument be
// macro-expanded before the inner one applies `#`, so
// AILEGO_VERSION_TO_STRING(__GLIBC__) yields e.g. "2" rather than "__GLIBC__".
#ifndef AILEGO_VERSION_TO_STRING_
#define AILEGO_VERSION_TO_STRING_(x) #x
#endif

#ifndef AILEGO_VERSION_TO_STRING
#define AILEGO_VERSION_TO_STRING(x) AILEGO_VERSION_TO_STRING_(x)
#endif

/*!
http://nadeausoftware.com/articles/2012/01/ * c_c_tip_how_use_compiler_predefined_macros_detect_operating_system */ #if defined(__linux) || defined(__linux__) #define AILEGO_VERSION_PLATFORM "Linux" #elif defined(__FreeBSD__) #define AILEGO_VERSION_PLATFORM "FreeBSD" #elif defined(__NetBSD__) #define AILEGO_VERSION_PLATFORM "NetBSD" #elif defined(__OpenBSD__) #define AILEGO_VERSION_PLATFORM "OpenBSD" #elif defined(__APPLE__) || defined(__MACH__) #define AILEGO_VERSION_PLATFORM "Darwin" #elif defined(__CYGWIN__) && !defined(_WIN32) #define AILEGO_VERSION_PLATFORM "Cygwin" #elif defined(_WIN64) #define AILEGO_VERSION_PLATFORM "Microsoft Windows (64-bit)" #elif defined(_WIN32) #define AILEGO_VERSION_PLATFORM "Microsoft Windows (32-bit)" #elif defined(__sun) && defined(__SVR4) #define AILEGO_VERSION_PLATFORM "Solaris" #elif defined(_AIX) #define AILEGO_VERSION_PLATFORM "AIX" #elif defined(__hpux) #define AILEGO_VERSION_PLATFORM "HP-UX" #elif defined(__unix) || defined(__unix__) #define AILEGO_VERSION_PLATFORM "Unix" #else #define AILEGO_VERSION_PLATFORM "Unknown Platform" #endif /*! 
http://nadeausoftware.com/articles/2012/10/ * c_c_tip_how_detect_compiler_name_and_version_using_compiler_predefined_macros */ #if defined(__NVCC__) #define AILEGO_VERSION_COMPILER_NAME "Nvidia CUDA Compiler" #elif defined(__clang__) #define AILEGO_VERSION_COMPILER_NAME "Clang/LLVM" #elif defined(__ICC) || defined(__INTEL_COMPILER) #define AILEGO_VERSION_COMPILER_NAME "Intel ICC/ICPC" #elif defined(__GNUC__) || defined(__GNUG__) #define AILEGO_VERSION_COMPILER_NAME "GNU GCC/G++" #elif defined(__HP_cc) || defined(__HP_aCC) #define AILEGO_VERSION_COMPILER_NAME "Hewlett-Packard C/aC++" #elif defined(__IBMC__) || defined(__IBMCPP__) #define AILEGO_VERSION_COMPILER_NAME "IBM XL C/C++" #elif defined(_MSC_VER) #define AILEGO_VERSION_COMPILER_NAME "Microsoft Visual C++" #elif defined(__PGI) #define AILEGO_VERSION_COMPILER_NAME "Portland Group PGCC/PGCPP" #elif defined(__SUNPRO_C) || defined(__SUNPRO_CC) #define AILEGO_VERSION_COMPILER_NAME "Oracle Solaris Studio" #else #define AILEGO_VERSION_COMPILER_NAME "Unknown Compiler" #endif #if defined(__CUDACC_VER_MAJOR__) #define AILEGO_VERSION_COMPILER \ AILEGO_VERSION_COMPILER_NAME \ " (" AILEGO_VERSION_TO_STRING(__CUDACC_VER_MAJOR__) \ "." AILEGO_VERSION_TO_STRING(__CUDACC_VER_MINOR__) \ "." AILEGO_VERSION_TO_STRING(__CUDACC_VER_BUILD__) ")" #elif defined(__VERSION__) #define AILEGO_VERSION_COMPILER \ AILEGO_VERSION_COMPILER_NAME " (" __VERSION__ ")" #elif defined(_MSC_FULL_VER) #define AILEGO_VERSION_COMPILER \ AILEGO_VERSION_COMPILER_NAME " (" AILEGO_VERSION_TO_STRING(_MSC_FULL_VER) ")" #elif defined(_MSC_VER) #define AILEGO_VERSION_COMPILER \ AILEGO_VERSION_COMPILER_NAME " (" AILEGO_VERSION_TO_STRING(_MSC_VER) ")" #elif defined(__PGIC__) #define AILEGO_VERSION_COMPILER \ AILEGO_VERSION_COMPILER_NAME \ " (" AILEGO_VERSION_TO_STRING(__PGIC__) "." AILEGO_VERSION_TO_STRING( \ __PGIC_MINOR__) "." 
AILEGO_VERSION_TO_STRING(__PGIC_PATCHLEVEL__) ")" #elif defined(__xlc__) #define AILEGO_VERSION_COMPILER AILEGO_VERSION_COMPILER_NAME " (" __xlc__ ")" #elif defined(__SUNPRO_C) #define AILEGO_VERSION_COMPILER \ AILEGO_VERSION_COMPILER_NAME " (" AILEGO_VERSION_TO_STRING(__SUNPRO_C) ")" #elif defined(__HP_cc) #define AILEGO_VERSION_COMPILER \ AILEGO_VERSION_COMPILER_NAME " (" AILEGO_VERSION_TO_STRING(__HP_cc) ")" #else #define AILEGO_VERSION_COMPILER AILEGO_VERSION_COMPILER_NAME #endif #if defined(__x86_64__) || defined(_M_X64) #define AILEGO_VERSION_PROCESSOR "x86 64-bit Processor" #elif defined(__i386) || defined(_M_IX86) #define AILEGO_VERSION_PROCESSOR "x86 32-bit Processor" #elif defined(__ARM_ARCH) #if defined(__ARM_64BIT_STATE) #define AILEGO_VERSION_PROCESSOR "ARM 64-bit Processor" #else #define AILEGO_VERSION_PROCESSOR "ARM 32-bit Processor" #endif #elif defined(__ia64) || defined(__itanium__) || defined(_M_IA64) #define AILEGO_VERSION_PROCESSOR "Itanium Processor" #elif defined(__powerpc64__) || defined(__ppc64__) || defined(__PPC64__) #define AILEGO_VERSION_PROCESSOR "PowerPC 64-bit Processor" #elif defined(__powerpc__) || defined(__ppc__) || defined(__PPC__) #define AILEGO_VERSION_PROCESSOR "PowerPC 32-bit Processor" #elif defined(__sparc) #define AILEGO_VERSION_PROCESSOR "SPARC Processor" #else #define AILEGO_VERSION_PROCESSOR "Unknown Processor" #endif #if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ #define AILEGO_VERSION_BYTE_ORDER " Little-endian Byte Order\n" #elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ #define AILEGO_VERSION_BYTE_ORDER " Big-endian Byte Order\n" #elif __BYTE_ORDER__ == __ORDER_PDP_ENDIAN__ #define AILEGO_VERSION_BYTE_ORDER " PDP-endian Byte Order\n" #else #define AILEGO_VERSION_BYTE_ORDER "" #endif #if defined(_DEBUG) || (!defined(__OPTIMIZE__) && !defined(NDEBUG)) #define AILEGO_VERSION_DEBUG_INFO " Debug Information\n" #else #define AILEGO_VERSION_DEBUG_INFO "" #endif #if defined(__SANITIZE_ADDRESS__) #define AILEGO_VERSION_ASAN " 
Address Sanitizer\n" #else #define AILEGO_VERSION_ASAN "" #endif #if defined(__STDC_VERSION__) #define AILEGO_VERSION_STDC \ " C Standard " AILEGO_VERSION_TO_STRING(__STDC_VERSION__) "\n" #else #define AILEGO_VERSION_STDC "" #endif #if defined(__cplusplus) #define AILEGO_VERSION_CPLUSPLUS \ " C++ Standard " AILEGO_VERSION_TO_STRING(__cplusplus) "\n" #else #define AILEGO_VERSION_CPLUSPLUS "" #endif #if defined(__GXX_ABI_VERSION) #define AILEGO_VERSION_GXX_ABI \ " GNU C++ ABI " AILEGO_VERSION_TO_STRING(__GXX_ABI_VERSION) "\n" #else #define AILEGO_VERSION_GXX_ABI "" #endif #if defined(__GLIBC__) #define AILEGO_VERSION_GLIBC \ " GNU glibc " AILEGO_VERSION_TO_STRING( \ __GLIBC__) "." AILEGO_VERSION_TO_STRING(__GLIBC_MINOR__) "\n" #else #define AILEGO_VERSION_GLIBC "" #endif #if defined(WINVER) #define AILEGO_VERSION_WINSDK \ " Microsoft Windows SDK " AILEGO_VERSION_TO_STRING(WINVER) "\n" #else #define AILEGO_VERSION_WINSDK "" #endif #if defined(__CLR_VER) #define AILEGO_VERSION_CLR \ " Microsoft CLR " AILEGO_VERSION_TO_STRING(__CLR_VER) "\n" #else #define AILEGO_VERSION_CLR "" #endif #if defined(__LSB_VERSION__) #define AILEGO_VERSION_LSB \ " Linux Standards Base " AILEGO_VERSION_TO_STRING(__LSB_VERSION__) "\n" #else #define AILEGO_VERSION_LSB "" #endif #if defined(_POSIX_VERSION) #define AILEGO_VERSION_POSIX \ " POSIX Specification " AILEGO_VERSION_TO_STRING(_POSIX_VERSION) "\n" #else #define AILEGO_VERSION_POSIX "" #endif #if defined(_XOPEN_VERSION) #define AILEGO_VERSION_XOPEN \ " X/Open Specification " AILEGO_VERSION_TO_STRING(_XOPEN_VERSION) "\n" #else #define AILEGO_VERSION_XOPEN "" #endif #if defined(_OPENMP) #define AILEGO_VERSION_OPENMP \ " OpenMP API " AILEGO_VERSION_TO_STRING(_OPENMP) "\n" #else #define AILEGO_VERSION_OPENMP "" #endif #if defined(__ARM_NEON) #define AILEGO_VERSION_SIMD " Arm Neon Instruction Set\n" #elif defined(__AVX512FP16__) #define AILEGO_VERSION_SIMD " AVX-512FP16 Instruction Set\n" #elif defined(__AVX512F__) #define AILEGO_VERSION_SIMD 
" AVX-512F Instruction Set\n" #elif defined(__AVX2__) #define AILEGO_VERSION_SIMD " AVX-2 Instruction Set\n" #elif defined(__AVX__) #define AILEGO_VERSION_SIMD " AVX Instruction Set\n" #elif defined(__SSE4_2__) #define AILEGO_VERSION_SIMD " SSE-4.2 Instruction Set\n" #elif defined(__SSE4_1__) #define AILEGO_VERSION_SIMD " SSE-4.1 Instruction Set\n" #elif defined(__SSSE3__) #define AILEGO_VERSION_SIMD " SSSE-3 Instruction Set\n" #elif defined(__SSE3__) #define AILEGO_VERSION_SIMD " SSE-3 Instruction Set\n" #elif defined(__SSE2__) #define AILEGO_VERSION_SIMD " SSE-2 Instruction Set\n" #elif defined(__SSE__) #define AILEGO_VERSION_SIMD " SSE Instruction Set\n" #elif defined(__MMX__) #define AILEGO_VERSION_SIMD " MMX Instruction Set\n" #else #define AILEGO_VERSION_SIMD "" #endif #if defined(PY_VERSION) #if PY_RELEASE_LEVEL == PY_RELEASE_LEVEL_ALPHA #define AILEGO_VERSION_PYTHON \ " Python API " PY_VERSION \ " Alpha " AILEGO_VERSION_TO_STRING(PY_RELEASE_SERIAL) "\n" #elif PY_RELEASE_LEVEL == PY_RELEASE_LEVEL_BETA #define AILEGO_VERSION_PYTHON \ " Python API " PY_VERSION \ " Beta " AILEGO_VERSION_TO_STRING(PY_RELEASE_SERIAL) "\n" #elif PY_RELEASE_LEVEL == PY_RELEASE_LEVEL_GAMMA #define AILEGO_VERSION_PYTHON \ " Python API " PY_VERSION \ " Release Candidate " AILEGO_VERSION_TO_STRING(PY_RELEASE_SERIAL) "\n" #elif PY_RELEASE_LEVEL == PY_RELEASE_LEVEL_FINAL #define AILEGO_VERSION_PYTHON " Python API " PY_VERSION " Final\n" #else #define AILEGO_VERSION_PYTHON " Python API " PY_VERSION "\n" #endif #else #define AILEGO_VERSION_PYTHON "" #endif //! 
Gather information of compiling
// Expands to a single string literal: the caller-supplied prefix, then
// "Compiled by/for/on" lines built from AILEGO_VERSION_COMPILER,
// AILEGO_VERSION_PROCESSOR, AILEGO_VERSION_PLATFORM, __DATE__ and __TIME__,
// then a "Compiled with:" list concatenated from the AILEGO_VERSION_* feature
// macros defined earlier in this file (each is either "" or one text line,
// so undetected features simply vanish). No comments inside the macro body:
// a `//` before a trailing backslash would swallow the continuation.
#define AILEGO_VERSION_COMPILE_DETAILS(__PREFIX_INFO__)                        \
  __PREFIX_INFO__                                                              \
  "Compiled by " AILEGO_VERSION_COMPILER                                       \
  ".\n"                                                                        \
  "Compiled for " AILEGO_VERSION_PROCESSOR                                     \
  ".\n"                                                                        \
  "Compiled on " AILEGO_VERSION_PLATFORM " on " __DATE__ " " __TIME__          \
  ".\n"                                                                        \
  "Compiled with: \n"                                                          \
  "" AILEGO_VERSION_BYTE_ORDER "" AILEGO_VERSION_SIMD                          \
  "" AILEGO_VERSION_DEBUG_INFO "" AILEGO_VERSION_ASAN "" AILEGO_VERSION_STDC   \
  "" AILEGO_VERSION_CPLUSPLUS "" AILEGO_VERSION_GXX_ABI                        \
  "" AILEGO_VERSION_POSIX "" AILEGO_VERSION_XOPEN "" AILEGO_VERSION_LSB        \
  "" AILEGO_VERSION_GLIBC "" AILEGO_VERSION_WINSDK "" AILEGO_VERSION_CLR       \
  "" AILEGO_VERSION_OPENMP "" AILEGO_VERSION_PYTHON "\n"

================================================
FILE: src/binding/CMakeLists.txt
================================================
include(${PROJECT_ROOT_DIR}/cmake/bazel.cmake)
include(${PROJECT_ROOT_DIR}/cmake/option.cmake)

# Retrieve version from git repository (stored in ZVEC_VERSION)
git_version(ZVEC_VERSION ${CMAKE_CURRENT_SOURCE_DIR})

# Add repository: presumably adds the python/ subdirectory to the build
# (cc_directory comes from cmake/bazel.cmake -- confirm there)
cc_directory(python)

================================================
FILE: src/binding/python/CMakeLists.txt
================================================
include(${PROJECT_ROOT_DIR}/cmake/bazel.cmake)
include(${PROJECT_ROOT_DIR}/cmake/option.cmake)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

find_package(pybind11 REQUIRED)

# Sources of the `_zvec` extension module.
set(SRC_LISTS
  binding.cc
  model/python_collection.cc
  model/python_doc.cc
  model/param/python_param.cc
  model/schema/python_schema.cc
  model/common/python_config.cc
  typing/python_type.cc
)

# Builds the Python extension module `_zvec` (its entry point is
# PYBIND11_MODULE(_zvec, m) in binding.cc).
pybind11_add_module(_zvec ${SRC_LISTS})

# Link the listed static libraries in full (--whole-archive / -force_load
# pull in every object file, not only those referenced), then limit the
# exported symbols. NOTE(review): the `$` entries are generator expressions
# whose `<...>` contents appear stripped in this extraction.
if (CMAKE_SYSTEM_NAME STREQUAL "Linux")
  target_link_libraries(_zvec PRIVATE
    -Wl,--whole-archive
    $
    $
    $
    $
    $
    $
    $
    $
    $
    $
    $
    -Wl,--no-whole-archive
    zvec_db
  )
  # Symbol visibility is controlled by exports.map (presumably exporting
  # only the module init function, as exports.mac does on macOS).
  target_link_options(_zvec PRIVATE
    "LINKER:--version-script=${CMAKE_CURRENT_SOURCE_DIR}/exports.map"
  )
elseif (APPLE)
  target_link_libraries(_zvec PRIVATE
    -Wl,-force_load,$
    -Wl,-force_load,$
    -Wl,-force_load,$
    -Wl,-force_load,$
    -Wl,-force_load,$
    -Wl,-force_load,$
    -Wl,-force_load,$
    -Wl,-force_load,$
    -Wl,-force_load,$
    -Wl,-force_load,$
    -Wl,-force_load,$
    zvec_db
  )
  # exports.mac lists the only symbol to export (_PyInit__zvec).
  target_link_libraries(_zvec PRIVATE
    -Wl,-exported_symbols_list,${CMAKE_CURRENT_SOURCE_DIR}/exports.mac
  )
endif ()

target_include_directories(_zvec PRIVATE
  ${PYBIND11_INCLUDE_DIR}
  ${PROJECT_ROOT_DIR}/src
  ${PROJECT_ROOT_DIR}/src/binding/python/include)

================================================
FILE: src/binding/python/binding.cc
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "python_collection.h"
#include "python_config.h"
#include "python_doc.h"
#include "python_param.h"
#include "python_schema.h"
#include "python_type.h"

namespace zvec {

// Entry point of the `_zvec` Python extension module: sets the module
// docstring, then lets each ZVecPy* helper register its bindings on `m`.
// NOTE(review): the registration order presumably matters — later bindings
// use enum types (e.g. MetricType / QuantizeType as default arguments in the
// param bindings) that ZVecPyTyping registers first. Confirm before
// reordering these calls.
PYBIND11_MODULE(_zvec, m) {
  m.doc() = "Zvec core module";
  ZVecPyTyping::Initialize(m);
  ZVecPyParams::Initialize(m);
  ZVecPySchemas::Initialize(m);
  ZVecPyConfig::Initialize(m);
  ZVecPyDoc::Initialize(m);
  ZVecPyCollection::Initialize(m);
}

}  // namespace zvec
================================================
FILE: src/binding/python/exports.mac
================================================
_PyInit__zvec
================================================
FILE: src/binding/python/include/python_collection.h
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // limitations under the License.#pragma once #include #include namespace py = pybind11; namespace zvec { class ZVecPyCollection { public: ZVecPyCollection() = delete; public: static void Initialize(py::module_ &m); private: static void bind_db_methods(py::class_ &col); static void bind_ddl_methods(py::class_ &col); static void bind_dml_methods(py::class_ &col); static void bind_dql_methods(py::class_ &col); }; } // namespace zvec ================================================ FILE: src/binding/python/include/python_config.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License.#pragma once #include #include namespace py = pybind11; namespace zvec { class ZVecPyConfig { public: ZVecPyConfig() = delete; public: static void Initialize(py::module_ &m); }; } // namespace zvec ================================================ FILE: src/binding/python/include/python_doc.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License.#pragma once #include #include namespace py = pybind11; namespace zvec { class ZVecPyDoc { public: ZVecPyDoc() = delete; public: static void Initialize(py::module_ &m); private: static void bind_doc_operator(py::module_ &m); static void bind_doc(py::module_ &m); }; } // namespace zvec ================================================ FILE: src/binding/python/include/python_param.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and // limitations under the License.#pragma once #include #include #include namespace py = pybind11; namespace zvec { class ZVecPyParams { public: ZVecPyParams() = delete; public: static void Initialize(py::module_ &m); private: static void bind_index_params(py::module_ &m); static void bind_query_params(py::module_ &m); static void bind_options(py::module_ &m); static void bind_vector_query(py::module_ &m); }; } // namespace zvec ================================================ FILE: src/binding/python/include/python_schema.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License.#pragma once #include #include namespace py = pybind11; namespace zvec { class ZVecPySchemas { public: ZVecPySchemas() = delete; public: static void Initialize(py::module_ &m); private: static void bind_field_schema(py::module_ &m); static void bind_collection_schema(py::module_ &m); static void bind_collection_stats(py::module_ &m); }; } // namespace zvec ================================================ FILE: src/binding/python/include/python_type.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License.#pragma once #include #include #include namespace py = pybind11; namespace zvec { class ZVecPyTyping { public: ZVecPyTyping() = delete; public: static void Initialize(py::module_ &m); private: static void bind_datatypes(py::module_ &m); static void bind_index_types(py::module_ &m); static void bind_metric_types(py::module_ &m); static void bind_quantize_types(py::module_ &m); static void bind_status(py::module_ &m); }; } // namespace zvec ================================================ FILE: src/binding/python/model/common/python_config.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "python_config.h"
#include

namespace zvec {

// Returns true when the Python dict `d` contains `key`.
inline bool has_key(py::dict d, const std::string &key) {
  return d.contains(key);
}

// Looks up `key` in `d` and converts the value with py::object::cast.
// Returns std::nullopt when the key is absent. Raises a Python TypeError
// when the stored value cannot be converted.
template std::optional get_if(py::dict d, const std::string &key) {
  if (has_key(d, key)) {
    try {
      py::object obj = d[py::str(key)];
      return obj.cast();
    } catch (const py::cast_error &) {
      throw py::type_error("Key '" + key + "' is not of expected type.");
    }
  }
  return std::nullopt;
}

// Lower-cases `s` byte by byte with ::tolower. Each byte is widened through
// `unsigned char` first: passing a negative `char` (any non-ASCII byte where
// char is signed) straight to tolower() is undefined behavior.
inline std::string to_lower(const std::string &s) {
  std::string lower;
  lower.reserve(s.size());
  std::transform(
      s.begin(), s.end(), std::back_inserter(lower),
      [](unsigned char c) { return static_cast<char>(::tolower(c)); });
  return lower;
}

// Case-insensitive string equality, implemented via to_lower().
inline bool iequals(const std::string &a, const std::string &b) {
  return to_lower(a) == to_lower(b);
}

// Parses a case-insensitive log-level name: "debug", "info",
// "warn"/"warning", "error" or "fatal". Any other value raises a Python
// ValueError that names the rejected input.
GlobalConfig::LogLevel str_to_loglevel(const std::string &s) {
  if (iequals(s, "debug")) return GlobalConfig::LogLevel::DEBUG;
  if (iequals(s, "info")) return GlobalConfig::LogLevel::INFO;
  if (iequals(s, "warn") || iequals(s, "warning"))
    return GlobalConfig::LogLevel::WARN;
  if (iequals(s, "error")) return GlobalConfig::LogLevel::ERROR;
  if (iequals(s, "fatal")) return GlobalConfig::LogLevel::FATAL;
  throw py::value_error("Invalid log level: " + s);
}

// Registers the module-level Python function `Initialize(*args, **kwargs)`.
//
// Settings are gathered from positional dicts, then from keyword arguments.
// Keywords are merged last, so they win on duplicate keys. A non-dict
// positional argument raises TypeError. Recognized keys:
//   memory_limit_mb (> 0)
//   log_type ("console" | "file", default "console")
//   log_level (default "warn")
//   log_dir, log_basename, log_file_size (> 0), log_overdue_days (> 0)
//     -- these four are only read when log_type is "file"
//   query_threads (> 0), optimize_threads (> 0)
//   invert_to_forward_scan_ratio, brute_force_by_keys_ratio (in [0.0, 1.0])
// A call with no settings at all is a no-op. Out-of-range values raise
// ValueError. A failing GlobalConfig::Initialize raises RuntimeError
// carrying the status message.
void ZVecPyConfig::Initialize(pybind11::module_ &m) {
  m.def("Initialize", [](py::args args, py::kwargs kwargs) -> py::none {
    py::dict config_dict;

    // parse args: every positional argument must be a dict
    for (auto &arg : args) {
      if (py::isinstance(arg)) {
        for (auto item : arg.cast()) {
          config_dict[item.first] = item.second;
        }
      } else {
        throw py::type_error("Positional argument must be a dict if provided");
      }
    }

    // parse kwargs (applied after args, so they override duplicate keys)
    if (kwargs) {
      for (auto item : kwargs) {
        config_dict[item.first] = item.second;
      }
    }

    if (config_dict.empty()) {
      return py::none();
    }

    GlobalConfig::ConfigData data;

    // config memory_limit_mb
    if (has_key(config_dict, "memory_limit_mb")) {
      auto mb = get_if(config_dict, "memory_limit_mb").value();
      if (mb <= 0) throw py::value_error("memory_limit_mb must be positive");
      data.memory_limit_bytes = static_cast(mb) * 1024 * 1024;
    }

    // config log: either key triggers log setup; defaults fill the other one
    bool has_log_type = has_key(config_dict, "log_type");
    bool has_log_level = has_key(config_dict, "log_level");
    if (has_log_type || has_log_level) {
      std::string log_type = "console";
      std::string log_level_str = "warn";
      if (has_log_type) {
        log_type = config_dict["log_type"].cast();
      }
      if (has_log_level) {
        log_level_str = config_dict["log_level"].cast();
      }
      auto log_level = str_to_loglevel(log_level_str);

      if (iequals(log_type, "file")) {
        std::string dir = DEFAULT_LOG_DIR;
        std::string basename = DEFAULT_LOG_BASENAME;
        uint32_t file_size = DEFAULT_LOG_FILE_SIZE;
        uint32_t overdue_days = DEFAULT_LOG_OVERDUE_DAYS;
        if (has_key(config_dict, "log_dir")) {
          dir = get_if(config_dict, "log_dir").value();
        }
        if (has_key(config_dict, "log_basename")) {
          basename = get_if(config_dict, "log_basename").value();
        }
        if (has_key(config_dict, "log_file_size")) {
          auto s = get_if(config_dict, "log_file_size").value();
          if (s <= 0) {
            throw py::value_error("log_file_size must be positive");
          }
          file_size = static_cast(s);
        }
        if (has_key(config_dict, "log_overdue_days")) {
          auto d = get_if(config_dict, "log_overdue_days").value();
          if (d <= 0) {
            throw py::value_error("log_overdue_days must be positive");
          }
          overdue_days = static_cast(d);
        }
        data.log_config = std::make_shared(
            log_level, dir, basename, file_size, overdue_days);
      } else if (iequals(log_type, "console")) {
        data.log_config = std::make_shared(log_level);
      } else {
        throw py::value_error("log_type must be 'console' or 'file'");
      }
    }

    // set query thread count
    if (has_key(config_dict, "query_threads")) {
      auto q = get_if(config_dict, "query_threads").value();
      if (q <= 0) throw py::value_error("query_threads must be positive");
      data.query_thread_count = static_cast(q);
    }

    // set optimize thread count
    if (has_key(config_dict, "optimize_threads")) {
      auto o = get_if(config_dict, "optimize_threads").value();
      if (o <= 0) throw py::value_error("optimize_threads must be positive");
      data.optimize_thread_count = static_cast(o);
    }

    // set invert_to_forward_scan_ratio
    if (has_key(config_dict, "invert_to_forward_scan_ratio")) {
      auto v = get_if(config_dict, "invert_to_forward_scan_ratio").value();
      // Written as !(in range) so NaN, which fails every comparison, is
      // rejected too (`v < 0.0 || v > 1.0` would let it through).
      if (!(v >= 0.0 && v <= 1.0)) {
        throw py::value_error(
            "invert_to_forward_scan_ratio must be in [0.0, 1.0]");
      }
      data.invert_to_forward_scan_ratio = static_cast(v);
    }

    // set brute_force_by_keys_ratio
    if (has_key(config_dict, "brute_force_by_keys_ratio")) {
      auto v = get_if(config_dict, "brute_force_by_keys_ratio").value();
      // Same NaN-rejecting range check as above.
      if (!(v >= 0.0 && v <= 1.0)) {
        throw py::value_error(
            "brute_force_by_keys_ratio must be in [0.0, 1.0]");
      }
      data.brute_force_by_keys_ratio = static_cast(v);
    }

    // initialize (GlobalConfig::Initialize also validates `data`)
    Status status = GlobalConfig::Instance().Initialize(data);
    if (!status.ok()) {
      throw std::runtime_error("Initialization failed: " + status.message());
    }
    return py::none();
  });
}

}  // namespace zvec
================================================
FILE: src/binding/python/model/param/python_param.cc
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "python_param.h"
#include
#include
#include
#include
#include "python_doc.h"

namespace zvec {

// Returns the upper-case name of an IndexType as exposed by to_dict() and
// __repr__. Any value without a dedicated case maps to "UNDEFINED".
static std::string index_type_to_string(const IndexType type) {
  switch (type) {
    case IndexType::INVERT:
      return "INVERT";
    case IndexType::FLAT:
      return "FLAT";
    case IndexType::IVF:
      return "IVF";
    case IndexType::HNSW:
      return "HNSW";
    case IndexType::HNSW_RABITQ:
      return "HNSW_RABITQ";
    default:
      return "UNDEFINED";
  }
}

// Returns the upper-case name of a MetricType. Values without a dedicated
// case map to "UNDEFINED".
static std::string metric_type_to_string(const MetricType type) {
  switch (type) {
    case MetricType::COSINE:
      return "COSINE";
    case MetricType::IP:
      return "IP";
    case MetricType::L2:
      return "L2";
    default:
      return "UNDEFINED";
  }
}

// Returns the upper-case name of a QuantizeType. Values without a dedicated
// case map to "UNDEFINED".
static std::string quantize_type_to_string(const QuantizeType type) {
  switch (type) {
    case QuantizeType::UNDEFINED:
      return "UNDEFINED";
    case QuantizeType::INT8:
      return "INT8";
    case QuantizeType::INT4:
      return "INT4";
    case QuantizeType::FP16:
      return "FP16";
    case QuantizeType::RABITQ:
      return "RABITQ";
    default:
      return "UNDEFINED";
  }
}

// Converts `h` with py::cast. On failure raises a Python TypeError reading
// "<vector_field>: expected <expected_type>, got <Python type of h>".
template T checked_cast(const py::handle &h, const std::string &vector_field,
                        const std::string &expected_type) {
  try {
    return py::cast(h);
  } catch (const py::cast_error &) {
    std::string actual_type = std::string(py::str(py::type::of(h)));
    std::string msg = vector_field + ": expected " + expected_type +
                      ", got " + actual_type;
    throw py::type_error(msg);
  }
}

// Copies the raw in-memory bytes of `n` elements of T into a std::string.
template std::string serialize_vector(const T *data, size_t n) {
  std::string buf;
  buf.resize(n * sizeof(T));
  std::memcpy(buf.data(), data, n * sizeof(T));
  return buf;
}

// Serializes a sparse vector given as a Python dict {index: value} into two
// parallel byte buffers: uint32 indices and ValueType values, both in the
// dict's iteration order. Keys go through checked_cast, which reports
// "Sparse indices"/"UINT32" on failure. Each value is converted by
// `value_caster(py_val, i)`. An empty dict yields two empty buffers.
template std::pair serialize_sparse_vector(
    const py::dict &sparse_dict, ValueCastFn &&value_caster) {
  const size_t n = sparse_dict.size();
  if (n == 0) return {{}, {}};

  std::string indices_buf;
  indices_buf.resize(n * sizeof(uint32_t));
  auto *indices_ptr = reinterpret_cast(indices_buf.data());

  std::string values_buf;
  values_buf.resize(n * sizeof(ValueType));
  auto *values_ptr = reinterpret_cast(values_buf.data());

  size_t i = 0;
  for (const auto &[py_key, py_val] : sparse_dict) {
    indices_ptr[i] = checked_cast(py_key, "Sparse indices", "UINT32");
    values_ptr[i] = value_caster(py_val, i);
    ++i;
  }
  return {std::move(indices_buf), std::move(values_buf)};
}

// Creates the `param` submodule under `parent` and registers every
// parameter binding on it.
void ZVecPyParams::Initialize(pybind11::module_ &parent) {
  auto m =
      parent.def_submodule("param", "This module contains the params of Zvec");

  // binding index_params [invert/hnsw/hnsw_rabitq/flat/ivf]
  bind_index_params(m);

  // bind query_params [hnsw/ivf]
  bind_query_params(m);

  // bind options [collection/index/optimize/column]
  bind_options(m);

  // bind vector query
  bind_vector_query(m);
}

// Registers the index-construction parameter classes on `m`:
// IndexParam (base), InvertIndexParam, VectorIndexParam (base),
// HnswIndexParam, HnswRabitqIndexParam, FlatIndexParam and IVFIndexParam.
void ZVecPyParams::bind_index_params(pybind11::module_ &m) {
  // binding base index params
  py::class_> index_params(
      m, "IndexParam", R"pbdoc( Base class for all index parameter configurations. This abstract base class defines the common interface for index types. It should not be instantiated directly; use derived classes instead. Attributes: type (IndexType): The type of the index (e.g., HNSW, FLAT, INVERT). )pbdoc");
  index_params
      .def_property_readonly(
          "type",
          [](const IndexParams &self) -> IndexType { return self.type(); },
          "IndexType: The type of the index.")
      .def("clone", &IndexParams::clone, py::return_value_policy::copy)
      .def(
          "__eq__",
          [](const IndexParams &self, const py::object &other) {
            if (!py::isinstance(other)) return false;
            return self == other.cast();
          },
          py::is_operator())
      .def(
          "to_dict",
          [](const IndexParams &self) -> py::dict {
            py::dict dict;
            dict["type"] = index_type_to_string(self.type());
            return dict;
          },
          "Convert to dictionary with all fields")
      .def(py::pickle(
          [](const IndexParams &self) {  // __getstate__
            return py::make_tuple(self.type());
          },
          [](py::tuple t) {  // __setstate__
            if (t.size() != 1)
              throw std::runtime_error("Invalid state for IndexParams");
            // NOTE(review): IndexParams is abstract, so an empty (null)
            // holder is returned. pybind11 presumably rejects a null result
            // from a pickle factory, so unpickling the bare base fails;
            // confirm this is intended.
            return std::shared_ptr();
          }));

  // binding invert index params
  py::class_>
      invert_params(m, "InvertIndexParam", R"pbdoc( Parameters for configuring an invert index.
This class controls whether range query optimization is enabled for invert index structures. Attributes: type (IndexType): Always `IndexType.INVERT`. enable_range_optimization (bool): Whether range optimization is enabled. enable_extended_wildcard (bool): Whether extended wildcard (suffix and infix) search is enabled. Examples: >>> params = InvertIndexParam(enable_range_optimization=True, enable_extended_wildcard=False) >>> print(params.enable_range_optimization) True >>> print(params.enable_extended_wildcard) False >>> config = params.to_dict() >>> print(config) {'enable_range_optimization': True, 'enable_extended_wildcard': False} )pbdoc");
  invert_params
      .def(py::init(), py::arg("enable_range_optimization") = false,
           py::arg("enable_extended_wildcard") = false,
           R"pbdoc( Constructs an InvertIndexParam instance. Args: enable_range_optimization (bool, optional): If True, enables range query optimization for the invert index. Defaults to False. enable_extended_wildcard (bool, optional): If True, enables extended wildcard search including suffix and infix patterns. Defaults to False. )pbdoc")
      .def_property_readonly("enable_range_optimization",
                             &InvertIndexParams::enable_range_optimization,
                             R"pbdoc( bool: Whether range optimization is enabled for this inverted index. )pbdoc")
      .def_property_readonly("enable_extended_wildcard",
                             &InvertIndexParams::enable_extended_wildcard,
                             R"pbdoc( bool: Whether extended wildcard (suffix and infix) search is enabled. Note: Prefix search is always enabled regardless of this setting.
)pbdoc")
      .def(
          "to_dict",
          [](const InvertIndexParams &self) -> py::dict {
            py::dict dict;
            dict["enable_range_optimization"] =
                self.enable_range_optimization();
            dict["enable_extended_wildcard"] = self.enable_extended_wildcard();
            return dict;
          },
          "Convert to dictionary with all fields")
      .def("__repr__",
           [](const InvertIndexParams &self) -> std::string {
             return "{"
                    "\"enable_range_optimization\":" +
                    std::to_string(self.enable_range_optimization()) +
                    ","
                    "\"enable_extended_wildcard\":" +
                    std::to_string(self.enable_extended_wildcard()) + "}";
           })
      .def(py::pickle(
          [](const InvertIndexParams &self) {  // __getstate__
            return py::make_tuple(self.enable_range_optimization(),
                                  self.enable_extended_wildcard());
          },
          [](py::tuple t) {  // __setstate__
            if (t.size() != 2)
              throw std::runtime_error("Invalid state for InvertIndexParams");
            return std::make_shared(t[0].cast(), t[1].cast());
          }));

  // binding base vector index params
  py::class_>
      vector_params(m, "VectorIndexParam", R"pbdoc( Base class for vector index parameter configurations. Encapsulates common settings for all vector index types. Attributes: type (IndexType): The specific vector index type (e.g., HNSW, FLAT). metric_type (MetricType): Distance metric used for similarity search. quantize_type (QuantizeType): Optional vector quantization type.
)pbdoc");
  vector_params
      .def_property_readonly(
          "metric_type",
          [](const VectorIndexParams &self) -> MetricType {
            return self.metric_type();
          },
          "MetricType: Distance metric (e.g., IP, COSINE, L2).")
      .def_property_readonly(
          "quantize_type",
          [](const VectorIndexParams &self) -> QuantizeType {
            return self.quantize_type();
          },
          "QuantizeType: Vector quantization type (e.g., FP16, INT8).")
      .def(
          "to_dict",
          [](const VectorIndexParams &self) -> py::dict {
            py::dict dict;
            dict["type"] = index_type_to_string(self.type());
            dict["metric_type"] = metric_type_to_string(self.metric_type());
            dict["quantize_type"] =
                quantize_type_to_string(self.quantize_type());
            return dict;
          },
          "Convert to dictionary with all fields")
      .def(py::pickle(
          [](const VectorIndexParams &self) {  // __getstate__
            return py::make_tuple(self.type(), self.metric_type(),
                                  self.quantize_type());
          },
          [](py::tuple t) {  // __setstate__
            if (t.size() != 3)
              throw std::runtime_error("Invalid state for VectorIndexParams");
            // Abstract base class: it cannot be instantiated directly and
            // only exists to be subclassed, so no object is built here.
            return std::shared_ptr();
          }));

  // binding hnsw index params
  py::class_>
      hnsw_params(m, "HnswIndexParam", R"pbdoc( Parameters for configuring an HNSW (Hierarchical Navigable Small World) index. HNSW is a graph-based approximate nearest neighbor search index. This class encapsulates its construction hyperparameters. Attributes: metric_type (MetricType): Distance metric used for similarity computation. Default is ``MetricType.IP`` (inner product). m (int): Number of bi-directional links created for every new element during construction. Higher values improve accuracy but increase memory usage and construction time. Default is 50. ef_construction (int): Size of the dynamic candidate list for nearest neighbors during index construction. Larger values yield better graph quality at the cost of slower build time. Default is 500. quantize_type (QuantizeType): Optional quantization type for vector compression (e.g., FP16, INT8). Default is `QuantizeType.UNDEFINED` to disable quantization.
Examples: >>> from zvec.typing import MetricType, QuantizeType >>> params = HnswIndexParam( ... metric_type=MetricType.COSINE, ... m=16, ... ef_construction=200, ... quantize_type=QuantizeType.INT8 ... ) >>> print(params) {"metric_type":COSINE, "m":16, "ef_construction":200, "quantize_type":INT8} )pbdoc");
  hnsw_params
      .def(py::init(), py::arg("metric_type") = MetricType::IP,
           py::arg("m") = core_interface::kDefaultHnswNeighborCnt,
           py::arg("ef_construction") =
               core_interface::kDefaultHnswEfConstruction,
           py::arg("quantize_type") = QuantizeType::UNDEFINED)
      .def_property_readonly(
          "m", &HnswIndexParams::m,
          "int: Maximum number of neighbors per node in upper layers.")
      .def_property_readonly(
          "ef_construction", &HnswIndexParams::ef_construction,
          "int: Candidate list size during index construction.")
      .def(
          "to_dict",
          [](const HnswIndexParams &self) -> py::dict {
            py::dict dict;
            dict["type"] = index_type_to_string(self.type());
            dict["metric_type"] = metric_type_to_string(self.metric_type());
            dict["m"] = self.m();
            dict["ef_construction"] = self.ef_construction();
            dict["quantize_type"] =
                quantize_type_to_string(self.quantize_type());
            return dict;
          },
          "Convert to dictionary with all fields")
      .def("__repr__",
           [](const HnswIndexParams &self) -> std::string {
             return "{"
                    "\"metric_type\":" +
                    metric_type_to_string(self.metric_type()) +
                    ", \"m\":" + std::to_string(self.m()) +
                    ", \"ef_construction\":" +
                    std::to_string(self.ef_construction()) +
                    ", \"quantize_type\":" +
                    quantize_type_to_string(self.quantize_type()) + "}";
           })
      .def(py::pickle(
          [](const HnswIndexParams &self) {
            return py::make_tuple(self.metric_type(), self.m(),
                                  self.ef_construction(),
                                  self.quantize_type());
          },
          [](py::tuple t) {
            if (t.size() != 4)
              throw std::runtime_error("Invalid state for HnswIndexParams");
            return std::make_shared(t[0].cast(), t[1].cast(),
                                    t[2].cast(), t[3].cast());
          }));

  // binding hnsw rabitq index params
  py::class_>
      hnsw_rabitq_params(m, "HnswRabitqIndexParam", R"pbdoc( Parameters for configuring
an HNSW (Hierarchical Navigable Small World) index with RabitQ quantization. HNSW is a graph-based approximate nearest neighbor search index. RabitQ is a quantization method that provides high compression with minimal accuracy loss. Attributes: metric_type (MetricType): Distance metric used for similarity computation. Default is ``MetricType.IP`` (inner product). m (int): Number of bi-directional links created for every new element during construction. Higher values improve accuracy but increase memory usage and construction time. Default is 50. ef_construction (int): Size of the dynamic candidate list for nearest neighbors during index construction. Larger values yield better graph quality at the cost of slower build time. Default is 500. total_bits (int): Total bits for RabitQ quantization. num_clusters (int): Number of clusters for RabitQ. sample_count (int): Sample count for RabitQ training. Default is 0. Examples: >>> from zvec.typing import MetricType >>> params = HnswRabitqIndexParam( ... metric_type=MetricType.COSINE, ... m=16, ... ef_construction=200 ... ) >>> print(params.m) 16 >>> print(params.ef_construction) 200 )pbdoc");
  hnsw_rabitq_params
      .def(py::init(), py::arg("metric_type") = MetricType::IP,
           py::arg("total_bits") = core_interface::kDefaultRabitqTotalBits,
           py::arg("num_clusters") = core_interface::kDefaultRabitqNumClusters,
           py::arg("m") = core_interface::kDefaultHnswNeighborCnt,
           py::arg("ef_construction") =
               core_interface::kDefaultHnswEfConstruction,
           py::arg("sample_count") = 0)
      .def_property_readonly("m", &HnswRabitqIndexParams::m,
                             "int: Maximum number of neighbors per node.")
      .def_property_readonly(
          "ef_construction", &HnswRabitqIndexParams::ef_construction,
          "int: Candidate list size during index construction.")
      .def_property_readonly("total_bits", &HnswRabitqIndexParams::total_bits,
                             "int: Total bits for RabitQ quantization.")
      .def_property_readonly("num_clusters",
                             &HnswRabitqIndexParams::num_clusters,
                             "int: Number of clusters for RabitQ.")
      .def_property_readonly("sample_count",
                             &HnswRabitqIndexParams::sample_count,
                             "int: Sample count for RabitQ training.")
      .def(
          "to_dict",
          [](const HnswRabitqIndexParams &self) -> py::dict {
            py::dict dict;
            dict["type"] = index_type_to_string(self.type());
            dict["metric_type"] = metric_type_to_string(self.metric_type());
            dict["quantize_type"] =
                quantize_type_to_string(self.quantize_type());
            dict["total_bits"] = self.total_bits();
            dict["num_clusters"] = self.num_clusters();
            dict["sample_count"] = self.sample_count();
            dict["m"] = self.m();
            dict["ef_construction"] = self.ef_construction();
            return dict;
          },
          "Convert to dictionary with all fields")
      .def(
          "__repr__",
          [](const HnswRabitqIndexParams &self) -> std::string {
            return "{"
                   "\"type\":\"" +
                   index_type_to_string(self.type()) +
                   "\", \"metric_type\":\"" +
                   metric_type_to_string(self.metric_type()) +
                   "\", \"total_bits\":" + std::to_string(self.total_bits()) +
                   ", \"num_clusters\":" +
                   std::to_string(self.num_clusters()) +
                   ", \"sample_count\":" +
                   std::to_string(self.sample_count()) +
                   ", \"m\":" + std::to_string(self.m()) +
                   ", \"ef_construction\":" +
                   std::to_string(self.ef_construction()) +
                   ", \"quantize_type\":\"" +
                   quantize_type_to_string(self.quantize_type()) + "\"}";
          })
      .def(py::pickle(
          [](const HnswRabitqIndexParams &self) {
            return py::make_tuple(self.metric_type(), self.total_bits(),
                                  self.num_clusters(), self.m(),
                                  self.ef_construction(), self.sample_count());
          },
          [](py::tuple t) {
            if (t.size() != 6)
              throw std::runtime_error(
                  "Invalid state for HnswRabitqIndexParams");
            return std::make_shared(t[0].cast(), t[1].cast(),
                                    t[2].cast(), t[3].cast(),
                                    t[4].cast(), t[5].cast());
          }));

  // FlatIndexParams
  py::class_>
      flat_params(m, "FlatIndexParam", R"pbdoc( Parameters for configuring a flat (brute-force) index. A flat index performs exact nearest neighbor search by comparing the query vector against all vectors in the collection. It is simple, accurate, and suitable for small to medium datasets or as a baseline. Attributes: metric_type (MetricType): Distance metric used for similarity computation. Default is ``MetricType.IP`` (inner product).
quantize_type (QuantizeType): Optional quantization type for vector compression (e.g., FP16, INT8). Use ``QuantizeType.UNDEFINED`` to disable quantization. Default is ``QuantizeType.UNDEFINED``. Examples: >>> from zvec.typing import MetricType, QuantizeType >>> params = FlatIndexParam( ... metric_type=MetricType.L2, ... quantize_type=QuantizeType.FP16 ... ) >>> print(params) {"metric_type":L2, "quantize_type":FP16} )pbdoc");
  flat_params
      .def(py::init(), py::arg("metric_type") = MetricType::IP,
           py::arg("quantize_type") = QuantizeType::UNDEFINED,
           R"pbdoc( Constructs a FlatIndexParam instance. Args: metric_type (MetricType, optional): Distance metric. Defaults to MetricType.IP. quantize_type (QuantizeType, optional): Vector quantization type. Defaults to QuantizeType.UNDEFINED (no quantization). )pbdoc")
      .def(
          "to_dict",
          [](const FlatIndexParams &self) -> py::dict {
            // NOTE(review): unlike the other vector index params, this dict
            // has no "type" entry; left as-is in case callers compare the
            // exact dict.
            py::dict dict;
            dict["metric_type"] = metric_type_to_string(self.metric_type());
            dict["quantize_type"] =
                quantize_type_to_string(self.quantize_type());
            return dict;
          },
          "Convert to dictionary with all fields")
      .def("__repr__",
           [](const FlatIndexParams &self) -> std::string {
             return "{"
                    "\"metric_type\":" +
                    metric_type_to_string(self.metric_type()) +
                    ", \"quantize_type\":" +
                    quantize_type_to_string(self.quantize_type()) + "}";
           })
      .def(py::pickle(
          [](const FlatIndexParams &self) {
            return py::make_tuple(self.metric_type(), self.quantize_type());
          },
          [](py::tuple t) {
            if (t.size() != 2)
              throw std::runtime_error("Invalid state for FlatIndexParams");
            return std::make_shared(t[0].cast(), t[1].cast());
          }));

  // IVFIndexParams
  py::class_>
      ivf_params(m, "IVFIndexParam", R"pbdoc( Parameters for configuring an IVF (Inverted File Index) index. IVF partitions the vector space into clusters (inverted lists). At query time, only a subset of clusters is searched, providing a trade-off between speed and accuracy. Attributes: metric_type (MetricType): Distance metric used for similarity computation.
Default is ``MetricType.IP`` (inner product). n_list (int): Number of clusters (inverted lists) to partition the dataset into. If set to 0, the system will auto-select a reasonable value based on data size. Default is 0 (auto). n_iters (int): Number of iterations for k-means clustering during index training. Higher values yield more stable centroids. Default is 10. use_soar (bool): Whether to enable SOAR (Scalable Optimized Adaptive Routing) for improved IVF search performance. Default is False. quantize_type (QuantizeType): Optional quantization type for vector compression (e.g., FP16, INT8). Default is ``QuantizeType.UNDEFINED``. Examples: >>> from zvec.typing import MetricType, QuantizeType >>> params = IVFIndexParam( ... metric_type=MetricType.COSINE, ... n_list=100, ... n_iters=15, ... use_soar=True, ... quantize_type=QuantizeType.INT8 ... ) >>> print(params.n_list) 100 )pbdoc");
  ivf_params
      .def(py::init(), py::arg("metric_type") = MetricType::IP,
           py::arg("n_list") = 0, py::arg("n_iters") = 10,
           py::arg("use_soar") = false,
           py::arg("quantize_type") = QuantizeType::UNDEFINED,
           R"pbdoc( Constructs an IVFIndexParam instance. Args: metric_type (MetricType, optional): Distance metric. Defaults to MetricType.IP. n_list (int, optional): Number of inverted lists (clusters). Set to 0 for auto. Defaults to 0. n_iters (int, optional): Number of k-means iterations during training. Defaults to 10. use_soar (bool, optional): Enable SOAR optimization. Defaults to False. quantize_type (QuantizeType, optional): Vector quantization type. Defaults to QuantizeType.UNDEFINED.
)pbdoc")
      .def_property_readonly("n_list", &IVFIndexParams::n_list,
                             "int: Number of inverted lists (0 = auto).")
      .def_property_readonly(
          "n_iters", &IVFIndexParams::n_iters,
          "int: Number of k-means iterations during training.")
      .def_property_readonly("use_soar", &IVFIndexParams::use_soar,
                             "bool: Whether SOAR optimization is enabled.")
      .def(
          "to_dict",
          [](const IVFIndexParams &self) -> py::dict {
            py::dict dict;
            dict["type"] = index_type_to_string(self.type());
            dict["metric_type"] = metric_type_to_string(self.metric_type());
            dict["n_list"] = self.n_list();
            dict["n_iters"] = self.n_iters();
            dict["use_soar"] = self.use_soar();
            dict["quantize_type"] =
                quantize_type_to_string(self.quantize_type());
            return dict;
          },
          "Convert to dictionary with all fields")
      .def("__repr__",
           [](const IVFIndexParams &self) -> std::string {
             return "{"
                    "\"metric_type\":" +
                    metric_type_to_string(self.metric_type()) +
                    ", \"n_list\":" + std::to_string(self.n_list()) +
                    ", \"n_iters\":" + std::to_string(self.n_iters()) +
                    ", \"use_soar\":" + std::to_string(self.use_soar()) +
                    ", \"quantize_type\":" +
                    quantize_type_to_string(self.quantize_type()) + "}";
           })
      .def(py::pickle(
          [](const IVFIndexParams &self) {
            return py::make_tuple(self.metric_type(), self.n_list(),
                                  self.n_iters(), self.use_soar(),
                                  self.quantize_type());
          },
          [](py::tuple t) {
            if (t.size() != 5)
              throw std::runtime_error("Invalid state for IVFIndexParams");
            return std::make_shared(t[0].cast(), t[1].cast(),
                                    t[2].cast(), t[3].cast(),
                                    t[4].cast());
          }));
}

void ZVecPyParams::bind_query_params(py::module_ &m) {
  // binding base query params
  py::class_> query_params(
      m, "QueryParam", R"pbdoc( Base class for all query parameter configurations. This abstract base class defines common query settings such as search radius and whether to force linear (brute-force) search. It should not be instantiated directly; use derived classes like `HnswQueryParam` or `IVFQueryParam`. Attributes: type (IndexType): The index type this query is configured for.
radius (float): Search radius for range queries. Used in combination with top-k to filter results. Default is 0.0 (disabled). is_linear (bool): If True, forces brute-force linear search instead of using the index. Useful for debugging or small datasets. Default is False. is_using_refiner (bool, optional): Whether to use refiner for the query. Default is False. )pbdoc");
  query_params
      .def_property_readonly(
          "type",
          [](const QueryParams &self) -> IndexType { return self.type(); },
          "IndexType: The type of index this query targets.")
      .def_property_readonly(
          "radius",
          [](const QueryParams &self) -> float { return self.radius(); },
          "float: Search radius for range queries.")
      .def_property_readonly(
          "is_linear",
          [](const QueryParams &self) -> bool { return self.is_linear(); },
          "bool: Whether to bypass the index and use brute-force linear "
          "search.")
      .def_property_readonly(
          "is_using_refiner",
          [](const QueryParams &self) -> bool {
            return self.is_using_refiner();
          },
          "bool: Whether to use refiner for the query.")
      .def(py::pickle(
          [](const QueryParams &self) {  // __getstate__
            return py::make_tuple(self.type(), self.radius(),
                                  self.is_linear());
          },
          [](py::tuple t) {  // __setstate__
            if (t.size() != 3)
              throw std::runtime_error("Invalid state for QueryParams");
            // Abstract base class: no object is built (null holder).
            return std::shared_ptr();
          }));

  // binding hnsw query params
  py::class_>
      hnsw_params(m, "HnswQueryParam", R"pbdoc( Query parameters for HNSW (Hierarchical Navigable Small World) index. Controls the trade-off between search speed and accuracy via the `ef` parameter. Attributes: type (IndexType): Always ``IndexType.HNSW``. ef (int): Size of the dynamic candidate list during search. Larger values improve recall but slow down search. Default is 300. radius (float): Search radius for range queries. Default is 0.0. is_linear (bool): Force linear search. Default is False. is_using_refiner (bool, optional): Whether to use refiner for the query. Default is False.
Examples: >>> params = HnswQueryParam(ef=300) >>> print(params.ef) 300 >>> print(params) {"type":HNSW, "ef":300, "radius":0.000000, "is_linear":0, "is_using_refiner":0} )pbdoc");
  // NOTE(review): the class docstring above says the ef default is 300,
  // while the constructor docstring below says 100. The real default is
  // core_interface::kDefaultHnswEfSearch; reconcile both texts with it.
  hnsw_params
      .def(py::init(), py::arg("ef") = core_interface::kDefaultHnswEfSearch,
           py::arg("radius") = 0.0f, py::arg("is_linear") = false,
           py::arg("is_using_refiner") = false,
           R"pbdoc( Constructs an HnswQueryParam instance. Args: ef (int, optional): Search-time candidate list size. Higher values improve accuracy. Defaults to 100. radius (float, optional): Search radius for range queries. Default is 0.0. is_linear (bool, optional): Force linear search. Default is False. is_using_refiner (bool, optional): Whether to use refiner for the query. Default is False. )pbdoc")
      .def_property_readonly(
          "ef", [](const HnswQueryParams &self) -> int { return self.ef(); },
          "int: Size of the dynamic candidate list during HNSW search.")
      .def("__repr__",
           [](const HnswQueryParams &self) -> std::string {
             return "{"
                    "\"type\":" +
                    index_type_to_string(self.type()) +
                    ", \"ef\":" + std::to_string(self.ef()) +
                    ", \"radius\":" + std::to_string(self.radius()) +
                    ", \"is_linear\":" + std::to_string(self.is_linear()) +
                    ", \"is_using_refiner\":" +
                    std::to_string(self.is_using_refiner()) + "}";
           })
      .def(py::pickle(
          [](const HnswQueryParams &self) {
            return py::make_tuple(self.ef(), self.radius(), self.is_linear(),
                                  self.is_using_refiner());
          },
          [](py::tuple t) {
            if (t.size() != 4)
              throw std::runtime_error("Invalid state for HnswQueryParams");
            auto obj = std::make_shared(t[0].cast());
            obj->set_radius(t[1].cast());
            obj->set_is_linear(t[2].cast());
            obj->set_is_using_refiner(t[3].cast());
            return obj;
          }));

  // binding ivf query params
  py::class_>
      ivf_params(m, "IVFQueryParam", R"pbdoc( Query parameters for IVF (Inverted File Index) index. Controls how many inverted lists (`nprobe`) to visit during search. Attributes: type (IndexType): Always ``IndexType.IVF``. nprobe (int): Number of closest clusters (inverted lists) to search.
Higher values improve recall but increase latency. Default is 10. radius (float): Search radius for range queries. Default is 0.0. is_linear (bool): Force linear search. Default is False. Examples: >>> params = IVFQueryParam(nprobe=20) >>> print(params.nprobe) 20 )pbdoc"); ivf_params .def(py::init(), py::arg("nprobe") = 10, R"pbdoc( Constructs an IVFQueryParam instance. Args: nprobe (int, optional): Number of inverted lists to probe during search. Higher values improve accuracy. Defaults to 10. )pbdoc") .def_property_readonly( "nprobe", [](const IVFQueryParams &self) -> int { return self.nprobe(); }, "int: Number of inverted lists to search during IVF query.") .def("__repr__", [](const IVFQueryParams &self) -> std::string { return "{" "\"type\":" + index_type_to_string(self.type()) + ", \"nprobe\":" + std::to_string(self.nprobe()) + "}"; }) .def(py::pickle( [](const IVFQueryParams &self) { return py::make_tuple(self.nprobe(), self.radius(), self.is_linear()); }, [](py::tuple t) { if (t.size() != 3) throw std::runtime_error("Invalid state for IVFQueryParams"); auto obj = std::make_shared(t[0].cast()); obj->set_radius(t[1].cast()); obj->set_is_linear(t[2].cast()); return obj; })); // binding hnsw rabitq query params py::class_> hnsw_rabitq_query_params(m, "HnswRabitqQueryParam", R"pbdoc( Query parameters for HNSW RaBitQ (Hierarchical Navigable Small World with RaBitQ quantization) index. Controls the trade-off between search speed and accuracy via the `ef` parameter. RaBitQ provides efficient quantization while maintaining high search quality. Attributes: type (IndexType): Always ``IndexType.HNSW_RABITQ``. ef (int): Size of the dynamic candidate list during search. Larger values improve recall but slow down search. Default is 300. radius (float): Search radius for range queries. Default is 0.0. is_linear (bool): Force linear search. Default is False. is_using_refiner (bool, optional): Whether to use refiner for the query. Default is False. 
Examples: >>> params = HnswRabitqQueryParam(ef=300) >>> print(params.ef) 300 >>> print(params.to_dict() if hasattr(params, 'to_dict') else params) {"type":"HNSW_RABITQ", "ef":300} )pbdoc"); hnsw_rabitq_query_params .def(py::init(), py::arg("ef") = core_interface::kDefaultHnswEfSearch, py::arg("radius") = 0.0f, py::arg("is_linear") = false, py::arg("is_using_refiner") = false, R"pbdoc( Constructs an HnswRabitqQueryParam instance. Args: ef (int, optional): Search-time candidate list size. Higher values improve accuracy. Defaults to 300. radius (float, optional): Search radius for range queries. Default is 0.0. is_linear (bool, optional): Force linear search. Default is False. is_using_refiner (bool, optional): Whether to use refiner for the query. Default is False. )pbdoc") .def_property_readonly( "ef", [](const HnswRabitqQueryParams &self) -> int { return self.ef(); }, "int: Size of the dynamic candidate list during HNSW RaBitQ search.") .def("__repr__", [](const HnswRabitqQueryParams &self) -> std::string { return "{" "\"type\":\"" + index_type_to_string(self.type()) + "\", \"ef\":" + std::to_string(self.ef()) + ", \"radius\":" + std::to_string(self.radius()) + ", \"is_linear\":" + std::to_string(self.is_linear()) + ", \"is_using_refiner\":" + std::to_string(self.is_using_refiner()) + "}"; }) .def(py::pickle( [](const HnswRabitqQueryParams &self) { return py::make_tuple(self.ef(), self.radius(), self.is_linear(), self.is_using_refiner()); }, [](py::tuple t) { if (t.size() != 4) throw std::runtime_error( "Invalid state for HnswRabitqQueryParams"); auto obj = std::make_shared(t[0].cast()); obj->set_radius(t[1].cast()); obj->set_is_linear(t[2].cast()); obj->set_is_using_refiner(t[3].cast()); return obj; })); } void ZVecPyParams::bind_options(py::module_ &m) { // binding collection options py::class_(m, "CollectionOption", R"pbdoc( Options for opening or creating a collection. Attributes: read_only (bool): Whether the collection is opened in read-only mode. 
Default is False. enable_mmap (bool): Whether to use memory-mapped I/O for data files. Default is True. Examples: >>> opt = CollectionOption(read_only=True, enable_mmap=False) >>> print(opt.read_only) True )pbdoc") .def(py::init(), py::arg("read_only") = false, py::arg("enable_mmap") = true, R"pbdoc( Constructs a CollectionOption instance. Args: read_only (bool, optional): Open collection in read-only mode. Defaults to False. enable_mmap (bool, optional): Enable memory-mapped I/O. Defaults to True. )pbdoc") .def_property_readonly( "enable_mmap", [](const CollectionOptions &self) { return self.enable_mmap_; }) .def_property_readonly( "read_only", [](const CollectionOptions &self) { return self.read_only_; }) .def("__repr__", [](const CollectionOptions &self) -> std::string { return "{" "\"enable_mmap\":" + std::to_string(self.enable_mmap_) + ", \"read_only\":" + std::to_string(self.read_only_) + "}"; }) .def(py::pickle( [](const CollectionOptions &self) { return py::make_tuple(self.read_only_, self.enable_mmap_, self.max_buffer_size_); }, [](py::tuple t) { if (t.size() != 3) throw std::runtime_error( "Invalid pickle data for CollectionOptions"); CollectionOptions obj{}; obj.read_only_ = t[0].cast(); obj.enable_mmap_ = t[1].cast(); obj.max_buffer_size_ = t[2].cast(); return obj; })); // SegmentOptions py::class_(m, "SegmentOption", R"pbdoc( Options for segment-level operations. Currently, this class mirrors CollectionOption and is used internally. It supports read-only mode, memory mapping, and buffer configuration. Note: This class is primarily for internal use. Most users should use CollectionOption instead. 
Examples: >>> opt = SegmentOption() >>> print(opt.enable_mmap) True )pbdoc") .def(py::init<>(), "Constructs a SegmentOption with default settings.") .def_property_readonly( "enable_mmap", [](const SegmentOptions &self) { return self.enable_mmap_; }, "bool: Whether memory-mapped I/O is enabled.") .def_property_readonly( "read_only", [](const SegmentOptions &self) { return self.read_only_; }, "bool: Whether the segment is read-only.") .def_property_readonly( "max_buffer_size", [](const SegmentOptions &self) { return self.max_buffer_size_; }, "int: Maximum buffer size in bytes (internal use).") .def("__repr__", [](const SegmentOptions &self) -> std::string { return "{" "\"enable_mmap\":" + std::to_string(self.enable_mmap_) + ", \"read_only\":" + std::to_string(self.read_only_) + ", \"max_buffer_size\":" + std::to_string(self.max_buffer_size_) + "}"; }) .def(py::pickle( [](const SegmentOptions &self) { return py::make_tuple(self.read_only_, self.enable_mmap_, self.max_buffer_size_); }, [](py::tuple t) { if (t.size() != 3) throw std::runtime_error( "Invalid pickle data for SegmentOptions"); SegmentOptions obj{}; obj.read_only_ = t[0].cast(); obj.enable_mmap_ = t[1].cast(); obj.max_buffer_size_ = t[2].cast(); return obj; })); // CreateIndexOptions py::class_(m, "IndexOption", R"pbdoc( Options for creating an index. Attributes: concurrency (int): Number of threads to use during index creation. If 0, the system will choose an optimal value automatically. Default is 0. Examples: >>> opt = IndexOption(concurrency=4) >>> print(opt.concurrency) 4 )pbdoc") .def(py::init(), py::arg("concurrency") = 0, R"pbdoc( Constructs an IndexOption instance. Args: concurrency (int, optional): Number of concurrent threads. 0 means auto-detect. Defaults to 0. 
)pbdoc") .def_property_readonly( "concurrency", [](const CreateIndexOptions &self) { return self.concurrency_; }, "int: Number of threads used for index creation (0 = auto).") .def(py::pickle( [](const CreateIndexOptions &self) { return py::make_tuple(self.concurrency_); }, [](py::tuple t) { if (t.size() != 1) throw std::runtime_error( "Invalid pickle data for CreateIndexOptions"); CreateIndexOptions obj{}; obj.concurrency_ = t[0].cast(); return obj; })); // OptimizeOptions py::class_(m, "OptimizeOption", R"pbdoc( Options for optimizing a collection (e.g., merging segments). Attributes: concurrency (int): Number of threads to use during optimization. If 0, the system will choose an optimal value automatically. Default is 0. Examples: >>> opt = OptimizeOption(concurrency=2) >>> print(opt.concurrency) 2 )pbdoc") .def(py::init(), py::arg("concurrency") = 0, R"pbdoc( Constructs an OptimizeOption instance. Args: concurrency (int, optional): Number of concurrent threads. 0 means auto-detect. Defaults to 0. )pbdoc") .def_property_readonly( "concurrency", [](const OptimizeOptions &self) { return self.concurrency_; }, "int: Number of threads used for optimization (0 = auto).") .def(py::pickle( [](const OptimizeOptions &self) { return py::make_tuple(self.concurrency_); }, [](py::tuple t) { if (t.size() != 1) throw std::runtime_error( "Invalid pickle data for OptimizeOptions"); OptimizeOptions obj{}; obj.concurrency_ = t[0].cast(); return obj; })); // AddColumnOptions py::class_(m, "AddColumnOption", R"pbdoc( Options for adding a new column to a collection. Attributes: concurrency (int): Number of threads to use when backfilling data for the new column. If 0, auto-detect is used. Default is 0. Examples: >>> opt = AddColumnOption(concurrency=1) >>> print(opt.concurrency) 1 )pbdoc") .def(py::init(), py::arg("concurrency") = 0, R"pbdoc( Constructs an AddColumnOption instance. Args: concurrency (int, optional): Number of threads for data backfill. 0 means auto-detect. 
Defaults to 0. )pbdoc") .def_property_readonly( "concurrency", [](const AddColumnOptions &self) { return self.concurrency_; }, "int: Number of threads used when adding a column (0 = auto).") .def(py::pickle( [](const AddColumnOptions &self) { return py::make_tuple(self.concurrency_); }, [](py::tuple t) { if (t.size() != 1) throw std::runtime_error( "Invalid pickle data for AddColumnOptions"); AddColumnOptions obj{}; obj.concurrency_ = t[0].cast(); return obj; })); // AlterColumnOptions py::class_(m, "AlterColumnOption", R"pbdoc( Options for altering an existing column (e.g., changing index settings). Attributes: concurrency (int): Number of threads to use during the alteration process. If 0, the system will choose an optimal value automatically. Default is 0. Examples: >>> opt = AlterColumnOption(concurrency=1) >>> print(opt.concurrency) 1 )pbdoc") .def(py::init(), py::arg("concurrency") = 0, R"pbdoc( Constructs an AlterColumnOption instance. Args: concurrency (int, optional): Number of threads for column alteration. 0 means auto-detect. Defaults to 0. 
)pbdoc") .def_property_readonly( "concurrency", [](const AlterColumnOptions &self) { return self.concurrency_; }, "int: Number of threads used when altering a column (0 = auto).") .def(py::pickle( [](const AlterColumnOptions &self) { return py::make_tuple(self.concurrency_); }, [](py::tuple t) { if (t.size() != 1) throw std::runtime_error( "Invalid pickle data for AlterColumnOptions"); AlterColumnOptions obj{}; obj.concurrency_ = t[0].cast(); return obj; })); } void ZVecPyParams::bind_vector_query(py::module_ &m) { py::class_(m, "_VectorQuery") .def(py::init<>()) // properties .def_readwrite("topk", &VectorQuery::topk_) .def_readwrite("field_name", &VectorQuery::field_name_) .def_readwrite("filter", &VectorQuery::filter_) .def_readwrite("include_vector", &VectorQuery::include_vector_) .def_readwrite("query_params", &VectorQuery::query_params_) .def_readwrite("output_fields", &VectorQuery::output_fields_) // vector .def("set_vector", [](VectorQuery &self, const FieldSchema &field_schema, const py::object &obj) { const DataType data_type = field_schema.data_type(); // dense vector if (FieldSchema::is_dense_vector_field(data_type)) { if (!py::isinstance(obj)) { throw py::type_error("Dense vector[" + field_schema.name() + "] expects a ndarray, got " + std::string(py::str(py::type::of(obj)))); } const auto arr = obj.cast(); if (arr.ndim() != 1) { throw py::type_error("Dense vector expects 1D array, got " + std::to_string(arr.ndim()) + "D"); } const auto buf = arr.request(); switch (data_type) { case DataType::VECTOR_FP32: { self.query_vector_ = serialize_vector( static_cast(buf.ptr), buf.size); return; } case DataType::VECTOR_FP64: { self.query_vector_ = serialize_vector( static_cast(buf.ptr), buf.size); return; } case DataType::VECTOR_INT8: { self.query_vector_ = serialize_vector( static_cast(buf.ptr), buf.size); return; } case DataType::VECTOR_FP16: { self.query_vector_ = serialize_vector( static_cast(buf.ptr), buf.size); return; } default: throw py::type_error( 
"Unsupported dense vector type for ndarray input: " + std::to_string(static_cast(data_type))); } } // sparse vector if (FieldSchema::is_sparse_vector_field(data_type)) { if (!py::isinstance(obj)) { throw py::type_error("Sparse vector[" + field_schema.name() + "] expects a Python dict, got " + std::string(py::str(py::type::of(obj)))); } const auto sparse = obj.cast(); switch (data_type) { case DataType::SPARSE_VECTOR_FP16: { auto [indices, values] = serialize_sparse_vector( sparse, [](const py::handle &h, size_t idx) { float f = checked_cast( h, "Sparse value[" + std::to_string(idx) + "]", "FLOAT"); return ailego::Float16(f); }); self.query_sparse_indices_ = std::move(indices); self.query_sparse_values_ = std::move(values); break; } case DataType::SPARSE_VECTOR_FP32: { auto [indices, values] = serialize_sparse_vector( sparse, [](const py::handle &h, size_t idx) { return checked_cast( h, "Sparse value[" + std::to_string(idx) + "]", "FLOAT"); }); self.query_sparse_indices_ = std::move(indices); self.query_sparse_values_ = std::move(values); break; } default: throw py::type_error( "Unsupported sparse vector type: " + std::to_string(static_cast(data_type))); } return; } throw py::type_error("Unsupported vector field type for field: " + field_schema.name()); }) .def( "get_vector", [](const VectorQuery &self, const FieldSchema &field_schema) -> py::object { DataType data_type = field_schema.data_type(); if (FieldSchema::is_dense_vector_field(data_type)) { if (self.query_vector_.empty()) { throw std::runtime_error("No dense vector has been set"); } size_t byte_size = self.query_vector_.size(); const void *data = self.query_vector_.data(); switch (data_type) { case DataType::VECTOR_FP32: { if (byte_size % sizeof(float) != 0) { throw std::runtime_error( "Invalid buffer size for VECTOR_FP32"); } size_t dim = byte_size / sizeof(float); return py::array_t({dim}, {sizeof(float)}, static_cast(data)); } case DataType::VECTOR_FP64: { if (byte_size % sizeof(double) != 0) { throw 
std::runtime_error( "Invalid buffer size for VECTOR_FP64"); } size_t dim = byte_size / sizeof(double); return py::array_t({dim}, {sizeof(double)}, static_cast(data)); } case DataType::VECTOR_INT8: { if (byte_size % sizeof(int8_t) != 0) { throw std::runtime_error( "Invalid buffer size for VECTOR_INT8"); } size_t dim = byte_size / sizeof(int8_t); return py::array_t({dim}, {sizeof(int8_t)}, static_cast(data)); } case DataType::VECTOR_FP16: { if (byte_size % 2 != 0) { throw std::runtime_error( "Invalid buffer size for VECTOR_FP16"); } size_t dim = byte_size / 2; return py::array(py::dtype("float16"), {dim}, {2}, data); } default: throw py::type_error( "Unsupported dense vector type for get_vector: " + std::to_string(static_cast(data_type))); } } if (FieldSchema::is_sparse_vector_field(data_type)) { if (self.query_sparse_indices_.empty()) { return py::dict(); } // Deserialize indices: stored as uint32_t[] size_t indices_byte_size = self.query_sparse_indices_.size(); if (indices_byte_size % sizeof(uint32_t) != 0) { throw std::runtime_error( "Sparse indices buffer size not aligned to uint32_t"); } size_t n = indices_byte_size / sizeof(uint32_t); const uint32_t *indices = reinterpret_cast( self.query_sparse_indices_.data()); // Deserialize values switch (data_type) { case DataType::SPARSE_VECTOR_FP32: { if (self.query_sparse_values_.size() != n * sizeof(float)) { throw std::runtime_error( "Sparse FP32 values buffer size mismatch"); } const float *values = reinterpret_cast( self.query_sparse_values_.data()); py::dict result; for (size_t i = 0; i < n; ++i) { result[py::int_(indices[i])] = py::float_(values[i]); } return result; } case DataType::SPARSE_VECTOR_FP16: { if (self.query_sparse_values_.size() != n * sizeof(uint16_t)) { throw std::runtime_error( "Sparse FP16 values buffer size mismatch"); } const uint16_t *raw_bits = reinterpret_cast( self.query_sparse_values_.data()); py::dict result; for (size_t i = 0; i < n; ++i) { float f = 
ailego::FloatHelper::ToFP32(raw_bits[i]); result[py::int_(indices[i])] = py::float_(f); } return result; } default: throw py::type_error("Unsupported sparse vector type..."); } } throw py::type_error("Unsupported vector field type: " + field_schema.name()); }, py::arg("field_schema")) .def(py::pickle( [](const VectorQuery &self) { return py::make_tuple( self.topk_, self.field_name_, self.query_vector_, self.query_sparse_indices_, self.query_sparse_values_, self.filter_, self.include_vector_, self.output_fields_, self.query_params_ ? py::cast(self.query_params_) : py::none()); }, [](py::tuple t) { if (t.size() != 9) throw std::runtime_error("Invalid pickle data for VectorQuery"); VectorQuery obj{}; obj.topk_ = t[0].cast(); obj.field_name_ = t[1].cast(); obj.query_vector_ = t[2].cast(); obj.query_sparse_indices_ = t[3].cast(); obj.query_sparse_values_ = t[4].cast(); obj.filter_ = t[5].cast(); obj.include_vector_ = t[6].cast(); obj.output_fields_ = t[7].cast>(); if (!t[8].is_none()) { obj.query_params_ = t[8].cast(); } return obj; })); } } // namespace zvec ================================================ FILE: src/binding/python/model/python_collection.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "python_collection.h"
#include
#include

namespace zvec {

// Converts a zvec Status into the matching Python exception.
// OK returns normally; NOT_FOUND raises KeyError; INVALID_ARGUMENT raises
// ValueError; every other code raises RuntimeError with the status message.
inline void throw_if_error(const Status &status) {
  switch (status.code()) {
    case StatusCode::OK:
      return;
    case StatusCode::NOT_FOUND:
      throw py::key_error(status.message());
    case StatusCode::INVALID_ARGUMENT:
      throw py::value_error(status.message());
    case StatusCode::INTERNAL_ERROR:
    case StatusCode::ALREADY_EXISTS:
    case StatusCode::NOT_SUPPORTED:
    case StatusCode::PERMISSION_DENIED:
    case StatusCode::FAILED_PRECONDITION:
    case StatusCode::UNKNOWN:
    default:
      throw std::runtime_error(status.message());
  }
}

// Returns the value held by `exp`, or raises the Python exception mapped
// from its error Status via throw_if_error(). The trailing `return T{}` is
// only reached if the stored error Status is OK (throw_if_error does not
// raise for OK).
template T unwrap_expected(const tl::expected &exp) {
  if (exp.has_value()) {
    return exp.value();
  }
  throw_if_error(exp.error());
  return T{};
}

// Registers the `_Collection` class, its DB/DDL/DML/DQL methods and pickle
// support.
void ZVecPyCollection::Initialize(pybind11::module_ &m) {
  py::class_ collection(m, "_Collection");

  bind_db_methods(collection);
  bind_ddl_methods(collection);
  bind_dml_methods(collection);
  bind_dql_methods(collection);

  collection.def(py::pickle(
      [](const Collection &c) {
        // Path()/Schema()/Options() return tl::expected (see the property
        // bindings below). Unwrap them so the tuple holds convertible
        // values, since an expected<> has no Python binding and would fail
        // to cast; errors surface as the mapped Python exceptions.
        return py::make_tuple(unwrap_expected(c.Path()),
                              unwrap_expected(c.Schema()),
                              unwrap_expected(c.Options()));
      },
      [](py::tuple t) {
        if (t.size() != 3) {
          throw std::runtime_error("Invalid tuple size for Collection pickle");
        }
        // Unpickling re-opens the existing collection at the stored path;
        // the pickled schema (t[1]) is not needed for Open().
        std::string path = t[0].cast();
        CollectionOptions options = t[2].cast();
        auto result = Collection::Open(path, options);
        return unwrap_expected(result);
      }));
}

// Static constructors: create-and-open a new collection, or open an existing
// one at `path`.
void ZVecPyCollection::bind_db_methods(
    py::class_ &col) {
  col.def_static("CreateAndOpen",
                 [](const std::string &path, const CollectionSchema &schema,
                    const CollectionOptions &options) {
                   auto result =
                       Collection::CreateAndOpen(path, schema, options);
                   return unwrap_expected(result);
                 })
      .def_static("Open",
                  [](const std::string &path, const CollectionOptions &options) {
                    auto result = Collection::Open(path, options);
                    return unwrap_expected(result);
                  });
}

// Property accessors plus collection-, index- and column-level DDL. Every
// method converts a failing Status into a Python exception.
void ZVecPyCollection::bind_ddl_methods(
    py::class_ &col) {
  // bind collection properties
  col.def("Path",
          [](const Collection &self) {
            auto ret = self.Path();
            return unwrap_expected(ret);
          })
      .def("Options",
           [](const Collection &self) {
             auto ret = self.Options();
             return unwrap_expected(ret);
           })
      .def("Schema",
           [](const Collection &self) {
             auto ret = self.Schema();
             return unwrap_expected(ret);
           })
      .def("Stats", [](const Collection &self) {
        auto ret = self.Stats();
        return unwrap_expected(ret);
      });

  // bind collection ddl methods
  col.def("Destroy",
          [](Collection &self) {
            const auto status = self.Destroy();
            throw_if_error(status);
          })
      .def("Flush", [](Collection &self) {
        auto status = self.Flush();
        throw_if_error(status);
      });

  // binding index ddl methods
  col.def("CreateIndex",
          [](Collection &self, const std::string &column_name,
             const IndexParams::Ptr &index_options,
             const CreateIndexOptions &options) {
            const auto status =
                self.CreateIndex(column_name, index_options, options);
            throw_if_error(status);
          })
      .def("DropIndex",
           [](Collection &self, const std::string &column_name) {
             const auto status = self.DropIndex(column_name);
             throw_if_error(status);
           })
      .def("Optimize", [](Collection &self, const OptimizeOptions &options) {
        const auto status = self.Optimize(options);
        throw_if_error(status);
      });

  // binding column ddl methods
  // Column names are only read, so they are taken by const reference like
  // every other string argument in this file.
  col.def("AddColumn",
          [](Collection &self, const FieldSchema::Ptr &column_schema,
             const std::string &expression, const AddColumnOptions &options) {
            const auto status =
                self.AddColumn(column_schema, expression, options);
            throw_if_error(status);
          })
      .def("DropColumn",
           [](Collection &self, const std::string &column_name) {
             auto status = self.DropColumn(column_name);
             throw_if_error(status);
           })
      .def("AlterColumn",
           [](Collection &self, const std::string &column_name,
              const std::string &rename,
              const FieldSchema::Ptr &new_column_schema,
              const AlterColumnOptions &options) {
             const auto status = self.AlterColumn(column_name, rename,
                                                  new_column_schema, options);
             throw_if_error(status);
           });
}

// Document mutation methods; batch methods return the unwrapped per-call
// result, DeleteByFilter only reports failure via an exception.
void ZVecPyCollection::bind_dml_methods(
    py::class_ &col) {
  // bind collection upsert/insert/update/delete methods
  col.def("Insert",
          [](Collection &self, std::vector &docs) {
            const auto result = self.Insert(docs);
            return unwrap_expected(result);
          })
      .def("Update",
           [](Collection &self, std::vector &docs) {
             const auto result = self.Update(docs);
             return unwrap_expected(result);
           })
      .def("Upsert",
           [](Collection &self, std::vector &docs) {
             const auto result = self.Upsert(docs);
             return unwrap_expected(result);
           })
      .def("Delete",
           [](Collection &self, const std::vector &pks) {
             const auto result = self.Delete(pks);
             return unwrap_expected(result);
           })
      .def("DeleteByFilter",
           [](Collection &self, const std::string &filter) {
             const auto status = self.DeleteByFilter(filter);
             throw_if_error(status);
           });
}

// Read-only query methods: vector query, grouped query and fetch by
// primary key.
void ZVecPyCollection::bind_dql_methods(
    py::class_ &col) {
  col.def("Query",
          [](const Collection &self, const VectorQuery &query) {
            const auto result = self.Query(query);
            // return DocPtrList
            return unwrap_expected(result);
          })
      .def("GroupByQuery",
           [](const Collection &self, const GroupByVectorQuery &query) {
             const auto result = self.GroupByQuery(query);
             // return GroupResults
             return unwrap_expected(result);
           })
      .def("Fetch",
           [](const Collection &self, const std::vector &pks) {
             const auto result = self.Fetch(pks);
             // return DocPtrMap
             return unwrap_expected(result);
           });
}

} // namespace zvec

================================================
FILE: src/binding/python/model/python_doc.cc
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // limitations under the License. #include "python_doc.h" #include #include namespace zvec { template T checked_cast(const py::object &obj, const std::string &field, const std::string &expected_type) { try { return obj.cast(); } catch (const py::cast_error &e) { std::string actual_type = std::string(py::str(py::type::of(obj))); std::string msg = "Field '" + field + "': expected " + expected_type + ", got " + actual_type; throw py::type_error(msg); } } void ZVecPyDoc::Initialize(pybind11::module_ &m) { bind_doc_operator(m); bind_doc(m); } void ZVecPyDoc::bind_doc_operator(py::module_ &m) { py::enum_(m, "_DocOp") .value("INSERT", Operator::INSERT) .value("UPDATE", Operator::UPDATE) .value("DELETE", Operator::DELETE) .value("UPSERT", Operator::UPSERT); } void ZVecPyDoc::bind_doc(py::module_ &m) { // binding doc py::class_ doc(m, "_Doc"); doc.def(py::init([]() { return std::make_shared(); })) .def("set_pk", &Doc::set_pk) .def("pk", &Doc::pk) .def("set_score", &Doc::set_score) .def("score", &Doc::score) .def("has_field", &Doc::has) .def("field_names", &Doc::field_names) .def(py::pickle( [](const Doc &d) { std::vector data = d.serialize(); return py::bytes(reinterpret_cast(data.data()), data.size()); }, [](py::bytes b) { py::buffer_info info(py::buffer(b).request()); const uint8_t *buf = reinterpret_cast(info.ptr); size_t size = static_cast(info.size); Doc::Ptr d = Doc::deserialize(buf, size); if (!d) throw std::runtime_error("Failed to unpickle Doc"); return d; })); // binding doc set field doc.def( "set_any", [](Doc &self, const std::string &field, const FieldSchema &field_schema, const py::object &obj) -> bool { if (obj.is_none()) { if (field_schema.nullable()) { self.set_null(field); return true; } throw py::value_error("Field '" + field + "': expected non-nullable type"); } switch (field_schema.data_type()) { // base datatypes case DataType::STRING: return self.set(field, checked_cast(obj, field, 
"STRING")); case DataType::BOOL: return self.set(field, checked_cast(obj, field, "BOOL")); case DataType::INT32: return self.set(field, checked_cast(obj, field, "INT32")); case DataType::INT64: return self.set(field, checked_cast(obj, field, "INT64")); case DataType::UINT32: return self.set(field, checked_cast(obj, field, "UINT32")); case DataType::UINT64: return self.set(field, checked_cast(obj, field, "UINT64")); case DataType::FLOAT: return self.set(field, checked_cast(obj, field, "FLOAT")); case DataType::DOUBLE: return self.set(field, checked_cast(obj, field, "DOUBLE")); // array datatypes case DataType::ARRAY_STRING: return self.set(field, checked_cast>( obj, field, "ARRAY_STRING")); case DataType::ARRAY_BOOL: return self.set(field, checked_cast>( obj, field, "ARRAY_BOOL")); case DataType::ARRAY_INT32: return self.set(field, checked_cast>( obj, field, "ARRAY_INT32")); case DataType::ARRAY_UINT32: return self.set(field, checked_cast>( obj, field, "ARRAY_UINT32")); case DataType::ARRAY_INT64: return self.set(field, checked_cast>( obj, field, "ARRAY_INT64")); case DataType::ARRAY_UINT64: return self.set(field, checked_cast>( obj, field, "ARRAY_UINT64")); case DataType::ARRAY_FLOAT: return self.set(field, checked_cast>( obj, field, "ARRAY_FLOAT")); case DataType::ARRAY_DOUBLE: return self.set(field, checked_cast>( obj, field, "ARRAY_DOUBLE")); // dense vector datatypes case DataType::VECTOR_FP16: { const auto value = checked_cast( obj, field, "VECTOR_FP16 (list of numbers)"); std::vector new_value; new_value.reserve(value.size()); for (const auto &item : value) { try { new_value.emplace_back(item.cast()); } catch (const py::cast_error &e) { throw py::type_error("Vector '" + field + "': expected VECTOR_FP16, got " + std::string(py::str(py::type::of(obj)))); } } return self.set(field, new_value); } case DataType::VECTOR_FP32: return self.set(field, checked_cast>( obj, field, "VECTOR_FP32")); case DataType::VECTOR_FP64: return self.set(field, checked_cast>( obj, 
field, "VECTOR_FP64")); case DataType::VECTOR_INT8: return self.set(field, checked_cast>( obj, field, "VECTOR_INT8")); // sparse vector datatypes case DataType::SPARSE_VECTOR_FP32: { const auto sparse_dict = checked_cast(obj, field, "SPARSE_VECTOR_FP32 (dict)"); std::vector indices; std::vector values; for (const auto &item : sparse_dict) { try { indices.push_back(item.first.cast()); values.push_back(item.second.cast()); } catch (const py::cast_error &e) { throw py::type_error( "Vector '" + field + "': sparse vector key/value must be (uint32, float), " "got key=" + std::string(py::str(py::type::of(item.first))) + ", value=" + std::string(py::str(py::type::of(item.second)))); } } const std::pair, std::vector> sparse_vector{std::move(indices), std::move(values)}; return self.set(field, sparse_vector); } case DataType::SPARSE_VECTOR_FP16: { const auto sparse_dict = checked_cast(obj, field, "SPARSE_VECTOR_FP16 (dict)"); std::vector indices; std::vector values; for (const auto &item : sparse_dict) { try { indices.push_back(item.first.cast()); values.push_back(ailego::Float16(item.second.cast())); } catch (const py::cast_error &e) { throw py::type_error( "Field '" + field + "': sparse vector key/value must be (uint32, float), " "got key=" + std::string(py::str(py::type::of(item.first))) + ", value=" + std::string(py::str(py::type::of(item.second)))); } } const std::pair, std::vector> sparse_vector{std::move(indices), std::move(values)}; return self.set(field, sparse_vector); } default: throw py::type_error("Unsupported type for field: " + field); } }); // binding doc get field doc.def( "get_any", [](Doc &self, const std::string &field, const DataType &type) -> py::object { switch (type) { // base datatypes case DataType::STRING: return py::cast(self.get(field)); case DataType::BOOL: return py::cast(self.get(field)); case DataType::INT32: return py::cast(self.get(field)); case DataType::UINT32: return py::cast(self.get(field)); case DataType::INT64: return 
py::cast(self.get(field)); case DataType::UINT64: return py::cast(self.get(field)); case DataType::FLOAT: return py::cast(self.get(field)); case DataType::DOUBLE: return py::cast(self.get(field)); // array datatypes case DataType::ARRAY_STRING: return py::cast(self.get>(field)); case DataType::ARRAY_INT32: return py::cast(self.get>(field)); case DataType::ARRAY_INT64: return py::cast(self.get>(field)); case DataType::ARRAY_UINT32: return py::cast(self.get>(field)); case DataType::ARRAY_UINT64: return py::cast(self.get>(field)); case DataType::ARRAY_FLOAT: return py::cast(self.get>(field)); case DataType::ARRAY_DOUBLE: return py::cast(self.get>(field)); case DataType::ARRAY_BOOL: return py::cast(self.get>(field)); // vector datatypes case DataType::VECTOR_INT8: return py::cast(self.get>(field)); case DataType::VECTOR_FP16: { auto value = self.get>(field); if (value.has_value()) { std::vector new_value; new_value.reserve(value.value().size()); for (auto &item : value.value()) { new_value.push_back(static_cast(item)); } return py::cast(new_value); } return py::none(); } case DataType::VECTOR_FP32: return py::cast(self.get>(field)); case DataType::VECTOR_FP64: return py::cast(self.get>(field)); case DataType::SPARSE_VECTOR_FP16: { auto vector = self.get< std::pair, std::vector>>( field); const auto &indices = vector->first; const auto &values = vector->second; py::dict d; for (size_t i = 0; i < indices.size(); ++i) { d[py::int_(indices[i])] = py::float_(static_cast(values[i])); } return std::move(d); } case DataType::SPARSE_VECTOR_FP32: { auto vector = self.get, std::vector>>( field); const auto &indices = vector->first; const auto &values = vector->second; py::dict d; for (size_t i = 0; i < indices.size(); ++i) { d[py::int_(indices[i])] = py::float_(values[i]); } return std::move(d); } default: throw py::type_error("Unsupported type for field: " + field); } }); doc.def( "get_all", [](Doc &self, const CollectionSchema &schema) -> py::tuple { py::tuple result(4); // 1. 
set doc id and score result[0] = py::str(self.pk()); result[1] = py::float_(self.score()); if (self.is_empty()) { result[2] = py::none(); result[3] = py::none(); return result; } // 2. set scalar fields py::dict fields; for (const auto &field_meta : schema.forward_fields()) { const std::string &field = field_meta->name(); if (!self.has_value(field)) { continue; } try { auto val = [&]() -> py::object { switch (field_meta->data_type()) { // base datatypes case DataType::STRING: return py::str(self.get(field).value()); case DataType::BOOL: return py::cast(self.get(field)); case DataType::INT32: return py::cast(self.get(field)); case DataType::UINT32: return py::cast(self.get(field)); case DataType::INT64: return py::cast(self.get(field)); case DataType::UINT64: return py::cast(self.get(field)); case DataType::FLOAT: return py::cast(self.get(field)); case DataType::DOUBLE: return py::cast(self.get(field)); // array datatypes case DataType::ARRAY_STRING: return py::cast(self.get>(field)); case DataType::ARRAY_INT32: return py::cast(self.get>(field)); case DataType::ARRAY_INT64: return py::cast(self.get>(field)); case DataType::ARRAY_UINT32: return py::cast(self.get>(field)); case DataType::ARRAY_UINT64: return py::cast(self.get>(field)); case DataType::ARRAY_FLOAT: return py::cast(self.get>(field)); case DataType::ARRAY_DOUBLE: return py::cast(self.get>(field)); case DataType::ARRAY_BOOL: return py::cast(self.get>(field)); default: throw py::type_error("Unsupported type for field: " + field); } }(); fields[py::str(field)] = val; } catch (const std::exception &e) { fields[py::str(field)] = py::none(); } } if (!fields.empty()) { result[2] = fields; } else { result[2] = py::none(); } // 3. 
set vector fields py::dict vectors; for (const auto &vec_meta : schema.vector_fields()) { const std::string &vec = vec_meta->name(); if (!self.has_value(vec)) continue; try { auto array = [&]() -> py::object { switch (vec_meta->data_type()) { case DataType::VECTOR_INT8: return py::cast(self.get>(vec)); case DataType::VECTOR_FP16: { auto value = self.get>(vec); if (value.has_value()) { std::vector new_value; new_value.reserve(value.value().size()); for (auto &item : value.value()) { new_value.push_back(static_cast(item)); } return py::cast(new_value); } return py::none(); } case DataType::VECTOR_FP32: return py::cast(self.get>(vec)); case DataType::VECTOR_FP64: return py::cast(self.get>(vec)); case DataType::SPARSE_VECTOR_FP16: { auto vector = self.get, std::vector>>(vec); const auto &indices = vector->first; const auto &values = vector->second; py::dict d; for (size_t i = 0; i < indices.size(); ++i) { d[py::int_(indices[i])] = py::float_(static_cast(values[i])); } return std::move(d); } case DataType::SPARSE_VECTOR_FP32: { auto vector = self.get< std::pair, std::vector>>( vec); const auto &indices = vector->first; const auto &values = vector->second; py::dict d; for (size_t i = 0; i < indices.size(); ++i) { d[py::int_(indices[i])] = py::float_(values[i]); } return std::move(d); } default: throw py::type_error("Unsupported type for field: " + vec); } }(); vectors[py::str(vec)] = array; } catch (const std::exception &e) { vectors[py::str(vec)] = py::none(); } } if (!vectors.empty()) { result[3] = vectors; } else { result[3] = py::none(); } return result; }, py::arg("schema"), "Get all fields and vectors as a tuple: (id, score, fields, vectors). 
" "Vectors are zero-copy numpy arrays (dense: ndarray, sparse: (indices, " "values) tuple)."); } } // namespace zvec ================================================ FILE: src/binding/python/model/schema/python_schema.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "python_schema.h" #include #include #include namespace zvec { void ZVecPySchemas::Initialize(pybind11::module_ &parent) { auto m = parent.def_submodule("schema", "This module contains the schema of Zvec"); bind_field_schema(m); bind_collection_schema(m); bind_collection_stats(m); } void ZVecPySchemas::bind_field_schema(py::module_ &m) { py::class_(m, "_FieldSchema") .def(py::init(), py::arg("name"), py::arg("data_type"), py::arg("dimension") = 0, py::arg("nullable") = false, py::arg("index_param") = nullptr) .def_property_readonly("name", &FieldSchema::name) .def_property_readonly("data_type", &FieldSchema::data_type) .def_property_readonly("nullable", &FieldSchema::nullable) .def_property_readonly("dimension", &FieldSchema::dimension) .def_property_readonly("is_dense_vector", &FieldSchema::is_dense_vector) .def_property_readonly("is_sparse_vector", &FieldSchema::is_sparse_vector) .def_property_readonly("index_type", [](const FieldSchema &self) { return self.index_params() ? 
self.index_type() : IndexType::UNDEFINED; }) .def_property_readonly("index_param", [](const FieldSchema &self) -> py::object { if (self.index_params()) { return py::cast(self.index_params()); } return py::none(); }) .def("__eq__", &FieldSchema::operator==) .def("__ne__", &FieldSchema::operator!=) .def(py::pickle( [](const FieldSchema &self) { return py::make_tuple(self.name(), self.data_type(), self.dimension(), self.nullable(), self.index_params() ? py::cast(self.index_params()) : py::none()); }, [](py::tuple t) { if (t.size() != 5) { throw std::runtime_error( "Invalid tuple size for FieldSchema pickle"); } std::string name = t[0].cast(); DataType dtype = t[1].cast(); uint32_t dim = t[2].cast(); bool nullable = t[3].cast(); IndexParams::Ptr idx_params = nullptr; if (!t[4].is_none()) { idx_params = t[4].cast(); } return std::make_shared(name, dtype, dim, nullable, idx_params); })); } void ZVecPySchemas::bind_collection_schema(py::module_ &m) { py::class_(m, "_CollectionSchema") .def(py::init(), py::arg("name"), py::arg("fields"), "Construct with name and list of fields") .def_property_readonly("name", &CollectionSchema::name) .def("has_field", &CollectionSchema::has_field, py::arg("field_name"), "Check if a field exists.") .def( "get_field", [](const CollectionSchema &self, const std::string &name) -> const FieldSchema * { return self.get_field(name); }, py::arg("field_name"), py::return_value_policy::reference_internal, "Get field by name (const pointer), returns None if not found.") .def( "get_forward_field", [](const CollectionSchema &self, const std::string &name) -> const FieldSchema * { return self.get_forward_field(name); }, py::arg("field_name"), py::return_value_policy::reference_internal, "Get forward field (used for filtering).") .def( "get_vector_field", [](const CollectionSchema &self, const std::string &name) -> const FieldSchema * { return self.get_vector_field(name); }, py::arg("field_name"), py::return_value_policy::reference_internal, "Get vector 
field by name.") .def("fields", &CollectionSchema::fields, "Return list of all field schemas.", py::return_value_policy::copy) .def("forward_fields", &CollectionSchema::forward_fields, "Return list of forward-indexed fields.", py::return_value_policy::copy) .def("vector_fields", &CollectionSchema::vector_fields, "Return list of vector fields.", py::return_value_policy::copy) .def("__eq__", &CollectionSchema::operator==) .def("__ne__", &CollectionSchema::operator!=) .def(py::pickle( [](const CollectionSchema &cs) { return py::make_tuple(cs.name(), cs.fields(), cs.max_doc_count_per_segment()); }, [](py::tuple t) { if (t.size() != 3) throw std::runtime_error("Invalid state for CollectionSchema!"); auto name = t[0].cast(); auto fields = t[1].cast(); auto max_docs = t[2].cast(); auto cs = std::make_shared(name, fields); cs->set_max_doc_count_per_segment(max_docs); return cs; })); } void ZVecPySchemas::bind_collection_stats(py::module_ &m) { pybind11::class_(m, "CollectionStats") .def(pybind11::init<>()) .def_property_readonly( "doc_count", [](const CollectionStats &c) { return c.doc_count; }) .def_property_readonly( "index_completeness", [](const CollectionStats &c) { return c.index_completeness; }) .def("__repr__", [](const CollectionStats &c) { std::string map_str = "{"; bool first = true; for (const auto &[k, v] : c.index_completeness) { if (!first) map_str += ", "; map_str += "\"" + k + "\":" + std::to_string(v); first = false; } map_str += "}"; return "{\"doc_count\":" + std::to_string(c.doc_count) + ", \"index_completeness\":" + map_str + "}"; }); } } // namespace zvec ================================================ FILE: src/binding/python/typing/python_type.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "python_type.h" namespace zvec { void ZVecPyTyping::Initialize(pybind11::module_ &parent) { auto m = parent.def_submodule( "typing", "This module contains the basic data types of Zvec"); // binding base types bind_datatypes(m); bind_index_types(m); bind_metric_types(m); bind_quantize_types(m); bind_status(m); } void ZVecPyTyping::bind_datatypes(pybind11::module_ &m) { py::enum_(m, "DataType", R"pbdoc( Enumeration of supported data types in Zvec. Includes scalar types, dense/sparse vector types, and array types. Examples: >>> from zvec.typing import DataType >>> print(DataType.FLOAT) DataType.FLOAT >>> print(DataType.VECTOR_FP32) DataType.VECTOR_FP32 )pbdoc") // field type .value("STRING", DataType::STRING) .value("BOOL", DataType::BOOL) .value("INT32", DataType::INT32) .value("INT64", DataType::INT64) .value("FLOAT", DataType::FLOAT) .value("DOUBLE", DataType::DOUBLE) .value("UINT32", DataType::UINT32) .value("UINT64", DataType::UINT64) // dense vector type .value("VECTOR_FP16", DataType::VECTOR_FP16) .value("VECTOR_FP32", DataType::VECTOR_FP32) .value("VECTOR_FP64", DataType::VECTOR_FP64) .value("VECTOR_INT8", DataType::VECTOR_INT8) // sparse vector type .value("SPARSE_VECTOR_FP32", DataType::SPARSE_VECTOR_FP32) .value("SPARSE_VECTOR_FP16", DataType::SPARSE_VECTOR_FP16) // array type [not support bool/bytes] .value("ARRAY_STRING", DataType::ARRAY_STRING) .value("ARRAY_INT32", DataType::ARRAY_INT32) .value("ARRAY_INT64", DataType::ARRAY_INT64) .value("ARRAY_FLOAT", DataType::ARRAY_FLOAT) .value("ARRAY_DOUBLE", DataType::ARRAY_DOUBLE) 
.value("ARRAY_BOOL", DataType::ARRAY_BOOL) .value("ARRAY_UINT32", DataType::ARRAY_UINT32) .value("ARRAY_UINT64", DataType::ARRAY_UINT64) // non support // .value("BINARY", DataType::BINARY) // .value("ARRAY_BINARY", DataType::ARRAY_BINARY) // .value("VECTOR_INT4", DataType::VECTOR_INT4) // .value("VECTOR_INT16", DataType::VECTOR_INT16) // .value("VECTOR_BINARY32", DataType::VECTOR_BINARY32) // .value("VECTOR_BINARY64", DataType::VECTOR_BINARY64) // .value("UNDEFINED", DataType::UNDEFINED) ; } void ZVecPyTyping::bind_index_types(pybind11::module_ &m) { py::enum_(m, "IndexType", R"pbdoc( Enumeration of supported index types in Zvec. Examples: >>> from zvec.typing import IndexType >>> print(IndexType.HNSW) IndexType.HNSW )pbdoc") .value("UNDEFINED", IndexType::UNDEFINED) .value("HNSW", IndexType::HNSW) .value("HNSW_RABITQ", IndexType::HNSW_RABITQ) .value("IVF", IndexType::IVF) .value("FLAT", IndexType::FLAT) .value("INVERT", IndexType::INVERT); } void ZVecPyTyping::bind_metric_types(pybind11::module_ &m) { py::enum_(m, "MetricType", R"pbdoc( Enumeration of supported distance/similarity metrics. - COSINE: Cosine similarity. - IP: Inner product (dot product). - L2: Euclidean distance (L2 norm). Examples: >>> from zvec.typing import MetricType >>> print(MetricType.COSINE) MetricType.COSINE )pbdoc") .value("COSINE", MetricType::COSINE) .value("IP", MetricType::IP) .value("L2", MetricType::L2); } void ZVecPyTyping::bind_quantize_types(py::module_ &m) { py::enum_(m, "QuantizeType", R"pbdoc( Enumeration of supported quantization types for vector compression. 
Examples: >>> from zvec.typing import QuantizeType >>> print(QuantizeType.INT8) QuantizeType.INT8 )pbdoc") .value("UNDEFINED", QuantizeType::UNDEFINED) .value("FP16", QuantizeType::FP16) .value("INT8", QuantizeType::INT8) .value("INT4", QuantizeType::INT4) .value("RABITQ", QuantizeType::RABITQ); } void ZVecPyTyping::bind_status(py::module_ &m) { // bind status code py::enum_(m, "StatusCode", R"pbdoc( Enumeration of possible status codes for Zvec operations. Used by the `Status` class to indicate success or failure reason. )pbdoc") .value("OK", StatusCode::OK) .value("NOT_FOUND", StatusCode::NOT_FOUND) .value("ALREADY_EXISTS", StatusCode::ALREADY_EXISTS) .value("INVALID_ARGUMENT", StatusCode::INVALID_ARGUMENT) .value("PERMISSION_DENIED", StatusCode::PERMISSION_DENIED) .value("FAILED_PRECONDITION", StatusCode::FAILED_PRECONDITION) .value("RESOURCE_EXHAUSTED", StatusCode::RESOURCE_EXHAUSTED) .value("UNAVAILABLE", StatusCode::UNAVAILABLE) .value("INTERNAL_ERROR", StatusCode::INTERNAL_ERROR) .value("NOT_SUPPORTED", StatusCode::NOT_SUPPORTED) .value("UNKNOWN", StatusCode::UNKNOWN); // bind status py::class_(m, "Status", R"pbdoc( Represents the outcome of a Zvec operation. A `Status` object is either OK (success) or carries an error code and message. Examples: >>> from zvec.typing import Status, StatusCode >>> s = Status() >>> print(s.ok()) True >>> s = Status(StatusCode.INVALID_ARGUMENT, "Field not found") >>> print(s.code() == StatusCode.INVALID_ARGUMENT) True >>> print(s.message()) Field not found )pbdoc") .def(py::init<>()) .def(py::init(), py::arg("code"), py::arg("message") = "", R"pbdoc( Construct a status with the given code and optional message. Args: code (StatusCode): The status code. message (str, optional): Error message. Defaults to empty string. 
)pbdoc") .def("ok", &Status::ok, "bool: Returns True if the status is OK.") .def("code", &Status::code, "StatusCode: Returns the status code.") .def("message", &Status::message, "str: Returns the error message (may be empty).") .def_static("OK", &Status::OK, "Create an OK status.") .def_static( "InvalidArgument", [](const std::string &msg) { return Status::InvalidArgument(msg); }, py::arg("message")) .def_static( "NotFound", [](const std::string &msg) { return Status::NotFound(msg); }, py::arg("message")) .def_static( "AlreadyExists", [](const std::string &msg) { return Status::AlreadyExists(msg); }, py::arg("message")) .def_static( "InternalError", [](const std::string &msg) { return Status::InternalError(msg); }, py::arg("message")) .def_static( "PermissionDenied", [](const std::string &msg) { return Status::PermissionDenied(msg); }, py::arg("message")) .def("__eq__", [](const Status &self, const Status &other) { return self == other; }) .def("__ne__", [](const Status &self, const Status &other) { return self != other; }) .def("__repr__", [](const Status &self) { std::string result = "{" "\"code\":" + std::to_string(static_cast(self.code())); if (!self.message().empty()) { result += ", \"message\":\"" + self.message() + "\""; } result += "}"; return result; }); } } // namespace zvec ================================================ FILE: src/core/CMakeLists.txt ================================================ include(${PROJECT_ROOT_DIR}/cmake/bazel.cmake) include(${PROJECT_ROOT_DIR}/cmake/option.cmake) if(RABITQ_SUPPORTED AND AUTO_DETECT_ARCH) set(HNSW_RABITQ_FILES hnsw_rabitq_query_algorithm.cc hnsw_rabitq_streamer.cc hnsw_rabitq_searcher.cc hnsw_rabitq_entity.cc rabitq_reformer.cc rabitq_converter.cc ) set(HNSW_RABITQ_FILES_FULL ${HNSW_RABITQ_FILES}) list(TRANSFORM HNSW_RABITQ_FILES_FULL PREPEND "algorithm/hnsw_rabitq/") foreach(FILE ${HNSW_RABITQ_FILES_FULL}) set_source_files_properties( ${FILE} PROPERTIES COMPILE_FLAGS "${RABITQ_ARCH_FLAG}" ) endforeach() 
endif() cc_directory(framework) cc_directory(algorithm) cc_directory(metric) cc_directory(quantizer) cc_directory(utility) cc_directory(interface) cc_directory(mixed_reducer) git_version(GIT_SRCS_VER ${CMAKE_CURRENT_SOURCE_DIR}) file(GLOB_RECURSE ALL_CORE_SRCS *.cc *.c *.h) # Remove algorithm/hnsw_rabitq implementation files if not supported. # interface/indexes/hnsw_rabitq_index.cc is kept because it provides the vtable # for HNSWRabitqIndex and guards rabitqlib usage with #if RABITQ_SUPPORTED. if(NOT RABITQ_SUPPORTED) list(FILTER ALL_CORE_SRCS EXCLUDE REGEX ".*/algorithm/hnsw_rabitq/.*") endif() cc_library( NAME zvec_core STATIC STRICT PACKED SRCS ${ALL_CORE_SRCS} LIBS zvec_ailego zvec_turbo sparsehash magic_enum rabitqlib INCS . ${PROJECT_ROOT_DIR}/src/core VERSION "${GIT_SRCS_VER}" ) ================================================ FILE: src/core/algorithm/CMakeLists.txt ================================================ include(${PROJECT_ROOT_DIR}/cmake/bazel.cmake) include(${PROJECT_ROOT_DIR}/cmake/option.cmake) cc_directory(cluster) cc_directory(flat) cc_directory(flat_sparse) cc_directory(ivf) cc_directory(hnsw) cc_directory(hnsw_sparse) if(RABITQ_SUPPORTED) message(STATUS "BUILD RABITQ") cc_directory(hnsw_rabitq) else() message(STATUS "NOT BUILD RABITQ") # Empty stub library for unsupported platforms file(WRITE ${CMAKE_CURRENT_BINARY_DIR}/rabitq_stub.cc "// Stub implementation for unsupported platforms\n" "// RaBitQ only supports Linux x86_64\n" "namespace zvec { namespace core { /* empty namespace for compatibility */ } }\n" ) cc_library( NAME core_knn_hnsw_rabitq STATIC SHARED STRICT ALWAYS_LINK SRCS ${CMAKE_CURRENT_BINARY_DIR}/rabitq_stub.cc LIBS core_framework INCS . 
${PROJECT_ROOT_DIR}/src ${PROJECT_ROOT_DIR}/src/core ${PROJECT_ROOT_DIR}/src/core/algorithm VERSION "${PROXIMA_ZVEC_VERSION}" ) endif() ================================================ FILE: src/core/algorithm/cluster/CMakeLists.txt ================================================ include(${PROJECT_ROOT_DIR}/cmake/bazel.cmake) include(${PROJECT_ROOT_DIR}/cmake/option.cmake) cc_library( NAME core_knn_cluster STATIC SHARED STRICT ALWAYS_LINK SRCS *.cc LIBS zvec_ailego core_framework INCS . ${PROJECT_ROOT_DIR}/src/core ${PROJECT_ROOT_DIR}/src/core/cluster VERSION "${PROXIMA_ZVEC_VERSION}" ) ================================================ FILE: src/core/algorithm/cluster/cluster_params.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include namespace zvec { namespace core { //! General static const std::string GENERAL_CLUSTER_COUNT = "proxima.general.cluster.count"; static const std::string GENERAL_THREAD_COUNT = "proxima.general.cluster.thread_count"; //! 
Optimize K-means static const std::string OPTKMEANS_CLUSTER_COUNT = "proxima.optkmeans.cluster.count"; static const std::string OPTKMEANS_CLUSTER_MAX_ITERATIONS = "proxima.optkmeans.cluster.max_iterations"; static const std::string OPTKMEANS_CLUSTER_EPSILON = "proxima.optkmeans.cluster.epsilon"; static const std::string OPTKMEANS_CLUSTER_SHARD_FACTOR = "proxima.optkmeans.cluster.shard_factor"; static const std::string OPTKMEANS_CLUSTER_PURGE_EMPTY = "proxima.optkmeans.cluster.purge_empty"; static const std::string OPTKMEANS_CLUSTER_MARKOV_CHAIN_LENGTH = "proxima.optkmeans.cluster.markov_chain_length"; static const std::string OPTKMEANS_CLUSTER_ASSUMPTION_FREE = "proxima.optkmeans.cluster.assumption_free"; //! K-means static const std::string KMEANS_CLUSTER_COUNT = "proxima.kmeans.cluster.count"; static const std::string KMEANS_CLUSTER_SHARD_FACTOR = "proxima.kmeans.cluster.shard_factor"; static const std::string KMEANS_CLUSTER_EPSILON = "proxima.kmeans.cluster.epsilon"; static const std::string KMEANS_CLUSTER_MAX_ITERATIONS = "proxima.kmeans.cluster.max_iterations"; static const std::string KMEANS_CLUSTER_PURGE_EMPTY = "proxima.kmeans.cluster.purge_empty"; static const std::string KMEANS_CLUSTER_BATCH = "proxima.kmeans.cluster.batch"; static const std::string KMEANS_CLUSTER_SEEKER_CLASS = "proxima.kmeans.cluster.seeker_class"; static const std::string KMEANS_CLUSTER_SEEKER_PARAMS = "proxima.kmeans.cluster.seeker_params"; //! 
//! Mini Batch K-means
//! Keys under "proxima.minibatchkmeans.cluster.*".
static const std::string MINIBATCHKMEANS_CLUSTER_COUNT =
    "proxima.minibatchkmeans.cluster.count";
static const std::string MINIBATCHKMEANS_CLUSTER_SHARD_FACTOR =
    "proxima.minibatchkmeans.cluster.shard_factor";
static const std::string MINIBATCHKMEANS_CLUSTER_EPSILON =
    "proxima.minibatchkmeans.cluster.epsilon";
static const std::string MINIBATCHKMEANS_CLUSTER_MAX_ITERATIONS =
    "proxima.minibatchkmeans.cluster.max_iterations";
static const std::string MINIBATCHKMEANS_CLUSTER_PURGE_EMPTY =
    "proxima.minibatchkmeans.cluster.purge_empty";
static const std::string MINIBATCHKMEANS_CLUSTER_TRY_COUNT =
    "proxima.minibatchkmeans.cluster.try_count";
static const std::string MINIBATCHKMEANS_CLUSTER_BATCH_COUNT =
    "proxima.minibatchkmeans.cluster.batch_count";
static const std::string MINIBATCHKMEANS_CLUSTER_SEEKER_CLASS =
    "proxima.minibatchkmeans.cluster.seeker_class";
static const std::string MINIBATCHKMEANS_CLUSTER_SEEKER_PARAMS =
    "proxima.minibatchkmeans.cluster.seeker_params";

//! K-means++
//! Keys under "proxima.kmeanspp.cluster.*".
static const std::string KMEANSPP_CLUSTER_COUNT =
    "proxima.kmeanspp.cluster.count";
static const std::string KMEANSPP_CLUSTER_SHARD_FACTOR =
    "proxima.kmeanspp.cluster.shard_factor";
static const std::string KMEANSPP_CLUSTER_CLASS =
    "proxima.kmeanspp.cluster.class";
static const std::string KMEANSPP_CLUSTER_PARAMS =
    "proxima.kmeanspp.cluster.params";

//! K-MC2
//! Keys under "proxima.kmc2.cluster.*".
static const std::string KMC2_CLUSTER_COUNT = "proxima.kmc2.cluster.count";
static const std::string KMC2_CLUSTER_SHARD_FACTOR =
    "proxima.kmc2.cluster.shard_factor";
static const std::string KMC2_CLUSTER_MARKOV_CHAIN_LENGTH =
    "proxima.kmc2.cluster.markov_chain_length";
static const std::string KMC2_CLUSTER_ASSUMPTION_FREE =
    "proxima.kmc2.cluster.assumption_free";
static const std::string KMC2_CLUSTER_CLASS = "proxima.kmc2.cluster.class";
static const std::string KMC2_CLUSTER_PARAMS = "proxima.kmc2.cluster.params";

//! Bisecting K-means
//! Keys under "proxima.bikmeans.cluster.*"; "first"/"second" keys come in
//! pairs (class + params for each of two stages, judging by the names).
static const std::string BIKMEANS_CLUSTER_COUNT =
    "proxima.bikmeans.cluster.count";
static const std::string BIKMEANS_CLUSTER_INIT_COUNT =
    "proxima.bikmeans.cluster.init_count";
static const std::string BIKMEANS_CLUSTER_PURGE_EMPTY =
    "proxima.bikmeans.cluster.purge_empty";
static const std::string BIKMEANS_CLUSTER_FIRST_CLASS =
    "proxima.bikmeans.cluster.first_class";
static const std::string BIKMEANS_CLUSTER_SECOND_CLASS =
    "proxima.bikmeans.cluster.second_class";
static const std::string BIKMEANS_CLUSTER_FIRST_PARAMS =
    "proxima.bikmeans.cluster.first_params";
static const std::string BIKMEANS_CLUSTER_SECOND_PARAMS =
    "proxima.bikmeans.cluster.second_params";

//! K-medoids
//! Keys under "proxima.kmedoids.cluster.*".
static const std::string KMEDOIDS_CLUSTER_COUNT =
    "proxima.kmedoids.cluster.count";
static const std::string KMEDOIDS_CLUSTER_SHARD_FACTOR =
    "proxima.kmedoids.cluster.shard_factor";
static const std::string KMEDOIDS_CLUSTER_EPSILON =
    "proxima.kmedoids.cluster.epsilon";
static const std::string KMEDOIDS_CLUSTER_MAX_ITERATIONS =
    "proxima.kmedoids.cluster.max_iterations";
static const std::string KMEDOIDS_CLUSTER_PURGE_EMPTY =
    "proxima.kmedoids.cluster.purge_empty";
static const std::string KMEDOIDS_CLUSTER_BENCH_RATIO =
    "proxima.kmedoids.cluster.bench_ratio";
static const std::string KMEDOIDS_CLUSTER_ONLY_MEANS =
    "proxima.kmedoids.cluster.only_means";
static const std::string KMEDOIDS_CLUSTER_WITHOUT_MEANS =
    "proxima.kmedoids.cluster.without_means";
static const std::string KMEDOIDS_CLUSTER_SEEKER_CLASS =
    "proxima.kmedoids.cluster.seeker_class";
static const std::string KMEDOIDS_CLUSTER_SEEKER_PARAMS =
    "proxima.kmedoids.cluster.seeker_params";

//! Stratified
//! Keys under "proxima.stratified.cluster.*".
static const std::string STRATIFIED_CLUSTER_COUNT =
    "proxima.stratified.cluster.count";
static const std::string STRATIFIED_CLUSTER_FIRST_CLASS =
    "proxima.stratified.cluster.first_class";
static const std::string STRATIFIED_CLUSTER_SECOND_CLASS =
    "proxima.stratified.cluster.second_class";
static const std::string STRATIFIED_CLUSTER_FIRST_COUNT =
    "proxima.stratified.cluster.first_count";
static const std::string STRATIFIED_CLUSTER_SECOND_COUNT =
    "proxima.stratified.cluster.second_count";
static const std::string STRATIFIED_CLUSTER_FIRST_PARAMS =
    "proxima.stratified.cluster.first_params";
static const std::string STRATIFIED_CLUSTER_SECOND_PARAMS =
    "proxima.stratified.cluster.second_params";
static const std::string STRATIFIED_CLUSTER_AUTO_TUNING =
    "proxima.stratified.cluster.auto_tuning";
static const std::string STRATIFIED_CLUSTER_SECOND_POOL_COUNT =
    "proxima.stratified.cluster.second_pool_count";

//! Gap Statistics
//! Keys under "proxima.gapstats.cluster_estimater.*".
//! NOTE(review): "estimater" is spelled this way in both the identifiers and
//! the key strings; the strings are runtime keys, so do not "correct" them.
static const std::string GAPSTATS_CLUSTER_ESTIMATER_K_MIN =
    "proxima.gapstats.cluster_estimater.k_min";
static const std::string GAPSTATS_CLUSTER_ESTIMATER_K_MAX =
    "proxima.gapstats.cluster_estimater.k_max";
static const std::string GAPSTATS_CLUSTER_ESTIMATER_K_MIN_STEP =
    "proxima.gapstats.cluster_estimater.k_min_step";
static const std::string GAPSTATS_CLUSTER_ESTIMATER_K_MAX_STEP =
    "proxima.gapstats.cluster_estimater.k_max_step";
static const std::string GAPSTATS_CLUSTER_ESTIMATER_TRY_COUNT =
    "proxima.gapstats.cluster_estimater.try_count";
static const std::string GAPSTATS_CLUSTER_ESTIMATER_SHARD_FACTOR =
    "proxima.gapstats.cluster_estimater.shard_factor";
static const std::string GAPSTATS_CLUSTER_ESTIMATER_ENABLE_MC2 =
    "proxima.gapstats.cluster_estimater.enable_mc2";
static const std::string GAPSTATS_CLUSTER_ESTIMATER_MARKOV_CHAIN_LENGTH =
    "proxima.gapstats.cluster_estimater.markov_chain_length";
static const std::string GAPSTATS_CLUSTER_ESTIMATER_CLUSTER_CLASS =
    "proxima.gapstats.cluster_estimater.cluster_class";

//! Cluster trainer
//! Keys under "proxima.cluster.trainer.*".
static const std::string CLUSTER_TRAINER_SAMPLE_COUNT =
    "proxima.cluster.trainer.sample_count";
static const std::string CLUSTER_TRAINER_SAMPLE_RATIO =
    "proxima.cluster.trainer.sample_ratio";
static const std::string CLUSTER_TRAINER_THREAD_COUNT =
    "proxima.cluster.trainer.thread_count";
static const std::string CLUSTER_TRAINER_FILE_NAME =
    "proxima.cluster.trainer.file_name";
static const std::string CLUSTER_TRAINER_CLASS_NAME =
    "proxima.cluster.trainer.class_name";

//! Stratified trainer
//! Keys under "proxima.stratified.trainer.*".
static const std::string STRATIFIED_TRAINER_SAMPLE_COUNT =
    "proxima.stratified.trainer.sample_count";
static const std::string STRATIFIED_TRAINER_SAMPLE_RATIO =
    "proxima.stratified.trainer.sample_ratio";
static const std::string STRATIFIED_TRAINER_THREAD_COUNT =
    "proxima.stratified.trainer.thread_count";
static const std::string STRATIFIED_TRAINER_FILE_NAME =
    "proxima.stratified.trainer.file_name";
static const std::string STRATIFIED_TRAINER_CLASS_NAME =
    "proxima.stratified.trainer.class_name";
static const std::string STRATIFIED_TRAINER_CLUSTER_COUNT =
    "proxima.stratified.trainer.cluster_count";
// NOTE(review): the identifier is spelled AUTOAUNE although its key ends in
// "autotune"; the name is kept as-is because other code may reference it.
static const std::string STRATIFIED_TRAINER_AUTOAUNE =
    "proxima.stratified.trainer.autotune";
// Prefix only: a level identifier is presumably appended to form the full
// key (TODO confirm against the consumer of this constant).
static const std::string STRATIFIED_TRAINER_PARAMS_IN_LEVEL_PREFIX =
    "proxima.stratified.trainer.cluster_params_in_level_";

}  // namespace core
}  // namespace zvec

================================================
FILE: src/core/algorithm/cluster/kmeans_cluster.cc
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // limitations under the License. #include #include #include #include #include #include #include "cluster_params.h" #include "linear_seeker.h" #include "vector_mean.h" namespace zvec { namespace core { /*! Kmeans Cluster */ class KmeansCluster : public IndexCluster { public: //! Constructor KmeansCluster(void) {} //! Constructor KmeansCluster(size_t iters, bool batch) : max_iterations_(iters), batch_(batch) {} //! Constructor KmeansCluster(bool batch) : batch_(batch) {} //! Destructor virtual ~KmeansCluster(void) {} //! Initialize Cluster virtual int init(const IndexMeta &meta, const ailego::Params ¶ms); //! Cleanup Cluster virtual int cleanup(void); //! Reset Cluster virtual int reset(void); //! Update Cluster virtual int update(const ailego::Params ¶ms); //! Suggest dividing to K clusters virtual void suggest(uint32_t k); //! Mount features virtual int mount(IndexFeatures::Pointer feats); //! Cluster virtual int cluster(IndexThreads::Pointer threads, IndexCluster::CentroidList ¢s); //! Classify virtual int classify(IndexThreads::Pointer threads, IndexCluster::CentroidList ¢s); //! Label virtual int label(IndexThreads::Pointer threads, const IndexCluster::CentroidList ¢s, std::vector *out); protected: //! Test if it is valid bool is_valid(void) const; //! Cluster once int clustering(IndexThreads *threads, IndexCluster::CentroidList ¢s, double *cost); //! Update parameters void update_params(const ailego::Params ¶ms); //! Init seeker int init_seeker(void); //! Build seeker int build_seeker(const IndexCluster::CentroidList ¢s); //! Check Centroids bool check_centroids(const IndexCluster::CentroidList ¢s); //! Initialize Centroids void init_centroids(size_t count, IndexCluster::CentroidList *out); //! Initialize Shard Containers void init_containers(size_t shard_count); //! Initialize Shard Features Containers void init_features_containers(size_t shard_count); //! 
Split Clusters void split_clusters(IndexThreads *threads, const IndexCluster::CentroidList ¢s); //! Update Centroids void update_centroids(IndexThreads *threads, IndexCluster::CentroidList ¢s); //! Update Clusters void update_clusters(IndexThreads *threads, const IndexCluster::CentroidList ¢s); //! Update Clusters' Features void update_features(IndexThreads *threads, IndexCluster::CentroidList ¢s); //! Update Labels void update_labels(IndexThreads *threads, std::vector *labels); //! Split Clusters in Thread void split_clusters_thread(size_t index_begin, size_t index_end, const IndexThreads *threads); //! Update Centroid in Thread void update_centroid_thread(size_t column, IndexCluster::CentroidList *out); //! Update Cluster in Thread void update_cluster_thread(size_t index_begin, size_t index_end, const IndexThreads *threads); //! Update Cluster's Features in Thread void update_features_thread(size_t column, IndexCluster::CentroidList *out); //! Update Labels in Thread void update_labels_thread(size_t index_begin, size_t index_end, std::vector *labels); protected: //! Members IndexMeta meta_{}; IndexFeatures::Pointer features_{}; LinearSeeker::Pointer seeker_{}; std::vector shard_cluster_scores_{}; std::vector> shard_cluster_features_{}; std::shared_ptr shard_cluster_means_{}; std::shared_ptr batch_means_{}; std::vector batch_scores_{}; double epsilon_{std::numeric_limits::epsilon()}; float shard_factor_{16.0f}; uint32_t max_iterations_{20u}; uint32_t cluster_count_{0u}; uint32_t thread_count_{0u}; bool batch_{false}; bool purge_empty_{false}; }; /*! Centroid Features */ class KmeansCentroidFeatures : public IndexFeatures { public: //! 
Constructor KmeansCentroidFeatures(const IndexMeta &meta, const IndexCluster::CentroidList ¢s) : centroids_(cents), feature_size_(meta.element_size()), feature_dimension_(meta.dimension()), data_type_(meta.data_type()) {} virtual size_t count(void) const { return centroids_.size(); } virtual size_t dimension(void) const { return feature_dimension_; } virtual const void *element(size_t i) const { return centroids_[i].feature(); } virtual IndexMeta::DataType data_type(void) const { return data_type_; } virtual size_t element_size(void) const { return feature_size_; } private: const IndexCluster::CentroidList ¢roids_; size_t feature_size_; size_t feature_dimension_; IndexMeta::DataType data_type_; }; static inline std::shared_ptr NewVectorMean(const IndexMeta &meta) { switch (meta.data_type()) { case IndexMeta::DataType::DT_FP16: return std::make_shared>( meta.dimension()); case IndexMeta::DataType::DT_FP32: return std::make_shared>(meta.dimension()); case IndexMeta::DataType::DT_FP64: return std::make_shared>(meta.dimension()); case IndexMeta::DataType::DT_INT8: return std::make_shared>(meta.dimension()); case IndexMeta::DataType::DT_INT4: return std::make_shared>(meta.dimension()); case IndexMeta::DataType::DT_INT16: return std::make_shared>(meta.dimension()); default: break; } // As binary default return std::make_shared(meta.dimension()); } static inline std::shared_ptr NewVectorMeanArray( const IndexMeta &meta) { switch (meta.data_type()) { case IndexMeta::DataType::DT_FP16: return std::make_shared< GeneralVectorMeanArray>>( meta.dimension()); case IndexMeta::DataType::DT_FP32: return std::make_shared< GeneralVectorMeanArray>>(meta.dimension()); case IndexMeta::DataType::DT_FP64: return std::make_shared< GeneralVectorMeanArray>>( meta.dimension()); case IndexMeta::DataType::DT_INT8: return std::make_shared< GeneralVectorMeanArray>>( meta.dimension()); case IndexMeta::DataType::DT_INT4: return std::make_shared< GeneralVectorMeanArray>>(meta.dimension()); case 
IndexMeta::DataType::DT_INT16: return std::make_shared< GeneralVectorMeanArray>>( meta.dimension()); default: break; } // As binary default return std::make_shared>( meta.dimension()); } static inline std::shared_ptr NewVectorMeanArray( const IndexMeta &meta, const IndexCluster::CentroidList ¢s) { switch (meta.data_type()) { case IndexMeta::DataType::DT_FP16: { auto ptr = std::make_shared< GeneralVectorMeanArray>>( meta.dimension()); for (const auto &it : cents) { ptr->emplace(reinterpret_cast(it.feature()), meta.dimension(), it.follows()); } return ptr; } case IndexMeta::DataType::DT_FP32: { auto ptr = std::make_shared>>( meta.dimension()); for (const auto &it : cents) { ptr->emplace(reinterpret_cast(it.feature()), meta.dimension(), it.follows()); } return ptr; } case IndexMeta::DataType::DT_FP64: { auto ptr = std::make_shared>>( meta.dimension()); for (const auto &it : cents) { ptr->emplace(reinterpret_cast(it.feature()), meta.dimension(), it.follows()); } return ptr; } case IndexMeta::DataType::DT_INT8: { auto ptr = std::make_shared>>( meta.dimension()); for (const auto &it : cents) { ptr->emplace(reinterpret_cast(it.feature()), meta.dimension(), it.follows()); } return ptr; } case IndexMeta::DataType::DT_INT4: { auto ptr = std::make_shared>>( meta.dimension()); for (const auto &it : cents) { ptr->emplace(reinterpret_cast(it.feature()), meta.dimension(), it.follows()); } return ptr; } case IndexMeta::DataType::DT_INT16: { auto ptr = std::make_shared< GeneralVectorMeanArray>>( meta.dimension()); for (const auto &it : cents) { ptr->emplace(reinterpret_cast(it.feature()), meta.dimension(), it.follows()); } return ptr; } default: break; } // As binary default auto ptr = std::make_shared>( meta.dimension()); for (const auto &it : cents) { ptr->emplace(it.feature(), meta.dimension(), it.follows()); } return ptr; } static inline double CalculateSSE(const IndexCluster::CentroidList ¢s) { double accum = 0.0; for (const auto &it : cents) { accum += it.score(); } return 
accum; } static inline void PurgeCentroids(IndexCluster::CentroidList ¢s, bool cutting) { size_t index = 0; size_t tamp = cents.size(); while (index < tamp) { if (cents[index].follows() == 0) { size_t last_index = tamp - 1; if (index != last_index) { std::swap(cents[index], cents[last_index]); } tamp = last_index; continue; } ++index; } if (cutting) { cents.resize(tamp); } } int KmeansCluster::init(const IndexMeta &meta, const ailego::Params ¶ms) { meta_ = meta; this->update_params(params); return this->init_seeker(); } int KmeansCluster::cleanup(void) { features_.reset(); shard_cluster_scores_.clear(); shard_cluster_features_.clear(); shard_cluster_means_.reset(); batch_means_.reset(); batch_scores_.clear(); seeker_->cleanup(); return 0; } int KmeansCluster::reset(void) { features_.reset(); shard_cluster_scores_.clear(); shard_cluster_features_.clear(); shard_cluster_means_->clear(); batch_means_->clear(); batch_scores_.clear(); seeker_->reset(); return 0; } int KmeansCluster::update(const ailego::Params ¶ms) { this->update_params(params); return 0; } void KmeansCluster::suggest(uint32_t k) { cluster_count_ = k; } int KmeansCluster::mount(IndexFeatures::Pointer feats) { if (!feats) { return IndexError_InvalidArgument; } if (!feats->is_matched(meta_)) { return IndexError_Mismatch; } // Check dimension auto data_type = meta_.data_type(); switch (data_type) { case IndexMeta::DataType::DT_INT4: if (feats->dimension() % 2 != 0) { LOG_ERROR( "Unsupported feature dimension %zu (dimension of int4 " "must be an integer multiple of 2).", feats->dimension()); return IndexError_Mismatch; } break; case IndexMeta::DataType::DT_BINARY32: if (feats->dimension() % 32 != 0) { LOG_ERROR( "Unsupported feature dimension %zu (dimension of binary32 " "must be an integer multiple of 32).", feats->dimension()); return IndexError_Mismatch; } break; case IndexMeta::DataType::DT_BINARY64: if (feats->dimension() % 64 != 0) { LOG_ERROR( "Unsupported feature dimension %zu (dimension of binary64 
" "must be an integer multiple of 64).", feats->dimension()); return IndexError_Mismatch; } break; default: break; } features_ = std::move(feats); return 0; } int KmeansCluster::cluster(IndexThreads::Pointer threads, IndexCluster::CentroidList ¢s) { ailego::ElapsedTime stamp; if (!threads) { threads = std::make_shared(thread_count_, false); if (!threads) { return IndexError_NoMemory; } } if (!this->check_centroids(cents)) { LOG_ERROR("The input centroid's list includes some invalid centroids."); return IndexError_InvalidArgument; } if (!this->is_valid()) { LOG_ERROR("The cluster is not ready."); return IndexError_NoReady; } if (cents.empty()) { if (cluster_count_ == 0) { LOG_ERROR("The count of cluster is unknown."); return IndexError_NoReady; } this->init_centroids(cluster_count_, ¢s); } if (batch_) { batch_means_ = NewVectorMeanArray(meta_, cents); batch_scores_.clear(); for (const auto &it : cents) { batch_scores_.push_back(it.score()); } } double cost = 0.0; // we need to do clustering and update the centroids' follows, even if // cents.size() == 1. 
Otherwise, the centroid with empty follows will be // removed if purge_empty enabled for (uint32_t i = 0; (i < max_iterations_) && (cents.size() > 0); ++i) { double new_cost, new_epsilon; int result = this->clustering(threads.get(), cents, &new_cost); if (result != 0) { LOG_ERROR("(%u) Failed to cluster.", i + 1); return result; } new_epsilon = new_cost - cost; LOG_DEBUG("(%u) Updated %zu Clusters, %zu Features: %zu ms, %f -> %f = %f", i, cents.size(), features_->count(), (size_t)stamp.milli_seconds(), cost, new_cost, new_epsilon); stamp.reset(); new_epsilon = std::abs(new_epsilon); if (new_epsilon < epsilon_) { break; } cost = new_cost; } // Purge the empty centroids PurgeCentroids(cents, purge_empty_); return 0; } int KmeansCluster::classify(IndexThreads::Pointer threads, IndexCluster::CentroidList ¢s) { if (!threads) { threads = std::make_shared(thread_count_, false); if (!threads) { return IndexError_NoMemory; } } if (cents.empty()) { LOG_ERROR("The input centroid's list is empty."); return IndexError_InvalidArgument; } if (!this->check_centroids(cents)) { LOG_ERROR("The input centroid's list includes some invalid centroids."); return IndexError_InvalidArgument; } if (!this->is_valid()) { LOG_ERROR("The cluster is not ready."); return IndexError_NoReady; } int result = this->build_seeker(cents); if (result != 0) { LOG_ERROR("Failed to build the seeker."); return result; } this->update_clusters(threads.get(), cents); this->update_features(threads.get(), cents); return 0; } int KmeansCluster::label(IndexThreads::Pointer threads, const IndexCluster::CentroidList ¢s, std::vector *out) { if (!threads) { threads = std::make_shared(thread_count_, false); if (!threads) { return IndexError_NoMemory; } } if (cents.empty()) { LOG_ERROR("The input centroid's list is empty."); return IndexError_InvalidArgument; } if (!this->check_centroids(cents)) { LOG_ERROR("The input centroid's list includes some invalid centroids."); return IndexError_InvalidArgument; } if 
(!this->is_valid()) { LOG_ERROR("The cluster is not ready."); return IndexError_NoReady; } int result = this->build_seeker(cents); if (result != 0) { LOG_ERROR("Failed to build the seeker."); return result; } this->update_labels(threads.get(), out); return 0; } bool KmeansCluster::is_valid(void) const { if (!seeker_ || !features_ || !features_->count()) { return false; } return true; } int KmeansCluster::clustering(IndexThreads *threads, IndexCluster::CentroidList ¢s, double *cost) { int result = this->build_seeker(cents); if (result != 0) { LOG_ERROR("Failed to build the seeker."); return result; } this->split_clusters(threads, cents); this->update_centroids(threads, cents); *cost = CalculateSSE(cents); return 0; } void KmeansCluster::update_params(const ailego::Params ¶ms) { params.get(GENERAL_THREAD_COUNT, &thread_count_); params.get(GENERAL_CLUSTER_COUNT, &cluster_count_); params.get(KMEANS_CLUSTER_COUNT, &cluster_count_); params.get(KMEANS_CLUSTER_SHARD_FACTOR, &shard_factor_); params.get(KMEANS_CLUSTER_EPSILON, &epsilon_); params.get(KMEANS_CLUSTER_MAX_ITERATIONS, &max_iterations_); params.get(KMEANS_CLUSTER_BATCH, &batch_); params.get(KMEANS_CLUSTER_PURGE_EMPTY, &purge_empty_); } int KmeansCluster::init_seeker(void) { seeker_.reset(new (std::nothrow) LinearSeeker); if (!seeker_) { LOG_ERROR("Failed to create linear seeker."); return IndexError_NoMemory; } int result = seeker_->init(meta_); if (result != 0) { LOG_ERROR("Failed to initialize linear seeker."); return result; } return 0; } int KmeansCluster::build_seeker(const IndexCluster::CentroidList ¢s) { int result = seeker_->mount(std::make_shared(meta_, cents)); if (result != 0) { LOG_ERROR("Failed to mount features for linear seeker."); return result; } return 0; } bool KmeansCluster::check_centroids(const IndexCluster::CentroidList ¢s) { for (const auto &it : cents) { if (it.size() != meta_.element_size()) { return false; } } return true; } void KmeansCluster::init_centroids(size_t count, 
IndexCluster::CentroidList *out) { size_t feature_size = features_->element_size(); size_t features_count = features_->count(); size_t sample_count = std::min(count, features_count); ailego::Reservoir sampler(sample_count); for (size_t i = 0; i < features_count; ++i) { sampler.fill(i); } // Save centroids out->reserve(sampler.pool().size()); for (auto i : sampler.pool()) { out->emplace_back(features_->element(i), feature_size); } } void KmeansCluster::init_containers(size_t shard_count) { if (!shard_cluster_means_) { shard_cluster_means_ = NewVectorMeanArray(meta_); } shard_cluster_means_->clear(); shard_cluster_means_->resize(shard_count); shard_cluster_scores_.clear(); shard_cluster_scores_.resize(shard_count); } void KmeansCluster::init_features_containers(size_t shard_count) { shard_cluster_features_.resize(shard_count); for (auto &features : shard_cluster_features_) { features.clear(); } } void KmeansCluster::split_clusters(IndexThreads *threads, const IndexCluster::CentroidList ¢s) { // Initilize containers this->init_containers(threads->count() * cents.size()); auto task_group = threads->make_group(); // Initilize base information size_t features_count = features_->count(); size_t shard_count = std::max( static_cast(std::ceil(threads->count() * shard_factor_)), 1u); size_t fregment_count = (features_count + shard_count - 1) / shard_count; for (size_t i = 0, index = 0; (i != shard_count) && (index < features_count); ++i) { size_t next_index = index + fregment_count; if (next_index > features_count) { next_index = features_count; } // Process in work thread task_group->submit( ailego::Closure::New(this, &KmeansCluster::split_clusters_thread, index, next_index, threads)); // Next index index = next_index; } task_group->wait_finish(); } void KmeansCluster::update_centroids(IndexThreads *threads, IndexCluster::CentroidList ¢s) { auto task_group = threads->make_group(); for (size_t i = 0; i < cents.size(); ++i) { task_group->submit(ailego::Closure::New( this, 
&KmeansCluster::update_centroid_thread, i, ¢s)); } task_group->wait_finish(); } void KmeansCluster::update_clusters(IndexThreads *threads, const IndexCluster::CentroidList ¢s) { // Initilize containers this->init_features_containers(threads->count() * cents.size()); auto task_group = threads->make_group(); size_t features_count = features_->count(); size_t shard_count = std::max( static_cast(std::ceil(threads->count() * shard_factor_)), 1u); size_t fregment_count = (features_count + shard_count - 1) / shard_count; for (size_t i = 0, index = 0; (i != shard_count) && (index < features_count); ++i) { size_t next_index = index + fregment_count; if (next_index > features_count) { next_index = features_count; } // Process in work thread task_group->submit( ailego::Closure::New(this, &KmeansCluster::update_cluster_thread, index, next_index, threads)); // Next index index = next_index; } task_group->wait_finish(); } void KmeansCluster::update_features(IndexThreads *threads, IndexCluster::CentroidList ¢s) { auto task_group = threads->make_group(); for (size_t i = 0; i < cents.size(); ++i) { // Process in work thread task_group->submit(ailego::Closure::New( this, &KmeansCluster::update_features_thread, i, ¢s)); } task_group->wait_finish(); } void KmeansCluster::update_labels(IndexThreads *threads, std::vector *labels) { size_t features_count = features_->count(); size_t shard_count = std::max( static_cast(std::ceil(threads->count() * shard_factor_)), 1u); size_t fregment_count = (features_count + shard_count - 1) / shard_count; auto task_group = threads->make_group(); // Prepare buffer labels->resize(features_count); for (size_t i = 0, index = 0; (i != shard_count) && (index < features_count); ++i) { size_t next_index = index + fregment_count; if (next_index > features_count) { next_index = features_count; } // Process in work thread task_group->submit(ailego::Closure::New( this, &KmeansCluster::update_labels_thread, index, next_index, labels)); // Next index index = 
next_index; } task_group->wait_finish(); } void KmeansCluster::split_clusters_thread(size_t index_begin, size_t index_end, const IndexThreads *threads) { size_t feature_size = features_->element_size(); size_t thread_offset = threads->indexof_this() * seeker_->original()->count(); for (size_t i = index_begin; i != index_end; ++i) { const void *feat = features_->element(i); LinearSeeker::Document result(0, std::numeric_limits::max()); // ignore error seeker_->seek(feat, meta_.element_size(), &result); size_t sel_column = thread_offset + result.index; shard_cluster_scores_[sel_column] += result.score; shard_cluster_means_->at(sel_column).plus(feat, feature_size); } } void KmeansCluster::update_centroid_thread(size_t column, IndexCluster::CentroidList *out) { size_t cluster_count = out->size(); double cluster_score = 0.0; // Create Accumulator std::shared_ptr accum = NewVectorMean(meta_); if (batch_) { cluster_score += batch_scores_[column]; accum->merge(batch_means_->at(column)); } // Compute the score of centroid for (size_t i = column; i < shard_cluster_scores_.size(); i += cluster_count) { cluster_score += shard_cluster_scores_[i]; accum->merge(shard_cluster_means_->at(i)); } // Update centroid IndexCluster::Centroid *centroid = &(out->at(column)); centroid->set_score(cluster_score); centroid->set_follows(accum->count()); accum->mean(centroid->mutable_buffer()); } void KmeansCluster::update_cluster_thread(size_t index_begin, size_t index_end, const IndexThreads *threads) { size_t thread_offset = threads->indexof_this() * seeker_->original()->count(); for (size_t i = index_begin; i != index_end; ++i) { const void *feat = features_->element(i); LinearSeeker::Document result(0, std::numeric_limits::max()); // ignore error seeker_->seek(feat, meta_.element_size(), &result); size_t sel_column = thread_offset + result.index; shard_cluster_features_[sel_column].emplace_back(feat); } } void KmeansCluster::update_features_thread(size_t column, IndexCluster::CentroidList 
*out) { size_t cluster_count = out->size(); size_t cluster_follows = 0u; // Compute the follows of cluster for (size_t i = column; i < shard_cluster_features_.size(); i += cluster_count) { cluster_follows += shard_cluster_features_[i].size(); } // Merge all features in cluster std::vector &cluster_features = *(out->at(column).mutable_similars()); cluster_features.resize(cluster_follows); for (size_t i = column, j = 0; i < shard_cluster_features_.size(); i += cluster_count) { const std::vector &it = shard_cluster_features_[i]; std::memcpy(&cluster_features[j], it.data(), it.size() * sizeof(void *)); j += it.size(); } } void KmeansCluster::update_labels_thread(size_t index_begin, size_t index_end, std::vector *labels) { for (size_t i = index_begin; i != index_end; ++i) { const void *feat = features_->element(i); LinearSeeker::Document result(0, std::numeric_limits::max()); // ignore error seeker_->seek(feat, meta_.element_size(), &result); (*labels)[i] = static_cast(result.index); } } INDEX_FACTORY_REGISTER_CLUSTER_ALIAS(KmeansCluster, KmeansCluster, false); INDEX_FACTORY_REGISTER_CLUSTER_ALIAS(BatchKmeansCluster, KmeansCluster, true); } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/cluster/linear_seeker.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "linear_seeker.h"

namespace zvec {
namespace core {

//! Find the mounted feature with the smallest metric score against `query`.
//! @param query  feature to search for; must be meta_.element_size() bytes.
//! @param len    byte length of `query`.
//! @param out    receives the index and score of the best match.
//! @return 0 on success; IndexError_InvalidArgument for a null/mis-sized
//!         argument; IndexError_NoReady if init() did not yield a distance
//!         function or no features are mounted.
int LinearSeeker::seek(const void *query, size_t len, Document *out) {
  if (ailego_unlikely(!query || !out || meta_.element_size() != len)) {
    return IndexError_InvalidArgument;
  }
  // Without these guards an un-initialized or un-mounted seeker would
  // dereference null pointers below.
  if (ailego_unlikely(!distance_func_ || !features_)) {
    return IndexError_NoReady;
  }

  float sel_score = std::numeric_limits<float>::max();
  uint32_t sel_column = 0;
  uint32_t total = static_cast<uint32_t>(features_->count());

  // Exhaustive scan; ties keep the lowest index because of the strict `<`.
  for (uint32_t i = 0; i < total; ++i) {
    float score = 0.0f;
    distance_func_(features_->element(i), query, meta_.dimension(), &score);
    if (score < sel_score) {
      sel_score = score;
      sel_column = i;
    }
  }
  out->index = sel_column;
  out->score = sel_score;
  return 0;
}

}  // namespace core
}  // namespace zvec

================================================
FILE: src/core/algorithm/cluster/linear_seeker.h
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "seeker.h"

namespace zvec {
namespace core {

/*! Linear Seeker
 *
 *  Brute-force TOP-1 search over a mounted IndexFeatures set using the
 *  distance function of the metric named in the IndexMeta.
 */
class LinearSeeker : public Seeker {
 public:
  typedef std::shared_ptr<LinearSeeker> Pointer;

  //! Constructor
  LinearSeeker(void) : meta_(), metric_(), features_() {}

  //! Destructor
  ~LinearSeeker(void) {}

  //! Initialize Seeker: create and initialize the metric, then fetch its
  //! 1x1 distance-matrix function.
  int init(const IndexMeta &meta) override {
    meta_ = meta;
    metric_ = IndexFactory::CreateMetric(meta_.metric_name());
    if (!metric_) {
      LOG_ERROR("Create Metric %s failed.", meta_.metric_name().c_str());
      return IndexError_Unsupported;
    }
    int ret = metric_->init(meta_, meta_.metric_params());
    if (ret != 0) {
      LOG_ERROR("IndexMetric init failed with ret %d.", ret);
      return ret;
    }
    distance_func_ = metric_->distance_matrix(1, 1);
    if (!distance_func_) {
      LOG_ERROR("DistanceMatrix function is nullptr.");
      return IndexError_Unsupported;
    }
    return 0;
  }

  //! Cleanup Seeker
  int cleanup(void) override {
    features_.reset();
    return 0;
  }

  //! Reset Seeker
  int reset(void) override {
    features_.reset();
    return 0;
  }

  //! Mount features; they must match the meta given to init().
  int mount(IndexFeatures::Pointer feats) override {
    if (!feats) {
      return IndexError_InvalidArgument;
    }
    if (!feats->is_matched(meta_)) {
      return IndexError_Mismatch;
    }
    features_ = std::move(feats);
    return 0;
  }

  //! Seek (TOP 1 Document)
  int seek(const void *query, size_t len, Document *out) override;

  //! Retrieve the original features
  IndexFeatures::Pointer original(void) const override {
    return features_;
  }

 private:
  IndexMeta meta_{};
  IndexMetric::Pointer metric_{};
  IndexFeatures::Pointer features_{};
  IndexMetric::MatrixDistance distance_func_{nullptr};
};

}  // namespace core
}  // namespace zvec

================================================
FILE: src/core/algorithm/cluster/opt_kmeans_cluster.cc
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // limitations under the License. #include #include #include #include #include #include "cluster_params.h" namespace zvec { namespace core { /* NOTE(review): this section is an extracted copy of opt_kmeans_cluster.cc in which angle-bracket template arguments and #include targets were stripped and &cents / &params were mis-encoded as ¢s / ¶ms; restore from upstream before relying on it. OptKmeansAlgorithm below is the shared base of the k-means variants in this file: it owns the parameters, the mounted features and the cached distance function, and implements classify() and label() as a sharded nearest-centroid scan over a thread pool. */ /*! Optimize K-Means cluster algorithm */ class OptKmeansAlgorithm : public IndexCluster { public: //! Constructor OptKmeansAlgorithm(void) {} //! Destructor virtual ~OptKmeansAlgorithm(void) {} //! Initialize Cluster int init(const IndexMeta &meta, const ailego::Params ¶ms); //! Mount features virtual int mount(IndexFeatures::Pointer feats); //! Suggest dividing to K clusters virtual void suggest(uint32_t k); //! Classify virtual int classify(IndexThreads::Pointer threads, IndexCluster::CentroidList ¢s); //! Label virtual int label(IndexThreads::Pointer threads, const IndexCluster::CentroidList ¢s, std::vector *out); //! Cluster virtual int cluster(IndexThreads::Pointer threads, IndexCluster::CentroidList ¢s) = 0; //! Cleanup Cluster virtual int cleanup(void); //! Reset Cluster virtual int reset(void); //! Update Cluster virtual int update(const ailego::Params ¶ms); protected: //! Update parameters void update_params(const ailego::Params ¶ms); //! Init Kmeans Algorithm int init_algorithm(); //! Init Distance function int init_distance_func(); //! Check Centroids bool check_centroids(const IndexCluster::CentroidList ¢s); //! Test if it is valid bool is_valid(void) const; //! Update Clusters void update_clusters(IndexThreads *threads, const IndexCluster::CentroidList ¢s); //! 
Update Cluster in Thread void update_cluster_thread(size_t index_begin, size_t index_end, const IndexThreads *threads, const IndexCluster::CentroidList ¢s); //! Initialize Shard Features Containers void init_features_containers(size_t shard_count); //! Update Clusters' Features void update_features(IndexThreads *threads, IndexCluster::CentroidList ¢s); //! Update Cluster's Features in Thread void update_features_thread(size_t column, IndexCluster::CentroidList *out); //! 
Update Labels void update_labels(IndexThreads *threads, std::vector *labels, const IndexCluster::CentroidList ¢s); //! Update Labels in Thread void update_labels_thread(size_t index_begin, size_t index_end, std::vector *labels, const IndexCluster::CentroidList ¢s); //! Initialize Centroids void init_centroids(size_t count, IndexCluster::CentroidList *out); protected: uint32_t cluster_count_{0u}; uint32_t thread_count_{0u}; uint32_t max_iterations_{20u}; double epsilon_{std::numeric_limits::epsilon()}; float shard_factor_{16.0f}; bool purge_empty_{false}; bool assumption_free_{false}; uint32_t markov_chain_length_{32}; IndexMeta meta_{}; IndexFeatures::Pointer features_{}; std::vector> shard_cluster_features_{}; IndexMetric::MatrixDistance distance_func_{nullptr}; }; bool OptKmeansAlgorithm::is_valid(void) const { if (!features_ || !features_->count()) { return false; } return true; } bool OptKmeansAlgorithm::check_centroids( const IndexCluster::CentroidList ¢s) { for (const auto &it : cents) { if (it.size() != meta_.element_size()) { return false; } } return true; } /* Copy tuning parameters from params into the members above. NOTE(review): both GENERAL_CLUSTER_COUNT and OPTKMEANS_CLUSTER_COUNT write cluster_count_; assuming ailego::Params::get leaves the target untouched for missing keys, the OPTKMEANS key wins when both are set. Confirm. */ void OptKmeansAlgorithm::update_params(const ailego::Params ¶ms) { params.get(GENERAL_THREAD_COUNT, &thread_count_); params.get(GENERAL_CLUSTER_COUNT, &cluster_count_); params.get(OPTKMEANS_CLUSTER_COUNT, &cluster_count_); params.get(OPTKMEANS_CLUSTER_SHARD_FACTOR, &shard_factor_); params.get(OPTKMEANS_CLUSTER_EPSILON, &epsilon_); params.get(OPTKMEANS_CLUSTER_MAX_ITERATIONS, &max_iterations_); params.get(OPTKMEANS_CLUSTER_PURGE_EMPTY, &purge_empty_); params.get(OPTKMEANS_CLUSTER_MARKOV_CHAIN_LENGTH, &markov_chain_length_); params.get(OPTKMEANS_CLUSTER_ASSUMPTION_FREE, &assumption_free_); } /* Create and initialize the metric named by meta_ and cache its 1x1 distance-matrix function in distance_func_. 
NOTE(review): the metric object lives in a local (named metric_ although it is not a member) and is released on return; this assumes the returned MatrixDistance does not depend on that object staying alive. Confirm. */ int OptKmeansAlgorithm::init_distance_func() { IndexMetric::Pointer metric_{}; metric_ = IndexFactory::CreateMetric(meta_.metric_name()); if (!metric_) { LOG_ERROR("Create Metric %s failed.", meta_.metric_name().c_str()); return IndexError_Unsupported; } int ret = metric_->init(meta_, meta_.metric_params()); if (ret 
!= 0) { LOG_ERROR("IndexMetric init failed wit ret %d.", ret); return ret; } distance_func_ = metric_->distance_matrix(1, 1); if (!distance_func_) { LOG_ERROR("DistanceMatrix function is nullptr."); return IndexError_Unsupported; } return 0; } /* Assign every mounted feature to its nearest centroid. The feature range is split into max(ceil(threads->count() * shard_factor_), 1) contiguous fragments, one task each; workers append into their own slice of shard_cluster_features_ (slot = thread index * cents.size() + centroid index). */ void OptKmeansAlgorithm::update_clusters( IndexThreads *threads, const IndexCluster::CentroidList ¢s) { // Initialize containers this->init_features_containers(threads->count() * cents.size()); auto task_group = threads->make_group(); size_t features_count = features_->count(); size_t shard_count = std::max( static_cast(std::ceil(threads->count() * shard_factor_)), 1u); size_t fregment_count = (features_count + shard_count - 1) / shard_count; for (size_t i = 0, index = 0; (i != shard_count) && (index < features_count); ++i) { size_t next_index = index + fregment_count; if (next_index > features_count) { next_index = features_count; } // Process in work thread task_group->submit( ailego::Closure::New(this, &OptKmeansAlgorithm::update_cluster_thread, index, next_index, threads, cents)); // Next index index = next_index; } task_group->wait_finish(); } /* Worker: for features [index_begin, index_end), pick the centroid with the smallest distance score (lowest index on ties) and record the feature in this thread slice of shard_cluster_features_. 
*/ void OptKmeansAlgorithm::update_cluster_thread( size_t index_begin, size_t index_end, const IndexThreads *threads, const IndexCluster::CentroidList ¢s) { size_t thread_offset = threads->indexof_this() * cents.size(); for (size_t i = index_begin; i != index_end; ++i) { const void *feat = features_->element(i); uint32_t sel_index = 0; float sel_score = std::numeric_limits::max(); // todo: get min distance uint32_t total = static_cast(cents.size()); for (uint32_t j = 0; j < total; ++j) { float score = 0.0f; distance_func_(cents[j].feature(), feat, meta_.dimension(), &score); if (score < sel_score) { sel_score = score; sel_index = j; } } size_t sel_column = thread_offset + sel_index; shard_cluster_features_[sel_column].emplace_back(feat); } } void OptKmeansAlgorithm::init_features_containers(size_t shard_count) { shard_cluster_features_.resize(shard_count); for (auto &features : 
shard_cluster_features_) { features.clear(); } } /* Fill the similars list of every centroid from the per-thread shards, one task per centroid. */ void OptKmeansAlgorithm::update_features(IndexThreads *threads, IndexCluster::CentroidList ¢s) { auto task_group = threads->make_group(); for (size_t i = 0; i < cents.size(); ++i) { // Process in work thread task_group->submit(ailego::Closure::New( this, &OptKmeansAlgorithm::update_features_thread, i, ¢s)); } task_group->wait_finish(); } /* Resize labels to the feature count and set labels[i] to the index of the nearest centroid of feature i, using the same sharding as update_clusters(). */ void OptKmeansAlgorithm::update_labels( IndexThreads *threads, std::vector *labels, const IndexCluster::CentroidList ¢s) { size_t features_count = features_->count(); size_t shard_count = std::max( static_cast(std::ceil(threads->count() * shard_factor_)), 1u); size_t fregment_count = (features_count + shard_count - 1) / shard_count; auto task_group = threads->make_group(); // Prepare buffer labels->resize(features_count); for (size_t i = 0, index = 0; (i != shard_count) && (index < features_count); ++i) { size_t next_index = index + fregment_count; if (next_index > features_count) { next_index = features_count; } // Process in work thread task_group->submit( ailego::Closure::New(this, &OptKmeansAlgorithm::update_labels_thread, index, next_index, labels, cents)); // Next index index = next_index; } task_group->wait_finish(); } void OptKmeansAlgorithm::update_labels_thread( size_t index_begin, size_t index_end, std::vector *labels, const IndexCluster::CentroidList ¢s) { for (size_t i = index_begin; i != index_end; ++i) { const void *feat = features_->element(i); uint32_t sel_index = 0; float sel_score = std::numeric_limits::max(); // todo: get min distance uint32_t total = static_cast(cents.size()); for (uint32_t j = 0; j < total; ++j) { float score = 0.0f; distance_func_(cents[j].feature(), feat, meta_.dimension(), &score); if (score < sel_score) { sel_score = score; sel_index = j; } } 
(*labels)[i] = static_cast(sel_index); } } void OptKmeansAlgorithm::init_centroids(size_t count, IndexCluster::CentroidList *out) { // Just resize, because the random centroid selection step is done by 
cluster_once out->resize(count); } /* Concatenate every shard that belongs to centroid column (stride = number of centroids) into the similars vector of that centroid. NOTE(review): &cluster_features[j] is evaluated even for empty shards, so it indexes one past the end when cluster_follows is 0 or when empty shards trail the last non-empty one (undefined for std::vector operator[]); memcpy with a null it.data() is also undefined even for 0 bytes. Consider skipping empty shards or using std::copy. */ void OptKmeansAlgorithm::update_features_thread( size_t column, IndexCluster::CentroidList *out) { size_t cluster_count = out->size(); size_t cluster_follows = 0u; // Compute the follows of cluster for (size_t i = column; i < shard_cluster_features_.size(); i += cluster_count) { cluster_follows += shard_cluster_features_[i].size(); } // Merge all features in cluster std::vector &cluster_features = *(out->at(column).mutable_similars()); cluster_features.resize(cluster_follows); for (size_t i = column, j = 0; i < shard_cluster_features_.size(); i += cluster_count) { const std::vector &it = shard_cluster_features_[i]; std::memcpy(&cluster_features[j], it.data(), it.size() * sizeof(void *)); j += it.size(); } } /* Move centroids with zero follows to the tail by swapping them with the last live entry (so the order of the remaining centroids changes), and drop them only when cutting is true. */ static inline void PurgeCentroids(IndexCluster::CentroidList ¢s, bool cutting) { size_t index = 0; size_t tamp = cents.size(); while (index < tamp) { if (cents[index].follows() == 0) { size_t last_index = tamp - 1; if (index != last_index) { std::swap(cents[index], cents[last_index]); } tamp = last_index; continue; } ++index; } if (cutting) { cents.resize(tamp); } } /* Store meta, read params and resolve the distance function; returns the init_distance_func() result. 
*/ int OptKmeansAlgorithm::init(const IndexMeta &meta, const ailego::Params ¶ms) { meta_ = meta; this->update_params(params); return init_distance_func(); } /* Accept features matching meta_; INT4, INT8 and BINARY32/64 dimensions must be multiples of 8, 4 and 32 respectively, otherwise IndexError_Mismatch is returned. */ int OptKmeansAlgorithm::mount(IndexFeatures::Pointer feats) { if (!feats) { return IndexError_InvalidArgument; } if (!feats->is_matched(meta_)) { return IndexError_Mismatch; } // Check dimension auto type_ = meta_.data_type(); switch (type_) { case IndexMeta::DataType::DT_INT4: if (feats->dimension() % 8 != 0) { LOG_ERROR( "Unsupported feature dimension %zu (dimension of int4 " "must be an integer multiple of 8).", feats->dimension()); return IndexError_Mismatch; } break; case IndexMeta::DataType::DT_INT8: if (feats->dimension() % 4 != 0) { LOG_ERROR( "Unsupported feature dimension %zu (dimension of int8 " "must be an integer multiple of 4).", feats->dimension()); return IndexError_Mismatch; } break; case 
IndexMeta::DataType::DT_BINARY32: case IndexMeta::DataType::DT_BINARY64: if (feats->dimension() % 32 != 0) { LOG_ERROR( "Unsupported feature dimension %zu (dimension of binary " "must be an integer multiple of 32).", feats->dimension()); return IndexError_Mismatch; } break; default: break; } features_ = std::move(feats); return 0; } void OptKmeansAlgorithm::suggest(uint32_t k) { cluster_count_ = k; } /* Partition all mounted features among cents and store the members of each centroid in its similars list. When threads is null a pool is built from thread_count_. NOTE(review): std::make_shared throws instead of returning null, so the following !threads check never fires. */ int OptKmeansAlgorithm::classify(IndexThreads::Pointer threads, IndexCluster::CentroidList ¢s) { if (!threads) { threads = std::make_shared(thread_count_, false); if (!threads) { return IndexError_NoMemory; } } if (cents.empty()) { LOG_ERROR("The input centroid's list is empty."); return IndexError_InvalidArgument; } if (!this->check_centroids(cents)) { LOG_ERROR("The input centroid's list includes some invalid centroids."); return IndexError_InvalidArgument; } if (!this->is_valid()) { LOG_ERROR("The cluster is not ready."); return IndexError_NoReady; } this->update_clusters(threads.get(), cents); this->update_features(threads.get(), cents); return 0; } /* Write into out the nearest-centroid index of every mounted feature. NOTE(review): out is handed to update_labels() and dereferenced without a null check; consider returning IndexError_InvalidArgument for a null out. 
*/ int OptKmeansAlgorithm::label(IndexThreads::Pointer threads, const IndexCluster::CentroidList ¢s, std::vector *out) { if (!threads) { threads = std::make_shared(thread_count_, false); if (!threads) { return IndexError_NoMemory; } } if (cents.empty()) { LOG_ERROR("The input centroid's list is empty."); return IndexError_InvalidArgument; } if (!this->check_centroids(cents)) { LOG_ERROR("The input centroid's list includes some invalid centroids."); return IndexError_InvalidArgument; } if (!this->is_valid()) { LOG_ERROR("The cluster is not ready."); return IndexError_NoReady; } this->update_labels(threads.get(), out, cents); return 0; } int OptKmeansAlgorithm::update(const ailego::Params ¶ms) { this->update_params(params); // algorithm_->reset(cluster_count_); return 0; } int OptKmeansAlgorithm::reset(void) { features_.reset(); shard_cluster_features_.clear(); return 0; } int OptKmeansAlgorithm::cleanup(void) { 
features_.reset(); shard_cluster_features_.clear(); return 0; } /*! Numerical K-Means cluster algorithm */ template class NumericalKmeansAlgorithm : public OptKmeansAlgorithm { public: //! Type of value using ValueType = typename std::remove_cv::type; // Check supporting type static_assert(ailego::IsArithmetic::value, "ValueType must be arithmetic"); //! Constructor NumericalKmeansAlgorithm(void) {} //! Destructor virtual ~NumericalKmeansAlgorithm(void) {} //! Cluster virtual int cluster(IndexThreads::Pointer threads, IndexCluster::CentroidList ¢s); protected: void update_centroids( IndexCluster::CentroidList ¢s, const ailego::NumericalKmeans &algorithm); }; /* Copy the result out of the ailego k-means object: one centroid per trained centroid, with the cluster cost as score, the member count as follows and dimension * sizeof(T) bytes of feature data. */ template void NumericalKmeansAlgorithm::update_centroids( IndexCluster::CentroidList ¢s, const ailego::NumericalKmeans &algorithm) { this->init_centroids(algorithm.centroids().count(), ¢s); for (size_t i = 0; i < cents.size(); ++i) { IndexCluster::Centroid *centroid = &(cents.at(i)); centroid->set_score(algorithm.context().clusters()[i].cost()); centroid->set_follows(algorithm.context().clusters()[i].count()); centroid->set_feature(algorithm.centroids()[i], meta_.dimension() * sizeof(T)); } } /* Train k-means: seed from cents when it is non-empty (otherwise with Kmc2CentroidsGenerator and min(cluster_count_, feature count) centroids), run cluster_once() until the cost changes by less than epsilon_ or max_iterations_ passes are done, then write the centroids back into cents and move empty ones to the tail (dropping them when purge_empty_ is set). The Nibble, Binary and inner-product variants below follow the same flow. NOTE(review): a failed cluster_once() returns a bare -1 instead of an IndexError code. 
*/ template int NumericalKmeansAlgorithm::cluster(IndexThreads::Pointer threads, IndexCluster::CentroidList ¢s) { ailego::ElapsedTime stamp; if (!threads) { threads = std::make_shared(thread_count_, false); if (!threads) { return IndexError_NoMemory; } } if (!this->check_centroids(cents)) { LOG_ERROR("The input centroid's list includes some invalid centroids."); return IndexError_InvalidArgument; } if (!this->is_valid()) { LOG_ERROR("The cluster is not ready."); return IndexError_NoReady; } // get cluster algorithm size_t centroid_count = cents.empty() ? 
std::min(cluster_count_, static_cast(features_->count())) : cents.size(); if (centroid_count == 0) { LOG_ERROR("The count of cluster is unknown."); return IndexError_NoReady; } ailego::NumericalKmeans algorithm(centroid_count, meta_.dimension()); // mount features into algorithm auto features_count = features_->count(); auto dim = meta_.dimension(); algorithm.feature_matrix_reserve(features_count); for (size_t i = 0; i < features_count; ++i) { auto vec = reinterpret_cast(features_->element(i)); algorithm.append(vec, dim); } if (!cents.empty()) { auto centroids = algorithm.mutable_centroids(); centroids->reserve(cents.size()); for (const auto &it : cents) { centroids->append(reinterpret_cast(it.feature()), meta_.dimension()); } } else { ailego::Kmc2CentroidsGenerator< ailego::NumericalKmeans, IndexThreads> g; g.set_chain_length(markov_chain_length_); g.set_assumption_free(assumption_free_); algorithm.init_centroids(*threads, g); } double cost = 0.0; for (uint32_t i = 0; i < max_iterations_; ++i) { double old_cost, new_epsilon; old_cost = cost; bool result = algorithm.cluster_once(*threads, &cost); if (result != true) { LOG_ERROR("(%u) Failed to cluster.", i + 1); return -1; } new_epsilon = std::abs(cost - old_cost); LOG_DEBUG("(%u) Updated %zu Clusters, %zu Features: %zu ms, %f -> %f = %f", i, algorithm.centroids().count(), features_->count(), (size_t)stamp.milli_seconds(), old_cost, cost, new_epsilon); stamp.reset(); if (new_epsilon < epsilon_) { break; } } // update_centroids(cents); update_centroids(cents, algorithm); // Purge the empty centroids PurgeCentroids(cents, purge_empty_); return 0; } /*! Nibble K-Means cluster algorithm */ template class NibbleKmeansAlgorithm : public OptKmeansAlgorithm { public: //! Type of value using ValueType = typename std::remove_cv::type; // Check supporting type static_assert(ailego::IsArithmetic::value, "ValueType must be arithmetic"); //! Constructor NibbleKmeansAlgorithm(void) {} //! 
Destructor virtual ~NibbleKmeansAlgorithm(void) {} //! Cluster virtual int cluster(IndexThreads::Pointer threads, IndexCluster::CentroidList ¢s); protected: //! update centroids void update_centroids(IndexCluster::CentroidList ¢s, const ailego::NibbleKmeans &algorithm); }; /* Same as the numerical variant, except each INT4 centroid stores dimension >> 1 bytes (half a byte per dimension). */ template void NibbleKmeansAlgorithm::update_centroids( IndexCluster::CentroidList ¢s, const ailego::NibbleKmeans &algorithm) { this->init_centroids(algorithm.centroids().count(), ¢s); for (size_t i = 0; i < cents.size(); ++i) { IndexCluster::Centroid *centroid = &(cents.at(i)); centroid->set_score(algorithm.context().clusters()[i].cost()); centroid->set_follows(algorithm.context().clusters()[i].count()); centroid->set_feature(algorithm.centroids()[i], (meta_.dimension() >> 1)); } } template int NibbleKmeansAlgorithm::cluster(IndexThreads::Pointer threads, IndexCluster::CentroidList ¢s) { ailego::ElapsedTime stamp; if (!threads) { threads = std::make_shared(thread_count_, false); if (!threads) { return IndexError_NoMemory; } } if (!this->check_centroids(cents)) { LOG_ERROR("The input centroid's list includes some invalid centroids."); return IndexError_InvalidArgument; } if (!this->is_valid()) { LOG_ERROR("The cluster is not ready."); return IndexError_NoReady; } // get cluster algorithm size_t centroid_count = cents.empty() ? 
std::min(cluster_count_, static_cast(features_->count())) : cents.size(); if (centroid_count == 0) { LOG_ERROR("The count of cluster is unknown."); return IndexError_NoReady; } ailego::NibbleKmeans algorithm(centroid_count, meta_.dimension()); // mount features into algorithm auto features_count = features_->count(); auto dim = meta_.dimension(); for (size_t i = 0; i < features_count; ++i) { auto vec = reinterpret_cast::type *>( features_->element(i)); algorithm.append(vec, dim); } if (!cents.empty()) { auto centroids = algorithm.mutable_centroids(); centroids->reserve(cents.size()); for (const auto &it : cents) { centroids->append( reinterpret_cast::type *>( it.feature()), size_t(meta_.dimension())); } } else { ailego::Kmc2CentroidsGenerator< ailego::NibbleKmeans, IndexThreads> g; g.set_chain_length(markov_chain_length_); g.set_assumption_free(assumption_free_); algorithm.init_centroids(*threads, g); } double cost = 0.0; for (uint32_t i = 0; i < max_iterations_; ++i) { double old_cost, new_epsilon; old_cost = cost; bool result = algorithm.cluster_once(*threads, &cost); if (result != true) { LOG_ERROR("(%u) Failed to cluster.", i + 1); return -1; } new_epsilon = std::abs(cost - old_cost); LOG_DEBUG( "(%u) Updated %zu Clusters, %zu Features: %zu ms, %f -> " "%f = %f", i, algorithm.centroids().count(), features_->count(), (size_t)stamp.milli_seconds(), old_cost, cost, new_epsilon); stamp.reset(); if (new_epsilon < epsilon_) { break; } } // update centroids update_centroids(cents, algorithm); // Purge the empty centroids PurgeCentroids(cents, purge_empty_); return 0; } /*! Binary K-Means cluster algorithm */ template class BinaryKmeansAlgorithm : public OptKmeansAlgorithm { public: //! Type of value using ValueType = typename std::remove_cv::type; // Check supporting type static_assert(ailego::IsArithmetic::value, "ValueType must be arithmetic"); //! Constructor BinaryKmeansAlgorithm(void) {} //! Destructor virtual ~BinaryKmeansAlgorithm(void) {} //! 
Cluster virtual int cluster(IndexThreads::Pointer threads, IndexCluster::CentroidList ¢s); protected: //! update centroids void update_centroids(IndexCluster::CentroidList ¢s, const ailego::BinaryKmeans &algorithm); }; /* Binary variant: each centroid stores dimension >> 3 bytes (one bit per dimension). */ template void BinaryKmeansAlgorithm::update_centroids( IndexCluster::CentroidList ¢s, const ailego::BinaryKmeans &algorithm) { this->init_centroids(algorithm.centroids().count(), ¢s); for (size_t i = 0; i < cents.size(); ++i) { IndexCluster::Centroid *centroid = &(cents.at(i)); centroid->set_score(algorithm.context().clusters()[i].cost()); centroid->set_follows(algorithm.context().clusters()[i].count()); centroid->set_feature(algorithm.centroids()[i], (meta_.dimension() >> 3)); } } template int BinaryKmeansAlgorithm::cluster(IndexThreads::Pointer threads, IndexCluster::CentroidList ¢s) { ailego::ElapsedTime stamp; if (!threads) { threads = std::make_shared(thread_count_, false); if (!threads) { return IndexError_NoMemory; } } if (!this->check_centroids(cents)) { LOG_ERROR("The input centroid's list includes some invalid centroids."); return IndexError_InvalidArgument; } if (!this->is_valid()) { LOG_ERROR("The cluster is not ready."); return IndexError_NoReady; } // get cluster algorithm size_t centroid_count = cents.empty() ? 
std::min(cluster_count_, static_cast(features_->count())) : cents.size(); if (centroid_count == 0) { LOG_ERROR("The count of cluster is unknown."); return IndexError_NoReady; } ailego::BinaryKmeans algorithm(centroid_count, meta_.dimension()); // mount features into algorithm auto features_count = features_->count(); auto dim = meta_.dimension(); for (size_t i = 0; i < features_count; ++i) { auto vec = reinterpret_cast(features_->element(i)); algorithm.append(vec, dim); } if (!cents.empty()) { auto centroids = algorithm.mutable_centroids(); centroids->reserve(cents.size()); for (const auto &it : cents) { centroids->append(reinterpret_cast(it.feature()), meta_.dimension()); } } else { ailego::Kmc2CentroidsGenerator< ailego::BinaryKmeans, IndexThreads> g; g.set_chain_length(markov_chain_length_); g.set_assumption_free(assumption_free_); algorithm.init_centroids(*threads, g); } double cost = 0.0; for (uint32_t i = 0; i < max_iterations_; ++i) { double old_cost, new_epsilon; old_cost = cost; bool result = algorithm.cluster_once(*threads, &cost); if (result != true) { LOG_ERROR("(%u) Failed to cluster.", i + 1); return -1; } new_epsilon = std::abs(cost - old_cost); LOG_DEBUG( "(%u) Updated %zu Clusters, %zu Features: %zu ms, %f -> " "%f = %f", i, algorithm.centroids().count(), features_->count(), (size_t)stamp.milli_seconds(), old_cost, cost, new_epsilon); stamp.reset(); if (new_epsilon < epsilon_) { break; } } // update centroids update_centroids(cents, algorithm); // Purge the empty centroids PurgeCentroids(cents, purge_empty_); return 0; } /*! Numerical Inner Product K-Means cluster algorithm */ template class NumericalInnerProductKmeansAlgorithm : public OptKmeansAlgorithm { public: //! Type of value using ValueType = typename std::remove_cv::type; // Check supporting type static_assert(ailego::IsArithmetic::value, "ValueType must be arithmetic"); //! Constructor NumericalInnerProductKmeansAlgorithm(void) {} //! 
Destructor virtual ~NumericalInnerProductKmeansAlgorithm(void) {} //! Cluster virtual int cluster(IndexThreads::Pointer threads, IndexCluster::CentroidList ¢s); protected: void update_centroids( IndexCluster::CentroidList ¢s, const ailego::NumericalInnerProductKmeans &algorithm); }; template void NumericalInnerProductKmeansAlgorithm::update_centroids( IndexCluster::CentroidList ¢s, const ailego::NumericalInnerProductKmeans &algorithm) { this->init_centroids(algorithm.centroids().count(), ¢s); for (size_t i = 0; i < cents.size(); ++i) { IndexCluster::Centroid *centroid = &(cents.at(i)); centroid->set_score(algorithm.context().clusters()[i].cost()); centroid->set_follows(algorithm.context().clusters()[i].count()); centroid->set_feature(algorithm.centroids()[i], meta_.dimension() * sizeof(T)); } } template int NumericalInnerProductKmeansAlgorithm::cluster( IndexThreads::Pointer threads, IndexCluster::CentroidList ¢s) { ailego::ElapsedTime stamp; if (!threads) { threads = std::make_shared(thread_count_, false); if (!threads) { return IndexError_NoMemory; } } if (!this->check_centroids(cents)) { LOG_ERROR("The input centroid's list includes some invalid centroids."); return IndexError_InvalidArgument; } if (!this->is_valid()) { LOG_ERROR("The cluster is not ready."); return IndexError_NoReady; } // get cluster algorithm size_t centroid_count = cents.empty() ? 
std::min(cluster_count_, static_cast(features_->count())) : cents.size(); if (centroid_count == 0) { LOG_ERROR("The count of cluster is unknown."); return IndexError_NoReady; } ailego::NumericalInnerProductKmeans algorithm( centroid_count, meta_.dimension(), true); // mount features into algorithm auto features_count = features_->count(); auto dim = meta_.dimension(); algorithm.feature_matrix_reserve(features_count); for (size_t i = 0; i < features_count; ++i) { auto vec = reinterpret_cast(features_->element(i)); algorithm.append(vec, dim); } if (!cents.empty()) { auto centroids = algorithm.mutable_centroids(); centroids->reserve(cents.size()); for (const auto &it : cents) { centroids->append(reinterpret_cast(it.feature()), meta_.dimension()); } } else { ailego::Kmc2CentroidsGenerator< ailego::NumericalInnerProductKmeans, IndexThreads> g; g.set_chain_length(markov_chain_length_); g.set_assumption_free(assumption_free_); algorithm.init_centroids(*threads, g); } double cost = 0.0; for (uint32_t i = 0; i < max_iterations_; ++i) { double old_cost, new_epsilon; old_cost = cost; bool result = algorithm.cluster_once(*threads, &cost); if (result != true) { LOG_ERROR("(%u) Failed to cluster.", i + 1); return -1; } new_epsilon = std::abs(cost - old_cost); LOG_DEBUG("(%u) Updated %zu Clusters, %zu Features: %zu ms, %f -> %f = %f", i, algorithm.centroids().count(), features_->count(), (size_t)stamp.milli_seconds(), old_cost, cost, new_epsilon); stamp.reset(); if (new_epsilon < epsilon_) { break; } } // update_centroids(cents); update_centroids(cents, algorithm); // Purge the empty centroids PurgeCentroids(cents, purge_empty_); return 0; } /*! Nibble Inner Product K-Means cluster algorithm */ template class NibbleInnerProductKmeansAlgorithm : public OptKmeansAlgorithm { public: //! Type of value using ValueType = typename std::remove_cv::type; // Check supporting type static_assert(ailego::IsArithmetic::value, "ValueType must be arithmetic"); //! 
Constructor NibbleInnerProductKmeansAlgorithm(void) {} //! Destructor virtual ~NibbleInnerProductKmeansAlgorithm(void) {} //! Cluster virtual int cluster(IndexThreads::Pointer threads, IndexCluster::CentroidList ¢s); protected: //! update centroids void update_centroids( IndexCluster::CentroidList ¢s, const ailego::NibbleInnerProductKmeans &algorithm); }; template void NibbleInnerProductKmeansAlgorithm::update_centroids( IndexCluster::CentroidList ¢s, const ailego::NibbleInnerProductKmeans &algorithm) { this->init_centroids(algorithm.centroids().count(), ¢s); for (size_t i = 0; i < cents.size(); ++i) { IndexCluster::Centroid *centroid = &(cents.at(i)); centroid->set_score(algorithm.context().clusters()[i].cost()); centroid->set_follows(algorithm.context().clusters()[i].count()); centroid->set_feature(algorithm.centroids()[i], (meta_.dimension() >> 1)); } } template int NibbleInnerProductKmeansAlgorithm::cluster( IndexThreads::Pointer threads, IndexCluster::CentroidList ¢s) { ailego::ElapsedTime stamp; if (!threads) { threads = std::make_shared(thread_count_, false); if (!threads) { return IndexError_NoMemory; } } if (!this->check_centroids(cents)) { LOG_ERROR("The input centroid's list includes some invalid centroids."); return IndexError_InvalidArgument; } if (!this->is_valid()) { LOG_ERROR("The cluster is not ready."); return IndexError_NoReady; } // get cluster algorithm size_t centroid_count = cents.empty() ? 
std::min(cluster_count_, static_cast(features_->count())) : cents.size(); if (centroid_count == 0) { LOG_ERROR("The count of cluster is unknown."); return IndexError_NoReady; } ailego::NibbleInnerProductKmeans algorithm( centroid_count, meta_.dimension()); // mount features into algorithm auto features_count = features_->count(); auto dim = meta_.dimension(); for (size_t i = 0; i < features_count; ++i) { auto vec = reinterpret_cast::type *>( features_->element(i)); algorithm.append(vec, dim); } if (!cents.empty()) { auto centroids = algorithm.mutable_centroids(); centroids->reserve(cents.size()); for (const auto &it : cents) { centroids->append( reinterpret_cast::type *>( it.feature()), size_t(meta_.dimension())); } } else { ailego::Kmc2CentroidsGenerator< ailego::NibbleInnerProductKmeans, IndexThreads> g; g.set_chain_length(markov_chain_length_); g.set_assumption_free(assumption_free_); algorithm.init_centroids(*threads, g); } double cost = 0.0; for (uint32_t i = 0; i < max_iterations_; ++i) { double old_cost, new_epsilon; old_cost = cost; bool result = algorithm.cluster_once(*threads, &cost); if (result != true) { LOG_ERROR("(%u) Failed to cluster.", i + 1); return -1; } new_epsilon = std::abs(cost - old_cost); LOG_DEBUG( "(%u) Updated %zu Clusters, %zu Features: %zu ms, %f -> " "%f = %f", i, algorithm.centroids().count(), features_->count(), (size_t)stamp.milli_seconds(), old_cost, cost, new_epsilon); stamp.reset(); if (new_epsilon < epsilon_) { break; } } // update centroids update_centroids(cents, algorithm); // Purge the empty centroids PurgeCentroids(cents, purge_empty_); return 0; } /*! Kmeans Cluster: the registered IndexCluster, which forwards every call to the OptKmeansAlgorithm variant selected in init() */ class OptKmeansCluster : public IndexCluster { public: //! Constructor OptKmeansCluster(void) {} //! Destructor virtual ~OptKmeansCluster(void) {} //! Initialize Cluster virtual int init(const IndexMeta &meta, const ailego::Params ¶ms); //! Cleanup Cluster virtual int cleanup(void); //! 
Reset Cluster virtual int reset(void); //! 
Update Cluster virtual int update(const ailego::Params ¶ms); //! Suggest dividing to K clusters virtual void suggest(uint32_t k); //! Mount features virtual int mount(IndexFeatures::Pointer feats); //! Cluster virtual int cluster(IndexThreads::Pointer threads, IndexCluster::CentroidList ¢s); //! Classify virtual int classify(IndexThreads::Pointer threads, IndexCluster::CentroidList ¢s); //! Label virtual int label(IndexThreads::Pointer threads, const IndexCluster::CentroidList ¢s, std::vector *out); protected: //! Members IndexCluster::Pointer algorithm_{}; }; /* NOTE(review): the forwarding methods below dereference algorithm_ unconditionally; calling any of them before init() has installed an algorithm (or after new (std::nothrow) returned null) dereferences a null pointer. */ //! Cluster int OptKmeansCluster::cluster(IndexThreads::Pointer threads, IndexCluster::CentroidList ¢s) { return algorithm_->cluster(std::move(threads), cents); } //! Classify int OptKmeansCluster::classify(IndexThreads::Pointer threads, IndexCluster::CentroidList ¢s) { return algorithm_->classify(std::move(threads), cents); } //! Label int OptKmeansCluster::label(IndexThreads::Pointer threads, const IndexCluster::CentroidList ¢s, std::vector *out) { return algorithm_->label(std::move(threads), cents, out); } //! Update Cluster int OptKmeansCluster::update(const ailego::Params ¶ms) { return algorithm_->update(params); } //! Reset Cluster int OptKmeansCluster::reset(void) { return algorithm_->reset(); } //! Cleanup Cluster int OptKmeansCluster::cleanup(void) { return algorithm_->cleanup(); } //! 
Suggest dividing to K clusters void OptKmeansCluster::suggest(uint32_t k) { algorithm_->suggest(k); } int OptKmeansCluster::mount(IndexFeatures::Pointer feats) { return algorithm_->mount(feats); } /* Select the k-means variant: InnerProduct and Cosine metrics use the inner-product variants, every other metric the plain ones, keyed by data type; an unsupported type returns IndexError_Mismatch and leaves algorithm_ unchanged. NOTE(review): algorithm_ is not checked for null after new (std::nothrow), and the result of algorithm_->init() is discarded, so this returns 0 even when the metric cannot be created. Consider returning IndexError_NoMemory and propagating the init() result. If DataType is a scoped enum, type_ should also be cast to int for the %d log format. */ int OptKmeansCluster::init(const IndexMeta &meta, const ailego::Params ¶ms) { auto type_ = meta.data_type(); if (meta.metric_name() == "InnerProduct" || meta.metric_name() == "Cosine") { switch (type_) { case IndexMeta::DataType::DT_FP16: { algorithm_.reset( new (std::nothrow) NumericalInnerProductKmeansAlgorithm); break; } case IndexMeta::DataType::DT_FP32: { algorithm_.reset(new (std::nothrow) NumericalInnerProductKmeansAlgorithm); break; } case IndexMeta::DataType::DT_FP64: { algorithm_.reset(new (std::nothrow) NumericalInnerProductKmeansAlgorithm); break; } case IndexMeta::DataType::DT_INT8: { algorithm_.reset(new (std::nothrow) NumericalInnerProductKmeansAlgorithm); break; } case IndexMeta::DataType::DT_INT16: { algorithm_.reset(new (std::nothrow) NumericalInnerProductKmeansAlgorithm); break; } case IndexMeta::DataType::DT_INT4: { algorithm_.reset(new (std::nothrow) NibbleInnerProductKmeansAlgorithm); break; } default: { LOG_ERROR("Unsupported feature types %d.", type_); return IndexError_Mismatch; } } } else { switch (type_) { case IndexMeta::DataType::DT_FP16: { algorithm_.reset(new (std::nothrow) NumericalKmeansAlgorithm); break; } case IndexMeta::DataType::DT_FP32: { algorithm_.reset(new (std::nothrow) NumericalKmeansAlgorithm); break; } case IndexMeta::DataType::DT_FP64: { 
algorithm_.reset(new (std::nothrow) NumericalKmeansAlgorithm); break; } case IndexMeta::DataType::DT_INT8: { algorithm_.reset(new (std::nothrow) NumericalKmeansAlgorithm); break; } case IndexMeta::DataType::DT_INT16: { algorithm_.reset(new (std::nothrow) NumericalKmeansAlgorithm); break; } case IndexMeta::DataType::DT_INT4: { algorithm_.reset(new (std::nothrow) NibbleKmeansAlgorithm); break; } // TODO case IndexMeta::DataType::DT_BINARY32: { algorithm_.reset(new (std::nothrow) 
BinaryKmeansAlgorithm); break; } #if defined(AILEGO_M64) case IndexMeta::DataType::DT_BINARY64: { algorithm_.reset(new (std::nothrow) BinaryKmeansAlgorithm); break; } #endif // AILEGO_M64 default: { LOG_ERROR("Unsupported feature types %d.", type_); return IndexError_Mismatch; } } } algorithm_->init(meta, params); return 0; } INDEX_FACTORY_REGISTER_CLUSTER(OptKmeansCluster); } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/cluster/seeker.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include namespace zvec { namespace core { /* Abstract nearest-feature lookup: init() receives the IndexMeta, mount() supplies the candidate features and seek() reports one Document; LinearSeeker implements it by brute force. */ class Seeker { public: /* One search hit: feature index plus score; operator< and operator> compare the score only. */ struct Document { uint32_t index; float score; //! Constructor Document(void) : index(0), score(0.0f) {} //! Constructor Document(uint32_t i, float v) : index(i), score(v) {} //! Constructor Document(const Document &rhs) : index(rhs.index), score(rhs.score) {} //! Assignment Document &operator=(const Document &rhs) { index = rhs.index; score = rhs.score; return *this; } //! Less than bool operator<(const Document &rhs) const { return (this->score < rhs.score); } //! 
Greater than bool operator>(const Document &rhs) const { return (this->score > rhs.score); } }; public: //! 
Destructor virtual ~Seeker(void) {} virtual int init(const IndexMeta &meta) = 0; virtual int cleanup(void) = 0; virtual int reset(void) = 0; virtual int mount(IndexFeatures::Pointer feats) = 0; virtual int seek(const void *query, size_t len, Document *out) = 0; virtual IndexFeatures::Pointer original(void) const = 0; }; } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/cluster/stratified_cluster.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include #include #include #include "cluster_params.h" namespace zvec { namespace core { /*! Stratified Cluster */ class StratifiedCluster : public IndexCluster { public: //! Constructor StratifiedCluster(void) {} //! Destructor virtual ~StratifiedCluster(void) {} //! Initialize Cluster virtual int init(const IndexMeta &meta, const ailego::Params ¶ms) { meta_ = meta; this->update_params(params); return 0; } //! Cleanup Cluster virtual int cleanup(void) { features_.reset(); return 0; } //! Reset Cluster virtual int reset(void) { features_.reset(); return 0; } //! Update Cluster virtual int update(const ailego::Params ¶ms) { this->update_params(params); return 0; } //! Suggest dividing to K clusters virtual void suggest(uint32_t k) { cluster_count_ = k; } //! 
Mount features virtual int mount(IndexFeatures::Pointer feats) { if (!feats) { return IndexError_InvalidArgument; } if (!feats->is_matched(meta_)) { return IndexError_Mismatch; } features_ = std::move(feats); return 0; } //! Cluster virtual int cluster(IndexThreads::Pointer threads, IndexCluster::CentroidList ¢s); //! Classify virtual int classify(IndexThreads::Pointer threads, IndexCluster::CentroidList ¢s); //! Label virtual int label(IndexThreads::Pointer threads, const IndexCluster::CentroidList ¢s, std::vector *out); protected: //! Test if it is valid bool is_valid(void) const { if (!features_ || !features_->count()) { return false; } return true; } //! Update parameters void update_params(const ailego::Params ¶ms); //! Check Centroids bool check_centroids(const IndexCluster::CentroidList ¢s); //! Initialize Sub Clusters int init_sub_clusters(IndexCluster::Pointer *first, IndexCluster::Pointer *second); //! Initialize First Cluster int init_first_cluster(IndexCluster::Pointer *first); //! Initialize Second Cluster int init_second_cluster(IndexCluster::Pointer *second, IndexFeatures::Pointer features); private: //! 
Members IndexMeta meta_{}; IndexFeatures::Pointer features_{}; uint32_t cluster_count_{0u}; uint32_t thread_count_{0u}; uint32_t first_cluster_count_{0u}; uint32_t second_cluster_count_{0u}; bool auto_tuning_{false}; std::string first_cluster_class_{"OptKmeansCluster"}; std::string second_cluster_class_{"OptKmeansCluster"}; ailego::Params first_cluster_params_{}; ailego::Params second_cluster_params_{}; // TODO: Maybe optimize later uint32_t second_threads_count_{10u}; // todo }; int StratifiedCluster::cluster(IndexThreads::Pointer threads, IndexCluster::CentroidList ¢s) { if (!threads) { threads = std::make_shared(thread_count_, false); if (!threads) { return IndexError_NoMemory; } } if (!this->check_centroids(cents)) { LOG_ERROR("The input centroid's list includes some invalid centroids."); return IndexError_InvalidArgument; } if (!this->is_valid()) { LOG_ERROR("The cluster is not ready."); return IndexError_NoReady; } IndexCluster::Pointer first_cluster; int result = init_first_cluster(&first_cluster); if (result != 0) { LOG_ERROR("Failed to initialize the first cluster."); return result; } if (first_cluster_count_) { first_cluster->suggest(first_cluster_count_); } // The first clustering LOG_DEBUG("Clustering with first cluster: %s.", first_cluster_class_.c_str()); result = first_cluster->cluster(threads, cents); if (result != 0) { LOG_ERROR("Failed to cluster in first cluster: %s.", first_cluster_class_.c_str()); return result; } result = first_cluster->classify(threads, cents); if (result != 0) { LOG_ERROR("Failed to classify in first cluster: %s.", first_cluster_class_.c_str()); return result; } // Cleanup for saving memory first_cluster.reset(); // Calculate the total cluster count uint32_t total_cluster_count = cents.size() * second_cluster_count_; if (cluster_count_) { total_cluster_count = cluster_count_; } // Use thread_threads cluster instead uint32_t tail_threads = threads->count() % second_threads_count_; std::vector> threads_cluster; // TODO: 
reusing threads pool? // Incase the threads count less than second threads count if (threads->count() / second_threads_count_ == 0) { for (size_t threads_idx = 0; threads_idx < tail_threads; threads_idx++) { std::shared_ptr curr_threads = std::make_shared(1, false); threads_cluster.push_back(curr_threads); } } else { for (size_t threads_idx = 0; threads_idx < second_threads_count_; threads_idx++) { uint32_t curr_threads_count = threads->count() / second_threads_count_; if (threads_idx >= second_threads_count_ - tail_threads) { curr_threads_count++; } std::shared_ptr curr_threads = std::make_shared(curr_threads_count, false); threads_cluster.push_back(curr_threads); } } auto task_group = threads->make_group(); // The second clustering for (size_t i = 0; i < cents.size(); ++i) { if (cents[i].similars().empty()) { continue; } IndexThreads::Pointer &curr_threads = threads_cluster[i % (threads_cluster.size())]; task_group->submit(ailego::Closure::New( [this, &curr_threads, &total_cluster_count, ¢s](size_t index) { auto &it = cents[index]; IndexCluster::Pointer second_cluster; std::shared_ptr features = std::make_shared( meta_, it.similars().data(), it.similars().size()); int ret = this->init_second_cluster(&second_cluster, features); if (ret != 0) { LOG_ERROR("Failed to initialize the second cluster."); return; } if (auto_tuning_) { if (total_cluster_count) { double factor = static_cast(it.similars().size()) / static_cast(this->features_->count()); second_cluster->suggest( std::max(static_cast( std::floor(total_cluster_count * factor)), 1u)); } } else if (second_cluster_count_) { second_cluster->suggest(second_cluster_count_); } LOG_DEBUG("Clustering with second cluster: %s.", second_cluster_class_.c_str()); ret = second_cluster->cluster(curr_threads, *(it.mutable_subitems())); if (ret != 0) { LOG_ERROR("Failed to cluster in second cluster: %s.", second_cluster_class_.c_str()); } }, i)); } task_group->wait_finish(); return 0; } int 
StratifiedCluster::classify(IndexThreads::Pointer threads, IndexCluster::CentroidList ¢s) { if (!threads) { threads = std::make_shared(thread_count_, false); if (!threads) { return IndexError_NoMemory; } } if (cents.empty()) { LOG_ERROR("The input centroid's list is empty."); return IndexError_InvalidArgument; } if (!this->check_centroids(cents)) { LOG_ERROR("The input centroid's list includes some invalid centroids."); return IndexError_InvalidArgument; } if (!this->is_valid()) { LOG_ERROR("The cluster is not ready."); return IndexError_NoReady; } IndexCluster::Pointer first_cluster, second_cluster; int result = init_sub_clusters(&first_cluster, &second_cluster); if (result != 0) { LOG_ERROR("Failed to initialize the subclusters."); return result; } // The first classifying result = first_cluster->classify(threads, cents); if (result != 0) { LOG_ERROR("Failed to classify in first cluster: %s.", first_cluster_class_.c_str()); return result; } // Cleanup for saving memory first_cluster.reset(); std::shared_ptr shell = std::make_shared(meta_); // The second classifying for (IndexCluster::Centroid &it : cents) { const auto &feats = it.similars(); if (feats.empty()) { continue; } shell->mount(feats.data(), feats.size()); result = second_cluster->mount(shell); if (result != 0) { LOG_ERROR("Failed to mount features for second cluster: %s.", second_cluster_class_.c_str()); return result; } result = second_cluster->classify(threads, *it.mutable_subitems()); if (result != 0) { LOG_ERROR("Failed to classify in second cluster: %s.", second_cluster_class_.c_str()); return result; } } return 0; } int StratifiedCluster::label(IndexThreads::Pointer threads, const IndexCluster::CentroidList ¢s, std::vector *out) { if (!threads) { threads = std::make_shared(thread_count_, false); if (!threads) { return IndexError_NoMemory; } } if (cents.empty()) { LOG_ERROR("The input centroid's list is empty."); return IndexError_InvalidArgument; } if (!this->check_centroids(cents)) { 
LOG_ERROR("The input centroid's list includes some invalid centroids."); return IndexError_InvalidArgument; } if (!this->is_valid()) { LOG_ERROR("The cluster is not ready."); return IndexError_NoReady; } IndexCluster::Pointer first_cluster; int result = init_first_cluster(&first_cluster); if (result != 0) { LOG_ERROR("Failed to initialize the subclusters."); return result; } result = first_cluster->label(threads, cents, out); if (result != 0) { LOG_ERROR("Failed to label in first cluster: %s.", first_cluster_class_.c_str()); return result; } return 0; } void StratifiedCluster::update_params(const ailego::Params ¶ms) { params.get(GENERAL_THREAD_COUNT, &thread_count_); params.get(GENERAL_CLUSTER_COUNT, &cluster_count_); params.get(STRATIFIED_CLUSTER_COUNT, &cluster_count_); params.get(STRATIFIED_CLUSTER_FIRST_COUNT, &first_cluster_count_); params.get(STRATIFIED_CLUSTER_SECOND_COUNT, &second_cluster_count_); params.get(STRATIFIED_CLUSTER_FIRST_CLASS, &first_cluster_class_); params.get(STRATIFIED_CLUSTER_SECOND_CLASS, &second_cluster_class_); params.get(STRATIFIED_CLUSTER_FIRST_PARAMS, &first_cluster_params_); params.get(STRATIFIED_CLUSTER_SECOND_PARAMS, &second_cluster_params_); params.get(STRATIFIED_CLUSTER_AUTO_TUNING, &auto_tuning_); params.get(STRATIFIED_CLUSTER_SECOND_POOL_COUNT, &second_threads_count_); } bool StratifiedCluster::check_centroids( const IndexCluster::CentroidList ¢s) { for (const auto &it : cents) { if (it.size() != meta_.element_size()) { return false; } } return true; } int StratifiedCluster::init_sub_clusters(IndexCluster::Pointer *first, IndexCluster::Pointer *second) { IndexCluster::Pointer first_cluster = IndexFactory::CreateCluster(first_cluster_class_); if (!first_cluster) { LOG_ERROR("Failed to create first cluster: %s.", first_cluster_class_.c_str()); return IndexError_NoExist; } IndexCluster::Pointer second_cluster = IndexFactory::CreateCluster(second_cluster_class_); if (!second_cluster) { LOG_ERROR("Failed to create second cluster: 
%s.", first_cluster_class_.c_str()); return IndexError_NoExist; } int result = first_cluster->init(meta_, first_cluster_params_); if (result != 0) { LOG_ERROR("Failed to initialize first cluster: %s.", first_cluster_class_.c_str()); return result; } result = second_cluster->init(meta_, second_cluster_params_); if (result != 0) { LOG_ERROR("Failed to initialize second cluster: %s.", second_cluster_class_.c_str()); return result; } result = first_cluster->mount(features_); if (result != 0) { LOG_ERROR("Failed to mount features for first cluster: %s.", first_cluster_class_.c_str()); return result; } *first = std::move(first_cluster); *second = std::move(second_cluster); return 0; } int StratifiedCluster::init_first_cluster(IndexCluster::Pointer *first) { IndexCluster::Pointer first_cluster = IndexFactory::CreateCluster(first_cluster_class_); if (!first_cluster) { LOG_ERROR("Failed to create first cluster: %s.", first_cluster_class_.c_str()); return IndexError_NoExist; } int result = first_cluster->init(meta_, first_cluster_params_); if (result != 0) { LOG_ERROR("Failed to initialize first cluster: %s.", first_cluster_class_.c_str()); return result; } result = first_cluster->mount(features_); if (result != 0) { LOG_ERROR("Failed to mount features for first cluster: %s.", first_cluster_class_.c_str()); return result; } *first = std::move(first_cluster); return 0; } int StratifiedCluster::init_second_cluster(IndexCluster::Pointer *second, IndexFeatures::Pointer features) { IndexCluster::Pointer second_cluster = IndexFactory::CreateCluster(second_cluster_class_); if (!second_cluster) { LOG_ERROR("Failed to create second cluster: %s.", second_cluster_class_.c_str()); return IndexError_NoExist; } int result = second_cluster->init(meta_, second_cluster_params_); if (result != 0) { LOG_ERROR("Failed to initialize second cluster: %s.", second_cluster_class_.c_str()); return result; } result = second_cluster->mount(features); if (result != 0) { LOG_ERROR("Failed to mount 
features for second cluster: %s.", second_cluster_class_.c_str()); return result; } *second = std::move(second_cluster); return 0; } INDEX_FACTORY_REGISTER_CLUSTER(StratifiedCluster); } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/cluster/stratified_cluster_trainer.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "stratified_cluster_trainer.h" #include #include #include #include #include #include "cluster_params.h" namespace zvec { namespace core { const std::string StratifiedClusterTrainer::SEP_TOKEN = "*"; const std::string StratifiedClusterTrainer::DEFAULT_CLUSTER_CLASS = "OptKmeansCluster"; int StratifiedClusterTrainer::init_params(const ailego::Params ¶ms) { params.get(STRATIFIED_TRAINER_SAMPLE_COUNT, &sample_count_); params.get(STRATIFIED_TRAINER_SAMPLE_RATIO, &sample_ratio_); params.get(STRATIFIED_TRAINER_THREAD_COUNT, &thread_count_); cluster_auto_tuning_ = params.get_as_bool(STRATIFIED_TRAINER_AUTOAUNE); std::string centroids_num = params.get_as_string(STRATIFIED_TRAINER_CLUSTER_COUNT); if (!centroids_num.empty()) { ailego::StringHelper::Split(centroids_num, SEP_TOKEN, ¢roid_num_vec_); for (size_t i = 0; i < centroid_num_vec_.size(); ++i) { if (centroid_num_vec_[i] == 0) { LOG_ERROR("Invalid centroid num %s", centroids_num.c_str()); return IndexError_InvalidArgument; } } } else { LOG_ERROR("No 
centroids_num configed."); return IndexError_InvalidArgument; } size_t level_cnt = centroid_num_vec_.size(); for (size_t i = 1; i <= level_cnt; ++i) { std::string level_params_key = STRATIFIED_TRAINER_PARAMS_IN_LEVEL_PREFIX + std::to_string(i); ailego::Params level_params; params.get(level_params_key, &level_params); cluster_params_.push_back(level_params); } std::string cluster_class(DEFAULT_CLUSTER_CLASS); params.get(STRATIFIED_TRAINER_CLASS_NAME, &cluster_class); ailego::StringHelper::Split(cluster_class, SEP_TOKEN, &cluster_class_); if (cluster_class_.size() == 1) { // repeat the cluster class to level_cnt for (size_t i = 1; i < level_cnt; ++i) { cluster_class_.push_back(cluster_class_[0]); } } else if (cluster_class_.size() != level_cnt) { LOG_ERROR("Cluster class should be equal to level count"); return IndexError_InvalidArgument; } return 0; } int StratifiedClusterTrainer::init(const IndexMeta &index_meta, const ailego::Params ¶ms) { int err = init_params(params); if (err != 0) { LOG_ERROR("init params failed, errno:%d,%s", err, IndexError::What(err)); return err; } meta_ = index_meta; ailego::Params cluster_params; if (centroid_num_vec_.size() == 0) { LOG_ERROR("invalid centroid num"); return IndexError_InvalidArgument; } else if (centroid_num_vec_.size() == 1) { // one level clustering class_name_ = cluster_class_[0]; cluster_params = cluster_params_[0]; suggest_centriod_cnt_ = centroid_num_vec_[0]; } else if (centroid_num_vec_.size() == 2) { // cluster level > 1 class_name_ = "StratifiedCluster"; int level_cnt = centroid_num_vec_.size(); cluster_params.set(STRATIFIED_CLUSTER_FIRST_CLASS, cluster_class_[level_cnt - 2]); cluster_params.set(STRATIFIED_CLUSTER_SECOND_CLASS, cluster_class_[level_cnt - 1]); cluster_params.set(STRATIFIED_CLUSTER_FIRST_COUNT, centroid_num_vec_[level_cnt - 2]); cluster_params.set(STRATIFIED_CLUSTER_SECOND_COUNT, centroid_num_vec_[level_cnt - 1]); cluster_params.set(STRATIFIED_CLUSTER_FIRST_PARAMS, cluster_params_[level_cnt - 2]); 
cluster_params.set(STRATIFIED_CLUSTER_SECOND_PARAMS, cluster_params_[level_cnt - 1]); cluster_params.set(STRATIFIED_CLUSTER_AUTO_TUNING, cluster_auto_tuning_); suggest_centriod_cnt_ = centroid_num_vec_[level_cnt - 1] * centroid_num_vec_[level_cnt - 2]; } else { LOG_ERROR("Unsupported more than 2 level clustering."); return IndexError_Unsupported; } cluster_ = IndexFactory::CreateCluster(class_name_); if (!cluster_) { LOG_ERROR("Failed to create cluster[%s]", class_name_.c_str()); return IndexError_InvalidArgument; } int result = cluster_->init(meta_, cluster_params); if (result != 0) { LOG_ERROR("Failed to initialize of cluster[%s], error: %d, %s", class_name_.c_str(), result, IndexError::What(result)); return result; } if (suggest_centriod_cnt_ > 0) { cluster_->suggest(suggest_centriod_cnt_); } return 0; } int StratifiedClusterTrainer::cleanup(void) { cluster_ = nullptr; centroids_.clear(); return 0; } int StratifiedClusterTrainer::train(IndexThreads::Pointer threads, IndexHolder::Pointer holder) { ailego::ElapsedTime timer; if (!holder) { return IndexError_InvalidArgument; } if (!holder->is_matched(meta_)) { return IndexError_Mismatch; } if (!threads) { threads = std::make_shared(thread_count_, false); if (!threads) { return IndexError_NoMemory; } } size_t train_sample_count = std::max( sample_count_, static_cast(sample_ratio_ * holder->count())); IndexFeatures::Pointer features; if (train_sample_count > 0) { LOG_INFO( "Train sampling, SampleCount=%u, SampleRatio=%f, HolderCount=%lu, " "TrainCount=%lu", sample_count_, sample_ratio_, holder->count(), train_sample_count); auto sampler = std::make_shared>( meta_, train_sample_count); size_t pre_reserve = train_sample_count < holder->count() ? 
train_sample_count : holder->count(); sampler->reserve(pre_reserve); for (auto iter = holder->create_iterator(); iter && iter->is_valid(); iter->next()) { sampler->emplace(iter->data()); } features = sampler; stats_.set_trained_count(train_sample_count); } else { LOG_INFO( "Do no sampling, SampleCount=%u, SampleRatio=%f, " "HolderCount=%lu, TrainCount=%lu", sample_count_, sample_ratio_, holder->count(), holder->count()); auto no_sampler = std::make_shared(meta_); for (auto iter = holder->create_iterator(); iter && iter->is_valid(); iter->next()) { no_sampler->emplace(iter->data()); } features = no_sampler; stats_.set_trained_count(holder->count()); } stats_.set_discarded_count(0); // Holder is not needed, cleanup it. holder.reset(); int result = cluster_->mount(features); if (result != 0) { LOG_ERROR("Failed to mount features of cluster[%s], error: %d, %s", class_name_.c_str(), result, IndexError::What(result)); return result; } centroids_.clear(); result = cluster_->cluster(std::move(threads), centroids_); if (result != 0) { LOG_ERROR("Failed to cluster features of cluster[%s], error: %d, %s", class_name_.c_str(), result, IndexError::What(result)); return result; } // check build result std::vector level_size; std::function cal_centroid_cnt = [&cal_centroid_cnt, &level_size]( const IndexCluster::CentroidList ¢s, size_t level) { if (level > level_size.size()) { level_size.resize(level); } level_size[level - 1] += cents.size(); for (const auto &it : cents) { if (!it.subitems().empty()) { cal_centroid_cnt(it.subitems(), level + 1); } } }; cal_centroid_cnt(centroids_, 1); size_t centroids_num = level_size[level_size.size() - 1]; if (centroids_num > suggest_centriod_cnt_) { LOG_WARN( "Built centroid(%zd level) count[%zd] bigger than expected " "count[%d]", level_size.size(), centroids_num, suggest_centriod_cnt_); } else { LOG_INFO("Built centroid(%zd level) count[%zd], expected count[%d]", level_size.size(), centroids_num, suggest_centriod_cnt_); } 
stats_.set_trained_costtime(timer.milli_seconds()); return 0; } int StratifiedClusterTrainer::load(IndexStorage::Pointer cntr) { if (!cntr) { LOG_ERROR("IndexStorage is nullptr."); return IndexError_InvalidArgument; } std::shared_ptr bundle = std::make_shared(); if (!bundle) { LOG_ERROR("New MemoryInndexBundle failed."); return IndexError_NoMemory; } auto results = cntr->get_all(); for (auto &it : results) { IndexStorage::Segment::Pointer &seg = it.second; if (!seg) { LOG_ERROR("Get Segment %s failed.", it.first.c_str()); return IndexError_InvalidArgument; } size_t data_size = seg->data_size(); const void *data = nullptr; size_t actual_size = seg->read(0, &data, data_size); if (actual_size != data_size) { LOG_ERROR("Read data failed expect %zu, actual %zu.", data_size, actual_size); return IndexError_ReadData; } bundle->set(it.first, data, data_size); } int result = IndexHelper::DeserializeFromStorage(cntr.get(), &meta_); if (result != 0) { LOG_ERROR("Failed to deserialize meta from container"); return result; } result = IndexCluster::Deserialize(meta_, std::move(bundle), ¢roids_); if (result != 0) { LOG_ERROR("Failed to deserialize index: %d", result); return result; } return 0; } int StratifiedClusterTrainer::dump(const IndexDumper::Pointer &dumper) { IndexBundle::Pointer bundle; int result = IndexCluster::Serialize(meta_, centroids_, &bundle); if (result != 0) { LOG_ERROR("IndexCluster Serialize failed with ret %d.", result); return result; } result = IndexHelper::SerializeToDumper(meta_, dumper.get()); if (result != 0) { LOG_ERROR("Failed to serialize meta into dumper."); return result; } for (const auto &it : bundle->all()) { size_t data_size = it.second.size(); result = dumper->append(it.first, data_size, 0, 0); if (result != 0) { LOG_ERROR("Dumper append meta %s %zu failed.", it.first.c_str(), data_size); return IndexError_PackIndex; } size_t actual_size = dumper->write(it.second.buffer(), data_size); if (actual_size != data_size) { LOG_ERROR("Dumper segment 
%s expect %zu, actual %zu.", it.first.c_str(), data_size, actual_size); return IndexError_PackIndex; } } return 0; } const IndexMeta &StratifiedClusterTrainer::meta(void) const { return meta_; } const IndexTrainer::Stats &StratifiedClusterTrainer::stats(void) const { return stats_; } IndexBundle::Pointer StratifiedClusterTrainer::indexes(void) const { IndexBundle::Pointer bundle; IndexCluster::Serialize(meta_, centroids_, &bundle); return bundle; } //! Register Cluster Trainer in Factory INDEX_FACTORY_REGISTER_TRAINER(StratifiedClusterTrainer); } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/cluster/stratified_cluster_trainer.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include namespace zvec { namespace core { /*! Cluster Trainer */ class StratifiedClusterTrainer : public IndexTrainer { public: typedef std::shared_ptr Pointer; //! Constructor StratifiedClusterTrainer(void) {} //! Destructor ~StratifiedClusterTrainer(void) {} protected: //! Initialize Trainer virtual int init(const IndexMeta &meta, const ailego::Params ¶ms); //! Cleanup Trainer virtual int cleanup(void); //! Train the data virtual int train(IndexThreads::Pointer threads, IndexHolder::Pointer holder); //! Load index from file path or dir virtual int load(IndexStorage::Pointer cntr); //! 
Dump index into file path or dir virtual int dump(const IndexDumper::Pointer &dumper); //! Retrieve Index Meta virtual const IndexMeta &meta(void) const; //! Retrieve statistics virtual const IndexTrainer::Stats &stats(void) const; //! Retrieve the output indexes virtual IndexBundle::Pointer indexes(void) const; private: int init_params(const ailego::Params ¶ms); private: IndexMeta meta_{}; uint32_t sample_count_{0u}; float sample_ratio_{0.0}; uint32_t thread_count_{0u}; bool cluster_auto_tuning_{false}; IndexCluster::Pointer cluster_{}; IndexCluster::CentroidList centroids_{}; uint32_t suggest_centriod_cnt_{0u}; std::string class_name_; std::vector cluster_class_; std::vector centroid_num_vec_; std::vector cluster_params_; IndexTrainer::Stats stats_{}; private: static const std::string SEP_TOKEN; static const std::string DEFAULT_CLUSTER_CLASS; }; } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/cluster/vector_mean.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include #include #include #include #include namespace zvec { namespace core { /*! Vector Mean */ struct VectorMean { //! Destructor virtual ~VectorMean(void) {} //! Reset accumulator virtual void reset(void) = 0; //! Plus a vector virtual bool plus(const void *vec, size_t len) = 0; //! 
Retrieve the mean of vectors virtual bool mean(void *out, size_t len) const = 0; //! Retrieve the mean of vectors virtual void mean(std::string *out) const = 0; //! Merge another vector mean virtual bool merge(const VectorMean &rhs) = 0; //! Retrieve the count of vectors virtual size_t count(void) const = 0; //! Retrieve the dimension of vectors virtual size_t dimension(void) const = 0; }; /*! Vector Mean Array */ struct VectorMeanArray { //! Destructor virtual ~VectorMeanArray(void) {} //! Operator [] VectorMean &operator[](size_t i) { return this->at(i); } //! Operator [] const VectorMean &operator[](size_t i) const { return this->at(i); } //! Resize accumulators virtual void resize(size_t cnt) = 0; //! Clear accumulators virtual void clear(void) = 0; //! Retrieve an accumulator virtual VectorMean &at(size_t i) = 0; //! Retrieve an accumulator virtual const VectorMean &at(size_t i) const = 0; //! Retrieve the count of accumulators virtual size_t count(void) const = 0; //! Retrieve the dimension of accumulators virtual size_t dimension(void) const = 0; }; /*! General Vector Mean Array */ template ::type> class GeneralVectorMeanArray : public VectorMeanArray { public: //! Constructor GeneralVectorMeanArray(size_t dim) : dimension_(dim), array_() {} //! Constructor GeneralVectorMeanArray(const GeneralVectorMeanArray &rhs) : dimension_(rhs.dimension_), array_(rhs.array_) {} //! Constructor GeneralVectorMeanArray(GeneralVectorMeanArray &&rhs) : dimension_(rhs.dimension_), array_(std::move(rhs.array_)) {} //! Emplace an accumulator template bool emplace(TArgs &&...args) { T accum(std::forward(args)...); if (accum.dimension() != dimension_) { return false; } array_.push_back(std::move(accum)); return true; } //! Resize accumulators virtual void resize(size_t cnt) { if (array_.size() < cnt) { for (size_t i = array_.size(); i < cnt; ++i) { array_.emplace_back(dimension_); } } else { array_.resize(cnt); } } //! 
Clear accumulators virtual void clear(void) { array_.clear(); } //! Retrieve an accumulator virtual VectorMean &at(size_t i) { return array_[i]; } //! Retrieve an accumulator virtual const VectorMean &at(size_t i) const { return array_[i]; } //! Retrieve the count of accumulators virtual size_t count(void) const { return array_.size(); } //! Retrieve the dimension of accumulators virtual size_t dimension(void) const { return dimension_; } private: //! Disable them GeneralVectorMeanArray(void) = delete; //! Members size_t dimension_; std::vector array_; }; /*! Numerical Vector Mean */ template ::value>::type> class NumericalVectorMean : public VectorMean { public: //! Constructor NumericalVectorMean(void) : count_(0), accums_() {} //! Constructor NumericalVectorMean(const NumericalVectorMean &rhs) : count_(rhs.count_), accums_(rhs.accums_) {} //! Constructor NumericalVectorMean(NumericalVectorMean &&rhs) : count_(rhs.count_), accums_(std::move(rhs.accums_)) {} //! Constructor NumericalVectorMean(size_t dim) : count_(0), accums_(dim) {} //! Constructor NumericalVectorMean(const T *means, size_t dim, size_t cnt) : count_(cnt), accums_(dim) { for (size_t i = 0; i < dim; ++i) { accums_[i] = static_cast(means[i]) * count_; } } //! Reset accumulator void reset(size_t dim) { count_ = 0u; accums_.clear(); accums_.resize(dim, 0.0); } //! Reset accumulator virtual void reset(void) { this->reset(accums_.size()); } //! Plus a vector virtual bool plus(const void *vec, size_t len) { size_t dim = accums_.size(); if (dim * sizeof(T) != len) { return false; } for (size_t i = 0; i < dim; ++i) { accums_[i] += *(static_cast(vec) + i); } ++count_; return true; } //! Retrieve the mean of vectors virtual bool mean(void *out, size_t len) const { size_t dim = accums_.size(); if (dim * sizeof(T) != len) { return false; } for (size_t i = 0; i < dim; ++i) { *(static_cast(out) + i) = FloatCast(accums_[i] / count_); } return true; } //! 
Retrieve the mean of vectors virtual void mean(std::string *out) const { ailego::NumericalVector &vec = *static_cast *>(out); size_t dim = accums_.size(); vec.resize(dim); for (size_t i = 0; i < dim; ++i) { vec[i] = FloatCast(accums_[i] / count_); } } //! Merge another vector mean virtual bool merge(const VectorMean &rhs) { const NumericalVectorMean &src = dynamic_cast &>(rhs); size_t dim = accums_.size(); if (dim != src.accums_.size()) { return false; } count_ += src.count_; for (size_t i = 0; i < dim; ++i) { accums_[i] += src.accums_[i]; } return true; } //! Retrieve the count of vectors virtual size_t count(void) const { return count_; } //! Retrieve dimension of accumulator virtual size_t dimension(void) const { return accums_.size(); } protected: //! Convert float type to another type template static auto FloatCast(const double &val) -> typename std::enable_if::value, U>::type { return static_cast(val); } //! Convert float type to another type template static auto FloatCast(const double &val) -> typename std::enable_if::value, U>::type { return static_cast(std::round(val)); } private: //! Members size_t count_; std::vector accums_; }; /*! Numerical Vector Harmonic Mean */ template ::value>::type> class NumericalVectorHarmonicMean : public VectorMean { public: //! Constructor NumericalVectorHarmonicMean(void) : count_(0), accums_() {} //! Constructor NumericalVectorHarmonicMean(const NumericalVectorHarmonicMean &rhs) : count_(rhs.count_), accums_(rhs.accums_) {} //! Constructor NumericalVectorHarmonicMean(NumericalVectorHarmonicMean &&rhs) : count_(rhs.count_), accums_(std::move(rhs.accums_)) {} //! Constructor NumericalVectorHarmonicMean(size_t dim) : count_(0), accums_(dim) {} //! Constructor NumericalVectorHarmonicMean(const T *means, size_t dim, size_t cnt) : count_(cnt), accums_(dim) { for (size_t i = 0; i < dim; ++i) { accums_[i] = static_cast(count_) / static_cast(means[i]); } } //! 
Reset accumulator void reset(size_t dim) { count_ = 0u; accums_.clear(); accums_.resize(dim, 0.0); } //! Reset accumulator virtual void reset(void) { this->reset(accums_.size()); } //! Plus a vector (harmonic) virtual bool plus(const void *vec, size_t len) { size_t dim = accums_.size(); if (dim * sizeof(T) != len) { return false; } for (size_t i = 0; i < dim; ++i) { accums_[i] += 1.0 / *(static_cast(vec) + i); } ++count_; return true; } //! Retrieve the mean of vectors (harmonic) virtual bool mean(void *out, size_t len) const { size_t dim = accums_.size(); if (dim * sizeof(T) != len) { return false; } for (size_t i = 0; i < dim; ++i) { *(static_cast(out) + i) = FloatCast(count_ / accums_[i]); } return true; } //! Retrieve the mean of vectors virtual void mean(std::string *out) const { ailego::NumericalVector &vec = *static_cast *>(out); size_t dim = accums_.size(); vec.resize(dim); for (size_t i = 0; i < dim; ++i) { vec[i] = FloatCast(count_ / accums_[i]); } } //! Merge another vector mean virtual bool merge(const VectorMean &rhs) { const NumericalVectorHarmonicMean &src = dynamic_cast &>(rhs); size_t dim = accums_.size(); if (dim != src.accums_.size()) { return false; } count_ += src.count_; for (size_t i = 0; i < dim; ++i) { accums_[i] += src.accums_[i]; } return true; } //! Retrieve the count of vectors virtual size_t count(void) const { return count_; } //! Retrieve dimension of accumulator virtual size_t dimension(void) const { return accums_.size(); } protected: //! Convert float type to another type template static auto FloatCast(const double &val) -> typename std::enable_if::value, U>::type { return static_cast(val); } //! Convert float type to another type template static auto FloatCast(const double &val) -> typename std::enable_if::value, U>::type { return static_cast(std::round(val)); } private: //! Members size_t count_; std::vector accums_; }; /*! 
Numerical Vector Geometric Mean */ template ::value>::type> class NumericalVectorGeometricMean : public VectorMean { public: //! Constructor NumericalVectorGeometricMean(void) : count_(0), accums_() {} //! Constructor NumericalVectorGeometricMean(const NumericalVectorGeometricMean &rhs) : count_(rhs.count_), accums_(rhs.accums_) {} //! Constructor NumericalVectorGeometricMean(NumericalVectorGeometricMean &&rhs) : count_(rhs.count_), accums_(std::move(rhs.accums_)) {} //! Constructor NumericalVectorGeometricMean(size_t dim) : count_(0), accums_(dim, 1.0) {} //! Constructor NumericalVectorGeometricMean(const T *means, size_t dim, size_t cnt) : count_(cnt), accums_(dim, 1.0) { for (size_t i = 0; i < dim; ++i) { accums_[i] = std::pow(static_cast(means[i]), count_); } } //! Reset accumulator void reset(size_t dim) { count_ = 0u; accums_.clear(); accums_.resize(dim, 1.0); } //! Reset accumulator virtual void reset(void) { this->reset(accums_.size()); } //! Plus a vector (geometric) virtual bool plus(const void *vec, size_t len) { size_t dim = accums_.size(); if (dim * sizeof(T) != len) { return false; } for (size_t i = 0; i < dim; ++i) { accums_[i] *= *(static_cast(vec) + i); } ++count_; return true; } //! Retrieve the mean of vectors (geometric) virtual bool mean(void *out, size_t len) const { size_t dim = accums_.size(); if (dim * sizeof(T) != len) { return false; } for (size_t i = 0; i < dim; ++i) { *(static_cast(out) + i) = FloatCast(std::pow(accums_[i], 1.0 / count_)); } return true; } //! Retrieve the mean of vectors virtual void mean(std::string *out) const { ailego::NumericalVector &vec = *static_cast *>(out); size_t dim = accums_.size(); vec.resize(dim); for (size_t i = 0; i < dim; ++i) { vec[i] = FloatCast(std::pow(accums_[i], 1.0 / count_)); } } //! 
Merge another vector mean virtual bool merge(const VectorMean &rhs) { const NumericalVectorGeometricMean &src = dynamic_cast &>(rhs); size_t dim = accums_.size(); if (dim != src.accums_.size()) { return false; } count_ += src.count_; for (size_t i = 0; i < dim; ++i) { accums_[i] *= src.accums_[i]; } return true; } //! Retrieve the count of vectors virtual size_t count(void) const { return count_; } //! Retrieve dimension of accumulator virtual size_t dimension(void) const { return accums_.size(); } protected: //! Convert float type to another type template static auto FloatCast(const double &val) -> typename std::enable_if::value, U>::type { return static_cast(val); } //! Convert float type to another type template static auto FloatCast(const double &val) -> typename std::enable_if::value, U>::type { return static_cast(std::round(val)); } private: //! Members size_t count_; std::vector accums_; }; /*! Binary Vector Mean */ class BinaryVectorMean : public VectorMean { public: //! Constructor BinaryVectorMean(void) : count_(0), accums_() {} //! Constructor BinaryVectorMean(const BinaryVectorMean &rhs) : count_(rhs.count_), accums_(rhs.accums_) {} //! Constructor BinaryVectorMean(BinaryVectorMean &&rhs) : count_(rhs.count_), accums_(std::move(rhs.accums_)) {} //! Constructor BinaryVectorMean(size_t dim) : count_(0), accums_(((dim + 7) >> 3) << 3) {} //! Constructor BinaryVectorMean(const void *means, size_t dim, size_t cnt) : count_(cnt), accums_(((dim + 7) >> 3) << 3) { const uint8_t *bits = reinterpret_cast(means); for (size_t i = 0; i < dim; ++i) { accums_[i] = (count_ >> 1); if (bits[i >> 3] & static_cast(1 << (i & 0x7))) { accums_[i] += 1; } } } //! Reset accumulator void reset(size_t dim) { count_ = 0u; accums_.clear(); accums_.resize(dim); } //! Reset accumulator virtual void reset(void) { this->reset(accums_.size()); } //! 
Plus a vector virtual bool plus(const void *vec, size_t len) { size_t dim = accums_.size(); if (dim != (len << 3)) { return false; } const uint8_t *bits = reinterpret_cast(vec); for (size_t i = 0; i < dim; ++i) { if (bits[i >> 3] & static_cast(1 << (i & 0x7))) { accums_[i] += 1; } } ++count_; return true; } //! Retrieve the mean of vectors virtual bool mean(void *out, size_t len) const { size_t dim = accums_.size(); if (dim != (len << 3)) { return false; } memset(out, 0, len); uint8_t *bits = reinterpret_cast(out); size_t half_count = count_ >> 1; for (size_t i = 0; i < dim; ++i) { if (accums_[i] > half_count) { bits[i >> 3] |= static_cast(1 << (i & 0x7)); } } return true; } //! Retrieve the mean of vectors virtual void mean(std::string *out) const { size_t dim = accums_.size(); out->clear(); out->resize((dim + 7) / 8); uint8_t *bits = reinterpret_cast(const_cast(out->data())); size_t half_count = count_ >> 1; for (size_t i = 0; i < dim; ++i) { if (accums_[i] > half_count) { bits[i >> 3] |= static_cast(1 << (i & 0x7)); } } } //! Merge another vector mean virtual bool merge(const VectorMean &rhs) { const BinaryVectorMean &src = dynamic_cast(rhs); size_t dim = accums_.size(); if (dim != src.accums_.size()) { return false; } count_ += src.count_; for (size_t i = 0; i < dim; ++i) { accums_[i] += src.accums_[i]; } return true; } //! Retrieve the count of vectors virtual size_t count(void) const { return count_; } //! Retrieve dimension of accumulator virtual size_t dimension(void) const { return accums_.size(); } private: //! Members size_t count_; std::vector accums_; }; /*! Numerical Vector Mean */ template ::value>::type> class NibbleVectorMean : public VectorMean { public: //! Constructor NibbleVectorMean(void) : count_(0), accums_() {} //! Constructor NibbleVectorMean(const NibbleVectorMean &rhs) : count_(rhs.count_), accums_(rhs.accums_) {} //! Constructor NibbleVectorMean(NibbleVectorMean &&rhs) : count_(rhs.count_), accums_(std::move(rhs.accums_)) {} //! 
Constructor NibbleVectorMean(size_t dim) : count_(0), accums_(dim) {} //! Constructor NibbleVectorMean(const void *means, size_t dim, size_t cnt) : count_(cnt), accums_(dim) { const uint8_t *arr = reinterpret_cast(means); for (size_t i = 0; i != dim; i += 2) { uint8_t val = arr[i >> 1]; int lo = ((int8_t)(val << 4) >> 4); int hi = ((int8_t)(val) >> 4); accums_[i] = static_cast(lo) * count_; accums_[i + 1] = static_cast(hi) * count_; } } //! Reset accumulator void reset(size_t dim) { count_ = 0u; accums_.clear(); accums_.resize(dim, 0.0); } //! Reset accumulator virtual void reset(void) { this->reset(accums_.size()); } //! Plus a vector virtual bool plus(const void *vec, size_t len) { size_t dim = accums_.size(); if (dim != (len << 1)) { return false; } const uint8_t *arr = reinterpret_cast(vec); for (size_t i = 0; i != dim; i += 2) { uint8_t val = arr[i >> 1]; accums_[i] += ((int8_t)(val << 4) >> 4); accums_[i + 1] += ((int8_t)(val) >> 4); } ++count_; return true; } //! Retrieve the mean of vectors virtual bool mean(void *out, size_t len) const { size_t dim = accums_.size(); if (dim != (len << 1)) { return false; } memset(out, 0, len); uint8_t *arr = reinterpret_cast(out); for (size_t i = 0; i != dim; i += 2) { int lo = static_cast(std::round(accums_[i] / count_)); int hi = static_cast(std::round(accums_[i + 1] / count_)); arr[i >> 1] = (uint8_t)((hi << 4) & 0xf0) | (uint8_t)(lo & 0xf); } return true; } //! Retrieve the mean of vectors virtual void mean(std::string *out) const { size_t dim = accums_.size(); out->clear(); out->resize(dim >> 1); uint8_t *arr = reinterpret_cast(const_cast(out->data())); for (size_t i = 0; i != dim; i += 2) { int lo = static_cast(std::round(accums_[i] / count_)); int hi = static_cast(std::round(accums_[i + 1] / count_)); arr[i >> 1] = (uint8_t)((hi << 4) & 0xf0) | (uint8_t)(lo & 0xf); } } //! 
Merge another vector mean virtual bool merge(const VectorMean &rhs) { const NibbleVectorMean &src = dynamic_cast(rhs); size_t dim = accums_.size(); if (dim != src.accums_.size()) { return false; } count_ += src.count_; for (size_t i = 0; i < dim; ++i) { accums_[i] += src.accums_[i]; } return true; } //! Retrieve the count of vectors virtual size_t count(void) const { return count_; } //! Retrieve dimension of accumulator virtual size_t dimension(void) const { return accums_.size(); } private: //! Members size_t count_; std::vector accums_; }; } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/flat/CMakeLists.txt ================================================ include(${PROJECT_ROOT_DIR}/cmake/bazel.cmake) include(${PROJECT_ROOT_DIR}/cmake/option.cmake) #message(STATUS "PROJECT_ROOT_DIR = ${PROJECT_ROOT_DIR}") cc_library( NAME core_knn_flat STATIC SHARED STRICT ALWAYS_LINK SRCS *.cc LIBS core_framework INCS . ${PROJECT_ROOT_DIR}/src/core ${PROJECT_ROOT_DIR}/src/core/algorithm ${PROJECT_ROOT_DIR}/src/core/framework VERSION "${PROXIMA_ZVEC_VERSION}" ) ================================================ FILE: src/core/algorithm/flat/flat_builder.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "flat_builder.h"
#include <algorithm>
#include <cstdint>
#include <numeric>
#include <string>
#include <vector>
#include "flat_utility.h"

namespace zvec {
namespace core {

namespace {

//! Append `size` zero bytes to `dumper` (alignment padding after a payload).
//! Returns 0 on success (including when `size` is 0), or
//! IndexError_WriteData if the dumper accepted fewer bytes than requested.
int WriteZeroPadding(IndexDumper *dumper, size_t size) {
  if (size == 0) {
    return 0;
  }
  std::string padding(size, '\0');
  if (dumper->write(padding.data(), padding.size()) != padding.size()) {
    LOG_ERROR("Failed to write data into dumper %s", dumper->name().c_str());
    return IndexError_WriteData;
  }
  return 0;
}

}  // namespace

//! Initialize the builder.
//! Copies `meta`, applies PARAM_FLAT_COLUMN_MAJOR_ORDER when present,
//! rejects column-major order for unsupported data types / sizes, verifies
//! the metric, and records "FlatSearcher<N>" / "FlatBuilder<N>" (N =
//! BATCH_SIZE) on the meta. Returns 0 or an IndexError_* code.
template <size_t BATCH_SIZE>
int FlatBuilder<BATCH_SIZE>::init(const IndexMeta &meta,
                                  const ailego::Params &params) {
  meta_ = meta;

  // Set the major order (an explicit parameter overrides the meta)
  bool column_major_order = false;
  if (params.get(PARAM_FLAT_COLUMN_MAJOR_ORDER, &column_major_order)) {
    meta_.set_major_order(column_major_order ? IndexMeta::MO_COLUMN
                                             : IndexMeta::MO_ROW);
  }

  // Verify column major order
  if (meta_.major_order() != IndexMeta::MO_ROW) {
    IndexMeta::DataType dt = meta_.data_type();
    // NOTE(review): this flag starts as false and is only ever reassigned to
    // false, so the MO_UNDEFINED -> MO_COLUMN promotion at the end of this
    // branch never fires. Initializing it to true would enable the promotion
    // but change the layout of default builds; confirm the intent.
    bool support_column_major = false;
    if ((dt != IndexMeta::DataType::DT_FP32 &&
         dt != IndexMeta::DataType::DT_FP16 &&
         dt != IndexMeta::DataType::DT_INT8 && dt != IndexMeta::DT_INT4 &&
         dt != IndexMeta::DT_BINARY32 && dt != IndexMeta::DT_BINARY64) ||
        (meta_.unit_size() != IndexMeta::UnitSizeof(dt))) {
      if (meta_.major_order() == IndexMeta::MO_COLUMN) {
        LOG_ERROR("Unsupported type %d with unit size %u.", dt,
                  meta_.unit_size());
        return IndexError_Unsupported;
      } else {
        support_column_major = false;
      }
    }
    if (meta_.element_size() % IndexMeta::AlignSizeof(dt) != 0) {
      if (meta_.major_order() == IndexMeta::MO_COLUMN) {
        LOG_ERROR("Unsupported type %d with dimension %u.", dt,
                  meta_.dimension());
        return IndexError_Unsupported;
      } else {
        support_column_major = false;
      }
    }
    if (meta_.major_order() == IndexMeta::MO_UNDEFINED &&
        support_column_major) {
      meta_.set_major_order(IndexMeta::MO_COLUMN);
    }
  }

  if (!VerifyMetric(meta_)) {
    LOG_ERROR("Invalid index measure %s.", meta_.metric_name().c_str());
    return IndexError_InvalidArgument;
  }

  std::string tag = std::to_string(BATCH_SIZE);
  ailego::Params searcher_params;
  searcher_params.set(PARAM_FLAT_BATCH_SIZE, BATCH_SIZE);
  meta_.set_searcher("FlatSearcher" + tag, 0, searcher_params);
  meta_.set_builder("FlatBuilder" + tag, 0, params);
  return 0;
}

//! Build the index.
//! Only checks the holder against the initialized meta and keeps it; the
//! vectors are streamed out later by dump(). Returns IndexError_Mismatch if
//! the holder does not match and IndexError_InvalidArgument if it is null.
template <size_t BATCH_SIZE>
int FlatBuilder<BATCH_SIZE>::build(IndexThreads::Pointer,
                                   IndexHolder::Pointer holder) {
  ailego::ElapsedTime stamp;

  // Guard against a null holder rather than dereferencing it below
  if (!holder) {
    LOG_ERROR("Invalid holder for building.");
    return IndexError_InvalidArgument;
  }
  if (!holder->is_matched(meta_)) {
    LOG_ERROR("The holder is unmatched with initialized meta.");
    return IndexError_Mismatch;
  }
  holder_ = std::move(holder);

  stats_.set_built_count(holder_->count());
  stats_.set_built_costtime(stamp.milli_seconds());
  return 0;
}

//! Dump the index: the features segment (column- or row-major according to
//! meta_), then the keys, the key-sorted mapping and finally the meta.
//! Returns IndexError_NoReady if build() has not stored a holder.
template <size_t BATCH_SIZE>
int FlatBuilder<BATCH_SIZE>::dump(const IndexDumper::Pointer &dumper) {
  ailego::ElapsedTime stamp;

  if (!holder_) {
    return IndexError_NoReady;
  }

  std::vector<uint64_t> keys;
  int error_code = (meta_.major_order() == IndexMeta::MO_COLUMN)
                       ? this->write_column_index(dumper.get(), &keys)
                       : this->write_row_index(dumper.get(), &keys);
  if (error_code != 0) {
    return error_code;
  }

  error_code = this->write_keys(keys, dumper.get());
  if (error_code != 0) {
    return error_code;
  }

  error_code = this->write_mapping(keys, dumper.get());
  if (error_code != 0) {
    return error_code;
  }

  error_code = IndexHelper::SerializeToDumper(meta_, dumper.get());
  if (error_code != 0) {
    return error_code;
  }

  stats_.set_dumped_count(keys.size());
  stats_.set_dumped_costtime(stamp.milli_seconds());
  return 0;
}

//! Write `keys` as a raw uint64_t array, zero-padded to a 32-byte boundary,
//! and register it as the FLAT_SEGMENT_KEYS_SEG_ID segment.
template <size_t BATCH_SIZE>
int FlatBuilder<BATCH_SIZE>::write_keys(const std::vector<uint64_t> &keys,
                                        IndexDumper *dumper) {
  size_t keys_size = keys.size() * sizeof(uint64_t);
  size_t keys_padding_size = ailego_align(keys_size, 32) - keys_size;

  if (dumper->write(keys.data(), keys_size) != keys_size) {
    LOG_ERROR("Failed to write data into dumper %s", dumper->name().c_str());
    return IndexError_WriteData;
  }

  // Write the padding if need
  int error_code = WriteZeroPadding(dumper, keys_padding_size);
  if (error_code != 0) {
    return error_code;
  }
  return dumper->append(FLAT_SEGMENT_KEYS_SEG_ID, keys_size,
                        keys_padding_size, 0);
}

//! Write the argsort of `keys` (mapping[i] is the position in `keys` of the
//! i-th smallest key) as uint32_t values, zero-padded to 32 bytes, and
//! register it as the FLAT_SEGMENT_MAPPING_SEG_ID segment.
template <size_t BATCH_SIZE>
int FlatBuilder<BATCH_SIZE>::write_mapping(const std::vector<uint64_t> &keys,
                                           IndexDumper *dumper) {
  std::vector<uint32_t> mapping(keys.size());
  std::iota(mapping.begin(), mapping.end(), 0u);
  std::sort(mapping.begin(), mapping.end(),
            [&keys](uint32_t lhs, uint32_t rhs) {
              return (keys[lhs] < keys[rhs]);
            });

  size_t mapping_size = mapping.size() * sizeof(uint32_t);
  size_t mapping_padding_size =
      ailego_align(mapping_size, 32) - mapping_size;

  if (dumper->write(mapping.data(), mapping_size) != mapping_size) {
    LOG_ERROR("Failed to write data into dumper %s", dumper->name().c_str());
    return IndexError_WriteData;
  }

  // Write the padding if need
  int error_code = WriteZeroPadding(dumper, mapping_padding_size);
  if (error_code != 0) {
    return error_code;
  }
  return dumper->append(FLAT_SEGMENT_MAPPING_SEG_ID, mapping_size,
                        mapping_padding_size, 0);
}

//! Write the features segment in column-major order.
//! Rows are gathered BATCH_SIZE at a time and each full batch is transposed
//! as elements of type T before being written; a trailing partial batch is
//! written untransposed. Collects every key into `keys` in iteration order.
template <size_t BATCH_SIZE>
template <typename T>
int FlatBuilder<BATCH_SIZE>::write_column_index(IndexDumper *dumper,
                                                std::vector<uint64_t> *keys) {
  auto iter = holder_->create_iterator();
  if (!iter) {
    LOG_ERROR("Failed to create iterator of holder");
    return IndexError_Runtime;
  }

  // Write features
  size_t element_size = holder_->element_size();
  size_t block_size = element_size * BATCH_SIZE;
  std::string block1;
  block1.reserve(block_size);
  // Transpose writes block_size bytes through this buffer and the same bytes
  // are written to the dumper, so the string must own them (size, not just
  // reserved capacity) for the access to be well-defined.
  std::string block2(block_size, '\0');

  for (; iter->is_valid(); iter->next()) {
    block1.append(reinterpret_cast<const char *>(iter->data()), element_size);
    keys->emplace_back(iter->key());

    if (block1.size() == block_size) {
      ailego::MatrixHelper::Transpose<T, BATCH_SIZE>(
          block1.data(), element_size / sizeof(T), &block2[0]);
      if (dumper->write(block2.data(), block_size) != block_size) {
        LOG_ERROR("Failed to write data into dumper %s",
                  dumper->name().c_str());
        return IndexError_WriteData;
      }
      block1.clear();
    }
  }

  // A trailing partial batch is written as-is (no transposition)
  if (!block1.empty()) {
    if (dumper->write(block1.data(), block1.size()) != block1.size()) {
      LOG_ERROR("Failed to write data into dumper %s", dumper->name().c_str());
      return IndexError_WriteData;
    }
  }

  // Write the padding if need
  size_t features_size = keys->size() * element_size;
  size_t features_padding_size =
      ailego_align(features_size, 32) - features_size;
  int error_code = WriteZeroPadding(dumper, features_padding_size);
  if (error_code != 0) {
    return error_code;
  }
  return dumper->append(FLAT_SEGMENT_FEATURES_SEG_ID, features_size,
                        features_padding_size, 0);
}

//! Write the features segment in row-major order (one element per vector,
//! in iteration order), collecting every key into `keys`.
template <size_t BATCH_SIZE>
int FlatBuilder<BATCH_SIZE>::write_row_index(IndexDumper *dumper,
                                             std::vector<uint64_t> *keys) {
  auto iter = holder_->create_iterator();
  if (!iter) {
    LOG_ERROR("Failed to create iterator of holder");
    return IndexError_Runtime;
  }

  // Write features
  size_t element_size = holder_->element_size();
  for (; iter->is_valid(); iter->next()) {
    if (dumper->write(iter->data(), element_size) != element_size) {
      LOG_ERROR("Failed to write data into dumper %s", dumper->name().c_str());
      return IndexError_WriteData;
    }
    keys->emplace_back(iter->key());
  }

  // Write the padding if need
  size_t features_size = keys->size() * element_size;
  size_t features_padding_size =
      ailego_align(features_size, 32) - features_size;
  int error_code = WriteZeroPadding(dumper, features_padding_size);
  if (error_code != 0) {
    return error_code;
  }
  return dumper->append(FLAT_SEGMENT_FEATURES_SEG_ID, features_size,
                        features_padding_size, 0);
}

INDEX_FACTORY_REGISTER_BUILDER_ALIAS(LinearBuilder, FlatBuilder<32>);
INDEX_FACTORY_REGISTER_BUILDER_ALIAS(FlatBuilder, FlatBuilder<32>);
INDEX_FACTORY_REGISTER_BUILDER_ALIAS(FlatBuilder16, FlatBuilder<16>);
INDEX_FACTORY_REGISTER_BUILDER_ALIAS(FlatBuilder32, FlatBuilder<32>);

}  // namespace core
}  // namespace zvec

================================================
FILE: src/core/algorithm/flat/flat_builder.h
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

// NOTE(review): the two #include directives below carry no header name in
// this copy of the file; restore them from upstream before building.
#include
#include
#include "flat_utility.h"

namespace zvec {
namespace core {

/*! Flat Builder
 *  Brute-force index builder: build() only keeps the holder, and dump()
 *  streams its vectors (row- or column-major), keys and key mapping out.
 */
template class FlatBuilder : public IndexBuilder {
 public:
  //! Destructor
  virtual ~FlatBuilder(void) {}

  //! Initialize the builder
  //! NOTE(review): `¶ms` below looks like a mis-encoded `&params`.
  int init(const IndexMeta &meta, const ailego::Params ¶ms) override;

  //! Cleanup the builder (drops the held vectors)
  int cleanup(void) override {
    holder_ = nullptr;
    return 0;
  }

  //! Train the data: a no-op for this builder that only zeroes the
  //! training statistics
  int train(IndexThreads::Pointer, IndexHolder::Pointer) override {
    stats_.set_trained_count(0u);
    stats_.set_trained_costtime(0u);
    return 0;
  }

  //! Train the data with an external trainer: also a no-op that only
  //! zeroes the training statistics
  int train(const IndexTrainer::Pointer &) override {
    stats_.set_trained_count(0u);
    stats_.set_trained_costtime(0u);
    return 0;
  }

  //! Build the index
  int build(IndexThreads::Pointer, IndexHolder::Pointer holder) override;

  //! Dump index into storage
  int dump(const IndexDumper::Pointer &dumper) override;

  //! Retrieve statistics
  const IndexBuilder::Stats &stats(void) const override {
    return stats_;
  }

 protected:
  //! Dump index keys
  int write_keys(const std::vector &keys, IndexDumper *dumper);

  //! Dump index keys mapping
  int write_mapping(const std::vector &keys, IndexDumper *dumper);

  //! Dump index using column-major-order format (typed implementation)
  template int write_column_index(IndexDumper *dumper, std::vector *keys);

  //! Dump index using column-major-order format.
  //! Dispatches to the templated overload for data types whose
  //! IndexMeta::AlignSizeof is 2, 4 or 8 bytes; any other size fails
  //! ailego_check_with and otherwise falls through to IndexError_Runtime.
  int write_column_index(IndexDumper *dumper, std::vector *keys) {
    switch (IndexMeta::AlignSizeof(meta_.data_type())) {
      case 2:
        return this->write_column_index(dumper, keys);
      case 4:
        return this->write_column_index(dumper, keys);
      case 8:
        return this->write_column_index(dumper, keys);
      default:
        ailego_check_with(0, "BAD CASE");
    }
    return IndexError_Runtime;
  }

  //! Dump index using row-major-order format
  int write_row_index(IndexDumper *dumper, std::vector *keys);

 private:
  IndexMeta meta_{};
  IndexBuilder::Stats stats_{};
  IndexHolder::Pointer holder_{};  // set by build(), released by cleanup()
};

}  // namespace core
}  // namespace zvec

================================================
FILE: src/core/algorithm/flat/flat_distance_matrix.h
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "flat_utility.h"

namespace zvec {
namespace core {

/*! Brute Force Distance Tuple
 *  A chain of IndexMetric::MatrixDistance callables, one per level, where
 *  each level K delegates smaller sizes to the level (K >> 1).
 */
template class FlatDistanceTuple;

/*! Brute Force Distance Tuple
 *  Terminal level of the chain: holds a single distance function.
 */
template <> class FlatDistanceTuple<1> {
 public:
  //! Retrieve non-zero if all distances are valid.
  bool is_valid(void) const {
    return !!distance_;
  }

  //! Retrieve non-zero if a distance is valid (only size 1 exists here).
  bool is_valid(size_t m) const {
    return m == 1 && !!distance_;
  }

  //! Initialize the distance tuple with distance_matrix(1, 1)
  void initialize(const IndexMetric &measure) {
    distance_ = measure.distance_matrix(1, 1);
  }

  //! Initialize the distance tuple with distance_matrix(m, 1)
  void initialize(const IndexMetric &measure, size_t m) {
    distance_ = measure.distance_matrix(m, 1);
  }

  //! Compute the distance between matrix and query
  template auto distance(const void *m, const void *q, size_t dim, float *out) const -> typename std::enable_if::type {
    distance_(m, q, dim, out);
  }

 private:
  IndexMetric::MatrixDistance distance_{};
};

/*! Brute Force Distance Tuple
 *  Level K: owns one distance function and a half-size tuple
 *  (FlatDistanceTuple<(K >> 1)>) for smaller sizes.
 */
template class FlatDistanceTuple< K, typename std::enable_if::value>::type> {
 public:
  //! Retrieve non-zero if all distances are valid (this level and all
  //! smaller levels).
  bool is_valid(void) const {
    return (distance_tuple_.is_valid() && !!distance_);
  }

  //! Retrieve non-zero if a distance is valid: size K is served by this
  //! level, smaller sizes by the nested tuple, larger sizes never.
  bool is_valid(size_t m) const {
    return (m == K ? (!!distance_)
                   : (m < K ? distance_tuple_.is_valid(m) : false));
  }

  //! Initialize the distance tuple: this level uses distance_matrix(K, 1),
  //! nested levels are initialized recursively the same way.
  void initialize(const IndexMetric &measure) {
    distance_tuple_.initialize(measure);
    distance_ = measure.distance_matrix(K, 1);
  }

  //! Initialize the distance tuple for a fixed first dimension `m`: this
  //! level uses distance_matrix(m, K), nested levels use (m, K >> 1), ...
  void initialize(const IndexMetric &measure, size_t m) {
    distance_tuple_.initialize(measure, m);
    distance_ = measure.distance_matrix(m, K);
  }

  //! Compute the distance between matrix and query (size handled by this
  //! level)
  template auto distance(const void *m, const void *q, size_t dim, float *out) const -> typename std::enable_if::type {
    distance_(m, q, dim, out);
  }

  //! Compute the distance between matrix and query: when K > M, forward to
  //! the half-size nested tuple
  template auto distance(const void *m, const void *q, size_t dim, float *out) const -> typename std::enable_if<(K > M) && IsEqualPowerofTwo::value>::type {
    distance_tuple_.template distance(m, q, dim, out);
  }

 private:
  FlatDistanceTuple<(K >> 1)> distance_tuple_{};
  IndexMetric::MatrixDistance distance_{};
};

/*! Brute Force Distance Matrix
 *  Two-dimensional dispatch over FlatDistanceTuple levels.
 */
template class FlatDistanceMatrix;

/*! Brute Force Distance Matrix
 *  Terminal case: a single distance_matrix(1, 1) function.
 */
template <> class FlatDistanceMatrix<1> {
 public:
  //! Retrieve non-zero if all distances are valid.
  bool is_valid(void) const {
    return (!!distance_);
  }

  //! Initialize the distance tuple with distance_matrix(1, 1)
  void initialize(const IndexMetric &measure) {
    distance_ = measure.distance_matrix(1, 1);
  }

  //! Compute the distance between matrix and query
  template auto distance(const void *m, const void *q, size_t dim, float *out) const -> typename std::enable_if::type {
    distance_(m, q, dim, out);
  }

 private:
  IndexMetric::MatrixDistance distance_{};
};

/*! Brute Force Distance Matrix
 *  tuple_h_ covers first dimension K with second dimensions K, K >> 1, ...;
 *  tuple_v_ covers first dimensions below K with second dimension 1.
 *  Presumably (matrix count, query count) — confirm against IndexMetric.
 */
template class FlatDistanceMatrix< K, typename std::enable_if::value>::type> {
 public:
  //! Retrieve non-zero if all distances are valid.
  bool is_valid(void) const {
    return (tuple_h_.is_valid() && tuple_v_.is_valid());
  }

  //! Retrieve non-zero if a distance is valid: (K, n) via tuple_h_,
  //! (m < K, 1) via tuple_v_, anything else is invalid.
  bool is_valid(size_t m, size_t n) const {
    return (m == K ? tuple_h_.is_valid(n)
                   : (m < K && n == 1 ? tuple_v_.is_valid(m) : false));
  }

  //! Initialize the distance tuple: tuple_h_ with first dimension K,
  //! tuple_v_ with its (m, 1) chain
  void initialize(const IndexMetric &measure) {
    tuple_h_.initialize(measure, K);
    tuple_v_.initialize(measure);
  }

  //! Compute the distance between matrix and query: K == M and K >= N is
  //! served by tuple_h_
  template auto distance(const void *m, const void *q, size_t dim, float *out) const -> typename std::enable_if<(K == M) && (K >= N)>::type {
    tuple_h_.template distance(m, q, dim, out);
  }

  //! Compute the distance between matrix and query: K > M with N == 1 is
  //! served by tuple_v_
  template auto distance(const void *m, const void *q, size_t dim, float *out) const -> typename std::enable_if<(K > M) && (N == 1u)>::type {
    tuple_v_.template distance(m, q, dim, out);
  }

 private:
  FlatDistanceTuple tuple_h_{};
  FlatDistanceTuple<(K >> 1)> tuple_v_{};
};

}  // namespace core
}  // namespace zvec

================================================
FILE: src/core/algorithm/flat/flat_index_format.h
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

// NOTE(review): the #include directive below carries no header name in this
// copy of the file; restore it from upstream before building.
#include

namespace zvec {
namespace core {

//! Scalar aliases shared by the flat index on-disk/in-memory formats
using node_id_t = uint32_t;
using key_t = uint64_t;
using level_t = int32_t;
using dist_t = float;

//! Result / candidate heaps built on ailego::KeyValueHeap
using TopkHeap = ailego::KeyValueHeap;
using CandidateHeap = ailego::KeyValueHeap>;

//! Sentinels: all bits set (static_cast of -1 to the unsigned type)
constexpr node_id_t kInvalidNodeId = static_cast(-1);
constexpr key_t kInvalidKey = static_cast(-1);

/*! Index Format of Linear Index Header
 *  Counters and sizes of the linear index, 28 reserved bytes, then a
 *  zero-length trailing array `index_meta` (GNU extension) — presumably
 *  followed in storage by index_meta_size bytes of serialized meta; confirm.
 */
struct LinearIndexHeader {
  LinearIndexHeader()
      : header_size(0),
        total_vector_count(0),
        linear_body_size(0),
        linear_list_count(0),
        block_vector_count(0),
        block_size(0),
        block_count(0),
        index_meta_size(0) {
    memset(reserved_, 0, sizeof(reserved_));
    memset(index_meta, 0, sizeof(index_meta));  // sizeof is 0 here
  }

  uint32_t header_size{0};
  uint32_t total_vector_count{0};
  uint64_t linear_body_size{0};
  uint32_t linear_list_count{0};
  uint32_t block_vector_count{0};
  uint32_t block_size{0};
  uint32_t block_count{0};
  uint32_t index_meta_size{0};
  char reserved_[28] = {0};
  char index_meta[0];
};

/*! Index Format of Linear Index Meta for each Linear list
 *  Offset plus block/vector counts of one list, with 16 reserved bytes.
 */
struct LinearListMeta {
  LinearListMeta() : offset(0), block_count(0), vector_count(0), id_offset(0) {
    memset(reserved_, 0, sizeof(reserved_));
  }

  uint64_t offset{0};
  uint32_t block_count{0};
  uint32_t vector_count{0};
  uint32_t id_offset{0};
  char reserved_[16] = {0};
};

/*! Index Format of Location in Linear Index for each vector
 *  Packed into one uint64_t: 48-bit offset, 1-bit layout flag, 15 reserved.
 */
struct LinearVecLocation {
  LinearVecLocation(size_t off, bool col)
      : offset(off), column_major(col), reserved(0u) {}

  uint64_t offset : 48;       // feature offset in posting block segment
  uint64_t column_major : 1;  // column major if true
  uint64_t reserved : 15;
};

/*! Index Format of Integer Quantizer params for each linear list
 *  Defaults are scale 1.0 and bias 0.0.
 */
struct LinearIntegerQuantizerParams {
  float scale{1.0};
  float bias{0.0};
};

/*! Location of Vectors Block in Storage Segment
 */
struct BlockLocation {
  uint32_t segment_id{0};
  uint32_t block_index{0};
};

/*! The Header of a Block in Storage Segment
 *  `next` presumably links to the following block of the same list —
 *  confirm against the streamer code.
 */
struct BlockHeader {
  BlockHeader() : vector_count(0u), column_major(0u), reserved(0u) {}

  BlockLocation next;
  uint16_t vector_count{0};
  uint16_t column_major : 1;
  uint16_t reserved : 15;
};

//! A 32-bit bitset with set/reset/test per index; is_dirty() reports
//! whether any bit is set. Presumably marks deleted slots of a block —
//! confirm. The static_assert below pins its size to 4 bytes.
struct DeletionMap {
  void set(uint32_t index) {
    bitset.set(index);
  }
  void reset(uint32_t index) {
    bitset.reset(index);
  }
  bool test(uint32_t index) const {
    return bitset.test(index);
  }
  bool is_dirty() const {
    return bitset.test_any();
  }

  ailego::FixedBitset<32> bitset{};
};

static_assert(sizeof(DeletionMap) == 4, "DeletionMap must be 4 bytes");

/*! Meta Information of Streamer Entity
 *  Timestamps, revision and segment bookkeeping, 32 reserved bytes, and an
 *  embedded LinearIndexHeader as the last member (that header ends with a
 *  zero-length array).
 */
struct StreamerLinearMeta {
  StreamerLinearMeta()
      : create_time(0),
        update_time(0),
        revision_id(0),
        segment_count(0),
        segment_size(0) {
    memset(reserved_, 0, sizeof(reserved_));
  }

  uint64_t create_time{0};
  uint64_t update_time{0};
  uint64_t revision_id{0};
  uint32_t segment_count{0};
  uint32_t segment_size{0};
  uint8_t reserved_[32] = {0};
  LinearIndexHeader header;
};

/*! Location of Vector in Storage Segment
 *  Segment id, a 1-bit column-major flag (15 reserved bits) and an offset.
 */
struct VectorLocation {
  //! Constructor
  VectorLocation(void)
      : segment_id(0u), column_major(0u), reserved(0u), offset(0u) {}

  //! Constructor
  VectorLocation(uint32_t id, bool col, uint32_t off)
      : segment_id(id), column_major(col), reserved(0u), offset(off) {}

  uint32_t segment_id{0};
  uint16_t column_major : 1;
  uint16_t reserved : 15;
  uint32_t offset{0};

 public:
  //! Equality compares segment_id, column_major and offset (not `reserved`)
  bool operator==(const VectorLocation &other) const {
    return segment_id == other.segment_id &&
           column_major == other.column_major && offset == other.offset;
  }
};

// NOTE(review): with a uint32_t, a 16-bit bit-field and another uint32_t the
// struct is likely 12 bytes on common ABIs, which may be why this check is
// disabled — confirm before re-enabling.
// static_assert(sizeof(VectorLocation) == sizeof(uint64_t),
//               "VectorLocation must be size of 8 bytes");

//! A centroid index paired with a VectorLocation.
//! Note: the single-argument VectorLocation constructor is not explicit,
//! so a VectorLocation converts implicitly to a KeyInfo.
struct KeyInfo {
  KeyInfo(void) : centroid_idx(0u) {}
  KeyInfo(uint32_t idx, const VectorLocation &loc)
      : centroid_idx(idx), location(loc) {}
  KeyInfo(VectorLocation loc) : location(loc) {}

  uint32_t centroid_idx{0};
  VectorLocation location;
};

}  // namespace core
}  // namespace zvec

================================================
FILE: src/core/algorithm/flat/flat_searcher.cc
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "flat_searcher.h"
#include
#include
#include "flat_distance_matrix.h"
#include "flat_searcher_context.h"
#include "flat_searcher_provider.h"
#include "flat_utility.h"

namespace zvec {
namespace core {

//! Create a provider over the loaded index.
//! The first call after a (re)load copies the mapping segment into the
//! `mapping_` cache while holding `mapping_mutex_`. Returns nullptr if the
//! searcher is not loaded, or if the mapping segment is missing, malformed,
//! inconsistent with the features segment, or unreadable.
template
IndexProvider::Pointer FlatSearcher::create_provider(void) const {
  std::lock_guard lock(mapping_mutex_);

  // container_ and features_segment_ are only populated by load(); refuse
  // instead of dereferencing null pointers on an unloaded searcher.
  if (!container_ || !features_segment_) {
    LOG_ERROR("Failed to create provider, the searcher is not loaded");
    return nullptr;
  }

  if (mapping_.empty()) {
    auto mapping_segment = container_->get(FLAT_SEGMENT_MAPPING_SEG_ID);
    if (!mapping_segment) {
      LOG_ERROR("Failed to fetch segment %s",
                FLAT_SEGMENT_MAPPING_SEG_ID.c_str());
      return nullptr;
    }
    const size_t mapping_bytes = mapping_segment->data_size();
    if (mapping_bytes % sizeof(uint32_t) != 0) {
      LOG_ERROR("Invalid data size %zu of mapping segment", mapping_bytes);
      return nullptr;
    }

    // The mapping holds exactly one uint32_t entry per stored feature.
    size_t mapping_count = mapping_bytes / sizeof(uint32_t);
    if (mapping_count * meta_.element_size() !=
        features_segment_->data_size()) {
      LOG_ERROR(
          "Invalid data size %zu of features segment, mismatched with %zu "
          "mapping entries",
          features_segment_->data_size(), mapping_count);
      return nullptr;
    }

    const uint32_t *mapping = nullptr;
    if (mapping_segment->read(0, reinterpret_cast(&mapping), mapping_bytes) !=
        mapping_bytes) {
      LOG_ERROR("Failed to read data (%zu bytes) from mapping segment",
                mapping_bytes);
      return nullptr;
    }
    mapping_.clear();
    mapping_.reserve(mapping_count);
    std::copy(mapping, mapping + mapping_count, std::back_inserter(mapping_));
  }
  return IndexProvider::Pointer(new (std::nothrow) FlatSearcherProvider(this));
}

//! Load the index stored in `cntr`.
//! Steps:
//! 1. Deserialize the index meta.
//! 2. Set up the metric: initialize it from the meta when `measure` is null,
//!    otherwise require `measure` to match the meta.
//! 3. Check that distance functions exist for the index's major order.
//! 4. Read the keys segment and build the key -> index-id lookup.
//! Returns 0 on success or an IndexError code.
template
int FlatSearcher::load(IndexStorage::Pointer cntr,
                       IndexMetric::Pointer measure) {
  ailego::ElapsedTime stamp;
  if (!cntr) {
    return IndexError_InvalidArgument;
  }

  int error_code = IndexHelper::DeserializeFromStorage(cntr.get(), &meta_);
  if (error_code != 0) {
    LOG_ERROR(
        "Failed to deserialize index meta from container %s, error=%d, %s",
        cntr->name().c_str(), error_code, IndexError::What(error_code));
    return error_code;
  }

  if (!measure) {
    error_code = InitializeMetric(meta_, &measure_);
    if (error_code != 0) {
      LOG_ERROR("Failed to initialize index measure %s, error=%d, %s",
                meta_.metric_name().c_str(), error_code,
                IndexError::What(error_code));
      return error_code;
    }
    // Prefer the dedicated query-side metric when the metric provides one.
    if (measure_->query_metric()) {
      measure_ = measure_->query_metric();
    }
  } else {
    if (!measure->is_matched(meta_)) {
      LOG_ERROR(
          "The index measure is unmatched with index meta from container.");
      return IndexError_Mismatch;
    }
    measure_ = std::move(measure);
  }

  column_major_order_ = (meta_.major_order() == IndexMeta::MO_COLUMN);
  distance_matrix_.initialize(*measure_);
  if (column_major_order_) {
    if (!distance_matrix_.is_valid()) {
      LOG_ERROR("Lack of distance functions to support column index.");
      return IndexError_Unsupported;
    }
  } else {
    if (!distance_matrix_.is_valid(1, 1)) {
      LOG_ERROR("Lack of distance functions to support row index.");
      return IndexError_Unsupported;
    }
  }

  auto keys_segment = cntr->get(FLAT_SEGMENT_KEYS_SEG_ID);
  if (!keys_segment) {
    LOG_ERROR("Failed to fetch segment %s", FLAT_SEGMENT_KEYS_SEG_ID.c_str());
    return IndexError_NoExist;
  }
  features_segment_ = cntr->get(FLAT_SEGMENT_FEATURES_SEG_ID);
  if (!features_segment_) {
    // Report the segment that is actually missing.
    LOG_ERROR("Failed to fetch segment %s",
              FLAT_SEGMENT_FEATURES_SEG_ID.c_str());
    return IndexError_NoExist;
  }

  const size_t keys_bytes = keys_segment->data_size();
  if (keys_bytes % sizeof(uint64_t) != 0) {
    LOG_ERROR("Invalid data size %zu of keys segment", keys_bytes);
    return IndexError_InvalidLength;
  }
  // One uint64_t key per stored feature.
  size_t keys_count = keys_bytes / sizeof(uint64_t);
  if (keys_count * meta_.element_size() != features_segment_->data_size()) {
    LOG_ERROR("Invalid data size %zu of features segment",
              features_segment_->data_size());
    return IndexError_Mismatch;
  }
  if (keys_segment->read(0, reinterpret_cast(&keys_), keys_bytes) !=
      keys_bytes) {
    LOG_ERROR("Failed to read data (%zu bytes) from keys segment", keys_bytes);
    return IndexError_ReadData;
  }

  // Rebuild the key -> index-id lookup from scratch so entries from a
  // previously loaded index cannot leak in. If a key appears more than once,
  // the last index wins.
  key_id_mapping_.clear();
  key_id_mapping_.reserve(keys_count);
  for (size_t i = 0; i < keys_count; i++) {
    key_id_mapping_[keys_[i]] = i;
  }

  {
    // create_provider() only fills `mapping_` while it is empty, and unload()
    // does not clear it; drop any cache built for a previous container so
    // it is rebuilt from this one.
    std::lock_guard lock(mapping_mutex_);
    mapping_.clear();
  }

  container_ = cntr;
  magic_ = IndexContext::GenerateMagic();
  stats_.set_loaded_count(keys_count);
  stats_.set_loaded_costtime(stamp.milli_seconds());
  return 0;
}

//! Single-result search: resets a stale context, then dispatches to grouped
//! search or to the column/row scan matching the index's major order.
template
int FlatSearcher::search_impl(const void *query, const IndexQueryMeta &qmeta,
                              Context::Pointer &context) const {
  ailego_assert(query && !!context);
  ailego_assert(measure_->is_matched(meta_, qmeta));

  FlatSearcherContext *bf_context =
      dynamic_cast *>(context.get());
  if (!bf_context) {
    LOG_ERROR("Invalid brute-force searcher context");
    return IndexError_InvalidArgument;
  }
  // A magic mismatch means the context was built for another load; rebind.
  if (bf_context->magic() != magic_) {
    bf_context->reset(this);
  }

  if (bf_context->group_by_search()) {
    return bf_context->group_by_search_impl(query, qmeta, 1);
  } else {
    return (column_major_order_ ? bf_context->search_column(query, qmeta)
                                : bf_context->search_row(query, qmeta));
  }
}

//! Batch search of `count` queries; same dispatch as the single-query form.
template
int FlatSearcher::search_impl(const void *query, const IndexQueryMeta &qmeta,
                              uint32_t count,
                              Context::Pointer &context) const {
  ailego_assert(query && count && !!context);
  ailego_assert(measure_->is_matched(meta_, qmeta));

  FlatSearcherContext *bf_context =
      dynamic_cast *>(context.get());
  if (!bf_context) {
    LOG_ERROR("Invalid brute-force searcher context");
    return IndexError_InvalidArgument;
  }
  if (bf_context->magic() != magic_) {
    bf_context->reset(this);
  }

  if (bf_context->group_by_search()) {
    return bf_context->group_by_search_impl(query, qmeta, count);
  } else {
    return (column_major_order_
                ? bf_context->search_column(query, qmeta, count)
                : bf_context->search_row(query, qmeta, count));
  }
}

//! Brute-force search restricted to primary keys: one `p_keys` entry is
//! required per query, so `p_keys.size()` must equal `count`.
template
int FlatSearcher::search_bf_by_p_keys_impl(
    const void *query, const std::vector> &p_keys,
    const IndexQueryMeta &qmeta, uint32_t count,
    Context::Pointer &context) const {
  ailego_assert(query && count && !!context);
  ailego_assert(measure_->is_matched(meta_, qmeta));
  if (ailego_unlikely(p_keys.size() != count)) {
    LOG_ERROR("The size of p_keys is not equal to count");
    return IndexError_InvalidArgument;
  }

  FlatSearcherContext *bf_context =
      dynamic_cast *>(context.get());
  if (!bf_context) {
    LOG_ERROR("Invalid brute-force searcher context");
    return IndexError_InvalidArgument;
  }
  if (bf_context->magic() != magic_) {
    bf_context->reset(this);
  }
  return bf_context->search_bf_by_p_keys_impl(query, p_keys, qmeta, count);
}

//! Create a fresh brute-force context bound to this searcher.
template
IndexSearcher::Context::Pointer FlatSearcher::create_context(
    void) const {
  return IndexSearcher::Context::Pointer(
      new FlatSearcherContext(this));
}

INDEX_FACTORY_REGISTER_SEARCHER_ALIAS(LinearSearcher, FlatSearcher<32>);
INDEX_FACTORY_REGISTER_SEARCHER_ALIAS(FlatSearcher, FlatSearcher<32>);
INDEX_FACTORY_REGISTER_SEARCHER_ALIAS(FlatSearcher16, FlatSearcher<16>);
INDEX_FACTORY_REGISTER_SEARCHER_ALIAS(FlatSearcher32, FlatSearcher<32>);

}  // namespace core
}  // namespace zvec
================================================
FILE: src/core/algorithm/flat/flat_searcher.h
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include #include "flat_distance_matrix.h" #include "flat_index_format.h" namespace zvec { namespace core { /*! Flat Searcher */ template class FlatSearcher : public IndexSearcher { public: //! Destructor virtual ~FlatSearcher(void) = default; //! Initialize Searcher int init(const ailego::Params &index_params) override { params_ = index_params; read_block_size_ = FLAT_DEFAULT_READ_BLOCK_SIZE; index_params.get(PARAM_FLAT_READ_BLOCK_SIZE, &read_block_size_); return 0; } //! Cleanup Searcher int cleanup(void) override { return this->unload(); } //! Load index from container int load(IndexStorage::Pointer cntr, IndexMetric::Pointer measure) override; //! Unload index int unload(void) override { container_ = nullptr; measure_ = nullptr; features_segment_ = nullptr; keys_ = nullptr; key_id_mapping_.clear(); return 0; } //! Similarity brute force search int search_bf_impl(const void *query, const IndexQueryMeta &qmeta, Context::Pointer &context) const override { return this->search_impl(query, qmeta, context); } //! Similarity brute force search int search_bf_impl(const void *query, const IndexQueryMeta &qmeta, uint32_t count, Context::Pointer &context) const override { return this->search_impl(query, qmeta, count, context); } //! Similarity search int search_impl(const void *query, const IndexQueryMeta &qmeta, Context::Pointer &context) const override; //! Similarity search int search_impl(const void *query, const IndexQueryMeta &qmeta, uint32_t count, Context::Pointer &context) const override; //! Linear search by primary keys int search_bf_by_p_keys_impl(const void *query, const std::vector> &p_keys, const IndexQueryMeta &qmeta, Context::Pointer &context) const override { return search_bf_by_p_keys_impl(query, p_keys, qmeta, 1, context); } //! 
Linear search by primary keys int search_bf_by_p_keys_impl(const void *query, const std::vector> &p_keys, const IndexQueryMeta &qmeta, uint32_t count, Context::Pointer &context) const override; //! Retrieve statistics const IndexSearcher::Stats &stats(void) const override { return stats_; } //! Retrieve meta of index const IndexMeta &meta(void) const override { return meta_; } //! Retrieve params of index const ailego::Params ¶ms(void) const override { return params_; } //! Create a searcher context IndexSearcher::Context::Pointer create_context(void) const override; //! Create a searcher provider IndexProvider::Pointer create_provider(void) const override; //! Retrieve magic number uint32_t magic(void) const { return magic_; } //! Retrieve block size of data read uint32_t read_block_size(void) const { return read_block_size_; } //! Retrieve primary key via index id uint64_t key(size_t i) const { return keys_[i]; } // Retrieve index id via primary key node_id_t get_id(key_t key) const { auto it = key_id_mapping_.find(key); if (it != key_id_mapping_.end()) { return it->second; } else { return kInvalidNodeId; } } //! Retrieve primary key via index id uint32_t local_index(size_t i) const { return mapping_[i]; } //! Retrieve primary key via index id inline bool column_major_order(void) const { return column_major_order_; } //! Retrieve the distance matrix const FlatDistanceMatrix &distance_matrix(void) const { return distance_matrix_; } //! Clone a features segment IndexStorage::Segment::Pointer clone_features_segment(void) const { return features_segment_->clone(); } const void *get_vector(key_t key) const override { auto provider = this->create_provider(); return provider->get_vector(key); } private: //! 
Members const uint64_t *keys_{nullptr}; std::unordered_map key_id_mapping_; uint32_t magic_{IndexContext::GenerateMagic()}; uint32_t read_block_size_{FLAT_DEFAULT_READ_BLOCK_SIZE}; bool column_major_order_{false}; IndexMeta meta_{}; IndexStorage::Pointer container_{}; IndexMetric::Pointer measure_{}; ailego::Params params_{}; IndexStorage::Segment::Pointer features_segment_{}; mutable std::vector mapping_{}; mutable std::mutex mapping_mutex_{}; FlatDistanceMatrix distance_matrix_{}; IndexSearcher::Stats stats_{}; }; } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/flat/flat_searcher_context.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include "flat_index_format.h" #include "flat_searcher.h" #include "flat_utility.h" namespace zvec { namespace core { /*! Brute Force Searcher Context */ template class FlatSearcherContext : public IndexSearcher::Context { public: //! Constructor FlatSearcherContext(const FlatSearcher *owner) { this->reset(owner); } //! Destructor virtual ~FlatSearcherContext(void) {} //! Set topk of search result void set_topk(uint32_t topk) override { topk_ = topk; } //! Retrieve search result const IndexDocumentList &result(void) const override { return result_heaps_.at(0); } //! 
Retrieve search result with index const IndexDocumentList &result(size_t index) const override { return result_heaps_.at(index); } //! Retrieve result object for output IndexDocumentList *mutable_result(size_t idx) override { return &result_heaps_.at(idx); } //! Retrieve search group result with index virtual const IndexGroupDocumentList &group_result(void) const override { return group_results_[0]; } //! Retrieve search group result with index virtual const IndexGroupDocumentList &group_result( size_t idx) const override { return group_results_[idx]; } //! Update the parameters of context int update(const ailego::Params & /*params*/) override { return 0; } //! Retrieve magic number uint32_t magic(void) const override { return magic_; } //! Get group topk inline uint32_t group_topk() const { return group_topk_; } //! Get group num inline uint32_t group_num() const { return group_num_; } inline std::map &group_topk_heaps() { return group_topk_heaps_; } void set_fetch_vector(bool v) override { fetch_vector_ = v; } bool fetch_vector() const override { return fetch_vector_; } inline void resize_group_results(size_t size) { if (group_by_search()) { group_results_.resize(size); } } void topk_to_group_result(uint32_t idx) { ailego_assert_with(idx < group_results_.size(), "invalid idx"); group_results_[idx].clear(); std::vector> group_topk_list; std::vector> best_score_in_groups; for (auto itr = group_topk_heaps_.begin(); itr != group_topk_heaps_.end(); itr++) { const std::string &group_id = (*itr).first; auto &heap = (*itr).second; heap.sort(); if (heap.size() > 0) { float best_score = heap[0].second; best_score_in_groups.push_back(std::make_pair(group_id, best_score)); } } std::sort(best_score_in_groups.begin(), best_score_in_groups.end(), [](const std::pair &a, const std::pair &b) -> int { return a.second < b.second; }); // truncate to group num for (uint32_t i = 0; i < group_num() && i < best_score_in_groups.size(); ++i) { const std::string &group_id = 
best_score_in_groups[i].first; group_topk_list.emplace_back( std::make_pair(group_id, group_topk_heaps_[group_id])); } group_results_[idx].resize(group_topk_list.size()); for (uint32_t i = 0; i < group_topk_list.size(); ++i) { const std::string &group_id = group_topk_list[i].first; group_results_[idx][i].set_group_id(group_id); uint32_t size = std::min( group_topk_, static_cast(group_topk_list[i].second.size())); for (uint32_t j = 0; j < size; ++j) { auto score = group_topk_list[i].second[j].second; if (score > this->threshold()) { break; } node_id_t id = group_topk_list[i].second[j].first; auto provider = owner_->create_provider(); if (fetch_vector_) { group_results_[idx][i].mutable_docs()->emplace_back( id, score, id, provider->get_vector(id)); } else { group_results_[idx][i].mutable_docs()->emplace_back(id, score, id); } } } } //! Get if group by search bool group_by_search() { return group_num_ > 0; } //! Set group params void set_group_params(uint32_t group_num, uint32_t group_topk) override { group_num_ = group_num; group_topk_ = group_topk; group_topk_heaps_.clear(); } void reset() override {} //! Reset the context void reset(const FlatSearcher *owner) { magic_ = owner->magic(); feature_size_ = owner->meta().element_size(); uint32_t block_size = feature_size_ * BATCH_SIZE; actual_read_size_ = (owner->read_block_size() + block_size - 1) / block_size * block_size; features_segment_ = owner->clone_features_segment(); owner_ = owner; } //! Similarity search int search_row(const void *query, const IndexQueryMeta &qmeta) { return (this->filter().is_valid() ? this->search_row_filter(query, qmeta) : this->search_row_nofilter(query, qmeta)); } //! Similarity search int search_row(const void *query, const IndexQueryMeta &qmeta, size_t count) { return (this->filter().is_valid() ? this->batch_search_row_filter(query, qmeta, count) : this->batch_search_row_nofilter(query, qmeta, count)); } //! 
Similarity search int search_column(const void *query, const IndexQueryMeta &qmeta) { return (this->filter().is_valid() ? this->search_column_filter(query, qmeta) : this->search_column_nofilter(query, qmeta)); } //! Similarity search int search_column(const void *query, const IndexQueryMeta &qmeta, size_t count) { return (this->filter().is_valid() ? this->batch_search_column_filter(query, qmeta, count) : this->batch_search_column_nofilter(query, qmeta, count)); } int group_by_search_impl(const void *query, const IndexQueryMeta &qmeta, uint32_t count); int search_bf_by_p_keys_impl(const void *query, const std::vector> &p_keys, const IndexQueryMeta &qmeta, uint32_t count); protected: //! Enqueue items into the search heaps (without filter) template auto batch_enqueue_nofilter(const void *block, size_t block_index, size_t query_index, const IndexQueryMeta &qmeta, size_t query_count) -> typename std::enable_if::value>::type { size_t query_batch_count = query_count / K; for (size_t i = 0; i != query_batch_count; ++i) { owner_->distance_matrix().template distance( block, &batch_queries_[query_index * qmeta.element_size()], qmeta.dimension(), scores_); for (size_t k = 0; k != K; ++k) { IndexDocumentHeap *heap = &result_heaps_[query_index++]; for (size_t j = 0; j != BATCH_SIZE; ++j) { heap->emplace(0, scores_[k * BATCH_SIZE + j], block_index + j); } } // end of for } // end of for size_t query_left_count = query_count % K; if (query_left_count != 0) { this->batch_enqueue_nofilter<(K >> 1)>(block, block_index, query_index, qmeta, query_left_count); } } //! 
Enqueue items into the search heaps (without filter) template auto batch_enqueue_nofilter(const void *block, size_t block_index, size_t query_index, const IndexQueryMeta &qmeta, size_t query_count) -> typename std::enable_if::type { ailego_assert(query_count == 1); (void)query_count; owner_->distance_matrix().template distance( block, &batch_queries_[query_index * qmeta.element_size()], qmeta.dimension(), scores_); IndexDocumentHeap *heap = &result_heaps_[query_index]; for (size_t i = 0; i != BATCH_SIZE; ++i) { heap->emplace(0, scores_[i], block_index + i); } } //! Enqueue items into the search heaps (with filter) template auto batch_enqueue_filter(const void *block, size_t block_index, size_t block_mask, size_t query_index, const IndexQueryMeta &qmeta, size_t query_count) -> typename std::enable_if::value>::type { size_t query_batch_count = query_count / K; for (size_t i = 0; i != query_batch_count; ++i) { owner_->distance_matrix().template distance( block, &batch_queries_[query_index * qmeta.element_size()], qmeta.dimension(), scores_); for (size_t k = 0; k != K; ++k) { IndexDocumentHeap *heap = &result_heaps_[query_index++]; for (size_t j = 0; j != BATCH_SIZE; ++j) { if ((block_mask & (1 << j)) != 0) { heap->emplace(0, scores_[k * BATCH_SIZE + j], block_index + j); } } } // end of for } // end of for size_t query_left_count = query_count % K; if (query_left_count != 0) { this->batch_enqueue_filter<(K >> 1)>( block, block_index, block_mask, query_index, qmeta, query_left_count); } } //! 
Enqueue items into the search heaps (with filter) template auto batch_enqueue_filter(const void *block, size_t block_index, size_t block_mask, size_t query_index, const IndexQueryMeta &qmeta, size_t query_count) -> typename std::enable_if::type { ailego_assert(query_count == 1); (void)query_count; owner_->distance_matrix().template distance( block, &batch_queries_[query_index * qmeta.element_size()], qmeta.dimension(), scores_); IndexDocumentHeap *heap = &result_heaps_[query_index]; for (size_t i = 0; i != BATCH_SIZE; ++i) { if ((block_mask & (1 << i)) != 0) { heap->emplace(0, scores_[i], block_index + i); } } } //! Enqueue items into the search heaps (without filter) template auto single_enqueue_nofilter(const void *feature, size_t feature_index, size_t query_index, const IndexQueryMeta &qmeta, size_t query_count) -> typename std::enable_if::value>::type { size_t query_batch_count = query_count / K; for (size_t i = 0; i != query_batch_count; ++i) { owner_->distance_matrix().template distance( &batch_queries_[query_index * qmeta.element_size()], feature, qmeta.dimension(), scores_); for (size_t k = 0; k != K; ++k) { result_heaps_[query_index++].emplace(0, scores_[k], feature_index); } } size_t query_left_count = query_count % K; if (query_left_count != 0) { this->single_enqueue_nofilter<(K >> 1)>( feature, feature_index, query_index, qmeta, query_left_count); } } //! Enqueue items into the search heaps (without filter) template auto single_enqueue_nofilter(const void *feature, size_t feature_index, size_t query_index, const IndexQueryMeta &qmeta, size_t query_count) -> typename std::enable_if::type { ailego_assert(query_count == 1); (void)query_count; owner_->distance_matrix().template distance<1>( feature, &batch_queries_[query_index * qmeta.element_size()], qmeta.dimension(), scores_); result_heaps_[query_index].emplace(0, scores_[0], feature_index); } protected: //! 
Similarity search (1 column without filter) int search_column_nofilter(const void *query, const IndexQueryMeta &qmeta); //! Similarity search (1 column with filter) int search_column_filter(const void *query, const IndexQueryMeta &qmeta); //! Similarity search (1 row without filter) int search_row_nofilter(const void *query, const IndexQueryMeta &qmeta); //! Similarity search (1 row with filter) int search_row_filter(const void *query, const IndexQueryMeta &qmeta); //! Similarity search (n columns without filter) int batch_search_column_nofilter(const void *query, const IndexQueryMeta &qmeta, size_t query_count); //! Similarity search (n columns with filter) int batch_search_column_filter(const void *query, const IndexQueryMeta &qmeta, size_t query_count); //! Similarity search (n rows without filter) int batch_search_row_nofilter(const void *query, const IndexQueryMeta &qmeta, size_t query_count); //! Similarity search (n rows with filter) int batch_search_row_filter(const void *query, const IndexQueryMeta &qmeta, size_t query_count); private: const FlatSearcher *owner_{nullptr}; uint32_t magic_{0}; uint32_t topk_{0}; uint32_t feature_size_{0}; uint32_t actual_read_size_{0}; IndexStorage::Segment::Pointer features_segment_{}; std::vector result_heaps_{1}; std::string batch_queries_{}; float scores_[BATCH_SIZE * BATCH_SIZE]; bool fetch_vector_{false}; // group uint32_t group_num_{0}, group_topk_{0}; std::map group_topk_heaps_{}; std::vector group_results_{}; }; template int FlatSearcherContext::search_column_nofilter( const void *query, const IndexQueryMeta &qmeta) { IndexDocumentHeap *heap = &result_heaps_[0]; heap->clear(); heap->limit(topk_); heap->set_threshold(this->threshold()); size_t left_size = features_segment_->data_size(); size_t block_size = feature_size_ * BATCH_SIZE; size_t read_offset = 0; size_t feature_index = 0; auto matrix = this->owner_->distance_matrix(); while (left_size >= actual_read_size_) { const void *data = nullptr; if 
(features_segment_->read(read_offset, &data, actual_read_size_) != actual_read_size_) { LOG_ERROR("Failed to read data (%u bytes) from features segment", actual_read_size_); return IndexError_ReadData; } for (size_t offset = 0; offset < actual_read_size_; offset += block_size) { matrix.template distance( (const char *)data + offset, query, qmeta.dimension(), scores_); for (size_t i = 0; i != BATCH_SIZE; ++i) { heap->emplace(0, scores_[i], feature_index++); } } read_offset += actual_read_size_; left_size -= actual_read_size_; } const void *data = nullptr; if (features_segment_->read(read_offset, &data, left_size) != left_size) { LOG_ERROR("Failed to read data (%zu bytes) from features segment", left_size); return IndexError_ReadData; } // Process left block features size_t left_size_aligned = left_size / block_size * block_size; for (size_t offset = 0; offset != left_size_aligned; offset += block_size) { matrix.template distance((const char *)data + offset, query, qmeta.dimension(), scores_); for (size_t i = 0; i != BATCH_SIZE; ++i) { heap->emplace(0, scores_[i], feature_index++); } } // Process left single features for (size_t offset = left_size_aligned; offset < left_size; offset += feature_size_) { float score; matrix.template distance<1>((const char *)data + offset, query, qmeta.dimension(), &score); heap->emplace(0, score, feature_index++); } for (auto &it : *heap) { it.set_key(owner_->key(it.index())); } heap->sort(); return 0; } template int FlatSearcherContext::search_column_filter( const void *query, const IndexQueryMeta &qmeta) { IndexDocumentHeap *heap = &result_heaps_[0]; heap->clear(); heap->limit(topk_); heap->set_threshold(this->threshold()); size_t left_size = features_segment_->data_size(); size_t block_size = feature_size_ * BATCH_SIZE; size_t read_offset = 0; size_t feature_index = 0; auto matrix = owner_->distance_matrix(); while (left_size >= actual_read_size_) { const void *data = nullptr; if (features_segment_->read(read_offset, &data, 
actual_read_size_) != actual_read_size_) { LOG_ERROR("Failed to read data (%u bytes) from features segment", actual_read_size_); return IndexError_ReadData; } for (size_t offset = 0; offset < actual_read_size_; offset += block_size) { matrix.template distance( (const char *)data + offset, query, qmeta.dimension(), scores_); for (size_t i = 0; i != BATCH_SIZE; ++i) { uint64_t feature_key = owner_->key(feature_index); if (!this->filter()(feature_key)) { if (group_by_search()) { } heap->emplace(feature_key, scores_[i], feature_index); } feature_index += 1; } } read_offset += actual_read_size_; left_size -= actual_read_size_; } const void *data = nullptr; if (features_segment_->read(read_offset, &data, left_size) != left_size) { LOG_ERROR("Failed to read data (%zu bytes) from features segment", left_size); return IndexError_ReadData; } // Process left block features size_t left_size_aligned = left_size / block_size * block_size; for (size_t offset = 0; offset != left_size_aligned; offset += block_size) { matrix.template distance((const char *)data + offset, query, qmeta.dimension(), scores_); for (size_t i = 0; i != BATCH_SIZE; ++i) { uint64_t feature_key = owner_->key(feature_index); if (!this->filter()(feature_key)) { heap->emplace(feature_key, scores_[i], feature_index); } feature_index += 1; } } // Process left single features for (size_t offset = left_size_aligned; offset < left_size; offset += feature_size_) { uint64_t feature_key = owner_->key(feature_index); if (!this->filter()(feature_key)) { float score; matrix.template distance<1>((const char *)data + offset, query, qmeta.dimension(), &score); heap->emplace(feature_key, score, feature_index); } feature_index += 1; } heap->sort(); return 0; } template int FlatSearcherContext::search_row_nofilter( const void *query, const IndexQueryMeta &qmeta) { IndexDocumentHeap *heap = &result_heaps_[0]; heap->clear(); heap->limit(topk_); heap->set_threshold(this->threshold()); size_t left_size = 
features_segment_->data_size(); size_t read_offset = 0; size_t feature_index = 0; auto matrix = owner_->distance_matrix(); while (left_size >= actual_read_size_) { const void *data = nullptr; if (features_segment_->read(read_offset, &data, actual_read_size_) != actual_read_size_) { LOG_ERROR("Failed to read data (%u bytes) from features segment", actual_read_size_); return IndexError_ReadData; } for (size_t offset = 0; offset < actual_read_size_; offset += feature_size_) { float score; matrix.template distance<1>((const char *)data + offset, query, qmeta.dimension(), &score); heap->emplace(0, score, feature_index++); } read_offset += actual_read_size_; left_size -= actual_read_size_; } const void *data = nullptr; if (features_segment_->read(read_offset, &data, left_size) != left_size) { LOG_ERROR("Failed to read data (%zu bytes) from features segment", left_size); return IndexError_ReadData; } for (size_t offset = 0; offset < left_size; offset += feature_size_) { float score; matrix.template distance<1>((const char *)data + offset, query, qmeta.dimension(), &score); heap->emplace(0, score, feature_index++); } for (auto &it : *heap) { it.set_key(owner_->key(it.index())); } heap->sort(); return 0; } template int FlatSearcherContext::search_row_filter( const void *query, const IndexQueryMeta &qmeta) { IndexDocumentHeap *heap = &result_heaps_[0]; heap->clear(); heap->limit(topk_); heap->set_threshold(this->threshold()); size_t left_size = features_segment_->data_size(); size_t read_offset = 0; size_t feature_index = 0; auto matrix = owner_->distance_matrix(); while (left_size >= actual_read_size_) { const void *data = nullptr; if (features_segment_->read(read_offset, &data, actual_read_size_) != actual_read_size_) { LOG_ERROR("Failed to read data (%u bytes) from features segment", actual_read_size_); return IndexError_ReadData; } for (size_t offset = 0; offset < actual_read_size_; offset += feature_size_) { uint64_t feature_key = owner_->key(feature_index); if 
(!this->filter()(feature_key)) { float score; matrix.template distance<1>((const char *)data + offset, query, qmeta.dimension(), &score); heap->emplace(feature_key, score, feature_index); } feature_index += 1; } read_offset += actual_read_size_; left_size -= actual_read_size_; } const void *data = nullptr; if (features_segment_->read(read_offset, &data, left_size) != left_size) { LOG_ERROR("Failed to read data (%zu bytes) from features segment", left_size); return IndexError_ReadData; } for (size_t offset = 0; offset < left_size; offset += feature_size_) { uint64_t feature_key = owner_->key(feature_index); if (!this->filter()(feature_key)) { float score; matrix.template distance<1>((const char *)data + offset, query, qmeta.dimension(), &score); heap->emplace(feature_key, score, feature_index); } feature_index += 1; } heap->sort(); return 0; } template int FlatSearcherContext::batch_search_column_nofilter( const void *query, const IndexQueryMeta &qmeta, size_t query_count) { // Initialize resources result_heaps_.resize(query_count); for (auto &heap : result_heaps_) { heap.clear(); heap.limit(topk_); heap.set_threshold(this->threshold()); } // Transpose queries batch_queries_.clear(); batch_queries_.reserve(query_count * qmeta.element_size()); TransposeQueries(query, qmeta, query_count, &batch_queries_); size_t left_size = features_segment_->data_size(); size_t block_size = feature_size_ * BATCH_SIZE; size_t read_offset = 0; size_t block_index = 0; // Process feature blocks while (left_size >= actual_read_size_) { const void *data = nullptr; if (features_segment_->read(read_offset, &data, actual_read_size_) != actual_read_size_) { LOG_ERROR("Failed to read data (%u bytes) from features segment", actual_read_size_); return IndexError_ReadData; } for (size_t offset = 0; offset < actual_read_size_; offset += block_size) { this->batch_enqueue_nofilter( (const char *)data + offset, block_index, 0, qmeta, query_count); block_index += BATCH_SIZE; } read_offset += 
actual_read_size_; left_size -= actual_read_size_; } const void *data = nullptr; if (features_segment_->read(read_offset, &data, left_size) != left_size) { LOG_ERROR("Failed to read data (%zu bytes) from features segment", left_size); return IndexError_ReadData; } // Process left block features size_t left_size_aligned = left_size / block_size * block_size; for (size_t offset = 0; offset != left_size_aligned; offset += block_size) { this->batch_enqueue_nofilter( (const char *)data + offset, block_index, 0, qmeta, query_count); block_index += BATCH_SIZE; } // Process left single features for (size_t offset = left_size_aligned; offset < left_size; offset += feature_size_) { this->single_enqueue_nofilter( (const char *)data + offset, block_index, 0, qmeta, query_count); block_index += 1; } // Normalize results for (auto &heap : result_heaps_) { for (auto &it : heap) { it.set_key(owner_->key(it.index())); } heap.sort(); } return 0; } template int FlatSearcherContext::batch_search_column_filter( const void *query, const IndexQueryMeta &qmeta, size_t query_count) { // Initialize resources result_heaps_.resize(query_count); for (auto &heap : result_heaps_) { heap.clear(); heap.limit(topk_); heap.set_threshold(this->threshold()); } // Transpose queries batch_queries_.clear(); batch_queries_.reserve(query_count * qmeta.element_size()); TransposeQueries(query, qmeta, query_count, &batch_queries_); size_t left_size = features_segment_->data_size(); size_t block_size = feature_size_ * BATCH_SIZE; size_t read_offset = 0; size_t block_index = 0; // Process feature blocks while (left_size >= actual_read_size_) { const void *data = nullptr; if (features_segment_->read(read_offset, &data, actual_read_size_) != actual_read_size_) { LOG_ERROR("Failed to read data (%u bytes) from features segment", actual_read_size_); return IndexError_ReadData; } for (size_t offset = 0; offset < actual_read_size_; offset += block_size) { size_t block_mask = 0; for (size_t i = 0; i != BATCH_SIZE; ++i) 
{ if (!this->filter()(this->owner_->key(block_index + i))) { block_mask |= (1 << i); } } if (block_mask != 0) { this->batch_enqueue_filter((const char *)data + offset, block_index, block_mask, 0, qmeta, query_count); } block_index += BATCH_SIZE; } read_offset += actual_read_size_; left_size -= actual_read_size_; } const void *data = nullptr; if (features_segment_->read(read_offset, &data, left_size) != left_size) { LOG_ERROR("Failed to read data (%zu bytes) from features segment", left_size); return IndexError_ReadData; } // Process left block features size_t left_size_aligned = left_size / block_size * block_size; for (size_t offset = 0; offset != left_size_aligned; offset += block_size) { size_t block_mask = 0; for (size_t i = 0; i != BATCH_SIZE; ++i) { if (!this->filter()(this->owner_->key(block_index + i))) { block_mask |= (1 << i); } } if (block_mask != 0) { this->batch_enqueue_filter((const char *)data + offset, block_index, block_mask, 0, qmeta, query_count); } block_index += BATCH_SIZE; } // Process left single features for (size_t offset = left_size_aligned; offset < left_size; offset += feature_size_) { if (!this->filter()(owner_->key(block_index))) { this->single_enqueue_nofilter( (const char *)data + offset, block_index, 0, qmeta, query_count); } block_index += 1; } // Normalize results for (auto &heap : result_heaps_) { for (auto &it : heap) { it.set_key(owner_->key(it.index())); } heap.sort(); } return 0; } template int FlatSearcherContext::batch_search_row_nofilter( const void *query, const IndexQueryMeta &qmeta, size_t query_count) { // Initialize resources result_heaps_.resize(query_count); for (auto &heap : result_heaps_) { heap.clear(); heap.limit(topk_); heap.set_threshold(this->threshold()); } size_t left_size = features_segment_->data_size(); size_t read_offset = 0; size_t feature_index = 0; auto matrix = owner_->distance_matrix(); // Process feature blocks while (left_size >= actual_read_size_) { const void *data = nullptr; if 
(features_segment_->read(read_offset, &data, actual_read_size_) != actual_read_size_) { LOG_ERROR("Failed to read data (%u bytes) from features segment", actual_read_size_); return IndexError_ReadData; } for (size_t offset = 0; offset < actual_read_size_; offset += feature_size_) { size_t query_offset = 0; const void *feature = (const char *)data + offset; for (auto &heap : result_heaps_) { float score; matrix.template distance<1>(feature, (const char *)query + query_offset, qmeta.dimension(), &score); heap.emplace(0, score, feature_index); query_offset += qmeta.element_size(); } feature_index += 1; } read_offset += actual_read_size_; left_size -= actual_read_size_; } const void *data = nullptr; if (features_segment_->read(read_offset, &data, left_size) != left_size) { LOG_ERROR("Failed to read data (%zu bytes) from features segment", left_size); return IndexError_ReadData; } // Process left features for (size_t offset = 0; offset < left_size; offset += feature_size_) { size_t query_offset = 0; const void *feature = (const char *)data + offset; for (auto &heap : result_heaps_) { float score; matrix.template distance<1>(feature, (const char *)query + query_offset, qmeta.dimension(), &score); heap.emplace(0, score, feature_index); query_offset += qmeta.element_size(); } feature_index += 1; } // Normalize results for (auto &heap : result_heaps_) { for (auto &it : heap) { it.set_key(owner_->key(it.index())); } heap.sort(); } return 0; } template int FlatSearcherContext::batch_search_row_filter( const void *query, const IndexQueryMeta &qmeta, size_t query_count) { // Initialize resources result_heaps_.resize(query_count); for (auto &heap : result_heaps_) { heap.clear(); heap.limit(topk_); heap.set_threshold(this->threshold()); } size_t left_size = features_segment_->data_size(); size_t read_offset = 0; size_t feature_index = 0; auto matrix = owner_->distance_matrix(); // Process feature blocks while (left_size >= actual_read_size_) { const void *data = nullptr; if 
(features_segment_->read(read_offset, &data, actual_read_size_) != actual_read_size_) { LOG_ERROR("Failed to read data (%u bytes) from features segment", actual_read_size_); return IndexError_ReadData; } for (size_t offset = 0; offset < actual_read_size_; offset += feature_size_) { uint64_t feature_key = owner_->key(feature_index); if (!this->filter()(feature_key)) { size_t query_offset = 0; const void *feature = (const char *)data + offset; for (auto &heap : result_heaps_) { float score; matrix.template distance<1>(feature, (const char *)query + query_offset, qmeta.dimension(), &score); heap.emplace(feature_key, score, feature_index); query_offset += qmeta.element_size(); } } feature_index += 1; } read_offset += actual_read_size_; left_size -= actual_read_size_; } const void *data = nullptr; if (features_segment_->read(read_offset, &data, left_size) != left_size) { LOG_ERROR("Failed to read data (%zu bytes) from features segment", left_size); return IndexError_ReadData; } // Process left features for (size_t offset = 0; offset < left_size; offset += feature_size_) { uint64_t feature_key = owner_->key(feature_index); if (!this->filter()(feature_key)) { size_t query_offset = 0; const void *feature = (const char *)data + offset; for (auto &heap : result_heaps_) { float score; matrix.template distance<1>(feature, (const char *)query + query_offset, qmeta.dimension(), &score); heap.emplace(feature_key, score, feature_index); query_offset += qmeta.element_size(); } } feature_index += 1; } // Normalize results for (auto &heap : result_heaps_) { heap.sort(); } return 0; } template int FlatSearcherContext::group_by_search_impl( const void *query, const IndexQueryMeta &qmeta, uint32_t count) { this->resize_group_results(count); if (!this->group_by().is_valid()) { LOG_ERROR("Invalid group-by function"); return IndexError_InvalidArgument; } std::function group_by = [&](uint64_t key) { return this->group_by()(key); }; auto provider = owner_->create_provider(); for (size_t q = 
0; q < count; ++q) { this->group_topk_heaps().clear(); for (node_id_t id = 0; id < provider->count(); ++id) { if (!this->filter().is_valid() || !this->filter()(owner_->key(id))) { dist_t dist = 0; owner_->distance_matrix().template distance<1>( query, provider->get_vector(owner_->key(id)), provider->dimension(), &dist); std::string group_id = group_by(owner_->key(id)); auto &topk_heap = this->group_topk_heaps()[group_id]; if (topk_heap.empty()) { topk_heap.limit(this->group_topk()); } topk_heap.emplace(id, dist); } } this->topk_to_group_result(q); query = static_cast(query) + qmeta.element_size(); } return 0; } template int FlatSearcherContext::search_bf_by_p_keys_impl( const void *query, const std::vector> &p_keys, const IndexQueryMeta &qmeta, uint32_t count) { auto provider = owner_->create_provider(); if (this->group_by_search()) { this->resize_group_results(count); if (!this->group_by().is_valid()) { LOG_ERROR("Invalid group-by function"); return IndexError_InvalidArgument; } std::function group_by = [&](uint64_t key) { return this->group_by()(key); }; for (size_t q = 0; q < count; ++q) { this->group_topk_heaps().clear(); for (size_t idx = 0; idx < p_keys[q].size(); ++idx) { uint64_t pk = p_keys[q][idx]; if (!this->filter().is_valid() || !this->filter()(pk)) { dist_t dist = 0; owner_->distance_matrix().template distance<1>( query, provider->get_vector(pk), provider->dimension(), &dist); std::string group_id = group_by(pk); auto &topk_heap = this->group_topk_heaps()[group_id]; if (topk_heap.empty()) { topk_heap.limit(this->group_topk()); } topk_heap.emplace(owner_->get_id(pk), dist); } } this->topk_to_group_result(q); query = static_cast(query) + qmeta.element_size(); } } else { result_heaps_.resize(count); for (auto &heap : result_heaps_) { heap.clear(); heap.limit(topk_); heap.set_threshold(this->threshold()); } for (size_t q = 0; q < count; ++q) { for (size_t idx = 0; idx < p_keys[q].size(); ++idx) { uint64_t pk = p_keys[q][idx]; if 
(!this->filter().is_valid() || !this->filter()(pk)) { dist_t dist = 0; owner_->distance_matrix().template distance<1>( query, provider->get_vector(pk), provider->dimension(), &dist); result_heaps_[q].emplace(pk, dist, owner_->get_id(pk)); } } query = static_cast(query) + qmeta.element_size(); } for (auto &heap : result_heaps_) { heap.sort(); } } return 0; } } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/flat/flat_searcher_provider.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include "flat_distance_matrix.h" #include "flat_searcher.h" // #include "flat_streamer.h" #include "flat_utility.h" namespace zvec { namespace core { /*! Brute Force Searcher Provider */ template class FlatSearcherProvider : public IndexProvider { public: //! Constructor FlatSearcherProvider(const FlatSearcher *owner) { feature_size_ = owner->meta().element_size(); features_segment_ = owner->clone_features_segment(); total_vector_count_ = features_segment_->data_size() / owner->meta().element_size(); owner_ = owner; block_buffer_.resize(BATCH_SIZE * feature_size_); } //! Create a new iterator IndexProvider::Iterator::Pointer create_iterator(void) override { return IndexProvider::Iterator::Pointer( new (std::nothrow) FlatSearcherProvider::Iterator(owner_)); } //! 
Retrieve count of vectors size_t count(void) const override { return total_vector_count_; } //! Retrieve dimension of vector size_t dimension(void) const override { return owner_->meta().dimension(); } //! Retrieve type of vector IndexMeta::DataType data_type(void) const override { return owner_->meta().data_type(); } //! Retrieve vector size in bytes size_t element_size(void) const override { return owner_->meta().element_size(); } //! Retrieve a vector using a primary key const void *get_vector(uint64_t key) const override { return this->get_vector_by_index(owner_->get_id(key)); } //! Retrieve the owner class const std::string &owner_class(void) const override { return owner_->name(); } protected: /*! Brute Force Provider Iterator */ class Iterator : public IndexProvider::Iterator { public: //! Constructor Iterator(const FlatSearcher *owner) { block_buffer_.resize(BATCH_SIZE * owner->meta().element_size()); feature_size_ = owner->meta().element_size(); features_segment_ = owner->clone_features_segment(); total_vector_count_ = features_segment_->data_size() / owner->meta().element_size(); owner_ = owner; cursor_index_ = 0; offset_ = 0; this->next_block(); } //! Retrieve pointer of data //! NOTICE: the vec feature will be changed after iterating to next, so //! the caller need to keep a copy of it before iterator to next vector const void *data(void) const override { return data_; } //! Test if the iterator is valid bool is_valid(void) const override { return (!invalid_ && cursor_index_ < total_vector_count_); } //! Retrieve primary key uint64_t key(void) const override { return owner_->key(cursor_index_); } //! Next iterator void next(void) override { ++cursor_index_; if ((cursor_index_ % BATCH_SIZE) != 0) { data_ += feature_size_; } else { this->next_block(); } } protected: //! 
Read a block of data void next_block(void) { const void *read_data = nullptr; size_t read_size = 0; if (cursor_index_ >= total_vector_count_) { invalid_ = true; return; } if (cursor_index_ + BATCH_SIZE < total_vector_count_) { read_size = BATCH_SIZE * feature_size_; } else { read_size = (total_vector_count_ - cursor_index_) * feature_size_; } if (features_segment_->read(offset_, &read_data, read_size) != read_size) { LOG_ERROR("Failed to read data (%zu bytes) from features segment", read_size); invalid_ = true; return; } offset_ += read_size; // The order of data may be a column format, convert it to the row format. if (owner_->column_major_order() && read_size == BATCH_SIZE * feature_size_) { uint32_t align_size = IndexMeta::AlignSizeof(owner_->meta().data_type()); ReverseTranspose(align_size, read_data, feature_size_ / align_size, &block_buffer_[0]); data_ = block_buffer_.data(); } else { data_ = reinterpret_cast(read_data); } } private: const FlatSearcher *owner_{nullptr}; IndexStorage::Segment::Pointer features_segment_{}; uint32_t total_vector_count_{0}; uint32_t feature_size_{0}; std::vector block_buffer_{}; const uint8_t *data_{nullptr}; uint64_t offset_{0}; uint32_t cursor_index_{0}; bool invalid_{false}; }; //! 
Retrieve a vector via local index const void *get_vector_by_index(uint32_t index) const { const void *read_data = nullptr; if (index == kInvalidNodeId) { LOG_ERROR("Failed to get vector by Invalid Id."); return nullptr; } if (owner_->column_major_order() && index < (total_vector_count_ - (total_vector_count_ % BATCH_SIZE))) { uint32_t block_size = feature_size_ * BATCH_SIZE; uint64_t offset = (index - (index % BATCH_SIZE)) * feature_size_; if (features_segment_->read(offset, &read_data, block_size) != block_size) { LOG_ERROR("Failed to read data (%u bytes) from features segment", block_size); return nullptr; } uint32_t align_size = IndexMeta::AlignSizeof(owner_->meta().data_type()); ReverseTranspose( align_size, read_data, feature_size_ / align_size, &block_buffer_[0]); read_data = block_buffer_.data() + ((index % BATCH_SIZE) * feature_size_); } else { if (features_segment_->read(index * feature_size_, &read_data, feature_size_) != feature_size_) { LOG_ERROR("Failed to read data (%u bytes) from features segment", feature_size_); return nullptr; } } return read_data; } private: //! Members const FlatSearcher *owner_{nullptr}; IndexStorage::Segment::Pointer features_segment_{}; uint32_t feature_size_{0}; uint32_t total_vector_count_{0}; mutable std::vector block_buffer_{}; }; } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/flat/flat_streamer.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and // limitations under the License. #include "flat_streamer.h" #include #include "flat_streamer_context.h" #include "flat_streamer_dumper.h" #include "flat_streamer_provider.h" namespace zvec { namespace core { #define WRITE_LOCK_GUARD(MUTEX, LOCK_NAME) \ ailego::WriteLock write_lock(MUTEX); \ std::unique_lock LOCK_NAME(write_lock); #define READ_LOCK_GUARD_DEFER(MUTEX, LOCK_NAME) \ ailego::ReadLock read_lock(MUTEX); \ std::unique_lock LOCK_NAME(read_lock, std::defer_lock); template FlatStreamer::FlatStreamer() : entity_(stats_) {} template FlatStreamer::~FlatStreamer() { if (state_ == STATE_INITED) { this->cleanup(); } } template int FlatStreamer::init(const IndexMeta &imeta, const ailego::Params ¶ms) { meta_ = imeta; meta_.set_streamer("FlatStreamer", 0U, params); int error_code = InitializeMetric(meta_, &metric_); if (error_code != 0) { LOG_ERROR("Failed to initialize index metric %s, error=%d, %s", meta_.metric_name().c_str(), error_code, IndexError::What(error_code)); return error_code; } if (metric_->query_metric()) { metric_ = metric_->query_metric(); } // 参数设置 if (params.get(PARAM_FLAT_COLUMN_MAJOR_ORDER, &column_major_order_)) { meta_.set_major_order(column_major_order_ ? 
IndexMeta::MO_COLUMN : IndexMeta::MO_ROW); } // Verify column major order if (meta_.major_order() != IndexMeta::MO_ROW) { IndexMeta::DataType ft = meta_.data_type(); bool support_column_major = true; if ((ft != IndexMeta::DT_FP32 && ft != IndexMeta::DT_FP16 && ft != IndexMeta::DT_INT8 && ft != IndexMeta::DT_INT4 && ft != IndexMeta::DT_BINARY32 && ft != IndexMeta::DT_BINARY64) || (meta_.unit_size() != IndexMeta::UnitSizeof(ft))) { if (meta_.major_order() == IndexMeta::MO_COLUMN) { LOG_ERROR("Unsupported type %d with unit size %u.", ft, meta_.unit_size()); return IndexError_Unsupported; } else { support_column_major = false; } } if (meta_.element_size() % IndexMeta::AlignSizeof(ft) != 0) { if (meta_.major_order() == IndexMeta::MO_COLUMN) { LOG_ERROR("Unsupported type %d with dimension %u.", ft, meta_.dimension()); return IndexError_Unsupported; } else { support_column_major = false; } } if (meta_.major_order() == IndexMeta::MO_UNDEFINED && support_column_major) { meta_.set_major_order(IndexMeta::MO_ROW); } } if (!VerifyMetric(meta_)) { LOG_ERROR("Invalid index metric %s.", meta_.metric_name().c_str()); return IndexError_InvalidArgument; } read_block_size_ = FLAT_DEFAULT_READ_BLOCK_SIZE; params.get(PARAM_FLAT_READ_BLOCK_SIZE, &read_block_size_); params.get(PARAM_FLAT_USE_ID_MAP, &use_key_info_map_); // entity init uint32_t block_vector_count = kDefaultBlockVecCount; uint32_t segment_size = kDefaultSegmentSize; bool filter_same_key = true; entity_.set_block_vector_count(block_vector_count); entity_.set_segment_size(segment_size); entity_.enable_filter_same_key(filter_same_key); entity_.set_linear_list_count(1); entity_.set_use_key_info_map(use_key_info_map_); *entity_.mutable_meta() = meta_; state_ = STATE_INITED; return 0; } template int FlatStreamer::cleanup() { if (state_ == STATE_OPENED) { this->close(); } LOG_DEBUG("FlatStreamer cleanup"); state_ = STATE_INIT; return 0; } template int FlatStreamer::open(IndexStorage::Pointer stg) { if (!stg) { LOG_ERROR("Failed to 
open for invalid storage"); return IndexError_InvalidArgument; } if (ailego_unlikely(state_ != STATE_INITED)) { LOG_ERROR("Open storage failed, init streamer first!"); return IndexError_NoReady; } LOG_DEBUG("FlatStreamer open with %s", stg->name().c_str()); int ret = entity_.open(std::move(stg), meta_); if (ailego_unlikely(ret != 0)) { LOG_ERROR("Failed to open storage"); return ret; } magic_ = IndexContext::GenerateMagic(); state_ = STATE_OPENED; return 0; } template int FlatStreamer::close(void) { LOG_DEBUG("FlatStreamer close"); entity_.flush_linear_meta(); stats_.clear(); int ret = entity_.close(); if (ailego_unlikely(ret != 0)) { return ret; } state_ = STATE_INITED; return 0; } template int FlatStreamer::flush(uint64_t checkpoint) { LOG_INFO("FlatStreamer flush with checkpoint %zu", (size_t)checkpoint); return entity_.flush(checkpoint); } template int FlatStreamer::dump(const IndexDumper::Pointer &dumper) { std::string searcher_name = "FlatSearcher"; if (BATCH_SIZE == 16) { searcher_name = "FlatSearcher16"; } meta_.set_searcher(searcher_name, 0U, ailego::Params()); WRITE_LOCK_GUARD(dump_mutex_, dump_lock); std::shared_ptr> bf_dumper = std::make_shared>(this); int ret = bf_dumper->dump(dumper); *(stats_.mutable_dumped_size()) += bf_dumper->dump_size(); return ret; } template IndexStreamer::Context::UPointer FlatStreamer::create_context( void) const { if (state_ != STATE_OPENED) { LOG_ERROR("Failed to create Context, open storage first!"); return Context::UPointer(); } return IndexStreamer::Context::Pointer( new FlatStreamerContext(this)); } template IndexProvider::Pointer FlatStreamer::create_provider(void) const { return IndexProvider::Pointer(new (std::nothrow) FlatStreamerProvider(this)); } template int FlatStreamer::add_impl(uint64_t pkey, const void *query, const IndexQueryMeta &qmeta, Context::UPointer &context) { if (!query || qmeta.dimension() != meta_.dimension() || qmeta.data_type() != meta_.data_type() || qmeta.element_size() != meta_.element_size()) 
{ LOG_ERROR( "Failed to add for invalid arguments, query=%p, qmeta(type=%u " "dim=%u size=%u) vs meta(type=%u dim=%u size=%u)", query, qmeta.data_type(), qmeta.dimension(), qmeta.element_size(), meta_.data_type(), meta_.dimension(), meta_.element_size()); (*stats_.mutable_discarded_count())++; return IndexError_InvalidArgument; } auto *ctx = dynamic_cast *>(context.get()); if (!ctx) { LOG_ERROR("Failed to cast FlatStreamerContext"); (*stats_.mutable_discarded_count())++; return IndexError_Cast; } READ_LOCK_GUARD_DEFER(dump_mutex_, dump_lock); if (!dump_lock.try_lock()) { LOG_ERROR("Cannot add vector while dumping index"); (*stats_.mutable_discarded_count())++; return IndexError_Unsupported; } // IndexQueryMeta iv_qmeta; // int ret = entity_.convert(query, qmeta, &query, &iv_qmeta); // if (ret != 0) { // LOG_ERROR("Failed to convert record for %s", // IndexError::What(ret)); // (*stats_.mutable_discarded_count())++; // return ret; // } int ret = entity_.add(pkey, query, qmeta.element_size()); if (ret != 0) { LOG_ERROR("Failed to add record for %s", IndexError::What(ret)); (*stats_.mutable_discarded_count())++; return ret; } return 0; } template int FlatStreamer::add_with_id_impl(uint32_t id, const void *query, const IndexQueryMeta &qmeta, Context::Pointer &context) { if (!query || qmeta.dimension() != meta_.dimension() || qmeta.data_type() != meta_.data_type() || qmeta.element_size() != meta_.element_size()) { LOG_ERROR( "Failed to add for invalid arguments, query=%p, qmeta(type=%u " "dim=%u size=%u) vs meta(type=%u dim=%u size=%u)", query, qmeta.data_type(), qmeta.dimension(), qmeta.element_size(), meta_.data_type(), meta_.dimension(), meta_.element_size()); (*stats_.mutable_discarded_count())++; return IndexError_InvalidArgument; } auto *ctx = dynamic_cast *>(context.get()); if (!ctx) { LOG_ERROR("Failed to cast FlatStreamerContext"); (*stats_.mutable_discarded_count())++; return IndexError_Cast; } READ_LOCK_GUARD_DEFER(dump_mutex_, dump_lock); if 
(!dump_lock.try_lock()) { LOG_ERROR("Cannot add vector while dumping index"); (*stats_.mutable_discarded_count())++; return IndexError_Unsupported; } int ret = entity_.add_vector_with_id(id, query, qmeta.element_size()); if (ret != 0) { LOG_ERROR("Failed to add record for %s", IndexError::What(ret)); (*stats_.mutable_discarded_count())++; return ret; } return 0; } template int FlatStreamer::search_bf_impl(const void *query, const IndexQueryMeta &qmeta, uint32_t count, Context::Pointer &context) const { ailego_assert(query && count && !!context); ailego_assert(metric_->is_matched(meta_, qmeta)); FlatStreamerContext *bf_context = dynamic_cast *>(context.get()); if (!bf_context) { LOG_ERROR("Invalid brute-force streamer context"); return IndexError_InvalidArgument; } if (bf_context->magic() != magic_) { bf_context->reset(this); } if (bf_context->group_by_search()) { return group_by_search_impl(query, qmeta, count, context); } bf_context->reset_results(count); auto &filter = bf_context->filter(); for (size_t q = 0; q < count; ++q) { auto *heap = bf_context->result_heap(); auto *context_stats = bf_context->mutable_stats(q); uint32_t scan_count = 0; int ret = entity_.search(query, filter, &scan_count, heap, context_stats); if (ailego_unlikely(ret != 0)) { LOG_ERROR("Failed to search for %s", IndexError::What(ret)); return ret; } heap->sort(); bf_context->topk_to_result(q); query = static_cast(query) + qmeta.element_size(); } return 0; } template int FlatStreamer::search_bf_by_p_keys_impl( const void *query, const std::vector> &p_keys, const IndexQueryMeta &qmeta, uint32_t count, Context::Pointer &context) const { ailego_assert(query && count && !!context); ailego_assert(metric_->is_matched(meta_, qmeta)); FlatStreamerContext *bf_context = dynamic_cast *>(context.get()); if (!bf_context) { LOG_ERROR("Invalid brute-force streamer context"); return IndexError_InvalidArgument; } if (bf_context->magic() != magic_) { bf_context->reset(this); } if 
(bf_context->group_by_search()) { return group_by_search_p_keys_impl(query, p_keys, qmeta, count, context); } bf_context->reset_results(count); auto &filter = bf_context->filter(); for (size_t q = 0; q < count; ++q) { auto *heap = bf_context->result_heap(); for (node_id_t idx = 0; idx < p_keys[q].size(); ++idx) { uint64_t key = p_keys[q][idx]; if (!filter.is_valid() || !filter(key)) { dist_t dist = 0; IndexStorage::MemoryBlock block; if (entity_.get_vector_by_key(key, block) != 0) continue; entity_.row_major_distance(query, block.data(), 1, &dist); heap->emplace(key, dist); } } heap->sort(); bf_context->topk_to_result(q); query = static_cast(query) + qmeta.element_size(); } return 0; } template int FlatStreamer::group_by_search_impl( const void *query, const IndexQueryMeta &qmeta, uint32_t count, Context::Pointer &context) const { FlatStreamerContext *bf_context = dynamic_cast *>(context.get()); if (!bf_context) { LOG_ERROR("Invalid brute-force streamer context"); return IndexError_InvalidArgument; } bf_context->resize_group_results(count); if (!bf_context->group_by().is_valid()) { LOG_ERROR("Invalid group-by function"); return IndexError_InvalidArgument; } std::function group_by = [&](uint64_t key) { return bf_context->group_by()(key); }; auto iterator = entity_.creater_iterator(); for (size_t q = 0; q < count; ++q) { bf_context->group_topk_heaps().clear(); for (node_id_t id = 0; id < entity_.vector_count(); ++id) { uint64_t key = entity_.key(id); if (!bf_context->filter().is_valid() || !bf_context->filter()(key)) { dist_t dist = 0; IndexStorage::MemoryBlock block; if (entity_.get_vector_by_key(key, block) != 0) continue; entity_.row_major_distance(query, block.data(), 1, &dist); std::string group_id = group_by(key); auto &topk_heap = bf_context->group_topk_heaps()[group_id]; if (topk_heap.empty()) { topk_heap.limit(bf_context->group_topk()); } topk_heap.emplace(key, dist); } } bf_context->topk_to_group_result(q); query = static_cast(query) + qmeta.element_size(); 
} return 0; } template int FlatStreamer::group_by_search_p_keys_impl( const void *query, const std::vector> &p_keys, const IndexQueryMeta &qmeta, uint32_t count, Context::Pointer &context) const { FlatStreamerContext *bf_context = dynamic_cast *>(context.get()); if (!bf_context) { LOG_ERROR("Invalid brute-force streamer context"); return IndexError_InvalidArgument; } bf_context->resize_group_results(count); if (!bf_context->group_by().is_valid()) { LOG_ERROR("Invalid group-by function"); return IndexError_InvalidArgument; } std::function group_by = [&](uint64_t key) { return bf_context->group_by()(key); }; auto iterator = entity_.creater_iterator(); for (size_t q = 0; q < count; ++q) { bf_context->group_topk_heaps().clear(); for (node_id_t idx = 0; idx < p_keys[q].size(); ++idx) { uint64_t key = p_keys[q][idx]; if (!bf_context->filter().is_valid() || !bf_context->filter()(key)) { dist_t dist = 0; IndexStorage::MemoryBlock block; if (entity_.get_vector_by_key(key, block) != 0) continue; entity_.row_major_distance(query, block.data(), 1, &dist); std::string group_id = group_by(key); auto &topk_heap = bf_context->group_topk_heaps()[group_id]; if (topk_heap.empty()) { topk_heap.limit(bf_context->group_topk()); } topk_heap.emplace(key, dist); } } bf_context->topk_to_group_result(q); query = static_cast(query) + qmeta.element_size(); } return 0; } INDEX_FACTORY_REGISTER_STREAMER_ALIAS(LinearStreamer, FlatStreamer<32>); INDEX_FACTORY_REGISTER_STREAMER_ALIAS(FlatStreamer, FlatStreamer<32>); INDEX_FACTORY_REGISTER_STREAMER_ALIAS(FlatStreamer16, FlatStreamer<16>); INDEX_FACTORY_REGISTER_STREAMER_ALIAS(FlatStreamer32, FlatStreamer<32>); } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/flat/flat_streamer.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in 
compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include
#include
#include "flat_streamer_entity.h"
#include "flat_utility.h"

namespace zvec {
namespace core {

/*! Flat Streamer */
//! Mutable brute-force index: vectors are added into a FlatStreamerEntity
//! and every search_impl overload forwards to the brute-force search path.
//! BATCH_SIZE (registered as 16 or 32 in flat_streamer.cc) also selects
//! the searcher name used when dumping.
template class FlatStreamer : public IndexStreamer {
 public:
  using ContextPointer = IndexStreamer::Context::UPointer;

  //! Constructor; the entity is constructed from this streamer's stats.
  FlatStreamer(void);
  //! Destructor
  virtual ~FlatStreamer(void);
  //! Non-copyable
  FlatStreamer(const FlatStreamer &streamer) = delete;
  FlatStreamer &operator=(const FlatStreamer &streamer) = delete;

 public:
  //! Initialize Streamer from the index meta and params (reads
  //! PARAM_FLAT_COLUMN_MAJOR_ORDER, PARAM_FLAT_READ_BLOCK_SIZE and
  //! PARAM_FLAT_USE_ID_MAP)
  int init(const IndexMeta &, const ailego::Params &) override;

  //! Cleanup Streamer (closes it first when opened)
  int cleanup(void) override;

  //! Create a context; an empty pointer unless the streamer is opened
  IndexStreamer::Context::UPointer create_context(void) const override;

  //! Create a new provider over the stored vectors
  IndexProvider::Pointer create_provider(void) const override;

  //! Add a vector into index under primary key `pkey`; rejected (and counted
  //! as discarded) on a meta mismatch or while a dump is running
  int add_impl(uint64_t pkey, const void *query, const IndexQueryMeta &qmeta,
               Context::UPointer &context) override;

  //! Add a vector into index under local id `id`; same checks as add_impl
  int add_with_id_impl(uint32_t id, const void *query,
                       const IndexQueryMeta &qmeta,
                       Context::Pointer &context) override;

  //! Similarity search (single query; forwards to search_bf_impl)
  int search_impl(const void *query, const IndexQueryMeta &qmeta,
                  Context::UPointer &context) const override {
    return search_bf_impl(query, qmeta, 1, context);
  }

  //! Similarity search (`count` queries; forwards to search_bf_impl)
  int search_impl(const void *query, const IndexQueryMeta &qmeta,
                  uint32_t count, Context::UPointer &context) const override {
    return search_bf_impl(query, qmeta, count, context);
  }

  //! Similarity brute force search (single query)
  int search_bf_impl(const void *query, const IndexQueryMeta &qmeta,
                     Context::UPointer &context) const override {
    return search_bf_impl(query, qmeta, 1, context);
  }

  //! Similarity brute force search (`count` queries stored back to back)
  int search_bf_impl(const void *query, const IndexQueryMeta &qmeta,
                     uint32_t count,
                     Context::UPointer &context) const override;

  //! Linear search by primary keys (single query)
  int search_bf_by_p_keys_impl(const void *query,
                               const std::vector> &p_keys,
                               const IndexQueryMeta &qmeta,
                               Context::UPointer &context) const override {
    return search_bf_by_p_keys_impl(query, p_keys, qmeta, 1, context);
  }

  //! Linear search by primary keys; p_keys[q] holds the candidates of query q
  int search_bf_by_p_keys_impl(const void *query,
                               const std::vector> &p_keys,
                               const IndexQueryMeta &qmeta, uint32_t count,
                               Context::UPointer &context) const override;

  //! Group-by brute force search over all stored vectors
  int group_by_search_impl(const void *query, const IndexQueryMeta &qmeta,
                           uint32_t count, Context::UPointer &context) const;

  //! Group-by search restricted to the given primary keys
  int group_by_search_p_keys_impl(
      const void *query, const std::vector> &p_keys,
      const IndexQueryMeta &qmeta, uint32_t count,
      Context::Pointer &context) const;

  //! Open index from file path
  int open(IndexStorage::Pointer stg) override;

  //! Close file
  int close(void) override;

  //! flush file
  int flush(uint64_t checkpoint) override;

  //! Dump index into storage
  int dump(const IndexDumper::Pointer &dumper) override;

  //! Retrieve statistics
  const Stats &stats(void) const override {
    return stats_;
  }

  //! Retrieve meta of index
  const IndexMeta &meta(void) const override {
    return meta_;
  }

  //! Retrieve the underlying storage entity
  const FlatStreamerEntity &entity(void) const {
    return entity_;
  }

  //! Retrieve a vector by primary key
  virtual const void *get_vector(uint64_t key) const override {
    return this->get_vector_by_key(key);
  }

  //! Retrieve a vector by primary key into `block`
  virtual int get_vector(const uint64_t key,
                         IndexStorage::MemoryBlock &block) const override {
    return this->get_vector_by_key(key, block);
  }

  const void *get_vector_by_key(uint64_t key) const {
    return entity_.get_vector_by_key(key);
  }

  int get_vector_by_key(const uint64_t key,
                        IndexStorage::MemoryBlock &block) const override {
    return entity_.get_vector_by_key(key, block);
  }

  // NOTE(review): both get_vector_by_id overloads pass the id unchanged to
  // get_vector_by_key(), i.e. they treat a local id as a primary key.
  // Confirm ids and keys coincide for this streamer.
  const void *get_vector_by_id(uint32_t id) const override {
    return get_vector_by_key(id);
  }

  int get_vector_by_id(const uint32_t id,
                       IndexStorage::MemoryBlock &block) const override {
    return get_vector_by_key(id, block);
  }

  //! Magic number regenerated on every successful open()
  uint32_t magic(void) const {
    return magic_;
  }

  //! Retrieve block size of data read
  uint32_t read_block_size(void) const {
    return read_block_size_;
  }

 private:
  //! Constants
  static constexpr uint32_t kDefaultBlockVecCount = 32u;
  static constexpr uint32_t kDefaultSegmentSize = 4 * 1024 * 1024u;
  // NOTE(review): not referenced by flat_streamer.cc or this header.
  static constexpr float kDefaultDocsSoftLimitRatio = 0.9f;

  //! Lifecycle: init() -> STATE_INITED, open() -> STATE_OPENED,
  //! close() -> STATE_INITED, cleanup() -> STATE_INIT
  enum State { STATE_INIT = 0, STATE_INITED = 1, STATE_OPENED = 2 };

  //! Members
  uint32_t magic_{0};
  // NOTE(review): docs_hard_limit_, docs_soft_limit_, data_ and
  // mapping_mutex_ are not referenced by flat_streamer.cc or this header.
  uint32_t docs_hard_limit_{std::numeric_limits::max()};
  uint32_t docs_soft_limit_{0};
  IndexMeta meta_{};
  std::vector> data_;
  IndexStreamer::Stats stats_{};
  IndexMetric::Pointer metric_{};
  State state_{STATE_INIT};
  mutable std::mutex mapping_mutex_{};
  // Held shared (try_lock) by the add paths and exclusively by dump().
  ailego::SharedMutex dump_mutex_{};
  // Constructed from stats_ (declared earlier) by the constructor.
  FlatStreamerEntity entity_;
  bool column_major_order_{false};
  bool use_key_info_map_{true};
  uint32_t read_block_size_{0};
};

}  // namespace core
}  // namespace zvec
================================================
FILE: src/core/algorithm/flat/flat_streamer_context.h
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "flat_streamer.h"

namespace zvec {
namespace core {

/*! Brute Force Streamer Context */
template class FlatStreamerContext : public IndexStreamer::Context {
 public:
  //! Constructor
  FlatStreamerContext(const FlatStreamer *owner) {
    this->reset(owner);
  }

  //! Destructor
  virtual ~FlatStreamerContext(void) = default;

  //! Set topk of search result
  void set_topk(uint32_t topk) override {
    topk_ = topk;
    result_heap_.limit(topk);
  }

  //! Retrieve search result
  const IndexDocumentList &result(void) const override {
    return results_[0];
  }

  //! Retrieve search result with index
  const IndexDocumentList &result(size_t idx) const override {
    return results_[idx];
  }

  //!
Retrieve result object for output IndexDocumentList *mutable_result(size_t idx) override { ailego_assert_with(idx < results_.size(), "invalid idx"); return &results_[idx]; } inline IndexDocumentHeap *result_heap() { return &result_heap_; } //! Retrieve search group result with index virtual const IndexGroupDocumentList &group_result(void) const override { return group_results_[0]; } //! Retrieve search group result with index virtual const IndexGroupDocumentList &group_result( size_t idx) const override { return group_results_[idx]; } //! Update the parameters of context int update(const ailego::Params & /*params*/) override { return 0; } //! Retrieve magic number uint32_t magic(void) const override { return magic_; } //! Get group topk inline uint32_t group_topk() const { return group_topk_; } //! Get group num inline uint32_t group_num() const { return group_num_; } inline std::map &group_topk_heaps() { return group_topk_heaps_; } void set_fetch_vector(bool v) override { fetch_vector_ = v; } bool fetch_vector() const override { return fetch_vector_; } inline void resize_group_results(size_t size) { if (group_by_search()) { group_results_.resize(size); } } void topk_to_result(uint32_t idx) { if (ailego_unlikely(result_heap_.size() == 0)) { return; } ailego_assert_with(idx < results_.size(), "invalid idx"); int size = std::min(topk_, static_cast(result_heap_.size())); result_heap_.sort(); results_[idx].clear(); for (int i = 0; i < size; ++i) { auto score = result_heap_[i].score(); if (score > this->threshold()) { break; } key_t key = result_heap_[i].key(); if (fetch_vector_) { IndexStorage::MemoryBlock block; owner_->entity().get_vector_by_key(key, block); results_[idx].emplace_back(key, score, key, block); } else { results_[idx].emplace_back(key, score, key); } } } void topk_to_group_result(uint32_t idx) { ailego_assert_with(idx < group_results_.size(), "invalid idx"); group_results_[idx].clear(); std::vector> group_topk_list; std::vector> best_score_in_groups; 
for (auto itr = group_topk_heaps_.begin(); itr != group_topk_heaps_.end(); itr++) { const std::string &group_id = (*itr).first; auto &heap = (*itr).second; heap.sort(); if (heap.size() > 0) { float best_score = heap[0].second; best_score_in_groups.push_back(std::make_pair(group_id, best_score)); } } std::sort(best_score_in_groups.begin(), best_score_in_groups.end(), [](const std::pair &a, const std::pair &b) -> int { return a.second < b.second; }); // truncate to group num for (uint32_t i = 0; i < group_num() && i < best_score_in_groups.size(); ++i) { const std::string &group_id = best_score_in_groups[i].first; group_topk_list.emplace_back( std::make_pair(group_id, group_topk_heaps_[group_id])); } group_results_[idx].resize(group_topk_list.size()); for (uint32_t i = 0; i < group_topk_list.size(); ++i) { const std::string &group_id = group_topk_list[i].first; group_results_[idx][i].set_group_id(group_id); uint32_t size = std::min( group_topk_, static_cast(group_topk_list[i].second.size())); for (uint32_t j = 0; j < size; ++j) { auto score = group_topk_list[i].second[j].second; if (score > this->threshold()) { break; } node_id_t id = group_topk_list[i].second[j].first; auto provider = owner_->create_provider(); if (fetch_vector_) { IndexStorage::MemoryBlock block; provider->get_vector(id, block); group_results_[idx][i].mutable_docs()->emplace_back(id, score, id, block); } else { group_results_[idx][i].mutable_docs()->emplace_back(id, score, id); } } } } //! Get if group by search bool group_by_search() { return group_num_ > 0; } //! Set group params void set_group_params(uint32_t group_num, uint32_t group_topk) override { group_num_ = group_num; group_topk_ = group_topk; group_topk_heaps_.clear(); } void reset() override { for (auto &it : results_) { it.clear(); } for (auto &it : group_results_) { it.clear(); } } //! 
Reset the context void reset(const FlatStreamer *owner) { this->reset(); magic_ = owner->magic(); feature_size_ = owner->meta().element_size(); uint32_t block_size = feature_size_ * BATCH_SIZE; actual_read_size_ = (owner->read_block_size() + block_size - 1) / block_size * block_size; owner_ = owner; } //! Reset all the query results void reset_results(size_t qnum) { results_.resize(qnum); stats_vec_.resize(qnum); for (size_t i = 0; i < qnum; ++i) { results_[i].clear(); stats_vec_[i].clear(); } result_heap_.clear(); result_heap_.limit(topk_); result_heap_.set_threshold(this->threshold()); } Stats *mutable_stats(size_t idx = 0) { ailego_assert_with(stats_vec_.size() > idx, "invalid index"); return &stats_vec_[idx]; } private: const FlatStreamer *owner_{nullptr}; std::vector stats_vec_{}; uint32_t magic_{0}; uint32_t topk_{0}; uint32_t feature_size_{0}; uint32_t actual_read_size_{0}; IndexDocumentHeap result_heap_; std::vector results_{}; std::string batch_queries_{}; float scores_[BATCH_SIZE * BATCH_SIZE]; bool fetch_vector_{false}; // group uint32_t group_num_{0}; uint32_t group_topk_{0}; std::map group_topk_heaps_{}; std::vector group_results_{}; }; } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/flat/flat_streamer_dumper.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#pragma once #include "flat_streamer.h" #include "flat_utility.h" namespace zvec { namespace core { template class FlatStreamerDumper { public: typedef std::unique_ptr Pointer; FlatStreamerDumper(const FlatStreamer *owner) { owner_ = owner; dump_size_ = 0; } int dump(const IndexDumper::Pointer &dumper) { ailego::ElapsedTime stamp; std::vector keys; if (owner_->meta().major_order() == IndexMeta::MO_COLUMN) { int error_code = this->write_column_index(dumper.get(), &keys); if (error_code != 0) { return error_code; } } else { int error_code = this->write_row_index(dumper.get(), &keys); if (error_code != 0) { return error_code; } } int error_code = this->write_keys(keys, dumper.get()); if (error_code != 0) { return error_code; } error_code = this->write_mapping(keys, dumper.get()); if (error_code != 0) { return error_code; } error_code = IndexHelper::SerializeToDumper(owner_->meta(), dumper.get()); if (error_code != 0) { return error_code; } LOG_DEBUG("dumped_count: %zu, costtime: %zu", keys.size(), (size_t)stamp.milli_seconds()); return 0; } size_t dump_size() { return dump_size_; } private: int write_column_index(IndexDumper *dumper, std::vector *keys) { switch (IndexMeta::AlignSizeof(owner_->meta().data_type())) { case 2: return this->write_column_index(dumper, keys); case 4: return this->write_column_index(dumper, keys); case 8: return this->write_column_index(dumper, keys); default: ailego_check_with(0, "BAD CASE"); } return IndexError_Runtime; } template int write_column_index(IndexDumper *dumper, std::vector *keys) { auto iter = owner_->entity().creater_iterator(); if (!iter) { LOG_ERROR("Failed to create iterator"); return IndexError_Runtime; } // Write features size_t element_size = owner_->meta().element_size(); size_t block_size = element_size * BATCH_SIZE; std::string block1, block2; block1.reserve(block_size); block2.reserve(block_size); for (; iter->is_valid(); iter->next()) { block1.append(reinterpret_cast(iter->data()), element_size); 
keys->emplace_back(iter->key()); if (block1.size() == block_size) { ailego::MatrixHelper::Transpose( block1.data(), element_size / sizeof(T), (void *)block2.data()); if (dumper->write(block2.data(), block_size) != block_size) { LOG_ERROR("Failed to write data into dumper %s", dumper->name().c_str()); return IndexError_WriteData; } block1.clear(); dump_size_ += block_size; } } if (!block1.empty()) { if (dumper->write(block1.data(), block1.size()) != block1.size()) { LOG_ERROR("Failed to write data into dumper %s", dumper->name().c_str()); return IndexError_WriteData; } dump_size_ += block1.size(); } // Write the padding if need size_t features_size = keys->size() * element_size; size_t features_padding_size = ailego_align(features_size, 32) - features_size; if (features_padding_size) { std::string padding(features_padding_size, '\0'); if (dumper->write(padding.data(), padding.size()) != padding.size()) { LOG_ERROR("Failed to write data into dumper %s", dumper->name().c_str()); return IndexError_WriteData; } dump_size_ += padding.size(); } return dumper->append(FLAT_SEGMENT_FEATURES_SEG_ID, features_size, features_padding_size, 0); } int write_row_index(IndexDumper *dumper, std::vector *keys) { auto iter = owner_->entity().creater_iterator(); if (!iter) { LOG_ERROR("Failed to create iterator"); return IndexError_Runtime; } // Write features size_t element_size = owner_->meta().element_size(); for (; iter->is_valid(); iter->next()) { if (dumper->write(iter->data(), element_size) != element_size) { LOG_ERROR("Failed to write data into dumper %s", dumper->name().c_str()); return IndexError_WriteData; } dump_size_ += element_size; keys->emplace_back(iter->key()); } // Write the padding if need size_t features_size = keys->size() * element_size; size_t features_padding_size = ailego_align(features_size, 32) - features_size; if (features_padding_size) { std::string padding(features_padding_size, '\0'); if (dumper->write(padding.data(), padding.size()) != padding.size()) { 
LOG_ERROR("Failed to write data into dumper %s", dumper->name().c_str()); return IndexError_WriteData; } dump_size_ += padding.size(); } return dumper->append(FLAT_SEGMENT_FEATURES_SEG_ID, features_size, features_padding_size, 0); } int write_keys(const std::vector &keys, IndexDumper *dumper) { size_t keys_size = keys.size() * sizeof(uint64_t); size_t keys_padding_size = ailego_align(keys_size, 32) - keys_size; if (dumper->write(keys.data(), keys_size) != keys_size) { LOG_ERROR("Failed to write data into dumper %s", dumper->name().c_str()); return IndexError_WriteData; } dump_size_ += keys_size; // Write the padding if need if (keys_padding_size) { std::string padding(keys_padding_size, '\0'); if (dumper->write(padding.data(), padding.size()) != padding.size()) { LOG_ERROR("Failed to write data into dumper %s", dumper->name().c_str()); return IndexError_WriteData; } dump_size_ += padding.size(); } return dumper->append(FLAT_SEGMENT_KEYS_SEG_ID, keys_size, keys_padding_size, 0); } int write_mapping(const std::vector &keys, IndexDumper *dumper) { std::vector mapping(keys.size()); std::iota(mapping.begin(), mapping.end(), 0); std::sort(mapping.begin(), mapping.end(), [&keys](uint32_t lhs, uint32_t rhs) { return (keys[lhs] < keys[rhs]); }); size_t mapping_size = mapping.size() * sizeof(uint32_t); size_t mapping_padding_size = ailego_align(mapping_size, 32) - mapping_size; if (dumper->write(mapping.data(), mapping_size) != mapping_size) { LOG_ERROR("Failed to write data into dumper %s", dumper->name().c_str()); return IndexError_WriteData; } dump_size_ += mapping_size; // Write the padding if need if (mapping_padding_size) { std::string padding(mapping_padding_size, '\0'); if (dumper->write(padding.data(), padding.size()) != padding.size()) { LOG_ERROR("Failed to write data into dumper %s", dumper->name().c_str()); return IndexError_WriteData; } dump_size_ += padding.size(); } return dumper->append(FLAT_SEGMENT_MAPPING_SEG_ID, mapping_size, mapping_padding_size, 0); } 
private: const FlatStreamer *owner_{nullptr}; size_t dump_size_{0}; }; } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/flat/flat_streamer_entity.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "flat_streamer_entity.h" #include #include #include "flat_utility.h" namespace zvec { namespace core { FlatStreamerEntity::FlatStreamerEntity(IndexStreamer::Stats &stats) : stats_(stats) {} int FlatStreamerEntity::open(IndexStorage::Pointer storage, const IndexMeta & /*mt*/) { if (storage_) { LOG_ERROR("An storage instance is already opened"); return IndexError_Duplicate; } // segments_[0] store the meta information of the linear list ailego_assert_with(segments_.size() == 0, "Invalid Size"); key_info_map_lock_ = std::make_shared(); key_info_map_.clear(); id_key_vector_.clear(); withid_key_info_map_.clear(); withid_key_map_.clear(); vec_unit_size_ = IndexMeta::AlignSizeof(index_meta_.data_type()); vec_cols_ = index_meta_.element_size() / vec_unit_size_; meta_.header.block_size = ailego_align(sizeof(BlockHeader) + sizeof(DeletionMap) + (index_meta_.element_size() + sizeof(uint64_t)) * meta_.header.block_vector_count, 32); if (storage->get(FLAT_LINEAR_LIST_HEAD_SEG_ID) || storage->get(FLAT_LINEAR_META_SEG_ID)) { int ret = this->load_storage(storage); if (ailego_unlikely(ret != 0)) { LOG_ERROR("Failed to load 
storage index"); return ret; } } else { int ret = this->init_storage(storage); if (ailego_unlikely(ret != 0)) { LOG_ERROR("Failed to init storage"); return ret; } } storage_ = storage; //! Create the distance calculator auto metric = IndexFactory::CreateMetric(index_meta_.metric_name()); if (!metric) { LOG_ERROR("Failed to create metric %s", index_meta_.metric_name().c_str()); return IndexError_NoExist; } int ret = metric->init(index_meta_, index_meta_.metric_params()); if (ret != 0) { LOG_ERROR("Failed to initialize metric %s", index_meta_.metric_name().c_str()); return ret; } row_distance_ = metric->distance(); column_distance_ = metric->distance_matrix(meta_.header.block_vector_count, 1); LOG_DEBUG("Open storage %s done, metric=%s", storage_->name().c_str(), index_meta_.metric_name().c_str()); return 0; } int FlatStreamerEntity::close(void) { segments_.clear(); storage_.reset(); key_info_map_lock_.reset(); key_info_map_.clear(); withid_key_info_map_.clear(); withid_key_map_.clear(); id_key_vector_.clear(); meta_.create_time = 0; meta_.update_time = 0; meta_.segment_count = 0; meta_.header.total_vector_count = 0; meta_.header.block_count = 0; meta_.header.block_size = 0; meta_.header.linear_body_size = 0; return 0; } int FlatStreamerEntity::flush_linear_meta(void) { if (!storage_) { return 0; } meta_.update_time = ailego::Realtime::Seconds(); meta_.revision_id = stats_.revision_id(); stats_.set_update_time(meta_.update_time); auto segment = storage_->get(FLAT_LINEAR_META_SEG_ID); if (ailego_unlikely(!segment)) { LOG_ERROR("Failed to get segment %s", FLAT_LINEAR_META_SEG_ID.c_str()); return IndexError_Runtime; } if (segment->write(0, &meta_, sizeof(meta_)) != sizeof(meta_)) { LOG_ERROR("Failed to write segment %s", FLAT_LINEAR_META_SEG_ID.c_str()); return IndexError_WriteData; } return 0; } int FlatStreamerEntity::flush(uint64_t checkpoint) { int ret = this->flush_linear_meta(); if (ailego_unlikely(ret != 0)) { return ret; } if (checkpoint != 0) { 
storage_->refresh(checkpoint); } ret = storage_->flush(); if (ailego_unlikely(ret != 0)) { LOG_ERROR("Failed to refresh storage for %s", IndexError::What(ret)); return ret; } if (checkpoint != 0) { stats_.set_check_point(checkpoint); } return 0; } int FlatStreamerEntity::add(uint64_t key, const void *vec, size_t size) { std::lock_guard lock(mutex_); if (filter_same_key_) { key_info_map_lock_->lock_shared(); if (key_info_map_.find(key) != key_info_map_.end()) { key_info_map_lock_->unlock_shared(); LOG_WARN("Try to add duplicate key, drop it"); return IndexError_Duplicate; } key_info_map_lock_->unlock_shared(); } if (size != static_cast(index_meta_.element_size())) { LOG_ERROR("Failed to add, mismatch size %zu vs elemsize %u", size, index_meta_.element_size()); return IndexError_Mismatch; } IndexStorage::MemoryBlock head_block; this->get_head_block(head_block); const BlockLocation *bl = reinterpret_cast(head_block.data()); if (ailego_unlikely(bl == nullptr)) { LOG_ERROR("Failed to get block loc"); return IndexError_ReadData; } BlockLocation block = *bl; if (!this->is_valid_block(block)) { int ret = this->alloc_block(block, &block); if (ailego_unlikely(ret != 0)) { return ret; } ret = this->update_head_block(block); if (ailego_unlikely(ret != 0)) { return ret; } } int ret = this->add_to_block(block, key, vec, size); if (ret == IndexError_IndexFull) { ret = this->alloc_block(block, &block); if (ailego_unlikely(ret != 0)) { return ret; } ret = this->update_head_block(block); if (ailego_unlikely(ret != 0)) { return ret; } ret = this->add_to_block(block, key, vec, size); if (ailego_unlikely(ret != 0)) { return ret; } } if (ailego_unlikely(ret != 0)) { return ret; } (*stats_.mutable_added_count())++; stats_.set_revision_id(meta_.revision_id + 1); return 0; } int FlatStreamerEntity::search(const void *query, const IndexFilter &filter, uint32_t *scan_count, IndexDocumentHeap *heap, IndexContext::Stats *context_stats) const { IndexStorage::MemoryBlock head_block; 
this->get_head_block(head_block); const BlockLocation *bl = reinterpret_cast(head_block.data()); if (ailego_unlikely(bl == nullptr)) { LOG_ERROR("Failed to get block loc"); return IndexError_ReadData; } BlockLocation block = *bl; while (this->is_valid_block(block)) { IndexStorage::MemoryBlock block_header_block; this->get_block_header(block, block_header_block); const BlockHeader *hd = reinterpret_cast(block_header_block.data()); if (ailego_unlikely(hd == nullptr)) { LOG_ERROR("Failed to get block header"); return IndexError_ReadData; } if (hd->vector_count > 0) { *scan_count += hd->vector_count; IndexStorage::MemoryBlock deletion_map_block; this->get_block_deletion_map(block, deletion_map_block); const DeletionMap *deletion_map = reinterpret_cast(deletion_map_block.data()); if (filter.is_valid() || deletion_map->is_dirty()) { this->search_block(query, block, hd, 1.0, filter, deletion_map, heap, context_stats); } else { *(context_stats->mutable_dist_calced_count()) += hd->vector_count; this->search_block(query, block, hd, 1.0, heap); } } block = hd->next; } return 0; } //! Search in a block void FlatStreamerEntity::search_block(const void *query, const BlockLocation &bl, const BlockHeader *hd, float norm_val, IndexDocumentHeap *heap) const { std::vector distances(block_vector_count()); IndexStorage::MemoryBlock vecs_block; this->get_block_vectors(bl, vecs_block); const char *vecs = reinterpret_cast(vecs_block.data()); IndexStorage::MemoryBlock keys_block; this->get_block_keys(bl, keys_block); const uint64_t *keys = reinterpret_cast(keys_block.data()); row_major_distance(query, vecs, hd->vector_count, distances.data()); for (size_t k = 0; k < hd->vector_count; ++k) { if (keys[k] != kInvalidKey) { heap->emplace(keys[k], distances[k] * norm_val); } } } //! 
Search in a block with filter void FlatStreamerEntity::search_block( const void *query, const BlockLocation &bl, const BlockHeader *hd, float norm_val, const IndexFilter &filter, const DeletionMap *deletion_map, IndexDocumentHeap *heap, IndexContext::Stats *context_stats) const { std::vector distances(block_vector_count()); IndexStorage::MemoryBlock vecs_block; this->get_block_vectors(bl, vecs_block); const char *vecs = reinterpret_cast(vecs_block.data()); IndexStorage::MemoryBlock keys_block; this->get_block_keys(bl, keys_block); const uint64_t *keys = reinterpret_cast(keys_block.data()); DeletionMap keeps; for (size_t k = 0; k < hd->vector_count; ++k) { const bool condition1 = !deletion_map->test(k); const bool condition2 = filter.is_valid() ? !filter(keys[k]) : true; const bool condition3 = keys[k] != kInvalidKey; if (condition1 && condition2 && condition3) { keeps.set(k); } } if (!keeps.is_dirty()) { (*context_stats->mutable_filtered_count()) += hd->vector_count; return; } for (size_t k = 0; k < hd->vector_count; ++k) { if (keeps.test(k)) { auto cur_vec = vecs + index_meta_.element_size() * k; row_major_distance(query, cur_vec, 1, distances.data() + k); ++(*context_stats->mutable_dist_calced_count()); } } for (size_t k = 0; k < hd->vector_count; ++k) { if (keeps.test(k)) { heap->emplace(keys[k], distances[k] * norm_val); } else { ++(*context_stats->mutable_filtered_count()); } } } int FlatStreamerEntity::search_bf(const void *query, const IndexFilter &filter, IndexDocumentHeap *heap, IndexContext::Stats *context_stats) const { uint32_t scan_count; return this->search(query, filter, &scan_count, heap, context_stats); } FlatStreamerEntity::Pointer FlatStreamerEntity::clone(void) const { std::vector segments; segments.reserve(segments_.size()); for (size_t i = 0; i < segments_.size(); ++i) { segments.emplace_back(segments_[i]->clone()); if (!segments[i]) { LOG_ERROR("Failed to clone segment, index=%zu", i); return nullptr; } } auto entity = new (std::nothrow) 
FlatStreamerEntity(stats_); if (!entity) { LOG_ERROR("Failed to New FlatStreamerEntity object"); return nullptr; } entity->index_meta_ = this->index_meta_; entity->storage_ = this->storage_; // entity->reformer_ = this->reformer_; entity->segments_ = segments; entity->meta_ = this->meta_; entity->key_info_map_lock_ = this->key_info_map_lock_; entity->key_info_map_ = this->key_info_map_; entity->id_key_vector_ = this->id_key_vector_; entity->withid_key_info_map_ = this->withid_key_info_map_; entity->withid_key_map_ = this->withid_key_map_; entity->filter_same_key_ = this->filter_same_key_; entity->vec_unit_size_ = this->vec_unit_size_; entity->vec_cols_ = this->vec_cols_; return FlatStreamerEntity::Pointer(entity); } const void *FlatStreamerEntity::get_vector_by_key(uint64_t key) const { VectorLocation loc{}; key_info_map_lock_->lock_shared(); if (use_key_info_map_) { auto iterator = key_info_map_.find(key); if (iterator == key_info_map_.end()) { key_info_map_lock_->unlock_shared(); return nullptr; } loc = iterator->second; } else { if (key < withid_key_info_map_.size()) { loc = withid_key_info_map_[key]; } else { key_info_map_lock_->unlock_shared(); return nullptr; } } key_info_map_lock_->unlock_shared(); auto segment = this->get_segment(loc.segment_id); const void *data = nullptr; if (segment->read(loc.offset, &data, index_meta_.element_size()) != index_meta_.element_size()) { LOG_ERROR("Failed to read segment, size=%u", index_meta_.element_size()); return nullptr; } return data; } int FlatStreamerEntity::get_vector_by_key( const uint64_t key, IndexStorage::MemoryBlock &block) const { VectorLocation loc{}; key_info_map_lock_->lock_shared(); if (use_key_info_map_) { auto iterator = key_info_map_.find(key); if (iterator == key_info_map_.end()) { key_info_map_lock_->unlock_shared(); return -1; } loc = iterator->second; } else { if (key < withid_key_info_map_.size()) { loc = withid_key_info_map_[key]; } else { key_info_map_lock_->unlock_shared(); return -1; } } 
key_info_map_lock_->unlock_shared(); auto segment = this->get_segment(loc.segment_id); if (segment->read(loc.offset, block, index_meta_.element_size()) != index_meta_.element_size()) { LOG_ERROR("Failed to read segment, size=%u", index_meta_.element_size()); return -1; } return 0; } IndexProvider::Iterator::Pointer FlatStreamerEntity::creater_iterator( void) const { auto entity = this->clone(); if (!entity) { LOG_ERROR("Failed to clone entity"); return nullptr; } return Iterator::Pointer(new (std::nothrow) FlatStreamerEntity::Iterator(std::move(entity))); } void FlatStreamerEntity::Iterator::read_next_block(void) { auto block_size = entity_->linear_block_size(); while (segment_id_ < entity_->segments_.size()) { auto &segment = entity_->segments_[segment_id_]; size_t off = block_index_ * block_size; if (off + block_size > segment->data_size()) { ++segment_id_; block_index_ = 0; continue; } if (segment->read(off, block_, block_size) != block_size) { LOG_ERROR("Failed to read block, off=%zu", off); break; } data_ = block_.data(); auto hd = reinterpret_cast( static_cast(data_) + block_size - sizeof(BlockHeader)); if (hd->vector_count == 0) { ++block_index_; continue; } block_vector_count_ = hd->vector_count; block_vector_index_ = 0; size_t elemsize = entity_->index_meta_.element_size(); keys_ = reinterpret_cast( reinterpret_cast(data_) + elemsize * entity_->block_vector_count()); return; } is_valid_ = false; } int FlatStreamerEntity::init_storage(IndexStorage::Pointer storage) { // Init Linear Meta Segment meta_.create_time = ailego::Realtime::Seconds(); stats_.set_create_time(meta_.create_time); meta_.update_time = ailego::Realtime::Seconds(); stats_.set_update_time(meta_.update_time); meta_.segment_count = 0; meta_.revision_id = 0; std::string str; index_meta_.serialize(&str); const size_t page = ailego::MemoryHelper::PageSize(); meta_.header.header_size = sizeof(LinearIndexHeader) + str.size(); meta_.header.total_vector_count = 0; meta_.header.linear_body_size = 0; 
meta_.header.block_count = 0; meta_.header.index_meta_size = str.size(); meta_.header.linear_list_count = 1; AdjustSegmentSize(&meta_); LOG_DEBUG( "Create Streamer Index, VecSize=%u, BlockSize=%u SegmentSize=%u " "LinearListCount=%u", index_meta_.element_size(), meta_.header.block_size, meta_.segment_size, meta_.header.linear_list_count); size_t size = ailego_align(sizeof(meta_) + str.size(), page); int ret = storage->append(FLAT_LINEAR_META_SEG_ID, size); if (ailego_unlikely(ret != 0)) { LOG_ERROR("Failed to append segment %s", FLAT_LINEAR_META_SEG_ID.c_str()); return ret; } auto segment = storage->get(FLAT_LINEAR_META_SEG_ID); if (ailego_unlikely(!segment)) { LOG_ERROR("Failed to get segment %s", FLAT_LINEAR_META_SEG_ID.c_str()); return IndexError_Runtime; } if (segment->write(0, &meta_, sizeof(meta_)) != sizeof(meta_)) { LOG_ERROR("Failed to write segment data"); return IndexError_WriteData; } if (segment->write(sizeof(meta_), str.data(), str.size()) != str.size()) { LOG_ERROR("Failed to write segment data, size=%zu", str.size()); return IndexError_WriteData; } ret = storage->append("IndexMeta", str.size()); if (ailego_unlikely(ret != 0)) { LOG_ERROR("Failed to append segment IndexMeta, code: %d", ret); return ret; } auto index_meta_segment = storage->get("IndexMeta"); if (index_meta_segment->write(0, str.data(), str.size()) != str.size()) { LOG_ERROR("Failed to write segment data, size=%zu", str.size()); return IndexError_WriteData; } *stats_.mutable_index_size() += size; // Init Linear List Head Segment size = ailego_align(sizeof(BlockLocation) * linear_list_count(), page); ret = storage->append(FLAT_LINEAR_LIST_HEAD_SEG_ID, size); if (ailego_unlikely(ret != 0)) { LOG_ERROR("Failed to append segment %s for %s, size=%zu", FLAT_LINEAR_LIST_HEAD_SEG_ID.c_str(), IndexError::What(ret), size); return ret; } segment = storage->get(FLAT_LINEAR_LIST_HEAD_SEG_ID); if (ailego_unlikely(!segment)) { LOG_ERROR("Failed to get segment %s", 
FLAT_LINEAR_LIST_HEAD_SEG_ID.c_str()); return IndexError_Runtime; } if (segment->resize(size) != size) { LOG_ERROR("Failed to resize segment, size=%zu", size); return IndexError_WriteData; } segments_.emplace_back(std::move(segment)); *stats_.mutable_index_size() += size; return 0; } int FlatStreamerEntity::load_linear_meta(IndexStorage::Pointer storage) { AdjustSegmentSize(&meta_); // Load Meta Segment auto segment = storage->get(FLAT_LINEAR_META_SEG_ID); if (!segment || segment->data_size() < sizeof(meta_)) { LOG_ERROR("Missing segment %s, or invalid segment size", FLAT_LINEAR_META_SEG_ID.c_str()); return IndexError_InvalidFormat; } IndexStorage::MemoryBlock data_block; if (segment->read(0, data_block, segment->data_size()) != segment->data_size()) { LOG_ERROR("Failed to read storage, size=%zu", segment->data_size()); return IndexError_InvalidFormat; } auto *mt = reinterpret_cast(data_block.data()); if (mt->header.block_vector_count != meta_.header.block_vector_count) { LOG_ERROR("Unmatched BlockVecCount Setting, Index %u vs Setting %u", mt->header.block_vector_count, meta_.header.block_vector_count); return IndexError_Mismatch; } if (mt->header.block_size != meta_.header.block_size) { LOG_ERROR("Unmatched BlockSize Setting, Index %u vs Setting %u", mt->header.block_size, meta_.header.block_size); return IndexError_Mismatch; } if (mt->header.index_meta_size + sizeof(meta_) > segment->data_size()) { LOG_ERROR("Invalid format, IndexMetaSize %u, SegmentSize %zu", mt->header.index_meta_size, segment->data_size()); return IndexError_InvalidFormat; } if (mt->header.linear_list_count != meta_.header.linear_list_count) { LOG_ERROR("Unmatch LinearListCount, Index size %u vs Setting %u", mt->header.linear_list_count, meta_.header.linear_list_count); return IndexError_InvalidFormat; } IndexMeta index_meta; if (!index_meta.deserialize(mt->header.index_meta, mt->header.index_meta_size)) { LOG_ERROR("Failed to deserialize IndexMeta, size=%u", mt->header.index_meta_size); 
return IndexError_InvalidFormat; } if (index_meta.data_type() != index_meta_.data_type() || index_meta.dimension() != index_meta_.dimension() || index_meta.element_size() != index_meta_.element_size() || index_meta.metric_name() != index_meta_.metric_name()) { LOG_ERROR( "Unmatch IndexMeta, Index(type=%u dim=%u elemsize=%u " "metric=%s) Setting(type=%u dim=%u elemsize=%u metric=%s)", index_meta.data_type(), index_meta.dimension(), index_meta.element_size(), index_meta.metric_name().c_str(), index_meta_.data_type(), index_meta_.dimension(), index_meta_.element_size(), index_meta_.metric_name().c_str()); return IndexError_Mismatch; } // Segment Size can be reconfigurable auto segment_size = meta_.segment_size; std::memcpy(&meta_, mt, sizeof(meta_)); meta_.segment_size = segment_size; return 0; } int FlatStreamerEntity::load_segment_keys_to_map(BlockLocation block) { while (this->is_valid_block(block)) { auto segment = this->get_segment(block.segment_id); IndexStorage::MemoryBlock block_header_block; this->get_block_header(block, block_header_block); const BlockHeader *hd = reinterpret_cast(block_header_block.data()); if (ailego_unlikely(hd == nullptr)) { LOG_ERROR("Failed to get block header"); return IndexError_ReadData; } IndexStorage::MemoryBlock keys_block; this->get_block_keys(block, keys_block); const uint64_t *keys = reinterpret_cast(keys_block.data()); IndexStorage::MemoryBlock deletion_map_block; this->get_block_deletion_map(block, deletion_map_block); const DeletionMap *deletion_map = reinterpret_cast(deletion_map_block.data()); for (uint32_t vector_index = 0; vector_index < hd->vector_count; ++vector_index) { if (deletion_map->test(vector_index)) { continue; } size_t vector_off = this->get_block_vector_offset(block.block_index, vector_index); key_info_map_[keys[vector_index]] = VectorLocation(block.segment_id, false, vector_off); id_key_vector_.push_back(keys[vector_index]); } block = hd->next; } return 0; } int 
FlatStreamerEntity::load_segment_keys_to_vector() { for (uint32_t i = 0; i < meta_.header.total_vector_count; i++) { size_t block_id = i / block_vector_count(); uint32_t vector_index = i % block_vector_count(); ailego_assert(segments_.size() > 1); size_t segment_block_count = segments_[1]->data_size() / linear_block_size(); size_t segment_id = block_id / segment_block_count + 1; size_t real_block_id = block_id % segment_block_count; size_t vector_off = this->get_block_vector_offset(real_block_id, vector_index); withid_key_info_map_.push_back( VectorLocation(segment_id, false, vector_off)); size_t key_off = get_block_key_offset(real_block_id, vector_index); withid_key_map_.push_back(key_off); } return 0; } int FlatStreamerEntity::load_storage(IndexStorage::Pointer storage) { int ret = this->load_linear_meta(storage); if (ailego_unlikely(ret != 0)) { return ret; } // Load Linear List auto hd_segment = storage->get(FLAT_LINEAR_LIST_HEAD_SEG_ID); if (ailego_unlikely(!hd_segment)) { LOG_ERROR("Failed to get segment %s", FLAT_LINEAR_LIST_HEAD_SEG_ID.c_str()); return IndexError_Runtime; } if (hd_segment->data_size() < linear_list_count() * sizeof(BlockLocation)) { LOG_ERROR("Invalid segment size, LinearListCount=%zu, size=%zu", linear_list_count(), hd_segment->data_size()); return IndexError_InvalidFormat; } segments_.emplace_back(hd_segment); size_t index_size = hd_segment->capacity(); for (size_t i = 1; i <= meta_.segment_count; ++i) { std::string segment_id = ailego::StringHelper::Concat(FLAT_SEGMENT_FEATURES_SEG_ID, i); auto seg = storage->get(segment_id); if (!seg || seg->data_size() < meta_.header.block_size) { LOG_ERROR("Failed to get segment %s, or invalid segment size", segment_id.c_str()); return IndexError_InvalidFormat; } index_size += seg->capacity(); segments_.emplace_back(std::move(seg)); } for (size_t i = 0; i < linear_list_count(); i++) { IndexStorage::MemoryBlock head_block; this->get_head_block(head_block); const BlockLocation *bl = 
reinterpret_cast(head_block.data()); if (ailego_unlikely(bl == nullptr)) { LOG_ERROR("Failed to get block loc"); return IndexError_ReadData; } BlockLocation block = *bl; if (use_key_info_map_) { ret = this->load_segment_keys_to_map(block); } else { ret = this->load_segment_keys_to_vector(); } if (ailego_unlikely(ret != 0)) { return ret; } } char create_time[32]; char update_time[32]; ailego::Realtime::Gmtime(meta_.create_time, "%Y-%m-%d %H:%M:%S", create_time, sizeof(create_time)); ailego::Realtime::Gmtime(meta_.update_time, "%Y-%m-%d %H:%M:%S", update_time, sizeof(update_time)); LOG_DEBUG( "Load Index, IndexSize=%zu SegmentCount=%u SegmentSize=%u " "RevisionId=%zu BlockCount=%u BlockSize=%u " "BlockVectorCount=%u LinearListCount=%u TotalVecCount=%zu " "CreateTime=%s UpdateTime=%s", index_size, meta_.segment_count, meta_.segment_size, static_cast(meta_.revision_id), meta_.header.block_count, meta_.header.block_size, meta_.header.block_vector_count, meta_.header.linear_list_count, static_cast(meta_.header.total_vector_count), create_time, update_time); stats_.set_index_size(index_size); stats_.set_check_point(storage->check_point()); stats_.set_create_time(meta_.create_time); stats_.set_revision_id(meta_.revision_id); stats_.set_update_time(meta_.update_time); stats_.set_loaded_count(meta_.header.total_vector_count); return 0; } int FlatStreamerEntity::alloc_segment(void) { size_t index = segments_.size(); if (index == kMaxSegmentId) { LOG_ERROR("Failed to alloc new segment, exceed max count %zu", kMaxSegmentId); return IndexError_IndexFull; } std::string segment_id = ailego::StringHelper::Concat(FLAT_SEGMENT_FEATURES_SEG_ID, index); size_t size = ailego_align(meta_.segment_size, ailego::MemoryHelper::PageSize()); auto segment = storage_->get(segment_id); if (segment) { if (segment->padding_size() < linear_block_size()) { LOG_ERROR( "Unexpect segment, index=%zu, data_size=%zu " "padding_size=%zu block_size=%zu", index, segment->data_size(), segment->padding_size(), 
linear_block_size()); return IndexError_Runtime; } LOG_WARN("Alloc an existing segment=%s capacity=%zu", segment_id.c_str(), segment->capacity()); } else { int ret = storage_->append(segment_id, size); if (ailego_unlikely(ret != 0)) { LOG_ERROR("Failed to alloc segment from storage"); return ret; } segment = storage_->get(segment_id); if (ailego_unlikely(!segment)) { LOG_ERROR("Failed to get segment %s", segment_id.c_str()); return IndexError_Runtime; } } meta_.segment_count += 1; meta_.header.linear_body_size += size; segments_.emplace_back(std::move(segment)); *stats_.mutable_index_size() += size; // Update meta information auto meta_segment = storage_->get(FLAT_LINEAR_META_SEG_ID); if (ailego_unlikely(!meta_segment)) { LOG_ERROR("Failed to get segment %s", FLAT_LINEAR_META_SEG_ID.c_str()); return IndexError_Runtime; } if (meta_segment->write(0, &meta_, sizeof(meta_)) != sizeof(meta_)) { LOG_ERROR("Failed to write meta segment"); return IndexError_WriteData; } return 0; } int FlatStreamerEntity::alloc_block(const BlockLocation &next, BlockLocation *block) { if (segments_.size() <= 1 || segments_.back()->padding_size() < linear_block_size()) { int ret = this->alloc_segment(); if (ailego_unlikely(ret != 0)) { return ret; } } auto &segment = segments_.back(); size_t block_index = segment->data_size() / linear_block_size(); if (block_index == kMaxBlockId) { LOG_ERROR("Failed to alloc block, exceed max count %zu per segment", kMaxBlockId); return IndexError_IndexFull; } BlockHeader header; header.next = next; header.vector_count = 0; header.column_major = false; size_t hd_off = segment->data_size() + linear_block_size() - sizeof(header); if (segment->write(hd_off, &header, sizeof(header)) != sizeof(header)) { LOG_ERROR("Failed to write block header"); return IndexError_WriteData; } size_t del_off = hd_off - sizeof(DeletionMap); DeletionMap reset_del_map{}; if (segment->write(del_off, &reset_del_map, sizeof(reset_del_map)) != sizeof(reset_del_map)) { LOG_ERROR("Failed 
to write block deletion map"); return IndexError_WriteData; } ++meta_.header.block_count; block->segment_id = segments_.size() - 1; block->block_index = (segment->data_size() / linear_block_size()) - 1; return 0; } int FlatStreamerEntity::add_to_block(const BlockLocation &block, uint64_t key, const void *data, size_t size) { IndexStorage::MemoryBlock block_header_block; this->get_block_header(block, block_header_block); const BlockHeader *header = reinterpret_cast(block_header_block.data()); if (ailego_unlikely(header == nullptr)) { LOG_ERROR("Failed to get header"); return IndexError_ReadData; } if (header->vector_count == block_vector_count()) { return IndexError_IndexFull; } auto &segment = segments_[block.segment_id]; size_t vector_off = get_block_vector_offset(block.block_index, header->vector_count); if (segment->write(vector_off, data, size) != size) { LOG_ERROR("Failed to write vector, off=%zu size=%zu", vector_off, size); return IndexError_WriteData; } size_t key_off = get_block_key_offset(block.block_index, header->vector_count); if (segment->write(key_off, &key, sizeof(key)) != sizeof(key)) { LOG_ERROR("Failed to write key, off=%zu", key_off); return IndexError_WriteData; } BlockHeader hd = *header; hd.vector_count += 1; size_t hd_off = get_block_header_offset(block.block_index); if (segment->write(hd_off, &hd, sizeof(hd)) != sizeof(hd)) { LOG_ERROR("Failed to write block header, off=%zu", hd_off); return IndexError_WriteData; } VectorLocation loc(block.segment_id, false, vector_off); key_info_map_lock_->lock(); key_info_map_[key] = loc; id_key_vector_.push_back(key); withid_key_info_map_.push_back(loc); withid_key_map_.push_back(key_off); key_info_map_lock_->unlock(); ++meta_.header.total_vector_count; return 0; } int FlatStreamerEntity::add_vector_with_id(const uint32_t id, const void *query, const uint32_t size) { std::lock_guard lock(mutex_); // if (filter_same_key_) { // key_info_map_lock_->lock_shared(); // if (key_info_map_.find(id) != 
key_info_map_.end()) { // key_info_map_lock_->unlock_shared(); // LOG_WARN("Try to add duplicate key, drop it"); // return IndexError_Duplicate; // } // key_info_map_lock_->unlock_shared(); // } if (size != static_cast(index_meta_.element_size())) { LOG_ERROR("Failed to add, mismatch size %u vs elemsize %u", size, index_meta_.element_size()); return IndexError_Mismatch; } if (id >= vector_count()) { IndexStorage::MemoryBlock head_block; this->get_head_block(head_block); BlockLocation block = *reinterpret_cast(head_block.data()); if (!this->is_valid_block(block)) { int ret = this->alloc_block(block, &block); if (ailego_unlikely(ret != 0)) { return ret; } ret = this->update_head_block(block); if (ailego_unlikely(ret != 0)) { return ret; } } for (size_t start_id = vector_count(); start_id < id; ++start_id) { std::vector vec(size); int ret = this->add_to_block(block, kInvalidKey, vec.data(), size); if (ret == IndexError_IndexFull) { ret = this->alloc_block(block, &block); if (ailego_unlikely(ret != 0)) { return ret; } ret = this->update_head_block(block); if (ailego_unlikely(ret != 0)) { return ret; } ret = this->add_to_block(block, kInvalidKey, vec.data(), size); if (ailego_unlikely(ret != 0)) { return ret; } } } int ret = this->add_to_block(block, id, query, size); if (ret == IndexError_IndexFull) { ret = this->alloc_block(block, &block); if (ailego_unlikely(ret != 0)) { return ret; } ret = this->update_head_block(block); if (ailego_unlikely(ret != 0)) { return ret; } ret = this->add_to_block(block, id, query, size); if (ailego_unlikely(ret != 0)) { return ret; } } } else { VectorLocation vector_loc = withid_key_info_map_[id]; auto segment = this->get_segment(vector_loc.segment_id); size_t vector_off = vector_loc.offset; if (segment->write(vector_off, query, size) != size) { LOG_ERROR("Failed to write vector, off=%zu size=%u", vector_off, size); return IndexError_WriteData; } size_t key_off = withid_key_map_[id]; uint64_t key = id; if (segment->write(key_off, &key, 
sizeof(key)) != sizeof(key)) { LOG_ERROR("Failed to write key, off=%zu", key_off); return IndexError_WriteData; } key_info_map_lock_->lock(); key_info_map_[key] = vector_loc; key_info_map_lock_->unlock(); } (*stats_.mutable_added_count())++; stats_.set_revision_id(meta_.revision_id + 1); return 0; } } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/flat/flat_streamer_entity.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include #include #include #include #include #include "flat_index_format.h" #include "flat_utility.h" namespace zvec { namespace core { /*! Flat Streamer Entity */ class FlatStreamerEntity { public: typedef std::shared_ptr Pointer; //! Constructor explicit FlatStreamerEntity(IndexStreamer::Stats &stats); //! Destructor virtual ~FlatStreamerEntity(void) = default; //! Open the entity with storage int open(IndexStorage::Pointer storage, const IndexMeta &mt); //! Close the entity int close(void); //! Flush Linear Meta information to storage int flush_linear_meta(void); //! Flush linear index to storage int flush(uint64_t checkpoint); //! Add vector to linear index int add(uint64_t key, const void *vec, size_t size); //! 
Search in linear list with filter int search(const void *query, const IndexFilter &filter, uint32_t *scan_count, IndexDocumentHeap *heap, IndexContext::Stats *context_stats) const; //! Search in a block void search_block(const void *query, const BlockLocation &bl, const BlockHeader *hd, float norm_val, IndexDocumentHeap *heap) const; //! Search in a block with filter void search_block(const void *query, const BlockLocation &bl, const BlockHeader *hd, float norm_val, const IndexFilter &filter, const DeletionMap *deletion_map, IndexDocumentHeap *heap, IndexContext::Stats *context_stats) const; //! Flat Search with filter int search_bf(const void *query, const IndexFilter &filter, IndexDocumentHeap *heap, IndexContext::Stats *context_stats) const; //! Clone the entity virtual FlatStreamerEntity::Pointer clone(void) const; //! Retrieve the total vectors in the index size_t vector_count(void) const { return meta_.header.total_vector_count; } //! Retrieve the linear list count size_t linear_list_count(void) const { return meta_.header.linear_list_count; } //! Retrieve block size of the linear vector size_t linear_block_size(void) const { return meta_.header.block_size; } //! Retrieve the vectors count in one block size_t block_vector_count(void) const { // assert(meta_.header.block_vector_count == 32); return meta_.header.block_vector_count; } //! Retrieve IndexMeta of the linear index const IndexMeta &meta(void) const { return index_meta_; } //! Retrieve mutable IndexMeta of the linear index IndexMeta *mutable_meta(void) { return &index_meta_; } //! Retrieve vector by local id const void *get_vector_by_key(uint64_t key) const; int get_vector_by_key(const uint64_t key, IndexStorage::MemoryBlock &block) const; //! Create a new iterator IndexProvider::Iterator::Pointer creater_iterator(void) const; //! 
Set params void set_block_vector_count(uint32_t count) { meta_.header.block_vector_count = count; } void set_use_key_info_map(bool use_id_map) { use_key_info_map_ = use_id_map; LOG_DEBUG("use_key_info_map_: %d", (int)use_key_info_map_); } //! Set params void set_segment_size(uint32_t size) { meta_.segment_size = size; } //! Set params void set_linear_list_count(uint32_t count) { meta_.header.linear_list_count = count; } //! Set params void enable_filter_same_key(bool enabled) { filter_same_key_ = enabled; } inline uint64_t key(uint32_t id) const { if (id < id_key_vector_.size()) { return id_key_vector_[id]; } else { return kInvalidKey; } } inline void row_major_distance(const void *query, const void *feature, size_t fnum, float *out) const { const uint8_t *cur_feature = reinterpret_cast(feature); for (size_t f = 0; f < fnum; ++f) { row_distance_(query, cur_feature, index_meta_.dimension(), out + f); cur_feature += index_meta_.element_size(); } } int add_vector_with_id(const uint32_t id, const void *query, const uint32_t element_size); private: //! Disable them FlatStreamerEntity(const FlatStreamerEntity &) = delete; FlatStreamerEntity &operator=(const FlatStreamerEntity &) = delete; /*! Iterator of all the linear list */ class Iterator : public IndexProvider::Iterator { public: //! Constructor Iterator(const FlatStreamerEntity::Pointer &entity) : entity_(entity) { this->read_next_block(); } //! Retrieve pointer of data const void *data(void) const override { return reinterpret_cast(data_) + block_vector_index_ * entity_->index_meta_.element_size(); } //! Test if the iterator is valid bool is_valid(void) const override { return is_valid_; } //! Retrieve primary key uint64_t key(void) const override { return keys_[block_vector_index_]; } //! Next iterator void next(void) override { if (++block_vector_index_ == block_vector_count_) { ++block_index_; this->read_next_block(); } } private: //! Read next non-empty block void read_next_block(void); //! 
Members std::string buffer_{}; const FlatStreamerEntity::Pointer entity_; IndexStorage::MemoryBlock block_; const void *data_{nullptr}; const uint64_t *keys_{nullptr}; uint32_t segment_id_{1u}; // The first segment is header info uint32_t block_index_{0u}; uint32_t block_vector_index_{0u}; uint32_t block_vector_count_{0u}; bool is_valid_{true}; }; //! Retrive storage segment by index const IndexStorage::Segment::Pointer get_segment(size_t index) const { for (size_t i = segments_.size(); i <= index; ++i) { auto segment_id = ailego::StringHelper::Concat(FLAT_SEGMENT_FEATURES_SEG_ID, i); auto segment = storage_->get(segment_id); if (!segment) { LOG_ERROR("Failed to get segment %s", segment_id.c_str()); return IndexStorage::Segment::Pointer(); } segments_.emplace_back(std::move(segment)); } return segments_[index]; } //! Rejust the segment size as to aligned by page size void AdjustSegmentSize(StreamerLinearMeta *mt) { if (mt->segment_size < mt->header.block_size) { mt->segment_size = mt->header.block_size; } mt->segment_size = ailego_align( mt->segment_size / mt->header.block_size * mt->header.block_size, ailego::MemoryHelper::PageSize()); } //! Init with an empty storage int init_storage(IndexStorage::Pointer storage); //! Load linear meta information from storage int load_linear_meta(IndexStorage::Pointer storage); //! Load keys to keys map int load_segment_keys_to_map(BlockLocation block); //! Load keys to keys map int load_segment_keys_to_vector(void); //! Load index from storage int load_storage(IndexStorage::Pointer storage); //! Check whether the block is empty bool is_valid_block(const BlockLocation &block) const { return block.segment_id != 0; } //! 
Update header block of an linear list int update_head_block(const BlockLocation &block) { ailego_assert_with(segments_.size() != 0, "Invalid Segments"); auto &hd_segment = segments_[0]; if (hd_segment->write(0, &block, sizeof(block)) != sizeof(block)) { LOG_ERROR("Failed to write head block location"); return IndexError_WriteData; } return 0; } //! Alloc a new segment int alloc_segment(void); //! Alloc a new block int alloc_block(const BlockLocation &next, BlockLocation *block); //! Add a record to a block int add_to_block(const BlockLocation &block, uint64_t key, const void *data, size_t size); private: size_t get_block_offset(uint32_t block_index) const { return block_index * linear_block_size(); } size_t get_block_header_offset(uint32_t block_index) const { return get_block_offset(block_index) + linear_block_size() - sizeof(BlockHeader); } size_t get_block_deletion_map_offset(uint32_t block_index) const { return get_block_header_offset(block_index) - sizeof(DeletionMap); } size_t get_block_key_offset(uint32_t block_index, uint32_t vector_index) const { return get_block_offset(block_index) + block_vector_count() * index_meta_.element_size() + sizeof(uint64_t) * vector_index; } size_t get_block_vector_offset(uint32_t block_index, uint32_t vector_index) const { return this->get_block_offset(block_index) + vector_index * index_meta_.element_size(); } //! Get header block of an linear list int get_head_block(IndexStorage::MemoryBlock &header_block) const { ailego_assert_with(segments_.size() != 0, "Invalid Segments"); auto &hd_segment = segments_[0]; if (hd_segment->read(0, header_block, sizeof(BlockLocation)) != sizeof(BlockLocation)) { LOG_ERROR("Failed to read head block location"); return -1; } return 0; } //! 
Get BlockHeader of the block int get_block_header(const BlockLocation &block, IndexStorage::MemoryBlock &header_block) const { // The header is located in the end of a block to align features auto &segment = this->get_segment(block.segment_id); ailego_assert_with(segment != nullptr, "Index Overflow"); size_t off = this->get_block_header_offset(block.block_index); if (segment->read(off, header_block, sizeof(BlockHeader)) != sizeof(BlockHeader)) { LOG_ERROR("Failed to read block header, off=%zu", off); return -1; } return 0; } int get_block_deletion_map( const BlockLocation &block, IndexStorage::MemoryBlock &deletion_map_block) const { auto &segment = this->get_segment(block.segment_id); ailego_assert_with(segment != nullptr, "Index Overflow"); size_t off = this->get_block_deletion_map_offset(block.block_index); if (segment->read(off, deletion_map_block, sizeof(DeletionMap)) != sizeof(DeletionMap)) { LOG_ERROR("Failed to read deletion map, off=%zu", off); return -1; } return 0; } int get_block_keys(const BlockLocation &block, IndexStorage::MemoryBlock &keys_block) const { auto &segment = this->get_segment(block.segment_id); ailego_assert_with(segment != nullptr, "Index Overflow"); size_t off = this->get_block_key_offset(block.block_index, 0); if (segment->read(off, keys_block, block_vector_count() * sizeof(uint64_t)) != block_vector_count() * sizeof(uint64_t)) { LOG_ERROR("Failed to read block header, off=%zu", off); return -1; } return 0; } int get_block_vectors(const BlockLocation &block, IndexStorage::MemoryBlock &vector_block) const { auto &segment = this->get_segment(block.segment_id); ailego_assert_with(segment != nullptr, "Index Overflow"); size_t off = this->get_block_vector_offset(block.block_index, 0); if (segment->read(off, vector_block, block_vector_count() * index_meta_.element_size()) != block_vector_count() * index_meta_.element_size()) { LOG_ERROR("Failed to read block header, off=%zu", off); return -1; } return 0; } private: //! 
Constants static constexpr size_t kMaxSegmentId = std::numeric_limits::max(); static constexpr size_t kMaxBlockId = std::numeric_limits::max(); //! Members std::mutex mutex_{}; IndexMeta index_meta_{}; IndexStorage::Pointer storage_{}; IndexMetric::MatrixDistance row_distance_{}, column_distance_{}; mutable std::vector segments_{}; StreamerLinearMeta meta_{}; IndexStreamer::Stats &stats_; mutable std::shared_ptr key_info_map_lock_{}; std::unordered_map key_info_map_{}; std::vector withid_key_info_map_{}; std::vector withid_key_map_{}; std::vector id_key_vector_{}; bool filter_same_key_{false}; bool use_key_info_map_{true}; uint32_t vec_unit_size_{0}; uint32_t vec_cols_{0}; mutable std::string vec_buf_{}; }; } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/flat/flat_streamer_provider.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include "flat_distance_matrix.h" #include "flat_searcher.h" #include "flat_streamer.h" #include "flat_utility.h" namespace zvec { namespace core { /*! Brute Force Streamer Provider */ template class FlatStreamerProvider : public IndexProvider { public: //! 
Constructor FlatStreamerProvider(const FlatStreamer *owner) { feature_size_ = owner->meta().element_size(); total_vector_count_ = owner->entity().vector_count(); owner_ = owner; block_buffer_.resize(BATCH_SIZE * feature_size_); } //! Create a new iterator IndexProvider::Iterator::Pointer create_iterator(void) override { return owner_->entity().creater_iterator(); } //! Retrieve count of vectors size_t count(void) const override { return total_vector_count_; } //! Retrieve dimension of vector size_t dimension(void) const override { return owner_->meta().dimension(); } //! Retrieve type of vector IndexMeta::DataType data_type(void) const override { return owner_->meta().data_type(); } //! Retrieve vector size in bytes size_t element_size(void) const override { return owner_->meta().element_size(); } //! Retrieve a vector using a primary key const void *get_vector(uint64_t key) const override { return this->get_vector_by_key(key); } int get_vector(const uint64_t key, IndexStorage::MemoryBlock &block) const override { return this->get_vector_by_key(key, block); } //! Retrieve the owner class const std::string &owner_class(void) const override { return owner_->name(); } protected: //! Retrieve a vector via primary key const void *get_vector_by_key(uint64_t key) const { return owner_->get_vector_by_key(key); } int get_vector_by_key(const uint64_t key, IndexStorage::MemoryBlock &block) const { return owner_->get_vector_by_key(key, block); } private: //! 
Members const FlatStreamer *owner_{nullptr}; IndexStorage::Segment::Pointer features_segment_{}; uint32_t feature_size_{0}; uint32_t total_vector_count_{0}; mutable std::vector block_buffer_{}; }; } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/flat/flat_utility.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include #include #include #include #include namespace zvec { namespace core { //! The default size of reading a block static constexpr uint32_t FLAT_DEFAULT_READ_BLOCK_SIZE = 4 * 1024 * 1024; static const std::string FLAT_LINEAR_META_SEG_ID = "flat.linear_meta"; static const std::string FLAT_LINEAR_LIST_HEAD_SEG_ID = "flat.linear_list_head"; static const std::string FLAT_SEGMENT_KEYS_SEG_ID("flat.keys"); static const std::string FLAT_SEGMENT_FEATURES_SEG_ID("flat.features"); static const std::string FLAT_SEGMENT_MAPPING_SEG_ID("flat.mapping"); // index params static const std::string PARAM_FLAT_COLUMN_MAJOR_ORDER( "proxima.flat.column_major_order"); static const std::string PARAM_FLAT_BATCH_SIZE("proxima.flat.batch_size"); static const std::string PARAM_FLAT_READ_BLOCK_SIZE( "proxima.flat.read_block_size"); static const std::string PARAM_FLAT_USE_ID_MAP("proxima.flat.use_id_map"); //! Determines if a number is equal to two to the power of n. 
template struct IsEqualPowerofTwo : std::integral_constant {}; //! Transpose a block template static inline void ReverseTranspose(size_t align_size, const void *src, size_t dim, void *dst) { switch (align_size) { case 2: ailego::MatrixHelper::ReverseTranspose(src, dim, dst); break; case 4: ailego::MatrixHelper::ReverseTranspose(src, dim, dst); break; case 8: ailego::MatrixHelper::ReverseTranspose(src, dim, dst); break; } } static inline void ReverseTranspose(size_t align_size, const void *src, size_t m, size_t dim, void *dst) { switch (align_size) { case 2: ailego::MatrixHelper::ReverseTranspose(src, m, dim, dst); break; case 4: ailego::MatrixHelper::ReverseTranspose(src, m, dim, dst); break; case 8: ailego::MatrixHelper::ReverseTranspose(src, m, dim, dst); break; } } template static inline void TransposeOne(const void *src, size_t M, size_t N, void *dst) { for (size_t i = 0; i < N; ++i) { reinterpret_cast(dst)[i] = reinterpret_cast(src)[i * M]; } } static inline void Transpose(size_t align_size, const void *src, size_t m, size_t dim, void *dst) { switch (align_size) { case 2: ailego::MatrixHelper::Transpose(src, m, dim, dst); break; case 4: ailego::MatrixHelper::Transpose(src, m, dim, dst); break; case 8: ailego::MatrixHelper::Transpose(src, m, dim, dst); break; } } //! 
Transpose queries template void TransposeQueries(const void *query, const IndexQueryMeta &qmeta, size_t query_count, std::string *out) { if (K <= 1) { ailego_assert(query_count == 1); (void)query_count; out->append(reinterpret_cast(query) + out->size(), qmeta.element_size()); } else { ailego_assert_with(IsEqualPowerofTwo::value, "K must be equal to two to the power of n."); size_t query_batch_count = query_count / K; size_t query_offset = out->size(); out->resize(query_offset + query_batch_count * K * qmeta.element_size()); switch (IndexMeta::AlignSizeof(qmeta.data_type())) { case 2: for (size_t i = 0; i != query_batch_count; ++i) { ailego::MatrixHelper::Transpose( (const char *)query + query_offset, qmeta.element_size() / sizeof(uint16_t), &((*out)[query_offset])); query_offset += qmeta.element_size() * K; } break; case 4: for (size_t i = 0; i != query_batch_count; ++i) { ailego::MatrixHelper::Transpose( (const char *)query + query_offset, qmeta.element_size() / sizeof(uint32_t), &((*out)[query_offset])); query_offset += qmeta.element_size() * K; } break; case 8: for (size_t i = 0; i != query_batch_count; ++i) { ailego::MatrixHelper::Transpose( (const char *)query + query_offset, qmeta.element_size() / sizeof(uint64_t), &((*out)[query_offset])); query_offset += qmeta.element_size() * K; } break; default: ailego_check_with(0, "BAD CASE"); } size_t query_left_count = query_count % K; if (query_left_count != 0) { TransposeQueries<(K >> 1)>(query, qmeta, query_left_count, out); } } } //! Create and initialize measure static inline int InitializeMetric(const IndexMeta &mt, IndexMetric::Pointer *out) { IndexMetric::Pointer measure = IndexFactory::CreateMetric(mt.metric_name()); if (!measure) { return IndexError_NoExist; } int error_code = measure->init(mt, mt.metric_params()); if (error_code != 0) { return error_code; } *out = measure; return 0; } //! 
Verify measure static inline bool VerifyMetric(const IndexMeta &meta) { IndexMetric::Pointer measure = IndexFactory::CreateMetric(meta.metric_name()); if (!measure) { return false; } int error_code = measure->init(meta, meta.metric_params()); if (error_code != 0) { return false; } return true; } } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/flat_sparse/CMakeLists.txt ================================================ include(${PROJECT_ROOT_DIR}/cmake/bazel.cmake) include(${PROJECT_ROOT_DIR}/cmake/option.cmake) cc_library( NAME core_knn_flat_sparse STATIC SHARED STRICT ALWAYS_LINK SRCS *.cc LIBS core_framework INCS . ${PROJECT_ROOT_DIR}/src/core ${PROJECT_ROOT_DIR}/src/core/algorithm VERSION "${PROXIMA_ZVEC_VERSION}" ) ================================================ FILE: src/core/algorithm/flat_sparse/flat_sparse_builder.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "flat_sparse_builder.h"
// NOTE(review): the seven angle-bracket #include lines below lost their header
// names during extraction; restore them from upstream before compiling.
#include
#include
#include
#include
#include
#include
#include
#include "flat_sparse_index_format.h"
#include "flat_sparse_utility.h"

namespace zvec {
namespace core {

FlatSparseBuilder::FlatSparseBuilder() {}

//! Initialize the builder with the index meta; builder params are unused.
int FlatSparseBuilder::init(const IndexMeta &meta,
                            const ailego::Params & /*params*/) {
  LOG_INFO("Begin FlatSparseBuilder::init");

  meta_ = meta;
  state_ = BUILD_STATE_INITED;

  LOG_INFO("End FlatSparseBuilder::init");
  return 0;
}

//! Reset all statistics, release any built-but-undumped holder and return
//! the builder to the INIT state.
int FlatSparseBuilder::cleanup(void) {
  LOG_INFO("Begin FlatSparseBuilder::cleanup");

  stats_.clear_attributes();
  stats_.set_trained_count(0UL);
  stats_.set_built_count(0UL);
  stats_.set_dumped_count(0UL);
  stats_.set_discarded_count(0UL);
  stats_.set_trained_costtime(0UL);
  stats_.set_built_costtime(0UL);
  stats_.set_dumped_costtime(0UL);
  // dump() requires BUILD_STATE_BUILT, so a holder kept past this reset can
  // never be dumped; drop it instead of keeping its data alive.
  holder_ = nullptr;
  state_ = BUILD_STATE_INIT;

  LOG_INFO("End FlatSparseBuilder::cleanup");
  return 0;
}

//! Flat sparse indexes need no training; this only advances the state.
int FlatSparseBuilder::train(IndexThreads::Pointer,
                             IndexSparseHolder::Pointer /*holder*/) {
  if (state_ != BUILD_STATE_INITED) {
    LOG_ERROR("Init the builder before FlatSparseBuilder::train");
    return IndexError_NoReady;
  }

  LOG_INFO("Begin FlatSparseBuilder::train");

  stats_.set_trained_count(0UL);
  stats_.set_trained_costtime(0UL);
  state_ = BUILD_STATE_TRAINED;

  LOG_INFO("End FlatSparseBuilder::train");
  return 0;
}

//! Trainer-based variant; likewise only advances the state.
int FlatSparseBuilder::train(const IndexTrainer::Pointer & /*trainer*/) {
  if (state_ != BUILD_STATE_INITED) {
    LOG_ERROR("Init the builder before FlatSparseBuilder::train");
    return IndexError_NoReady;
  }

  LOG_INFO("Begin FlatSparseBuilder::train by trainer");

  stats_.set_trained_count(0UL);
  stats_.set_trained_costtime(0UL);
  state_ = BUILD_STATE_TRAINED;

  LOG_INFO("End FlatSparseBuilder::train by trainer");
  return 0;
}

//! Take ownership of `holder` after checking it matches the index meta.
//! The vectors are only read back from the holder when dumping.
int FlatSparseBuilder::build(IndexThreads::Pointer,
                             IndexSparseHolder::Pointer holder) {
  LOG_INFO("Begin FlatSparseBuilder::build");

  ailego::ElapsedTime stamp;
  if (!holder) {
    LOG_ERROR("Input holder is nullptr while building index");
    return IndexError_InvalidArgument;
  }
  if (!holder->is_matched(meta_)) {
    LOG_ERROR("Input holder doesn't match index meta while building index");
    return IndexError_Mismatch;
  }

  holder_ = std::move(holder);
  stats_.set_built_count(holder_->count());
  stats_.set_built_costtime(stamp.milli_seconds());
  state_ = BUILD_STATE_BUILT;

  LOG_INFO("End FlatSparseBuilder::build");
  return 0;
}

//! Dump the serialized IndexMeta followed by the flat-sparse segments.
//! The holder is released after a successful dump.
int FlatSparseBuilder::dump(const IndexDumper::Pointer &dumper) {
  if (state_ != BUILD_STATE_BUILT || !holder_) {
    // A caller error: report it at error level, as train() does.
    LOG_ERROR("Build the index before FlatSparseBuilder::dump");
    return IndexError_NoReady;
  }

  LOG_INFO("Begin FlatSparseBuilder::dump");

  auto start_time = ailego::Monotime::MilliSeconds();
  int ret = IndexHelper::SerializeToDumper(meta_, dumper.get());
  if (ret != 0) {
    LOG_ERROR("Failed to serialize meta into dumper.");
    return ret;
  }

  uint32_t dump_count;
  ret = do_dump(dumper, &dump_count);
  if (ret != 0) {
    LOG_ERROR("Failed to dump index");
    return ret;
  }

  holder_ = nullptr;
  stats_.set_dumped_count(dump_count);
  stats_.set_dumped_costtime(ailego::Monotime::MilliSeconds() - start_time);

  LOG_INFO("End FlatSparseBuilder::dump");
  return 0;
}

//! Write the flat-sparse segments in order: meta, vector data + offsets,
//! keys, then the key-sorted mapping. `*dump_count` receives the record count.
int FlatSparseBuilder::do_dump(const IndexDumper::Pointer &dumper,
                               uint32_t *dump_count) {
  // bf meta
  int ret = dump_meta(dumper.get());
  if (ret != 0) {
    LOG_ERROR("Failed to dump meta");
    return ret;
  }

  std::vector<uint64_t> keys;
  ret = dump_vector_and_offset(dumper.get(), &keys);
  if (ret != 0) {
    LOG_ERROR("Failed to dump offset data");
    return ret;
  }

  ret = dump_keys(keys, dumper.get());
  if (ret != 0) {
    LOG_ERROR("Failed to dump keys");
    return ret;
  }

  ret = dump_mapping(keys, dumper.get());
  if (ret != 0) {
    LOG_ERROR("Failed to dump mapping");
    return ret;
  }

  *dump_count = static_cast<uint32_t>(keys.size());
  return 0;
}

//! Write the FlatSparseMeta segment, padded to a 32-byte boundary.
int FlatSparseBuilder::dump_meta(IndexDumper *dumper) {
  FlatSparseMeta meta;
  meta.create_time = ailego::Realtime::Seconds();
  meta.update_time = ailego::Realtime::Seconds();
  meta.doc_cnt = holder_->count();

  if (dumper->write(&meta, sizeof(meta)) != sizeof(meta)) {
    LOG_ERROR("Failed to write meta");
    return IndexError_WriteData;
  }

  size_t meta_padding_size = ailego_align(sizeof(meta), 32) - sizeof(meta);
  if (meta_padding_size) {
    std::string padding(meta_padding_size, '\0');
    if (dumper->write(padding.data(), meta_padding_size) !=
        meta_padding_size) {
      LOG_ERROR("Failed to write meta padding");
      return IndexError_WriteData;
    }
  }

  return dumper->append(PARAM_FLAT_SPARSE_META_SEG_ID, sizeof(meta),
                        meta_padding_size, 0);
}

//! Stream every holder vector into the data segment, then write an offset
//! segment of (uint64 offset, uint32 length) pairs in the same order.
//! The keys are collected into `*keys` in iteration order.
int FlatSparseBuilder::dump_vector_and_offset(IndexDumper *dumper,
                                              std::vector<uint64_t> *keys) {
  // iterate the holder
  auto iter = holder_->create_iterator();
  if (!iter) {
    LOG_ERROR("Failed to create iterator");
    return IndexError_Runtime;
  }

  uint64_t written_length{0U};
  std::vector<std::pair<uint64_t, uint32_t>> offset_lens;
  keys->reserve(keys->size() + holder_->count());
  offset_lens.reserve(holder_->count());

  while (iter->is_valid()) {
    keys->push_back(iter->key());

    uint32_t length;
    if (write_vector_data(iter->sparse_count(), iter->sparse_indices(),
                          iter->sparse_data(), dumper, &length) != 0) {
      return IndexError_WriteData;
    }

    offset_lens.push_back({written_length, length});
    written_length += length;

    iter->next();
  }

  if (dumper->append(PARAM_FLAT_SPARSE_DUMP_DATA_SEG_ID, written_length, 0,
                     0) != 0) {
    LOG_ERROR("Failed to append offset data");
    return IndexError_WriteData;
  }

  LOG_DEBUG("Data total written: %zu", (size_t)written_length);

  for (const auto &offset_len : offset_lens) {
    if (dumper->write(&offset_len.first, sizeof(offset_len.first)) !=
        sizeof(offset_len.first)) {
      LOG_ERROR("Failed to write offset");
      return IndexError_WriteData;
    }
    if (dumper->write(&offset_len.second, sizeof(offset_len.second)) !=
        sizeof(offset_len.second)) {
      LOG_ERROR("Failed to write length");
      return IndexError_WriteData;
    }
  }

  if (dumper->append(PARAM_FLAT_SPARSE_DUMP_OFFSET_SEG_ID,
                     offset_lens.size() * (sizeof(uint64_t) + sizeof(uint32_t)),
                     0, 0) != 0) {
    LOG_ERROR("Failed to append offset data");
    return IndexError_WriteData;
  }

  LOG_DEBUG("Offset total written: %zu",
            offset_lens.size() * (sizeof(uint64_t) + sizeof(uint32_t)));

  return 0;
}

//! Encode one sparse vector with SparseUtility::TransSparseFormat and write
//! it to the dumper; `*length` receives the encoded byte size.
int FlatSparseBuilder::write_vector_data(const uint32_t sparse_count,
                                         const uint32_t *sparse_indices,
                                         const void *sparse_vec,
                                         IndexDumper *dumper,
                                         uint32_t *length) {
  std::string sparse_buffer;
  SparseUtility::TransSparseFormat(sparse_count, sparse_indices, sparse_vec,
                                   meta_.unit_size(), sparse_buffer);

  // The offset segment stores each record length as uint32_t; refuse a
  // record that would be silently truncated and corrupt every later offset.
  if (sparse_buffer.size() > std::numeric_limits<uint32_t>::max()) {
    LOG_ERROR("Sparse data too large to dump: %zu bytes",
              sparse_buffer.size());
    return IndexError_WriteData;
  }

  if (dumper->write(sparse_buffer.data(), sparse_buffer.size()) !=
      sparse_buffer.size()) {
    LOG_ERROR("Failed to write sparse data");
    return IndexError_WriteData;
  }

  *length = static_cast<uint32_t>(sparse_buffer.size());
  return 0;
}

//! Write the keys segment (uint64 per record), padded to 32 bytes.
int FlatSparseBuilder::dump_keys(const std::vector<uint64_t> &keys,
                                 IndexDumper *dumper) {
  size_t keys_size = keys.size() * sizeof(uint64_t);
  if (dumper->write(keys.data(), keys_size) != keys_size) {
    LOG_ERROR("Failed to write keys to dumper %s", dumper->name().c_str());
    return IndexError_WriteData;
  }

  size_t keys_padding_size = ailego_align(keys_size, 32) - keys_size;
  if (keys_padding_size) {
    std::string padding(keys_padding_size, '\0');
    if (dumper->write(padding.data(), padding.size()) != padding.size()) {
      LOG_ERROR("Failed to write padding to dumper %s",
                dumper->name().c_str());
      return IndexError_WriteData;
    }
  }

  return dumper->append(PARAM_FLAT_SPARSE_DUMP_KEYS_SEG_ID, keys_size,
                        keys_padding_size, 0);
}

//! Write the mapping segment: record indices (uint32) ordered by ascending
//! key, padded to 32 bytes.
int FlatSparseBuilder::dump_mapping(const std::vector<uint64_t> &keys,
                                    IndexDumper *dumper) {
  std::vector<uint32_t> mapping(keys.size());
  std::iota(mapping.begin(), mapping.end(), 0);
  std::sort(
      mapping.begin(), mapping.end(),
      [&keys](uint32_t lhs, uint32_t rhs) { return (keys[lhs] < keys[rhs]); });

  size_t mapping_size = mapping.size() * sizeof(uint32_t);
  size_t mapping_padding_size = ailego_align(mapping_size, 32) - mapping_size;
  if (dumper->write(mapping.data(), mapping_size) != mapping_size) {
    LOG_ERROR("Failed to write data into dumper %s", dumper->name().c_str());
    return IndexError_WriteData;
  }

  // Write the padding if need
  if (mapping_padding_size) {
    std::string padding(mapping_padding_size, '\0');
    if (dumper->write(padding.data(), padding.size()) != padding.size()) {
      LOG_ERROR("Failed to write data into dumper %s", dumper->name().c_str());
      return IndexError_WriteData;
    }
  }
  return dumper->append(PARAM_FLAT_SPARSE_DUMP_MAPPING_SEG_ID, mapping_size,
                        mapping_padding_size, 0);
}

INDEX_FACTORY_REGISTER_BUILDER(FlatSparseBuilder);

}  // namespace core
}  // namespace zvec
================================================
FILE: src/core/algorithm/flat_sparse/flat_sparse_builder.h
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

// NOTE(review): the six angle-bracket #include lines below lost their header
// names during extraction; restore them from upstream before compiling.
#include
#include
#include
#include
#include
#include

namespace zvec {
namespace core {

/*! Brute Force Sparse Builder
 */
class FlatSparseBuilder : public IndexBuilder {
 public:
  //! Constructor
  FlatSparseBuilder();

  //! Initialize the builder
  int init(const IndexMeta &meta, const ailego::Params &params) override;

  //! Cleanup the builder
  int cleanup(void) override;

  //! Train the data
  int train(IndexThreads::Pointer, IndexSparseHolder::Pointer holder) override;

  //! Train the data
  int train(const IndexTrainer::Pointer &trainer) override;

  //! Dense-holder training is not supported by a sparse builder
  int train(IndexThreads::Pointer /*threads*/,
            IndexHolder::Pointer /*holder*/) override {
    return IndexError_NotImplemented;
  }

  //! Dense-holder building is not supported by a sparse builder
  int build(IndexThreads::Pointer /*threads*/,
            IndexHolder::Pointer /*holder*/) override {
    return IndexError_NotImplemented;
  }

  //! Build the index
  int build(IndexThreads::Pointer threads,
            IndexSparseHolder::Pointer holder) override;

  //! Dump index into storage
  int dump(const IndexDumper::Pointer &dumper) override;

  //! Retrieve statistics
  const Stats &stats(void) const override {
    return stats_;
  }

 private:
  int do_dump(const IndexDumper::Pointer &dumper, uint32_t *dump_count);

  int dump_meta(IndexDumper *dumper);

  int dump_keys(const std::vector<uint64_t> &keys, IndexDumper *dumper);

  int dump_mapping(const std::vector<uint64_t> &keys, IndexDumper *dumper);

  int dump_vector_and_offset(IndexDumper *dumper, std::vector<uint64_t> *keys);

  int write_vector_data(const uint32_t sparse_count,
                        const uint32_t *sparse_indices, const void *sparse_vec,
                        IndexDumper *dumper, uint32_t *length);

 private:
  enum BUILD_STATE {
    BUILD_STATE_INIT = 0,
    BUILD_STATE_INITED = 1,
    BUILD_STATE_TRAINED = 2,
    BUILD_STATE_BUILT = 3
  };

  IndexSparseHolder::Pointer holder_{};
  std::atomic_bool error_{false};
  IndexMeta meta_{};
  IndexMetric::Pointer measure_{};
  std::mutex mutex_{};
  std::condition_variable cond_{};
  Stats stats_{};
  BUILD_STATE state_{BUILD_STATE_INIT};
};

}  // namespace core
}  // namespace zvec
================================================
FILE: src/core/algorithm/flat_sparse/flat_sparse_context.cc
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "flat_sparse_context.h"

namespace zvec {
namespace core {

//! Resolve the entity of whichever owner (streamer or searcher) this context
//! was created for; nullptr when the context type is unknown.
const FlatSparseEntity *FlatSparseContext::entity() const {
  if (context_type_ == kStreamerContext) {
    return &streamer_owner_->entity();
  }
  if (context_type_ == kSearcherContext) {
    return &searcher_owner_->entity();
  }
  return nullptr;
}

}  // namespace core
}  // namespace zvec
================================================
FILE: src/core/algorithm/flat_sparse/flat_sparse_context.h
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

// NOTE(review): the four angle-bracket #include lines below lost their header
// names during extraction; restore them from upstream before compiling.
#include
#include
#include
#include
#include "flat_sparse_entity.h"
#include "flat_sparse_searcher.h"
#include "flat_sparse_streamer.h"

namespace zvec {
namespace core {

class FlatSparseStreamer;
class FlatSparseSearcher;

/*! Brute Force Sparse Streamer Context
 */
class FlatSparseContext : public IndexContext {
 public:
  //! Which kind of owner created this context
  enum ContextType {
    kUnknownContext = 0,
    kSearcherContext = 1,
    kStreamerContext = 3
  };

  //! Constructor
  FlatSparseContext(const FlatSparseStreamer *streamer_ptr)
      : streamer_owner_(streamer_ptr), context_type_(kStreamerContext) {}

  FlatSparseContext(const FlatSparseSearcher *searcher_ptr)
      : searcher_owner_(searcher_ptr), context_type_(kSearcherContext) {}

  //! Destructor
  virtual ~FlatSparseContext(void) = default;

  //! Set topk of search result
  void set_topk(uint32_t topk) override {
    topk_ = topk;
    result_heap_.limit(topk_);
    result_heap_.set_threshold(this->threshold());
  }

  //! Retrieve search result
  const IndexDocumentList &result(void) const override {
    return results_.at(0);
  }

  //! Retrieve search result with index
  const IndexDocumentList &result(size_t index) const override {
    return results_.at(index);
  }

  //! Retrieve result object for output
  IndexDocumentList *mutable_result(size_t idx) override {
    return &results_.at(idx);
  }

  inline IndexDocumentHeap *result_heap() {
    return &result_heap_;
  }

  //! Update the parameters of context
  int update(const ailego::Params & /*params*/) override {
    return 0;
  }

  //! Retrieve magic number
  uint32_t magic(void) const override {
    return magic_;
  }

  void set_fetch_vector(bool v) override {
    fetch_vector_ = v;
  }

  bool fetch_vector() const override {
    return fetch_vector_;
  }

  //! Retrieve search group result (bounds-checked, like result())
  const IndexGroupDocumentList &group_result(void) const override {
    return group_results_.at(0);
  }

  //! Retrieve search group result with index (bounds-checked, like result())
  const IndexGroupDocumentList &group_result(size_t idx) const override {
    return group_results_.at(idx);
  }

  //! Retrieve group result object for output (bounds-checked)
  IndexGroupDocumentList *mutable_group_result(size_t idx) {
    return &group_results_.at(idx);
  }

  //! Set group params
  void set_group_params(uint32_t group_num, uint32_t group_topk) override {
    group_num_ = group_num;
    group_topk_ = group_topk;
    result_group_heap_.clear();
  }

  //! Get if group by search
  inline bool group_by_search() {
    return group_num_ > 0;
  }

  inline uint32_t group_topk() const {
    return group_topk_;
  }

  inline uint32_t group_num() const {
    return group_num_;
  }

  void reset() override {}

  //! Reset the context
  void reset(const FlatSparseStreamer *streamer_ptr) {
    magic_ = streamer_ptr->magic();
    streamer_owner_ = streamer_ptr;
    context_type_ = kStreamerContext;
  }

  void reset(const FlatSparseSearcher *searcher_ptr) {
    magic_ = searcher_ptr->magic();
    searcher_owner_ = searcher_ptr;
    context_type_ = kSearcherContext;
  }

  //! Reset all the query results
  void reset_results(size_t qnum) {
    // Per-query stats are sized in both search modes so mutable_stats(q) is
    // valid whichever path ran (previously only the non-group path sized it).
    stats_vec_.resize(qnum);
    for (auto &stats : stats_vec_) {
      stats.clear();
    }

    if (group_by_search()) {
      group_results_.resize(qnum);
    } else {
      result_heap_.clear();
      result_heap_.limit(topk_);
      result_heap_.set_threshold(this->threshold());
      results_.resize(qnum);
      for (auto &result : results_) {
        result.clear();
      }
    }
  }

  Stats *mutable_stats(size_t idx = 0) {
    ailego_assert_with(stats_vec_.size() > idx, "invalid index");
    return &stats_vec_[idx];
  }

  //! Move the sorted heap contents (up to topk, stopping at the first score
  //! above the threshold) into results_[idx]. When fetch_vector() is set,
  //! each document also carries its decoded sparse vector.
  inline void topk_to_result(uint32_t idx) {
    if (ailego_unlikely(result_heap_.size() == 0)) {
      return;
    }

    ailego_assert_with(idx < results_.size(), "invalid idx");
    uint32_t size =
        std::min(topk_, static_cast<uint32_t>(result_heap_.size()));
    result_heap_.sort();
    results_[idx].clear();

    for (uint32_t i = 0; i < size; ++i) {
      auto score = result_heap_[i].score();
      if (score > this->threshold()) {
        break;
      }

      key_t key = result_heap_[i].key();
      if (fetch_vector_) {
        node_id_t id = entity()->get_id(key);
        IndexStorage::MemoryBlock vec_block;
        entity()->get_sparse_vector(id, vec_block);
        const void *sparse_data = vec_block.data();
        IndexSparseDocument sparse_doc;
        if (sparse_data != nullptr) {
          SparseUtility::ReverseSparseFormat(sparse_data, sparse_doc,
                                             entity()->sparse_unit_size());
        }
        results_[idx].emplace_back(key, score, id, nullptr, sparse_doc);
      } else {
        results_[idx].emplace_back(key, score);
      }
    }
  }

 private:
  const FlatSparseEntity *entity() const;

 private:
  const FlatSparseStreamer *streamer_owner_{nullptr};
  const FlatSparseSearcher *searcher_owner_{nullptr};
  ContextType context_type_{kUnknownContext};
  std::vector<Stats> stats_vec_{};
  uint32_t magic_{0};
  uint32_t topk_{0};
  IndexDocumentHeap result_heap_;
  // std::string batch_queries_{};
  bool fetch_vector_{false};

  // group
  uint32_t group_num_{0};
  uint32_t group_topk_{0};
  // NOTE(review): template arguments were lost in extraction; string group
  // ids match the group maps used by the flat sparse search — confirm.
  std::map<std::string, IndexDocumentHeap> result_group_heap_{};
  std::vector<IndexDocumentList> results_{};
  std::vector<IndexGroupDocumentList> group_results_{};
};

}  // namespace core
}  // namespace zvec
================================================
FILE: 
src/core/algorithm/flat_sparse/flat_sparse_entity.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include "flat_sparse_index_format.h" namespace zvec { namespace core { using node_id_t = uint32_t; constexpr node_id_t kInvalidNodeId = static_cast(-1); /*! Flat Sparse Entity */ class FlatSparseEntity { public: typedef std::shared_ptr Pointer; //! Constructor explicit FlatSparseEntity() {} //! Destructor ~FlatSparseEntity() = default; //! Disable them FlatSparseEntity(const FlatSparseEntity &) = delete; FlatSparseEntity &operator=(const FlatSparseEntity &) = delete; //! Search in linear list with filter int search(const std::string &sparse_vector, const IndexFilter &filter, IndexDocumentHeap *heap) const { for (node_id_t i = 0; i < doc_cnt(); i++) { uint64_t key = get_key(i); if (ailego_unlikely(key == kInvalidKey)) { // LOG_ERROR("The key of node_id[%u] not found in keys map", i); // return IndexError_Runtime; continue; } if (!filter.is_valid() || !filter(key)) { float dist = get_search_distance(sparse_vector, i); heap->emplace(key, dist); } } return 0; } //! 
Search in linear list with filter and target pkeys int search_p_keys(const std::string &sparse_vector, const std::vector &p_keys, const IndexFilter &filter, IndexDocumentHeap *heap) const { for (auto p_key : p_keys) { if (!filter.is_valid() || !filter(p_key)) { auto node_id = get_id(p_key); if (node_id != kInvalidNodeId) { float dist = get_search_distance(sparse_vector, node_id); heap->emplace(p_key, dist); } } } return 0; } //! Group search in linear list with filter int search_group( const std::string &sparse_vector, const IndexFilter &filter, const std::function &group_by_func, uint32_t topk, std::unordered_map *heap) const { for (node_id_t i = 0; i < doc_cnt(); i++) { uint64_t key = get_key(i); if (ailego_unlikely(key == kInvalidKey)) { LOG_ERROR("The key of node_id[%u] not found in keys map", i); return IndexError_Runtime; } if (!filter.is_valid() || !filter(key)) { float dist = get_search_distance(sparse_vector, i); std::string group_id = group_by_func(key); auto &group_heap = (*heap)[group_id]; if (group_heap.empty()) { group_heap.limit(topk); } group_heap.emplace(key, dist); } } return 0; } //! Group search in linear list with filter and target pkeys int search_group_p_keys( const std::string &sparse_vector, const std::vector &p_keys, const IndexFilter &filter, const std::function &group_by_func, uint32_t topk, std::unordered_map *heap) const { for (auto p_key : p_keys) { if (!filter.is_valid() || !filter(p_key)) { auto node_id = get_id(p_key); if (node_id != kInvalidNodeId) { float dist = get_search_distance(sparse_vector, node_id); std::string group_id = group_by_func(p_key); auto &group_heap = (*heap)[group_id]; if (group_heap.empty()) { group_heap.limit(topk); } group_heap.emplace(p_key, dist); } } } return 0; } //! 
Get sparse vector by key int get_sparse_vector(uint64_t key, std::string *sparse_vector) const { const void *sparse_vector_ptr; uint32_t sparse_vector_len; int ret = get_sparse_vector_ptr_by_key(key, &sparse_vector_ptr, &sparse_vector_len); if (ret != 0) { return ret; } *sparse_vector = std::string(static_cast(sparse_vector_ptr), sparse_vector_len); return 0; } //! Get sparse vector by node id const void *get_sparse_vector(node_id_t id) const { const void *sparse_vector_ptr; uint32_t sparse_vector_len; int ret = get_sparse_vector_ptr_by_id(id, &sparse_vector_ptr, &sparse_vector_len); if (ret != 0) { return nullptr; } return sparse_vector_ptr; } int get_sparse_vector_by_key(const uint64_t key, std::string *sparse_vector) const { uint32_t sparse_vector_len; IndexStorage::MemoryBlock sparse_vector_block; int ret = get_sparse_vector_ptr_by_key(key, sparse_vector_block, &sparse_vector_len); if (ret != 0) { return ret; } *sparse_vector = std::string(static_cast(sparse_vector_block.data()), sparse_vector_len); return 0; } int get_sparse_vector(node_id_t id, IndexStorage::MemoryBlock &sparse_vector_block) const { uint32_t sparse_vector_len; return get_sparse_vector_ptr_by_id(id, sparse_vector_block, &sparse_vector_len); } int get_sparse_vector_ptr_by_key(uint64_t key, const void **sparse_vector_ptr, uint32_t *sparse_vector_len_ptr) const { auto node_id = get_id(key); if (node_id == kInvalidNodeId) { return IndexError_NoExist; } return get_sparse_vector_ptr_by_id(node_id, sparse_vector_ptr, sparse_vector_len_ptr); } int get_sparse_vector_ptr_by_key( const uint64_t key, IndexStorage::MemoryBlock &sparse_vector_block, uint32_t *sparse_vector_len_ptr) const { auto node_id = get_id(key); if (node_id == kInvalidNodeId) { return IndexError_NoExist; } return get_sparse_vector_ptr_by_id(node_id, sparse_vector_block, sparse_vector_len_ptr); } std::vector get_keys() const { std::vector keys; node_id_t doc_total_cnt = doc_cnt(); for (node_id_t node_id = 0; node_id < doc_total_cnt; 
++node_id) { uint64_t key = get_key(node_id); if (key == kInvalidKey) { return {kInvalidKey}; } else { keys.push_back(key); } } return keys; } public: virtual uint32_t doc_cnt() const = 0; virtual uint32_t total_sparse_count() const = 0; virtual node_id_t get_id(uint64_t key) const = 0; virtual uint64_t get_key(node_id_t id) const = 0; virtual int get_sparse_vector_ptr_by_id( node_id_t id, const void **sparse_vector, uint32_t *sparse_vector_len) const = 0; virtual int get_sparse_vector_ptr_by_id( const node_id_t /*id*/, IndexStorage::MemoryBlock & /*sparse_vector_block*/, uint32_t * /*sparse_vector_len*/) const { return IndexError_NotImplemented; } virtual float get_search_distance(const std::string &vector, node_id_t target_node_id) const = 0; virtual size_t sparse_unit_size() const = 0; }; } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/flat_sparse/flat_sparse_index_format.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#pragma once #include namespace zvec { namespace core { static constexpr uint64_t kInvalidKey = std::numeric_limits::max(); static constexpr uint32_t kDefaultOffsetChunkSize = 1024 * 1024; // 1MB static constexpr uint32_t kDefaultDataChunkSize = 8 * 1024 * 1024; // 8MB struct FlatSparseMeta { uint64_t create_time{0}; uint64_t update_time{0}; uint32_t doc_cnt{0}; uint32_t total_sparse_count{0}; uint8_t reserved[8] = {0}; }; static_assert(sizeof(FlatSparseMeta) % 32 == 0, "FlatSparseMeta must be aligned with 32 bytes"); struct FlatSparseStreamerMeta { uint32_t offset_chunk_count{0}; uint32_t offset_chunk_size{kDefaultOffsetChunkSize}; uint32_t data_chunk_count{0}; uint32_t data_chunk_size{kDefaultDataChunkSize}; }; } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/flat_sparse/flat_sparse_provider.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include #include #include "flat_sparse_streamer_entity.h" namespace zvec { namespace core { /*! Brute Force Sparse Streamer Provider */ // FlatSparseStreamerEntity or FlatSparseSearcherEntity template class FlatSparseIndexProvider : public IndexSparseProvider { public: //! Constructor FlatSparseIndexProvider(const std::shared_ptr entity, const IndexMeta &meta, const std::string &owner) : entity_(entity), meta_(meta), owner_class_(owner) {} //! 
Create a new iterator IndexSparseProvider::Iterator::Pointer create_iterator(void) override { return IndexSparseProvider::Iterator::Pointer(new (std::nothrow) Iterator(entity_, meta_)); } //! Retrieve count of vectors size_t count(void) const override { return entity_->doc_cnt(); } //! Retrieve type of vector IndexMeta::DataType data_type(void) const override { return meta_.data_type(); } //! Retrieve a vector using a primary key int get_sparse_vector(uint64_t key, uint32_t *sparse_count, std::string *sparse_indices_buffer, std::string *sparse_values_buffer) const override { std::string sparse_data; int ret = entity_->get_sparse_vector_by_key(key, &sparse_data); if (ailego_unlikely(ret != 0)) { LOG_ERROR("Failed to get sparse vector, key=%zu, ret=%s", (size_t)key, IndexError::What(ret)); return ret; } SparseUtility::ReverseSparseFormat(sparse_data, sparse_count, sparse_indices_buffer, sparse_values_buffer, meta_.unit_size()); return 0; } //! Retrieve the owner class const std::string &owner_class(void) const override { return owner_class_; } size_t total_sparse_count() const override { return entity_->total_sparse_count(); } private: class Iterator : public IndexSparseProvider::Iterator { public: Iterator(const std::shared_ptr &entity, const IndexMeta &meta) : entity_(entity), meta_(meta), cur_id_(0U), valid_(false) { IndexStorage::MemoryBlock sparse_data_block; entity_->get_sparse_vector(cur_id_, sparse_data_block); const void *sparse_data = sparse_data_block.data(); if (sparse_data != nullptr) { valid_ = true; sparse_indices_buffer_.clear(); sparse_data_buffer_.clear(); SparseUtility::ReverseSparseFormat( sparse_data, &sparse_count_, &sparse_indices_buffer_, &sparse_data_buffer_, meta.unit_size()); } } //! Retrieve sparse count virtual uint32_t sparse_count() const override { return sparse_count_; } //! Retrieve sparse indices virtual const uint32_t *sparse_indices() const override { return reinterpret_cast(sparse_indices_buffer_.data()); } //! 
Retrieve sparse data virtual const void *sparse_data() const override { return reinterpret_cast(sparse_data_buffer_.data()); } //! Test if the iterator is valid virtual bool is_valid(void) const override { return cur_id_ < entity_->doc_cnt() && valid_; } //! Retrieve primary key virtual uint64_t key(void) const override { // std::cout << "iter key=" << cur_id_ << std::endl; return entity_->get_key(cur_id_); } //! Next iterator virtual void next(void) override { cur_id_ = get_next_valid_id(cur_id_ + 1); if (cur_id_ < entity_->doc_cnt()) { IndexStorage::MemoryBlock sparse_data_block; entity_->get_sparse_vector(cur_id_, sparse_data_block); const void *sparse_data = sparse_data_block.data(); if (sparse_data != nullptr) { valid_ = true; sparse_indices_buffer_.clear(); sparse_data_buffer_.clear(); SparseUtility::ReverseSparseFormat( sparse_data, &sparse_count_, &sparse_indices_buffer_, &sparse_data_buffer_, meta_.unit_size()); } else { valid_ = false; } } } //! Reset the iterator void reset(void) { cur_id_ = get_next_valid_id(0); IndexStorage::MemoryBlock sparse_data_block; entity_->get_sparse_vector(cur_id_, sparse_data_block); const void *sparse_data = sparse_data_block.data(); if (sparse_data != nullptr) { valid_ = true; SparseUtility::ReverseSparseFormat( sparse_data, &sparse_count_, &sparse_indices_buffer_, &sparse_data_buffer_, meta_.unit_size()); } } private: node_id_t get_next_valid_id(node_id_t start_id) { for (node_id_t i = start_id; i < entity_->doc_cnt(); i++) { if (entity_->get_key(i) != kInvalidNodeId) { return i; } } return kInvalidNodeId; } private: const std::shared_ptr entity_{nullptr}; const IndexMeta &meta_; node_id_t cur_id_; uint32_t sparse_count_; std::string sparse_indices_buffer_; std::string sparse_data_buffer_; bool valid_{false}; }; private: const std::shared_ptr entity_{nullptr}; const IndexMeta &meta_; const std::string owner_class_; }; } // namespace core } // namespace zvec ================================================ FILE: 
src/core/algorithm/flat_sparse/flat_sparse_search.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include "flat_sparse_context.h" namespace zvec { namespace core { static inline IndexGroupDocumentList ConvertGroupMapToResult( std::unordered_map group_map, uint32_t group_num) { IndexGroupDocumentList result; std::vector> best_score_in_groups; for (auto itr = group_map.begin(); itr != group_map.end(); itr++) { const std::string &group_id = (*itr).first; auto &heap = (*itr).second; if (heap.size() > 0) { float best_score = heap[0].score(); best_score_in_groups.push_back(std::make_pair(group_id, best_score)); } } std::sort(best_score_in_groups.begin(), best_score_in_groups.end(), [](const std::pair &a, const std::pair &b) -> int { return a.second < b.second; }); // truncate to group num for (uint32_t i = 0; i < group_num && i < best_score_in_groups.size(); ++i) { const std::string &group_id = best_score_in_groups[i].first; result.emplace_back( GroupIndexDocument(group_id, std::move(group_map[group_id]))); } return result; } static inline int FlatSearch(const uint32_t *sparse_count, const uint32_t *sparse_indices, const void *sparse_query, bool with_p_keys, const std::vector> &p_keys, const IndexQueryMeta &qmeta, uint32_t count, const IndexMeta, IndexContext::Pointer &context, FlatSparseEntity *entity) { int ret; FlatSparseContext *ctx = 
dynamic_cast(context.get()); ailego_do_if_false(ctx) { LOG_ERROR("Cast context to FlatSparseContext failed"); return IndexError_Cast; } // reset context results ctx->reset_results(count); const uint32_t *sparse_indices_tmp = sparse_indices; const void *sparse_query_tmp = sparse_query; if (ctx->group_by_search()) { if (!ctx->group_by().is_valid()) { LOG_ERROR("Invalid group-by function"); return IndexError_InvalidArgument; } std::function group_by = [&](uint64_t key) { return ctx->group_by()(key); }; for (size_t q = 0; q < count; ++q) { std::string sparse_query_buffer; ailego::MinusInnerProductSparseMatrix::transform_sparse_format( sparse_count[q], sparse_indices_tmp, sparse_query_tmp, sparse_query_buffer); std::unordered_map group_heap{}; if (with_p_keys) { ret = entity->search_group_p_keys(sparse_query_buffer, p_keys[q], ctx->filter(), group_by, ctx->group_topk(), &group_heap); } else { ret = entity->search_group(sparse_query_buffer, ctx->filter(), group_by, ctx->group_topk(), &group_heap); } if (ailego_unlikely(ret != 0)) { LOG_ERROR("Failed to search group, ret=%s", IndexError::What(ret)); return ret; } // sort group heap for (auto &group : group_heap) { group.second.sort(); } auto group_result = ConvertGroupMapToResult(std::move(group_heap), ctx->group_num()); ctx->mutable_group_result(q)->swap(group_result); } } else { for (size_t q = 0; q < count; ++q) { std::string sparse_query_buffer; ailego::MinusInnerProductSparseMatrix::transform_sparse_format( sparse_count[q], sparse_indices_tmp, sparse_query_tmp, sparse_query_buffer); auto heap = ctx->result_heap(); if (with_p_keys) { ret = entity->search_p_keys(sparse_query_buffer, p_keys[q], ctx->filter(), heap); } else { ret = entity->search(sparse_query_buffer, ctx->filter(), heap); } if (ailego_unlikely(ret != 0)) { LOG_ERROR("Failed to search, ret=%s", IndexError::What(ret)); return ret; } ctx->topk_to_result(q); sparse_indices_tmp += sparse_count[q]; sparse_query_tmp = reinterpret_cast(sparse_query_tmp) + 
sparse_count[q] * qmeta.unit_size(); } } return 0; } } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/flat_sparse/flat_sparse_searcher.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "flat_sparse_searcher.h" #include #include #include "flat_sparse_context.h" #include "flat_sparse_provider.h" #include "flat_sparse_search.h" namespace zvec { namespace core { const uint32_t FlatSparseSearcher::VERSION = 0U; FlatSparseSearcher::FlatSparseSearcher(void) {} FlatSparseSearcher::~FlatSparseSearcher(void) {} int FlatSparseSearcher::init(const ailego::Params & /*params*/) { state_ = STATE_INITED; return 0; } int FlatSparseSearcher::cleanup(void) { this->unload(); return 0; } int FlatSparseSearcher::load(IndexStorage::Pointer container, IndexMetric::Pointer /*measure*/) { if (state_ != STATE_INITED) { LOG_ERROR("Init the searcher first before load index"); return IndexError_Runtime; } LOG_INFO("Begin FlatSparseSearcher::load"); int ret = IndexHelper::DeserializeFromStorage(container.get(), &meta_); if (ret != 0) { LOG_ERROR("Failed to deserialize meta from container"); return ret; } if (meta_.searcher_revision() != VERSION) { LOG_ERROR("Unsupported searcher revision %u", meta_.searcher_revision()); return IndexError_Unsupported; } ret = entity_.load(container, meta_); if (ret != 0) { LOG_ERROR("FlatSparseSearcher 
load index failed"); return ret; } state_ = STATE_LOADED; magic_ = IndexContext::GenerateMagic(); LOG_INFO("End FlatSparseSearcher::load"); return 0; } int FlatSparseSearcher::unload(void) { LOG_INFO("Begin FlatSparseSearcher::unload"); meta_.clear(); entity_.unload(); state_ = STATE_INITED; LOG_INFO("End FlatSparseSearcher::unload"); return 0; } int FlatSparseSearcher::search_bf_impl(const uint32_t *sparse_count, const uint32_t *sparse_indices, const void *sparse_query, const IndexQueryMeta &qmeta, uint32_t count, Context::Pointer &context) const { return do_search(sparse_count, sparse_indices, sparse_query, false, {}, qmeta, count, context); } int FlatSparseSearcher::search_bf_by_p_keys_impl( const uint32_t *sparse_count, const uint32_t *sparse_indices, const void *sparse_query, const std::vector> &p_keys, const IndexQueryMeta &qmeta, uint32_t count, ContextPointer &context) const { return do_search(sparse_count, sparse_indices, sparse_query, true, p_keys, qmeta, count, context); } int FlatSparseSearcher::get_sparse_vector( uint64_t key, uint32_t *sparse_count, std::string *sparse_indices_buffer, std::string *sparse_values_buffer) const { if (state_ != STATE_LOADED) { LOG_ERROR("Failed to get sparse vector, load container first!"); return IndexError_NoIndexLoaded; } std::string sparse_data; int ret = entity_.get_sparse_vector(key, &sparse_data); if (ailego_unlikely(ret != 0)) { LOG_ERROR("Failed to get sparse vector, key=%zu, ret=%s", (size_t)key, IndexError::What(ret)); return ret; } SparseUtility::ReverseSparseFormat(sparse_data, sparse_count, sparse_indices_buffer, sparse_values_buffer, meta_.unit_size()); return 0; } FlatSparseSearcher::ContextPointer FlatSparseSearcher::create_context() const { if (state_ != STATE_LOADED) { LOG_ERROR("Failed to create Context, load container first!"); return Context::UPointer(); } FlatSparseSearcherEntity::Pointer entity = entity_.clone(); return FlatSparseSearcher::ContextPointer(new FlatSparseContext(this)); } //! 
Create a new iterator IndexSearcher::SparseProvider::Pointer FlatSparseSearcher::create_sparse_provider(void) const { if (state_ != STATE_LOADED) { LOG_ERROR("Failed to create provider, load container first!"); return SparseProvider::Pointer(); } auto entity = entity_.clone(); if (ailego_unlikely(!entity)) { LOG_ERROR("Clone entity failed"); return SparseProvider::Pointer(); } return SparseProvider::Pointer( new FlatSparseIndexProvider( entity, meta_, "FlatSparseSearcher")); } int FlatSparseSearcher::do_search( const uint32_t *sparse_count, const uint32_t *sparse_indices, const void *sparse_query, bool with_p_keys, const std::vector> &p_keys, const IndexQueryMeta &qmeta, uint32_t count, ContextPointer &context) const { if (state_ != STATE_LOADED) { LOG_ERROR("Failed to do search, load container first!"); return IndexError_NoIndexLoaded; } int ret = check_params(qmeta); if (ailego_unlikely(ret != 0)) { return ret; } return FlatSearch(sparse_count, sparse_indices, sparse_query, with_p_keys, p_keys, qmeta, count, meta_, context, (FlatSparseEntity *)&entity_); } INDEX_FACTORY_REGISTER_SEARCHER(FlatSparseSearcher); } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/flat_sparse/flat_sparse_searcher.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#pragma once #include "flat_sparse_searcher_entity.h" namespace zvec { namespace core { class FlatSparseSearcher : public IndexSearcher { public: static const uint32_t VERSION; public: using ContextPointer = IndexSearcher::Context::Pointer; public: FlatSparseSearcher(void); virtual ~FlatSparseSearcher(void); FlatSparseSearcher(const FlatSparseSearcher &) = delete; FlatSparseSearcher &operator=(const FlatSparseSearcher &) = delete; public: //! Initialize Searcher int init(const ailego::Params ¶ms) override; //! Cleanup Searcher int cleanup(void) override; //! Load Index from storage int load(IndexStorage::Pointer container, IndexMetric::Pointer /*measure*/) override; //! Unload index from storage int unload(void) override; int search_impl(const void * /*query*/, const IndexQueryMeta & /*qmeta*/, Context::Pointer & /*context*/) const override { return IndexError_NotImplemented; } int search_impl(const void * /*query*/, const IndexQueryMeta & /*qmeta*/, uint32_t /*count*/, Context::Pointer & /*context*/) const override { return IndexError_NotImplemented; } int search_bf_impl(const void * /*query*/, const IndexQueryMeta & /*qmeta*/, Context::Pointer & /*context*/) const override { return IndexError_NotImplemented; } int search_bf_impl(const void * /*query*/, const IndexQueryMeta & /*qmeta*/, uint32_t /*count*/, Context::Pointer & /*context*/) const override { return IndexError_NotImplemented; } //! Similarity search with sparse inputs int search_impl(const uint32_t sparse_count, const uint32_t *sparse_indices, const void *sparse_query, const IndexQueryMeta &qmeta, Context::Pointer &context) const override { return search_impl(&sparse_count, sparse_indices, sparse_query, qmeta, 1, context); } //! 
Similarity search with sparse inputs int search_impl(const uint32_t *sparse_count, const uint32_t *sparse_indices, const void *sparse_query, const IndexQueryMeta &qmeta, uint32_t count, Context::Pointer &context) const override { return search_bf_impl(sparse_count, sparse_indices, sparse_query, qmeta, count, context); } //! Similarity brute force search with sparse inputs int search_bf_impl(const uint32_t sparse_count, const uint32_t *sparse_indices, const void *sparse_query, const IndexQueryMeta &qmeta, Context::Pointer &context) const override { return search_bf_impl(&sparse_count, sparse_indices, sparse_query, qmeta, 1, context); } //! Similarity brute force search with sparse inputs int search_bf_impl(const uint32_t *sparse_count, const uint32_t *sparse_indices, const void *sparse_query, const IndexQueryMeta &qmeta, uint32_t count, Context::Pointer &context) const override; //! Linear search by primary keys int search_bf_by_p_keys_impl(const uint32_t sparse_count, const uint32_t *sparse_indices, const void *sparse_query, const std::vector> &p_keys, const IndexQueryMeta &qmeta, ContextPointer &context) const override { return search_bf_by_p_keys_impl(&sparse_count, sparse_indices, sparse_query, p_keys, qmeta, 1, context); } //! Linear search by primary keys int search_bf_by_p_keys_impl(const uint32_t *sparse_count, const uint32_t *sparse_indices, const void *sparse_query, const std::vector> &p_keys, const IndexQueryMeta &qmeta, uint32_t count, ContextPointer &context) const override; //! Fetch sparser vector by key int get_sparse_vector(uint64_t key, uint32_t *sparse_count, std::string *sparse_indices_buffer, std::string *sparse_values_buffer) const override; //! Create a searcher context ContextPointer create_context() const override; //! Create a new iterator IndexSearcher::SparseProvider::Pointer create_sparse_provider( void) const override; //! Retrieve statistics const Stats &stats(void) const override { return stats_; } //! 
Retrieve meta of index const IndexMeta &meta(void) const override { return meta_; } //! Retrieve params of index const ailego::Params ¶ms(void) const override { return params_; } const FlatSparseSearcherEntity &entity(void) const { return entity_; } uint32_t magic(void) const { return magic_; } private: inline int check_params(const IndexQueryMeta &qmeta) const { if (ailego_unlikely(qmeta.data_type() != meta_.data_type())) { LOG_ERROR("Unsupported query meta"); return IndexError_Mismatch; } return 0; } int do_search(const uint32_t *sparse_count, const uint32_t *sparse_indices, const void *sparse_query, bool with_p_keys, const std::vector> &p_keys, const IndexQueryMeta &qmeta, uint32_t count, ContextPointer &context) const; private: enum State { STATE_INIT = 0, STATE_INITED = 1, STATE_LOADED = 2 }; FlatSparseSearcherEntity entity_{}; IndexMeta meta_{}; ailego::Params params_{}; uint32_t magic_{0U}; Stats stats_; State state_{STATE_INIT}; }; } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/flat_sparse/flat_sparse_searcher_entity.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "flat_sparse_searcher_entity.h"
#include 
#include 
#include <cstring>
#include "flat_sparse_utility.h"

namespace zvec {
namespace core {

FlatSparseSearcherEntity::FlatSparseSearcherEntity() {}

//! Attach the entity to `container` and build the metric from `index_meta`.
//! Fails with IndexError_Duplicate if a container is already attached.
int FlatSparseSearcherEntity::load(const IndexStorage::Pointer &container,
                                   const IndexMeta &index_meta) {
  if (container_) {
    LOG_ERROR("An storage instance is already opened");
    return IndexError_Duplicate;
  }

  int ret = this->load_container(container);
  if (ailego_unlikely(ret != 0)) {
    LOG_ERROR("Failed to load storage index");
    return ret;
  }

  if (init_measure(index_meta) != 0) {
    LOG_ERROR("Failed to init measure");
    return IndexError_InvalidFormat;
  }

  container_ = container;

  return 0;
}

//! Create and initialize the metric named in `meta` and pick the sparse
//! distance used for search (the query metric's, when it provides one).
int FlatSparseSearcherEntity::init_measure(const IndexMeta &meta) {
  measure_ = IndexFactory::CreateMetric(meta.metric_name());
  if (!measure_) {
    LOG_ERROR("Failed to create measure %s", meta.metric_name().c_str());
    return IndexError_NoExist;
  }
  int ret = measure_->init(meta, meta.metric_params());
  if (ret != 0) {
    LOG_ERROR("Failled to init measure, ret=%d", ret);
    return ret;
  }
  if (!measure_->sparse_distance()) {
    LOG_ERROR("Invalid measure distance");
    return IndexError_InvalidArgument;
  }
  search_sparse_distance_ = measure_->sparse_distance();
  // Only switch to the query metric when it actually has a *sparse*
  // distance; testing the dense distance() could install an empty function.
  if (measure_->query_metric() && measure_->query_metric()->sparse_distance()) {
    search_sparse_distance_ = measure_->query_metric()->sparse_distance();
  }

  sparse_unit_size_ = meta.unit_size();

  return 0;
}

//! Read the meta segment and look up the keys, mapping, offset and data
//! segments. Every segment is mandatory.
int FlatSparseSearcherEntity::load_container(
    const IndexStorage::Pointer &container) {
  // meta
  auto segment = container->get(PARAM_FLAT_SPARSE_META_SEG_ID);
  if (!segment || segment->data_size() < sizeof(meta_)) {
    LOG_ERROR("Missing segment %s, or invalid segment size",
              PARAM_FLAT_SPARSE_META_SEG_ID.c_str());
    return IndexError_InvalidFormat;
  }
  const void *data;
  if (ailego_unlikely(segment->read(0, &data, sizeof(meta_)) !=
                      sizeof(meta_))) {
    LOG_ERROR("Failed to read meta segment %s",
              PARAM_FLAT_SPARSE_META_SEG_ID.c_str());
    return IndexError_ReadData;
  }
  meta_ = *(reinterpret_cast(data));

  // keys segment
  keys_chunk_ = container->get(PARAM_FLAT_SPARSE_DUMP_KEYS_SEG_ID);
  if (!keys_chunk_) {
    LOG_ERROR("Missing segment %s",
              PARAM_FLAT_SPARSE_DUMP_KEYS_SEG_ID.c_str());
    return IndexError_InvalidFormat;
  }

  // mapping segment
  mapping_chunk_ = container->get(PARAM_FLAT_SPARSE_DUMP_MAPPING_SEG_ID);
  if (!mapping_chunk_) {
    LOG_ERROR("Missing segment %s",
              PARAM_FLAT_SPARSE_DUMP_MAPPING_SEG_ID.c_str());
    return IndexError_InvalidFormat;
  }

  // offset segment
  sparse_offset_chunk_ = container->get(PARAM_FLAT_SPARSE_DUMP_OFFSET_SEG_ID);
  if (!sparse_offset_chunk_) {
    LOG_ERROR("Missing segment %s",
              PARAM_FLAT_SPARSE_DUMP_OFFSET_SEG_ID.c_str());
    return IndexError_InvalidFormat;
  }

  // data segment
  sparse_data_chunk_ = container->get(PARAM_FLAT_SPARSE_DUMP_DATA_SEG_ID);
  if (!sparse_data_chunk_) {
    LOG_ERROR("Missing segment %s",
              PARAM_FLAT_SPARSE_DUMP_DATA_SEG_ID.c_str());
    return IndexError_InvalidFormat;
  }

  return 0;
}

//! Release the container and every segment handle.
int FlatSparseSearcherEntity::unload() {
  container_.reset();
  sparse_data_chunk_.reset();
  sparse_offset_chunk_.reset();
  keys_chunk_.reset();
  mapping_chunk_.reset();
  return 0;
}

//! Create an entity sharing this one's meta, segments and metric state.
//! Returns an empty pointer if allocation fails.
FlatSparseSearcherEntity::Pointer FlatSparseSearcherEntity::clone() const {
  auto entity = new (std::nothrow)
      FlatSparseSearcherEntity(meta_, sparse_data_chunk_, sparse_offset_chunk_,
                               keys_chunk_, mapping_chunk_);
  if (ailego_unlikely(entity == nullptr)) {
    return FlatSparseSearcherEntity::Pointer();
  }
  // The private constructor only takes meta and segments; copy the metric
  // state too so the clone reports the same unit size and distance function.
  entity->measure_ = measure_;
  entity->search_sparse_distance_ = search_sparse_distance_;
  entity->sparse_unit_size_ = sparse_unit_size_;
  return FlatSparseSearcherEntity::Pointer(entity);
}

//! Locate the stored sparse vector of node `id`.
//! On success, *sparse_vector_ptr / *sparse_vector_len_ptr describe the data,
//! or are nullptr / 0 when the stored length is zero.
int FlatSparseSearcherEntity::get_sparse_vector_ptr_by_id(
    node_id_t id, const void **sparse_vector_ptr,
    uint32_t *sparse_vector_len_ptr) const {
  *sparse_vector_ptr = nullptr;
  *sparse_vector_len_ptr = 0;

  // Widen before multiplying so large ids cannot wrap the 32-bit product.
  const size_t offset_chunk_offset =
      static_cast<size_t>(id) * offset_size_per_node();

  const void *offset_info = nullptr;
  if (ailego_unlikely(sparse_offset_chunk_->read(
                          offset_chunk_offset, &offset_info,
                          offset_size_per_node()) != offset_size_per_node())) {
    LOG_ERROR("Read offset info failed, offset=%zu", offset_chunk_offset);
    return IndexError_ReadData;
  }

  // Each record is a uint64_t data offset followed by a uint32_t length.
  // Records are packed, so they are not 8-byte aligned in general; copy the
  // fields out rather than dereferencing possibly misaligned pointers.
  uint64_t sparse_offset = 0;
  uint32_t sparse_vector_len = 0;
  std::memcpy(&sparse_offset, offset_info, sizeof(sparse_offset));
  std::memcpy(&sparse_vector_len,
              static_cast<const uint8_t *>(offset_info) + sizeof(uint64_t),
              sizeof(sparse_vector_len));

  if (sparse_vector_len > 0) {
    const void *sparse_data =
        get_sparse_vector_data(sparse_offset, sparse_vector_len);
    if (ailego_unlikely(sparse_data == nullptr)) {
      LOG_ERROR("Get nullptr sparse, offset=%zu, len=%u",
                (size_t)sparse_offset, sparse_vector_len);
      return IndexError_ReadData;
    }

    *sparse_vector_ptr = sparse_data;
    *sparse_vector_len_ptr = sparse_vector_len;
  }

  return 0;
}

//! Read `length` bytes at `offset` of the data segment; nullptr on a short
//! read.
const void *FlatSparseSearcherEntity::get_sparse_vector_data(
    uint64_t offset, uint32_t length) const {
  const void *data;
  auto size = sparse_data_chunk_->read(offset, &data, length);
  if (size != length) {
    LOG_ERROR(
        "read sparse vector data failed: offset=%zu, "
        "length=%u, size=%zu",
        (size_t)offset, length, size);
    return nullptr;
  }

  return data;
}

//! Map a primary key to its node id by binary search over the mapping
//! segment, which is read as node ids ordered by key.
//! Returns kInvalidNodeId when the key is absent or a read fails.
node_id_t FlatSparseSearcherEntity::get_id(uint64_t key) const {
  if (ailego_unlikely(!mapping_chunk_)) {
    LOG_ERROR("Index missing mapping segment");
    return kInvalidNodeId;
  }

  //! Do binary search
  node_id_t start = 0UL;
  node_id_t end = doc_cnt();
  const void *data;
  node_id_t idx = 0u;
  while (start < end) {
    idx = start + (end - start) / 2;
    if (ailego_unlikely(mapping_chunk_->read(idx * sizeof(node_id_t), &data,
                                             sizeof(node_id_t)) !=
                        sizeof(node_id_t))) {
      LOG_ERROR("Read key from segment failed");
      return kInvalidNodeId;
    }
    const uint64_t *mkey;
    node_id_t local_id = *reinterpret_cast(data);
    if (ailego_unlikely(keys_chunk_->read(local_id * sizeof(uint64_t),
                                          (const void **)(&mkey),
                                          sizeof(uint64_t)) !=
                        sizeof(uint64_t))) {
      LOG_ERROR("Read key from segment failed");
      return kInvalidNodeId;
    }

    if (*mkey < key) {
      start = idx + 1;
    } else if (*mkey > key) {
      end = idx;
    } else {
      return local_id;
    }
  }
  return kInvalidNodeId;
}

//! Read the primary key of node `id`; kInvalidKey on a failed read.
uint64_t FlatSparseSearcherEntity::get_key(node_id_t id) const {
  const void *key;
  if (ailego_unlikely(
          keys_chunk_->read(id * sizeof(uint64_t), &key, sizeof(uint64_t)) !=
          sizeof(uint64_t))) {
    LOG_ERROR("Read key from segment failed");
    return kInvalidKey;
  }
  return *(reinterpret_cast(key));
}

}  // namespace core
}  // namespace zvec
================================================ FILE: src/core/algorithm/flat_sparse/flat_sparse_searcher_entity.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include "flat_sparse_entity.h" #include "flat_sparse_index_format.h" namespace zvec { namespace core { /*! Flat Sparse Searcher Entity */ class FlatSparseSearcherEntity : public FlatSparseEntity { public: typedef std::shared_ptr Pointer; using Chunk = IndexStorage::Segment; //! Constructor explicit FlatSparseSearcherEntity(); //! Destructor virtual ~FlatSparseSearcherEntity() = default; //! Disable them FlatSparseSearcherEntity(const FlatSparseSearcherEntity &) = delete; FlatSparseSearcherEntity &operator=(const FlatSparseSearcherEntity &) = delete; //! Load the entity with container int load(const IndexStorage::Pointer &container, const IndexMeta &index_meta); //! 
Unload the entity int unload(); public: inline uint32_t doc_cnt() const override { return meta_.doc_cnt; } inline uint32_t total_sparse_count() const override { return meta_.total_sparse_count; } size_t sparse_unit_size() const override { return sparse_unit_size_; } float get_search_distance(const std::string &vector, node_id_t target_node_id) const override { float dist; const void *target_vector; uint32_t target_vector_len; get_sparse_vector_ptr_by_id(target_node_id, &target_vector, &target_vector_len); search_sparse_distance_(vector.c_str(), target_vector, &dist); return dist; } FlatSparseSearcherEntity::Pointer clone() const; node_id_t get_id(uint64_t key) const override; uint64_t get_key(node_id_t id) const override; int get_sparse_vector_ptr_by_id(node_id_t id, const void **sparse_vector, uint32_t *sparse_vector_len) const override; private: int load_container(const IndexStorage::Pointer &container); int init_measure(const IndexMeta &meta); inline uint32_t offset_size_per_node() const { return sizeof(uint64_t) + sizeof(uint32_t); } const void *get_sparse_vector_data(uint64_t offset, uint32_t length) const; private: FlatSparseSearcherEntity(const FlatSparseMeta &meta, Chunk::Pointer sparse_data_chunk, Chunk::Pointer sparse_offset_chunk, Chunk::Pointer keys_chunk, Chunk::Pointer mapping_chunk) : meta_(meta), sparse_data_chunk_(sparse_data_chunk), sparse_offset_chunk_(sparse_offset_chunk), keys_chunk_(keys_chunk), mapping_chunk_(mapping_chunk) {} private: IndexStorage::Pointer container_{}; // meta FlatSparseMeta meta_; // measure IndexMetric::Pointer measure_{}; IndexMetric::MatrixSparseDistance search_sparse_distance_{}; // chunk Chunk::Pointer sparse_data_chunk_; Chunk::Pointer sparse_offset_chunk_; Chunk::Pointer keys_chunk_; Chunk::Pointer mapping_chunk_; size_t sparse_unit_size_{0U}; }; } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/flat_sparse/flat_sparse_streamer.cc 
================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "flat_sparse_streamer.h" #include #include #include #include #include #include "flat_sparse_context.h" #include "flat_sparse_provider.h" #include "flat_sparse_search.h" namespace zvec { namespace core { const uint32_t FlatSparseStreamer::VERSION = 0U; FlatSparseStreamer::FlatSparseStreamer() : entity_(stats_) {} FlatSparseStreamer::~FlatSparseStreamer() { this->close(); } int FlatSparseStreamer::init(const IndexMeta &imeta, const ailego::Params ¶ms) { LOG_DEBUG("FlatSparseStreamer init"); meta_ = imeta; meta_.set_streamer("FlatSparseStreamer", VERSION, params); state_ = STATE_INITED; return 0; } int FlatSparseStreamer::cleanup() { LOG_DEBUG("FlatSparseStreamer cleanup"); this->close(); meta_.clear(); return 0; } int FlatSparseStreamer::open(IndexStorage::Pointer stg) { LOG_DEBUG("FlatSparseStreamer open"); if (ailego_unlikely(state_ != STATE_INITED)) { LOG_ERROR("Open storage failed, init streamer first!"); return IndexError_NoReady; } int ret = entity_.open(std::move(stg), meta_); if (ret != 0) { LOG_ERROR("FlatSparseStreamer entity failed to open storage"); return ret; } IndexMeta index_meta; ret = entity_.get_index_sparse_meta(&index_meta); if (ret == IndexError_NoExist) { // Set IndexMeta for the new index ret = entity_.set_index_sparse_meta(meta_); if (ret != 0) { LOG_ERROR("Failed to set index meta for %s", 
IndexError::What(ret)); return ret; } } else { if (index_meta.streamer_revision() != meta_.streamer_revision()) { LOG_ERROR("Streamer revision mismatch, expect=%u, actual=%u", meta_.streamer_revision(), index_meta.streamer_revision()); return IndexError_Mismatch; } if (index_meta.metric_name() != meta_.metric_name() || index_meta.data_type() != meta_.data_type()) { LOG_ERROR("IndexMeta mismatch from the previous in index"); return IndexError_Mismatch; } // The IndexMeasure Params may be updated like MipsSquaredEuclidean auto metric_params = index_meta.metric_params(); metric_params.merge(meta_.metric_params()); meta_.set_metric(index_meta.metric_name(), 0, metric_params); } state_ = STATE_OPENED; magic_ = IndexContext::GenerateMagic(); return 0; } int FlatSparseStreamer::close() { if (state_ != STATE_OPENED) { return 0; } LOG_DEBUG("FlatSparseStreamer close"); stats_.clear(); int ret = entity_.close(); if (ret != 0) { LOG_ERROR("Failed to close entity %s", IndexError::What(ret)); return ret; } state_ = STATE_INITED; return 0; } int FlatSparseStreamer::flush(uint64_t checkpoint) { if (state_ != STATE_OPENED) { LOG_ERROR("Failed to flush, open streamer first!"); return IndexError_NoReady; } LOG_INFO("FlatSparseStreamer flush, checkpoint=%zu", (size_t)checkpoint); return entity_.flush(checkpoint); } int FlatSparseStreamer::dump(const IndexDumper::Pointer &dumper) { if (state_ != STATE_OPENED) { LOG_ERROR("Failed to dump, open streamer first!"); return IndexError_NoReady; } LOG_INFO("FlatSparseStreamer dump"); shared_mutex_.lock(); AILEGO_DEFER([&]() { shared_mutex_.unlock(); }); meta_.set_searcher("FlatSparseSearcher", VERSION, ailego::Params()); int ret = IndexHelper::SerializeToDumper(meta_, dumper.get()); if (ret != 0) { LOG_ERROR("Failed to serialize meta into dumper."); return ret; } return entity_.dump(dumper); } FlatSparseStreamer::ContextPointer FlatSparseStreamer::create_context() const { if (state_ != STATE_OPENED) { LOG_ERROR("Failed to create Context, open 
streamer first!"); return Context::UPointer(); } FlatSparseStreamerEntity::Pointer entity = entity_.clone(); return FlatSparseStreamer::ContextPointer(new FlatSparseContext(this)); } IndexStreamer::SparseProvider::Pointer FlatSparseStreamer::create_sparse_provider(void) const { if (state_ != STATE_OPENED) { LOG_ERROR("Failed to create provider, open streamer first!"); return SparseProvider::Pointer(); } auto entity = entity_.clone(); if (ailego_unlikely(!entity)) { LOG_ERROR("Clone entity failed"); return SparseProvider::Pointer(); } return SparseProvider::Pointer( new FlatSparseIndexProvider( entity, meta_, "FlatSparseStreamerProvider")); } int FlatSparseStreamer::add_impl(uint64_t pkey, const uint32_t sparse_count, const uint32_t *sparse_indices, const void *sparse_query, const IndexQueryMeta &qmeta, Context::Pointer &context) { if (state_ != STATE_OPENED) { LOG_ERROR("Failed to add_impl, open streamer first!"); (*stats_.mutable_discarded_count())++; return IndexError_NoReady; } int ret = check_params(qmeta); if (ailego_unlikely(ret != 0)) { (*stats_.mutable_discarded_count())++; return ret; } if (ailego_unlikely(sparse_count > PARAM_FLAT_SPARSE_MAX_DIM_SIZE)) { LOG_ERROR( "Failed to add sparse vector: number of non-zero elements (%u) exceeds " "maximum allowed (%u), key=%zu", sparse_count, PARAM_FLAT_SPARSE_MAX_DIM_SIZE, (size_t)pkey); (*stats_.mutable_discarded_count())++; return IndexError_InvalidValue; } // context is trivial here FlatSparseContext *ctx = dynamic_cast(context.get()); ailego_do_if_false(ctx) { LOG_ERROR("Cast context to FlatSparseContext failed"); (*stats_.mutable_discarded_count())++; return IndexError_Cast; } if (ailego_unlikely(!shared_mutex_.try_lock_shared())) { LOG_ERROR("Cannot add vector while dumping index"); (*stats_.mutable_discarded_count())++; return IndexError_Unsupported; } AILEGO_DEFER([&]() { shared_mutex_.unlock_shared(); }); // convert to sparse format and add to entity std::string sparse_query_buffer; 
SparseUtility::TransSparseFormat(sparse_count, sparse_indices, sparse_query, meta_.unit_size(), sparse_query_buffer); ret = entity_.add(pkey, sparse_query_buffer, sparse_count); if (ret != 0) { LOG_ERROR("Failed to add sparse vector, key=%zu, ret=%s", (size_t)pkey, IndexError::What(ret)); (*stats_.mutable_discarded_count())++; return ret; } (*stats_.mutable_added_count())++; return 0; } int FlatSparseStreamer::add_with_id_impl(uint32_t pkey, const uint32_t sparse_count, const uint32_t *sparse_indices, const void *sparse_query, const IndexQueryMeta &qmeta, Context::Pointer &context) { if (state_ != STATE_OPENED) { LOG_ERROR("Failed to add_with_id_impl, open streamer first!"); (*stats_.mutable_discarded_count())++; return IndexError_NoReady; } int ret = check_params(qmeta); if (ailego_unlikely(ret != 0)) { (*stats_.mutable_discarded_count())++; return ret; } if (ailego_unlikely(sparse_count > PARAM_FLAT_SPARSE_MAX_DIM_SIZE)) { LOG_ERROR( "Failed to add sparse vector: number of non-zero elements (%u) exceeds " "maximum allowed (%u), key=%zu", sparse_count, PARAM_FLAT_SPARSE_MAX_DIM_SIZE, (size_t)pkey); (*stats_.mutable_discarded_count())++; return IndexError_InvalidValue; } // context is trivial here FlatSparseContext *ctx = dynamic_cast(context.get()); ailego_do_if_false(ctx) { LOG_ERROR("Cast context to FlatSparseContext failed"); (*stats_.mutable_discarded_count())++; return IndexError_Cast; } if (ailego_unlikely(!shared_mutex_.try_lock_shared())) { LOG_ERROR("Cannot add vector while dumping index"); (*stats_.mutable_discarded_count())++; return IndexError_Unsupported; } AILEGO_DEFER([&]() { shared_mutex_.unlock_shared(); }); // convert to sparse format and add to entity std::string sparse_query_buffer; SparseUtility::TransSparseFormat(sparse_count, sparse_indices, sparse_query, meta_.unit_size(), sparse_query_buffer); ret = entity_.add_vector_with_id(pkey, sparse_query_buffer, sparse_count); if (ret != 0) { LOG_ERROR("Failed to add sparse vector, key=%zu, ret=%s", 
(size_t)pkey, IndexError::What(ret)); (*stats_.mutable_discarded_count())++; return ret; } (*stats_.mutable_added_count())++; return 0; } //! Similarity search with sparse inputs int FlatSparseStreamer::search_impl(const uint32_t sparse_count, const uint32_t *sparse_indices, const void *sparse_query, const IndexQueryMeta &qmeta, Context::Pointer &context) const { return search_impl(&sparse_count, sparse_indices, sparse_query, qmeta, 1, context); } //! Similarity search with sparse inputs int FlatSparseStreamer::search_impl(const uint32_t *sparse_count, const uint32_t *sparse_indices, const void *sparse_query, const IndexQueryMeta &qmeta, uint32_t count, Context::Pointer &context) const { return search_bf_impl(sparse_count, sparse_indices, sparse_query, qmeta, count, context); } //! Similarity brute force search with sparse inputs int FlatSparseStreamer::search_bf_impl(const uint32_t sparse_count, const uint32_t *sparse_indices, const void *sparse_query, const IndexQueryMeta &qmeta, Context::Pointer &context) const { return search_bf_impl(&sparse_count, sparse_indices, sparse_query, qmeta, 1, context); } //! Linear search by primary keys int FlatSparseStreamer::search_bf_by_p_keys_impl( const uint32_t sparse_count, const uint32_t *sparse_indices, const void *sparse_query, const std::vector> &p_keys, const IndexQueryMeta &qmeta, ContextPointer &context) const { return search_bf_by_p_keys_impl(&sparse_count, sparse_indices, sparse_query, p_keys, qmeta, 1, context); } //! Similarity brute force search with sparse inputs int FlatSparseStreamer::search_bf_impl(const uint32_t *sparse_count, const uint32_t *sparse_indices, const void *sparse_query, const IndexQueryMeta &qmeta, uint32_t count, Context::Pointer &context) const { return do_search(sparse_count, sparse_indices, sparse_query, false, {}, qmeta, count, context); } //! 
Linear search by primary keys with sparse inputs int FlatSparseStreamer::search_bf_by_p_keys_impl( const uint32_t *sparse_count, const uint32_t *sparse_indices, const void *sparse_query, const std::vector> &p_keys, const IndexQueryMeta &qmeta, uint32_t count, ContextPointer &context) const { return do_search(sparse_count, sparse_indices, sparse_query, true, p_keys, qmeta, count, context); } //! Fetch sparse vector by key int FlatSparseStreamer::get_sparse_vector( uint64_t key, uint32_t *sparse_count, std::string *sparse_indices_buffer, std::string *sparse_values_buffer) const { if (state_ != STATE_OPENED) { LOG_ERROR("Failed to get_sparse_vector, open streamer first!"); return IndexError_NoReady; } std::string sparse_data; int ret = entity_.get_sparse_vector_by_key(key, &sparse_data); if (ailego_unlikely(ret != 0)) { LOG_ERROR("Failed to get sparse vector, key=%zu, ret=%s", (size_t)key, IndexError::What(ret)); return ret; } SparseUtility::ReverseSparseFormat(sparse_data, sparse_count, sparse_indices_buffer, sparse_values_buffer, meta_.unit_size()); return 0; } int FlatSparseStreamer::do_search( const uint32_t *sparse_count, const uint32_t *sparse_indices, const void *sparse_query, bool with_p_keys, const std::vector> &p_keys, const IndexQueryMeta &qmeta, uint32_t count, ContextPointer &context) const { if (state_ != STATE_OPENED) { LOG_ERROR("Failed to do_search, open streamer first!"); return IndexError_NoReady; } int ret = check_params(qmeta); if (ailego_unlikely(ret != 0)) { return ret; } FlatSparseContext *ctx = dynamic_cast(context.get()); if (ctx->magic() != magic_) { ctx->reset(this); } return FlatSearch(sparse_count, sparse_indices, sparse_query, with_p_keys, p_keys, qmeta, count, meta_, context, (FlatSparseEntity *)&entity_); } INDEX_FACTORY_REGISTER_STREAMER(FlatSparseStreamer); } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/flat_sparse/flat_sparse_streamer.h 
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

// NOTE(review): the targets of the first two #include directives did not
// survive in this copy of the header (only the bare directives remain);
// restore them from the upstream file.
#include
#include
#include "flat_sparse_streamer_entity.h"

namespace zvec {
namespace core {

/*! Flat Sparse Streamer
 *
 * IndexStreamer implementation for sparse vectors. Storage, key mapping and
 * chunk management are delegated to the FlatSparseStreamerEntity member
 * `entity_`; this class validates requests, maintains `stats_`, and forwards
 * to the entity and search routines.
 */
class FlatSparseStreamer : public IndexStreamer {
 public:
  //! Streamer format version (declared here, defined out of line)
  static const uint32_t VERSION;

 public:
  using ContextPointer = IndexStreamer::Context::Pointer;

  FlatSparseStreamer(void);
  virtual ~FlatSparseStreamer(void);

  //! Non-copyable
  FlatSparseStreamer(const FlatSparseStreamer &streamer) = delete;
  FlatSparseStreamer &operator=(const FlatSparseStreamer &streamer) = delete;

 public:
  //! Initialize Streamer
  int init(const IndexMeta &, const ailego::Params &) override;

  //! Cleanup Streamer
  int cleanup(void) override;

  //! Open the index on the given storage
  int open(IndexStorage::Pointer stg) override;

  //! Close the index
  int close(void) override;

  //! Flush pending state to storage (checkpoint as passed by the caller)
  int flush(uint64_t checkpoint) override;

  //! Dump index into storage
  int dump(const IndexDumper::Pointer &dumper) override;

  //! Create a search/add context for this streamer
  ContextPointer create_context(void) const override;

  //! Create a new sparse-vector provider (iterator)
  IndexStreamer::SparseProvider::Pointer create_sparse_provider(
      void) const override;

  //! Add one sparse vector (sparse_count entries given by sparse_indices /
  //! sparse_query) under primary key `pkey`
  int add_impl(uint64_t pkey, const uint32_t sparse_count,
               const uint32_t *sparse_indices, const void *sparse_query,
               const IndexQueryMeta &qmeta, Context::Pointer &context) override;

  //! Add one sparse vector at the explicit id `pkey`
  int add_with_id_impl(uint32_t pkey, const uint32_t sparse_count,
                       const uint32_t *sparse_indices, const void *sparse_query,
                       const IndexQueryMeta &qmeta,
                       Context::Pointer &context) override;

  //! Similarity search with sparse inputs (single query)
  int search_impl(const uint32_t sparse_count, const uint32_t *sparse_indices,
                  const void *sparse_query, const IndexQueryMeta &qmeta,
                  Context::Pointer &context) const override;

  //! Similarity search with sparse inputs (`count` queries)
  int search_impl(const uint32_t *sparse_count, const uint32_t *sparse_indices,
                  const void *sparse_query, const IndexQueryMeta &qmeta,
                  uint32_t count, Context::Pointer &context) const override;

  //! Similarity brute force search with sparse inputs (single query)
  int search_bf_impl(const uint32_t sparse_count,
                     const uint32_t *sparse_indices, const void *sparse_query,
                     const IndexQueryMeta &qmeta,
                     Context::Pointer &context) const override;

  //! Similarity brute force search with sparse inputs (`count` queries)
  int search_bf_impl(const uint32_t *sparse_count,
                     const uint32_t *sparse_indices, const void *sparse_query,
                     const IndexQueryMeta &qmeta, uint32_t count,
                     Context::Pointer &context) const override;

  //! Linear search by primary keys (single query)
  //! NOTE(review): the template arguments of p_keys' type appear stripped in
  //! this copy (`std::vector>`); presumably
  //! std::vector<std::vector<uint64_t>> — confirm against upstream.
  int search_bf_by_p_keys_impl(const uint32_t sparse_count,
                               const uint32_t *sparse_indices,
                               const void *sparse_query,
                               const std::vector> &p_keys,
                               const IndexQueryMeta &qmeta,
                               ContextPointer &context) const override;

  //! Linear search by primary keys with sparse inputs (`count` queries)
  int search_bf_by_p_keys_impl(const uint32_t *sparse_count,
                               const uint32_t *sparse_indices,
                               const void *sparse_query,
                               const std::vector> &p_keys,
                               const IndexQueryMeta &qmeta, uint32_t count,
                               ContextPointer &context) const override;

  //! Fetch sparse vector by key
  int get_sparse_vector(uint64_t key, uint32_t *sparse_count,
                        std::string *sparse_indices_buffer,
                        std::string *sparse_values_buffer) const override;

  //! Fetch sparse vector by id.
  //! NOTE(review): `id` is forwarded unchanged to get_sparse_vector(), which
  //! resolves its first argument as a primary key. That only matches when
  //! key == id (as appears to be the case for vectors added via
  //! add_with_id_impl); confirm the behaviour for vectors added via add_impl.
  int get_sparse_vector_by_id(
      uint32_t id, uint32_t *sparse_count, std::string *sparse_indices_buffer,
      std::string *sparse_values_buffer) const override {
    return get_sparse_vector(id, sparse_count, sparse_indices_buffer,
                             sparse_values_buffer);
  }

  //! Retrieve statistics
  const Stats &stats(void) const override {
    return stats_;
  }

  //! Retrieve meta of index
  const IndexMeta &meta(void) const override {
    return meta_;
  }

  //! Read-only access to the underlying storage entity
  const FlatSparseStreamerEntity &entity(void) const {
    return entity_;
  }

  //! Magic value compared against a context's magic() before searching
  uint32_t magic(void) const {
    return magic_;
  }

 private:
  //! Reject queries whose data type differs from the index meta's.
  //! @return 0 if compatible, IndexError_Mismatch otherwise
  inline int check_params(const IndexQueryMeta &qmeta) const {
    if (ailego_unlikely(qmeta.data_type() != meta_.data_type())) {
      LOG_ERROR("Unsupported query meta, type=%d, expected=%d",
                qmeta.data_type(), meta_.data_type());
      return IndexError_Mismatch;
    }
    return 0;
  }

  //! Common implementation behind all search entry points; `with_p_keys`
  //! selects whether `p_keys` restricts the scan
  int do_search(const uint32_t *sparse_count, const uint32_t *sparse_indices,
                const void *sparse_query, bool with_p_keys,
                const std::vector> &p_keys,
                const IndexQueryMeta &qmeta, uint32_t count,
                ContextPointer &context) const;

 private:
  //! Lifecycle state; the add, search and fetch paths require STATE_OPENED
  enum State { STATE_INIT = 0, STATE_INITED = 1, STATE_OPENED = 2 };

  IndexMeta meta_{};
  // NOTE(review): entity_ is declared, and therefore constructed, before
  // stats_, while FlatSparseStreamerEntity's constructor takes an
  // IndexStreamer::Stats &. If it is bound to stats_, the entity may only
  // store that reference during construction.
  FlatSparseStreamerEntity entity_;
  uint32_t magic_{0U};
  Stats stats_{};
  State state_{STATE_INIT};

  //! avoid add vector while dumping index: add paths take this in shared mode
  //! via try_lock_shared() and fail fast when that does not succeed
  ailego::SharedMutex shared_mutex_{};
};

}  // namespace core
}  // namespace zvec
================================================
FILE: src/core/algorithm/flat_sparse/flat_sparse_streamer_entity.cc
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "flat_sparse_streamer_entity.h" #include #include #include #include #include #include #include #include #include "flat_sparse_index_format.h" #include "flat_sparse_utility.h" namespace zvec { namespace core { FlatSparseStreamerEntity::FlatSparseStreamerEntity(IndexStreamer::Stats &stats) : stats_(stats) {} int FlatSparseStreamerEntity::open(IndexStorage::Pointer storage, const IndexMeta &meta) { if (storage_) { LOG_ERROR("An storage instance is already opened"); return IndexError_Duplicate; } keys_map_lock_ = std::make_shared(); if (!keys_map_lock_) { LOG_ERROR("FlatSparseStreamerEntity new object failed"); return IndexError_NoMemory; } keys_map_ = std::make_shared>(); if (storage->get(PARAM_FLAT_SPARSE_META_SEG_ID) || storage->get(PARAM_FLAT_SPARSE_STREAMER_META_SEG_ID)) { int ret = this->load_storage(storage, meta); if (ailego_unlikely(ret != 0)) { LOG_ERROR("Failed to load storage index"); return ret; } } else { int ret = this->init_storage(storage, meta); if (ailego_unlikely(ret != 0)) { LOG_ERROR("Failed to load storage index"); return ret; } } if (init_metric(meta) != 0) { LOG_ERROR("Failed to init metric"); return IndexError_InvalidFormat; } // reserve data chunk meta.streamer_params().get(PARAM_FLAT_SPARSE_STREAMER_MAX_DATA_CHUNK_CNT, &max_data_chunk_cnt_); sparse_data_chunks_.reserve(max_data_chunk_cnt_); // reserve offset chunk meta.streamer_params().get(PARAM_FLAT_SPARSE_STREAMER_MAX_DOC_CNT, &max_doc_cnt_); sparse_offset_chunks_.reserve(max_doc_cnt_ / doc_cnt_per_offset_chunk() + 1); sparse_unit_size_ = meta.unit_size(); LOG_DEBUG( 
"FlatSparseStreamerEntity open success, doc_count[%u], " "data_chunk_size[%u], offset_chunk_size[%u], data_chunk_count[%zu], " "offset_chunk_count[%zu]", meta_.doc_cnt, streamer_meta_.data_chunk_size, streamer_meta_.offset_chunk_size, sparse_data_chunks_.size(), sparse_offset_chunks_.size()); storage_ = storage; return 0; } int FlatSparseStreamerEntity::init_metric(const IndexMeta &meta) { metric_ = IndexFactory::CreateMetric(meta.metric_name()); if (!metric_) { LOG_ERROR("Failed to create metric %s", meta.metric_name().c_str()); return IndexError_NoExist; } int ret = metric_->init(meta, meta.metric_params()); if (ret != 0) { LOG_ERROR("Failled to init metric, ret=%d", ret); return ret; } if (!metric_->sparse_distance()) { LOG_ERROR("Invalid metric distance"); return IndexError_InvalidArgument; } search_sparse_distance_ = metric_->sparse_distance(); if (metric_->query_metric() && metric_->query_metric()->distance()) { search_sparse_distance_ = metric_->query_metric()->sparse_distance(); } return 0; } int FlatSparseStreamerEntity::load_storage(IndexStorage::Pointer storage, const IndexMeta &meta) { size_t index_size{0}; // load meta auto segment = storage->get(PARAM_FLAT_SPARSE_META_SEG_ID); if (!segment || segment->data_size() < sizeof(meta_)) { LOG_ERROR("Missing segment %s, or invalid segment size", PARAM_FLAT_SPARSE_META_SEG_ID.c_str()); return IndexError_InvalidFormat; } IndexStorage::MemoryBlock data_block; if (ailego_unlikely(segment->read(0, data_block, sizeof(meta_)) != sizeof(meta_))) { LOG_ERROR("Failed to read meta segment %s", PARAM_FLAT_SPARSE_META_SEG_ID.c_str()); return IndexError_ReadData; } meta_ = *(reinterpret_cast(data_block.data())); index_size += segment->capacity(); // load streamer meta segment = storage->get(PARAM_FLAT_SPARSE_STREAMER_META_SEG_ID); if (!segment || segment->data_size() < sizeof(streamer_meta_)) { LOG_ERROR("Missing segment %s, or invalid segment size", PARAM_FLAT_SPARSE_STREAMER_META_SEG_ID.c_str()); return 
IndexError_InvalidFormat; } if (ailego_unlikely(segment->read(0, data_block, sizeof(streamer_meta_)) != sizeof(streamer_meta_))) { LOG_ERROR("Failed to read streamer meta segment %s", PARAM_FLAT_SPARSE_STREAMER_META_SEG_ID.c_str()); return IndexError_ReadData; } streamer_meta_ = *(reinterpret_cast(data_block.data())); index_size += segment->capacity(); uint32_t meta_data_chunk_size{streamer_meta_.data_chunk_size}; uint32_t meta_offset_chunk_size{streamer_meta_.offset_chunk_size}; meta.streamer_params().get(PARAM_FLAT_SPARSE_STREAMER_DATA_CHUNK_SIZE, &meta_data_chunk_size); meta.streamer_params().get(PARAM_FLAT_SPARSE_STREAMER_OFFSET_CHUNK_SIZE, &meta_offset_chunk_size); if (streamer_meta_.data_chunk_size != meta_data_chunk_size || streamer_meta_.offset_chunk_size != meta_offset_chunk_size) { LOG_ERROR( "Invalid streamer meta chunk size data[%u] offset[%u], expect data[%u] " "offset[%u]", streamer_meta_.data_chunk_size, streamer_meta_.offset_chunk_size, meta_data_chunk_size, meta_offset_chunk_size); return IndexError_InvalidFormat; } // check chunk cnt if (streamer_meta_.data_chunk_count > max_data_chunk_cnt_ || meta_.doc_cnt > max_doc_cnt_) { LOG_ERROR( "Invalid data chunk count[%u] doc count[%u], expect less than " "chunk count[%u] doc count[%u]", streamer_meta_.data_chunk_count, meta_.doc_cnt, max_data_chunk_cnt_, max_doc_cnt_); return IndexError_InvalidFormat; } // load offset chunks for (size_t i = 0; i < streamer_meta_.offset_chunk_count; ++i) { std::string segment_id = ailego::StringHelper::Concat(PARAM_FLAT_SPARSE_OFFSET_SEG_ID_PREFIX, i); segment = storage->get(segment_id); if (!segment) { LOG_ERROR("Missing segment %s", segment_id.c_str()); return IndexError_InvalidFormat; } sparse_offset_chunks_.emplace_back(segment); index_size += segment->capacity(); } // load data chunks for (size_t i = 0; i < streamer_meta_.data_chunk_count; ++i) { std::string segment_id = ailego::StringHelper::Concat(PARAM_FLAT_SPARSE_DATA_SEG_ID_PREFIX, i); segment = 
storage->get(segment_id); if (!segment) { LOG_ERROR("Missing segment %s", segment_id.c_str()); } sparse_data_chunks_.emplace_back(segment); index_size += segment->capacity(); } // load keys for (node_id_t i = 0; i < meta_.doc_cnt; ++i) { (*keys_map_)[get_key(i)] = i; } stats_.set_index_size(index_size); stats_.set_check_point(storage->check_point()); stats_.set_create_time(meta_.create_time); stats_.set_update_time(meta_.update_time); stats_.set_loaded_count(keys_map_->size()); return 0; } int FlatSparseStreamerEntity::init_storage(IndexStorage::Pointer storage, const IndexMeta &meta) { meta_.create_time = ailego::Realtime::Seconds(); stats_.set_create_time(meta_.create_time); meta_.update_time = ailego::Realtime::Seconds(); stats_.set_update_time(meta_.update_time); meta_.doc_cnt = 0; meta.streamer_params().get(PARAM_FLAT_SPARSE_STREAMER_DATA_CHUNK_SIZE, &streamer_meta_.data_chunk_size); meta.streamer_params().get(PARAM_FLAT_SPARSE_STREAMER_OFFSET_CHUNK_SIZE, &streamer_meta_.offset_chunk_size); // append meta segment size_t size = ailego_align(sizeof(meta_), ailego::MemoryHelper::PageSize()); int ret = storage->append(PARAM_FLAT_SPARSE_META_SEG_ID, size); if (ailego_unlikely(ret != 0)) { LOG_ERROR("Failed to append meta segment %s", PARAM_FLAT_SPARSE_META_SEG_ID.c_str()); return ret; } auto segment = storage->get(PARAM_FLAT_SPARSE_META_SEG_ID); if (ailego_unlikely(!segment)) { LOG_ERROR("Failed to get meta segment %s", PARAM_FLAT_SPARSE_META_SEG_ID.c_str()); return IndexError_Runtime; } if (segment->write(0, &meta_, sizeof(meta_)) != sizeof(meta_)) { LOG_ERROR("Failed to write meta segment %s", PARAM_FLAT_SPARSE_META_SEG_ID.c_str()); return IndexError_WriteData; } *stats_.mutable_index_size() += size; // append streamer meta segment size = ailego_align(sizeof(streamer_meta_), ailego::MemoryHelper::PageSize()); ret = storage->append(PARAM_FLAT_SPARSE_STREAMER_META_SEG_ID, size); if (ailego_unlikely(ret != 0)) { LOG_ERROR("Failed to append streamer meta segment %s", 
PARAM_FLAT_SPARSE_STREAMER_META_SEG_ID.c_str()); return ret; } segment = storage->get(PARAM_FLAT_SPARSE_STREAMER_META_SEG_ID); if (ailego_unlikely(!segment)) { LOG_ERROR("Failed to get streamer meta segment %s", PARAM_FLAT_SPARSE_STREAMER_META_SEG_ID.c_str()); return IndexError_Runtime; } if (segment->write(0, &streamer_meta_, sizeof(streamer_meta_)) != sizeof(streamer_meta_)) { LOG_ERROR("Failed to write streamer meta segment %s", PARAM_FLAT_SPARSE_STREAMER_META_SEG_ID.c_str()); return IndexError_WriteData; } *stats_.mutable_index_size() += size; return 0; } int FlatSparseStreamerEntity::close() { storage_.reset(); sparse_data_chunks_.clear(); sparse_offset_chunks_.clear(); keys_map_lock_.reset(); keys_map_.reset(); return 0; } int FlatSparseStreamerEntity::flush(uint64_t checkpoint) { // flush meta meta_.update_time = ailego::Realtime::Seconds(); stats_.set_update_time(meta_.update_time); auto segment = storage_->get(PARAM_FLAT_SPARSE_META_SEG_ID); if (ailego_unlikely(!segment)) { LOG_ERROR("Failed to get meta segment %s", PARAM_FLAT_SPARSE_META_SEG_ID.c_str()); return IndexError_Runtime; } if (segment->write(0, &meta_, sizeof(meta_)) != sizeof(meta_)) { LOG_ERROR("Failed to write meta segment %s", PARAM_FLAT_SPARSE_META_SEG_ID.c_str()); return IndexError_WriteData; } // flush streamer meta streamer_meta_.data_chunk_count = sparse_data_chunks_.size(); streamer_meta_.offset_chunk_count = sparse_offset_chunks_.size(); segment = storage_->get(PARAM_FLAT_SPARSE_STREAMER_META_SEG_ID); if (ailego_unlikely(!segment)) { LOG_ERROR("Failed to get streamer meta segment %s", PARAM_FLAT_SPARSE_STREAMER_META_SEG_ID.c_str()); return IndexError_Runtime; } if (segment->write(0, &streamer_meta_, sizeof(streamer_meta_)) != sizeof(streamer_meta_)) { LOG_ERROR("Failed to write streamer meta segment %s", PARAM_FLAT_SPARSE_STREAMER_META_SEG_ID.c_str()); return IndexError_WriteData; } if (checkpoint != 0) { storage_->refresh(checkpoint); } int ret = storage_->flush(); if 
(ailego_unlikely(ret != 0)) { LOG_ERROR("Failed to flush storage for %s", IndexError::What(ret)); return ret; } if (checkpoint != 0) { stats_.set_check_point(checkpoint); } return 0; } int FlatSparseStreamerEntity::dump(const IndexDumper::Pointer &dumper) { ailego::ElapsedTime stamp; int ret; // meta ret = dump_meta(dumper.get()); if (ret != 0) { return ret; } auto duration_dump_meta = stamp.milli_seconds(); // offset & data ret = dump_offset_data(dumper.get()); if (ret != 0) { return ret; } auto duration_dump_offset_data = stamp.milli_seconds() - duration_dump_meta; // keys std::vector keys = get_keys(); ret = dump_keys(keys, dumper.get()); if (ret != 0) { return ret; } auto duration_dump_keys = stamp.milli_seconds() - duration_dump_offset_data - duration_dump_meta; // mapping ret = dump_mapping(keys, dumper.get()); if (ret != 0) { return ret; } auto duration_dump_mapping = stamp.milli_seconds() - duration_dump_offset_data - duration_dump_meta - duration_dump_keys; LOG_INFO( "Dump index meta: %zu ms, offset & data: %zu ms, keys: %zu ms, " "mapping: %zu ms", (size_t)duration_dump_meta, (size_t)duration_dump_offset_data, (size_t)duration_dump_keys, (size_t)duration_dump_mapping); return 0; } int FlatSparseStreamerEntity::dump_offset_data(IndexDumper *dumper) { ailego::ElapsedTime stamp; uint64_t init_offset = dump_size_; std::vector> offset_length; // write data int ret; node_id_t total_doc_cnt = doc_cnt(); for (node_id_t node_id = 0; node_id < total_doc_cnt; node_id++) { uint32_t target_vector_len; IndexStorage::MemoryBlock target_vector_block; ret = get_sparse_vector_ptr_by_id(node_id, target_vector_block, &target_vector_len); if (ret != 0) { LOG_ERROR("Failed to get vector, node_id=%u, error: %s", node_id, IndexError::What(ret)); return ret; } const void *target_vector = target_vector_block.data(); ret = dump_sparse_vector_data(target_vector, target_vector_len, dumper); if (ret != 0) { LOG_ERROR("Failed to dump sparse vector data, node_id=%u, error: %s", node_id, 
IndexError::What(ret)); return ret; } offset_length.push_back({dump_size_ - init_offset, target_vector_len}); dump_size_ += target_vector_len; } // append data segment if (dumper->append(PARAM_FLAT_SPARSE_DUMP_DATA_SEG_ID, dump_size_ - init_offset, 0, 0) != 0) { LOG_ERROR("append data segment failed"); return IndexError_WriteData; } auto duration_dump_data = stamp.milli_seconds(); // write offset for (auto &offset_length_pair : offset_length) { if (dumper->write(&offset_length_pair.first, sizeof(offset_length_pair.first)) != sizeof(offset_length_pair.first)) { return IndexError_WriteData; } if (dumper->write(&offset_length_pair.second, sizeof(offset_length_pair.second)) != sizeof(offset_length_pair.second)) { return IndexError_WriteData; } dump_size_ += sizeof(offset_length_pair.first) + sizeof(offset_length_pair.second); } // append offset segment if (dumper->append( PARAM_FLAT_SPARSE_DUMP_OFFSET_SEG_ID, offset_length.size() * (sizeof(uint64_t) + sizeof(uint32_t)), 0, 0) != 0) { LOG_ERROR("append offset segment failed"); return IndexError_WriteData; } auto duration_dump_offset = stamp.milli_seconds() - duration_dump_data; LOG_INFO("Dump offset: %zu ms, data: %zu ms", (size_t)duration_dump_offset, (size_t)duration_dump_data); return 0; } int FlatSparseStreamerEntity::dump_sparse_vector_data(const void *data, uint32_t length, IndexDumper *dumper) { if (dumper->write(data, length) != length) { return IndexError_WriteData; } return 0; } int FlatSparseStreamerEntity::dump_meta(IndexDumper *dumper) { if (dumper->write(&meta_, sizeof(meta_)) != sizeof(meta_)) { LOG_ERROR("write meta failed"); return IndexError_WriteData; } size_t meta_padding_size = ailego_align(sizeof(meta_), 32) - sizeof(meta_); if (meta_padding_size) { std::string padding(meta_padding_size, '\0'); if (dumper->write(padding.data(), meta_padding_size) != meta_padding_size) { LOG_ERROR("write meta padding failed"); return IndexError_WriteData; } } return dumper->append(PARAM_FLAT_SPARSE_META_SEG_ID, 
sizeof(meta_), meta_padding_size, 0); } int FlatSparseStreamerEntity::dump_keys(const std::vector &keys, IndexDumper *dumper) { if (keys.size() == 1 && keys.back() == kInvalidKey) { return IndexError_Runtime; } size_t keys_size = keys.size() * sizeof(uint64_t); if (dumper->write(keys.data(), keys_size) != keys_size) { LOG_ERROR("Failed to write keys to dumper %s", dumper->name().c_str()); return IndexError_WriteData; } size_t keys_padding_size = ailego_align(keys_size, 32) - keys_size; if (keys_padding_size) { std::string padding(keys_padding_size, '\0'); if (dumper->write(padding.data(), padding.size()) != padding.size()) { LOG_ERROR("Failed to write padding to dumper %s", dumper->name().c_str()); return IndexError_WriteData; } } return dumper->append(PARAM_FLAT_SPARSE_DUMP_KEYS_SEG_ID, keys_size, keys_padding_size, 0); } int FlatSparseStreamerEntity::dump_mapping(const std::vector &keys, IndexDumper *dumper) { std::vector mapping(keys.size()); std::iota(mapping.begin(), mapping.end(), 0); std::sort( mapping.begin(), mapping.end(), [&keys](uint32_t lhs, uint32_t rhs) { return (keys[lhs] < keys[rhs]); }); size_t mapping_size = mapping.size() * sizeof(uint32_t); size_t mapping_padding_size = ailego_align(mapping_size, 32) - mapping_size; if (dumper->write(mapping.data(), mapping_size) != mapping_size) { LOG_ERROR("Failed to write data into dumper %s", dumper->name().c_str()); return IndexError_WriteData; } // Write the padding if need if (mapping_padding_size) { std::string padding(mapping_padding_size, '\0'); if (dumper->write(padding.data(), padding.size()) != padding.size()) { LOG_ERROR("Failed to write data into dumper %s", dumper->name().c_str()); return IndexError_WriteData; } } return dumper->append(PARAM_FLAT_SPARSE_DUMP_MAPPING_SEG_ID, mapping_size, mapping_padding_size, 0); } FlatSparseStreamerEntity::Pointer FlatSparseStreamerEntity::clone() const { auto entity = new (std::nothrow) FlatSparseStreamerEntity( stats_, meta_, streamer_meta_, keys_map_lock_, 
keys_map_, sparse_data_chunks_, sparse_offset_chunks_); return FlatSparseStreamerEntity::Pointer(entity); } int FlatSparseStreamerEntity::add(uint64_t key, const std::string &sparse_vector, const uint32_t sparse_count) { uint32_t sparse_vector_len = sparse_vector.size(); sparse_vector_len = AlignSize(sparse_vector_len); if (sparse_vector_len > streamer_meta_.data_chunk_size) { LOG_ERROR( "Sparse Vector Length exceed the chunk size, sparse vec len: %u, chunk " "size: %u", sparse_vector_len, streamer_meta_.data_chunk_size); return IndexError_InvalidArgument; } std::lock_guard lock(mutex_); node_id_t local_id = doc_cnt(); if (ailego_unlikely(local_id >= max_doc_cnt_)) { LOG_ERROR("Add vector failed for exceed max doc count: %u", max_doc_cnt_); return IndexError_IndexFull; } // duplicate check if (ailego_unlikely(get_id(key) != kInvalidNodeId)) { LOG_WARN("Try to add duplicate key, ignore it"); return IndexError_Duplicate; } // get sparse data chunk and offset for write sparse vector Chunk::Pointer sparse_data_chunk; uint32_t sparse_data_chunk_offset = -1U; uint32_t sparse_data_chunk_index = sparse_data_chunks_.size() - 1U; if (sparse_data_chunk_index == -1U || sparse_data_chunks_[sparse_data_chunk_index]->data_size() + sparse_vector_len > streamer_meta_.data_chunk_size) { if (ailego_unlikely(sparse_data_chunks_.capacity() == sparse_data_chunks_.size())) { LOG_ERROR("add vector failed for no memory quota"); if (sparse_data_chunk_index != -1U) { LOG_ERROR( "capacity: %zu, chunk used size: %zu, chunk size: %u, " "sparse_vector_len: %u", sparse_data_chunks_.capacity(), sparse_data_chunks_[sparse_data_chunk_index]->data_size(), streamer_meta_.data_chunk_size, sparse_vector_len); } return IndexError_IndexFull; } sparse_data_chunk = alloc_new_data_chunk(sparse_data_chunks_.size()); if (ailego_unlikely(!sparse_data_chunk)) { LOG_ERROR("allocate data chunk failed"); return IndexError_NoMemory; } sparse_data_chunks_.emplace_back(sparse_data_chunk); sparse_data_chunk_index = 
sparse_data_chunks_.size() - 1U; sparse_data_chunk_offset = 0UL; } else { sparse_data_chunk = sparse_data_chunks_[sparse_data_chunk_index]; sparse_data_chunk_offset = sparse_data_chunk->data_size(); } // write sparse vector if (sparse_vector.size() > 0) { if (ailego_unlikely(write_sparse_vector_data( sparse_data_chunk_index, sparse_data_chunk_offset, sparse_vector.data(), sparse_vector.size()) != 0)) { LOG_ERROR("write sparse vector failed"); return IndexError_NoMemory; } } uint64_t sparse_offset = sparse_data_chunk_index; sparse_offset = (sparse_offset << 32U) + sparse_data_chunk_offset; // get sparse offset chunk and offset for write new info Chunk::Pointer sparse_offset_chunk; uint32_t sparse_offset_chunk_offset = -1U; uint32_t sparse_offset_chunk_index = sparse_offset_chunks_.size() - 1U; if (sparse_offset_chunk_index == -1U || sparse_offset_chunks_[sparse_offset_chunk_index]->data_size() + offset_size_per_node() > streamer_meta_.offset_chunk_size) { // no space left and need to allocate new offset chunk if (ailego_unlikely(sparse_offset_chunks_.capacity() == sparse_offset_chunks_.size())) { LOG_ERROR("add vector failed for no memory quota"); return IndexError_IndexFull; } sparse_offset_chunk = alloc_new_offset_chunk(sparse_offset_chunks_.size()); if (ailego_unlikely(!sparse_offset_chunk)) { LOG_ERROR("allocate offset chunk failed"); return IndexError_NoMemory; } sparse_offset_chunks_.emplace_back(sparse_offset_chunk); sparse_offset_chunk_index = sparse_offset_chunks_.size() - 1U; sparse_offset_chunk_offset = 0UL; } else { sparse_offset_chunk = sparse_offset_chunks_[sparse_offset_chunk_index]; sparse_offset_chunk_offset = sparse_offset_chunk->data_size(); } // write offset size_t size = sparse_offset_chunk->write(sparse_offset_chunk_offset, &sparse_offset, sizeof(uint64_t)); if (ailego_unlikely(size != sizeof(uint64_t))) { LOG_ERROR("Chunk write sparse vec offset failed, ret=%zu", size); return IndexError_WriteData; } // write length size = 
sparse_offset_chunk->write(sparse_offset_chunk_offset + sizeof(uint64_t), &sparse_vector_len, sizeof(uint32_t)); if (ailego_unlikely(size != sizeof(uint32_t))) { LOG_ERROR("Chunk write sparse vec len failed, ret=%zu", size); return IndexError_WriteData; } // write key size = sparse_offset_chunk->write( sparse_offset_chunk_offset + 2 * sizeof(uint64_t), &key, sizeof(uint64_t)); if (ailego_unlikely(size != sizeof(uint64_t))) { LOG_ERROR("Chunk write key failed, ret=%zu", size); return IndexError_WriteData; } // LOG_INFO("Write sparse vector, key=%lu, offset chunk=%u, offset=%u, // len=%u", // key, sparse_offset_chunk_index, sparse_offset_chunk_offset, // offset_size_per_node()); // LOG_INFO("Write sparse vector, key=%lu, data chunk=%u, offset=%u, len=%u", // key, sparse_data_chunk_index, sparse_data_chunk_offset, // sparse_vector_len); // resize chunk if (sparse_vector_len > 0) { sparse_data_chunk_offset += sparse_vector_len; if (ailego_unlikely(sparse_data_chunk->resize(sparse_data_chunk_offset) != sparse_data_chunk_offset)) { LOG_ERROR("Sparse Chunk resize to %u failed", sparse_data_chunk_offset); return IndexError_Runtime; } } // persist in keys_map { keys_map_lock_->lock(); (*keys_map_)[key] = local_id; keys_map_lock_->unlock(); } inc_doc_count(); inc_total_sparse_count(sparse_count); return 0; } int FlatSparseStreamerEntity::add_vector_with_id( uint32_t id, const std::string &sparse_vector, const uint32_t sparse_count) { uint32_t sparse_vector_len = sparse_vector.size(); sparse_vector_len = AlignSize(sparse_vector_len); if (sparse_vector_len > streamer_meta_.data_chunk_size) { LOG_ERROR( "Sparse Vector Length exceed the chunk size, sparse vec len: %u, chunk " "size: %u", sparse_vector_len, streamer_meta_.data_chunk_size); return IndexError_InvalidArgument; } std::lock_guard lock(mutex_); if (id >= doc_cnt()) { for (auto i = doc_cnt(); i <= id; i++) { node_id_t local_id = doc_cnt(); if (ailego_unlikely(local_id >= max_doc_cnt_)) { LOG_ERROR("Add vector failed for 
exceed max doc count: %u", max_doc_cnt_); return IndexError_IndexFull; } uint32_t sparse_data_chunk_index, sparse_data_chunk_offset, sparse_offset_chunk_index, sparse_offset_chunk_offset; if (i < id) { write_sparse_vector_to_chunk("", 0, sparse_data_chunk_index, sparse_data_chunk_offset); } else { write_sparse_vector_to_chunk(sparse_vector, sparse_vector_len, sparse_data_chunk_index, sparse_data_chunk_offset); } uint64_t sparse_offset = ((uint64_t)sparse_data_chunk_index << 32U) + sparse_data_chunk_offset; get_new_sparse_offset_chunk(sparse_offset_chunk_index, sparse_offset_chunk_offset); uint64_t written_key = kInvalidKey; if (i == id) { written_key = i; } write_sparse_offset_to_chunk(sparse_offset_chunk_index, sparse_offset_chunk_offset, sparse_offset, sparse_vector_len, written_key); { keys_map_lock_->lock(); (*keys_map_)[i] = written_key; keys_map_lock_->unlock(); } inc_doc_count(); } } else { uint32_t sparse_data_chunk_index, sparse_data_chunk_offset; write_sparse_vector_to_chunk(sparse_vector, sparse_vector_len, sparse_data_chunk_index, sparse_data_chunk_offset); uint64_t sparse_offset = ((uint64_t)sparse_data_chunk_index << 32U) + sparse_data_chunk_offset; uint32_t sparse_offset_chunk_index = id / get_offset_info_number_per_chunk(); uint32_t sparse_offset_chunk_offset = id % get_offset_info_number_per_chunk() * offset_size_per_node(); write_sparse_offset_to_chunk(sparse_offset_chunk_index, sparse_offset_chunk_offset, sparse_offset, sparse_vector_len, id); { keys_map_lock_->lock(); (*keys_map_)[id] = id; keys_map_lock_->unlock(); } } inc_total_sparse_count(sparse_count); return 0; } int FlatSparseStreamerEntity::write_sparse_vector_to_chunk( const std::string &sparse_vector, const uint32_t sparse_vector_len, uint32_t &sparse_data_chunk_index, uint32_t &sparse_data_chunk_offset) { // get sparse data chunk and offset for write sparse vector Chunk::Pointer sparse_data_chunk; sparse_data_chunk_offset = -1U; sparse_data_chunk_index = sparse_data_chunks_.size() - 
1U; if (sparse_data_chunk_index == -1U || sparse_data_chunks_[sparse_data_chunk_index]->data_size() + sparse_vector_len > streamer_meta_.data_chunk_size) { if (ailego_unlikely(sparse_data_chunks_.capacity() == sparse_data_chunks_.size())) { LOG_ERROR("add vector failed for no memory quota"); if (sparse_data_chunk_index != -1U) { LOG_ERROR( "capacity: %zu, chunk used size: %zu, chunk size: %u, " "sparse_vector_len: %u", sparse_data_chunks_.capacity(), sparse_data_chunks_[sparse_data_chunk_index]->data_size(), streamer_meta_.data_chunk_size, sparse_vector_len); } return IndexError_IndexFull; } sparse_data_chunk = alloc_new_data_chunk(sparse_data_chunks_.size()); if (ailego_unlikely(!sparse_data_chunk)) { LOG_ERROR("allocate data chunk failed"); return IndexError_NoMemory; } sparse_data_chunks_.emplace_back(sparse_data_chunk); sparse_data_chunk_index = sparse_data_chunks_.size() - 1U; sparse_data_chunk_offset = 0UL; } else { sparse_data_chunk = sparse_data_chunks_[sparse_data_chunk_index]; sparse_data_chunk_offset = sparse_data_chunk->data_size(); } // write sparse vector if (sparse_vector.size() > 0) { if (ailego_unlikely(write_sparse_vector_data( sparse_data_chunk_index, sparse_data_chunk_offset, sparse_vector.data(), sparse_vector.size()) != 0)) { LOG_ERROR("write sparse vector failed"); return IndexError_NoMemory; } } // resize chunk if (sparse_vector_len > 0) { uint32_t sparse_data_chunk_size = sparse_data_chunk_offset + sparse_vector_len; if (ailego_unlikely(sparse_data_chunk->resize(sparse_data_chunk_size) != sparse_data_chunk_size)) { LOG_ERROR("Sparse Chunk resize to %u failed", sparse_data_chunk_size); return IndexError_Runtime; } } return 0; } int FlatSparseStreamerEntity::get_new_sparse_offset_chunk( uint32_t &sparse_offset_chunk_index, uint32_t &sparse_offset_chunk_offset) { // get sparse offset chunk and offset for write new info Chunk::Pointer sparse_offset_chunk; sparse_offset_chunk_offset = -1U; sparse_offset_chunk_index = sparse_offset_chunks_.size() 
- 1U; if (sparse_offset_chunk_index == -1U || sparse_offset_chunks_[sparse_offset_chunk_index]->data_size() + offset_size_per_node() > streamer_meta_.offset_chunk_size) { // no space left and need to allocate new offset chunk if (ailego_unlikely(sparse_offset_chunks_.capacity() == sparse_offset_chunks_.size())) { LOG_ERROR("add vector failed for no memory quota"); return IndexError_IndexFull; } sparse_offset_chunk = alloc_new_offset_chunk(sparse_offset_chunks_.size()); if (ailego_unlikely(!sparse_offset_chunk)) { LOG_ERROR("allocate offset chunk failed"); return IndexError_NoMemory; } sparse_offset_chunks_.emplace_back(sparse_offset_chunk); sparse_offset_chunk_index = sparse_offset_chunks_.size() - 1U; sparse_offset_chunk_offset = 0UL; } else { sparse_offset_chunk = sparse_offset_chunks_[sparse_offset_chunk_index]; sparse_offset_chunk_offset = sparse_offset_chunk->data_size(); } return 0; } int FlatSparseStreamerEntity::write_sparse_offset_to_chunk( const uint32_t sparse_offset_chunk_index, const uint32_t sparse_offset_chunk_offset, const uint64_t sparse_offset, const uint32_t sparse_vector_len, const uint64_t node_id) { // write offset Chunk::Pointer sparse_offset_chunk = sparse_offset_chunks_[sparse_offset_chunk_index]; size_t size = sparse_offset_chunk->write(sparse_offset_chunk_offset, &sparse_offset, sizeof(uint64_t)); if (ailego_unlikely(size != sizeof(uint64_t))) { LOG_ERROR("Chunk write sparse vec offset failed, ret=%zu", size); return IndexError_WriteData; } // write length size = sparse_offset_chunk->write(sparse_offset_chunk_offset + sizeof(uint64_t), &sparse_vector_len, sizeof(uint32_t)); if (ailego_unlikely(size != sizeof(uint32_t))) { LOG_ERROR("Chunk write sparse vec len failed, ret=%zu", size); return IndexError_WriteData; } // write key size = sparse_offset_chunk->write( sparse_offset_chunk_offset + 2 * sizeof(uint64_t), &node_id, sizeof(uint64_t)); if (ailego_unlikely(size != sizeof(uint64_t))) { LOG_ERROR("Chunk write key failed, ret=%zu", size); 
return IndexError_WriteData; } return 0; } uint64_t FlatSparseStreamerEntity::get_key(node_id_t node_id) const { uint32_t offset_chunk_index = node_id / get_offset_info_number_per_chunk(); uint32_t offset_chunk_key_offset = node_id % get_offset_info_number_per_chunk() * offset_size_per_node() + 2 * sizeof(uint64_t); IndexStorage::MemoryBlock block; if (ailego_unlikely(sparse_offset_chunks_[offset_chunk_index]->read( offset_chunk_key_offset, block, sizeof(uint64_t)) != sizeof(uint64_t))) { LOG_ERROR("Read key failed, offset=%u, node_id=%u", offset_chunk_key_offset, node_id); return kInvalidKey; }; return *reinterpret_cast(block.data()); } int FlatSparseStreamerEntity::get_sparse_vector_ptr_by_id( node_id_t node_id, const void **sparse_vector_ptr, uint32_t *sparse_vector_len_ptr) const { uint32_t offset_chunk_index = node_id / get_offset_info_number_per_chunk(); uint32_t offset_chunk_offset = node_id % get_offset_info_number_per_chunk() * offset_size_per_node(); // LOG_DEBUG("Read sparse vector, offset chunk=%u, offset=%u, len=%u", // offset_chunk_index, offset_chunk_offset, offset_size_per_node()); auto offset_chunk = sparse_offset_chunks_[offset_chunk_index]; const void *offset_info = nullptr; size_t read_len = offset_chunk->read(offset_chunk_offset, &offset_info, offset_size_per_node()); if (ailego_unlikely(read_len != offset_size_per_node())) { LOG_ERROR("Read offset info failed, offset=%u, read_len=%zu, expect=%u", offset_chunk_offset, read_len, offset_size_per_node()); return IndexError_ReadData; }; // sparse offset uint64_t sparse_offset = *(uint64_t *)offset_info; uint32_t sparse_vector_len = *(uint32_t *)((uint8_t *)offset_info + sizeof(uint64_t)); uint32_t sparse_data_chunk_index = static_cast((sparse_offset >> 32) & 0xFFFFFFFF); uint32_t sparse_data_chunk_offset = static_cast(sparse_offset & 0xFFFFFFFF); if (sparse_vector_len > 0) { const void *sparse_data = get_sparse_vector_data( sparse_data_chunk_index, sparse_data_chunk_offset, sparse_vector_len); if 
(ailego_unlikely(sparse_data == nullptr)) { LOG_ERROR("Get nullptr sparse, offset=%zu, len=%u", (size_t)sparse_offset, sparse_vector_len); return IndexError_ReadData; } *sparse_vector_ptr = sparse_data; *sparse_vector_len_ptr = sparse_vector_len; } // LOG_DEBUG("Read sparse vector, data chunk=%u, offset=%u, len=%u", // sparse_data_chunk_index, sparse_data_chunk_offset, // sparse_vector_len); return 0; } int FlatSparseStreamerEntity::get_sparse_vector_ptr_by_id( node_id_t node_id, IndexStorage::MemoryBlock &sparse_vector_block, uint32_t *sparse_vector_len_ptr) const { uint32_t offset_chunk_index = node_id / get_offset_info_number_per_chunk(); uint32_t offset_chunk_offset = node_id % get_offset_info_number_per_chunk() * offset_size_per_node(); // LOG_DEBUG("Read sparse vector, offset chunk=%u, offset=%u, len=%u", // offset_chunk_index, offset_chunk_offset, offset_size_per_node()); auto offset_chunk = sparse_offset_chunks_[offset_chunk_index]; const void *offset_info = nullptr; IndexStorage::MemoryBlock offset_info_block; size_t read_len = offset_chunk->read(offset_chunk_offset, offset_info_block, offset_size_per_node()); if (ailego_unlikely(read_len != offset_size_per_node())) { LOG_ERROR("Read offset info failed, offset=%u, read_len=%zu, expect=%u", offset_chunk_offset, read_len, offset_size_per_node()); return IndexError_ReadData; }; offset_info = offset_info_block.data(); // sparse offset uint64_t sparse_offset = *(uint64_t *)offset_info; uint32_t sparse_vector_len = *(uint32_t *)((uint8_t *)offset_info + sizeof(uint64_t)); uint32_t sparse_data_chunk_index = static_cast((sparse_offset >> 32) & 0xFFFFFFFF); uint32_t sparse_data_chunk_offset = static_cast(sparse_offset & 0xFFFFFFFF); if (sparse_vector_len > 0) { get_sparse_vector_data(sparse_data_chunk_index, sparse_data_chunk_offset, sparse_vector_len, sparse_vector_block); if (ailego_unlikely(sparse_vector_block.data() == nullptr)) { LOG_ERROR("Get nullptr sparse, offset=%zu, len=%u", (size_t)sparse_offset, 
sparse_vector_len); return IndexError_ReadData; } *sparse_vector_len_ptr = sparse_vector_len; } return 0; } int FlatSparseStreamerEntity::write_sparse_vector_data(uint32_t chunk_index, uint64_t offset, const void *data, uint32_t length) { auto size = sparse_data_chunks_[chunk_index]->write(offset, data, length); if (size != length) { LOG_ERROR( "write sparse vector data failed: chunk_index=%u, offset=%zu, " "length=%u, size=%zu, chunk_data_size=%zu", chunk_index, (size_t)offset, length, size, sparse_data_chunks_[chunk_index]->data_size()); return IndexError_WriteData; } // LOG_DEBUG( // "write_sparse_vector_data: chunk_index=%u, offset=%lu, length=%u, " // "data=%p", // chunk_index, offset, length, data); return 0; } const void *FlatSparseStreamerEntity::get_sparse_vector_data( uint32_t chunk_index, uint64_t offset, uint32_t length) const { const void *data; auto size = sparse_data_chunks_[chunk_index]->read(offset, &data, length); if (size != length) { LOG_ERROR( "read sparse vector data failed: chunk_index=%u, offset=%zu, " "length=%u, size=%zu", chunk_index, (size_t)offset, length, size); return nullptr; } // LOG_DEBUG( // "get_sparse_vector_data: chunk_index=%u, offset=%lu, length=%u, " // "data=%p", // chunk_index, offset, length, data); return data; } int FlatSparseStreamerEntity::get_sparse_vector_data( uint32_t chunk_index, uint64_t offset, uint32_t length, IndexStorage::MemoryBlock &block) const { auto size = sparse_data_chunks_[chunk_index]->read(offset, block, length); if (size != length) { LOG_ERROR( "read sparse vector data failed: chunk_index=%u, offset=%zu, " "length=%u, size=%zu", chunk_index, (size_t)offset, length, size); return IndexError_ReadData; } return 0; } } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/flat_sparse/flat_sparse_streamer_entity.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, 
Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include #include #include #include #include #include #include #include #include #include "flat_sparse_entity.h" #include "flat_sparse_index_format.h" #include "flat_sparse_utility.h" namespace zvec { namespace core { /*! Flat Sparse Streamer Entity */ class FlatSparseStreamerEntity : public FlatSparseEntity { public: typedef std::shared_ptr Pointer; using Chunk = IndexStorage::Segment; //! Constructor explicit FlatSparseStreamerEntity(IndexStreamer::Stats &stats); //! Destructor virtual ~FlatSparseStreamerEntity() = default; //! Disable them FlatSparseStreamerEntity(const FlatSparseStreamerEntity &) = delete; FlatSparseStreamerEntity &operator=(const FlatSparseStreamerEntity &) = delete; //! Open the entity with storage int open(IndexStorage::Pointer storage, const IndexMeta &meta); //! Close the entity int close(); //! Flush linear index to storage int flush(uint64_t checkpoint); //! Dump index by dumper int dump(const IndexDumper::Pointer &dumper); //! Add sparse vector to linear index int add(uint64_t key, const std::string &sparse_vector, const uint32_t sparse_count); //! Add sparse vector to linear index with id int add_vector_with_id(uint32_t id, const std::string &sparse_vector, uint32_t sparse_count); //! 
Clone entity FlatSparseStreamerEntity::Pointer clone() const; int get_index_sparse_meta(IndexMeta *meta) const { return IndexHelper::DeserializeFromStorage(storage_.get(), meta); } int set_index_sparse_meta(const IndexMeta &meta) const { return IndexHelper::SerializeToStorage(meta, storage_.get()); } public: inline uint32_t doc_cnt() const override { return meta_.doc_cnt; } inline uint32_t total_sparse_count() const override { return meta_.total_sparse_count; } size_t sparse_unit_size() const override { return sparse_unit_size_; } inline node_id_t get_id(uint64_t key) const override { keys_map_lock_->lock_shared(); auto it = keys_map_->find(key); keys_map_lock_->unlock_shared(); return it == keys_map_->end() ? kInvalidNodeId : it->second; } uint64_t get_key(node_id_t node_id) const override; int get_sparse_vector_ptr_by_id(node_id_t id, const void **sparse_vector, uint32_t *sparse_vector_len) const override; int get_sparse_vector_ptr_by_id( const node_id_t id, IndexStorage::MemoryBlock &sparse_vector_block, uint32_t *sparse_vector_len) const override; float get_search_distance(const std::string &vector, node_id_t target_node_id) const override { float dist; const void *target_vector; uint32_t target_vector_len; get_sparse_vector_ptr_by_id(target_node_id, &target_vector, &target_vector_len); search_sparse_distance_(vector.c_str(), target_vector, &dist); return dist; } private: void inc_doc_count() { meta_.doc_cnt++; } void inc_total_sparse_count(uint32_t count) { meta_.total_sparse_count += count; } int init_metric(const IndexMeta &meta); int init_storage(IndexStorage::Pointer storage, const IndexMeta &meta); int load_storage(IndexStorage::Pointer storage, const IndexMeta &meta); static inline size_t AlignSize(size_t size) { return (size + 0x1F) & (~0x1F); } inline uint32_t offset_size_per_node() const { return 3 * sizeof(uint64_t); } inline uint32_t doc_cnt_per_offset_chunk() const { return streamer_meta_.offset_chunk_size / offset_size_per_node(); } Chunk::Pointer 
alloc_new_offset_chunk(uint32_t chunk_id) { std::string segment_id = ailego::StringHelper::Concat( PARAM_FLAT_SPARSE_OFFSET_SEG_ID_PREFIX, chunk_id); // LOG_INFO("Alloc new offset chunk %s", segment_id.c_str()); return alloc_new_chunk(segment_id, streamer_meta_.offset_chunk_size); } Chunk::Pointer alloc_new_data_chunk(uint32_t chunk_id) { std::string segment_id = ailego::StringHelper::Concat( PARAM_FLAT_SPARSE_DATA_SEG_ID_PREFIX, chunk_id); // LOG_INFO("Alloc new data chunk %s", segment_id.c_str()); return alloc_new_chunk(segment_id, streamer_meta_.data_chunk_size); } Chunk::Pointer alloc_new_chunk(const std::string &segment_id, uint32_t size) { int ret = storage_->append(segment_id, size); if (ailego_unlikely(ret != 0)) { return nullptr; } *stats_.mutable_index_size() += size; return storage_->get(segment_id); } inline uint32_t get_offset_info_number_per_chunk() const { return streamer_meta_.offset_chunk_size / offset_size_per_node(); } int write_sparse_vector_to_chunk(const std::string &sparse_vector, const uint32_t sparse_vector_len, uint32_t &sparse_data_chunk_index, uint32_t &sparse_data_chunk_offset); int get_new_sparse_offset_chunk(uint32_t &sparse_offset_chunk_index, uint32_t &sparse_offset_chunk_offset); int write_sparse_offset_to_chunk(const uint32_t sparse_offset_chunk_index, const uint32_t sparse_offset_chunk_offset, const uint64_t sparse_offset, const uint32_t sparse_vector_len, const uint64_t node_id); int write_sparse_vector_data(uint32_t chunk_index, uint64_t offset, const void *data, uint32_t length); const void *get_sparse_vector_data(uint32_t chunk_index, uint64_t offset, uint32_t length) const; int get_sparse_vector_data(uint32_t chunk_index, uint64_t offset, uint32_t length, IndexStorage::MemoryBlock &block) const; int dump_sparse_vector_data(const void *data, uint32_t length, IndexDumper *dumper); int dump_meta(IndexDumper *dumper); int dump_index_meta(IndexDumper *dumper); int dump_keys(const std::vector &keys, IndexDumper *dumper); int 
dump_mapping(const std::vector &keys, IndexDumper *dumper); int dump_offset_data(IndexDumper *dumper); private: FlatSparseStreamerEntity( IndexStreamer::Stats &stats, const FlatSparseMeta &meta, const FlatSparseStreamerMeta &streamer_meta, std::shared_ptr keys_map_lock, std::shared_ptr> keys_map, std::vector sparse_data_chunks, std::vector sparse_offset_chunks) : stats_(stats), meta_(meta), streamer_meta_(streamer_meta), keys_map_lock_(keys_map_lock), keys_map_(keys_map), sparse_data_chunks_(std::move(sparse_data_chunks)), sparse_offset_chunks_(std::move(sparse_offset_chunks)) {} private: IndexStorage::Pointer storage_{}; IndexStreamer::Stats &stats_; // meta FlatSparseMeta meta_; FlatSparseStreamerMeta streamer_meta_; // metric IndexMetric::Pointer metric_{}; IndexMetric::MatrixSparseDistance search_sparse_distance_{}; std::mutex mutex_{}; // keys map mutable std::shared_ptr keys_map_lock_{}; std::shared_ptr> keys_map_{}; // chunks mutable std::vector sparse_data_chunks_{}; mutable std::vector sparse_offset_chunks_{}; // config uint32_t max_doc_cnt_{1 << 24U}; // 16 million uint32_t max_data_chunk_cnt_{ 1 << 10U}; // 1024, default single_data_chunk_size = 8M, // default_total_max = 1024 * 8M = 8G uint64_t dump_size_{0U}; size_t sparse_unit_size_{0U}; }; } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/flat_sparse/flat_sparse_utility.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and // limitations under the License. #pragma once #include namespace zvec { namespace core { static constexpr uint32_t PARAM_FLAT_SPARSE_MAX_DIM_SIZE = 16384; static const std::string PARAM_FLAT_SPARSE_META_SEG_ID = "bruteforce_sparse_meta"; // streamer static const std::string PARAM_FLAT_SPARSE_STREAMER_META_SEG_ID = "bruteforce_sparse_streamer_meta"; static const std::string PARAM_FLAT_SPARSE_OFFSET_SEG_ID_PREFIX = "bruteforce_sparse_streamer_offset_"; static const std::string PARAM_FLAT_SPARSE_DATA_SEG_ID_PREFIX = "bruteforce_sparse_streamer_data_"; // searcher static const std::string PARAM_FLAT_SPARSE_DUMP_OFFSET_SEG_ID = "bruteforce_sparse_searcher_offset_segment"; static const std::string PARAM_FLAT_SPARSE_DUMP_DATA_SEG_ID = "bruteforce_sparse_searcher_data_segment"; static const std::string PARAM_FLAT_SPARSE_DUMP_KEYS_SEG_ID = "bruteforce_sparse_searcher_keys_segment"; static const std::string PARAM_FLAT_SPARSE_DUMP_MAPPING_SEG_ID = "bruteforce_sparse_searcher_mapping_segment"; // streamer static const std::string PARAM_FLAT_SPARSE_STREAMER_OFFSET_CHUNK_SIZE( "proxima.bruteforce.sparse_streamer.offset_chunk_size"); static const std::string PARAM_FLAT_SPARSE_STREAMER_DATA_CHUNK_SIZE( "proxima.bruteforce.sparse_streamer.data_chunk_size"); static const std::string PARAM_FLAT_SPARSE_STREAMER_MAX_DOC_CNT( "proxima.bruteforce.sparse_streamer.max_doc_cnt"); static const std::string PARAM_FLAT_SPARSE_STREAMER_MAX_DATA_CHUNK_CNT( "proxima.bruteforce.sparse_streamer.max_data_chunk_cnt"); } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/hnsw/CMakeLists.txt ================================================ include(${PROJECT_ROOT_DIR}/cmake/bazel.cmake) include(${PROJECT_ROOT_DIR}/cmake/option.cmake) cc_library( NAME core_knn_hnsw STATIC SHARED STRICT ALWAYS_LINK SRCS *.cc LIBS core_framework sparsehash INCS . 
${PROJECT_ROOT_DIR}/src/core ${PROJECT_ROOT_DIR}/src/core/algorithm VERSION "${PROXIMA_ZVEC_VERSION}" ) ================================================ FILE: src/core/algorithm/hnsw/hnsw_algorithm.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "hnsw_algorithm.h" #include #include #include #include namespace zvec { namespace core { HnswAlgorithm::HnswAlgorithm(HnswEntity &entity) : entity_(entity), mt_(std::chrono::system_clock::now().time_since_epoch().count()), lock_pool_(kLockCnt) {} int HnswAlgorithm::cleanup() { return 0; } int HnswAlgorithm::add_node(node_id_t id, level_t level, HnswContext *ctx) { spin_lock_.lock(); // std::cout << "id: " << id << ", level: " << level << std::endl; auto cur_max_level = entity_.cur_max_level(); auto entry_point = entity_.entry_point(); if (ailego_unlikely(entry_point == kInvalidNodeId)) { entity_.update_ep_and_level(id, level); spin_lock_.unlock(); return 0; } spin_lock_.unlock(); if (ailego_unlikely(level > cur_max_level)) { mutex_.lock(); // re-check max level cur_max_level = entity_.cur_max_level(); entry_point = entity_.entry_point(); if (level <= cur_max_level) { mutex_.unlock(); } } level_t cur_level = cur_max_level; dist_t dist = ctx->dist_calculator()(entry_point); for (; cur_level > level; --cur_level) { select_entry_point(cur_level, &entry_point, &dist, ctx); } for (; cur_level >= 0; --cur_level) { 
search_neighbors(cur_level, &entry_point, &dist, ctx->level_topk(cur_level), ctx); } // add neighbors from down level to top level, to avoid upper level visible // to knn_search but the under layer level not ready for (cur_level = 0; cur_level <= level; ++cur_level) { add_neighbors(id, cur_level, ctx->level_topk(cur_level), ctx); ctx->level_topk(cur_level).clear(); } if (ailego_unlikely(level > cur_max_level)) { spin_lock_.lock(); entity_.update_ep_and_level(id, level); spin_lock_.unlock(); mutex_.unlock(); } return 0; } int HnswAlgorithm::search(HnswContext *ctx) const { spin_lock_.lock(); auto maxLevel = entity_.cur_max_level(); auto entry_point = entity_.entry_point(); spin_lock_.unlock(); if (ailego_unlikely(entry_point == kInvalidNodeId)) { return 0; } dist_t dist = ctx->dist_calculator().dist(entry_point); for (level_t cur_level = maxLevel; cur_level >= 1; --cur_level) { select_entry_point(cur_level, &entry_point, &dist, ctx); } auto &topk_heap = ctx->topk_heap(); topk_heap.clear(); search_neighbors(0, &entry_point, &dist, topk_heap, ctx); if (ctx->group_by_search()) { expand_neighbors_by_group(topk_heap, ctx); } return 0; } //! 
select_entry_point on hnsw level, ef = 1 void HnswAlgorithm::select_entry_point(level_t level, node_id_t *entry_point, dist_t *dist, HnswContext *ctx) const { auto &entity = ctx->get_entity(); HnswDistCalculator &dc = ctx->dist_calculator(); while (true) { const Neighbors neighbors = entity.get_neighbors(level, *entry_point); if (ailego_unlikely(ctx->debugging())) { (*ctx->mutable_stats_get_neighbors())++; } uint32_t size = neighbors.size(); if (size == 0) { break; } std::vector neighbor_vec_blocks; int ret = entity.get_vector(&neighbors[0], size, neighbor_vec_blocks); if (ailego_unlikely(ctx->debugging())) { (*ctx->mutable_stats_get_vector())++; } if (ailego_unlikely(ret != 0)) { break; } bool find_closer = false; std::vector dists(size); std::vector neighbor_vecs(size); for (uint32_t i = 0; i < size; ++i) { neighbor_vecs[i] = neighbor_vec_blocks[i].data(); } dc.batch_dist(neighbor_vecs.data(), size, dists.data()); for (uint32_t i = 0; i < size; ++i) { dist_t cur_dist = dists[i]; if (cur_dist < *dist) { *entry_point = neighbors[i]; *dist = cur_dist; find_closer = true; } } if (!find_closer) { break; } } return; } void HnswAlgorithm::add_neighbors(node_id_t id, level_t level, TopkHeap &topk_heap, HnswContext *ctx) { if (ailego_unlikely(topk_heap.size() == 0)) { return; } HnswDistCalculator &dc = ctx->dist_calculator(); update_neighbors(dc, id, level, topk_heap); // reverse update neighbors for (size_t i = 0; i < topk_heap.size(); ++i) { reverse_update_neighbors(dc, topk_heap[i].first, level, id, topk_heap[i].second, ctx->update_heap()); } return; } void HnswAlgorithm::search_neighbors(level_t level, node_id_t *entry_point, dist_t *dist, TopkHeap &topk, HnswContext *ctx) const { const auto &entity = ctx->get_entity(); HnswDistCalculator &dc = ctx->dist_calculator(); VisitFilter &visit = ctx->visit_filter(); CandidateHeap &candidates = ctx->candidates(); std::function filter = [](node_id_t) { return false; }; if (ctx->filter().is_valid()) { filter = [&](node_id_t id) 
{ return ctx->filter()(entity.get_key(id)); }; } candidates.clear(); visit.clear(); visit.set_visited(*entry_point); if (!filter(*entry_point)) { topk.emplace(*entry_point, *dist); } candidates.emplace(*entry_point, *dist); while (!candidates.empty() && !ctx->reach_scan_limit()) { auto top = candidates.begin(); node_id_t main_node = top->first; dist_t main_dist = top->second; if (topk.full() && main_dist > topk[0].second) { break; } candidates.pop(); const Neighbors neighbors = entity.get_neighbors(level, main_node); ailego_prefetch(neighbors.data); if (ailego_unlikely(ctx->debugging())) { (*ctx->mutable_stats_get_neighbors())++; } std::vector neighbor_ids(neighbors.size()); uint32_t size = 0; for (uint32_t i = 0; i < neighbors.size(); ++i) { node_id_t node = neighbors[i]; if (visit.visited(node)) { if (ailego_unlikely(ctx->debugging())) { (*ctx->mutable_stats_visit_dup_cnt())++; } continue; } visit.set_visited(node); neighbor_ids[size++] = node; } if (size == 0) { continue; } std::vector neighbor_vec_blocks; int ret = entity.get_vector(neighbor_ids.data(), size, neighbor_vec_blocks); if (ailego_unlikely(ctx->debugging())) { (*ctx->mutable_stats_get_vector())++; } if (ailego_unlikely(ret != 0)) { break; } // do prefetch static constexpr node_id_t BATCH_SIZE = 12; static constexpr node_id_t PREFETCH_STEP = 2; for (uint32_t i = 0; i < std::min(BATCH_SIZE * PREFETCH_STEP, size); ++i) { ailego_prefetch(neighbor_vec_blocks[i].data()); } // done std::vector dists(size); std::vector neighbor_vecs(size); for (uint32_t i = 0; i < size; ++i) { neighbor_vecs[i] = neighbor_vec_blocks[i].data(); } dc.batch_dist(neighbor_vecs.data(), size, dists.data()); for (uint32_t i = 0; i < size; ++i) { node_id_t node = neighbor_ids[i]; dist_t cur_dist = dists[i]; if ((!topk.full()) || cur_dist < topk[0].second) { candidates.emplace(node, cur_dist); // update entry_point for next level scan if (cur_dist < *dist) { *entry_point = node; *dist = cur_dist; } if (!filter(node)) { 
topk.emplace(node, cur_dist); } } // end if } // end for } // while return; } void HnswAlgorithm::expand_neighbors_by_group(TopkHeap &topk, HnswContext *ctx) const { if (!ctx->group_by().is_valid()) { return; } const auto &entity = ctx->get_entity(); std::function group_by = [&](node_id_t id) { return ctx->group_by()(entity.get_key(id)); }; // devide into groups std::map &group_topk_heaps = ctx->group_topk_heaps(); for (uint32_t i = 0; i < topk.size(); ++i) { node_id_t id = topk[i].first; auto score = topk[i].second; std::string group_id = group_by(id); auto &topk_heap = group_topk_heaps[group_id]; if (topk_heap.empty()) { topk_heap.limit(ctx->group_topk()); } topk_heap.emplace_back(id, score); } // stage 2, expand to reach group num as possible if (group_topk_heaps.size() < ctx->group_num()) { VisitFilter &visit = ctx->visit_filter(); CandidateHeap &candidates = ctx->candidates(); HnswDistCalculator &dc = ctx->dist_calculator(); std::function filter = [](node_id_t) { return false; }; if (ctx->filter().is_valid()) { filter = [&](node_id_t id) { return ctx->filter()(entity.get_key(id)); }; } // refill to get enough groups candidates.clear(); visit.clear(); for (uint32_t i = 0; i < topk.size(); ++i) { node_id_t id = topk[i].first; float score = topk[i].second; visit.set_visited(id); candidates.emplace_back(id, score); } // do expand while (!candidates.empty() && !ctx->reach_scan_limit()) { auto top = candidates.begin(); node_id_t main_node = top->first; candidates.pop(); const Neighbors neighbors = entity.get_neighbors(0, main_node); if (ailego_unlikely(ctx->debugging())) { (*ctx->mutable_stats_get_neighbors())++; } std::vector neighbor_ids(neighbors.size()); uint32_t size = 0; for (uint32_t i = 0; i < neighbors.size(); ++i) { node_id_t node = neighbors[i]; if (visit.visited(node)) { if (ailego_unlikely(ctx->debugging())) { (*ctx->mutable_stats_visit_dup_cnt())++; } continue; } visit.set_visited(node); neighbor_ids[size++] = node; } if (size == 0) { continue; } 
std::vector neighbor_vec_blocks; int ret = entity.get_vector(neighbor_ids.data(), size, neighbor_vec_blocks); if (ailego_unlikely(ctx->debugging())) { (*ctx->mutable_stats_get_vector())++; } if (ailego_unlikely(ret != 0)) { break; } static constexpr node_id_t PREFETCH_STEP = 2; for (uint32_t i = 0; i < size; ++i) { node_id_t node = neighbor_ids[i]; node_id_t prefetch_id = i + PREFETCH_STEP; if (prefetch_id < size) { ailego_prefetch(neighbor_vec_blocks[prefetch_id].data()); } dist_t cur_dist = dc.dist(neighbor_vec_blocks[i].data()); if (!filter(node)) { std::string group_id = group_by(node); auto &topk_heap = group_topk_heaps[group_id]; if (topk_heap.empty()) { topk_heap.limit(ctx->group_topk()); } topk_heap.emplace_back(node, cur_dist); if (group_topk_heaps.size() >= ctx->group_num()) { break; } } candidates.emplace(node, cur_dist); } // end for } // end while } // end if } void HnswAlgorithm::update_neighbors(HnswDistCalculator &dc, node_id_t id, level_t level, TopkHeap &topk_heap) { topk_heap.sort(); uint32_t max_neighbor_cnt = entity_.neighbor_cnt(level); if (topk_heap.size() <= static_cast(entity_.prune_cnt())) { if (topk_heap.size() <= static_cast(max_neighbor_cnt)) { entity_.update_neighbors(level, id, topk_heap); return; } } uint32_t cur_size = 0; for (size_t i = 0; i < topk_heap.size(); ++i) { node_id_t cur_node = topk_heap[i].first; dist_t cur_node_dist = topk_heap[i].second; bool good = true; for (uint32_t j = 0; j < cur_size; ++j) { dist_t tmp_dist = dc.dist(cur_node, topk_heap[j].first); if (tmp_dist <= cur_node_dist) { good = false; break; } } if (good) { topk_heap[cur_size].first = cur_node; topk_heap[cur_size].second = cur_node_dist; cur_size++; if (cur_size >= max_neighbor_cnt) { break; } } } // when after-prune neighbor count is too seldom, // we use this strategy to make-up enough edges // not only just make-up out-degrees // we also make-up enough in-degrees uint32_t min_neighbors = entity_.min_neighbor_cnt(); for (size_t k = cur_size; cur_size < 
min_neighbors && k < topk_heap.size(); ++k) { bool exist = false; for (size_t j = 0; j < cur_size; ++j) { if (topk_heap[j].first == topk_heap[k].first) { exist = true; break; } } if (!exist) { topk_heap[cur_size].first = topk_heap[k].first; topk_heap[cur_size].second = topk_heap[k].second; cur_size++; } } topk_heap.resize(cur_size); entity_.update_neighbors(level, id, topk_heap); return; } void HnswAlgorithm::reverse_update_neighbors(HnswDistCalculator &dc, node_id_t id, level_t level, node_id_t link_id, dist_t dist, TopkHeap &update_heap) { const size_t max_neighbor_cnt = entity_.neighbor_cnt(level); uint32_t lock_idx = id & kLockMask; lock_pool_[lock_idx].lock(); const Neighbors neighbors = entity_.get_neighbors(level, id); size_t size = neighbors.size(); ailego_assert_with(size <= max_neighbor_cnt, "invalid neighbor size"); if (size < max_neighbor_cnt) { entity_.add_neighbor(level, id, size, link_id); lock_pool_[lock_idx].unlock(); return; } update_heap.emplace(link_id, dist); for (size_t i = 0; i < size; ++i) { node_id_t node = neighbors[i]; dist_t cur_dist = dc.dist(id, node); update_heap.emplace(node, cur_dist); } //! TODO: optimize prune //! 
prune edges update_heap.sort(); size_t cur_size = 0; for (size_t i = 0; i < update_heap.size(); ++i) { node_id_t cur_node = update_heap[i].first; dist_t cur_node_dist = update_heap[i].second; bool good = true; for (size_t j = 0; j < cur_size; ++j) { dist_t tmp_dist = dc.dist(cur_node, update_heap[j].first); if (tmp_dist <= cur_node_dist) { good = false; break; } } if (good) { update_heap[cur_size].first = cur_node; update_heap[cur_size].second = cur_node_dist; cur_size++; if (cur_size >= max_neighbor_cnt) { break; } } } update_heap.resize(cur_size); entity_.update_neighbors(level, id, update_heap); lock_pool_[lock_idx].unlock(); update_heap.clear(); return; } } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/hnsw/hnsw_algorithm.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include "hnsw_context.h" #include "hnsw_dist_calculator.h" #include "hnsw_entity.h" namespace zvec { namespace core { //! hnsw graph algorithm implement class HnswAlgorithm { public: typedef std::unique_ptr UPointer; public: //! Constructor explicit HnswAlgorithm(HnswEntity &entity); //! Destructor ~HnswAlgorithm() = default; //! Cleanup HnswAlgorithm int cleanup(); //! Add a node to hnsw graph //! @id: the node unique id //! @level: a node will be add to graph in each level [0, level] //! 
return 0 on success, or errCode in failure int add_node(node_id_t id, level_t level, HnswContext *ctx); //! do knn search in graph //! return 0 on success, or errCode in failure. results saved in ctx int search(HnswContext *ctx) const; //! Initiate HnswAlgorithm int init() { level_probas_.clear(); double level_mult = 1 / std::log(static_cast(entity_.scaling_factor())); for (int level = 0;; level++) { // refers faiss get_random_level alg double proba = std::exp(-level / level_mult) * (1 - std::exp(-1 / level_mult)); if (proba < 1e-9) { break; } level_probas_.push_back(proba); } return 0; } //! Generate a random level //! return graph level uint32_t get_random_level() const { // gen rand float (0, 1) double f = mt_() / static_cast(mt_.max()); for (size_t level = 0; level < level_probas_.size(); level++) { if (f < level_probas_[level]) { return level; } f -= level_probas_[level]; } return level_probas_.size() - 1; } private: //! Select in upper layer to get entry point for next layer search void select_entry_point(level_t level, node_id_t *entry_point, dist_t *dist, HnswContext *ctx) const; //! update node id neighbors from topkHeap, and reverse link is also updated void add_neighbors(node_id_t id, level_t level, TopkHeap &topk_heap, HnswContext *ctx); //! Given a node id and level, search the nearest neighbors in graph //! Note: the nearest neighbors result keeps in topk, and entry_point and //! dist will be updated to current level nearest node id and distance void search_neighbors(level_t level, node_id_t *entry_point, dist_t *dist, TopkHeap &topk, HnswContext *ctx) const; //! Update the node's neighbors void update_neighbors(HnswDistCalculator &dc, node_id_t id, level_t level, TopkHeap &topk_heap); //! Checking linkId could be id's new neighbor, and add as neighbor if true //! @dc distance calculator //! 
@updateHeap temporary heap in updating neighbors void reverse_update_neighbors(HnswDistCalculator &dc, node_id_t id, level_t level, node_id_t link_id, dist_t dist, TopkHeap &update_heap); //! expand neighbors until group nums are reached void expand_neighbors_by_group(TopkHeap &topk, HnswContext *ctx) const; private: HnswAlgorithm(const HnswAlgorithm &) = delete; HnswAlgorithm &operator=(const HnswAlgorithm &) = delete; private: static constexpr uint32_t kLockCnt{1U << 8}; static constexpr uint32_t kLockMask{kLockCnt - 1U}; HnswEntity &entity_; mutable std::mt19937 mt_{}; std::vector level_probas_{}; mutable ailego::SpinMutex spin_lock_{}; // global spin lock std::mutex mutex_{}; // global mutex // TODO: spin lock? std::vector lock_pool_{}; }; } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/hnsw/hnsw_builder.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "hnsw_builder.h"
#include
#include
#include
#include
#include
#include
#include "hnsw_algorithm.h"
#include "hnsw_params.h"

namespace zvec {
namespace core {

HnswBuilder::HnswBuilder() = default;

//! Initialize the builder from the index meta and builder params.
//! Reads the HNSW builder parameters, resolves "0 = use default" values,
//! validates their ranges, creates/initializes the metric, and initializes
//! the builder entity and the graph algorithm.
//! @return 0 on success, or an IndexError code on failure.
int HnswBuilder::init(const IndexMeta &meta, const ailego::Params &params) {
  LOG_INFO("Begin HnswBuilder::init");

  meta_ = meta;
  auto params_copy = params;
  meta_.set_builder("HnswBuilder", HnswEntity::kRevision,
                    std::move(params_copy));

  size_t memory_quota = 0UL;
  params.get(PARAM_HNSW_BUILDER_MEMORY_QUOTA, &memory_quota);
  params.get(PARAM_HNSW_BUILDER_THREAD_COUNT, &thread_cnt_);
  params.get(PARAM_HNSW_BUILDER_MIN_NEIGHBOR_COUNT, &min_neighbor_cnt_);
  params.get(PARAM_HNSW_BUILDER_EFCONSTRUCTION, &ef_construction_);
  params.get(PARAM_HNSW_BUILDER_CHECK_INTERVAL_SECS, &check_interval_secs_);
  params.get(PARAM_HNSW_BUILDER_MAX_NEIGHBOR_COUNT, &upper_max_neighbor_cnt_);

  // Resolve and validate the upper-layer neighbor count first. The level-0
  // count, the default scaling factor and the prune count below are all
  // derived from it, so an explicit 0 ("use default") must behave exactly
  // like the parameter being absent.
  if (upper_max_neighbor_cnt_ == 0) {
    upper_max_neighbor_cnt_ = HnswEntity::kDefaultUpperMaxNeighborCnt;
  }
  if (upper_max_neighbor_cnt_ > kMaxNeighborCnt) {
    LOG_ERROR("[%s] must be in range (0,%d]",
              PARAM_HNSW_BUILDER_MAX_NEIGHBOR_COUNT.c_str(), kMaxNeighborCnt);
    return IndexError_InvalidArgument;
  }

  float multiplier = HnswEntity::kDefaultL0MaxNeighborCntMultiplier;
  params.get(PARAM_HNSW_BUILDER_L0_MAX_NEIGHBOR_COUNT_MULTIPLIER, &multiplier);
  l0_max_neighbor_cnt_ = multiplier * upper_max_neighbor_cnt_;
  scaling_factor_ = upper_max_neighbor_cnt_;
  params.get(PARAM_HNSW_BUILDER_SCALING_FACTOR, &scaling_factor_);
  multiplier = HnswEntity::kDefaultNeighborPruneMultiplier;
  params.get(PARAM_HNSW_BUILDER_NEIGHBOR_PRUNE_MULTIPLIER, &multiplier);
  size_t prune_cnt = multiplier * upper_max_neighbor_cnt_;

  if (ef_construction_ == 0) {
    ef_construction_ = HnswEntity::kDefaultEfConstruction;
  }
  if (min_neighbor_cnt_ > upper_max_neighbor_cnt_) {
    LOG_ERROR("[%s]-[%d] must be <= [%s]-[%d]",
              PARAM_HNSW_BUILDER_MIN_NEIGHBOR_COUNT.c_str(), min_neighbor_cnt_,
              PARAM_HNSW_BUILDER_MAX_NEIGHBOR_COUNT.c_str(),
              upper_max_neighbor_cnt_);
    return IndexError_InvalidArgument;
  }
  if (l0_max_neighbor_cnt_ == 0) {
    // Fall back to the level-0 default (the same value cleanup() restores),
    // not the upper-layer default.
    l0_max_neighbor_cnt_ = HnswEntity::kDefaultL0MaxNeighborCnt;
  }
  if (l0_max_neighbor_cnt_ > HnswEntity::kMaxNeighborCnt) {
    // The check is inclusive of kMaxNeighborCnt, hence the closed bracket.
    LOG_ERROR("L0MaxNeighborCnt must be in range (0,%d]",
              HnswEntity::kMaxNeighborCnt);
    return IndexError_InvalidArgument;
  }
  if (scaling_factor_ == 0U) {
    scaling_factor_ = HnswEntity::kDefaultScalingFactor;
  }
  if (scaling_factor_ < 5 || scaling_factor_ > 1000) {
    LOG_ERROR("[%s] must be in range [5,1000]",
              PARAM_HNSW_BUILDER_SCALING_FACTOR.c_str());
    return IndexError_InvalidArgument;
  }
  if (thread_cnt_ == 0) {
    thread_cnt_ = std::thread::hardware_concurrency();
    // hardware_concurrency() may return 0 when it cannot be determined; a
    // zero-thread pool would "build" without linking any node.
    if (thread_cnt_ == 0) {
      thread_cnt_ = 1;
    }
  }
  if (thread_cnt_ > std::thread::hardware_concurrency()) {
    LOG_WARN("[%s] greater than cpu cores %u",
             PARAM_HNSW_BUILDER_THREAD_COUNT.c_str(),
             std::thread::hardware_concurrency());
  }
  if (prune_cnt == 0UL) {
    prune_cnt = upper_max_neighbor_cnt_;
  }

  metric_ = IndexFactory::CreateMetric(meta_.metric_name());
  if (!metric_) {
    LOG_ERROR("CreateMetric failed, name: %s", meta_.metric_name().c_str());
    return IndexError_NoExist;
  }
  int ret = metric_->init(meta_, meta_.metric_params());
  if (ret != 0) {
    LOG_ERROR("IndexMetric init failed, ret=%d", ret);
    return ret;
  }

  entity_.set_vector_size(meta_.element_size());
  entity_.set_ef_construction(ef_construction_);
  entity_.set_l0_neighbor_cnt(l0_max_neighbor_cnt_);
  entity_.set_min_neighbor_cnt(min_neighbor_cnt_);
  entity_.set_upper_neighbor_cnt(upper_max_neighbor_cnt_);
  entity_.set_scaling_factor(scaling_factor_);
  entity_.set_memory_quota(memory_quota);
  entity_.set_prune_cnt(prune_cnt);
  ret = entity_.init();
  if (ret != 0) {
    return ret;
  }

  alg_ = HnswAlgorithm::UPointer(new HnswAlgorithm(entity_));
  ret = alg_->init();
  if (ret != 0) {
    return ret;
  }

  state_ = BUILD_STATE_INITED;
  LOG_INFO(
      "End HnswBuilder::init, params: vectorSize=%u efConstruction=%u "
      "l0NeighborCnt=%u upperNeighborCnt=%u scalingFactor=%u "
      "memoryQuota=%zu neighborPruneCnt=%zu metricName=%s ",
      meta_.element_size(), ef_construction_, l0_max_neighbor_cnt_,
      upper_max_neighbor_cnt_, scaling_factor_, memory_quota, prune_cnt,
      meta_.metric_name().c_str());
  return 0;
}

//! Reset the builder to its pre-init state, releasing entity memory,
//! the metric and the statistics. Safe to call even if init() never ran.
int HnswBuilder::cleanup(void) {
  LOG_INFO("Begin HnswBuilder::cleanup");

  l0_max_neighbor_cnt_ = HnswEntity::kDefaultL0MaxNeighborCnt;
  min_neighbor_cnt_ = 0;
  upper_max_neighbor_cnt_ = HnswEntity::kDefaultUpperMaxNeighborCnt;
  ef_construction_ = HnswEntity::kDefaultEfConstruction;
  scaling_factor_ = HnswEntity::kDefaultScalingFactor;
  check_interval_secs_ = kDefaultLogIntervalSecs;
  // Restore the "auto-detect" default so a later init() without an explicit
  // thread count does not silently reuse the previous value.
  thread_cnt_ = 0;
  errcode_ = 0;
  error_ = false;
  entity_.cleanup();
  // alg_ only exists after a successful init().
  if (alg_) {
    alg_->cleanup();
  }
  meta_.clear();
  metric_.reset();
  stats_.clear_attributes();
  stats_.set_trained_count(0UL);
  stats_.set_built_count(0UL);
  stats_.set_dumped_count(0UL);
  stats_.set_discarded_count(0UL);
  stats_.set_trained_costtime(0UL);
  stats_.set_built_costtime(0UL);
  stats_.set_dumped_costtime(0UL);
  state_ = BUILD_STATE_INIT;

  LOG_INFO("End HnswBuilder::cleanup");
  return 0;
}

//! Train the metric (if it supports training) over every vector in holder.
//! Requires init(); moves the builder to the TRAINED state on success.
int HnswBuilder::train(IndexThreads::Pointer, IndexHolder::Pointer holder) {
  if (state_ != BUILD_STATE_INITED) {
    LOG_ERROR("Init the builder before HnswBuilder::train");
    return IndexError_NoReady;
  }
  if (!holder) {
    LOG_ERROR("Input holder is nullptr while training index");
    return IndexError_InvalidArgument;
  }
  if (!holder->is_matched(meta_)) {
    LOG_ERROR("Input holder doesn't match index meta while training index");
    return IndexError_Mismatch;
  }

  LOG_INFO("Begin HnswBuilder::train");
  size_t trained_cost_time = 0;
  size_t trained_count = 0;
  if (metric_->support_train()) {
    auto start_time = ailego::Monotime::MilliSeconds();
    auto iter = holder->create_iterator();
    if (!iter) {
      LOG_ERROR("Create iterator for holder failed");
      return IndexError_Runtime;
    }
    while (iter->is_valid()) {
      int ret = metric_->train(iter->data(), meta_.dimension());
      if (ailego_unlikely(ret != 0)) {
        LOG_ERROR("Hnsw build measure train failed, ret=%d", ret);
        return ret;
      }
      iter->next();
      ++trained_count;
    }
    trained_cost_time = ailego::Monotime::MilliSeconds() - start_time;
  }
  stats_.set_trained_count(trained_count);
  stats_.set_trained_costtime(trained_cost_time);
  state_ = BUILD_STATE_TRAINED;

  LOG_INFO("End HnswBuilder::train");
  return 0;
}

//! Trainer-based training: nothing to learn for HNSW, only advances state.
int HnswBuilder::train(const IndexTrainer::Pointer & /*trainer*/) {
  if (state_ != BUILD_STATE_INITED) {
    LOG_ERROR("Init the builder before HnswBuilder::train");
    return IndexError_NoReady;
  }

  LOG_INFO("Begin HnswBuilder::train by trainer");
  stats_.set_trained_count(0UL);
  stats_.set_trained_costtime(0UL);
  state_ = BUILD_STATE_TRAINED;
  LOG_INFO("End HnswBuilder::train by trainer");
  return 0;
}

//! Build the graph: first copy every vector from holder into the entity
//! (assigning each a random level), then link nodes in parallel on threads.
//! Requires train(); moves the builder to the BUILT state on success.
int HnswBuilder::build(IndexThreads::Pointer threads,
                       IndexHolder::Pointer holder) {
  if (state_ != BUILD_STATE_TRAINED) {
    LOG_ERROR("Train the index before HnswBuilder::build");
    return IndexError_NoReady;
  }
  if (!holder) {
    LOG_ERROR("Input holder is nullptr while building index");
    return IndexError_InvalidArgument;
  }
  if (!holder->is_matched(meta_)) {
    LOG_ERROR("Input holder doesn't match index meta while building index");
    return IndexError_Mismatch;
  }
  if (!threads) {
    threads = std::make_shared(thread_cnt_, false);
    if (!threads) {
      return IndexError_NoMemory;
    }
  }

  auto start_time = ailego::Monotime::MilliSeconds();
  LOG_INFO("Begin HnswBuilder::build");

  if (holder->count() != static_cast(-1)) {
    LOG_DEBUG("HnswBuilder holder documents count %lu", holder->count());
    int ret = entity_.reserve_space(holder->count());
    if (ret != 0) {
      LOG_ERROR("HnswBuilde reserver space failed");
      return ret;
    }
  }
  auto iter = holder->create_iterator();
  if (!iter) {
    LOG_ERROR("Create iterator for holder failed");
    return IndexError_Runtime;
  }
  int ret;
  error_ = false;
  errcode_ = 0;
  while (iter->is_valid()) {
    level_t level = alg_->get_random_level();
    node_id_t id;

    const void *vec = iter->data();
    ret = entity_.add_vector(level, iter->key(), vec, &id);
    if (ailego_unlikely(ret != 0)) {
      return ret;
    }
    iter->next();
  }
  // Holder is not needed, cleanup it.
  holder.reset();

  LOG_INFO("Finished save vector, start build graph...");
  auto task_group = threads->make_group();
  if (!task_group) {
    LOG_ERROR("Failed to create task group");
    return IndexError_Runtime;
  }

  std::atomic finished{0};
  for (size_t i = 0; i < threads->count(); ++i) {
    task_group->submit(ailego::Closure ::New(this, &HnswBuilder::do_build, i,
                                             threads->count(), &finished));
  }
  while (!task_group->is_finished()) {
    std::unique_lock lk(mutex_);
    cond_.wait_until(lk, std::chrono::system_clock::now() +
                             std::chrono::seconds(check_interval_secs_));
    if (error_.load(std::memory_order_acquire)) {
      // Leave the loop (releasing mutex_) instead of returning: the workers
      // must be joined first, see below.
      break;
    }
    LOG_INFO("Built cnt %u, finished percent %.3f%%", finished.load(),
             finished.load() * 100.0f / entity_.doc_cnt());
  }

  // Always join the workers before leaving this function. They write through
  // a pointer to the local `finished` counter, and a failing worker stores
  // errcode_ only after raising error_, so errcode_ is only reliable once
  // every task has completed.
  task_group->wait_finish();
  if (error_.load(std::memory_order_acquire)) {
    LOG_ERROR("Failed to build index while waiting finish");
    return errcode_;
  }

  stats_.set_built_count(finished.load());
  stats_.set_built_costtime(ailego::Monotime::MilliSeconds() - start_time);

  state_ = BUILD_STATE_BUILT;
  LOG_INFO("End HnswBuilder::build");
  return 0;
}

//! Worker body: links nodes idx, idx + step_size, idx + 2*step_size, ...
//! into the graph. The first failing worker records errcode_ and raises
//! error_; every worker stops early once error_ is set. Always notifies
//! cond_ on exit so build() can re-check progress promptly.
void HnswBuilder::do_build(node_id_t idx, size_t step_size,
                           std::atomic *finished) {
  AILEGO_DEFER([&]() {
    std::lock_guard latch(mutex_);
    cond_.notify_one();
  });
  // The entity is owned by the builder, hence the no-op deleter.
  HnswContext *ctx = new (std::nothrow)
      HnswContext(meta_.dimension(), metric_,
                  std::shared_ptr(&entity_, [](HnswEntity *) {}));
  if (ailego_unlikely(ctx == nullptr)) {
    if (!error_.exchange(true)) {
      LOG_ERROR("Failed to create context");
      errcode_ = IndexError_NoMemory;
    }
    return;
  }
  HnswContext::Pointer auto_ptr(ctx);
  ctx->set_max_scan_num(entity_.doc_cnt());
  int ret = ctx->init(HnswContext::kBuilderContext);
  if (ret != 0) {
    if (!error_.exchange(true)) {
      LOG_ERROR("Failed to init context");
      errcode_ = IndexError_Runtime;
    }
    return;
  }

  for (node_id_t id = idx; id < entity_.doc_cnt(); id += step_size) {
    // Another worker already failed; the build result is discarded anyway.
    if (error_.load(std::memory_order_relaxed)) {
      return;
    }
    ctx->reset_query(entity_.get_vector(id));
    ret = alg_->add_node(id, entity_.get_level(id), ctx);
    if (ailego_unlikely(ret != 0)) {
      if (!error_.exchange(true)) {
        LOG_ERROR("Hnsw graph add node failed");
        errcode_ = ret;
      }
      return;
    }
    ctx->clear();
    (*finished)++;
  }
}

//! Serialize the index meta and the built graph into dumper.
//! Requires build(); returns 0 on success or an IndexError code.
int HnswBuilder::dump(const IndexDumper::Pointer &dumper) {
  if (state_ != BUILD_STATE_BUILT) {
    LOG_ERROR("Build the index before HnswBuilder::dump");
    return IndexError_NoReady;
  }
  if (!dumper) {
    LOG_ERROR("Input dumper is nullptr while dumping index");
    return IndexError_InvalidArgument;
  }

  LOG_INFO("Begin HnswBuilder::dump");

  meta_.set_searcher("HnswSearcher", HnswEntity::kRevision, ailego::Params());
  auto start_time = ailego::Monotime::MilliSeconds();

  int ret = IndexHelper::SerializeToDumper(meta_, dumper.get());
  if (ret != 0) {
    LOG_ERROR("Failed to serialize meta into dumper.");
    return ret;
  }

  ret = entity_.dump(dumper);
  if (ret != 0) {
    LOG_ERROR("HnswBuilder dump index failed");
    return ret;
  }

  stats_.set_dumped_count(entity_.doc_cnt());
  stats_.set_dumped_costtime(ailego::Monotime::MilliSeconds() - start_time);

  LOG_INFO("EndHnswBuilder::dump");
  return 0;
}

INDEX_FACTORY_REGISTER_BUILDER(HnswBuilder);

}  // namespace core
}  // namespace zvec
================================================
FILE: src/core/algorithm/hnsw/hnsw_builder.h
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include
#include
#include "hnsw_algorithm.h"
#include "hnsw_builder_entity.h"

namespace zvec {
namespace core {

//! Offline HNSW builder: init -> train -> build -> dump. The dumped meta
//! names "HnswSearcher" as the searcher for the produced index.
class HnswBuilder : public IndexBuilder {
 public:
  //! Constructor
  HnswBuilder();

  //! Initialize the builder
  virtual int init(const IndexMeta &meta,
                   const ailego::Params &params) override;

  //! Cleanup the builder
  virtual int cleanup(void) override;

  //! Train the data
  virtual int train(IndexThreads::Pointer,
                    IndexHolder::Pointer holder) override;

  //! Train the data
  virtual int train(const IndexTrainer::Pointer &trainer) override;

  //! Build the index
  virtual int build(IndexThreads::Pointer threads,
                    IndexHolder::Pointer holder) override;

  //! Dump index into storage
  virtual int dump(const IndexDumper::Pointer &dumper) override;

  //! Retrieve statistics
  virtual const Stats &stats(void) const override {
    return stats_;
  }

 private:
  //! Worker task: links every step_size-th node starting at idx.
  void do_build(node_id_t idx, size_t step_size,
                std::atomic *finished);

  constexpr static uint32_t kDefaultLogIntervalSecs = 15U;
  constexpr static uint32_t kMaxNeighborCnt = 65535;

 private:
  enum BUILD_STATE {
    BUILD_STATE_INIT = 0,
    BUILD_STATE_INITED = 1,
    BUILD_STATE_TRAINED = 2,
    BUILD_STATE_BUILT = 3
  };

  HnswBuilderEntity entity_{};
  HnswAlgorithm::UPointer alg_;  // impl graph algorithm

  uint32_t thread_cnt_{0};
  uint32_t min_neighbor_cnt_{0};
  uint32_t upper_max_neighbor_cnt_{HnswEntity::kDefaultUpperMaxNeighborCnt};
  uint32_t l0_max_neighbor_cnt_{HnswEntity::kDefaultL0MaxNeighborCnt};
  uint32_t ef_construction_{HnswEntity::kDefaultEfConstruction};
  uint32_t scaling_factor_{HnswEntity::kDefaultScalingFactor};
  uint32_t check_interval_secs_{kDefaultLogIntervalSecs};

  // First worker error; written by the worker that wins error_.exchange().
  int errcode_{0};
  std::atomic_bool error_{false};

  IndexMeta meta_{};
  IndexMetric::Pointer metric_{};

  // Used by workers to wake build()'s progress loop.
  std::mutex mutex_{};
  std::condition_variable cond_{};
  Stats stats_{};
  BUILD_STATE state_{BUILD_STATE_INIT};
};

}  // namespace core
}  // namespace zvec
================================================
FILE: src/core/algorithm/hnsw/hnsw_builder_entity.cc
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "hnsw_builder_entity.h"
#include
#include
#include "utility/sparse_utility.h"

namespace zvec {
namespace core {

HnswBuilderEntity::HnswBuilderEntity() {
  update_ep_and_level(kInvalidNodeId, 0U);
}

//! Release every buffer and reset the sizes, then clean up the base entity.
int HnswBuilderEntity::cleanup() {
  memory_quota_ = 0UL;
  neighbors_size_ = 0U;
  upper_neighbors_size_ = 0U;
  padding_size_ = 0U;
  sparse_data_offset_ = 0U;
  vectors_buffer_.clear();
  keys_buffer_.clear();
  neighbors_buffer_.clear();
  upper_neighbors_buffer_.clear();
  sparse_data_buffer_.clear();
  neighbors_index_.clear();
  vectors_buffer_.shrink_to_fit();
  keys_buffer_.shrink_to_fit();
  neighbors_buffer_.shrink_to_fit();
  upper_neighbors_buffer_.shrink_to_fit();
  sparse_data_buffer_.shrink_to_fit();
  neighbors_index_.shrink_to_fit();

  this->HnswEntity::cleanup();

  return 0;
}

//! Compute per-node sizes from the configured vector size and neighbor counts.
int HnswBuilderEntity::init() {
  size_t size = vector_size();

  //! aligned size to 32
  set_node_size(AlignSize(size));
  //! if node size is aligned to 1k, the build performance will downgrade
  if (node_size() % 1024 == 0) {
    set_node_size(AlignSize(node_size() + 1));
  }

  padding_size_ = node_size() - size;

  neighbors_size_ = neighbors_size();
  upper_neighbors_size_ = upper_neighbors_size();

  return 0;
}

//! Reserve buffer space for docs documents.
//! @return IndexError_NoMemory if the estimate exceeds a non-zero quota.
int HnswBuilderEntity::reserve_space(size_t docs) {
  // Account for the same per-document buffers that add_vector() includes in
  // its quota check (vector, key, level-0 neighbors, neighbor index).
  const size_t per_doc_size = node_size() + sizeof(key_t) + neighbors_size_ +
                              sizeof(NeighborIndex);
  if (memory_quota_ > 0 && per_doc_size * docs > memory_quota_) {
    return IndexError_NoMemory;
  }

  vectors_buffer_.reserve(node_size() * docs);
  keys_buffer_.reserve(sizeof(key_t) * docs);
  neighbors_buffer_.reserve(neighbors_size_ * docs);
  neighbors_index_.reserve(docs);

  return 0;
}

//! Append a vector with its key and zeroed neighbor lists for levels
//! [0, level]; the new local id is written to *id.
int HnswBuilderEntity::add_vector(level_t level, key_t key, const void *vec,
                                  node_id_t *id) {
  if (memory_quota_ > 0 &&
      (vectors_buffer_.capacity() + keys_buffer_.capacity() +
       neighbors_buffer_.capacity() + upper_neighbors_buffer_.capacity() +
       neighbors_index_.capacity() * sizeof(NeighborIndex)) > memory_quota_) {
    LOG_ERROR("Add vector failed, used memory exceed quota, cur_doc=%u",
              doc_cnt());
    return IndexError_NoMemory;
  }

  vectors_buffer_.append(reinterpret_cast(vec), vector_size());
  vectors_buffer_.append(padding_size_, '\0');

  keys_buffer_.append(reinterpret_cast(&key), sizeof(key));

  // init level 0 neighbors
  neighbors_buffer_.append(neighbors_size_, '\0');

  neighbors_index_.emplace_back(upper_neighbors_buffer_.size(), level);

  // init upper layer neighbors
  for (level_t cur_level = 1; cur_level <= level; ++cur_level) {
    upper_neighbors_buffer_.append(upper_neighbors_size_, '\0');
  }

  *id = (*mutable_doc_cnt())++;

  return 0;
}

key_t HnswBuilderEntity::get_key(node_id_t id) const {
  return *(reinterpret_cast(keys_buffer_.data() +
                            id * sizeof(key_t)));
}

const void *HnswBuilderEntity::get_vector(node_id_t id) const {
  return vectors_buffer_.data() + id * node_size();
}

int HnswBuilderEntity::get_vector(const node_id_t id,
                                  IndexStorage::MemoryBlock &block) const {
  const void *vec = get_vector(id);
  block.reset((void *)vec);
  return 0;
}

int HnswBuilderEntity::get_vector(const node_id_t *ids, uint32_t count,
                                  const void **vecs) const {
  for (uint32_t i = 0; i < count; ++i) {
    vecs[i] = vectors_buffer_.data() + ids[i] * node_size();
  }
  return 0;
}

//! Append one non-owning MemoryBlock per id to vec_blocks.
int HnswBuilderEntity::get_vector(
    const node_id_t *ids, uint32_t count,
    std::vector &vec_blocks) const {
  // Resolve each id directly instead of staging pointers in a stack
  // variable-length array: VLAs are not standard C++ and a large count
  // could overflow the stack.
  vec_blocks.reserve(vec_blocks.size() + count);
  for (uint32_t i = 0; i < count; ++i) {
    const void *vec = vectors_buffer_.data() + ids[i] * node_size();
    vec_blocks.emplace_back(IndexStorage::MemoryBlock(const_cast<void *>(vec)));
  }
  return 0;
}

const Neighbors HnswBuilderEntity::get_neighbors(level_t level,
                                                 node_id_t id) const {
  const NeighborsHeader *hd = get_neighbor_header(level, id);
  return {hd->neighbor_cnt, hd->neighbors};
}

//! Overwrite the neighbor list of id on level with the ids in neighbors.
int HnswBuilderEntity::update_neighbors(
    level_t level, node_id_t id,
    const std::vector> &neighbors) {
  NeighborsHeader *hd =
      const_cast(get_neighbor_header(level, id));
  for (size_t i = 0; i < neighbors.size(); ++i) {
    hd->neighbors[i] = neighbors[i].first;
  }
  hd->neighbor_cnt = neighbors.size();

  return 0;
}

//! Append neighbor_id to the neighbor list of id on level.
//! NOTE(review): the size argument is ignored and no capacity check is made;
//! callers are presumably responsible for not exceeding the list capacity.
void HnswBuilderEntity::add_neighbor(level_t level, node_id_t id,
                                     uint32_t /*size*/,
                                     node_id_t neighbor_id) {
  NeighborsHeader *hd =
      const_cast(get_neighbor_header(level, id));
  hd->neighbors[hd->neighbor_cnt++] = neighbor_id;

  return;
}

int HnswBuilderEntity::dump(const IndexDumper::Pointer &dumper) {
  key_t *keys =
      reinterpret_cast(const_cast(keys_buffer_.data()));
  auto ret =
      dump_segments(dumper, keys, [&](node_id_t id) { return get_level(id); });
  if (ailego_unlikely(ret < 0)) {
    return ret;
  }

  return 0;
}

}  // namespace core
}  // namespace zvec
================================================
FILE: src/core/algorithm/hnsw/hnsw_builder_entity.h
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include
#include "hnsw_entity.h"

namespace zvec {
namespace core {

//! In-memory HNSW entity used while building: vectors, keys and neighbor
//! lists live in contiguous std::string buffers indexed by local node id.
class HnswBuilderEntity : public HnswEntity {
 public:
  //! Add vector and key to hnsw entity, and local id will be saved to id
  virtual int add_vector(level_t level, key_t key, const void *vec,
                         node_id_t *id) override;

  //! Get primary key of the node id
  virtual key_t get_key(node_id_t id) const override;

  //! Get vector feature data by local node id
  virtual const void *get_vector(node_id_t id) const override;

  //! Batch get vectors feature data by local node ids
  virtual int get_vector(const node_id_t *ids, uint32_t count,
                         const void **vecs) const override;

  //! Get vector feature data of id as a (non-owning) memory block
  virtual int get_vector(const node_id_t id,
                         IndexStorage::MemoryBlock &block) const override;

  //! Batch get vectors as memory blocks appended to vec_blocks
  virtual int get_vector(
      const node_id_t *ids, uint32_t count,
      std::vector &vec_blocks) const override;

  //! Get the neighbor header (count + ids) of node id on graph level.
  //! Level 0 lives in neighbors_buffer_ at a fixed stride; upper levels live
  //! in upper_neighbors_buffer_ starting at the node's recorded offset.
  const NeighborsHeader *get_neighbor_header(level_t level,
                                             node_id_t id) const {
    if (level == 0) {
      return reinterpret_cast(
          neighbors_buffer_.data() + neighbors_size_ * id);
    } else {
      size_t offset = neighbors_index_[id].offset;
      return reinterpret_cast(
          upper_neighbors_buffer_.data() + offset +
          (level - 1) * upper_neighbors_size_);
    }
  }

  //! Get the node id's neighbors on graph level
  virtual const Neighbors get_neighbors(level_t level,
                                        node_id_t id) const override;

  //! Replace node id in level's neighbors
  virtual int update_neighbors(
      level_t level, node_id_t id,
      const std::vector> &neighbors) override;

  //! add a neighbor to id in graph level
  virtual void add_neighbor(level_t level, node_id_t id, uint32_t size,
                            node_id_t neighbor_id) override;

  //! Dump the hnsw graph to dumper
  virtual int dump(const IndexDumper::Pointer &dumper) override;

  //! Cleanup the entity
  virtual int cleanup(void) override;

 public:
  //! Constructor
  HnswBuilderEntity();

  //! Get the node graph level by id
  level_t get_level(node_id_t id) const {
    return neighbors_index_[id].level;
  }

  //! Init the builder entity (computes node, padding and neighbor sizes)
  int init();

  //! reserve buffer space for documents
  //! @param docs number of documents
  int reserve_space(size_t docs);

  //! Set memory quota params (0 disables the quota check)
  inline void set_memory_quota(size_t memory_quota) {
    memory_quota_ = memory_quota;
  }

  //! Get level 0 neighbors size in bytes (header + l0_neighbor_cnt ids)
  inline size_t neighbors_size() const {
    return sizeof(NeighborsHeader) + l0_neighbor_cnt() * sizeof(node_id_t);
  }

  //! Get upper neighbors size in bytes for one upper level
  //! (header + upper_neighbor_cnt ids)
  inline size_t upper_neighbors_size() const {
    return sizeof(NeighborsHeader) + upper_neighbor_cnt() * sizeof(node_id_t);
  }

 public:
  HnswBuilderEntity(const HnswBuilderEntity &) = delete;
  HnswBuilderEntity &operator=(const HnswBuilderEntity &) = delete;

 private:
  friend class HnswSearcherEntity;

  //! class internal used only: per-node byte offset into
  //! upper_neighbors_buffer_ (48 bits) and the node's graph level (16 bits)
  struct NeighborIndex {
    NeighborIndex(size_t off, level_t l) : offset(off), level(l) {}
    uint64_t offset : 48;
    uint64_t level : 16;
  };

  std::string vectors_buffer_{};          // aligned (padded) vectors
  std::string keys_buffer_{};             // primary keys, sizeof(key_t) each
  std::string neighbors_buffer_{};        // level 0 neighbors buffer
  std::string upper_neighbors_buffer_{};  // upper layer neighbors buffer
  std::string sparse_data_buffer_{};      // aligned sparse data buffer
  size_t sparse_data_offset_{0};          //
  // upper layer offset + level in upper_neighbors_buffer_
  std::vector neighbors_index_{};

  size_t memory_quota_{0UL};
  size_t neighbors_size_{0U};        // level 0 neighbors size
  size_t upper_neighbors_size_{0U};  // neighbors size of one upper level
  size_t padding_size_{};  // padding size for each vector element
};

}  // namespace core
}  // namespace zvec
================================================
FILE: src/core/algorithm/hnsw/hnsw_chunk.cc
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "hnsw_chunk.h"
#include
#include
#include
#include
#include
#include
#include
#include

namespace zvec {
namespace core {

//! Create the meta chunk of a brand-new index and persist an initial
//! HnswChunkMeta into it. Returns 0 on success or an IndexError code.
int ChunkBroker::init_storage(size_t chunk_size) {
  chunk_meta_.clear();
  chunk_meta_.chunk_size = chunk_size;
  chunk_meta_.create_time = ailego::Realtime::Seconds();
  stats_.set_create_time(chunk_meta_.create_time);
  chunk_meta_.update_time = ailego::Realtime::Seconds();
  stats_.set_update_time(chunk_meta_.update_time);

  //! alloc meta chunk
  size_t size = sizeof(HnswChunkMeta);
  size = (size + page_mask_) & (~page_mask_);
  const std::string segment_id =
      make_segment_id(CHUNK_TYPE_META, kDefaultChunkSeqId);
  int ret = stg_->append(segment_id, size);
  if (ailego_unlikely(ret != 0)) {
    LOG_ERROR("Storage append segment failed for %s", IndexError::What(ret));
    return ret;
  }
  chunk_meta_segment_ = get_chunk(CHUNK_TYPE_META, kDefaultChunkSeqId);
  if (ailego_unlikely(!chunk_meta_segment_)) {
    LOG_ERROR("Get meta segment failed");
    return IndexError_Runtime;
  }

  //! update meta info and write to storage
  chunk_meta_.chunk_cnts[CHUNK_TYPE_META] += 1;
  chunk_meta_.total_size += size;
  (*stats_.mutable_index_size()) += size;
  size_t wsize =
      chunk_meta_segment_->write(0UL, &chunk_meta_, sizeof(HnswChunkMeta));
  if (ailego_unlikely(wsize != sizeof(HnswChunkMeta))) {
    LOG_ERROR("Storage write data failed, wsize=%zu", wsize);
    return IndexError_WriteData;
  }

  return 0;
}

//! Load HnswChunkMeta from the existing meta chunk and check that the
//! stored chunk size matches chunk_size. Returns 0 or an IndexError code.
int ChunkBroker::load_storage(size_t chunk_size) {
  IndexStorage::MemoryBlock data_block;
  size_t size = chunk_meta_segment_->read(0UL, data_block,
                                          chunk_meta_segment_->data_size());
  if (size != sizeof(HnswChunkMeta)) {
    LOG_ERROR("Invalid hnsw meta chunk, read size=%zu chunk size=%zu", size,
              chunk_meta_segment_->data_size());
    return IndexError_InvalidFormat;
  }
  std::memcpy(&chunk_meta_, data_block.data(), size);
  if (chunk_meta_.chunk_size != chunk_size) {
    LOG_ERROR(
        "Params hnsw chunk size=%zu mismatch from previous %zu "
        "in index",
        chunk_size, (size_t)chunk_meta_.chunk_size);
    return IndexError_Mismatch;
  }

  *stats_.mutable_check_point() = stg_->check_point();
  stats_.set_revision_id(chunk_meta_.revision_id);
  stats_.set_update_time(chunk_meta_.update_time);
  stats_.set_create_time(chunk_meta_.create_time);

  char create_time[32];
  char update_time[32];
  ailego::Realtime::Gmtime(chunk_meta_.create_time, "%Y-%m-%d %H:%M:%S",
                           create_time, sizeof(create_time));
  ailego::Realtime::Gmtime(chunk_meta_.update_time, "%Y-%m-%d %H:%M:%S",
                           update_time, sizeof(update_time));
  LOG_DEBUG(
      "Load index, indexSize=%zu chunkSize=%zu nodeChunks=%zu "
      "upperNeighborChunks=%zu revisionId=%zu "
      "createTime=%s updateTime=%s",
      (size_t)chunk_meta_.total_size, (size_t)chunk_meta_.chunk_size,
      (size_t)chunk_meta_.chunk_cnts[CHUNK_TYPE_NODE],
      (size_t)chunk_meta_.chunk_cnts[CHUNK_TYPE_UPPER_NEIGHBOR],
      (size_t)chunk_meta_.revision_id, create_time, update_time);

  return 0;
}

//! Attach to stg: create a new index if it has no meta chunk, otherwise
//! load the existing one. On failure the broker is left closed, so a later
//! close()/flush() cannot write a half-initialized meta chunk back over the
//! index.
int ChunkBroker::open(IndexStorage::Pointer stg, size_t max_index_size,
                      size_t chunk_size, bool check_crc) {
  if (ailego_unlikely(stg_)) {
    LOG_ERROR("An storage instance is already opened");
    return IndexError_Duplicate;
  }
  if (ailego_unlikely(!stg)) {
    LOG_ERROR("Storage instance is null");
    return IndexError_InvalidArgument;
  }

  stg_ = std::move(stg);
  if (stg_->isHugePage()) {
    page_mask_ = ailego::MemoryHelper::HugePageSize() - 1;
  } else {
    page_mask_ = ailego::MemoryHelper::PageSize() - 1;
  }
  check_crc_ = check_crc;
  max_chunks_size_ = max_index_size;
  dirty_ = false;

  const std::string segment_id =
      make_segment_id(CHUNK_TYPE_META, kDefaultChunkSeqId);
  chunk_meta_segment_ = stg_->get(segment_id);

  int ret = 0;
  if (!chunk_meta_segment_) {
    LOG_DEBUG("Create new index");
    ret = init_storage(chunk_size);
  } else {
    ret = load_storage(chunk_size);
  }
  if (ret != 0) {
    chunk_meta_segment_.reset();
    stg_.reset();
  }
  return ret;
}

//! Flush (if a storage is attached) and detach from the storage.
//! Best-effort: always returns 0, flush errors are only logged.
int ChunkBroker::close(void) {
  // Calling close() on a broker that was never (successfully) opened must
  // not dereference the null storage / meta segment.
  if (stg_ && chunk_meta_segment_) {
    flush(0UL);
  }
  chunk_meta_segment_.reset();
  stg_.reset();
  check_crc_ = false;
  dirty_ = false;

  return 0;
}

//! Persist the meta chunk and flush the storage at checkpoint.
//! Returns the storage flush error, or IndexError_WriteData if the meta
//! chunk could not be written; the checkpoint stat only advances on success.
int ChunkBroker::flush(uint64_t checkpoint) {
  ailego_assert_with(chunk_meta_segment_, "invalid meta segment");

  chunk_meta_.update_time = ailego::Realtime::Seconds();
  stats_.set_update_time(chunk_meta_.update_time);
  size_t size =
      chunk_meta_segment_->write(0UL, &chunk_meta_, sizeof(HnswChunkMeta));
  const bool meta_written = (size == sizeof(HnswChunkMeta));
  if (ailego_unlikely(!meta_written)) {
    LOG_ERROR("Storage write data failed, wsize=%zu", size);
  }

  // Still flush the rest of the storage even if the meta write failed.
  stg_->refresh(checkpoint);
  int ret = stg_->flush();
  if (ret != 0) {
    LOG_ERROR("Storage flush failed for %s", IndexError::What(ret));
    return ret;
  }
  if (ailego_unlikely(!meta_written)) {
    // Do not report success (or advance the checkpoint) with a stale meta.
    return IndexError_WriteData;
  }
  (*stats_.mutable_check_point()) = checkpoint;

  return 0;
}

//! Return the chunk (type, seq_id), appending a new page-aligned segment of
//! size bytes if it does not exist yet. Not thread-safe.
std::pair ChunkBroker::alloc_chunk(int type, uint64_t seq_id,
                                   size_t size) {
  ailego_assert_with(type < CHUNK_TYPE_MAX, "chunk type overflow");
  Chunk::Pointer chunk;
  if (ailego_unlikely(!stg_)) {
    LOG_ERROR("Init storage first");
    return std::make_pair(IndexError_Uninitialized, chunk);
  }

  //! check exist a empty chunk with the same name
  chunk = get_chunk(type, seq_id);
  if (chunk) {
    // NOTE(review): this rejects an *empty* chunk whose capacity equals the
    // requested size, while the message reads like a size-mismatch check;
    // the condition may be inverted. Left unchanged pending confirmation.
    if (ailego_unlikely(chunk->capacity() == size &&
                        chunk->data_size() == 0UL)) {
      LOG_ERROR("Exist invalid chunk size %zu, expect size %zu",
                chunk->capacity(), size);
      chunk.reset();
      return std::make_pair(IndexError_Runtime, chunk);
    }
    return std::make_pair(0, chunk);
  }

  //! align to page size
  size = (size + page_mask_) & (~page_mask_);
  if (ailego_unlikely(chunk_meta_.total_size + size >= max_chunks_size_)) {
    LOG_ERROR("No space to new a chunk, curIndexSize=%zu allocSize=%zu",
              (size_t)chunk_meta_.total_size, size);
    return std::make_pair(IndexError_IndexFull, chunk);
  }

  std::string segment_id = make_segment_id(type, seq_id);
  int ret = stg_->append(segment_id, size);
  if (ailego_unlikely(ret != 0)) {
    LOG_ERROR("Storage append segment failed for %s", IndexError::What(ret));
    return std::make_pair(ret, chunk);
  }

  chunk_meta_.chunk_cnts[type] += 1;
  chunk_meta_.total_size += size;
  (*stats_.mutable_index_size()) += size;
  size_t wsize =
      chunk_meta_segment_->write(0UL, &chunk_meta_, sizeof(HnswChunkMeta));
  if (ailego_unlikely(wsize != sizeof(HnswChunkMeta))) {
    LOG_ERROR("Storage append segment failed, wsize=%zu", wsize);
  }

  chunk = get_chunk(type, seq_id);
  return std::make_pair(chunk ? 0 : IndexError_NoMemory, chunk);
}

Chunk::Pointer ChunkBroker::get_chunk(int type, uint64_t seq_id) const {
  std::string segment_id = make_segment_id(type, seq_id);
  return stg_->get(segment_id);
}

}  // namespace core
}  // namespace zvec
================================================
FILE: src/core/algorithm/hnsw/hnsw_chunk.h
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include

namespace zvec {
namespace core {

using Chunk = IndexStorage::Segment;

//! Manages the storage segments ("chunks") of a streaming HNSW index and
//! the HnswChunkMeta bookkeeping record kept in the meta chunk.
class ChunkBroker {
 public:
  typedef std::shared_ptr Pointer;

  enum CHUNK_TYPE {
    CHUNK_TYPE_HEADER = 1,
    CHUNK_TYPE_META = 2,
    CHUNK_TYPE_NODE = 3,
    CHUNK_TYPE_UPPER_NEIGHBOR = 4,
    CHUNK_TYPE_NEIGHBOR_INDEX = 5,
    CHUNK_TYPE_SPARSE_NODE = 6,
    CHUNK_TYPE_MAX = 8
  };
  static constexpr size_t kDefaultChunkSeqId = 0UL;

  ChunkBroker(IndexStreamer::Stats &stats) : stats_(stats) {}

  //! Open storage (creates a new index or loads an existing one)
  int open(IndexStorage::Pointer stg, size_t max_index_size, size_t chunk_size,
           bool check_crc);

  //! Flush and detach from the storage
  int close(void);

  //! Persist the meta chunk and flush the storage at checkpoint
  int flush(uint64_t checkpoint);

  //! alloc a new chunk with size, not thread-safe
  std::pair alloc_chunk(int type, uint64_t seq_id,
                                     size_t size);

  //! alloc a new chunk with chunk size
  inline std::pair alloc_chunk(int type, uint64_t seq_id) {
    return alloc_chunk(type, seq_id, chunk_meta_.chunk_size);
  }

  Chunk::Pointer get_chunk(int type, uint64_t seq_id) const;

  inline size_t get_chunk_cnt(int type) const {
    ailego_assert_with(type < CHUNK_TYPE_MAX, "chunk type overflow");
    return chunk_meta_.chunk_cnts[type];
  }

  inline bool dirty(void) const {
    return dirty_;
  }

  //! Mark the index modified; bumps the revision once per dirty period
  inline void mark_dirty(void) {
    if (!dirty_) {
      dirty_ = true;
      chunk_meta_.revision_id += 1;
      stats_.set_revision_id(chunk_meta_.revision_id);
    }
  }

  const IndexStorage::Pointer storage(void) const {
    return stg_;
  }

 private:
  ChunkBroker(const ChunkBroker &) = delete;
  ChunkBroker &operator=(const ChunkBroker &) = delete;

  //! Bookkeeping record stored at offset 0 of the meta chunk
  struct HnswChunkMeta {
    HnswChunkMeta(void) {
      memset(this, 0, sizeof(HnswChunkMeta));
    }
    void clear() {
      memset(this, 0, sizeof(HnswChunkMeta));
    }

    uint64_t chunk_cnts[CHUNK_TYPE_MAX];
    uint64_t chunk_size;   // size of per chunk
    uint64_t total_size;   // total size of allocated chunk
    uint64_t revision_id;  // index revision
    uint64_t create_time;
    uint64_t update_time;
    uint64_t reserved[3];
  };

  static_assert(sizeof(HnswChunkMeta) % 32 == 0,
                "HnswChunkMeta must be aligned with 32 bytes");

  //! Init the storage after open an empty index
  int init_storage(size_t chunk_size);

  //! Load index from storage
  int load_storage(size_t chunk_size);

  static inline const std::string make_segment_id(int type, uint64_t seq_id) {
    return "HnswT" + ailego::StringHelper::ToString(type) + "S" +
           ailego::StringHelper::ToString(seq_id);
  }

 private:
  IndexStreamer::Stats &stats_;
  HnswChunkMeta chunk_meta_{};
  size_t page_mask_{0UL};
  size_t max_chunks_size_{0UL};
  IndexStorage::Pointer stg_{};
  IndexStorage::Segment::Pointer chunk_meta_segment_{};
  bool check_crc_{false};
  bool dirty_{false};  // set as true if index is modified , the flag
                       // will not be cleared even if flushed
};

}  // namespace core
}  // namespace zvec
================================================
FILE: src/core/algorithm/hnsw/hnsw_context.cc
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "hnsw_context.h"
#include
#include "hnsw_params.h"

namespace zvec {
namespace core {

HnswContext::HnswContext(size_t dimension, const IndexMetric::Pointer &metric,
                         const HnswEntity::Pointer &entity)
    : IndexContext(metric),
      entity_(entity),
      dc_(entity_.get(), metric, dimension) {}

HnswContext::HnswContext(const IndexMetric::Pointer &metric,
                         const HnswEntity::Pointer &entity)
    : IndexContext(metric), entity_(entity), dc_(entity_.get(), metric) {}

HnswContext::~HnswContext() {
  visit_filter_.destroy();
}

//! Initialize heaps and the visit filter for the given context type.
//! Returns 0 on success, the visit filter's error code if it cannot be
//! created, or IndexError_Runtime for an unknown type.
int HnswContext::init(ContextType type) {
  int ret;
  uint32_t doc_cnt;
  type_ = type;
  switch (type) {
    case kBuilderContext:
      // The builder always uses a byte map, independent of filter_mode_.
      ret = visit_filter_.init(VisitFilter::ByteMap, entity_->doc_cnt(),
                               max_scan_num_, negative_probability_);
      if (ret != 0) {
        LOG_ERROR("Create filter failed, mode %d", filter_mode_);
        return ret;
      }
      candidates_.limit(max_scan_num_);
      update_heap_.limit(entity_->l0_neighbor_cnt() + 1);
      break;

    case kSearcherContext:
      ret = visit_filter_.init(filter_mode_, entity_->doc_cnt(), max_scan_num_,
                               negative_probability_);
      if (ret != 0) {
        LOG_ERROR("Create filter failed, mode %d", filter_mode_);
        return ret;
      }
      candidates_.limit(max_scan_num_);
      break;

    case kStreamerContext:
      // maxScanNum is unknown if inited from streamer, so the docCnt may
      // change. we need to compute maxScanNum by scan ratio, and preserve
      // max_doc_cnt space from visit filter
      doc_cnt = entity_->doc_cnt();
      max_scan_num_ = compute_max_scan_num(doc_cnt);
      reserve_max_doc_cnt_ = doc_cnt + compute_reserve_cnt(doc_cnt);
      ret = visit_filter_.init(filter_mode_, reserve_max_doc_cnt_,
                               max_scan_num_, negative_probability_);
      if (ret != 0) {
        LOG_ERROR("Create filter failed, mode %d", filter_mode_);
        return ret;
      }
      update_heap_.limit(entity_->l0_neighbor_cnt() + 1);
      candidates_.limit(max_scan_num_);
      check_need_adjuct_ctx();
      break;

    default:
      LOG_ERROR("Init context failed");
      return IndexError_Runtime;
  }
  return 0;
}

//! Apply per-query parameters (ef, scan ratio/limits, brute-force threshold,
//! bloom-filter switch and negative probability) for searcher/streamer
//! contexts. The visit filter is re-created only when its mode or its
//! bloom-filter negative probability actually changes.
//! Returns 0 on success, IndexError_InvalidArgument for out-of-range scan
//! params, a visit-filter error code, or IndexError_Runtime for other types.
int HnswContext::update(const ailego::Params &params) {
  auto update_visit_filter_param = [&]() {
    bool need_update = false;
    std::string p;
    switch (type_) {
      case kSearcherContext:
        p = PARAM_HNSW_SEARCHER_VISIT_BLOOMFILTER_ENABLE;
        break;
      case kStreamerContext:
        p = PARAM_HNSW_STREAMER_VISIT_BLOOMFILTER_ENABLE;
        break;
    }
    if (params.has(p)) {
      bool bf_enabled;
      params.get(p, &bf_enabled);
      // Re-create only when the requested mode differs from the current one.
      if (bf_enabled ^ (filter_mode_ == VisitFilter::BloomFilter)) {
        need_update = true;
        filter_mode_ =
            bf_enabled ? VisitFilter::BloomFilter : VisitFilter::ByteMap;
      }
    }

    float prob = negative_probability_;
    p.clear();
    switch (type_) {
      case kSearcherContext:
        p = PARAM_HNSW_SEARCHER_VISIT_BLOOMFILTER_NEGATIVE_PROB;
        break;
      case kStreamerContext:
        p = PARAM_HNSW_STREAMER_VISIT_BLOOMFILTER_NEGATIVE_PROB;
        break;
    }
    params.get(p, &prob);
    if (filter_mode_ == VisitFilter::BloomFilter &&
        std::abs(prob - negative_probability_) > 1e-6) {
      need_update = true;
      // Store the new probability; otherwise the filter below would be
      // re-created with the old value and the change silently dropped.
      negative_probability_ = prob;
    }

    if (need_update) {
      visit_filter_.destroy();
      int max_doc_cnt = 0;
      if (type_ == kSearcherContext) {
        max_doc_cnt = entity_->doc_cnt();
      } else {
        max_doc_cnt = reserve_max_doc_cnt_;
      }
      int ret = visit_filter_.init(filter_mode_, max_doc_cnt, max_scan_num_,
                                   negative_probability_);
      if (ret != 0) {
        LOG_ERROR("Create filter failed, mode %d", filter_mode_);
        return ret;
      }
    }
    return 0;
  };

  switch (type_) {
    case kSearcherContext:
      if (params.has(PARAM_HNSW_SEARCHER_EF)) {
        params.get(PARAM_HNSW_SEARCHER_EF, &ef_);
        topk_heap_.limit(std::max(topk_, ef_));
      }
      if (params.has(PARAM_HNSW_SEARCHER_MAX_SCAN_RATIO)) {
        params.get(PARAM_HNSW_SEARCHER_MAX_SCAN_RATIO, &max_scan_ratio_);
        max_scan_num_ = static_cast(max_scan_ratio_ * entity_->doc_cnt());
        max_scan_num_ = std::max(10000U, max_scan_num_);
      }
      if (params.has(PARAM_HNSW_SEARCHER_BRUTE_FORCE_THRESHOLD)) {
        params.get(PARAM_HNSW_SEARCHER_BRUTE_FORCE_THRESHOLD,
                   &bruteforce_threshold_);
      }
      return update_visit_filter_param();

    case kStreamerContext:
      if (params.has(PARAM_HNSW_STREAMER_EF)) {
        params.get(PARAM_HNSW_STREAMER_EF, &ef_);
        topk_heap_.limit(std::max(topk_, ef_));
      }
      params.get(PARAM_HNSW_STREAMER_EF, &ef_);
      params.get(PARAM_HNSW_STREAMER_MAX_SCAN_RATIO, &max_scan_ratio_);
      params.get(PARAM_HNSW_STREAMER_MAX_SCAN_LIMIT, &max_scan_limit_);
      params.get(PARAM_HNSW_STREAMER_MIN_SCAN_LIMIT, &min_scan_limit_);
      if (max_scan_ratio_ <= 0.0f || max_scan_ratio_ > 1.0f) {
        LOG_ERROR("[%s] must be in range (0.0f,1.0f]",
                  PARAM_HNSW_STREAMER_MAX_SCAN_RATIO.c_str());
        return IndexError_InvalidArgument;
      }
      if (max_scan_limit_ < min_scan_limit_) {
        LOG_ERROR("[%s] must be >= [%s]",
                  PARAM_HNSW_STREAMER_MAX_SCAN_LIMIT.c_str(),
                  PARAM_HNSW_STREAMER_MIN_SCAN_LIMIT.c_str());
        return IndexError_InvalidArgument;
      }
      if (params.has(PARAM_HNSW_STREAMER_BRUTE_FORCE_THRESHOLD)) {
        params.get(PARAM_HNSW_STREAMER_BRUTE_FORCE_THRESHOLD,
                   &bruteforce_threshold_);
      }
      return update_visit_filter_param();

    default:
      LOG_ERROR("update context failed, type=%u", type_);
      return IndexError_Runtime;
  }
}

//! Rebind this context to another entity/metric (the context may be shared
//! by different searchers/streamers of the same type). magic_ is invalidated
//! first and only set to magic_num on success.
int HnswContext::update_context(ContextType type, const IndexMeta &meta,
                                const IndexMetric::Pointer &metric,
                                const HnswEntity::Pointer &entity,
                                uint32_t magic_num) {
  uint32_t doc_cnt;

  if (ailego_unlikely(type != type_)) {
    LOG_ERROR(
        "HnswContext doesn't support shared by different type, "
        "src=%u dst=%u",
        type_, type);
    return IndexError_Unsupported;
  }

  magic_ = kInvalidMgic;

  // TODO: support change filter mode?
  switch (type) {
    case kBuilderContext:
      LOG_ERROR("BuildContext doesn't support update");
      return IndexError_NotImplemented;

    case kSearcherContext:
      if (!visit_filter_.reset(entity->doc_cnt(), max_scan_num_)) {
        LOG_ERROR("Reset filter failed, mode %d", visit_filter_.get_mode());
        return IndexError_Runtime;
      }
      candidates_.limit(max_scan_num_);
      topk_heap_.limit(std::max(topk_, ef_));
      break;

    case kStreamerContext:
      doc_cnt = entity->doc_cnt();
      max_scan_num_ = compute_max_scan_num(doc_cnt);
      reserve_max_doc_cnt_ = doc_cnt + compute_reserve_cnt(doc_cnt);
      if (!visit_filter_.reset(reserve_max_doc_cnt_, max_scan_num_)) {
        LOG_ERROR("Reset filter failed, mode %d", visit_filter_.get_mode());
        return IndexError_Runtime;
      }
      update_heap_.limit(entity->l0_neighbor_cnt() + 1);
      candidates_.limit(max_scan_num_);
      topk_heap_.limit(std::max(topk_, ef_));
      break;

    default:
      LOG_ERROR("update context failed");
      return IndexError_Runtime;
  }

  entity_ = entity;
  dc_.update(entity_.get(), metric, meta.dimension());
  magic_ = magic_num;
  level_topks_.clear();

  return 0;
}

//! Pad topk_heap_ with not-yet-visited, non-filtered nodes until it is full
//! or every node has been tried once.
void HnswContext::fill_random_to_topk_full(void) {
  // One engine per thread: std::mt19937 must not be used concurrently, and a
  // plain function-local static would be shared by every context on every
  // thread that reaches this path.
  static thread_local std::mt19937 mt(
      std::chrono::system_clock::now().time_since_epoch().count());
  std::uniform_int_distribution dt(0, entity_->doc_cnt() - 1);
  std::function gen;
  node_id_t seqid;
  std::function myfilter = [](node_id_t) { return false; };
  if (this->filter().is_valid()) {
    myfilter = [&](node_id_t id) {
      return this->filter()(entity_->get_key(id));
    };
  }

  if (topk_heap_.limit() < entity_->doc_cnt() / 2) {
    gen = [&](void) { return dt(mt); };
  } else {
    // If topk limit is big value, gen sequential id from an random initial
    seqid = dt(mt);
    gen = [&](void) {
      seqid = seqid == (entity_->doc_cnt() - 1) ? 0 : (seqid + 1);
      return seqid;
    };
  }

  for (size_t i = 0; !topk_heap_.full() && i < entity_->doc_cnt(); ++i) {
    const auto id = gen();
    if (!visit_filter_.visited(id) && !myfilter(id)) {
      visit_filter_.set_visited(id);
      topk_heap_.emplace(id, dc_.dist(id));
    }
  }
  return;
}

}  // namespace core
}  // namespace zvec

================================================
FILE: src/core/algorithm/hnsw/hnsw_context.h
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include
#include "utility/sparse_utility.h"
#include "utility/visit_filter.h"
#include "hnsw_dist_calculator.h"
#include "hnsw_entity.h"

namespace zvec {
namespace core {

/*! Working state of one HNSW build/search/stream operation: the distance
 *  calculator, topk/update/candidate heaps, the visit filter, plain and
 *  grouped result buffers, and debug counters.
 */
class HnswContext : public IndexContext {
 public:
  //! Index Context Pointer
  typedef std::unique_ptr Pointer;

  enum ContextType {
    kUnknownContext = 0,
    kSearcherContext = 1,
    kBuilderContext = 2,
    kStreamerContext = 3
  };

  //! Construct
  HnswContext(size_t dimension, const IndexMetric::Pointer &metric,
              const HnswEntity::Pointer &entity);

  //! Construct
  HnswContext(const IndexMetric::Pointer &metric,
              const HnswEntity::Pointer &entity);

  //! Destructor
  virtual ~HnswContext();

 public:
  //! Set topk of search result; the topk heap keeps max(topk, ef) entries
  virtual void set_topk(uint32_t val) override {
    topk_ = val;
    topk_heap_.limit(std::max(val, ef_));
  }

  //! Retrieve search result
  virtual const IndexDocumentList &result(void) const override {
    return results_[0];
  }

  //! Retrieve search result
  virtual const IndexDocumentList &result(size_t idx) const override {
    return results_[idx];
  }

  //! Retrieve result object for output
  virtual IndexDocumentList *mutable_result(size_t idx) override {
    ailego_assert_with(idx < results_.size(), "invalid idx");
    return &results_[idx];
  }

  //! Retrieve search group result with index
  virtual const IndexGroupDocumentList &group_result(void) const override {
    return group_results_[0];
  }

  //! Retrieve search group result with index
  virtual const IndexGroupDocumentList &group_result(
      size_t idx) const override {
    return group_results_[idx];
  }

  virtual uint32_t magic(void) const override {
    return magic_;
  }

  //! Set mode of debug
  virtual void set_debug_mode(bool enable) override {
    debug_mode_ = enable;
  }

  //! Retrieve mode of debug
  virtual bool debug_mode(void) const override {
    return this->debugging();
  }

  //! Retrieve string of debug
  virtual std::string debug_string(void) const override {
    char buf[4096];
    int written = snprintf(
        buf, sizeof(buf),
        "scan_cnt=%zu,get_vector_cnt=%u,get_neighbors_cnt=%u,dup_node=%u",
        get_scan_num(), stats_get_vector_cnt_, stats_get_neighbors_cnt_,
        stats_visit_dup_cnt_);
    if (written < 0) {
      return std::string();
    }
    // snprintf returns the length the full output WOULD have had; clamp to
    // what was actually stored so std::string never reads past buf.
    size_t size = std::min(static_cast<size_t>(written), sizeof(buf) - 1);
    return std::string(buf, size);
  }

  //! Update the parameters of context
  virtual int update(const ailego::Params &params) override;

 public:
  //! Init context
  int init(ContextType type);

  //! Update context, the context may be shared by different searcher/streamer
  int update_context(ContextType type, const IndexMeta &meta,
                     const IndexMetric::Pointer &metric,
                     const HnswEntity::Pointer &entity, uint32_t magic_num);

  inline const HnswEntity &get_entity() const {
    return *entity_;
  }

  //! Resize the grouped results when group-by search is on, else the plain
  //! results.
  inline void resize_results(size_t size) {
    if (group_by_search()) {
      group_results_.resize(size);
    } else {
      results_.resize(size);
    }
  }

  inline void topk_to_result() {
    return topk_to_result(0);
  }

  //! Construct result from topk heap, result will be normalized
  inline void topk_to_result(uint32_t idx) {
    if (group_by_search()) {
      topk_to_group_result(idx);
    } else {
      topk_to_single_result(idx);
    }
  }

  //! Recompute the distance of every entry in topk_heap_ against the current
  //! query (entries are re-added with emplace_back).
  inline void recal_topk_dist() {
    TopkHeap heap(topk_heap_);
    topk_heap_.clear();
    for (size_t i = 0; i < heap.size(); ++i) {
      node_id_t id = heap[i].first;
      dist_t dist = dc_.dist(id);
      topk_heap_.emplace_back(id, dist);
    }
  }

  //! Fill results_[idx] from topk_heap_ in ascending score order, stopping
  //! at topk_ entries or the first score above threshold(). Optionally pads
  //! the heap with random nodes first (force_padding_topk_).
  inline void topk_to_single_result(uint32_t idx) {
    if (force_padding_topk_ && !topk_heap_.full() &&
        topk_heap_.size() < entity_->doc_cnt()) {
      this->fill_random_to_topk_full();
    }

    if (ailego_unlikely(topk_heap_.size() == 0)) {
      return;
    }

    ailego_assert_with(idx < results_.size(), "invalid idx");
    int size = std::min(topk_, static_cast(topk_heap_.size()));
    topk_heap_.sort();
    results_[idx].clear();

    for (int i = 0; i < size; ++i) {
      auto score = topk_heap_[i].second;
      if (score > this->threshold()) {
        break;
      }

      node_id_t id = topk_heap_[i].first;
      if (fetch_vector_) {
        results_[idx].emplace_back(entity_->get_key(id), score, id,
                                   entity_->get_vector(id));
      } else {
        results_[idx].emplace_back(entity_->get_key(id), score, id);
      }
    }

    return;
  }

  //! Construct result from topk heap, result will be normalized
  //! Groups are ranked by their best (smallest) score, truncated to
  //! group_num(), and each keeps at most group_topk_ entries within
  //! threshold().
  inline void topk_to_group_result(uint32_t idx) {
    ailego_assert_with(idx < group_results_.size(), "invalid idx");

    group_results_[idx].clear();

    std::vector> group_topk_list;
    std::vector> best_score_in_groups;
    for (auto itr = group_topk_heaps_.begin(); itr != group_topk_heaps_.end();
         itr++) {
      const std::string &group_id = (*itr).first;
      auto &heap = (*itr).second;
      heap.sort();

      if (heap.size() > 0) {
        float best_score = heap[0].second;
        best_score_in_groups.push_back(std::make_pair(group_id, best_score));
      }
    }

    std::sort(best_score_in_groups.begin(), best_score_in_groups.end(),
              [](const std::pair &a, const std::pair &b) -> bool {
                return a.second < b.second;
              });

    // truncate to group num
    for (uint32_t i = 0; i < group_num() && i < best_score_in_groups.size();
         ++i) {
      const std::string &group_id = best_score_in_groups[i].first;

      group_topk_list.emplace_back(
          std::make_pair(group_id, group_topk_heaps_[group_id]));
    }

    group_results_[idx].resize(group_topk_list.size());

    for (uint32_t i = 0; i < group_topk_list.size(); ++i) {
      const std::string &group_id = group_topk_list[i].first;
      group_results_[idx][i].set_group_id(group_id);

      uint32_t size = std::min(
          group_topk_, static_cast(group_topk_list[i].second.size()));

      for (uint32_t j = 0; j < size; ++j) {
        auto score = group_topk_list[i].second[j].second;
        if (score > this->threshold()) {
          break;
        }

        node_id_t id = group_topk_list[i].second[j].first;

        if (fetch_vector_) {
          group_results_[idx][i].mutable_docs()->emplace_back(
              entity_->get_key(id), score, id, entity_->get_vector(id));
        } else {
          group_results_[idx][i].mutable_docs()->emplace_back(
              entity_->get_key(id), score, id);
        }
      }
    }
  }

  //! Set the query for subsequent distance computations. When the metric
  //! provides a query preprocess function, the query is first copied into
  //! preprocess_buffer_ and preprocessed in place there. Also resets the
  //! comparison counter.
  inline void reset_query(const void *query) {
    if (auto query_preprocess_func = index_metric_->get_query_preprocess_func();
        query_preprocess_func != nullptr) {
      size_t dim = dc_.dimension();
      // NOTE(review): this copies `dim` BYTES, and dc_.dimension() is the
      // same dim_ handed to the distance functions. If that is an element
      // count, only 1-byte element types are fully copied here — confirm
      // which element types reach this path.
      preprocess_buffer_.resize(dim);
      memcpy(preprocess_buffer_.data(), query, dim);
      query_preprocess_func(preprocess_buffer_.data(), dim);
      query = preprocess_buffer_.data();
    }
    dc_.reset_query(query);
    dc_.clear_compare_cnt();
  }

  inline HnswDistCalculator &dist_calculator() {
    return dc_;
  }

  inline TopkHeap &topk_heap() {
    return topk_heap_;
  }

  inline TopkHeap &update_heap() {
    return update_heap_;
  }

  inline VisitFilter &visit_filter() {
    return visit_filter_;
  }

  inline CandidateHeap &candidates() {
    return candidates_;
  }

  inline void set_max_scan_num(uint32_t max_scan_num) {
    max_scan_num_ = max_scan_num;
  }

  inline void set_max_scan_limit(uint32_t max_scan_limit) {
    max_scan_limit_ = max_scan_limit;
  }

  inline void set_min_scan_limit(uint32_t min_scan_limit) {
    min_scan_limit_ = min_scan_limit;
  }

  inline void set_ef(uint32_t v) {
    ef_ = v;
  }

  inline void set_filter_mode(uint32_t v) {
    filter_mode_ = v;
  }

  inline void set_filter_negative_probability(float v) {
    negative_probability_ = v;
  }

  inline void set_max_scan_ratio(float v) {
    max_scan_ratio_ = v;
  }

  virtual void set_magic(uint32_t v) {
    magic_ = v;
  }

  virtual void set_force_padding_topk(bool v) {
    force_padding_topk_ = v;
  }

  void set_bruteforce_threshold(uint32_t v) override {
    bruteforce_threshold_ = v;
  }

  inline uint32_t get_bruteforce_threshold() const {
    return bruteforce_threshold_;
  }

  void set_fetch_vector(bool v) override {
    fetch_vector_ = v;
  }

  bool fetch_vector() const override {
    return fetch_vector_;
  }

  //! Reset context
  //! Clears per-query state, filter, threshold, fetch_vector and group-by
  //! settings. Note that set_group_params(0, 0) also sets topk_ to 0.
  void reset(void) override {
    this->clear();
    set_filter(nullptr);
    reset_threshold();
    set_fetch_vector(false);
    set_group_params(0, 0);
    reset_group_by();
  }

  inline std::map &group_topk_heaps() {
    return group_topk_heaps_;
  }

  //! Per-level heap, lazily created with limit
  //! max(neighbor_cnt(level), ef_construction()).
  inline TopkHeap &level_topk(int level) {
    if (ailego_unlikely(level_topks_.size() <= static_cast(level))) {
      int cur_level = level_topks_.size();
      level_topks_.resize(level + 1);
      for (; cur_level <= level; ++cur_level) {
        size_t heap_size = std::max(entity_->neighbor_cnt(cur_level),
                                    entity_->ef_construction());
        level_topks_[cur_level].clear();
        level_topks_[cur_level].limit(heap_size);
      }
    }

    return level_topks_[level];
  }

  inline void check_need_adjuct_ctx(void) {
    check_need_adjuct_ctx(entity_->doc_cnt());
  }

  //! Growth step for reserve_max_doc_cnt_: cur_doc clamped to
  //! [kMinReserveDocCnt, kMaxReserveDocCnt].
  inline size_t compute_reserve_cnt(uint32_t cur_doc) const {
    if (cur_doc > kMaxReserveDocCnt) {
      return kMaxReserveDocCnt;
    } else if (cur_doc < kMinReserveDocCnt) {
      return kMinReserveDocCnt;
    }
    return cur_doc;
  }

  //! candidates heap and visitfilter need to resize as doc cnt growing up
  inline void check_need_adjuct_ctx(uint32_t doc_cnt) {
    if (ailego_unlikely(doc_cnt + kTriggerReserveCnt > reserve_max_doc_cnt_)) {
      while (doc_cnt + kTriggerReserveCnt > reserve_max_doc_cnt_) {
        reserve_max_doc_cnt_ =
            reserve_max_doc_cnt_ + compute_reserve_cnt(reserve_max_doc_cnt_);
      }
      uint32_t max_scan_cnt = compute_max_scan_num(reserve_max_doc_cnt_);
      max_scan_num_ = max_scan_cnt;
      visit_filter_.reset(reserve_max_doc_cnt_, max_scan_cnt);
      candidates_.clear();
      candidates_.limit(max_scan_num_);
    }
  }

  //! max_doc_cnt * max_scan_ratio_, clamped to
  //! [min_scan_limit_, max_scan_limit_].
  inline uint32_t compute_max_scan_num(uint32_t max_doc_cnt) const {
    uint32_t max_scan = max_doc_cnt * max_scan_ratio_;
    if (max_scan < min_scan_limit_) {
      max_scan = min_scan_limit_;
    } else if (max_scan > max_scan_limit_) {
      max_scan = max_scan_limit_;
    }
    return max_scan;
  }

  inline size_t get_scan_num() const {
    return dc_.compare_cnt();
  }

  inline uint64_t reach_scan_limit() const {
    return dc_.compare_cnt() >= max_scan_num_;
  }

  inline bool error() const {
    return dc_.error();
  }

  //! Reset per-query state: distance calculator counters/error, debug stats
  //! (only when debugging), and the contents of every result list.
  inline void clear() {
    dc_.clear();
    if (ailego_unlikely(this->debugging())) {
      stats_get_neighbors_cnt_ = 0u;
      stats_get_vector_cnt_ = 0u;
      stats_visit_dup_cnt_ = 0u;
    }

    // Keep results_ itself sized (the next query reuses its slots); only
    // empty each list.
    for (auto &it : results_) {
      it.clear();
    }
    for (auto &it : group_results_) {
      it.clear();
    }
  }

  uint32_t *mutable_stats_get_neighbors() {
    return &stats_get_neighbors_cnt_;
  }

  uint32_t *mutable_stats_get_vector() {
    return &stats_get_vector_cnt_;
  }

  uint32_t *mutable_stats_visit_dup_cnt() {
    return &stats_visit_dup_cnt_;
  }

  inline bool debugging(void) const {
    return debug_mode_;
  }

  inline void update_dist_caculator_distance(
      const IndexMetric::MatrixDistance &distance,
      const IndexMetric::MatrixBatchDistance &batch_distance) {
    dc_.update_distance(distance, batch_distance);
  }

  //! Get topk
  inline uint32_t topk() const override {
    return topk_;
  }

  //! Get group topk
  inline uint32_t group_topk() const {
    return group_topk_;
  }

  //! Get group num
  inline uint32_t group_num() const {
    return group_num_;
  }

  //! Get if group by search
  inline bool group_by_search() {
    return group_num_ > 0;
  }

  //! Set group params
  //! Also sets topk_ = group_topk * group_num and drops the per-group heaps.
  void set_group_params(uint32_t group_num, uint32_t group_topk) override {
    group_num_ = group_num;
    group_topk_ = group_topk;

    topk_ = group_topk_ * group_num_;

    topk_heap_.limit(std::max(topk_, ef_));

    group_topk_heaps_.clear();
  }

 private:
  // Filling random nodes if topk not full
  void fill_random_to_topk_full(void);

  constexpr static uint32_t kTriggerReserveCnt = 4096UL;
  constexpr static uint32_t kMinReserveDocCnt = 4096UL;
  constexpr static uint32_t kMaxReserveDocCnt = 128 * 1024UL;
  constexpr static uint32_t kInvalidMgic = -1U;

 private:
  HnswEntity::Pointer entity_;
  HnswDistCalculator dc_;
  IndexMetric::Pointer metric_;
  bool debug_mode_{false};
  bool force_padding_topk_{false};
  uint32_t max_scan_num_{0};
  uint32_t max_scan_limit_{0};
  uint32_t min_scan_limit_{0};
  uint32_t reserve_max_doc_cnt_{kMinReserveDocCnt};
  uint32_t topk_{0};
  uint32_t group_topk_{0};
  uint32_t filter_mode_{VisitFilter::ByteMap};
  float negative_probability_{HnswEntity::kDefaultBFNegativeProbability};
  uint32_t ef_{HnswEntity::kDefaultEf};
  float max_scan_ratio_{HnswEntity::kDefaultScanRatio};
  uint32_t magic_{0U};
  std::vector results_{};
  std::vector group_results_{};
  TopkHeap topk_heap_{};
  TopkHeap update_heap_{};
  std::vector level_topks_{};
  CandidateHeap candidates_{};
  VisitFilter visit_filter_{};
  uint32_t bruteforce_threshold_{};
  bool fetch_vector_{false};
  uint32_t group_num_{0};
  std::map group_topk_heaps_{};
  uint32_t type_{kUnknownContext};

  //! debug stats info
  uint32_t stats_get_neighbors_cnt_{0u};
  uint32_t stats_get_vector_cnt_{0u};
  uint32_t stats_visit_dup_cnt_{0u};
  std::string preprocess_buffer_;
};

}  // namespace core
}  // namespace zvec

================================================
FILE: src/core/algorithm/hnsw/hnsw_dist_calculator.h
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include "hnsw_entity.h" namespace zvec { namespace core { class HnswDistCalculator { public: typedef std::shared_ptr Pointer; public: enum DistType { DIST_NONE = 0, DIST_DENSE = 1, DIST_HYBRID = 2, DIST_SPARSE = 3 }; public: //! Constructor HnswDistCalculator(const HnswEntity *entity, const IndexMetric::Pointer &metric, uint32_t dim) : entity_(entity), distance_(metric->distance()), batch_distance_(metric->batch_distance()), query_(nullptr), dim_(dim), compare_cnt_(0) {} //! Constructor HnswDistCalculator(const HnswEntity *entity, const IndexMetric::Pointer &metric, uint32_t dim, const void *query) : entity_(entity), distance_(metric->distance()), batch_distance_(metric->batch_distance()), query_(query), dim_(dim), compare_cnt_(0) {} //! Constructor HnswDistCalculator(const HnswEntity *entity, const IndexMetric::Pointer &metric) : entity_(entity), distance_(metric->distance()), batch_distance_(metric->batch_distance()), query_(nullptr), dim_(0), compare_cnt_(0) {} void update(const HnswEntity *entity, const IndexMetric::Pointer &metric) { entity_ = entity; distance_ = metric->distance(); batch_distance_ = metric->batch_distance(); } void update(const HnswEntity *entity, const IndexMetric::Pointer &metric, uint32_t dim) { entity_ = entity; distance_ = metric->distance(); batch_distance_ = metric->batch_distance(); dim_ = dim; } inline void update_distance( const IndexMetric::MatrixDistance &distance, const IndexMetric::MatrixBatchDistance &batch_distance) { distance_ = distance; batch_distance_ = batch_distance; } //! 
Reset query vector data inline void reset_query(const void *query) { error_ = false; query_ = query; } //! Returns distance inline dist_t dist(const void *vec_lhs, const void *vec_rhs) { if (ailego_unlikely(vec_lhs == nullptr || vec_rhs == nullptr)) { LOG_ERROR("Nullptr of dense vector"); error_ = true; return 0.0f; } float score{0.0f}; distance_(vec_lhs, vec_rhs, dim_, &score); return score; } //! Returns distance between query and vec. inline dist_t dist(const void *vec) { compare_cnt_++; return dist(vec, query_); } //! Return distance between query and node id. inline dist_t dist(node_id_t id) { compare_cnt_++; const void *feat = entity_->get_vector(id); if (ailego_unlikely(feat == nullptr)) { LOG_ERROR("Get nullptr vector, id=%u", id); error_ = true; return 0.0f; } return dist(feat, query_); } //! Return dist node lhs between node rhs inline dist_t dist(node_id_t lhs, node_id_t rhs) { compare_cnt_++; const void *feat = entity_->get_vector(lhs); const void *query = entity_->get_vector(rhs); if (ailego_unlikely(feat == nullptr || query == nullptr)) { LOG_ERROR("Get nullptr vector"); error_ = true; return 0.0f; } return dist(feat, query); } dist_t operator()(const void *vec) { return dist(vec); } dist_t operator()(id_t i) { return dist(i); } dist_t operator()(id_t lhs, id_t rhs) { return dist(lhs, rhs); } void batch_dist(const void **vecs, size_t num, dist_t *distances) { compare_cnt_++; batch_distance_(vecs, query_, num, dim_, distances); } inline dist_t batch_dist(node_id_t id) { compare_cnt_++; const void *feat = entity_->get_vector(id); if (ailego_unlikely(feat == nullptr)) { LOG_ERROR("Get nullptr vector, id=%u", id); error_ = true; return 0.0f; } dist_t score = 0; batch_distance_(&feat, query_, 1, dim_, &score); return score; } inline void clear() { compare_cnt_ = 0; error_ = false; } inline void clear_compare_cnt() { compare_cnt_ = 0; } inline bool error() const { return error_; } //! 
Get distances compute times inline uint32_t compare_cnt() const { return compare_cnt_; } inline uint32_t dimension() const { return dim_; } private: HnswDistCalculator(const HnswDistCalculator &) = delete; HnswDistCalculator &operator=(const HnswDistCalculator &) = delete; private: const HnswEntity *entity_; IndexMetric::MatrixDistance distance_; IndexMetric::MatrixBatchDistance batch_distance_; const void *query_; uint32_t dim_; uint32_t compare_cnt_; // record distance compute times // uint32_t compare_cnt_batch_; // record batch distance compute time bool error_{false}; }; } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/hnsw/hnsw_entity.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "hnsw_entity.h"
#include
#include "utility/sparse_utility.h"

namespace zvec {
namespace core {

// Segment ids under which the dump_* functions below register their data.
const std::string HnswEntity::kGraphHeaderSegmentId = "graph.header";
const std::string HnswEntity::kGraphFeaturesSegmentId = "graph.features";
const std::string HnswEntity::kGraphKeysSegmentId = "graph.keys";
const std::string HnswEntity::kGraphNeighborsSegmentId = "graph.neighbors";
const std::string HnswEntity::kGraphOffsetsSegmentId = "graph.offsets";
const std::string HnswEntity::kGraphMappingSegmentId = "graph.mapping";
const std::string HnswEntity::kHnswHeaderSegmentId = "hnsw.header";
const std::string HnswEntity::kHnswNeighborsSegmentId = "hnsw.neighbors";
const std::string HnswEntity::kHnswOffsetsSegmentId = "hnsw.offsets";

//! Write zero bytes so that `data_size` bytes already written are padded up
//! to AlignSize(data_size). The number of padding bytes is stored in
//! *padding_size. Returns 0 on success (also when no padding is needed), or
//! IndexError_WriteData if the dumper writes fewer bytes than requested.
int HnswEntity::CalcAndAddPadding(const IndexDumper::Pointer &dumper,
                                  size_t data_size, size_t *padding_size) {
  *padding_size = AlignSize(data_size) - data_size;
  if (*padding_size == 0) {
    return 0;
  }

  std::string padding(*padding_size, '\0');
  if (dumper->write(padding.data(), *padding_size) != *padding_size) {
    LOG_ERROR("Append padding failed, size %lu", *padding_size);
    return IndexError_WriteData;
  }
  return 0;
}

//! Write `size` bytes of `data`, zero-pad up to AlignSize(size), then
//! register the segment (size, padding, crc) with the dumper. The CRC covers
//! the data bytes only, not the padding. Returns the bytes written
//! (data + padding), or an error code (callers treat negatives as failure).
int64_t HnswEntity::dump_segment(const IndexDumper::Pointer &dumper,
                                 const std::string &segment_id,
                                 const void *data, size_t size) const {
  size_t len = dumper->write(data, size);
  if (len != size) {
    LOG_ERROR("Dump segment %s data failed, expect: %lu, actual: %lu",
              segment_id.c_str(), size, len);
    return IndexError_WriteData;
  }

  size_t padding_size = AlignSize(size) - size;
  if (padding_size > 0) {
    std::string padding(padding_size, '\0');
    if (dumper->write(padding.data(), padding_size) != padding_size) {
      LOG_ERROR("Append padding failed, size %lu", padding_size);
      return IndexError_WriteData;
    }
  }
  uint32_t crc = ailego::Crc32c::Hash(data, size);
  int ret = dumper->append(segment_id, size, padding_size, crc);
  if (ret != 0) {
    LOG_ERROR("Dump segment %s meta failed, ret=%d", segment_id.c_str(), ret);
    return ret;
  }

  return len + padding_size;
}

//! Dump the graph header and the hnsw header as two separate segments.
//! Returns the combined bytes written, or the first negative error.
int64_t HnswEntity::dump_header(const IndexDumper::Pointer &dumper,
                                const HNSWHeader &hd) const {
  //! dump basic graph header. header is aligned and does not need padding
  //! (dump_segment would still pad it if it were not)
  int64_t graph_hd_size =
      dump_segment(dumper, kGraphHeaderSegmentId, &hd.graph, hd.graph.size);
  if (graph_hd_size < 0) {
    return graph_hd_size;
  }

  //! dump hnsw header. header is aligned and does not need padding
  int64_t hnsw_hd_size =
      dump_segment(dumper, kHnswHeaderSegmentId, &hd.hnsw, hd.hnsw.size);
  if (hnsw_hd_size < 0) {
    return hnsw_hd_size;
  }

  return graph_hd_size + hnsw_hd_size;
}

//! Placeholder for reordering node ids before dumping. Currently a no-op:
//! both mappings are left empty, which the dump functions treat as the
//! identity mapping.
void HnswEntity::reshuffle_vectors(
    const std::function & /*get_level*/, std::vector * /*n2o_mapping*/,
    std::vector * /*o2n_mapping*/, key_t * /*keys*/) const {
  // TODO
  return;
}

//! Dump node ids sorted by key: entry i is the node id holding the i-th
//! smallest key (keys are indexed by node id).
int64_t HnswEntity::dump_mapping_segment(const IndexDumper::Pointer &dumper,
                                         const key_t *keys) const {
  std::vector mapping(doc_cnt());

  std::iota(mapping.begin(), mapping.end(), 0U);
  std::sort(mapping.begin(), mapping.end(),
            [&](node_id_t i, node_id_t j) { return keys[i] < keys[j]; });

  size_t size = mapping.size() * sizeof(node_id_t);

  return dump_segment(dumper, kGraphMappingSegmentId, mapping.data(), size);
}

//! Dump the whole graph in order: headers, vectors, neighbors, keys and the
//! key-sorted mapping. Returns the total bytes written, or the first
//! negative error returned by a step.
int64_t HnswEntity::dump_segments(
    const IndexDumper::Pointer &dumper, key_t *keys,
    const std::function &get_level) const {
  HNSWHeader dump_hd(header());

  dump_hd.graph.node_size = AlignSize(vector_size());

  std::vector n2o_mapping;  // map new id to origin id
  std::vector o2n_mapping;  // map origin id to new id
  reshuffle_vectors(get_level, &n2o_mapping, &o2n_mapping, keys);
  if (!o2n_mapping.empty()) {
    dump_hd.hnsw.entry_point = o2n_mapping[entry_point()];
  }

  //! Dump header
  int64_t hd_size = dump_header(dumper, dump_hd);
  if (hd_size < 0) {
    return hd_size;
  }

  //! Dump vectors
  int64_t vecs_size = dump_vectors(dumper, n2o_mapping);
  if (vecs_size < 0) {
    return vecs_size;
  }

  //! Dump neighbors
  auto neighbors_size =
      dump_neighbors(dumper, get_level, n2o_mapping, o2n_mapping);
  if (neighbors_size < 0) {
    return neighbors_size;
  }

  //! free memory
  n2o_mapping = std::vector();
  o2n_mapping = std::vector();

  //! Dump keys
  size_t key_segment_size = doc_cnt() * sizeof(key_t);
  int64_t keys_size =
      dump_segment(dumper, kGraphKeysSegmentId, keys, key_segment_size);
  if (keys_size < 0) {
    return keys_size;
  }

  //! Dump mapping
  int64_t mapping_size = dump_mapping_segment(dumper, keys);
  if (mapping_size < 0) {
    return mapping_size;
  }

  return hd_size + keys_size + vecs_size + neighbors_size + mapping_size;
}

//! Dump every vector (in reorder_mapping order if non-empty, else id order),
//! each followed by zero padding up to AlignSize(vector_size()). Unlike
//! dump_segment, the CRC and the registered size here include the padding,
//! and the segment is appended with padding 0. Returns the bytes written or
//! an error (IndexError_ReadData if a vector cannot be fetched).
int64_t HnswEntity::dump_vectors(
    const IndexDumper::Pointer &dumper,
    const std::vector &reorder_mapping) const {
  size_t vector_dump_size = vector_size();

  size_t padding_size = AlignSize(vector_dump_size) - vector_dump_size;

  std::vector padding(padding_size);
  memset(padding.data(), 0, sizeof(char) * padding_size);

  const void *data = nullptr;
  uint32_t crc = 0U;
  size_t vecs_size = 0UL;

  //! dump vectors
  for (node_id_t id = 0; id < doc_cnt(); ++id) {
    data = get_vector(reorder_mapping.empty() ? id : reorder_mapping[id]);
    if (ailego_unlikely(!data)) {
      return IndexError_ReadData;
    }
    size_t len = dumper->write(data, vector_size());
    if (len != vector_size()) {
      LOG_ERROR("Dump vectors failed, write=%zu expect=%zu", len,
                vector_size());
      return IndexError_WriteData;
    }

    crc = ailego::Crc32c::Hash(data, vector_size(), crc);
    vecs_size += vector_size();

    if (padding_size == 0) {
      continue;
    }

    len = dumper->write(padding.data(), padding_size);
    if (len != padding_size) {
      LOG_ERROR("Dump vectors failed, write=%zu expect=%zu", len,
                padding_size);
      return IndexError_WriteData;
    }
    crc = ailego::Crc32c::Hash(padding.data(), padding_size, crc);
    vecs_size += padding_size;
  }

  int ret = dumper->append(kGraphFeaturesSegmentId, vecs_size, 0UL, crc);
  if (ret != 0) {
    LOG_ERROR("Dump vectors segment meta failed, ret %d", ret);
    return ret;
  }

  return vecs_size;
}

//! Dump level-0 neighbor lists back to back (ids remapped through
//! neighbor_mapping when it is non-empty), pad the block, and then dump one
//! (offset, count) meta entry per node as a separate segment. Also logs
//! min/max/average neighbor counts. Returns the total bytes written or an
//! error.
int64_t HnswEntity::dump_graph_neighbors(
    const IndexDumper::Pointer &dumper,
    const std::vector &reorder_mapping,
    const std::vector &neighbor_mapping) const {
  std::vector graph_meta;
  graph_meta.reserve(doc_cnt());

  size_t offset = 0;
  uint32_t crc = 0;
  std::vector mapping(l0_neighbor_cnt());

  // NOTE(review): if doc_cnt() == 0 the log below reports min as 10000.
  uint32_t min_neighbor_count = 10000;
  uint32_t max_neighbor_count = 0;
  size_t sum_neighbor_count = 0;

  for (node_id_t id = 0; id < doc_cnt(); ++id) {
    const Neighbors neighbors =
        get_neighbors(0, reorder_mapping.empty() ? id : reorder_mapping[id]);
    ailego_assert_with(!!neighbors.data, "invalid neighbors");
    ailego_assert_with(neighbors.size() <= l0_neighbor_cnt(),
                       "invalid neighbors");

    uint32_t neighbor_count = neighbors.size();
    if (neighbor_count < min_neighbor_count) {
      min_neighbor_count = neighbor_count;
    }
    if (neighbor_count > max_neighbor_count) {
      max_neighbor_count = neighbor_count;
    }
    sum_neighbor_count += neighbor_count;

    graph_meta.emplace_back(offset, neighbor_count);
    size_t size = neighbors.size() * sizeof(node_id_t);
    const node_id_t *data = &neighbors[0];
    if (!neighbor_mapping.empty()) {
      // Translate neighbor ids into the dumped id space.
      for (node_id_t i = 0; i < neighbors.size(); ++i) {
        mapping[i] = neighbor_mapping[neighbors[i]];
      }
      data = mapping.data();
    }
    if (dumper->write(data, size) != size) {
      LOG_ERROR("Dump graph neighbor id=%u failed, size %lu", id, size);
      return IndexError_WriteData;
    }
    crc = ailego::Crc32c::Hash(data, size, crc);
    offset += size;
  }

  uint32_t average_neighbor_count = 0;
  if (doc_cnt() > 0) {
    average_neighbor_count = sum_neighbor_count / doc_cnt();
  }
  LOG_INFO(
      "Dump hnsw graph: min_neighbor_count[%u] max_neighbor_count[%u] "
      "average_neighbor_count[%u]",
      min_neighbor_count, max_neighbor_count, average_neighbor_count);

  size_t padding_size = 0;
  int ret = CalcAndAddPadding(dumper, offset, &padding_size);
  if (ret != 0) {
    return ret;
  }
  ret = dumper->append(kGraphNeighborsSegmentId, offset, padding_size, crc);
  if (ret != 0) {
    LOG_ERROR("Dump segment %s failed, ret %d",
              kGraphNeighborsSegmentId.c_str(), ret);
    return ret;
  }

  //! dump level 0 neighbors meta
  auto len = dump_segment(dumper, kGraphOffsetsSegmentId, graph_meta.data(),
                          graph_meta.size() * sizeof(GraphNeighborMeta));
  if (len < 0) {
    return len;
  }

  return len + offset + padding_size;
}

//! Dump neighbor lists of levels >= 1. Each (node, level) list is written
//! as a fixed-size record of upper_neighbor_cnt() + 1 node ids:
//! [count, ids..., zero fill]. Nodes whose level is 0 get a (0, 0) meta entry
//! and no records.
int64_t HnswEntity::dump_upper_neighbors(
    const IndexDumper::Pointer &dumper,
    const std::function &get_level,
    const std::vector &reorder_mapping,
    const std::vector &neighbor_mapping) const {
  std::vector hnsw_meta;
  hnsw_meta.reserve(doc_cnt());
  size_t offset = 0;
  uint32_t crc = 0;
  std::vector buffer(upper_neighbor_cnt() + 1);
  for (node_id_t id = 0; id < doc_cnt(); ++id) {
    node_id_t new_id = reorder_mapping.empty() ? id : reorder_mapping[id];
    auto level = get_level(new_id);
    if (level == 0) {
      hnsw_meta.emplace_back(0U, 0U);
      continue;
    }
    hnsw_meta.emplace_back(offset, level);
    ailego_assert_with((size_t)level < kMaxGraphLayers, "invalid level");
    for (level_t cur_level = 1; cur_level <= level; ++cur_level) {
      const Neighbors neighbors = get_neighbors(cur_level, new_id);
      ailego_assert_with(!!neighbors.data, "invalid neighbors");
      ailego_assert_with(neighbors.size() <= neighbor_cnt(cur_level),
                         "invalid neighbors");
      // Zero the whole record so unused slots are written as 0.
      memset(buffer.data(), 0, sizeof(node_id_t) * buffer.size());
      buffer[0] = neighbors.size();
      if (neighbor_mapping.empty()) {
        memcpy(&buffer[1], &neighbors[0], neighbors.size() * sizeof(node_id_t));
      } else {
        for (node_id_t i = 0; i < neighbors.size(); ++i) {
          buffer[i + 1] = neighbor_mapping[neighbors[i]];
        }
      }
      if (dumper->write(buffer.data(), sizeof(node_id_t) * buffer.size()) !=
          sizeof(node_id_t) * buffer.size()) {
        LOG_ERROR("Dump graph neighbor id=%u failed, size %lu", id,
                  sizeof(node_id_t) * buffer.size());
        return IndexError_WriteData;
      }
      crc = ailego::Crc32c::Hash(buffer.data(),
                                 sizeof(node_id_t) * buffer.size(), crc);
      offset += sizeof(node_id_t) * buffer.size();
    }
  }
  size_t padding_size = 0;
  int ret = CalcAndAddPadding(dumper, offset, &padding_size);
  if (ret != 0) {
    return ret;
  }
  ret = dumper->append(kHnswNeighborsSegmentId, offset, padding_size, crc);
if (ret != 0) { LOG_ERROR("Dump segment %s failed, ret %d", kHnswNeighborsSegmentId.c_str(), ret); return ret; } //! dump level 0 neighbors meta auto len = dump_segment(dumper, kHnswOffsetsSegmentId, hnsw_meta.data(), hnsw_meta.size() * sizeof(HnswNeighborMeta)); if (len < 0) { return len; } return len + offset + padding_size; } } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/hnsw/hnsw_entity.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include #include #include #include #include namespace zvec { namespace core { using node_id_t = uint32_t; using key_t = uint64_t; using level_t = int32_t; using dist_t = float; using TopkHeap = ailego::KeyValueHeap; using CandidateHeap = ailego::KeyValueHeap>; constexpr node_id_t kInvalidNodeId = static_cast(-1); constexpr key_t kInvalidKey = static_cast(-1); class DistCalculator; struct GraphHeader { uint32_t size; uint32_t version; uint32_t graph_type; uint32_t doc_count; uint32_t vector_size; uint32_t node_size; uint32_t l0_neighbor_count; uint32_t prune_type; uint32_t prune_neighbor_count; uint32_t ef_construction; uint32_t options; uint32_t min_neighbor_count; uint8_t reserved_[4080]; }; static_assert(sizeof(GraphHeader) % 32 == 0, "GraphHeader must be aligned with 32 bytes"); //! 
Hnsw upper neighbor header struct HnswHeader { uint32_t size; // header size uint32_t revision; // current total docs of the graph uint32_t upper_neighbor_count; uint32_t ef_construction; uint32_t scaling_factor; uint32_t max_level; uint32_t entry_point; uint32_t options; uint8_t reserved_[30]; }; static_assert(sizeof(HnswHeader) % 32 == 0, "GraphHeader must be aligned with 32 bytes"); //! Hnsw common header and upper neighbor header struct HNSWHeader { HNSWHeader() { clear(); } HNSWHeader(const HNSWHeader &header) { memcpy(this, &header, sizeof(header)); } HNSWHeader &operator=(const HNSWHeader &header) { memcpy(this, &header, sizeof(header)); return *this; } //! Reset state to zero, and the params is untouched void inline reset() { graph.doc_count = 0U; hnsw.entry_point = kInvalidNodeId; hnsw.max_level = 0; } //! Clear all fields to init value void inline clear() { memset(this, 0, sizeof(HNSWHeader)); hnsw.entry_point = kInvalidNodeId; graph.size = sizeof(GraphHeader); hnsw.size = sizeof(HnswHeader); } size_t l0_neighbor_cnt() const { return graph.l0_neighbor_count; } size_t upper_neighbor_cnt() const { return hnsw.upper_neighbor_count; } size_t vector_size() const { return graph.vector_size; } size_t ef_construction() const { return graph.ef_construction; } size_t scaling_factor() const { return hnsw.scaling_factor; } size_t neighbor_prune_cnt() const { return graph.prune_neighbor_count; } node_id_t entry_point() const { return hnsw.entry_point; } node_id_t doc_cnt() const { return graph.doc_count; } GraphHeader graph; HnswHeader hnsw; }; struct NeighborsHeader { uint32_t neighbor_cnt; node_id_t neighbors[0]; }; struct Neighbors { Neighbors() : cnt{0}, data{nullptr} {} Neighbors(uint32_t cnt_in, const node_id_t *data_in) : cnt{cnt_in}, data{data_in} {} Neighbors(const IndexStorage::MemoryBlock &mem_block) : neighbor_block{mem_block} { auto hd = reinterpret_cast(neighbor_block.data()); cnt = hd->neighbor_cnt; data = hd->neighbors; } size_t size(void) const { 
return cnt; } const node_id_t &operator[](size_t idx) const { return data[idx]; } uint32_t cnt; const node_id_t *data; IndexStorage::MemoryBlock neighbor_block; }; //! level 0 neighbors offset struct GraphNeighborMeta { GraphNeighborMeta(size_t o, size_t cnt) : offset(o), neighbor_cnt(cnt) {} uint64_t offset : 48; uint64_t neighbor_cnt : 16; }; //! hnsw upper neighbors meta struct HnswNeighborMeta { HnswNeighborMeta(size_t o, size_t l) : offset(o), level(l) {} uint64_t offset : 48; // offset = idx * upper neighors size uint64_t level : 16; }; class HnswEntity { public: //! Constructor HnswEntity() {} //! Constructor HnswEntity(const HNSWHeader &hd) { header_ = hd; } //! Destructor virtual ~HnswEntity() {} //! HnswEntity Pointerd; typedef std::shared_ptr Pointer; //! Get max neighbor size of graph level inline size_t neighbor_cnt(level_t level) const { return level == 0 ? header_.graph.l0_neighbor_count : header_.hnsw.upper_neighbor_count; } //! get max neighbor size of graph level 0 inline size_t l0_neighbor_cnt() const { return header_.graph.l0_neighbor_count; } //! get min neighbor size of graph inline size_t min_neighbor_cnt() const { return header_.graph.min_neighbor_count; } //! get upper neighbor size of graph level other than 0 inline size_t upper_neighbor_cnt() const { return header_.hnsw.upper_neighbor_count; } //! Get current total doc of the hnsw graph inline node_id_t *mutable_doc_cnt() { return &header_.graph.doc_count; } inline node_id_t doc_cnt() const { return header_.graph.doc_count; } //! Get hnsw graph scaling params inline size_t scaling_factor() const { return header_.hnsw.scaling_factor; } //! Get prune_size inline size_t prune_cnt() const { return header_.graph.prune_neighbor_count; } //! Current entity of top level graph inline node_id_t entry_point() const { return header_.hnsw.entry_point; } //! Current max graph level inline level_t cur_max_level() const { return header_.hnsw.max_level; } //! 
Retrieve index vector size size_t vector_size() const { return header_.graph.vector_size; } //! Retrieve node size size_t node_size() const { return header_.graph.node_size; } //! Retrieve ef constuction size_t ef_construction() const { return header_.graph.ef_construction; } void set_vector_size(size_t size) { header_.graph.vector_size = size; } void set_prune_cnt(size_t v) { header_.graph.prune_neighbor_count = v; } void set_scaling_factor(size_t val) { header_.hnsw.scaling_factor = val; } void set_l0_neighbor_cnt(size_t cnt) { header_.graph.l0_neighbor_count = cnt; } void set_min_neighbor_cnt(size_t cnt) { header_.graph.min_neighbor_count = cnt; } void set_upper_neighbor_cnt(size_t cnt) { header_.hnsw.upper_neighbor_count = cnt; } void set_ef_construction(size_t ef) { header_.graph.ef_construction = ef; } protected: inline const HNSWHeader &header() const { return header_; } inline HNSWHeader *mutable_header() { return &header_; } inline size_t header_size() const { return sizeof(header_); } void set_node_size(size_t size) { header_.graph.node_size = size; } //! Dump all segment by dumper //! Return dump size if success, errno(<0) in failure int64_t dump_segments( const IndexDumper::Pointer &dumper, key_t *keys, const std::function &get_level) const; private: //! dump mapping segment, for get_vector_by_key in provider int64_t dump_mapping_segment(const IndexDumper::Pointer &dumper, const key_t *keys) const; //! dump hnsw head by dumper //! Return dump size if success, errno(<0) in failure int64_t dump_header(const IndexDumper::Pointer &dumper, const HNSWHeader &hd) const; //! dump vectors by dumper //! Return dump size if success, errno(<0) in failure int64_t dump_vectors(const IndexDumper::Pointer &dumper, const std::vector &reorder_mapping) const; //! dump hnsw neighbors by dumper //! 
Return dump size if success, errno(<0) in failure int64_t dump_neighbors(const IndexDumper::Pointer &dumper, const std::function &get_level, const std::vector &reorder_mapping, const std::vector &neighbor_mapping) const { auto len1 = dump_graph_neighbors(dumper, reorder_mapping, neighbor_mapping); if (len1 < 0) { return len1; } auto len2 = dump_upper_neighbors(dumper, get_level, reorder_mapping, neighbor_mapping); if (len2 < 0) { return len2; } return len1 + len2; } //! dump segment by dumper //! Return dump size if success, errno(<0) in failure int64_t dump_segment(const IndexDumper::Pointer &dumper, const std::string &segment_id, const void *data, size_t size) const; //! Dump level 0 neighbors //! Return dump size if success, errno(<0) in failure int64_t dump_graph_neighbors( const IndexDumper::Pointer &dumper, const std::vector &reorder_mapping, const std::vector &neighbor_mapping) const; //! Dump upper level neighbors //! Return dump size if success, errno(<0) in failure int64_t dump_upper_neighbors( const IndexDumper::Pointer &dumper, const std::function &get_level, const std::vector &reorder_mapping, const std::vector &neighbor_mapping) const; public: //! Cleanup the entity virtual int cleanup(void) { header_.clear(); return 0; } //! Make a copy of searcher entity, to support thread-safe operation. //! The segment in container cannot be read concurrenly virtual const HnswEntity::Pointer clone() const { LOG_ERROR("Update neighbors not implemented"); return HnswEntity::Pointer(); } //! Get primary key of the node id virtual key_t get_key(node_id_t id) const = 0; //! Get vector feature data by key virtual const void *get_vector(node_id_t id) const = 0; //! 
Get vectors feature data by keys virtual int get_vector(const node_id_t *ids, uint32_t count, const void **vecs) const = 0; virtual int get_vector(const node_id_t id, IndexStorage::MemoryBlock &block) const = 0; virtual int get_vector( const node_id_t *ids, uint32_t count, std::vector &vec_blocks) const = 0; //! Retrieve a vector using a primary key virtual const void *get_vector_by_key(uint64_t /*key*/) const { LOG_ERROR("get vector not implemented"); return nullptr; } virtual int get_vector_by_key(const key_t /*key*/, IndexStorage::MemoryBlock & /*block*/) const { return IndexError_NotImplemented; } //! Get the node id's neighbors on graph level //! Note: the neighbors cannot be modified, using the following //! method to get WritableNeighbors if want to virtual const Neighbors get_neighbors(level_t level, node_id_t id) const = 0; //! Add vector and key to hnsw entity, and local id will be saved in id virtual int add_vector(level_t /*level*/, key_t /*key*/, const void * /*vec*/, node_id_t * /*id*/) { return IndexError_NotImplemented; } //! Add vector and id to hnsw entity virtual int add_vector_with_id(level_t /*level*/, node_id_t /*id*/, const void * /*vec*/) { return IndexError_NotImplemented; } virtual int update_neighbors( level_t /*level*/, node_id_t /*id*/, const std::vector> & /*neighbors*/) { LOG_ERROR("Update neighbors dense not implemented"); return 0; } //! Append neighbor_id to node id neighbors on level, size is the current //! neighbors size. Notice: the caller must be ensure the neighbors not full virtual void add_neighbor(level_t /*level*/, node_id_t /*id*/, uint32_t /*size*/, node_id_t /*neighbor_id*/) { LOG_ERROR("Add neighbor not implemented"); } //! 
Update entry point and max level virtual void update_ep_and_level(node_id_t ep, level_t level) { header_.hnsw.entry_point = ep; header_.hnsw.max_level = level; } virtual int load(const IndexStorage::Pointer & /*container*/, bool /*check_crc*/) { LOG_ERROR("Load not implemented"); return IndexError_NotImplemented; } virtual int dump(const IndexDumper::Pointer & /*dumper*/) { LOG_ERROR("Dump not implemented"); return IndexError_NotImplemented; } static int CalcAndAddPadding(const IndexDumper::Pointer &dumper, size_t data_size, size_t *padding_size); protected: static inline size_t AlignSize(size_t size) { return (size + 0x1F) & (~0x1F); } static inline size_t AlignPageSize(size_t size) { size_t page_mask = ailego::MemoryHelper::PageSize() - 1; return (size + page_mask) & (~page_mask); } static inline size_t AlignHugePageSize(size_t size) { size_t page_mask = ailego::MemoryHelper::HugePageSize() - 1; return (size + page_mask) & (~page_mask); } //! rearrange vectors to improve cache locality void reshuffle_vectors(const std::function &get_level, std::vector *n2o_mapping, std::vector *o2n_mapping, key_t *keys) const; public: const static std::string kGraphHeaderSegmentId; const static std::string kGraphFeaturesSegmentId; const static std::string kGraphKeysSegmentId; const static std::string kGraphNeighborsSegmentId; const static std::string kGraphOffsetsSegmentId; const static std::string kGraphMappingSegmentId; const static std::string kHnswHeaderSegmentId; const static std::string kHnswNeighborsSegmentId; const static std::string kHnswOffsetsSegmentId; constexpr static uint32_t kRevision = 0U; constexpr static size_t kMaxGraphLayers = 15; constexpr static uint32_t kDefaultEfConstruction = 500; constexpr static uint32_t kDefaultEf = 500; constexpr static uint32_t kDefaultUpperMaxNeighborCnt = 50; // M of HNSW constexpr static uint32_t kDefaultL0MaxNeighborCnt = 100; constexpr static uint32_t kMaxNeighborCnt = 65535; constexpr static float kDefaultScanRatio = 0.1f; 
constexpr static uint32_t kDefaultMinScanLimit = 10000; constexpr static uint32_t kDefaultMaxScanLimit = std::numeric_limits::max(); constexpr static float kDefaultBFNegativeProbability = 0.001f; constexpr static uint32_t kDefaultScalingFactor = 50U; constexpr static uint32_t kDefaultBruteForceThreshold = 1000U; constexpr static uint32_t kDefaultDocsHardLimit = 1 << 30U; // 1 billion constexpr static float kDefaultDocsSoftLimitRatio = 0.9f; constexpr static size_t kMaxChunkSize = 0xFFFFFFFF; constexpr static size_t kDefaultChunkSize = 2UL * 1024UL * 1024UL; constexpr static size_t kDefaultMaxChunkCnt = 50000UL; constexpr static float kDefaultNeighborPruneMultiplier = 1.0f; // prune_cnt = upper_max_neighbor_cnt * multiplier constexpr static float kDefaultL0MaxNeighborCntMultiplier = 2.0f; // l0_max_neighbor_cnt = upper_max_neighbor_cnt * multiplier protected: HNSWHeader header_{}; }; } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/hnsw/hnsw_index_hash.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include "hnsw_chunk.h" namespace zvec { namespace core { //! 
Persistent hashmap implement through open addressing algorithm template ::value>::type> class HnswIndexHashMap { using key_type = Key; using val_type = Val; struct Iterator { key_type first; val_type second; }; typedef Iterator *iterator; typedef Iterator Item; typedef const Iterator *const_iterator; class Slot { public: Slot(Chunk::Pointer &&chunk, const void *data) : chunk_(std::move(chunk)), items_(reinterpret_cast(data)) {} //! Return a empty loc or the key item loc Slot(Chunk::Pointer &&chunk, IndexStorage::MemoryBlock &&mem_block) : chunk_(std::move(chunk)), items_block_(std::move(mem_block)) { items_ = reinterpret_cast(items_block_.data()); } const_iterator find(key_type key, uint32_t max_items, uint32_t mask) const { auto it = &items_[key & mask]; for (auto i = 0U; i < max_items; ++i) { if (it->first == key || it->second == EmptyVal) { // LOG_DEBUG("i=%u", i); return it; } ++it; if (it == &items_[max_items]) { it = &items_[0]; } } return nullptr; } bool update(const_iterator it) { uint32_t offset = reinterpret_cast(it) - reinterpret_cast(&items_[0]); if (ailego_unlikely(chunk_->write(offset, it, sizeof(Item)) != sizeof(Item))) { LOG_ERROR("Chunk write failed"); return false; } return true; } private: Chunk::Pointer chunk_{}; const Item *items_{nullptr}; // point to chunk data IndexStorage::MemoryBlock items_block_{}; }; public: //! Init the hash //! broker the index allocator //! chunk_size the size of per chunk allocated, actual size may greater //! factor factor = 1/ratio, ratio is the probability of a squence //! number inserted to this container //! max the max number key can be inserted //! 
expansion_ratio memory expansion ratio int init(ChunkBroker::Pointer &broker, uint32_t chunk_size, uint32_t factor, size_t max, float expansion_ratio) { ailego_assert_with(expansion_ratio > 1.0f, "ratio must > 1.0f"); broker_ = broker; size_t items = std::ceil(chunk_size * 1.0f / sizeof(Item)); slot_items_ = 1UL << static_cast((std::ceil(std::log2(items)))); size_t range = slot_items_ * factor / expansion_ratio; mask_bits_ = std::floor(std::log2(range)); range = 1UL << mask_bits_; size_t max_slots = std::ceil(max * 1.0f / range); slots_.reserve(max_slots); slot_loc_mask_ = slot_items_ - 1U; int ret = load(); if (ret != 0) { return ret; } LOG_DEBUG( "HnswIndexHash init, chunkSize=%u factor=%u max=%zu " "ratio=%f slotItems=%u maxSlots=%zu maskBits=%u " "range=%zu", chunk_size, factor, max, expansion_ratio, slot_items_, max_slots, mask_bits_, range); return 0; } int cleanup(void) { broker_.reset(); slots_.clear(); slots_.shrink_to_fit(); mask_bits_ = 0U; slot_items_ = 0U; slot_loc_mask_ = 0U; return 0; } const_iterator end(void) const { return nullptr; } const_iterator find(const key_type key) const { auto idx = key >> mask_bits_; if (idx >= slots_.size()) { return end(); } auto it = slots_[idx].find(key, slot_items_, slot_loc_mask_); return it && it->second != EmptyVal ? it : nullptr; } bool insert(key_type key, val_type val) { auto idx = key >> mask_bits_; if (idx >= slots_.size()) { if (ailego_unlikely(idx >= slots_.capacity())) { LOG_ERROR("no space to insert"); return false; } for (auto i = slots_.size(); i <= idx; ++i) { if (ailego_unlikely(!alloc_slot(i))) { return false; } } } auto it = slots_[idx].find(key, slot_items_, slot_loc_mask_); if (ailego_unlikely(it == nullptr)) { LOG_ERROR("no space to insert"); return false; } //! TODO: write memory is ok? 
const_cast(it)->first = key; const_cast(it)->second = val; return slots_[idx].update(it); } private: bool alloc_slot(size_t idx) { ailego_assert_with(idx == slots_.size(), "invalid idx"); size_t size = slot_items_ * sizeof(Item); auto p = broker_->alloc_chunk(ChunkBroker::CHUNK_TYPE_NEIGHBOR_INDEX, idx, size); if (ailego_unlikely(p.first != 0)) { LOG_ERROR("Alloc data chunk failed"); return false; } Chunk::Pointer chunk = p.second; if (ailego_unlikely(chunk->resize(size) != size)) { LOG_ERROR("Chunk resize failed, size=%zu", size); return false; } //! Read the whole data to memory IndexStorage::MemoryBlock data_block; if (ailego_unlikely(chunk->read(0U, data_block, size) != size)) { LOG_ERROR("Chunk read failed, size=%zu", size); return false; } slots_.emplace_back(std::move(chunk), std::move(data_block)); return true; } int load(void) { size_t slots_cnt = broker_->get_chunk_cnt(ChunkBroker::CHUNK_TYPE_NEIGHBOR_INDEX); for (size_t i = 0UL; i < slots_cnt; ++i) { auto chunk = broker_->get_chunk(ChunkBroker::CHUNK_TYPE_NEIGHBOR_INDEX, i); if (!chunk) { LOG_ERROR("Get chunk failed, seq=%zu", i); return IndexError_InvalidFormat; } size_t size = sizeof(Item) * slot_items_; if (chunk->data_size() < size) { LOG_ERROR( "Hash params may be mismatch, seq=%zu, data_size=%zu " "expect=%zu", i, chunk->data_size(), size); return IndexError_InvalidFormat; } //! 
Read the whole data to memory IndexStorage::MemoryBlock data_block; if (ailego_unlikely(chunk->read(0U, data_block, size) != size)) { LOG_ERROR("Chunk read failed, size=%zu", size); return false; } slots_.emplace_back(std::move(chunk), std::move(data_block)); } return 0; } private: ChunkBroker::Pointer broker_{}; // chunk broker std::vector slots_{}; uint32_t mask_bits_{0U}; uint32_t slot_items_{}; // must be a power of 2 uint32_t slot_loc_mask_{}; }; } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/hnsw/hnsw_index_provider.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include #include "hnsw_entity.h" namespace zvec { namespace core { class HnswIndexProvider : public IndexProvider { public: HnswIndexProvider(const IndexMeta &meta, const HnswEntity::Pointer &entity, const std::string &owner) : meta_(meta), entity_(entity), owner_class_(owner) {} HnswIndexProvider(const HnswIndexProvider &) = delete; HnswIndexProvider &operator=(const HnswIndexProvider &) = delete; public: // holder interface //! Create a new iterator IndexProvider::Iterator::Pointer create_iterator() override { return HnswIndexProvider::Iterator::Pointer(new (std::nothrow) Iterator(entity_)); } //! Retrieve count of vectors size_t count(void) const override { return entity_->doc_cnt(); } //! 
Retrieve dimension of vector size_t dimension(void) const override { return meta_.dimension(); } //! Retrieve type of vector IndexMeta::DataType data_type(void) const override { return meta_.data_type(); } //! Retrieve vector size in bytes size_t element_size(void) const override { return meta_.element_size(); } public: // provider's unique interface //! Retrieve a vector using a primary key const void *get_vector(uint64_t key) const override { return entity_->get_vector_by_key(key); } int get_vector(const uint64_t key, IndexStorage::MemoryBlock &block) const override { return entity_->get_vector_by_key(key, block); } //! Retrieve the owner class const std::string &owner_class(void) const override { return owner_class_; } private: class Iterator : public IndexProvider::Iterator { public: Iterator(const HnswEntity::Pointer &entity) : entity_(entity), cur_id_(0U) {} //! Retrieve pointer of data //! NOTICE: the vec feature will be changed after iterating to next, so //! the caller need to keep a copy of it before iterator to next vector virtual const void *data(void) const override { return entity_->get_vector(cur_id_); } //! Test if the iterator is valid virtual bool is_valid(void) const override { return cur_id_ < entity_->doc_cnt(); } //! Retrieve primary key virtual uint64_t key(void) const override { return entity_->get_key(cur_id_); } //! Next iterator virtual void next(void) override { // cur_id_ += 1; cur_id_ = get_next_valid_id(cur_id_ + 1); } //! 
Reset the iterator void reset(void) { cur_id_ = get_next_valid_id(0); } private: node_id_t get_next_valid_id(node_id_t start_id) { for (node_id_t i = start_id; i < entity_->doc_cnt(); i++) { if (entity_->get_key(i) != kInvalidNodeId) { cur_id_ = i; return i; } } return kInvalidNodeId; } private: const HnswEntity::Pointer entity_; node_id_t cur_id_; }; private: const IndexMeta &meta_; const HnswEntity::Pointer entity_; const std::string owner_class_; }; } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/hnsw/hnsw_params.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#pragma once #include namespace zvec { namespace core { static const std::string PARAM_HNSW_BUILDER_THREAD_COUNT( "proxima.hnsw.builder.thread_count"); static const std::string PARAM_HNSW_BUILDER_MEMORY_QUOTA( "proxima.hnsw.builder.memory_quota"); static const std::string PARAM_HNSW_BUILDER_EFCONSTRUCTION( "proxima.hnsw.builder.efconstruction"); static const std::string PARAM_HNSW_BUILDER_SCALING_FACTOR( "proxima.hnsw.builder.scaling_factor"); static const std::string PARAM_HNSW_BUILDER_CHECK_INTERVAL_SECS( "proxima.hnsw.builder.check_interval_secs"); static const std::string PARAM_HNSW_BUILDER_NEIGHBOR_PRUNE_MULTIPLIER( "proxima.hnsw.builder.neighbor_prune_multiplier"); static const std::string PARAM_HNSW_BUILDER_MIN_NEIGHBOR_COUNT( "proxima.hnsw.builder.min_neighbor_count"); static const std::string PARAM_HNSW_BUILDER_MAX_NEIGHBOR_COUNT( "proxima.hnsw.builder.max_neighbor_count"); static const std::string PARAM_HNSW_BUILDER_L0_MAX_NEIGHBOR_COUNT_MULTIPLIER( "proxima.hnsw.builder.l0_max_neighbor_count_multiplier"); static const std::string PARAM_HNSW_SEARCHER_EF("proxima.hnsw.searcher.ef"); static const std::string PARAM_HNSW_SEARCHER_BRUTE_FORCE_THRESHOLD( "proxima.hnsw.searcher.brute_force_threshold"); static const std::string PARAM_HNSW_SEARCHER_NEIGHBORS_IN_MEMORY_ENABLE( "proxima.hnsw.searcher.neighbors_in_memory_enable"); static const std::string PARAM_HNSW_SEARCHER_MAX_SCAN_RATIO( "proxima.hnsw.searcher.max_scan_ratio"); static const std::string PARAM_HNSW_SEARCHER_CHECK_CRC_ENABLE( "proxima.hnsw.searcher.check_crc_enable"); static const std::string PARAM_HNSW_SEARCHER_VISIT_BLOOMFILTER_ENABLE( "proxima.hnsw.searcher.visit_bloomfilter_enable"); static const std::string PARAM_HNSW_SEARCHER_VISIT_BLOOMFILTER_NEGATIVE_PROB( "proxima.hnsw.searcher.visit_bloomfilter_negative_prob"); static const std::string PARAM_HNSW_SEARCHER_FORCE_PADDING_RESULT_ENABLE( "proxima.hnsw.searcher.force_padding_result_enable"); static const std::string 
PARAM_HNSW_STREAMER_MAX_SCAN_RATIO( "proxima.hnsw.streamer.max_scan_ratio"); static const std::string PARAM_HNSW_STREAMER_MIN_SCAN_LIMIT( "proxima.hnsw.streamer.min_scan_limit"); static const std::string PARAM_HNSW_STREAMER_MAX_SCAN_LIMIT( "proxima.hnsw.streamer.max_scan_limit"); static const std::string PARAM_HNSW_STREAMER_EF("proxima.hnsw.streamer.ef"); static const std::string PARAM_HNSW_STREAMER_EFCONSTRUCTION( "proxima.hnsw.streamer.efconstruction"); static const std::string PARAM_HNSW_STREAMER_MAX_NEIGHBOR_COUNT( "proxima.hnsw.streamer.max_neighbor_count"); static const std::string PARAM_HNSW_STREAMER_L0_MAX_NEIGHBOR_COUNT_MULTIPLIER( "proxima.hnsw.streamer.l0_max_neighbor_count_multiplier"); static const std::string PARAM_HNSW_STREAMER_SCALING_FACTOR( "proxima.hnsw.streamer.scaling_factor"); static const std::string PARAM_HNSW_STREAMER_BRUTE_FORCE_THRESHOLD( "proxima.hnsw.streamer.brute_force_threshold"); static const std::string PARAM_HNSW_STREAMER_DOCS_HARD_LIMIT( "proxima.hnsw.streamer.docs_hard_limit"); static const std::string PARAM_HNSW_STREAMER_DOCS_SOFT_LIMIT( "proxima.hnsw.streamer.docs_soft_limit"); static const std::string PARAM_HNSW_STREAMER_MAX_INDEX_SIZE( "proxima.hnsw.streamer.max_index_size"); static const std::string PARAM_HNSW_STREAMER_VISIT_BLOOMFILTER_ENABLE( "proxima.hnsw.streamer.visit_bloomfilter_enable"); static const std::string PARAM_HNSW_STREAMER_VISIT_BLOOMFILTER_NEGATIVE_PROB( "proxima.hnsw.streamer.visit_bloomfilter_negative_prob"); static const std::string PARAM_HNSW_STREAMER_CHECK_CRC_ENABLE( "proxima.hnsw.streamer.check_crc_enable"); static const std::string PARAM_HNSW_STREAMER_NEIGHBOR_PRUNE_MULTIPLIER( "proxima.hnsw.streamer.neighbor_prune_multiplier"); static const std::string PARAM_HNSW_STREAMER_CHUNK_SIZE( "proxima.hnsw.streamer.chunk_size"); static const std::string PARAM_HNSW_STREAMER_FILTER_SAME_KEY( "proxima.hnsw.streamer.filter_same_key"); static const std::string PARAM_HNSW_STREAMER_GET_VECTOR_ENABLE( 
"proxima.hnsw.streamer.get_vector_enable"); static const std::string PARAM_HNSW_STREAMER_MIN_NEIGHBOR_COUNT( "proxima.hnsw.streamer.min_neighbor_count"); static const std::string PARAM_HNSW_STREAMER_FORCE_PADDING_RESULT_ENABLE( "proxima.hnsw.streamer.force_padding_result_enable"); static const std::string PARAM_HNSW_STREAMER_ESTIMATE_DOC_COUNT( "proxima.hnsw.streamer.estimate_doc_count"); static const std::string PARAM_HNSW_STREAMER_USE_ID_MAP( "proxima.hnsw.streamer.use_id_map"); static const std::string PARAM_HNSW_REDUCER_WORKING_PATH( "proxima.hnsw.reducer.working_path"); static const std::string PARAM_HNSW_REDUCER_NUM_OF_ADD_THREADS( "proxima.hnsw.reducer.num_of_add_threads"); static const std::string PARAM_HNSW_REDUCER_INDEX_NAME( "proxima.hnsw.reducer.index_name"); static const std::string PARAM_HNSW_REDUCER_EFCONSTRUCTION( "proxima.hnsw.reducer.efconstruction"); } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/hnsw/hnsw_searcher.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "hnsw_searcher.h"
#include <algorithm>
#include <iostream>
#include <string>
#include <vector>
#include "hnsw_algorithm.h"
#include "hnsw_index_provider.h"
#include "hnsw_params.h"

namespace zvec {
namespace core {

HnswSearcher::HnswSearcher() = default;

HnswSearcher::~HnswSearcher() = default;

//! Read the searcher options from `search_params` into the member fields,
//! validate them and move the searcher to STATE_INITED.
//! Returns 0 on success, IndexError_InvalidArgument when the bloom-filter
//! negative probability is outside (0, 1).
int HnswSearcher::init(const ailego::Params &search_params) {
  params_ = search_params;
  params_.get(PARAM_HNSW_SEARCHER_EF, &ef_);
  params_.get(PARAM_HNSW_SEARCHER_MAX_SCAN_RATIO, &max_scan_ratio_);
  params_.get(PARAM_HNSW_SEARCHER_VISIT_BLOOMFILTER_ENABLE, &bf_enabled_);
  params_.get(PARAM_HNSW_SEARCHER_CHECK_CRC_ENABLE, &check_crc_enabled_);
  params_.get(PARAM_HNSW_SEARCHER_NEIGHBORS_IN_MEMORY_ENABLE,
              &neighbors_in_memory_enabled_);
  params_.get(PARAM_HNSW_SEARCHER_VISIT_BLOOMFILTER_NEGATIVE_PROB,
              &bf_negative_probability_);
  params_.get(PARAM_HNSW_SEARCHER_BRUTE_FORCE_THRESHOLD,
              &bruteforce_threshold_);
  params_.get(PARAM_HNSW_SEARCHER_FORCE_PADDING_RESULT_ENABLE,
              &force_padding_topk_enabled_);
  if (ef_ == 0) {
    ef_ = HnswEntity::kDefaultEf;
  }
  if (bf_negative_probability_ <= 0.0f || bf_negative_probability_ >= 1.0f) {
    LOG_ERROR("[%s] must be in range (0,1)",
              PARAM_HNSW_SEARCHER_VISIT_BLOOMFILTER_NEGATIVE_PROB.c_str());
    return IndexError_InvalidArgument;
  }

  entity_.set_neighbors_in_memory(neighbors_in_memory_enabled_);

  state_ = STATE_INITED;
  LOG_DEBUG(
      "Init params: ef=%u maxScanRatio=%f bfEnabled=%u checkCrcEnabled=%u "
      "neighborsInMemoryEnabled=%u bfNagtiveProb=%f bruteForceThreshold=%u "
      "forcePadding=%u",
      ef_, max_scan_ratio_, bf_enabled_, check_crc_enabled_,
      neighbors_in_memory_enabled_, bf_negative_probability_,
      bruteforce_threshold_, force_padding_topk_enabled_);
  return 0;
}

//! Print the level-0 neighbor ids of every node to stdout, one line per node.
void HnswSearcher::print_debug_info() {
  for (node_id_t id = 0; id < entity_.doc_cnt(); ++id) {
    Neighbors neighbours = entity_.get_neighbors(0, id);
    std::cout << "node: " << id << "; ";
    for (uint32_t i = 0; i < neighbours.size(); ++i) {
      if (i != 0) {
        std::cout << ", ";
      }
      std::cout << neighbours[i];
    }
    // Terminate the line unconditionally: the old code only emitted the
    // newline after the last neighbor, so nodes without neighbors ran into
    // the next node's output.
    std::cout << std::endl;
  }
}

//! Release the loaded index and restore every option read by init() to its
//! class default, returning the searcher to STATE_INIT.
int HnswSearcher::cleanup() {
  LOG_INFO("Begin HnswSearcher:cleanup");

  metric_.reset();
  meta_.clear();
  stats_.clear_attributes();
  stats_.set_loaded_count(0UL);
  stats_.set_loaded_costtime(0UL);
  max_scan_ratio_ = HnswEntity::kDefaultScanRatio;
  max_scan_num_ = 0U;
  ef_ = HnswEntity::kDefaultEf;
  bf_enabled_ = false;
  bf_negative_probability_ = HnswEntity::kDefaultBFNegativeProbability;
  bruteforce_threshold_ = HnswEntity::kDefaultBruteForceThreshold;
  check_crc_enabled_ = false;
  neighbors_in_memory_enabled_ = false;
  // Also an init()-driven option; it was previously left set across cleanup.
  force_padding_topk_enabled_ = false;
  entity_.cleanup();
  state_ = STATE_INIT;

  LOG_INFO("End HnswSearcher:cleanup");
  return 0;
}

//! Load the index meta and graph from `container`. Uses `metric` when given,
//! otherwise creates the metric named in the index meta. Requires
//! STATE_INITED; on success moves to STATE_LOADED and generates a new magic
//! used to recognise contexts that belong to this searcher.
int HnswSearcher::load(IndexStorage::Pointer container,
                       IndexMetric::Pointer metric) {
  if (state_ != STATE_INITED) {
    LOG_ERROR("Init the searcher first before load index");
    return IndexError_Runtime;
  }

  LOG_INFO("Begin HnswSearcher:load");

  auto start_time = ailego::Monotime::MilliSeconds();

  int ret = IndexHelper::DeserializeFromStorage(container.get(), &meta_);
  if (ret != 0) {
    LOG_ERROR("Failed to deserialize meta from container");
    return ret;
  }

  // unload() calls entity_.cleanup(), which clears the entity's
  // neighbors-in-memory flag. Re-apply the configured value so a reload after
  // unload() honours the option exactly like the first load does.
  entity_.set_neighbors_in_memory(neighbors_in_memory_enabled_);

  ret = entity_.load(container, check_crc_enabled_);
  if (ret != 0) {
    LOG_ERROR("HnswSearcher load index failed");
    return ret;
  }

  alg_ = HnswAlgorithm::UPointer(new HnswAlgorithm(entity_));

  if (metric) {
    metric_ = metric;
  } else {
    metric_ = IndexFactory::CreateMetric(meta_.metric_name());
    if (!metric_) {
      LOG_ERROR("CreateMetric failed, name: %s", meta_.metric_name().c_str());
      return IndexError_NoExist;
    }
    ret = metric_->init(meta_, meta_.metric_params());
    if (ret != 0) {
      LOG_ERROR("IndexMetric init failed, ret=%d", ret);
      return ret;
    }
    if (metric_->query_metric()) {
      metric_ = metric_->query_metric();
    }
  }

  if (!metric_->is_matched(meta_)) {
    LOG_ERROR("IndexMetric not match index meta");
    return IndexError_Mismatch;
  }

  // Scan budget: a fraction of the document count, never below 4096.
  max_scan_num_ = static_cast<uint32_t>(max_scan_ratio_ * entity_.doc_cnt());
  max_scan_num_ = std::max(4096U, max_scan_num_);

  stats_.set_loaded_count(entity_.doc_cnt());
  stats_.set_loaded_costtime(ailego::Monotime::MilliSeconds() - start_time);
  state_ = STATE_LOADED;
  magic_ = IndexContext::GenerateMagic();

  LOG_INFO("End HnswSearcher::load");

  return 0;
}

//! Drop the loaded index but keep the init() options (back to STATE_INITED).
int HnswSearcher::unload() {
  LOG_INFO("HnswSearcher unload index");

  meta_.clear();
  entity_.cleanup();
  metric_.reset();
  max_scan_num_ = 0;
  stats_.set_loaded_count(0UL);
  stats_.set_loaded_costtime(0UL);
  state_ = STATE_INITED;

  return 0;
}

//! Rebind a context created by another searcher/streamer to this searcher:
//! gives it a fresh entity clone plus this searcher's scan limit, brute-force
//! threshold, meta, metric and magic.
int HnswSearcher::update_context(HnswContext *ctx) const {
  const HnswEntity::Pointer entity = entity_.clone();
  if (!entity) {
    LOG_ERROR("Failed to clone search context entity");
    return IndexError_Runtime;
  }
  ctx->set_max_scan_num(max_scan_num_);
  ctx->set_bruteforce_threshold(bruteforce_threshold_);

  return ctx->update_context(HnswContext::kSearcherContext, meta_, metric_,
                             entity, magic_);
}

//! Graph (KNN) search for `count` consecutive queries starting at `query`,
//! each `qmeta.element_size()` bytes apart. Falls back to a linear scan when
//! the index holds no more documents than the brute-force threshold.
int HnswSearcher::search_impl(const void *query, const IndexQueryMeta &qmeta,
                              uint32_t count,
                              Context::Pointer &context) const {
  if (ailego_unlikely(!query || !context)) {
    LOG_ERROR("The context is not created by this searcher");
    return IndexError_Mismatch;
  }
  HnswContext *ctx = dynamic_cast<HnswContext *>(context.get());
  ailego_do_if_false(ctx) {
    LOG_ERROR("Cast context to HnswContext failed");
    return IndexError_Cast;
  }
  // Rebind a foreign context BEFORE choosing the search mode: update_context()
  // installs this searcher's brute-force threshold, so the decision below is
  // no longer made with another searcher's/streamer's threshold.
  if (ctx->magic() != magic_) {
    //! context is created by another searcher or streamer
    int ret = update_context(ctx);
    if (ret != 0) {
      return ret;
    }
  }

  if (entity_.doc_cnt() <= ctx->get_bruteforce_threshold()) {
    return search_bf_impl(query, qmeta, count, context);
  }

  ctx->clear();
  ctx->resize_results(count);
  for (size_t q = 0; q < count; ++q) {
    ctx->reset_query(query);
    int ret = alg_->search(ctx);
    if (ailego_unlikely(ret != 0)) {
      LOG_ERROR("Hnsw searcher fast search failed");
      return ret;
    }
    ctx->topk_to_result(q);
    query = static_cast<const char *>(query) + qmeta.element_size();
  }

  if (ailego_unlikely(ctx->error())) {
    return IndexError_Runtime;
  }

  return 0;
}

//! Linear search for `count` consecutive queries: scores every node whose key
//! is valid and not rejected by the context filter, collecting results in
//! the top-k heap (or in per-group heaps for group-by searches).
int HnswSearcher::search_bf_impl(const void *query, const IndexQueryMeta &qmeta,
                                 uint32_t count,
                                 Context::Pointer &context) const {
  if (ailego_unlikely(!query || !context)) {
    LOG_ERROR("The context is not created by this searcher");
    return IndexError_Mismatch;
  }
  HnswContext *ctx = dynamic_cast<HnswContext *>(context.get());
  ailego_do_if_false(ctx) {
    LOG_ERROR("Cast context to HnswContext failed");
    return IndexError_Cast;
  }
  if (ctx->magic() != magic_) {
    //! context is created by another searcher or streamer
    int ret = update_context(ctx);
    if (ret != 0) {
      return ret;
    }
  }

  ctx->clear();
  ctx->resize_results(count);

  if (ctx->group_by_search()) {
    if (!ctx->group_by().is_valid()) {
      LOG_ERROR("Invalid group-by function");
      return IndexError_InvalidArgument;
    }
    for (size_t q = 0; q < count; ++q) {
      ctx->reset_query(query);
      ctx->group_topk_heaps().clear();
      for (node_id_t id = 0; id < entity_.doc_cnt(); ++id) {
        // Read the key once: it feeds the validity check, the filter and the
        // group-by function (previously three separate segment reads).
        const key_t key = entity_.get_key(id);
        if (key == kInvalidKey) {
          continue;
        }
        if (!ctx->filter().is_valid() || !ctx->filter()(key)) {
          dist_t dist = ctx->dist_calculator().batch_dist(id);
          std::string group_id = ctx->group_by()(key);

          auto &topk_heap = ctx->group_topk_heaps()[group_id];
          if (topk_heap.empty()) {
            topk_heap.limit(ctx->group_topk());
          }
          topk_heap.emplace_back(id, dist);
        }
      }
      ctx->topk_to_result(q);
      query = static_cast<const char *>(query) + qmeta.element_size();
    }
  } else {
    for (size_t q = 0; q < count; ++q) {
      ctx->reset_query(query);
      ctx->topk_heap().clear();
      for (node_id_t id = 0; id < entity_.doc_cnt(); ++id) {
        const key_t key = entity_.get_key(id);
        if (key == kInvalidKey) {
          continue;
        }
        if (!ctx->filter().is_valid() || !ctx->filter()(key)) {
          dist_t dist = ctx->dist_calculator().batch_dist(id);
          ctx->topk_heap().emplace(id, dist);
        }
      }
      ctx->topk_to_result(q);
      query = static_cast<const char *>(query) + qmeta.element_size();
    }
  }

  if (ailego_unlikely(ctx->error())) {
    return IndexError_Runtime;
  }

  return 0;
}

//! Linear search restricted to caller-supplied primary keys: query `q` is
//! scored only against the nodes whose keys appear in `p_keys[q]` (keys not
//! found in the index or rejected by the filter are skipped).
//! `p_keys.size()` must equal `count`.
int HnswSearcher::search_bf_by_p_keys_impl(
    const void *query, const std::vector<std::vector<uint64_t>> &p_keys,
    const IndexQueryMeta &qmeta, uint32_t count,
    Context::Pointer &context) const {
  if (ailego_unlikely(!query || !context)) {
    LOG_ERROR("The context is not created by this searcher");
    return IndexError_Mismatch;
  }

  if (ailego_unlikely(p_keys.size() != count)) {
    LOG_ERROR("The size of p_keys is not equal to count");
    return IndexError_InvalidArgument;
  }

  HnswContext *ctx = dynamic_cast<HnswContext *>(context.get());
  ailego_do_if_false(ctx) {
    LOG_ERROR("Cast context to HnswContext failed");
    return IndexError_Cast;
  }
  if (ctx->magic() != magic_) {
    //! context is created by another searcher or streamer
    int ret = update_context(ctx);
    if (ret != 0) {
      return ret;
    }
  }

  ctx->clear();
  ctx->resize_results(count);

  if (ctx->group_by_search()) {
    if (!ctx->group_by().is_valid()) {
      LOG_ERROR("Invalid group-by function");
      return IndexError_InvalidArgument;
    }
    for (size_t q = 0; q < count; ++q) {
      ctx->reset_query(query);
      ctx->group_topk_heaps().clear();
      for (size_t idx = 0; idx < p_keys[q].size(); ++idx) {
        uint64_t pk = p_keys[q][idx];
        if (!ctx->filter().is_valid() || !ctx->filter()(pk)) {
          node_id_t id = entity_.get_id(pk);
          if (id != kInvalidNodeId) {
            dist_t dist = ctx->dist_calculator().batch_dist(id);
            std::string group_id = ctx->group_by()(entity_.get_key(id));

            auto &topk_heap = ctx->group_topk_heaps()[group_id];
            if (topk_heap.empty()) {
              topk_heap.limit(ctx->group_topk());
            }
            topk_heap.emplace_back(id, dist);
          }
        }
      }
      ctx->topk_to_result(q);
      query = static_cast<const char *>(query) + qmeta.element_size();
    }
  } else {
    for (size_t q = 0; q < count; ++q) {
      ctx->reset_query(query);
      ctx->topk_heap().clear();
      for (size_t idx = 0; idx < p_keys[q].size(); ++idx) {
        uint64_t pk = p_keys[q][idx];
        if (!ctx->filter().is_valid() || !ctx->filter()(pk)) {
          node_id_t id = entity_.get_id(pk);
          if (id != kInvalidNodeId) {
            dist_t dist = ctx->dist_calculator().batch_dist(id);
            ctx->topk_heap().emplace(id, dist);
          }
        }
      }
      ctx->topk_to_result(q);
      query = static_cast<const char *>(query) + qmeta.element_size();
    }
  }

  if (ailego_unlikely(ctx->error())) {
    return IndexError_Runtime;
  }

  return 0;
}

//! Create a search context bound to this searcher (requires STATE_LOADED).
//! Returns an empty pointer on failure.
IndexSearcher::Context::Pointer HnswSearcher::create_context() const {
  if (ailego_unlikely(state_ != STATE_LOADED)) {
    LOG_ERROR("Load the index first before create context");
    return Context::Pointer();
  }
  const HnswEntity::Pointer search_ctx_entity = entity_.clone();
  if (!search_ctx_entity) {
    LOG_ERROR("Failed to create search context entity");
    return Context::Pointer();
  }
  HnswContext *ctx = new (std::nothrow)
      HnswContext(meta_.dimension(), metric_, search_ctx_entity);
  if (ailego_unlikely(ctx == nullptr)) {
    LOG_ERROR("Failed to new HnswContext");
    return Context::Pointer();
  }
  ctx->set_ef(ef_);
  ctx->set_max_scan_num(max_scan_num_);
  uint32_t filter_mode =
      bf_enabled_ ? VisitFilter::BloomFilter : VisitFilter::ByteMap;
  ctx->set_filter_mode(filter_mode);
  ctx->set_filter_negative_probability(bf_negative_probability_);
  ctx->set_magic(magic_);
  ctx->set_force_padding_topk(force_padding_topk_enabled_);
  ctx->set_bruteforce_threshold(bruteforce_threshold_);
  // The comparison belongs inside the branch hint; previously the hint
  // wrapped only the call and `!= 0` was applied to the hint's result.
  if (ailego_unlikely(ctx->init(HnswContext::kSearcherContext) != 0)) {
    LOG_ERROR("Init HnswContext failed");
    delete ctx;
    return Context::Pointer();
  }

  return Context::Pointer(ctx);
}

//! Create an index provider over a clone of the loaded entity.
IndexProvider::Pointer HnswSearcher::create_provider(void) const {
  LOG_DEBUG("HnswSearcher create provider");

  auto entity = entity_.clone();
  if (ailego_unlikely(!entity)) {
    LOG_ERROR("Clone HnswEntity failed");
    return Provider::Pointer();
  }
  return Provider::Pointer(
      new (std::nothrow) HnswIndexProvider(meta_, entity, "HnswSearcher"));
}

//! Return the stored vector for primary key `key`, or nullptr if not found.
const void *HnswSearcher::get_vector(uint64_t key) const {
  return entity_.get_vector_by_key(key);
}

INDEX_FACTORY_REGISTER_SEARCHER(HnswSearcher);

}  // namespace core
}  // namespace zvec
================================================
FILE: src/core/algorithm/hnsw/hnsw_searcher.h
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once #include #include "hnsw_searcher_entity.h" #include "hnsw_streamer.h" namespace zvec { namespace core { class HnswSearcher : public IndexSearcher { public: using ContextPointer = IndexSearcher::Context::Pointer; public: HnswSearcher(void); ~HnswSearcher(void); HnswSearcher(const HnswSearcher &) = delete; HnswSearcher &operator=(const HnswSearcher &) = delete; protected: //! Initialize Searcher virtual int init(const ailego::Params ¶ms) override; //! Cleanup Searcher virtual int cleanup(void) override; //! Load Index from storage virtual int load(IndexStorage::Pointer container, IndexMetric::Pointer metric) override; //! Unload index from storage virtual int unload(void) override; //! KNN Search virtual int search_impl(const void *query, const IndexQueryMeta &qmeta, ContextPointer &context) const override { return search_impl(query, qmeta, 1, context); } //! KNN Search virtual int search_impl(const void *query, const IndexQueryMeta &qmeta, uint32_t count, ContextPointer &context) const override; //! Linear Search virtual int search_bf_impl(const void *query, const IndexQueryMeta &qmeta, ContextPointer &context) const override { return search_bf_impl(query, qmeta, 1, context); } //! Linear Search virtual int search_bf_impl(const void *query, const IndexQueryMeta &qmeta, uint32_t count, ContextPointer &context) const override; //! Linear search by primary keys virtual int search_bf_by_p_keys_impl( const void *query, const std::vector> &p_keys, const IndexQueryMeta &qmeta, ContextPointer &context) const override { return search_bf_by_p_keys_impl(query, p_keys, qmeta, 1, context); } //! Linear search by primary keys virtual int search_bf_by_p_keys_impl( const void *query, const std::vector> &p_keys, const IndexQueryMeta &qmeta, uint32_t count, ContextPointer &context) const override; //! Fetch vector by key virtual const void *get_vector(uint64_t key) const override; //! Create a searcher context virtual ContextPointer create_context() const override; //! 
Create a new iterator virtual IndexProvider::Pointer create_provider(void) const override; //! Retrieve statistics virtual const Stats &stats(void) const override { return stats_; } //! Retrieve meta of index virtual const IndexMeta &meta(void) const override { return meta_; } //! Retrieve params of index virtual const ailego::Params ¶ms(void) const override { return params_; } virtual void print_debug_info() override; private: //! To share ctx across streamer/searcher, we need to update the context for //! current streamer/searcher int update_context(HnswContext *ctx) const; private: enum State { STATE_INIT = 0, STATE_INITED = 1, STATE_LOADED = 2 }; HnswSearcherEntity entity_{}; HnswAlgorithm::UPointer alg_; // impl graph algorithm IndexMetric::Pointer metric_{}; IndexMeta meta_{}; ailego::Params params_{}; Stats stats_; uint32_t ef_{HnswEntity::kDefaultEf}; uint32_t max_scan_num_{0U}; uint32_t bruteforce_threshold_{HnswEntity::kDefaultBruteForceThreshold}; float max_scan_ratio_{HnswEntity::kDefaultScanRatio}; bool bf_enabled_{false}; bool check_crc_enabled_{false}; bool neighbors_in_memory_enabled_{false}; bool force_padding_topk_enabled_{false}; float bf_negative_probability_{HnswEntity::kDefaultBFNegativeProbability}; uint32_t magic_{0U}; State state_{STATE_INIT}; }; } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/hnsw/hnsw_searcher_entity.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and // limitations under the License. #include "hnsw_searcher_entity.h" #include #include "utility/sparse_utility.h" namespace zvec { namespace core { HnswSearcherEntity::HnswSearcherEntity() {} int HnswSearcherEntity::cleanup(void) { storage_.reset(); vectors_.reset(); keys_.reset(); neighbors_.reset(); neighbors_meta_.reset(); neighbors_in_memory_enabled_ = false; loaded_ = false; this->HnswEntity::cleanup(); return 0; } key_t HnswSearcherEntity::get_key(node_id_t id) const { const void *key; if (ailego_unlikely(keys_->read(id * sizeof(key_t), &key, sizeof(key_t)) != sizeof(key_t))) { LOG_ERROR("Read key from segment failed"); return kInvalidKey; } return *(reinterpret_cast(key)); } //! Get vector local id by key node_id_t HnswSearcherEntity::get_id(key_t key) const { if (ailego_unlikely(!mapping_)) { LOG_ERROR("Index missing mapping segment"); return kInvalidNodeId; } //! Do binary search node_id_t start = 0UL; node_id_t end = doc_cnt(); const void *data; node_id_t idx = 0u; while (start < end) { idx = start + (end - start) / 2; if (ailego_unlikely( mapping_->read(idx * sizeof(node_id_t), &data, sizeof(node_id_t)) != sizeof(node_id_t))) { LOG_ERROR("Read key from segment failed"); return kInvalidNodeId; } const key_t *mkey; node_id_t local_id = *reinterpret_cast(data); if (ailego_unlikely(keys_->read(local_id * sizeof(key_t), (const void **)(&mkey), sizeof(key_t)) != sizeof(key_t))) { LOG_ERROR("Read key from segment failed"); return kInvalidNodeId; } if (*mkey < key) { start = idx + 1; } else if (*mkey > key) { end = idx; } else { return local_id; } } return kInvalidNodeId; } const void *HnswSearcherEntity::get_vector_by_key(key_t key) const { node_id_t local_id = get_id(key); if (ailego_unlikely(local_id == kInvalidNodeId)) { return nullptr; } return get_vector(local_id); } const void *HnswSearcherEntity::get_vector(node_id_t id) const { size_t read_size = vector_size(); size_t offset = 
node_size() * id; const void *vec; if (ailego_unlikely(vectors_->read(offset, &vec, read_size) != read_size)) { LOG_ERROR("Read vector from segment failed"); return nullptr; } return vec; } int HnswSearcherEntity::get_vector(const node_id_t id, IndexStorage::MemoryBlock &block) const { const void *vec = get_vector(id); block.reset((void *)vec); return 0; } const void *HnswSearcherEntity::get_vectors() const { const void *vec; size_t len = node_size() * doc_cnt(); if (vectors_->read(0, &vec, len) != len) { LOG_ERROR("Read vectors from segment failed"); return nullptr; } return vec; } int HnswSearcherEntity::get_vector(const node_id_t *ids, uint32_t count, const void **vecs) const { ailego_assert_with(count <= segment_datas_.size(), "invalid count"); size_t read_size = vector_size(); for (uint32_t i = 0; i < count; ++i) { segment_datas_[i].offset = node_size() * ids[i]; segment_datas_[i].length = read_size; ailego_assert_with(segment_datas_[i].offset < vectors_->data_size(), "invalid offset"); } if (ailego_unlikely(!vectors_->read(&segment_datas_[0], count))) { LOG_ERROR("Read vectors from segment failed"); return IndexError_ReadData; } for (uint32_t i = 0; i < count; ++i) { vecs[i] = segment_datas_[i].data; } return 0; } int HnswSearcherEntity::get_vector( const node_id_t *ids, uint32_t count, std::vector &vec_blocks) const { const void *vecs[count]; get_vector(ids, count, vecs); for (uint32_t i = 0; i < count; ++i) { vec_blocks.emplace_back(IndexStorage::MemoryBlock((void *)vecs[i])); } return 0; } const Neighbors HnswSearcherEntity::get_neighbors(level_t level, node_id_t id) const { if (level == 0) { if (neighbors_in_memory_enabled_) { auto hd = reinterpret_cast( fixed_neighbors_.get() + neighbors_size() * id); return {hd->neighbor_cnt, hd->neighbors}; } const GraphNeighborMeta *m; if (ailego_unlikely(neighbors_meta_->read(id * sizeof(GraphNeighborMeta), (const void **)(&m), sizeof(GraphNeighborMeta)) != sizeof(GraphNeighborMeta))) { LOG_ERROR("Read neighbors meta 
from segment failed"); return {0, nullptr}; } const void *data; if (ailego_unlikely(neighbors_->read(m->offset, &data, m->neighbor_cnt * sizeof(node_id_t)) != m->neighbor_cnt * sizeof(node_id_t))) { LOG_ERROR("Read neighbors from segment failed"); return {0, nullptr}; } return {static_cast(m->neighbor_cnt), reinterpret_cast(data)}; } //! Read level > 0 neighbors const HnswNeighborMeta *m; if (ailego_unlikely(upper_neighbors_meta_->read(id * sizeof(HnswNeighborMeta), (const void **)(&m), sizeof(HnswNeighborMeta)) != sizeof(HnswNeighborMeta))) { LOG_ERROR("Read neighbors meta from segment failed"); return {0, nullptr}; } ailego_assert_with(level <= m->level, "invalid level"); size_t offset = m->offset + (level - 1) * upper_neighbors_size(); ailego_assert_with(offset <= upper_neighbors_->data_size(), "invalid offset"); const void *data; if (ailego_unlikely( upper_neighbors_->read(offset, &data, upper_neighbors_size()) != upper_neighbors_size())) { LOG_ERROR("Read neighbors from segment failed"); return {0, nullptr}; } auto hd = reinterpret_cast(data); return {hd->neighbor_cnt, hd->neighbors}; } int HnswSearcherEntity::load(const IndexStorage::Pointer &container, bool check_crc) { storage_ = container; int ret = load_segments(check_crc); if (ret != 0) { return ret; } loaded_ = true; LOG_INFO( "Index info: docCnt=%u entryPoint=%u maxLevel=%d efConstruct=%zu " "l0NeighborCnt=%zu upperNeighborCnt=%zu scalingFactor=%zu " "vectorSize=%zu nodeSize=%zu vectorSegmentSize=%zu keySegmentSize=%zu " "neighborsSegmentSize=%zu neighborsMetaSegmentSize=%zu ", doc_cnt(), entry_point(), cur_max_level(), ef_construction(), l0_neighbor_cnt(), upper_neighbor_cnt(), scaling_factor(), vector_size(), node_size(), vectors_->data_size(), keys_->data_size(), neighbors_ == nullptr ? 0 : neighbors_->data_size(), neighbors_meta_ == nullptr ? 0 : neighbors_meta_->data_size()); return 0; } int HnswSearcherEntity::load_segments(bool check_crc) { //! 
load header const void *data = nullptr; HNSWHeader hd; auto graph_hd_segment = storage_->get(kGraphHeaderSegmentId); if (!graph_hd_segment || graph_hd_segment->data_size() < sizeof(hd.graph)) { LOG_ERROR("Miss or invalid segment %s", kGraphHeaderSegmentId.c_str()); return IndexError_InvalidFormat; } if (graph_hd_segment->read(0, reinterpret_cast(&data), sizeof(hd.graph)) != sizeof(hd.graph)) { LOG_ERROR("Read segment %s failed", kGraphHeaderSegmentId.c_str()); return IndexError_ReadData; } memcpy(&hd.graph, data, sizeof(hd.graph)); auto hnsw_hd_segment = storage_->get(kHnswHeaderSegmentId); if (!hnsw_hd_segment || hnsw_hd_segment->data_size() < sizeof(hd.hnsw)) { LOG_ERROR("Miss or invalid segment %s", kHnswHeaderSegmentId.c_str()); return IndexError_InvalidFormat; } if (hnsw_hd_segment->read(0, reinterpret_cast(&data), sizeof(hd.hnsw)) != sizeof(hd.hnsw)) { LOG_ERROR("Read segment %s failed", kHnswHeaderSegmentId.c_str()); return IndexError_ReadData; } memcpy(&hd.hnsw, data, sizeof(hd.hnsw)); *mutable_header() = hd; segment_datas_.resize(std::max(l0_neighbor_cnt(), upper_neighbor_cnt())); vectors_ = storage_->get(kGraphFeaturesSegmentId); if (!vectors_) { LOG_ERROR("IndexStorage get segment %s failed", kGraphFeaturesSegmentId.c_str()); return IndexError_InvalidFormat; } keys_ = storage_->get(kGraphKeysSegmentId); if (!keys_) { LOG_ERROR("IndexStorage get segment %s failed", kGraphKeysSegmentId.c_str()); return IndexError_InvalidFormat; } neighbors_ = storage_->get(kGraphNeighborsSegmentId); if (!neighbors_ || (neighbors_->data_size() == 0 && doc_cnt() > 1)) { LOG_ERROR("IndexStorage get segment %s failed or empty", kGraphNeighborsSegmentId.c_str()); return IndexError_InvalidArgument; } neighbors_meta_ = storage_->get(kGraphOffsetsSegmentId); if (!neighbors_meta_ || neighbors_meta_->data_size() < sizeof(GraphNeighborMeta) * doc_cnt()) { LOG_ERROR("IndexStorage get segment %s failed or invalid size", kGraphOffsetsSegmentId.c_str()); return 
IndexError_InvalidArgument; } upper_neighbors_ = storage_->get(kHnswNeighborsSegmentId); if (!upper_neighbors_ || (upper_neighbors_->data_size() == 0 && cur_max_level() > 0)) { LOG_ERROR("IndexStorage get segment %s failed or empty", kHnswNeighborsSegmentId.c_str()); return IndexError_InvalidArgument; } upper_neighbors_meta_ = storage_->get(kHnswOffsetsSegmentId); if (!upper_neighbors_meta_ || upper_neighbors_meta_->data_size() < sizeof(HnswNeighborMeta) * doc_cnt()) { LOG_ERROR("IndexStorage get segment %s failed or invalid size", kHnswOffsetsSegmentId.c_str()); return IndexError_InvalidArgument; } mapping_ = storage_->get(kGraphMappingSegmentId); if (!mapping_ || mapping_->data_size() < sizeof(node_id_t) * doc_cnt()) { LOG_ERROR("IndexStorage get segment %s failed or invalid size", kGraphMappingSegmentId.c_str()); return IndexError_InvalidArgument; } if (check_crc) { std::vector segments; segments.emplace_back(graph_hd_segment); segments.emplace_back(hnsw_hd_segment); segments.emplace_back(vectors_); segments.emplace_back(keys_); segments.emplace_back(neighbors_); segments.emplace_back(neighbors_meta_); segments.emplace_back(upper_neighbors_); segments.emplace_back(upper_neighbors_meta_); if (!do_crc_check(segments)) { LOG_ERROR("Check index crc failed, the index may broken"); return IndexError_Runtime; } } if (neighbors_in_memory_enabled_) { int ret = load_and_flat_neighbors(); if (ret != 0) { return ret; } } return 0; } int HnswSearcherEntity::load_and_flat_neighbors() { fixed_neighbors_.reset( new (std::nothrow) char[neighbors_size() * doc_cnt()]{}, std::default_delete()); if (!fixed_neighbors_) { LOG_ERROR("Malloc memory failed"); return IndexError_NoMemory; } //! 
Get a new segemnt to release the buffer after loading neighbors auto neighbors_meta = storage_->get(kGraphOffsetsSegmentId); if (!neighbors_meta) { LOG_ERROR("IndexStorage get segment graph.offsets failed"); return IndexError_InvalidArgument; } const GraphNeighborMeta *neighbors_index = nullptr; if (neighbors_meta->read(0, reinterpret_cast(&neighbors_index), neighbors_meta->data_size()) != neighbors_meta->data_size()) { LOG_ERROR("Read segment %s data failed", kGraphOffsetsSegmentId.c_str()); return IndexError_InvalidArgument; } const char *neighbor_data; for (node_id_t id = 0; id < doc_cnt(); ++id) { size_t rd_size = neighbors_index[id].neighbor_cnt * sizeof(node_id_t); if (ailego_unlikely( neighbors_->read(neighbors_index[id].offset, reinterpret_cast(&neighbor_data), rd_size) != rd_size)) { LOG_ERROR("Read neighbors from segment failed"); return IndexError_ReadData; } // copy level 0 neighbors to fixed size neighbors memory char *dst = fixed_neighbors_.get() + neighbors_size() * id; *reinterpret_cast(dst) = neighbors_index[id].neighbor_cnt; memcpy(dst + sizeof(uint32_t), neighbor_data, rd_size); } return 0; } int HnswSearcherEntity::get_fixed_neighbors( std::vector *fixed_neighbors) const { //! 
Get a new segemnt to release the buffer after loading neighbors auto neighbors_meta = storage_->get(kGraphOffsetsSegmentId); if (!neighbors_meta) { LOG_ERROR("IndexStorage get segment graph.offsets failed"); return IndexError_InvalidArgument; } const GraphNeighborMeta *neighbors_index = nullptr; size_t meta_size = neighbors_meta->data_size(); if (neighbors_meta->read(0, reinterpret_cast(&neighbors_index), meta_size) != meta_size) { LOG_ERROR("Read segment %s data failed", kGraphOffsetsSegmentId.c_str()); return IndexError_InvalidArgument; } size_t fixed_neighbor_cnt = l0_neighbor_cnt(); fixed_neighbors->resize((fixed_neighbor_cnt + 1) * doc_cnt(), kInvalidNodeId); size_t neighbors_cnt_offset = fixed_neighbor_cnt * doc_cnt(); size_t total_neighbor_cnt = 0; for (node_id_t id = 0; id < doc_cnt(); ++id) { size_t cur_neighbor_cnt = neighbors_index[id].neighbor_cnt; if (cur_neighbor_cnt == 0) { (*fixed_neighbors)[neighbors_cnt_offset + id] = 0; continue; } size_t rd_size = cur_neighbor_cnt * sizeof(node_id_t); const uint32_t *neighbors; if (neighbors_->read(neighbors_index[id].offset, reinterpret_cast(&neighbors), rd_size) != rd_size) { LOG_ERROR("Read neighbors from segment failed"); return IndexError_ReadData; } // copy level 0 neighbors to fixed size neighbors memory auto it = fixed_neighbors->begin() + id * fixed_neighbor_cnt; std::copy(neighbors, neighbors + cur_neighbor_cnt, it); (*fixed_neighbors)[neighbors_cnt_offset + id] = cur_neighbor_cnt; total_neighbor_cnt += cur_neighbor_cnt; } LOG_INFO("total neighbor cnt: %zu, average neighbor cnt: %zu", total_neighbor_cnt, total_neighbor_cnt / doc_cnt()); return 0; } bool HnswSearcherEntity::do_crc_check( std::vector &segments) const { constexpr size_t blk_size = 4096; const void *data; for (auto &segment : segments) { size_t offset = 0; size_t rd_size; uint32_t crc = 0; while (offset < segment->data_size()) { size_t size = std::min(blk_size, segment->data_size() - offset); if ((rd_size = segment->read(offset, &data, 
size)) <= 0) { break; } offset += rd_size; crc = ailego::Crc32c::Hash(data, rd_size, crc); } if (crc != segment->data_crc()) { return false; } } return true; } const HnswEntity::Pointer HnswSearcherEntity::clone() const { auto vectors = vectors_->clone(); if (ailego_unlikely(!vectors)) { LOG_ERROR("clone segment %s failed", kGraphFeaturesSegmentId.c_str()); return HnswEntity::Pointer(); } auto keys = keys_->clone(); if (ailego_unlikely(!keys)) { LOG_ERROR("clone segment %s failed", kGraphKeysSegmentId.c_str()); return HnswEntity::Pointer(); } auto mapping = mapping_->clone(); if (ailego_unlikely(!mapping)) { LOG_ERROR("clone segment %s failed", kGraphMappingSegmentId.c_str()); return HnswEntity::Pointer(); } auto neighbors = neighbors_->clone(); if (ailego_unlikely(!neighbors)) { LOG_ERROR("clone segment %s failed", kGraphNeighborsSegmentId.c_str()); return HnswEntity::Pointer(); } auto upper_neighbors = upper_neighbors_->clone(); if (ailego_unlikely(!neighbors)) { LOG_ERROR("clone segment %s failed", kHnswNeighborsSegmentId.c_str()); return HnswEntity::Pointer(); } auto neighbors_meta = neighbors_meta_->clone(); if (ailego_unlikely(!neighbors_meta)) { LOG_ERROR("clone segment %s failed", kGraphOffsetsSegmentId.c_str()); return HnswEntity::Pointer(); } auto upper_neighbors_meta = upper_neighbors_meta_->clone(); if (ailego_unlikely(!upper_neighbors_meta)) { LOG_ERROR("clone segment %s failed", kHnswOffsetsSegmentId.c_str()); return HnswEntity::Pointer(); } SegmentGroupParam neighbor_group{neighbors, neighbors_meta, upper_neighbors, upper_neighbors_meta}; HnswSearcherEntity *entity = new (std::nothrow) HnswSearcherEntity(header(), vectors, keys, mapping, neighbor_group, fixed_neighbors_, neighbors_in_memory_enabled_); if (ailego_unlikely(!entity)) { LOG_ERROR("HnswSearcherEntity new failed"); } return HnswEntity::Pointer(entity); } } // namespace core } // namespace zvec ================================================ FILE: 
src/core/algorithm/hnsw/hnsw_searcher_entity.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include "hnsw_builder_entity.h" #include "hnsw_entity.h" namespace zvec { namespace core { //! HNSW graph entity used for searching; its vectors, keys, id mapping and neighbor lists are held as IndexStorage segments (presumably opened by load()/load_segments()). class HnswSearcherEntity : public HnswEntity { public: using Pointer = std::shared_ptr; using SegmentPointer = IndexStorage::Segment::Pointer; public: //! Bundle of the level-0 (neighbors, neighbors_meta) and upper-level (upper_neighbors, upper_neighbors_meta) neighbor segments passed to the private constructor. struct SegmentGroupParam { SegmentGroupParam(SegmentPointer neighbors_in, SegmentPointer neighbors_meta_in, SegmentPointer upper_neighbors_in, SegmentPointer upper_neighbors_meta_in) : neighbors{neighbors_in}, neighbors_meta{neighbors_meta_in}, upper_neighbors{upper_neighbors_in}, upper_neighbors_meta{upper_neighbors_meta_in} {} SegmentPointer neighbors{nullptr}; SegmentPointer neighbors_meta{nullptr}; SegmentPointer upper_neighbors{nullptr}; SegmentPointer upper_neighbors_meta{nullptr}; }; //! Constructor HnswSearcherEntity(); //! Make a copy of the searcher entity (every storage segment is cloned), to support thread-safe operation. //! The segments in the container cannot be read concurrently. virtual const HnswEntity::Pointer clone() const override; //! Get primary key of the node id virtual key_t get_key(node_id_t id) const override; //! Get vector local id by key node_id_t get_id(key_t key) const; //! Get vector feature data by key virtual const void *get_vector_by_key(key_t key) const override; //! 
Get vector feature data by id virtual const void *get_vector(node_id_t id) const override; //! Get vector feature data for a batch of `count` ids (presumably one pointer per id in vecs) virtual int get_vector(const node_id_t *ids, uint32_t count, const void **vecs) const override; //! Get vector feature data by id into a MemoryBlock virtual int get_vector(const node_id_t id, IndexStorage::MemoryBlock &block) const override; //! Get vector feature data for a batch of `count` ids into MemoryBlocks virtual int get_vector( const node_id_t *ids, uint32_t count, std::vector &vec_blocks) const override; //! Get all vectors const void *get_vectors() const; //! Get the node id's neighbors on graph level virtual const Neighbors get_neighbors(level_t level, node_id_t id) const override; //! Load the index from the storage container (check_crc presumably enables do_crc_check()) virtual int load(const IndexStorage::Pointer &container, bool check_crc) override; int load_segments(bool check_crc); virtual int cleanup(void) override; public: bool is_loaded() const { return loaded_; } void set_neighbors_in_memory(bool enabled) { neighbors_in_memory_enabled_ = enabled; } //! get fixed length neighbors data int get_fixed_neighbors(std::vector *fixed_neighbors) const; private: //! 
Constructor (used by clone()): adopts the given, already cloned segments and shares fixed_neighbors. HnswSearcherEntity(const HNSWHeader &hd, const SegmentPointer &vectors, const SegmentPointer &keys, const SegmentPointer &mapping, const SegmentGroupParam &neighbor_group, const std::shared_ptr &fixed_neighbors, bool neighbors_in_memory_enabled) : HnswEntity(hd), vectors_(vectors), keys_(keys), mapping_(mapping), neighbors_(neighbor_group.neighbors), neighbors_meta_(neighbor_group.neighbors_meta), upper_neighbors_(neighbor_group.upper_neighbors), upper_neighbors_meta_(neighbor_group.upper_neighbors_meta), neighbors_in_memory_enabled_(neighbors_in_memory_enabled) { segment_datas_.resize(std::max(l0_neighbor_cnt(), upper_neighbor_cnt()), IndexStorage::SegmentData(0U, 0U)); fixed_neighbors_ = fixed_neighbors; } //! Verify the CRC of each given segment; returns false on the first mismatch. bool do_crc_check(std::vector &segments) const; //! Byte size of a level-0 neighbor record: NeighborsHeader plus l0_neighbor_cnt() node ids. inline size_t neighbors_size() const { return sizeof(NeighborsHeader) + l0_neighbor_cnt() * sizeof(node_id_t); } //! Byte size of an upper-level neighbor record: NeighborsHeader plus upper_neighbor_cnt() node ids. inline size_t upper_neighbors_size() const { return sizeof(NeighborsHeader) + upper_neighbor_cnt() * sizeof(node_id_t); } //! 
If neighbors_in_memory_enabled, load the level0 neighbors to memory int load_and_flat_neighbors(void); public: HnswSearcherEntity(const HnswSearcherEntity &) = delete; HnswSearcherEntity &operator=(const HnswSearcherEntity &) = delete; private: IndexStorage::Pointer storage_{}; SegmentPointer vectors_{}; SegmentPointer keys_{}; SegmentPointer mapping_{}; SegmentPointer neighbors_{}; SegmentPointer neighbors_meta_{}; SegmentPointer upper_neighbors_{}; SegmentPointer upper_neighbors_meta_{}; //! Sized to max(l0_neighbor_cnt(), upper_neighbor_cnt()) by the private constructor; mutable, presumably as scratch space for const read paths. mutable std::vector segment_datas_{}; std::shared_ptr fixed_neighbors_{}; // level 0 fixed size neighbors bool neighbors_in_memory_enabled_{false}; bool loaded_{false}; }; } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/hnsw/hnsw_streamer.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "hnsw_streamer.h" #include #include #include #include #include "utility/sparse_utility.h" #include "hnsw_algorithm.h" #include "hnsw_context.h" #include "hnsw_dist_calculator.h" #include "hnsw_index_provider.h" namespace zvec { namespace core { HnswStreamer::HnswStreamer() : entity_(stats_) {} /* NOTE(review): cleanup() runs only when state_ == STATE_INITED; a streamer destroyed while STATE_OPENED is not closed here (cleanup() itself would close it) -- confirm callers always close() first. */ HnswStreamer::~HnswStreamer() { if (state_ == STATE_INITED) { this->cleanup(); } } /* Configure the streamer: copy imeta, read the HNSW tunables from params, substitute defaults for zero values, validate ranges, push the settings into entity_ and create/initialize the HnswAlgorithm. Returns 0 on success (state_ becomes STATE_INITED) or an IndexError code. params.get() presumably leaves a member untouched when its key is absent, which is why cleanup() must reset every tunable. */ int HnswStreamer::init(const IndexMeta &imeta, const ailego::Params &params) { meta_ = imeta; meta_.set_streamer("HnswStreamer", HnswEntity::kRevision, params); params.get(PARAM_HNSW_STREAMER_MAX_INDEX_SIZE, &max_index_size_); params.get(PARAM_HNSW_STREAMER_MAX_NEIGHBOR_COUNT, &upper_max_neighbor_cnt_); float multiplier = HnswEntity::kDefaultL0MaxNeighborCntMultiplier; params.get(PARAM_HNSW_STREAMER_L0_MAX_NEIGHBOR_COUNT_MULTIPLIER, &multiplier); l0_max_neighbor_cnt_ = multiplier * upper_max_neighbor_cnt_; multiplier = HnswEntity::kDefaultNeighborPruneMultiplier; params.get(PARAM_HNSW_STREAMER_NEIGHBOR_PRUNE_MULTIPLIER, &multiplier); size_t prune_cnt = multiplier * upper_max_neighbor_cnt_; scaling_factor_ = upper_max_neighbor_cnt_; params.get(PARAM_HNSW_STREAMER_SCALING_FACTOR, &scaling_factor_); params.get(PARAM_HNSW_STREAMER_DOCS_HARD_LIMIT, &docs_hard_limit_); params.get(PARAM_HNSW_STREAMER_EF, &ef_); params.get(PARAM_HNSW_STREAMER_EFCONSTRUCTION, &ef_construction_); params.get(PARAM_HNSW_STREAMER_VISIT_BLOOMFILTER_ENABLE, &bf_enabled_); params.get(PARAM_HNSW_STREAMER_VISIT_BLOOMFILTER_NEGATIVE_PROB, &bf_negative_prob_); params.get(PARAM_HNSW_STREAMER_BRUTE_FORCE_THRESHOLD, &bruteforce_threshold_); params.get(PARAM_HNSW_STREAMER_MAX_SCAN_RATIO, &max_scan_ratio_); params.get(PARAM_HNSW_STREAMER_MAX_SCAN_LIMIT, &max_scan_limit_); params.get(PARAM_HNSW_STREAMER_MIN_SCAN_LIMIT, &min_scan_limit_); params.get(PARAM_HNSW_STREAMER_CHECK_CRC_ENABLE, &check_crc_enabled_); params.get(PARAM_HNSW_STREAMER_CHUNK_SIZE, &chunk_size_); params.get(PARAM_HNSW_STREAMER_FILTER_SAME_KEY, &filter_same_key_); 
params.get(PARAM_HNSW_STREAMER_GET_VECTOR_ENABLE, &get_vector_enabled_); params.get(PARAM_HNSW_STREAMER_MIN_NEIGHBOR_COUNT, &min_neighbor_cnt_); params.get(PARAM_HNSW_STREAMER_FORCE_PADDING_RESULT_ENABLE, &force_padding_topk_enabled_); params.get(PARAM_HNSW_STREAMER_USE_ID_MAP, &use_id_map_); entity_.set_use_key_info_map(use_id_map_); params.get(PARAM_HNSW_STREAMER_DOCS_SOFT_LIMIT, &docs_soft_limit_); if (docs_soft_limit_ > 0 && docs_soft_limit_ > docs_hard_limit_) { LOG_ERROR("[%s] must be >= [%s]", PARAM_HNSW_STREAMER_DOCS_HARD_LIMIT.c_str(), PARAM_HNSW_STREAMER_DOCS_SOFT_LIMIT.c_str()); return IndexError_InvalidArgument; } else if (docs_soft_limit_ == 0UL) { docs_soft_limit_ = docs_hard_limit_ * HnswEntity::kDefaultDocsSoftLimitRatio; } if (ef_ == 0U) { ef_ = HnswEntity::kDefaultEf; } if (ef_construction_ == 0U) { ef_construction_ = HnswEntity::kDefaultEfConstruction; } if (upper_max_neighbor_cnt_ == 0U) { upper_max_neighbor_cnt_ = HnswEntity::kDefaultUpperMaxNeighborCnt; } if (upper_max_neighbor_cnt_ > HnswEntity::kMaxNeighborCnt) { LOG_ERROR("[%s] must be in range (0,%d)", PARAM_HNSW_STREAMER_MAX_NEIGHBOR_COUNT.c_str(), HnswEntity::kMaxNeighborCnt); return IndexError_InvalidArgument; } if (l0_max_neighbor_cnt_ == 0U) { l0_max_neighbor_cnt_ = HnswEntity::kDefaultL0MaxNeighborCnt; } if (l0_max_neighbor_cnt_ > HnswEntity::kMaxNeighborCnt) { LOG_ERROR("MaxL0NeighborCnt must be in range (0,%d)", HnswEntity::kMaxNeighborCnt); return IndexError_InvalidArgument; } if (min_neighbor_cnt_ > upper_max_neighbor_cnt_) { LOG_ERROR("[%s]-[%u] must be <= [%s]-[%u]", PARAM_HNSW_STREAMER_MIN_NEIGHBOR_COUNT.c_str(), min_neighbor_cnt_, PARAM_HNSW_STREAMER_MAX_NEIGHBOR_COUNT.c_str(), upper_max_neighbor_cnt_); return IndexError_InvalidArgument; } if (bf_negative_prob_ <= 0.0f || bf_negative_prob_ >= 1.0f) { LOG_ERROR("[%s] must be in range (0,1)", PARAM_HNSW_STREAMER_VISIT_BLOOMFILTER_NEGATIVE_PROB.c_str()); return IndexError_InvalidArgument; } if (scaling_factor_ == 0U) { 
scaling_factor_ = HnswEntity::kDefaultScalingFactor; } if (scaling_factor_ < 5 || scaling_factor_ > 1000) { LOG_ERROR("[%s] must be in range [5,1000]", PARAM_HNSW_STREAMER_SCALING_FACTOR.c_str()); return IndexError_InvalidArgument; } if (max_scan_ratio_ <= 0.0f || max_scan_ratio_ > 1.0f) { LOG_ERROR("[%s] must be in range (0.0f,1.0f]", PARAM_HNSW_STREAMER_MAX_SCAN_RATIO.c_str()); return IndexError_InvalidArgument; } if (max_scan_limit_ < min_scan_limit_) { LOG_ERROR("[%s] must be >= [%s]", PARAM_HNSW_STREAMER_MAX_SCAN_LIMIT.c_str(), PARAM_HNSW_STREAMER_MIN_SCAN_LIMIT.c_str()); return IndexError_InvalidArgument; } if (prune_cnt == 0UL) { prune_cnt = upper_max_neighbor_cnt_; } if (chunk_size_ == 0UL) { chunk_size_ = HnswEntity::kDefaultChunkSize; } if (chunk_size_ > HnswEntity::kMaxChunkSize) { LOG_ERROR("[%s] must be < %zu", PARAM_HNSW_STREAMER_CHUNK_SIZE.c_str(), HnswEntity::kMaxChunkSize); return IndexError_InvalidArgument; } entity_.set_ef_construction(ef_construction_); entity_.set_upper_neighbor_cnt(upper_max_neighbor_cnt_); entity_.set_l0_neighbor_cnt(l0_max_neighbor_cnt_); entity_.set_scaling_factor(scaling_factor_); entity_.set_prune_cnt(prune_cnt); entity_.set_vector_size(meta_.element_size()); entity_.set_chunk_size(chunk_size_); entity_.set_filter_same_key(filter_same_key_); entity_.set_get_vector(get_vector_enabled_); entity_.set_min_neighbor_cnt(min_neighbor_cnt_); int ret = entity_.init(docs_hard_limit_); if (ret != 0) { LOG_ERROR("Hnsw entity init failed for %s", IndexError::What(ret)); return ret; } LOG_DEBUG( "Init params: maxIndexSize=%zu docsHardLimit=%zu docsSoftLimit=%zu " "efConstruction=%u ef=%u upperMaxNeighborCnt=%u l0MaxNeighborCnt=%u " "scalingFactor=%u maxScanRatio=%.3f minScanLimit=%zu maxScanLimit=%zu " "bfEnabled=%d bruteFoceThreshold=%zu bfNegativeProbability=%.5f " "checkCrcEnabled=%d pruneSize=%zu vectorSize=%u chunkSize=%zu " "filterSameKey=%u getVectorEnabled=%u minNeighborCount=%u " "forcePadding=%u ", max_index_size_, 
docs_hard_limit_, docs_soft_limit_, ef_construction_, ef_, upper_max_neighbor_cnt_, l0_max_neighbor_cnt_, scaling_factor_, max_scan_ratio_, min_scan_limit_, max_scan_limit_, bf_enabled_, bruteforce_threshold_, bf_negative_prob_, check_crc_enabled_, prune_cnt, meta_.element_size(), chunk_size_, filter_same_key_, get_vector_enabled_, min_neighbor_cnt_, force_padding_topk_enabled_); alg_ = HnswAlgorithm::UPointer(new HnswAlgorithm(entity_)); ret = alg_->init(); if (ret != 0) { return ret; } state_ = STATE_INITED; return 0; } /* Close the index if it is open, release metric/entity/algorithm state and restore every tunable read by init() to the member-initializer default declared in hnsw_streamer.h, leaving the streamer in STATE_INIT. */ int HnswStreamer::cleanup(void) { if (state_ == STATE_OPENED) { this->close(); } LOG_INFO("HnswStreamer cleanup"); meta_.clear(); metric_.reset(); stats_.clear(); entity_.cleanup(); if (alg_) { alg_->cleanup(); } max_index_size_ = 0UL; docs_hard_limit_ = HnswEntity::kDefaultDocsHardLimit; docs_soft_limit_ = 0UL; upper_max_neighbor_cnt_ = HnswEntity::kDefaultUpperMaxNeighborCnt; l0_max_neighbor_cnt_ = HnswEntity::kDefaultL0MaxNeighborCnt; ef_ = HnswEntity::kDefaultEf; ef_construction_ = HnswEntity::kDefaultEfConstruction; bf_enabled_ = false; scaling_factor_ = HnswEntity::kDefaultScalingFactor; bruteforce_threshold_ = HnswEntity::kDefaultBruteForceThreshold; max_scan_limit_ = HnswEntity::kDefaultMaxScanLimit; min_scan_limit_ = HnswEntity::kDefaultMinScanLimit; chunk_size_ = HnswEntity::kDefaultChunkSize; bf_negative_prob_ = HnswEntity::kDefaultBFNegativeProbability; max_scan_ratio_ = HnswEntity::kDefaultScanRatio; state_ = STATE_INIT; check_crc_enabled_ = false; filter_same_key_ = false; get_vector_enabled_ = false; /* Also read by init(); without resetting them a later init() whose params omit these keys would inherit stale values. */ min_neighbor_cnt_ = 0U; force_padding_topk_enabled_ = false; use_id_map_ = true; return 0; } /* Open the index on stg (requires STATE_INITED): persist meta_ for a new index or check it against the stored IndexMeta (dimension, element size, metric name, data type), then create the metric and pick the add/search distance functions (query_metric() distances for search when available). On success sets STATE_OPENED and a fresh magic_. NOTE(review): failures after entity_.open() succeeds leave entity_ open while state_ stays STATE_INITED -- confirm this is intended. */ int HnswStreamer::open(IndexStorage::Pointer stg) { LOG_INFO("HnswStreamer open"); if (ailego_unlikely(state_ != STATE_INITED)) { LOG_ERROR("Open storage failed, init streamer first!"); return IndexError_NoReady; } int ret = entity_.open(std::move(stg), max_index_size_, check_crc_enabled_); if (ret != 0) { return ret; } IndexMeta index_meta; ret = entity_.get_index_meta(&index_meta); if (ret == IndexError_NoExist) { // Set 
IndexMeta for the new index ret = entity_.set_index_meta(meta_); if (ret != 0) { LOG_ERROR("Failed to set index meta for %s", IndexError::What(ret)); return ret; } } else if (ret != 0) { LOG_ERROR("Failed to get index meta for %s", IndexError::What(ret)); return ret; } else { if (index_meta.dimension() != meta_.dimension() || index_meta.element_size() != meta_.element_size() || index_meta.metric_name() != meta_.metric_name() || index_meta.data_type() != meta_.data_type()) { LOG_ERROR("IndexMeta mismatch from the previous in index"); return IndexError_Mismatch; } // The IndexMetric Params may be updated like MipsSquaredEuclidean auto metric_params = index_meta.metric_params(); metric_params.merge(meta_.metric_params()); meta_.set_metric(index_meta.metric_name(), 0, metric_params); } metric_ = IndexFactory::CreateMetric(meta_.metric_name()); if (!metric_) { LOG_ERROR("Failed to create metric %s", meta_.metric_name().c_str()); return IndexError_NoExist; } ret = metric_->init(meta_, meta_.metric_params()); if (ret != 0) { LOG_ERROR("Failed to init metric, ret=%d", ret); return ret; } if (!metric_->distance()) { LOG_ERROR("Invalid metric distance"); return IndexError_InvalidArgument; } if (!metric_->batch_distance()) { LOG_ERROR("Invalid metric batch distance"); return IndexError_InvalidArgument; } add_distance_ = metric_->distance(); add_batch_distance_ = metric_->batch_distance(); search_distance_ = add_distance_; search_batch_distance_ = add_batch_distance_; if (metric_->query_metric() && metric_->query_metric()->distance() && metric_->query_metric()->batch_distance()) { search_distance_ = metric_->query_metric()->distance(); search_batch_distance_ = metric_->query_metric()->batch_distance(); } state_ = STATE_OPENED; magic_ = IndexContext::GenerateMagic(); return 0; } /* Store the metric params in meta_, write the index meta, close entity_ and return to STATE_INITED. NOTE(review): dereferences metric_, so it assumes open() succeeded. */ int HnswStreamer::close(void) { LOG_INFO("HnswStreamer close"); stats_.clear(); meta_.set_metric(metric_->name(), 0, metric_->params()); entity_.set_index_meta(meta_); int ret = entity_.close(); if (ret 
!= 0) { return ret; } state_ = STATE_INITED; return 0; } /* Write the current meta into the index and flush entity_ at checkpoint. */ int HnswStreamer::flush(uint64_t checkpoint) { LOG_INFO("HnswStreamer flush checkpoint=%zu", (size_t)checkpoint); meta_.set_metric(metric_->name(), 0, metric_->params()); entity_.set_index_meta(meta_); return entity_.flush(checkpoint); } /* Serialize meta_ (with HnswSearcher as the searcher) and the entity into dumper while holding shared_mutex_ exclusively; concurrent adds fail their try_lock_shared() and are rejected. */ int HnswStreamer::dump(const IndexDumper::Pointer &dumper) { LOG_INFO("HnswStreamer dump"); shared_mutex_.lock(); AILEGO_DEFER([&]() { shared_mutex_.unlock(); }); meta_.set_searcher("HnswSearcher", HnswEntity::kRevision, ailego::Params()); int ret = IndexHelper::SerializeToDumper(meta_, dumper.get()); if (ret != 0) { LOG_ERROR("Failed to serialize meta into dumper."); return ret; } return entity_.dump(dumper); } /* Create an HnswContext over a clone of entity_, configured with this streamer's search settings and magic_; returns a null pointer on any failure. */ IndexStreamer::Context::Pointer HnswStreamer::create_context(void) const { if (ailego_unlikely(state_ != STATE_OPENED)) { LOG_ERROR("Create context failed, open storage first!"); return Context::Pointer(); } HnswEntity::Pointer entity = entity_.clone(); if (ailego_unlikely(!entity)) { LOG_ERROR("CreateContext clone init failed"); return Context::Pointer(); } HnswContext *ctx = new (std::nothrow) HnswContext(meta_.dimension(), metric_, entity); if (ailego_unlikely(ctx == nullptr)) { LOG_ERROR("Failed to new HnswContext"); return Context::Pointer(); } ctx->set_ef(ef_); ctx->set_max_scan_limit(max_scan_limit_); ctx->set_min_scan_limit(min_scan_limit_); ctx->set_max_scan_ratio(max_scan_ratio_); ctx->set_filter_mode(bf_enabled_ ? 
VisitFilter::BloomFilter : VisitFilter::ByteMap); ctx->set_filter_negative_probability(bf_negative_prob_); ctx->set_magic(magic_); ctx->set_force_padding_topk(force_padding_topk_enabled_); ctx->set_bruteforce_threshold(bruteforce_threshold_); /* The error comparison belongs inside ailego_unlikely() so the branch hint applies to the failure test. */ if (ailego_unlikely(ctx->init(HnswContext::kStreamerContext) != 0)) { LOG_ERROR("Init HnswContext failed"); delete ctx; return Context::Pointer(); } uint32_t estimate_doc_count = 0; if (meta_.streamer_params().get(PARAM_HNSW_STREAMER_ESTIMATE_DOC_COUNT, &estimate_doc_count)) { LOG_DEBUG("HnswStreamer doc_count[%zu] estimate[%zu]", (size_t)entity_.doc_cnt(), (size_t)estimate_doc_count); } ctx->check_need_adjuct_ctx(std::max(entity_.doc_cnt(), estimate_doc_count)); return Context::Pointer(ctx); } /* Create an index provider over a clone of entity_; returns nullptr if the clone fails. */ IndexProvider::Pointer HnswStreamer::create_provider(void) const { LOG_DEBUG("HnswStreamer create provider"); auto entity = entity_.clone(); if (ailego_unlikely(!entity)) { LOG_ERROR("Clone HnswEntity failed"); return nullptr; } return Provider::Pointer( new HnswIndexProvider(meta_, entity, "HnswStreamer")); } int HnswStreamer::update_context(HnswContext *ctx) const { const HnswEntity::Pointer entity = entity_.clone(); if (!entity) { LOG_ERROR("Failed to clone search context entity"); return IndexError_Runtime; } ctx->set_max_scan_limit(max_scan_limit_); ctx->set_min_scan_limit(min_scan_limit_); ctx->set_max_scan_ratio(max_scan_ratio_); ctx->set_bruteforce_threshold(bruteforce_threshold_); return ctx->update_context(HnswContext::kStreamerContext, meta_, metric_, entity, magic_); } //! Add a vector with id into index int HnswStreamer::add_with_id_impl(uint32_t id, const void *query, const IndexQueryMeta &qmeta, IndexStreamer::Context::Pointer &context) { int ret = check_params(query, qmeta); if (ailego_unlikely(ret != 0)) { return ret; } HnswContext *ctx = dynamic_cast(context.get()); ailego_do_if_false(ctx) { LOG_ERROR("Cast context to HnswContext failed"); return IndexError_Cast; } if (ctx->magic() != magic_) { //! 
context is created by another searcher or streamer ret = update_context(ctx); if (ret != 0) { return ret; } } if (ailego_unlikely(entity_.doc_cnt() >= docs_soft_limit_)) { if (entity_.doc_cnt() >= docs_hard_limit_) { LOG_ERROR("Current docs %u exceed [%s]", entity_.doc_cnt(), PARAM_HNSW_STREAMER_DOCS_HARD_LIMIT.c_str()); const std::lock_guard lk(mutex_); (*stats_.mutable_discarded_count())++; return IndexError_IndexFull; } else { LOG_WARN("Current docs %u exceed [%s]", entity_.doc_cnt(), PARAM_HNSW_STREAMER_DOCS_SOFT_LIMIT.c_str()); } } if (ailego_unlikely(!shared_mutex_.try_lock_shared())) { LOG_ERROR("Cannot add vector while dumping index"); (*stats_.mutable_discarded_count())++; return IndexError_Unsupported; } AILEGO_DEFER([&]() { shared_mutex_.unlock_shared(); }); ctx->clear(); ctx->update_dist_caculator_distance(add_distance_, add_batch_distance_); ctx->reset_query(query); ctx->check_need_adjuct_ctx(entity_.doc_cnt()); if (metric_->support_train()) { const std::lock_guard lk(mutex_); ret = metric_->train(query, meta_.dimension()); if (ailego_unlikely(ret != 0)) { LOG_ERROR("Hnsw streamer metric train failed"); (*stats_.mutable_discarded_count())++; return ret; } } level_t level = alg_->get_random_level(); ret = entity_.add_vector_with_id(level, id, query); if (ailego_unlikely(ret != 0)) { LOG_ERROR("Hnsw streamer add vector failed"); (*stats_.mutable_discarded_count())++; return ret; } ret = alg_->add_node(id, level, ctx); if (ailego_unlikely(ret != 0)) { LOG_ERROR("Hnsw steamer add node failed"); (*stats_.mutable_discarded_count())++; return ret; } if (ailego_unlikely(ctx->error())) { (*stats_.mutable_discarded_count())++; return IndexError_Runtime; } (*stats_.mutable_added_count())++; return 0; } //! 
Add a vector into index int HnswStreamer::add_impl(uint64_t pkey, const void *query, const IndexQueryMeta &qmeta, IndexStreamer::Context::Pointer &context) { int ret = check_params(query, qmeta); if (ailego_unlikely(ret != 0)) { return ret; } HnswContext *ctx = dynamic_cast(context.get()); ailego_do_if_false(ctx) { LOG_ERROR("Cast context to HnswContext failed"); return IndexError_Cast; } if (ctx->magic() != magic_) { //! context is created by another searcher or streamer ret = update_context(ctx); if (ret != 0) { return ret; } } if (ailego_unlikely(entity_.doc_cnt() >= docs_soft_limit_)) { if (entity_.doc_cnt() >= docs_hard_limit_) { LOG_ERROR("Current docs %u exceed [%s]", entity_.doc_cnt(), PARAM_HNSW_STREAMER_DOCS_HARD_LIMIT.c_str()); const std::lock_guard lk(mutex_); (*stats_.mutable_discarded_count())++; return IndexError_IndexFull; } else { LOG_WARN("Current docs %u exceed [%s]", entity_.doc_cnt(), PARAM_HNSW_STREAMER_DOCS_SOFT_LIMIT.c_str()); } } if (ailego_unlikely(!shared_mutex_.try_lock_shared())) { LOG_ERROR("Cannot add vector while dumping index"); (*stats_.mutable_discarded_count())++; return IndexError_Unsupported; } AILEGO_DEFER([&]() { shared_mutex_.unlock_shared(); }); ctx->clear(); ctx->update_dist_caculator_distance(add_distance_, add_batch_distance_); ctx->reset_query(query); ctx->check_need_adjuct_ctx(entity_.doc_cnt()); if (metric_->support_train()) { const std::lock_guard lk(mutex_); ret = metric_->train(query, meta_.dimension()); if (ailego_unlikely(ret != 0)) { LOG_ERROR("Hnsw streamer metric train failed"); (*stats_.mutable_discarded_count())++; return ret; } } level_t level = alg_->get_random_level(); node_id_t id = kInvalidNodeId; /* assigned by add_vector() on success */ ret = entity_.add_vector(level, pkey, query, &id); if (ailego_unlikely(ret != 0)) { LOG_ERROR("Hnsw streamer add vector failed"); (*stats_.mutable_discarded_count())++; return ret; } ret = alg_->add_node(id, level, ctx); if (ailego_unlikely(ret != 0)) { LOG_ERROR("Hnsw steamer add node failed"); 
(*stats_.mutable_discarded_count())++; return ret; } if (ailego_unlikely(ctx->error())) { (*stats_.mutable_discarded_count())++; return IndexError_Runtime; } (*stats_.mutable_added_count())++; return 0; } int HnswStreamer::search_impl(const void *query, const IndexQueryMeta &qmeta, IndexStreamer::Context::Pointer &context) const { return search_impl(query, qmeta, 1, context); } //! Similarity search int HnswStreamer::search_impl(const void *query, const IndexQueryMeta &qmeta, uint32_t count, IndexStreamer::Context::Pointer &context) const { int ret = check_params(query, qmeta); if (ailego_unlikely(ret != 0)) { return ret; } HnswContext *ctx = dynamic_cast(context.get()); ailego_do_if_false(ctx) { LOG_ERROR("Cast context to HnswContext failed"); return IndexError_Cast; } if (ctx->magic() != magic_) { //! context is created by another searcher or streamer ret = update_context(ctx); if (ret != 0) { return ret; } } /* Decide on brute force only after update_context(), which refreshes the context's bruteforce threshold for this streamer. */ if (entity_.doc_cnt() <= ctx->get_bruteforce_threshold()) { return search_bf_impl(query, qmeta, count, context); } ctx->clear(); ctx->update_dist_caculator_distance(search_distance_, search_batch_distance_); ctx->resize_results(count); ctx->check_need_adjuct_ctx(entity_.doc_cnt()); for (size_t q = 0; q < count; ++q) { ctx->reset_query(query); ret = alg_->search(ctx); if (ailego_unlikely(ret != 0)) { LOG_ERROR("Hnsw searcher fast search failed"); return ret; } ctx->topk_to_result(q); query = static_cast(query) + qmeta.element_size(); } if (ailego_unlikely(ctx->error())) { return IndexError_Runtime; } return 0; } /* Debug helper: print the level-0 neighbor ids of every live node (key != kInvalidKey) to stdout. */ void HnswStreamer::print_debug_info() { for (node_id_t id = 0; id < entity_.doc_cnt(); ++id) { if (entity_.get_key(id) == kInvalidKey) { continue; } Neighbors neighbours = entity_.get_neighbors(0, id); std::cout << "node: " << id << "; "; if (neighbours.size() == 0) std::cout << std::endl; for (uint32_t i = 0; i < neighbours.size(); ++i) { std::cout << neighbours[i]; if (i == neighbours.size() - 1) { std::cout << std::endl; } else { std::cout << ", "; 
} } } // entity_.print_key_map(); } int HnswStreamer::search_bf_impl( const void *query, const IndexQueryMeta &qmeta, IndexStreamer::Context::Pointer &context) const { return search_bf_impl(query, qmeta, 1, context); } /* Exhaustive search: for each of the count queries (laid out back to back, qmeta.element_size() bytes apart) score every live node that passes the filter, into per-group heaps for group-by search or into the single top-k heap otherwise. */ int HnswStreamer::search_bf_impl( const void *query, const IndexQueryMeta &qmeta, uint32_t count, IndexStreamer::Context::Pointer &context) const { int ret = check_params(query, qmeta); if (ailego_unlikely(ret != 0)) { return ret; } HnswContext *ctx = dynamic_cast(context.get()); ailego_do_if_false(ctx) { LOG_ERROR("Cast context to HnswContext failed"); return IndexError_Cast; } if (ctx->magic() != magic_) { //! context is created by another searcher or streamer ret = update_context(ctx); if (ret != 0) { return ret; } } ctx->clear(); ctx->update_dist_caculator_distance(search_distance_, search_batch_distance_); ctx->resize_results(count); if (ctx->group_by_search()) { if (!ctx->group_by().is_valid()) { LOG_ERROR("Invalid group-by function"); return IndexError_InvalidArgument; } std::function group_by = [&](node_id_t id) { return ctx->group_by()(entity_.get_key(id)); }; for (size_t q = 0; q < count; ++q) { ctx->reset_query(query); ctx->group_topk_heaps().clear(); for (node_id_t id = 0; id < entity_.doc_cnt(); ++id) { if (entity_.get_key(id) == kInvalidKey) { continue; } if (!ctx->filter().is_valid() || !ctx->filter()(entity_.get_key(id))) { dist_t dist = ctx->dist_calculator().batch_dist(id); std::string group_id = group_by(id); auto &topk_heap = ctx->group_topk_heaps()[group_id]; if (topk_heap.empty()) { topk_heap.limit(ctx->group_topk()); } topk_heap.emplace_back(id, dist); } } ctx->topk_to_result(q); query = static_cast(query) + qmeta.element_size(); } } else { auto &filter = ctx->filter(); auto &topk = ctx->topk_heap(); for (size_t q = 0; q < count; ++q) { ctx->reset_query(query); topk.clear(); for (node_id_t id = 0; id < entity_.doc_cnt(); ++id) { if (entity_.get_key(id) == kInvalidKey) { continue; } if (!filter.is_valid() || 
!filter(entity_.get_key(id))) { dist_t dist = ctx->dist_calculator().batch_dist(id); topk.emplace(id, dist); } } ctx->topk_to_result(q); query = static_cast(query) + qmeta.element_size(); } } if (ailego_unlikely(ctx->error())) { return IndexError_Runtime; } return 0; } /* Like search_bf_impl(), but query q only scores the nodes whose primary keys are listed in p_keys[q]; keys without a node (get_id() == kInvalidNodeId) are skipped. p_keys.size() must equal count. */ int HnswStreamer::search_bf_by_p_keys_impl( const void *query, const std::vector> &p_keys, const IndexQueryMeta &qmeta, uint32_t count, Context::Pointer &context) const { int ret = check_params(query, qmeta); if (ailego_unlikely(ret != 0)) { return ret; } if (ailego_unlikely(p_keys.size() != count)) { LOG_ERROR("The size of p_keys is not equal to count"); return IndexError_InvalidArgument; } HnswContext *ctx = dynamic_cast(context.get()); ailego_do_if_false(ctx) { LOG_ERROR("Cast context to HnswContext failed"); return IndexError_Cast; } if (ctx->magic() != magic_) { //! context is created by another searcher or streamer ret = update_context(ctx); if (ret != 0) { return ret; } } ctx->clear(); ctx->update_dist_caculator_distance(search_distance_, search_batch_distance_); ctx->resize_results(count); if (ctx->group_by_search()) { if (!ctx->group_by().is_valid()) { LOG_ERROR("Invalid group-by function"); return IndexError_InvalidArgument; } std::function group_by = [&](node_id_t id) { return ctx->group_by()(entity_.get_key(id)); }; for (size_t q = 0; q < count; ++q) { ctx->reset_query(query); ctx->group_topk_heaps().clear(); for (size_t idx = 0; idx < p_keys[q].size(); ++idx) { key_t pk = p_keys[q][idx]; if (!ctx->filter().is_valid() || !ctx->filter()(pk)) { node_id_t id = entity_.get_id(pk); if (id != kInvalidNodeId) { dist_t dist = ctx->dist_calculator().batch_dist(id); std::string group_id = group_by(id); auto &topk_heap = ctx->group_topk_heaps()[group_id]; if (topk_heap.empty()) { topk_heap.limit(ctx->group_topk()); } topk_heap.emplace_back(id, dist); } } } ctx->topk_to_result(q); query = static_cast(query) + qmeta.element_size(); } } else { auto &filter = ctx->filter(); auto &topk = 
ctx->topk_heap(); for (size_t q = 0; q < count; ++q) { ctx->reset_query(query); topk.clear(); for (size_t idx = 0; idx < p_keys[q].size(); ++idx) { key_t pk = p_keys[q][idx]; if (!filter.is_valid() || !filter(pk)) { node_id_t id = entity_.get_id(pk); if (id != kInvalidNodeId) { dist_t dist = ctx->dist_calculator().batch_dist(id); topk.emplace(id, dist); } } } ctx->topk_to_result(q); query = static_cast(query) + qmeta.element_size(); } } if (ailego_unlikely(ctx->error())) { return IndexError_Runtime; } return 0; } INDEX_FACTORY_REGISTER_STREAMER(HnswStreamer); } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/hnsw/hnsw_streamer.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include "hnsw_algorithm.h" #include "hnsw_streamer_entity.h" namespace zvec { namespace core { class HnswStreamer : public IndexStreamer { public: using ContextPointer = IndexStreamer::Context::Pointer; HnswStreamer(void); virtual ~HnswStreamer(void); HnswStreamer(const HnswStreamer &streamer) = delete; HnswStreamer &operator=(const HnswStreamer &streamer) = delete; protected: //! Initialize Streamer virtual int init(const IndexMeta &imeta, const ailego::Params &params) override; //! Cleanup Streamer virtual int cleanup(void) override; //! 
Create a context virtual Context::Pointer create_context(void) const override; //! Create a new iterator virtual IndexProvider::Pointer create_provider(void) const override; //! Add a vector into index virtual int add_impl(uint64_t pkey, const void *query, const IndexQueryMeta &qmeta, Context::Pointer &context) override; //! Add a vector with id into index virtual int add_with_id_impl(uint32_t id, const void *query, const IndexQueryMeta &qmeta, Context::Pointer &context) override; //! Similarity search virtual int search_impl(const void *query, const IndexQueryMeta &qmeta, Context::Pointer &context) const override; //! Similarity search virtual int search_impl(const void *query, const IndexQueryMeta &qmeta, uint32_t count, Context::Pointer &context) const override; //! Similarity brute force search virtual int search_bf_impl(const void *query, const IndexQueryMeta &qmeta, Context::Pointer &context) const override; //! Similarity brute force search virtual int search_bf_impl(const void *query, const IndexQueryMeta &qmeta, uint32_t count, Context::Pointer &context) const override; //! Linear search by primary keys virtual int search_bf_by_p_keys_impl( const void *query, const std::vector> &p_keys, const IndexQueryMeta &qmeta, ContextPointer &context) const override { return search_bf_by_p_keys_impl(query, p_keys, qmeta, 1, context); } //! Linear search by primary keys virtual int search_bf_by_p_keys_impl( const void *query, const std::vector> &p_keys, const IndexQueryMeta &qmeta, uint32_t count, ContextPointer &context) const override; //! Fetch vector by key virtual const void *get_vector(uint64_t key) const override { return entity_.get_vector_by_key(key); } virtual int get_vector(const uint64_t key, IndexStorage::MemoryBlock &block) const override { return entity_.get_vector_by_key(key, block); } //! 
Fetch vector by id virtual const void *get_vector_by_id(uint32_t id) const override { return entity_.get_vector(id); } virtual int get_vector_by_id( const uint32_t id, IndexStorage::MemoryBlock &block) const override { return entity_.get_vector(id, block); } //! Open index from file path virtual int open(IndexStorage::Pointer stg) override; //! Close file virtual int close(void) override; //! flush file virtual int flush(uint64_t checkpoint) override; //! Dump index into storage virtual int dump(const IndexDumper::Pointer &dumper) override; //! Retrieve statistics virtual const Stats &stats(void) const override { return stats_; } //! Retrieve meta of index virtual const IndexMeta &meta(void) const override { return meta_; } virtual void print_debug_info() override; private: inline int check_params(const void *query, const IndexQueryMeta &qmeta) const { if (ailego_unlikely(!query)) { LOG_ERROR("null query"); return IndexError_InvalidArgument; } if (ailego_unlikely(qmeta.dimension() != meta_.dimension() || qmeta.data_type() != meta_.data_type() || qmeta.element_size() != meta_.element_size())) { LOG_ERROR("Unsupported query meta"); return IndexError_Mismatch; } return 0; } /* Return 0 when all count entries of sparse_count are zero; otherwise log the first non-zero entry and return IndexError_InvalidArgument. The braces matter: without them the return ran unconditionally on the first iteration. */ inline int check_sparse_count_is_zero(const uint32_t *sparse_count, uint32_t count) const { for (uint32_t i = 0; i < count; ++i) { if (sparse_count[i] != 0) { LOG_ERROR("Sparse cout is not empty. Index: %u, Sparse Count: %u", i, sparse_count[i]); return IndexError_InvalidArgument; } } return 0; } private: //! To share ctx across streamer/searcher, we need to update the context for //! 
current streamer/searcher int update_context(HnswContext *ctx) const; private: enum State { STATE_INIT = 0, STATE_INITED = 1, STATE_OPENED = 2 }; class Stats : public IndexStreamer::Stats { public: void clear(void) { set_revision_id(0u); set_loaded_count(0u); set_added_count(0u); set_discarded_count(0u); set_index_size(0u); set_dumped_size(0u); set_check_point(0u); set_create_time(0u); set_update_time(0u); clear_attributes(); } }; HnswStreamerEntity entity_; HnswAlgorithm::UPointer alg_; IndexMeta meta_{}; IndexMetric::Pointer metric_{}; IndexMetric::MatrixDistance add_distance_{}; IndexMetric::MatrixDistance search_distance_{}; IndexMetric::MatrixBatchDistance add_batch_distance_{}; IndexMetric::MatrixBatchDistance search_batch_distance_{}; Stats stats_{}; std::mutex mutex_{}; size_t max_index_size_{0UL}; size_t chunk_size_{HnswEntity::kDefaultChunkSize}; size_t docs_hard_limit_{HnswEntity::kDefaultDocsHardLimit}; size_t docs_soft_limit_{0UL}; uint32_t min_neighbor_cnt_{0u}; uint32_t upper_max_neighbor_cnt_{HnswEntity::kDefaultUpperMaxNeighborCnt}; uint32_t l0_max_neighbor_cnt_{HnswEntity::kDefaultL0MaxNeighborCnt}; uint32_t ef_{HnswEntity::kDefaultEf}; uint32_t ef_construction_{HnswEntity::kDefaultEfConstruction}; uint32_t scaling_factor_{HnswEntity::kDefaultScalingFactor}; size_t bruteforce_threshold_{HnswEntity::kDefaultBruteForceThreshold}; size_t max_scan_limit_{HnswEntity::kDefaultMaxScanLimit}; size_t min_scan_limit_{HnswEntity::kDefaultMinScanLimit}; float bf_negative_prob_{HnswEntity::kDefaultBFNegativeProbability}; float max_scan_ratio_{HnswEntity::kDefaultScanRatio}; uint32_t magic_{0U}; State state_{STATE_INIT}; bool bf_enabled_{false}; bool check_crc_enabled_{false}; bool filter_same_key_{false}; bool get_vector_enabled_{false}; bool force_padding_topk_enabled_{false}; bool use_id_map_{true}; //! 
avoid add vector while dumping index ailego::SharedMutex shared_mutex_{}; }; } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/hnsw/hnsw_streamer_entity.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "hnsw_streamer_entity.h" #include // #define DEBUG_PRINT namespace zvec { namespace core { HnswStreamerEntity::HnswStreamerEntity(IndexStreamer::Stats &stats) : stats_(stats) {} HnswStreamerEntity::~HnswStreamerEntity() {} /* Reject a scaling factor whose kMaxGraphLayers-th power cannot cover max_doc_cnt, then (under mutex_) allocate broker_, the upper-neighbor index and the key map, and derive the neighbor-record and node sizes. */ int HnswStreamerEntity::init(size_t max_doc_cnt) { if (std::pow(scaling_factor(), kMaxGraphLayers) < max_doc_cnt) { LOG_ERROR("scalingFactor=%zu is too small", scaling_factor()); return IndexError_InvalidArgument; } std::lock_guard lock(mutex_); broker_ = std::make_shared(stats_); upper_neighbor_index_ = std::make_shared(); keys_map_lock_ = std::make_shared(); keys_map_ = std::make_shared>(); if (!keys_map_ || !upper_neighbor_index_ || !broker_ || !keys_map_lock_) { LOG_ERROR("HnswStreamerEntity new object failed"); return IndexError_NoMemory; } keys_map_->set_empty_key(kInvalidKey); neighbor_size_ = neighbors_size(); upper_neighbor_size_ = upper_neighbors_size(); //! 
vector + key + level 0 neighbors size_t size = vector_size() + sizeof(key_t) + neighbor_size_; size = AlignSize(size); set_node_size(size); return 0; } /* Under mutex_, clear the header and layout fields, drop all chunks and key-map contents, and release broker_. */ int HnswStreamerEntity::cleanup() { std::lock_guard lock(mutex_); mutable_header()->clear(); chunk_size_ = kDefaultChunkSize; node_index_mask_bits_ = 0U; node_index_mask_ = 0U; node_cnt_per_chunk_ = 0U; neighbor_size_ = 0U; upper_neighbor_size_ = 0U; if (upper_neighbor_index_) { upper_neighbor_index_->cleanup(); } if (keys_map_) { keys_map_->clear(); } node_chunks_.clear(); upper_neighbor_chunks_.clear(); filter_same_key_ = false; get_vector_enabled_ = false; broker_.reset(); return 0; } /* Overwrite the neighbor list of node id on level with the .first ids of neighbors (header + ids only). NOTE(review): nothing checks that neighbors.size() fits in the neighbor_size_ buffer -- confirm callers cap it. */ int HnswStreamerEntity::update_neighbors( level_t level, node_id_t id, const std::vector> &neighbors) { std::vector buffer(neighbor_size_); NeighborsHeader *hd = reinterpret_cast(buffer.data()); hd->neighbor_cnt = neighbors.size(); size_t i = 0; for (; i < neighbors.size(); ++i) { hd->neighbors[i] = neighbors[i].first; } auto loc = get_neighbor_chunk_loc(level, id); size_t size = reinterpret_cast(&hd->neighbors[i]) - &buffer[0]; size_t ret = loc.first->write(loc.second, hd, size); if (ailego_unlikely(ret != size)) { LOG_ERROR("Write neighbor header failed, ret=%zu", ret); return IndexError_Runtime; } return 0; } const Neighbors HnswStreamerEntity::get_neighbors(level_t level, node_id_t id) const { Chunk *chunk = nullptr; size_t offset = 0UL; size_t neighbor_size = neighbor_size_; if (level == 0UL) { uint32_t chunk_idx = id >> node_index_mask_bits_; offset = (id & node_index_mask_) * node_size() + vector_size() + sizeof(key_t); sync_chunks(ChunkBroker::CHUNK_TYPE_NODE, chunk_idx, &node_chunks_); ailego_assert_with(chunk_idx < node_chunks_.size(), "invalid chunk idx"); chunk = node_chunks_[chunk_idx].get(); } else { auto p = get_upper_neighbor_chunk_loc(level, id); chunk = upper_neighbor_chunks_[p.first].get(); offset = p.second; neighbor_size = upper_neighbor_size_; } ailego_assert_with(offset < chunk->data_size(), "invalid chunk 
offset"); IndexStorage::MemoryBlock neighbor_block; size_t size = chunk->read(offset, neighbor_block, neighbor_size); if (ailego_unlikely(size != neighbor_size)) { LOG_ERROR("Read neighbor header failed, ret=%zu", size); return Neighbors(); } return Neighbors(neighbor_block); } //! Get vector data by key const void *HnswStreamerEntity::get_vector(node_id_t id) const { auto loc = get_vector_chunk_loc(id); const void *vec = nullptr; ailego_assert_with(loc.first < node_chunks_.size(), "invalid chunk idx"); ailego_assert_with(loc.second < node_chunks_[loc.first]->data_size(), "invalid chunk offset"); size_t read_size = vector_size(); size_t ret = node_chunks_[loc.first]->read(loc.second, &vec, read_size); if (ailego_unlikely(ret != read_size)) { LOG_ERROR("Read vector failed, offset=%u, read size=%zu, ret=%zu", loc.second, read_size, ret); } return vec; } int HnswStreamerEntity::get_vector(const node_id_t *ids, uint32_t count, const void **vecs) const { for (auto i = 0U; i < count; ++i) { auto loc = get_vector_chunk_loc(ids[i]); ailego_assert_with(loc.first < node_chunks_.size(), "invalid chunk idx"); ailego_assert_with(loc.second < node_chunks_[loc.first]->data_size(), "invalid chunk offset"); size_t read_size = vector_size(); size_t ret = node_chunks_[loc.first]->read(loc.second, &vecs[i], read_size); if (ailego_unlikely(ret != read_size)) { LOG_ERROR("Read vector failed, offset=%u, read size=%zu, ret=%zu", loc.second, read_size, ret); return IndexError_ReadData; } } return 0; } int HnswStreamerEntity::get_vector(const node_id_t id, IndexStorage::MemoryBlock &block) const { auto loc = get_vector_chunk_loc(id); ailego_assert_with(loc.first < node_chunks_.size(), "invalid chunk idx"); ailego_assert_with(loc.second < node_chunks_[loc.first]->data_size(), "invalid chunk offset"); size_t read_size = vector_size(); size_t ret = node_chunks_[loc.first]->read(loc.second, block, read_size); if (ailego_unlikely(ret != read_size)) { LOG_ERROR("Read vector failed, offset=%u, 
read size=%zu, ret=%zu", loc.second, read_size, ret); return IndexError_ReadData; } return 0; } int HnswStreamerEntity::get_vector( const node_id_t *ids, uint32_t count, std::vector &vec_blocks) const { vec_blocks.resize(count); for (auto i = 0U; i < count; ++i) { auto loc = get_vector_chunk_loc(ids[i]); ailego_assert_with(loc.first < node_chunks_.size(), "invalid chunk idx"); ailego_assert_with(loc.second < node_chunks_[loc.first]->data_size(), "invalid chunk offset"); size_t read_size = vector_size(); size_t ret = node_chunks_[loc.first]->read(loc.second, vec_blocks[i], read_size); if (ailego_unlikely(ret != read_size)) { LOG_ERROR("Read vector failed, offset=%u, read size=%zu, ret=%zu", loc.second, read_size, ret); return IndexError_ReadData; } } return 0; } key_t HnswStreamerEntity::get_key(node_id_t id) const { if (use_key_info_map_) { auto loc = get_key_chunk_loc(id); IndexStorage::MemoryBlock key_block; ailego_assert_with(loc.first < node_chunks_.size(), "invalid chunk idx"); ailego_assert_with(loc.second < node_chunks_[loc.first]->data_size(), "invalid chunk offset"); size_t ret = node_chunks_[loc.first]->read(loc.second, key_block, sizeof(key_t)); if (ailego_unlikely(ret != sizeof(key_t))) { LOG_ERROR("Read vector failed, ret=%zu", ret); return kInvalidKey; } return *reinterpret_cast(key_block.data()); } else { return id; } } void HnswStreamerEntity::add_neighbor(level_t level, node_id_t id, uint32_t size, node_id_t neighbor_id) { auto loc = get_neighbor_chunk_loc(level, id); size_t offset = loc.second + sizeof(NeighborsHeader) + size * sizeof(node_id_t); ailego_assert_with(size < neighbor_cnt(level), "invalid neighbor size"); ailego_assert_with(offset < loc.first->data_size(), "invalid chunk offset"); size_t ret = loc.first->write(offset, &neighbor_id, sizeof(node_id_t)); if (ailego_unlikely(ret != sizeof(node_id_t))) { LOG_ERROR("Write neighbor id failed, ret=%zu", ret); return; } uint32_t neighbors = size + 1; ret = loc.first->write(loc.second, 
&neighbors, sizeof(uint32_t)); if (ailego_unlikely(ret != sizeof(uint32_t))) { LOG_ERROR("Write neighbor cnt failed, ret=%zu", ret); } return; } int HnswStreamerEntity::init_chunks(const Chunk::Pointer &header_chunk) { if (header_chunk->data_size() < header_size()) { LOG_ERROR("Invalid header chunk size"); return IndexError_InvalidFormat; } IndexStorage::MemoryBlock header_block; size_t size = header_chunk->read(0UL, header_block, header_size()); if (ailego_unlikely(size != header_size())) { LOG_ERROR("Read header chunk failed"); return IndexError_ReadData; } *mutable_header() = *reinterpret_cast(header_block.data()); int ret = check_hnsw_index(&header()); if (ret != 0) { broker_->close(); return ret; } node_chunks_.resize(broker_->get_chunk_cnt(ChunkBroker::CHUNK_TYPE_NODE)); for (auto seq = 0UL; seq < node_chunks_.size(); ++seq) { node_chunks_[seq] = broker_->get_chunk(ChunkBroker::CHUNK_TYPE_NODE, seq); if (!node_chunks_[seq]) { LOG_ERROR("Missing hnsw streamer data chunk %zu th of %zu", seq, node_chunks_.size()); return IndexError_InvalidFormat; } } upper_neighbor_chunks_.resize( broker_->get_chunk_cnt(ChunkBroker::CHUNK_TYPE_UPPER_NEIGHBOR)); for (auto seq = 0UL; seq < upper_neighbor_chunks_.size(); ++seq) { upper_neighbor_chunks_[seq] = broker_->get_chunk(ChunkBroker::CHUNK_TYPE_UPPER_NEIGHBOR, seq); if (!upper_neighbor_chunks_[seq]) { LOG_ERROR("Missing hnsw streamer index chunk %zu th of %zu", seq, upper_neighbor_chunks_.size()); return IndexError_InvalidFormat; } } return 0; } int HnswStreamerEntity::open(IndexStorage::Pointer stg, uint64_t max_index_size, bool check_crc) { std::lock_guard lock(mutex_); bool huge_page = stg->isHugePage(); LOG_DEBUG("huge_page: %d", (int)huge_page); int ret = init_chunk_params(max_index_size, huge_page); if (ailego_unlikely(ret != 0)) { LOG_ERROR("init_chunk_params failed for %s", IndexError::What(ret)); return ret; } ret = broker_->open(std::move(stg), max_index_size_, chunk_size_, check_crc); if (ailego_unlikely(ret != 
0)) { LOG_ERROR("Open index failed for %s", IndexError::What(ret)); return ret; } ret = upper_neighbor_index_->init(broker_, upper_neighbor_chunk_size_, scaling_factor(), estimate_doc_capacity(), kUpperHashMemoryInflateRatio); if (ailego_unlikely(ret != 0)) { LOG_ERROR("Init neighbor hash map failed"); return ret; } //! init header auto header_chunk = broker_->get_chunk(ChunkBroker::CHUNK_TYPE_HEADER, ChunkBroker::kDefaultChunkSeqId); if (!header_chunk) { // open empty index, create one auto p = broker_->alloc_chunk(ChunkBroker::CHUNK_TYPE_HEADER, ChunkBroker::kDefaultChunkSeqId, header_size()); if (ailego_unlikely(p.first != 0)) { LOG_ERROR("Alloc header chunk failed"); return p.first; } size_t size = p.second->write(0UL, &header(), header_size()); if (ailego_unlikely(size != header_size())) { LOG_ERROR("Write header chunk failed"); return IndexError_WriteData; } return 0; } //! Open an exist hnsw index ret = init_chunks(header_chunk); if (ailego_unlikely(ret != 0)) { return ret; } //! total docs including features wrote in index but neighbors may not ready node_id_t total_vecs = 0; if (node_chunks_.size() > 0) { size_t last_idx = node_chunks_.size() - 1; auto last_chunk = node_chunks_[last_idx]; if (last_chunk->data_size() % node_size()) { LOG_WARN("The index may broken"); return IndexError_InvalidFormat; } total_vecs = last_idx * node_cnt_per_chunk_ + node_chunks_[last_idx]->data_size() / node_size(); } LOG_INFO( "Open index, l0NeighborCnt=%zu upperNeighborCnt=%zu " "efConstruction=%zu curDocCnt=%u totalVecs=%u maxLevel=%u", l0_neighbor_cnt(), upper_neighbor_cnt(), ef_construction(), doc_cnt(), total_vecs, cur_max_level()); //! 
try to correct the docCnt if index not fully flushed if (doc_cnt() != total_vecs) { LOG_WARN("Index closed abnormally, using totalVecs as curDocCnt"); *mutable_doc_cnt() = total_vecs; } if (filter_same_key_ || get_vector_enabled_) { if (use_key_info_map_) { for (node_id_t id = 0U; id < doc_cnt(); ++id) { if (get_key(id) == kInvalidKey) { continue; } (*keys_map_)[get_key(id)] = id; } } } stats_.set_loaded_count(doc_cnt()); return 0; } int HnswStreamerEntity::close() { LOG_DEBUG("close index"); std::lock_guard lock(mutex_); flush_header(); mutable_header()->reset(); upper_neighbor_index_->cleanup(); keys_map_->clear(); header_.clear(); node_chunks_.clear(); upper_neighbor_chunks_.clear(); return broker_->close(); } int HnswStreamerEntity::flush(uint64_t checkpoint) { LOG_INFO("Flush index, curDocs=%u", doc_cnt()); std::lock_guard lock(mutex_); flush_header(); int ret = broker_->flush(checkpoint); if (ret != 0) { return ret; } return 0; } int HnswStreamerEntity::dump(const IndexDumper::Pointer &dumper) { LOG_INFO("Dump index, curDocs=%u", doc_cnt()); //! sort by keys, to support get_vector by key in searcher std::vector keys(doc_cnt()); for (node_id_t i = 0; i < doc_cnt(); ++i) { keys[i] = get_key(i); } //! 
dump neighbors auto get_level = [&](node_id_t id) { auto it = upper_neighbor_index_->find(id); if (it == upper_neighbor_index_->end()) { return 0U; }; auto meta = reinterpret_cast(&it->second); return meta->level; }; auto ret = dump_segments(dumper, keys.data(), get_level); if (ailego_unlikely(ret < 0)) { return ret; } *stats_.mutable_dumped_size() += ret; return 0; } int HnswStreamerEntity::check_hnsw_index(const HNSWHeader *hd) const { if (l0_neighbor_cnt() != hd->l0_neighbor_cnt() || upper_neighbor_cnt() != hd->upper_neighbor_cnt()) { LOG_ERROR("Param neighbor cnt: %zu:%zu mismatch index previous %zu:%zu", l0_neighbor_cnt(), upper_neighbor_cnt(), hd->l0_neighbor_cnt(), hd->upper_neighbor_cnt()); return IndexError_Mismatch; } if (vector_size() != hd->vector_size()) { LOG_ERROR("vector size %zu mismatch index previous %zu", vector_size(), hd->vector_size()); return IndexError_Mismatch; } if (ef_construction() != hd->ef_construction()) { LOG_WARN("Param efConstruction %zu mismatch index previous %zu", ef_construction(), hd->ef_construction()); } if (scaling_factor() != hd->scaling_factor()) { LOG_WARN("Param scalingFactor %zu mismatch index previous %zu", scaling_factor(), hd->scaling_factor()); return IndexError_Mismatch; } if (prune_cnt() != hd->neighbor_prune_cnt()) { LOG_WARN("Param pruneCnt %zu mismatch index previous %zu", prune_cnt(), hd->neighbor_prune_cnt()); return IndexError_Mismatch; } if ((hd->entry_point() != kInvalidNodeId && hd->entry_point() >= hd->doc_cnt()) || (hd->entry_point() == kInvalidNodeId && hd->doc_cnt() > 0U)) { LOG_WARN("Invalid entryPoint %u, docCnt %u", hd->entry_point(), hd->doc_cnt()); return IndexError_InvalidFormat; } if (hd->entry_point() == kInvalidNodeId && broker_->get_chunk_cnt(ChunkBroker::CHUNK_TYPE_NODE) > 0) { LOG_WARN("The index is broken, maybe it haven't flush"); return IndexError_InvalidFormat; } return 0; } int HnswStreamerEntity::add_vector(level_t level, key_t key, const void *vec, node_id_t *id) { Chunk::Pointer 
node_chunk; size_t chunk_offset = -1UL; std::lock_guard lock(mutex_); // duplicate check if (ailego_unlikely(filter_same_key_ && get_id(key) != kInvalidNodeId)) { LOG_WARN("Try to add duplicate key, ignore it"); return IndexError_Duplicate; } node_id_t local_id = static_cast(doc_cnt()); uint32_t chunk_index = node_chunks_.size() - 1U; if (chunk_index == -1U || (node_chunks_[chunk_index]->data_size() >= node_cnt_per_chunk_ * node_size())) { // no space left and need to alloc if (ailego_unlikely(node_chunks_.capacity() == node_chunks_.size())) { LOG_ERROR("add vector failed for no memory quota"); return IndexError_IndexFull; } chunk_index++; auto p = broker_->alloc_chunk(ChunkBroker::CHUNK_TYPE_NODE, chunk_index, chunk_size_); if (ailego_unlikely(p.first != 0)) { LOG_ERROR("Alloc data chunk failed"); return p.first; } node_chunk = p.second; chunk_offset = 0UL; node_chunks_.emplace_back(node_chunk); } else { node_chunk = node_chunks_[chunk_index]; chunk_offset = node_chunk->data_size(); } size_t size = node_chunk->write(chunk_offset, vec, vector_size()); if (ailego_unlikely(size != vector_size())) { LOG_ERROR("Chunk write vec failed, ret=%zu", size); return IndexError_WriteData; } size = node_chunk->write(chunk_offset + vector_size(), &key, sizeof(key_t)); if (ailego_unlikely(size != sizeof(key_t))) { LOG_ERROR("Chunk write vec failed, ret=%zu", size); return IndexError_WriteData; } //! 
level 0 neighbors is inited to zero by default int ret = add_upper_neighbor(level, local_id); if (ret != 0) { return ret; } chunk_offset += node_size(); if (ailego_unlikely(node_chunk->resize(chunk_offset) != chunk_offset)) { LOG_ERROR("Chunk resize to %zu failed", chunk_offset); return IndexError_Runtime; } if (filter_same_key_ || get_vector_enabled_) { if (use_key_info_map_) { keys_map_lock_->lock(); (*keys_map_)[key] = local_id; keys_map_lock_->unlock(); } } *mutable_doc_cnt() += 1; broker_->mark_dirty(); *id = local_id; return 0; } int HnswStreamerEntity::add_vector_with_id(level_t level, node_id_t id, const void *vec) { Chunk::Pointer node_chunk; size_t chunk_offset = -1UL; key_t key = id; std::lock_guard lock(mutex_); // duplicate check if (ailego_unlikely(filter_same_key_ && get_id(key) != kInvalidNodeId)) { LOG_WARN("Try to add duplicate key, ignore it"); return IndexError_Duplicate; } // set node_chunk & chunk_offset if succeed auto func_get_node_chunk_and_offset = [&](node_id_t node_id) -> int { uint32_t chunk_index = node_id >> node_index_mask_bits_; ailego_assert_with(chunk_index <= node_chunks_.size(), "invalid chunk idx"); // belongs to next chunk if (chunk_index == node_chunks_.size()) { if (ailego_unlikely(node_chunks_.capacity() == node_chunks_.size())) { LOG_ERROR("add vector failed for no memory quota"); return IndexError_IndexFull; } auto p = broker_->alloc_chunk(ChunkBroker::CHUNK_TYPE_NODE, chunk_index, chunk_size_); if (ailego_unlikely(p.first != 0)) { LOG_ERROR("Alloc data chunk failed"); return p.first; } node_chunk = p.second; node_chunks_.emplace_back(node_chunk); } node_chunk = node_chunks_[chunk_index]; chunk_offset = (node_id & node_index_mask_) * node_size(); return 0; }; for (size_t start_id = doc_cnt(); start_id < id; ++start_id) { if (auto ret = func_get_node_chunk_and_offset(start_id); ret != 0) { LOG_ERROR("func_get_node_chunk_and_offset failed"); return ret; } size_t size = node_chunk->write(chunk_offset + vector_size(), 
&kInvalidKey, sizeof(key_t)); if (ailego_unlikely(size != sizeof(key_t))) { LOG_ERROR("Chunk write key failed, ret=%zu", size); return IndexError_WriteData; } chunk_offset += node_size(); if (ailego_unlikely(node_chunk->resize(chunk_offset) != chunk_offset)) { LOG_ERROR("Chunk resize to %zu failed", chunk_offset); return IndexError_Runtime; } } if (auto ret = func_get_node_chunk_and_offset(id); ret != 0) { LOG_ERROR("func_get_node_chunk_and_offset failed"); return ret; } size_t size = node_chunk->write(chunk_offset, vec, vector_size()); if (ailego_unlikely(size != vector_size())) { LOG_ERROR("Chunk write vec failed, ret=%zu", size); return IndexError_WriteData; } size = node_chunk->write(chunk_offset + vector_size(), &key, sizeof(key_t)); if (ailego_unlikely(size != sizeof(key_t))) { LOG_ERROR("Chunk write vec failed, ret=%zu", size); return IndexError_WriteData; } //! level 0 neighbors is inited to zero by default int ret = add_upper_neighbor(level, id); if (ret != 0) { return ret; } if (*mutable_doc_cnt() <= id) { *mutable_doc_cnt() = id + 1; chunk_offset += node_size(); if (ailego_unlikely(node_chunk->resize(chunk_offset) != chunk_offset)) { LOG_ERROR("Chunk resize to %zu failed", chunk_offset); return IndexError_Runtime; } } if (filter_same_key_ || get_vector_enabled_) { if (use_key_info_map_) { keys_map_lock_->lock(); (*keys_map_)[key] = id; keys_map_lock_->unlock(); } } broker_->mark_dirty(); return 0; } void HnswStreamerEntity::update_ep_and_level(node_id_t ep, level_t level) { HnswEntity::update_ep_and_level(ep, level); flush_header(); return; } const HnswEntity::Pointer HnswStreamerEntity::clone() const { std::vector node_chunks; node_chunks.reserve(node_chunks_.size()); for (size_t i = 0UL; i < node_chunks_.size(); ++i) { node_chunks.emplace_back(node_chunks_[i]->clone()); if (ailego_unlikely(!node_chunks[i])) { LOG_ERROR("HnswStreamerEntity get chunk failed in clone"); return HnswEntity::Pointer(); } } std::vector upper_neighbor_chunks; 
upper_neighbor_chunks.reserve(upper_neighbor_chunks_.size()); for (size_t i = 0UL; i < upper_neighbor_chunks_.size(); ++i) { upper_neighbor_chunks.emplace_back(upper_neighbor_chunks_[i]->clone()); if (ailego_unlikely(!upper_neighbor_chunks[i])) { LOG_ERROR("HnswStreamerEntity get chunk failed in clone"); return HnswEntity::Pointer(); } } HnswStreamerEntity *entity = new (std::nothrow) HnswStreamerEntity( stats_, header(), chunk_size_, node_index_mask_bits_, upper_neighbor_mask_bits_, filter_same_key_, get_vector_enabled_, upper_neighbor_index_, keys_map_lock_, keys_map_, use_key_info_map_, std::move(node_chunks), std::move(upper_neighbor_chunks), broker_); if (ailego_unlikely(!entity)) { LOG_ERROR("HnswStreamerEntity new failed"); } return HnswEntity::Pointer(entity); } } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/hnsw/hnsw_streamer_entity.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include #include #include #include #include "hnsw_chunk.h" #include "hnsw_entity.h" #include "hnsw_index_hash.h" #include "hnsw_params.h" namespace zvec { namespace core { //! HnswStreamerEntity manage vector data, pkey, and node's neighbors class HnswStreamerEntity : public HnswEntity { public: //! Cleanup //! return 0 on success, or errCode in failure virtual int cleanup() override; //! 
//! Make a copy of streamer entity, to support thread-safe operation.
  //! The segments in the container cannot be read concurrently, so each
  //! reader works on its own clone (chunk handles are cloned per entity).
  virtual const HnswEntity::Pointer clone() const override;

  //! Get primary key of the node id (the id itself when no key-info map is
  //! used)
  virtual key_t get_key(node_id_t id) const override;

  //! Get vector feature data by local node id
  virtual const void *get_vector(node_id_t id) const override;

  //! Get vectors feature data by local ids
  virtual int get_vector(const node_id_t *ids, uint32_t count,
                         const void **vecs) const override;

  //! Read the vector of local id `id` into `block`
  virtual int get_vector(const node_id_t id,
                         IndexStorage::MemoryBlock &block) const override;

  //! Batch variant: one memory block per id in `ids`
  virtual int get_vector(
      const node_id_t *ids, uint32_t count,
      std::vector &vec_blocks) const override;

  //! Get the node id's neighbors on graph level
  //! Note: the returned neighbors are read-only; the original note points to
  //! a WritableNeighbors accessor that is not declared in this section.
  virtual const Neighbors get_neighbors(level_t level,
                                        node_id_t id) const override;

  //! Add vector and key to hnsw entity, and local id will be saved in id
  virtual int add_vector(level_t level, key_t key, const void *vec,
                         node_id_t *id) override;

  //! Add vector and id to hnsw entity (the id is also used as the key)
  virtual int add_vector_with_id(level_t level, node_id_t id,
                                 const void *vec) override;

  //! Replace the neighbor list of `id` on `level`
  virtual int update_neighbors(
      level_t level, node_id_t id,
      const std::vector> &neighbors) override;

  //! Append neighbor_id to node id neighbors on level
  //! Notice: the caller must ensure the neighbor list is not full
  virtual void add_neighbor(level_t level, node_id_t id, uint32_t size,
                            node_id_t neighbor_id) override;

  //! Dump index by dumper
  virtual int dump(const IndexDumper::Pointer &dumper) override;

  //! Update entry point and max level, then persist the header
  virtual void update_ep_and_level(node_id_t ep, level_t level) override;

  //! When false, keys and local ids are treated as identical and keys_map_
  //! is not consulted.
  void set_use_key_info_map(bool use_id_map) {
    use_key_info_map_ = use_id_map;
    LOG_DEBUG("use_key_info_map_: %d", (int)use_key_info_map_);
  }

 public:
  //! Constructor
  HnswStreamerEntity(IndexStreamer::Stats &stats);

  //! Destructor
  ~HnswStreamerEntity();

  //!
Get vector feature data by key virtual const void *get_vector_by_key(key_t key) const override { auto id = get_id(key); return id == kInvalidNodeId ? nullptr : get_vector(id); } virtual int get_vector_by_key( const key_t key, IndexStorage::MemoryBlock &block) const override { auto id = get_id(key); if (id != kInvalidNodeId) { return get_vector(id, block); } else { return IndexError_InvalidArgument; } } //! Init entity int init(size_t max_doc_cnt); //! Flush graph entity to disk //! return 0 on success, or errCode in failure int flush(uint64_t checkpoint); //! Open entity from storage //! return 0 on success, or errCode in failure int open(IndexStorage::Pointer stg, uint64_t max_index_size, bool check_crc); //! Close entity //! return 0 on success, or errCode in failure int close(); //! Set meta information from entity int set_index_meta(const IndexMeta &meta) const { return IndexHelper::SerializeToStorage(meta, broker_->storage().get()); } //! Get meta information from entity int get_index_meta(IndexMeta *meta) const { return IndexHelper::DeserializeFromStorage(broker_->storage().get(), meta); } //! Set params: chunk size inline void set_chunk_size(size_t val) { chunk_size_ = val; } //! Set params inline void set_filter_same_key(bool val) { filter_same_key_ = val; } //! Set params inline void set_get_vector(bool val) { get_vector_enabled_ = val; } //! Get vector local id by key inline node_id_t get_id(key_t key) const { if (use_key_info_map_) { keys_map_lock_->lock_shared(); auto it = keys_map_->find(key); keys_map_lock_->unlock_shared(); return it == keys_map_->end() ? kInvalidNodeId : it->second; } else { return key; } } void print_key_map() const { std::cout << "key map begins" << std::endl; auto iter = keys_map_->begin(); while (iter != keys_map_->end()) { std::cout << "key: " << iter->first << ", id: " << iter->second << std::endl; ; iter++; } std::cout << "key map ends" << std::endl; } //! 
//! Get l0 neighbors size: header plus l0_neighbor_cnt() ids
  inline size_t neighbors_size() const {
    return sizeof(NeighborsHeader) + l0_neighbor_cnt() * sizeof(node_id_t);
  }

  //! Get neighbors size for level > 0: header plus upper_neighbor_cnt() ids
  inline size_t upper_neighbors_size() const {
    return sizeof(NeighborsHeader) + upper_neighbor_cnt() * sizeof(node_id_t);
  }

 private:
  //! Value stored in upper_neighbor_index_ for each node with level > 0.
  union UpperNeighborIndexMeta {
    struct {
      uint32_t level : 4;
      uint32_t index : 28;  // index is composite type: chunk idx, and the
                            // N th neighbors in chunk, they two composite
                            // the 28 bits location
    };
    uint32_t data;
  };

  template
  using HashMap = google::dense_hash_map>;
  template
  using HashMapPointer = std::shared_ptr>;

  template
  using HashSet = google::dense_hash_set>;
  template
  using HashSetPointer = std::shared_ptr>;

  //! upper neighbor index hashmap
  using NIHashMap = HnswIndexHashMap;
  using NIHashMapPointer = std::shared_ptr;

  //! Private construct, only be called by clone method.
  //! Shares the upper-neighbor index, key map, key-map lock and broker with
  //! the source entity; neighbor sizes are recomputed from the copied header.
  //! NOTE(review): max_index_size_ and upper_neighbor_chunk_size_ are not
  //! copied and keep their defaults — fine if clones never allocate chunks;
  //! confirm.
  HnswStreamerEntity(IndexStreamer::Stats &stats, const HNSWHeader &hd,
                     size_t chunk_size, uint32_t node_index_mask_bits,
                     uint32_t upper_neighbor_mask_bits, bool filter_same_key,
                     bool get_vector_enabled,
                     const NIHashMapPointer &upper_neighbor_index,
                     std::shared_ptr &keys_map_lock,
                     const HashMapPointer &keys_map,
                     bool use_key_info_map,
                     std::vector &&node_chunks,
                     std::vector &&upper_neighbor_chunks,
                     const ChunkBroker::Pointer &broker)
      : stats_(stats),
        chunk_size_(chunk_size),
        node_index_mask_bits_(node_index_mask_bits),
        node_cnt_per_chunk_(1UL << node_index_mask_bits_),
        node_index_mask_(node_cnt_per_chunk_ - 1),
        upper_neighbor_mask_bits_(upper_neighbor_mask_bits),
        upper_neighbor_mask_((1U << upper_neighbor_mask_bits_) - 1),
        filter_same_key_(filter_same_key),
        get_vector_enabled_(get_vector_enabled),
        use_key_info_map_(use_key_info_map),
        upper_neighbor_index_(upper_neighbor_index),
        keys_map_lock_(keys_map_lock),
        keys_map_(keys_map),
        node_chunks_(std::move(node_chunks)),
        upper_neighbor_chunks_(std::move(upper_neighbor_chunks)),
        broker_(broker) {
    *mutable_header() = hd;
    neighbor_size_ = neighbors_size();
    upper_neighbor_size_ = upper_neighbors_size();
  }

  //! Called only in searching procedure per context, so no need to lock.
  //! Lazily appends broker chunks [chunks->size(), idx] so that `idx` is
  //! addressable. NOTE(review): also reached from write paths through
  //! get_neighbor_chunk_loc(); confirm no concurrent mutation of `chunks`.
  void sync_chunks(ChunkBroker::CHUNK_TYPE type, size_t idx,
                   std::vector *chunks) const {
    if (ailego_likely(idx < chunks->size())) {
      return;
    }
    for (size_t i = chunks->size(); i <= idx; ++i) {
      auto chunk = broker_->get_chunk(type, i);
      // the storage can ensure get chunk will success after the first get
      ailego_assert_with(!!chunk, "get chunk failed");
      chunks->emplace_back(std::move(chunk));
    }
  }

  //! return pair: chunk index + chunk offset of the node's vector
  inline std::pair get_vector_chunk_loc(
      node_id_t id) const {
    uint32_t chunk_idx = id >> node_index_mask_bits_;
    uint32_t offset = (id & node_index_mask_) * node_size();

    sync_chunks(ChunkBroker::CHUNK_TYPE_NODE, chunk_idx, &node_chunks_);
    return std::make_pair(chunk_idx, offset);
  }

  //! return pair: chunk index + chunk offset of the node's key (stored
  //! right after the vector)
  inline std::pair get_key_chunk_loc(node_id_t id) const {
    uint32_t chunk_idx = id >> node_index_mask_bits_;
    uint32_t offset = (id & node_index_mask_) * node_size() + vector_size();

    sync_chunks(ChunkBroker::CHUNK_TYPE_NODE, chunk_idx, &node_chunks_);
    return std::make_pair(chunk_idx, offset);
  }

  //! return pair: upper chunk index + offset of id's neighbor list on
  //! `level` (>= 1). Level L occupies slot (meta.index & mask) + L - 1.
  //! Aborts if `id` has no upper-neighbor entry.
  inline std::pair get_upper_neighbor_chunk_loc(
      level_t level, node_id_t id) const {
    auto it = upper_neighbor_index_->find(id);
    ailego_assert_abort(it != upper_neighbor_index_->end(),
                        "Get upper neighbor header failed");
    auto meta = reinterpret_cast(&it->second);
    uint32_t chunk_idx = (meta->index) >> upper_neighbor_mask_bits_;
    uint32_t offset = (((meta->index) & upper_neighbor_mask_) + level - 1) *
                      upper_neighbor_size_;
    sync_chunks(ChunkBroker::CHUNK_TYPE_UPPER_NEIGHBOR, chunk_idx,
                &upper_neighbor_chunks_);
    ailego_assert_abort(chunk_idx < upper_neighbor_chunks_.size(),
                        "invalid chunk idx");
    ailego_assert_abort(offset < upper_neighbor_chunks_[chunk_idx]->data_size(),
                        "invalid chunk offset");
    return std::make_pair(chunk_idx, offset);
  }

  //!
return pair: chunk + chunk offset inline std::pair get_neighbor_chunk_loc(level_t level, node_id_t id) const { if (level == 0UL) { uint32_t chunk_idx = id >> node_index_mask_bits_; uint32_t offset = (id & node_index_mask_) * node_size() + vector_size() + sizeof(key_t); sync_chunks(ChunkBroker::CHUNK_TYPE_NODE, chunk_idx, &node_chunks_); ailego_assert_abort(chunk_idx < node_chunks_.size(), "invalid chunk idx"); ailego_assert_abort(offset < node_chunks_[chunk_idx]->data_size(), "invalid chunk offset"); return std::make_pair(node_chunks_[chunk_idx].get(), offset); } else { auto p = get_upper_neighbor_chunk_loc(level, id); return std::make_pair(upper_neighbor_chunks_[p.first].get(), p.second); } } //! Chunk hnsw index valid int check_hnsw_index(const HNSWHeader *hd) const; size_t get_total_upper_neighbors_size(level_t level) const { return level * upper_neighbor_size_; } //! Add upper neighbor header and reserve space for upper neighbor int add_upper_neighbor(level_t level, node_id_t id) { if (level == 0) { return 0; } Chunk::Pointer chunk; uint64_t chunk_offset = -1UL; size_t neighbors_size = get_total_upper_neighbors_size(level); uint64_t chunk_index = upper_neighbor_chunks_.size() - 1UL; if (chunk_index == -1UL || (upper_neighbor_chunks_[chunk_index]->padding_size() < neighbors_size)) { // no space left and need to alloc chunk_index++; if (ailego_unlikely(upper_neighbor_chunks_.capacity() == upper_neighbor_chunks_.size())) { LOG_ERROR("add upper neighbor failed for no memory quota"); return IndexError_IndexFull; } auto p = broker_->alloc_chunk(ChunkBroker::CHUNK_TYPE_UPPER_NEIGHBOR, chunk_index, upper_neighbor_chunk_size_); if (ailego_unlikely(p.first != 0)) { LOG_ERROR("Alloc data chunk failed"); return p.first; } chunk = p.second; chunk_offset = 0UL; upper_neighbor_chunks_.emplace_back(chunk); } else { chunk = upper_neighbor_chunks_[chunk_index]; chunk_offset = chunk->data_size(); } ailego_assert_with((size_t)level < kMaxGraphLayers, "invalid level"); 
ailego_assert_with(chunk_offset % upper_neighbor_size_ == 0, "invalid offset"); ailego_assert_with((chunk_offset / upper_neighbor_size_) < (1U << upper_neighbor_mask_bits_), "invalid offset"); ailego_assert_with(chunk_index < (1U << (28 - upper_neighbor_mask_bits_)), "invalid chunk index"); UpperNeighborIndexMeta meta; meta.level = level; meta.index = (chunk_index << upper_neighbor_mask_bits_) | (chunk_offset / upper_neighbor_size_); chunk_offset += upper_neighbor_size_ * level; if (ailego_unlikely(!upper_neighbor_index_->insert(id, meta.data))) { LOG_ERROR("HashMap insert value failed"); return IndexError_Runtime; } if (ailego_unlikely(chunk->resize(chunk_offset) != chunk_offset)) { LOG_ERROR("Chunk resize to %zu failed", (size_t)chunk_offset); return IndexError_Runtime; } return 0; } size_t estimate_doc_capacity() const { return node_chunks_.capacity() * node_cnt_per_chunk_; } int init_chunk_params(size_t max_index_size, bool huge_page) { node_cnt_per_chunk_ = std::max(1, chunk_size_ / node_size()); //! align node cnt per chunk to pow of 2 node_index_mask_bits_ = std::ceil(std::log2(node_cnt_per_chunk_)); node_cnt_per_chunk_ = 1UL << node_index_mask_bits_; if (huge_page) { chunk_size_ = AlignHugePageSize(node_cnt_per_chunk_ * node_size()); } else { chunk_size_ = AlignPageSize(node_cnt_per_chunk_ * node_size()); } node_index_mask_ = node_cnt_per_chunk_ - 1; if (max_index_size == 0UL) { max_index_size_ = chunk_size_ * kDefaultMaxChunkCnt; } else { max_index_size_ = max_index_size; } //! To get a balanced upper neighbor chunk size. //! If the upper chunk size is equal to node chunk size, it may waste //! upper neighbor chunk space; if the upper neighbor chunk size is too //! small, the will need large upper neighbor chunks index space. So to //! 
get a balanced ratio be sqrt of the node/neighbor size ratio float ratio = std::sqrt(node_size() * scaling_factor() * 1.0f / upper_neighbor_size_); if (huge_page) { upper_neighbor_chunk_size_ = AlignHugePageSize( std::max(get_total_upper_neighbors_size(kMaxGraphLayers), static_cast(chunk_size_ / ratio))); } else { upper_neighbor_chunk_size_ = AlignPageSize( std::max(get_total_upper_neighbors_size(kMaxGraphLayers), static_cast(chunk_size_ / ratio))); } upper_neighbor_mask_bits_ = std::ceil(std::log2(upper_neighbor_chunk_size_ / upper_neighbor_size_)); upper_neighbor_mask_ = (1 << upper_neighbor_mask_bits_) - 1; size_t max_node_chunk_cnt = std::ceil(max_index_size_ / chunk_size_); size_t max_upper_chunk_cnt = std::ceil( (max_node_chunk_cnt * node_cnt_per_chunk_ * 1.0f / scaling_factor()) / (upper_neighbor_chunk_size_ / upper_neighbor_size_)); max_upper_chunk_cnt = max_upper_chunk_cnt + std::ceil(max_upper_chunk_cnt / scaling_factor()); //! reserve space to avoid memmove in chunks vector emplace chunk, so //! as to lock-free in reading chunk node_chunks_.reserve(max_node_chunk_cnt); upper_neighbor_chunks_.reserve(max_upper_chunk_cnt); LOG_DEBUG( "Settings: nodeSize=%zu chunkSize=%u upperNeighborSize=%u " "upperNeighborChunkSize=%u " "nodeCntPerChunk=%u maxChunkCnt=%zu maxNeighborChunkCnt=%zu " "maxIndexSize=%zu ratio=%.3f", node_size(), chunk_size_, upper_neighbor_size_, upper_neighbor_chunk_size_, node_cnt_per_chunk_, max_node_chunk_cnt, max_upper_chunk_cnt, max_index_size_, ratio); return 0; } //! 
Init node chunk and neighbor chunks int init_chunks(const Chunk::Pointer &header_chunk); int flush_header(void) { if (!broker_->dirty()) { // do not need to flush return 0; } auto header_chunk = broker_->get_chunk(ChunkBroker::CHUNK_TYPE_HEADER, ChunkBroker::kDefaultChunkSeqId); if (ailego_unlikely(!header_chunk)) { LOG_ERROR("get header chunk failed"); return IndexError_Runtime; } size_t size = header_chunk->write(0UL, &header(), header_size()); if (ailego_unlikely(size != header_size())) { LOG_ERROR("Write header chunk failed"); return IndexError_WriteData; } return 0; } private: HnswStreamerEntity(const HnswStreamerEntity &) = delete; HnswStreamerEntity &operator=(const HnswStreamerEntity &) = delete; static constexpr uint64_t kUpperHashMemoryInflateRatio = 2.0f; private: IndexStreamer::Stats &stats_; HNSWHeader header_{}; std::mutex mutex_{}; size_t max_index_size_{0UL}; uint32_t chunk_size_{kDefaultChunkSize}; uint32_t upper_neighbor_chunk_size_{kDefaultChunkSize}; uint32_t node_index_mask_bits_{0U}; uint32_t node_cnt_per_chunk_{0U}; uint32_t node_index_mask_{0U}; uint32_t neighbor_size_{0U}; uint32_t upper_neighbor_size_{0U}; //! UpperNeighborIndex.index composite chunkIdx and offset in chunk by the //! following mask uint32_t upper_neighbor_mask_bits_{0U}; uint32_t upper_neighbor_mask_{0U}; bool filter_same_key_{false}; bool get_vector_enabled_{false}; bool use_key_info_map_{true}; NIHashMapPointer upper_neighbor_index_{}; mutable std::shared_ptr keys_map_lock_{}; HashMapPointer keys_map_{}; //! the chunks will be changed in searcher, so need mutable //! data chunk include: vector, key, level 0 neighbors mutable std::vector node_chunks_{}; //! 
upper neighbor chunk inlude: UpperNeighborHeader + (1~level) neighbors mutable std::vector upper_neighbor_chunks_{}; ChunkBroker::Pointer broker_{}; // chunk broker }; } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/hnsw_rabitq/CMakeLists.txt ================================================ include(${PROJECT_ROOT_DIR}/cmake/bazel.cmake) include(${PROJECT_ROOT_DIR}/cmake/option.cmake) if(AUTO_DETECT_ARCH) foreach(FILE ${HNSW_RABITQ_FILES}) set_source_files_properties( ${FILE} PROPERTIES COMPILE_FLAGS "${RABITQ_ARCH_FLAG}" ) endforeach() endif() cc_library( NAME core_knn_hnsw_rabitq STATIC SHARED STRICT ALWAYS_LINK SRCS *.cc LIBS core_framework rabitqlib sparsehash INCS . ${PROJECT_ROOT_DIR}/src ${PROJECT_ROOT_DIR}/src/core ${PROJECT_ROOT_DIR}/src/core/algorithm VERSION "${PROXIMA_ZVEC_VERSION}" ) ================================================ FILE: src/core/algorithm/hnsw_rabitq/hnsw_rabitq_algorithm.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "hnsw_rabitq_algorithm.h"
#include
#include "hnsw_rabitq_entity.h"

namespace zvec {
namespace core {

//! Seeds the level generator from the wall clock and sizes the per-node
//! lock pool with kLockCnt mutexes.
HnswRabitqAlgorithm::HnswRabitqAlgorithm(HnswRabitqEntity &entity)
    : entity_(entity),
      mt_(std::chrono::system_clock::now().time_since_epoch().count()),
      lock_pool_(kLockCnt) {}

//! Nothing to release; always succeeds.
int HnswRabitqAlgorithm::cleanup() {
  return 0;
}

//! Inserts node `id` into graph levels [0, level]. Always returns 0.
int HnswRabitqAlgorithm::add_node(node_id_t id, level_t level,
                                  HnswRabitqContext *ctx) {
  // Snapshot the global entry point and top level under the spin lock. The
  // very first node simply becomes the entry point.
  spin_lock_.lock();
  auto top_level = entity_.cur_max_level();
  auto start_node = entity_.entry_point();
  const bool graph_was_empty = (start_node == kInvalidNodeId);
  if (graph_was_empty) {
    entity_.update_ep_and_level(id, level);
  }
  spin_lock_.unlock();
  if (ailego_unlikely(graph_was_empty)) {
    return 0;
  }

  // A node that would raise the top level keeps mutex_ for the whole
  // insertion, which serializes promotions. Re-read after acquiring it,
  // because another thread may have raised the level in the meantime.
  std::unique_lock<std::mutex> promotion_guard(mutex_, std::defer_lock);
  if (ailego_unlikely(level > top_level)) {
    promotion_guard.lock();
    top_level = entity_.cur_max_level();
    start_node = entity_.entry_point();
    if (level <= top_level) {
      promotion_guard.unlock();
    }
  }

  level_t cur_level = top_level;
  ResultRecord dist = ctx->dist_calculator()(start_node);

  // Greedy descent through the levels above the new node's top level.
  while (cur_level > level) {
    select_entry_point(cur_level, &start_node, &dist, ctx);
    --cur_level;
  }
  // Candidate search on every level the node joins, top-down.
  while (cur_level >= 0) {
    search_neighbors(cur_level, &start_node, &dist, ctx->level_topk(cur_level),
                     ctx);
    --cur_level;
  }

  // Link bottom-up so an upper level never becomes visible to knn_search
  // before the levels beneath it are ready.
  for (level_t lv = 0; lv <= level; ++lv) {
    add_neighbors(id, lv, ctx->level_topk(lv), ctx);
    ctx->level_topk(lv).clear();
  }

  if (ailego_unlikely(level > top_level)) {
    spin_lock_.lock();
    entity_.update_ep_and_level(id, level);
    spin_lock_.unlock();
  }
  // promotion_guard releases mutex_ here if it is still held.
  return 0;
}

//!
select_entry_point on hnsw level, ef = 1 void HnswRabitqAlgorithm::select_entry_point(level_t level, node_id_t *entry_point, ResultRecord *dist, HnswRabitqContext *ctx) const { auto &entity = ctx->get_entity(); HnswRabitqAddDistCalculator &dc = ctx->dist_calculator(); while (true) { const Neighbors neighbors = entity.get_neighbors(level, *entry_point); if (ailego_unlikely(ctx->debugging())) { (*ctx->mutable_stats_get_neighbors())++; } uint32_t size = neighbors.size(); if (size == 0) { break; } std::vector neighbor_vec_blocks; int ret = dc.get_vector(&neighbors[0], size, neighbor_vec_blocks); if (ailego_unlikely(ctx->debugging())) { (*ctx->mutable_stats_get_vector())++; } if (ailego_unlikely(ret != 0)) { break; } bool find_closer = false; std::vector dists(size); std::vector neighbor_vecs(size); for (uint32_t i = 0; i < size; ++i) { neighbor_vecs[i] = neighbor_vec_blocks[i].data(); } dc.batch_dist(neighbor_vecs.data(), size, dists.data()); for (uint32_t i = 0; i < size; ++i) { ResultRecord cur_dist = dists[i]; if (cur_dist < *dist) { *entry_point = neighbors[i]; *dist = cur_dist; find_closer = true; } } if (!find_closer) { break; } } return; } void HnswRabitqAlgorithm::add_neighbors(node_id_t id, level_t level, TopkHeap &topk_heap, HnswRabitqContext *ctx) { if (ailego_unlikely(topk_heap.size() == 0)) { return; } HnswRabitqAddDistCalculator &dc = ctx->dist_calculator(); update_neighbors(dc, id, level, topk_heap); // reverse update neighbors for (size_t i = 0; i < topk_heap.size(); ++i) { reverse_update_neighbors(dc, topk_heap[i].first, level, id, topk_heap[i].second, ctx->update_heap()); } return; } void HnswRabitqAlgorithm::search_neighbors(level_t level, node_id_t *entry_point, ResultRecord *dist, TopkHeap &topk, HnswRabitqContext *ctx) const { const auto &entity = ctx->get_entity(); HnswRabitqAddDistCalculator &dc = ctx->dist_calculator(); VisitFilter &visit = ctx->visit_filter(); CandidateHeap &candidates = ctx->candidates(); std::function filter = [](node_id_t) 
{ return false; }; if (ctx->filter().is_valid()) { filter = [&](node_id_t id) { return ctx->filter()(entity.get_key(id)); }; } candidates.clear(); visit.clear(); visit.set_visited(*entry_point); if (!filter(*entry_point)) { topk.emplace(*entry_point, *dist); } candidates.emplace(*entry_point, *dist); while (!candidates.empty() && !ctx->reach_scan_limit()) { auto top = candidates.begin(); node_id_t main_node = top->first; ResultRecord main_dist = top->second; if (topk.full() && main_dist > topk[0].second) { break; } candidates.pop(); const Neighbors neighbors = entity.get_neighbors(level, main_node); ailego_prefetch(neighbors.data); if (ailego_unlikely(ctx->debugging())) { (*ctx->mutable_stats_get_neighbors())++; } std::vector neighbor_ids(neighbors.size()); uint32_t size = 0; for (uint32_t i = 0; i < neighbors.size(); ++i) { node_id_t node = neighbors[i]; if (visit.visited(node)) { if (ailego_unlikely(ctx->debugging())) { (*ctx->mutable_stats_visit_dup_cnt())++; } continue; } visit.set_visited(node); neighbor_ids[size++] = node; } if (size == 0) { continue; } std::vector neighbor_vec_blocks; int ret = dc.get_vector(neighbor_ids.data(), size, neighbor_vec_blocks); if (ailego_unlikely(ctx->debugging())) { (*ctx->mutable_stats_get_vector())++; } if (ailego_unlikely(ret != 0)) { break; } // do prefetch static constexpr node_id_t BATCH_SIZE = 12; static constexpr node_id_t PREFETCH_STEP = 2; for (uint32_t i = 0; i < std::min(BATCH_SIZE * PREFETCH_STEP, size); ++i) { ailego_prefetch(neighbor_vec_blocks[i].data()); } // done std::vector dists(size); std::vector neighbor_vecs(size); for (uint32_t i = 0; i < size; ++i) { neighbor_vecs[i] = neighbor_vec_blocks[i].data(); } dc.batch_dist(neighbor_vecs.data(), size, dists.data()); for (uint32_t i = 0; i < size; ++i) { node_id_t node = neighbor_ids[i]; ResultRecord cur_dist = dists[i]; if ((!topk.full()) || cur_dist < topk[0].second) { candidates.emplace(node, cur_dist); // update entry_point for next level scan if (cur_dist < 
*dist) { *entry_point = node; *dist = cur_dist; } if (!filter(node)) { topk.emplace(node, cur_dist); } } // end if } // end for } // while return; } void HnswRabitqAlgorithm::update_neighbors(HnswRabitqAddDistCalculator &dc, node_id_t id, level_t level, TopkHeap &topk_heap) { topk_heap.sort(); uint32_t max_neighbor_cnt = entity_.neighbor_cnt(level); if (topk_heap.size() <= static_cast(entity_.prune_cnt())) { if (topk_heap.size() <= static_cast(max_neighbor_cnt)) { entity_.update_neighbors(level, id, topk_heap); return; } } uint32_t cur_size = 0; for (size_t i = 0; i < topk_heap.size(); ++i) { node_id_t cur_node = topk_heap[i].first; ResultRecord cur_node_dist = topk_heap[i].second; bool good = true; for (uint32_t j = 0; j < cur_size; ++j) { ResultRecord tmp_dist = dc.dist(cur_node, topk_heap[j].first); if (tmp_dist <= cur_node_dist) { good = false; break; } } if (good) { topk_heap[cur_size].first = cur_node; topk_heap[cur_size].second = cur_node_dist; cur_size++; if (cur_size >= max_neighbor_cnt) { break; } } } // when after-prune neighbor count is too seldom, // we use this strategy to make-up enough edges // not only just make-up out-degrees // we also make-up enough in-degrees uint32_t min_neighbors = entity_.min_neighbor_cnt(); for (size_t k = cur_size; cur_size < min_neighbors && k < topk_heap.size(); ++k) { bool exist = false; for (size_t j = 0; j < cur_size; ++j) { if (topk_heap[j].first == topk_heap[k].first) { exist = true; break; } } if (!exist) { topk_heap[cur_size].first = topk_heap[k].first; topk_heap[cur_size].second = topk_heap[k].second; cur_size++; } } topk_heap.resize(cur_size); entity_.update_neighbors(level, id, topk_heap); return; } void HnswRabitqAlgorithm::reverse_update_neighbors( HnswRabitqAddDistCalculator &dc, node_id_t id, level_t level, node_id_t link_id, ResultRecord dist, TopkHeap &update_heap) { const size_t max_neighbor_cnt = entity_.neighbor_cnt(level); uint32_t lock_idx = id & kLockMask; lock_pool_[lock_idx].lock(); const 
Neighbors neighbors = entity_.get_neighbors(level, id); size_t size = neighbors.size(); ailego_assert_with(size <= max_neighbor_cnt, "invalid neighbor size"); if (size < max_neighbor_cnt) { entity_.add_neighbor(level, id, size, link_id); lock_pool_[lock_idx].unlock(); return; } update_heap.emplace(link_id, dist); for (size_t i = 0; i < size; ++i) { node_id_t node = neighbors[i]; ResultRecord cur_dist = dc.dist(id, node); update_heap.emplace(node, cur_dist); } //! TODO: optimize prune //! prune edges update_heap.sort(); size_t cur_size = 0; for (size_t i = 0; i < update_heap.size(); ++i) { node_id_t cur_node = update_heap[i].first; ResultRecord cur_node_dist = update_heap[i].second; bool good = true; for (size_t j = 0; j < cur_size; ++j) { ResultRecord tmp_dist = dc.dist(cur_node, update_heap[j].first); if (tmp_dist <= cur_node_dist) { good = false; break; } } if (good) { update_heap[cur_size].first = cur_node; update_heap[cur_size].second = cur_node_dist; cur_size++; if (cur_size >= max_neighbor_cnt) { break; } } } update_heap.resize(cur_size); entity_.update_neighbors(level, id, update_heap); lock_pool_[lock_idx].unlock(); update_heap.clear(); return; } } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/hnsw_rabitq/hnsw_rabitq_algorithm.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#pragma once #include #include #include "hnsw_rabitq_context.h" #include "hnsw_rabitq_dist_calculator.h" #include "hnsw_rabitq_entity.h" namespace zvec { namespace core { //! hnsw graph algorithm implement class HnswRabitqAlgorithm { public: typedef std::unique_ptr UPointer; public: //! Constructor explicit HnswRabitqAlgorithm(HnswRabitqEntity &entity); //! Destructor ~HnswRabitqAlgorithm() = default; //! Cleanup HnswRabitqAlgorithm int cleanup(); //! Add a node to hnsw graph //! @id: the node unique id //! @level: a node will be add to graph in each level [0, level] //! return 0 on success, or errCode in failure int add_node(node_id_t id, level_t level, HnswRabitqContext *ctx); //! Initiate HnswRabitqAlgorithm int init() { level_probas_.clear(); double level_mult = 1 / std::log(static_cast(entity_.scaling_factor())); for (int level = 0;; level++) { // refers faiss get_random_level alg double proba = std::exp(-level / level_mult) * (1 - std::exp(-1 / level_mult)); if (proba < 1e-9) { break; } level_probas_.push_back(proba); } return 0; } //! Generate a random level //! return graph level uint32_t get_random_level() const { // gen rand float (0, 1) double f = mt_() / static_cast(mt_.max()); for (size_t level = 0; level < level_probas_.size(); level++) { if (f < level_probas_[level]) { return level; } f -= level_probas_[level]; } return level_probas_.size() - 1; } private: //! Select in upper layer to get entry point for next layer search void select_entry_point(level_t level, node_id_t *entry_point, ResultRecord *dist, HnswRabitqContext *ctx) const; //! update node id neighbors from topkHeap, and reverse link is also updated void add_neighbors(node_id_t id, level_t level, TopkHeap &topk_heap, HnswRabitqContext *ctx); //! Given a node id and level, search the nearest neighbors in graph //! Note: the nearest neighbors result keeps in topk, and entry_point and //! 
dist will be updated to current level nearest node id and distance void search_neighbors(level_t level, node_id_t *entry_point, ResultRecord *dist, TopkHeap &topk, HnswRabitqContext *ctx) const; //! Update the node's neighbors void update_neighbors(HnswRabitqAddDistCalculator &dc, node_id_t id, level_t level, TopkHeap &topk_heap); //! Checking linkId could be id's new neighbor, and add as neighbor if true //! @dc distance calculator //! @updateHeap temporary heap in updating neighbors void reverse_update_neighbors(HnswRabitqAddDistCalculator &dc, node_id_t id, level_t level, node_id_t link_id, ResultRecord dist, TopkHeap &update_heap); private: HnswRabitqAlgorithm(const HnswRabitqAlgorithm &) = delete; HnswRabitqAlgorithm &operator=(const HnswRabitqAlgorithm &) = delete; private: static constexpr uint32_t kLockCnt{1U << 8}; static constexpr uint32_t kLockMask{kLockCnt - 1U}; HnswRabitqEntity &entity_; mutable std::mt19937 mt_{}; std::vector level_probas_{}; mutable ailego::SpinMutex spin_lock_{}; // global spin lock std::mutex mutex_{}; // global mutex // TODO: spin lock? std::vector lock_pool_{}; }; } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/hnsw_rabitq/hnsw_rabitq_builder.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "hnsw_rabitq_builder.h"
#include
#include
#include
#include
#include
#include
#include
#include "zvec/core/framework/index_error.h"
#include "zvec/core/framework/index_factory.h"
#include "zvec/core/framework/index_logger.h"
#include "zvec/core/framework/index_memory.h"
#include "zvec/core/framework/index_meta.h"
#include "zvec/core/framework/index_provider.h"
#include "hnsw_rabitq_algorithm.h"
#include "hnsw_rabitq_entity.h"
#include "hnsw_rabitq_params.h"
#include "rabitq_converter.h"
#include "rabitq_params.h"
#include "rabitq_reformer.h"

namespace zvec {
namespace core {

HnswRabitqBuilder::HnswRabitqBuilder() {}

//! Read builder parameters, validate them, and initialize the entity, the
//! graph algorithm and the RaBitQ converter. Returns 0 on success.
int HnswRabitqBuilder::init(const IndexMeta &meta,
                            const ailego::Params &params) {
  LOG_INFO("Begin HnswRabitqBuilder::init");

  meta_ = meta;
  auto params_copy = params;
  meta_.set_builder("HnswRabitqBuilder", HnswRabitqEntity::kRevision,
                    std::move(params_copy));

  size_t memory_quota = 0UL;
  params.get(PARAM_HNSW_RABITQ_BUILDER_MEMORY_QUOTA, &memory_quota);
  params.get(PARAM_HNSW_RABITQ_BUILDER_THREAD_COUNT, &thread_cnt_);
  params.get(PARAM_HNSW_RABITQ_BUILDER_MIN_NEIGHBOR_COUNT, &min_neighbor_cnt_);
  params.get(PARAM_HNSW_RABITQ_BUILDER_EFCONSTRUCTION, &ef_construction_);
  params.get(PARAM_HNSW_RABITQ_BUILDER_CHECK_INTERVAL_SECS,
             &check_interval_secs_);
  params.get(PARAM_HNSW_RABITQ_BUILDER_MAX_NEIGHBOR_COUNT,
             &upper_max_neighbor_cnt_);
  float multiplier = HnswRabitqEntity::kDefaultL0MaxNeighborCntMultiplier;
  params.get(PARAM_HNSW_RABITQ_BUILDER_L0_MAX_NEIGHBOR_COUNT_MULTIPLIER,
             &multiplier);
  l0_max_neighbor_cnt_ = multiplier * upper_max_neighbor_cnt_;

  scaling_factor_ = upper_max_neighbor_cnt_;
  params.get(PARAM_HNSW_RABITQ_BUILDER_SCALING_FACTOR, &scaling_factor_);

  multiplier = HnswRabitqEntity::kDefaultNeighborPruneMultiplier;
  params.get(PARAM_HNSW_RABITQ_BUILDER_NEIGHBOR_PRUNE_MULTIPLIER, &multiplier);
  size_t prune_cnt = multiplier * upper_max_neighbor_cnt_;

  if (ef_construction_ == 0) {
    ef_construction_ = HnswRabitqEntity::kDefaultEfConstruction;
  }
  if (upper_max_neighbor_cnt_ == 0) {
    upper_max_neighbor_cnt_ = HnswRabitqEntity::kDefaultUpperMaxNeighborCnt;
  }
  if (upper_max_neighbor_cnt_ > kMaxNeighborCnt) {
    LOG_ERROR("[%s] must be in range (0,%d]",
              PARAM_HNSW_RABITQ_BUILDER_MAX_NEIGHBOR_COUNT.c_str(),
              kMaxNeighborCnt);
    return IndexError_InvalidArgument;
  }
  if (min_neighbor_cnt_ > upper_max_neighbor_cnt_) {
    LOG_ERROR("[%s]-[%d] must be <= [%s]-[%d]",
              PARAM_HNSW_RABITQ_BUILDER_MIN_NEIGHBOR_COUNT.c_str(),
              min_neighbor_cnt_,
              PARAM_HNSW_RABITQ_BUILDER_MAX_NEIGHBOR_COUNT.c_str(),
              upper_max_neighbor_cnt_);
    return IndexError_InvalidArgument;
  }
  if (l0_max_neighbor_cnt_ == 0) {
    // Fall back to the level-0 default (the value the member initializer and
    // cleanup() use), not the upper-level default.
    l0_max_neighbor_cnt_ = HnswRabitqEntity::kDefaultL0MaxNeighborCnt;
  }
  if (l0_max_neighbor_cnt_ > HnswRabitqEntity::kMaxNeighborCnt) {
    LOG_ERROR("L0MaxNeighborCnt must be in range (0,%d]",
              HnswRabitqEntity::kMaxNeighborCnt);
    return IndexError_InvalidArgument;
  }

  if (scaling_factor_ == 0U) {
    scaling_factor_ = HnswRabitqEntity::kDefaultScalingFactor;
  }
  if (scaling_factor_ < 5 || scaling_factor_ > 1000) {
    LOG_ERROR("[%s] must be in range [5,1000]",
              PARAM_HNSW_RABITQ_BUILDER_SCALING_FACTOR.c_str());
    return IndexError_InvalidArgument;
  }

  const unsigned int cpu_cores = std::thread::hardware_concurrency();
  if (thread_cnt_ == 0) {
    // hardware_concurrency() returns 0 when the core count is not computable;
    // use one worker instead of building with an empty thread pool.
    thread_cnt_ = (cpu_cores == 0) ? 1U : cpu_cores;
  }
  if (cpu_cores != 0 && thread_cnt_ > cpu_cores) {
    LOG_WARN("[%s] greater than cpu cores %zu",
             PARAM_HNSW_RABITQ_BUILDER_THREAD_COUNT.c_str(),
             static_cast<size_t>(cpu_cores));
  }

  if (prune_cnt == 0UL) {
    prune_cnt = upper_max_neighbor_cnt_;
  }

  metric_ = IndexFactory::CreateMetric(meta_.metric_name());
  if (!metric_) {
    LOG_ERROR("CreateMetric failed, name: %s", meta_.metric_name().c_str());
    return IndexError_NoExist;
  }
  int ret = metric_->init(meta_, meta_.metric_params());
  if (ret != 0) {
    LOG_ERROR("IndexMetric init failed, ret=%d", ret);
    return ret;
  }

  uint32_t total_bits = 0;
  params.get(PARAM_RABITQ_TOTAL_BITS, &total_bits);
  if (total_bits == 0) {
    total_bits = kDefaultRabitqTotalBits;
  }
  if (total_bits < 1 || total_bits > 9) {
    LOG_ERROR("Invalid total_bits: %zu, must be in [1, 9]", (size_t)total_bits);
    return IndexError_InvalidArgument;
  }
  uint8_t ex_bits = total_bits - 1;
  entity_.set_ex_bits(ex_bits);
  uint32_t dimension = 0;
  params.get(PARAM_HNSW_RABITQ_GENERAL_DIMENSION, &dimension);
  if (dimension == 0) {
    LOG_ERROR("%s not set", PARAM_HNSW_RABITQ_GENERAL_DIMENSION.c_str());
    return IndexError_InvalidArgument;
  }
  if (dimension < kMinRabitqDimSize || dimension > kMaxRabitqDimSize) {
    LOG_ERROR("Invalid dimension: %u, must be in [%d, %d]", dimension,
              kMinRabitqDimSize, kMaxRabitqDimSize);
    return IndexError_InvalidArgument;
  }
  entity_.update_rabitq_params_and_vector_size(dimension);

  entity_.set_ef_construction(ef_construction_);
  entity_.set_l0_neighbor_cnt(l0_max_neighbor_cnt_);
  entity_.set_min_neighbor_cnt(min_neighbor_cnt_);
  entity_.set_upper_neighbor_cnt(upper_max_neighbor_cnt_);
  entity_.set_scaling_factor(scaling_factor_);
  entity_.set_memory_quota(memory_quota);
  entity_.set_prune_cnt(prune_cnt);

  ret = entity_.init();
  if (ret != 0) {
    return ret;
  }

  alg_ = HnswRabitqAlgorithm::UPointer(new HnswRabitqAlgorithm(entity_));

  ret = alg_->init();
  if (ret != 0) {
    return ret;
  }

  // Create and initialize RaBitQ converter
  converter_ = std::make_shared();
  IndexMeta converter_meta = meta_;
  converter_meta.set_dimension(dimension);
  ret = converter_->init(converter_meta, params);
  if (ret != 0) {
    LOG_ERROR("Failed to initialize RabitqConverter: %d", ret);
    return ret;
  }

  state_ = BUILD_STATE_INITED;
  LOG_INFO(
      "End HnswRabitqBuilder::init, params: rawVectorSize=%u vectorSize=%zu "
      "efConstruction=%u "
      "l0NeighborCnt=%u upperNeighborCnt=%u scalingFactor=%u "
      "memoryQuota=%zu neighborPruneCnt=%zu metricName=%s ",
      meta_.element_size(), entity_.vector_size(), ef_construction_,
      l0_max_neighbor_cnt_, upper_max_neighbor_cnt_, scaling_factor_,
      memory_quota, prune_cnt, meta_.metric_name().c_str());

  return 0;
}

//! Reset the builder to its pre-init state.
int HnswRabitqBuilder::cleanup(void) {
  LOG_INFO("Begin HnswRabitqBuilder::cleanup");

  l0_max_neighbor_cnt_ = HnswRabitqEntity::kDefaultL0MaxNeighborCnt;
  min_neighbor_cnt_ = 0;
  upper_max_neighbor_cnt_ = HnswRabitqEntity::kDefaultUpperMaxNeighborCnt;
  ef_construction_ = HnswRabitqEntity::kDefaultEfConstruction;
  scaling_factor_ = HnswRabitqEntity::kDefaultScalingFactor;
  check_interval_secs_ = kDefaultLogIntervalSecs;
  // Reset like the other tunables, so a later init() without the thread-count
  // parameter picks the hardware default instead of the stale value.
  thread_cnt_ = 0;
  errcode_ = 0;
  error_ = false;
  entity_.cleanup();
  if (alg_) {
    alg_->cleanup();
  }
  // init() recreates the converter and train() recreates the reformer;
  // release the trained state now.
  converter_.reset();
  reformer_.reset();
  meta_.clear();
  metric_.reset();
  stats_.clear_attributes();
  stats_.set_trained_count(0UL);
  stats_.set_built_count(0UL);
  stats_.set_dumped_count(0UL);
  stats_.set_discarded_count(0UL);
  stats_.set_trained_costtime(0UL);
  stats_.set_built_costtime(0UL);
  stats_.set_dumped_costtime(0UL);
  state_ = BUILD_STATE_INIT;

  LOG_INFO("End HnswRabitqBuilder::cleanup");

  return 0;
}

//! Train the RaBitQ converter/reformer on `holder`, then the metric if it
//! supports training.
int HnswRabitqBuilder::train(IndexThreads::Pointer,
                             IndexHolder::Pointer holder) {
  if (state_ != BUILD_STATE_INITED) {
    LOG_ERROR("Init the builder before HnswRabitqBuilder::train");
    return IndexError_NoReady;
  }

  if (!holder) {
    LOG_ERROR("Input holder is nullptr while training index");
    return IndexError_InvalidArgument;
  }
  if (!holder->is_matched(meta_)) {
    LOG_ERROR("Input holder doesn't match index meta while training index");
    return IndexError_Mismatch;
  }
  LOG_INFO("Begin HnswRabitqBuilder::train");
  size_t trained_cost_time = 0;
  size_t trained_count = 0;

  int ret = train_converter_and_load_reformer(holder);
  if (ret != 0) {
    return ret;
  }

  if (metric_->support_train()) {
    auto start_time = ailego::Monotime::MilliSeconds();
    auto iter = holder->create_iterator();
    if (!iter) {
      LOG_ERROR("Create iterator for holder failed");
      return IndexError_Runtime;
    }
    while (iter->is_valid()) {
      ret = metric_->train(iter->data(), meta_.dimension());
      if (ailego_unlikely(ret != 0)) {
        LOG_ERROR("Hnsw build measure train failed, ret=%d", ret);
        return ret;
      }
      iter->next();
      ++trained_count;
    }
    trained_cost_time = ailego::Monotime::MilliSeconds() - start_time;
  }
  stats_.set_trained_count(trained_count);
  stats_.set_trained_costtime(trained_cost_time);
  state_ = BUILD_STATE_TRAINED;

  LOG_INFO("End HnswRabitqBuilder::train");

  return 0;
}

//! Train the converter, serialize it into an in-memory file, and load the
//! reformer from that file.
int HnswRabitqBuilder::train_converter_and_load_reformer(
    IndexHolder::Pointer holder) {
  // Train converter (KMeans clustering)
  int ret = converter_->train(holder);
  if (ret != 0) {
    LOG_ERROR("Failed to train RabitqConverter: %d", ret);
    return ret;
  }

  auto memory_dumper = IndexFactory::CreateDumper("MemoryDumper");
  if (!memory_dumper) {
    LOG_ERROR("Failed to create memory dumper");
    return IndexError_NoExist;
  }
  ret = memory_dumper->init(ailego::Params());
  if (ret != 0) {
    LOG_ERROR("Failed to init memory dumper: %d", ret);
    return ret;
  }
  std::string file_id = ailego::StringHelper::Concat(
      "rabitq_converter_", ailego::Monotime::MilliSeconds(), rand());
  ret = memory_dumper->create(file_id);
  if (ret != 0) {
    LOG_ERROR("Failed to create memory dumper: %d", ret);
    return ret;
  }
  // Release memory
  AILEGO_DEFER([&file_id]() { IndexMemory::Instance()->remove(file_id); });
  ret = converter_->dump(memory_dumper);
  if (ret != 0) {
    LOG_ERROR("Failed to dump RabitqConverter: %d", ret);
    return ret;
  }
  ret = memory_dumper->close();
  if (ret != 0) {
    LOG_ERROR("Failed to close memory dumper: %d", ret);
    return ret;
  }

  reformer_ = std::make_shared();
  ailego::Params reformer_params;
  reformer_params.set(PARAM_RABITQ_METRIC_NAME, meta_.metric_name());
  ret = reformer_->init(reformer_params);
  if (ret != 0) {
    LOG_ERROR("Failed to initialize RabitqReformer: %d", ret);
    return ret;
  }
  auto memory_storage = IndexFactory::CreateStorage("MemoryReadStorage");
  if (!memory_storage) {
    LOG_ERROR("Failed to create memory storage");
    return IndexError_NoExist;
  }
  ret = memory_storage->open(file_id, false);
  if (ret != 0) {
    LOG_ERROR("Failed to open memory storage: %d", ret);
    return ret;
  }
  ret = reformer_->load(memory_storage);
  if (ret != 0) {
    LOG_ERROR("Failed to load RabitqReformer: %d", ret);
    return ret;
  }
  return 0;
}

//! Trainer-based training is a no-op here; it only advances the state.
int HnswRabitqBuilder::train(const IndexTrainer::Pointer & /*trainer*/) {
  if (state_ != BUILD_STATE_INITED) {
    LOG_ERROR("Init the builder before HnswRabitqBuilder::train");
    return IndexError_NoReady;
  }

  LOG_INFO("Begin HnswRabitqBuilder::train by trainer");

  stats_.set_trained_count(0UL);
  stats_.set_trained_costtime(0UL);
  state_ = BUILD_STATE_TRAINED;

  LOG_INFO("End HnswRabitqBuilder::train by trainer");

  return 0;
}

//! Quantize and store every vector of `holder`, then link the graph with
//! one do_build task per thread.
int HnswRabitqBuilder::build(IndexThreads::Pointer threads,
                             IndexHolder::Pointer holder) {
  if (state_ != BUILD_STATE_TRAINED) {
    LOG_ERROR("Train the index before HnswRabitqBuilder::build");
    return IndexError_NoReady;
  }

  if (!holder) {
    LOG_ERROR("Input holder is nullptr while building index");
    return IndexError_InvalidArgument;
  }
  if (!holder->is_matched(meta_)) {
    LOG_ERROR("Input holder doesn't match index meta while building index");
    return IndexError_Mismatch;
  }

  IndexProvider::Pointer provider =
      std::dynamic_pointer_cast(holder);
  if (!provider) {
    LOG_ERROR("Rabitq builder expect IndexProvider");
    return IndexError_InvalidArgument;
  }

  if (!threads) {
    threads = std::make_shared(thread_cnt_, false);
  }

  auto start_time = ailego::Monotime::MilliSeconds();
  LOG_INFO("Begin HnswRabitqBuilder::build");

  if (holder->count() != static_cast(-1)) {
    LOG_DEBUG("HnswRabitqBuilder holder documents count %lu", holder->count());
    int ret = entity_.reserve_space(holder->count());
    if (ret != 0) {
      LOG_ERROR("HnswBuilde reserver space failed");
      return ret;
    }
  }

  auto iter = holder->create_iterator();
  if (!iter) {
    LOG_ERROR("Create iterator for holder failed");
    return IndexError_Runtime;
  }
  int ret;
  error_ = false;

  IndexQueryMeta ometa;
  ometa.set_meta(holder->data_type(), holder->dimension());
  while (iter->is_valid()) {
    const void *vec = iter->data();

    // quantize vector
    std::string converted_vector;
    IndexQueryMeta converted_meta;
    ret = reformer_->convert(vec, ometa, &converted_vector, &converted_meta);
    if (ret != 0) {
      LOG_ERROR("Rabitq hnsw convert failed, ret=%d", ret);
      return ret;
    }

    level_t level = alg_->get_random_level();
    node_id_t id;

    if (converted_vector.size() != entity_.vector_size()) {
      LOG_ERROR(
          "Converted vector size %zu is not equal to entity vector size %zu",
          converted_vector.size(), entity_.vector_size());
      return IndexError_InvalidArgument;
    }

    ret = entity_.add_vector(level, iter->key(), converted_vector.data(), &id);
    if (ailego_unlikely(ret != 0)) {
      return ret;
    }
    iter->next();
  }

  LOG_INFO("Finished save vector, start build graph...");

  auto task_group = threads->make_group();
  if (!task_group) {
    LOG_ERROR("Failed to create task group");
    return IndexError_Runtime;
  }

  std::atomic finished{0};
  for (size_t i = 0; i < threads->count(); ++i) {
    task_group->submit(ailego::Closure ::New(this, &HnswRabitqBuilder::do_build,
                                             i, threads->count(), provider,
                                             &finished));
  }

  while (!task_group->is_finished()) {
    {
      std::unique_lock lk(mutex_);
      cond_.wait_until(lk, std::chrono::system_clock::now() +
                               std::chrono::seconds(check_interval_secs_));
    }
    if (error_.load(std::memory_order_acquire)) {
      // Stop reporting progress; the workers are joined below.
      break;
    }
    size_t doc_cnt = entity_.doc_cnt();
    LOG_INFO("Built cnt %zu, finished percent %.3f%%",
             static_cast<size_t>(finished.load()),
             doc_cnt == 0 ? 100.0f : finished.load() * 100.0f / doc_cnt);
  }

  // Always join the workers before leaving. They hold a pointer to the local
  // `finished` counter, and a failing worker stores errcode_ only after it
  // raises error_, so errcode_ is safe to read only once all have returned.
  task_group->wait_finish();
  if (error_.load(std::memory_order_acquire)) {
    LOG_ERROR("Failed to build index while waiting finish");
    return errcode_;
  }

  stats_.set_built_count(finished.load());
  stats_.set_built_costtime(ailego::Monotime::MilliSeconds() - start_time);

  state_ = BUILD_STATE_BUILT;
  LOG_INFO("End HnswRabitqBuilder::build with RaBitQ quantization");
  return 0;
}

//! Worker task: link nodes idx, idx + step_size, ... into the graph. The
//! first worker to fail records errcode_; the others stop at their next node.
void HnswRabitqBuilder::do_build(node_id_t idx, size_t step_size,
                                 IndexProvider::Pointer provider,
                                 std::atomic *finished) {
  AILEGO_DEFER([&]() {
    std::lock_guard latch(mutex_);
    cond_.notify_one();
  });
  HnswRabitqContext *ctx = new (std::nothrow) HnswRabitqContext(
      meta_.dimension(), metric_,
      std::shared_ptr(&entity_, [](HnswRabitqEntity *) {}));
  if (ailego_unlikely(ctx == nullptr)) {
    if (!error_.exchange(true)) {
      LOG_ERROR("Failed to create context");
      errcode_ = IndexError_NoMemory;
    }
    return;
  }
  HnswRabitqContext::Pointer auto_ptr(ctx);
  ctx->set_provider(std::move(provider));
  ctx->set_max_scan_num(entity_.doc_cnt());
  int ret = ctx->init(HnswRabitqContext::kBuilderContext);
  if (ret != 0) {
    if (!error_.exchange(true)) {
      LOG_ERROR("Failed to init context");
      errcode_ = IndexError_Runtime;
    }
    return;
  }

  for (node_id_t id = idx; id < entity_.doc_cnt(); id += step_size) {
    // Another worker already failed: stop so build() can report promptly.
    if (error_.load(std::memory_order_acquire)) {
      return;
    }
    ctx->reset_query(ctx->dist_calculator().get_vector(id));
    ret = alg_->add_node(id, entity_.get_level(id), ctx);
    if (ailego_unlikely(ret != 0)) {
      if (!error_.exchange(true)) {
        LOG_ERROR("Hnsw graph add node failed");
        errcode_ = ret;
      }
      return;
    }
    ctx->clear();
    (*finished)++;
  }
}

//! Serialize the meta, the converter and the graph into `dumper`.
int HnswRabitqBuilder::dump(const IndexDumper::Pointer &dumper) {
  if (state_ != BUILD_STATE_BUILT) {
    LOG_ERROR("Build the index before HnswRabitqBuilder::dump");
    return IndexError_NoReady;
  }

  LOG_INFO("Begin HnswRabitqBuilder::dump");

  meta_.set_searcher("HnswRabitqSearcher", HnswRabitqEntity::kRevision,
                     ailego::Params());
  auto start_time = ailego::Monotime::MilliSeconds();

  int ret = IndexHelper::SerializeToDumper(meta_, dumper.get());
  if (ret != 0) {
    LOG_ERROR("Failed to serialize meta into dumper.");
    return ret;
  }

  // Dump RaBitQ centroids first
  if (converter_) {
    ret = converter_->dump(dumper);
    if (ret != 0) {
      LOG_ERROR("Failed to dump RabitqConverter: %d", ret);
      return ret;
    }
    LOG_INFO("RaBitQ centroids dumped: %zu bytes, cost %zu ms",
             converter_->stats().dumped_size(),
             static_cast<size_t>(converter_->stats().dumped_costtime()));
  }

  ret = entity_.dump(dumper);
  if (ret != 0) {
    LOG_ERROR("HnswRabitqBuilder dump index failed");
    return ret;
  }

  stats_.set_dumped_count(entity_.doc_cnt());
  stats_.set_dumped_costtime(ailego::Monotime::MilliSeconds() - start_time);

  LOG_INFO("End HnswRabitqBuilder::dump");
  return 0;
}

}  // namespace core
}  // namespace zvec

================================================
FILE: src/core/algorithm/hnsw_rabitq/hnsw_rabitq_builder.h
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include "zvec/core/framework/index_builder.h" #include "zvec/core/framework/index_converter.h" #include "zvec/core/framework/index_reformer.h" #include "hnsw_rabitq_algorithm.h" #include "hnsw_rabitq_builder_entity.h" namespace zvec { namespace core { class HnswRabitqBuilder : public IndexBuilder { public: //! Constructor HnswRabitqBuilder(); //! Initialize the builder virtual int init(const IndexMeta &meta, const ailego::Params ¶ms) override; //! Cleanup the builder virtual int cleanup(void) override; //! Train the data virtual int train(IndexThreads::Pointer, IndexHolder::Pointer holder) override; //! Train the data virtual int train(const IndexTrainer::Pointer &trainer) override; //! Build the index virtual int build(IndexThreads::Pointer threads, IndexHolder::Pointer holder) override; //! Dump index into storage virtual int dump(const IndexDumper::Pointer &dumper) override; //! 
Retrieve statistics virtual const Stats &stats(void) const override { return stats_; } private: void do_build(node_id_t idx, size_t step_size, IndexProvider::Pointer provider, std::atomic *finished); int train_converter_and_load_reformer(IndexHolder::Pointer holder); constexpr static uint32_t kDefaultLogIntervalSecs = 15U; constexpr static uint32_t kMaxNeighborCnt = 65535; private: enum BUILD_STATE { BUILD_STATE_INIT = 0, BUILD_STATE_INITED = 1, BUILD_STATE_TRAINED = 2, BUILD_STATE_BUILT = 3 }; HnswRabitqBuilderEntity entity_{}; HnswRabitqAlgorithm::UPointer alg_; // impl graph algorithm uint32_t thread_cnt_{0}; uint32_t min_neighbor_cnt_{0}; uint32_t upper_max_neighbor_cnt_{ HnswRabitqEntity::kDefaultUpperMaxNeighborCnt}; uint32_t l0_max_neighbor_cnt_{HnswRabitqEntity::kDefaultL0MaxNeighborCnt}; uint32_t ef_construction_{HnswRabitqEntity::kDefaultEfConstruction}; uint32_t scaling_factor_{HnswRabitqEntity::kDefaultScalingFactor}; uint32_t check_interval_secs_{kDefaultLogIntervalSecs}; int errcode_{0}; std::atomic_bool error_{false}; IndexMeta meta_{}; IndexMetric::Pointer metric_{}; IndexConverter::Pointer converter_{}; // RaBitQ converter IndexReformer::Pointer reformer_{}; // RaBitQ reformer std::mutex mutex_{}; std::condition_variable cond_{}; Stats stats_{}; BUILD_STATE state_{BUILD_STATE_INIT}; }; } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/hnsw_rabitq/hnsw_rabitq_builder_entity.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "hnsw_rabitq_builder_entity.h" #include #include #include "utility/sparse_utility.h" namespace zvec { namespace core { HnswRabitqBuilderEntity::HnswRabitqBuilderEntity() { update_ep_and_level(kInvalidNodeId, 0U); } int HnswRabitqBuilderEntity::cleanup() { memory_quota_ = 0UL; neighbors_size_ = 0U; upper_neighbors_size_ = 0U; padding_size_ = 0U; vectors_buffer_.clear(); keys_buffer_.clear(); neighbors_buffer_.clear(); upper_neighbors_buffer_.clear(); neighbors_index_.clear(); vectors_buffer_.shrink_to_fit(); keys_buffer_.shrink_to_fit(); neighbors_buffer_.shrink_to_fit(); upper_neighbors_buffer_.shrink_to_fit(); neighbors_index_.shrink_to_fit(); this->HnswRabitqEntity::cleanup(); return 0; } int HnswRabitqBuilderEntity::init() { size_t size = vector_size(); //! aligned size to 32 set_node_size(AlignSize(size)); //! 
if node size is aligned to 1k, the build performance will downgrade if (node_size() % 1024 == 0) { set_node_size(AlignSize(node_size() + 1)); } padding_size_ = node_size() - size; neighbors_size_ = neighbors_size(); upper_neighbors_size_ = upper_neighbors_size(); return 0; } int HnswRabitqBuilderEntity::reserve_space(size_t docs) { if (memory_quota_ > 0 && (node_size() * docs + neighbors_size_ * docs + sizeof(NeighborIndex) * docs > memory_quota_)) { return IndexError_NoMemory; } vectors_buffer_.reserve(node_size() * docs); keys_buffer_.reserve(sizeof(key_t) * docs); neighbors_buffer_.reserve(neighbors_size_ * docs); neighbors_index_.reserve(docs); return 0; } int HnswRabitqBuilderEntity::add_vector(level_t level, key_t key, const void *vec, node_id_t *id) { if (memory_quota_ > 0 && (vectors_buffer_.capacity() + keys_buffer_.capacity() + neighbors_buffer_.capacity() + upper_neighbors_buffer_.capacity() + neighbors_index_.capacity() * sizeof(NeighborIndex)) > memory_quota_) { LOG_ERROR("Add vector failed, used memory exceed quota, cur_doc=%zu", static_cast(doc_cnt())); return IndexError_NoMemory; } vectors_buffer_.append(reinterpret_cast(vec), vector_size()); vectors_buffer_.append(padding_size_, '\0'); keys_buffer_.append(reinterpret_cast(&key), sizeof(key)); // init level 0 neighbors neighbors_buffer_.append(neighbors_size_, '\0'); neighbors_index_.emplace_back(upper_neighbors_buffer_.size(), level); // init upper layer neighbors for (level_t cur_level = 1; cur_level <= level; ++cur_level) { upper_neighbors_buffer_.append(upper_neighbors_size_, '\0'); } *id = (*mutable_doc_cnt())++; return 0; } key_t HnswRabitqBuilderEntity::get_key(node_id_t id) const { return *(reinterpret_cast(keys_buffer_.data() + id * sizeof(key_t))); } const void *HnswRabitqBuilderEntity::get_vector(node_id_t id) const { return vectors_buffer_.data() + id * node_size(); } int HnswRabitqBuilderEntity::get_vector( const node_id_t id, IndexStorage::MemoryBlock &block) const { const void *vec = 
get_vector(id); block.reset((void *)vec); return 0; } int HnswRabitqBuilderEntity::get_vector(const node_id_t *ids, uint32_t count, const void **vecs) const { for (uint32_t i = 0; i < count; ++i) { vecs[i] = vectors_buffer_.data() + ids[i] * node_size(); } return 0; } int HnswRabitqBuilderEntity::get_vector( const node_id_t *ids, uint32_t count, std::vector &vec_blocks) const { std::vector vecs(count); get_vector(ids, count, vecs.data()); for (uint32_t i = 0; i < count; ++i) { vec_blocks.emplace_back(IndexStorage::MemoryBlock((void *)vecs[i])); } return 0; } const Neighbors HnswRabitqBuilderEntity::get_neighbors(level_t level, node_id_t id) const { const NeighborsHeader *hd = get_neighbor_header(level, id); return {hd->neighbor_cnt, hd->neighbors}; } int HnswRabitqBuilderEntity::update_neighbors( level_t level, node_id_t id, const std::vector> &neighbors) { NeighborsHeader *hd = const_cast(get_neighbor_header(level, id)); for (size_t i = 0; i < neighbors.size(); ++i) { hd->neighbors[i] = neighbors[i].first; } hd->neighbor_cnt = neighbors.size(); // std::cout << "id: " << id << ", neighbour, id: "; // for (size_t i = 0; i < neighbors.size(); ++i) { // if (i == neighbors.size()-1) // std::cout << neighbors[i].first << ", score:" << neighbors[i].second << // std::endl; // else // std::cout << neighbors[i].first << ", score:" << neighbors[i].second << // ", id: "; // } return 0; } void HnswRabitqBuilderEntity::add_neighbor(level_t level, node_id_t id, uint32_t /*size*/, node_id_t neighbor_id) { NeighborsHeader *hd = const_cast(get_neighbor_header(level, id)); hd->neighbors[hd->neighbor_cnt++] = neighbor_id; return; } int HnswRabitqBuilderEntity::dump(const IndexDumper::Pointer &dumper) { key_t *keys = reinterpret_cast(const_cast(keys_buffer_.data())); auto ret = dump_segments(dumper, keys, [&](node_id_t id) { return get_level(id); }); if (ailego_unlikely(ret < 0)) { return ret; } return 0; } } // namespace core } // namespace zvec 
================================================
FILE: src/core/algorithm/hnsw_rabitq/hnsw_rabitq_builder_entity.h
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include
#include "hnsw_rabitq_entity.h"

namespace zvec {
namespace core {

//! In-memory HNSW entity used while building: vectors, keys and neighbour
//! lists live in growable contiguous buffers owned by this object.
class HnswRabitqBuilderEntity : public HnswRabitqEntity {
 public:
  //! Add vector and key to hnsw entity, and local id will be saved to id
  virtual int add_vector(level_t level, key_t key, const void *vec,
                         node_id_t *id) override;

  //! Get primary key of the node id
  virtual key_t get_key(node_id_t id) const override;

  //! Get vector feature data by node id
  virtual const void *get_vector(node_id_t id) const override;

  //! Batch get vectors feature data by node ids
  virtual int get_vector(const node_id_t *ids, uint32_t count,
                         const void **vecs) const override;

  //! Get vector feature data of a node id as a memory block
  virtual int get_vector(const node_id_t id,
                         IndexStorage::MemoryBlock &block) const override;

  //! Batch get vectors as memory blocks (appended to vec_blocks)
  virtual int get_vector(
      const node_id_t *ids, uint32_t count,
      std::vector &vec_blocks) const override;

  //! Get the node id's neighbors header on graph level.
  //! Level 0 lists are fixed-size slots in neighbors_buffer_ indexed by id;
  //! upper levels start at the node's offset in upper_neighbors_buffer_,
  //! one upper_neighbors_size_ slot per level (level 1 first).
  const NeighborsHeader *get_neighbor_header(level_t level,
                                             node_id_t id) const {
    if (level == 0) {
      return reinterpret_cast(
          neighbors_buffer_.data() + neighbors_size_ * id);
    } else {
      size_t offset = neighbors_index_[id].offset;
      return reinterpret_cast(
          upper_neighbors_buffer_.data() + offset +
          (level - 1) * upper_neighbors_size_);
    }
  }

  //! Get the node id's neighbors on graph level
  virtual const Neighbors get_neighbors(level_t level,
                                        node_id_t id) const override;

  //! Replace node id in level's neighbors
  virtual int update_neighbors(
      level_t level, node_id_t id,
      const std::vector> &neighbors) override;

  //! add a neighbor to id in graph level
  virtual void add_neighbor(level_t level, node_id_t id, uint32_t size,
                            node_id_t neighbor_id) override;

  //! Dump the hnsw graph to dumper
  virtual int dump(const IndexDumper::Pointer &dumper) override;

  //! Cleanup the entity
  virtual int cleanup(void) override;

 public:
  //! Constructor
  HnswRabitqBuilderEntity();

  //! Get the node graph level by id
  level_t get_level(node_id_t id) const {
    return neighbors_index_[id].level;
  }

  //! Init builder entity (computes node and neighbour-list sizes)
  int init();

  //! reserve buffer space for documents
  //! @param docs number of documents
  int reserve_space(size_t docs);

  //! Set memory quota params (0 disables the quota checks)
  inline void set_memory_quota(size_t memory_quota) {
    memory_quota_ = memory_quota;
  }

  //! Get level-0 neighbors slot size: header plus l0_neighbor_cnt() ids
  inline size_t neighbors_size() const {
    return sizeof(NeighborsHeader) + l0_neighbor_cnt() * sizeof(node_id_t);
  }

  //! Get upper-layer neighbors slot size: header plus upper_neighbor_cnt() ids
  inline size_t upper_neighbors_size() const {
    return sizeof(NeighborsHeader) + upper_neighbor_cnt() * sizeof(node_id_t);
  }

 public:
  HnswRabitqBuilderEntity(const HnswRabitqBuilderEntity &) = delete;
  HnswRabitqBuilderEntity &operator=(const HnswRabitqBuilderEntity &) = delete;

 private:
  friend class HnswRabitqSearcherEntity;

  //! class internal used only: per-node record of where its upper-layer
  //! neighbour lists start (offset) and its graph level, packed in 64 bits.
  struct NeighborIndex {
    NeighborIndex(size_t off, level_t l) : offset(off), level(l) {}
    uint64_t offset : 48;
    uint64_t level : 16;
  };

  std::string vectors_buffer_{};          // aligned (padded) vectors
  std::string keys_buffer_{};             // packed keys, one key_t per node
  std::string neighbors_buffer_{};        // level 0 neighbors buffer
  std::string upper_neighbors_buffer_{};  // upper layer neighbors buffer
  std::string sparse_data_buffer_{};      // aligned sparse data buffer
  size_t sparse_data_offset_{0};          //
  // upper layer offset + level in upper_neighbors_buffer_
  std::vector neighbors_index_{};

  size_t memory_quota_{0UL};
  size_t neighbors_size_{0U};        // level 0 neighbors slot size
  size_t upper_neighbors_size_{0U};  // upper layer neighbors slot size
  size_t padding_size_{};            // padding size for each vector element
};

}  // namespace core
}  // namespace zvec
================================================
FILE: src/core/algorithm/hnsw_rabitq/hnsw_rabitq_chunk.cc
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "hnsw_rabitq_chunk.h"
#include
#include
#include
#include
#include "zvec/core/framework/index_error.h"
#include "zvec/core/framework/index_helper.h"
#include "zvec/core/framework/index_logger.h"
#include "zvec/core/framework/index_streamer.h"

namespace zvec {
namespace core {

//! Initialize a fresh index: reset the meta, allocate the page-aligned meta
//! segment, and write the meta to it.
int HnswRabitqChunkBroker::init_storage(size_t chunk_size) {
  chunk_meta_.clear();
  chunk_meta_.chunk_size = chunk_size;
  chunk_meta_.create_time = ailego::Realtime::Seconds();
  stats_.set_create_time(chunk_meta_.create_time);
  chunk_meta_.update_time = ailego::Realtime::Seconds();
  stats_.set_update_time(chunk_meta_.update_time);

  //! alloc meta chunk
  size_t size = sizeof(HnswChunkMeta);
  size = (size + page_mask_) & (~page_mask_);
  const std::string segment_id =
      make_segment_id(CHUNK_TYPE_META, kDefaultChunkSeqId);
  int ret = stg_->append(segment_id, size);
  if (ailego_unlikely(ret != 0)) {
    LOG_ERROR("Storage append segment failed for %s", IndexError::What(ret));
    return ret;
  }
  chunk_meta_segment_ = get_chunk(CHUNK_TYPE_META, kDefaultChunkSeqId);
  if (ailego_unlikely(!chunk_meta_segment_)) {
    LOG_ERROR("Get meta segment failed");
    return IndexError_Runtime;
  }

  //! update meta info and write to storage
  chunk_meta_.chunk_cnts[CHUNK_TYPE_META] += 1;
  chunk_meta_.total_size += size;
  (*stats_.mutable_index_size()) += size;
  size = chunk_meta_segment_->write(0UL, &chunk_meta_, sizeof(HnswChunkMeta));
  if (ailego_unlikely(size != sizeof(HnswChunkMeta))) {
    LOG_ERROR("Storage write data failed, wsize=%zu", size);
    return IndexError_WriteData;
  }

  return 0;
}

//! Load the meta of an existing index and check it was created with the same
//! chunk size.
int HnswRabitqChunkBroker::load_storage(size_t chunk_size) {
  IndexStorage::MemoryBlock data_block;
  size_t size = chunk_meta_segment_->read(0UL, data_block,
                                          chunk_meta_segment_->data_size());
  if (size != sizeof(HnswChunkMeta)) {
    LOG_ERROR("Invalid hnsw meta chunk, read size=%zu chunk size=%zu", size,
              chunk_meta_segment_->data_size());
    return IndexError_InvalidFormat;
  }
  std::memcpy(&chunk_meta_, data_block.data(), size);

  if (chunk_meta_.chunk_size != chunk_size) {
    LOG_ERROR(
        "Params hnsw chunk size=%zu mismatch from previous %zu "
        "in index",
        chunk_size, (size_t)chunk_meta_.chunk_size);
    return IndexError_Mismatch;
  }

  *stats_.mutable_check_point() = stg_->check_point();
  stats_.set_revision_id(chunk_meta_.revision_id);
  stats_.set_update_time(chunk_meta_.update_time);
  stats_.set_create_time(chunk_meta_.create_time);
  char create_time[32];
  char update_time[32];
  ailego::Realtime::Gmtime(chunk_meta_.create_time, "%Y-%m-%d %H:%M:%S",
                           create_time, sizeof(create_time));
  ailego::Realtime::Gmtime(chunk_meta_.update_time, "%Y-%m-%d %H:%M:%S",
                           update_time, sizeof(update_time));
  LOG_DEBUG(
      "Load index, indexSize=%zu chunkSize=%zu nodeChunks=%zu "
      "upperNeighborChunks=%zu revisionId=%zu "
      "createTime=%s updateTime=%s",
      (size_t)chunk_meta_.total_size, (size_t)chunk_meta_.chunk_size,
      (size_t)chunk_meta_.chunk_cnts[CHUNK_TYPE_NODE],
      (size_t)chunk_meta_.chunk_cnts[CHUNK_TYPE_UPPER_NEIGHBOR],
      (size_t)chunk_meta_.revision_id, create_time, update_time);

  return 0;
}

//! Attach `stg` and either create a new index (no meta segment yet) or load
//! the existing one.
//! @return IndexError_Duplicate if a storage is already attached,
//!         IndexError_InvalidArgument if `stg` is null, otherwise the result
//!         of init_storage()/load_storage().
int HnswRabitqChunkBroker::open(IndexStorage::Pointer stg,
                                size_t max_index_size, size_t chunk_size,
                                bool check_crc) {
  if (ailego_unlikely(stg_)) {
    LOG_ERROR("An storage instance is already opened");
    return IndexError_Duplicate;
  }
  if (ailego_unlikely(!stg)) {
    LOG_ERROR("Invalid null storage instance");
    return IndexError_InvalidArgument;
  }

  stg_ = std::move(stg);
  if (stg_->isHugePage()) {
    page_mask_ = ailego::MemoryHelper::HugePageSize() - 1;
  } else {
    page_mask_ = ailego::MemoryHelper::PageSize() - 1;
  }
  check_crc_ = check_crc;
  max_chunks_size_ = max_index_size;
  dirty_ = false;

  const std::string segment_id =
      make_segment_id(CHUNK_TYPE_META, kDefaultChunkSeqId);
  chunk_meta_segment_ = stg_->get(segment_id);
  if (!chunk_meta_segment_) {
    LOG_DEBUG("Create new index");
    return init_storage(chunk_size);
  }

  return load_storage(chunk_size);
}

//! Flush (only if a storage and meta segment are attached) and detach.
//! Safe to call on a broker that was never opened.
int HnswRabitqChunkBroker::close(void) {
  if (stg_ && chunk_meta_segment_) {
    flush(0UL);
  }
  chunk_meta_segment_.reset();
  stg_.reset();
  check_crc_ = false;
  dirty_ = false;

  return 0;
}

//! Rewrite the meta (with a fresh update time), then refresh and flush the
//! storage at `checkpoint`.
//! @return IndexError_Uninitialized when no storage/meta segment is attached,
//!         otherwise the storage flush result.
int HnswRabitqChunkBroker::flush(uint64_t checkpoint) {
  if (ailego_unlikely(!stg_ || !chunk_meta_segment_)) {
    LOG_ERROR("Init storage first");
    return IndexError_Uninitialized;
  }

  chunk_meta_.update_time = ailego::Realtime::Seconds();
  stats_.set_update_time(chunk_meta_.update_time);
  size_t size =
      chunk_meta_segment_->write(0UL, &chunk_meta_, sizeof(HnswChunkMeta));
  if (ailego_unlikely(size != sizeof(HnswChunkMeta))) {
    LOG_ERROR("Storage write data failed, wsize=%zu", size);
  }
  stg_->refresh(checkpoint);
  int ret = stg_->flush();
  if (ret == 0) {
    (*stats_.mutable_check_point()) = checkpoint;
  } else {
    LOG_ERROR("Storage flush failed for %s", IndexError::What(ret));
  }

  return ret;
}

//! Allocate (page-aligned) segment `type`/`seq_id`, or return the existing
//! segment with that id. Not thread-safe.
//! @return {0, chunk} on success; otherwise an error code with a null chunk.
std::pair<int, Chunk::Pointer> HnswRabitqChunkBroker::alloc_chunk(
    int type, uint64_t seq_id, size_t size) {
  ailego_assert_with(type < CHUNK_TYPE_MAX, "chunk type overflow");
  Chunk::Pointer chunk;

  if (ailego_unlikely(!stg_)) {
    LOG_ERROR("Init storage first");
    return std::make_pair(IndexError_Uninitialized, chunk);
  }

  //! check exist a empty chunk with the same name
  chunk = get_chunk(type, seq_id);
  if (chunk) {
    // NOTE(review): as written this rejects an existing, empty chunk whose
    // capacity equals `size` (the case the comment above describes) and
    // reuses every other existing chunk. `size` is also not page-aligned yet
    // at this point. Confirm the intended reuse condition.
    if (ailego_unlikely(chunk->capacity() == size &&
                        chunk->data_size() == 0UL)) {
      LOG_ERROR("Exist invalid chunk size %zu, expect size %zu",
                chunk->capacity(), size);
      chunk.reset();
      return std::make_pair(IndexError_Runtime, chunk);
    }
    return std::make_pair(0, chunk);
  }

  //! align to page size
  size = (size + page_mask_) & (~page_mask_);
  if (ailego_unlikely(chunk_meta_.total_size + size >= max_chunks_size_)) {
    LOG_ERROR("No space to new a chunk, curIndexSize=%zu allocSize=%zu",
              (size_t)chunk_meta_.total_size, size);
    return std::make_pair(IndexError_IndexFull, chunk);
  }

  std::string segment_id = make_segment_id(type, seq_id);
  int ret = stg_->append(segment_id, size);
  if (ailego_unlikely(ret != 0)) {
    LOG_ERROR("Storage append segment failed for %s", IndexError::What(ret));
    return std::make_pair(ret, chunk);
  }

  chunk_meta_.chunk_cnts[type] += 1;
  chunk_meta_.total_size += size;
  (*stats_.mutable_index_size()) += size;
  size = chunk_meta_segment_->write(0UL, &chunk_meta_, sizeof(HnswChunkMeta));
  if (ailego_unlikely(size != sizeof(HnswChunkMeta))) {
    LOG_ERROR("Storage append segment failed, wsize=%zu", size);
  }
  chunk = get_chunk(type, seq_id);
  return std::make_pair(chunk ? 0 : IndexError_NoMemory, chunk);
}

//! Look up segment `type`/`seq_id` in the attached storage.
Chunk::Pointer HnswRabitqChunkBroker::get_chunk(int type,
                                                uint64_t seq_id) const {
  std::string segment_id = make_segment_id(type, seq_id);
  return stg_->get(segment_id);
}

}  // namespace core
}  // namespace zvec
================================================
FILE: src/core/algorithm/hnsw_rabitq/hnsw_rabitq_chunk.h
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include
#include
#include
#include
#include
#include
#include
#include
#include
#include "zvec/core/framework/index_error.h"
#include "zvec/core/framework/index_logger.h"
#include "zvec/core/framework/index_storage.h"
#include "zvec/core/framework/index_streamer.h"

namespace zvec {
namespace core {

using Chunk = IndexStorage::Segment;

//! Manages the storage segments ("chunks") of an HNSW index. Segments are
//! named "HnswT<type>S<seq_id>". A meta segment records per-type chunk counts,
//! total size, revision and timestamps.
class HnswRabitqChunkBroker {
 public:
  typedef std::shared_ptr Pointer;

  //! Chunk kinds; each value indexes HnswChunkMeta::chunk_cnts, so every value
  //! must stay below CHUNK_TYPE_MAX.
  enum CHUNK_TYPE {
    CHUNK_TYPE_HEADER = 1,
    CHUNK_TYPE_META = 2,
    CHUNK_TYPE_NODE = 3,
    CHUNK_TYPE_UPPER_NEIGHBOR = 4,
    CHUNK_TYPE_NEIGHBOR_INDEX = 5,
    CHUNK_TYPE_SPARSE_NODE = 6,
    CHUNK_TYPE_MAX = 8
  };
  static constexpr size_t kDefaultChunkSeqId = 0UL;

  //! `stats` is held by reference and must outlive the broker.
  HnswRabitqChunkBroker(IndexStreamer::Stats &stats) : stats_(stats) {}

  //! Open storage
  int open(IndexStorage::Pointer stg, size_t max_index_size, size_t chunk_size,
           bool check_crc);

  //! Flush and detach the storage
  int close(void);

  //! Persist the meta and flush the storage at `checkpoint`
  int flush(uint64_t checkpoint);

  //! alloc a new chunk with size, not thread-safe
  std::pair alloc_chunk(int type, uint64_t seq_id,
                                              size_t size);

  //! alloc a new chunk with chunk size
  inline std::pair alloc_chunk(int type, uint64_t seq_id) {
    return alloc_chunk(type, seq_id, chunk_meta_.chunk_size);
  }

  //! Get an existing chunk by type and sequence id (null if absent)
  Chunk::Pointer get_chunk(int type, uint64_t seq_id) const;

  //! Number of chunks allocated so far for `type`
  inline size_t get_chunk_cnt(int type) const {
    ailego_assert_with(type < CHUNK_TYPE_MAX, "chunk type overflow");
    return chunk_meta_.chunk_cnts[type];
  }

  inline bool dirty(void) const {
    return dirty_;
  }

  //! Mark the index modified; the revision id is bumped only on the first
  //! call since open().
  inline void mark_dirty(void) {
    if (!dirty_) {
      dirty_ = true;
      chunk_meta_.revision_id += 1;
      stats_.set_revision_id(chunk_meta_.revision_id);
    }
  }

  const IndexStorage::Pointer storage(void) const {
    return stg_;
  }

 private:
  HnswRabitqChunkBroker(const HnswRabitqChunkBroker &) = delete;
  HnswRabitqChunkBroker &operator=(const HnswRabitqChunkBroker &) = delete;

  //! On-storage meta record; copied to/from storage with memcpy, so it must
  //! remain a plain block of uint64_t fields.
  struct HnswChunkMeta {
    HnswChunkMeta(void) {
      memset(this, 0, sizeof(HnswChunkMeta));
    }
    void clear() {
      memset(this, 0, sizeof(HnswChunkMeta));
    }
    uint64_t chunk_cnts[CHUNK_TYPE_MAX];
    uint64_t chunk_size;   // size of per chunk
    uint64_t total_size;   // total size of allocated chunk
    uint64_t revision_id;  // index revision
    uint64_t create_time;
    uint64_t update_time;
    uint64_t reserved[3];
  };

  static_assert(sizeof(HnswChunkMeta) % 32 == 0,
                "HnswChunkMeta must be aligned with 32 bytes");

  //! Init the storage after open an empty index
  int init_storage(size_t chunk_size);

  //! Load index from storage
  int load_storage(size_t chunk_size);

  //! Segment name for a chunk: "HnswT<type>S<seq_id>"
  static inline const std::string make_segment_id(int type, uint64_t seq_id) {
    return "HnswT" + ailego::StringHelper::ToString(type) + "S" +
           ailego::StringHelper::ToString(seq_id);
  }

 private:
  IndexStreamer::Stats &stats_;
  HnswChunkMeta chunk_meta_{};
  size_t page_mask_{0UL};        // page size - 1, used to round up allocations
  size_t max_chunks_size_{0UL};  // upper bound on total allocated size
  IndexStorage::Pointer stg_{};
  IndexStorage::Segment::Pointer chunk_meta_segment_{};
  bool check_crc_{false};
  bool dirty_{false};  // set as true if index is modified , the flag
                       // will not be cleared even if flushed
};

}  // namespace core
}  // namespace zvec
================================================
FILE: src/core/algorithm/hnsw_rabitq/hnsw_rabitq_context.cc
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "hnsw_rabitq_context.h"
#include
#include "hnsw_rabitq_params.h"

namespace zvec {
namespace core {

HnswRabitqContext::HnswRabitqContext(size_t dimension,
                                     const IndexMetric::Pointer &metric,
                                     const HnswRabitqEntity::Pointer &entity)
    : IndexContext(metric),
      entity_(entity),
      add_dc_(entity_.get(), metric, dimension) {}

HnswRabitqContext::HnswRabitqContext(const IndexMetric::Pointer &metric,
                                     const HnswRabitqEntity::Pointer &entity)
    : IndexContext(metric), entity_(entity), add_dc_(entity_.get(), metric) {}

HnswRabitqContext::~HnswRabitqContext() {
  visit_filter_.destroy();
}

//! Set up the visit filter and heaps for the given context type.
//! Builder contexts always use a ByteMap filter sized to the current doc
//! count. Streamer contexts derive max_scan_num_ from the doc count and
//! reserve extra filter capacity for documents added later.
//! @return 0 on success, the filter's error code, or IndexError_Runtime for
//!         an unknown type.
int HnswRabitqContext::init(ContextType type) {
  int ret;
  uint32_t doc_cnt;

  type_ = type;

  switch (type) {
    case kBuilderContext:
      ret = visit_filter_.init(VisitFilter::ByteMap, entity_->doc_cnt(),
                               max_scan_num_, negative_probability_);
      if (ret != 0) {
        // Report the mode that was actually requested (always ByteMap here).
        LOG_ERROR("Create filter failed, mode %d", VisitFilter::ByteMap);
        return ret;
      }
      candidates_.limit(max_scan_num_);
      update_heap_.limit(entity_->l0_neighbor_cnt() + 1);
      break;

    case kSearcherContext:
      ret = visit_filter_.init(filter_mode_, entity_->doc_cnt(), max_scan_num_,
                               negative_probability_);
      if (ret != 0) {
        LOG_ERROR("Create filter failed, mode %d", filter_mode_);
        return ret;
      }

      candidates_.limit(max_scan_num_);
      break;

    case kStreamerContext:
      // maxScanNum is unknown if inited from streamer, so the docCnt may
      // change. we need to compute maxScanNum by scan ratio, and preserve
      // max_doc_cnt space from visit filter
      doc_cnt = entity_->doc_cnt();
      max_scan_num_ = compute_max_scan_num(doc_cnt);
      reserve_max_doc_cnt_ = doc_cnt + compute_reserve_cnt(doc_cnt);
      ret = visit_filter_.init(filter_mode_, reserve_max_doc_cnt_,
                               max_scan_num_, negative_probability_);
      if (ret != 0) {
        LOG_ERROR("Create filter failed, mode %d", filter_mode_);
        return ret;
      }

      update_heap_.limit(entity_->l0_neighbor_cnt() + 1);
      candidates_.limit(max_scan_num_);

      check_need_adjuct_ctx();
      break;

    default:
      LOG_ERROR("Init context failed");
      return IndexError_Runtime;
  }

  return 0;
}

//! Apply searcher/streamer query parameters (ef, scan ratio/limits, brute
//! force threshold, visit-filter mode and bloom-filter negative probability).
//! The visit filter is rebuilt only when its mode or, in BloomFilter mode,
//! its negative probability changes.
//! @return 0 on success, IndexError_InvalidArgument for out-of-range streamer
//!         scan settings, a filter error code, or IndexError_Runtime for
//!         other context types.
int HnswRabitqContext::update(const ailego::Params &params) {
  auto update_visit_filter_param = [&]() {
    bool need_update = false;
    std::string p;
    switch (type_) {
      case kSearcherContext:
        p = PARAM_HNSW_RABITQ_SEARCHER_VISIT_BLOOMFILTER_ENABLE;
        break;
      case kStreamerContext:
        p = PARAM_HNSW_RABITQ_STREAMER_VISIT_BLOOMFILTER_ENABLE;
        break;
    }
    if (params.has(p)) {
      bool bf_enabled;
      params.get(p, &bf_enabled);
      if (bf_enabled ^ (filter_mode_ == VisitFilter::BloomFilter)) {
        need_update = true;
        filter_mode_ =
            bf_enabled ? VisitFilter::BloomFilter : VisitFilter::ByteMap;
      }
    }

    float prob = negative_probability_;
    p.clear();
    switch (type_) {
      case kSearcherContext:
        p = PARAM_HNSW_RABITQ_SEARCHER_VISIT_BLOOMFILTER_NEGATIVE_PROB;
        break;
      case kStreamerContext:
        p = PARAM_HNSW_RABITQ_STREAMER_VISIT_BLOOMFILTER_NEGATIVE_PROB;
        break;
    }
    params.get(p, &prob);
    // Store the requested probability: the filter below is (re)built from
    // negative_probability_, and later calls must compare against the value
    // actually in use, not the stale one.
    if (std::abs(prob - negative_probability_) > 1e-6) {
      negative_probability_ = prob;
      if (filter_mode_ == VisitFilter::BloomFilter) {
        need_update = true;
      }
    }

    if (need_update) {
      visit_filter_.destroy();
      int max_doc_cnt = 0;
      if (type_ == kSearcherContext) {
        max_doc_cnt = entity_->doc_cnt();
      } else {
        max_doc_cnt = reserve_max_doc_cnt_;
      }
      int ret = visit_filter_.init(filter_mode_, max_doc_cnt, max_scan_num_,
                                   negative_probability_);
      if (ret != 0) {
        LOG_ERROR("Create filter failed, mode %d", filter_mode_);
        return ret;
      }
    }
    return 0;
  };

  switch (type_) {
    case kSearcherContext:
      if (params.has(PARAM_HNSW_RABITQ_SEARCHER_EF)) {
        params.get(PARAM_HNSW_RABITQ_SEARCHER_EF, &ef_);
        topk_heap_.limit(std::max(topk_, ef_));
      }

      if (params.has(PARAM_HNSW_RABITQ_SEARCHER_MAX_SCAN_RATIO)) {
        params.get(PARAM_HNSW_RABITQ_SEARCHER_MAX_SCAN_RATIO,
                   &max_scan_ratio_);
        max_scan_num_ =
            static_cast<uint32_t>(max_scan_ratio_ * entity_->doc_cnt());
        max_scan_num_ = std::max(10000U, max_scan_num_);
      }

      if (params.has(PARAM_HNSW_RABITQ_SEARCHER_BRUTE_FORCE_THRESHOLD)) {
        params.get(PARAM_HNSW_RABITQ_SEARCHER_BRUTE_FORCE_THRESHOLD,
                   &bruteforce_threshold_);
      }

      return update_visit_filter_param();

    case kStreamerContext:
      if (params.has(PARAM_HNSW_RABITQ_STREAMER_EF)) {
        params.get(PARAM_HNSW_RABITQ_STREAMER_EF, &ef_);
        topk_heap_.limit(std::max(topk_, ef_));
      }
      params.get(PARAM_HNSW_RABITQ_STREAMER_MAX_SCAN_RATIO, &max_scan_ratio_);
      params.get(PARAM_HNSW_RABITQ_STREAMER_MAX_SCAN_LIMIT, &max_scan_limit_);
      params.get(PARAM_HNSW_RABITQ_STREAMER_MIN_SCAN_LIMIT, &min_scan_limit_);
      if (max_scan_ratio_ <= 0.0f || max_scan_ratio_ > 1.0f) {
        LOG_ERROR("[%s] must be in range (0.0f,1.0f]",
                  PARAM_HNSW_RABITQ_STREAMER_MAX_SCAN_RATIO.c_str());
        return IndexError_InvalidArgument;
      }
      if (max_scan_limit_ < min_scan_limit_) {
        LOG_ERROR("[%s] must be >= [%s]",
                  PARAM_HNSW_RABITQ_STREAMER_MAX_SCAN_LIMIT.c_str(),
                  PARAM_HNSW_RABITQ_STREAMER_MIN_SCAN_LIMIT.c_str());
        return IndexError_InvalidArgument;
      }

      if (params.has(PARAM_HNSW_RABITQ_STREAMER_BRUTE_FORCE_THRESHOLD)) {
        params.get(PARAM_HNSW_RABITQ_STREAMER_BRUTE_FORCE_THRESHOLD,
                   &bruteforce_threshold_);
      }

      return update_visit_filter_param();

    default:
      LOG_ERROR("update context failed, type=%zu", static_cast<size_t>(type_));
      return IndexError_Runtime;
  }
}

//! Re-target a pooled context at `entity` (possibly a different
//! searcher/streamer of the same type). Resets the visit filter and heap
//! limits, rebinds the distance calculator, and stamps `magic_num`.
//! magic_ stays kInvalidMgic if any step fails.
int HnswRabitqContext::update_context(ContextType type, const IndexMeta &meta,
                                      const IndexMetric::Pointer &metric,
                                      const HnswRabitqEntity::Pointer &entity,
                                      uint32_t magic_num) {
  uint32_t doc_cnt;

  if (ailego_unlikely(type != type_)) {
    LOG_ERROR(
        "HnswRabitqContext doesn't support shared by different type, "
        "src=%u dst=%u",
        static_cast<uint32_t>(type_), static_cast<uint32_t>(type));
    return IndexError_Unsupported;
  }

  magic_ = kInvalidMgic;

  // TODO: support change filter mode?
  switch (type) {
    case kBuilderContext:
      LOG_ERROR("BuildContext doesn't support update");
      return IndexError_NotImplemented;

    case kSearcherContext:
      if (!visit_filter_.reset(entity->doc_cnt(), max_scan_num_)) {
        LOG_ERROR("Reset filter failed, mode %d", visit_filter_.get_mode());
        return IndexError_Runtime;
      }

      candidates_.limit(max_scan_num_);
      topk_heap_.limit(std::max(topk_, ef_));
      break;

    case kStreamerContext:
      doc_cnt = entity->doc_cnt();
      max_scan_num_ = compute_max_scan_num(doc_cnt);
      reserve_max_doc_cnt_ = doc_cnt + compute_reserve_cnt(doc_cnt);
      if (!visit_filter_.reset(reserve_max_doc_cnt_, max_scan_num_)) {
        LOG_ERROR("Reset filter failed, mode %d", visit_filter_.get_mode());
        return IndexError_Runtime;
      }

      update_heap_.limit(entity->l0_neighbor_cnt() + 1);
      candidates_.limit(max_scan_num_);
      topk_heap_.limit(std::max(topk_, ef_));
      break;

    default:
      LOG_ERROR("update context failed");
      return IndexError_Runtime;
  }

  entity_ = entity;
  dc().update(entity_.get(), metric, meta.dimension());
  magic_ = magic_num;
  level_topks_.clear();

  return 0;
}

//! Pad topk_heap_ with not-yet-visited documents (skipping filtered keys)
//! until it is full or every document has been tried once. Ids are random
//! when the heap limit is small relative to doc_cnt; otherwise they are
//! sequential from a random start, wrapping around.
//! Precondition: entity_->doc_cnt() > 0.
void HnswRabitqContext::fill_random_to_topk_full(void) {
  // One generator per thread: a single function-static engine would be
  // mutated concurrently by contexts running on different search threads.
  static thread_local std::mt19937 mt(
      std::chrono::system_clock::now().time_since_epoch().count());
  std::uniform_int_distribution<node_id_t> dt(0, entity_->doc_cnt() - 1);
  std::function<node_id_t(void)> gen;
  node_id_t seqid;
  std::function<bool(node_id_t)> myfilter = [](node_id_t) { return false; };
  if (this->filter().is_valid()) {
    myfilter = [&](node_id_t id) {
      return this->filter()(entity_->get_key(id));
    };
  }

  if (topk_heap_.limit() < entity_->doc_cnt() / 2) {
    gen = [&](void) { return dt(mt); };
  } else {
    // If topk limit is big value, gen sequential id from an random initial
    seqid = dt(mt);
    gen = [&](void) {
      seqid = seqid == (entity_->doc_cnt() - 1) ? 0 : (seqid + 1);
      return seqid;
    };
  }
  for (size_t i = 0; !topk_heap_.full() && i < entity_->doc_cnt(); ++i) {
    const auto id = gen();
    if (!visit_filter_.visited(id) && !myfilter(id)) {
      visit_filter_.set_visited(id);
      topk_heap_.emplace(id, dc().dist(id));
    }
  }
  return;
}

}  // namespace core
}  // namespace zvec
================================================
FILE: src/core/algorithm/hnsw_rabitq/hnsw_rabitq_context.h
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include
#include "utility/visit_filter.h"
#include "zvec/core/framework/index_context.h"
#include "zvec/core/framework/index_provider.h"
#include "hnsw_rabitq_dist_calculator.h"
#include "hnsw_rabitq_entity.h"

namespace zvec {
namespace core {

class HnswRabitqContext : public IndexContext {
 public:
  //! Index Context Pointer
  typedef std::unique_ptr Pointer;

  enum ContextType {
    kUnknownContext = 0,
    kSearcherContext = 1,
    kBuilderContext = 2,
    kStreamerContext = 3
  };

  //! Construct
  HnswRabitqContext(size_t dimension, const IndexMetric::Pointer &metric,
                    const HnswRabitqEntity::Pointer &entity);

  //! Construct
  HnswRabitqContext(const IndexMetric::Pointer &metric,
                    const HnswRabitqEntity::Pointer &entity);

  //! Destructor
  virtual ~HnswRabitqContext();

 public:
  //! Set topk of search result
  virtual void set_topk(uint32_t val) override {
    topk_ = val;
    topk_heap_.limit(std::max(val, ef_));
  }

  //!
Retrieve search result virtual const IndexDocumentList &result(void) const override { return results_[0]; } //! Retrieve search result virtual const IndexDocumentList &result(size_t idx) const override { return results_[idx]; } //! Retrieve result object for output virtual IndexDocumentList *mutable_result(size_t idx) override { ailego_assert_with(idx < results_.size(), "invalid idx"); return &results_[idx]; } //! Retrieve search group result with index virtual const IndexGroupDocumentList &group_result(void) const override { return group_results_[0]; } //! Retrieve search group result with index virtual const IndexGroupDocumentList &group_result( size_t idx) const override { return group_results_[idx]; } virtual uint32_t magic(void) const override { return magic_; } //! Set mode of debug virtual void set_debug_mode(bool enable) override { debug_mode_ = enable; } //! Retrieve mode of debug virtual bool debug_mode(void) const override { return this->debugging(); } //! Retrieve string of debug virtual std::string debug_string(void) const override { char buf[4096]; size_t size = snprintf( buf, sizeof(buf), "scan_cnt=%zu,get_vector_cnt=%u,get_neighbors_cnt=%u,dup_node=%u", get_scan_num(), stats_get_vector_cnt_, stats_get_neighbors_cnt_, stats_visit_dup_cnt_); return std::string(buf, size); } //! Update the parameters of context virtual int update(const ailego::Params ¶ms) override; public: //! Init context int init(ContextType type); //! Update context, the context may be shared by different searcher/streamer int update_context(ContextType type, const IndexMeta &meta, const IndexMetric::Pointer &metric, const HnswRabitqEntity::Pointer &entity, uint32_t magic_num); inline const HnswRabitqEntity &get_entity() const { return *entity_; } inline void resize_results(size_t size) { if (group_by_search()) { group_results_.resize(size); } else { results_.resize(size); } } inline void topk_to_result() { return topk_to_result(0); } //! 
Construct result from topk heap, result will be normalized inline void topk_to_result(uint32_t idx) { if (group_by_search()) { topk_to_group_result(idx); } else { topk_to_single_result(idx); } } inline void topk_to_single_result(uint32_t idx) { if (force_padding_topk_ && !topk_heap_.full() && topk_heap_.size() < entity_->doc_cnt()) { this->fill_random_to_topk_full(); } if (ailego_unlikely(topk_heap_.size() == 0)) { return; } ailego_assert_with(idx < results_.size(), "invalid idx"); int size = std::min(topk_, static_cast(topk_heap_.size())); topk_heap_.sort(); results_[idx].clear(); for (int i = 0; i < size; ++i) { auto score = topk_heap_[i].second; if (score.est_dist > this->threshold()) { break; } node_id_t id = topk_heap_[i].first; if (fetch_vector_) { results_[idx].emplace_back(entity_->get_key(id), score.est_dist, id, entity_->get_vector(id)); } else { results_[idx].emplace_back(entity_->get_key(id), score.est_dist, id); } } return; } //! Construct result from topk heap, result will be normalized inline void topk_to_group_result(uint32_t idx) { ailego_assert_with(idx < group_results_.size(), "invalid idx"); group_results_[idx].clear(); std::vector> group_topk_list; std::vector> best_score_in_groups; for (auto itr = group_topk_heaps_.begin(); itr != group_topk_heaps_.end(); itr++) { const std::string &group_id = (*itr).first; auto &heap = (*itr).second; heap.sort(); if (heap.size() > 0) { ResultRecord best_score = heap[0].second; best_score_in_groups.push_back(std::make_pair(group_id, best_score)); } } std::sort(best_score_in_groups.begin(), best_score_in_groups.end(), [](const std::pair &a, const std::pair &b) -> int { return a.second < b.second; }); // truncate to group num for (uint32_t i = 0; i < group_num() && i < best_score_in_groups.size(); ++i) { const std::string &group_id = best_score_in_groups[i].first; group_topk_list.emplace_back( std::make_pair(group_id, group_topk_heaps_[group_id])); } group_results_[idx].resize(group_topk_list.size()); for 
(uint32_t i = 0; i < group_topk_list.size(); ++i) { const std::string &group_id = group_topk_list[i].first; group_results_[idx][i].set_group_id(group_id); uint32_t size = std::min( group_topk_, static_cast(group_topk_list[i].second.size())); for (uint32_t j = 0; j < size; ++j) { auto score = group_topk_list[i].second[j].second; if (score > this->threshold()) { break; } node_id_t id = group_topk_list[i].second[j].first; if (fetch_vector_) { group_results_[idx][i].mutable_docs()->emplace_back( entity_->get_key(id), score.est_dist, id, entity_->get_vector(id)); } else { group_results_[idx][i].mutable_docs()->emplace_back( entity_->get_key(id), score.est_dist, id); } } } } inline void reset_query(const void *query) { if (auto query_preprocess_func = index_metric_->get_query_preprocess_func(); query_preprocess_func != nullptr) { size_t dim = dc().dimension(); preprocess_buffer_.resize(dim); memcpy(preprocess_buffer_.data(), query, dim); query_preprocess_func(preprocess_buffer_.data(), dim); query = preprocess_buffer_.data(); } dc().reset_query(query); dc().clear_compare_cnt(); query_ = query; } inline HnswRabitqAddDistCalculator &dist_calculator() { return dc(); } inline TopkHeap &topk_heap() { return topk_heap_; } inline TopkHeap &update_heap() { return update_heap_; } inline VisitFilter &visit_filter() { return visit_filter_; } inline CandidateHeap &candidates() { return candidates_; } inline void set_max_scan_num(uint32_t max_scan_num) { max_scan_num_ = max_scan_num; } inline void set_max_scan_limit(uint32_t max_scan_limit) { max_scan_limit_ = max_scan_limit; } inline void set_min_scan_limit(uint32_t min_scan_limit) { min_scan_limit_ = min_scan_limit; } inline void set_ef(uint32_t v) { ef_ = v; } inline void set_filter_mode(uint32_t v) { filter_mode_ = v; } inline void set_filter_negative_probability(float v) { negative_probability_ = v; } inline void set_max_scan_ratio(float v) { max_scan_ratio_ = v; } virtual void set_magic(uint32_t v) { magic_ = v; } virtual void 
set_force_padding_topk(bool v) { force_padding_topk_ = v; } void set_bruteforce_threshold(uint32_t v) override { bruteforce_threshold_ = v; } inline uint32_t get_bruteforce_threshold() const { return bruteforce_threshold_; } void set_fetch_vector(bool v) override { fetch_vector_ = v; } bool fetch_vector() const override { return fetch_vector_; } //! Reset context void reset(void) override { set_filter(nullptr); reset_threshold(); set_fetch_vector(false); set_group_params(0, 0); reset_group_by(); } inline std::map &group_topk_heaps() { return group_topk_heaps_; } inline TopkHeap &level_topk(int level) { if (ailego_unlikely(level_topks_.size() <= static_cast(level))) { int cur_level = level_topks_.size(); level_topks_.resize(level + 1); for (; cur_level <= level; ++cur_level) { size_t heap_size = std::max(entity_->neighbor_cnt(cur_level), entity_->ef_construction()); level_topks_[cur_level].clear(); level_topks_[cur_level].limit(heap_size); } } return level_topks_[level]; } inline void check_need_adjuct_ctx(void) { check_need_adjuct_ctx(entity_->doc_cnt()); } inline size_t compute_reserve_cnt(uint32_t cur_doc) const { if (cur_doc > kMaxReserveDocCnt) { return kMaxReserveDocCnt; } else if (cur_doc < kMinReserveDocCnt) { return kMinReserveDocCnt; } return cur_doc; } //! 
candidates heap and visitfilter need to resize as doc cnt growing up inline void check_need_adjuct_ctx(uint32_t doc_cnt) { if (ailego_unlikely(doc_cnt + kTriggerReserveCnt > reserve_max_doc_cnt_)) { while (doc_cnt + kTriggerReserveCnt > reserve_max_doc_cnt_) { reserve_max_doc_cnt_ = reserve_max_doc_cnt_ + compute_reserve_cnt(reserve_max_doc_cnt_); } uint32_t max_scan_cnt = compute_max_scan_num(reserve_max_doc_cnt_); max_scan_num_ = max_scan_cnt; visit_filter_.reset(reserve_max_doc_cnt_, max_scan_cnt); candidates_.clear(); candidates_.limit(max_scan_num_); } } inline uint32_t compute_max_scan_num(uint32_t max_doc_cnt) const { uint32_t max_scan = max_doc_cnt * max_scan_ratio_; if (max_scan < min_scan_limit_) { max_scan = min_scan_limit_; } else if (max_scan > max_scan_limit_) { max_scan = max_scan_limit_; } return max_scan; } inline size_t get_scan_num() const { return dc().compare_cnt(); } inline uint64_t reach_scan_limit() const { return dc().compare_cnt() >= max_scan_num_; } inline bool error() const { return dc().error(); } inline void clear() { add_dc_.clear(); if (ailego_unlikely(this->debugging())) { stats_get_neighbors_cnt_ = 0u; stats_get_vector_cnt_ = 0u; stats_visit_dup_cnt_ = 0u; } // do not clear results_ for the next query will need it for (auto &it : results_) { it.clear(); } } uint32_t *mutable_stats_get_neighbors() { return &stats_get_neighbors_cnt_; } uint32_t *mutable_stats_get_vector() { return &stats_get_vector_cnt_; } uint32_t *mutable_stats_visit_dup_cnt() { return &stats_visit_dup_cnt_; } inline bool debugging(void) const { return debug_mode_; } inline void update_dist_caculator_distance( const IndexMetric::MatrixDistance &distance, const IndexMetric::MatrixBatchDistance &batch_distance) { dc().update_distance(distance, batch_distance); } //! Get topk inline uint32_t topk() const override { return topk_; } //! Get group topk inline uint32_t group_topk() const { return group_topk_; } //! 
Get group num inline uint32_t group_num() const { return group_num_; } //! Get if group by search inline bool group_by_search() { return group_num_ > 0; } //! Set group params void set_group_params(uint32_t group_num, uint32_t group_topk) override { group_num_ = group_num; group_topk_ = group_topk; topk_ = group_topk_ * group_num_; topk_heap_.limit(std::max(topk_, ef_)); group_topk_heaps_.clear(); } void set_provider(IndexProvider::Pointer provider) { add_dc_.set_provider(std::move(provider)); } const void *query() const { return query_; } private: inline HnswRabitqAddDistCalculator &dc() { return add_dc_; } inline const HnswRabitqAddDistCalculator &dc() const { return add_dc_; } private: // Filling random nodes if topk not full void fill_random_to_topk_full(void); constexpr static uint32_t kTriggerReserveCnt = 4096UL; constexpr static uint32_t kMinReserveDocCnt = 4096UL; constexpr static uint32_t kMaxReserveDocCnt = 128 * 1024UL; constexpr static uint32_t kInvalidMgic = -1U; private: HnswRabitqEntity::Pointer entity_; HnswRabitqAddDistCalculator add_dc_; IndexMetric::Pointer metric_; bool debug_mode_{false}; bool force_padding_topk_{false}; uint32_t max_scan_num_{0}; uint32_t max_scan_limit_{0}; uint32_t min_scan_limit_{0}; uint32_t reserve_max_doc_cnt_{kMinReserveDocCnt}; uint32_t topk_{0}; uint32_t group_topk_{0}; uint32_t filter_mode_{VisitFilter::ByteMap}; float negative_probability_{HnswRabitqEntity::kDefaultBFNegativeProbability}; uint32_t ef_{HnswRabitqEntity::kDefaultEf}; float max_scan_ratio_{HnswRabitqEntity::kDefaultScanRatio}; uint32_t magic_{0U}; std::vector results_{}; std::vector group_results_{}; TopkHeap topk_heap_{}; TopkHeap update_heap_{}; std::vector level_topks_{}; CandidateHeap candidates_{}; VisitFilter visit_filter_{}; uint32_t bruteforce_threshold_{}; bool fetch_vector_{false}; uint32_t group_num_{0}; std::map group_topk_heaps_{}; uint32_t type_{kUnknownContext}; //! 
debug stats info uint32_t stats_get_neighbors_cnt_{0u}; uint32_t stats_get_vector_cnt_{0u}; uint32_t stats_visit_dup_cnt_{0u}; std::string preprocess_buffer_; const void *query_{nullptr}; }; } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/hnsw_rabitq/hnsw_rabitq_dist_calculator.cc ================================================ // Copyright 2025-present the centaurdb project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License #include "core/algorithm/hnsw_rabitq/hnsw_rabitq_dist_calculator.h" #include "zvec/core/framework/index_error.h" namespace zvec::core { int HnswRabitqAddDistCalculator::get_vector( const node_id_t *ids, uint32_t count, std::vector &vec_blocks) const { for (uint32_t i = 0; i < count; ++i) { const node_id_t id = ids[i]; key_t key = entity_->get_key(id); if (key == kInvalidKey) { return IndexError_NoExist; } IndexStorage::MemoryBlock block; int ret = provider_->get_vector(key, block); if (ret != 0) { return ret; } vec_blocks.push_back(std::move(block)); } return 0; } } // namespace zvec::core ================================================ FILE: src/core/algorithm/hnsw_rabitq/hnsw_rabitq_dist_calculator.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include "zvec/core/framework/index_meta.h" #include "zvec/core/framework/index_metric.h" #include "zvec/core/framework/index_provider.h" #include "hnsw_rabitq_entity.h" namespace zvec { namespace core { //! HnswRabitqAddDistCalculator is only used for index construction class HnswRabitqAddDistCalculator { public: typedef std::shared_ptr Pointer; public: enum DistType { DIST_NONE = 0, DIST_DENSE = 1, DIST_HYBRID = 2, DIST_SPARSE = 3 }; public: //! Constructor HnswRabitqAddDistCalculator(const HnswRabitqEntity *entity, const IndexMetric::Pointer &metric, uint32_t dim) : entity_(entity), distance_(metric->distance()), batch_distance_(metric->batch_distance()), query_(nullptr), dim_(dim), compare_cnt_(0) {} //! Constructor HnswRabitqAddDistCalculator(const HnswRabitqEntity *entity, const IndexMetric::Pointer &metric, uint32_t dim, const void *query) : entity_(entity), distance_(metric->distance()), batch_distance_(metric->batch_distance()), query_(query), dim_(dim), compare_cnt_(0) {} //! 
Constructor HnswRabitqAddDistCalculator(const HnswRabitqEntity *entity, const IndexMetric::Pointer &metric) : entity_(entity), distance_(metric->distance()), batch_distance_(metric->batch_distance()), query_(nullptr), dim_(0), compare_cnt_(0) {} void update(const HnswRabitqEntity *entity, const IndexMetric::Pointer &metric) { entity_ = entity; distance_ = metric->distance(); batch_distance_ = metric->batch_distance(); } void update(const HnswRabitqEntity *entity, const IndexMetric::Pointer &metric, uint32_t dim) { entity_ = entity; distance_ = metric->distance(); batch_distance_ = metric->batch_distance(); dim_ = dim; } inline void update_distance( const IndexMetric::MatrixDistance &distance, const IndexMetric::MatrixBatchDistance &batch_distance) { distance_ = distance; batch_distance_ = batch_distance; } //! Reset query vector data inline void reset_query(const void *query) { error_ = false; query_ = query; } //! Returns distance inline dist_t dist(const void *vec_lhs, const void *vec_rhs) { if (ailego_unlikely(vec_lhs == nullptr || vec_rhs == nullptr)) { LOG_ERROR("Nullptr of dense vector"); error_ = true; return 0.0f; } float score{0.0f}; distance_(vec_lhs, vec_rhs, dim_, &score); return score; } //! Returns distance between query and vec. inline dist_t dist(const void *vec) { compare_cnt_++; return dist(vec, query_); } //! Return distance between query and node id. inline dist_t dist(node_id_t id) { compare_cnt_++; const void *feat = get_vector(id); if (ailego_unlikely(feat == nullptr)) { LOG_ERROR("Get nullptr vector, id=%u", id); error_ = true; return 0.0f; } return dist(feat, query_); } //! 
Return dist node lhs between node rhs inline dist_t dist(node_id_t lhs, node_id_t rhs) { compare_cnt_++; const void *feat = get_vector(lhs); const void *query = get_vector(rhs); if (ailego_unlikely(feat == nullptr || query == nullptr)) { LOG_ERROR("Get nullptr vector"); error_ = true; return 0.0f; } return dist(feat, query); } dist_t operator()(const void *vec) { return dist(vec); } dist_t operator()(id_t i) { return dist(i); } dist_t operator()(id_t lhs, id_t rhs) { return dist(lhs, rhs); } void batch_dist(const void **vecs, size_t num, dist_t *distances) { compare_cnt_++; batch_distance_(vecs, query_, num, dim_, distances); } inline dist_t batch_dist(node_id_t id) { compare_cnt_++; const void *feat = get_vector(id); if (ailego_unlikely(feat == nullptr)) { LOG_ERROR("Get nullptr vector, id=%u", id); error_ = true; return 0.0f; } dist_t score = 0; batch_distance_(&feat, query_, 1, dim_, &score); return score; } inline void clear() { compare_cnt_ = 0; error_ = false; } inline void clear_compare_cnt() { compare_cnt_ = 0; } inline bool error() const { return error_; } //! 
Get distances compute times inline uint32_t compare_cnt() const { return compare_cnt_; } inline uint32_t dimension() const { return dim_; } void set_provider(IndexProvider::Pointer provider) { provider_ = std::move(provider); } int get_vector(const node_id_t *ids, uint32_t count, std::vector &vec_blocks) const; const void *get_vector(node_id_t id) const { key_t key = entity_->get_key(id); if (key == kInvalidKey) { return nullptr; } return provider_->get_vector(key); } private: HnswRabitqAddDistCalculator(const HnswRabitqAddDistCalculator &) = delete; HnswRabitqAddDistCalculator &operator=(const HnswRabitqAddDistCalculator &) = delete; private: const HnswRabitqEntity *entity_; IndexMetric::MatrixDistance distance_; IndexMetric::MatrixBatchDistance batch_distance_; const void *query_; uint32_t dim_; uint32_t compare_cnt_; // record distance compute times uint32_t compare_cnt_batch_; // record batch distance compute time bool error_{false}; // get raw vector IndexProvider::Pointer provider_; }; } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/hnsw_rabitq/hnsw_rabitq_entity.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "hnsw_rabitq_entity.h"
#include <algorithm>
#include <numeric>
#include <string>
#include <vector>
#include "utility/sparse_utility.h"
#include "zvec/core/framework/index_stats.h"

namespace zvec {
namespace core {

const std::string HnswRabitqEntity::kGraphHeaderSegmentId = "graph.header";
const std::string HnswRabitqEntity::kGraphFeaturesSegmentId = "graph.features";
const std::string HnswRabitqEntity::kGraphKeysSegmentId = "graph.keys";
const std::string HnswRabitqEntity::kGraphNeighborsSegmentId =
    "graph.neighbors";
const std::string HnswRabitqEntity::kGraphOffsetsSegmentId = "graph.offsets";
const std::string HnswRabitqEntity::kGraphMappingSegmentId = "graph.mapping";
const std::string HnswRabitqEntity::kHnswHeaderSegmentId = "hnsw.header";
const std::string HnswRabitqEntity::kHnswNeighborsSegmentId = "hnsw.neighbors";
const std::string HnswRabitqEntity::kHnswOffsetsSegmentId = "hnsw.offsets";

//! Derive the RaBitQ layout sizes stored in the graph header from the raw
//! dimension: padded_dim (multiple of 64), bin/ex data sizes and the
//! per-node quantized vector size.
void HnswRabitqEntity::update_rabitq_params_and_vector_size(
    uint32_t dimension) {
  // Round the dimension up to the next multiple of 64.
  uint32_t padded_dim = ((dimension + 63) / 64) * 64;
  header_.graph.padded_dim = padded_dim;

  // BinDataMap layout: bin_code (padded_dim/8) + f_add + f_rescale + f_error
  header_.graph.size_bin_data =
      rabitqlib::BinDataMap<float>::data_bytes(padded_dim);

  // ExDataMap layout: ex_code (padded_dim*ex_bits/8) + f_add_ex + f_rescale_ex
  header_.graph.size_ex_data = rabitqlib::ExDataMap<float>::data_bytes(
      padded_dim, header_.graph.ex_bits);

  // quantized vector format: cluster_id + bin_data + ex_data
  header_.graph.vector_size =
      sizeof(uint32_t) + size_bin_data() + size_ex_data();
}

//! Write zero bytes so that `data_size` becomes AlignSize(data_size);
//! the number of padding bytes is returned through `padding_size`.
int HnswRabitqEntity::CalcAndAddPadding(const IndexDumper::Pointer &dumper,
                                        size_t data_size,
                                        size_t *padding_size) {
  *padding_size = AlignSize(data_size) - data_size;
  if (*padding_size == 0) {
    return 0;
  }

  std::string padding(*padding_size, '\0');
  if (dumper->write(padding.data(), *padding_size) != *padding_size) {
    LOG_ERROR("Append padding failed, size %zu", *padding_size);
    return IndexError_WriteData;
  }
  return 0;
}

//! Write `size` bytes of `data` plus alignment padding and register the
//! segment. The CRC covers the data only. Returns the bytes written
//! (data + padding) or a negative error code.
int64_t HnswRabitqEntity::dump_segment(const IndexDumper::Pointer &dumper,
                                       const std::string &segment_id,
                                       const void *data, size_t size) const {
  size_t len = dumper->write(data, size);
  if (len != size) {
    LOG_ERROR("Dump segment %s data failed, expect: %zu, actual: %zu",
              segment_id.c_str(), size, len);
    return IndexError_WriteData;
  }

  size_t padding_size = AlignSize(size) - size;
  if (padding_size > 0) {
    std::string padding(padding_size, '\0');
    if (dumper->write(padding.data(), padding_size) != padding_size) {
      LOG_ERROR("Append padding failed, size %zu", padding_size);
      return IndexError_WriteData;
    }
  }

  uint32_t crc = ailego::Crc32c::Hash(data, size);
  int ret = dumper->append(segment_id, size, padding_size, crc);
  if (ret != 0) {
    LOG_ERROR("Dump segment %s meta failed, ret=%d", segment_id.c_str(), ret);
    return ret;
  }

  return len + padding_size;
}

//! Dump the graph header and the hnsw header as two segments.
int64_t HnswRabitqEntity::dump_header(const IndexDumper::Pointer &dumper,
                                      const HNSWHeader &hd) const {
  //! dump basic graph header. header is aligned and does not need padding
  int64_t graph_hd_size =
      dump_segment(dumper, kGraphHeaderSegmentId, &hd.graph, hd.graph.size);
  if (graph_hd_size < 0) {
    return graph_hd_size;
  }

  //! dump hnsw (upper-level) header. header is aligned and does not need
  //! padding
  int64_t hnsw_hd_size =
      dump_segment(dumper, kHnswHeaderSegmentId, &hd.hnsw, hd.hnsw.size);
  if (hnsw_hd_size < 0) {
    return hnsw_hd_size;
  }

  return graph_hd_size + hnsw_hd_size;
}

//! Node reordering hook; not implemented yet, so both mappings stay empty
//! and nodes are dumped in their current order.
void HnswRabitqEntity::reshuffle_vectors(
    const std::function<level_t(node_id_t)> & /*get_level*/,
    std::vector<node_id_t> * /*n2o_mapping*/,
    std::vector<node_id_t> * /*o2n_mapping*/, key_t * /*keys*/) const {
  // TODO
  return;
}

//! Dump the node ids sorted by their key (keys[i] is node i's key).
int64_t HnswRabitqEntity::dump_mapping_segment(
    const IndexDumper::Pointer &dumper, const key_t *keys) const {
  std::vector<node_id_t> mapping(doc_cnt());

  std::iota(mapping.begin(), mapping.end(), 0U);
  std::sort(mapping.begin(), mapping.end(),
            [&](node_id_t i, node_id_t j) { return keys[i] < keys[j]; });

  size_t size = mapping.size() * sizeof(node_id_t);

  return dump_segment(dumper, kGraphMappingSegmentId, mapping.data(), size);
}

//! Dump the whole index: headers, vectors, neighbors, keys and the
//! key-sorted mapping. Returns total bytes written or a negative error.
int64_t HnswRabitqEntity::dump_segments(
    const IndexDumper::Pointer &dumper, key_t *keys,
    const std::function<level_t(node_id_t)> &get_level) const {
  HNSWHeader dump_hd(header());

  dump_hd.graph.node_size = AlignSize(vector_size());

  std::vector<node_id_t> n2o_mapping;  // map new id to origin id
  std::vector<node_id_t> o2n_mapping;  // map origin id to new id
  reshuffle_vectors(get_level, &n2o_mapping, &o2n_mapping, keys);
  // Only remap a real entry point; kInvalidNodeId (empty graph) must not be
  // used as an index.
  if (!o2n_mapping.empty() && entry_point() < o2n_mapping.size()) {
    dump_hd.hnsw.entry_point = o2n_mapping[entry_point()];
  }

  //! Dump header
  int64_t hd_size = dump_header(dumper, dump_hd);
  if (hd_size < 0) {
    return hd_size;
  }

  //! Dump vectors
  int64_t vecs_size = dump_vectors(dumper, n2o_mapping);
  if (vecs_size < 0) {
    return vecs_size;
  }

  //! Dump neighbors
  auto neighbors_size =
      dump_neighbors(dumper, get_level, n2o_mapping, o2n_mapping);
  if (neighbors_size < 0) {
    return neighbors_size;
  }

  //! free memory
  n2o_mapping = std::vector<node_id_t>();
  o2n_mapping = std::vector<node_id_t>();

  //! Dump keys
  size_t key_segment_size = doc_cnt() * sizeof(key_t);
  int64_t keys_size =
      dump_segment(dumper, kGraphKeysSegmentId, keys, key_segment_size);
  if (keys_size < 0) {
    return keys_size;
  }

  //! Dump mapping
  int64_t mapping_size = dump_mapping_segment(dumper, keys);
  if (mapping_size < 0) {
    return mapping_size;
  }

  return hd_size + keys_size + vecs_size + neighbors_size + mapping_size;
}

//! Dump every node's vector, each padded to AlignSize(vector_size()).
//! Unlike dump_segment, the CRC here also covers the padding bytes.
int64_t HnswRabitqEntity::dump_vectors(
    const IndexDumper::Pointer &dumper,
    const std::vector<node_id_t> &reorder_mapping) const {
  size_t vector_dump_size = vector_size();

  size_t padding_size = AlignSize(vector_dump_size) - vector_dump_size;

  // std::string instead of a VLA: VLAs are a non-standard extension and a
  // zero-length one (padding_size == 0) is undefined behavior.
  const std::string padding(padding_size, '\0');

  const void *data = nullptr;
  uint32_t crc = 0U;
  size_t vecs_size = 0UL;

  //! dump vectors
  for (node_id_t id = 0; id < doc_cnt(); ++id) {
    data = get_vector(reorder_mapping.empty() ? id : reorder_mapping[id]);
    if (ailego_unlikely(!data)) {
      return IndexError_ReadData;
    }
    size_t len = dumper->write(data, vector_size());
    if (len != vector_size()) {
      LOG_ERROR("Dump vectors failed, write=%zu expect=%zu", len,
                vector_size());
      return IndexError_WriteData;
    }

    crc = ailego::Crc32c::Hash(data, vector_size(), crc);
    vecs_size += vector_size();

    if (padding_size == 0) {
      continue;
    }

    len = dumper->write(padding.data(), padding_size);
    if (len != padding_size) {
      LOG_ERROR("Dump vectors failed, write=%zu expect=%zu", len,
                padding_size);
      return IndexError_WriteData;
    }
    crc = ailego::Crc32c::Hash(padding.data(), padding_size, crc);
    vecs_size += padding_size;
  }

  int ret = dumper->append(kGraphFeaturesSegmentId, vecs_size, 0UL, crc);
  if (ret != 0) {
    LOG_ERROR("Dump vectors segment meta failed, ret %d", ret);
    return ret;
  }

  return vecs_size;
}

//! Dump level-0 neighbor lists back to back, followed by the per-node
//! (offset, count) meta segment.
int64_t HnswRabitqEntity::dump_graph_neighbors(
    const IndexDumper::Pointer &dumper,
    const std::vector<node_id_t> &reorder_mapping,
    const std::vector<node_id_t> &neighbor_mapping) const {
  std::vector<GraphNeighborMeta> graph_meta;
  graph_meta.reserve(doc_cnt());
  size_t offset = 0;
  uint32_t crc = 0;
  // Scratch buffer for remapped neighbor ids (replaces a non-standard VLA).
  std::vector<node_id_t> mapping(l0_neighbor_cnt());

  uint32_t min_neighbor_count = 10000;
  uint32_t max_neighbor_count = 0;
  size_t sum_neighbor_count = 0;

  for (node_id_t id = 0; id < doc_cnt(); ++id) {
    const Neighbors neighbors =
        get_neighbors(0, reorder_mapping.empty() ? id : reorder_mapping[id]);
    ailego_assert_with(!!neighbors.data, "invalid neighbors");
    ailego_assert_with(neighbors.size() <= l0_neighbor_cnt(),
                       "invalid neighbors");

    uint32_t neighbor_count = neighbors.size();
    if (neighbor_count < min_neighbor_count) {
      min_neighbor_count = neighbor_count;
    }
    if (neighbor_count > max_neighbor_count) {
      max_neighbor_count = neighbor_count;
    }
    sum_neighbor_count += neighbor_count;

    graph_meta.emplace_back(offset, neighbor_count);
    size_t size = neighbors.size() * sizeof(node_id_t);
    const node_id_t *data = &neighbors[0];
    if (!neighbor_mapping.empty()) {
      for (node_id_t i = 0; i < neighbors.size(); ++i) {
        mapping[i] = neighbor_mapping[neighbors[i]];
      }
      data = mapping.data();
    }
    if (dumper->write(data, size) != size) {
      LOG_ERROR("Dump graph neighbor id=%zu failed, size %zu",
                static_cast<size_t>(id), size);
      return IndexError_WriteData;
    }
    crc = ailego::Crc32c::Hash(data, size, crc);
    offset += size;
  }

  uint32_t average_neighbor_count = 0;
  if (doc_cnt() > 0) {
    average_neighbor_count = sum_neighbor_count / doc_cnt();
  }
  LOG_INFO(
      "Dump hnsw graph: min_neighbor_count[%u] max_neighbor_count[%u] "
      "average_neighbor_count[%u]",
      min_neighbor_count, max_neighbor_count, average_neighbor_count);

  size_t padding_size = 0;
  int ret = CalcAndAddPadding(dumper, offset, &padding_size);
  if (ret != 0) {
    return ret;
  }
  ret = dumper->append(kGraphNeighborsSegmentId, offset, padding_size, crc);
  if (ret != 0) {
    LOG_ERROR("Dump segment %s failed, ret %d",
              kGraphNeighborsSegmentId.c_str(), ret);
    return ret;
  }

  //! dump level 0 neighbors meta
  auto len = dump_segment(dumper, kGraphOffsetsSegmentId, graph_meta.data(),
                          graph_meta.size() * sizeof(GraphNeighborMeta));
  if (len < 0) {
    return len;
  }

  return len + offset + padding_size;
}

//! Dump neighbor lists of levels 1..level for every node with level > 0,
//! as fixed-size records [count, ids..., zero fill], followed by the
//! per-node (offset, level) meta segment.
int64_t HnswRabitqEntity::dump_upper_neighbors(
    const IndexDumper::Pointer &dumper,
    const std::function<level_t(node_id_t)> &get_level,
    const std::vector<node_id_t> &reorder_mapping,
    const std::vector<node_id_t> &neighbor_mapping) const {
  std::vector<HnswNeighborMeta> hnsw_meta;
  hnsw_meta.reserve(doc_cnt());
  size_t offset = 0;
  uint32_t crc = 0;
  // One record per (node, level); a std::vector replaces the former VLA.
  std::vector<node_id_t> buffer(upper_neighbor_cnt() + 1);
  const size_t buffer_bytes = buffer.size() * sizeof(node_id_t);

  for (node_id_t id = 0; id < doc_cnt(); ++id) {
    node_id_t new_id = reorder_mapping.empty() ? id : reorder_mapping[id];
    auto level = get_level(new_id);
    if (level == 0) {
      hnsw_meta.emplace_back(0U, 0U);
      continue;
    }
    hnsw_meta.emplace_back(offset, level);
    ailego_assert_with(static_cast<size_t>(level) < kMaxGraphLayers,
                       "invalid level");
    for (level_t cur_level = 1; cur_level <= level; ++cur_level) {
      const Neighbors neighbors = get_neighbors(cur_level, new_id);
      ailego_assert_with(!!neighbors.data, "invalid neighbors");
      ailego_assert_with(neighbors.size() <= neighbor_cnt(cur_level),
                         "invalid neighbors");
      std::fill(buffer.begin(), buffer.end(), node_id_t{0});
      buffer[0] = neighbors.size();
      if (neighbor_mapping.empty()) {
        memcpy(buffer.data() + 1, &neighbors[0],
               neighbors.size() * sizeof(node_id_t));
      } else {
        for (node_id_t i = 0; i < neighbors.size(); ++i) {
          buffer[i + 1] = neighbor_mapping[neighbors[i]];
        }
      }
      if (dumper->write(buffer.data(), buffer_bytes) != buffer_bytes) {
        LOG_ERROR("Dump graph neighbor id=%zu failed, size %zu",
                  static_cast<size_t>(id), buffer_bytes);
        return IndexError_WriteData;
      }
      crc = ailego::Crc32c::Hash(buffer.data(), buffer_bytes, crc);
      offset += buffer_bytes;
    }
  }

  size_t padding_size = 0;
  int ret = CalcAndAddPadding(dumper, offset, &padding_size);
  if (ret != 0) {
    return ret;
  }
  ret = dumper->append(kHnswNeighborsSegmentId, offset, padding_size, crc);
  if (ret != 0) {
    LOG_ERROR("Dump segment %s failed, ret %d",
              kHnswNeighborsSegmentId.c_str(), ret);
    return ret;
  }

  //! dump upper-level neighbors meta
  auto len = dump_segment(dumper, kHnswOffsetsSegmentId, hnsw_meta.data(),
                          hnsw_meta.size() * sizeof(HnswNeighborMeta));
  if (len < 0) {
    return len;
  }

  return len + offset + padding_size;
}

}  // namespace core
}  // namespace zvec
================================================
FILE: src/core/algorithm/hnsw_rabitq/hnsw_rabitq_entity.h
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once #include #include #include #include #include #include #include #include #include "zvec/core/framework/index_dumper.h" #include "zvec/core/framework/index_error.h" #include "zvec/core/framework/index_storage.h" namespace zvec { namespace core { using node_id_t = uint32_t; using key_t = uint64_t; using level_t = int32_t; using dist_t = float; struct EstimateRecord { float ip_x0_qr; float est_dist; float low_dist; bool operator<(const EstimateRecord &other) const { return this->est_dist < other.est_dist; } }; struct ResultRecord { float est_dist; float low_dist; ResultRecord() : est_dist(0.0f), low_dist(0.0f) {} ResultRecord(float dist) : est_dist(dist), low_dist(dist) {} explicit ResultRecord(const EstimateRecord &other) : est_dist(other.est_dist), low_dist(other.low_dist) {} ResultRecord(float est_dist, float low_dist) : est_dist(est_dist), low_dist(low_dist) {} bool operator<(const ResultRecord &other) const { return this->est_dist < other.est_dist; } bool operator<=(const ResultRecord &other) const { return this->est_dist <= other.est_dist; } bool operator>(const ResultRecord &other) const { return this->est_dist > other.est_dist; } }; using TopkHeap = ailego::KeyValueHeap; using CandidateHeap = ailego::KeyValueHeap>; constexpr node_id_t kInvalidNodeId = static_cast(-1); constexpr key_t kInvalidKey = static_cast(-1); class DistCalculator; struct GraphHeader { uint32_t size; uint32_t version; uint32_t graph_type; uint32_t doc_count; uint32_t vector_size; uint32_t node_size; uint32_t l0_neighbor_count; uint32_t prune_type; uint32_t prune_neighbor_count; uint32_t ef_construction; uint32_t options; uint32_t min_neighbor_count; uint32_t padded_dim; uint32_t size_bin_data; uint32_t size_ex_data; uint8_t ex_bits; uint8_t reserved_[4067]; }; static_assert(sizeof(GraphHeader) % 32 == 0, "GraphHeader must be aligned with 32 bytes"); //! 
Hnsw upper neighbor header struct HnswHeader { uint32_t size; // header size uint32_t revision; // current total docs of the graph uint32_t upper_neighbor_count; uint32_t ef_construction; uint32_t scaling_factor; uint32_t max_level; uint32_t entry_point; uint32_t options; uint8_t reserved_[30]; }; static_assert(sizeof(HnswHeader) % 32 == 0, "GraphHeader must be aligned with 32 bytes"); //! Hnsw common header and upper neighbor header struct HNSWHeader { HNSWHeader() { clear(); } HNSWHeader(const HNSWHeader &header) { memcpy(this, &header, sizeof(header)); } HNSWHeader &operator=(const HNSWHeader &header) { memcpy(this, &header, sizeof(header)); return *this; } //! Reset state to zero, and the params is untouched void inline reset() { graph.doc_count = 0U; hnsw.entry_point = kInvalidNodeId; hnsw.max_level = 0; } //! Clear all fields to init value void inline clear() { memset(this, 0, sizeof(HNSWHeader)); hnsw.entry_point = kInvalidNodeId; graph.size = sizeof(GraphHeader); hnsw.size = sizeof(HnswHeader); } size_t l0_neighbor_cnt() const { return graph.l0_neighbor_count; } size_t upper_neighbor_cnt() const { return hnsw.upper_neighbor_count; } size_t vector_size() const { return graph.vector_size; } uint8_t ex_bits() const { return graph.ex_bits; } uint32_t padded_dim() const { return graph.padded_dim; } size_t ef_construction() const { return graph.ef_construction; } size_t scaling_factor() const { return hnsw.scaling_factor; } size_t neighbor_prune_cnt() const { return graph.prune_neighbor_count; } node_id_t entry_point() const { return hnsw.entry_point; } node_id_t doc_cnt() const { return graph.doc_count; } GraphHeader graph; HnswHeader hnsw; }; struct NeighborsHeader { uint32_t neighbor_cnt; node_id_t neighbors[0]; }; struct Neighbors { Neighbors() : cnt{0}, data{nullptr} {} Neighbors(uint32_t cnt_in, const node_id_t *data_in) : cnt{cnt_in}, data{data_in} {} Neighbors(IndexStorage::MemoryBlock &&mem_block) : neighbor_block{std::move(mem_block)} { auto hd = 
reinterpret_cast(neighbor_block.data()); cnt = hd->neighbor_cnt; data = hd->neighbors; } size_t size(void) const { return cnt; } const node_id_t &operator[](size_t idx) const { return data[idx]; } uint32_t cnt; const node_id_t *data; IndexStorage::MemoryBlock neighbor_block; }; //! level 0 neighbors offset struct GraphNeighborMeta { GraphNeighborMeta(size_t o, size_t cnt) : offset(o), neighbor_cnt(cnt) {} uint64_t offset : 48; uint64_t neighbor_cnt : 16; }; //! hnsw upper neighbors meta struct HnswNeighborMeta { HnswNeighborMeta(size_t o, size_t l) : offset(o), level(l) {} uint64_t offset : 48; // offset = idx * upper neighors size uint64_t level : 16; }; class HnswRabitqEntity { public: //! Constructor HnswRabitqEntity() {} //! Constructor HnswRabitqEntity(const HNSWHeader &hd) { header_ = hd; } //! Destructor virtual ~HnswRabitqEntity() {} //! HnswRabitqEntity Pointerd; typedef std::shared_ptr Pointer; //! Get max neighbor size of graph level inline size_t neighbor_cnt(level_t level) const { return level == 0 ? header_.graph.l0_neighbor_count : header_.hnsw.upper_neighbor_count; } //! get max neighbor size of graph level 0 inline size_t l0_neighbor_cnt() const { return header_.graph.l0_neighbor_count; } //! get min neighbor size of graph inline size_t min_neighbor_cnt() const { return header_.graph.min_neighbor_count; } //! get upper neighbor size of graph level other than 0 inline size_t upper_neighbor_cnt() const { return header_.hnsw.upper_neighbor_count; } //! Get current total doc of the hnsw graph inline node_id_t *mutable_doc_cnt() { return &header_.graph.doc_count; } inline node_id_t doc_cnt() const { return header_.graph.doc_count; } //! Get hnsw graph scaling params inline size_t scaling_factor() const { return header_.hnsw.scaling_factor; } //! Get prune_size inline size_t prune_cnt() const { return header_.graph.prune_neighbor_count; } //! Current entity of top level graph inline node_id_t entry_point() const { return header_.hnsw.entry_point; } //! 
Current max graph level inline level_t cur_max_level() const { return header_.hnsw.max_level; } //! Retrieve index vector size size_t vector_size() const { return header_.graph.vector_size; } //! Retrieve node size size_t node_size() const { return header_.graph.node_size; } //! Retrieve ef constuction size_t ef_construction() const { return header_.graph.ef_construction; } uint8_t ex_bits() const { return header_.graph.ex_bits; } uint32_t padded_dim() const { return header_.graph.padded_dim; } uint32_t size_bin_data() const { return header_.graph.size_bin_data; } uint32_t size_ex_data() const { return header_.graph.size_ex_data; } void update_rabitq_params_and_vector_size(uint32_t dimension); void set_ex_bits(uint8_t ex_bits) { header_.graph.ex_bits = ex_bits; } void set_prune_cnt(size_t v) { header_.graph.prune_neighbor_count = v; } void set_scaling_factor(size_t val) { header_.hnsw.scaling_factor = val; } void set_l0_neighbor_cnt(size_t cnt) { header_.graph.l0_neighbor_count = cnt; } void set_min_neighbor_cnt(size_t cnt) { header_.graph.min_neighbor_count = cnt; } void set_upper_neighbor_cnt(size_t cnt) { header_.hnsw.upper_neighbor_count = cnt; } void set_ef_construction(size_t ef) { header_.graph.ef_construction = ef; } protected: inline const HNSWHeader &header() const { return header_; } inline HNSWHeader *mutable_header() { return &header_; } inline size_t header_size() const { return sizeof(header_); } void set_node_size(size_t size) { header_.graph.node_size = size; } //! Dump all segment by dumper //! Return dump size if success, errno(<0) in failure int64_t dump_segments( const IndexDumper::Pointer &dumper, key_t *keys, const std::function &get_level) const; private: //! dump mapping segment, for get_vector_by_key in provider int64_t dump_mapping_segment(const IndexDumper::Pointer &dumper, const key_t *keys) const; //! dump hnsw head by dumper //! 
Return dump size if success, errno(<0) in failure int64_t dump_header(const IndexDumper::Pointer &dumper, const HNSWHeader &hd) const; //! dump vectors by dumper //! Return dump size if success, errno(<0) in failure int64_t dump_vectors(const IndexDumper::Pointer &dumper, const std::vector &reorder_mapping) const; //! dump hnsw neighbors by dumper //! Return dump size if success, errno(<0) in failure int64_t dump_neighbors(const IndexDumper::Pointer &dumper, const std::function &get_level, const std::vector &reorder_mapping, const std::vector &neighbor_mapping) const { auto len1 = dump_graph_neighbors(dumper, reorder_mapping, neighbor_mapping); if (len1 < 0) { return len1; } auto len2 = dump_upper_neighbors(dumper, get_level, reorder_mapping, neighbor_mapping); if (len2 < 0) { return len2; } return len1 + len2; } //! dump segment by dumper //! Return dump size if success, errno(<0) in failure int64_t dump_segment(const IndexDumper::Pointer &dumper, const std::string &segment_id, const void *data, size_t size) const; //! Dump level 0 neighbors //! Return dump size if success, errno(<0) in failure int64_t dump_graph_neighbors( const IndexDumper::Pointer &dumper, const std::vector &reorder_mapping, const std::vector &neighbor_mapping) const; //! Dump upper level neighbors //! Return dump size if success, errno(<0) in failure int64_t dump_upper_neighbors( const IndexDumper::Pointer &dumper, const std::function &get_level, const std::vector &reorder_mapping, const std::vector &neighbor_mapping) const; public: //! Cleanup the entity virtual int cleanup(void) { header_.clear(); return 0; } //! Make a copy of searcher entity, to support thread-safe operation. //! The segment in container cannot be read concurrenly virtual const HnswRabitqEntity::Pointer clone() const { LOG_ERROR("Update neighbors not implemented"); return HnswRabitqEntity::Pointer(); } //! Get primary key of the node id virtual key_t get_key(node_id_t id) const = 0; //! 
Get vector feature data by key virtual const void *get_vector(node_id_t id) const = 0; //! Get vectors feature data by keys virtual int get_vector(const node_id_t *ids, uint32_t count, const void **vecs) const = 0; virtual int get_vector(const node_id_t id, IndexStorage::MemoryBlock &block) const = 0; virtual int get_vector( const node_id_t *ids, uint32_t count, std::vector &vec_blocks) const = 0; //! Retrieve a vector using a primary key virtual const void *get_vector_by_key(uint64_t /*key*/) const { LOG_ERROR("get vector not implemented"); return nullptr; } virtual int get_vector_by_key(const key_t /*key*/, IndexStorage::MemoryBlock & /*block*/) const { return IndexError_NotImplemented; } //! Get the node id's neighbors on graph level //! Note: the neighbors cannot be modified, using the following //! method to get WritableNeighbors if want to virtual const Neighbors get_neighbors(level_t level, node_id_t id) const = 0; //! Add vector and key to hnsw entity, and local id will be saved in id virtual int add_vector(level_t /*level*/, key_t /*key*/, const void * /*vec*/, node_id_t * /*id*/) { return IndexError_NotImplemented; } //! Add vector and id to hnsw entity virtual int add_vector_with_id(level_t /*level*/, node_id_t /*id*/, const void * /*vec*/) { return IndexError_NotImplemented; } virtual int update_neighbors( level_t /*level*/, node_id_t /*id*/, const std::vector> & /*neighbors*/) { LOG_ERROR("Update neighbors dense not implemented"); return 0; } //! Append neighbor_id to node id neighbors on level, size is the current //! neighbors size. Notice: the caller must be ensure the neighbors not full virtual void add_neighbor(level_t /*level*/, node_id_t /*id*/, uint32_t /*size*/, node_id_t /*neighbor_id*/) { LOG_ERROR("Add neighbor not implemented"); } //! 
Update entry point and max level virtual void update_ep_and_level(node_id_t ep, level_t level) { header_.hnsw.entry_point = ep; header_.hnsw.max_level = level; } virtual int load(const IndexStorage::Pointer & /*container*/, bool /*check_crc*/) { LOG_ERROR("Load not implemented"); return IndexError_NotImplemented; } virtual int dump(const IndexDumper::Pointer & /*dumper*/) { LOG_ERROR("Dump not implemented"); return IndexError_NotImplemented; } static int CalcAndAddPadding(const IndexDumper::Pointer &dumper, size_t data_size, size_t *padding_size); uint32_t get_cluster_id(const void *vec) const { return *reinterpret_cast( reinterpret_cast(vec) + cluster_id_offset()); } const char *get_bin_data(const void *vec) const { return reinterpret_cast(vec) + bin_data_offset(); } const char *get_ex_data(const void *vec) const { return reinterpret_cast(vec) + ex_data_offset(); } uint32_t cluster_id_offset() const { return 0; } uint32_t bin_data_offset() const { return cluster_id_offset() + sizeof(uint32_t); } uint32_t ex_data_offset() const { return bin_data_offset() + size_bin_data(); } protected: static inline size_t AlignSize(size_t size) { return (size + 0x1F) & (~0x1F); } static inline size_t AlignPageSize(size_t size) { size_t page_mask = ailego::MemoryHelper::PageSize() - 1; return (size + page_mask) & (~page_mask); } static inline size_t AlignHugePageSize(size_t size) { size_t page_mask = ailego::MemoryHelper::HugePageSize() - 1; return (size + page_mask) & (~page_mask); } //! 
rearrange vectors to improve cache locality void reshuffle_vectors(const std::function &get_level, std::vector *n2o_mapping, std::vector *o2n_mapping, key_t *keys) const; public: const static std::string kGraphHeaderSegmentId; const static std::string kGraphFeaturesSegmentId; const static std::string kGraphKeysSegmentId; const static std::string kGraphNeighborsSegmentId; const static std::string kGraphOffsetsSegmentId; const static std::string kGraphMappingSegmentId; const static std::string kHnswHeaderSegmentId; const static std::string kHnswNeighborsSegmentId; const static std::string kHnswOffsetsSegmentId; constexpr static uint32_t kRevision = 0U; constexpr static size_t kMaxGraphLayers = 15; constexpr static uint32_t kDefaultEfConstruction = 500; constexpr static uint32_t kDefaultEf = 500; constexpr static uint32_t kDefaultUpperMaxNeighborCnt = 50; // M of HNSW constexpr static uint32_t kDefaultL0MaxNeighborCnt = 100; constexpr static uint32_t kMaxNeighborCnt = 65535; constexpr static float kDefaultScanRatio = 0.1f; constexpr static uint32_t kDefaultMinScanLimit = 10000; constexpr static uint32_t kDefaultMaxScanLimit = std::numeric_limits::max(); constexpr static float kDefaultBFNegativeProbability = 0.001f; constexpr static uint32_t kDefaultScalingFactor = 50U; constexpr static uint32_t kDefaultBruteForceThreshold = 1000U; constexpr static uint32_t kDefaultDocsHardLimit = 1 << 30U; // 1 billion constexpr static float kDefaultDocsSoftLimitRatio = 0.9f; constexpr static size_t kMaxChunkSize = 0xFFFFFFFF; constexpr static size_t kDefaultChunkSize = 2UL * 1024UL * 1024UL; constexpr static size_t kDefaultMaxChunkCnt = 50000UL; constexpr static float kDefaultNeighborPruneMultiplier = 1.0f; // prune_cnt = upper_max_neighbor_cnt * multiplier constexpr static float kDefaultL0MaxNeighborCntMultiplier = 2.0f; // l0_max_neighbor_cnt = upper_max_neighbor_cnt * multiplier protected: HNSWHeader header_{}; }; } // namespace core } // namespace zvec 
================================================ FILE: src/core/algorithm/hnsw_rabitq/hnsw_rabitq_index_hash.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include "hnsw_rabitq_chunk.h" namespace zvec { namespace core { //! Persistent hashmap implement through open addressing algorithm template ::value>::type> class HnswIndexHashMap { using key_type = Key; using val_type = Val; struct Iterator { key_type first; val_type second; }; typedef Iterator *iterator; typedef Iterator Item; typedef const Iterator *const_iterator; class Slot { public: Slot(Chunk::Pointer &&chunk, const void *data) : chunk_(std::move(chunk)), items_(reinterpret_cast(data)) {} //! 
Return a empty loc or the key item loc Slot(Chunk::Pointer &&chunk, IndexStorage::MemoryBlock &&mem_block) : chunk_(std::move(chunk)), items_block_(std::move(mem_block)) { items_ = reinterpret_cast(items_block_.data()); } const_iterator find(key_type key, uint32_t max_items, uint32_t mask) const { auto it = &items_[key & mask]; for (auto i = 0U; i < max_items; ++i) { if (it->first == key || it->second == EmptyVal) { // LOG_DEBUG("i=%u", i); return it; } ++it; if (it == &items_[max_items]) { it = &items_[0]; } } return nullptr; } bool update(const_iterator it) { uint32_t offset = reinterpret_cast(it) - reinterpret_cast(&items_[0]); if (ailego_unlikely(chunk_->write(offset, it, sizeof(Item)) != sizeof(Item))) { LOG_ERROR("Chunk write failed"); return false; } return true; } private: Chunk::Pointer chunk_{}; const Item *items_{nullptr}; // point to chunk data IndexStorage::MemoryBlock items_block_{}; }; public: //! Init the hash //! broker the index allocator //! chunk_size the size of per chunk allocated, actual size may greater //! factor factor = 1/ratio, ratio is the probability of a squence //! number inserted to this container //! max the max number key can be inserted //! 
expansion_ratio memory expansion ratio int init(HnswRabitqChunkBroker::Pointer &broker, uint32_t chunk_size, uint32_t factor, size_t max, float expansion_ratio) { ailego_assert_with(expansion_ratio > 1.0f, "ratio must > 1.0f"); broker_ = broker; size_t items = std::ceil(chunk_size * 1.0f / sizeof(Item)); slot_items_ = 1UL << static_cast((std::ceil(std::log2(items)))); size_t range = slot_items_ * factor / expansion_ratio; mask_bits_ = std::floor(std::log2(range)); range = 1UL << mask_bits_; size_t max_slots = std::ceil(max * 1.0f / range); slots_.reserve(max_slots); slot_loc_mask_ = slot_items_ - 1U; int ret = load(); if (ret != 0) { return ret; } LOG_DEBUG( "HnswRabitqIndexHash init, chunkSize=%u factor=%u max=%zu " "ratio=%f slotItems=%u maxSlots=%zu maskBits=%u " "range=%zu", chunk_size, factor, max, expansion_ratio, slot_items_, max_slots, mask_bits_, range); return 0; } int cleanup(void) { broker_.reset(); slots_.clear(); slots_.shrink_to_fit(); mask_bits_ = 0U; slot_items_ = 0U; slot_loc_mask_ = 0U; return 0; } const_iterator end(void) const { return nullptr; } const_iterator find(const key_type key) const { auto idx = key >> mask_bits_; if (idx >= slots_.size()) { return end(); } auto it = slots_[idx].find(key, slot_items_, slot_loc_mask_); return it && it->second != EmptyVal ? it : nullptr; } bool insert(key_type key, val_type val) { auto idx = key >> mask_bits_; if (idx >= slots_.size()) { if (ailego_unlikely(idx >= slots_.capacity())) { LOG_ERROR("no space to insert"); return false; } for (auto i = slots_.size(); i <= idx; ++i) { if (ailego_unlikely(!alloc_slot(i))) { return false; } } } auto it = slots_[idx].find(key, slot_items_, slot_loc_mask_); if (ailego_unlikely(it == nullptr)) { LOG_ERROR("no space to insert"); return false; } //! TODO: write memory is ok? 
const_cast(it)->first = key; const_cast(it)->second = val; return slots_[idx].update(it); } private: bool alloc_slot(size_t idx) { ailego_assert_with(idx == slots_.size(), "invalid idx"); size_t size = slot_items_ * sizeof(Item); auto p = broker_->alloc_chunk( HnswRabitqChunkBroker::CHUNK_TYPE_NEIGHBOR_INDEX, idx, size); if (ailego_unlikely(p.first != 0)) { LOG_ERROR("Alloc data chunk failed"); return false; } Chunk::Pointer chunk = p.second; if (ailego_unlikely(chunk->resize(size) != size)) { LOG_ERROR("Chunk resize failed, size=%zu", size); return false; } //! Read the whole data to memory IndexStorage::MemoryBlock data_block; if (ailego_unlikely(chunk->read(0U, data_block, size) != size)) { LOG_ERROR("Chunk read failed, size=%zu", size); return false; } slots_.emplace_back(std::move(chunk), std::move(data_block)); return true; } int load(void) { size_t slots_cnt = broker_->get_chunk_cnt( HnswRabitqChunkBroker::CHUNK_TYPE_NEIGHBOR_INDEX); for (size_t i = 0UL; i < slots_cnt; ++i) { auto chunk = broker_->get_chunk( HnswRabitqChunkBroker::CHUNK_TYPE_NEIGHBOR_INDEX, i); if (!chunk) { LOG_ERROR("Get chunk failed, seq=%zu", i); return IndexError_InvalidFormat; } size_t size = sizeof(Item) * slot_items_; if (chunk->data_size() < size) { LOG_ERROR( "Hash params may be mismatch, seq=%zu, data_size=%zu " "expect=%zu", i, chunk->data_size(), size); return IndexError_InvalidFormat; } //! 
Read the whole data to memory IndexStorage::MemoryBlock data_block; if (ailego_unlikely(chunk->read(0U, data_block, size) != size)) { LOG_ERROR("Chunk read failed, size=%zu", size); return false; } slots_.emplace_back(std::move(chunk), std::move(data_block)); } return 0; } private: HnswRabitqChunkBroker::Pointer broker_{}; // chunk broker std::vector slots_{}; uint32_t mask_bits_{0U}; uint32_t slot_items_{}; // must be a power of 2 uint32_t slot_loc_mask_{}; }; } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/hnsw_rabitq/hnsw_rabitq_index_provider.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include "zvec/core/framework/index_provider.h" #include "zvec/core/framework/index_searcher.h" #include "zvec/core/framework/index_streamer.h" #include "hnsw_rabitq_entity.h" namespace zvec { namespace core { class HnswRabitqIndexProvider : public IndexProvider { public: HnswRabitqIndexProvider(const IndexMeta &meta, const HnswRabitqEntity::Pointer &entity, const std::string &owner) : meta_(meta), entity_(entity), owner_class_(owner) {} HnswRabitqIndexProvider(const HnswRabitqIndexProvider &) = delete; HnswRabitqIndexProvider &operator=(const HnswRabitqIndexProvider &) = delete; public: // holder interface //! 
Create a new iterator IndexProvider::Iterator::Pointer create_iterator() override { return HnswRabitqIndexProvider::Iterator::Pointer(new (std::nothrow) Iterator(entity_)); } //! Retrieve count of vectors size_t count(void) const override { return entity_->doc_cnt(); } //! Retrieve dimension of vector size_t dimension(void) const override { return meta_.dimension(); } //! Retrieve type of vector IndexMeta::DataType data_type(void) const override { return meta_.data_type(); } //! Retrieve vector size in bytes size_t element_size(void) const override { return meta_.element_size(); } public: // provider's unique interface //! Retrieve a vector using a primary key const void *get_vector(uint64_t key) const override { return entity_->get_vector_by_key(key); } int get_vector(const uint64_t key, IndexStorage::MemoryBlock &block) const override { return entity_->get_vector_by_key(key, block); } //! Retrieve the owner class const std::string &owner_class(void) const override { return owner_class_; } private: class Iterator : public IndexProvider::Iterator { public: Iterator(const HnswRabitqEntity::Pointer &entity) : entity_(entity), cur_id_(0U) {} //! Retrieve pointer of data //! NOTICE: the vec feature will be changed after iterating to next, so //! the caller need to keep a copy of it before iterator to next vector virtual const void *data(void) const override { return entity_->get_vector(cur_id_); } //! Test if the iterator is valid virtual bool is_valid(void) const override { return cur_id_ < entity_->doc_cnt(); } //! Retrieve primary key virtual uint64_t key(void) const override { return entity_->get_key(cur_id_); } //! Next iterator virtual void next(void) override { // cur_id_ += 1; cur_id_ = get_next_valid_id(cur_id_ + 1); } //! 
Reset the iterator void reset(void) { cur_id_ = get_next_valid_id(0); } private: node_id_t get_next_valid_id(node_id_t start_id) { for (node_id_t i = start_id; i < entity_->doc_cnt(); i++) { if (entity_->get_key(i) != kInvalidNodeId) { cur_id_ = i; return i; } } return kInvalidNodeId; } private: const HnswRabitqEntity::Pointer entity_; node_id_t cur_id_; }; private: const IndexMeta &meta_; const HnswRabitqEntity::Pointer entity_; const std::string owner_class_; }; } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/hnsw_rabitq/hnsw_rabitq_params.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#pragma once

// NOTE(review): the header name of this include was lost in extraction.
#include

namespace zvec {
namespace core {

// Configuration keys for the HNSW-RaBitQ index, grouped by the component
// whose name appears in the key string (general / builder / searcher /
// streamer / reducer).

// ---- general ----
inline const std::string PARAM_HNSW_RABITQ_GENERAL_DIMENSION(
    "proxima.hnsw_rabitq.general.dimension");

// ---- builder ----
inline const std::string PARAM_HNSW_RABITQ_BUILDER_THREAD_COUNT(
    "proxima.hnsw_rabitq.builder.thread_count");
inline const std::string PARAM_HNSW_RABITQ_BUILDER_MEMORY_QUOTA(
    "proxima.hnsw_rabitq.builder.memory_quota");
inline const std::string PARAM_HNSW_RABITQ_BUILDER_EFCONSTRUCTION(
    "proxima.hnsw_rabitq.builder.efconstruction");
inline const std::string PARAM_HNSW_RABITQ_BUILDER_SCALING_FACTOR(
    "proxima.hnsw_rabitq.builder.scaling_factor");
inline const std::string PARAM_HNSW_RABITQ_BUILDER_CHECK_INTERVAL_SECS(
    "proxima.hnsw_rabitq.builder.check_interval_secs");
inline const std::string PARAM_HNSW_RABITQ_BUILDER_NEIGHBOR_PRUNE_MULTIPLIER(
    "proxima.hnsw_rabitq.builder.neighbor_prune_multiplier");
inline const std::string PARAM_HNSW_RABITQ_BUILDER_MIN_NEIGHBOR_COUNT(
    "proxima.hnsw_rabitq.builder.min_neighbor_count");
inline const std::string PARAM_HNSW_RABITQ_BUILDER_MAX_NEIGHBOR_COUNT(
    "proxima.hnsw_rabitq.builder.max_neighbor_count");
inline const std::string
    PARAM_HNSW_RABITQ_BUILDER_L0_MAX_NEIGHBOR_COUNT_MULTIPLIER(
        "proxima.hnsw_rabitq.builder.l0_max_neighbor_count_multiplier");

// ---- searcher ----
inline const std::string PARAM_HNSW_RABITQ_SEARCHER_EF(
    "proxima.hnsw_rabitq.searcher.ef");
inline const std::string PARAM_HNSW_RABITQ_SEARCHER_BRUTE_FORCE_THRESHOLD(
    "proxima.hnsw_rabitq.searcher.brute_force_threshold");
inline const std::string PARAM_HNSW_RABITQ_SEARCHER_NEIGHBORS_IN_MEMORY_ENABLE(
    "proxima.hnsw_rabitq.searcher.neighbors_in_memory_enable");
inline const std::string PARAM_HNSW_RABITQ_SEARCHER_MAX_SCAN_RATIO(
    "proxima.hnsw_rabitq.searcher.max_scan_ratio");
inline const std::string PARAM_HNSW_RABITQ_SEARCHER_CHECK_CRC_ENABLE(
    "proxima.hnsw_rabitq.searcher.check_crc_enable");
inline const std::string PARAM_HNSW_RABITQ_SEARCHER_VISIT_BLOOMFILTER_ENABLE(
    "proxima.hnsw_rabitq.searcher.visit_bloomfilter_enable");
inline const std::string
    PARAM_HNSW_RABITQ_SEARCHER_VISIT_BLOOMFILTER_NEGATIVE_PROB(
        "proxima.hnsw_rabitq.searcher.visit_bloomfilter_negative_prob");
inline const std::string PARAM_HNSW_RABITQ_SEARCHER_FORCE_PADDING_RESULT_ENABLE(
    "proxima.hnsw_rabitq.searcher.force_padding_result_enable");

// ---- streamer ----
inline const std::string PARAM_HNSW_RABITQ_STREAMER_MAX_SCAN_RATIO(
    "proxima.hnsw_rabitq.streamer.max_scan_ratio");
inline const std::string PARAM_HNSW_RABITQ_STREAMER_MIN_SCAN_LIMIT(
    "proxima.hnsw_rabitq.streamer.min_scan_limit");
inline const std::string PARAM_HNSW_RABITQ_STREAMER_MAX_SCAN_LIMIT(
    "proxima.hnsw_rabitq.streamer.max_scan_limit");
inline const std::string PARAM_HNSW_RABITQ_STREAMER_EF(
    "proxima.hnsw_rabitq.streamer.ef");
inline const std::string PARAM_HNSW_RABITQ_STREAMER_EFCONSTRUCTION(
    "proxima.hnsw_rabitq.streamer.efconstruction");
inline const std::string PARAM_HNSW_RABITQ_STREAMER_MAX_NEIGHBOR_COUNT(
    "proxima.hnsw_rabitq.streamer.max_neighbor_count");
inline const std::string
    PARAM_HNSW_RABITQ_STREAMER_L0_MAX_NEIGHBOR_COUNT_MULTIPLIER(
        "proxima.hnsw_rabitq.streamer.l0_max_neighbor_count_multiplier");
inline const std::string PARAM_HNSW_RABITQ_STREAMER_SCALING_FACTOR(
    "proxima.hnsw_rabitq.streamer.scaling_factor");
inline const std::string PARAM_HNSW_RABITQ_STREAMER_BRUTE_FORCE_THRESHOLD(
    "proxima.hnsw_rabitq.streamer.brute_force_threshold");
inline const std::string PARAM_HNSW_RABITQ_STREAMER_DOCS_HARD_LIMIT(
    "proxima.hnsw_rabitq.streamer.docs_hard_limit");
inline const std::string PARAM_HNSW_RABITQ_STREAMER_DOCS_SOFT_LIMIT(
    "proxima.hnsw_rabitq.streamer.docs_soft_limit");
inline const std::string PARAM_HNSW_RABITQ_STREAMER_MAX_INDEX_SIZE(
    "proxima.hnsw_rabitq.streamer.max_index_size");
inline const std::string PARAM_HNSW_RABITQ_STREAMER_VISIT_BLOOMFILTER_ENABLE(
    "proxima.hnsw_rabitq.streamer.visit_bloomfilter_enable");
inline const std::string
    PARAM_HNSW_RABITQ_STREAMER_VISIT_BLOOMFILTER_NEGATIVE_PROB(
        "proxima.hnsw_rabitq.streamer.visit_bloomfilter_negative_prob");
inline const std::string PARAM_HNSW_RABITQ_STREAMER_CHECK_CRC_ENABLE(
    "proxima.hnsw_rabitq.streamer.check_crc_enable");
inline const std::string PARAM_HNSW_RABITQ_STREAMER_NEIGHBOR_PRUNE_MULTIPLIER(
    "proxima.hnsw_rabitq.streamer.neighbor_prune_multiplier");
inline const std::string PARAM_HNSW_RABITQ_STREAMER_CHUNK_SIZE(
    "proxima.hnsw_rabitq.streamer.chunk_size");
inline const std::string PARAM_HNSW_RABITQ_STREAMER_FILTER_SAME_KEY(
    "proxima.hnsw_rabitq.streamer.filter_same_key");
inline const std::string PARAM_HNSW_RABITQ_STREAMER_GET_VECTOR_ENABLE(
    "proxima.hnsw_rabitq.streamer.get_vector_enable");
inline const std::string PARAM_HNSW_RABITQ_STREAMER_MIN_NEIGHBOR_COUNT(
    "proxima.hnsw_rabitq.streamer.min_neighbor_count");
inline const std::string PARAM_HNSW_RABITQ_STREAMER_FORCE_PADDING_RESULT_ENABLE(
    "proxima.hnsw_rabitq.streamer.force_padding_result_enable");
inline const std::string PARAM_HNSW_RABITQ_STREAMER_ESTIMATE_DOC_COUNT(
    "proxima.hnsw_rabitq.streamer.estimate_doc_count");
inline const std::string PARAM_HNSW_RABITQ_STREAMER_USE_ID_MAP(
    "proxima.hnsw_rabitq.streamer.use_id_map");

// ---- reducer ----
inline const std::string PARAM_HNSW_RABITQ_REDUCER_WORKING_PATH(
    "proxima.hnsw_rabitq.reducer.working_path");
inline const std::string PARAM_HNSW_RABITQ_REDUCER_NUM_OF_ADD_THREADS(
    "proxima.hnsw_rabitq.reducer.num_of_add_threads");
inline const std::string PARAM_HNSW_RABITQ_REDUCER_INDEX_NAME(
    "proxima.hnsw_rabitq.reducer.index_name");
inline const std::string PARAM_HNSW_RABITQ_REDUCER_EFCONSTRUCTION(
    "proxima.hnsw_rabitq.reducer.efconstruction");

}  // namespace core
}  // namespace zvec

================================================
FILE: src/core/algorithm/hnsw_rabitq/hnsw_rabitq_query_algorithm.cc
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "hnsw_rabitq_query_algorithm.h"
// NOTE(review): the header names of the next four includes were lost in
// extraction.
#include
#include
#include
#include
#include "zvec/ailego/internal/platform.h"
#include "hnsw_rabitq_entity.h"
#include "hnsw_rabitq_query_entity.h"

namespace zvec {
namespace core {

//! Build a query algorithm bound to `entity`.
//! Caches ex_bits / padded_dim from the entity header, selects the
//! extra-code inner-product kernel for that ex_bits, and seeds mt_ from the
//! system clock.
HnswRabitqQueryAlgorithm::HnswRabitqQueryAlgorithm(HnswRabitqEntity &entity,
                                                   size_t num_clusters,
                                                   RabitqMetricType metric_type)
    : entity_(entity),
      mt_(std::chrono::system_clock::now().time_since_epoch().count()),
      lock_pool_(kLockCnt),
      num_clusters_(num_clusters),
      metric_type_(metric_type) {
  ex_bits_ = entity_.ex_bits();
  padded_dim_ = entity_.padded_dim();
  ip_func_ = rabitqlib::select_excode_ipfunc(ex_bits_);
  // NOTE(review): %zu assumes ex_bits_ and padded_dim_ are declared size_t in
  // the header (entity_.ex_bits() is uint8_t, padded_dim() is uint32_t) -
  // confirm.
  LOG_INFO(
      "Create query algorithm. num_clusters=%zu ex_bits=%zu padded_dim=%zu",
      num_clusters_, ex_bits_, padded_dim_);
}

//! Nothing to release; always returns 0.
int HnswRabitqQueryAlgorithm::cleanup() {
  return 0;
}

//! Run one HNSW query.
//! 1. Snapshot max level and entry point under spin_lock_.
//! 2. Greedily descend levels max..1 with select_entry_point.
//! 3. Search level 0 into ctx->topk_heap() (cleared first).
//! 4. Optionally expand results per group when ctx->group_by_search().
//! Always returns 0.
//! NOTE(review): on an empty graph (entry point == kInvalidNodeId) it returns
//! before ctx->topk_heap() is cleared; confirm the caller resets the context.
int HnswRabitqQueryAlgorithm::search(HnswRabitqQueryEntity *entity,
                                     HnswRabitqContext *ctx) const {
  spin_lock_.lock();
  auto maxLevel = entity_.cur_max_level();
  auto entry_point = entity_.entry_point();
  spin_lock_.unlock();

  if (ailego_unlikely(entry_point == kInvalidNodeId)) {
    return 0;
  }

  // Binary-code estimate for the entry point seeds the greedy descent.
  EstimateRecord curest;
  get_bin_est(entity_.get_vector(entry_point), curest, *entity);
  for (level_t cur_level = maxLevel; cur_level >= 1; --cur_level) {
    select_entry_point(cur_level, &entry_point, &curest, ctx, entity);
  }

  auto &topk_heap = ctx->topk_heap();
  topk_heap.clear();
  search_neighbors(0, &entry_point, &curest, topk_heap, ctx, entity);

  if (ctx->group_by_search()) {
    expand_neighbors_by_group(topk_heap, ctx, entity);
  }

  return 0;
}

//!
select_entry_point on hnsw level, ef = 1 void HnswRabitqQueryAlgorithm::select_entry_point( level_t level, node_id_t *entry_point, EstimateRecord *curest, HnswRabitqContext *ctx, HnswRabitqQueryEntity *query_entity) const { auto &entity = ctx->get_entity(); while (true) { const Neighbors neighbors = entity.get_neighbors(level, *entry_point); if (ailego_unlikely(ctx->debugging())) { (*ctx->mutable_stats_get_neighbors())++; } ailego_prefetch(neighbors.data); uint32_t size = neighbors.size(); if (size == 0) { break; } bool find_closer = false; for (uint32_t i = 0; i < size; ++i) { EstimateRecord candest; get_bin_est(entity_.get_vector(neighbors[i]), candest, *query_entity); if (candest.est_dist < curest->est_dist) { *curest = candest; *entry_point = neighbors[i]; find_closer = true; } } if (!find_closer) { break; } } return; } void HnswRabitqQueryAlgorithm::search_neighbors( level_t level, node_id_t *entry_point, EstimateRecord *dist, TopkHeap &topk, HnswRabitqContext *ctx, HnswRabitqQueryEntity *query_entity) const { const auto &entity = ctx->get_entity(); VisitFilter &visit = ctx->visit_filter(); CandidateHeap &candidates = ctx->candidates(); std::function filter = [](node_id_t) { return false; }; if (ctx->filter().is_valid()) { filter = [&](node_id_t id) { return ctx->filter()(entity.get_key(id)); }; } candidates.clear(); visit.clear(); visit.set_visited(*entry_point); if (!filter(*entry_point)) { topk.emplace(*entry_point, ResultRecord(*dist)); } candidates.emplace(*entry_point, ResultRecord(*dist)); while (!candidates.empty() && !ctx->reach_scan_limit()) { auto top = candidates.begin(); node_id_t main_node = top->first; auto main_dist = top->second; if (topk.full() && main_dist.est_dist > topk[0].second.est_dist) { break; } candidates.pop(); const Neighbors neighbors = entity.get_neighbors(level, main_node); ailego_prefetch(neighbors.data); if (ailego_unlikely(ctx->debugging())) { (*ctx->mutable_stats_get_neighbors())++; } std::vector 
neighbor_ids(neighbors.size()); uint32_t size = 0; for (uint32_t i = 0; i < neighbors.size(); ++i) { node_id_t node = neighbors[i]; if (visit.visited(node)) { if (ailego_unlikely(ctx->debugging())) { (*ctx->mutable_stats_visit_dup_cnt())++; } continue; } visit.set_visited(node); neighbor_ids[size++] = node; } if (size == 0) { continue; } for (uint32_t i = 0; i < size; ++i) { node_id_t node = neighbor_ids[i]; EstimateRecord candest; auto *cand_vector = entity_.get_vector(node); ailego_prefetch(cand_vector); get_bin_est(cand_vector, candest, *query_entity); if (ex_bits_ > 0) { // Check preliminary score against current worst full estimate. bool flag_update_KNNs = (!topk.full()) || candest.low_dist < topk[0].second.est_dist; if (flag_update_KNNs) { // Compute the full estimate if promising. get_full_est(cand_vector, candest, *query_entity); } else { continue; } } else { // ex_bits_ == 0: est_dist is already the best estimate if (topk.full() && candest.est_dist >= topk[0].second.est_dist) { continue; } } candidates.emplace(node, ResultRecord(candest)); // update entry_point for next level scan if (candest < *dist) { *entry_point = node; *dist = candest; } if (!filter(node)) { topk.emplace(node, ResultRecord(candest)); } } // end for } // while return; } void HnswRabitqQueryAlgorithm::expand_neighbors_by_group( TopkHeap &topk, HnswRabitqContext *ctx, HnswRabitqQueryEntity *query_entity) const { if (!ctx->group_by().is_valid()) { return; } const auto &entity = ctx->get_entity(); std::function group_by = [&](node_id_t id) { return ctx->group_by()(entity.get_key(id)); }; // devide into groups std::map &group_topk_heaps = ctx->group_topk_heaps(); for (uint32_t i = 0; i < topk.size(); ++i) { node_id_t id = topk[i].first; auto score = topk[i].second; std::string group_id = group_by(id); auto &topk_heap = group_topk_heaps[group_id]; if (topk_heap.empty()) { topk_heap.limit(ctx->group_topk()); } topk_heap.emplace_back(id, score); } // stage 2, expand to reach group num as 
possible if (group_topk_heaps.size() < ctx->group_num()) { VisitFilter &visit = ctx->visit_filter(); CandidateHeap &candidates = ctx->candidates(); std::function filter = [](node_id_t) { return false; }; if (ctx->filter().is_valid()) { filter = [&](node_id_t id) { return ctx->filter()(entity.get_key(id)); }; } // refill to get enough groups candidates.clear(); visit.clear(); for (uint32_t i = 0; i < topk.size(); ++i) { node_id_t id = topk[i].first; auto score = topk[i].second; visit.set_visited(id); candidates.emplace_back(id, score); } // do expand while (!candidates.empty() && !ctx->reach_scan_limit()) { auto top = candidates.begin(); node_id_t main_node = top->first; candidates.pop(); const Neighbors neighbors = entity.get_neighbors(0, main_node); ailego_prefetch(neighbors.data); if (ailego_unlikely(ctx->debugging())) { (*ctx->mutable_stats_get_neighbors())++; } std::vector neighbor_ids(neighbors.size()); uint32_t size = 0; for (uint32_t i = 0; i < neighbors.size(); ++i) { node_id_t node = neighbors[i]; if (visit.visited(node)) { if (ailego_unlikely(ctx->debugging())) { (*ctx->mutable_stats_visit_dup_cnt())++; } continue; } visit.set_visited(node); neighbor_ids[size++] = node; } if (size == 0) { continue; } for (uint32_t i = 0; i < size; ++i) { node_id_t node = neighbor_ids[i]; EstimateRecord candest; auto *cand_vector = entity_.get_vector(node); ailego_prefetch(cand_vector); get_full_est(cand_vector, candest, *query_entity); if (!filter(node)) { std::string group_id = group_by(node); auto &topk_heap = group_topk_heaps[group_id]; if (topk_heap.empty()) { topk_heap.limit(ctx->group_topk()); } topk_heap.emplace_back(node, ResultRecord(candest)); if (group_topk_heaps.size() >= ctx->group_num()) { break; } } candidates.emplace(node, ResultRecord(candest)); } // end for } // end while } // end if } void HnswRabitqQueryAlgorithm::get_bin_est( const void *vector, EstimateRecord &res, HnswRabitqQueryEntity &entity) const { const auto &q_to_centroids = 
entity.q_to_centroids; auto &query_wrapper = *entity.query_wrapper; uint32_t cluster_id = entity_.get_cluster_id(vector); const char *bin_data = entity_.get_bin_data(vector); if (metric_type_ == RabitqMetricType::kIP) { float norm = q_to_centroids[cluster_id]; float error = q_to_centroids[cluster_id + num_clusters_]; rabitqlib::split_single_estdist(bin_data, query_wrapper, padded_dim_, res.ip_x0_qr, res.est_dist, res.low_dist, -norm, error); } else { // L2 distance float norm = q_to_centroids[cluster_id]; rabitqlib::split_single_estdist(bin_data, query_wrapper, padded_dim_, res.ip_x0_qr, res.est_dist, res.low_dist, norm * norm, norm); } } void HnswRabitqQueryAlgorithm::get_full_est( const void *vector, EstimateRecord &res, HnswRabitqQueryEntity &entity) const { const auto &q_to_centroids = entity.q_to_centroids; auto &query_wrapper = *entity.query_wrapper; uint32_t cluster_id = entity_.get_cluster_id(vector); const char *bin_data = entity_.get_bin_data(vector); const char *ex_data = entity_.get_ex_data(vector); if (metric_type_ == RabitqMetricType::kIP) { float norm = q_to_centroids[cluster_id]; float error = q_to_centroids[cluster_id + num_clusters_]; rabitqlib::split_single_fulldist(bin_data, ex_data, ip_func_, query_wrapper, padded_dim_, ex_bits_, res.est_dist, res.low_dist, res.ip_x0_qr, -norm, error); } else { // L2 distance float norm = q_to_centroids[cluster_id]; rabitqlib::split_single_fulldist( bin_data, ex_data, ip_func_, query_wrapper, padded_dim_, ex_bits_, res.est_dist, res.low_dist, res.ip_x0_qr, norm * norm, norm); } } } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/hnsw_rabitq/hnsw_rabitq_query_algorithm.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include "hnsw_rabitq_context.h" #include "hnsw_rabitq_dist_calculator.h" #include "hnsw_rabitq_entity.h" #include "rabitq_params.h" namespace zvec { namespace core { class HnswRabitqQueryEntity; //! hnsw graph algorithm implement class HnswRabitqQueryAlgorithm { public: typedef std::unique_ptr UPointer; public: //! Constructor explicit HnswRabitqQueryAlgorithm(HnswRabitqEntity &entity, size_t num_clusters, RabitqMetricType metric_type); //! Destructor ~HnswRabitqQueryAlgorithm() = default; //! Cleanup HnswRabitqQueryAlgorithm int cleanup(); //! do knn search in graph //! return 0 on success, or errCode in failure. results saved in ctx int search(HnswRabitqQueryEntity *entity, HnswRabitqContext *ctx) const; //! Initiate HnswRabitqQueryAlgorithm int init() { level_probas_.clear(); double level_mult = 1 / std::log(static_cast(entity_.scaling_factor())); for (int level = 0;; level++) { // refers faiss get_random_level alg double proba = std::exp(-level / level_mult) * (1 - std::exp(-1 / level_mult)); if (proba < 1e-9) { break; } level_probas_.push_back(proba); } return 0; } //! Generate a random level //! 
return graph level uint32_t get_random_level() const { // gen rand float (0, 1) double f = mt_() / static_cast(mt_.max()); for (size_t level = 0; level < level_probas_.size(); level++) { if (f < level_probas_[level]) { return level; } f -= level_probas_[level]; } return level_probas_.size() - 1; } void get_full_est(node_id_t id, EstimateRecord &res, HnswRabitqQueryEntity &entity) const { return get_full_est(entity_.get_vector(id), res, entity); } private: //! Select in upper layer to get entry point for next layer search void select_entry_point(level_t level, node_id_t *entry_point, EstimateRecord *dist, HnswRabitqContext *ctx, HnswRabitqQueryEntity *entity) const; //! Given a node id and level, search the nearest neighbors in graph //! Note: the nearest neighbors result keeps in topk, and entry_point and //! dist will be updated to current level nearest node id and distance void search_neighbors(level_t level, node_id_t *entry_point, EstimateRecord *dist, TopkHeap &topk, HnswRabitqContext *ctx, HnswRabitqQueryEntity *entity) const; //! expand neighbors until group nums are reached void expand_neighbors_by_group(TopkHeap &topk, HnswRabitqContext *ctx, HnswRabitqQueryEntity *query_entity) const; void get_full_est(const void *vector, EstimateRecord &res, HnswRabitqQueryEntity &entity) const; void get_bin_est(const void *vector, EstimateRecord &res, HnswRabitqQueryEntity &entity) const; private: HnswRabitqQueryAlgorithm(const HnswRabitqQueryAlgorithm &) = delete; HnswRabitqQueryAlgorithm &operator=(const HnswRabitqQueryAlgorithm &) = delete; private: static constexpr uint32_t kLockCnt{1U << 8}; static constexpr uint32_t kLockMask{kLockCnt - 1U}; HnswRabitqEntity &entity_; mutable std::mt19937 mt_{}; std::vector level_probas_{}; mutable ailego::SpinMutex spin_lock_{}; // global spin lock std::mutex mutex_{}; // global mutex // TODO: spin lock? 
std::vector lock_pool_{}; size_t num_clusters_{0}; RabitqMetricType metric_type_{RabitqMetricType::kL2}; size_t padded_dim_{0}; size_t ex_bits_{0}; float (*ip_func_)(const float *, const uint8_t *, size_t); }; } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/hnsw_rabitq/hnsw_rabitq_query_entity.h ================================================ // Copyright 2025-present the centaurdb project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License #pragma once #include #include #include namespace zvec::core { struct HnswRabitqQueryEntity { std::vector rotated_query; std::vector q_to_centroids; std::unique_ptr> query_wrapper; }; } // namespace zvec::core ================================================ FILE: src/core/algorithm/hnsw_rabitq/hnsw_rabitq_register.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and // limitations under the License #include "hnsw_rabitq_builder.h" #include "hnsw_rabitq_searcher.h" #include "hnsw_rabitq_streamer.h" #include "rabitq_converter.h" #include "rabitq_reformer.h" namespace zvec::core { INDEX_FACTORY_REGISTER_STREAMER(HnswRabitqStreamer); INDEX_FACTORY_REGISTER_REFORMER_ALIAS(RabitqReformer, RabitqReformer); INDEX_FACTORY_REGISTER_SEARCHER(HnswRabitqSearcher); INDEX_FACTORY_REGISTER_CONVERTER_ALIAS(RabitqConverter, RabitqConverter); INDEX_FACTORY_REGISTER_BUILDER(HnswRabitqBuilder); } // namespace zvec::core ================================================ FILE: src/core/algorithm/hnsw_rabitq/hnsw_rabitq_searcher.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "hnsw_rabitq_searcher.h"
// NOTE(review): the operand of the next #include (presumably the header that
// provides std::cout), and several template argument lists in this file, appear
// to have been stripped during text extraction; restore them from upstream.
#include
#include "hnsw_rabitq_algorithm.h"
#include "hnsw_rabitq_entity.h"
#include "hnsw_rabitq_index_provider.h"
#include "hnsw_rabitq_params.h"
#include "hnsw_rabitq_query_entity.h"
#include "hnsw_rabitq_searcher_entity.h"
#include "rabitq_params.h"

namespace zvec {
namespace core {

HnswRabitqSearcher::HnswRabitqSearcher() {}

HnswRabitqSearcher::~HnswRabitqSearcher() {}

//! Read the searcher parameters, validate them and initialize the RaBitQ
//! reformer; moves the searcher to STATE_INITED.
//! @return 0 on success, IndexError_InvalidArgument when the bloom-filter
//!         negative probability is outside (0, 1), or the reformer's error.
int HnswRabitqSearcher::init(const ailego::Params &search_params) {
  params_ = search_params;
  params_.get(PARAM_HNSW_RABITQ_SEARCHER_EF, &ef_);
  params_.get(PARAM_HNSW_RABITQ_SEARCHER_MAX_SCAN_RATIO, &max_scan_ratio_);
  params_.get(PARAM_HNSW_RABITQ_SEARCHER_VISIT_BLOOMFILTER_ENABLE,
              &bf_enabled_);
  params_.get(PARAM_HNSW_RABITQ_SEARCHER_CHECK_CRC_ENABLE,
              &check_crc_enabled_);
  params_.get(PARAM_HNSW_RABITQ_SEARCHER_NEIGHBORS_IN_MEMORY_ENABLE,
              &neighbors_in_memory_enabled_);
  params_.get(PARAM_HNSW_RABITQ_SEARCHER_VISIT_BLOOMFILTER_NEGATIVE_PROB,
              &bf_negative_probability_);
  params_.get(PARAM_HNSW_RABITQ_SEARCHER_BRUTE_FORCE_THRESHOLD,
              &bruteforce_threshold_);
  params_.get(PARAM_HNSW_RABITQ_SEARCHER_FORCE_PADDING_RESULT_ENABLE,
              &force_padding_topk_enabled_);

  if (ef_ == 0) {
    ef_ = HnswRabitqEntity::kDefaultEf;
  }
  if (bf_negative_probability_ <= 0.0f || bf_negative_probability_ >= 1.0f) {
    LOG_ERROR(
        "[%s] must be in range (0,1)",
        PARAM_HNSW_RABITQ_SEARCHER_VISIT_BLOOMFILTER_NEGATIVE_PROB.c_str());
    return IndexError_InvalidArgument;
  }

  entity_.set_neighbors_in_memory(neighbors_in_memory_enabled_);

  // NOTE(review): meta_ is only populated by load(), so at init() time this
  // passes the metric name of an empty (or cleared) IndexMeta -- confirm that
  // reformer_.load() restores the real metric.
  ailego::Params reformer_params;
  reformer_params.set(PARAM_RABITQ_METRIC_NAME, meta_.metric_name());
  int ret = reformer_.init(reformer_params);
  if (ret != 0) {
    LOG_ERROR("Failed to initialize RabitqReformer: %d", ret);
    return ret;
  }

  state_ = STATE_INITED;
  LOG_DEBUG(
      "Init params: ef=%u maxScanRatio=%f bfEnabled=%u checkCrcEnabled=%u "
      "neighborsInMemoryEnabled=%u bfNagtiveProb=%f bruteForceThreshold=%u "
      "forcePadding=%u",
      ef_, max_scan_ratio_, bf_enabled_, check_crc_enabled_,
      neighbors_in_memory_enabled_, bf_negative_probability_,
      bruteforce_threshold_, force_padding_topk_enabled_);
  return 0;
}

//! Print each node's level-0 neighbor list to stdout, one line per node.
void HnswRabitqSearcher::print_debug_info() {
  for (node_id_t id = 0; id < entity_.doc_cnt(); ++id) {
    Neighbors neighbours = entity_.get_neighbors(0, id);
    std::cout << "node: " << id << "; ";
    for (uint32_t i = 0; i < neighbours.size(); ++i) {
      if (i != 0) {
        std::cout << ", ";
      }
      std::cout << neighbours[i];
    }
    // Terminate the line unconditionally so that nodes without neighbors do
    // not run into the next node's output.
    std::cout << std::endl;
  }
}

//! Release the loaded index and reset every init()-configurable parameter to
//! its default; moves the searcher back to STATE_INIT.
int HnswRabitqSearcher::cleanup() {
  LOG_INFO("Begin HnswRabitqSearcher:cleanup");

  metric_.reset();
  meta_.clear();
  stats_.clear_attributes();
  stats_.set_loaded_count(0UL);
  stats_.set_loaded_costtime(0UL);
  max_scan_ratio_ = HnswRabitqEntity::kDefaultScanRatio;
  max_scan_num_ = 0U;
  ef_ = HnswRabitqEntity::kDefaultEf;
  bf_enabled_ = false;
  bf_negative_probability_ = HnswRabitqEntity::kDefaultBFNegativeProbability;
  bruteforce_threshold_ = HnswRabitqEntity::kDefaultBruteForceThreshold;
  check_crc_enabled_ = false;
  neighbors_in_memory_enabled_ = false;
  // Reset like the other init() parameters; otherwise a later init() whose
  // params omit this key could inherit the previous value.
  force_padding_topk_enabled_ = false;
  entity_.cleanup();
  state_ = STATE_INIT;

  LOG_INFO("End HnswRabitqSearcher:cleanup");
  return 0;
}

//! Load meta, reformer and graph from `container` and build the query
//! algorithm. Uses `metric` when given; otherwise creates one from the index
//! meta (switching to its query metric when it has one).
//! @return 0 on success; IndexError_Runtime unless in STATE_INITED;
//!         IndexError_NoExist / IndexError_Mismatch for metric problems;
//!         otherwise the failing component's error code.
int HnswRabitqSearcher::load(IndexStorage::Pointer container,
                             IndexMetric::Pointer metric) {
  if (state_ != STATE_INITED) {
    LOG_ERROR("Init the searcher first before load index");
    return IndexError_Runtime;
  }

  LOG_INFO("Begin HnswRabitqSearcher:load");

  auto start_time = ailego::Monotime::MilliSeconds();

  int ret = IndexHelper::DeserializeFromStorage(container.get(), &meta_);
  if (ret != 0) {
    LOG_ERROR("Failed to deserialize meta from container");
    return ret;
  }

  ret = reformer_.load(container);
  if (ret != 0) {
    LOG_ERROR("Failed to load reformer from container: %d", ret);
    return ret;
  }

  // unload() runs entity_.cleanup(), which clears the entity's
  // neighbors-in-memory flag; re-apply the configured value so a reload
  // honours it (presumably the same flag set_neighbors_in_memory() controls).
  entity_.set_neighbors_in_memory(neighbors_in_memory_enabled_);
  ret = entity_.load(container, check_crc_enabled_);
  if (ret != 0) {
    LOG_ERROR("HnswRabitqSearcher load index failed");
    return ret;
  }

  alg_ = HnswRabitqQueryAlgorithm::UPointer(new HnswRabitqQueryAlgorithm(
      entity_, reformer_.num_clusters(), reformer_.rabitq_metric_type()));

  if (metric) {
    metric_ = metric;
  } else {
    metric_ = IndexFactory::CreateMetric(meta_.metric_name());
    if (!metric_) {
      LOG_ERROR("CreateMetric failed, name: %s", meta_.metric_name().c_str());
      return IndexError_NoExist;
    }
    ret = metric_->init(meta_, meta_.metric_params());
    if (ret != 0) {
      LOG_ERROR("IndexMetric init failed, ret=%d", ret);
      return ret;
    }
    if (metric_->query_metric()) {
      metric_ = metric_->query_metric();
    }
  }

  if (!metric_->is_matched(meta_)) {
    LOG_ERROR("IndexMetric not match index meta");
    return IndexError_Mismatch;
  }
  // Scan budget: max_scan_ratio_ of the documents, but never below 4096.
  max_scan_num_ = static_cast(max_scan_ratio_ * entity_.doc_cnt());
  max_scan_num_ = std::max(4096U, max_scan_num_);

  stats_.set_loaded_count(entity_.doc_cnt());
  stats_.set_loaded_costtime(ailego::Monotime::MilliSeconds() - start_time);
  state_ = STATE_LOADED;
  magic_ = IndexContext::GenerateMagic();

  LOG_INFO("End HnswRabitqSearcher::load");

  return 0;
}

//! Drop the loaded index but keep the init() parameters; back to STATE_INITED.
int HnswRabitqSearcher::unload() {
  LOG_INFO("HnswRabitqSearcher unload index");
  meta_.clear();
  entity_.cleanup();
  metric_.reset();
  max_scan_num_ = 0;
  stats_.set_loaded_count(0UL);
  stats_.set_loaded_costtime(0UL);
  state_ = STATE_INITED;
  return 0;
}

//! Rebind a context created by another searcher/streamer to this searcher:
//! hands it a fresh clone of the entity plus this searcher's meta, metric,
//! scan limit, brute-force threshold and magic.
int HnswRabitqSearcher::update_context(HnswRabitqContext *ctx) const {
  const HnswRabitqEntity::Pointer entity = entity_.clone();
  if (!entity) {
    LOG_ERROR("Failed to clone search context entity");
    return IndexError_Runtime;
  }
  ctx->set_max_scan_num(max_scan_num_);
  ctx->set_bruteforce_threshold(bruteforce_threshold_);

  return ctx->update_context(HnswRabitqContext::kSearcherContext, meta_,
                             metric_, entity, magic_);
}

//! Graph k-NN search for `count` packed queries (stride qmeta.element_size()).
//! Falls back to search_bf_impl when the index holds no more documents than
//! the context's brute-force threshold.
int HnswRabitqSearcher::search_impl(const void *query,
                                    const IndexQueryMeta &qmeta,
                                    uint32_t count,
                                    Context::Pointer &context) const {
  if (ailego_unlikely(!query || !context)) {
    LOG_ERROR("The context is not created by this searcher");
    return IndexError_Mismatch;
  }
  HnswRabitqContext *ctx = dynamic_cast(context.get());
  ailego_do_if_false(ctx) {
    LOG_ERROR("Cast context to HnswRabitqContext failed");
    return IndexError_Cast;
  }

  // Sync a foreign context first so the brute-force decision below uses this
  // searcher's threshold instead of one carried over from its creator.
  if (ctx->magic() != magic_) {
    //! context is created by another searcher or streamer
    int ret = update_context(ctx);
    if (ret != 0) {
      return ret;
    }
  }

  if (entity_.doc_cnt() <= ctx->get_bruteforce_threshold()) {
    return search_bf_impl(query, qmeta, count, context);
  }

  ctx->clear();
  ctx->resize_results(count);
  for (size_t q = 0; q < count; ++q) {
    HnswRabitqQueryEntity entity;
    int ret = reformer_.transform_to_entity(query, &entity);
    if (ailego_unlikely(ret != 0)) {
      LOG_ERROR("Hnsw searcher transform failed");
      return ret;
    }
    ctx->reset_query(query);
    ret = alg_->search(&entity, ctx);
    if (ailego_unlikely(ret != 0)) {
      LOG_ERROR("Hnsw searcher fast search failed");
      return ret;
    }
    ctx->topk_to_result(q);
    // Advance to the next packed query.
    query = static_cast(query) + qmeta.element_size();
  }

  if (ailego_unlikely(ctx->error())) {
    return IndexError_Runtime;
  }

  return 0;
}

//! Exhaustive scan over every stored node for each of the `count` queries,
//! scoring with the full RaBitQ estimate. Nodes whose key is kInvalidKey are
//! skipped; a valid context filter excludes keys for which it returns true.
//! Handles group-by search through ctx->group_topk_heaps().
int HnswRabitqSearcher::search_bf_impl(const void *query,
                                       const IndexQueryMeta &qmeta,
                                       uint32_t count,
                                       Context::Pointer &context) const {
  if (ailego_unlikely(!query || !context)) {
    LOG_ERROR("The context is not created by this searcher");
    return IndexError_Mismatch;
  }
  HnswRabitqContext *ctx = dynamic_cast(context.get());
  ailego_do_if_false(ctx) {
    LOG_ERROR("Cast context to HnswRabitqContext failed");
    return IndexError_Cast;
  }
  if (ctx->magic() != magic_) {
    //! context is created by another searcher or streamer
    int ret = update_context(ctx);
    if (ret != 0) {
      return ret;
    }
  }

  ctx->clear();
  ctx->resize_results(count);

  if (ctx->group_by_search()) {
    if (!ctx->group_by().is_valid()) {
      LOG_ERROR("Invalid group-by function");
      return IndexError_InvalidArgument;
    }

    std::function group_by = [&](node_id_t id) {
      return ctx->group_by()(entity_.get_key(id));
    };

    for (size_t q = 0; q < count; ++q) {
      HnswRabitqQueryEntity entity;
      int ret = reformer_.transform_to_entity(query, &entity);
      if (ailego_unlikely(ret != 0)) {
        LOG_ERROR("Hnsw searcher transform failed");
        return ret;
      }
      ctx->reset_query(query);
      ctx->group_topk_heaps().clear();

      for (node_id_t id = 0; id < entity_.doc_cnt(); ++id) {
        if (entity_.get_key(id) == kInvalidKey) {
          continue;
        }

        if (!ctx->filter().is_valid() || !ctx->filter()(entity_.get_key(id))) {
          EstimateRecord dist;
          alg_->get_full_est(id, dist, entity);
          std::string group_id = group_by(id);

          auto &topk_heap = ctx->group_topk_heaps()[group_id];
          if (topk_heap.empty()) {
            topk_heap.limit(ctx->group_topk());
          }
          topk_heap.emplace_back(id, dist);
        }
      }
      ctx->topk_to_result(q);
      query = static_cast(query) + qmeta.element_size();
    }
  } else {
    for (size_t q = 0; q < count; ++q) {
      HnswRabitqQueryEntity entity;
      int ret = reformer_.transform_to_entity(query, &entity);
      if (ailego_unlikely(ret != 0)) {
        LOG_ERROR("Hnsw searcher transform failed");
        return ret;
      }
      ctx->reset_query(query);
      ctx->topk_heap().clear();
      for (node_id_t id = 0; id < entity_.doc_cnt(); ++id) {
        if (entity_.get_key(id) == kInvalidKey) {
          continue;
        }

        if (!ctx->filter().is_valid() || !ctx->filter()(entity_.get_key(id))) {
          EstimateRecord dist;
          alg_->get_full_est(id, dist, entity);
          ctx->topk_heap().emplace(id, dist);
        }
      }
      ctx->topk_to_result(q);
      query = static_cast(query) + qmeta.element_size();
    }
  }

  if (ailego_unlikely(ctx->error())) {
    return IndexError_Runtime;
  }

  return 0;
}

//! Brute-force scoring restricted to the primary keys in p_keys[q] for query q.
//! Keys rejected by a valid context filter, or without a local id, are skipped.
//! @return IndexError_InvalidArgument when p_keys.size() != count.
int HnswRabitqSearcher::search_bf_by_p_keys_impl(
    const void *query, const std::vector> &p_keys,
    const IndexQueryMeta &qmeta, uint32_t count,
    Context::Pointer &context) const {
  if (ailego_unlikely(!query || !context)) {
    LOG_ERROR("The context is not created by this searcher");
    return IndexError_Mismatch;
  }

  if (ailego_unlikely(p_keys.size() != count)) {
    LOG_ERROR("The size of p_keys is not equal to count");
    return IndexError_InvalidArgument;
  }

  HnswRabitqContext *ctx = dynamic_cast(context.get());
  ailego_do_if_false(ctx) {
    LOG_ERROR("Cast context to HnswRabitqContext failed");
    return IndexError_Cast;
  }
  if (ctx->magic() != magic_) {
    //! context is created by another searcher or streamer
    int ret = update_context(ctx);
    if (ret != 0) {
      return ret;
    }
  }

  ctx->clear();
  ctx->resize_results(count);

  if (ctx->group_by_search()) {
    if (!ctx->group_by().is_valid()) {
      LOG_ERROR("Invalid group-by function");
      return IndexError_InvalidArgument;
    }

    std::function group_by = [&](node_id_t id) {
      return ctx->group_by()(entity_.get_key(id));
    };

    for (size_t q = 0; q < count; ++q) {
      HnswRabitqQueryEntity entity;
      int ret = reformer_.transform_to_entity(query, &entity);
      if (ailego_unlikely(ret != 0)) {
        LOG_ERROR("Hnsw searcher transform failed");
        return ret;
      }
      ctx->reset_query(query);
      ctx->group_topk_heaps().clear();

      for (size_t idx = 0; idx < p_keys[q].size(); ++idx) {
        uint64_t pk = p_keys[q][idx];
        if (!ctx->filter().is_valid() || !ctx->filter()(pk)) {
          node_id_t id = entity_.get_id(pk);
          if (id != kInvalidNodeId) {
            EstimateRecord dist;
            alg_->get_full_est(id, dist, entity);
            std::string group_id = group_by(id);

            auto &topk_heap = ctx->group_topk_heaps()[group_id];
            if (topk_heap.empty()) {
              topk_heap.limit(ctx->group_topk());
            }
            topk_heap.emplace_back(id, dist);
          }
        }
      }
      ctx->topk_to_result(q);
      query = static_cast(query) + qmeta.element_size();
    }
  } else {
    for (size_t q = 0; q < count; ++q) {
      HnswRabitqQueryEntity entity;
      int ret = reformer_.transform_to_entity(query, &entity);
      if (ailego_unlikely(ret != 0)) {
        LOG_ERROR("Hnsw searcher transform failed");
        return ret;
      }
      ctx->reset_query(query);
      ctx->topk_heap().clear();
      for (size_t idx = 0; idx < p_keys[q].size(); ++idx) {
        uint64_t pk = p_keys[q][idx];
        if (!ctx->filter().is_valid() || !ctx->filter()(pk)) {
          node_id_t id = entity_.get_id(pk);
          if (id != kInvalidNodeId) {
            EstimateRecord dist;
            alg_->get_full_est(id, dist, entity);
            ctx->topk_heap().emplace(id, dist);
          }
        }
      }
      ctx->topk_to_result(q);
      query = static_cast(query) + qmeta.element_size();
    }
  }

  if (ailego_unlikely(ctx->error())) {
    return IndexError_Runtime;
  }

  return 0;
}

//! Create a context bound to the loaded index and configured with this
//! searcher's ef, scan limit, visit-filter mode, padding and thresholds.
//! @return an empty pointer if the index is not loaded or setup fails.
IndexSearcher::Context::Pointer HnswRabitqSearcher::create_context() const {
  if (ailego_unlikely(state_ != STATE_LOADED)) {
    LOG_ERROR("Load the index first before create context");
    return Context::Pointer();
  }

  const HnswRabitqEntity::Pointer search_ctx_entity = entity_.clone();
  if (!search_ctx_entity) {
    LOG_ERROR("Failed to create search context entity");
    return Context::Pointer();
  }
  HnswRabitqContext *ctx = new (std::nothrow)
      HnswRabitqContext(meta_.dimension(), metric_, search_ctx_entity);
  if (ailego_unlikely(ctx == nullptr)) {
    LOG_ERROR("Failed to new HnswRabitqContext");
    return Context::Pointer();
  }
  ctx->set_ef(ef_);
  ctx->set_max_scan_num(max_scan_num_);
  uint32_t filter_mode =
      bf_enabled_ ? VisitFilter::BloomFilter : VisitFilter::ByteMap;
  ctx->set_filter_mode(filter_mode);
  ctx->set_filter_negative_probability(bf_negative_probability_);
  ctx->set_magic(magic_);
  ctx->set_force_padding_topk(force_padding_topk_enabled_);
  ctx->set_bruteforce_threshold(bruteforce_threshold_);

  // The branch hint wraps the whole failure condition.
  if (ailego_unlikely(ctx->init(HnswRabitqContext::kSearcherContext) != 0)) {
    LOG_ERROR("Init HnswRabitqContext failed");
    delete ctx;
    return Context::Pointer();
  }

  return Context::Pointer(ctx);
}

//! Create an index provider over a clone of the loaded entity.
IndexProvider::Pointer HnswRabitqSearcher::create_provider(void) const {
  LOG_DEBUG("HnswRabitqSearcher create provider");
  auto entity = entity_.clone();
  if (ailego_unlikely(!entity)) {
    LOG_ERROR("Clone HnswRabitqEntity failed");
    return Provider::Pointer();
  }
  return Provider::Pointer(new (std::nothrow) HnswRabitqIndexProvider(
      meta_, entity, "HnswRabitqSearcher"));
}

//! Stored vector record for primary key `key`, or nullptr if unknown.
const void *HnswRabitqSearcher::get_vector(uint64_t key) const {
  return entity_.get_vector_by_key(key);
}

}  // namespace core
}  // namespace zvec
================================================ FILE: src/core/algorithm/hnsw_rabitq/hnsw_rabitq_searcher.h ================================================
// Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License.
#pragma once #include "zvec/core/framework/index_framework.h" #include "hnsw_rabitq_query_algorithm.h" #include "hnsw_rabitq_searcher_entity.h" #include "rabitq_reformer.h" namespace zvec { namespace core { class HnswRabitqSearcher : public IndexSearcher { public: using ContextPointer = IndexSearcher::Context::Pointer; public: HnswRabitqSearcher(void); ~HnswRabitqSearcher(void); HnswRabitqSearcher(const HnswRabitqSearcher &) = delete; HnswRabitqSearcher &operator=(const HnswRabitqSearcher &) = delete; protected: //! Initialize Searcher virtual int init(const ailego::Params ¶ms) override; //! Cleanup Searcher virtual int cleanup(void) override; //! Load Index from storage virtual int load(IndexStorage::Pointer container, IndexMetric::Pointer metric) override; //! Unload index from storage virtual int unload(void) override; //! KNN Search virtual int search_impl(const void *query, const IndexQueryMeta &qmeta, ContextPointer &context) const override { return search_impl(query, qmeta, 1, context); } //! KNN Search virtual int search_impl(const void *query, const IndexQueryMeta &qmeta, uint32_t count, ContextPointer &context) const override; //! Linear Search virtual int search_bf_impl(const void *query, const IndexQueryMeta &qmeta, ContextPointer &context) const override { return search_bf_impl(query, qmeta, 1, context); } //! Linear Search virtual int search_bf_impl(const void *query, const IndexQueryMeta &qmeta, uint32_t count, ContextPointer &context) const override; //! Linear search by primary keys virtual int search_bf_by_p_keys_impl( const void *query, const std::vector> &p_keys, const IndexQueryMeta &qmeta, ContextPointer &context) const override { return search_bf_by_p_keys_impl(query, p_keys, qmeta, 1, context); } //! Linear search by primary keys virtual int search_bf_by_p_keys_impl( const void *query, const std::vector> &p_keys, const IndexQueryMeta &qmeta, uint32_t count, ContextPointer &context) const override; //! 
Fetch vector by key virtual const void *get_vector(uint64_t key) const override; //! Create a searcher context virtual ContextPointer create_context() const override; //! Create a new iterator virtual IndexProvider::Pointer create_provider(void) const override; //! Retrieve statistics virtual const Stats &stats(void) const override { return stats_; } //! Retrieve meta of index virtual const IndexMeta &meta(void) const override { return meta_; } //! Retrieve params of index virtual const ailego::Params ¶ms(void) const override { return params_; } virtual void print_debug_info() override; private: //! To share ctx across streamer/searcher, we need to update the context for //! current streamer/searcher int update_context(HnswRabitqContext *ctx) const; private: enum State { STATE_INIT = 0, STATE_INITED = 1, STATE_LOADED = 2 }; HnswRabitqSearcherEntity entity_{}; HnswRabitqQueryAlgorithm::UPointer alg_; // impl graph algorithm IndexMetric::Pointer metric_{}; IndexMeta meta_{}; ailego::Params params_{}; Stats stats_; uint32_t ef_{HnswRabitqEntity::kDefaultEf}; uint32_t max_scan_num_{0U}; uint32_t bruteforce_threshold_{HnswRabitqEntity::kDefaultBruteForceThreshold}; float max_scan_ratio_{HnswRabitqEntity::kDefaultScanRatio}; bool bf_enabled_{false}; bool check_crc_enabled_{false}; bool neighbors_in_memory_enabled_{false}; bool force_padding_topk_enabled_{false}; float bf_negative_probability_{ HnswRabitqEntity::kDefaultBFNegativeProbability}; uint32_t magic_{0U}; RabitqReformer reformer_; State state_{STATE_INIT}; }; } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/hnsw_rabitq/hnsw_rabitq_searcher_entity.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "hnsw_rabitq_searcher_entity.h" #include #include "utility/sparse_utility.h" namespace zvec { namespace core { HnswRabitqSearcherEntity::HnswRabitqSearcherEntity() {} int HnswRabitqSearcherEntity::cleanup(void) { storage_.reset(); vectors_.reset(); keys_.reset(); neighbors_.reset(); neighbors_meta_.reset(); neighbors_in_memory_enabled_ = false; loaded_ = false; this->HnswRabitqEntity::cleanup(); return 0; } key_t HnswRabitqSearcherEntity::get_key(node_id_t id) const { const void *key; if (ailego_unlikely(keys_->read(id * sizeof(key_t), &key, sizeof(key_t)) != sizeof(key_t))) { LOG_ERROR("Read key from segment failed"); return kInvalidKey; } return *(reinterpret_cast(key)); } //! Get vector local id by key node_id_t HnswRabitqSearcherEntity::get_id(key_t key) const { if (ailego_unlikely(!mapping_)) { LOG_ERROR("Index missing mapping segment"); return kInvalidNodeId; } //! 
Do binary search node_id_t start = 0UL; node_id_t end = doc_cnt(); const void *data; node_id_t idx = 0u; while (start < end) { idx = start + (end - start) / 2; if (ailego_unlikely( mapping_->read(idx * sizeof(node_id_t), &data, sizeof(node_id_t)) != sizeof(node_id_t))) { LOG_ERROR("Read key from segment failed"); return kInvalidNodeId; } const key_t *mkey; node_id_t local_id = *reinterpret_cast(data); if (ailego_unlikely(keys_->read(local_id * sizeof(key_t), (const void **)(&mkey), sizeof(key_t)) != sizeof(key_t))) { LOG_ERROR("Read key from segment failed"); return kInvalidNodeId; } if (*mkey < key) { start = idx + 1; } else if (*mkey > key) { end = idx; } else { return local_id; } } return kInvalidNodeId; } const void *HnswRabitqSearcherEntity::get_vector_by_key(key_t key) const { node_id_t local_id = get_id(key); if (ailego_unlikely(local_id == kInvalidNodeId)) { return nullptr; } return get_vector(local_id); } const void *HnswRabitqSearcherEntity::get_vector(node_id_t id) const { size_t read_size = vector_size(); size_t offset = node_size() * id; const void *vec; if (ailego_unlikely(vectors_->read(offset, &vec, read_size) != read_size)) { LOG_ERROR("Read vector from segment failed"); return nullptr; } return vec; } int HnswRabitqSearcherEntity::get_vector( const node_id_t id, IndexStorage::MemoryBlock &block) const { const void *vec = get_vector(id); block.reset((void *)vec); return 0; } const void *HnswRabitqSearcherEntity::get_vectors() const { const void *vec; size_t len = node_size() * doc_cnt(); if (vectors_->read(0, &vec, len) != len) { LOG_ERROR("Read vectors from segment failed"); return nullptr; } return vec; } int HnswRabitqSearcherEntity::get_vector(const node_id_t *ids, uint32_t count, const void **vecs) const { ailego_assert_with(count <= segment_datas_.size(), "invalid count"); size_t read_size = vector_size(); for (uint32_t i = 0; i < count; ++i) { segment_datas_[i].offset = node_size() * ids[i]; segment_datas_[i].length = read_size; 
ailego_assert_with(segment_datas_[i].offset < vectors_->data_size(), "invalid offset"); } if (ailego_unlikely(!vectors_->read(&segment_datas_[0], count))) { LOG_ERROR("Read vectors from segment failed"); return IndexError_ReadData; } for (uint32_t i = 0; i < count; ++i) { vecs[i] = segment_datas_[i].data; } return 0; } int HnswRabitqSearcherEntity::get_vector( const node_id_t *ids, uint32_t count, std::vector &vec_blocks) const { const void *vecs[count]; get_vector(ids, count, vecs); for (uint32_t i = 0; i < count; ++i) { vec_blocks.emplace_back(IndexStorage::MemoryBlock((void *)vecs[i])); } return 0; } const Neighbors HnswRabitqSearcherEntity::get_neighbors(level_t level, node_id_t id) const { if (level == 0) { if (neighbors_in_memory_enabled_) { auto hd = reinterpret_cast( fixed_neighbors_.get() + neighbors_size() * id); return {hd->neighbor_cnt, hd->neighbors}; } const GraphNeighborMeta *m; if (ailego_unlikely(neighbors_meta_->read(id * sizeof(GraphNeighborMeta), (const void **)(&m), sizeof(GraphNeighborMeta)) != sizeof(GraphNeighborMeta))) { LOG_ERROR("Read neighbors meta from segment failed"); return {0, nullptr}; } const void *data; if (ailego_unlikely(neighbors_->read(m->offset, &data, m->neighbor_cnt * sizeof(node_id_t)) != m->neighbor_cnt * sizeof(node_id_t))) { LOG_ERROR("Read neighbors from segment failed"); return {0, nullptr}; } return {static_cast(m->neighbor_cnt), reinterpret_cast(data)}; } //! 
Read level > 0 neighbors const HnswNeighborMeta *m; if (ailego_unlikely(upper_neighbors_meta_->read(id * sizeof(HnswNeighborMeta), (const void **)(&m), sizeof(HnswNeighborMeta)) != sizeof(HnswNeighborMeta))) { LOG_ERROR("Read neighbors meta from segment failed"); return {0, nullptr}; } ailego_assert_with(level <= m->level, "invalid level"); size_t offset = m->offset + (level - 1) * upper_neighbors_size(); ailego_assert_with(offset <= upper_neighbors_->data_size(), "invalid offset"); const void *data; if (ailego_unlikely( upper_neighbors_->read(offset, &data, upper_neighbors_size()) != upper_neighbors_size())) { LOG_ERROR("Read neighbors from segment failed"); return {0, nullptr}; } auto hd = reinterpret_cast(data); return {hd->neighbor_cnt, hd->neighbors}; } int HnswRabitqSearcherEntity::load(const IndexStorage::Pointer &container, bool check_crc) { storage_ = container; int ret = load_segments(check_crc); if (ret != 0) { return ret; } loaded_ = true; LOG_INFO( "Index info: docCnt=%u entryPoint=%u maxLevel=%d efConstruct=%zu " "l0NeighborCnt=%zu upperNeighborCnt=%zu scalingFactor=%zu " "vectorSize=%zu nodeSize=%zu vectorSegmentSize=%zu keySegmentSize=%zu " "neighborsSegmentSize=%zu neighborsMetaSegmentSize=%zu ", doc_cnt(), entry_point(), cur_max_level(), ef_construction(), l0_neighbor_cnt(), upper_neighbor_cnt(), scaling_factor(), vector_size(), node_size(), vectors_->data_size(), keys_->data_size(), neighbors_ == nullptr ? 0 : neighbors_->data_size(), neighbors_meta_ == nullptr ? 0 : neighbors_meta_->data_size()); return 0; } int HnswRabitqSearcherEntity::load_segments(bool check_crc) { //! 
load header const void *data = nullptr; HNSWHeader hd; auto graph_hd_segment = storage_->get(kGraphHeaderSegmentId); if (!graph_hd_segment || graph_hd_segment->data_size() < sizeof(hd.graph)) { LOG_ERROR("Miss or invalid segment %s", kGraphHeaderSegmentId.c_str()); return IndexError_InvalidFormat; } if (graph_hd_segment->read(0, reinterpret_cast(&data), sizeof(hd.graph)) != sizeof(hd.graph)) { LOG_ERROR("Read segment %s failed", kGraphHeaderSegmentId.c_str()); return IndexError_ReadData; } memcpy(&hd.graph, data, sizeof(hd.graph)); auto hnsw_hd_segment = storage_->get(kHnswHeaderSegmentId); if (!hnsw_hd_segment || hnsw_hd_segment->data_size() < sizeof(hd.hnsw)) { LOG_ERROR("Miss or invalid segment %s", kHnswHeaderSegmentId.c_str()); return IndexError_InvalidFormat; } if (hnsw_hd_segment->read(0, reinterpret_cast(&data), sizeof(hd.hnsw)) != sizeof(hd.hnsw)) { LOG_ERROR("Read segment %s failed", kHnswHeaderSegmentId.c_str()); return IndexError_ReadData; } memcpy(&hd.hnsw, data, sizeof(hd.hnsw)); *mutable_header() = hd; segment_datas_.resize(std::max(l0_neighbor_cnt(), upper_neighbor_cnt())); vectors_ = storage_->get(kGraphFeaturesSegmentId); if (!vectors_) { LOG_ERROR("IndexStorage get segment %s failed", kGraphFeaturesSegmentId.c_str()); return IndexError_InvalidFormat; } keys_ = storage_->get(kGraphKeysSegmentId); if (!keys_) { LOG_ERROR("IndexStorage get segment %s failed", kGraphKeysSegmentId.c_str()); return IndexError_InvalidFormat; } neighbors_ = storage_->get(kGraphNeighborsSegmentId); if (!neighbors_ || (neighbors_->data_size() == 0 && doc_cnt() > 1)) { LOG_ERROR("IndexStorage get segment %s failed or empty", kGraphNeighborsSegmentId.c_str()); return IndexError_InvalidArgument; } neighbors_meta_ = storage_->get(kGraphOffsetsSegmentId); if (!neighbors_meta_ || neighbors_meta_->data_size() < sizeof(GraphNeighborMeta) * doc_cnt()) { LOG_ERROR("IndexStorage get segment %s failed or invalid size", kGraphOffsetsSegmentId.c_str()); return 
IndexError_InvalidArgument; } upper_neighbors_ = storage_->get(kHnswNeighborsSegmentId); if (!upper_neighbors_ || (upper_neighbors_->data_size() == 0 && cur_max_level() > 0)) { LOG_ERROR("IndexStorage get segment %s failed or empty", kHnswNeighborsSegmentId.c_str()); return IndexError_InvalidArgument; } upper_neighbors_meta_ = storage_->get(kHnswOffsetsSegmentId); if (!upper_neighbors_meta_ || upper_neighbors_meta_->data_size() < sizeof(HnswNeighborMeta) * doc_cnt()) { LOG_ERROR("IndexStorage get segment %s failed or invalid size", kHnswOffsetsSegmentId.c_str()); return IndexError_InvalidArgument; } mapping_ = storage_->get(kGraphMappingSegmentId); if (!mapping_ || mapping_->data_size() < sizeof(node_id_t) * doc_cnt()) { LOG_ERROR("IndexStorage get segment %s failed or invalid size", kGraphMappingSegmentId.c_str()); return IndexError_InvalidArgument; } if (check_crc) { std::vector segments; segments.emplace_back(graph_hd_segment); segments.emplace_back(hnsw_hd_segment); segments.emplace_back(vectors_); segments.emplace_back(keys_); segments.emplace_back(neighbors_); segments.emplace_back(neighbors_meta_); segments.emplace_back(upper_neighbors_); segments.emplace_back(upper_neighbors_meta_); if (!do_crc_check(segments)) { LOG_ERROR("Check index crc failed, the index may broken"); return IndexError_Runtime; } } if (neighbors_in_memory_enabled_) { int ret = load_and_flat_neighbors(); if (ret != 0) { return ret; } } return 0; } int HnswRabitqSearcherEntity::load_and_flat_neighbors() { fixed_neighbors_.reset( new (std::nothrow) char[neighbors_size() * doc_cnt()]{}, std::default_delete()); if (!fixed_neighbors_) { LOG_ERROR("Malloc memory failed"); return IndexError_NoMemory; } //! 
Get a new segemnt to release the buffer after loading neighbors auto neighbors_meta = storage_->get(kGraphOffsetsSegmentId); if (!neighbors_meta) { LOG_ERROR("IndexStorage get segment graph.offsets failed"); return IndexError_InvalidArgument; } const GraphNeighborMeta *neighbors_index = nullptr; if (neighbors_meta->read(0, reinterpret_cast(&neighbors_index), neighbors_meta->data_size()) != neighbors_meta->data_size()) { LOG_ERROR("Read segment %s data failed", kGraphOffsetsSegmentId.c_str()); return IndexError_InvalidArgument; } const char *neighbor_data; for (node_id_t id = 0; id < doc_cnt(); ++id) { size_t rd_size = neighbors_index[id].neighbor_cnt * sizeof(node_id_t); if (ailego_unlikely( neighbors_->read(neighbors_index[id].offset, reinterpret_cast(&neighbor_data), rd_size) != rd_size)) { LOG_ERROR("Read neighbors from segment failed"); return IndexError_ReadData; } // copy level 0 neighbors to fixed size neighbors memory char *dst = fixed_neighbors_.get() + neighbors_size() * id; *reinterpret_cast(dst) = neighbors_index[id].neighbor_cnt; memcpy(dst + sizeof(uint32_t), neighbor_data, rd_size); } return 0; } int HnswRabitqSearcherEntity::get_fixed_neighbors( std::vector *fixed_neighbors) const { //! 
Get a new segemnt to release the buffer after loading neighbors auto neighbors_meta = storage_->get(kGraphOffsetsSegmentId); if (!neighbors_meta) { LOG_ERROR("IndexStorage get segment graph.offsets failed"); return IndexError_InvalidArgument; } const GraphNeighborMeta *neighbors_index = nullptr; size_t meta_size = neighbors_meta->data_size(); if (neighbors_meta->read(0, reinterpret_cast(&neighbors_index), meta_size) != meta_size) { LOG_ERROR("Read segment %s data failed", kGraphOffsetsSegmentId.c_str()); return IndexError_InvalidArgument; } size_t fixed_neighbor_cnt = l0_neighbor_cnt(); fixed_neighbors->resize((fixed_neighbor_cnt + 1) * doc_cnt(), kInvalidNodeId); size_t neighbors_cnt_offset = fixed_neighbor_cnt * doc_cnt(); size_t total_neighbor_cnt = 0; for (node_id_t id = 0; id < doc_cnt(); ++id) { size_t cur_neighbor_cnt = neighbors_index[id].neighbor_cnt; if (cur_neighbor_cnt == 0) { (*fixed_neighbors)[neighbors_cnt_offset + id] = 0; continue; } size_t rd_size = cur_neighbor_cnt * sizeof(node_id_t); const uint32_t *neighbors; if (neighbors_->read(neighbors_index[id].offset, reinterpret_cast(&neighbors), rd_size) != rd_size) { LOG_ERROR("Read neighbors from segment failed"); return IndexError_ReadData; } // copy level 0 neighbors to fixed size neighbors memory auto it = fixed_neighbors->begin() + id * fixed_neighbor_cnt; std::copy(neighbors, neighbors + cur_neighbor_cnt, it); (*fixed_neighbors)[neighbors_cnt_offset + id] = cur_neighbor_cnt; total_neighbor_cnt += cur_neighbor_cnt; } LOG_INFO("total neighbor cnt: %zu, average neighbor cnt: %zu", total_neighbor_cnt, total_neighbor_cnt / doc_cnt()); return 0; } bool HnswRabitqSearcherEntity::do_crc_check( std::vector &segments) const { constexpr size_t blk_size = 4096; const void *data; for (auto &segment : segments) { size_t offset = 0; size_t rd_size; uint32_t crc = 0; while (offset < segment->data_size()) { size_t size = std::min(blk_size, segment->data_size() - offset); if ((rd_size = segment->read(offset, 
&data, size)) <= 0) { break; } offset += rd_size; crc = ailego::Crc32c::Hash(data, rd_size, crc); } if (crc != segment->data_crc()) { return false; } } return true; } const HnswRabitqEntity::Pointer HnswRabitqSearcherEntity::clone() const { auto vectors = vectors_->clone(); if (ailego_unlikely(!vectors)) { LOG_ERROR("clone segment %s failed", kGraphFeaturesSegmentId.c_str()); return HnswRabitqEntity::Pointer(); } auto keys = keys_->clone(); if (ailego_unlikely(!keys)) { LOG_ERROR("clone segment %s failed", kGraphKeysSegmentId.c_str()); return HnswRabitqEntity::Pointer(); } auto mapping = mapping_->clone(); if (ailego_unlikely(!mapping)) { LOG_ERROR("clone segment %s failed", kGraphMappingSegmentId.c_str()); return HnswRabitqEntity::Pointer(); } auto neighbors = neighbors_->clone(); if (ailego_unlikely(!neighbors)) { LOG_ERROR("clone segment %s failed", kGraphNeighborsSegmentId.c_str()); return HnswRabitqEntity::Pointer(); } auto upper_neighbors = upper_neighbors_->clone(); if (ailego_unlikely(!neighbors)) { LOG_ERROR("clone segment %s failed", kHnswNeighborsSegmentId.c_str()); return HnswRabitqEntity::Pointer(); } auto neighbors_meta = neighbors_meta_->clone(); if (ailego_unlikely(!neighbors_meta)) { LOG_ERROR("clone segment %s failed", kGraphOffsetsSegmentId.c_str()); return HnswRabitqEntity::Pointer(); } auto upper_neighbors_meta = upper_neighbors_meta_->clone(); if (ailego_unlikely(!upper_neighbors_meta)) { LOG_ERROR("clone segment %s failed", kHnswOffsetsSegmentId.c_str()); return HnswRabitqEntity::Pointer(); } SegmentGroupParam neighbor_group{neighbors, neighbors_meta, upper_neighbors, upper_neighbors_meta}; HnswRabitqSearcherEntity *entity = new (std::nothrow) HnswRabitqSearcherEntity(header(), vectors, keys, mapping, neighbor_group, fixed_neighbors_, neighbors_in_memory_enabled_); if (ailego_unlikely(!entity)) { LOG_ERROR("HnswRabitqSearcherEntity new failed"); } return HnswRabitqEntity::Pointer(entity); } } // namespace core } // namespace zvec 
================================================
FILE: src/core/algorithm/hnsw_rabitq/hnsw_rabitq_searcher_entity.h
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include "hnsw_rabitq_builder_entity.h"
#include "hnsw_rabitq_entity.h"

namespace zvec {
namespace core {

//! Read-side HNSW-RaBitQ entity: serves keys, vectors and per-level neighbor
//! lists out of the segments of a loaded IndexStorage.
//!
//! NOTE(review): template argument lists in this copy (e.g.
//! `std::shared_ptr;`, `std::vector &vec_blocks`) look like extraction
//! damage; restore them from version control before building.
class HnswRabitqSearcherEntity : public HnswRabitqEntity {
 public:
  using Pointer = std::shared_ptr;
  using SegmentPointer = IndexStorage::Segment::Pointer;

 public:
  //! Bundles the four neighbor-related segments (level-0 neighbors and
  //! offsets, upper-level neighbors and offsets) into one constructor
  //! argument.
  struct SegmentGroupParam {
    SegmentGroupParam(SegmentPointer neighbors_in,
                      SegmentPointer neighbors_meta_in,
                      SegmentPointer upper_neighbors_in,
                      SegmentPointer upper_neighbors_meta_in)
        : neighbors{neighbors_in},
          neighbors_meta{neighbors_meta_in},
          upper_neighbors{upper_neighbors_in},
          upper_neighbors_meta{upper_neighbors_meta_in} {}

    SegmentPointer neighbors{nullptr};
    SegmentPointer neighbors_meta{nullptr};
    SegmentPointer upper_neighbors{nullptr};
    SegmentPointer upper_neighbors_meta{nullptr};
  };

  //! Constructor
  HnswRabitqSearcherEntity();

  //! Make a copy of searcher entity, to support thread-safe operation.
  //! The segment in container cannot be read concurrently
  virtual const HnswRabitqEntity::Pointer clone() const override;

  //! Get primary key of the node id
  virtual key_t get_key(node_id_t id) const override;

  //! Get vector local id by key
  node_id_t get_id(key_t key) const;

  //! Get vector feature data by key
  virtual const void *get_vector_by_key(key_t key) const override;

  //! Get vector feature data by id
  virtual const void *get_vector(node_id_t id) const override;

  //! Get vector feature data by id (batch; one entry of `vecs` per id)
  virtual int get_vector(const node_id_t *ids, uint32_t count,
                         const void **vecs) const override;

  //! Get vector feature data of one id, wrapped in a MemoryBlock
  virtual int get_vector(const node_id_t id,
                         IndexStorage::MemoryBlock &block) const override;

  //! Get vector feature data of `count` ids, appended as MemoryBlocks
  virtual int get_vector(
      const node_id_t *ids, uint32_t count,
      std::vector &vec_blocks) const override;

  //! Get all vectors
  const void *get_vectors() const;

  //! Get the node id's neighbors on graph level
  virtual const Neighbors get_neighbors(level_t level,
                                        node_id_t id) const override;

  //! Attach to `container` and load all segments, optionally checking CRCs
  virtual int load(const IndexStorage::Pointer &container,
                   bool check_crc) override;

  //! Fetch and validate every segment from the attached storage
  int load_segments(bool check_crc);

  //! Release loaded segments and reset the load state
  virtual int cleanup(void) override;

 public:
  bool is_loaded() const {
    return loaded_;
  }

  //! Enable flattening level-0 neighbors into memory during load
  void set_neighbors_in_memory(bool enabled) {
    neighbors_in_memory_enabled_ = enabled;
  }

  //! get fixed length neighbors data
  int get_fixed_neighbors(std::vector *fixed_neighbors) const;

 private:
  //! Constructor used by clone(): adopts already-cloned segments and shares
  //! the flat level-0 neighbor buffer.
  HnswRabitqSearcherEntity(const HNSWHeader &hd, const SegmentPointer &vectors,
                           const SegmentPointer &keys,
                           const SegmentPointer &mapping,
                           const SegmentGroupParam &neighbor_group,
                           const std::shared_ptr &fixed_neighbors,
                           bool neighbors_in_memory_enabled)
      : HnswRabitqEntity(hd),
        vectors_(vectors),
        keys_(keys),
        mapping_(mapping),
        neighbors_(neighbor_group.neighbors),
        neighbors_meta_(neighbor_group.neighbors_meta),
        upper_neighbors_(neighbor_group.upper_neighbors),
        upper_neighbors_meta_(neighbor_group.upper_neighbors_meta),
        neighbors_in_memory_enabled_(neighbors_in_memory_enabled) {
    // Scratch descriptors for batched vector reads, one per possible neighbor.
    segment_datas_.resize(std::max(l0_neighbor_cnt(), upper_neighbor_cnt()),
                          IndexStorage::SegmentData(0U, 0U));

    fixed_neighbors_ = fixed_neighbors;
  }

  //! Verify the stored CRC of every segment in `segments`
  bool do_crc_check(std::vector &segments) const;

  //! Byte size of one level-0 neighbor record: header plus
  //! l0_neighbor_cnt() ids.
  inline size_t neighbors_size() const {
    return sizeof(NeighborsHeader) + l0_neighbor_cnt() * sizeof(node_id_t);
  }

  //! Byte size of one upper-level neighbor record: header plus
  //! upper_neighbor_cnt() ids.
  inline size_t upper_neighbors_size() const {
    return sizeof(NeighborsHeader) + upper_neighbor_cnt() * sizeof(node_id_t);
  }

  //! If neighbors_in_memory_enabled, load the level0 neighbors to memory
  int load_and_flat_neighbors(void);

 public:
  HnswRabitqSearcherEntity(const HnswRabitqSearcherEntity &) = delete;
  HnswRabitqSearcherEntity &operator=(const HnswRabitqSearcherEntity &) =
      delete;

 private:
  IndexStorage::Pointer storage_{};

  SegmentPointer vectors_{};
  SegmentPointer keys_{};
  SegmentPointer mapping_{};

  SegmentPointer neighbors_{};
  SegmentPointer neighbors_meta_{};
  SegmentPointer upper_neighbors_{};
  SegmentPointer upper_neighbors_meta_{};

  // mutable: reused as scratch by the const batch get_vector().
  mutable std::vector segment_datas_{};
  std::shared_ptr fixed_neighbors_{};  // level 0 fixed size neighbors
  bool neighbors_in_memory_enabled_{false};
  bool loaded_{false};
};

}  // namespace core
}  // namespace zvec

================================================
FILE: src/core/algorithm/hnsw_rabitq/hnsw_rabitq_streamer.cc
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "hnsw_rabitq_streamer.h"
// NOTE(review): six angle-bracket includes originally followed here but were
// lost in extraction. The standard headers below cover what this file
// visibly uses (std::max, std::make_shared, std::move, fixed-width ints);
// restore any further headers (e.g. the one providing AILEGO_DEFER) from
// version control.
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <utility>
#include "algorithm/hnsw_rabitq/rabitq_reformer.h"
#include "zvec/ailego/container/params.h"
#include "zvec/ailego/logger/logger.h"
#include "hnsw_rabitq_algorithm.h"
#include "hnsw_rabitq_context.h"
#include "hnsw_rabitq_dist_calculator.h"
#include "hnsw_rabitq_index_provider.h"
#include "hnsw_rabitq_query_entity.h"
#include "rabitq_params.h"
#include "rabitq_utils.h"

namespace zvec {
namespace core {

HnswRabitqStreamer::HnswRabitqStreamer() : entity_(stats_) {}

//! Construct with an externally supplied provider and reformer; open() then
//! dumps that reformer into the storage instead of loading one from it.
HnswRabitqStreamer::HnswRabitqStreamer(IndexProvider::Pointer provider,
                                       RabitqReformer::Pointer reformer)
    : entity_(stats_),
      reformer_(std::move(reformer)),
      provider_(std::move(provider)) {}

HnswRabitqStreamer::~HnswRabitqStreamer() {
  if (state_ == STATE_INITED) {
    this->cleanup();
  }
}

//! Read streamer parameters from `params`, substitute defaults for unset
//! (zero) values, validate ranges, then configure entity_ and create alg_.
//! Returns IndexError_InvalidArgument on any invalid parameter.
int HnswRabitqStreamer::init(const IndexMeta &imeta,
                             const ailego::Params &params) {
  meta_ = imeta;
  meta_.set_streamer("HnswRabitqStreamer", HnswRabitqEntity::kRevision, params);

  params.get(PARAM_HNSW_RABITQ_STREAMER_MAX_INDEX_SIZE, &max_index_size_);
  params.get(PARAM_HNSW_RABITQ_STREAMER_MAX_NEIGHBOR_COUNT,
             &upper_max_neighbor_cnt_);
  float multiplier = HnswRabitqEntity::kDefaultL0MaxNeighborCntMultiplier;
  params.get(PARAM_HNSW_RABITQ_STREAMER_L0_MAX_NEIGHBOR_COUNT_MULTIPLIER,
             &multiplier);
  l0_max_neighbor_cnt_ = multiplier * upper_max_neighbor_cnt_;

  multiplier = HnswRabitqEntity::kDefaultNeighborPruneMultiplier;
  params.get(PARAM_HNSW_RABITQ_STREAMER_NEIGHBOR_PRUNE_MULTIPLIER, &multiplier);
  size_t prune_cnt = multiplier * upper_max_neighbor_cnt_;
  scaling_factor_ = upper_max_neighbor_cnt_;
  params.get(PARAM_HNSW_RABITQ_STREAMER_SCALING_FACTOR, &scaling_factor_);

  params.get(PARAM_HNSW_RABITQ_STREAMER_DOCS_HARD_LIMIT, &docs_hard_limit_);
  params.get(PARAM_HNSW_RABITQ_STREAMER_EF, &ef_);
  params.get(PARAM_HNSW_RABITQ_STREAMER_EFCONSTRUCTION, &ef_construction_);
  params.get(PARAM_HNSW_RABITQ_STREAMER_VISIT_BLOOMFILTER_ENABLE, &bf_enabled_);
  params.get(PARAM_HNSW_RABITQ_STREAMER_VISIT_BLOOMFILTER_NEGATIVE_PROB,
             &bf_negative_prob_);
  params.get(PARAM_HNSW_RABITQ_STREAMER_BRUTE_FORCE_THRESHOLD,
             &bruteforce_threshold_);
  params.get(PARAM_HNSW_RABITQ_STREAMER_MAX_SCAN_RATIO, &max_scan_ratio_);
  params.get(PARAM_HNSW_RABITQ_STREAMER_MAX_SCAN_LIMIT, &max_scan_limit_);
  params.get(PARAM_HNSW_RABITQ_STREAMER_MIN_SCAN_LIMIT, &min_scan_limit_);
  params.get(PARAM_HNSW_RABITQ_STREAMER_CHECK_CRC_ENABLE, &check_crc_enabled_);
  params.get(PARAM_HNSW_RABITQ_STREAMER_CHUNK_SIZE, &chunk_size_);
  params.get(PARAM_HNSW_RABITQ_STREAMER_FILTER_SAME_KEY, &filter_same_key_);
  params.get(PARAM_HNSW_RABITQ_STREAMER_GET_VECTOR_ENABLE,
             &get_vector_enabled_);
  params.get(PARAM_HNSW_RABITQ_STREAMER_MIN_NEIGHBOR_COUNT, &min_neighbor_cnt_);
  params.get(PARAM_HNSW_RABITQ_STREAMER_FORCE_PADDING_RESULT_ENABLE,
             &force_padding_topk_enabled_);
  params.get(PARAM_HNSW_RABITQ_STREAMER_USE_ID_MAP, &use_id_map_);
  entity_.set_use_key_info_map(use_id_map_);

  params.get(PARAM_HNSW_RABITQ_STREAMER_DOCS_SOFT_LIMIT, &docs_soft_limit_);
  if (docs_soft_limit_ > 0 && docs_soft_limit_ > docs_hard_limit_) {
    LOG_ERROR("[%s] must be >= [%s]",
              PARAM_HNSW_RABITQ_STREAMER_DOCS_HARD_LIMIT.c_str(),
              PARAM_HNSW_RABITQ_STREAMER_DOCS_SOFT_LIMIT.c_str());
    return IndexError_InvalidArgument;
  } else if (docs_soft_limit_ == 0UL) {
    docs_soft_limit_ =
        docs_hard_limit_ * HnswRabitqEntity::kDefaultDocsSoftLimitRatio;
  }

  if (ef_ == 0U) {
    ef_ = HnswRabitqEntity::kDefaultEf;
  }
  if (ef_construction_ == 0U) {
    ef_construction_ = HnswRabitqEntity::kDefaultEfConstruction;
  }
  if (upper_max_neighbor_cnt_ == 0U) {
    upper_max_neighbor_cnt_ = HnswRabitqEntity::kDefaultUpperMaxNeighborCnt;
  }
  // The check allows kMaxNeighborCnt itself, so the message says "(0,%d]".
  if (upper_max_neighbor_cnt_ > HnswRabitqEntity::kMaxNeighborCnt) {
    LOG_ERROR("[%s] must be in range (0,%d]",
              PARAM_HNSW_RABITQ_STREAMER_MAX_NEIGHBOR_COUNT.c_str(),
              HnswRabitqEntity::kMaxNeighborCnt);
    return IndexError_InvalidArgument;
  }
  if (l0_max_neighbor_cnt_ == 0U) {
    l0_max_neighbor_cnt_ = HnswRabitqEntity::kDefaultL0MaxNeighborCnt;
  }
  if (l0_max_neighbor_cnt_ > HnswRabitqEntity::kMaxNeighborCnt) {
    LOG_ERROR("MaxL0NeighborCnt must be in range (0,%d]",
              HnswRabitqEntity::kMaxNeighborCnt);
    return IndexError_InvalidArgument;
  }
  if (min_neighbor_cnt_ > upper_max_neighbor_cnt_) {
    LOG_ERROR("[%s]-[%zu] must be <= [%s]-[%zu]",
              PARAM_HNSW_RABITQ_STREAMER_MIN_NEIGHBOR_COUNT.c_str(),
              static_cast<size_t>(min_neighbor_cnt_),
              PARAM_HNSW_RABITQ_STREAMER_MAX_NEIGHBOR_COUNT.c_str(),
              static_cast<size_t>(upper_max_neighbor_cnt_));
    return IndexError_InvalidArgument;
  }

  if (bf_negative_prob_ <= 0.0f || bf_negative_prob_ >= 1.0f) {
    LOG_ERROR(
        "[%s] must be in range (0,1)",
        PARAM_HNSW_RABITQ_STREAMER_VISIT_BLOOMFILTER_NEGATIVE_PROB.c_str());
    return IndexError_InvalidArgument;
  }

  if (scaling_factor_ == 0U) {
    scaling_factor_ = HnswRabitqEntity::kDefaultScalingFactor;
  }
  if (scaling_factor_ < 5 || scaling_factor_ > 1000) {
    LOG_ERROR("[%s] must be in range [5,1000]",
              PARAM_HNSW_RABITQ_STREAMER_SCALING_FACTOR.c_str());
    return IndexError_InvalidArgument;
  }

  if (max_scan_ratio_ <= 0.0f || max_scan_ratio_ > 1.0f) {
    LOG_ERROR("[%s] must be in range (0.0f,1.0f]",
              PARAM_HNSW_RABITQ_STREAMER_MAX_SCAN_RATIO.c_str());
    return IndexError_InvalidArgument;
  }

  if (max_scan_limit_ < min_scan_limit_) {
    LOG_ERROR("[%s] must be >= [%s]",
              PARAM_HNSW_RABITQ_STREAMER_MAX_SCAN_LIMIT.c_str(),
              PARAM_HNSW_RABITQ_STREAMER_MIN_SCAN_LIMIT.c_str());
    return IndexError_InvalidArgument;
  }

  if (prune_cnt == 0UL) {
    prune_cnt = upper_max_neighbor_cnt_;
  }
  if (chunk_size_ == 0UL) {
    chunk_size_ = HnswRabitqEntity::kDefaultChunkSize;
  }
  // The check allows kMaxChunkSize itself, so the message says "<=".
  if (chunk_size_ > HnswRabitqEntity::kMaxChunkSize) {
    LOG_ERROR("[%s] must be <= %zu",
              PARAM_HNSW_RABITQ_STREAMER_CHUNK_SIZE.c_str(),
              HnswRabitqEntity::kMaxChunkSize);
    return IndexError_InvalidArgument;
  }

  uint32_t total_bits = 0;
  params.get(PARAM_RABITQ_TOTAL_BITS, &total_bits);
  if (total_bits == 0) {
    total_bits = kDefaultRabitqTotalBits;
  }
  if (total_bits < 1 || total_bits > 9) {
    LOG_ERROR("Invalid total_bits: %zu, must be in [1, 9]", (size_t)total_bits);
    return IndexError_InvalidArgument;
  }
  uint8_t ex_bits = total_bits - 1;
  entity_.set_ex_bits(ex_bits);

  uint32_t dimension = 0;
  params.get(PARAM_HNSW_RABITQ_GENERAL_DIMENSION, &dimension);
  if (dimension == 0) {
    LOG_ERROR("%s not set", PARAM_HNSW_RABITQ_GENERAL_DIMENSION.c_str());
    return IndexError_InvalidArgument;
  }
  if (dimension < kMinRabitqDimSize || dimension > kMaxRabitqDimSize) {
    LOG_ERROR("Invalid dimension: %u, must be in [%d, %d]", dimension,
              kMinRabitqDimSize, kMaxRabitqDimSize);
    return IndexError_InvalidArgument;
  }
  entity_.update_rabitq_params_and_vector_size(dimension);

  entity_.set_ef_construction(ef_construction_);
  entity_.set_upper_neighbor_cnt(upper_max_neighbor_cnt_);
  entity_.set_l0_neighbor_cnt(l0_max_neighbor_cnt_);
  entity_.set_scaling_factor(scaling_factor_);
  entity_.set_prune_cnt(prune_cnt);
  entity_.set_chunk_size(chunk_size_);
  entity_.set_filter_same_key(filter_same_key_);
  entity_.set_get_vector(get_vector_enabled_);
  entity_.set_min_neighbor_cnt(min_neighbor_cnt_);

  int ret = entity_.init(docs_hard_limit_);
  if (ret != 0) {
    LOG_ERROR("Hnsw entity init failed for %s", IndexError::What(ret));
    return ret;
  }

  LOG_DEBUG(
      "Init params: maxIndexSize=%zu docsHardLimit=%zu docsSoftLimit=%zu "
      "efConstruction=%u ef=%u upperMaxNeighborCnt=%u l0MaxNeighborCnt=%u "
      "scalingFactor=%u maxScanRatio=%.3f minScanLimit=%zu maxScanLimit=%zu "
      "bfEnabled=%d bruteForceThreshold=%zu bfNegativeProbability=%.5f "
      "checkCrcEnabled=%d pruneSize=%zu vectorSize=%u chunkSize=%zu "
      "filterSameKey=%u getVectorEnabled=%u minNeighborCount=%u "
      "forcePadding=%u ",
      max_index_size_, docs_hard_limit_, docs_soft_limit_, ef_construction_,
      ef_, upper_max_neighbor_cnt_, l0_max_neighbor_cnt_, scaling_factor_,
      max_scan_ratio_, min_scan_limit_, max_scan_limit_, bf_enabled_,
      bruteforce_threshold_, bf_negative_prob_, check_crc_enabled_, prune_cnt,
      meta_.element_size(), chunk_size_, filter_same_key_, get_vector_enabled_,
      min_neighbor_cnt_, force_padding_topk_enabled_);

  alg_ = HnswRabitqAlgorithm::UPointer(new HnswRabitqAlgorithm(entity_));

  ret = alg_->init();
  if (ret != 0) {
    return ret;
  }

  state_ = STATE_INITED;

  return 0;
}

//! Close if opened, then reset metric, stats, entity, algorithm and all
//! tunables back to their defaults and return to STATE_INIT.
int HnswRabitqStreamer::cleanup(void) {
  if (state_ == STATE_OPENED) {
    this->close();
  }

  LOG_INFO("HnswRabitqStreamer cleanup");

  meta_.clear();
  metric_.reset();
  stats_.clear();
  entity_.cleanup();

  if (alg_) {
    alg_->cleanup();
  }

  max_index_size_ = 0UL;
  docs_hard_limit_ = HnswRabitqEntity::kDefaultDocsHardLimit;
  docs_soft_limit_ = 0UL;
  upper_max_neighbor_cnt_ = HnswRabitqEntity::kDefaultUpperMaxNeighborCnt;
  l0_max_neighbor_cnt_ = HnswRabitqEntity::kDefaultL0MaxNeighborCnt;
  ef_ = HnswRabitqEntity::kDefaultEf;
  ef_construction_ = HnswRabitqEntity::kDefaultEfConstruction;
  bf_enabled_ = false;
  scaling_factor_ = HnswRabitqEntity::kDefaultScalingFactor;
  bruteforce_threshold_ = HnswRabitqEntity::kDefaultBruteForceThreshold;
  max_scan_limit_ = HnswRabitqEntity::kDefaultMaxScanLimit;
  min_scan_limit_ = HnswRabitqEntity::kDefaultMinScanLimit;
  chunk_size_ = HnswRabitqEntity::kDefaultChunkSize;
  bf_negative_prob_ = HnswRabitqEntity::kDefaultBFNegativeProbability;
  max_scan_ratio_ = HnswRabitqEntity::kDefaultScanRatio;
  state_ = STATE_INIT;
  check_crc_enabled_ = false;
  filter_same_key_ = false;
  get_vector_enabled_ = false;

  return 0;
}

//! Open `stg` for streaming.
//! Steps: load the RaBitQ reformer from storage (or dump the supplied one
//! into it), open the entity, reconcile the stored IndexMeta with meta_,
//! create and initialize the metric, then build the query algorithm.
int HnswRabitqStreamer::open(IndexStorage::Pointer stg) {
  LOG_INFO("HnswRabitqStreamer open");

  if (ailego_unlikely(state_ != STATE_INITED)) {
    LOG_ERROR("Open storage failed, init streamer first!");
    return IndexError_NoReady;
  }

  // try to load reformer
  if (reformer_ == nullptr) {
    // Build the reformer locally and publish it only once it has loaded, so a
    // failed open() does not leave a half-initialized reformer in reformer_
    // for a retried open() to pick up.
    RabitqReformer::Pointer reformer = std::make_shared<RabitqReformer>();
    ailego::Params reformer_params;
    reformer_params.set(PARAM_RABITQ_METRIC_NAME, meta_.metric_name());
    int ret = reformer->init(reformer_params);
    if (ret != 0) {
      LOG_ERROR("Failed to initialize RabitqReformer: %d", ret);
      return ret;
    }
    ret = reformer->load(stg);
    if (ret != 0) {
      LOG_ERROR("Failed to load reformer, ret=%d", ret);
      return ret;
    }
    reformer_ = std::move(reformer);
  } else {
    if (!stg->has(RABITQ_CONVERTER_SEG_ID)) {
      int ret = reformer_->dump(stg);
      if (ret != 0) {
        LOG_ERROR("Failed to dump reformer, ret=%d", ret);
        return ret;
      }
      LOG_INFO("Dump reformer success.");
    }
  }

  int ret = entity_.open(std::move(stg), max_index_size_, check_crc_enabled_);
  if (ret != 0) {
    return ret;
  }
  IndexMeta index_meta;
  ret = entity_.get_index_meta(&index_meta);
  if (ret == IndexError_NoExist) {
    // Set IndexMeta for the new index
    ret = entity_.set_index_meta(meta_);
    if (ret != 0) {
      LOG_ERROR("Failed to set index meta for %s", IndexError::What(ret));
      return ret;
    }
  } else if (ret != 0) {
    LOG_ERROR("Failed to get index meta for %s", IndexError::What(ret));
    return ret;
  } else {
    if (index_meta.dimension() != meta_.dimension() ||
        index_meta.element_size() != meta_.element_size() ||
        index_meta.metric_name() != meta_.metric_name() ||
        index_meta.data_type() != meta_.data_type()) {
      LOG_ERROR("IndexMeta mismatch from the previous in index");
      return IndexError_Mismatch;
    }
    // The IndexMetric Params may be updated like MipsSquaredEuclidean
    auto metric_params = index_meta.metric_params();
    metric_params.merge(meta_.metric_params());
    meta_.set_metric(index_meta.metric_name(), 0, metric_params);
  }

  metric_ = IndexFactory::CreateMetric(meta_.metric_name());
  if (!metric_) {
    LOG_ERROR("Failed to create metric %s", meta_.metric_name().c_str());
    return IndexError_NoExist;
  }
  ret = metric_->init(meta_, meta_.metric_params());
  if (ret != 0) {
    LOG_ERROR("Failed to init metric, ret=%d", ret);
    return ret;
  }

  if (!metric_->distance()) {
    LOG_ERROR("Invalid metric distance");
    return IndexError_InvalidArgument;
  }

  if (!metric_->batch_distance()) {
    LOG_ERROR("Invalid metric batch distance");
    return IndexError_InvalidArgument;
  }

  add_distance_ = metric_->distance();
  add_batch_distance_ = metric_->batch_distance();

  search_distance_ = add_distance_;
  search_batch_distance_ = add_batch_distance_;

  // Prefer the query-side metric for searching when it provides both kernels.
  if (metric_->query_metric() && metric_->query_metric()->distance() &&
      metric_->query_metric()->batch_distance()) {
    search_distance_ = metric_->query_metric()->distance();
    search_batch_distance_ = metric_->query_metric()->batch_distance();
  }

  state_ = STATE_OPENED;
  magic_ = IndexContext::GenerateMagic();
  query_alg_ = HnswRabitqQueryAlgorithm::UPointer(new HnswRabitqQueryAlgorithm(
      entity_, reformer_->num_clusters(), reformer_->rabitq_metric_type()));

  return 0;
}

//! Persist the current meta into the entity, close it, and return to
//! STATE_INITED.
int HnswRabitqStreamer::close(void) {
  LOG_INFO("HnswRabitqStreamer close");

  stats_.clear();
  // metric_ is assigned in open() and may still be null if close() is reached
  // without a successful open(); avoid dereferencing it in that case.
  if (metric_) {
    meta_.set_metric(metric_->name(), 0, metric_->params());
  }
  entity_.set_index_meta(meta_);
  int ret = entity_.close();
  if (ret != 0) {
    return ret;
  }
  state_ = STATE_INITED;

  return 0;
}

//! Persist the current meta into the entity and flush it at `checkpoint`.
int HnswRabitqStreamer::flush(uint64_t checkpoint) {
  LOG_INFO("HnswRabitqStreamer flush checkpoint=%zu", (size_t)checkpoint);

  // Same guard as close(): metric_ only exists after a successful open().
  if (metric_) {
    meta_.set_metric(metric_->name(), 0, metric_->params());
  }
  entity_.set_index_meta(meta_);
  return entity_.flush(checkpoint);
}

//! Dump the index in searcher format under the exclusive lock: meta (tagged
//! for HnswRabitqSearcher), then the reformer, then the entity data.
int HnswRabitqStreamer::dump(const IndexDumper::Pointer &dumper) {
  LOG_INFO("HnswRabitqStreamer dump");

  shared_mutex_.lock();
  AILEGO_DEFER([&]() { shared_mutex_.unlock(); });

  meta_.set_searcher("HnswRabitqSearcher", HnswRabitqEntity::kRevision,
                     ailego::Params());

  int ret = IndexHelper::SerializeToDumper(meta_, dumper.get());
  if (ret != 0) {
    LOG_ERROR("Failed to serialize meta into dumper.");
    return ret;
  }

  ret = reformer_->dump(dumper);
  if (ret != 0) {
    LOG_ERROR("Failed to dump reformer into dumper.");
    return ret;
  }

  return entity_.dump(dumper);
}

//! Create a search context bound to a cloned entity and configured with the
//! streamer's scan / filter settings. Returns an empty pointer if the
//! streamer is not opened or setup fails.
IndexStreamer::Context::Pointer HnswRabitqStreamer::create_context(void) const {
  if (ailego_unlikely(state_ != STATE_OPENED)) {
    LOG_ERROR("Create context failed, open storage first!");
    return Context::Pointer();
  }

  HnswRabitqEntity::Pointer entity = entity_.clone();
  if (ailego_unlikely(!entity)) {
    LOG_ERROR("CreateContext clone init failed");
    return Context::Pointer();
  }
  HnswRabitqContext *ctx =
      new (std::nothrow) HnswRabitqContext(meta_.dimension(), metric_, entity);
  if (ailego_unlikely(ctx == nullptr)) {
    LOG_ERROR("Failed to new HnswRabitqContext");
    return Context::Pointer();
  }
  ctx->set_ef(ef_);
  ctx->set_max_scan_limit(max_scan_limit_);
  ctx->set_min_scan_limit(min_scan_limit_);
  ctx->set_max_scan_ratio(max_scan_ratio_);
  ctx->set_filter_mode(bf_enabled_ ? VisitFilter::BloomFilter
                                   : VisitFilter::ByteMap);
  ctx->set_filter_negative_probability(bf_negative_prob_);
  ctx->set_magic(magic_);
  ctx->set_force_padding_topk(force_padding_topk_enabled_);
  ctx->set_bruteforce_threshold(bruteforce_threshold_);
  // The branch hint wraps the whole condition, not just the init() call.
  if (ailego_unlikely(ctx->init(HnswRabitqContext::kStreamerContext) != 0)) {
    LOG_ERROR("Init HnswRabitqContext failed");
    delete ctx;
    return Context::Pointer();
  }
  uint32_t estimate_doc_count = 0;
  if (meta_.streamer_params().get(PARAM_HNSW_RABITQ_STREAMER_ESTIMATE_DOC_COUNT,
                                  &estimate_doc_count)) {
    LOG_DEBUG("HnswRabitqStreamer doc_count[%zu] estimate[%zu]",
              (size_t)entity_.doc_cnt(), (size_t)estimate_doc_count);
  }
  ctx->check_need_adjuct_ctx(
      std::max<uint32_t>(entity_.doc_cnt(), estimate_doc_count));

  return Context::Pointer(ctx);
}

//! Create an IndexProvider over a cloned entity; nullptr if cloning fails.
IndexProvider::Pointer HnswRabitqStreamer::create_provider(void) const {
  LOG_DEBUG("HnswRabitqStreamer create provider");

  auto entity = entity_.clone();
  if (ailego_unlikely(!entity)) {
    LOG_ERROR("Clone HnswRabitqEntity failed");
    return nullptr;
  }
  return Provider::Pointer(
      new HnswRabitqIndexProvider(meta_, entity, "HnswRabitqStreamer"));
}

//! Rebind an existing context (possibly created by another streamer or
//! searcher) to a fresh clone of this streamer's entity and settings.
int HnswRabitqStreamer::update_context(HnswRabitqContext *ctx) const {
  const HnswRabitqEntity::Pointer entity = entity_.clone();
  if (!entity) {
    LOG_ERROR("Failed to clone search context entity");
    return IndexError_Runtime;
  }
  ctx->set_max_scan_limit(max_scan_limit_);
  ctx->set_min_scan_limit(min_scan_limit_);
  ctx->set_max_scan_ratio(max_scan_ratio_);
  ctx->set_bruteforce_threshold(bruteforce_threshold_);
  return ctx->update_context(HnswRabitqContext::kStreamerContext, meta_,
                             metric_, entity, magic_);
}

//!
Add a vector with id into index int HnswRabitqStreamer::add_with_id_impl( uint32_t id, const void *query, const IndexQueryMeta &qmeta, IndexStreamer::Context::Pointer &context) { if (!provider_) { LOG_ERROR("Provider is nullptr, cannot add vector"); return IndexError_InvalidArgument; } int ret = check_params(query, qmeta); if (ailego_unlikely(ret != 0)) { return ret; } HnswRabitqContext *ctx = dynamic_cast(context.get()); ailego_do_if_false(ctx) { LOG_ERROR("Cast context to HnswRabitqContext failed"); return IndexError_Cast; } if (ctx->magic() != magic_) { //! context is created by another searcher or streamer ret = update_context(ctx); if (ret != 0) { return ret; } } if (ailego_unlikely(entity_.doc_cnt() >= docs_soft_limit_)) { if (entity_.doc_cnt() >= docs_hard_limit_) { LOG_ERROR("Current docs %zu exceed [%s]", static_cast(entity_.doc_cnt()), PARAM_HNSW_RABITQ_STREAMER_DOCS_HARD_LIMIT.c_str()); const std::lock_guard lk(mutex_); (*stats_.mutable_discarded_count())++; return IndexError_IndexFull; } else { LOG_WARN("Current docs %zu exceed [%s]", static_cast(entity_.doc_cnt()), PARAM_HNSW_RABITQ_STREAMER_DOCS_SOFT_LIMIT.c_str()); } } if (ailego_unlikely(!shared_mutex_.try_lock_shared())) { LOG_ERROR("Cannot add vector while dumping index"); (*stats_.mutable_discarded_count())++; return IndexError_Unsupported; } AILEGO_DEFER([&]() { shared_mutex_.unlock_shared(); }); ctx->clear(); ctx->update_dist_caculator_distance(add_distance_, add_batch_distance_); ctx->reset_query(query); ctx->check_need_adjuct_ctx(entity_.doc_cnt()); ctx->set_provider(provider_); if (metric_->support_train()) { const std::lock_guard lk(mutex_); ret = metric_->train(query, meta_.dimension()); if (ailego_unlikely(ret != 0)) { LOG_ERROR("Hnsw streamer metric train failed"); (*stats_.mutable_discarded_count())++; return ret; } } std::string converted_vector; IndexQueryMeta converted_meta; ret = reformer_->convert(query, qmeta, &converted_vector, &converted_meta); if (ret != 0) { LOG_ERROR("Rabitq 
hnsw convert failed, ret=%d", ret); return ret; } level_t level = alg_->get_random_level(); ret = entity_.add_vector_with_id(level, id, converted_vector.data()); if (ailego_unlikely(ret != 0)) { LOG_ERROR("Hnsw streamer add vector failed"); (*stats_.mutable_discarded_count())++; return ret; } ret = alg_->add_node(id, level, ctx); if (ailego_unlikely(ret != 0)) { LOG_ERROR("Hnsw steamer add node failed"); (*stats_.mutable_discarded_count())++; return ret; } if (ailego_unlikely(ctx->error())) { (*stats_.mutable_discarded_count())++; return IndexError_Runtime; } (*stats_.mutable_added_count())++; return 0; } //! Add a vector into index int HnswRabitqStreamer::add_impl(uint64_t pkey, const void *query, const IndexQueryMeta &qmeta, IndexStreamer::Context::Pointer &context) { if (!provider_) { LOG_ERROR("Provider is nullptr, cannot add vector"); return IndexError_InvalidArgument; } int ret = check_params(query, qmeta); if (ailego_unlikely(ret != 0)) { return ret; } HnswRabitqContext *ctx = dynamic_cast(context.get()); ailego_do_if_false(ctx) { LOG_ERROR("Cast context to HnswRabitqContext failed"); return IndexError_Cast; } if (ctx->magic() != magic_) { //! 
context is created by another searcher or streamer ret = update_context(ctx); if (ret != 0) { return ret; } } if (ailego_unlikely(entity_.doc_cnt() >= docs_soft_limit_)) { if (entity_.doc_cnt() >= docs_hard_limit_) { LOG_ERROR("Current docs %zu exceed [%s]", static_cast(entity_.doc_cnt()), PARAM_HNSW_RABITQ_STREAMER_DOCS_HARD_LIMIT.c_str()); const std::lock_guard lk(mutex_); (*stats_.mutable_discarded_count())++; return IndexError_IndexFull; } else { LOG_WARN("Current docs %zu exceed [%s]", static_cast(entity_.doc_cnt()), PARAM_HNSW_RABITQ_STREAMER_DOCS_SOFT_LIMIT.c_str()); } } if (ailego_unlikely(!shared_mutex_.try_lock_shared())) { LOG_ERROR("Cannot add vector while dumping index"); (*stats_.mutable_discarded_count())++; return IndexError_Unsupported; } AILEGO_DEFER([&]() { shared_mutex_.unlock_shared(); }); ctx->clear(); ctx->update_dist_caculator_distance(add_distance_, add_batch_distance_); ctx->reset_query(query); ctx->check_need_adjuct_ctx(entity_.doc_cnt()); ctx->set_provider(provider_); if (metric_->support_train()) { const std::lock_guard lk(mutex_); ret = metric_->train(query, meta_.dimension()); if (ailego_unlikely(ret != 0)) { LOG_ERROR("Hnsw streamer metric train failed"); (*stats_.mutable_discarded_count())++; return ret; } } std::string converted_vector; IndexQueryMeta converted_meta; ret = reformer_->convert(query, qmeta, &converted_vector, &converted_meta); if (ret != 0) { LOG_ERROR("Rabitq hnsw convert failed, ret=%d", ret); return ret; } level_t level = alg_->get_random_level(); node_id_t id; ret = entity_.add_vector(level, pkey, converted_vector.data(), &id); if (ailego_unlikely(ret != 0)) { LOG_ERROR("Hnsw streamer add vector failed"); (*stats_.mutable_discarded_count())++; return ret; } ret = alg_->add_node(id, level, ctx); if (ailego_unlikely(ret != 0)) { LOG_ERROR("Hnsw steamer add node failed"); (*stats_.mutable_discarded_count())++; return ret; } if (ailego_unlikely(ctx->error())) { (*stats_.mutable_discarded_count())++; return 
IndexError_Runtime; } (*stats_.mutable_added_count())++; return 0; } int HnswRabitqStreamer::search_impl( const void *query, const IndexQueryMeta &qmeta, IndexStreamer::Context::Pointer &context) const { return search_impl(query, qmeta, 1, context); } //! Similarity search int HnswRabitqStreamer::search_impl( const void *query, const IndexQueryMeta &qmeta, uint32_t count, IndexStreamer::Context::Pointer &context) const { int ret = check_params(query, qmeta); if (ailego_unlikely(ret != 0)) { return ret; } HnswRabitqContext *ctx = dynamic_cast(context.get()); ailego_do_if_false(ctx) { LOG_ERROR("Cast context to HnswRabitqContext failed"); return IndexError_Cast; } if (entity_.doc_cnt() <= ctx->get_bruteforce_threshold()) { return search_bf_impl(query, qmeta, count, context); } if (ctx->magic() != magic_) { //! context is created by another searcher or streamer ret = update_context(ctx); if (ret != 0) { return ret; } } ctx->clear(); ctx->update_dist_caculator_distance(search_distance_, search_batch_distance_); ctx->resize_results(count); ctx->check_need_adjuct_ctx(entity_.doc_cnt()); for (size_t q = 0; q < count; ++q) { HnswRabitqQueryEntity entity; ret = reformer_->transform_to_entity(query, &entity); if (ailego_unlikely(ret != 0)) { LOG_ERROR("Hnsw searcher transform failed"); return ret; } ctx->reset_query(query); ret = query_alg_->search(&entity, ctx); if (ailego_unlikely(ret != 0)) { LOG_ERROR("Hnsw searcher fast search failed"); return ret; } ctx->topk_to_result(q); query = static_cast(query) + qmeta.element_size(); } if (ailego_unlikely(ctx->error())) { return IndexError_Runtime; } return 0; } void HnswRabitqStreamer::print_debug_info() { for (node_id_t id = 0; id < entity_.doc_cnt(); ++id) { if (entity_.get_key(id) == kInvalidKey) { continue; } Neighbors neighbours = entity_.get_neighbors(0, id); std::cout << "node: " << id << "; "; if (neighbours.size() == 0) std::cout << std::endl; for (uint32_t i = 0; i < neighbours.size(); ++i) { std::cout << 
neighbours[i]; if (i == neighbours.size() - 1) { std::cout << std::endl; } else { std::cout << ", "; } } } // entity_.print_key_map(); } int HnswRabitqStreamer::search_bf_impl( const void *query, const IndexQueryMeta &qmeta, IndexStreamer::Context::Pointer &context) const { return search_bf_impl(query, qmeta, 1, context); } int HnswRabitqStreamer::search_bf_impl( const void *query, const IndexQueryMeta &qmeta, uint32_t count, IndexStreamer::Context::Pointer &context) const { int ret = check_params(query, qmeta); if (ailego_unlikely(ret != 0)) { return ret; } HnswRabitqContext *ctx = dynamic_cast(context.get()); ailego_do_if_false(ctx) { LOG_ERROR("Cast context to HnswRabitqContext failed"); return IndexError_Cast; } if (ctx->magic() != magic_) { //! context is created by another searcher or streamer ret = update_context(ctx); if (ret != 0) { return ret; } } ctx->clear(); ctx->update_dist_caculator_distance(search_distance_, search_batch_distance_); ctx->resize_results(count); if (ctx->group_by_search()) { if (!ctx->group_by().is_valid()) { LOG_ERROR("Invalid group-by function"); return IndexError_InvalidArgument; } std::function group_by = [&](node_id_t id) { return ctx->group_by()(entity_.get_key(id)); }; for (size_t q = 0; q < count; ++q) { HnswRabitqQueryEntity entity; ret = reformer_->transform_to_entity(query, &entity); if (ailego_unlikely(ret != 0)) { LOG_ERROR("Hnsw rabitq streamer transform failed"); return ret; } ctx->reset_query(query); ctx->group_topk_heaps().clear(); for (node_id_t id = 0; id < entity_.doc_cnt(); ++id) { if (entity_.get_key(id) == kInvalidKey) { continue; } if (!ctx->filter().is_valid() || !ctx->filter()(entity_.get_key(id))) { EstimateRecord dist; query_alg_->get_full_est(id, dist, entity); std::string group_id = group_by(id); auto &topk_heap = ctx->group_topk_heaps()[group_id]; if (topk_heap.empty()) { topk_heap.limit(ctx->group_topk()); } topk_heap.emplace_back(id, dist); } } ctx->topk_to_result(q); query = static_cast(query) + 
qmeta.element_size(); } } else { for (size_t q = 0; q < count; ++q) { HnswRabitqQueryEntity entity; ret = reformer_->transform_to_entity(query, &entity); if (ailego_unlikely(ret != 0)) { LOG_ERROR("Hnsw rabitq streamer transform failed"); return ret; } ctx->reset_query(query); ctx->topk_heap().clear(); for (node_id_t id = 0; id < entity_.doc_cnt(); ++id) { if (entity_.get_key(id) == kInvalidKey) { continue; } if (!ctx->filter().is_valid() || !ctx->filter()(entity_.get_key(id))) { EstimateRecord dist; query_alg_->get_full_est(id, dist, entity); ctx->topk_heap().emplace(id, dist); } } ctx->topk_to_result(q); query = static_cast(query) + qmeta.element_size(); } } if (ailego_unlikely(ctx->error())) { return IndexError_Runtime; } return 0; } int HnswRabitqStreamer::search_bf_by_p_keys_impl( const void *query, const std::vector> &p_keys, const IndexQueryMeta &qmeta, uint32_t count, Context::Pointer &context) const { int ret = check_params(query, qmeta); if (ailego_unlikely(ret != 0)) { return ret; } if (ailego_unlikely(p_keys.size() != count)) { LOG_ERROR("The size of p_keys is not equal to count"); return IndexError_InvalidArgument; } HnswRabitqContext *ctx = dynamic_cast(context.get()); ailego_do_if_false(ctx) { LOG_ERROR("Cast context to HnswRabitqContext failed"); return IndexError_Cast; } if (ctx->magic() != magic_) { //! 
context is created by another searcher or streamer ret = update_context(ctx); if (ret != 0) { return ret; } } ctx->clear(); ctx->update_dist_caculator_distance(search_distance_, search_batch_distance_); ctx->resize_results(count); if (ctx->group_by_search()) { if (!ctx->group_by().is_valid()) { LOG_ERROR("Invalid group-by function"); return IndexError_InvalidArgument; } std::function group_by = [&](node_id_t id) { return ctx->group_by()(entity_.get_key(id)); }; for (size_t q = 0; q < count; ++q) { HnswRabitqQueryEntity entity; ret = reformer_->transform_to_entity(query, &entity); if (ailego_unlikely(ret != 0)) { LOG_ERROR("Hnsw rabitq streamer transform failed"); return ret; } ctx->reset_query(query); ctx->group_topk_heaps().clear(); for (size_t idx = 0; idx < p_keys[q].size(); ++idx) { uint64_t pk = p_keys[q][idx]; if (!ctx->filter().is_valid() || !ctx->filter()(pk)) { node_id_t id = entity_.get_id(pk); if (id != kInvalidNodeId) { EstimateRecord dist; query_alg_->get_full_est(id, dist, entity); std::string group_id = group_by(id); auto &topk_heap = ctx->group_topk_heaps()[group_id]; if (topk_heap.empty()) { topk_heap.limit(ctx->group_topk()); } topk_heap.emplace_back(id, dist); } } } ctx->topk_to_result(q); query = static_cast(query) + qmeta.element_size(); } } else { for (size_t q = 0; q < count; ++q) { HnswRabitqQueryEntity entity; ret = reformer_->transform_to_entity(query, &entity); if (ailego_unlikely(ret != 0)) { LOG_ERROR("Hnsw rabitq streamer transform failed"); return ret; } ctx->reset_query(query); ctx->topk_heap().clear(); for (size_t idx = 0; idx < p_keys[q].size(); ++idx) { key_t pk = p_keys[q][idx]; if (!ctx->filter().is_valid() || !ctx->filter()(pk)) { node_id_t id = entity_.get_id(pk); if (id != kInvalidNodeId) { EstimateRecord dist; query_alg_->get_full_est(id, dist, entity); ctx->topk_heap().emplace(id, dist); } } } ctx->topk_to_result(q); query = static_cast(query) + qmeta.element_size(); } } if (ailego_unlikely(ctx->error())) { return 
IndexError_Runtime; } return 0; } } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/hnsw_rabitq/hnsw_rabitq_streamer.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include "algorithm/hnsw_rabitq/rabitq_reformer.h" #include "zvec/core/framework/index_framework.h" #include "zvec/core/framework/index_provider.h" #include "zvec/core/framework/index_reformer.h" #include "hnsw_rabitq_algorithm.h" #include "hnsw_rabitq_query_algorithm.h" #include "hnsw_rabitq_streamer_entity.h" namespace zvec { namespace core { class HnswRabitqStreamer : public IndexStreamer { public: using ContextPointer = IndexStreamer::Context::Pointer; HnswRabitqStreamer(); explicit HnswRabitqStreamer(IndexProvider::Pointer provider, RabitqReformer::Pointer reformer = nullptr); virtual ~HnswRabitqStreamer(void); HnswRabitqStreamer(const HnswRabitqStreamer &streamer) = delete; HnswRabitqStreamer &operator=(const HnswRabitqStreamer &streamer) = delete; void set_provider(IndexProvider::Pointer provider) { provider_ = std::move(provider); } void set_reformer(IndexReformer::Pointer reformer) { reformer_ = std::dynamic_pointer_cast(reformer); } protected: //! Initialize Streamer virtual int init(const IndexMeta &imeta, const ailego::Params ¶ms) override; //! Cleanup Streamer virtual int cleanup(void) override; //! 
//! Create a context for add/search operations. Returns an empty pointer if
//! the streamer is not opened or the context cannot be set up.
virtual Context::Pointer create_context(void) const override;

//! Create a new iterator (provider) backed by a clone of the index entity;
//! returns nullptr if the entity cannot be cloned.
virtual IndexProvider::Pointer create_provider(void) const override;

//! Add a vector into index under primary key `pkey`. The raw vector is
//! converted by the RaBitQ reformer before it is stored in the graph.
virtual int add_impl(uint64_t pkey, const void *query,
                     const IndexQueryMeta &qmeta,
                     Context::Pointer &context) override;

//! Add a vector with id into index; the caller supplies the node id.
virtual int add_with_id_impl(uint32_t id, const void *query,
                             const IndexQueryMeta &qmeta,
                             Context::Pointer &context) override;

//! Similarity search for a single query (forwards with count = 1).
virtual int search_impl(const void *query, const IndexQueryMeta &qmeta,
                        Context::Pointer &context) const override;

//! Similarity search for `count` consecutive queries. Falls back to brute
//! force when the doc count is at or below the context's brute-force
//! threshold.
virtual int search_impl(const void *query, const IndexQueryMeta &qmeta,
                        uint32_t count,
                        Context::Pointer &context) const override;

//! Similarity brute force search for a single query (forwards with
//! count = 1).
virtual int search_bf_impl(const void *query, const IndexQueryMeta &qmeta,
                           Context::Pointer &context) const override;

//! Similarity brute force search for `count` consecutive queries over every
//! stored node that has a valid key and passes the context filter.
virtual int search_bf_impl(const void *query, const IndexQueryMeta &qmeta,
                           uint32_t count,
                           Context::Pointer &context) const override;

//! Linear search by primary keys for a single query (forwards with
//! count = 1).
//! NOTE(review): the template arguments of `p_keys` appear stripped in this
//! copy; presumably std::vector<std::vector<uint64_t>> — confirm against the
//! IndexStreamer base declaration.
virtual int search_bf_by_p_keys_impl(
    const void *query, const std::vector> &p_keys,
    const IndexQueryMeta &qmeta, ContextPointer &context) const override {
  return search_bf_by_p_keys_impl(query, p_keys, qmeta, 1, context);
}

//! Linear search by primary keys: query `q` is scored only against the keys
//! listed in `p_keys[q]`; `p_keys.size()` must equal `count`.
virtual int search_bf_by_p_keys_impl(
    const void *query, const std::vector> &p_keys,
    const IndexQueryMeta &qmeta, uint32_t count,
    ContextPointer &context) const override;

//! Fetch vector by key (delegates to the entity's key lookup)
virtual const void *get_vector(uint64_t key) const override {
  return entity_.get_vector_by_key(key);
}

//! Fetch vector by key into the caller-provided memory block
virtual int get_vector(const uint64_t key,
                       IndexStorage::MemoryBlock &block) const override {
  return entity_.get_vector_by_key(key, block);
}

//!
Fetch vector by id virtual const void *get_vector_by_id(uint32_t id) const override { return entity_.get_vector(id); } virtual int get_vector_by_id( const uint32_t id, IndexStorage::MemoryBlock &block) const override { return entity_.get_vector(id, block); } //! Open index from file path virtual int open(IndexStorage::Pointer stg) override; //! Close file virtual int close(void) override; //! flush file virtual int flush(uint64_t checkpoint) override; //! Dump index into storage virtual int dump(const IndexDumper::Pointer &dumper) override; //! Retrieve statistics virtual const Stats &stats(void) const override { return stats_; } //! Retrieve meta of index virtual const IndexMeta &meta(void) const override { return meta_; } virtual void print_debug_info() override; private: inline int check_params(const void *query, const IndexQueryMeta &qmeta) const { if (ailego_unlikely(!query)) { LOG_ERROR("null query"); return IndexError_InvalidArgument; } if (ailego_unlikely(qmeta.dimension() != meta_.dimension() || qmeta.data_type() != meta_.data_type() || qmeta.element_size() != meta_.element_size())) { LOG_ERROR("Unsupported query meta"); return IndexError_Mismatch; } return 0; } inline int check_sparse_count_is_zero(const uint32_t *sparse_count, uint32_t count) const { for (uint32_t i = 0; i < count; ++i) { if (sparse_count[i] != 0) LOG_ERROR("Sparse cout is not empty. Index: %u, Sparse Count: %u", i, sparse_count[i]); return IndexError_InvalidArgument; } return 0; } private: //! To share ctx across streamer/searcher, we need to update the context for //! 
current streamer/searcher int update_context(HnswRabitqContext *ctx) const; private: enum State { STATE_INIT = 0, STATE_INITED = 1, STATE_OPENED = 2 }; class Stats : public IndexStreamer::Stats { public: void clear(void) { set_revision_id(0u); set_loaded_count(0u); set_added_count(0u); set_discarded_count(0u); set_index_size(0u); set_dumped_size(0u); set_check_point(0u); set_create_time(0u); set_update_time(0u); clear_attributes(); } }; HnswRabitqStreamerEntity entity_; HnswRabitqAlgorithm::UPointer alg_; IndexMeta meta_{}; IndexMetric::Pointer metric_{}; IndexMetric::MatrixDistance add_distance_{}; IndexMetric::MatrixDistance search_distance_{}; IndexMetric::MatrixBatchDistance add_batch_distance_{}; IndexMetric::MatrixBatchDistance search_batch_distance_{}; RabitqReformer::Pointer reformer_{}; // RaBitQ reformer HnswRabitqQueryAlgorithm::UPointer query_alg_; // query algorithm // provider_ provides raw vector, which is used to build graph IndexProvider::Pointer provider_{}; Stats stats_{}; std::mutex mutex_{}; size_t max_index_size_{0UL}; size_t chunk_size_{HnswRabitqEntity::kDefaultChunkSize}; size_t docs_hard_limit_{HnswRabitqEntity::kDefaultDocsHardLimit}; size_t docs_soft_limit_{0UL}; uint32_t min_neighbor_cnt_{0u}; uint32_t upper_max_neighbor_cnt_{ HnswRabitqEntity::kDefaultUpperMaxNeighborCnt}; uint32_t l0_max_neighbor_cnt_{HnswRabitqEntity::kDefaultL0MaxNeighborCnt}; uint32_t ef_{HnswRabitqEntity::kDefaultEf}; uint32_t ef_construction_{HnswRabitqEntity::kDefaultEfConstruction}; uint32_t scaling_factor_{HnswRabitqEntity::kDefaultScalingFactor}; size_t bruteforce_threshold_{HnswRabitqEntity::kDefaultBruteForceThreshold}; size_t max_scan_limit_{HnswRabitqEntity::kDefaultMaxScanLimit}; size_t min_scan_limit_{HnswRabitqEntity::kDefaultMinScanLimit}; float bf_negative_prob_{HnswRabitqEntity::kDefaultBFNegativeProbability}; float max_scan_ratio_{HnswRabitqEntity::kDefaultScanRatio}; uint32_t magic_{0U}; State state_{STATE_INIT}; bool bf_enabled_{false}; bool 
check_crc_enabled_{false}; bool filter_same_key_{false}; bool get_vector_enabled_{false}; bool force_padding_topk_enabled_{false}; bool use_id_map_{true}; //! avoid add vector while dumping index ailego::SharedMutex shared_mutex_{}; }; } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/hnsw_rabitq/hnsw_rabitq_streamer_entity.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "hnsw_rabitq_streamer_entity.h" #include // #define DEBUG_PRINT namespace zvec { namespace core { HnswRabitqStreamerEntity::HnswRabitqStreamerEntity(IndexStreamer::Stats &stats) : stats_(stats) {} HnswRabitqStreamerEntity::~HnswRabitqStreamerEntity() {} int HnswRabitqStreamerEntity::init(size_t max_doc_cnt) { if (std::pow(scaling_factor(), kMaxGraphLayers) < max_doc_cnt) { LOG_ERROR("scalingFactor=%zu is too small", scaling_factor()); return IndexError_InvalidArgument; } std::lock_guard lock(mutex_); broker_ = std::make_shared(stats_); upper_neighbor_index_ = std::make_shared(); keys_map_lock_ = std::make_shared(); keys_map_ = std::make_shared>(); if (!keys_map_ || !upper_neighbor_index_ || !broker_ || !keys_map_lock_) { LOG_ERROR("HnswRabitqStreamerEntity new object failed"); return IndexError_NoMemory; } keys_map_->set_empty_key(kInvalidKey); neighbor_size_ = neighbors_size(); upper_neighbor_size_ = upper_neighbors_size(); //! 
vector + key + level 0 neighbors size_t size = vector_size() + sizeof(key_t) + neighbor_size_; size = AlignSize(size); set_node_size(size); return 0; } int HnswRabitqStreamerEntity::cleanup() { std::lock_guard lock(mutex_); mutable_header()->clear(); chunk_size_ = kDefaultChunkSize; node_index_mask_bits_ = 0U; node_index_mask_ = 0U; node_cnt_per_chunk_ = 0U; neighbor_size_ = 0U; upper_neighbor_size_ = 0U; if (upper_neighbor_index_) { upper_neighbor_index_->cleanup(); } if (keys_map_) { keys_map_->clear(); } node_chunks_.clear(); upper_neighbor_chunks_.clear(); filter_same_key_ = false; get_vector_enabled_ = false; broker_.reset(); return 0; } int HnswRabitqStreamerEntity::update_neighbors( level_t level, node_id_t id, const std::vector> &neighbors) { char buffer[neighbor_size_]; NeighborsHeader *hd = reinterpret_cast(buffer); hd->neighbor_cnt = neighbors.size(); size_t i = 0; for (; i < neighbors.size(); ++i) { hd->neighbors[i] = neighbors[i].first; } auto loc = get_neighbor_chunk_loc(level, id); size_t size = reinterpret_cast(&hd->neighbors[i]) - &buffer[0]; size_t ret = loc.first->write(loc.second, hd, size); if (ailego_unlikely(ret != size)) { LOG_ERROR("Write neighbor header failed, ret=%zu", ret); return IndexError_Runtime; } return 0; } const Neighbors HnswRabitqStreamerEntity::get_neighbors(level_t level, node_id_t id) const { Chunk *chunk = nullptr; size_t offset = 0UL; size_t neighbor_size = neighbor_size_; if (level == 0UL) { uint32_t chunk_idx = id >> node_index_mask_bits_; offset = (id & node_index_mask_) * node_size() + vector_size() + sizeof(key_t); sync_chunks(HnswRabitqChunkBroker::CHUNK_TYPE_NODE, chunk_idx, &node_chunks_); ailego_assert_with(chunk_idx < node_chunks_.size(), "invalid chunk idx"); chunk = node_chunks_[chunk_idx].get(); } else { auto p = get_upper_neighbor_chunk_loc(level, id); chunk = upper_neighbor_chunks_[p.first].get(); offset = p.second; neighbor_size = upper_neighbor_size_; } ailego_assert_with(offset < chunk->data_size(), 
"invalid chunk offset"); IndexStorage::MemoryBlock neighbor_block; size_t size = chunk->read(offset, neighbor_block, neighbor_size); if (ailego_unlikely(size != neighbor_size)) { LOG_ERROR("Read neighbor header failed, ret=%zu", size); return Neighbors(); } return Neighbors(std::move(neighbor_block)); } //! Get vector data by key const void *HnswRabitqStreamerEntity::get_vector(node_id_t id) const { auto loc = get_vector_chunk_loc(id); const void *vec = nullptr; ailego_assert_with(loc.first < node_chunks_.size(), "invalid chunk idx"); ailego_assert_with(loc.second < node_chunks_[loc.first]->data_size(), "invalid chunk offset"); size_t read_size = vector_size(); size_t ret = node_chunks_[loc.first]->read(loc.second, &vec, read_size); if (ailego_unlikely(ret != read_size)) { LOG_ERROR("Read vector failed, offset=%zu, read size=%zu, ret=%zu", static_cast(loc.second), read_size, ret); } return vec; } int HnswRabitqStreamerEntity::get_vector(const node_id_t *ids, uint32_t count, const void **vecs) const { for (auto i = 0U; i < count; ++i) { auto loc = get_vector_chunk_loc(ids[i]); ailego_assert_with(loc.first < node_chunks_.size(), "invalid chunk idx"); ailego_assert_with(loc.second < node_chunks_[loc.first]->data_size(), "invalid chunk offset"); size_t read_size = vector_size(); size_t ret = node_chunks_[loc.first]->read(loc.second, &vecs[i], read_size); if (ailego_unlikely(ret != read_size)) { LOG_ERROR("Read vector failed, offset=%zu, read size=%zu, ret=%zu", static_cast(loc.second), read_size, ret); return IndexError_ReadData; } } return 0; } int HnswRabitqStreamerEntity::get_vector( const node_id_t id, IndexStorage::MemoryBlock &block) const { auto loc = get_vector_chunk_loc(id); ailego_assert_with(loc.first < node_chunks_.size(), "invalid chunk idx"); ailego_assert_with(loc.second < node_chunks_[loc.first]->data_size(), "invalid chunk offset"); size_t read_size = vector_size(); size_t ret = node_chunks_[loc.first]->read(loc.second, block, read_size); if 
(ailego_unlikely(ret != read_size)) { LOG_ERROR("Read vector failed, offset=%zu, read size=%zu, ret=%zu", static_cast(loc.second), read_size, ret); return IndexError_ReadData; } return 0; } int HnswRabitqStreamerEntity::get_vector( const node_id_t *ids, uint32_t count, std::vector &vec_blocks) const { vec_blocks.resize(count); for (auto i = 0U; i < count; ++i) { auto loc = get_vector_chunk_loc(ids[i]); ailego_assert_with(loc.first < node_chunks_.size(), "invalid chunk idx"); ailego_assert_with(loc.second < node_chunks_[loc.first]->data_size(), "invalid chunk offset"); size_t read_size = vector_size(); size_t ret = node_chunks_[loc.first]->read(loc.second, vec_blocks[i], read_size); if (ailego_unlikely(ret != read_size)) { LOG_ERROR("Read vector failed, offset=%zu, read size=%zu, ret=%zu", static_cast(loc.second), read_size, ret); return IndexError_ReadData; } } return 0; } key_t HnswRabitqStreamerEntity::get_key(node_id_t id) const { if (use_key_info_map_) { auto loc = get_key_chunk_loc(id); IndexStorage::MemoryBlock key_block; ailego_assert_with(loc.first < node_chunks_.size(), "invalid chunk idx"); ailego_assert_with(loc.second < node_chunks_[loc.first]->data_size(), "invalid chunk offset"); size_t ret = node_chunks_[loc.first]->read(loc.second, key_block, sizeof(key_t)); if (ailego_unlikely(ret != sizeof(key_t))) { LOG_ERROR("Read vector failed, ret=%zu", ret); return kInvalidKey; } return *reinterpret_cast(key_block.data()); } else { return id; } } void HnswRabitqStreamerEntity::add_neighbor(level_t level, node_id_t id, uint32_t size, node_id_t neighbor_id) { auto loc = get_neighbor_chunk_loc(level, id); size_t offset = loc.second + sizeof(NeighborsHeader) + size * sizeof(node_id_t); ailego_assert_with(size < neighbor_cnt(level), "invalid neighbor size"); ailego_assert_with(offset < loc.first->data_size(), "invalid chunk offset"); size_t ret = loc.first->write(offset, &neighbor_id, sizeof(node_id_t)); if (ailego_unlikely(ret != sizeof(node_id_t))) { 
LOG_ERROR("Write neighbor id failed, ret=%zu", ret); return; } uint32_t neighbors = size + 1; ret = loc.first->write(loc.second, &neighbors, sizeof(uint32_t)); if (ailego_unlikely(ret != sizeof(uint32_t))) { LOG_ERROR("Write neighbor cnt failed, ret=%zu", ret); } return; } int HnswRabitqStreamerEntity::init_chunks(const Chunk::Pointer &header_chunk) { if (header_chunk->data_size() < header_size()) { LOG_ERROR("Invalid header chunk size"); return IndexError_InvalidFormat; } IndexStorage::MemoryBlock header_block; size_t size = header_chunk->read(0UL, header_block, header_size()); if (ailego_unlikely(size != header_size())) { LOG_ERROR("Read header chunk failed"); return IndexError_ReadData; } *mutable_header() = *reinterpret_cast(header_block.data()); int ret = check_hnsw_index(&header()); if (ret != 0) { broker_->close(); return ret; } node_chunks_.resize( broker_->get_chunk_cnt(HnswRabitqChunkBroker::CHUNK_TYPE_NODE)); for (auto seq = 0UL; seq < node_chunks_.size(); ++seq) { node_chunks_[seq] = broker_->get_chunk(HnswRabitqChunkBroker::CHUNK_TYPE_NODE, seq); if (!node_chunks_[seq]) { LOG_ERROR("Missing hnsw streamer data chunk %zu th of %zu", seq, node_chunks_.size()); return IndexError_InvalidFormat; } } upper_neighbor_chunks_.resize( broker_->get_chunk_cnt(HnswRabitqChunkBroker::CHUNK_TYPE_UPPER_NEIGHBOR)); for (auto seq = 0UL; seq < upper_neighbor_chunks_.size(); ++seq) { upper_neighbor_chunks_[seq] = broker_->get_chunk( HnswRabitqChunkBroker::CHUNK_TYPE_UPPER_NEIGHBOR, seq); if (!upper_neighbor_chunks_[seq]) { LOG_ERROR("Missing hnsw streamer index chunk %zu th of %zu", seq, upper_neighbor_chunks_.size()); return IndexError_InvalidFormat; } } return 0; } int HnswRabitqStreamerEntity::open(IndexStorage::Pointer stg, uint64_t max_index_size, bool check_crc) { std::lock_guard lock(mutex_); bool huge_page = stg->isHugePage(); LOG_DEBUG("huge_page: %d", (int)huge_page); int ret = init_chunk_params(max_index_size, huge_page); if (ailego_unlikely(ret != 0)) { 
LOG_ERROR("init_chunk_params failed for %s", IndexError::What(ret)); return ret; } ret = broker_->open(std::move(stg), max_index_size_, chunk_size_, check_crc); if (ailego_unlikely(ret != 0)) { LOG_ERROR("Open index failed for %s", IndexError::What(ret)); return ret; } ret = upper_neighbor_index_->init(broker_, upper_neighbor_chunk_size_, scaling_factor(), estimate_doc_capacity(), kUpperHashMemoryInflateRatio); if (ailego_unlikely(ret != 0)) { LOG_ERROR("Init neighbor hash map failed"); return ret; } //! init header auto header_chunk = broker_->get_chunk(HnswRabitqChunkBroker::CHUNK_TYPE_HEADER, HnswRabitqChunkBroker::kDefaultChunkSeqId); if (!header_chunk) { // open empty index, create one auto p = broker_->alloc_chunk(HnswRabitqChunkBroker::CHUNK_TYPE_HEADER, HnswRabitqChunkBroker::kDefaultChunkSeqId, header_size()); if (ailego_unlikely(p.first != 0)) { LOG_ERROR("Alloc header chunk failed"); return p.first; } size_t size = p.second->write(0UL, &header(), header_size()); if (ailego_unlikely(size != header_size())) { LOG_ERROR("Write header chunk failed"); return IndexError_WriteData; } return 0; } //! Open an exist hnsw index ret = init_chunks(header_chunk); if (ailego_unlikely(ret != 0)) { return ret; } //! total docs including features wrote in index but neighbors may not ready node_id_t total_vecs = 0; if (node_chunks_.size() > 0) { size_t last_idx = node_chunks_.size() - 1; auto last_chunk = node_chunks_[last_idx]; if (last_chunk->data_size() % node_size()) { LOG_WARN("The index may broken"); return IndexError_InvalidFormat; } total_vecs = last_idx * node_cnt_per_chunk_ + node_chunks_[last_idx]->data_size() / node_size(); } LOG_INFO( "Open index, l0NeighborCnt=%zu upperNeighborCnt=%zu " "efConstruction=%zu curDocCnt=%u totalVecs=%u maxLevel=%u", l0_neighbor_cnt(), upper_neighbor_cnt(), ef_construction(), doc_cnt(), total_vecs, cur_max_level()); //! 
try to correct the docCnt if index not fully flushed if (doc_cnt() != total_vecs) { LOG_WARN("Index closed abnormally, using totalVecs as curDocCnt"); *mutable_doc_cnt() = total_vecs; } if (filter_same_key_ || get_vector_enabled_) { if (use_key_info_map_) { for (node_id_t id = 0U; id < doc_cnt(); ++id) { if (get_key(id) == kInvalidKey) { continue; } (*keys_map_)[get_key(id)] = id; } } } stats_.set_loaded_count(doc_cnt()); return 0; } int HnswRabitqStreamerEntity::close() { LOG_DEBUG("close index"); std::lock_guard lock(mutex_); flush_header(); mutable_header()->reset(); upper_neighbor_index_->cleanup(); keys_map_->clear(); header_.clear(); node_chunks_.clear(); upper_neighbor_chunks_.clear(); return broker_->close(); } int HnswRabitqStreamerEntity::flush(uint64_t checkpoint) { LOG_INFO("Flush index, curDocs=%zu", static_cast(doc_cnt())); std::lock_guard lock(mutex_); flush_header(); int ret = broker_->flush(checkpoint); if (ret != 0) { return ret; } return 0; } int HnswRabitqStreamerEntity::dump(const IndexDumper::Pointer &dumper) { LOG_INFO("Dump index, curDocs=%zu", static_cast(doc_cnt())); //! sort by keys, to support get_vector by key in searcher std::vector keys(doc_cnt()); for (node_id_t i = 0; i < doc_cnt(); ++i) { keys[i] = get_key(i); } //! 
dump neighbors auto get_level = [&](node_id_t id) { auto it = upper_neighbor_index_->find(id); if (it == upper_neighbor_index_->end()) { return 0U; }; auto meta = reinterpret_cast(&it->second); return meta->level; }; auto ret = dump_segments(dumper, keys.data(), get_level); if (ailego_unlikely(ret < 0)) { return ret; } *stats_.mutable_dumped_size() += ret; return 0; } int HnswRabitqStreamerEntity::check_hnsw_index(const HNSWHeader *hd) const { if (l0_neighbor_cnt() != hd->l0_neighbor_cnt() || upper_neighbor_cnt() != hd->upper_neighbor_cnt()) { LOG_ERROR("Param neighbor cnt: %zu:%zu mismatch index previous %zu:%zu", l0_neighbor_cnt(), upper_neighbor_cnt(), hd->l0_neighbor_cnt(), hd->upper_neighbor_cnt()); return IndexError_Mismatch; } if (vector_size() != hd->vector_size()) { LOG_ERROR("vector size %zu mismatch index previous %zu", vector_size(), hd->vector_size()); return IndexError_Mismatch; } if (ef_construction() != hd->ef_construction()) { LOG_WARN("Param efConstruction %zu mismatch index previous %zu", ef_construction(), hd->ef_construction()); } if (scaling_factor() != hd->scaling_factor()) { LOG_WARN("Param scalingFactor %zu mismatch index previous %zu", scaling_factor(), hd->scaling_factor()); return IndexError_Mismatch; } if (prune_cnt() != hd->neighbor_prune_cnt()) { LOG_WARN("Param pruneCnt %zu mismatch index previous %zu", prune_cnt(), hd->neighbor_prune_cnt()); return IndexError_Mismatch; } if ((hd->entry_point() != kInvalidNodeId && hd->entry_point() >= hd->doc_cnt()) || (hd->entry_point() == kInvalidNodeId && hd->doc_cnt() > 0U)) { LOG_WARN("Invalid entryPoint %zu, docCnt %zu", static_cast(hd->entry_point()), static_cast(hd->doc_cnt())); return IndexError_InvalidFormat; } if (hd->entry_point() == kInvalidNodeId && broker_->get_chunk_cnt(HnswRabitqChunkBroker::CHUNK_TYPE_NODE) > 0) { LOG_WARN("The index is broken, maybe it haven't flush"); return IndexError_InvalidFormat; } return 0; } int HnswRabitqStreamerEntity::add_vector(level_t level, key_t key, 
const void *vec, node_id_t *id) { Chunk::Pointer node_chunk; size_t chunk_offset = -1UL; std::lock_guard lock(mutex_); // duplicate check if (ailego_unlikely(filter_same_key_ && get_id(key) != kInvalidNodeId)) { LOG_WARN("Try to add duplicate key, ignore it"); return IndexError_Duplicate; } node_id_t local_id = static_cast(doc_cnt()); uint32_t chunk_index = node_chunks_.size() - 1U; if (chunk_index == -1U || (node_chunks_[chunk_index]->data_size() >= node_cnt_per_chunk_ * node_size())) { // no space left and need to alloc if (ailego_unlikely(node_chunks_.capacity() == node_chunks_.size())) { LOG_ERROR("add vector failed for no memory quota"); return IndexError_IndexFull; } chunk_index++; auto p = broker_->alloc_chunk(HnswRabitqChunkBroker::CHUNK_TYPE_NODE, chunk_index, chunk_size_); if (ailego_unlikely(p.first != 0)) { LOG_ERROR("Alloc data chunk failed"); return p.first; } node_chunk = p.second; chunk_offset = 0UL; node_chunks_.emplace_back(node_chunk); } else { node_chunk = node_chunks_[chunk_index]; chunk_offset = node_chunk->data_size(); } size_t size = node_chunk->write(chunk_offset, vec, vector_size()); if (ailego_unlikely(size != vector_size())) { LOG_ERROR("Chunk write vec failed, ret=%zu", size); return IndexError_WriteData; } size = node_chunk->write(chunk_offset + vector_size(), &key, sizeof(key_t)); if (ailego_unlikely(size != sizeof(key_t))) { LOG_ERROR("Chunk write vec failed, ret=%zu", size); return IndexError_WriteData; } //! 
level 0 neighbors is inited to zero by default int ret = add_upper_neighbor(level, local_id); if (ret != 0) { return ret; } chunk_offset += node_size(); if (ailego_unlikely(node_chunk->resize(chunk_offset) != chunk_offset)) { LOG_ERROR("Chunk resize to %zu failed", chunk_offset); return IndexError_Runtime; } if (filter_same_key_ || get_vector_enabled_) { if (use_key_info_map_) { keys_map_lock_->lock(); (*keys_map_)[key] = local_id; keys_map_lock_->unlock(); } } *mutable_doc_cnt() += 1; broker_->mark_dirty(); *id = local_id; return 0; } int HnswRabitqStreamerEntity::add_vector_with_id(level_t level, node_id_t id, const void *vec) { Chunk::Pointer node_chunk; size_t chunk_offset = -1UL; key_t key = id; std::lock_guard lock(mutex_); // duplicate check if (ailego_unlikely(filter_same_key_ && get_id(key) != kInvalidNodeId)) { LOG_WARN("Try to add duplicate key, ignore it"); return IndexError_Duplicate; } // set node_chunk & chunk_offset if succeed auto func_get_node_chunk_and_offset = [&](node_id_t node_id) -> int { uint32_t chunk_index = node_id >> node_index_mask_bits_; ailego_assert_with(chunk_index <= node_chunks_.size(), "invalid chunk idx"); // belongs to next chunk if (chunk_index == node_chunks_.size()) { if (ailego_unlikely(node_chunks_.capacity() == node_chunks_.size())) { LOG_ERROR("add vector failed for no memory quota"); return IndexError_IndexFull; } auto p = broker_->alloc_chunk(HnswRabitqChunkBroker::CHUNK_TYPE_NODE, chunk_index, chunk_size_); if (ailego_unlikely(p.first != 0)) { LOG_ERROR("Alloc data chunk failed"); return p.first; } node_chunk = p.second; node_chunks_.emplace_back(node_chunk); } node_chunk = node_chunks_[chunk_index]; chunk_offset = (node_id & node_index_mask_) * node_size(); return 0; }; for (size_t start_id = doc_cnt(); start_id < id; ++start_id) { if (auto ret = func_get_node_chunk_and_offset(start_id); ret != 0) { LOG_ERROR("func_get_node_chunk_and_offset failed"); return ret; } size_t size = node_chunk->write(chunk_offset + 
vector_size(), &kInvalidKey, sizeof(key_t)); if (ailego_unlikely(size != sizeof(key_t))) { LOG_ERROR("Chunk write key failed, ret=%zu", size); return IndexError_WriteData; } chunk_offset += node_size(); if (ailego_unlikely(node_chunk->resize(chunk_offset) != chunk_offset)) { LOG_ERROR("Chunk resize to %zu failed", chunk_offset); return IndexError_Runtime; } } if (auto ret = func_get_node_chunk_and_offset(id); ret != 0) { LOG_ERROR("func_get_node_chunk_and_offset failed"); return ret; } size_t size = node_chunk->write(chunk_offset, vec, vector_size()); if (ailego_unlikely(size != vector_size())) { LOG_ERROR("Chunk write vec failed, ret=%zu", size); return IndexError_WriteData; } size = node_chunk->write(chunk_offset + vector_size(), &key, sizeof(key_t)); if (ailego_unlikely(size != sizeof(key_t))) { LOG_ERROR("Chunk write vec failed, ret=%zu", size); return IndexError_WriteData; } //! level 0 neighbors is inited to zero by default int ret = add_upper_neighbor(level, id); if (ret != 0) { return ret; } if (*mutable_doc_cnt() <= id) { *mutable_doc_cnt() = id + 1; chunk_offset += node_size(); if (ailego_unlikely(node_chunk->resize(chunk_offset) != chunk_offset)) { LOG_ERROR("Chunk resize to %zu failed", chunk_offset); return IndexError_Runtime; } } if (filter_same_key_ || get_vector_enabled_) { if (use_key_info_map_) { keys_map_lock_->lock(); (*keys_map_)[key] = id; keys_map_lock_->unlock(); } } broker_->mark_dirty(); return 0; } void HnswRabitqStreamerEntity::update_ep_and_level(node_id_t ep, level_t level) { HnswRabitqEntity::update_ep_and_level(ep, level); flush_header(); return; } const HnswRabitqEntity::Pointer HnswRabitqStreamerEntity::clone() const { std::vector node_chunks; node_chunks.reserve(node_chunks_.size()); for (size_t i = 0UL; i < node_chunks_.size(); ++i) { node_chunks.emplace_back(node_chunks_[i]->clone()); if (ailego_unlikely(!node_chunks[i])) { LOG_ERROR("HnswRabitqStreamerEntity get chunk failed in clone"); return HnswRabitqEntity::Pointer(); } } 
std::vector upper_neighbor_chunks; upper_neighbor_chunks.reserve(upper_neighbor_chunks_.size()); for (size_t i = 0UL; i < upper_neighbor_chunks_.size(); ++i) { upper_neighbor_chunks.emplace_back(upper_neighbor_chunks_[i]->clone()); if (ailego_unlikely(!upper_neighbor_chunks[i])) { LOG_ERROR("HnswRabitqStreamerEntity get chunk failed in clone"); return HnswRabitqEntity::Pointer(); } } HnswRabitqStreamerEntity *entity = new (std::nothrow) HnswRabitqStreamerEntity( stats_, header(), chunk_size_, node_index_mask_bits_, upper_neighbor_mask_bits_, filter_same_key_, get_vector_enabled_, upper_neighbor_index_, keys_map_lock_, keys_map_, use_key_info_map_, std::move(node_chunks), std::move(upper_neighbor_chunks), broker_); if (ailego_unlikely(!entity)) { LOG_ERROR("HnswRabitqStreamerEntity new failed"); } return HnswRabitqEntity::Pointer(entity); } } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/hnsw_rabitq/hnsw_rabitq_streamer_entity.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include #include #include #include "zvec/core/framework/index_framework.h" #include "hnsw_rabitq_chunk.h" #include "hnsw_rabitq_entity.h" #include "hnsw_rabitq_index_hash.h" #include "hnsw_rabitq_params.h" namespace zvec { namespace core { //! 
HnswRabitqStreamerEntity manage vector data, pkey, and node's neighbors class HnswRabitqStreamerEntity : public HnswRabitqEntity { public: //! Cleanup //! return 0 on success, or errCode in failure virtual int cleanup() override; //! Make a copy of streamer entity, to support thread-safe operation. //! The segment in container cannot be read concurrenly virtual const HnswRabitqEntity::Pointer clone() const override; //! Get primary key of the node id virtual key_t get_key(node_id_t id) const override; //! Get vector feature data by key virtual const void *get_vector(node_id_t id) const override; //! Get vectors feature data by local ids virtual int get_vector(const node_id_t *ids, uint32_t count, const void **vecs) const override; virtual int get_vector(const node_id_t id, IndexStorage::MemoryBlock &block) const override; virtual int get_vector( const node_id_t *ids, uint32_t count, std::vector &vec_blocks) const override; //! Get the node id's neighbors on graph level //! Note: the neighbors cannot be modified, using the following //! method to get WritableNeighbors if want to virtual const Neighbors get_neighbors(level_t level, node_id_t id) const override; //! Add vector and key to hnsw entity, and local id will be saved in id virtual int add_vector(level_t level, key_t key, const void *vec, node_id_t *id) override; //! Add vector and id to hnsw entity virtual int add_vector_with_id(level_t level, node_id_t id, const void *vec) override; virtual int update_neighbors( level_t level, node_id_t id, const std::vector> &neighbors) override; //! Append neighbor_id to node id neighbors on level //! Notice: the caller must be ensure the neighbors not full virtual void add_neighbor(level_t level, node_id_t id, uint32_t size, node_id_t neighbor_id) override; //! 
Dump index by dumper virtual int dump(const IndexDumper::Pointer &dumper) override; virtual void update_ep_and_level(node_id_t ep, level_t level) override; void set_use_key_info_map(bool use_id_map) { use_key_info_map_ = use_id_map; LOG_DEBUG("use_key_info_map_: %d", (int)use_key_info_map_); } public: //! Constructor HnswRabitqStreamerEntity(IndexStreamer::Stats &stats); //! Destructor ~HnswRabitqStreamerEntity(); //! Get vector feature data by key virtual const void *get_vector_by_key(key_t key) const override { auto id = get_id(key); return id == kInvalidNodeId ? nullptr : get_vector(id); } virtual int get_vector_by_key( const key_t key, IndexStorage::MemoryBlock &block) const override { auto id = get_id(key); if (id != kInvalidNodeId) { return get_vector(id, block); } else { return IndexError_InvalidArgument; } } //! Init entity int init(size_t max_doc_cnt); //! Flush graph entity to disk //! return 0 on success, or errCode in failure int flush(uint64_t checkpoint); //! Open entity from storage //! return 0 on success, or errCode in failure int open(IndexStorage::Pointer stg, uint64_t max_index_size, bool check_crc); //! Close entity //! return 0 on success, or errCode in failure int close(); //! Set meta information from entity int set_index_meta(const IndexMeta &meta) const { return IndexHelper::SerializeToStorage(meta, broker_->storage().get()); } //! Get meta information from entity int get_index_meta(IndexMeta *meta) const { return IndexHelper::DeserializeFromStorage(broker_->storage().get(), meta); } //! Set params: chunk size inline void set_chunk_size(size_t val) { chunk_size_ = val; } //! Set params inline void set_filter_same_key(bool val) { filter_same_key_ = val; } //! Set params inline void set_get_vector(bool val) { get_vector_enabled_ = val; } //! 
Get vector local id by key inline node_id_t get_id(key_t key) const { if (use_key_info_map_) { keys_map_lock_->lock_shared(); auto it = keys_map_->find(key); keys_map_lock_->unlock_shared(); return it == keys_map_->end() ? kInvalidNodeId : it->second; } else { return key; } } void print_key_map() const { std::cout << "key map begins" << std::endl; auto iter = keys_map_->begin(); while (iter != keys_map_->end()) { std::cout << "key: " << iter->first << ", id: " << iter->second << std::endl; ; iter++; } std::cout << "key map ends" << std::endl; } //! Get l0 neighbors size inline size_t neighbors_size() const { return sizeof(NeighborsHeader) + l0_neighbor_cnt() * sizeof(node_id_t); } //! Get neighbors size for level > 0 inline size_t upper_neighbors_size() const { return sizeof(NeighborsHeader) + upper_neighbor_cnt() * sizeof(node_id_t); } private: union UpperNeighborIndexMeta { struct { uint32_t level : 4; uint32_t index : 28; // index is composite type: chunk idx, and the // N th neighbors in chunk, they two composite // the 28 bits location }; uint32_t data; }; template using HashMap = google::dense_hash_map>; template using HashMapPointer = std::shared_ptr>; template using HashSet = google::dense_hash_set>; template using HashSetPointer = std::shared_ptr>; //! upper neighbor index hashmap using NIHashMap = HnswIndexHashMap; using NIHashMapPointer = std::shared_ptr; //! 
Private construct, only be called by clone method HnswRabitqStreamerEntity(IndexStreamer::Stats &stats, const HNSWHeader &hd, size_t chunk_size, uint32_t node_index_mask_bits, uint32_t upper_neighbor_mask_bits, bool filter_same_key, bool get_vector_enabled, const NIHashMapPointer &upper_neighbor_index, std::shared_ptr &keys_map_lock, const HashMapPointer &keys_map, bool use_key_info_map, std::vector &&node_chunks, std::vector &&upper_neighbor_chunks, const HnswRabitqChunkBroker::Pointer &broker) : stats_(stats), chunk_size_(chunk_size), node_index_mask_bits_(node_index_mask_bits), node_cnt_per_chunk_(1UL << node_index_mask_bits_), node_index_mask_(node_cnt_per_chunk_ - 1), upper_neighbor_mask_bits_(upper_neighbor_mask_bits), upper_neighbor_mask_((1U << upper_neighbor_mask_bits_) - 1), filter_same_key_(filter_same_key), get_vector_enabled_(get_vector_enabled), use_key_info_map_(use_key_info_map), upper_neighbor_index_(upper_neighbor_index), keys_map_lock_(keys_map_lock), keys_map_(keys_map), node_chunks_(std::move(node_chunks)), upper_neighbor_chunks_(std::move(upper_neighbor_chunks)), broker_(broker) { *mutable_header() = hd; neighbor_size_ = neighbors_size(); upper_neighbor_size_ = upper_neighbors_size(); } //! Called only in searching procedure per context, so no need to lock void sync_chunks(HnswRabitqChunkBroker::CHUNK_TYPE type, size_t idx, std::vector *chunks) const { if (ailego_likely(idx < chunks->size())) { return; } for (size_t i = chunks->size(); i <= idx; ++i) { auto chunk = broker_->get_chunk(type, i); // the storage can ensure get chunk will success after the first get ailego_assert_with(!!chunk, "get chunk failed"); chunks->emplace_back(std::move(chunk)); } } //! 
return pair: chunk index + chunk offset inline std::pair get_vector_chunk_loc( node_id_t id) const { uint32_t chunk_idx = id >> node_index_mask_bits_; uint32_t offset = (id & node_index_mask_) * node_size(); sync_chunks(HnswRabitqChunkBroker::CHUNK_TYPE_NODE, chunk_idx, &node_chunks_); return std::make_pair(chunk_idx, offset); } //! return pair: chunk index + chunk offset inline std::pair get_key_chunk_loc(node_id_t id) const { uint32_t chunk_idx = id >> node_index_mask_bits_; uint32_t offset = (id & node_index_mask_) * node_size() + vector_size(); sync_chunks(HnswRabitqChunkBroker::CHUNK_TYPE_NODE, chunk_idx, &node_chunks_); return std::make_pair(chunk_idx, offset); } inline std::pair get_upper_neighbor_chunk_loc( level_t level, node_id_t id) const { auto it = upper_neighbor_index_->find(id); ailego_assert_abort(it != upper_neighbor_index_->end(), "Get upper neighbor header failed"); auto meta = reinterpret_cast(&it->second); uint32_t chunk_idx = (meta->index) >> upper_neighbor_mask_bits_; uint32_t offset = (((meta->index) & upper_neighbor_mask_) + level - 1) * upper_neighbor_size_; sync_chunks(HnswRabitqChunkBroker::CHUNK_TYPE_UPPER_NEIGHBOR, chunk_idx, &upper_neighbor_chunks_); ailego_assert_abort(chunk_idx < upper_neighbor_chunks_.size(), "invalid chunk idx"); ailego_assert_abort(offset < upper_neighbor_chunks_[chunk_idx]->data_size(), "invalid chunk offset"); return std::make_pair(chunk_idx, offset); } //! 
return pair: chunk + chunk offset inline std::pair get_neighbor_chunk_loc(level_t level, node_id_t id) const { if (level == 0UL) { uint32_t chunk_idx = id >> node_index_mask_bits_; uint32_t offset = (id & node_index_mask_) * node_size() + vector_size() + sizeof(key_t); sync_chunks(HnswRabitqChunkBroker::CHUNK_TYPE_NODE, chunk_idx, &node_chunks_); ailego_assert_abort(chunk_idx < node_chunks_.size(), "invalid chunk idx"); ailego_assert_abort(offset < node_chunks_[chunk_idx]->data_size(), "invalid chunk offset"); return std::make_pair(node_chunks_[chunk_idx].get(), offset); } else { auto p = get_upper_neighbor_chunk_loc(level, id); return std::make_pair(upper_neighbor_chunks_[p.first].get(), p.second); } } //! Chunk hnsw index valid int check_hnsw_index(const HNSWHeader *hd) const; size_t get_total_upper_neighbors_size(level_t level) const { return level * upper_neighbor_size_; } //! Add upper neighbor header and reserve space for upper neighbor int add_upper_neighbor(level_t level, node_id_t id) { if (level == 0) { return 0; } Chunk::Pointer chunk; uint64_t chunk_offset = -1UL; size_t neighbors_size = get_total_upper_neighbors_size(level); uint64_t chunk_index = upper_neighbor_chunks_.size() - 1UL; if (chunk_index == -1UL || (upper_neighbor_chunks_[chunk_index]->padding_size() < neighbors_size)) { // no space left and need to alloc chunk_index++; if (ailego_unlikely(upper_neighbor_chunks_.capacity() == upper_neighbor_chunks_.size())) { LOG_ERROR("add upper neighbor failed for no memory quota"); return IndexError_IndexFull; } auto p = broker_->alloc_chunk(HnswRabitqChunkBroker::CHUNK_TYPE_UPPER_NEIGHBOR, chunk_index, upper_neighbor_chunk_size_); if (ailego_unlikely(p.first != 0)) { LOG_ERROR("Alloc data chunk failed"); return p.first; } chunk = p.second; chunk_offset = 0UL; upper_neighbor_chunks_.emplace_back(chunk); } else { chunk = upper_neighbor_chunks_[chunk_index]; chunk_offset = chunk->data_size(); } ailego_assert_with((size_t)level < kMaxGraphLayers, "invalid 
level"); ailego_assert_with(chunk_offset % upper_neighbor_size_ == 0, "invalid offset"); ailego_assert_with((chunk_offset / upper_neighbor_size_) < (1U << upper_neighbor_mask_bits_), "invalid offset"); ailego_assert_with(chunk_index < (1U << (28 - upper_neighbor_mask_bits_)), "invalid chunk index"); UpperNeighborIndexMeta meta; meta.level = level; meta.index = (chunk_index << upper_neighbor_mask_bits_) | (chunk_offset / upper_neighbor_size_); chunk_offset += upper_neighbor_size_ * level; if (ailego_unlikely(!upper_neighbor_index_->insert(id, meta.data))) { LOG_ERROR("HashMap insert value failed"); return IndexError_Runtime; } if (ailego_unlikely(chunk->resize(chunk_offset) != chunk_offset)) { LOG_ERROR("Chunk resize to %zu failed", (size_t)chunk_offset); return IndexError_Runtime; } return 0; } size_t estimate_doc_capacity() const { return node_chunks_.capacity() * node_cnt_per_chunk_; } int init_chunk_params(size_t max_index_size, bool huge_page) { node_cnt_per_chunk_ = std::max(1, chunk_size_ / node_size()); //! align node cnt per chunk to pow of 2 node_index_mask_bits_ = std::ceil(std::log2(node_cnt_per_chunk_)); node_cnt_per_chunk_ = 1UL << node_index_mask_bits_; if (huge_page) { chunk_size_ = AlignHugePageSize(node_cnt_per_chunk_ * node_size()); } else { chunk_size_ = AlignPageSize(node_cnt_per_chunk_ * node_size()); } node_index_mask_ = node_cnt_per_chunk_ - 1; if (max_index_size == 0UL) { max_index_size_ = chunk_size_ * kDefaultMaxChunkCnt; } else { max_index_size_ = max_index_size; } //! To get a balanced upper neighbor chunk size. //! If the upper chunk size is equal to node chunk size, it may waste //! upper neighbor chunk space; if the upper neighbor chunk size is too //! small, the will need large upper neighbor chunks index space. So to //! 
get a balanced ratio be sqrt of the node/neighbor size ratio float ratio = std::sqrt(node_size() * scaling_factor() * 1.0f / upper_neighbor_size_); if (huge_page) { upper_neighbor_chunk_size_ = AlignHugePageSize( std::max(get_total_upper_neighbors_size(kMaxGraphLayers), static_cast(chunk_size_ / ratio))); } else { upper_neighbor_chunk_size_ = AlignPageSize( std::max(get_total_upper_neighbors_size(kMaxGraphLayers), static_cast(chunk_size_ / ratio))); } upper_neighbor_mask_bits_ = std::ceil(std::log2(upper_neighbor_chunk_size_ / upper_neighbor_size_)); upper_neighbor_mask_ = (1 << upper_neighbor_mask_bits_) - 1; size_t max_node_chunk_cnt = std::ceil(max_index_size_ / chunk_size_); size_t max_upper_chunk_cnt = std::ceil( (max_node_chunk_cnt * node_cnt_per_chunk_ * 1.0f / scaling_factor()) / (upper_neighbor_chunk_size_ / upper_neighbor_size_)); max_upper_chunk_cnt = max_upper_chunk_cnt + std::ceil(max_upper_chunk_cnt / scaling_factor()); //! reserve space to avoid memmove in chunks vector emplace chunk, so //! as to lock-free in reading chunk node_chunks_.reserve(max_node_chunk_cnt); upper_neighbor_chunks_.reserve(max_upper_chunk_cnt); LOG_DEBUG( "Settings: nodeSize=%zu chunkSize=%u upperNeighborSize=%u " "upperNeighborChunkSize=%u " "nodeCntPerChunk=%u maxChunkCnt=%zu maxNeighborChunkCnt=%zu " "maxIndexSize=%zu ratio=%.3f", node_size(), chunk_size_, upper_neighbor_size_, upper_neighbor_chunk_size_, node_cnt_per_chunk_, max_node_chunk_cnt, max_upper_chunk_cnt, max_index_size_, ratio); return 0; } //! 
Init node chunk and neighbor chunks int init_chunks(const Chunk::Pointer &header_chunk); int flush_header(void) { if (!broker_->dirty()) { // do not need to flush return 0; } auto header_chunk = broker_->get_chunk(HnswRabitqChunkBroker::CHUNK_TYPE_HEADER, HnswRabitqChunkBroker::kDefaultChunkSeqId); if (ailego_unlikely(!header_chunk)) { LOG_ERROR("get header chunk failed"); return IndexError_Runtime; } size_t size = header_chunk->write(0UL, &header(), header_size()); if (ailego_unlikely(size != header_size())) { LOG_ERROR("Write header chunk failed"); return IndexError_WriteData; } return 0; } private: HnswRabitqStreamerEntity(const HnswRabitqStreamerEntity &) = delete; HnswRabitqStreamerEntity &operator=(const HnswRabitqStreamerEntity &) = delete; static constexpr uint64_t kUpperHashMemoryInflateRatio = 2.0f; private: IndexStreamer::Stats &stats_; HNSWHeader header_{}; std::mutex mutex_{}; size_t max_index_size_{0UL}; uint32_t chunk_size_{kDefaultChunkSize}; uint32_t upper_neighbor_chunk_size_{kDefaultChunkSize}; uint32_t node_index_mask_bits_{0U}; uint32_t node_cnt_per_chunk_{0U}; uint32_t node_index_mask_{0U}; uint32_t neighbor_size_{0U}; uint32_t upper_neighbor_size_{0U}; //! UpperNeighborIndex.index composite chunkIdx and offset in chunk by the //! following mask uint32_t upper_neighbor_mask_bits_{0U}; uint32_t upper_neighbor_mask_{0U}; bool filter_same_key_{false}; bool get_vector_enabled_{false}; bool use_key_info_map_{true}; NIHashMapPointer upper_neighbor_index_{}; mutable std::shared_ptr keys_map_lock_{}; HashMapPointer keys_map_{}; //! the chunks will be changed in searcher, so need mutable //! data chunk include: vector, key, level 0 neighbors mutable std::vector node_chunks_{}; //! 
upper neighbor chunk inlude: UpperNeighborHeader + (1~level) neighbors mutable std::vector upper_neighbor_chunks_{}; HnswRabitqChunkBroker::Pointer broker_{}; // chunk broker }; } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/hnsw_rabitq/rabitq_converter.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "rabitq_converter.h" #include #include #include #include #include #include #include "ailego/pattern/defer.h" #include "algorithm/hnsw_rabitq/rabitq_reformer.h" #include "zvec/core/framework/index_cluster.h" #include "zvec/core/framework/index_error.h" #include "zvec/core/framework/index_factory.h" #include "zvec/core/framework/index_features.h" #include "zvec/core/framework/index_holder.h" #include "zvec/core/framework/index_memory.h" #include "zvec/core/framework/index_meta.h" #include "rabitq_params.h" #include "rabitq_utils.h" #ifdef _MSC_VER #define strncasecmp _strnicmp #endif namespace zvec { namespace core { RabitqConverter::~RabitqConverter() { this->cleanup(); } int RabitqConverter::init(const IndexMeta &meta, const ailego::Params ¶ms) { // Copy meta and ensure it has metric information meta_ = meta; dimension_ = meta.dimension(); if (meta_.metric_name().empty()) { LOG_ERROR("Meta metric is empty"); return IndexError_InvalidArgument; } // Round up dimension to multiple of 64 padded_dim_ = 
((dimension_ + 63) / 64) * 64; // Get RaBitQ parameters with defaults uint32_t total_bits = 0; params.get(PARAM_RABITQ_TOTAL_BITS, &total_bits); if (total_bits == 0) { total_bits = kDefaultRabitqTotalBits; } if (total_bits < 1 || total_bits > 9) { LOG_ERROR("Invalid total_bits: %zu, must be in [1, 9]", (size_t)total_bits); return IndexError_InvalidArgument; } ex_bits_ = total_bits - 1; params.get(PARAM_RABITQ_NUM_CLUSTERS, &num_clusters_); if (num_clusters_ == 0) { num_clusters_ = kDefaultNumClusters; } if (ex_bits_ > 8) { LOG_ERROR("Invalid ex_bits: %zu, must be <= 8", ex_bits_); return IndexError_InvalidArgument; } if (meta.data_type() != IndexMeta::DataType::DT_FP32) { LOG_ERROR("RaBitQ only supports FP32 data type"); return IndexError_Unsupported; } params.get(PARAM_RABITQ_SAMPLE_COUNT, &sample_count_); std::string rotator_type_str; params.get(PARAM_RABITQ_ROTATOR_TYPE, &rotator_type_str); if (rotator_type_str.empty()) { rotator_type_ = rabitqlib::RotatorType::FhtKacRotator; } else if (strncasecmp(rotator_type_str.c_str(), "fht", 3) == 0) { rotator_type_ = rabitqlib::RotatorType::FhtKacRotator; } else if (strncasecmp(rotator_type_str.c_str(), "matrix", 6) == 0) { rotator_type_ = rabitqlib::RotatorType::MatrixRotator; } else { LOG_ERROR("Invalid rotator_type: %s", rotator_type_str.c_str()); return IndexError_InvalidArgument; } // Create rotator rotator_.reset( rabitqlib::choose_rotator(dimension_, rotator_type_, padded_dim_)); LOG_INFO( "RabitqConverter initialized: dim=%zu, padded_dim=%zu, " "num_clusters=%zu, ex_bits=%zu, rotator_type=%d[%s] sample_count[%zu]", dimension_, padded_dim_, num_clusters_, ex_bits_, (int)rotator_type_, rotator_type_str.c_str(), sample_count_); return 0; } int RabitqConverter::cleanup() { centroids_.clear(); rotated_centroids_.clear(); result_holder_.reset(); rotator_.reset(); return 0; } int RabitqConverter::train(IndexHolder::Pointer holder) { if (!holder) { LOG_ERROR("Null holder for training"); return IndexError_InvalidArgument; 
} ailego::ElapsedTime timer; size_t vector_count = holder->count(); if (vector_count == 0) { LOG_ERROR("No vectors for training"); return IndexError_InvalidArgument; } // do sampling from all data size_t sample_count = vector_count; if (sample_count_ > 0) { sample_count = std::min(sample_count_, vector_count); } LOG_INFO("Training with %zu vectors from %zu of holder", sample_count, vector_count); auto sampler = std::make_shared>( meta_, sample_count); auto iter = holder->create_iterator(); if (!iter) { LOG_ERROR("Create iterator error"); return IndexError_Runtime; } for (; iter->is_valid(); iter->next()) { sampler->emplace(iter->data()); } // Holder is not needed, cleanup it. holder.reset(); if (sampler->count() == 0) { LOG_ERROR("Load training data error"); return IndexError_InvalidLength; } // Create KmeansCluster for training centroids auto cluster = IndexFactory::CreateCluster("OptKmeansCluster"); if (!cluster) { LOG_ERROR("Failed to create OptKmeansCluster"); return IndexError_NoExist; } // Initialize cluster LOG_INFO( "Initializing KmeansCluster with meta: dim=%u, data_type=%d, metric=%s", meta_.dimension(), (int)meta_.data_type(), meta_.metric_name().c_str()); ailego::Params cluster_params; int ret = cluster->init(meta_, cluster_params); if (ret != 0) { LOG_ERROR("Failed to initialize KmeansCluster: %d", ret); return ret; } ret = cluster->mount(sampler); if (ret != 0) { LOG_ERROR("Failed to mount training data: %d", ret); return ret; } cluster->suggest(num_clusters_); // Perform clustering IndexCluster::CentroidList cents; // TODO: support specify threads with argument auto threads = std::make_shared(0, false); ret = cluster->cluster(threads, cents); if (ret != 0) { LOG_ERROR("Failed to perform clustering: %d", ret); return ret; } if (cents.size() != num_clusters_) { LOG_WARN("Expected %zu clusters, got %zu", num_clusters_, cents.size()); num_clusters_ = cents.size(); } // Extract original centroids (for LinearSeeker query) centroids_.resize(num_clusters_ * 
dimension_); // Extract rotated centroids (for quantization) rotated_centroids_.resize(num_clusters_ * padded_dim_); for (uint32_t i = 0; i < num_clusters_; ++i) { const float *cent_data = static_cast(cents[i].feature()); // Save original centroids std::memcpy(¢roids_[i * dimension_], cent_data, dimension_ * sizeof(float)); // Save rotated centroids this->rotator_->rotate(cent_data, &rotated_centroids_[i * padded_dim_]); } stats_.set_trained_count(sampler->count()); stats_.set_trained_costtime(timer.milli_seconds()); LOG_INFO("Training completed: %zu centroids, cost %zu ms", num_clusters_, static_cast(timer.milli_seconds())); return 0; } int RabitqConverter::transform(IndexHolder::Pointer holder) { if (!holder) { LOG_ERROR("Null holder for transformation"); return IndexError_InvalidArgument; } if (rotated_centroids_.empty()) { LOG_ERROR("Centroids not trained yet"); return IndexError_NoReady; } LOG_ERROR("Not implemented"); return IndexError_NotImplemented; } int RabitqConverter::dump(const IndexDumper::Pointer &dumper) { if (!dumper) { LOG_ERROR("Null dumper"); return IndexError_InvalidArgument; } if (rotated_centroids_.empty() || centroids_.empty()) { LOG_ERROR("No centroids to dump"); return IndexError_NoReady; } ailego::ElapsedTime timer; size_t dumped_size = 0; int ret = dump_rabitq_centroids( dumper, dimension_, padded_dim_, ex_bits_, num_clusters_, rotator_type_, rotated_centroids_, centroids_, rotator_, &dumped_size); if (ret != 0) { return ret; } stats_.set_dumped_size(dumped_size); stats_.set_dumped_costtime(timer.milli_seconds()); LOG_INFO("Dump completed: %zu bytes, cost %zu ms", stats_.dumped_size(), static_cast(timer.milli_seconds())); return 0; } int RabitqConverter::to_reformer(IndexReformer::Pointer *reformer) { auto memory_dumper = IndexFactory::CreateDumper("MemoryDumper"); memory_dumper->init(ailego::Params()); std::string file_id = ailego::StringHelper::Concat( "rabitq_converter_", ailego::Monotime::MilliSeconds(), rand()); int ret = 
memory_dumper->create(file_id); if (ret != 0) { LOG_ERROR("Failed to create memory dumper: %d", ret); return ret; } // Release memory AILEGO_DEFER([&file_id]() { IndexMemory::Instance()->remove(file_id); }); ret = this->dump(memory_dumper); if (ret != 0) { LOG_ERROR("Failed to dump RabitqConverter: %d", ret); return ret; } ret = memory_dumper->close(); if (ret != 0) { LOG_ERROR("Failed to close memory dumper: %d", ret); return ret; } auto res = std::make_shared(); ailego::Params reformer_params; reformer_params.set(PARAM_RABITQ_METRIC_NAME, meta_.metric_name()); ret = res->init(reformer_params); if (ret != 0) { LOG_ERROR("Failed to initialize RabitqReformer: %d", ret); return ret; } auto memory_storage = IndexFactory::CreateStorage("MemoryReadStorage"); ret = memory_storage->open(file_id, false); if (ret != 0) { LOG_ERROR("Failed to open memory storage: %d", ret); return ret; } ret = res->load(memory_storage); if (ret != 0) { LOG_ERROR("Failed to load RabitqReformer: %d", ret); return ret; } *reformer = std::move(res); return 0; } } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/hnsw_rabitq/rabitq_converter.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#pragma once #include #include #include #include "zvec/core/framework/index_cluster.h" #include "zvec/core/framework/index_converter.h" #include "zvec/core/framework/index_reformer.h" #include "zvec/core/framework/index_threads.h" #include "rabitq_params.h" namespace zvec { namespace core { class RabitqReformer; /*! RaBitQ Converter * Trains KMeans centroids and quantizes vectors using RaBitQ */ class RabitqConverter : public IndexConverter { public: //! Constructor RabitqConverter() = default; //! Destructor ~RabitqConverter() override; //! Initialize Converter int init(const IndexMeta &meta, const ailego::Params ¶ms) override; //! Cleanup Converter int cleanup(void) override; //! Train the data - perform KMeans clustering int train(IndexHolder::Pointer holder) override; //! Transform the data - quantize vectors using RaBitQ int transform(IndexHolder::Pointer holder) override; //! Dump centroids and config into storage int dump(const IndexDumper::Pointer &dumper) override; //! Retrieve statistics const Stats &stats(void) const override { return stats_; } //! Retrieve a holder as result IndexHolder::Pointer result(void) const override { return result_holder_; } //! 
Retrieve Index Meta const IndexMeta &meta(void) const override { return meta_; } int to_reformer(IndexReformer::Pointer *reformer) override; private: static inline size_t AlignSize(size_t size) { return (size + 0x1F) & (~0x1F); } private: IndexMeta meta_; IndexHolder::Pointer result_holder_; Stats stats_; size_t sample_count_{0}; // RaBitQ parameters size_t num_clusters_{0}; size_t ex_bits_{0}; size_t dimension_{0}; size_t padded_dim_{0}; // Original centroids: num_clusters * dimension (for LinearSeeker query) std::vector centroids_; // Rotated centroids: num_clusters * padded_dim (for quantization) std::vector rotated_centroids_; // Rotator for vector transformation rabitqlib::RotatorType rotator_type_{rabitqlib::RotatorType::FhtKacRotator}; std::unique_ptr> rotator_; }; } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/hnsw_rabitq/rabitq_params.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include namespace zvec { namespace core { // Local metric type enum that mirrors rabitqlib::MetricType, // without exposing rabitqlib headers to consumers of this file. 
enum class RabitqMetricType { kL2 = 0, kIP = 1, }; // RaBitQ Converter parameters static const std::string PARAM_RABITQ_NUM_CLUSTERS( "proxima.rabitq.num_clusters"); static const std::string PARAM_RABITQ_TOTAL_BITS("proxima.rabitq.total_bits"); static const std::string PARAM_RABITQ_METRIC_NAME("proxima.rabitq.metric_name"); static const std::string PARAM_RABITQ_ROTATOR_TYPE( "proxima.rabitq.rotator.type"); static const std::string PARAM_RABITQ_SAMPLE_COUNT( "proxima.rabitq.sample_count"); // Default values constexpr size_t kDefaultNumClusters = 16; // 4-bit, 5-bit, and 7-bit quantization typically achieve 90%, 95%, and 99% // recall, respectively—without accessing raw vectors for reranking constexpr size_t kDefaultRabitqTotalBits = 7; constexpr int kMinRabitqDimSize = 64; constexpr int kMaxRabitqDimSize = 4095; } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/hnsw_rabitq/rabitq_reformer.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "rabitq_reformer.h"
#include <cmath>
#include <cstdint>
#include <cstring>
#include <memory>
#include <string>
#include <vector>
#include <rabitqlib/index/query.hpp>
#include <rabitqlib/quantization/data_layout.hpp>
#include <rabitqlib/quantization/rabitq.hpp>
#include <rabitqlib/utils/rotator.hpp>
#include <rabitqlib/utils/space.hpp>
#include "core/algorithm/cluster/linear_seeker.h"
#include "zvec/core/framework/index_error.h"
#include "zvec/core/framework/index_factory.h"
#include "zvec/core/framework/index_features.h"
#include "zvec/core/framework/index_meta.h"
#include "zvec/core/framework/index_storage.h"
#include "hnsw_rabitq_query_entity.h"
#include "rabitq_converter.h"
#include "rabitq_utils.h"

namespace zvec {
namespace core {

// All rabitqlib types are confined to this translation unit via pimpl.
struct RabitqReformer::Impl {
  // RaBitQ parameters (populated from RabitqConverterHeader in load()).
  size_t num_clusters{0};
  size_t ex_bits{0};
  size_t dimension{0};
  size_t padded_dim{0};
  // Byte sizes of the binary / extended sections of a quantized record.
  size_t size_bin_data{0};
  size_t size_ex_data{0};
  // True only after load() has completed successfully.
  bool loaded{false};

  // Original centroids: num_clusters * dimension (for LinearSeeker query)
  std::vector<float> centroids;
  // Rotated centroids: num_clusters * padded_dim (for quantization)
  std::vector<float> rotated_centroids;

  rabitqlib::RotatorType rotator_type{rabitqlib::RotatorType::FhtKacRotator};
  std::unique_ptr<rabitqlib::Rotator<float>> rotator;
  // Quantization config for queries (kNumBits) and records (ex_bits + 1).
  rabitqlib::quant::RabitqConfig query_config;
  rabitqlib::quant::RabitqConfig config;
  rabitqlib::MetricType metric_type{rabitqlib::METRIC_L2};

  LinearSeeker::Pointer centroid_seeker;
  CoherentIndexFeatures::Pointer centroid_features;

  // Translate local enum to rabitqlib enum (used only inside this .cc).
  static rabitqlib::MetricType to_rabitq(RabitqMetricType m) {
    return m == RabitqMetricType::kIP ? rabitqlib::METRIC_IP
                                      : rabitqlib::METRIC_L2;
  }

  // Translate rabitqlib enum to local enum.
  static RabitqMetricType from_rabitq(rabitqlib::MetricType m) {
    return m == rabitqlib::METRIC_IP ? RabitqMetricType::kIP
                                     : RabitqMetricType::kL2;
  }

  //! Rotate `raw_vector` and quantize it against centroid `cluster_id`.
  //! Output layout: cluster_id (uint32) + bin_data + ex_data.
  int quantize_vector(const float *raw_vector, uint32_t cluster_id,
                      std::string *quantized_data) const;
};

RabitqReformer::RabitqReformer() : impl_(std::make_unique<Impl>()) {}

RabitqReformer::~RabitqReformer() {
  this->cleanup();
}

size_t RabitqReformer::num_clusters() const {
  return impl_->num_clusters;
}

RabitqMetricType RabitqReformer::rabitq_metric_type() const {
  return Impl::from_rabitq(impl_->metric_type);
}

int RabitqReformer::init(const ailego::Params &params) {
  const std::string metric_name =
      params.get_as_string(PARAM_RABITQ_METRIC_NAME);
  if (metric_name == "SquaredEuclidean") {
    impl_->metric_type = rabitqlib::METRIC_L2;
  } else if (metric_name == "InnerProduct" || metric_name == "Cosine") {
    // Cosine inputs are expected to be normalized, so they are scored as IP.
    impl_->metric_type = rabitqlib::METRIC_IP;
  } else {
    LOG_ERROR("Unsupported metric name: %s", metric_name.c_str());
    return IndexError_InvalidArgument;
  }

  LOG_DEBUG("Rabitq reformer init done. metric_name=%s metric_type=%d",
            metric_name.c_str(), static_cast<int>(impl_->metric_type));
  return 0;
}

int RabitqReformer::cleanup() {
  impl_->centroids.clear();
  impl_->rotated_centroids.clear();
  impl_->centroid_seeker.reset();
  impl_->centroid_features.reset();
  impl_->loaded = false;
  impl_->rotator.reset();
  return 0;
}

int RabitqReformer::unload() {
  return this->cleanup();
}

int RabitqReformer::load(IndexStorage::Pointer storage) {
  if (!storage) {
    LOG_ERROR("Invalid storage for load");
    return IndexError_InvalidArgument;
  }

  auto segment = storage->get(RABITQ_CONVERTER_SEG_ID);
  if (!segment) {
    LOG_ERROR("Failed to get segment %s", RABITQ_CONVERTER_SEG_ID.c_str());
    return IndexError_InvalidFormat;
  }

  size_t offset = 0;
  RabitqConverterHeader header;
  IndexStorage::MemoryBlock block;
  size_t size = segment->read(offset, block, sizeof(header));
  if (size != sizeof(header)) {
    LOG_ERROR("Failed to read header");
    return IndexError_InvalidFormat;
  }
  memcpy(&header, block.data(), sizeof(header));
  // Reject obviously corrupt headers before sizing buffers from them.
  if (header.num_clusters == 0 || header.dim == 0 ||
      header.padded_dim < header.dim) {
    LOG_ERROR("Invalid header: num_clusters=%u, dim=%u, padded_dim=%u",
              header.num_clusters, header.dim, header.padded_dim);
    return IndexError_InvalidFormat;
  }
  offset += sizeof(header);

  // impl_ is mutated from here on; a failure below must not leave the
  // reformer claiming to be ready with a half-replaced state.
  impl_->loaded = false;
  impl_->dimension = header.dim;
  impl_->padded_dim = header.padded_dim;
  impl_->ex_bits = header.ex_bits;
  impl_->num_clusters = header.num_clusters;
  impl_->rotator_type =
      static_cast<rabitqlib::RotatorType>(header.rotator_type);

  // Read rotated centroids (sizes computed in size_t to avoid uint32 overflow)
  const size_t rotated_centroids_size =
      sizeof(float) * header.num_clusters * header.padded_dim;
  size = segment->read(offset, block, rotated_centroids_size);
  if (size != rotated_centroids_size) {
    LOG_ERROR("Failed to read rotated centroids");
    return IndexError_InvalidFormat;
  }
  impl_->rotated_centroids.resize(rotated_centroids_size / sizeof(float));
  memcpy(impl_->rotated_centroids.data(), block.data(),
         rotated_centroids_size);
  offset += size;

  // Read original centroids (for LinearSeeker query)
  const size_t centroids_size = sizeof(float) * header.num_clusters * header.dim;
  size = segment->read(offset, block, centroids_size);
  if (size != centroids_size) {
    LOG_ERROR("Failed to read centroids");
    return IndexError_InvalidFormat;
  }
  impl_->centroids.resize(centroids_size / sizeof(float));
  memcpy(impl_->centroids.data(), block.data(), centroids_size);
  offset += size;

  // Read rotator
  const size_t rotator_size = header.rotator_size;
  size = segment->read(offset, block, rotator_size);
  if (size != rotator_size) {
    LOG_ERROR("Failed to read rotator");
    return IndexError_InvalidFormat;
  }
  impl_->rotator.reset(rabitqlib::choose_rotator<float>(
      impl_->dimension, impl_->rotator_type, impl_->padded_dim));
  if (!impl_->rotator) {
    LOG_ERROR("Failed to create rotator, type=%d",
              static_cast<int>(impl_->rotator_type));
    return IndexError_InvalidFormat;
  }
  impl_->rotator->load(reinterpret_cast<const char *>(block.data()));

  impl_->query_config = rabitqlib::quant::faster_config(
      impl_->padded_dim, rabitqlib::SplitSingleQuery<float>::kNumBits);
  impl_->config =
      rabitqlib::quant::faster_config(impl_->padded_dim, impl_->ex_bits + 1);
  impl_->size_bin_data =
      rabitqlib::BinDataMap<float>::data_bytes(impl_->padded_dim);
  impl_->size_ex_data = rabitqlib::ExDataMap<float>::data_bytes(
      impl_->padded_dim, impl_->ex_bits);

  // Initialize LinearSeeker for centroid search
  IndexMeta centroid_meta;
  centroid_meta.set_data_type(IndexMeta::DataType::DT_FP32);
  centroid_meta.set_dimension(static_cast<uint32_t>(impl_->dimension));
  // Note:
  // 1. spherical kmeans is used for InnerProduct and Cosine, so centroids are
  // normalized.
  // 2. for Cosine metric, `transform_to_entity` input is normalized, need to
  // use InnerProduct metric as Cosine metric requires extra dimension which is
  // unsuitable for centroids.
  centroid_meta.set_metric(impl_->metric_type == rabitqlib::METRIC_L2
                               ? "SquaredEuclidean"
                               : "InnerProduct",
                           0, ailego::Params());
  impl_->centroid_features = std::make_shared<CoherentIndexFeatures>();
  impl_->centroid_features->mount(centroid_meta, impl_->centroids.data(),
                                  impl_->centroids.size() * sizeof(float));
  impl_->centroid_seeker = std::make_shared<LinearSeeker>();
  int ret = impl_->centroid_seeker->init(centroid_meta);
  if (ret != 0) {
    LOG_ERROR("Failed to init centroid seeker. ret[%d]", ret);
    return ret;
  }
  ret = impl_->centroid_seeker->mount(impl_->centroid_features);
  if (ret != 0) {
    LOG_ERROR("Failed to mount centroid features. ret[%d]", ret);
    return ret;
  }

  LOG_INFO(
      "Rabitq reformer load done. dimension=%zu, padded_dim=%zu, "
      "ex_bits=%zu, num_clusters=%zu, size_bin_data=%zu, size_ex_data=%zu "
      "rotator_type=%d",
      impl_->dimension, impl_->padded_dim, impl_->ex_bits,
      impl_->num_clusters, impl_->size_bin_data, impl_->size_ex_data,
      static_cast<int>(impl_->rotator_type));
  impl_->loaded = true;
  return 0;
}

int RabitqReformer::convert(const void *record, const IndexQueryMeta &rmeta,
                            std::string *out, IndexQueryMeta *ometa) const {
  if (!impl_->loaded) {
    LOG_ERROR("Centroids not loaded yet");
    return IndexError_NoReady;
  }
  // `ometa` is written on success, so it must be non-null as well.
  if (!record || !out || !ometa) {
    LOG_ERROR("Invalid arguments for convert");
    return IndexError_InvalidArgument;
  }
  // input may be transformed, require rmeta.dimension >= dimension
  if (rmeta.dimension() < impl_->dimension ||
      rmeta.data_type() != IndexMeta::DataType::DT_FP32) {
    LOG_ERROR("Invalid record meta: dimension=%zu, data_type=%d",
              static_cast<size_t>(rmeta.dimension()),
              static_cast<int>(rmeta.data_type()));
    return IndexError_InvalidArgument;
  }

  // Find nearest centroid using LinearSeeker
  Seeker::Document doc;
  int ret = impl_->centroid_seeker->seek(
      record, impl_->dimension * sizeof(float), &doc);
  if (ret != 0) {
    LOG_ERROR("Failed to seek centroid. ret[%d]", ret);
    return ret;
  }
  const uint32_t cluster_id = doc.index;

  const float *vector = static_cast<const float *>(record);
  ret = impl_->quantize_vector(vector, cluster_id, out);
  if (ret != 0) {
    LOG_ERROR("Failed to quantize vector");
    return ret;
  }
  ometa->set_meta(IndexMeta::DataType::DT_INT8,
                  static_cast<uint32_t>(out->size()));
  return 0;
}

int RabitqReformer::transform(const void *, const IndexQueryMeta &,
                              std::string *, IndexQueryMeta *) const {
  return IndexError_NotImplemented;
}

int RabitqReformer::transform_to_entity(const void *query,
                                        HnswRabitqQueryEntity *entity) const {
  if (!impl_->loaded) {
    LOG_ERROR("Centroids not loaded yet");
    return IndexError_NoReady;
  }
  if (!query || !entity) {
    LOG_ERROR("Invalid arguments for transform");
    return IndexError_InvalidArgument;
  }

  const float *query_vector = static_cast<const float *>(query);

  // Apply rotator
  entity->rotated_query.resize(impl_->padded_dim);
  impl_->rotator->rotate(query_vector, entity->rotated_query.data());

  // Quantize query to 4-bit representation
  entity->query_wrapper = std::make_unique<rabitqlib::SplitSingleQuery<float>>(
      entity->rotated_query.data(), impl_->padded_dim, impl_->ex_bits,
      impl_->query_config, impl_->metric_type);

  // Preprocess - get the distance from query to all centroids
  const float *rotated_query = entity->rotated_query.data();
  entity->q_to_centroids.resize(impl_->num_clusters);
  if (impl_->metric_type == rabitqlib::METRIC_L2) {
    for (size_t i = 0; i < impl_->num_clusters; i++) {
      const float *centroid =
          impl_->rotated_centroids.data() + (i * impl_->padded_dim);
      entity->q_to_centroids[i] = std::sqrt(
          rabitqlib::euclidean_sqr(rotated_query, centroid, impl_->padded_dim));
    }
  } else if (impl_->metric_type == rabitqlib::METRIC_IP) {
    // first half as g_add, second half as g_error
    entity->q_to_centroids.resize(impl_->num_clusters * 2);
    for (size_t i = 0; i < impl_->num_clusters; i++) {
      const float *centroid =
          impl_->rotated_centroids.data() + (i * impl_->padded_dim);
      entity->q_to_centroids[i] =
          rabitqlib::dot_product(rotated_query, centroid, impl_->padded_dim);
      entity->q_to_centroids[i + impl_->num_clusters] = std::sqrt(
          rabitqlib::euclidean_sqr(rotated_query, centroid, impl_->padded_dim));
    }
  }

  return 0;
}

int RabitqReformer::Impl::quantize_vector(const float *raw_vector,
                                          uint32_t cluster_id,
                                          std::string *quantized_data) const {
  // Guard the centroid lookup below against out-of-range ids.
  if (cluster_id >= num_clusters) {
    LOG_ERROR("Invalid cluster id %u, num_clusters=%zu", cluster_id,
              num_clusters);
    return IndexError_InvalidArgument;
  }

  std::vector<float> rotated_data(padded_dim);
  rotator->rotate(raw_vector, rotated_data.data());

  // quantized format: cluster_id + bin_data + ex_data
  quantized_data->resize(sizeof(cluster_id) + size_bin_data + size_ex_data);
  char *dst = &(*quantized_data)[0];
  memcpy(dst, &cluster_id, sizeof(cluster_id));
  const size_t bin_data_offset = sizeof(cluster_id);
  const size_t ex_data_offset = bin_data_offset + size_bin_data;

  rabitqlib::quant::quantize_split_single(
      rotated_data.data(),
      rotated_centroids.data() + (static_cast<size_t>(cluster_id) * padded_dim),
      padded_dim, ex_bits, dst + bin_data_offset, dst + ex_data_offset,
      metric_type, config);
  return 0;
}

int RabitqReformer::dump(const IndexDumper::Pointer &dumper) {
  if (!dumper) {
    LOG_ERROR("Null dumper");
    return IndexError_InvalidArgument;
  }

  if (!impl_->loaded || impl_->rotated_centroids.empty() ||
      impl_->centroids.empty()) {
    LOG_ERROR("No centroids to dump");
    return IndexError_NoReady;
  }

  size_t dumped_size = 0;
  int ret = dump_rabitq_centroids(
      dumper, impl_->dimension, impl_->padded_dim, impl_->ex_bits,
      impl_->num_clusters, impl_->rotator_type, impl_->rotated_centroids,
      impl_->centroids, impl_->rotator, &dumped_size);
  if (ret != 0) {
    return ret;
  }

  LOG_INFO("RabitqReformer dump completed: %zu bytes", dumped_size);
  return 0;
}

int RabitqReformer::dump(const IndexStorage::Pointer &storage) {
  if (!storage) {
    LOG_ERROR("Null storage");
    return IndexError_InvalidArgument;
  }

  if (!impl_->loaded || impl_->rotated_centroids.empty() ||
      impl_->centroids.empty()) {
    LOG_ERROR("No centroids to dump");
    return IndexError_NoReady;
  }

  // Segment size is rounded up to 32 bytes, same as the dumper path.
  auto align_size = [](size_t size) -> size_t {
    return (size + 0x1F) & (~0x1F);
  };

  const size_t header_size = sizeof(RabitqConverterHeader);
  const size_t rotated_centroids_size =
      impl_->rotated_centroids.size() * sizeof(float);
  const size_t centroids_size = impl_->centroids.size() * sizeof(float);
  const size_t rotator_size = impl_->rotator->dump_bytes();
  const size_t data_size =
      header_size + rotated_centroids_size + centroids_size + rotator_size;
  const size_t total_size = align_size(data_size);

  int ret = storage->append(RABITQ_CONVERTER_SEG_ID, total_size);
  if (ret != 0) {
    LOG_ERROR("Failed to append segment %s, ret=%d",
              RABITQ_CONVERTER_SEG_ID.c_str(), ret);
    return ret;
  }

  auto segment = storage->get(RABITQ_CONVERTER_SEG_ID);
  if (!segment) {
    LOG_ERROR("Failed to get segment %s", RABITQ_CONVERTER_SEG_ID.c_str());
    return IndexError_ReadData;
  }

  size_t offset = 0;

  RabitqConverterHeader header;
  header.dim = static_cast<uint32_t>(impl_->dimension);
  header.padded_dim = static_cast<uint32_t>(impl_->padded_dim);
  header.num_clusters = static_cast<uint32_t>(impl_->num_clusters);
  header.ex_bits = static_cast<uint8_t>(impl_->ex_bits);
  header.rotator_type = static_cast<uint8_t>(impl_->rotator_type);
  header.rotator_size = static_cast<uint32_t>(rotator_size);

  size_t written = segment->write(offset, &header, header_size);
  if (written != header_size) {
    LOG_ERROR("Failed to write header: written=%zu, expected=%zu", written,
              header_size);
    return IndexError_WriteData;
  }
  offset += header_size;

  written = segment->write(offset, impl_->rotated_centroids.data(),
                           rotated_centroids_size);
  if (written != rotated_centroids_size) {
    LOG_ERROR("Failed to write rotated centroids: written=%zu, expected=%zu",
              written, rotated_centroids_size);
    return IndexError_WriteData;
  }
  offset += rotated_centroids_size;

  written = segment->write(offset, impl_->centroids.data(), centroids_size);
  if (written != centroids_size) {
    LOG_ERROR("Failed to write centroids: written=%zu, expected=%zu", written,
              centroids_size);
    return IndexError_WriteData;
  }
  offset += centroids_size;

  std::vector<char> buffer(rotator_size);
  impl_->rotator->save(buffer.data());
  written = segment->write(offset, buffer.data(), rotator_size);
  if (written != rotator_size) {
    LOG_ERROR("Failed to write rotator data: written=%zu, expected=%zu",
              written, rotator_size);
    return IndexError_WriteData;
  }

  LOG_INFO("RabitqReformer dump to storage completed: %zu bytes", data_size);
  return 0;
}

}  // namespace core
}  // namespace zvec
================================================
FILE: src/core/algorithm/hnsw_rabitq/rabitq_reformer.h
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//

#pragma once

#include <memory>
#include "zvec/core/framework/index_dumper.h"
#include "zvec/core/framework/index_reformer.h"
#include "zvec/core/framework/index_storage.h"
#include "rabitq_params.h"

namespace zvec {
namespace core {

class HnswRabitqQueryEntity;

/*! RaBitQ Reformer
 * Loads centroids and performs query transformation and vector quantization.
 *
 * All rabitqlib types are hidden behind a pimpl to avoid leaking rabitqlib
 * headers to consumers of this class.
 */
class RabitqReformer : public IndexReformer {
 public:
  using Pointer = std::shared_ptr<RabitqReformer>;

  RabitqReformer();
  ~RabitqReformer() override;

  // Non-copyable
  RabitqReformer(const RabitqReformer &) = delete;
  RabitqReformer &operator=(const RabitqReformer &) = delete;

  //! Select the metric from PARAM_RABITQ_METRIC_NAME
  //! (SquaredEuclidean, InnerProduct or Cosine).
  int init(const ailego::Params &params) override;
  int cleanup(void) override;
  //! Load centroids and rotator from the RABITQ_CONVERTER_SEG_ID segment.
  int load(IndexStorage::Pointer storage) override;
  int unload(void) override;

  // transform() is not implemented for RabitqReformer; use transform_to_entity.
  int transform(const void *query, const IndexQueryMeta &qmeta,
                std::string *out, IndexQueryMeta *ometa) const override;

  //! Quantize an FP32 record against its nearest centroid; `out` receives
  //! cluster_id + bin_data + ex_data and `ometa` is set to DT_INT8.
  int convert(const void *record, const IndexQueryMeta &rmeta,
              std::string *out, IndexQueryMeta *ometa) const override;

  //! Write the loaded centroids/rotator through a dumper.
  int dump(const IndexDumper::Pointer &dumper);
  //! Write the loaded centroids/rotator directly into a storage segment.
  int dump(const IndexStorage::Pointer &storage);

  //! Rotate/quantize a query and precompute its distances to all centroids.
  int transform_to_entity(const void *query,
                          HnswRabitqQueryEntity *entity) const;

  size_t num_clusters() const;
  RabitqMetricType rabitq_metric_type() const;

 private:
  struct Impl;
  std::unique_ptr<Impl> impl_;
};

}  // namespace core
}  // namespace zvec
================================================
FILE: src/core/algorithm/hnsw_rabitq/rabitq_utils.cc
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "rabitq_utils.h"
#include <string>
#include <ailego/hash/crc32c.h>
#include "zvec/core/framework/index_error.h"
#include "zvec/core/framework/index_logger.h"

namespace zvec {
namespace core {

/*! Serialize RaBitQ centroids as the RABITQ_CONVERTER_SEG_ID segment.
 *
 * Layout: RabitqConverterHeader | rotated centroids (float) |
 *         original centroids (float) | rotator bytes | zero padding to 32B.
 * The CRC covers everything except the padding.
 *
 * Returns 0 on success, IndexError_InvalidArgument for a null dumper/rotator,
 * IndexError_WriteData on a short write, or the dumper's append() error.
 * On success `*out_dumped_size` (if non-null) receives the unpadded size.
 */
int dump_rabitq_centroids(
    const IndexDumper::Pointer &dumper, size_t dimension, size_t padded_dim,
    size_t ex_bits, size_t num_clusters, rabitqlib::RotatorType rotator_type,
    const std::vector<float> &rotated_centroids,
    const std::vector<float> &centroids,
    const std::unique_ptr<rabitqlib::Rotator<float>> &rotator,
    size_t *out_dumped_size) {
  if (!dumper || !rotator) {
    LOG_ERROR("Invalid dumper or rotator for dumping centroids");
    return IndexError_InvalidArgument;
  }

  auto align_size = [](size_t size) -> size_t {
    return (size + 0x1F) & (~static_cast<size_t>(0x1F));
  };

  uint32_t crc = 0;
  size_t dumped_size = 0;

  // Writes one section and folds it into the running CRC and dumped size.
  auto write_section = [&](const void *data, size_t len,
                           const char *what) -> int {
    const size_t written = dumper->write(data, len);
    if (written != len) {
      LOG_ERROR("Failed to write %s: written=%zu, expected=%zu", what, written,
                len);
      return IndexError_WriteData;
    }
    crc = ailego::Crc32c::Hash(data, len, crc);
    dumped_size += written;
    return 0;
  };

  // Write header
  RabitqConverterHeader header;
  header.dim = static_cast<uint32_t>(dimension);
  header.padded_dim = static_cast<uint32_t>(padded_dim);
  header.num_clusters = static_cast<uint32_t>(num_clusters);
  header.ex_bits = static_cast<uint8_t>(ex_bits);
  header.rotator_type = static_cast<uint8_t>(rotator_type);
  header.rotator_size = static_cast<uint32_t>(rotator->dump_bytes());
  int ret = write_section(&header, sizeof(header), "header");
  if (ret != 0) {
    return ret;
  }

  // Write rotated centroids
  ret = write_section(rotated_centroids.data(),
                      rotated_centroids.size() * sizeof(float),
                      "rotated centroids");
  if (ret != 0) {
    return ret;
  }

  // Write original centroids
  ret = write_section(centroids.data(), centroids.size() * sizeof(float),
                      "centroids");
  if (ret != 0) {
    return ret;
  }

  // Write rotator data
  std::vector<char> buffer(rotator->dump_bytes());
  rotator->save(buffer.data());
  ret = write_section(buffer.data(), buffer.size(), "rotator data");
  if (ret != 0) {
    return ret;
  }

  // Write padding (not included in the CRC or the reported size)
  const size_t padding_size = align_size(dumped_size) - dumped_size;
  if (padding_size > 0) {
    std::string padding(padding_size, '\0');
    if (dumper->write(padding.data(), padding_size) != padding_size) {
      // %zu: padding_size is size_t (%lu is wrong where size_t != ulong).
      LOG_ERROR("Append padding failed, size %zu", padding_size);
      return IndexError_WriteData;
    }
  }

  ret = dumper->append(RABITQ_CONVERTER_SEG_ID, dumped_size, padding_size, crc);
  if (ret != 0) {
    LOG_ERROR("Dump segment %s meta failed, ret=%d",
              RABITQ_CONVERTER_SEG_ID.c_str(), ret);
    return ret;
  }

  if (out_dumped_size) {
    *out_dumped_size = dumped_size;
  }
  return 0;
}

}  // namespace core
}  // namespace zvec
================================================
FILE: src/core/algorithm/hnsw_rabitq/rabitq_utils.h
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once #include #include #include #include "zvec/core/framework/index_dumper.h" namespace zvec { namespace core { inline const std::string RABITQ_CONVERTER_SEG_ID{"rabitq.converter"}; struct RabitqConverterHeader { uint32_t num_clusters; uint32_t dim; uint32_t padded_dim; uint32_t rotator_size; uint8_t ex_bits; uint8_t rotator_type; uint8_t padding[2]; uint32_t reserve[3]; RabitqConverterHeader() { memset(this, 0, sizeof(RabitqConverterHeader)); } }; static_assert(sizeof(RabitqConverterHeader) % 32 == 0, "RabitqConverterHeader must be aligned with 32 bytes"); // Common dump implementation for RabitqConverter and RabitqReformer int dump_rabitq_centroids( const IndexDumper::Pointer &dumper, size_t dimension, size_t padded_dim, size_t ex_bits, size_t num_clusters, rabitqlib::RotatorType rotator_type, const std::vector &rotated_centroids, const std::vector ¢roids, const std::unique_ptr> &rotator, size_t *out_dumped_size = nullptr); } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/hnsw_sparse/CMakeLists.txt ================================================ include(${PROJECT_ROOT_DIR}/cmake/bazel.cmake) include(${PROJECT_ROOT_DIR}/cmake/option.cmake) cc_library( NAME core_knn_hnsw_sparse STATIC SHARED STRICT ALWAYS_LINK SRCS *.cc LIBS core_framework sparsehash INCS . ${PROJECT_ROOT_DIR}/src/core ${PROJECT_ROOT_DIR}/src/core/algorithm VERSION "${PROXIMA_ZVEC_VERSION}" ) ================================================ FILE: src/core/algorithm/hnsw_sparse/hnsw_sparse_algorithm.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "hnsw_sparse_algorithm.h" #include #include #include #include namespace zvec { namespace core { HnswSparseAlgorithm::HnswSparseAlgorithm(HnswSparseEntity &entity) : entity_(entity), mt_(std::chrono::system_clock::now().time_since_epoch().count()), lock_pool_(kLockCnt) {} int HnswSparseAlgorithm::cleanup() { return 0; } int HnswSparseAlgorithm::add_node(node_id_t id, level_t level, HnswSparseContext *ctx) { spin_lock_.lock(); // std::cout << "id: " << id << ", level: " << level << std::endl; auto cur_max_level = entity_.cur_max_level(); auto entry_point = entity_.entry_point(); if (ailego_unlikely(entry_point == kInvalidNodeId)) { entity_.update_ep_and_level(id, level); spin_lock_.unlock(); return 0; } spin_lock_.unlock(); if (ailego_unlikely(level > cur_max_level)) { mutex_.lock(); // re-check max level cur_max_level = entity_.cur_max_level(); entry_point = entity_.entry_point(); if (level <= cur_max_level) { mutex_.unlock(); } } level_t cur_level = cur_max_level; dist_t dist = ctx->dist_calculator()(entry_point); for (; cur_level > level; --cur_level) { select_entry_point(cur_level, &entry_point, &dist, ctx); } for (; cur_level >= 0; --cur_level) { search_neighbors(cur_level, &entry_point, &dist, ctx->level_topk(cur_level), ctx); } // add neighbors from down level to top level, to avoid upper level visible // to knn_search but the under layer level not ready for (cur_level = 0; cur_level <= level; ++cur_level) { add_neighbors(id, cur_level, ctx->level_topk(cur_level), ctx); ctx->level_topk(cur_level).clear(); } if (ailego_unlikely(level > 
cur_max_level)) { spin_lock_.lock(); entity_.update_ep_and_level(id, level); spin_lock_.unlock(); mutex_.unlock(); } return 0; } int HnswSparseAlgorithm::search(HnswSparseContext *ctx) const { spin_lock_.lock(); auto maxLevel = entity_.cur_max_level(); auto entry_point = entity_.entry_point(); spin_lock_.unlock(); if (ailego_unlikely(entry_point == kInvalidNodeId)) { return 0; } dist_t dist = ctx->dist_calculator().dist(entry_point); for (level_t cur_level = maxLevel; cur_level >= 1; --cur_level) { select_entry_point(cur_level, &entry_point, &dist, ctx); } auto &topk_heap = ctx->topk_heap(); topk_heap.clear(); search_neighbors(0, &entry_point, &dist, topk_heap, ctx); if (ctx->group_by_search()) { expand_neighbors_by_group(topk_heap, ctx); } return 0; } //! select_entry_point on hnsw level, ef = 1 void HnswSparseAlgorithm::select_entry_point(level_t level, node_id_t *entry_point, dist_t *dist, HnswSparseContext *ctx) const { auto &entity = ctx->get_entity(); HnswSparseDistCalculator &dc = ctx->dist_calculator(); while (true) { const Neighbors neighbors = entity.get_neighbors(level, *entry_point); if (ailego_unlikely(ctx->debugging())) { (*ctx->mutable_stats_get_neighbors())++; } uint32_t size = neighbors.size(); if (size == 0) { break; } std::vector neighbor_block_vecs; int ret = entity.get_vector_metas(&neighbors[0], size, neighbor_block_vecs); if (ailego_unlikely(ctx->debugging())) { (*ctx->mutable_stats_get_vector())++; } if (ailego_unlikely(ret != 0)) { break; } bool find_closer = false; for (uint32_t i = 0; i < size; ++i) { dist_t cur_dist = dc.dist(neighbor_block_vecs[i].data()); if (cur_dist < *dist) { *entry_point = neighbors[i]; *dist = cur_dist; find_closer = true; } } if (!find_closer) { break; } } return; } void HnswSparseAlgorithm::add_neighbors(node_id_t id, level_t level, TopkHeap &topk_heap, HnswSparseContext *ctx) { if (ailego_unlikely(topk_heap.size() == 0)) { return; } HnswSparseDistCalculator &dc = ctx->dist_calculator(); update_neighbors(dc, id, 
level, topk_heap); // reverse update neighbors for (size_t i = 0; i < topk_heap.size(); ++i) { reverse_update_neighbors(dc, topk_heap[i].first, level, id, topk_heap[i].second, ctx->update_heap()); } return; } void HnswSparseAlgorithm::search_neighbors(level_t level, node_id_t *entry_point, dist_t *dist, TopkHeap &topk, HnswSparseContext *ctx) const { const auto &entity = ctx->get_entity(); HnswSparseDistCalculator &dc = ctx->dist_calculator(); VisitFilter &visit = ctx->visit_filter(); CandidateHeap &candidates = ctx->candidates(); std::function filter = [](node_id_t) { return false; }; if (ctx->filter().is_valid()) { filter = [&](node_id_t id) { return ctx->filter()(entity.get_key(id)); }; } candidates.clear(); visit.clear(); visit.set_visited(*entry_point); if (!filter(*entry_point)) { topk.emplace(*entry_point, *dist); } candidates.emplace(*entry_point, *dist); while (!candidates.empty() && !ctx->reach_scan_limit()) { auto top = candidates.begin(); node_id_t main_node = top->first; dist_t main_dist = top->second; if (topk.full() && main_dist > topk[0].second) { break; } candidates.pop(); const Neighbors neighbors = entity.get_neighbors(level, main_node); ailego_prefetch(neighbors.data); if (ailego_unlikely(ctx->debugging())) { (*ctx->mutable_stats_get_neighbors())++; } std::vector neighbor_ids(neighbors.size()); uint32_t size = 0; for (uint32_t i = 0; i < neighbors.size(); ++i) { node_id_t node = neighbors[i]; if (visit.visited(node)) { if (ailego_unlikely(ctx->debugging())) { (*ctx->mutable_stats_visit_dup_cnt())++; } continue; } visit.set_visited(node); neighbor_ids[size++] = node; } if (size == 0) { continue; } std::vector neighbor_block_vecs; int ret = entity.get_vector_metas(neighbor_ids.data(), size, neighbor_block_vecs); if (ailego_unlikely(ctx->debugging())) { (*ctx->mutable_stats_get_vector())++; } if (ailego_unlikely(ret != 0)) { break; } static constexpr node_id_t PREFETCH_STEP = 2; static constexpr node_id_t SPARSE_PREFETCH_STEP = 1; for (uint32_t i = 
0; i < std::min(PREFETCH_STEP, size); ++i) { ailego_prefetch(neighbor_block_vecs[i].data()); } for (uint32_t i = 0; i < size; ++i) { node_id_t node = neighbor_ids[i]; node_id_t prefetch_id = i + PREFETCH_STEP; if (prefetch_id < size) { ailego_prefetch(neighbor_block_vecs[prefetch_id].data()); } node_id_t sparse_prefetch_id = i + SPARSE_PREFETCH_STEP; if (sparse_prefetch_id < size) { IndexStorage::MemoryBlock sparse_block; int sparse_length = 0; entity.get_sparse_data_from_vector( neighbor_block_vecs[sparse_prefetch_id].data(), sparse_block, sparse_length); auto sparse_data = std::make_pair(sparse_block.data(), sparse_length); if (sparse_data.first != nullptr) { ailego_prefetch(sparse_data.first); } } dist_t cur_dist = dc.dist(neighbor_block_vecs[i].data()); if ((!topk.full()) || cur_dist < topk[0].second) { candidates.emplace(node, cur_dist); // update entry_point for next level scan if (cur_dist < *dist) { *entry_point = node; *dist = cur_dist; } if (!filter(node)) { topk.emplace(node, cur_dist); } } // end if } // end for } // while return; } void HnswSparseAlgorithm::expand_neighbors_by_group( TopkHeap &topk, HnswSparseContext *ctx) const { if (!ctx->group_by().is_valid()) { return; } const auto &entity = ctx->get_entity(); std::function group_by = [&](node_id_t id) { return ctx->group_by()(entity.get_key(id)); }; // devide into groups std::map &group_topk_heaps = ctx->group_topk_heaps(); for (uint32_t i = 0; i < topk.size(); ++i) { node_id_t id = topk[i].first; auto score = topk[i].second; std::string group_id = group_by(id); auto &topk_heap = group_topk_heaps[group_id]; if (topk_heap.empty()) { topk_heap.limit(ctx->group_topk()); } topk_heap.emplace_back(id, score); } // stage 2, expand to reach group num as possible if (group_topk_heaps.size() < ctx->group_num()) { VisitFilter &visit = ctx->visit_filter(); CandidateHeap &candidates = ctx->candidates(); HnswSparseDistCalculator &dc = ctx->dist_calculator(); std::function filter = [](node_id_t) { return false; 
}; if (ctx->filter().is_valid()) { filter = [&](node_id_t id) { return ctx->filter()(entity.get_key(id)); }; } // refill to get enough groups candidates.clear(); visit.clear(); for (uint32_t i = 0; i < topk.size(); ++i) { node_id_t id = topk[i].first; float score = topk[i].second; visit.set_visited(id); candidates.emplace_back(id, score); } // do expand while (!candidates.empty() && !ctx->reach_scan_limit()) { auto top = candidates.begin(); node_id_t main_node = top->first; candidates.pop(); const Neighbors neighbors = entity.get_neighbors(0, main_node); if (ailego_unlikely(ctx->debugging())) { (*ctx->mutable_stats_get_neighbors())++; } std::vector neighbor_ids(neighbors.size()); uint32_t size = 0; for (uint32_t i = 0; i < neighbors.size(); ++i) { node_id_t node = neighbors[i]; if (visit.visited(node)) { if (ailego_unlikely(ctx->debugging())) { (*ctx->mutable_stats_visit_dup_cnt())++; } continue; } visit.set_visited(node); neighbor_ids[size++] = node; } if (size == 0) { continue; } std::vector neighbor_block_vecs; int ret = entity.get_vector_metas(neighbor_ids.data(), size, neighbor_block_vecs); if (ailego_unlikely(ctx->debugging())) { (*ctx->mutable_stats_get_vector())++; } if (ailego_unlikely(ret != 0)) { break; } static constexpr node_id_t PREFETCH_STEP = 2; for (uint32_t i = 0; i < size; ++i) { node_id_t node = neighbor_ids[i]; node_id_t prefetch_id = i + PREFETCH_STEP; if (prefetch_id < size) { ailego_prefetch(neighbor_block_vecs[prefetch_id].data()); } dist_t cur_dist = dc.dist(neighbor_block_vecs[i].data()); if (!filter(node)) { std::string group_id = group_by(node); auto &topk_heap = group_topk_heaps[group_id]; if (topk_heap.empty()) { topk_heap.limit(ctx->group_topk()); } topk_heap.emplace_back(node, cur_dist); if (group_topk_heaps.size() >= ctx->group_num()) { break; } } candidates.emplace(node, cur_dist); } // end for } // end while } // end if } void HnswSparseAlgorithm::update_neighbors(HnswSparseDistCalculator &dc, node_id_t id, level_t level, 
TopkHeap &topk_heap) { topk_heap.sort(); uint32_t max_neighbor_cnt = entity_.neighbor_cnt(level); if (topk_heap.size() <= static_cast(entity_.prune_cnt())) { if (topk_heap.size() <= static_cast(max_neighbor_cnt)) { entity_.update_neighbors(level, id, topk_heap); return; } } uint32_t cur_size = 0; for (size_t i = 0; i < topk_heap.size(); ++i) { node_id_t cur_node = topk_heap[i].first; dist_t cur_node_dist = topk_heap[i].second; bool good = true; for (uint32_t j = 0; j < cur_size; ++j) { dist_t tmp_dist = dc.dist(cur_node, topk_heap[j].first); if (tmp_dist <= cur_node_dist) { good = false; break; } } if (good) { topk_heap[cur_size].first = cur_node; topk_heap[cur_size].second = cur_node_dist; cur_size++; if (cur_size >= max_neighbor_cnt) { break; } } } // when after-prune neighbor count is too seldom, // we use this strategy to make-up enough edges // not only just make-up out-degrees // we also make-up enough in-degrees uint32_t min_neighbors = entity_.min_neighbor_cnt(); for (size_t k = cur_size; cur_size < min_neighbors && k < topk_heap.size(); ++k) { bool exist = false; for (size_t j = 0; j < cur_size; ++j) { if (topk_heap[j].first == topk_heap[k].first) { exist = true; break; } } if (!exist) { topk_heap[cur_size].first = topk_heap[k].first; topk_heap[cur_size].second = topk_heap[k].second; cur_size++; } } topk_heap.resize(cur_size); entity_.update_neighbors(level, id, topk_heap); return; } void HnswSparseAlgorithm::reverse_update_neighbors(HnswSparseDistCalculator &dc, node_id_t id, level_t level, node_id_t link_id, dist_t dist, TopkHeap &update_heap) { const size_t max_neighbor_cnt = entity_.neighbor_cnt(level); uint32_t lock_idx = id & kLockMask; lock_pool_[lock_idx].lock(); const Neighbors neighbors = entity_.get_neighbors(level, id); size_t size = neighbors.size(); ailego_assert_with(size <= max_neighbor_cnt, "invalid neighbor size"); if (size < max_neighbor_cnt) { entity_.add_neighbor(level, id, size, link_id); lock_pool_[lock_idx].unlock(); return; } 
update_heap.emplace(link_id, dist); for (size_t i = 0; i < size; ++i) { node_id_t node = neighbors[i]; dist_t cur_dist = dc.dist(id, node); update_heap.emplace(node, cur_dist); } //! TODO: optimize prune //! prune edges update_heap.sort(); size_t cur_size = 0; for (size_t i = 0; i < update_heap.size(); ++i) { node_id_t cur_node = update_heap[i].first; dist_t cur_node_dist = update_heap[i].second; bool good = true; for (size_t j = 0; j < cur_size; ++j) { dist_t tmp_dist = dc.dist(cur_node, update_heap[j].first); if (tmp_dist <= cur_node_dist) { good = false; break; } } if (good) { update_heap[cur_size].first = cur_node; update_heap[cur_size].second = cur_node_dist; cur_size++; if (cur_size >= max_neighbor_cnt) { break; } } } update_heap.resize(cur_size); entity_.update_neighbors(level, id, update_heap); lock_pool_[lock_idx].unlock(); update_heap.clear(); return; } } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/hnsw_sparse/hnsw_sparse_algorithm.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include "hnsw_sparse_context.h" #include "hnsw_sparse_dist_calculator.h" #include "hnsw_sparse_entity.h" namespace zvec { namespace core { //! hnsw graph algorithm implement class HnswSparseAlgorithm { public: typedef std::unique_ptr UPointer; public: //! 
Constructor explicit HnswSparseAlgorithm(HnswSparseEntity &entity); //! Cleanup HnswSparseAlgorithm int cleanup(); //! Add a node to hnsw graph //! @id: the node unique id //! @level: a node will be add to graph in each level [0, level] //! return 0 on success, or errCode in failure int add_node(node_id_t id, level_t level, HnswSparseContext *ctx); //! do knn search in graph //! return 0 on success, or errCode in failure. results saved in ctx int search(HnswSparseContext *ctx) const; //! Initiate HnswAlgorithm int init() { level_probas_.clear(); double level_mult = 1 / std::log(static_cast(entity_.scaling_factor())); for (int level = 0;; level++) { // refers faiss get_random_level alg double proba = std::exp(-level / level_mult) * (1 - std::exp(-1 / level_mult)); if (proba < 1e-9) { break; } level_probas_.push_back(proba); } return 0; } //! Generate a random level //! return graph level uint32_t get_random_level() const { // gen rand float (0, 1) double f = mt_() / static_cast(mt_.max()); for (size_t level = 0; level < level_probas_.size(); level++) { if (f < level_probas_[level]) { return level; } f -= level_probas_[level]; } return level_probas_.size() - 1; } private: //! Select in upper layer to get entry point for next layer search void select_entry_point(level_t level, node_id_t *entry_point, dist_t *dist, HnswSparseContext *ctx) const; //! update node id neighbors from topkHeap, and reverse link is also updated void add_neighbors(node_id_t id, level_t level, TopkHeap &topk_heap, HnswSparseContext *ctx); //! Given a node id and level, search the nearest neighbors in graph //! Note: the nearest neighbors result keeps in topk, and entry_point and //! dist will be updated to current level nearest node id and distance void search_neighbors(level_t level, node_id_t *entry_point, dist_t *dist, TopkHeap &topk, HnswSparseContext *ctx) const; //! 
Update the node's neighbors void update_neighbors(HnswSparseDistCalculator &dc, node_id_t id, level_t level, TopkHeap &topk_heap); //! Checking linkId could be id's new neighbor, and add as neighbor if true //! @dc distance calculator //! @updateHeap temporary heap in updating neighbors void reverse_update_neighbors(HnswSparseDistCalculator &dc, node_id_t id, level_t level, node_id_t link_id, dist_t dist, TopkHeap &update_heap); //! expand neighbors until group nums are reached void expand_neighbors_by_group(TopkHeap &topk, HnswSparseContext *ctx) const; private: HnswSparseAlgorithm(const HnswSparseAlgorithm &) = delete; HnswSparseAlgorithm &operator=(const HnswSparseAlgorithm &) = delete; private: static constexpr uint32_t kLockCnt{1U << 8}; static constexpr uint32_t kLockMask{kLockCnt - 1U}; HnswSparseEntity &entity_; mutable std::mt19937 mt_{}; std::vector level_probas_{}; mutable ailego::SpinMutex spin_lock_{}; // global spin lock std::mutex mutex_{}; // global mutex // TODO: spin lock? std::vector lock_pool_{}; }; } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/hnsw_sparse/hnsw_sparse_builder.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "hnsw_sparse_builder.h"
#include
#include
#include
#include
#include
#include
#include "hnsw_sparse_algorithm.h"
#include "hnsw_sparse_params.h"

namespace zvec {
namespace core {

HnswSparseBuilder::HnswSparseBuilder() {}

//! Read builder parameters, validate them, create the metric and initialize
//! the graph entity and algorithm. Returns 0 or an IndexError_* code.
int HnswSparseBuilder::init(const IndexMeta &meta,
                            const ailego::Params &params) {
  LOG_INFO("Begin HnswSparseBuilder::init");

  meta_ = meta;
  auto params_copy = params;
  meta_.set_builder("HnswSparseBuilder", HnswSparseEntity::kRevision,
                    std::move(params_copy));

  size_t memory_quota = 0UL;
  params.get(PARAM_HNSW_SPARSE_BUILDER_MEMORY_QUOTA, &memory_quota);
  params.get(PARAM_HNSW_SPARSE_BUILDER_THREAD_COUNT, &thread_cnt_);
  // min neighbor count was validated below but never read from params
  params.get(PARAM_HNSW_SPARSE_BUILDER_MIN_NEIGHBOR_COUNT, &min_neighbor_cnt_);
  params.get(PARAM_HNSW_SPARSE_BUILDER_EFCONSTRUCTION, &ef_construction_);
  params.get(PARAM_HNSW_SPARSE_BUILDER_CHECK_INTERVAL_SECS,
             &check_interval_secs_);
  params.get(PARAM_HNSW_SPARSE_BUILDER_MAX_NEIGHBOR_COUNT,
             &upper_max_neighbor_cnt_);
  float multiplier = HnswSparseEntity::kDefaultL0MaxNeighborCntMultiplier;
  params.get(PARAM_HNSW_SPARSE_BUILDER_L0_MAX_NEIGHBOR_COUNT_MULTIPLIER,
             &multiplier);
  l0_max_neighbor_cnt_ = multiplier * upper_max_neighbor_cnt_;

  scaling_factor_ = upper_max_neighbor_cnt_;
  params.get(PARAM_HNSW_SPARSE_BUILDER_SCALING_FACTOR, &scaling_factor_);

  multiplier = HnswSparseEntity::kDefaultNeighborPruneMultiplier;
  params.get(PARAM_HNSW_SPARSE_BUILDER_NEIGHBOR_PRUNE_MULTIPLIER, &multiplier);
  size_t prune_cnt = multiplier * upper_max_neighbor_cnt_;

  if (ef_construction_ == 0) {
    ef_construction_ = HnswSparseEntity::kDefaultEfConstruction;
  }
  if (upper_max_neighbor_cnt_ == 0) {
    upper_max_neighbor_cnt_ = HnswSparseEntity::kDefaultUpperMaxNeighborCnt;
  }
  if (upper_max_neighbor_cnt_ > kMaxNeighborCnt) {
    LOG_ERROR("[%s] must be in range (0,%d]",
              PARAM_HNSW_SPARSE_BUILDER_MAX_NEIGHBOR_COUNT.c_str(),
              kMaxNeighborCnt);
    return IndexError_InvalidArgument;
  }
  if (min_neighbor_cnt_ > upper_max_neighbor_cnt_) {
    LOG_ERROR("[%s]-[%d] must be <= [%s]-[%d]",
              PARAM_HNSW_SPARSE_BUILDER_MIN_NEIGHBOR_COUNT.c_str(),
              min_neighbor_cnt_,
              PARAM_HNSW_SPARSE_BUILDER_MAX_NEIGHBOR_COUNT.c_str(),
              upper_max_neighbor_cnt_);
    return IndexError_InvalidArgument;
  }
  if (l0_max_neighbor_cnt_ == 0) {
    // NOTE(review): falls back to the *upper* default, while cleanup() resets
    // to kDefaultL0MaxNeighborCnt; confirm which default is intended.
    l0_max_neighbor_cnt_ = HnswSparseEntity::kDefaultUpperMaxNeighborCnt;
  }
  if (l0_max_neighbor_cnt_ > HnswSparseEntity::kMaxNeighborCnt) {
    LOG_ERROR("L0MaxNeighborCnt must be in range (0,%d)",
              HnswSparseEntity::kMaxNeighborCnt);
    return IndexError_InvalidArgument;
  }
  if (scaling_factor_ == 0U) {
    scaling_factor_ = HnswSparseEntity::kDefaultScalingFactor;
  }
  if (scaling_factor_ < 5 || scaling_factor_ > 1000) {
    LOG_ERROR("[%s] must be in range [5,1000]",
              PARAM_HNSW_SPARSE_BUILDER_SCALING_FACTOR.c_str());
    return IndexError_InvalidArgument;
  }
  if (thread_cnt_ == 0) {
    thread_cnt_ = std::thread::hardware_concurrency();
  }
  if (thread_cnt_ > std::thread::hardware_concurrency()) {
    LOG_WARN("[%s] greater than cpu cores %u",
             PARAM_HNSW_SPARSE_BUILDER_THREAD_COUNT.c_str(),
             std::thread::hardware_concurrency());
  }
  if (prune_cnt == 0UL) {
    prune_cnt = upper_max_neighbor_cnt_;
  }

  metric_ = IndexFactory::CreateMetric(meta_.metric_name());
  if (!metric_) {
    LOG_ERROR("CreateMeasure failed, name: %s", meta_.metric_name().c_str());
    return IndexError_NoExist;
  }
  int ret = metric_->init(meta_, meta_.metric_params());
  if (ret != 0) {
    LOG_ERROR("IndexMeasure init failed, ret=%d", ret);
    return ret;
  }

  entity_.set_ef_construction(ef_construction_);
  entity_.set_l0_neighbor_cnt(l0_max_neighbor_cnt_);
  entity_.set_min_neighbor_cnt(min_neighbor_cnt_);
  entity_.set_upper_neighbor_cnt(upper_max_neighbor_cnt_);
  entity_.set_scaling_factor(scaling_factor_);
  entity_.set_memory_quota(memory_quota);
  entity_.set_prune_cnt(prune_cnt);
  entity_.set_sparse_meta_size(HnswSparseEntity::kSparseMetaSize);
  entity_.set_sparse_unit_size(meta.unit_size());

  ret = entity_.init();
  if (ret != 0) {
    return ret;
  }

  alg_ = HnswSparseAlgorithm::UPointer(new HnswSparseAlgorithm(entity_));

  ret = alg_->init();
  if (ret != 0) {
    return ret;
  }

  state_ = BUILD_STATE_INITED;
  LOG_INFO(
      "End HnswSparseBuilder::init, params: efConstruction=%u "
      "l0NeighborCnt=%u upperNeighborCnt=%u scalingFactor=%u "
      "memoryQuota=%zu neighborPruneCnt=%zu measureName=%s ",
      ef_construction_, l0_max_neighbor_cnt_, upper_max_neighbor_cnt_,
      scaling_factor_, memory_quota, prune_cnt, meta_.metric_name().c_str());

  return 0;
}

//! Reset every parameter to its default and release entity / metric state.
//! Safe to call before init() (alg_ may still be null then).
int HnswSparseBuilder::cleanup(void) {
  LOG_INFO("Begin HnswSparseBuilder::cleanup");

  thread_cnt_ = 0;
  l0_max_neighbor_cnt_ = HnswSparseEntity::kDefaultL0MaxNeighborCnt;
  min_neighbor_cnt_ = 0;
  upper_max_neighbor_cnt_ = HnswSparseEntity::kDefaultUpperMaxNeighborCnt;
  ef_construction_ = HnswSparseEntity::kDefaultEfConstruction;
  scaling_factor_ = HnswSparseEntity::kDefaultScalingFactor;
  check_interval_secs_ = kDefaultLogIntervalSecs;
  errcode_ = 0;
  error_ = false;
  entity_.cleanup();
  if (alg_) {
    alg_->cleanup();
  }
  meta_.clear();
  metric_.reset();
  stats_.clear_attributes();
  stats_.set_trained_count(0UL);
  stats_.set_built_count(0UL);
  stats_.set_dumped_count(0UL);
  stats_.set_discarded_count(0UL);
  stats_.set_trained_costtime(0UL);
  stats_.set_built_costtime(0UL);
  stats_.set_dumped_costtime(0UL);
  state_ = BUILD_STATE_INIT;

  LOG_INFO("End HnswSparseBuilder::cleanup");

  return 0;
}

//! No training is needed for HNSW; only the state machine advances.
int HnswSparseBuilder::train(IndexThreads::Pointer,
                             IndexSparseHolder::Pointer /*holder*/) {
  if (state_ != BUILD_STATE_INITED) {
    LOG_ERROR("Init the builder before HnswSparseBuilder::train");
    return IndexError_NoReady;
  }

  stats_.set_trained_count(0UL);
  stats_.set_trained_costtime(0UL);
  state_ = BUILD_STATE_TRAINED;

  LOG_INFO("End HnswSparseBuilder::train");

  return 0;
}

//! No training is needed for HNSW; only the state machine advances.
int HnswSparseBuilder::train(const IndexTrainer::Pointer & /*trainer*/) {
  if (state_ != BUILD_STATE_INITED) {
    LOG_ERROR("Init the builder before HnswSparseBuilder::train");
    return IndexError_NoReady;
  }

  LOG_INFO("Begin HnswSparseBuilder::train by trainer");

  stats_.set_trained_count(0UL);
  stats_.set_trained_costtime(0UL);
  state_ = BUILD_STATE_TRAINED;

  LOG_INFO("End HnswSparseBuilder::train by trainer");

  return 0;
}

//! Load every vector from `holder` into the entity, then build the graph
//! with `threads` (a local pool of thread_cnt_ threads when null).
int HnswSparseBuilder::build(IndexThreads::Pointer threads,
                             IndexSparseHolder::Pointer holder) {
  if (state_ == BUILD_STATE_INIT) {
    // alg_ / entity_ are not ready before init()
    LOG_ERROR("Init the builder before HnswSparseBuilder::build");
    return IndexError_NoReady;
  }
  if (!holder) {
    LOG_ERROR("Input holder is nullptr while building index");
    return IndexError_InvalidArgument;
  }
  if (!holder->is_matched(meta_)) {
    LOG_ERROR("Input holder doesn't match index meta while building index");
    return IndexError_Mismatch;
  }

  if (!threads) {
    threads = std::make_shared(thread_cnt_, false);
    if (!threads) {
      return IndexError_NoMemory;
    }
  }

  auto start_time = ailego::Monotime::MilliSeconds();
  LOG_INFO("Begin HnswSparseBuilder::build sparse");

  // holder should be hybrid holder
  auto sparse_holder = std::dynamic_pointer_cast(holder);
  if (sparse_holder == nullptr) {
    LOG_ERROR("HnswSparseBuilder failed to cast holder");
    return IndexError_Runtime;
  }

  if (sparse_holder->count() != static_cast(-1)) {
    LOG_DEBUG("HnswSparseBuilder holder documents count %lu",
              sparse_holder->count());
    int ret = entity_.reserve_space(sparse_holder->count(),
                                    sparse_holder->total_sparse_count());
    if (ret != 0) {
      LOG_ERROR("HnswBuilde reserver space failed");
      return ret;
    }
  }

  auto iter = sparse_holder->create_iterator();
  if (!iter) {
    LOG_ERROR("Create iterator for holder failed");
    return IndexError_Runtime;
  }

  int ret = 0;
  while (iter->is_valid()) {
    level_t level = alg_->get_random_level();
    node_id_t id;

    ret = entity_.add_vector(level, iter->key(), iter->sparse_count(),
                             iter->sparse_indices(), iter->sparse_data(), &id);
    // IndexError_InvalidValue marks a skipped vector, not a fatal error
    if (ailego_unlikely(ret != 0) && ret != IndexError_InvalidValue) {
      return ret;
    }
    iter->next();
  }
  // Holder is not needed, cleanup it.
  sparse_holder.reset();

  LOG_INFO("Finished save vector, start build graph...");

  std::atomic finished{0};
  ret = build_graph(threads, finished);
  if (ret != 0) {
    LOG_ERROR("Failed to build graph");
    return ret;
  }

  stats_.set_built_count(finished.load());
  stats_.set_built_costtime(ailego::Monotime::MilliSeconds() - start_time);

  state_ = BUILD_STATE_BUILT;
  LOG_INFO("End HnswSparseBuilder::build");
  return 0;
}

//! Run do_build on every thread of `threads`, logging progress every
//! check_interval_secs_ seconds. Always waits for all tasks before returning:
//! they hold a pointer to `finished`, which lives in the caller's frame, and
//! errcode_ is only guaranteed to be written once the failing task is done.
int HnswSparseBuilder::build_graph(IndexThreads::Pointer threads,
                                   std::atomic &finished) {
  error_ = false;
  errcode_ = 0;

  auto task_group = threads->make_group();
  if (!task_group) {
    LOG_ERROR("Failed to create task group");
    return IndexError_Runtime;
  }

  for (size_t i = 0; i < threads->count(); ++i) {
    task_group->submit(ailego::Closure ::New(this, &HnswSparseBuilder::do_build,
                                             i, threads->count(), &finished));
  }

  while (!task_group->is_finished()) {
    std::unique_lock lk(mutex_);
    cond_.wait_until(lk, std::chrono::system_clock::now() +
                             std::chrono::seconds(check_interval_secs_));
    if (error_.load(std::memory_order_acquire)) {
      break;  // workers bail out early once error_ is set
    }
    LOG_INFO("Built cnt %u, finished percent %.3f%%", finished.load(),
             finished.load() * 100.0f / entity_.doc_cnt());
  }
  task_group->wait_finish();

  if (error_.load(std::memory_order_acquire)) {
    LOG_ERROR("Failed to build index while waiting finish");
    return errcode_;
  }

  return 0;
}

//! Worker task: inserts nodes idx, idx + step_size, ... into the graph with
//! its own context. The first failing worker records errcode_; every worker
//! stops as soon as error_ is set. Always wakes build_graph on exit.
void HnswSparseBuilder::do_build(node_id_t idx, size_t step_size,
                                 std::atomic *finished) {
  AILEGO_DEFER([&]() {
    std::lock_guard latch(mutex_);
    cond_.notify_one();
  });
  HnswSparseContext *ctx = new (std::nothrow) HnswSparseContext(
      metric_, std::shared_ptr(&entity_, [](HnswSparseEntity *) {}));
  if (ailego_unlikely(ctx == nullptr)) {
    if (!error_.exchange(true)) {
      LOG_ERROR("Failed to create context");
      errcode_ = IndexError_NoMemory;
    }
    return;
  }
  HnswSparseContext::Pointer auto_ptr(ctx);
  ctx->set_max_scan_num(entity_.doc_cnt());
  int ret = ctx->init(HnswSparseContext::kSparseBuilderContext);
  if (ret != 0) {
    if (!error_.exchange(true)) {
      LOG_ERROR("Failed to init context");
      errcode_ = IndexError_Runtime;
    }
    return;
  }

  for (node_id_t id = idx; id < entity_.doc_cnt(); id += step_size) {
    if (error_.load(std::memory_order_acquire)) {
      return;  // another worker failed; the build result is discarded anyway
    }
    const void *vec = entity_.get_vector_meta(id);
    auto sparse_data = entity_.get_sparse_data_from_vector(vec);
    ctx->reset_query(sparse_data.first);
    ret = alg_->add_node(id, entity_.get_level(id), ctx);
    if (ailego_unlikely(ret != 0)) {
      if (!error_.exchange(true)) {
        LOG_ERROR("Hnsw graph add node failed");
        errcode_ = ret;
      }
      return;
    }
    ctx->clear();
    (*finished)++;
  }
}

//! Serialize the index meta and the graph entity into `dumper`.
int HnswSparseBuilder::dump(const IndexDumper::Pointer &dumper) {
  if (state_ != BUILD_STATE_BUILT) {
    LOG_ERROR("Build the index before HnswSparseBuilder::dump");
    return IndexError_NoReady;
  }
  if (!dumper) {
    LOG_ERROR("Input dumper is nullptr while dumping index");
    return IndexError_InvalidArgument;
  }

  LOG_INFO("Begin HnswSparseBuilder::dump");

  meta_.set_searcher("HnswSparseSearcher", HnswSparseEntity::kRevision,
                     ailego::Params());
  auto start_time = ailego::Monotime::MilliSeconds();

  int ret = IndexHelper::SerializeToDumper(meta_, dumper.get());
  if (ret != 0) {
    LOG_ERROR("Failed to serialize meta into dumper.");
    return ret;
  }

  ret = entity_.dump(dumper);
  if (ret != 0) {
    LOG_ERROR("HnswSparseBuilder dump index failed");
    return ret;
  }

  stats_.set_dumped_count(entity_.doc_cnt());
  stats_.set_dumped_costtime(ailego::Monotime::MilliSeconds() - start_time);

  LOG_INFO("EndHnswSparseBuilder::dump");
  return 0;
}

//! Indptr build where the input data type equals the index data type.
int HnswSparseBuilder::build(IndexThreads::Pointer threads, size_t count,
                             const uint64_t *keys,
                             const uint64_t *sparse_indptr,
                             const uint32_t *sparse_indices,
                             const void *sparse_data) {
  IndexQueryMeta qmeta(meta_.data_type());
  return build(threads, qmeta, count, keys, sparse_indptr, sparse_indices,
               sparse_data);
}

//! Indptr (CSR) build: row i spans sparse_indptr[i] .. sparse_indptr[i + 1].
//! FP32 input is converted to FP16 through HalfFloatSparseReformer when the
//! index stores FP16; other type mismatches are rejected.
int HnswSparseBuilder::build(IndexThreads::Pointer threads,
                             const IndexQueryMeta &qmeta, size_t count,
                             const uint64_t *keys,
                             const uint64_t *sparse_indptr,
                             const uint32_t *sparse_indices,
                             const void *sparse_data) {
  if (state_ == BUILD_STATE_INIT) {
    // alg_ / entity_ are not ready before init()
    LOG_ERROR("Init the builder before HnswSparseBuilder::build");
    return IndexError_NoReady;
  }

  if (!threads) {
    threads = std::make_shared(thread_cnt_, false);
    if (!threads) {
      return IndexError_NoMemory;
    }
  }

  auto start_time = ailego::Monotime::MilliSeconds();
  LOG_INFO("Begin HnswSparseBuilder::build sparse, documents count %lu", count);

  size_t total_sparse_count = sparse_indptr[count];
  int ret = entity_.reserve_space(count, total_sparse_count);
  if (ret != 0) {
    LOG_ERROR("HnswBuilde reserver space failed");
    return ret;
  }

  if (qmeta.data_type() == meta_.data_type()) {
    for (size_t i = 0; i < count; i++) {
      level_t level = alg_->get_random_level();
      node_id_t id;

      uint32_t sparse_count = sparse_indptr[i + 1] - sparse_indptr[i];
      const uint32_t *sparse_indices_temp = sparse_indices + sparse_indptr[i];
      const void *sparse_data_temp =
          static_cast(sparse_data) +
          sparse_indptr[i] * qmeta.unit_size();

      ret = entity_.add_vector(level, keys[i], sparse_count,
                               sparse_indices_temp, sparse_data_temp, &id);
      if (ailego_unlikely(ret != 0) && ret != IndexError_InvalidValue) {
        return ret;
      }
    }
  } else if (meta_.data_type() == IndexMeta::DataType::DT_FP16 &&
             qmeta.data_type() == IndexMeta::DataType::DT_FP32) {
    // transform from float 32 to float 16
    auto reformer = IndexFactory::CreateReformer("HalfFloatSparseReformer");
    if (!reformer) {
      LOG_ERROR("Sparse reformer not existed.");
      return IndexError_NoExist;
    }

    meta_.set_converter("HalfFloatSparseConverter", 0, ailego::Params());
    meta_.set_reformer("HalfFloatSparseReformer", 0, ailego::Params());

    for (size_t i = 0; i < count; i++) {
      level_t level = alg_->get_random_level();
      node_id_t id;

      uint32_t sparse_count = sparse_indptr[i + 1] - sparse_indptr[i];
      const uint32_t *sparse_indices_temp = sparse_indices + sparse_indptr[i];
      const void *sparse_data_temp =
          static_cast(sparse_data) +
          sparse_indptr[i] * qmeta.unit_size();

      std::string query_fp16;
      IndexQueryMeta ometa;
      // a failed transform leaves query_fp16 empty; never feed it to the entity
      ret = reformer->transform(sparse_count, sparse_indices_temp,
                                sparse_data_temp, qmeta, &query_fp16, &ometa);
      if (ailego_unlikely(ret != 0)) {
        LOG_ERROR("Failed to transform sparse data, ret=%d", ret);
        return ret;
      }

      ret = entity_.add_vector(level, keys[i], sparse_count,
                               sparse_indices_temp, query_fp16.data(), &id);
      if (ailego_unlikely(ret != 0) && ret != IndexError_InvalidValue) {
        return ret;
      }
    }
  } else {
    LOG_ERROR("Format not supported.");
    return IndexError_Unsupported;
  }

  LOG_INFO("Finished save vector, start build graph...");

  std::atomic finished{0};
  ret = build_graph(threads, finished);
  if (ret != 0) {
    LOG_ERROR("Failed to build graph");
    return ret;
  }

  stats_.set_built_count(finished.load());
  stats_.set_built_costtime(ailego::Monotime::MilliSeconds() - start_time);

  state_ = BUILD_STATE_BUILT;
  LOG_INFO("End HnswSparseBuilder::build");
  return 0;
}

INDEX_FACTORY_REGISTER_BUILDER(HnswSparseBuilder);

}  // namespace core
}  // namespace zvec

================================================
FILE: src/core/algorithm/hnsw_sparse/hnsw_sparse_builder.h
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include
#include
#include "hnsw_sparse_algorithm.h"
#include "hnsw_sparse_builder_entity.h"

namespace zvec {
namespace core {

//! Builds a sparse HNSW graph index (HnswSparseSearcher is recorded as the
//! searcher at dump time).
class HnswSparseBuilder : public IndexBuilder {
 public:
  //! Constructor
  HnswSparseBuilder();

  //! Initialize the builder
  int init(const IndexMeta &meta, const ailego::Params &params) override;

  //! Cleanup the builder
  int cleanup(void) override;

  //! Train the data
  int train(IndexThreads::Pointer, IndexSparseHolder::Pointer holder) override;

  //! Train the data
  int train(const IndexTrainer::Pointer &trainer) override;

  //! Build the index
  int build(IndexThreads::Pointer threads,
            IndexSparseHolder::Pointer holder) override;

  //! Build the index with indptr format
  int build(IndexThreads::Pointer threads, const IndexQueryMeta &qmeta,
            size_t count, const uint64_t *keys, const uint64_t *sparse_indptr,
            const uint32_t *sparse_indices, const void *sparse_data) override;

  //! Build the index with indptr format
  int build(IndexThreads::Pointer threads, size_t count, const uint64_t *keys,
            const uint64_t *sparse_indptr, const uint32_t *sparse_indices,
            const void *sparse_data) override;

  //! Dump index into storage
  int dump(const IndexDumper::Pointer &dumper) override;

  //! Retrieve statistics
  const Stats &stats(void) const override {
    return stats_;
  }

 private:
  int build_graph(IndexThreads::Pointer threads, std::atomic &finished);

  void do_build(node_id_t idx, size_t step_size, std::atomic *finished);

  constexpr static uint32_t kDefaultLogIntervalSecs = 15U;
  constexpr static uint32_t kMaxNeighborCnt = 65535;

 private:
  enum BUILD_STATE {
    BUILD_STATE_INIT = 0,
    BUILD_STATE_INITED = 1,
    BUILD_STATE_TRAINED = 2,
    BUILD_STATE_BUILT = 3
  };

  HnswSparseBuilderEntity entity_{};
  HnswSparseAlgorithm::UPointer alg_;  // impl graph algorithm
  uint32_t thread_cnt_{0};
  uint32_t l0_max_neighbor_cnt_{HnswSparseEntity::kDefaultL0MaxNeighborCnt};
  uint32_t min_neighbor_cnt_{0};
  uint32_t upper_max_neighbor_cnt_{
      HnswSparseEntity::kDefaultUpperMaxNeighborCnt};
  uint32_t ef_construction_{HnswSparseEntity::kDefaultEfConstruction};
  uint32_t scaling_factor_{HnswSparseEntity::kDefaultScalingFactor};
  uint32_t check_interval_secs_{kDefaultLogIntervalSecs};

  int errcode_{0};
  std::atomic_bool error_{false};
  IndexMeta meta_{};
  IndexMetric::Pointer metric_{};
  std::mutex mutex_{};
  std::condition_variable cond_{};
  Stats stats_{};
  BUILD_STATE state_{BUILD_STATE_INIT};
};

}  // namespace core
}  // namespace zvec

================================================
FILE: src/core/algorithm/hnsw_sparse/hnsw_sparse_builder_entity.cc
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache
License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "hnsw_sparse_builder_entity.h" #include #include "utility/sparse_utility.h" namespace zvec { namespace core { HnswSparseBuilderEntity::HnswSparseBuilderEntity() { update_ep_and_level(kInvalidNodeId, 0U); } int HnswSparseBuilderEntity::cleanup() { memory_quota_ = 0UL; neighbors_size_ = 0U; upper_neighbors_size_ = 0U; padding_size_ = 0U; vectors_buffer_.clear(); keys_buffer_.clear(); neighbors_buffer_.clear(); upper_neighbors_buffer_.clear(); neighbors_index_.clear(); vectors_buffer_.shrink_to_fit(); keys_buffer_.shrink_to_fit(); neighbors_buffer_.shrink_to_fit(); upper_neighbors_buffer_.shrink_to_fit(); neighbors_index_.shrink_to_fit(); this->HnswSparseEntity::cleanup(); return 0; } int HnswSparseBuilderEntity::init() { size_t size = vector_size(); size += sparse_meta_size(); //! aligned size to 32 set_node_size(AlignSize(size)); //! 
if node size is aligned to 1k, the build performance will downgrade if (node_size() % 1024 == 0) { set_node_size(AlignSize(node_size() + 1)); } padding_size_ = node_size() - size; neighbors_size_ = neighbors_size(); upper_neighbors_size_ = upper_neighbors_size(); return 0; } int HnswSparseBuilderEntity::reserve_space(size_t docs, size_t total_sparse_count) { if (memory_quota_ > 0 && (node_size() * docs + neighbors_size_ * docs + sizeof(SparseNeighborIndex) * docs > memory_quota_)) { return IndexError_NoMemory; } vectors_buffer_.reserve(node_size() * docs); keys_buffer_.reserve(sizeof(key_t) * docs); neighbors_buffer_.reserve(neighbors_size_ * docs); neighbors_index_.reserve(docs); sparse_data_buffer_.reserve(sizeof(uint32_t) * docs + (sizeof(uint32_t)) * total_sparse_count + sparse_unit_size() * total_sparse_count); return 0; } int HnswSparseBuilderEntity::add_vector(level_t level, key_t key, const uint32_t sparse_count, const uint32_t *sparse_indices, const void *sparse_vec, node_id_t *id) { if (ailego_unlikely(sparse_count > HnswSparseEntity::kSparseMaxDimSize)) { LOG_WARN( "Failed to add sparse vector: number of non-zero elements (%u) exceeds " "maximum allowed (%u), key=%zu", sparse_count, HnswSparseEntity::kSparseMaxDimSize, (size_t)key); return IndexError_InvalidValue; } std::string sparse_buffer; SparseUtility::TransSparseFormat(sparse_count, sparse_indices, sparse_vec, sparse_unit_size(), sparse_buffer); uint32_t sparse_len = sparse_buffer.size(); if (memory_quota_ > 0 && (vectors_buffer_.capacity() + keys_buffer_.capacity() + neighbors_buffer_.capacity() + upper_neighbors_buffer_.capacity() + neighbors_index_.capacity() * sizeof(SparseNeighborIndex) + sparse_len > memory_quota_)) { LOG_ERROR("Add vector failed, used memory exceed quota, cur_doc=%u", doc_cnt()); return IndexError_NoMemory; } vectors_buffer_.append(reinterpret_cast(&sparse_data_offset_), sizeof(uint64_t)); vectors_buffer_.append(reinterpret_cast(&sparse_len), sizeof(uint32_t)); 
vectors_buffer_.append(sizeof(uint32_t), '\0'); // reserve to make it up to meta size vectors_buffer_.append(padding_size_, '\0'); keys_buffer_.append(reinterpret_cast(&key), sizeof(key)); sparse_data_buffer_.append(sparse_buffer.data(), sparse_len); sparse_data_offset_ += sparse_len; // init level 0 neighbors neighbors_buffer_.append(neighbors_size_, '\0'); neighbors_index_.emplace_back(upper_neighbors_buffer_.size(), level); // init upper layer neighbors for (level_t cur_level = 1; cur_level <= level; ++cur_level) { upper_neighbors_buffer_.append(upper_neighbors_size_, '\0'); } *id = (*mutable_doc_cnt())++; return 0; } key_t HnswSparseBuilderEntity::get_key(node_id_t id) const { return *(reinterpret_cast(keys_buffer_.data() + id * sizeof(key_t))); } const void *HnswSparseBuilderEntity::get_vector_meta(node_id_t id) const { return vectors_buffer_.data() + id * node_size(); } int HnswSparseBuilderEntity::get_vector_meta( const node_id_t id, IndexStorage::MemoryBlock &block) const { const void *vec = get_vector_meta(id); block.reset((void *)vec); return 0; } int HnswSparseBuilderEntity::get_vector_metas(const node_id_t *ids, uint32_t count, const void **vecs) const { for (uint32_t i = 0; i < count; ++i) { vecs[i] = vectors_buffer_.data() + ids[i] * node_size(); } return 0; } int HnswSparseBuilderEntity::get_vector_metas( const node_id_t *ids, uint32_t count, std::vector &block_vecs) const { const void *vecs[count]; get_vector_metas(ids, count, vecs); for (uint32_t i = 0; i < count; ++i) { block_vecs.emplace_back(IndexStorage::MemoryBlock((void *)vecs[i])); } return 0; } //! Get vector feature data by key const void *HnswSparseBuilderEntity::get_sparse_data(uint64_t offset, uint32_t /*len*/) const { return reinterpret_cast(sparse_data_buffer_.data()) + offset; } int HnswSparseBuilderEntity::get_sparse_data( uint64_t offset, uint32_t len, IndexStorage::MemoryBlock &block) const { const void *vec = get_sparse_data(offset, len); block.reset((void *)vec); return 0; } //! 
//! Get sparse data from id: resolve the node record of `id`, then the sparse
//! payload it references; only the payload pointer is returned.
const void *HnswSparseBuilderEntity::get_sparse_data(node_id_t id) const {
  auto sparse_data = get_sparse_data_from_vector(get_vector_meta(id));
  return sparse_data.first;
}

//! MemoryBlock variant of get_sparse_data(id); always returns 0.
int HnswSparseBuilderEntity::get_sparse_data(
    const node_id_t id, IndexStorage::MemoryBlock &block) const {
  const void *vec = get_sparse_data(id);
  block.reset((void *)vec);
  return 0;
}

//! Get sparse data from vector: read the uint64 payload offset stored at
//! `vec + vector_size()` and the uint32 payload length right after it, and
//! return {payload pointer, length}; {nullptr, 0} if the lookup yields null.
std::pair
HnswSparseBuilderEntity::get_sparse_data_from_vector(const void *vec) const {
  uint32_t vec_size = vector_size();
  const char *vec_ptr = reinterpret_cast(vec);
  uint64_t offset = *((uint64_t *)(vec_ptr + vec_size));
  uint32_t sparse_vector_len =
      *((uint32_t *)(vec_ptr + vec_size + sizeof(uint64_t)));
  const void *sparse_data = get_sparse_data(offset, sparse_vector_len);
  if (ailego_unlikely(sparse_data == nullptr)) {
    LOG_ERROR("Get nullptr sparse, offset=%zu, len=%u", (size_t)offset,
              sparse_vector_len);
    return std::make_pair(nullptr, 0);
  }
  return std::make_pair(sparse_data, sparse_vector_len);
}

//! MemoryBlock variant: stores the payload pointer in `block` and its byte
//! length in `sparse_length`; always returns 0.
int HnswSparseBuilderEntity::get_sparse_data_from_vector(
    const void *vec, IndexStorage::MemoryBlock &block,
    int &sparse_length) const {
  std::pair sparse_data = get_sparse_data_from_vector(vec);
  block.reset((void *)sparse_data.first);
  sparse_length = sparse_data.second;
  return 0;
}

//! Return the neighbor list of `id` on graph level `level`.
const Neighbors HnswSparseBuilderEntity::get_neighbors(level_t level,
                                                       node_id_t id) const {
  const NeighborsHeader *hd = get_neighbor_header(level, id);
  return {hd->neighbor_cnt, hd->neighbors};
}

//! Overwrite the neighbor list of `id` on `level` with the ids (`.first` of
//! each pair) in `neighbors`. No capacity check is done here; presumably the
//! caller bounds neighbors.size() by the level's max neighbor count.
int HnswSparseBuilderEntity::update_neighbors(
    level_t level, node_id_t id,
    const std::vector> &neighbors) {
  NeighborsHeader *hd =
      const_cast(get_neighbor_header(level, id));
  for (size_t i = 0; i < neighbors.size(); ++i) {
    hd->neighbors[i] = neighbors[i].first;
  }
  hd->neighbor_cnt = neighbors.size();
  return 0;
}

//! Append `neighbor_id` to the neighbor list of `id` on `level`. The size
//! argument is ignored and no capacity check is performed.
void HnswSparseBuilderEntity::add_neighbor(level_t level, node_id_t id,
                                           uint32_t /*size*/,
                                           node_id_t neighbor_id) {
  NeighborsHeader *hd =
      const_cast(get_neighbor_header(level, id));
  hd->neighbors[hd->neighbor_cnt++] = neighbor_id;
  return;
}

//! Dump the graph via dump_segments(), passing the key array and a per-node
//! level lookup; a negative result is propagated, otherwise 0 is returned.
int HnswSparseBuilderEntity::dump(const IndexDumper::Pointer &dumper) {
  key_t *keys =
      reinterpret_cast(const_cast(keys_buffer_.data()));
  auto ret =
      dump_segments(dumper, keys, [&](node_id_t id) { return get_level(id); });
  if (ailego_unlikely(ret < 0)) {
    return ret;
  }
  return 0;
}

}  // namespace core
}  // namespace zvec
================================================
FILE: src/core/algorithm/hnsw_sparse/hnsw_sparse_builder_entity.h
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

// NOTE(review): the header name of the bare #include below and the template
// arguments in this file appear to have been stripped during extraction;
// restore them from the upstream source.
#include
#include "hnsw_sparse_entity.h"

namespace zvec {
namespace core {

//! Builder-side HNSW sparse entity: keeps node records, keys, neighbor lists
//! and sparse payloads in in-memory buffers and dumps them via dump().
class HnswSparseBuilderEntity : public HnswSparseEntity {
 public:
  //! Add vector and key to hnsw entity, and local id will be saved in id
  virtual int add_vector(level_t level, key_t key, const uint32_t sparse_count,
                         const uint32_t *sparse_indices, const void *sparse_vec,
                         node_id_t *id) override;

  //! Get primary key of the node id
  virtual key_t get_key(node_id_t id) const override;

  //! Get the fixed-size node record (sparse offset/length meta) by node id
  virtual const void *get_vector_meta(node_id_t id) const override;

  virtual int get_vector_meta(const node_id_t id,
                              IndexStorage::MemoryBlock &block) const override;

  //! Batch get node records by node ids
  virtual int get_vector_metas(const node_id_t *ids, uint32_t count,
                               const void **vecs) const override;

  virtual int get_vector_metas(
      const node_id_t *ids, uint32_t count,
      std::vector &block_vecs) const override;

  //! Get the header of the node id's neighbor list on graph level. Level 0
  //! lists live in neighbors_buffer_ at a fixed stride; upper-level lists
  //! live in upper_neighbors_buffer_ from the node's recorded offset.
  const NeighborsHeader *get_neighbor_header(level_t level,
                                             node_id_t id) const {
    if (level == 0) {
      return reinterpret_cast(
          neighbors_buffer_.data() + neighbors_size_ * id);
    } else {
      size_t offset = neighbors_index_[id].offset;
      return reinterpret_cast(
          upper_neighbors_buffer_.data() + offset +
          (level - 1) * upper_neighbors_size_);
    }
  }

  //! Get the node id's neighbors on graph level
  virtual const Neighbors get_neighbors(level_t level,
                                        node_id_t id) const override;

  //! Replace node id in level's neighbors
  virtual int update_neighbors(
      level_t level, node_id_t id,
      const std::vector> &neighbors) override;

  //! add a neighbor to id in graph level
  virtual void add_neighbor(level_t level, node_id_t id, uint32_t size,
                            node_id_t neighbor_id) override;

  //! Get sparse payload at a byte offset into the builder's sparse buffer
  virtual const void *get_sparse_data(uint64_t offset,
                                      uint32_t len) const override;

  //! Get sparse data from id
  virtual const void *get_sparse_data(node_id_t id) const override;

  virtual int get_sparse_data(uint64_t offset, uint32_t len,
                              IndexStorage::MemoryBlock &block) const override;

  virtual int get_sparse_data(const node_id_t id,
                              IndexStorage::MemoryBlock &block) const override;

  //! Get sparse data from vector
  virtual std::pair get_sparse_data_from_vector(
      const void *vec) const override;

  virtual int get_sparse_data_from_vector(const void *vec,
                                          IndexStorage::MemoryBlock &block,
                                          int &sparse_length) const override;

  //! Dump the hnsw graph to dumper
  virtual int dump(const IndexDumper::Pointer &dumper) override;

  //! Cleanup the entity
  virtual int cleanup(void) override;

 public:
  //! Constructor
  HnswSparseBuilderEntity();

  //! Get the node graph level by id
  level_t get_level(node_id_t id) const {
    return neighbors_index_[id].level;
  }

  //! Init builder entity (computes node/neighbor sizes)
  int init();

  //! reserve buffer space for documents
  //! @param docs number of documents
  //! @param total_sparse_count total number of non-zero entries over all docs
  int reserve_space(size_t docs, size_t total_sparse_count);

  //! Set memory quota params (0 disables the quota checks)
  inline void set_memory_quota(size_t memory_quota) {
    memory_quota_ = memory_quota;
  }

  //! Get neighbors size (one level-0 list: header + l0_neighbor_cnt() ids)
  inline size_t neighbors_size() const {
    return sizeof(NeighborsHeader) + l0_neighbor_cnt() * sizeof(node_id_t);
  }

  //! Get upper neighbors size (one upper-level list)
  inline size_t upper_neighbors_size() const {
    return sizeof(NeighborsHeader) + upper_neighbor_cnt() * sizeof(node_id_t);
  }

 public:
  HnswSparseBuilderEntity(const HnswSparseBuilderEntity &) = delete;
  HnswSparseBuilderEntity &operator=(const HnswSparseBuilderEntity &) = delete;

 private:
  friend class HnswSparseSearcherEntity;

  //! class internal used only
  struct SparseNeighborIndex {
    SparseNeighborIndex(size_t off, level_t l) : offset(off), level(l) {}

    // byte offset of the node's first list in upper_neighbors_buffer_
    uint64_t offset : 48;
    // node's graph level; upper lists exist for levels 1..level
    uint64_t level : 16;
  };

  std::string vectors_buffer_{};          // fixed-size node records
  std::string keys_buffer_{};             // primary keys, indexed by node id
  std::string neighbors_buffer_{};        // level 0 neighbors buffer
  std::string upper_neighbors_buffer_{};  // upper layer neighbors buffer
  std::string sparse_data_buffer_{};      // encoded sparse payloads
  size_t sparse_data_offset_{0};  // next write offset in sparse_data_buffer_
  // upper layer offset + level in upper_neighbors_buffer_
  std::vector neighbors_index_{};
  size_t memory_quota_{0UL};
  size_t neighbors_size_{0U};        // level 0 neighbors size
  size_t upper_neighbors_size_{0U};  // upper layer neighbors size
  size_t padding_size_{};  // padding size for each vector element
};

}  // namespace core
}  // namespace zvec
================================================
FILE: src/core/algorithm/hnsw_sparse/hnsw_sparse_chunk.cc
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "hnsw_sparse_chunk.h"
// NOTE(review): the header names of the bare #include lines below, and the
// template arguments of std::pair in this file, were lost when this file was
// extracted; restore them from the upstream source.
#include
#include
#include
#include
#include
#include
#include
#include

namespace zvec {
namespace core {

//! Initialize an empty index: append the page-aligned meta segment and write
//! a freshly cleared HnswSparseChunkMeta into it.
int SparseChunkBroker::init_storage(size_t chunk_size) {
  chunk_meta_.clear();
  chunk_meta_.chunk_size = chunk_size;
  chunk_meta_.create_time = ailego::Realtime::Seconds();
  stats_.set_create_time(chunk_meta_.create_time);
  chunk_meta_.update_time = ailego::Realtime::Seconds();
  stats_.set_update_time(chunk_meta_.update_time);

  //! alloc meta chunk
  size_t size = sizeof(HnswSparseChunkMeta);
  size = (size + page_mask_) & (~page_mask_);
  const std::string segment_id =
      make_segment_id(CHUNK_TYPE_META, kDefaultChunkSeqId);
  int ret = stg_->append(segment_id, size);
  if (ailego_unlikely(ret != 0)) {
    LOG_ERROR("Storage append segment failed for %s", IndexError::What(ret));
    return ret;
  }
  chunk_meta_segment_ = get_chunk(CHUNK_TYPE_META, kDefaultChunkSeqId);
  if (ailego_unlikely(!chunk_meta_segment_)) {
    LOG_ERROR("Get meta segment failed");
    return IndexError_Runtime;
  }

  //! update meta info and write to storage
  chunk_meta_.chunk_cnts[CHUNK_TYPE_META] += 1;
  chunk_meta_.total_size += size;
  (*stats_.mutable_index_size()) += size;
  size = chunk_meta_segment_->write(0UL, &chunk_meta_,
                                    sizeof(HnswSparseChunkMeta));
  if (ailego_unlikely(size != sizeof(HnswSparseChunkMeta))) {
    LOG_ERROR("Storage write data failed, wsize=%zu", size);
    return IndexError_WriteData;
  }
  return 0;
}

//! Load the meta record of an existing index and check that its chunk size
//! matches `chunk_size`; publishes checkpoint/revision/times to stats_.
int SparseChunkBroker::load_storage(size_t chunk_size) {
  IndexStorage::MemoryBlock data_block;
  size_t size = chunk_meta_segment_->read(0UL, data_block,
                                          chunk_meta_segment_->data_size());
  if (size != sizeof(HnswSparseChunkMeta)) {
    LOG_ERROR("Invalid hnsw meta chunk, read size=%zu chunk size=%zu", size,
              chunk_meta_segment_->data_size());
    return IndexError_InvalidFormat;
  }
  std::memcpy(&chunk_meta_, data_block.data(), size);
  if (chunk_meta_.chunk_size != chunk_size) {
    LOG_ERROR(
        "Params hnsw chunk size=%zu mismatch from previous %zu "
        "in index",
        chunk_size, (size_t)chunk_meta_.chunk_size);
    return IndexError_Mismatch;
  }

  *stats_.mutable_check_point() = stg_->check_point();
  stats_.set_revision_id(chunk_meta_.revision_id);
  stats_.set_update_time(chunk_meta_.update_time);
  stats_.set_create_time(chunk_meta_.create_time);

  char create_time[32];
  char update_time[32];
  ailego::Realtime::Gmtime(chunk_meta_.create_time, "%Y-%m-%d %H:%M:%S",
                           create_time, sizeof(create_time));
  ailego::Realtime::Gmtime(chunk_meta_.update_time, "%Y-%m-%d %H:%M:%S",
                           update_time, sizeof(update_time));
  LOG_DEBUG(
      "Load index, indexSize=%zu chunkSize=%zu nodeChunks=%zu "
      "upperNeighborChunks=%zu revisionId=%zu "
      "createTime=%s updateTime=%s",
      (size_t)chunk_meta_.total_size, (size_t)chunk_meta_.chunk_size,
      (size_t)chunk_meta_.chunk_cnts[CHUNK_TYPE_NODE],
      (size_t)chunk_meta_.chunk_cnts[CHUNK_TYPE_UPPER_NEIGHBOR],
      (size_t)chunk_meta_.revision_id, create_time, update_time);
  return 0;
}

//! Attach to `stg`: create a new index when it has no meta segment, else load
//! and validate the existing one. On failure the broker is left detached, so
//! open() may be retried and close() does not touch the storage.
int SparseChunkBroker::open(IndexStorage::Pointer stg, size_t max_index_size,
                            size_t chunk_size, bool check_crc) {
  if (ailego_unlikely(stg_)) {
    LOG_ERROR("An storage instance is already opened");
    return IndexError_Duplicate;
  }
  stg_ = std::move(stg);
  check_crc_ = check_crc;
  max_chunks_size_ = max_index_size;
  dirty_ = false;

  const std::string segment_id =
      make_segment_id(CHUNK_TYPE_META, kDefaultChunkSeqId);
  chunk_meta_segment_ = stg_->get(segment_id);
  int ret = 0;
  if (!chunk_meta_segment_) {
    LOG_DEBUG("Create new index");
    ret = init_storage(chunk_size);
  } else {
    ret = load_storage(chunk_size);
  }
  if (ailego_unlikely(ret != 0)) {
    // Drop the half-initialized state: keeping stg_ would make every later
    // open() fail with IndexError_Duplicate, and close() would flush through
    // a meta segment that is missing (null) or failed validation.
    chunk_meta_segment_.reset();
    stg_.reset();
  }
  return ret;
}

//! Flush (only if open() succeeded) and detach from the storage. Safe to call
//! on a broker that was never opened or whose open() failed.
int SparseChunkBroker::close(void) {
  if (stg_ && chunk_meta_segment_) {
    flush(0UL);
  }
  chunk_meta_segment_.reset();
  stg_.reset();
  check_crc_ = false;
  dirty_ = false;
  return 0;
}

//! Rewrite the meta record (with a fresh update_time), then refresh and flush
//! the storage at `checkpoint`. A short meta write is only logged; the result
//! of the storage flush is returned.
int SparseChunkBroker::flush(uint64_t checkpoint) {
  ailego_assert_with(chunk_meta_segment_, "invalid meta segment");
  // Release builds: fail gracefully instead of dereferencing null.
  if (ailego_unlikely(!stg_ || !chunk_meta_segment_)) {
    LOG_ERROR("Init storage first");
    return IndexError_Uninitialized;
  }

  chunk_meta_.update_time = ailego::Realtime::Seconds();
  stats_.set_update_time(chunk_meta_.update_time);
  size_t size = chunk_meta_segment_->write(0UL, &chunk_meta_,
                                           sizeof(HnswSparseChunkMeta));
  if (ailego_unlikely(size != sizeof(HnswSparseChunkMeta))) {
    LOG_ERROR("Storage write data failed, wsize=%zu", size);
  }
  stg_->refresh(checkpoint);
  int ret = stg_->flush();
  if (ret == 0) {
    (*stats_.mutable_check_point()) = checkpoint;
  } else {
    LOG_ERROR("Storage flush failed for %s", IndexError::What(ret));
  }
  return ret;
}

//! Return {0, chunk} for chunk (type, seq_id), appending a new page-aligned
//! segment of `size` bytes when it does not exist yet and recording it in the
//! meta record. Not thread-safe.
std::pair SparseChunkBroker::alloc_chunk(int type, uint64_t seq_id,
                                         size_t size) {
  ailego_assert_with(type < CHUNK_TYPE_MAX, "chunk type overflow");
  SparseChunk::Pointer chunk;
  if (ailego_unlikely(!stg_)) {
    LOG_ERROR("Init storage first");
    return std::make_pair(IndexError_Uninitialized, chunk);
  }

  //! Check whether a chunk with the same name already exists.
  // NOTE(review): an existing chunk is rejected when its capacity equals the
  // requested (not yet page-aligned) size AND it holds no data; any other
  // existing chunk is reused. The log message suggests the test might be
  // inverted -- confirm the intent against the dense HNSW chunk broker.
  chunk = get_chunk(type, seq_id);
  if (chunk) {
    if (ailego_unlikely(chunk->capacity() == size &&
                        chunk->data_size() == 0UL)) {
      LOG_ERROR("Exist invalid chunk size %zu, expect size %zu",
                chunk->capacity(), size);
      chunk.reset();
      return std::make_pair(IndexError_Runtime, chunk);
    }
    return std::make_pair(0, chunk);
  }

  //! align to page size
  size = (size + page_mask_) & (~page_mask_);
  if (ailego_unlikely(chunk_meta_.total_size + size >= max_chunks_size_)) {
    LOG_ERROR("No space to new a chunk, curIndexSize=%zu allocSize=%zu",
              (size_t)chunk_meta_.total_size, size);
    return std::make_pair(IndexError_IndexFull, chunk);
  }
  std::string segment_id = make_segment_id(type, seq_id);
  int ret = stg_->append(segment_id, size);
  if (ailego_unlikely(ret != 0)) {
    LOG_ERROR("Storage append segment failed for %s", IndexError::What(ret));
    return std::make_pair(ret, chunk);
  }

  chunk_meta_.chunk_cnts[type] += 1;
  chunk_meta_.total_size += size;
  (*stats_.mutable_index_size()) += size;
  const size_t wsize = chunk_meta_segment_->write(0UL, &chunk_meta_,
                                                  sizeof(HnswSparseChunkMeta));
  if (ailego_unlikely(wsize != sizeof(HnswSparseChunkMeta))) {
    // The segment append succeeded; it is the meta record write that failed.
    LOG_ERROR("Storage write data failed, wsize=%zu", wsize);
  }
  chunk = get_chunk(type, seq_id);
  return std::make_pair(chunk ? 0 : IndexError_NoMemory, chunk);
}

//! Look up the segment of chunk (type, seq_id); null if it does not exist.
SparseChunk::Pointer SparseChunkBroker::get_chunk(int type,
                                                  uint64_t seq_id) const {
  std::string segment_id = make_segment_id(type, seq_id);
  return stg_->get(segment_id);
}

}  // namespace core
}  // namespace zvec
================================================
FILE: src/core/algorithm/hnsw_sparse/hnsw_sparse_chunk.h
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once #include #include #include #include #include #include #include #include #include #include #include #include #include namespace zvec { namespace core { using SparseChunk = IndexStorage::Segment; class SparseChunkBroker { public: typedef std::shared_ptr Pointer; enum CHUNK_TYPE { CHUNK_TYPE_HEADER = 1, CHUNK_TYPE_META = 2, CHUNK_TYPE_NODE = 3, CHUNK_TYPE_UPPER_NEIGHBOR = 4, CHUNK_TYPE_NEIGHBOR_INDEX = 5, CHUNK_TYPE_SPARSE_NODE = 6, CHUNK_TYPE_MAX = 8 }; static constexpr size_t kDefaultChunkSeqId = 0UL; SparseChunkBroker(IndexStreamer::Stats &stats) : stats_(stats) { page_mask_ = ailego::MemoryHelper::PageSize() - 1; } //! Open storage int open(IndexStorage::Pointer stg, size_t max_index_size, size_t chunk_size, bool check_crc); int close(void); int flush(uint64_t checkpoint); //! alloc a new chunk with size, not thread-safe std::pair alloc_chunk(int type, uint64_t seq_id, size_t size); //! alloc a new chunk with chunk size inline std::pair alloc_chunk(int type, uint64_t seq_id) { return alloc_chunk(type, seq_id, chunk_meta_.chunk_size); } SparseChunk::Pointer get_chunk(int type, uint64_t seq_id) const; inline size_t get_chunk_cnt(int type) const { ailego_assert_with(type < CHUNK_TYPE_MAX, "chunk type overflow"); return chunk_meta_.chunk_cnts[type]; } inline bool dirty(void) const { return dirty_; } inline void mark_dirty(void) { if (!dirty_) { dirty_ = true; chunk_meta_.revision_id += 1; stats_.set_revision_id(chunk_meta_.revision_id); } } const IndexStorage::Pointer storage(void) const { return stg_; } private: SparseChunkBroker(const SparseChunkBroker &) = delete; SparseChunkBroker &operator=(const SparseChunkBroker &) = delete; struct HnswSparseChunkMeta { HnswSparseChunkMeta(void) { memset(this, 0, sizeof(HnswSparseChunkMeta)); } void clear() { memset(this, 0, sizeof(HnswSparseChunkMeta)); } uint64_t chunk_cnts[CHUNK_TYPE_MAX]; uint64_t chunk_size; // size of per chunk uint64_t total_size; // total size of allocated chunk uint64_t revision_id; // 
index revision uint64_t create_time; uint64_t update_time; uint64_t reserved[3]; }; static_assert(sizeof(HnswSparseChunkMeta) % 32 == 0, "HnswSparseChunkMeta must be aligned with 32 bytes"); //! Init the storage after open an empty index int init_storage(size_t chunk_size); //! Load index from storage int load_storage(size_t chunk_size); static inline const std::string make_segment_id(int type, uint64_t seq_id) { return "HnswT" + ailego::StringHelper::ToString(type) + "S" + ailego::StringHelper::ToString(seq_id); } private: IndexStreamer::Stats &stats_; HnswSparseChunkMeta chunk_meta_{}; size_t page_mask_{0UL}; size_t max_chunks_size_{0UL}; IndexStorage::Pointer stg_{}; IndexStorage::Segment::Pointer chunk_meta_segment_{}; bool check_crc_{false}; bool dirty_{false}; // set as true if index is modified , the flag // will not be cleared even if flushed }; } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/hnsw_sparse/hnsw_sparse_context.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "hnsw_sparse_context.h" #include #include "hnsw_sparse_params.h" namespace zvec { namespace core { HnswSparseContext::HnswSparseContext(const IndexMetric::Pointer &metric, const HnswSparseEntity::Pointer &entity) : IndexContext(metric), entity_(entity), dc_(entity_.get(), metric) {} HnswSparseContext::~HnswSparseContext() { visit_filter_.destroy(); } int HnswSparseContext::init(ContextType type) { int ret; uint32_t doc_cnt; type_ = type; switch (type) { case kSparseBuilderContext: ret = visit_filter_.init(VisitFilter::ByteMap, entity_->doc_cnt(), max_scan_num_, negative_probability_); if (ret != 0) { LOG_ERROR("Create filter failed, mode %d", filter_mode_); return ret; } candidates_.limit(max_scan_num_); update_heap_.limit(entity_->l0_neighbor_cnt() + 1); break; case kSparseSearcherContext: ret = visit_filter_.init(filter_mode_, entity_->doc_cnt(), max_scan_num_, negative_probability_); if (ret != 0) { LOG_ERROR("Create filter failed, mode %d", filter_mode_); return ret; } candidates_.limit(max_scan_num_); break; case kSparseStreamerContext: // maxScanNum is unknown if inited from streamer, so the docCnt may // change. 
we need to compute maxScanNum by scan ratio, and preserve // max_doc_cnt space from visit filter doc_cnt = entity_->doc_cnt(); max_scan_num_ = compute_max_scan_num(doc_cnt); reserve_max_doc_cnt_ = doc_cnt + compute_reserve_cnt(doc_cnt); ret = visit_filter_.init(filter_mode_, reserve_max_doc_cnt_, max_scan_num_, negative_probability_); if (ret != 0) { LOG_ERROR("Create filter failed, mode %d", filter_mode_); return ret; } update_heap_.limit(entity_->l0_neighbor_cnt() + 1); candidates_.limit(max_scan_num_); check_need_adjuct_ctx(); break; default: LOG_ERROR("Init context failed"); return IndexError_Runtime; } return 0; } int HnswSparseContext::update(const ailego::Params ¶ms) { LOG_DEBUG("Update hnsw context params"); auto update_visit_filter_param = [&]() { bool need_update = false; std::string p; switch (type_) { case kSparseSearcherContext: p = PARAM_HNSW_SPARSE_SEARCHER_VISIT_BLOOMFILTER_ENABLE; break; case kSparseStreamerContext: p = PARAM_HNSW_SPARSE_STREAMER_VISIT_BLOOMFILTER_ENABLE; break; } if (params.has(p)) { bool bf_enabled; params.get(p, &bf_enabled); if (bf_enabled ^ (filter_mode_ == VisitFilter::BloomFilter)) { need_update = true; filter_mode_ = bf_enabled ? 
VisitFilter::BloomFilter : VisitFilter::ByteMap; } } float prob = negative_probability_; p.clear(); switch (type_) { case kSparseSearcherContext: p = PARAM_HNSW_SPARSE_SEARCHER_VISIT_BLOOMFILTER_NEGATIVE_PROB; break; case kSparseStreamerContext: p = PARAM_HNSW_SPARSE_STREAMER_VISIT_BLOOMFILTER_NEGATIVE_PROB; break; } params.get(p, &prob); if (filter_mode_ == VisitFilter::BloomFilter && std::abs(prob - negative_probability_) > 1e-6) { need_update = true; } if (need_update) { visit_filter_.destroy(); int max_doc_cnt = 0; if (type_ == kSparseSearcherContext) { max_doc_cnt = entity_->doc_cnt(); } else { max_doc_cnt = reserve_max_doc_cnt_; } int ret = visit_filter_.init(filter_mode_, max_doc_cnt, max_scan_num_, negative_probability_); if (ret != 0) { LOG_ERROR("Create filter failed, mode %d", filter_mode_); return ret; } } return 0; }; switch (type_) { case kSparseSearcherContext: if (params.has(PARAM_HNSW_SPARSE_SEARCHER_EF)) { params.get(PARAM_HNSW_SPARSE_SEARCHER_EF, &ef_); topk_heap_.limit(std::max(topk_, ef_)); } if (params.has(PARAM_HNSW_SPARSE_SEARCHER_MAX_SCAN_RATIO)) { params.get(PARAM_HNSW_SPARSE_SEARCHER_MAX_SCAN_RATIO, &max_scan_ratio_); max_scan_num_ = static_cast(max_scan_ratio_ * entity_->doc_cnt()); max_scan_num_ = std::max(10000U, max_scan_num_); } if (params.has(PARAM_HNSW_SPARSE_SEARCHER_BRUTE_FORCE_THRESHOLD)) { params.get(PARAM_HNSW_SPARSE_SEARCHER_BRUTE_FORCE_THRESHOLD, &bruteforce_threshold_); } return update_visit_filter_param(); case kSparseStreamerContext: if (params.has(PARAM_HNSW_SPARSE_STREAMER_EF)) { params.get(PARAM_HNSW_SPARSE_STREAMER_EF, &ef_); topk_heap_.limit(std::max(topk_, ef_)); } params.get(PARAM_HNSW_SPARSE_STREAMER_EF, &ef_); params.get(PARAM_HNSW_SPARSE_STREAMER_MAX_SCAN_RATIO, &max_scan_ratio_); params.get(PARAM_HNSW_SPARSE_STREAMER_MAX_SCAN_LIMIT, &max_scan_limit_); params.get(PARAM_HNSW_SPARSE_STREAMER_MIN_SCAN_LIMIT, &min_scan_limit_); if (max_scan_ratio_ <= 0.0f || max_scan_ratio_ > 1.0f) { LOG_ERROR("[%s] must be in range 
(0.0f,1.0f]", PARAM_HNSW_SPARSE_STREAMER_MAX_SCAN_RATIO.c_str()); return IndexError_InvalidArgument; } if (max_scan_limit_ < min_scan_limit_) { LOG_ERROR("[%s] must be >= [%s]", PARAM_HNSW_SPARSE_STREAMER_MAX_SCAN_LIMIT.c_str(), PARAM_HNSW_SPARSE_STREAMER_MIN_SCAN_LIMIT.c_str()); return IndexError_InvalidArgument; } if (params.has(PARAM_HNSW_SPARSE_STREAMER_BRUTE_FORCE_THRESHOLD)) { params.get(PARAM_HNSW_SPARSE_STREAMER_BRUTE_FORCE_THRESHOLD, &bruteforce_threshold_); } return update_visit_filter_param(); default: LOG_ERROR("update context failed, type=%u", type_); return IndexError_Runtime; } } int HnswSparseContext::update_context(ContextType type, const IndexMeta & /*meta*/, const IndexMetric::Pointer &metric, const HnswSparseEntity::Pointer &entity, uint32_t magic_num) { uint32_t doc_cnt; if (ailego_unlikely(type != type_)) { LOG_ERROR( "HnswSparseContext doesn't support shared by different type, " "src=%u dst=%u", type_, type); return IndexError_Unsupported; } magic_ = kInvalidMgic; // TODO: support change filter mode? 
switch (type) { case kSparseBuilderContext: LOG_ERROR("BuildContext doesn't support update"); return IndexError_NotImplemented; case kSparseSearcherContext: if (!visit_filter_.reset(entity->doc_cnt(), max_scan_num_)) { LOG_ERROR("Reset filter failed, mode %d", visit_filter_.get_mode()); return IndexError_Runtime; } candidates_.limit(max_scan_num_); topk_heap_.limit(std::max(topk_, ef_)); break; case kSparseStreamerContext: doc_cnt = entity->doc_cnt(); max_scan_num_ = compute_max_scan_num(doc_cnt); reserve_max_doc_cnt_ = doc_cnt + compute_reserve_cnt(doc_cnt); if (!visit_filter_.reset(reserve_max_doc_cnt_, max_scan_num_)) { LOG_ERROR("Reset filter failed, mode %d", visit_filter_.get_mode()); return IndexError_Runtime; } update_heap_.limit(entity->l0_neighbor_cnt() + 1); candidates_.limit(max_scan_num_); topk_heap_.limit(std::max(topk_, ef_)); break; default: LOG_ERROR("update context failed"); return IndexError_Runtime; } entity_ = entity; dc_.update(entity_.get(), metric); magic_ = magic_num; level_topks_.clear(); return 0; } void HnswSparseContext::fill_random_to_topk_full(void) { static std::mt19937 mt( std::chrono::system_clock::now().time_since_epoch().count()); std::uniform_int_distribution dt(0, entity_->doc_cnt() - 1); std::function gen; node_id_t seqid; std::function myfilter = [](node_id_t) { return false; }; if (this->filter().is_valid()) { myfilter = [&](node_id_t id) { return this->filter()(entity_->get_key(id)); }; } if (topk_heap_.limit() < entity_->doc_cnt() / 2) { gen = [&](void) { return dt(mt); }; } else { // If topk limit is big value, gen sequential id from an random initial seqid = dt(mt); gen = [&](void) { seqid = seqid == (entity_->doc_cnt() - 1) ? 
0 : (seqid + 1); return seqid; }; } for (size_t i = 0; !topk_heap_.full() && i < entity_->doc_cnt(); ++i) { const auto id = gen(); if (!visit_filter_.visited(id) && !myfilter(id)) { visit_filter_.set_visited(id); topk_heap_.emplace(id, dc_.dist(id)); } } return; } } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/hnsw_sparse/hnsw_sparse_context.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include "utility/sparse_utility.h" #include "utility/visit_filter.h" #include "hnsw_sparse_dist_calculator.h" namespace zvec { namespace core { class HnswSparseContext : public IndexContext { public: //! Index Context Pointer typedef std::unique_ptr Pointer; enum ContextType { kUnknownContext = 0, kSparseSearcherContext = 1, kSparseBuilderContext = 2, kSparseStreamerContext = 3, }; //! Construct HnswSparseContext(const IndexMetric::Pointer &metric, const HnswSparseEntity::Pointer &entity); //! Destructor virtual ~HnswSparseContext(); public: //! Set topk of search result virtual void set_topk(uint32_t val) override { topk_ = val; topk_heap_.limit(std::max(val, ef_)); } //! Retrieve search result virtual const IndexDocumentList &result(void) const override { return results_[0]; } //! Retrieve search result virtual const IndexDocumentList &result(size_t idx) const override { return results_[idx]; } //! 
Retrieve result object for output virtual IndexDocumentList *mutable_result(size_t idx) override { ailego_assert_with(idx < results_.size(), "invalid idx"); return &results_[idx]; } //! Retrieve search group result with index virtual const IndexGroupDocumentList &group_result(void) const override { return group_results_[0]; } //! Retrieve search group result with index virtual const IndexGroupDocumentList &group_result( size_t idx) const override { return group_results_[idx]; } virtual uint32_t magic(void) const override { return magic_; } //! Set mode of debug virtual void set_debug_mode(bool enable) override { debug_mode_ = enable; } //! Retrieve mode of debug virtual bool debug_mode(void) const override { return this->debugging(); } //! Retrieve string of debug virtual std::string debug_string(void) const override { char buf[4096]; size_t size = snprintf( buf, sizeof(buf), "scan_cnt=%zu,get_vector_cnt=%u,get_neighbors_cnt=%u,dup_node=%u", get_scan_num(), stats_get_vector_cnt_, stats_get_neighbors_cnt_, stats_visit_dup_cnt_); return std::string(buf, size); } //! Update the parameters of context virtual int update(const ailego::Params ¶ms) override; public: //! Init context int init(ContextType type); //! Update context, the context may be shared by different searcher/streamer int update_context(ContextType type, const IndexMeta &meta, const IndexMetric::Pointer &metric, const HnswSparseEntity::Pointer &entity, uint32_t magic_num); inline const HnswSparseEntity &get_entity() const { return *entity_; } inline void resize_results(size_t size) { if (group_by_search()) { group_results_.resize(size); } else { results_.resize(size); } } inline void topk_to_result() { return topk_to_result(0); } //! 
Construct result from topk heap, result will be normalized inline void topk_to_result(uint32_t idx) { if (group_by_search()) { topk_to_group_result(idx); } else { topk_to_single_result(idx); } } inline void recal_topk_dist() { TopkHeap heap(topk_heap_); topk_heap_.clear(); for (size_t i = 0; i < heap.size(); ++i) { node_id_t id = heap[i].first; dist_t dist = dc_.dist(id); topk_heap_.emplace_back(id, dist); } } inline void topk_to_single_result(uint32_t idx) { if (force_padding_topk_ && !topk_heap_.full() && topk_heap_.size() < entity_->doc_cnt()) { this->fill_random_to_topk_full(); } if (ailego_unlikely(topk_heap_.size() == 0)) { return; } ailego_assert_with(idx < results_.size(), "invalid idx"); int size = std::min(topk_, static_cast(topk_heap_.size())); topk_heap_.sort(); results_[idx].clear(); for (int i = 0; i < size; ++i) { auto score = topk_heap_[i].second; if (score > this->threshold()) { break; } node_id_t id = topk_heap_[i].first; if (fetch_vector_) { IndexSparseDocument sparse_doc; IndexStorage::MemoryBlock vec_block; entity_->get_sparse_data(id, vec_block); const void *sparse_data = vec_block.data(); if (sparse_data != nullptr) { SparseUtility::ReverseSparseFormat(sparse_data, sparse_doc, entity_->sparse_unit_size()); } results_[idx].emplace_back(entity_->get_key(id), score, id, entity_->get_vector_meta(id), sparse_doc); } else { results_[idx].emplace_back(entity_->get_key(id), score, id); } } return; } //! 
Construct result from topk heap, result will be normalized inline void topk_to_group_result(uint32_t idx) { ailego_assert_with(idx < group_results_.size(), "invalid idx"); group_results_[idx].clear(); std::vector> group_topk_list; std::vector> best_score_in_groups; for (auto itr = group_topk_heaps_.begin(); itr != group_topk_heaps_.end(); itr++) { const std::string &group_id = (*itr).first; auto &heap = (*itr).second; heap.sort(); if (heap.size() > 0) { float best_score = heap[0].second; best_score_in_groups.push_back(std::make_pair(group_id, best_score)); } } std::sort(best_score_in_groups.begin(), best_score_in_groups.end(), [](const std::pair &a, const std::pair &b) -> int { return a.second < b.second; }); // truncate to group num for (uint32_t i = 0; i < group_num() && i < best_score_in_groups.size(); ++i) { const std::string &group_id = best_score_in_groups[i].first; group_topk_list.emplace_back( std::make_pair(group_id, group_topk_heaps_[group_id])); } group_results_[idx].resize(group_topk_list.size()); for (uint32_t i = 0; i < group_topk_list.size(); ++i) { const std::string &group_id = group_topk_list[i].first; group_results_[idx][i].set_group_id(group_id); uint32_t size = std::min( group_topk_, static_cast(group_topk_list[i].second.size())); for (uint32_t j = 0; j < size; ++j) { auto score = group_topk_list[i].second[j].second; if (score > this->threshold()) { break; } node_id_t id = group_topk_list[i].second[j].first; if (fetch_vector_) { group_results_[idx][i].mutable_docs()->emplace_back( entity_->get_key(id), score, id, entity_->get_vector_meta(id)); } else { group_results_[idx][i].mutable_docs()->emplace_back( entity_->get_key(id), score, id); } } } } inline void reset_query(const void *query) { dc_.reset_query(query); dc_.clear_compare_cnt(); } inline HnswSparseDistCalculator &dist_calculator() { return dc_; } inline TopkHeap &topk_heap() { return topk_heap_; } inline TopkHeap &update_heap() { return update_heap_; } inline VisitFilter &visit_filter() 
{ return visit_filter_; } inline CandidateHeap &candidates() { return candidates_; } inline void set_max_scan_num(uint32_t max_scan_num) { max_scan_num_ = max_scan_num; } inline void set_max_scan_limit(uint32_t max_scan_limit) { max_scan_limit_ = max_scan_limit; } inline void set_min_scan_limit(uint32_t min_scan_limit) { min_scan_limit_ = min_scan_limit; } inline void set_ef(uint32_t v) { ef_ = v; } inline void set_filter_mode(uint32_t v) { filter_mode_ = v; } inline void set_filter_negative_probability(float v) { negative_probability_ = v; } inline void set_max_scan_ratio(float v) { max_scan_ratio_ = v; } virtual void set_magic(uint32_t v) { magic_ = v; } virtual void set_force_padding_topk(bool v) { force_padding_topk_ = v; } virtual void set_bruteforce_threshold(uint32_t v) override { bruteforce_threshold_ = v; } inline uint32_t get_bruteforce_threshold() const { return bruteforce_threshold_; } virtual void set_fetch_vector(bool v) override { fetch_vector_ = v; } virtual bool fetch_vector() const override { return fetch_vector_; } //! Reset context void reset(void) override { set_filter(nullptr); reset_threshold(); set_fetch_vector(false); set_group_params(0, 0); reset_group_by(); } inline std::map &group_topk_heaps() { return group_topk_heaps_; } inline TopkHeap &level_topk(int level) { if (ailego_unlikely(level_topks_.size() <= static_cast(level))) { int cur_level = level_topks_.size(); level_topks_.resize(level + 1); for (; cur_level <= level; ++cur_level) { size_t heap_size = std::max(entity_->neighbor_cnt(cur_level), entity_->ef_construction()); level_topks_[cur_level].clear(); level_topks_[cur_level].limit(heap_size); } } return level_topks_[level]; } inline void check_need_adjuct_ctx(void) { check_need_adjuct_ctx(entity_->doc_cnt()); } inline size_t compute_reserve_cnt(uint32_t cur_doc) const { if (cur_doc > kMaxReserveDocCnt) { return kMaxReserveDocCnt; } else if (cur_doc < kMinReserveDocCnt) { return kMinReserveDocCnt; } return cur_doc; } //! 
candidates heap and visitfilter need to resize as doc cnt growing up inline void check_need_adjuct_ctx(uint32_t doc_cnt) { if (ailego_unlikely(doc_cnt + kTriggerReserveCnt > reserve_max_doc_cnt_)) { while (doc_cnt + kTriggerReserveCnt > reserve_max_doc_cnt_) { reserve_max_doc_cnt_ = reserve_max_doc_cnt_ + compute_reserve_cnt(reserve_max_doc_cnt_); } uint32_t max_scan_cnt = compute_max_scan_num(reserve_max_doc_cnt_); max_scan_num_ = max_scan_cnt; visit_filter_.reset(reserve_max_doc_cnt_, max_scan_cnt); candidates_.clear(); candidates_.limit(max_scan_num_); } } inline uint32_t compute_max_scan_num(uint32_t max_doc_cnt) const { uint32_t max_scan = max_doc_cnt * max_scan_ratio_; if (max_scan < min_scan_limit_) { max_scan = min_scan_limit_; } else if (max_scan > max_scan_limit_) { max_scan = max_scan_limit_; } return max_scan; } inline size_t get_scan_num() const { return dc_.compare_cnt(); } inline uint64_t reach_scan_limit() const { return dc_.compare_cnt() >= max_scan_num_; } inline bool error() const { return dc_.error(); } inline void clear() { dc_.clear(); if (ailego_unlikely(this->debugging())) { stats_get_neighbors_cnt_ = 0u; stats_get_vector_cnt_ = 0u; stats_visit_dup_cnt_ = 0u; } // do not clear results_ for the next query will need it for (auto &it : results_) { it.clear(); } } uint32_t *mutable_stats_get_neighbors() { return &stats_get_neighbors_cnt_; } uint32_t *mutable_stats_get_vector() { return &stats_get_vector_cnt_; } uint32_t *mutable_stats_visit_dup_cnt() { return &stats_visit_dup_cnt_; } inline bool debugging(void) const { return debug_mode_; } inline void update_dist_caculator_distance( const IndexMetric::MatrixSparseDistance &distance) { dc_.update_distance(distance); } //! Get topk inline uint32_t topk() const override { return topk_; } //! Get group topk inline uint32_t group_topk() const { return group_topk_; } //! Get group num inline uint32_t group_num() const { return group_num_; } //! 
Get if group by search inline bool group_by_search() { return group_num_ > 0; } //! Set group params void set_group_params(uint32_t group_num, uint32_t group_topk) override { group_num_ = group_num; group_topk_ = group_topk; topk_ = group_topk_ * group_num_; topk_heap_.limit(std::max(topk_, ef_)); group_topk_heaps_.clear(); } private: // Filling random nodes if topk not full void fill_random_to_topk_full(void); constexpr static uint32_t kTriggerReserveCnt = 4096UL; constexpr static uint32_t kMinReserveDocCnt = 4096UL; constexpr static uint32_t kMaxReserveDocCnt = 128 * 1024UL; constexpr static uint32_t kInvalidMgic = -1U; private: HnswSparseEntity::Pointer entity_; HnswSparseDistCalculator dc_; bool debug_mode_{false}; bool force_padding_topk_{false}; uint32_t max_scan_num_{0}; uint32_t max_scan_limit_{0}; uint32_t min_scan_limit_{0}; uint32_t reserve_max_doc_cnt_{kMinReserveDocCnt}; uint32_t topk_{0}; uint32_t group_topk_{0}; uint32_t filter_mode_{VisitFilter::ByteMap}; float negative_probability_{HnswSparseEntity::kDefaultBFNegativeProbability}; uint32_t ef_{HnswSparseEntity::kDefaultEf}; float max_scan_ratio_{HnswSparseEntity::kDefaultScanRatio}; uint32_t magic_{0U}; std::vector results_{}; std::vector group_results_{}; TopkHeap topk_heap_{}; TopkHeap update_heap_{}; std::vector level_topks_{}; CandidateHeap candidates_{}; VisitFilter visit_filter_{}; uint32_t bruteforce_threshold_{}; bool fetch_vector_{false}; uint32_t group_num_{0}; std::map group_topk_heaps_{}; uint32_t type_{kUnknownContext}; //! 
debug stats info uint32_t stats_get_neighbors_cnt_{0u}; uint32_t stats_get_vector_cnt_{0u}; uint32_t stats_visit_dup_cnt_{0u}; }; } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/hnsw_sparse/hnsw_sparse_dist_calculator.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include "hnsw_sparse_entity.h" namespace zvec { namespace core { class HnswSparseDistCalculator { public: typedef std::shared_ptr Pointer; public: //! Constructor HnswSparseDistCalculator(const HnswSparseEntity *entity, const IndexMetric::Pointer &metric) : entity_(entity), distance_(metric->sparse_distance()), query_{nullptr}, compare_cnt_(0) {} //! Constructor HnswSparseDistCalculator(const HnswSparseEntity *entity, const IndexMetric::Pointer &metric, const void *query) : entity_(entity), distance_(metric->sparse_distance()), query_(query), compare_cnt_(0) {} void update(const HnswSparseEntity *entity, const IndexMetric::Pointer &metric) { entity_ = entity; distance_ = metric->sparse_distance(); } inline void update_distance( const IndexMetric::MatrixSparseDistance &distance) { distance_ = distance; } //! Reset query vector data inline void reset_query(const void *query) { error_ = false; query_ = query; } //! 
Returns distance inline dist_t dist(const void *sparse_data_lhs, const void *sparse_data_rhs) { float score{0.0f}; if (ailego_unlikely(sparse_data_lhs == nullptr || sparse_data_rhs == nullptr)) { // LOG_WARN("Nullptr of sparse vector. Return dense score only"); // error_ = true; return score; } distance_(sparse_data_lhs, sparse_data_rhs, &score); return score; } //! Returns distance between query and vec. inline dist_t dist(const void *vec) { compare_cnt_++; auto sparse_data = entity_->get_sparse_data_from_vector(vec); if (sparse_data.first == nullptr) { error_ = true; return 0.0f; } return dist(sparse_data.first, query_); } //! Return distance between query and node id. inline dist_t dist(node_id_t id) { compare_cnt_++; const void *feat = entity_->get_vector_meta(id); if (ailego_unlikely(feat == nullptr)) { LOG_ERROR("Get nullptr vector, id=%u", id); error_ = true; return 0.0f; } auto sparse_data = entity_->get_sparse_data_from_vector(feat); if (sparse_data.first == nullptr) { error_ = true; return 0.0f; } return dist(sparse_data.first, query_); } //! 
Return dist node lhs between node rhs inline dist_t dist(node_id_t lhs, node_id_t rhs) { compare_cnt_++; const void *feat = entity_->get_vector_meta(lhs); const void *query = entity_->get_vector_meta(rhs); if (ailego_unlikely(feat == nullptr || query == nullptr)) { LOG_ERROR("Get nullptr vector"); error_ = true; return 0.0f; } auto feat_sparse_data = entity_->get_sparse_data_from_vector(feat); if (feat_sparse_data.first == nullptr) { error_ = true; return 0.0f; } auto query_sparse_data = entity_->get_sparse_data_from_vector(query); if (query_sparse_data.first == nullptr) { error_ = true; return 0.0f; } return dist(feat_sparse_data.first, query_sparse_data.first); } dist_t operator()(const void *vec) { return dist(vec); } dist_t operator()(id_t i) { return dist(i); } dist_t operator()(id_t lhs, id_t rhs) { return dist(lhs, rhs); } inline void clear() { compare_cnt_ = 0; error_ = false; } inline void clear_compare_cnt() { compare_cnt_ = 0; } inline bool error() const { return error_; } //! Get distances compute times inline uint32_t compare_cnt() const { return compare_cnt_; } private: HnswSparseDistCalculator(const HnswSparseDistCalculator &) = delete; HnswSparseDistCalculator &operator=(const HnswSparseDistCalculator &) = delete; private: const HnswSparseEntity *entity_; IndexMetric::MatrixSparseDistance distance_; const void *query_; uint32_t compare_cnt_; // record distance compute times bool error_{false}; }; } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/hnsw_sparse/hnsw_sparse_entity.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "hnsw_sparse_entity.h" namespace zvec { namespace core { const std::string HnswSparseEntity::kSparseGraphHeaderSegmentId = "sparse_graph.header"; const std::string HnswSparseEntity::kSparseGraphFeaturesSegmentId = "sparse_graph.features"; const std::string HnswSparseEntity::kSparseGraphKeysSegmentId = "sparse_graph.keys"; const std::string HnswSparseEntity::kSparseGraphNeighborsSegmentId = "sparse_graph.neighbors"; const std::string HnswSparseEntity::kSparseGraphOffsetsSegmentId = "sparse_graph.offsets"; const std::string HnswSparseEntity::kSparseGraphMappingSegmentId = "sparse_graph.mapping"; const std::string HnswSparseEntity::kSparseHnswHeaderSegmentId = "sparse_hnsw.header"; const std::string HnswSparseEntity::kSparseHnswNeighborsSegmentId = "sparse_hnsw.neighbors"; const std::string HnswSparseEntity::kSparseHnswOffsetsSegmentId = "sparse_hnsw.offsets"; const std::string HnswSparseEntity::kSparseGraphVectorsSegmentId = "sparse_graph.vectors"; const std::string HnswSparseEntity::kSparseGraphVectorMetaSegmentId = "sparse_graph.vector_meta"; int HnswSparseEntity::CalcAndAddPadding(const IndexDumper::Pointer &dumper, size_t data_size, size_t *padding_size) { *padding_size = AlignSize(data_size) - data_size; if (*padding_size == 0) { return 0; } std::string padding(*padding_size, '\0'); if (dumper->write(padding.data(), *padding_size) != *padding_size) { LOG_ERROR("Append padding failed, size %lu", *padding_size); return IndexError_WriteData; } return 0; } int64_t HnswSparseEntity::dump_segment(const IndexDumper::Pointer &dumper, const std::string 
&segment_id, const void *data, size_t size) const { size_t len = dumper->write(data, size); if (len != size) { LOG_ERROR("Dump segment %s data failed, expect: %lu, actual: %lu", segment_id.c_str(), size, len); return IndexError_WriteData; } size_t padding_size = AlignSize(size) - size; if (padding_size > 0) { std::string padding(padding_size, '\0'); if (dumper->write(padding.data(), padding_size) != padding_size) { LOG_ERROR("Append padding failed, size %lu", padding_size); return IndexError_WriteData; } } uint32_t crc = ailego::Crc32c::Hash(data, size); int ret = dumper->append(segment_id, size, padding_size, crc); if (ret != 0) { LOG_ERROR("Dump segment %s meta failed, ret=%d", segment_id.c_str(), ret); return ret; } return len + padding_size; } int64_t HnswSparseEntity::dump_header(const IndexDumper::Pointer &dumper, const HNSWSparseHeader &hd) const { //! dump basic graph header. header is aligned and does not need padding int64_t graph_hd_size = dump_segment(dumper, kSparseGraphHeaderSegmentId, &hd.graph, hd.graph.size); if (graph_hd_size < 0) { return graph_hd_size; } //! dump basic graph header. 
header is aligned and does not need padding int64_t hnsw_hd_size = dump_segment(dumper, kSparseHnswHeaderSegmentId, &hd.hnsw, hd.hnsw.size); if (hnsw_hd_size < 0) { return hnsw_hd_size; } return graph_hd_size + hnsw_hd_size; } void HnswSparseEntity::reshuffle_vectors( const std::function & /*get_level*/, std::vector * /*n2o_mapping*/, std::vector * /*o2n_mapping*/, key_t * /*keys*/) const { // TODO return; } int64_t HnswSparseEntity::dump_mapping_segment( const IndexDumper::Pointer &dumper, const key_t *keys) const { std::vector mapping(doc_cnt()); std::iota(mapping.begin(), mapping.end(), 0U); std::sort(mapping.begin(), mapping.end(), [&](node_id_t i, node_id_t j) { return keys[i] < keys[j]; }); size_t size = mapping.size() * sizeof(node_id_t); return dump_segment(dumper, kSparseGraphMappingSegmentId, mapping.data(), size); } int64_t HnswSparseEntity::dump_segments( const IndexDumper::Pointer &dumper, key_t *keys, const std::function &get_level) const { HNSWSparseHeader dump_hd(header()); dump_hd.graph.node_size = sparse_meta_size(); std::vector n2o_mapping; // map new id to origin id std::vector o2n_mapping; // map origin id to new id reshuffle_vectors(get_level, &n2o_mapping, &o2n_mapping, keys); if (!o2n_mapping.empty()) { dump_hd.hnsw.entry_point = o2n_mapping[entry_point()]; } //! Dump header int64_t hd_size = dump_header(dumper, dump_hd); if (hd_size < 0) { return hd_size; } //! Dump vectors int64_t sparse_vector_meta_size = dump_sparse_vector_meta(dumper, n2o_mapping); if (sparse_vector_meta_size < 0) { return sparse_vector_meta_size; } int64_t sparse_vecs_size = dump_sparse_vector(dumper, n2o_mapping); if (sparse_vecs_size < 0) { return sparse_vecs_size; } //! Dump neighbors auto neighbors_size = dump_neighbors(dumper, get_level, n2o_mapping, o2n_mapping); if (neighbors_size < 0) { return neighbors_size; } //! free memory n2o_mapping = std::vector(); o2n_mapping = std::vector(); //! 
Dump keys size_t key_segment_size = doc_cnt() * sizeof(key_t); int64_t keys_size = dump_segment(dumper, kSparseGraphKeysSegmentId, keys, key_segment_size); if (keys_size < 0) { return keys_size; } //! Dump mapping int64_t mapping_size = dump_mapping_segment(dumper, keys); if (mapping_size < 0) { return mapping_size; } return hd_size + keys_size + sparse_vector_meta_size + sparse_vecs_size + neighbors_size + mapping_size; } int64_t HnswSparseEntity::dump_sparse_vector_meta( const IndexDumper::Pointer &dumper, const std::vector &reorder_mapping) const { const void *data = nullptr; uint32_t crc = 0U; size_t dump_size = 0UL; uint64_t sparse_data_offset = 0UL; uint64_t sparse_data_len = 0UL; //! dump vectors for (node_id_t id = 0; id < doc_cnt(); ++id) { data = get_vector_meta(reorder_mapping.empty() ? id : reorder_mapping[id]); if (ailego_unlikely(!data)) { return IndexError_ReadData; } const char *data_ptr = reinterpret_cast(data); sparse_data_len = *((uint32_t *)(data_ptr + sizeof(uint64_t))); size_t len = dumper->write(&sparse_data_offset, sizeof(uint64_t)); if (len != sizeof(uint64_t)) { LOG_ERROR("Dump sparse data offset failed, write=%zu expect=%zu", len, sizeof(uint64_t)); return IndexError_WriteData; } crc = ailego::Crc32c::Hash(&sparse_data_offset, sizeof(uint64_t), crc); dump_size += sizeof(uint64_t); len = dumper->write(&sparse_data_len, sizeof(uint64_t)); if (len != sizeof(uint64_t)) { LOG_ERROR("Dump sparse data len failed, write=%zu expect=%zu", len, sizeof(uint64_t)); return IndexError_WriteData; } crc = ailego::Crc32c::Hash(&sparse_data_len, sizeof(uint64_t), crc); dump_size += sizeof(uint64_t); sparse_data_offset += sparse_data_len; } int ret = dumper->append(kSparseGraphVectorMetaSegmentId, dump_size, 0UL, crc); if (ret != 0) { LOG_ERROR("Dump vectors segment meta failed, ret %d", ret); return ret; } return dump_size; } int64_t HnswSparseEntity::dump_sparse_vector( const IndexDumper::Pointer &dumper, const std::vector &reorder_mapping) const { 
uint32_t crc = 0U; size_t data_size = 0UL; const void *data = nullptr; uint64_t sparse_data_len = 0UL; uint32_t sparse_chunk_index = 0U; uint32_t sparse_chunk_offset = 0U; //! dump vectors for (node_id_t id = 0; id < doc_cnt(); ++id) { data = get_vector_meta(reorder_mapping.empty() ? id : reorder_mapping[id]); if (ailego_unlikely(!data)) { return IndexError_ReadData; } const char *data_ptr = reinterpret_cast(data); sparse_data_len = *((uint32_t *)(data_ptr + sizeof(uint64_t))); uint64_t sparse_offset = *((uint64_t *)(data_ptr)); const void *sparse = get_sparse_data(sparse_offset, sparse_data_len); if (ailego_unlikely(sparse == nullptr)) { LOG_ERROR("Get nullptr sparse, chunk index=%u, chunk offset=%u, len=%zu", sparse_chunk_index, sparse_chunk_offset, (size_t)sparse_data_len); return IndexError_ReadData; } size_t len = dumper->write(sparse, sparse_data_len); if (len != sparse_data_len) { LOG_ERROR("Dump sparse data failed, write=%zu expect=%zu", len, (size_t)sparse_data_len); return IndexError_WriteData; } crc = ailego::Crc32c::Hash(sparse, sparse_data_len, crc); data_size += sparse_data_len; } int ret = dumper->append(kSparseGraphVectorsSegmentId, data_size, 0UL, crc); if (ret != 0) { LOG_ERROR("Dump vectors segment meta failed, ret %d", ret); return ret; } return data_size; } int64_t HnswSparseEntity::dump_graph_neighbors( const IndexDumper::Pointer &dumper, const std::vector &reorder_mapping, const std::vector &neighbor_mapping) const { std::vector graph_meta; graph_meta.reserve(doc_cnt()); size_t offset = 0; uint32_t crc = 0; std::vector mapping(l0_neighbor_cnt()); uint32_t min_neighbor_count = 10000; uint32_t max_neighbor_count = 0; size_t sum_neighbor_count = 0; for (node_id_t id = 0; id < doc_cnt(); ++id) { const Neighbors neighbors = get_neighbors(0, reorder_mapping.empty() ? 
id : reorder_mapping[id]); ailego_assert_with(!!neighbors.data, "invalid neighbors"); ailego_assert_with(neighbors.size() <= l0_neighbor_cnt(), "invalid neighbors"); uint32_t neighbor_count = neighbors.size(); if (neighbor_count < min_neighbor_count) { min_neighbor_count = neighbor_count; } if (neighbor_count > max_neighbor_count) { max_neighbor_count = neighbor_count; } sum_neighbor_count += neighbor_count; graph_meta.emplace_back(offset, neighbor_count); size_t size = neighbors.size() * sizeof(node_id_t); const node_id_t *data = &neighbors[0]; if (!neighbor_mapping.empty()) { for (node_id_t i = 0; i < neighbors.size(); ++i) { mapping[i] = neighbor_mapping[neighbors[i]]; } data = mapping.data(); } if (dumper->write(data, size) != size) { LOG_ERROR("Dump graph neighbor id=%u failed, size %lu", id, size); return IndexError_WriteData; } crc = ailego::Crc32c::Hash(data, size, crc); offset += size; } uint32_t average_neighbor_count = 0; if (doc_cnt() > 0) { average_neighbor_count = sum_neighbor_count / doc_cnt(); } LOG_INFO( "Dump hnsw graph: min_neighbor_count[%u] max_neighbor_count[%u] " "average_neighbor_count[%u]", min_neighbor_count, max_neighbor_count, average_neighbor_count); size_t padding_size = 0; int ret = CalcAndAddPadding(dumper, offset, &padding_size); if (ret != 0) { return ret; } ret = dumper->append(kSparseGraphNeighborsSegmentId, offset, padding_size, crc); if (ret != 0) { LOG_ERROR("Dump segment %s failed, ret %d", kSparseGraphNeighborsSegmentId.c_str(), ret); return ret; } //! 
dump level 0 neighbors meta auto len = dump_segment(dumper, kSparseGraphOffsetsSegmentId, graph_meta.data(), graph_meta.size() * sizeof(SparseGraphNeighborMeta)); if (len < 0) { return len; } return len + offset + padding_size; } int64_t HnswSparseEntity::dump_upper_neighbors( const IndexDumper::Pointer &dumper, const std::function &get_level, const std::vector &reorder_mapping, const std::vector &neighbor_mapping) const { std::vector hnsw_meta; hnsw_meta.reserve(doc_cnt()); size_t offset = 0; uint32_t crc = 0; std::vector buffer(upper_neighbor_cnt() + 1); for (node_id_t id = 0; id < doc_cnt(); ++id) { node_id_t new_id = reorder_mapping.empty() ? id : reorder_mapping[id]; auto level = get_level(new_id); if (level == 0) { hnsw_meta.emplace_back(0U, 0U); continue; } hnsw_meta.emplace_back(offset, level); ailego_assert_with((size_t)level < kMaxGraphLayers, "invalid level"); for (level_t cur_level = 1; cur_level <= level; ++cur_level) { const Neighbors neighbors = get_neighbors(cur_level, new_id); ailego_assert_with(!!neighbors.data, "invalid neighbors"); ailego_assert_with(neighbors.size() <= neighbor_cnt(cur_level), "invalid neighbors"); memset(buffer.data(), 0, sizeof(node_id_t) * buffer.size()); buffer[0] = neighbors.size(); if (neighbor_mapping.empty()) { memcpy(&buffer[1], &neighbors[0], neighbors.size() * sizeof(node_id_t)); } else { for (node_id_t i = 0; i < neighbors.size(); ++i) { buffer[i + 1] = neighbor_mapping[neighbors[i]]; } } if (dumper->write(buffer.data(), sizeof(node_id_t) * buffer.size()) != sizeof(node_id_t) * buffer.size()) { LOG_ERROR("Dump graph neighbor id=%u failed, size %lu", id, sizeof(node_id_t) * buffer.size()); return IndexError_WriteData; } crc = ailego::Crc32c::Hash(buffer.data(), sizeof(node_id_t) * buffer.size(), crc); offset += sizeof(node_id_t) * buffer.size(); } } size_t padding_size = 0; int ret = CalcAndAddPadding(dumper, offset, &padding_size); if (ret != 0) { return ret; } ret = dumper->append(kSparseHnswNeighborsSegmentId, 
offset, padding_size, crc); if (ret != 0) { LOG_ERROR("Dump segment %s failed, ret %d", kSparseHnswNeighborsSegmentId.c_str(), ret); return ret; } //! dump level 0 neighbors meta auto len = dump_segment(dumper, kSparseHnswOffsetsSegmentId, hnsw_meta.data(), hnsw_meta.size() * sizeof(HnswSparseNeighborMeta)); if (len < 0) { return len; } return len + offset + padding_size; } } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/hnsw_sparse/hnsw_sparse_entity.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#pragma once #include #include #include #include #include #include #include namespace zvec { namespace core { using node_id_t = uint32_t; using key_t = uint64_t; using level_t = int32_t; using dist_t = float; using TopkHeap = ailego::KeyValueHeap; using CandidateHeap = ailego::KeyValueHeap>; constexpr node_id_t kInvalidNodeId = static_cast(-1); constexpr key_t kInvalidKey = static_cast(-1); class HnswSparseDistCalculator; struct SparseGraphHeader { uint32_t size; uint32_t version; uint32_t graph_type; uint32_t doc_count; uint32_t vector_size; uint32_t node_size; uint32_t l0_neighbor_count; uint32_t prune_type; uint32_t prune_neighbor_count; uint32_t ef_construction; uint32_t options; uint32_t min_neighbor_count; uint32_t sparse_meta_size; uint32_t sparse_unit_size; uint32_t total_sparse_count; uint8_t reserved[868]; }; static_assert(sizeof(SparseGraphHeader) % 32 == 0, "SparseGraphHeader must be aligned with 32 bytes"); //! Hnsw upper neighbor header struct HnswSparseHeader { uint32_t size; // header size uint32_t revision; // current total docs of the graph uint32_t upper_neighbor_count; uint32_t ef_construction; uint32_t scaling_factor; uint32_t max_level; uint32_t entry_point; uint32_t options; uint8_t reserved[30]; }; struct SparseData { public: SparseData() {}; SparseData(uint32_t sparse_count, const uint32_t *sparse_indices, const void *sparse_vec) : count(sparse_count), indices(sparse_indices), vec(sparse_vec) {} uint32_t count{0}; const uint32_t *indices{nullptr}; const void *vec{nullptr}; }; static_assert(sizeof(HnswSparseHeader) % 32 == 0, "SparseGraphHeader must be aligned with 32 bytes"); //! Hnsw common header and upper neighbor header struct HNSWSparseHeader { HNSWSparseHeader() { clear(); } HNSWSparseHeader(const HNSWSparseHeader &header) { memcpy(this, &header, sizeof(header)); } HNSWSparseHeader &operator=(const HNSWSparseHeader &header) { memcpy(this, &header, sizeof(header)); return *this; } //! 
Reset state to zero, and the params is untouched void inline reset() { graph.doc_count = 0U; hnsw.entry_point = kInvalidNodeId; hnsw.max_level = 0; graph.total_sparse_count = 0U; } //! Clear all fields to init value void inline clear() { memset(this, 0, sizeof(HNSWSparseHeader)); hnsw.entry_point = kInvalidNodeId; graph.size = sizeof(SparseGraphHeader); hnsw.size = sizeof(HnswSparseHeader); graph.total_sparse_count = 0U; } size_t neighbor_cnt() const { return graph.l0_neighbor_count; } size_t upper_neighbor_cnt() const { return hnsw.upper_neighbor_count; } size_t vector_size() const { return graph.vector_size; } size_t ef_construction() const { return graph.ef_construction; } size_t scaling_factor() const { return hnsw.scaling_factor; } size_t neighbor_prune_cnt() const { return graph.prune_neighbor_count; } node_id_t entry_point() const { return hnsw.entry_point; } node_id_t doc_cnt() const { return graph.doc_count; } uint32_t total_sparse_count() const { return graph.total_sparse_count; } SparseGraphHeader graph; HnswSparseHeader hnsw; }; struct NeighborsHeader { uint32_t neighbor_cnt; node_id_t neighbors[0]; }; struct Neighbors { Neighbors() : cnt{0}, data{nullptr} {} Neighbors(uint32_t cnt_in, const node_id_t *data_in) : cnt{cnt_in}, data{data_in} {} Neighbors(IndexStorage::MemoryBlock &&mem_block) : neighbor_block{std::move(mem_block)} { auto hd = reinterpret_cast(neighbor_block.data()); cnt = hd->neighbor_cnt; data = hd->neighbors; } size_t size(void) const { return cnt; } const node_id_t &operator[](size_t idx) const { return data[idx]; } uint32_t cnt; const node_id_t *data; IndexStorage::MemoryBlock neighbor_block; }; //! level 0 neighbors offset struct SparseGraphNeighborMeta { SparseGraphNeighborMeta(size_t o, size_t cnt) : offset(o), neighbor_cnt(cnt) {} uint64_t offset : 48; uint64_t neighbor_cnt : 16; }; //! 
hnsw upper neighbors meta struct HnswSparseNeighborMeta { HnswSparseNeighborMeta(size_t o, size_t l) : offset(o), level(l) {} uint64_t offset : 48; // offset = idx * upper neighors size uint64_t level : 16; }; class HnswSparseEntity { public: //! Constructor HnswSparseEntity() {} //! Constructor HnswSparseEntity(const HNSWSparseHeader &hd) { header_ = hd; } //! Destructor virtual ~HnswSparseEntity() {} //! HnswSparseEntity Pointerd; typedef std::shared_ptr Pointer; //! Get max neighbor size of graph level inline size_t neighbor_cnt(level_t level) const { return level == 0 ? header_.graph.l0_neighbor_count : header_.hnsw.upper_neighbor_count; } //! get max neighbor size of graph level 0 inline size_t l0_neighbor_cnt() const { return header_.graph.l0_neighbor_count; } //! get min neighbor size of graph inline size_t min_neighbor_cnt() const { return header_.graph.min_neighbor_count; } //! get upper neighbor size of graph level other than 0 inline size_t upper_neighbor_cnt() const { return header_.hnsw.upper_neighbor_count; } //! Get current total doc of the hnsw graph inline node_id_t *mutable_doc_cnt() { return &header_.graph.doc_count; } inline node_id_t doc_cnt() const { return header_.graph.doc_count; } inline uint32_t *mutable_total_sparse_count() { return &header_.graph.total_sparse_count; } uint32_t total_sparse_count() const { return header_.graph.total_sparse_count; } //! Get hnsw graph scaling params inline size_t scaling_factor() const { return header_.hnsw.scaling_factor; } //! Get prune_size inline size_t prune_cnt() const { return header_.graph.prune_neighbor_count; } //! Current entity of top level graph inline node_id_t entry_point() const { return header_.hnsw.entry_point; } //! Current max graph level inline level_t cur_max_level() const { return header_.hnsw.max_level; } //! Retrieve index vector size size_t vector_size() const { return header_.graph.vector_size; } //! 
Retrieve node size size_t node_size() const { return header_.graph.node_size; } //! Retrieve ef constuction size_t ef_construction() const { return header_.graph.ef_construction; } //! Retrieve sparse meta size size_t sparse_meta_size() const { return header_.graph.sparse_meta_size; } //! Retrieve sparse unit size size_t sparse_unit_size() const { return header_.graph.sparse_unit_size; } void set_vector_size(size_t size) { header_.graph.vector_size = size; } void set_prune_cnt(size_t v) { header_.graph.prune_neighbor_count = v; } void set_scaling_factor(size_t val) { header_.hnsw.scaling_factor = val; } void set_l0_neighbor_cnt(size_t cnt) { header_.graph.l0_neighbor_count = cnt; } void set_min_neighbor_cnt(size_t cnt) { header_.graph.min_neighbor_count = cnt; } void set_upper_neighbor_cnt(size_t cnt) { header_.hnsw.upper_neighbor_count = cnt; } void set_ef_construction(size_t ef) { header_.graph.ef_construction = ef; } void set_sparse_meta_size(size_t size) { header_.graph.sparse_meta_size = size; } void set_sparse_unit_size(size_t size) { header_.graph.sparse_unit_size = size; } protected: inline const HNSWSparseHeader &header() const { return header_; } inline HNSWSparseHeader *mutable_header() { return &header_; } inline size_t header_size() const { return sizeof(header_); } void set_node_size(size_t size) { header_.graph.node_size = size; } //! Dump all segment by dumper //! Return dump size if success, errno(<0) in failure int64_t dump_segments( const IndexDumper::Pointer &dumper, key_t *keys, const std::function &get_level) const; private: //! dump mapping segment, for get_vector_by_key in provider int64_t dump_mapping_segment(const IndexDumper::Pointer &dumper, const key_t *keys) const; //! dump hnsw head by dumper //! Return dump size if success, errno(<0) in failure int64_t dump_header(const IndexDumper::Pointer &dumper, const HNSWSparseHeader &hd) const; //! dump vectors by dumper //! 
Return dump size if success, errno(<0) in failure int64_t dump_sparse_vector_meta( const IndexDumper::Pointer &dumper, const std::vector &reorder_mapping) const; //! dump sparse vectors by dumper //! Return dump size if success, errno(<0) in failure int64_t dump_sparse_vector( const IndexDumper::Pointer &dumper, const std::vector &reorder_mapping) const; //! dump hnsw neighbors by dumper //! Return dump size if success, errno(<0) in failure int64_t dump_neighbors(const IndexDumper::Pointer &dumper, const std::function &get_level, const std::vector &reorder_mapping, const std::vector &neighbor_mapping) const { auto len1 = dump_graph_neighbors(dumper, reorder_mapping, neighbor_mapping); if (len1 < 0) { return len1; } auto len2 = dump_upper_neighbors(dumper, get_level, reorder_mapping, neighbor_mapping); if (len2 < 0) { return len2; } return len1 + len2; } //! dump segment by dumper //! Return dump size if success, errno(<0) in failure int64_t dump_segment(const IndexDumper::Pointer &dumper, const std::string &segment_id, const void *data, size_t size) const; //! Dump level 0 neighbors //! Return dump size if success, errno(<0) in failure int64_t dump_graph_neighbors( const IndexDumper::Pointer &dumper, const std::vector &reorder_mapping, const std::vector &neighbor_mapping) const; //! Dump upper level neighbors //! Return dump size if success, errno(<0) in failure int64_t dump_upper_neighbors( const IndexDumper::Pointer &dumper, const std::function &get_level, const std::vector &reorder_mapping, const std::vector &neighbor_mapping) const; public: //! Cleanup the entity virtual int cleanup(void) { header_.clear(); return 0; } //! Make a copy of searcher entity, to support thread-safe operation. //! The segment in container cannot be read concurrenly virtual const HnswSparseEntity::Pointer clone() const { LOG_ERROR("Update neighbors not implemented"); return HnswSparseEntity::Pointer(); } //! 
Get primary key of the node id virtual key_t get_key(node_id_t id) const = 0; //! Get vector feature data by key virtual const void *get_vector_meta(node_id_t id) const = 0; virtual int get_vector_meta(const node_id_t id, IndexStorage::MemoryBlock &block) const = 0; //! Get vectors feature data by keys virtual int get_vector_metas(const node_id_t *ids, uint32_t count, const void **vecs) const = 0; virtual int get_vector_metas( const node_id_t *ids, uint32_t count, std::vector &block_vecs) const = 0; //! Retrieve a sparse vector using a primary key virtual int get_sparse_vector_by_key( uint64_t /*key*/, uint32_t * /*sparse_count*/, std::string * /*sparse_indices_buffer*/, std::string * /*sparse_values_buffer*/) const { LOG_ERROR("get sparse vector not implemented"); return IndexError_NotImplemented; } //! Retrieve a sparse vector using a primary key virtual int get_sparse_vector_by_id( node_id_t /*id*/, uint32_t * /*sparse_count*/, std::string * /*sparse_indices_buffer*/, std::string * /*sparse_values_buffer*/) const { LOG_ERROR("get sparse vector not implemented"); return IndexError_NotImplemented; } //! Get vector sparse feature data by chunk index and offset virtual const void *get_sparse_data(uint64_t offset, uint32_t len) const = 0; //! Get sparse data from id virtual const void *get_sparse_data(node_id_t id) const = 0; virtual int get_sparse_data(uint64_t offset, uint32_t len, IndexStorage::MemoryBlock &block) const = 0; virtual int get_sparse_data(const node_id_t id, IndexStorage::MemoryBlock &block) const = 0; //! Get sparse data from vector virtual std::pair get_sparse_data_from_vector( const void *vec) const = 0; virtual int get_sparse_data_from_vector(const void *vec, IndexStorage::MemoryBlock &block, int &sparse_length) const = 0; //! Get the node id's neighbors on graph level //! Note: the neighbors cannot be modified, using the following //! 
method to get WritableNeighbors if want to virtual const Neighbors get_neighbors(level_t level, node_id_t id) const = 0; //! Add vector and key to hnsw entity, and local id will be saved in id virtual int add_vector(level_t /*level*/, key_t /*key*/, const std::string & /*vec*/, uint32_t /*sparse_count*/, node_id_t * /*id*/) { return IndexError_NotImplemented; } virtual int add_vector(level_t /*level*/, key_t /*key*/, const uint32_t /*sparse_count*/, const uint32_t * /*sparse_indices*/, const void * /*sparse_vec*/, node_id_t * /*id*/) { return IndexError_NotImplemented; } //! Add vector and id virtual int add_vector_with_id(level_t /*level*/, node_id_t /*id*/, const std::string & /*vec*/, uint32_t /*sparse_count*/) { return IndexError_NotImplemented; } virtual int update_neighbors( level_t /*level*/, node_id_t /*id*/, const std::vector> & /*neighbors*/) { LOG_ERROR("Update neighbors dense not implemented"); return 0; } //! Append neighbor_id to node id neighbors on level, size is the current //! neighbors size. Notice: the caller must be ensure the neighbors not full virtual void add_neighbor(level_t /*level*/, node_id_t /*id*/, uint32_t /*size*/, node_id_t /*neighbor_id*/) { LOG_ERROR("Add neighbor not implemented"); } //! 
Update entry point and max level virtual void update_ep_and_level(node_id_t ep, level_t level) { header_.hnsw.entry_point = ep; header_.hnsw.max_level = level; } virtual int load(const IndexStorage::Pointer & /*container*/, bool /*check_crc*/) { LOG_ERROR("Load not implemented"); return IndexError_NotImplemented; } virtual int dump(const IndexDumper::Pointer & /*dumper*/) { LOG_ERROR("Dump not implemented"); return IndexError_NotImplemented; } static int CalcAndAddPadding(const IndexDumper::Pointer &dumper, size_t data_size, size_t *padding_size); protected: static inline size_t AlignSize(size_t size) { return (size + 0x1F) & (~0x1F); } static inline size_t AlignPageSize(size_t size) { size_t page_mask = ailego::MemoryHelper::PageSize() - 1; return (size + page_mask) & (~page_mask); } //! rearrange vectors to improve cache locality void reshuffle_vectors(const std::function &get_level, std::vector *n2o_mapping, std::vector *o2n_mapping, key_t *keys) const; public: const static std::string kSparseGraphHeaderSegmentId; const static std::string kSparseGraphFeaturesSegmentId; const static std::string kSparseGraphKeysSegmentId; const static std::string kSparseGraphNeighborsSegmentId; const static std::string kSparseGraphOffsetsSegmentId; const static std::string kSparseGraphMappingSegmentId; const static std::string kSparseHnswHeaderSegmentId; const static std::string kSparseHnswNeighborsSegmentId; const static std::string kSparseHnswOffsetsSegmentId; const static std::string kSparseGraphVectorsSegmentId; const static std::string kSparseGraphVectorMetaSegmentId; constexpr static uint32_t kRevision = 0U; constexpr static size_t kMaxGraphLayers = 15; constexpr static uint32_t kDefaultEfConstruction = 500; constexpr static uint32_t kDefaultEf = 500; constexpr static uint32_t kDefaultUpperMaxNeighborCnt = 50; // M of HNSW constexpr static uint32_t kDefaultL0MaxNeighborCnt = 100; constexpr static uint32_t kMaxNeighborCnt = 65535; constexpr static float kDefaultScanRatio = 
0.1f; constexpr static uint32_t kDefaultMinScanLimit = 10000; constexpr static uint32_t kDefaultMaxScanLimit = std::numeric_limits::max(); constexpr static float kDefaultBFNegativeProbability = 0.001f; constexpr static uint32_t kDefaultScalingFactor = 50U; constexpr static uint32_t kDefaultBruteForceThreshold = 1000U; constexpr static uint32_t kDefaultDocsHardLimit = 1 << 30U; // 1 billion constexpr static float kDefaultDocsSoftLimitRatio = 0.9f; constexpr static size_t kMaxChunkSize = 0xFFFFFFFF; constexpr static size_t kDefaultChunkSize = 2UL * 1024UL * 1024UL; constexpr static size_t kDefaultMaxChunkCnt = 50000UL; constexpr static float kDefaultNeighborPruneMultiplier = 1.0f; // prune_cnt = upper_max_neighbor_cnt * multiplier constexpr static float kDefaultL0MaxNeighborCntMultiplier = 2.0f; // l0_max_neighbor_cnt = upper_max_neighbor_cnt * multiplier constexpr static uint32_t kSparseMetaSize = 2u * sizeof(uint64_t); constexpr static float kDefaultSparseNeighborRatio = 0.5f; constexpr static uint32_t kSparseMaxDimSize = 16384; constexpr static float kDefaultQueryFilteringRatio = 0.0f; // turn off protected: HNSWSparseHeader header_{}; }; } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/hnsw_sparse/hnsw_sparse_index_hash.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#pragma once #include "hnsw_sparse_chunk.h" namespace zvec { namespace core { //! Persistent hashmap implement through open addressing algorithm template ::value>::type> class HnswSparseIndexHashMap { using key_type = Key; using val_type = Val; struct Iterator { key_type first; val_type second; }; typedef Iterator *iterator; typedef Iterator Item; typedef const Iterator *const_iterator; class Slot { public: Slot(SparseChunk::Pointer &&chunk, const void *data) : chunk_(std::move(chunk)), items_(reinterpret_cast(data)) {} //! Return a empty loc or the key item loc Slot(SparseChunk::Pointer &&chunk, IndexStorage::MemoryBlock &&mem_block) : chunk_(std::move(chunk)), items_block_(std::move(mem_block)) { items_ = reinterpret_cast(items_block_.data()); } const_iterator find(key_type key, uint32_t max_items, uint32_t mask) const { auto it = &items_[key & mask]; for (auto i = 0U; i < max_items; ++i) { if (it->first == key || it->second == EmptyVal) { // LOG_DEBUG("i=%u", i); return it; } ++it; if (it == &items_[max_items]) { it = &items_[0]; } } return nullptr; } bool update(const_iterator it) { uint32_t offset = reinterpret_cast(it) - reinterpret_cast(&items_[0]); if (ailego_unlikely(chunk_->write(offset, it, sizeof(Item)) != sizeof(Item))) { LOG_ERROR("Chunk write failed"); return false; } return true; } private: SparseChunk::Pointer chunk_{}; const Item *items_{nullptr}; // point to chunk data IndexStorage::MemoryBlock items_block_{}; }; public: //! Init the hash //! broker the index allocator //! chunk_size the size of per chunk allocated, actual size may greater //! factor factor = 1/ratio, ratio is the probability of a squence //! number inserted to this container //! max the max number key can be inserted //! 
expansion_ratio memory expansion ratio int init(SparseChunkBroker::Pointer &broker, uint32_t chunk_size, uint32_t factor, size_t max, float expansion_ratio) { ailego_assert_with(expansion_ratio > 1.0f, "ratio must > 1.0f"); broker_ = broker; size_t items = std::ceil(chunk_size * 1.0f / sizeof(Item)); slot_items_ = 1UL << static_cast((std::ceil(std::log2(items)))); size_t range = slot_items_ * factor / expansion_ratio; mask_bits_ = std::floor(std::log2(range)); range = 1UL << mask_bits_; size_t max_slots = std::ceil(max * 1.0f / range); slots_.reserve(max_slots); slot_loc_mask_ = slot_items_ - 1U; int ret = load(); if (ret != 0) { return ret; } LOG_DEBUG( "HnswIndexHash init, chunkSize=%u factor=%u max=%zu " "ratio=%f slotItems=%u maxSlots=%zu maskBits=%u " "range=%zu", chunk_size, factor, max, expansion_ratio, slot_items_, max_slots, mask_bits_, range); return 0; } int cleanup(void) { broker_.reset(); slots_.clear(); slots_.shrink_to_fit(); mask_bits_ = 0U; slot_items_ = 0U; slot_loc_mask_ = 0U; return 0; } const_iterator end(void) const { return nullptr; } const_iterator find(const key_type key) const { auto idx = key >> mask_bits_; if (idx >= slots_.size()) { return end(); } auto it = slots_[idx].find(key, slot_items_, slot_loc_mask_); return it && it->second != EmptyVal ? it : nullptr; } bool insert(key_type key, val_type val) { auto idx = key >> mask_bits_; if (idx >= slots_.size()) { if (ailego_unlikely(idx >= slots_.capacity())) { LOG_ERROR("no space to insert"); return false; } for (auto i = slots_.size(); i <= idx; ++i) { if (ailego_unlikely(!alloc_slot(i))) { return false; } } } auto it = slots_[idx].find(key, slot_items_, slot_loc_mask_); if (ailego_unlikely(it == nullptr)) { LOG_ERROR("no space to insert"); return false; } //! TODO: write memory is ok? 
const_cast(it)->first = key; const_cast(it)->second = val; return slots_[idx].update(it); } private: bool alloc_slot(size_t idx) { ailego_assert_with(idx == slots_.size(), "invalid idx"); size_t size = slot_items_ * sizeof(Item); auto p = broker_->alloc_chunk(SparseChunkBroker::CHUNK_TYPE_NEIGHBOR_INDEX, idx, size); if (ailego_unlikely(p.first != 0)) { LOG_ERROR("Alloc data chunk failed"); return false; } SparseChunk::Pointer chunk = p.second; if (ailego_unlikely(chunk->resize(size) != size)) { LOG_ERROR("Chunk resize failed, size=%zu", size); return false; } //! Read the whole data to memory IndexStorage::MemoryBlock data_block; if (ailego_unlikely(chunk->read(0U, data_block, size) != size)) { LOG_ERROR("Chunk read failed, size=%zu", size); return false; } slots_.emplace_back(std::move(chunk), std::move(data_block)); return true; } int load(void) { size_t slots_cnt = broker_->get_chunk_cnt(SparseChunkBroker::CHUNK_TYPE_NEIGHBOR_INDEX); for (size_t i = 0UL; i < slots_cnt; ++i) { auto chunk = broker_->get_chunk(SparseChunkBroker::CHUNK_TYPE_NEIGHBOR_INDEX, i); if (!chunk) { LOG_ERROR("Get chunk failed, seq=%zu", i); return IndexError_InvalidFormat; } size_t size = sizeof(Item) * slot_items_; if (chunk->data_size() < size) { LOG_ERROR( "Hash params may be mismatch, seq=%zu, data_size=%zu " "expect=%zu", i, chunk->data_size(), size); return IndexError_InvalidFormat; } //! 
Read the whole data to memory IndexStorage::MemoryBlock data_block; if (ailego_unlikely(chunk->read(0U, data_block, size) != size)) { LOG_ERROR("Chunk read failed, size=%zu", size); return false; } slots_.emplace_back(std::move(chunk), std::move(data_block)); } return 0; } private: SparseChunkBroker::Pointer broker_{}; // chunk broker std::vector slots_{}; uint32_t mask_bits_{0U}; uint32_t slot_items_{}; // must be a power of 2 uint32_t slot_loc_mask_{}; }; } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/hnsw_sparse/hnsw_sparse_index_provider.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include "hnsw_sparse_entity.h" namespace zvec { namespace core { class HnswSparseIndexProvider : public IndexSparseProvider { public: HnswSparseIndexProvider(const IndexMeta &meta, const HnswSparseEntity::Pointer &entity, const std::string &owner) : meta_(meta), entity_(entity), owner_class_(owner) {} HnswSparseIndexProvider(const HnswSparseIndexProvider &) = delete; HnswSparseIndexProvider &operator=(const HnswSparseIndexProvider &) = delete; public: //! Create a new iterator IndexSparseProvider::Iterator::Pointer create_iterator(void) override { return IndexSparseProvider::Iterator::Pointer(new (std::nothrow) Iterator(entity_)); } //! 
Retrieve count of vectors size_t count(void) const override { return entity_->doc_cnt(); } size_t total_sparse_count(void) const override { return entity_->total_sparse_count(); } //! Retrieve type of vector IndexMeta::DataType data_type(void) const override { return meta_.data_type(); } //! Retrieve a vector using a primary key int get_sparse_vector(uint64_t key, uint32_t *sparse_count, std::string *sparse_indices_buffer, std::string *sparse_values_buffer) const override { return entity_->get_sparse_vector_by_key( key, sparse_count, sparse_indices_buffer, sparse_values_buffer); } //! Retrieve the owner class const std::string &owner_class(void) const override { return owner_class_; } private: class Iterator : public IndexSparseProvider::Iterator { public: Iterator(const HnswSparseEntity::Pointer &entity) : entity_(entity), cur_id_(0U), valid_(false) { const void *sparse_data = entity_->get_sparse_data(cur_id_); if (sparse_data != nullptr) { valid_ = true; sparse_indices_buffer_.clear(); sparse_data_buffer_.clear(); SparseUtility::ReverseSparseFormat( sparse_data, &sparse_count_, &sparse_indices_buffer_, &sparse_data_buffer_, entity_->sparse_unit_size()); } } //! Retrieve sparse count virtual uint32_t sparse_count() const override { return sparse_count_; } //! Retrieve sparse indices virtual const uint32_t *sparse_indices() const override { return reinterpret_cast(sparse_indices_buffer_.data()); } //! Retrieve sparse data virtual const void *sparse_data() const override { return reinterpret_cast(sparse_data_buffer_.data()); } //! Test if the iterator is valid virtual bool is_valid(void) const override { return cur_id_ < entity_->doc_cnt() && valid_; } //! Retrieve primary key virtual uint64_t key(void) const override { return entity_->get_key(cur_id_); } //! 
Next iterator virtual void next(void) override { cur_id_ = get_next_valid_id(cur_id_ + 1); if (cur_id_ < entity_->doc_cnt()) { const void *sparse_data = entity_->get_sparse_data(cur_id_); if (sparse_data != nullptr) { valid_ = true; sparse_indices_buffer_.clear(); sparse_data_buffer_.clear(); SparseUtility::ReverseSparseFormat( sparse_data, &sparse_count_, &sparse_indices_buffer_, &sparse_data_buffer_, entity_->sparse_unit_size()); } else { valid_ = false; } } } //! Reset the iterator void reset(void) { cur_id_ = get_next_valid_id(0); const void *sparse_data = entity_->get_sparse_data(cur_id_); if (sparse_data != nullptr) { valid_ = true; SparseUtility::ReverseSparseFormat( sparse_data, &sparse_count_, &sparse_indices_buffer_, &sparse_data_buffer_, entity_->sparse_unit_size()); } } private: node_id_t get_next_valid_id(node_id_t start_id) { for (node_id_t i = start_id; i < entity_->doc_cnt(); i++) { if (entity_->get_key(i) != kInvalidNodeId) { return i; } } return kInvalidNodeId; } private: const HnswSparseEntity::Pointer entity_; node_id_t cur_id_; uint32_t sparse_count_{0}; std::string sparse_indices_buffer_; std::string sparse_data_buffer_; bool valid_{false}; }; private: const IndexMeta &meta_; const HnswSparseEntity::Pointer entity_; const std::string owner_class_; }; } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/hnsw_sparse/hnsw_sparse_params.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and // limitations under the License. #pragma once #include namespace zvec { namespace core { static const std::string PARAM_HNSW_SPARSE_BUILDER_THREAD_COUNT( "proxima.hnsw.sparse_builder.thread_count"); static const std::string PARAM_HNSW_SPARSE_BUILDER_MEMORY_QUOTA( "proxima.hnsw.sparse_builder.memory_quota"); static const std::string PARAM_HNSW_SPARSE_BUILDER_EFCONSTRUCTION( "proxima.hnsw.sparse_builder.efconstruction"); static const std::string PARAM_HNSW_SPARSE_BUILDER_SCALING_FACTOR( "proxima.hnsw.sparse_builder.scaling_factor"); static const std::string PARAM_HNSW_SPARSE_BUILDER_CHECK_INTERVAL_SECS( "proxima.hnsw.sparse_builder.check_interval_secs"); static const std::string PARAM_HNSW_SPARSE_BUILDER_NEIGHBOR_PRUNE_MULTIPLIER( "proxima.hnsw.sparse_builder.neighbor_prune_multiplier"); static const std::string PARAM_HNSW_SPARSE_BUILDER_MIN_NEIGHBOR_COUNT( "proxima.hnsw.sparse_builder.min_neighbor_count"); static const std::string PARAM_HNSW_SPARSE_BUILDER_MAX_NEIGHBOR_COUNT( "proxima.hnsw.sparse_builder.max_neighbor_count"); static const std::string PARAM_HNSW_SPARSE_BUILDER_L0_MAX_NEIGHBOR_COUNT_MULTIPLIER( "proxima.hnsw.sparse_builder.l0_max_neighbor_count_multiplier"); static const std::string PARAM_HNSW_SPARSE_SEARCHER_EF( "proxima.hnsw.sparse_searcher.ef"); static const std::string PARAM_HNSW_SPARSE_SEARCHER_BRUTE_FORCE_THRESHOLD( "proxima.hnsw.sparse_searcher.brute_force_threshold"); static const std::string PARAM_HNSW_SPARSE_SEARCHER_NEIGHBORS_IN_MEMORY_ENABLE( "proxima.hnsw.sparse_searcher.neighbors_in_memory_enable"); static const std::string PARAM_HNSW_SPARSE_SEARCHER_MAX_SCAN_RATIO( "proxima.hnsw.sparse_searcher.max_scan_ratio"); static const std::string PARAM_HNSW_SPARSE_SEARCHER_CHECK_CRC_ENABLE( "proxima.hnsw.sparse_searcher.check_crc_enable"); static const std::string PARAM_HNSW_SPARSE_SEARCHER_VISIT_BLOOMFILTER_ENABLE( "proxima.hnsw.sparse_searcher.visit_bloomfilter_enable"); 
static const std::string PARAM_HNSW_SPARSE_SEARCHER_VISIT_BLOOMFILTER_NEGATIVE_PROB( "proxima.hnsw.sparse_searcher.visit_bloomfilter_negative_prob"); static const std::string PARAM_HNSW_SPARSE_SEARCHER_FORCE_PADDING_RESULT_ENABLE( "proxima.hnsw.sparse_searcher.force_padding_result_enable"); static const std::string PARAM_HNSW_SPARSE_SEARCHER_QUERY_FILTERING_RATIO( "proxima.hnsw.sparse_searcher.query_filtering_ratio"); static const std::string PARAM_HNSW_SPARSE_STREAMER_MAX_SCAN_RATIO( "proxima.hnsw.sparse_streamer.max_scan_ratio"); static const std::string PARAM_HNSW_SPARSE_STREAMER_MIN_SCAN_LIMIT( "proxima.hnsw.sparse_streamer.min_scan_limit"); static const std::string PARAM_HNSW_SPARSE_STREAMER_MAX_SCAN_LIMIT( "proxima.hnsw.sparse_streamer.max_scan_limit"); static const std::string PARAM_HNSW_SPARSE_STREAMER_EF( "proxima.hnsw.sparse_streamer.ef"); static const std::string PARAM_HNSW_SPARSE_STREAMER_EFCONSTRUCTION( "proxima.hnsw.sparse_streamer.efconstruction"); static const std::string PARAM_HNSW_SPARSE_STREAMER_MAX_NEIGHBOR_COUNT( "proxima.hnsw.sparse_streamer.max_neighbor_count"); static const std::string PARAM_HNSW_SPARSE_STREAMER_L0_MAX_NEIGHBOR_COUNT_MULTIPLIER( "proxima.hnsw.sparse_streamer.l0_max_neighbor_count_multiplier"); static const std::string PARAM_HNSW_SPARSE_STREAMER_SCALING_FACTOR( "proxima.hnsw.sparse_streamer.scaling_factor"); static const std::string PARAM_HNSW_SPARSE_STREAMER_BRUTE_FORCE_THRESHOLD( "proxima.hnsw.sparse_streamer.brute_force_threshold"); static const std::string PARAM_HNSW_SPARSE_STREAMER_DOCS_HARD_LIMIT( "proxima.hnsw.sparse_streamer.docs_hard_limit"); static const std::string PARAM_HNSW_SPARSE_STREAMER_DOCS_SOFT_LIMIT( "proxima.hnsw.sparse_streamer.docs_soft_limit"); static const std::string PARAM_HNSW_SPARSE_STREAMER_MAX_INDEX_SIZE( "proxima.hnsw.sparse_streamer.max_index_size"); static const std::string PARAM_HNSW_SPARSE_STREAMER_VISIT_BLOOMFILTER_ENABLE( "proxima.hnsw.sparse_streamer.visit_bloomfilter_enable"); static 
const std::string PARAM_HNSW_SPARSE_STREAMER_VISIT_BLOOMFILTER_NEGATIVE_PROB( "proxima.hnsw.sparse_streamer.visit_bloomfilter_negative_prob"); static const std::string PARAM_HNSW_SPARSE_STREAMER_CHECK_CRC_ENABLE( "proxima.hnsw.sparse_streamer.check_crc_enable"); static const std::string PARAM_HNSW_SPARSE_STREAMER_NEIGHBOR_PRUNE_MULTIPLIER( "proxima.hnsw.sparse_streamer.neighbor_prune_multiplier"); static const std::string PARAM_HNSW_SPARSE_STREAMER_CHUNK_SIZE( "proxima.hnsw.sparse_streamer.chunk_size"); static const std::string PARAM_HNSW_SPARSE_STREAMER_FILTER_SAME_KEY( "proxima.hnsw.sparse_streamer.filter_same_key"); static const std::string PARAM_HNSW_SPARSE_STREAMER_GET_VECTOR_ENABLE( "proxima.hnsw.sparse_streamer.get_vector_enable"); static const std::string PARAM_HNSW_SPARSE_STREAMER_MIN_NEIGHBOR_COUNT( "proxima.hnsw.sparse_streamer.min_neighbor_count"); static const std::string PARAM_HNSW_SPARSE_STREAMER_FORCE_PADDING_RESULT_ENABLE( "proxima.hnsw.sparse_streamer.force_padding_result_enable"); static const std::string PARAM_HNSW_SPARSE_STREAMER_QUERY_FILTERING_RATIO( "proxima.hnsw.sparse_streamer.query_filtering_ratio"); static const std::string PARAM_HNSW_SPARSE_REDUCER_WORKING_PATH( "proxima.hnsw.sparse_reducer.working_path"); static const std::string PARAM_HNSW_SPARSE_REDUCER_NUM_OF_ADD_THREADS( "proxima.hnsw.sparse_reducer.num_of_add_threads"); static const std::string PARAM_HNSW_SPARSE_REDUCER_INDEX_NAME( "proxima.hnsw.sparse_reducer.index_name"); static const std::string PARAM_HNSW_SPARSE_REDUCER_EFCONSTRUCTION( "proxima.hnsw.sparse_reducer.efconstruction"); } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/hnsw_sparse/hnsw_sparse_searcher.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "hnsw_sparse_searcher.h" #include "hnsw_sparse_algorithm.h" #include "hnsw_sparse_index_provider.h" #include "hnsw_sparse_params.h" namespace zvec { namespace core { HnswSparseSearcher::HnswSparseSearcher() {} HnswSparseSearcher::~HnswSparseSearcher() {} int HnswSparseSearcher::init(const ailego::Params &search_params) { params_ = search_params; params_.get(PARAM_HNSW_SPARSE_SEARCHER_EF, &ef_); params_.get(PARAM_HNSW_SPARSE_SEARCHER_MAX_SCAN_RATIO, &max_scan_ratio_); params_.get(PARAM_HNSW_SPARSE_SEARCHER_VISIT_BLOOMFILTER_ENABLE, &bf_enabled_); params_.get(PARAM_HNSW_SPARSE_SEARCHER_CHECK_CRC_ENABLE, &check_crc_enabled_); params_.get(PARAM_HNSW_SPARSE_SEARCHER_NEIGHBORS_IN_MEMORY_ENABLE, &neighbors_in_memory_enabled_); params_.get(PARAM_HNSW_SPARSE_SEARCHER_VISIT_BLOOMFILTER_NEGATIVE_PROB, &bf_negative_probability_); params_.get(PARAM_HNSW_SPARSE_SEARCHER_BRUTE_FORCE_THRESHOLD, &bruteforce_threshold_); params_.get(PARAM_HNSW_SPARSE_SEARCHER_FORCE_PADDING_RESULT_ENABLE, &force_padding_topk_enabled_); query_filtering_enabled_ = params_.get(PARAM_HNSW_SPARSE_SEARCHER_QUERY_FILTERING_RATIO, &query_filtering_ratio_); if (ef_ == 0) { ef_ = HnswSparseEntity::kDefaultEf; } if (bf_negative_probability_ <= 0.0f || bf_negative_probability_ >= 1.0f) { LOG_ERROR( "[%s] must be in range (0,1)", PARAM_HNSW_SPARSE_SEARCHER_VISIT_BLOOMFILTER_NEGATIVE_PROB.c_str()); return IndexError_InvalidArgument; } if (query_filtering_enabled_ && (query_filtering_ratio_ <= 0.0f || query_filtering_ratio_ >= 1.0f)) { LOG_ERROR("[%s] must be in range (0, 1)", 
PARAM_HNSW_SPARSE_SEARCHER_QUERY_FILTERING_RATIO.c_str()); return IndexError_InvalidArgument; } entity_.set_neighbors_in_memory(neighbors_in_memory_enabled_); state_ = STATE_INITED; LOG_DEBUG( "Init params: ef=%u maxScanRatio=%f bfEnabled=%u checkCrcEnabled=%u " "neighborsInMemoryEnabled=%u bfNagtiveProb=%f bruteForceThreshold=%u " "forcePadding=%u filteringRatio=%f", ef_, max_scan_ratio_, bf_enabled_, check_crc_enabled_, neighbors_in_memory_enabled_, bf_negative_probability_, bruteforce_threshold_, force_padding_topk_enabled_, query_filtering_ratio_); return 0; } void HnswSparseSearcher::print_debug_info() { for (node_id_t id = 0; id < entity_.doc_cnt(); ++id) { Neighbors neighbours = entity_.get_neighbors(0, id); std::cout << "node: " << id << "; "; for (uint32_t i = 0; i < neighbours.size(); ++i) { std::cout << neighbours[i]; if (i == neighbours.size() - 1) { std::cout << std::endl; } else { std::cout << ", "; } } } } int HnswSparseSearcher::cleanup() { LOG_INFO("Begin HnswSparseSearcher:cleanup"); metric_.reset(); meta_.clear(); stats_.clear_attributes(); stats_.set_loaded_count(0UL); stats_.set_loaded_costtime(0UL); max_scan_ratio_ = HnswSparseEntity::kDefaultScanRatio; max_scan_num_ = 0U; ef_ = HnswSparseEntity::kDefaultEf; bf_enabled_ = false; bf_negative_probability_ = HnswSparseEntity::kDefaultBFNegativeProbability; bruteforce_threshold_ = HnswSparseEntity::kDefaultBruteForceThreshold; check_crc_enabled_ = false; neighbors_in_memory_enabled_ = false; entity_.cleanup(); state_ = STATE_INIT; LOG_INFO("End HnswSparseSearcher:cleanup"); return 0; } int HnswSparseSearcher::load(IndexStorage::Pointer container, IndexMetric::Pointer metric) { if (state_ != STATE_INITED) { LOG_ERROR("Init the searcher first before load index"); return IndexError_Runtime; } LOG_INFO("Begin HnswSparseSearcher:load"); auto start_time = ailego::Monotime::MilliSeconds(); int ret = IndexHelper::DeserializeFromStorage(container.get(), &meta_); if (ret != 0) { LOG_ERROR("Failed to 
deserialize meta from container"); return ret; } ret = entity_.load(container, check_crc_enabled_); if (ret != 0) { LOG_ERROR("HnswSparseSearcher load index failed"); return ret; } alg_ = HnswSparseAlgorithm::UPointer(new HnswSparseAlgorithm(entity_)); if (metric) { metric_ = metric; } else { metric_ = IndexFactory::CreateMetric(meta_.metric_name()); if (!metric_) { LOG_ERROR("CreateMeasure failed, name: %s", meta_.metric_name().c_str()); return IndexError_NoExist; } ret = metric_->init(meta_, meta_.metric_params()); if (ret != 0) { LOG_ERROR("IndexMetric init failed, ret=%d", ret); return ret; } if (metric_->query_metric()) { metric_ = metric_->query_metric(); } } // if (!metric_->is_matched(meta_)) { // LOG_ERROR("IndexMeasure not match index meta"); // return IndexError_Mismatch; // } max_scan_num_ = static_cast(max_scan_ratio_ * entity_.doc_cnt()); max_scan_num_ = std::max(4096U, max_scan_num_); stats_.set_loaded_count(entity_.doc_cnt()); stats_.set_loaded_costtime(ailego::Monotime::MilliSeconds() - start_time); state_ = STATE_LOADED; magic_ = IndexContext::GenerateMagic(); LOG_INFO("End HnswSparseSearcher::load"); return 0; } int HnswSparseSearcher::unload() { LOG_INFO("HnswSparseSearcher unload index"); meta_.clear(); entity_.cleanup(); metric_.reset(); max_scan_num_ = 0; stats_.set_loaded_count(0UL); stats_.set_loaded_costtime(0UL); state_ = STATE_INITED; return 0; } int HnswSparseSearcher::update_context(HnswSparseContext *ctx) const { const HnswSparseEntity::Pointer entity = entity_.clone(); if (!entity) { LOG_ERROR("Failed to clone search context entity"); return IndexError_Runtime; } ctx->set_max_scan_num(max_scan_num_); ctx->set_bruteforce_threshold(bruteforce_threshold_); return ctx->update_context(HnswSparseContext::kSparseSearcherContext, meta_, metric_, entity, magic_); } //! 
Similarity search with sparse inputs int HnswSparseSearcher::search_impl(const uint32_t *sparse_count, const uint32_t *sparse_indices, const void *sparse_query, const IndexQueryMeta &qmeta, uint32_t count, Context::Pointer &context) const { if (ailego_unlikely(!context)) { LOG_ERROR("The context is not created by this searcher"); return IndexError_Mismatch; } HnswSparseContext *ctx = dynamic_cast(context.get()); ailego_do_if_false(ctx) { LOG_ERROR("Cast context to HnswSparseContext failed"); return IndexError_Cast; } if (entity_.doc_cnt() <= ctx->get_bruteforce_threshold()) { return search_bf_impl(sparse_count, sparse_indices, sparse_query, qmeta, count, context); } if (ctx->magic() != magic_) { //! context is created by another searcher or streamer int ret = update_context(ctx); if (ret != 0) { return ret; } } ctx->clear(); ctx->resize_results(count); const uint32_t *sparse_indices_tmp = sparse_indices; const void *sparse_query_tmp = sparse_query; for (size_t q = 0; q < count; ++q) { std::string sparse_query_buffer; std::string sparse_query_filtered_buffer; SparseUtility::TransSparseFormat( sparse_count[q], sparse_indices_tmp, sparse_query_tmp, entity_.sparse_unit_size(), sparse_query_buffer); if (query_filtering_enabled_) { if (!SparseUtility::FilterSparseQuery( sparse_count[q], sparse_indices_tmp, sparse_query_tmp, qmeta.data_type(), entity_.sparse_unit_size(), query_filtering_ratio_, &sparse_query_filtered_buffer)) { LOG_ERROR("Hnsw filtering failed"); return IndexError_Runtime; } ctx->reset_query(sparse_query_filtered_buffer.data()); } else { ctx->reset_query(sparse_query_buffer.data()); } int ret = alg_->search(ctx); if (ailego_unlikely(ret != 0)) { LOG_ERROR("Hnsw searcher fast search failed"); return ret; } if (query_filtering_enabled_) { ctx->reset_query(sparse_query_buffer.data()); ctx->recal_topk_dist(); } ctx->topk_to_result(q); sparse_indices_tmp += sparse_count[q]; sparse_query_tmp = reinterpret_cast(sparse_query_tmp) + sparse_count[q] * 
qmeta.unit_size(); } if (ailego_unlikely(ctx->error())) { return IndexError_Runtime; } return 0; } //! Similarity search with sparse inputs int HnswSparseSearcher::search_bf_impl( const uint32_t *sparse_count, const uint32_t *sparse_indices, const void *sparse_query, const IndexQueryMeta &qmeta, uint32_t count, IndexStreamer::Context::Pointer &context) const { if (ailego_unlikely(!context)) { LOG_ERROR("The context is not created by this searcher"); return IndexError_Mismatch; } HnswSparseContext *ctx = dynamic_cast(context.get()); ailego_do_if_false(ctx) { LOG_ERROR("Cast context to HnswSparseContext failed"); return IndexError_Cast; } if (ctx->magic() != magic_) { //! context is created by another searcher or streamer int ret = update_context(ctx); if (ret != 0) { return ret; } } ctx->clear(); ctx->resize_results(count); const uint32_t *sparse_indices_tmp = sparse_indices; const void *sparse_query_tmp = sparse_query; if (ctx->group_by_search()) { if (!ctx->group_by().is_valid()) { LOG_ERROR("Invalid group-by function"); return IndexError_Runtime; } std::function group_by = [&](node_id_t id) { return ctx->group_by()(entity_.get_key(id)); }; for (size_t q = 0; q < count; ++q) { std::string sparse_query_buffer; SparseUtility::TransSparseFormat( sparse_count[q], sparse_indices_tmp, sparse_query_tmp, entity_.sparse_unit_size(), sparse_query_buffer); ctx->reset_query(sparse_query_buffer.data()); ctx->group_topk_heaps().clear(); for (node_id_t id = 0; id < entity_.doc_cnt(); ++id) { if (entity_.get_key(id) == kInvalidKey) { continue; } if (!ctx->filter().is_valid() || !ctx->filter()(entity_.get_key(id))) { dist_t dist = ctx->dist_calculator().dist(id); std::string group_id = group_by(id); auto &topk_heap = ctx->group_topk_heaps()[group_id]; if (topk_heap.empty()) { topk_heap.limit(ctx->group_topk()); } topk_heap.emplace_back(id, dist); } } ctx->topk_to_result(q); sparse_indices_tmp += sparse_count[q]; sparse_query_tmp = reinterpret_cast(sparse_query_tmp) + 
sparse_count[q] * qmeta.unit_size(); } } else { for (size_t q = 0; q < count; ++q) { std::string sparse_query_buffer; SparseUtility::TransSparseFormat( sparse_count[q], sparse_indices_tmp, sparse_query_tmp, entity_.sparse_unit_size(), sparse_query_buffer); ctx->reset_query(sparse_query_buffer.data()); ctx->topk_heap().clear(); for (node_id_t id = 0; id < entity_.doc_cnt(); ++id) { if (entity_.get_key(id) == kInvalidKey) { continue; } if (!ctx->filter().is_valid() || !ctx->filter()(entity_.get_key(id))) { dist_t dist = ctx->dist_calculator().dist(id); ctx->topk_heap().emplace(id, dist); } } ctx->topk_to_result(q); sparse_indices_tmp += sparse_count[q]; sparse_query_tmp = reinterpret_cast(sparse_query_tmp) + sparse_count[q] * qmeta.unit_size(); } } if (ailego_unlikely(ctx->error())) { return IndexError_Runtime; } return 0; } //! Similarity search with sparse inputs int HnswSparseSearcher::search_bf_by_p_keys_impl( const uint32_t *sparse_count, const uint32_t *sparse_indices, const void *sparse_query, const std::vector> &p_keys, const IndexQueryMeta &qmeta, uint32_t count, Context::Pointer &context) const { if (ailego_unlikely(!context)) { LOG_ERROR("The context is not created by this searcher"); return IndexError_Mismatch; } if (ailego_unlikely(p_keys.size() != count)) { LOG_ERROR("The size of p_keys is not equal to count"); return IndexError_InvalidArgument; } HnswSparseContext *ctx = dynamic_cast(context.get()); ailego_do_if_false(ctx) { LOG_ERROR("Cast context to HnswSparseContext failed"); return IndexError_Cast; } if (ctx->magic() != magic_) { //! 
context is created by another searcher or streamer int ret = update_context(ctx); if (ret != 0) { return ret; } } ctx->clear(); ctx->resize_results(count); const uint32_t *sparse_indices_tmp = sparse_indices; const void *sparse_query_tmp = sparse_query; if (ctx->group_by_search()) { if (!ctx->group_by().is_valid()) { LOG_ERROR("Invalid group-by function"); return IndexError_Runtime; } std::function group_by = [&](node_id_t id) { return ctx->group_by()(entity_.get_key(id)); }; for (size_t q = 0; q < count; ++q) { std::string sparse_query_buffer; SparseUtility::TransSparseFormat( sparse_count[q], sparse_indices_tmp, sparse_query_tmp, entity_.sparse_unit_size(), sparse_query_buffer); ctx->reset_query(sparse_query_buffer.data()); ctx->group_topk_heaps().clear(); for (size_t idx = 0; idx < p_keys[q].size(); ++idx) { uint64_t pk = p_keys[q][idx]; if (!ctx->filter().is_valid() || !ctx->filter()(pk)) { node_id_t id = entity_.get_id(pk); if (id != kInvalidNodeId) { dist_t dist = ctx->dist_calculator().dist(id); std::string group_id = group_by(id); auto &topk_heap = ctx->group_topk_heaps()[group_id]; if (topk_heap.empty()) { topk_heap.limit(ctx->group_topk()); } topk_heap.emplace_back(id, dist); } } } ctx->topk_to_result(q); sparse_indices_tmp += sparse_count[q]; sparse_query_tmp = reinterpret_cast(sparse_query_tmp) + sparse_count[q] * qmeta.unit_size(); } } else { for (size_t q = 0; q < count; ++q) { std::string sparse_query_buffer; SparseUtility::TransSparseFormat( sparse_count[q], sparse_indices_tmp, sparse_query_tmp, entity_.sparse_unit_size(), sparse_query_buffer); ctx->reset_query(sparse_query_buffer.data()); ctx->topk_heap().clear(); for (size_t idx = 0; idx < p_keys[q].size(); ++idx) { uint64_t pk = p_keys[q][idx]; if (!ctx->filter().is_valid() || !ctx->filter()(pk)) { node_id_t id = entity_.get_id(pk); if (id != kInvalidNodeId) { dist_t dist = ctx->dist_calculator().dist(id); ctx->topk_heap().emplace(id, dist); } } } ctx->topk_to_result(q); sparse_indices_tmp += 
sparse_count[q]; sparse_query_tmp = reinterpret_cast(sparse_query_tmp) + sparse_count[q] * qmeta.unit_size(); } } if (ailego_unlikely(ctx->error())) { return IndexError_Runtime; } return 0; } IndexSearcher::Context::Pointer HnswSparseSearcher::create_context() const { if (ailego_unlikely(state_ != STATE_LOADED)) { LOG_ERROR("Load the index first before create context"); return Context::Pointer(); } const HnswSparseEntity::Pointer search_ctx_entity = entity_.clone(); if (!search_ctx_entity) { LOG_ERROR("Failed to create search context entity"); return Context::Pointer(); } HnswSparseContext *ctx = new (std::nothrow) HnswSparseContext(metric_, search_ctx_entity); if (ailego_unlikely(ctx == nullptr)) { LOG_ERROR("Failed to new HnswSparseContext"); return Context::Pointer(); } ctx->set_ef(ef_); ctx->set_max_scan_num(max_scan_num_); uint32_t filter_mode = bf_enabled_ ? VisitFilter::BloomFilter : VisitFilter::ByteMap; ctx->set_filter_mode(filter_mode); ctx->set_filter_negative_probability(bf_negative_probability_); ctx->set_magic(magic_); ctx->set_force_padding_topk(force_padding_topk_enabled_); ctx->set_bruteforce_threshold(bruteforce_threshold_); if (ailego_unlikely(ctx->init(HnswSparseContext::kSparseSearcherContext)) != 0) { LOG_ERROR("Init HnswSparseContext failed"); delete ctx; return Context::Pointer(); } return Context::Pointer(ctx); } IndexSearcher::SparseProvider::Pointer HnswSparseSearcher::create_sparse_provider(void) const { LOG_DEBUG("HnswSparseSearcher create sparse provider"); auto entity = entity_.clone(); if (ailego_unlikely(!entity)) { LOG_ERROR("Clone HnswSparseEntity failed"); return SparseProvider::Pointer(); } return SparseProvider::Pointer(new (std::nothrow) HnswSparseIndexProvider( meta_, entity, "HnswSparseSearcher")); } int HnswSparseSearcher::get_sparse_vector( uint64_t key, uint32_t *sparse_count, std::string *sparse_indices_buffer, std::string *sparse_values_buffer) const { return entity_.get_sparse_vector_by_key( key, sparse_count, 
sparse_indices_buffer, sparse_values_buffer); } INDEX_FACTORY_REGISTER_SEARCHER(HnswSparseSearcher); } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/hnsw_sparse/hnsw_sparse_searcher.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include "hnsw_sparse_searcher_entity.h" #include "hnsw_sparse_streamer.h" namespace zvec { namespace core { class HnswSparseSearcher : public IndexSearcher { public: using ContextPointer = IndexSearcher::Context::Pointer; public: HnswSparseSearcher(void); virtual ~HnswSparseSearcher(void); HnswSparseSearcher(const HnswSparseSearcher &) = delete; HnswSparseSearcher &operator=(const HnswSparseSearcher &) = delete; protected: //! Initialize Searcher int init(const ailego::Params ¶ms) override; //! Cleanup Searcher int cleanup(void) override; //! Load Index from storage int load(IndexStorage::Pointer container, IndexMetric::Pointer measure) override; //! Unload index from storage int unload(void) override; //! Similarity search with sparse inputs int search_impl(const uint32_t sparse_count, const uint32_t *sparse_indices, const void *sparse_query, const IndexQueryMeta &qmeta, Context::Pointer &context) const override { return search_impl(&sparse_count, sparse_indices, sparse_query, qmeta, 1, context); } //! 
Similarity search with sparse inputs int search_impl(const uint32_t *sparse_count, const uint32_t *sparse_indices, const void *sparse_query, const IndexQueryMeta &qmeta, uint32_t count, Context::Pointer &context) const override; //! Similarity brute force search with sparse inputs int search_bf_impl(const uint32_t sparse_count, const uint32_t *sparse_indices, const void *sparse_query, const IndexQueryMeta &qmeta, Context::Pointer &context) const override { return search_bf_impl(&sparse_count, sparse_indices, sparse_query, qmeta, 1, context); } //! Similarity brute force search with sparse inputs int search_bf_impl(const uint32_t *sparse_count, const uint32_t *sparse_indices, const void *sparse_query, const IndexQueryMeta &qmeta, uint32_t count, Context::Pointer &context) const override; //! Linear search by primary keys int search_bf_by_p_keys_impl(const uint32_t sparse_count, const uint32_t *sparse_indices, const void *sparse_query, const std::vector> &p_keys, const IndexQueryMeta &qmeta, ContextPointer &context) const override { return search_bf_by_p_keys_impl(&sparse_count, sparse_indices, sparse_query, p_keys, qmeta, 1, context); } //! Linear search by primary keys int search_bf_by_p_keys_impl(const uint32_t *sparse_count, const uint32_t *sparse_indices, const void *sparse_query, const std::vector> &p_keys, const IndexQueryMeta &qmeta, uint32_t count, ContextPointer &context) const override; //! Fetch sparser vector by key int get_sparse_vector(uint64_t key, uint32_t *sparse_count, std::string *sparse_indices_buffer, std::string *sparse_values_buffer) const override; //! Create a searcher context ContextPointer create_context() const override; //! Create a new iterator IndexSearcher::SparseProvider::Pointer create_sparse_provider( void) const override; //! Retrieve statistics const Stats &stats(void) const override { return stats_; } //! Retrieve meta of index const IndexMeta &meta(void) const override { return meta_; } //! 
Retrieve params of index const ailego::Params ¶ms(void) const override { return params_; } void print_debug_info() override; private: //! To share ctx across streamer/searcher, we need to update the context for //! current streamer/searcher int update_context(HnswSparseContext *ctx) const; private: enum State { STATE_INIT = 0, STATE_INITED = 1, STATE_LOADED = 2 }; HnswSparseSearcherEntity entity_{}; HnswSparseAlgorithm::UPointer alg_; // impl graph algorithm IndexMetric::Pointer metric_{}; IndexMeta meta_{}; ailego::Params params_{}; Stats stats_; uint32_t ef_{HnswSparseEntity::kDefaultEf}; uint32_t max_scan_num_{0U}; uint32_t bruteforce_threshold_{HnswSparseEntity::kDefaultBruteForceThreshold}; float max_scan_ratio_{HnswSparseEntity::kDefaultScanRatio}; bool bf_enabled_{false}; bool check_crc_enabled_{false}; bool neighbors_in_memory_enabled_{false}; bool force_padding_topk_enabled_{false}; float bf_negative_probability_{ HnswSparseEntity::kDefaultBFNegativeProbability}; bool query_filtering_enabled_{false}; float query_filtering_ratio_{HnswSparseEntity::kDefaultQueryFilteringRatio}; uint32_t magic_{0U}; State state_{STATE_INIT}; }; } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/hnsw_sparse/hnsw_sparse_searcher_entity.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
// NOTE(review): in this extracted copy, text inside angle brackets appears to
// have been stripped (e.g. the bare `#include` below, `reinterpret_cast(key)`,
// `std::vector &block_vecs`, `std::default_delete()`). Restore the missing
// template arguments / include target from upstream.
#include "hnsw_sparse_searcher_entity.h"
#include
#include "utility/sparse_utility.h"

namespace zvec {
namespace core {

HnswSparseSearcherEntity::HnswSparseSearcherEntity() {}

//! Drop every segment handle and reset the in-memory/loaded flags, then
//! clean up the base HnswSparseEntity state.
int HnswSparseSearcherEntity::cleanup(void) {
  container_.reset();
  sparse_vector_meta_.reset();
  keys_.reset();
  neighbors_.reset();
  neighbors_meta_.reset();
  sparse_vectors_.reset();
  neighbors_in_memory_enabled_ = false;
  loaded_ = false;

  this->HnswSparseEntity::cleanup();

  return 0;
}

//! Read the key stored for local node `id` from the keys segment.
//! Returns kInvalidKey (and logs) when the read fails.
key_t HnswSparseSearcherEntity::get_key(node_id_t id) const {
  const void *key;
  if (ailego_unlikely(keys_->read(id * sizeof(key_t), &key, sizeof(key_t)) !=
                      sizeof(key_t))) {
    LOG_ERROR("Read key from segment failed");
    return kInvalidKey;
  }
  return *(reinterpret_cast(key));
}

//! Get vector local id by key
//! Binary-searches the mapping segment, comparing the key of each mapped
//! local id with `key`; this assumes mapping_ lists local ids in ascending
//! key order (TODO confirm against the writer). Returns kInvalidNodeId when
//! the key is absent, the mapping segment is missing, or a read fails.
node_id_t HnswSparseSearcherEntity::get_id(key_t key) const {
  if (ailego_unlikely(!mapping_)) {
    LOG_ERROR("Index missing mapping segment");
    return kInvalidNodeId;
  }

  //! Do binary search
  node_id_t start = 0UL;
  node_id_t end = doc_cnt();
  const void *data;
  node_id_t idx = 0u;
  while (start < end) {
    idx = start + (end - start) / 2;
    if (ailego_unlikely(
            mapping_->read(idx * sizeof(node_id_t), &data, sizeof(node_id_t)) !=
            sizeof(node_id_t))) {
      LOG_ERROR("Read key from segment failed");
      return kInvalidNodeId;
    }
    const key_t *mkey;
    node_id_t local_id = *reinterpret_cast(data);
    if (ailego_unlikely(keys_->read(local_id * sizeof(key_t),
                                    (const void **)(&mkey),
                                    sizeof(key_t)) != sizeof(key_t))) {
      LOG_ERROR("Read key from segment failed");
      return kInvalidNodeId;
    }
    if (*mkey < key) {
      start = idx + 1;
    } else if (*mkey > key) {
      end = idx;
    } else {
      return local_id;
    }
  }
  return kInvalidNodeId;
}

//! Look up `key` and decode its stored sparse vector into the output
//! buffers. *sparse_count is zeroed first, so it stays 0 on failure.
//! Returns IndexError_NoExist for an unknown key and IndexError_InvalidValue
//! when the sparse data cannot be fetched.
int HnswSparseSearcherEntity::get_sparse_vector_by_key(
    key_t key, uint32_t *sparse_count, std::string *sparse_indices_buffer,
    std::string *sparse_values_buffer) const {
  *sparse_count = 0;

  auto id = get_id(key);
  if (id == kInvalidNodeId) {
    return IndexError_NoExist;
  }

  const void *sparse_data = get_sparse_data(id);
  if (sparse_data == nullptr) {
    return IndexError_InvalidValue;
  }

  SparseUtility::ReverseSparseFormat(sparse_data, sparse_count,
                                     sparse_indices_buffer,
                                     sparse_values_buffer, sparse_unit_size());

  return 0;
}

//! Return a pointer to the fixed-size sparse meta record of node `id`
//! (sparse_meta_size() bytes), or nullptr when the read fails.
const void *HnswSparseSearcherEntity::get_vector_meta(node_id_t id) const {
  size_t read_size = sparse_meta_size();
  size_t offset = sparse_meta_size() * id;

  const void *vec;
  if (ailego_unlikely(sparse_vector_meta_->read(offset, &vec, read_size) !=
                      read_size)) {
    LOG_ERROR("Read vector from segment failed");
    return nullptr;
  }
  return vec;
}

//! Wrap node `id`'s sparse meta record in `block`.
//! NOTE(review): always returns 0, even when get_vector_meta(id) failed and
//! returned nullptr; the caller then receives an empty block with success.
int HnswSparseSearcherEntity::get_vector_meta(
    const node_id_t id, IndexStorage::MemoryBlock &block) const {
  const void *vec = get_vector_meta(id);
  block.reset((void *)vec);
  return 0;
}

//! Batch-read the sparse meta records of `count` nodes into `vecs` with a
//! single segment read. `count` must not exceed segment_datas_.size(),
//! which load_segments() sizes to max(l0, upper) neighbor count.
int HnswSparseSearcherEntity::get_vector_metas(const node_id_t *ids,
                                               uint32_t count,
                                               const void **vecs) const {
  ailego_assert_with(count <= segment_datas_.size(), "invalid count");

  size_t read_size = sparse_meta_size();

  for (uint32_t i = 0; i < count; ++i) {
    segment_datas_[i].offset = sparse_meta_size() * ids[i];
    segment_datas_[i].length = read_size;

    ailego_assert_with(
        segment_datas_[i].offset < sparse_vector_meta_->data_size(),
        "invalid offset");
  }
  if (ailego_unlikely(!sparse_vector_meta_->read(&segment_datas_[0], count))) {
    LOG_ERROR("Read vectors from segment failed");
    return IndexError_ReadData;
  }
  for (uint32_t i = 0; i < count; ++i) {
    vecs[i] = segment_datas_[i].data;
  }

  return 0;
}

//! MemoryBlock flavor of get_vector_metas(): appends one block per id.
//! NOTE(review): `const void *vecs[count]` is a variable-length array (a
//! compiler extension, not standard C++), and the return code of
//! get_vector_metas() is ignored. On a read failure `vecs` stays
//! uninitialized and garbage pointers are wrapped into block_vecs while 0 is
//! returned. Consider a std::vector and propagating the error.
int HnswSparseSearcherEntity::get_vector_metas(
    const node_id_t *ids, uint32_t count,
    std::vector &block_vecs) const {
  const void *vecs[count];
  get_vector_metas(ids, count, vecs);
  for (uint32_t i = 0; i < count; ++i) {
    block_vecs.emplace_back(IndexStorage::MemoryBlock((void *)vecs[i]));
  }
  return 0;
}

//! Return the neighbor list of node `id` at `level`.
//! Level 0 comes from the flattened in-memory copy when
//! neighbors_in_memory_enabled_ is set, otherwise from neighbors_meta_ and
//! neighbors_. Upper levels come from upper_neighbors_meta_ and
//! upper_neighbors_. Returns {0, nullptr} when a read fails.
const Neighbors HnswSparseSearcherEntity::get_neighbors(level_t level,
                                                        node_id_t id) const {
  if (level == 0) {
    if (neighbors_in_memory_enabled_) {
      auto hd = reinterpret_cast(
          fixed_neighbors_.get() + neighbors_size() * id);
      return {hd->neighbor_cnt, hd->neighbors};
    }

    const SparseGraphNeighborMeta *m;
    if (ailego_unlikely(
            neighbors_meta_->read(id * sizeof(SparseGraphNeighborMeta),
                                  (const void **)(&m),
                                  sizeof(SparseGraphNeighborMeta)) !=
            sizeof(SparseGraphNeighborMeta))) {
      LOG_ERROR("Read neighbors meta from segment failed");
      return {0, nullptr};
    }

    const void *data;
    if (ailego_unlikely(neighbors_->read(m->offset, &data,
                                         m->neighbor_cnt * sizeof(node_id_t)) !=
                        m->neighbor_cnt * sizeof(node_id_t))) {
      LOG_ERROR("Read neighbors from segment failed");
      return {0, nullptr};
    }

    return {static_cast(m->neighbor_cnt),
            reinterpret_cast(data)};
  }

  //! Read level > 0 neighbors
  const HnswSparseNeighborMeta *m;
  if (ailego_unlikely(
          upper_neighbors_meta_->read(id * sizeof(HnswSparseNeighborMeta),
                                      (const void **)(&m),
                                      sizeof(HnswSparseNeighborMeta)) !=
          sizeof(HnswSparseNeighborMeta))) {
    LOG_ERROR("Read neighbors meta from segment failed");
    return {0, nullptr};
  }
  ailego_assert_with(level <= m->level, "invalid level");
  // Upper levels of a node are stored contiguously, one fixed-size
  // upper_neighbors_size() record per level starting at level 1.
  size_t offset = m->offset + (level - 1) * upper_neighbors_size();
  ailego_assert_with(offset <= upper_neighbors_->data_size(), "invalid offset");
  const void *data;
  if (ailego_unlikely(
          upper_neighbors_->read(offset, &data, upper_neighbors_size()) !=
          upper_neighbors_size())) {
    LOG_ERROR("Read neighbors from segment failed");
    return {0, nullptr};
  }

  auto hd = reinterpret_cast(data);
  return {hd->neighbor_cnt, hd->neighbors};
}

//! Attach `container`, load and validate all index segments (optionally
//! CRC-checked), mark the entity as loaded and log index statistics.
int HnswSparseSearcherEntity::load(const IndexStorage::Pointer &container,
                                   bool check_crc) {
  container_ = container;

  int ret = load_segments(check_crc);
  if (ret != 0) {
    return ret;
  }

  loaded_ = true;

  LOG_INFO(
      "Index info: docCnt=%u entryPoint=%u maxLevel=%d efConstruct=%zu "
      "l0NeighborCnt=%zu upperNeighborCnt=%zu scalingFactor=%zu "
      "nodeSize=%zu sparesMetaSegmentSize=%zu keySegmentSize=%zu "
      "neighborsSegmentSize=%zu neighborsMetaSegmentSize=%zu "
      "sparseVectorSegmentSize=%zu",
      doc_cnt(), entry_point(), cur_max_level(), ef_construction(),
      l0_neighbor_cnt(), upper_neighbor_cnt(), scaling_factor(), node_size(),
      sparse_vector_meta_->data_size(), keys_->data_size(),
      neighbors_->data_size(), neighbors_meta_->data_size(),
      sparse_vectors_->data_size());

  return 0;
}

//! Read the graph and HNSW headers, then fetch and size-check every data
//! segment. Optionally verifies CRCs, and when neighbors_in_memory_enabled_
//! is set, copies level-0 neighbors into fixed-size in-memory slots.
int HnswSparseSearcherEntity::load_segments(bool check_crc) {
  //! load header
  const void *data = nullptr;
  HNSWSparseHeader hd;
  auto graph_hd_segment = container_->get(kSparseGraphHeaderSegmentId);
  if (!graph_hd_segment || graph_hd_segment->data_size() < sizeof(hd.graph)) {
    LOG_ERROR("Miss or invalid segment %s",
              kSparseGraphHeaderSegmentId.c_str());
    return IndexError_InvalidFormat;
  }
  if (graph_hd_segment->read(0, reinterpret_cast(&data),
                             sizeof(hd.graph)) != sizeof(hd.graph)) {
    LOG_ERROR("Read segment %s failed", kSparseGraphHeaderSegmentId.c_str());
    return IndexError_ReadData;
  }
  memcpy(&hd.graph, data, sizeof(hd.graph));

  auto hnsw_hd_segment = container_->get(kSparseHnswHeaderSegmentId);
  if (!hnsw_hd_segment || hnsw_hd_segment->data_size() < sizeof(hd.hnsw)) {
    LOG_ERROR("Miss or invalid segment %s", kSparseHnswHeaderSegmentId.c_str());
    return IndexError_InvalidFormat;
  }
  if (hnsw_hd_segment->read(0, reinterpret_cast(&data),
                            sizeof(hd.hnsw)) != sizeof(hd.hnsw)) {
    LOG_ERROR("Read segment %s failed", kSparseHnswHeaderSegmentId.c_str());
    return IndexError_ReadData;
  }
  memcpy(&hd.hnsw, data, sizeof(hd.hnsw));
  *mutable_header() = hd;
  // Scratch space for batched meta reads in get_vector_metas().
  segment_datas_.resize(std::max(l0_neighbor_cnt(), upper_neighbor_cnt()));

  sparse_vector_meta_ = container_->get(kSparseGraphVectorMetaSegmentId);
  if (!sparse_vector_meta_) {
    LOG_ERROR("IndexStorage get segment %s failed",
              kSparseGraphVectorMetaSegmentId.c_str());
    return IndexError_InvalidFormat;
  }

  keys_ = container_->get(kSparseGraphKeysSegmentId);
  if (!keys_) {
    LOG_ERROR("IndexStorage get segment %s failed",
              kSparseGraphKeysSegmentId.c_str());
    return IndexError_InvalidFormat;
  }

  sparse_vectors_ = container_->get(kSparseGraphVectorsSegmentId);
  if (!sparse_vectors_) {
    LOG_ERROR("IndexStorage get segment %s failed",
              kSparseGraphVectorsSegmentId.c_str());
    return IndexError_InvalidFormat;
  }

  neighbors_ = container_->get(kSparseGraphNeighborsSegmentId);
  // An empty level-0 neighbors segment is tolerated only for 0 or 1 docs.
  if (!neighbors_ || (neighbors_->data_size() == 0 && doc_cnt() > 1)) {
    LOG_ERROR("IndexStorage get segment %s failed or empty",
              kSparseGraphNeighborsSegmentId.c_str());
    return IndexError_InvalidArgument;
  }

  neighbors_meta_ = container_->get(kSparseGraphOffsetsSegmentId);
  if (!neighbors_meta_ ||
      neighbors_meta_->data_size() < sizeof(SparseGraphNeighborMeta) * doc_cnt()) {
    LOG_ERROR("IndexStorage get segment %s failed or invalid size",
              kSparseGraphOffsetsSegmentId.c_str());
    return IndexError_InvalidArgument;
  }

  upper_neighbors_ = container_->get(kSparseHnswNeighborsSegmentId);
  if (!upper_neighbors_ ||
      (upper_neighbors_->data_size() == 0 && cur_max_level() > 0)) {
    LOG_ERROR("IndexStorage get segment %s failed or empty",
              kSparseHnswNeighborsSegmentId.c_str());
    return IndexError_InvalidArgument;
  }

  upper_neighbors_meta_ = container_->get(kSparseHnswOffsetsSegmentId);
  if (!upper_neighbors_meta_ ||
      upper_neighbors_meta_->data_size() < sizeof(HnswSparseNeighborMeta) * doc_cnt()) {
    LOG_ERROR("IndexStorage get segment %s failed or invalid size",
              kSparseHnswOffsetsSegmentId.c_str());
    return IndexError_InvalidArgument;
  }

  mapping_ = container_->get(kSparseGraphMappingSegmentId);
  if (!mapping_ || mapping_->data_size() < sizeof(node_id_t) * doc_cnt()) {
    LOG_ERROR("IndexStorage get segment %s failed or invalid size",
              kSparseGraphMappingSegmentId.c_str());
    return IndexError_InvalidArgument;
  }

  if (check_crc) {
    // NOTE(review): mapping_ is loaded above but is not part of this CRC
    // list; confirm whether that omission is intentional.
    std::vector segments;
    segments.emplace_back(graph_hd_segment);
    segments.emplace_back(hnsw_hd_segment);
    segments.emplace_back(sparse_vector_meta_);
    segments.emplace_back(keys_);
    segments.emplace_back(sparse_vectors_);
    segments.emplace_back(neighbors_);
    segments.emplace_back(neighbors_meta_);
    segments.emplace_back(upper_neighbors_);
    segments.emplace_back(upper_neighbors_meta_);

    if (!do_crc_check(segments)) {
      LOG_ERROR("Check index crc failed, the index may broken");
      return IndexError_Runtime;
    }
  }

  if (neighbors_in_memory_enabled_) {
    int ret = load_and_flat_neighbors();
    if (ret != 0) {
      return ret;
    }
  }

  return 0;
}

//! Copy every node's level-0 neighbor list into fixed_neighbors_, one
//! neighbors_size()-byte slot per node: a uint32_t count followed by the
//! neighbor ids.
int HnswSparseSearcherEntity::load_and_flat_neighbors() {
  fixed_neighbors_.reset(
      new (std::nothrow) char[neighbors_size() * doc_cnt()]{},
      std::default_delete());
  if (!fixed_neighbors_) {
    LOG_ERROR("Malloc memory failed");
    return IndexError_NoMemory;
  }

  //! Get a new segment to release the buffer after loading neighbors
  auto neighbors_meta = container_->get(kSparseGraphOffsetsSegmentId);
  if (!neighbors_meta) {
    LOG_ERROR("IndexStorage get segment graph.offsets failed");
    return IndexError_InvalidArgument;
  }
  const SparseGraphNeighborMeta *neighbors_index = nullptr;
  if (neighbors_meta->read(0, reinterpret_cast(&neighbors_index),
                           neighbors_meta->data_size()) !=
      neighbors_meta->data_size()) {
    LOG_ERROR("Read segment %s data failed",
              kSparseGraphOffsetsSegmentId.c_str());
    return IndexError_InvalidArgument;
  }

  const char *neighbor_data;
  for (node_id_t id = 0; id < doc_cnt(); ++id) {
    size_t rd_size = neighbors_index[id].neighbor_cnt * sizeof(node_id_t);
    if (ailego_unlikely(
            neighbors_->read(neighbors_index[id].offset,
                             reinterpret_cast(&neighbor_data),
                             rd_size) != rd_size)) {
      LOG_ERROR("Read neighbors from segment failed");
      return IndexError_ReadData;
    }
    // copy level 0 neighbors to fixed size neighbors memory
    // NOTE(review): rd_size is not checked against
    // neighbors_size() - sizeof(uint32_t); an oversized neighbor_cnt in the
    // file would overrun this node's slot.
    char *dst = fixed_neighbors_.get() + neighbors_size() * id;
    *reinterpret_cast(dst) = neighbors_index[id].neighbor_cnt;
    memcpy(dst + sizeof(uint32_t), neighbor_data, rd_size);
  }

  return 0;
}

//! Fill `fixed_neighbors` with a dense level-0 layout.
//! It holds doc_cnt() slots of l0_neighbor_cnt() ids each, padded with
//! kInvalidNodeId, followed by doc_cnt() per-node neighbor counts.
int HnswSparseSearcherEntity::get_fixed_neighbors(
    std::vector *fixed_neighbors) const {
  //! Get a new segment to release the buffer after loading neighbors
  auto neighbors_meta = container_->get(kSparseGraphOffsetsSegmentId);
  if (!neighbors_meta) {
    LOG_ERROR("IndexStorage get segment graph.offsets failed");
    return IndexError_InvalidArgument;
  }
  const SparseGraphNeighborMeta *neighbors_index = nullptr;
  size_t meta_size = neighbors_meta->data_size();
  if (neighbors_meta->read(0, reinterpret_cast(&neighbors_index), meta_size) !=
      meta_size) {
    LOG_ERROR("Read segment %s data failed",
              kSparseGraphOffsetsSegmentId.c_str());
    return IndexError_InvalidArgument;
  }

  size_t fixed_neighbor_cnt = l0_neighbor_cnt();
  fixed_neighbors->resize((fixed_neighbor_cnt + 1) * doc_cnt(), kInvalidNodeId);
  // The per-node counts are stored after all neighbor slots.
  size_t neighbors_cnt_offset = fixed_neighbor_cnt * doc_cnt();
  size_t total_neighbor_cnt = 0;
  for (node_id_t id = 0; id < doc_cnt(); ++id) {
    size_t cur_neighbor_cnt = neighbors_index[id].neighbor_cnt;
    if (cur_neighbor_cnt == 0) {
      (*fixed_neighbors)[neighbors_cnt_offset + id] = 0;
      continue;
    }
    size_t rd_size = cur_neighbor_cnt * sizeof(node_id_t);
    const uint32_t *neighbors;
    if (neighbors_->read(neighbors_index[id].offset,
                         reinterpret_cast(&neighbors),
                         rd_size) != rd_size) {
      LOG_ERROR("Read neighbors from segment failed");
      return IndexError_ReadData;
    }
    // copy level 0 neighbors to fixed size neighbors memory
    // NOTE(review): cur_neighbor_cnt is not checked against
    // fixed_neighbor_cnt; more than l0_neighbor_cnt() neighbors would spill
    // into the next node's slot.
    auto it = fixed_neighbors->begin() + id * fixed_neighbor_cnt;
    std::copy(neighbors, neighbors + cur_neighbor_cnt, it);
    (*fixed_neighbors)[neighbors_cnt_offset + id] = cur_neighbor_cnt;
    total_neighbor_cnt += cur_neighbor_cnt;
  }
  // NOTE(review): divides by zero when doc_cnt() == 0.
  LOG_INFO("total neighbor cnt: %zu, average neighbor cnt: %zu",
           total_neighbor_cnt, total_neighbor_cnt / doc_cnt());

  return 0;
}

//! Walk each segment in blk_size (4096-byte) chunks accumulating a CRC.
//! Presumably the CRC is compared against the stored segment checksum; the
//! rest of this definition lies beyond the visible chunk.
bool HnswSparseSearcherEntity::do_crc_check(
    std::vector &segments) const {
  constexpr size_t blk_size = 4096;
  const void *data;
  for (auto &segment : segments) {
    size_t offset = 0;
    size_t rd_size;
    uint32_t crc = 0;
    while (offset < segment->data_size()) {
      size_t size = std::min(blk_size, segment->data_size() - offset);
      if ((rd_size =
segment->read(offset, &data, size)) <= 0) { break; } offset += rd_size; crc = ailego::Crc32c::Hash(data, rd_size, crc); } if (crc != segment->data_crc()) { return false; } } return true; } const HnswSparseEntity::Pointer HnswSparseSearcherEntity::clone() const { auto keys = keys_->clone(); if (ailego_unlikely(!keys)) { LOG_ERROR("clone segment %s failed", kSparseGraphKeysSegmentId.c_str()); return HnswSparseEntity::Pointer(); } auto mapping = mapping_->clone(); if (ailego_unlikely(!mapping)) { LOG_ERROR("clone segment %s failed", kSparseGraphMappingSegmentId.c_str()); return HnswSparseEntity::Pointer(); } auto sparse_vector_meta = sparse_vector_meta_->clone(); if (ailego_unlikely(!sparse_vector_meta)) { LOG_ERROR("clone segment %s failed", kSparseGraphVectorMetaSegmentId.c_str()); return HnswSparseEntity::Pointer(); } auto sparse_vectors = sparse_vectors_->clone(); if (ailego_unlikely(!sparse_vectors)) { LOG_ERROR("clone segment %s failed", kSparseGraphVectorsSegmentId.c_str()); return HnswSparseEntity::Pointer(); } auto neighbors = neighbors_->clone(); if (ailego_unlikely(!neighbors)) { LOG_ERROR("clone segment %s failed", kSparseGraphNeighborsSegmentId.c_str()); return HnswSparseEntity::Pointer(); } auto upper_neighbors = upper_neighbors_->clone(); if (ailego_unlikely(!neighbors)) { LOG_ERROR("clone segment %s failed", kSparseHnswNeighborsSegmentId.c_str()); return HnswSparseEntity::Pointer(); } auto neighbors_meta = neighbors_meta_->clone(); if (ailego_unlikely(!neighbors_meta)) { LOG_ERROR("clone segment %s failed", kSparseGraphOffsetsSegmentId.c_str()); return HnswSparseEntity::Pointer(); } auto upper_neighbors_meta = upper_neighbors_meta_->clone(); if (ailego_unlikely(!upper_neighbors_meta)) { LOG_ERROR("clone segment %s failed", kSparseHnswOffsetsSegmentId.c_str()); return HnswSparseEntity::Pointer(); } SegmentGroupParam neighbor_group{neighbors, neighbors_meta, upper_neighbors, upper_neighbors_meta}; SegmentGroupParam dense_neighbor_group{nullptr, nullptr, 
nullptr, nullptr}; SegmentGroupParam sparse_neighbor_group{nullptr, nullptr, nullptr, nullptr}; HnswSparseSearcherEntity *entity = new (std::nothrow) HnswSparseSearcherEntity(header(), keys, mapping, neighbor_group, sparse_vector_meta, sparse_vectors, fixed_neighbors_, neighbors_in_memory_enabled_); if (ailego_unlikely(!entity)) { LOG_ERROR("HnswSparseSearcherEntity new failed"); } return HnswSparseEntity::Pointer(entity); } //! Get vector sparse feature data by chunk index and offset const void *HnswSparseSearcherEntity::get_sparse_data(uint64_t offset, uint32_t len) const { const void *sparse_data = nullptr; uint32_t real_length = sparse_vectors_->read(offset, &sparse_data, len); if (ailego_unlikely(real_length != len)) { LOG_ERROR("Read sparse data from segment failed, %u vs %u", real_length, len); return nullptr; } return sparse_data; } int HnswSparseSearcherEntity::get_sparse_data( uint64_t offset, uint32_t len, IndexStorage::MemoryBlock &block) const { const void *vec = get_sparse_data(offset, len); block.reset((void *)vec); return 0; } //! Get sparse data from id const void *HnswSparseSearcherEntity::get_sparse_data(node_id_t id) const { const void *vec = get_vector_meta(id); if (vec == nullptr) { LOG_ERROR("get vector failed, id: %u", id); return nullptr; } auto sparse_data = get_sparse_data_from_vector(vec); return sparse_data.first; } int HnswSparseSearcherEntity::get_sparse_data( const node_id_t id, IndexStorage::MemoryBlock &block) const { const void *vec = get_sparse_data(id); block.reset((void *)vec); return 0; } //! 
Get sparse data from vector std::pair HnswSparseSearcherEntity::get_sparse_data_from_vector(const void *vec) const { if (vec == nullptr) { LOG_ERROR("vec is nullptr"); return std::make_pair(nullptr, 0); } const char *vec_ptr = reinterpret_cast(vec); uint64_t offset = *((uint64_t *)(vec_ptr)); uint32_t sparse_vector_len = *((uint32_t *)(vec_ptr + sizeof(uint64_t))); const void *sparse_data = get_sparse_data(offset, sparse_vector_len); if (ailego_unlikely(sparse_data == nullptr)) { LOG_ERROR("Get nullptr sparse, offset=%zu, len=%u", (size_t)offset, sparse_vector_len); return std::make_pair(nullptr, 0); } return std::make_pair(sparse_data, sparse_vector_len); } int HnswSparseSearcherEntity::get_sparse_data_from_vector( const void *vec, IndexStorage::MemoryBlock &block, int &sparse_length) const { std::pair sparse_data = get_sparse_data_from_vector(vec); block.reset((void *)sparse_data.first); sparse_length = sparse_data.second; return 0; } } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/hnsw_sparse/hnsw_sparse_searcher_entity.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#pragma once #include "hnsw_sparse_builder_entity.h" #include "hnsw_sparse_entity.h" namespace zvec { namespace core { class HnswSparseSearcherEntity : public HnswSparseEntity { public: using Pointer = std::shared_ptr; using SegmentPointer = IndexStorage::Segment::Pointer; public: struct SegmentGroupParam { SegmentGroupParam(SegmentPointer neighbors_in, SegmentPointer neighbors_meta_in, SegmentPointer upper_neighbors_in, SegmentPointer upper_neighbors_meta_in) : neighbors{neighbors_in}, neighbors_meta{neighbors_meta_in}, upper_neighbors{upper_neighbors_in}, upper_neighbors_meta{upper_neighbors_meta_in} {} SegmentPointer neighbors{nullptr}; SegmentPointer neighbors_meta{nullptr}; SegmentPointer upper_neighbors{nullptr}; SegmentPointer upper_neighbors_meta{nullptr}; }; //! Constructor HnswSparseSearcherEntity(); //! Make a copy of searcher entity, to support thread-safe operation. //! The segment in container cannot be read concurrenly virtual const HnswSparseEntity::Pointer clone() const override; //! Get primary key of the node id virtual key_t get_key(node_id_t id) const override; //! Get vector local id by key node_id_t get_id(key_t key) const; //! Get sparse vector feature data by key virtual int get_sparse_vector_by_key( key_t key, uint32_t *sparse_count, std::string *sparse_indices_buffer, std::string *sparse_values_buffer) const override; //! Get vector feature data by id virtual const void *get_vector_meta(node_id_t id) const override; virtual int get_vector_meta(const node_id_t id, IndexStorage::MemoryBlock &block) const override; //! Get vector feature data by id virtual int get_vector_metas(const node_id_t *ids, uint32_t count, const void **vecs) const override; virtual int get_vector_metas( const node_id_t *ids, uint32_t count, std::vector &block_vecs) const override; //! Get vector sparse feature data by chunk index and offset virtual const void *get_sparse_data(uint64_t offset, uint32_t len) const override; //! 
Get sparse data from id virtual const void *get_sparse_data(node_id_t id) const override; virtual int get_sparse_data(uint64_t offset, uint32_t len, IndexStorage::MemoryBlock &block) const override; virtual int get_sparse_data(const node_id_t id, IndexStorage::MemoryBlock &block) const override; //! Get sparse data from vector virtual std::pair get_sparse_data_from_vector( const void *vec) const override; virtual int get_sparse_data_from_vector(const void *vec, IndexStorage::MemoryBlock &block, int &sparse_length) const override; //! Get the node id's neighbors on graph level virtual const Neighbors get_neighbors(level_t level, node_id_t id) const override; virtual int load(const IndexStorage::Pointer &container, bool check_crc) override; int load_segments(bool check_crc); virtual int cleanup(void) override; public: bool is_loaded() const { return loaded_; } void set_neighbors_in_memory(bool enabled) { neighbors_in_memory_enabled_ = enabled; } //! get fixed length neighbors data int get_fixed_neighbors(std::vector *fixed_neighbors) const; private: //! 
Constructor HnswSparseSearcherEntity(const HNSWSparseHeader &hd, const SegmentPointer &keys, const SegmentPointer &mapping, const SegmentGroupParam &neighbor_group, const SegmentPointer &sparse_vector_meta, const SegmentPointer &sparse_vectors, const std::shared_ptr &fixed_neighbors, bool neighbors_in_memory_enabled) : HnswSparseEntity(hd), keys_(keys), mapping_(mapping), neighbors_(neighbor_group.neighbors), neighbors_meta_(neighbor_group.neighbors_meta), upper_neighbors_(neighbor_group.upper_neighbors), upper_neighbors_meta_(neighbor_group.upper_neighbors_meta), sparse_vector_meta_(sparse_vector_meta), sparse_vectors_(sparse_vectors), neighbors_in_memory_enabled_(neighbors_in_memory_enabled) { segment_datas_.resize(std::max(l0_neighbor_cnt(), upper_neighbor_cnt()), IndexStorage::SegmentData(0U, 0U)); fixed_neighbors_ = fixed_neighbors; } bool do_crc_check(std::vector &segments) const; inline size_t neighbors_size() const { return sizeof(NeighborsHeader) + l0_neighbor_cnt() * sizeof(node_id_t); } inline size_t upper_neighbors_size() const { return sizeof(NeighborsHeader) + upper_neighbor_cnt() * sizeof(node_id_t); } //! 
If neighbors_in_memory_enabled, load the level0 neighbors to memory int load_and_flat_neighbors(void); public: HnswSparseSearcherEntity(const HnswSparseSearcherEntity &) = delete; HnswSparseSearcherEntity &operator=(const HnswSparseSearcherEntity &) = delete; private: IndexStorage::Pointer container_{}; SegmentPointer keys_{}; SegmentPointer mapping_{}; SegmentPointer neighbors_{}; SegmentPointer neighbors_meta_{}; SegmentPointer upper_neighbors_{}; SegmentPointer upper_neighbors_meta_{}; SegmentPointer sparse_vector_meta_{}; SegmentPointer sparse_vectors_{}; mutable std::vector segment_datas_{}; std::shared_ptr fixed_neighbors_{}; // level 0 fixed size neighbors bool neighbors_in_memory_enabled_{false}; bool loaded_{false}; }; } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/hnsw_sparse/hnsw_sparse_streamer.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "hnsw_sparse_streamer.h"
// NOTE(review): four `#include <...>` directives stood here; their header
// names were lost in extraction (text read `#include #include #include
// #include`). Restore them from the repository.
#include "hnsw_sparse_algorithm.h"
#include "hnsw_sparse_context.h"
#include "hnsw_sparse_dist_calculator.h"
#include "hnsw_sparse_index_provider.h"

namespace zvec {
namespace core {

HnswSparseStreamer::HnswSparseStreamer() : entity_(stats_) {}

HnswSparseStreamer::~HnswSparseStreamer() {
  // NOTE(review): an OPENED streamer is not cleaned up (nor closed) here —
  // confirm that is intended.
  if (state_ == STATE_INITED) {
    this->cleanup();
  }
}

//! Read and validate streamer parameters, configure the entity, and create
//! the HNSW algorithm. Moves the streamer to STATE_INITED on success.
//! @return 0 on success, an IndexError_* code otherwise
int HnswSparseStreamer::init(const IndexMeta &imeta,
                             const ailego::Params &params) {
  meta_ = imeta;
  meta_.set_streamer("HnswSparseStreamer", HnswSparseEntity::kRevision, params);

  params.get(PARAM_HNSW_SPARSE_STREAMER_MAX_INDEX_SIZE, &max_index_size_);
  params.get(PARAM_HNSW_SPARSE_STREAMER_MAX_NEIGHBOR_COUNT,
             &upper_max_neighbor_cnt_);
  float multiplier = HnswSparseEntity::kDefaultL0MaxNeighborCntMultiplier;
  params.get(PARAM_HNSW_SPARSE_STREAMER_L0_MAX_NEIGHBOR_COUNT_MULTIPLIER,
             &multiplier);
  l0_max_neighbor_cnt_ = multiplier * upper_max_neighbor_cnt_;

  multiplier = HnswSparseEntity::kDefaultNeighborPruneMultiplier;
  params.get(PARAM_HNSW_SPARSE_STREAMER_NEIGHBOR_PRUNE_MULTIPLIER, &multiplier);
  size_t prune_cnt = multiplier * upper_max_neighbor_cnt_;
  scaling_factor_ = upper_max_neighbor_cnt_;
  params.get(PARAM_HNSW_SPARSE_STREAMER_SCALING_FACTOR, &scaling_factor_);

  params.get(PARAM_HNSW_SPARSE_STREAMER_DOCS_HARD_LIMIT, &docs_hard_limit_);
  params.get(PARAM_HNSW_SPARSE_STREAMER_EF, &ef_);
  params.get(PARAM_HNSW_SPARSE_STREAMER_EFCONSTRUCTION, &ef_construction_);
  params.get(PARAM_HNSW_SPARSE_STREAMER_VISIT_BLOOMFILTER_ENABLE, &bf_enabled_);
  params.get(PARAM_HNSW_SPARSE_STREAMER_VISIT_BLOOMFILTER_NEGATIVE_PROB,
             &bf_negative_prob_);
  params.get(PARAM_HNSW_SPARSE_STREAMER_BRUTE_FORCE_THRESHOLD,
             &bruteforce_threshold_);
  params.get(PARAM_HNSW_SPARSE_STREAMER_MAX_SCAN_RATIO, &max_scan_ratio_);
  params.get(PARAM_HNSW_SPARSE_STREAMER_MAX_SCAN_LIMIT, &max_scan_limit_);
  params.get(PARAM_HNSW_SPARSE_STREAMER_MIN_SCAN_LIMIT, &min_scan_limit_);
  params.get(PARAM_HNSW_SPARSE_STREAMER_CHECK_CRC_ENABLE, &check_crc_enabled_);
  params.get(PARAM_HNSW_SPARSE_STREAMER_CHUNK_SIZE, &chunk_size_);
  params.get(PARAM_HNSW_SPARSE_STREAMER_FILTER_SAME_KEY, &filter_same_key_);
  params.get(PARAM_HNSW_SPARSE_STREAMER_GET_VECTOR_ENABLE,
             &get_vector_enabled_);
  params.get(PARAM_HNSW_SPARSE_STREAMER_MIN_NEIGHBOR_COUNT, &min_neighbor_cnt_);
  params.get(PARAM_HNSW_SPARSE_STREAMER_FORCE_PADDING_RESULT_ENABLE,
             &force_padding_topk_enabled_);
  // Query filtering is enabled only when the ratio parameter is present.
  query_filtering_enabled_ = params.get(
      PARAM_HNSW_SPARSE_STREAMER_QUERY_FILTERING_RATIO, &query_filtering_ratio_);
  params.get(PARAM_HNSW_SPARSE_STREAMER_DOCS_SOFT_LIMIT, &docs_soft_limit_);

  if (docs_soft_limit_ > 0 && docs_soft_limit_ > docs_hard_limit_) {
    LOG_ERROR("[%s] must be >= [%s]",
              PARAM_HNSW_SPARSE_STREAMER_DOCS_HARD_LIMIT.c_str(),
              PARAM_HNSW_SPARSE_STREAMER_DOCS_SOFT_LIMIT.c_str());
    return IndexError_InvalidArgument;
  } else if (docs_soft_limit_ == 0UL) {
    docs_soft_limit_ =
        docs_hard_limit_ * HnswSparseEntity::kDefaultDocsSoftLimitRatio;
  }

  if (ef_ == 0U) {
    ef_ = HnswSparseEntity::kDefaultEf;
  }
  if (ef_construction_ == 0U) {
    ef_construction_ = HnswSparseEntity::kDefaultEfConstruction;
  }
  if (upper_max_neighbor_cnt_ == 0U) {
    upper_max_neighbor_cnt_ = HnswSparseEntity::kDefaultUpperMaxNeighborCnt;
  }
  if (upper_max_neighbor_cnt_ > HnswSparseEntity::kMaxNeighborCnt) {
    LOG_ERROR("[%s] must be in range (0,%d)",
              PARAM_HNSW_SPARSE_STREAMER_MAX_NEIGHBOR_COUNT.c_str(),
              HnswSparseEntity::kMaxNeighborCnt);
    return IndexError_InvalidArgument;
  }
  if (l0_max_neighbor_cnt_ == 0U) {
    l0_max_neighbor_cnt_ = HnswSparseEntity::kDefaultL0MaxNeighborCnt;
  }
  if (l0_max_neighbor_cnt_ > HnswSparseEntity::kMaxNeighborCnt) {
    // This checks the level-0 count, not the upper-level one.
    LOG_ERROR("L0NeighborCnt must be in range (0,%d)",
              HnswSparseEntity::kMaxNeighborCnt);
    return IndexError_InvalidArgument;
  }
  if (min_neighbor_cnt_ > upper_max_neighbor_cnt_) {
    LOG_ERROR("[%s]-[%u] must be <= [%s]-[%u]",
              PARAM_HNSW_SPARSE_STREAMER_MIN_NEIGHBOR_COUNT.c_str(),
              min_neighbor_cnt_,
              PARAM_HNSW_SPARSE_STREAMER_MAX_NEIGHBOR_COUNT.c_str(),
              upper_max_neighbor_cnt_);
    return IndexError_InvalidArgument;
  }
  if (bf_negative_prob_ <= 0.0f || bf_negative_prob_ >= 1.0f) {
    LOG_ERROR(
        "[%s] must be in range (0,1)",
        PARAM_HNSW_SPARSE_STREAMER_VISIT_BLOOMFILTER_NEGATIVE_PROB.c_str());
    return IndexError_InvalidArgument;
  }
  if (scaling_factor_ == 0U) {
    scaling_factor_ = HnswSparseEntity::kDefaultScalingFactor;
  }
  if (scaling_factor_ < 5 || scaling_factor_ > 1000) {
    LOG_ERROR("[%s] must be in range [5,1000]",
              PARAM_HNSW_SPARSE_STREAMER_SCALING_FACTOR.c_str());
    return IndexError_InvalidArgument;
  }
  if (max_scan_ratio_ <= 0.0f || max_scan_ratio_ > 1.0f) {
    LOG_ERROR("[%s] must be in range (0.0f,1.0f]",
              PARAM_HNSW_SPARSE_STREAMER_MAX_SCAN_RATIO.c_str());
    return IndexError_InvalidArgument;
  }
  if (max_scan_limit_ < min_scan_limit_) {
    LOG_ERROR("[%s] must be >= [%s]",
              PARAM_HNSW_SPARSE_STREAMER_MAX_SCAN_LIMIT.c_str(),
              PARAM_HNSW_SPARSE_STREAMER_MIN_SCAN_LIMIT.c_str());
    return IndexError_InvalidArgument;
  }
  if (prune_cnt == 0UL) {
    prune_cnt = upper_max_neighbor_cnt_;
  }
  if (chunk_size_ == 0UL) {
    chunk_size_ = HnswSparseEntity::kDefaultChunkSize;
  }
  if (chunk_size_ > HnswSparseEntity::kMaxChunkSize) {
    // kMaxChunkSize itself is accepted, so the bound is inclusive.
    LOG_ERROR("[%s] must be <= %zu",
              PARAM_HNSW_SPARSE_STREAMER_CHUNK_SIZE.c_str(),
              HnswSparseEntity::kMaxChunkSize);
    return IndexError_InvalidArgument;
  }
  if (query_filtering_enabled_ &&
      (query_filtering_ratio_ <= 0.0f || query_filtering_ratio_ >= 1.0f)) {
    // Report the streamer parameter that was actually read above.
    LOG_ERROR("[%s] must be in range (0, 1)",
              PARAM_HNSW_SPARSE_STREAMER_QUERY_FILTERING_RATIO.c_str());
    return IndexError_InvalidArgument;
  }

  entity_.set_ef_construction(ef_construction_);
  entity_.set_l0_neighbor_cnt(l0_max_neighbor_cnt_);
  entity_.set_upper_neighbor_cnt(upper_max_neighbor_cnt_);
  entity_.set_scaling_factor(scaling_factor_);
  entity_.set_prune_cnt(prune_cnt);
  entity_.set_chunk_size(chunk_size_);
  entity_.set_filter_same_key(filter_same_key_);
  entity_.set_get_vector(get_vector_enabled_);
  entity_.set_min_neighbor_cnt(min_neighbor_cnt_);
  entity_.set_sparse_meta_size(HnswSparseEntity::kSparseMetaSize);
  entity_.set_sparse_unit_size(meta_.unit_size());

  int ret = entity_.init(max_index_size_, docs_hard_limit_);
  if (ret != 0) {
    LOG_ERROR("Hnsw entity init failed for %s", IndexError::What(ret));
    return ret;
  }

  LOG_DEBUG(
      "Init params: maxIndexSize=%zu docsHardLimit=%zu docsSoftLimit=%zu "
      "efConstruction=%u ef=%u l0NeighborCnt=%u upperNeighborCnt=%u "
      "scalingFactor=%u maxScanRatio=%.3f minScanLimit=%zu maxScanLimit=%zu "
      "bfEnabled=%d bruteForceThreshold=%zu bfNegativeProbability=%.5f "
      "checkCrcEnabled=%d pruneSize=%zu chunkSize=%zu "
      "filterSameKey=%u getVectorEnabled=%u "
      "minNeighborCount=%u forcePadding=%u filteringRatio=%f",
      max_index_size_, docs_hard_limit_, docs_soft_limit_, ef_construction_,
      ef_, l0_max_neighbor_cnt_, upper_max_neighbor_cnt_, scaling_factor_,
      max_scan_ratio_, min_scan_limit_, max_scan_limit_, bf_enabled_,
      bruteforce_threshold_, bf_negative_prob_, check_crc_enabled_, prune_cnt,
      chunk_size_, filter_same_key_, get_vector_enabled_, min_neighbor_cnt_,
      force_padding_topk_enabled_, query_filtering_ratio_);

  alg_ = HnswSparseAlgorithm::UPointer(new HnswSparseAlgorithm(entity_));
  ret = alg_->init();
  if (ret != 0) {
    return ret;
  }

  state_ = STATE_INITED;
  return 0;
}

//! Close (if opened), release the entity/algorithm and restore defaults.
//! Moves the streamer back to STATE_INIT.
int HnswSparseStreamer::cleanup(void) {
  if (state_ == STATE_OPENED) {
    this->close();
  }

  LOG_INFO("HnswSparseStreamer cleanup");

  meta_.clear();
  metric_.reset();
  stats_.clear();
  entity_.cleanup();
  if (alg_) {
    alg_->cleanup();
  }

  max_index_size_ = 0UL;
  docs_hard_limit_ = HnswSparseEntity::kDefaultDocsHardLimit;
  docs_soft_limit_ = 0UL;
  upper_max_neighbor_cnt_ = HnswSparseEntity::kDefaultUpperMaxNeighborCnt;
  ef_ = HnswSparseEntity::kDefaultEf;
  ef_construction_ = HnswSparseEntity::kDefaultEfConstruction;
  bf_enabled_ = false;
  scaling_factor_ = HnswSparseEntity::kDefaultScalingFactor;
  bruteforce_threshold_ = HnswSparseEntity::kDefaultBruteForceThreshold;
  max_scan_limit_ = HnswSparseEntity::kDefaultMaxScanLimit;
  min_scan_limit_ = HnswSparseEntity::kDefaultMinScanLimit;
  chunk_size_ = HnswSparseEntity::kDefaultChunkSize;
  bf_negative_prob_ = HnswSparseEntity::kDefaultBFNegativeProbability;
  max_scan_ratio_ = HnswSparseEntity::kDefaultScanRatio;
  state_ = STATE_INIT;
  check_crc_enabled_ = false;
  filter_same_key_ = false;
  get_vector_enabled_ = false;
  sparse_neighbor_ratio_ = HnswSparseEntity::kDefaultSparseNeighborRatio;
  sparse_neighbor_cnt_ = 0UL;
  sparse_min_neighbor_cnt_ = 0UL;
  upper_sparse_neighbor_cnt_ = 0UL;
  return 0;
}

//! Open the storage, reconcile the stored IndexMeta with the configured one,
//! and create the metric used for adding and searching.
//! @return 0 on success, an IndexError_* code otherwise
int HnswSparseStreamer::open(IndexStorage::Pointer stg) {
  LOG_INFO("HnswSparseStreamer open");

  if (ailego_unlikely(state_ != STATE_INITED)) {
    LOG_ERROR("Open storage failed, init streamer first!");
    return IndexError_NoReady;
  }
  int ret = entity_.open(std::move(stg), check_crc_enabled_);
  if (ret != 0) {
    return ret;
  }

  IndexMeta index_meta;
  ret = entity_.get_index_meta(&index_meta);
  if (ret == IndexError_NoExist) {
    // Set IndexMeta for the new index
    ret = entity_.set_index_meta(meta_);
    if (ret != 0) {
      LOG_ERROR("Failed to set index meta for %s", IndexError::What(ret));
      return ret;
    }
  } else if (ret != 0) {
    LOG_ERROR("Failed to get index meta for %s", IndexError::What(ret));
    return ret;
  } else {
    if (index_meta.metric_name() != meta_.metric_name() ||
        index_meta.data_type() != meta_.data_type()) {
      LOG_ERROR("IndexMeta mismatch from the previous in index");
      return IndexError_Mismatch;
    }
    // The IndexMetric Params may be updated like MipsSquaredEuclidean
    auto metric_params = index_meta.metric_params();
    metric_params.merge(meta_.metric_params());
    meta_.set_metric(index_meta.metric_name(), 0, metric_params);
  }

  metric_ = IndexFactory::CreateMetric(meta_.metric_name());
  if (!metric_) {
    LOG_ERROR("Failed to create metric %s", meta_.metric_name().c_str());
    return IndexError_NoExist;
  }
  ret = metric_->init(meta_, meta_.metric_params());
  if (ret != 0) {
    LOG_ERROR("Failed to init metric, ret=%d", ret);
    return ret;
  }
  if (!metric_->sparse_distance()) {
    LOG_ERROR("Invalid metric distance");
    return IndexError_InvalidArgument;
  }

  add_distance_ = metric_->sparse_distance();
  search_distance_ = add_distance_;
  // Only switch to the query metric when it provides the *sparse* distance
  // that is actually assigned; checking the dense distance() here could
  // leave search_distance_ null.
  if (metric_->query_metric() && metric_->query_metric()->sparse_distance()) {
    search_distance_ = metric_->query_metric()->sparse_distance();
  }

  state_ = STATE_OPENED;
  magic_ = IndexContext::GenerateMagic();

  return 0;
}

//! Persist the current IndexMeta and close the entity (back to STATE_INITED).
int HnswSparseStreamer::close(void) {
  LOG_INFO("HnswSparseStreamer close");

  stats_.clear();
  meta_.set_metric(metric_->name(), 0, metric_->params());
  entity_.set_index_meta(meta_);
  int ret = entity_.close();
  if (ret != 0) {
    return ret;
  }
  state_ = STATE_INITED;

  return 0;
}

//! Persist the current IndexMeta and flush the entity at `checkpoint`.
int HnswSparseStreamer::flush(uint64_t checkpoint) {
  LOG_INFO("HnswSparseStreamer flush checkpoint=%zu",
           static_cast<size_t>(checkpoint));

  meta_.set_metric(metric_->name(), 0, metric_->params());
  entity_.set_index_meta(meta_);
  return entity_.flush(checkpoint);
}

//! Dump the index as an HnswSparseSearcher index. Holds shared_mutex_
//! exclusively for the whole dump, so concurrent adds are rejected.
int HnswSparseStreamer::dump(const IndexDumper::Pointer &dumper) {
  LOG_INFO("HnswSparseStreamer dump");

  shared_mutex_.lock();
  AILEGO_DEFER([&]() { shared_mutex_.unlock(); });

  meta_.set_searcher("HnswSparseSearcher", HnswSparseEntity::kRevision,
                     ailego::Params());

  int ret = IndexHelper::SerializeToDumper(meta_, dumper.get());
  if (ret != 0) {
    LOG_ERROR("Failed to serialize meta into dumper.");
    return ret;
  }
  return entity_.dump(dumper);
}

//! Create a search/add context bound to a clone of the entity.
//! @return the context, or an empty pointer on failure
IndexStreamer::Context::Pointer HnswSparseStreamer::create_context(
    void) const {
  if (ailego_unlikely(state_ != STATE_OPENED)) {
    LOG_ERROR("Create context failed, open storage first!");
    return Context::Pointer();
  }

  HnswSparseEntity::Pointer entity = entity_.clone();
  if (ailego_unlikely(!entity)) {
    LOG_ERROR("CreateContext clone init failed");
    return Context::Pointer();
  }
  HnswSparseContext *ctx =
      new (std::nothrow) HnswSparseContext(metric_, entity);
  if (ailego_unlikely(ctx == nullptr)) {
    LOG_ERROR("Failed to new HnswSparseContext");
    return Context::Pointer();
  }
  ctx->set_ef(ef_);
  ctx->set_max_scan_limit(max_scan_limit_);
  ctx->set_min_scan_limit(min_scan_limit_);
  ctx->set_max_scan_ratio(max_scan_ratio_);
  ctx->set_filter_mode(bf_enabled_ ? VisitFilter::BloomFilter
                                   : VisitFilter::ByteMap);
  ctx->set_filter_negative_probability(bf_negative_prob_);
  ctx->set_magic(magic_);
  ctx->set_force_padding_topk(force_padding_topk_enabled_);
  ctx->set_bruteforce_threshold(bruteforce_threshold_);
  // The whole comparison goes inside the branch hint.
  if (ailego_unlikely(ctx->init(HnswSparseContext::kSparseStreamerContext) !=
                      0)) {
    LOG_ERROR("Init HnswSparseContext failed");
    delete ctx;
    return Context::Pointer();
  }

  return Context::Pointer(ctx);
}

//! Create a sparse vector provider over a clone of the entity.
IndexStreamer::SparseProvider::Pointer
HnswSparseStreamer::create_sparse_provider(void) const {
  LOG_DEBUG("HnswSparseStreamer create sparse provider");

  auto entity = entity_.clone();
  if (ailego_unlikely(!entity)) {
    LOG_ERROR("Clone HnswSparseEntity failed");
    return SparseProvider::Pointer();
  }
  return SparseProvider::Pointer(
      new HnswSparseIndexProvider(meta_, entity, "HnswSparseStreamer"));
}

//! Rebind a context created by another searcher/streamer to this streamer:
//! fresh entity clone, this streamer's scan limits, threshold and magic.
int HnswSparseStreamer::update_context(HnswSparseContext *ctx) const {
  const HnswSparseEntity::Pointer entity = entity_.clone();
  if (!entity) {
    LOG_ERROR("Failed to clone search context entity");
    return IndexError_Runtime;
  }
  ctx->set_max_scan_limit(max_scan_limit_);
  ctx->set_min_scan_limit(min_scan_limit_);
  ctx->set_max_scan_ratio(max_scan_ratio_);
  ctx->set_bruteforce_threshold(bruteforce_threshold_);
  return ctx->update_context(HnswSparseContext::kSparseStreamerContext, meta_,
                             metric_, entity, magic_);
}
Add a vector with id into index with sparse inputs int HnswSparseStreamer::add_with_id_impl(uint32_t id, const uint32_t sparse_count, const uint32_t *sparse_indices, const void *sparse_query, const IndexQueryMeta &qmeta, Context::Pointer &context) { int ret = check_params(qmeta); if (ailego_unlikely(ret != 0)) { return ret; } if (ailego_unlikely(sparse_count > HnswSparseEntity::kSparseMaxDimSize)) { LOG_WARN( "Failed to add sparse vector: number of non-zero elements (%u) exceeds " "maximum allowed (%u), id=%u", sparse_count, HnswSparseEntity::kSparseMaxDimSize, id); return IndexError_InvalidValue; } HnswSparseContext *ctx = dynamic_cast(context.get()); ailego_do_if_false(ctx) { LOG_ERROR("Cast context to HnswSparseContext failed"); return IndexError_Cast; } if (ctx->magic() != magic_) { //! context is created by another searcher or streamer ret = update_context(ctx); if (ret != 0) { return ret; } } if (ailego_unlikely(entity_.doc_cnt() >= docs_soft_limit_)) { if (entity_.doc_cnt() >= docs_hard_limit_) { LOG_ERROR("Current docs %u exceed [%s]", entity_.doc_cnt(), PARAM_HNSW_SPARSE_STREAMER_DOCS_HARD_LIMIT.c_str()); const std::lock_guard lk(mutex_); (*stats_.mutable_discarded_count())++; return IndexError_IndexFull; } else { LOG_WARN("Current docs %u exceed [%s]", entity_.doc_cnt(), PARAM_HNSW_SPARSE_STREAMER_DOCS_SOFT_LIMIT.c_str()); } } if (ailego_unlikely(!shared_mutex_.try_lock_shared())) { LOG_ERROR("Cannot add vector while dumping index"); (*stats_.mutable_discarded_count())++; return IndexError_Unsupported; } AILEGO_DEFER([&]() { shared_mutex_.unlock_shared(); }); ctx->clear(); ctx->update_dist_caculator_distance(add_distance_); std::string sparse_query_buffer; SparseUtility::TransSparseFormat(sparse_count, sparse_indices, sparse_query, entity_.sparse_unit_size(), sparse_query_buffer); ctx->reset_query(sparse_query_buffer.data()); ctx->check_need_adjuct_ctx(entity_.doc_cnt()); level_t level = alg_->get_random_level(); ret = entity_.add_vector_with_id(level, 
id, sparse_query_buffer, sparse_count); if (ailego_unlikely(ret != 0)) { LOG_ERROR("Hnsw streamer add vector failed"); (*stats_.mutable_discarded_count())++; return ret; } ret = alg_->add_node(id, level, ctx); if (ailego_unlikely(ret != 0)) { LOG_ERROR("Hnsw stramer add node failed"); (*stats_.mutable_discarded_count())++; return ret; } if (ailego_unlikely(ctx->error())) { (*stats_.mutable_discarded_count())++; return IndexError_Runtime; } (*stats_.mutable_added_count())++; return 0; } //! Add a vector into index with sparse inputs int HnswSparseStreamer::add_impl(uint64_t pkey, const uint32_t sparse_count, const uint32_t *sparse_indices, const void *sparse_query, const IndexQueryMeta &qmeta, Context::Pointer &context) { int ret = check_params(qmeta); if (ailego_unlikely(ret != 0)) { return ret; } if (ailego_unlikely(sparse_count > HnswSparseEntity::kSparseMaxDimSize)) { LOG_WARN( "Failed to add sparse vector: number of non-zero elements (%u) exceeds " "maximum allowed (%u), key=%zu", sparse_count, HnswSparseEntity::kSparseMaxDimSize, (size_t)pkey); return IndexError_InvalidValue; } HnswSparseContext *ctx = dynamic_cast(context.get()); ailego_do_if_false(ctx) { LOG_ERROR("Cast context to HnswSparseContext failed"); return IndexError_Cast; } if (ctx->magic() != magic_) { //! 
context is created by another searcher or streamer ret = update_context(ctx); if (ret != 0) { return ret; } } if (ailego_unlikely(entity_.doc_cnt() >= docs_soft_limit_)) { if (entity_.doc_cnt() >= docs_hard_limit_) { LOG_ERROR("Current docs %u exceed [%s]", entity_.doc_cnt(), PARAM_HNSW_SPARSE_STREAMER_DOCS_HARD_LIMIT.c_str()); const std::lock_guard lk(mutex_); (*stats_.mutable_discarded_count())++; return IndexError_IndexFull; } else { LOG_WARN("Current docs %u exceed [%s]", entity_.doc_cnt(), PARAM_HNSW_SPARSE_STREAMER_DOCS_SOFT_LIMIT.c_str()); } } if (ailego_unlikely(!shared_mutex_.try_lock_shared())) { LOG_ERROR("Cannot add vector while dumping index"); (*stats_.mutable_discarded_count())++; return IndexError_Unsupported; } AILEGO_DEFER([&]() { shared_mutex_.unlock_shared(); }); ctx->clear(); ctx->update_dist_caculator_distance(add_distance_); std::string sparse_query_buffer; SparseUtility::TransSparseFormat(sparse_count, sparse_indices, sparse_query, entity_.sparse_unit_size(), sparse_query_buffer); ctx->reset_query(sparse_query_buffer.data()); ctx->check_need_adjuct_ctx(entity_.doc_cnt()); level_t level = alg_->get_random_level(); node_id_t id; ret = entity_.add_vector(level, pkey, sparse_query_buffer, sparse_count, &id); if (ailego_unlikely(ret != 0)) { LOG_ERROR("Hnsw streamer add vector failed"); (*stats_.mutable_discarded_count())++; return ret; } ret = alg_->add_node(id, level, ctx); if (ailego_unlikely(ret != 0)) { LOG_ERROR("Hnsw stramer add node failed"); (*stats_.mutable_discarded_count())++; return ret; } if (ailego_unlikely(ctx->error())) { (*stats_.mutable_discarded_count())++; return IndexError_Runtime; } (*stats_.mutable_added_count())++; return 0; } //! 
Similarity search with sparse inputs int HnswSparseStreamer::search_impl( const uint32_t sparse_count, const uint32_t *sparse_indices, const void *sparse_query, const IndexQueryMeta &qmeta, IndexStreamer::Context::Pointer &context) const { return search_impl(&sparse_count, sparse_indices, sparse_query, qmeta, 1, context); } //! Similarity search with sparse inputs int HnswSparseStreamer::search_impl( const uint32_t *sparse_count, const uint32_t *sparse_indices, const void *sparse_query, const IndexQueryMeta &qmeta, uint32_t count, IndexStreamer::Context::Pointer &context) const { int ret = check_params(qmeta); if (ailego_unlikely(ret != 0)) { return ret; } HnswSparseContext *ctx = dynamic_cast(context.get()); ailego_do_if_false(ctx) { LOG_ERROR("Cast context to HnswSparseContext failed"); return IndexError_Cast; } if (entity_.doc_cnt() <= ctx->get_bruteforce_threshold()) { return search_bf_impl(sparse_count, sparse_indices, sparse_query, qmeta, count, context); } if (ctx->magic() != magic_) { //! 
context is created by another searcher or streamer ret = update_context(ctx); if (ret != 0) { return ret; } } ctx->clear(); ctx->update_dist_caculator_distance(search_distance_); ctx->resize_results(count); ctx->check_need_adjuct_ctx(entity_.doc_cnt()); const uint32_t *sparse_indices_tmp = sparse_indices; const void *sparse_query_tmp = sparse_query; for (size_t q = 0; q < count; ++q) { std::string sparse_query_buffer; std::string sparse_query_filtered_buffer; SparseUtility::TransSparseFormat( sparse_count[q], sparse_indices_tmp, sparse_query_tmp, entity_.sparse_unit_size(), sparse_query_buffer); if (query_filtering_enabled_) { if (!SparseUtility::FilterSparseQuery( sparse_count[q], sparse_indices_tmp, sparse_query_tmp, qmeta.data_type(), entity_.sparse_unit_size(), query_filtering_ratio_, &sparse_query_filtered_buffer)) { LOG_ERROR("Hnsw filtering failed"); return IndexError_Runtime; } ctx->reset_query(sparse_query_filtered_buffer.data()); } else { ctx->reset_query(sparse_query_buffer.data()); } ret = alg_->search(ctx); if (ailego_unlikely(ret != 0)) { LOG_ERROR("Hnsw searcher fast search failed"); return ret; } if (query_filtering_enabled_) { ctx->reset_query(sparse_query_buffer.data()); ctx->recal_topk_dist(); } ctx->topk_to_result(q); sparse_indices_tmp += sparse_count[q]; sparse_query_tmp = reinterpret_cast(sparse_query_tmp) + sparse_count[q] * qmeta.unit_size(); } if (ailego_unlikely(ctx->error())) { return IndexError_Runtime; } return 0; } //! Similarity search with sparse inputs int HnswSparseStreamer::search_bf_impl( const uint32_t sparse_count, const uint32_t *sparse_indices, const void *sparse_query, const IndexQueryMeta &qmeta, IndexStreamer::Context::Pointer &context) const { return search_bf_impl(&sparse_count, sparse_indices, sparse_query, qmeta, 1, context); } //! 
Similarity search with sparse inputs int HnswSparseStreamer::search_bf_impl( const uint32_t *sparse_count, const uint32_t *sparse_indices, const void *sparse_query, const IndexQueryMeta &qmeta, uint32_t count, IndexStreamer::Context::Pointer &context) const { int ret = check_params(qmeta); if (ailego_unlikely(ret != 0)) { return ret; } HnswSparseContext *ctx = dynamic_cast(context.get()); ailego_do_if_false(ctx) { LOG_ERROR("Cast context to HnswSparseContext failed"); return IndexError_Cast; } if (ctx->magic() != magic_) { //! context is created by another searcher or streamer ret = update_context(ctx); if (ret != 0) { return ret; } } ctx->clear(); ctx->update_dist_caculator_distance(search_distance_); ctx->resize_results(count); const uint32_t *sparse_indices_tmp = sparse_indices; const void *sparse_query_tmp = sparse_query; if (ctx->group_by_search()) { if (!ctx->group_by().is_valid()) { LOG_ERROR("Invalid group-by function"); return IndexError_Runtime; } std::function group_by = [&](node_id_t id) { return ctx->group_by()(entity_.get_key(id)); }; for (size_t q = 0; q < count; ++q) { std::string sparse_query_buffer; SparseUtility::TransSparseFormat( sparse_count[q], sparse_indices_tmp, sparse_query_tmp, entity_.sparse_unit_size(), sparse_query_buffer); ctx->reset_query(sparse_query_buffer.data()); ctx->group_topk_heaps().clear(); for (node_id_t id = 0; id < entity_.doc_cnt(); ++id) { if (entity_.get_key(id) == kInvalidKey) { continue; } if (!ctx->filter().is_valid() || !ctx->filter()(entity_.get_key(id))) { dist_t dist = ctx->dist_calculator().dist(id); std::string group_id = group_by(id); auto &topk_heap = ctx->group_topk_heaps()[group_id]; if (topk_heap.empty()) { topk_heap.limit(ctx->group_topk()); } topk_heap.emplace_back(id, dist); } } ctx->topk_to_result(q); sparse_indices_tmp += sparse_count[q]; sparse_query_tmp = reinterpret_cast(sparse_query_tmp) + sparse_count[q] * qmeta.unit_size(); } } else { auto &filter = ctx->filter(); auto &topk = ctx->topk_heap(); 
for (size_t q = 0; q < count; ++q) { std::string sparse_query_buffer; SparseUtility::TransSparseFormat( sparse_count[q], sparse_indices_tmp, sparse_query_tmp, entity_.sparse_unit_size(), sparse_query_buffer); ctx->reset_query(sparse_query_buffer.data()); topk.clear(); for (node_id_t id = 0; id < entity_.doc_cnt(); ++id) { if (entity_.get_key(id) == kInvalidKey) { continue; } if (!filter.is_valid() || !filter(entity_.get_key(id))) { dist_t dist = ctx->dist_calculator().dist(id); topk.emplace(id, dist); } } ctx->topk_to_result(q); sparse_indices_tmp += sparse_count[q]; sparse_query_tmp = reinterpret_cast(sparse_query_tmp) + sparse_count[q] * qmeta.unit_size(); } if (ailego_unlikely(ctx->error())) { return IndexError_Runtime; } } return 0; } //! Linear search by primary keys int HnswSparseStreamer::search_bf_by_p_keys_impl( const uint32_t sparse_count, const uint32_t *sparse_indices, const void *sparse_query, const std::vector> &p_keys, const IndexQueryMeta &qmeta, ContextPointer &context) const { return search_bf_by_p_keys_impl(&sparse_count, sparse_indices, sparse_query, p_keys, qmeta, 1, context); } //! Linear search by primary keys with sparse inputs int HnswSparseStreamer::search_bf_by_p_keys_impl( const uint32_t *sparse_count, const uint32_t *sparse_indices, const void *sparse_query, const std::vector> &p_keys, const IndexQueryMeta &qmeta, uint32_t count, Context::Pointer &context) const { int ret = check_params(qmeta); if (ailego_unlikely(ret != 0)) { return ret; } if (ailego_unlikely(p_keys.size() != count)) { LOG_ERROR("The size of p_keys is not equal to count"); return IndexError_InvalidArgument; } HnswSparseContext *ctx = dynamic_cast(context.get()); ailego_do_if_false(ctx) { LOG_ERROR("Cast context to HnswSparseContext failed"); return IndexError_Cast; } if (ctx->magic() != magic_) { //! 
context is created by another searcher or streamer ret = update_context(ctx); if (ret != 0) { return ret; } } ctx->clear(); ctx->update_dist_caculator_distance(search_distance_); ctx->resize_results(count); const uint32_t *sparse_indices_tmp = sparse_indices; const void *sparse_query_tmp = sparse_query; if (ctx->group_by_search()) { if (!ctx->group_by().is_valid()) { LOG_ERROR("Invalid group-by function"); return IndexError_Runtime; } std::function group_by = [&](node_id_t id) { return ctx->group_by()(entity_.get_key(id)); }; for (size_t q = 0; q < count; ++q) { std::string sparse_query_buffer; SparseUtility::TransSparseFormat( sparse_count[q], sparse_indices_tmp, sparse_query_tmp, entity_.sparse_unit_size(), sparse_query_buffer); ctx->reset_query(sparse_query_buffer.data()); ctx->group_topk_heaps().clear(); for (size_t idx = 0; idx < p_keys[q].size(); ++idx) { uint64_t pk = p_keys[q][idx]; if (!ctx->filter().is_valid() || !ctx->filter()(pk)) { node_id_t id = entity_.get_id(pk); if (id != kInvalidNodeId) { dist_t dist = ctx->dist_calculator().dist(id); std::string group_id = group_by(id); auto &topk_heap = ctx->group_topk_heaps()[group_id]; if (topk_heap.empty()) { topk_heap.limit(ctx->group_topk()); } topk_heap.emplace_back(id, dist); } } } ctx->topk_to_result(q); sparse_indices_tmp += sparse_count[q]; sparse_query_tmp = reinterpret_cast(sparse_query_tmp) + sparse_count[q] * qmeta.unit_size(); } } else { auto &filter = ctx->filter(); auto &topk = ctx->topk_heap(); for (size_t q = 0; q < count; ++q) { std::string sparse_query_buffer; SparseUtility::TransSparseFormat( sparse_count[q], sparse_indices_tmp, sparse_query_tmp, entity_.sparse_unit_size(), sparse_query_buffer); ctx->reset_query(sparse_query_buffer.data()); topk.clear(); for (size_t idx = 0; idx < p_keys[q].size(); ++idx) { key_t pk = p_keys[q][idx]; if (!filter.is_valid() || !filter(pk)) { node_id_t id = entity_.get_id(pk); if (id != kInvalidNodeId) { dist_t dist = ctx->dist_calculator().dist(id); 
topk.emplace(id, dist); } } } ctx->topk_to_result(q); sparse_indices_tmp += sparse_count[q]; sparse_query_tmp = reinterpret_cast(sparse_query_tmp) + sparse_count[q] * qmeta.unit_size(); } } if (ailego_unlikely(ctx->error())) { return IndexError_Runtime; } return 0; } void HnswSparseStreamer::print_debug_info() { for (node_id_t id = 0; id < entity_.doc_cnt(); ++id) { Neighbors neighbours = entity_.get_neighbors(0, id); std::cout << "node: " << id << "; "; for (uint32_t i = 0; i < neighbours.size(); ++i) { std::cout << neighbours[i]; if (i == neighbours.size() - 1) { std::cout << std::endl; } else { std::cout << ", "; } } } // entity_.print_key_map(); } INDEX_FACTORY_REGISTER_STREAMER(HnswSparseStreamer); } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/hnsw_sparse/hnsw_sparse_streamer.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include "hnsw_sparse_algorithm.h" #include "hnsw_sparse_streamer_entity.h" namespace zvec { namespace core { class HnswSparseStreamer : public IndexStreamer { public: using ContextPointer = IndexStreamer::Context::Pointer; HnswSparseStreamer(void); virtual ~HnswSparseStreamer(void); HnswSparseStreamer(const HnswSparseStreamer &streamer) = delete; HnswSparseStreamer &operator=(const HnswSparseStreamer &streamer) = delete; protected: //! 
Initialize Streamer int init(const IndexMeta &imeta, const ailego::Params ¶ms) override; //! Cleanup Streamer int cleanup(void) override; //! Create a context Context::Pointer create_context(void) const override; //! Create a new sparse iterator IndexStreamer::SparseProvider::Pointer create_sparse_provider( void) const override; int add_impl(uint64_t pkey, const uint32_t sparse_count, const uint32_t *sparse_indices, const void *sparse_query, const IndexQueryMeta &qmeta, Context::Pointer &context) override; int add_with_id_impl(uint32_t id, const uint32_t sparse_count, const uint32_t *sparse_indices, const void *sparse_query, const IndexQueryMeta &qmeta, Context::Pointer &context) override; //! Similarity search with sparse inputs int search_impl(const uint32_t sparse_count, const uint32_t *sparse_indices, const void *sparse_query, const IndexQueryMeta &qmeta, Context::Pointer &context) const override; //! Similarity search with sparse inputs int search_impl(const uint32_t *sparse_count, const uint32_t *sparse_indices, const void *sparse_query, const IndexQueryMeta &qmeta, uint32_t count, Context::Pointer &context) const override; //! Similarity brute force search with sparse inputs int search_bf_impl(const uint32_t sparse_count, const uint32_t *sparse_indices, const void *sparse_query, const IndexQueryMeta &qmeta, Context::Pointer &context) const override; //! Similarity brute force search with sparse inputs int search_bf_impl(const uint32_t *sparse_count, const uint32_t *sparse_indices, const void *sparse_query, const IndexQueryMeta &qmeta, uint32_t count, Context::Pointer &context) const override; //! Linear search by primary keys int search_bf_by_p_keys_impl(const uint32_t sparse_count, const uint32_t *sparse_indices, const void *sparse_query, const std::vector> &p_keys, const IndexQueryMeta &qmeta, ContextPointer &context) const override; //! 
Linear search by primary keys with sparse inputs int search_bf_by_p_keys_impl(const uint32_t *sparse_count, const uint32_t *sparse_indices, const void *sparse_query, const std::vector> &p_keys, const IndexQueryMeta &qmeta, uint32_t count, ContextPointer &context) const override; //! Fetch sparse vector by key int get_sparse_vector(uint64_t key, uint32_t *sparse_count, std::string *sparse_indices_buffer, std::string *sparse_values_buffer) const override { return entity_.get_sparse_vector_by_key( key, sparse_count, sparse_indices_buffer, sparse_values_buffer); } //! Fetch vector by id int get_sparse_vector_by_id( uint32_t id, uint32_t *sparse_count, std::string *sparse_indices_buffer, std::string *sparse_values_buffer) const override { return entity_.get_sparse_vector_by_id( id, sparse_count, sparse_indices_buffer, sparse_values_buffer); } //! Open index from file path int open(IndexStorage::Pointer stg) override; //! Close file int close(void) override; //! flush file int flush(uint64_t checkpoint) override; //! Dump index into storage int dump(const IndexDumper::Pointer &dumper) override; //! Retrieve statistics const Stats &stats(void) const override { return stats_; } //! Retrieve sparse meta of index const IndexMeta &meta(void) const override { return meta_; } void print_debug_info() override; private: inline int check_params(const IndexQueryMeta &qmeta) const { if (ailego_unlikely(qmeta.data_type() != meta_.data_type())) { LOG_ERROR("Unsupported query meta"); return IndexError_Mismatch; } return 0; } inline int check_sparse_count_is_zero(const uint32_t *sparse_count, uint32_t count) const { for (uint32_t i = 0; i < count; ++i) { if (sparse_count[i] != 0) LOG_ERROR("Sparse cout is not empty. Index: %u, Sparse Count: %u", i, sparse_count[i]); return IndexError_InvalidArgument; } return 0; } private: //! To share ctx across streamer/searcher, we need to update the context for //! 
current streamer/searcher int update_context(HnswSparseContext *ctx) const; private: enum State { STATE_INIT = 0, STATE_INITED = 1, STATE_OPENED = 2 }; class Stats : public IndexStreamer::Stats { public: void clear(void) { set_revision_id(0u); set_loaded_count(0u); set_added_count(0u); set_discarded_count(0u); set_index_size(0u); set_dumped_size(0u); set_check_point(0u); set_create_time(0u); set_update_time(0u); clear_attributes(); } }; HnswSparseStreamerEntity entity_; HnswSparseAlgorithm::UPointer alg_; IndexMeta meta_{}; IndexMetric::Pointer metric_{}; IndexMetric::MatrixSparseDistance add_distance_{}; IndexMetric::MatrixSparseDistance search_distance_{}; Stats stats_{}; std::mutex mutex_{}; size_t max_index_size_{0UL}; size_t chunk_size_{HnswSparseEntity::kDefaultChunkSize}; size_t docs_hard_limit_{HnswSparseEntity::kDefaultDocsHardLimit}; size_t docs_soft_limit_{0UL}; uint32_t min_neighbor_cnt_{0u}; uint32_t upper_max_neighbor_cnt_{ HnswSparseEntity::kDefaultUpperMaxNeighborCnt}; uint32_t l0_max_neighbor_cnt_{HnswSparseEntity::kDefaultL0MaxNeighborCnt}; uint32_t ef_{HnswSparseEntity::kDefaultEf}; uint32_t ef_construction_{HnswSparseEntity::kDefaultEfConstruction}; uint32_t scaling_factor_{HnswSparseEntity::kDefaultScalingFactor}; size_t bruteforce_threshold_{HnswSparseEntity::kDefaultBruteForceThreshold}; size_t max_scan_limit_{HnswSparseEntity::kDefaultMaxScanLimit}; size_t min_scan_limit_{HnswSparseEntity::kDefaultMinScanLimit}; float bf_negative_prob_{HnswSparseEntity::kDefaultBFNegativeProbability}; float max_scan_ratio_{HnswSparseEntity::kDefaultScanRatio}; float sparse_neighbor_ratio_{HnswSparseEntity::kDefaultSparseNeighborRatio}; uint32_t sparse_neighbor_cnt_{0UL}; uint32_t sparse_min_neighbor_cnt_{0UL}; uint32_t upper_sparse_neighbor_cnt_{0UL}; bool query_filtering_enabled_{false}; float query_filtering_ratio_{HnswSparseEntity::kDefaultQueryFilteringRatio}; uint32_t magic_{0U}; State state_{STATE_INIT}; bool bf_enabled_{false}; bool 
check_crc_enabled_{false}; bool filter_same_key_{false}; bool get_vector_enabled_{false}; bool force_padding_topk_enabled_{false}; //! avoid add vector while dumping index ailego::SharedMutex shared_mutex_{}; }; } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/hnsw_sparse/hnsw_sparse_streamer_entity.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "hnsw_sparse_streamer_entity.h" #include #include #include #include #include "utility/sparse_utility.h" #include "hnsw_sparse_dist_calculator.h" namespace zvec { namespace core { HnswSparseStreamerEntity::HnswSparseStreamerEntity(IndexStreamer::Stats &stats) : stats_(stats) {} HnswSparseStreamerEntity::~HnswSparseStreamerEntity() {} int HnswSparseStreamerEntity::init(uint64_t max_index_size, size_t max_doc_cnt) { if (std::pow(scaling_factor(), kMaxGraphLayers) < max_doc_cnt) { LOG_ERROR("scalingFactor=%zu is too small", scaling_factor()); return IndexError_InvalidArgument; } std::lock_guard lock(mutex_); broker_ = std::make_shared(stats_); upper_neighbor_index_ = std::make_shared(); keys_map_lock_ = std::make_shared(); keys_map_ = std::make_shared>(); if (!keys_map_ || !upper_neighbor_index_ || !broker_ || !keys_map_lock_) { LOG_ERROR("HnswSparseStreamerEntity new object failed"); return IndexError_NoMemory; } keys_map_->set_empty_key(kInvalidKey); neighbor_size_ = neighbors_size(); upper_neighbor_size_ = upper_neighbors_size(); //!
key + sparse meta + level 0 neighbors size_t size = sizeof(key_t) + neighbor_size_ + sparse_meta_size(); size = AlignSize(size); set_node_size(size); return init_chunk_params(max_index_size); } /* Reset per-index state and release the chunk broker; every cached chunk handle list (node, upper-neighbor and sparse) is dropped along with it. */ int HnswSparseStreamerEntity::cleanup() { std::lock_guard lock(mutex_); mutable_header()->clear(); chunk_size_ = kDefaultChunkSize; node_index_mask_bits_ = 0U; node_index_mask_ = 0U; node_cnt_per_chunk_ = 0U; neighbor_size_ = 0U; upper_neighbor_size_ = 0U; if (upper_neighbor_index_) { upper_neighbor_index_->cleanup(); } if (keys_map_) { keys_map_->clear(); } node_chunks_.clear(); upper_neighbor_chunks_.clear(); sparse_node_chunks_.clear(); filter_same_key_ = false; get_vector_enabled_ = false; broker_.reset(); return 0; } int HnswSparseStreamerEntity::update_neighbors( level_t level, node_id_t id, const std::vector> &neighbors) { std::vector buffer(neighbor_size_); NeighborsHeader *hd = reinterpret_cast(buffer.data()); hd->neighbor_cnt = neighbors.size(); size_t i = 0; for (; i < neighbors.size(); ++i) { hd->neighbors[i] = neighbors[i].first; } auto loc = get_neighbor_chunk_loc(level, id); size_t size = reinterpret_cast(&hd->neighbors[i]) - &buffer[0]; size_t ret = loc.first->write(loc.second, hd, size); if (ailego_unlikely(ret != size)) { LOG_ERROR("Write neighbor header failed, ret=%zu", ret); return IndexError_Runtime; } return 0; } const Neighbors HnswSparseStreamerEntity::get_neighbors(level_t level, node_id_t id) const { SparseChunk *chunk = nullptr; size_t offset = 0UL; size_t neighbor_size = neighbor_size_; if (level == 0UL) { uint32_t chunk_idx = id >> node_index_mask_bits_; offset = (id & node_index_mask_) * node_size() + sizeof(key_t) + sparse_meta_size(); sync_chunks(SparseChunkBroker::CHUNK_TYPE_NODE, chunk_idx, &node_chunks_); ailego_assert_with(chunk_idx < node_chunks_.size(), "invalid chunk idx"); chunk = node_chunks_[chunk_idx].get(); } else { auto p = get_upper_neighbor_chunk_loc(level, id); chunk = upper_neighbor_chunks_[p.first].get(); offset = p.second; neighbor_size =
upper_neighbor_size_; } ailego_assert_with(offset < chunk->data_size(), "invalid chunk offset"); IndexStorage::MemoryBlock neighbor_block; size_t size = chunk->read(offset, neighbor_block, neighbor_size); if (ailego_unlikely(size != neighbor_size)) { LOG_ERROR("Read neighbor header failed, ret=%zu", size); return Neighbors(); } return Neighbors(std::move(neighbor_block)); } //! Get vector feature data by key const void *HnswSparseStreamerEntity::get_vector_meta(node_id_t id) const { auto loc = get_vector_chunk_loc(id); const void *vec = nullptr; ailego_assert_with(loc.first < node_chunks_.size(), "invalid chunk idx"); ailego_assert_with(loc.second < node_chunks_[loc.first]->data_size(), "invalid chunk offset"); size_t read_size = sparse_meta_size(); size_t ret = node_chunks_[loc.first]->read(loc.second, &vec, read_size); if (ailego_unlikely(ret != read_size)) { LOG_ERROR("Read vector failed, offset=%u, read size=%zu, ret=%zu", loc.second, read_size, ret); } return vec; } int HnswSparseStreamerEntity::get_vector_meta( const node_id_t id, IndexStorage::MemoryBlock &block) const { auto loc = get_vector_chunk_loc(id); ailego_assert_with(loc.first < node_chunks_.size(), "invalid chunk idx"); ailego_assert_with(loc.second < node_chunks_[loc.first]->data_size(), "invalid chunk offset"); size_t read_size = sparse_meta_size(); size_t ret = node_chunks_[loc.first]->read(loc.second, block, read_size); if (ailego_unlikely(ret != read_size)) { LOG_ERROR("Read vector failed, offset=%u, read size=%zu, ret=%zu", loc.second, read_size, ret); return IndexError_ReadData; } return 0; } int HnswSparseStreamerEntity::get_vector_metas(const node_id_t *ids, uint32_t count, const void **vecs) const { for (auto i = 0U; i < count; ++i) { auto loc = get_vector_chunk_loc(ids[i]); ailego_assert_with(loc.first < node_chunks_.size(), "invalid chunk idx"); ailego_assert_with(loc.second < node_chunks_[loc.first]->data_size(), "invalid chunk offset"); size_t read_size = sparse_meta_size(); size_t
ret = node_chunks_[loc.first]->read(loc.second, &vecs[i], read_size); if (ailego_unlikely(ret != read_size)) { LOG_ERROR("Read vector failed, offset=%u, read size=%zu, ret=%zu", loc.second, read_size, ret); return IndexError_ReadData; } } return 0; } int HnswSparseStreamerEntity::get_vector_metas( const node_id_t *ids, uint32_t count, std::vector &block_vecs) const { block_vecs.resize(count); for (auto i = 0U; i < count; ++i) { auto loc = get_vector_chunk_loc(ids[i]); ailego_assert_with(loc.first < node_chunks_.size(), "invalid chunk idx"); ailego_assert_with(loc.second < node_chunks_[loc.first]->data_size(), "invalid chunk offset"); size_t read_size = sparse_meta_size(); size_t ret = node_chunks_[loc.first]->read(loc.second, block_vecs[i], read_size); if (ailego_unlikely(ret != read_size)) { LOG_ERROR("Read vector failed, offset=%u, read size=%zu, ret=%zu", loc.second, read_size, ret); return IndexError_ReadData; } } return 0; } //! Get vector feature data by key const void *HnswSparseStreamerEntity::get_sparse_data(uint64_t offset, uint32_t len) const { uint32_t chunk_index = offset >> 32; uint32_t chunk_offset = offset & 0xFFFFFFFF; auto loc = get_sparse_chunk_loc(chunk_index, chunk_offset); const void *data = nullptr; ailego_assert_with(loc.first < sparse_node_chunks_.size(), "invalid chunk idx"); ailego_assert_with(loc.second < sparse_node_chunks_[loc.first]->data_size(), "invalid chunk offset"); size_t ret = sparse_node_chunks_[loc.first]->read(loc.second, &data, len); if (ailego_unlikely(ret != len)) { LOG_ERROR("Read sparse vector failed, offset=%zu, read size=%u, ret=%zu", (size_t)offset, len, ret); } return data; } int HnswSparseStreamerEntity::get_sparse_data( uint64_t offset, uint32_t len, IndexStorage::MemoryBlock &block) const { uint32_t chunk_index = offset >> 32; uint32_t chunk_offset = offset & 0xFFFFFFFF; auto loc = get_sparse_chunk_loc(chunk_index, chunk_offset); ailego_assert_with(loc.first < sparse_node_chunks_.size(), "invalid chunk idx");
ailego_assert_with(loc.second < sparse_node_chunks_[loc.first]->data_size(), "invalid chunk offset"); size_t ret = sparse_node_chunks_[loc.first]->read(loc.second, block, len); if (ailego_unlikely(ret != len)) { LOG_ERROR("Read sparse vector failed, offset=%zu, read size=%u, ret=%zu", (size_t)offset, len, ret); return IndexError_ReadData; } return 0; } //! Get sparse data from id const void *HnswSparseStreamerEntity::get_sparse_data(node_id_t id) const { auto sparse_data = get_sparse_data_from_vector(get_vector_meta(id)); return sparse_data.first; } /* Read the node's sparse meta, then the sparse payload it references; a failed meta read is propagated instead of parsing an unread block. */ int HnswSparseStreamerEntity::get_sparse_data( node_id_t id, IndexStorage::MemoryBlock &block) const { IndexStorage::MemoryBlock meta_block; int meta_ret = get_vector_meta(id, meta_block); if (ailego_unlikely(meta_ret != 0)) { return meta_ret; } int sparse_length = 0; return get_sparse_data_from_vector(meta_block.data(), block, sparse_length); } //! Get sparse data from vector /* vec points at a node's sparse meta: a uint64_t packed offset (chunk index in the high 32 bits, in-chunk offset in the low 32) followed by a uint32_t payload length. A null vec (get_vector_meta may return nullptr) yields an empty result. */ std::pair HnswSparseStreamerEntity::get_sparse_data_from_vector(const void *vec) const { if (ailego_unlikely(vec == nullptr)) { return std::make_pair(nullptr, 0); } const char *vec_ptr = reinterpret_cast(vec); uint64_t offset = *((uint64_t *)(vec_ptr)); uint32_t sparse_vector_len = *((uint32_t *)(vec_ptr + sizeof(uint64_t))); if (sparse_vector_len > 0) { const void *sparse_data = get_sparse_data(offset, sparse_vector_len); if (ailego_unlikely(sparse_data == nullptr)) { LOG_ERROR("Get nullptr sparse, offset=%zu, len=%u", (size_t)offset, sparse_vector_len); return std::make_pair(nullptr, 0); } return std::make_pair(sparse_data, sparse_vector_len); } return std::make_pair(nullptr, 0); } /* Block-based variant of the above; a null vec is reported as IndexError_ReadData, and sparse_length is only written when the payload is non-empty. */ int HnswSparseStreamerEntity::get_sparse_data_from_vector( const void *vec, IndexStorage::MemoryBlock &block, int &sparse_length) const { if (ailego_unlikely(vec == nullptr)) { return IndexError_ReadData; } const char *vec_ptr = reinterpret_cast(vec); uint64_t offset = *((uint64_t *)(vec_ptr)); uint32_t sparse_vector_len = *((uint32_t *)(vec_ptr + sizeof(uint64_t))); if (sparse_vector_len > 0) { int ret = get_sparse_data(offset, sparse_vector_len, block); if (ailego_unlikely(ret != 0)) { LOG_ERROR("Get nullptr sparse, offset=%zu, len=%u", (size_t)offset, sparse_vector_len); return
IndexError_ReadData; } sparse_length = sparse_vector_len; } return 0; } key_t HnswSparseStreamerEntity::get_key(node_id_t id) const { auto loc = get_key_chunk_loc(id); IndexStorage::MemoryBlock key_block; ailego_assert_with(loc.first < node_chunks_.size(), "invalid chunk idx"); ailego_assert_with(loc.second < node_chunks_[loc.first]->data_size(), "invalid chunk offset"); size_t ret = node_chunks_[loc.first]->read(loc.second, key_block, sizeof(key_t)); if (ailego_unlikely(ret != sizeof(key_t))) { LOG_ERROR("Read vector failed, ret=%zu", ret); return kInvalidKey; } return *reinterpret_cast(key_block.data()); } void HnswSparseStreamerEntity::add_neighbor(level_t level, node_id_t id, uint32_t size, node_id_t neighbor_id) { auto loc = get_neighbor_chunk_loc(level, id); size_t offset = loc.second + sizeof(NeighborsHeader) + size * sizeof(node_id_t); ailego_assert_with(size < neighbor_cnt(level), "invalid neighbor size"); ailego_assert_with(offset < loc.first->data_size(), "invalid chunk offset"); size_t ret = loc.first->write(offset, &neighbor_id, sizeof(node_id_t)); if (ailego_unlikely(ret != sizeof(node_id_t))) { LOG_ERROR("Write neighbor id failed, ret=%zu", ret); return; } uint32_t neighbors = size + 1; ret = loc.first->write(loc.second, &neighbors, sizeof(uint32_t)); if (ailego_unlikely(ret != sizeof(uint32_t))) { LOG_ERROR("Write neighbor cnt failed, ret=%zu", ret); } return; } int HnswSparseStreamerEntity::init_chunks( const SparseChunk::Pointer &header_chunk) { if (header_chunk->data_size() < header_size()) { LOG_ERROR("Invalid header chunk size"); return IndexError_InvalidFormat; } IndexStorage::MemoryBlock data_block; size_t size = header_chunk->read(0UL, data_block, header_size()); if (ailego_unlikely(size != header_size())) { LOG_ERROR("Read header chunk failed"); return IndexError_ReadData; } *mutable_header() = *reinterpret_cast(data_block.data()); int ret = check_hnsw_index(&header()); if (ret != 0) { broker_->close(); return ret; } node_chunks_.resize(
broker_->get_chunk_cnt(SparseChunkBroker::CHUNK_TYPE_NODE)); for (auto seq = 0UL; seq < node_chunks_.size(); ++seq) { node_chunks_[seq] = broker_->get_chunk(SparseChunkBroker::CHUNK_TYPE_NODE, seq); if (!node_chunks_[seq]) { LOG_ERROR("Missing hnsw streamer data chunk %zu th of %zu", seq, node_chunks_.size()); return IndexError_InvalidFormat; } } upper_neighbor_chunks_.resize( broker_->get_chunk_cnt(SparseChunkBroker::CHUNK_TYPE_UPPER_NEIGHBOR)); for (auto seq = 0UL; seq < upper_neighbor_chunks_.size(); ++seq) { upper_neighbor_chunks_[seq] = broker_->get_chunk(SparseChunkBroker::CHUNK_TYPE_UPPER_NEIGHBOR, seq); if (!upper_neighbor_chunks_[seq]) { LOG_ERROR("Missing hnsw streamer index chunk %zu th of %zu", seq, upper_neighbor_chunks_.size()); return IndexError_InvalidFormat; } } sparse_node_chunks_.resize( broker_->get_chunk_cnt(SparseChunkBroker::CHUNK_TYPE_SPARSE_NODE)); for (auto seq = 0UL; seq < sparse_node_chunks_.size(); ++seq) { sparse_node_chunks_[seq] = broker_->get_chunk(SparseChunkBroker::CHUNK_TYPE_SPARSE_NODE, seq); if (!sparse_node_chunks_[seq]) { LOG_ERROR("Missing hnsw streamer sparse data chunk %zu th of %zu", seq, sparse_node_chunks_.size()); return IndexError_InvalidFormat; } } return 0; } int HnswSparseStreamerEntity::open(IndexStorage::Pointer stg, bool check_crc) { std::lock_guard lock(mutex_); int ret = broker_->open(std::move(stg), max_index_size_, chunk_size_, check_crc); if (ailego_unlikely(ret != 0)) { LOG_ERROR("Open index failed for %s", IndexError::What(ret)); return ret; } ret = upper_neighbor_index_->init(broker_, upper_neighbor_chunk_size_, scaling_factor(), estimate_doc_capacity(), kUpperHashMemoryInflateRatio); if (ailego_unlikely(ret != 0)) { LOG_ERROR("Init neighbor hash map failed"); return ret; } //!
init header auto header_chunk = broker_->get_chunk(SparseChunkBroker::CHUNK_TYPE_HEADER, SparseChunkBroker::kDefaultChunkSeqId); if (!header_chunk) { // open empty index, create one auto p = broker_->alloc_chunk(SparseChunkBroker::CHUNK_TYPE_HEADER, SparseChunkBroker::kDefaultChunkSeqId, header_size()); if (ailego_unlikely(p.first != 0)) { LOG_ERROR("Alloc header chunk failed"); return p.first; } size_t size = p.second->write(0UL, &header(), header_size()); if (ailego_unlikely(size != header_size())) { LOG_ERROR("Write header chunk failed"); return IndexError_WriteData; } return 0; } //! Open an exist hnsw index ret = init_chunks(header_chunk); if (ailego_unlikely(ret != 0)) { return ret; } //! total docs including features wrote in index but neighbors may not ready node_id_t total_vecs = 0; if (node_chunks_.size() > 0) { size_t last_idx = node_chunks_.size() - 1; auto last_chunk = node_chunks_[last_idx]; if (last_chunk->data_size() % node_size()) { LOG_WARN("The index may broken"); return IndexError_InvalidFormat; } total_vecs = last_idx * node_cnt_per_chunk_ + node_chunks_[last_idx]->data_size() / node_size(); } LOG_INFO( "Open index, l0NeighborCnt=%zu upperneighborCnt=%zu " "efConstruction=%zu curDocCnt=%u totalVecs=%u maxLevel=%u", l0_neighbor_cnt(), upper_neighbor_cnt(), ef_construction(), doc_cnt(), total_vecs, cur_max_level()); //!
try to correct the docCnt if index not fully flushed if (doc_cnt() != total_vecs) { LOG_WARN("Index closed abnormally, using totalVecs as curDocCnt"); *mutable_doc_cnt() = total_vecs; } if (filter_same_key_ || get_vector_enabled_) { for (node_id_t id = 0U; id < doc_cnt(); ++id) { (*keys_map_)[get_key(id)] = id; } } stats_.set_loaded_count(doc_cnt()); return 0; } int HnswSparseStreamerEntity::close() { LOG_DEBUG("close index"); std::lock_guard lock(mutex_); flush_header(); mutable_header()->reset(); upper_neighbor_index_->cleanup(); keys_map_->clear(); header_.clear(); node_chunks_.clear(); upper_neighbor_chunks_.clear(); sparse_node_chunks_.clear(); return broker_->close(); } int HnswSparseStreamerEntity::flush(uint64_t checkpoint) { LOG_INFO("Flush index, curDocs=%u", doc_cnt()); std::lock_guard lock(mutex_); flush_header(); int ret = broker_->flush(checkpoint); if (ret != 0) { return ret; } return 0; } int HnswSparseStreamerEntity::dump(const IndexDumper::Pointer &dumper) { LOG_INFO("Dump index, curDocs=%u", doc_cnt()); //! sort by keys, to support get_vector by key in searcher std::vector keys(doc_cnt()); for (node_id_t i = 0; i < doc_cnt(); ++i) { keys[i] = get_key(i); } //!
dump neighbors auto get_level = [&](node_id_t id) { auto it = upper_neighbor_index_->find(id); if (it == upper_neighbor_index_->end()) { return 0U; }; auto meta = reinterpret_cast(&it->second); return meta->level; }; auto ret = dump_segments(dumper, keys.data(), get_level); if (ailego_unlikely(ret < 0)) { return ret; } *stats_.mutable_dumped_size() += ret; return 0; } int HnswSparseStreamerEntity::check_hnsw_index( const HNSWSparseHeader *hd) const { if (l0_neighbor_cnt() != hd->neighbor_cnt() || upper_neighbor_cnt() != hd->upper_neighbor_cnt()) { LOG_ERROR("Param neighbors:%zu:%zu mismatch index previous %zu:%zu", l0_neighbor_cnt(), upper_neighbor_cnt(), hd->neighbor_cnt(), hd->upper_neighbor_cnt()); return IndexError_Mismatch; } if (ef_construction() != hd->ef_construction()) { LOG_WARN("Param efConstruction %zu mismatch index previous %zu", ef_construction(), hd->ef_construction()); } if (scaling_factor() != hd->scaling_factor()) { LOG_WARN("Param scalingFactor %zu mismatch index previous %zu", scaling_factor(), hd->scaling_factor()); return IndexError_Mismatch; } if (prune_cnt() != hd->neighbor_prune_cnt()) { LOG_WARN("Param pruneCnt %zu mismatch index previous %zu", prune_cnt(), hd->neighbor_prune_cnt()); return IndexError_Mismatch; } if ((hd->entry_point() != kInvalidNodeId && hd->entry_point() >= hd->doc_cnt()) || (hd->entry_point() == kInvalidNodeId && hd->doc_cnt() > 0U)) { LOG_WARN("Invalid entryPoint %u, docCnt %u", hd->entry_point(), hd->doc_cnt()); return IndexError_InvalidFormat; } if (hd->entry_point() == kInvalidNodeId && broker_->get_chunk_cnt(SparseChunkBroker::CHUNK_TYPE_NODE) > 0) { LOG_WARN("The index is broken, maybe it haven't flush"); return IndexError_InvalidFormat; } return 0; } int HnswSparseStreamerEntity::add_vector(level_t level, key_t key, const std::string &sparse_vec, uint32_t sparse_count, node_id_t *id) { // allocat sparse chunk uint32_t sparse_vector_len = sparse_vec.size(); sparse_vector_len = AlignSize(sparse_vector_len); if
(sparse_vector_len > sparse_chunk_size_) { LOG_ERROR( "Sparse Vector Length exceed the chunk size, sparse vec len: %u, chunk " "size: %u", sparse_vector_len, sparse_chunk_size_); return IndexError_InvalidArgument; } SparseChunk::Pointer node_chunk; SparseChunk::Pointer sparse_node_chunk; size_t chunk_offset = -1UL; size_t sparse_chunk_offset = -1UL; std::lock_guard lock(mutex_); // duplicate check if (ailego_unlikely(filter_same_key_ && get_id(key) != kInvalidNodeId)) { LOG_WARN("Try to add duplicate key, ignore it"); return IndexError_Duplicate; } node_id_t local_id = static_cast(doc_cnt()); uint32_t chunk_index = node_chunks_.size() - 1U; if (chunk_index == -1U || (node_chunks_[chunk_index]->data_size() >= node_cnt_per_chunk_ * node_size())) { // no space left and need to alloc if (ailego_unlikely(node_chunks_.capacity() == node_chunks_.size())) { LOG_ERROR("add vector failed for no memory quota"); return IndexError_IndexFull; } chunk_index++; auto p = broker_->alloc_chunk(SparseChunkBroker::CHUNK_TYPE_NODE, chunk_index, chunk_size_); if (ailego_unlikely(p.first != 0)) { LOG_ERROR("Alloc data chunk failed"); return p.first; } node_chunk = p.second; chunk_offset = 0UL; node_chunks_.emplace_back(node_chunk); } else { node_chunk = node_chunks_[chunk_index]; chunk_offset = node_chunk->data_size(); } uint32_t sparse_chunk_index = sparse_node_chunks_.size() - 1U; if (sparse_chunk_index == -1U || sparse_node_chunks_[sparse_chunk_index]->data_size() + sparse_vector_len > sparse_chunk_size_) { if (ailego_unlikely(sparse_node_chunks_.capacity() == sparse_node_chunks_.size())) { LOG_ERROR("add vector failed for no memory quota"); return IndexError_IndexFull; } sparse_chunk_index++; auto p = broker_->alloc_chunk(SparseChunkBroker::CHUNK_TYPE_SPARSE_NODE, sparse_chunk_index, sparse_chunk_size_); if (ailego_unlikely(p.first != 0)) { LOG_ERROR("Alloc data chunk failed"); return p.first; } sparse_node_chunk = p.second; sparse_node_chunks_.emplace_back(sparse_node_chunk);
sparse_chunk_offset = 0UL; } else { sparse_node_chunk = sparse_node_chunks_[sparse_chunk_index]; sparse_chunk_offset = sparse_node_chunk->data_size(); } // write sparse vector if (sparse_vec.size() > 0) { size_t size = sparse_node_chunk->write( sparse_chunk_offset, sparse_vec.data(), sparse_vec.size()); if (ailego_unlikely(size != sparse_vec.size())) { LOG_ERROR("SparseChunk write sparse vec failed, ret=%zu", size); return IndexError_WriteData; } } uint64_t sparse_offset = sparse_chunk_index; sparse_offset = (sparse_offset << 32) + sparse_chunk_offset; size_t size = node_chunk->write(chunk_offset, &sparse_offset, sizeof(uint64_t)); if (ailego_unlikely(size != sizeof(uint64_t))) { LOG_ERROR("SparseChunk write sparse vec index failed, ret=%zu", size); return IndexError_WriteData; } size = node_chunk->write(chunk_offset + sizeof(uint64_t), &sparse_vector_len, sizeof(uint32_t)); if (ailego_unlikely(size != sizeof(uint32_t))) { LOG_ERROR("SparseChunk write sparse vec len failed, ret=%zu", size); return IndexError_WriteData; } size = node_chunk->write(chunk_offset + sparse_meta_size(), &key, sizeof(key_t)); if (ailego_unlikely(size != sizeof(key_t))) { LOG_ERROR("SparseChunk write vec failed, ret=%zu", size); return IndexError_WriteData; } //! 
level 0 neighbors is inited to zero by default int ret = add_upper_neighbor(level, local_id); if (ret != 0) { return ret; } if (sparse_vector_len > 0) { sparse_chunk_offset += sparse_vector_len; if (ailego_unlikely(sparse_node_chunk->resize(sparse_chunk_offset) != sparse_chunk_offset)) { LOG_ERROR("SparseChunk resize to %zu failed", sparse_chunk_offset); return IndexError_Runtime; } } chunk_offset += node_size(); if (ailego_unlikely(node_chunk->resize(chunk_offset) != chunk_offset)) { LOG_ERROR("SparseChunk resize to %zu failed", chunk_offset); return IndexError_Runtime; } if (filter_same_key_ || get_vector_enabled_) { keys_map_lock_->lock(); (*keys_map_)[key] = local_id; keys_map_lock_->unlock(); } *mutable_doc_cnt() += 1; *mutable_total_sparse_count() += sparse_count; broker_->mark_dirty(); *id = local_id; return 0; } int HnswSparseStreamerEntity::add_vector_with_id(level_t level, node_id_t id, const std::string &sparse_vec, uint32_t sparse_count) { key_t key = id; SparseChunk::Pointer node_chunk; SparseChunk::Pointer sparse_node_chunk; size_t chunk_offset = -1UL; size_t sparse_chunk_offset = -1UL; // allocat sparse chunk uint32_t sparse_vector_len = sparse_vec.size(); sparse_vector_len = AlignSize(sparse_vector_len); if (sparse_vector_len > sparse_chunk_size_) { LOG_ERROR( "Sparse Vector Length exceed the chunk size, sparse vec len: %u, chunk " "size: %u", sparse_vector_len, sparse_chunk_size_); return IndexError_InvalidArgument; } std::lock_guard lock(mutex_); // duplicate check if (ailego_unlikely(filter_same_key_ && get_id(key) != kInvalidNodeId)) { LOG_WARN("Try to add duplicate key, ignore it"); return IndexError_Duplicate; } auto func_get_sparse_node_chunk_and_offset = [&](node_id_t node_id) -> int { uint32_t chunk_index = node_id >> node_index_mask_bits_; ailego_assert_with(chunk_index <= node_chunks_.size(), "invalid chunk idx"); // belongs to next chunk if (chunk_index == node_chunks_.size()) { if (ailego_unlikely(node_chunks_.capacity() == 
node_chunks_.size())) { LOG_ERROR("add vector failed for no memory quota"); return IndexError_IndexFull; } auto p = broker_->alloc_chunk(SparseChunkBroker::CHUNK_TYPE_NODE, chunk_index, chunk_size_); if (ailego_unlikely(p.first != 0)) { LOG_ERROR("Alloc data chunk failed"); return p.first; } node_chunk = p.second; node_chunks_.emplace_back(node_chunk); } node_chunk = node_chunks_[chunk_index]; chunk_offset = (node_id & node_index_mask_) * node_size(); return 0; }; for (size_t start_id = doc_cnt(); start_id < id; ++start_id) { if (auto ret = func_get_sparse_node_chunk_and_offset(start_id); ret != 0) { LOG_ERROR("func_get_sparse_node_chunk_and_offset failed"); return ret; } size_t size = node_chunk->write(chunk_offset + sparse_meta_size(), &kInvalidKey, sizeof(key_t)); if (ailego_unlikely(size != sizeof(key_t))) { LOG_ERROR("SparseChunk write key failed, ret=%zu", size); return IndexError_WriteData; } chunk_offset += node_size(); if (ailego_unlikely(node_chunk->resize(chunk_offset) != chunk_offset)) { LOG_ERROR("SparseChunk resize to %zu failed", chunk_offset); return IndexError_Runtime; } } if (auto ret = func_get_sparse_node_chunk_and_offset(id); ret != 0) { LOG_ERROR("func_get_sparse_node_chunk_and_offset failed"); return ret; } uint32_t sparse_chunk_index = sparse_node_chunks_.size() - 1U; if (sparse_chunk_index == -1U || sparse_node_chunks_[sparse_chunk_index]->data_size() + sparse_vector_len > sparse_chunk_size_) { if (ailego_unlikely(sparse_node_chunks_.capacity() == sparse_node_chunks_.size())) { LOG_ERROR("add vector failed for no memory quota"); return IndexError_IndexFull; } sparse_chunk_index++; auto p = broker_->alloc_chunk(SparseChunkBroker::CHUNK_TYPE_SPARSE_NODE, sparse_chunk_index, sparse_chunk_size_); if (ailego_unlikely(p.first != 0)) { LOG_ERROR("Alloc data chunk failed"); return p.first; } sparse_node_chunk = p.second; sparse_node_chunks_.emplace_back(sparse_node_chunk); sparse_chunk_offset = 0UL; } else { sparse_node_chunk = 
sparse_node_chunks_[sparse_chunk_index]; sparse_chunk_offset = sparse_node_chunk->data_size(); } // write sparse vector if (sparse_vec.size() > 0) { size_t size = sparse_node_chunk->write( sparse_chunk_offset, sparse_vec.data(), sparse_vec.size()); if (ailego_unlikely(size != sparse_vec.size())) { LOG_ERROR("SparseChunk write sparse vec failed, ret=%zu", size); return IndexError_WriteData; } } uint64_t sparse_offset = sparse_chunk_index; sparse_offset = (sparse_offset << 32) + sparse_chunk_offset; size_t size = node_chunk->write(chunk_offset, &sparse_offset, sizeof(uint64_t)); if (ailego_unlikely(size != sizeof(uint64_t))) { LOG_ERROR("SparseChunk write sparse vec index failed, ret=%zu", size); return IndexError_WriteData; } size = node_chunk->write(chunk_offset + sizeof(uint64_t), &sparse_vector_len, sizeof(uint32_t)); if (ailego_unlikely(size != sizeof(uint32_t))) { LOG_ERROR("SparseChunk write sparse vec len failed, ret=%zu", size); return IndexError_WriteData; } size = node_chunk->write(chunk_offset + sparse_meta_size(), &key, sizeof(key_t)); if (ailego_unlikely(size != sizeof(key_t))) { LOG_ERROR("SparseChunk write vec failed, ret=%zu", size); return IndexError_WriteData; } //! 
level 0 neighbors is inited to zero by default int ret = add_upper_neighbor(level, id); if (ret != 0) { return ret; } if (sparse_vector_len > 0) { sparse_chunk_offset += sparse_vector_len; if (ailego_unlikely(sparse_node_chunk->resize(sparse_chunk_offset) != sparse_chunk_offset)) { LOG_ERROR("SparseChunk resize to %zu failed", sparse_chunk_offset); return IndexError_Runtime; } } if (*mutable_doc_cnt() <= id) { *mutable_doc_cnt() = id + 1; chunk_offset += node_size(); if (ailego_unlikely(node_chunk->resize(chunk_offset) != chunk_offset)) { LOG_ERROR("Chunk resize to %zu failed", chunk_offset); return IndexError_Runtime; } } *mutable_total_sparse_count() += sparse_count; if (filter_same_key_ || get_vector_enabled_) { keys_map_lock_->lock(); (*keys_map_)[key] = id; keys_map_lock_->unlock(); } broker_->mark_dirty(); return 0; } void HnswSparseStreamerEntity::update_ep_and_level(node_id_t ep, level_t level) { HnswSparseEntity::update_ep_and_level(ep, level); flush_header(); return; } const HnswSparseEntity::Pointer HnswSparseStreamerEntity::clone() const { std::vector node_chunks; node_chunks.reserve(node_chunks_.size()); for (size_t i = 0UL; i < node_chunks_.size(); ++i) { node_chunks.emplace_back(node_chunks_[i]->clone()); if (ailego_unlikely(!node_chunks[i])) { LOG_ERROR("HnswSparseStreamerEntity get chunk failed in clone"); return HnswSparseEntity::Pointer(); } } std::vector sparse_node_chunks; sparse_node_chunks.reserve(sparse_node_chunks_.size()); for (size_t i = 0UL; i < sparse_node_chunks_.size(); ++i) { sparse_node_chunks.emplace_back(sparse_node_chunks_[i]->clone()); if (ailego_unlikely(!sparse_node_chunks[i])) { LOG_ERROR("HnswSparseStreamerEntity get sparse chunk failed in clone"); return HnswSparseEntity::Pointer(); } } std::vector upper_neighbor_chunks; upper_neighbor_chunks.reserve(upper_neighbor_chunks_.size()); for (size_t i = 0UL; i < upper_neighbor_chunks_.size(); ++i) { upper_neighbor_chunks.emplace_back(upper_neighbor_chunks_[i]->clone()); if 
(ailego_unlikely(!upper_neighbor_chunks[i])) { LOG_ERROR("HnswSparseStreamerEntity get chunk failed in clone"); return HnswSparseEntity::Pointer(); } } HnswSparseStreamerEntity *entity = new (std::nothrow) HnswSparseStreamerEntity( stats_, header(), chunk_size_, node_index_mask_bits_, upper_neighbor_mask_bits_, filter_same_key_, get_vector_enabled_, sparse_chunk_size_, upper_neighbor_index_, keys_map_lock_, keys_map_, std::move(node_chunks), std::move(upper_neighbor_chunks), std::move(sparse_node_chunks), broker_); if (ailego_unlikely(!entity)) { LOG_ERROR("HnswSparseStreamerEntity new failed"); } return HnswSparseEntity::Pointer(entity); } //! Get sparse vector feature data by key int HnswSparseStreamerEntity::get_sparse_vector_by_key( key_t key, uint32_t *sparse_count, std::string *sparse_indices_buffer, std::string *sparse_values_buffer) const { *sparse_count = 0; auto id = get_id(key); if (id == kInvalidNodeId) { return IndexError_NoExist; } return get_sparse_vector_by_id(id, sparse_count, sparse_indices_buffer, sparse_values_buffer); } int HnswSparseStreamerEntity::get_sparse_vector_by_id( node_id_t id, uint32_t *sparse_count, std::string *sparse_indices_buffer, std::string *sparse_values_buffer) const { IndexStorage::MemoryBlock block; get_sparse_data(id, block); const void *sparse_data = block.data(); if (sparse_data == nullptr) { return IndexError_InvalidValue; } SparseUtility::ReverseSparseFormat(sparse_data, sparse_count, sparse_indices_buffer, sparse_values_buffer, sparse_unit_size()); return 0; } } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/hnsw_sparse/hnsw_sparse_streamer_entity.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include #include #include #include #include "hnsw_sparse_chunk.h" #include "hnsw_sparse_entity.h" #include "hnsw_sparse_index_hash.h" #include "hnsw_sparse_params.h" namespace zvec { namespace core { //! HnswSparseStreamerEntity manage vector data, pkey, and node's neighbors class HnswSparseStreamerEntity : public HnswSparseEntity { public: //! Cleanup //! return 0 on success, or errCode in failure virtual int cleanup() override; //! Make a copy of streamer entity, to support thread-safe operation. //! The segment in container cannot be read concurrenly virtual const HnswSparseEntity::Pointer clone() const override; //! Get primary key of the node id virtual key_t get_key(node_id_t id) const override; //! Get vector feature data by key virtual const void *get_vector_meta(node_id_t id) const override; virtual int get_vector_meta(const node_id_t id, IndexStorage::MemoryBlock &block) const override; //! Get vectors feature data by local ids virtual int get_vector_metas(const node_id_t *ids, uint32_t count, const void **vecs) const override; virtual int get_vector_metas( const node_id_t *ids, uint32_t count, std::vector &block_vecs) const override; //! Get vector sparse feature data by chunk index and offset virtual const void *get_sparse_data(uint64_t offset, uint32_t len) const override; //! 
Get sparse data from id virtual const void *get_sparse_data(node_id_t id) const override; virtual int get_sparse_data(uint64_t offset, uint32_t len, IndexStorage::MemoryBlock &block) const override; virtual int get_sparse_data(node_id_t id, IndexStorage::MemoryBlock &block) const override; //! Get sparse data from vector virtual std::pair get_sparse_data_from_vector( const void *vec) const override; virtual int get_sparse_data_from_vector(const void *vec, IndexStorage::MemoryBlock &block, int &sparse_length) const override; //! Get sparse vector feature data by key virtual int get_sparse_vector_by_key( key_t key, uint32_t *sparse_count, std::string *sparse_indices_buffer, std::string *sparse_values_buffer) const override; //! Get sparse vector feature data by id virtual int get_sparse_vector_by_id( node_id_t id, uint32_t *sparse_count, std::string *sparse_indices_buffer, std::string *sparse_values_buffer) const override; //! Get the node id's neighbors on graph level //! Note: the neighbors cannot be modified, using the following //! method to get WritableNeighbors if want to virtual const Neighbors get_neighbors(level_t level, node_id_t id) const override; //! Add vector and key to hnsw entity, and local id will be saved in id virtual int add_vector(level_t level, key_t key, const std::string &sparse_vec_buffer, uint32_t sparse_count, node_id_t *id) override; //! Add vector and id to hnsw entity virtual int add_vector_with_id(level_t level, node_id_t id, const std::string &sparse_vec, uint32_t sparse_count) override; virtual int update_neighbors( level_t level, node_id_t id, const std::vector> &neighbors) override; //! Replace node id in level's neighbors int update_neighbors_dense( level_t level, node_id_t id, const std::vector> &neighbors); //! Replace node id in level's neighbors int update_neighbors_sparse( level_t level, node_id_t id, const std::vector> &neighbors); //! Append neighbor_id to node id neighbors on level //! 
Notice: the caller must be ensure the neighbors not full virtual void add_neighbor(level_t level, node_id_t id, uint32_t size, node_id_t neighbor_id) override; //! Dump index by dumper virtual int dump(const IndexDumper::Pointer &dumper) override; virtual void update_ep_and_level(node_id_t ep, level_t level) override; public: //! Constructor HnswSparseStreamerEntity(IndexStreamer::Stats &stats); //! Destructor ~HnswSparseStreamerEntity(); //! Init entity int init(uint64_t max_index_size, size_t max_doc_cnt); //! Flush graph entity to disk //! return 0 on success, or errCode in failure int flush(uint64_t checkpoint); //! Open entity from storage //! return 0 on success, or errCode in failure int open(IndexStorage::Pointer stg, bool check_crc); //! Close entity //! return 0 on success, or errCode in failure int close(); //! Set meta information from entity int set_index_meta(const IndexMeta &meta) const { return IndexHelper::SerializeToStorage(meta, broker_->storage().get()); } //! Get meta information from entity int get_index_meta(IndexMeta *meta) const { return IndexHelper::DeserializeFromStorage(broker_->storage().get(), meta); } //! Set params: chunk size inline void set_chunk_size(size_t val) { chunk_size_ = val; } //! Set params inline void set_filter_same_key(bool val) { filter_same_key_ = val; } //! Set params inline void set_get_vector(bool val) { get_vector_enabled_ = val; } //! Get vector local id by key inline node_id_t get_id(key_t key) const { keys_map_lock_->lock_shared(); auto it = keys_map_->find(key); keys_map_lock_->unlock_shared(); return it == keys_map_->end() ? kInvalidNodeId : it->second; } void print_key_map() { std::cout << "key map begins" << std::endl; auto iter = keys_map_->begin(); while (iter != keys_map_->end()) { std::cout << "key: " << iter->first << ", id: " << iter->second << std::endl; ; iter++; } std::cout << "key map ends" << std::endl; } //! 
Get neighbors size inline size_t neighbors_size() const { return sizeof(NeighborsHeader) + l0_neighbor_cnt() * sizeof(node_id_t); } //! Get upper neighbors size inline size_t upper_neighbors_size() const { return sizeof(NeighborsHeader) + upper_neighbor_cnt() * sizeof(node_id_t); } private: union UpperNeighborIndexMeta { struct { uint32_t level : 4; uint32_t index : 28; // index is composite type: chunk idx, and the // N th neighbors in chunk, they two composite // the 28 bits location }; uint32_t data; }; template using HashMap = google::dense_hash_map>; template using HashMapPointer = std::shared_ptr>; template using HashSet = google::dense_hash_set>; template using HashSetPointer = std::shared_ptr>; //! upper neighbor index hashmap using NIHashMap = HnswSparseIndexHashMap; using NIHashMapPointer = std::shared_ptr; //! Private construct, only be called by clone method HnswSparseStreamerEntity( IndexStreamer::Stats &stats, const HNSWSparseHeader &hd, size_t chunk_size, uint32_t node_index_mask_bits, uint32_t upper_neighbor_mask_bits, bool filter_same_key, bool get_vector_enabled, uint32_t sparse_chunk_size, const NIHashMapPointer &upper_neighbor_index, std::shared_ptr &keys_map_lock, const HashMapPointer &keys_map, std::vector &&node_chunks, std::vector &&upper_neighbor_chunks, std::vector &&sparse_node_chunks, const SparseChunkBroker::Pointer &broker) : stats_(stats), chunk_size_(chunk_size), node_index_mask_bits_(node_index_mask_bits), node_cnt_per_chunk_(1UL << node_index_mask_bits_), node_index_mask_(node_cnt_per_chunk_ - 1), upper_neighbor_mask_bits_(upper_neighbor_mask_bits), upper_neighbor_mask_((1U << upper_neighbor_mask_bits_) - 1), filter_same_key_(filter_same_key), get_vector_enabled_(get_vector_enabled), sparse_chunk_size_(sparse_chunk_size), upper_neighbor_index_(upper_neighbor_index), keys_map_lock_(keys_map_lock), keys_map_(keys_map), node_chunks_(std::move(node_chunks)), upper_neighbor_chunks_(std::move(upper_neighbor_chunks)), 
sparse_node_chunks_(std::move(sparse_node_chunks)), broker_(broker) { *mutable_header() = hd; neighbor_size_ = neighbors_size(); upper_neighbor_size_ = upper_neighbors_size(); } //! Called only in searching procedure per context, so no need to lock void sync_chunks(SparseChunkBroker::CHUNK_TYPE type, size_t idx, std::vector *chunks) const { if (ailego_likely(idx < chunks->size())) { return; } for (size_t i = chunks->size(); i <= idx; ++i) { auto chunk = broker_->get_chunk(type, i); // the storage can ensure get chunk will success after the first get ailego_assert_with(!!chunk, "get chunk failed"); chunks->emplace_back(std::move(chunk)); } } //! return pair: chunk index + chunk offset inline std::pair get_vector_chunk_loc( node_id_t id) const { uint32_t chunk_idx = id >> node_index_mask_bits_; uint32_t offset = (id & node_index_mask_) * node_size(); sync_chunks(SparseChunkBroker::CHUNK_TYPE_NODE, chunk_idx, &node_chunks_); return std::make_pair(chunk_idx, offset); } //! return pair: chunk index + chunk offset inline std::pair get_key_chunk_loc(node_id_t id) const { uint32_t chunk_idx = id >> node_index_mask_bits_; uint32_t offset = (id & node_index_mask_) * node_size() + vector_size(); offset += sparse_meta_size(); sync_chunks(SparseChunkBroker::CHUNK_TYPE_NODE, chunk_idx, &node_chunks_); return std::make_pair(chunk_idx, offset); } //! 
return pair: chunk index + chunk offset inline std::pair get_sparse_chunk_loc( uint32_t chunk_index, uint32_t chunk_offset) const { sync_chunks(SparseChunkBroker::CHUNK_TYPE_SPARSE_NODE, chunk_index, &sparse_node_chunks_); return std::make_pair(chunk_index, chunk_offset); } inline std::pair get_upper_neighbor_chunk_loc( level_t level, node_id_t id) const { auto it = upper_neighbor_index_->find(id); ailego_assert_abort(it != upper_neighbor_index_->end(), "Get upper neighbor header failed"); auto meta = reinterpret_cast(&it->second); uint32_t chunk_idx = (meta->index) >> upper_neighbor_mask_bits_; uint32_t offset = (((meta->index) & upper_neighbor_mask_) + level - 1) * upper_neighbor_size_; sync_chunks(SparseChunkBroker::CHUNK_TYPE_UPPER_NEIGHBOR, chunk_idx, &upper_neighbor_chunks_); ailego_assert_abort(chunk_idx < upper_neighbor_chunks_.size(), "invalid chunk idx"); ailego_assert_abort(offset < upper_neighbor_chunks_[chunk_idx]->data_size(), "invalid chunk offset"); return std::make_pair(chunk_idx, offset); } //! return pair: chunk + chunk offset inline std::pair get_neighbor_chunk_loc( level_t level, node_id_t id) const { if (level == 0UL) { uint32_t chunk_idx = id >> node_index_mask_bits_; uint32_t offset = (id & node_index_mask_) * node_size() + vector_size() + sizeof(key_t); offset += sparse_meta_size(); sync_chunks(SparseChunkBroker::CHUNK_TYPE_NODE, chunk_idx, &node_chunks_); ailego_assert_abort(chunk_idx < node_chunks_.size(), "invalid chunk idx"); ailego_assert_abort(offset < node_chunks_[chunk_idx]->data_size(), "invalid chunk offset"); return std::make_pair(node_chunks_[chunk_idx].get(), offset); } else { auto p = get_upper_neighbor_chunk_loc(level, id); return std::make_pair(upper_neighbor_chunks_[p.first].get(), p.second); } } //! Chunk hnsw index valid int check_hnsw_index(const HNSWSparseHeader *hd) const; size_t get_total_upper_neighbors_size(level_t level) const { return level * upper_neighbor_size_; } //! 
Add upper neighbor header and reserve space for upper neighbor int add_upper_neighbor(level_t level, node_id_t id) { if (level == 0) { return 0; } SparseChunk::Pointer chunk; uint64_t chunk_offset = -1UL; size_t neighbors_size = get_total_upper_neighbors_size(level); uint64_t chunk_index = upper_neighbor_chunks_.size() - 1UL; if (chunk_index == -1UL || (upper_neighbor_chunks_[chunk_index]->padding_size() < neighbors_size)) { // no space left and need to alloc chunk_index++; if (ailego_unlikely(upper_neighbor_chunks_.capacity() == upper_neighbor_chunks_.size())) { LOG_ERROR("add upper neighbor failed for no memory quota"); return IndexError_IndexFull; } auto p = broker_->alloc_chunk(SparseChunkBroker::CHUNK_TYPE_UPPER_NEIGHBOR, chunk_index, upper_neighbor_chunk_size_); if (ailego_unlikely(p.first != 0)) { LOG_ERROR("Alloc data chunk failed"); return p.first; } chunk = p.second; chunk_offset = 0UL; upper_neighbor_chunks_.emplace_back(chunk); } else { chunk = upper_neighbor_chunks_[chunk_index]; chunk_offset = chunk->data_size(); } ailego_assert_with((size_t)level < kMaxGraphLayers, "invalid level"); ailego_assert_with(chunk_offset % upper_neighbor_size_ == 0, "invalid offset"); ailego_assert_with((chunk_offset / upper_neighbor_size_) < (1U << upper_neighbor_mask_bits_), "invalid offset"); ailego_assert_with(chunk_index < (1U << (28 - upper_neighbor_mask_bits_)), "invalid chunk index"); UpperNeighborIndexMeta meta; meta.level = level; meta.index = (chunk_index << upper_neighbor_mask_bits_) | (chunk_offset / upper_neighbor_size_); chunk_offset += upper_neighbor_size_ * level; if (ailego_unlikely(!upper_neighbor_index_->insert(id, meta.data))) { LOG_ERROR("HashMap insert value failed"); return IndexError_Runtime; } if (ailego_unlikely(chunk->resize(chunk_offset) != chunk_offset)) { LOG_ERROR("SparseChunk resize to %zu failed", (size_t)chunk_offset); return IndexError_Runtime; } return 0; } size_t estimate_doc_capacity() const { return node_chunks_.capacity() * 
node_cnt_per_chunk_; } int init_chunk_params(size_t max_index_size) { sparse_chunk_size_ = AlignPageSize(chunk_size_); node_cnt_per_chunk_ = std::max(1, chunk_size_ / node_size()); //! align node cnt per chunk to pow of 2 node_index_mask_bits_ = std::ceil(std::log2(node_cnt_per_chunk_)); node_cnt_per_chunk_ = 1UL << node_index_mask_bits_; chunk_size_ = AlignPageSize(node_cnt_per_chunk_ * node_size()); node_index_mask_ = node_cnt_per_chunk_ - 1; if (max_index_size == 0UL) { max_index_size_ = chunk_size_ * kDefaultMaxChunkCnt; } else { max_index_size_ = max_index_size; } //! To get a balanced upper neighbor chunk size. //! If the upper chunk size is equal to node chunk size, it may waste //! upper neighbor chunk space; if the upper neighbor chunk size is too //! small, the will need large upper neighbor chunks index space. So to //! get a balanced ratio be sqrt of the node/neighbor size ratio float ratio = std::sqrt(node_size() * scaling_factor() * 1.0f / upper_neighbor_size_); upper_neighbor_chunk_size_ = AlignPageSize(std::max(get_total_upper_neighbors_size(kMaxGraphLayers), static_cast(chunk_size_ / ratio))); upper_neighbor_mask_bits_ = std::ceil(std::log2(upper_neighbor_chunk_size_ / upper_neighbor_size_)); upper_neighbor_mask_ = (1 << upper_neighbor_mask_bits_) - 1; size_t max_node_chunk_cnt = std::ceil(max_index_size_ / chunk_size_); size_t max_upper_chunk_cnt = std::ceil( (max_node_chunk_cnt * node_cnt_per_chunk_ * 1.0f / scaling_factor()) / (upper_neighbor_chunk_size_ / upper_neighbor_size_)); max_upper_chunk_cnt = max_upper_chunk_cnt + std::ceil(max_upper_chunk_cnt / scaling_factor()); //! reserve space to avoid memmove in chunks vector emplace chunk, so //! 
as to lock-free in reading chunk node_chunks_.reserve(max_node_chunk_cnt); sparse_node_chunks_.reserve(max_node_chunk_cnt); upper_neighbor_chunks_.reserve(max_upper_chunk_cnt); LOG_DEBUG( "Settings: nodeSize=%zu chunkSize=%u upperNeighborSize=%u " "upperNeighborChunkSize=%u " "nodeCntPerChunk=%u maxChunkCnt=%zu maxNeighborChunkCnt=%zu " "maxIndexSize=%zu ratio=%.3f", node_size(), chunk_size_, upper_neighbor_size_, upper_neighbor_chunk_size_, node_cnt_per_chunk_, max_node_chunk_cnt, max_upper_chunk_cnt, max_index_size_, ratio); return 0; } //! Init node chunk and neighbor chunks int init_chunks(const SparseChunk::Pointer &header_chunk); int flush_header(void) { if (!broker_->dirty()) { // do not need to flush return 0; } auto header_chunk = broker_->get_chunk(SparseChunkBroker::CHUNK_TYPE_HEADER, SparseChunkBroker::kDefaultChunkSeqId); if (ailego_unlikely(!header_chunk)) { LOG_ERROR("get header chunk failed"); return IndexError_Runtime; } size_t size = header_chunk->write(0UL, &header(), header_size()); if (ailego_unlikely(size != header_size())) { LOG_ERROR("Write header chunk failed"); return IndexError_WriteData; } return 0; } private: HnswSparseStreamerEntity(const HnswSparseStreamerEntity &) = delete; HnswSparseStreamerEntity &operator=(const HnswSparseStreamerEntity &) = delete; static constexpr uint64_t kUpperHashMemoryInflateRatio = 2.0f; private: IndexStreamer::Stats &stats_; HNSWSparseHeader header_{}; std::mutex mutex_{}; size_t max_index_size_{0UL}; uint32_t chunk_size_{kDefaultChunkSize}; uint32_t upper_neighbor_chunk_size_{kDefaultChunkSize}; uint32_t node_index_mask_bits_{0U}; uint32_t node_cnt_per_chunk_{0U}; uint32_t node_index_mask_{0U}; uint32_t neighbor_size_{0U}; uint32_t upper_neighbor_size_{0U}; //! UpperNeighborIndex.index composite chunkIdx and offset in chunk by the //! 
following mask uint32_t upper_neighbor_mask_bits_{0U}; uint32_t upper_neighbor_mask_{0U}; bool filter_same_key_{false}; bool get_vector_enabled_{false}; uint32_t sparse_chunk_size_{kDefaultChunkSize}; NIHashMapPointer upper_neighbor_index_{}; mutable std::shared_ptr keys_map_lock_{}; HashMapPointer keys_map_{}; //! the chunks will be changed in searcher, so need mutable //! data chunk include: vector, key, level 0 neighbors mutable std::vector node_chunks_{}; //! upper neighbor chunk inlude: UpperNeighborHeader + (1~level) neighbors mutable std::vector upper_neighbor_chunks_{}; //! chunk that holds up sparse part mutable std::vector sparse_node_chunks_{}; SparseChunkBroker::Pointer broker_{}; // chunk broker }; } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/ivf/CMakeLists.txt ================================================ include(${PROJECT_ROOT_DIR}/cmake/bazel.cmake) include(${PROJECT_ROOT_DIR}/cmake/option.cmake) cc_library( NAME core_knn_ivf STATIC SHARED STRICT ALWAYS_LINK SRCS *.cc LIBS zvec_ailego core_framework core_knn_cluster INCS . ${PROJECT_ROOT_DIR}/src/core ${PROJECT_ROOT_DIR}/src/core/algorithm VERSION "${PROXIMA_ZVEC_VERSION}" ) ================================================ FILE: src/core/algorithm/ivf/ivf_builder.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "ivf_builder.h" #include #include #include "algorithm/cluster/cluster_params.h" #include "ivf_dumper.h" namespace zvec { namespace core { /*! IndexHolder support filtered by vector labels */ class LabelFilteredIndexHolder : public IndexHolder { public: /*! Index Holder Iterator */ class Iterator : public IndexHolder::Iterator { public: //! Index Holder Iterator Pointer typedef std::unique_ptr Pointer; //! Constructor Iterator(const IVFBuilder::RandomAccessIndexHolder::Pointer &holder, const std::vector *elems) : holder_(holder), elems_(elems) {} //! Destructor virtual ~Iterator(void) {} //! Retrieve pointer of data virtual const void *data(void) const override { return holder_->element((*elems_)[index_]); } //! Test if the iterator is valid virtual bool is_valid(void) const override { return index_ < elems_->size(); } //! Retrieve primary key virtual uint64_t key(void) const override { return (*elems_)[index_]; } //! Next iterator virtual void next(void) override { ++index_; } private: //! Members const IVFBuilder::RandomAccessIndexHolder::Pointer holder_{nullptr}; const std::vector *elems_{nullptr}; size_t index_{0}; }; //! Constructor LabelFilteredIndexHolder( const IVFBuilder::RandomAccessIndexHolder::Pointer &holder, const std::vector &items) : holder_(holder), elems_(&items) {} //! Retrieve count of elements in holder (-1 indicates unknown) virtual size_t count(void) const override { return elems_->size(); } //! Retrieve dimension virtual size_t dimension(void) const override { return holder_->dimension(); } //! Retrieve type information virtual IndexMeta::DataType data_type(void) const override { return holder_->data_type(); } //! Retrieve element size in bytes virtual size_t element_size(void) const override { return holder_->element_size(); } //! Retrieve if it can multi-pass virtual bool multipass(void) const override { return true; } //! 
Create a new iterator virtual IndexHolder::Iterator::Pointer create_iterator(void) override { return IndexHolder::Iterator::Pointer( new LabelFilteredIndexHolder::Iterator(holder_, elems_)); } private: //! Members const IVFBuilder::RandomAccessIndexHolder::Pointer holder_{}; const std::vector *elems_{}; }; IVFBuilder::IVFBuilder() {} IVFBuilder::~IVFBuilder() { this->cleanup(); } int IVFBuilder::init(const IndexMeta &meta, const ailego::Params ¶ms) { LOG_INFO("Begin IVFBuilder::init!"); if (state_ != INIT) { LOG_ERROR("IVFBuilder state wrong. state=%d", state_); return IndexError_Logic; } meta_ = meta; converted_meta_ = meta; quantized_meta_ = meta; // Clear the converter/reformer params for external transforms converted_meta_.set_reformer(std::string(), 0, ailego::Params()); converted_meta_.set_converter(std::string(), 0, ailego::Params()); quantized_meta_.set_reformer(std::string(), 0, ailego::Params()); quantized_meta_.set_converter(std::string(), 0, ailego::Params()); params_ = params; if (!IndexFactory::HasMetric(meta_.metric_name())) { LOG_ERROR("Metric %s not exist", meta_.metric_name().c_str()); return IndexError_NoExist; } int ret = parse_centroids_num(params); ivf_check_with_msg(ret, "Failed to parse centroids, ret=%d", ret); ret = parse_clustering_params(params); ivf_check_with_msg(ret, "Failed to parse clustering params, ret=%d", ret); ret = parse_general_params(params); ivf_check_with_msg(ret, "Failed to parse general params, ret=%d", ret); LOG_INFO("End IVFBuilder::init!"); LOG_DEBUG( "Converter=%s Quantizer=%s Optimizer=%s " "OptimizerQuantizer=%s QuantizeByCentroid=%u StoreFeatures=%u " "ClusterClass=%s TrainSamplesCount=%u TrainSampleRatio=%f " "BlockVectorCount=%u", params.get_as_string(PARAM_IVF_BUILDER_CONVERTER_CLASS).c_str(), params.get_as_string(PARAM_IVF_BUILDER_QUANTIZER_CLASS).c_str(), params.get_as_string(PARAM_IVF_BUILDER_OPTIMIZER_CLASS).c_str(), params.get_as_string(PARAM_IVF_BUILDER_OPTIMIZER_QUANTIZER_CLASS).c_str(), 
params.get_as_bool(PARAM_IVF_BUILDER_QUANTIZE_BY_CENTROID), params.get_as_bool(PARAM_IVF_BUILDER_STORE_ORIGINAL_FEATURES), params.get_as_string(PARAM_IVF_BUILDER_CLUSTER_CLASS).c_str(), params.get_as_uint32(PARAM_IVF_BUILDER_TRAIN_SAMPLE_COUNT), params.get_as_float(PARAM_IVF_BUILDER_TRAIN_SAMPLE_RATIO), block_vector_count_); state_ = INITED; return 0; } int IVFBuilder::cleanup(void) { LOG_INFO("Begin IVFBuilder::cleanup"); state_ = INIT; stats_.clear_attributes(); stats_.set_built_costtime(0u); stats_.set_built_count(0u); stats_.set_discarded_count(0u); stats_.set_dumped_costtime(0u); stats_.set_dumped_count(0u); stats_.set_trained_costtime(0u); stats_.set_trained_count(0u); centroid_num_vec_.clear(); cluster_class_.clear(); converter_class_.clear(); cluster_params_.clear(); labels_.clear(); centroid_index_.reset(); holder_.reset(); converted_meta_ = meta_; converter_.reset(); quantized_meta_ = meta_; quantizers_.clear(); error_ = false; err_code_ = 0; thread_count_ = 0; sample_count_ = 0; cluster_auto_tuning_ = false; store_original_features_ = false; quantize_by_centroid_ = false; LOG_INFO("End IVFBuilder::cleanup"); return 0; } int IVFBuilder::train(IndexThreads::Pointer threads, IndexHolder::Pointer holder) { LOG_INFO("Begin IVFBuilder::train with holder"); if (state_ != INITED) { LOG_ERROR("IVFBuilder train failed, wrong state=%d", state_); return IndexError_Runtime; } if (!threads) { threads = std::make_shared(thread_count_, false); if (!threads) { return IndexError_NoMemory; } } ailego::ElapsedTime timer; if (!holder || holder->count() == 0) { LOG_ERROR("Input holder is nullptr or empty while train index"); return IndexError_InvalidArgument; } if (!holder->is_matched(meta_)) { LOG_ERROR("Input holder doesn't match index meta while train index"); return IndexError_Mismatch; } if (converter_) { int ret = IndexConverter::TrainAndTransform(converter_, std::move(holder)); ivf_check_with_msg(ret, "Failed to train or transform by converter %s", 
converter_->name().c_str()); converted_meta_ = converter_->meta(); holder = converter_->result(); } ailego::Params train_params; int ret = prepare_trainer_params(train_params); ivf_check_with_msg(ret, "Failed to prepare trainer params, ret=%d", ret); IndexTrainer::Pointer trainer = IndexFactory::CreateTrainer("StratifiedClusterTrainer"); ivf_assert_with_msg(trainer, IndexError_NoExist, "Failed to create trainer"); ret = trainer->init(converted_meta_, train_params); ivf_check_with_msg(ret, "Trainer init failed with ret %d", ret); ret = trainer->train(std::move(threads), std::move(holder)); ivf_check_with_msg(ret, "Trainer train failed with ret %d", ret); ret = this->train(trainer); ivf_check_error_code(ret); stats_.set_trained_costtime(timer.milli_seconds()); LOG_INFO("End IVFBuilder::train with holder"); state_ = TRAINED; return 0; } int IVFBuilder::train(const IndexTrainer::Pointer &trainer) { LOG_DEBUG("Begin IVFBuilder::train by trainer"); ailego::ElapsedTime timer; if (state_ != INITED) { LOG_ERROR("IVFBuilder train failed, wrong state=%d", state_); return IndexError_Runtime; } if (!trainer) { LOG_ERROR("Input trainer is nullptr while train index"); return IndexError_InvalidArgument; } IndexCluster::CentroidList centroid_list; IndexBundle::Pointer boundle = trainer->indexes(); int ret = IndexCluster::Deserialize(trainer->meta(), boundle, ¢roid_list); ivf_check_with_msg(ret, "Failed to deserialize index"); const IndexMeta &meta = trainer->meta(); if (meta.data_type() != converted_meta_.data_type() || meta.metric_name().compare(converted_meta_.metric_name()) != 0 || meta.element_size() != converted_meta_.element_size()) { if (meta.converter_name() != converter_class_) { LOG_ERROR("Input trainer doesn't match index meta while train index"); return IndexError_Mismatch; } //! 
Create converter from trainer params LOG_INFO("Train IVFBuilder by trainer with converter"); converter_ = CreateAndInitConverter(meta_, meta.converter_name(), meta.converter_params()); ivf_assert(converter_, IndexError_Runtime); converted_meta_ = meta; } centroid_index_ = std::make_shared(); if (!centroid_index_) { return IndexError_NoMemory; } ret = centroid_index_->init(converted_meta_, params_); ivf_check_error_code(ret); ret = centroid_index_->build(centroid_list); ivf_check_with_msg(ret, "Failed to build centroid index"); if (params_.has(PARAM_IVF_BUILDER_OPTIMIZER_QUANTIZER_CLASS)) { //! Quantize the centroids for searcher searcher_centroid_index_ = std::make_shared(); if (!searcher_centroid_index_) { return IndexError_NoMemory; } ailego::Params params; params_.get(PARAM_IVF_BUILDER_OPTIMIZER_QUANTIZER_PARAMS, ¶ms); searcher_centroid_index_->set_quantizer( params_.get_as_string(PARAM_IVF_BUILDER_OPTIMIZER_QUANTIZER_CLASS), params); ret = searcher_centroid_index_->init(converted_meta_, params_); ivf_check_error_code(ret); ret = searcher_centroid_index_->build(centroid_list); ivf_check_with_msg(ret, "Failed to build centroid index"); } stats_.set_trained_costtime(timer.milli_seconds()); LOG_DEBUG("End IVFBuilder::train by trainer"); state_ = TRAINED; return 0; } int IVFBuilder::build(IndexThreads::Pointer threads, IndexHolder::Pointer holder) { LOG_INFO("Begin IVFBuilder::build!"); if (state_ != TRAINED) { LOG_ERROR("Train the index first before build"); return IndexError_Runtime; } ailego::ElapsedTime timer; if (!holder || holder->count() == 0) { LOG_ERROR("Input holder is nullptr or empty while building index"); return IndexError_InvalidArgument; } if (!holder->is_matched(meta_)) { LOG_ERROR("Input holder doesn't match index meta while building index"); return IndexError_Mismatch; } if (!threads) { threads = std::make_shared(thread_count_, false); if (!threads) { return IndexError_NoMemory; } } holder_ = std::make_shared(meta_); if (!holder_) { return 
IndexError_NoMemory; } if (holder->count() > 0) { holder_->reserve(holder->count()); } for (auto iter = holder->create_iterator(); iter && iter->is_valid(); iter->next()) { holder_->emplace(iter->key(), iter->data()); } // Holder is not needed, cleanup it. holder.reset(); IndexHolder::Pointer converted_holder = holder_; if (converter_) { int ret = converter_->transform(holder_); ivf_check_with_msg(ret, "Failed to transform by converter %s", converter_->name().c_str()); converted_holder = converter_->result(); } labels_.resize(centroid_index_->centroids_count()); int ret = this->build_label_index(threads.get(), converted_holder); ivf_check_with_msg(ret, "Failed to build index for %s", IndexError::What(ret)); ret = this->prepare_quantizer(threads.get()); ivf_check_error_code(ret); stats_.set_built_costtime(timer.milli_seconds()); LOG_INFO("End IVFBuilder::build"); state_ = BUILT; return 0; } int IVFBuilder::dump(const IndexDumper::Pointer &dumper) { LOG_INFO("Begin IVFBuilder::dump"); if (state_ != BUILT) { LOG_ERROR("Build the index before dump QC Index"); return IndexError_Runtime; } ailego::ElapsedTime timer; int ret = this->dump_index(dumper); ivf_check_with_msg(ret, "Failed to dump index with ret=%d", ret); // the fitting function for the follow points: 1000000(0.02) 10000000(0.01) // 50000000(0.005) 100000000(0.001) float scan_ratio = -0.004 * std::log(holder_->count()) + 0.0751; scan_ratio = std::max(scan_ratio, 0.0001f); // Set Searcher Params ailego::Params params; params.set(PARAM_IVF_SEARCHER_SCAN_RATIO, scan_ratio); meta_.set_searcher("IVFSearcher", 0, std::move(params)); meta_.set_builder("IVFBuilder", 0, std::move(params_)); ret = IndexHelper::SerializeToDumper(meta_, dumper.get()); if (ret != 0) { LOG_ERROR("Failed to serialize meta into dumper."); return ret; } stats_.set_discarded_count(stats_.built_count() - stats_.dumped_count()); stats_.set_dumped_costtime(timer.milli_seconds()); LOG_INFO("End IVFBuilder::dump"); return 0; } int 
IVFBuilder::CheckAndUpdateMajorOrder(IndexMeta &meta) { const std::string &metric_name = meta.metric_name(); auto metric = IndexFactory::CreateMetric(metric_name); if (!metric) { LOG_ERROR("CreateMetric %s failed", metric_name.c_str()); return IndexError_InvalidArgument; } int ret = metric->init(meta, meta.metric_params()); ivf_check_with_msg(ret, "IndexMetric %s init failed", metric_name.c_str()); bool support_column_major = true; for (size_t m = 32; m != 0; m /= 2) { for (size_t n = m; n != 0; n /= 2) { if (metric->distance_matrix(m, n) == nullptr) { support_column_major = false; break; } } if (!support_column_major) { break; } } support_column_major &= meta.element_size() % IndexMeta::AlignSizeof(meta.data_type()) == 0; if (meta.major_order() == IndexMeta::MO_UNDEFINED) { if (support_column_major && meta.dimension() <= 512) { meta.set_major_order(IndexMeta::MO_COLUMN); } else { meta.set_major_order(IndexMeta::MO_ROW); } } else { if (!support_column_major && meta.major_order() == IndexMeta::MO_COLUMN) { LOG_WARN( "Index Metric %s Unsupported " "Column Major Order", metric_name.c_str()); return IndexError_Unsupported; } } if (block_vector_count_ * quantized_meta_.element_size() % 32 != 0) { LOG_ERROR( "block_vector_count * quantized_element_size not align with 32 bytes."); return IndexError_InvalidArgument; } return 0; } int IVFBuilder::parse_centroids_num(const ailego::Params ¶ms) { std::string centroids_num = params.get_as_string(PARAM_IVF_BUILDER_CENTROID_COUNT); if (centroids_num.empty()) { LOG_ERROR("Param %s is required", PARAM_IVF_BUILDER_CENTROID_COUNT.c_str()); return IndexError_InvalidArgument; } std::vector centroid_str_vec; ailego::StringHelper::Split(centroids_num, CENTROID_SEPERATOR, ¢roid_str_vec); size_t level_cnt = centroid_str_vec.size(); if ((level_cnt <= 0) || (level_cnt > 2)) { LOG_ERROR("Centroids level count must be [1,2]"); return IndexError_InvalidArgument; } for (size_t idx = 0; idx < level_cnt; ++idx) { uint32_t centroid_cnt = 0; if 
(!ailego::StringHelper::ToUint32(centroid_str_vec[idx], ¢roid_cnt)) { LOG_ERROR("Invalid centroids count %s", centroid_str_vec[idx].c_str()); return IndexError_InvalidArgument; } centroid_num_vec_.push_back(centroid_cnt); } return 0; } int IVFBuilder::parse_clustering_params(const ailego::Params ¶ms) { params.get(PARAM_IVF_BUILDER_CLUSTER_AUTO_TUNING, &cluster_auto_tuning_); cluster_class_ = params.get_as_string(PARAM_IVF_BUILDER_CLUSTER_CLASS); if (cluster_class_.empty()) { // OptKmeansCluster does not support custom metric cluster_class_ = meta_.metric_name() == kMipsMetricName ? "KmeansCluster" : "OptKmeansCluster"; LOG_INFO("Using [%s] as default cluster class", cluster_class_.c_str()); } for (size_t i = 1; i <= centroid_num_vec_.size(); ++i) { std::string level_params_key = PARAM_IVF_BUILDER_CLUSTER_PARAMS_IN_LEVEL_PREFIX + std::to_string(i); ailego::Params level_params; params.get(level_params_key, &level_params); cluster_params_.push_back(level_params); } return 0; } int IVFBuilder::parse_general_params(const ailego::Params ¶ms) { thread_count_ = params.get_as_uint32(PARAM_IVF_BUILDER_THREAD_COUNT); sample_count_ = params.get_as_uint32(PARAM_IVF_BUILDER_TRAIN_SAMPLE_COUNT); sample_ratio_ = params.get_as_float(PARAM_IVF_BUILDER_TRAIN_SAMPLE_RATIO); params.get(PARAM_IVF_BUILDER_QUANTIZE_BY_CENTROID, &quantize_by_centroid_); params.get(PARAM_IVF_BUILDER_STORE_ORIGINAL_FEATURES, &store_original_features_); //! 
Prepare Converter for training if (meta_.metric_name() == kIPMetricName) { converter_class_ = kMipsConverterName; } params.get(PARAM_IVF_BUILDER_CONVERTER_CLASS, &converter_class_); if (!converter_class_.empty()) { ailego::Params converter_params; params_.get(PARAM_IVF_BUILDER_CONVERTER_PARAMS, &converter_params); converter_ = CreateAndInitConverter(meta_, converter_class_, converter_params); ivf_assert(converter_, IndexError_NoExist); } params_.get(PARAM_IVF_BUILDER_BLOCK_VECTOR_COUNT, &block_vector_count_); if (block_vector_count_ == 0) { block_vector_count_ = kDefaultBlockCount; } if (block_vector_count_ > kDefaultBlockCount || block_vector_count_ & (block_vector_count_ - 1)) { LOG_ERROR("block_vector_count only can be [1|2|4|8|16|32]."); return IndexError_InvalidArgument; } if (block_vector_count_ * meta_.element_size() % 32 != 0) { LOG_ERROR("block_vector_count * element_size not align with 32 bytes."); return IndexError_InvalidArgument; } return 0; } int IVFBuilder::prepare_trainer_params(ailego::Params ¶ms) { params.set(STRATIFIED_TRAINER_SAMPLE_COUNT, sample_count_); params.set(STRATIFIED_TRAINER_SAMPLE_RATIO, sample_ratio_); params.set(STRATIFIED_TRAINER_THREAD_COUNT, thread_count_); params.set(STRATIFIED_TRAINER_AUTOAUNE, cluster_auto_tuning_); if (centroid_num_vec_.empty()) { LOG_ERROR("Centroids no specified."); return IndexError_InvalidArgument; } std::string cluster_count = std::to_string(centroid_num_vec_[0]); if (centroid_num_vec_.size() > 1) { cluster_count += (CENTROID_SEPERATOR + std::to_string(centroid_num_vec_[1])); } params.set(STRATIFIED_TRAINER_CLUSTER_COUNT, cluster_count); for (size_t i = 1; i <= cluster_params_.size(); ++i) { std::string level_params_key = STRATIFIED_TRAINER_PARAMS_IN_LEVEL_PREFIX + std::to_string(i); params.set(level_params_key, cluster_params_[i - 1]); } params.set(STRATIFIED_TRAINER_CLASS_NAME, cluster_class_); return 0; } int IVFBuilder::build_label_index(IndexThreads *threads, const IndexHolder::Pointer &holder) { 
auto iter = holder->create_iterator(); if (!iter) { LOG_ERROR("Create iterator for holder failed"); return IndexError_Runtime; } auto task_group = threads->make_group(); if (!task_group) { LOG_ERROR("Failed to create task group"); return IndexError_Runtime; } size_t id = 0UL; AILEGO_DEFER([&]() { task_group->wait_finish(); stats_.set_built_count(id); LOG_INFO("Finished building, total=%zu", id); }); size_t elem_size = holder->element_size(); std::shared_ptr vectors = std::make_shared(); ivf_assert(vectors, IndexError_NoMemory); for (; iter && iter->is_valid(); iter->next()) { ivf_assert(!error_, err_code_); vectors->emplace_back(iter->data(), elem_size, id); id++; if (vectors->size() == kBatchSize || id == holder_->count()) { auto task = ailego::Closure ::New(const_cast(this), &IVFBuilder::label, vectors); task_group->submit(std::move(task)); vectors = std::make_shared(); ivf_assert(vectors, IndexError_NoMemory); vectors->reserve(kBatchSize); } if (!(id & 0xFFFFF)) { LOG_INFO("Current built count:%zu", id); } } ailego_assert_with(vectors->size() == 0, "invalid size"); return err_code_; } int IVFBuilder::dump_index(const IndexDumper::Pointer &dumper) { int ret = CheckAndUpdateMajorOrder(quantized_meta_); ivf_check_error_code(ret); IVFDumper::Pointer ivf_dumper = std::make_shared( quantized_meta_, dumper, centroid_index_->centroids_count(), block_vector_count_); if (!ivf_dumper) { LOG_ERROR("Alloc IVFDumper failed"); return IndexError_NoMemory; } //! Dump inverted vectors std::vector dumped_ids; std::function record_dumped_id = [&](uint32_t) {}; if (store_original_features_) { dumped_ids.reserve(holder_->count()); record_dumped_id = [&](uint32_t id) { dumped_ids.emplace_back(id); }; } if (quantizers_.size() == 0) { //! 
No quantizer for inverted vectors for (size_t i = 0; i < centroid_index_->centroids_count(); ++i) { ailego_assert_with(i < labels_.size(), "Index Overflow"); for (size_t j = 0; j < labels_[i].size(); ++j) { auto id = labels_[i][j]; record_dumped_id(id); ret = ivf_dumper->dump_inverted_vector(i, holder_->key(id), holder_->element(id)); ivf_check_error_code(ret); } } } else { for (size_t i = 0; i < centroid_index_->centroids_count(); ++i) { ailego_assert_with(i < labels_.size(), "Index Overflow"); auto holder = std::make_shared(holder_, labels_[i]); if (!holder) { return IndexError_NoMemory; } auto quantizer = quantize_by_centroid_ ? quantizers_[i] : quantizers_[0]; ret = quantizer->transform(holder); ivf_check_error_code(ret); auto iter = quantizer->result()->create_iterator(); for (; iter->is_valid(); iter->next()) { uint32_t id = iter->key(); record_dumped_id(id); ret = ivf_dumper->dump_inverted_vector(i, holder_->key(id), iter->data()); ivf_check_error_code(ret); } } } ret = ivf_dumper->dump_inverted_vector_finished(); ivf_check_error_code(ret); ret = ivf_dumper->dump_quantizer_params(quantizers_); ivf_check_error_code(ret); auto centroid_index = searcher_centroid_index_ ? searcher_centroid_index_ : centroid_index_; ret = ivf_dumper->dump_centroid_index(centroid_index->data(), centroid_index->size()); ivf_check_with_msg(ret, "Failed to dump CentroidIndex"); if (store_original_features_) { for (size_t i = 0; i < dumped_ids.size(); ++i) { ret = ivf_dumper->dump_original_vector(holder_->element(dumped_ids[i]), holder_->element_size()); ivf_check_error_code(ret); } } stats_.set_dumped_count(stats_.dumped_count() + ivf_dumper->dumped_count()); return 0; } int IVFBuilder::prepare_quantizer(IndexThreads *threads) { std::string quantizer_name; params_.get(PARAM_IVF_BUILDER_QUANTIZER_CLASS, &quantizer_name); if (quantizer_name.empty()) { return 0; } //! 
Prepare Quantizers for inverted index ailego::Params quantizer_params; params_.get(PARAM_IVF_BUILDER_QUANTIZER_PARAMS, &quantizer_params); if (((quantizer_name != kInt8QuantizerName && quantizer_name != kInt4QuantizerName) || meta_.metric_name() != kIPMetricName) && quantize_by_centroid_) { LOG_WARN("%s is supported in InnerProduct only", PARAM_IVF_BUILDER_QUANTIZE_BY_CENTROID.c_str()); quantize_by_centroid_ = false; } if (quantizer_name == kInt4QuantizerName && meta_.dimension() & 0x1) { LOG_ERROR("Unsupport quantizer=%s for dim=%u", kInt4QuantizerName, meta_.dimension()); return IndexError_Unsupported; } int ret = 0; auto create_and_init_quantizer = [&]() { auto quantizer = IndexFactory::CreateConverter(quantizer_name); if (!quantizer) { LOG_ERROR("Failed to create converter %s", quantizer_name.c_str()); ret = IndexError_NoExist; return IndexConverter::Pointer(); } ret = quantizer->init(meta_, quantizer_params); if (ret != 0) { LOG_ERROR("Failed to initialize converter %s for %s", quantizer_name.c_str(), IndexError::What(ret)); return IndexConverter::Pointer(); } return quantizer; }; for (size_t i = 0; i < centroid_index_->centroids_count(); ++i) { quantizers_.emplace_back(create_and_init_quantizer()); ivf_check_error_code(ret); if (!quantize_by_centroid_) { break; } } //! 
Train the quantizers auto train_data = [&](size_t i) { IndexHolder::Pointer holder = holder_; size_t idx = 0; if (quantize_by_centroid_) { holder = std::make_shared(holder_, labels_[i]); if (!holder && !error_.exchange(true)) { err_code_ = IndexError_NoMemory; return; } idx = i; } if (holder->count() == 0) { return; } ret = quantizers_[idx]->train(holder); if (ret != 0) { LOG_ERROR("Failed to train converter %s for %s", quantizer_name.c_str(), IndexError::What(ret)); if (!error_.exchange(true)) { err_code_ = IndexError_Runtime; } } }; auto task_group = threads->make_group(); if (!task_group) { LOG_ERROR("Failed to create task group"); return IndexError_Runtime; } for (size_t i = 0; i < quantizers_.size(); ++i) { if (error_) { task_group->wait_finish(); return err_code_; } task_group->submit(ailego::Closure ::New(train_data, i)); } task_group->wait_finish(); if (quantizers_.size() > 0) { quantized_meta_ = quantizers_[0]->meta(); } return 0; } INDEX_FACTORY_REGISTER_BUILDER(IVFBuilder); } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/ivf/ivf_builder.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include "ivf_centroid_index.h" namespace zvec { namespace core { /*! IVF Builder */ class IVFBuilder : public IndexBuilder { public: //! Constructor IVFBuilder(); //! 
Destructor ~IVFBuilder(); //! Disable them IVFBuilder(const IVFBuilder &) = delete; IVFBuilder &operator=(const IVFBuilder &) = delete; public: //! Initialize the builder virtual int init(const IndexMeta &meta, const ailego::Params ¶ms) override; //! Cleanup the builder virtual int cleanup(void) override; //! Train the data virtual int train(IndexThreads::Pointer threads, IndexHolder::Pointer holder) override; //! Train the data virtual int train(const IndexTrainer::Pointer &trainer) override; //! Build the index virtual int build(IndexThreads::Pointer threads, IndexHolder::Pointer holder) override; //! Dump index into file system virtual int dump(const IndexDumper::Pointer &dumper) override; //! Retrieve statistics virtual const Stats &stats(void) const override { return stats_; } IVFCentroidIndex::Pointer centroid_index() const { return centroid_index_; } public: /*! Random Access Index Holder */ class RandomAccessIndexHolder : public IndexHolder { public: //! Index Holder Iterator Pointer typedef std::shared_ptr Pointer; /*! Random Access Index Holder Iterator */ class Iterator : public IndexHolder::Iterator { public: //! Index Holder Iterator Pointer typedef std::unique_ptr Pointer; //! Constructor Iterator(RandomAccessIndexHolder *owner) : holder_(owner) {} //! Destructor virtual ~Iterator(void) {} //! Retrieve pointer of data virtual const void *data(void) const override { return holder_->element(id_); } //! Test if the iterator is valid virtual bool is_valid(void) const override { return id_ < holder_->count(); } //! Retrieve primary key virtual uint64_t key(void) const override { return holder_->key(id_); } //! Next iterator virtual void next(void) override { ++id_; } private: //! Members RandomAccessIndexHolder *holder_{nullptr}; uint32_t id_{0}; }; //! Constructor RandomAccessIndexHolder(const IndexMeta &meta) : features_(std::make_shared(meta)) {} //! 
Retrieve count of elements in holder (-1 indicates unknown) virtual size_t count(void) const override { return features_->count(); } //! Retrieve dimension virtual size_t dimension(void) const override { return features_->dimension(); } //! Retrieve type information virtual IndexMeta::DataType data_type(void) const override { return features_->data_type(); } //! Retrieve element size in bytes virtual size_t element_size(void) const override { return features_->element_size(); } //! Retrieve if it can multi-pass virtual bool multipass(void) const override { return true; } //! Create a new iterator virtual IndexHolder::Iterator::Pointer create_iterator(void) override { return IndexHolder::Iterator::Pointer( new RandomAccessIndexHolder::Iterator(this)); } void reserve(size_t elems) { features_->reserve(elems); keys_.reserve(elems); } //! Append an element into holder void emplace(uint64_t pkey, const void *vec) { features_->emplace(vec); keys_.emplace_back(pkey); } //! Retrieve feature via local id const void *element(size_t id) const { return features_->element(id); } //! Retrieve key via local id uint64_t key(size_t id) const { ailego_assert_with(id < keys_.size(), "Index Overflow"); return keys_[id]; } private: //! Disable them RandomAccessIndexHolder(void) = delete; //! Members CompactIndexFeatures::Pointer features_{}; std::vector keys_{}; }; private: /*! Wrapper of feature */ class Vector { public: typedef std::shared_ptr Pointer; Vector(const void *vec, size_t len, uint32_t idx) : vec_(reinterpret_cast(vec), len), id_{idx} {} const void *data() const { return vec_.data(); } size_t size() const { return vec_.size(); } uint32_t id(void) const { return id_; } private: std::string vec_{}; uint32_t id_{0u}; }; using VectorList = std::vector; //! Check MajorOrder in meta, and update the major order if needed int CheckAndUpdateMajorOrder(IndexMeta &meta); //! 
Parse params int parse_centroids_num(const ailego::Params ¶ms); int parse_clustering_params(const ailego::Params ¶ms); int parse_general_params(const ailego::Params ¶ms); //! Prepare params for trainer int prepare_trainer_params(ailego::Params ¶ms); //! Build the index int build_label_index(IndexThreads *threads, const IndexHolder::Pointer &holder); //! Dump the index to dumper int dump_index(const IndexDumper::Pointer &dumper); //! Prepare the quantizer for inverted index int prepare_quantizer(IndexThreads *threads); //! Quantize the centrods list int quantize_centroids(); //! Create converter and init with params static IndexConverter::Pointer CreateAndInitConverter( const IndexMeta &meta, const std::string &name, const ailego::Params ¶ms) { auto converter = IndexFactory::CreateConverter(name); if (!converter) { LOG_ERROR("Failed to create converter %s", name.c_str()); return IndexConverter::Pointer(); } int ret = converter->init(meta, params); if (ret != 0) { LOG_ERROR("Failed to initialize converter %s for %s", name.c_str(), IndexError::What(ret)); return IndexConverter::Pointer(); } return converter; } //! Select the nearest centroid id for the vector void label(const std::shared_ptr &vecs) { for (size_t i = 0; i < vecs->size(); ++i) { auto &vec = (*vecs)[i]; uint32_t centroid_idx = centroid_index_->search_nearest_centroid(vec.data(), vec.size()); if (centroid_idx == IVFCentroidIndex::kInvalidID) { LOG_ERROR("Failed to search nearest centroid in CentroidIndex"); if (!error_.exchange(true)) { err_code_ = IndexError_Runtime; } return; } ailego_assert_with(centroid_idx < labels_.size(), "Index Overflow"); mutex_.lock(); labels_[centroid_idx].emplace_back(vec.id()); mutex_.unlock(); } } private: //! Constants static constexpr size_t kThreadPoolQueueSize = 300u; static constexpr size_t kBatchSize = 10u; static constexpr size_t kDefaultBlockCount = 32u; enum BuilderState { INIT = 0, INITED = 1, TRAINED = 2, BUILT = 3 }; //! 
Members BuilderState state_{INIT}; Stats stats_{}; ailego::Params params_{}; IndexMeta meta_{}; std::vector centroid_num_vec_{}; std::string cluster_class_{}; std::string converter_class_{}; std::vector cluster_params_{}; std::vector> labels_{}; std::mutex mutex_{}; IVFCentroidIndex::Pointer centroid_index_{}; IVFCentroidIndex::Pointer searcher_centroid_index_{}; RandomAccessIndexHolder::Pointer holder_{}; IndexMeta converted_meta_{}; IndexConverter::Pointer converter_{}; IndexMeta quantized_meta_{}; std::vector quantizers_{}; std::atomic_bool error_{false}; int err_code_{0}; uint32_t thread_count_{0}; uint32_t sample_count_{0}; float sample_ratio_{0.0}; uint32_t block_vector_count_{kDefaultBlockCount}; bool cluster_auto_tuning_{false}; bool store_original_features_{false}; bool quantize_by_centroid_{false}; }; } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/ivf/ivf_centroid_index.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "ivf_centroid_index.h" #include #include #include "metric/metric_params.h" namespace zvec { namespace core { /*! Fake Trainer to supply centroids in bundle */ class FakeClusterTrainer : public IndexTrainer { public: //! Constructor FakeClusterTrainer(const IndexMeta &imeta, const IndexBundle::Pointer &bundle) : meta_(imeta), bundle_(bundle) {} //! 
Destructor ~FakeClusterTrainer(void) {} protected: //! Initialize Trainer virtual int init(const IndexMeta &, const ailego::Params &) override { return 0; } //! Cleanup Trainer virtual int cleanup(void) override { return 0; } //! Train the data virtual int train(IndexHolder::Pointer) override { return 0; } //! Train the data virtual int train(IndexThreads::Pointer, IndexHolder::Pointer) override { return 0; } //! Load index from file path or dir virtual int load(IndexStorage::Pointer) override { return 0; } //! Dump index into file path or dir virtual int dump(const IndexDumper::Pointer &) override { return 0; } //! Retrieve Index Meta virtual const IndexMeta &meta(void) const override { return meta_; } //! Retrieve statistics virtual const IndexTrainer::Stats &stats(void) const override { return stats_; } //! Retrieve the output indexes virtual IndexBundle::Pointer indexes(void) const override { return bundle_; } private: //! Members IndexMeta meta_{}; Stats stats_{}; IndexBundle::Pointer bundle_{}; }; /*! Int8QuantizerReformer for InnerProduct Measure */ class Int8QuantizerReformer4IP : public IndexReformer { public: //! Initialize Reformer virtual int init(const ailego::Params &) override { return 0; } //! Cleanup Reformer virtual int cleanup(void) override { return 0; } //! Load index from container virtual int load(IndexStorage::Pointer) override { return 0; } //! Unload index virtual int unload(void) override { return 0; } //! 
Transform query virtual int transform(const void * /*query*/, const IndexQueryMeta & /*qmeta*/, std::string * /*out*/, IndexQueryMeta * /*ometa*/) const override { #if 0 size_t dim = qmeta.dimension(); out->resize(IndexMeta::ElementSizeof( IndexMeta::DataType::DT_INT8, dim)); ometa->set_meta(IndexMeta::DataType::DT_INT8, dim); const float *ivec = reinterpret_cast(query); int8_t *ovec = reinterpret_cast(&(*out)[0]); float abs_max = 0.0f; for (size_t i = 0; i < dim; ++i) { auto abs = std::abs(ivec[i]); if (abs > abs_max) { abs_max = abs; } } if (abs_max > 0.0f) { float scale = 127 / abs_max; for (size_t i = 0; i < dim; ++i) { ovec[i] = static_cast(std::round(ivec[i] * scale)); } } else { std::fill(ovec, ovec + dim, static_cast(1)); } return 0; #else return IndexError_NotImplemented; #endif } //! Transform queries virtual int transform(const void *query, const IndexQueryMeta &qmeta, uint32_t count, std::string *oquery, IndexQueryMeta *ometa) const override { size_t dim = qmeta.dimension(); oquery->resize(count * IndexMeta::ElementSizeof(IndexMeta::DataType::DT_INT8, dim)); ometa->set_meta(IndexMeta::DataType::DT_INT8, dim); const float *ivec = reinterpret_cast(query); int8_t *ovec = reinterpret_cast(&(*oquery)[0]); for (size_t q = 0; q < count; ++q) { float abs_max = 0.0f; const float *in = &ivec[q * dim]; int8_t *out = &ovec[q * dim]; for (size_t i = 0; i < dim; ++i) { auto abs = std::abs(in[i]); if (abs > abs_max) { abs_max = abs; } } if (abs_max > 0.0f) { float scale = 127 / abs_max; for (size_t i = 0; i < dim; ++i) { out[i] = static_cast(std::round(in[i] * scale)); } } else { std::fill(out, out + dim, static_cast(1)); } } return 0; } //! Normalize results virtual int normalize(const void * /*query*/, const IndexQueryMeta & /*qmeta*/, IndexDocumentList & /*result*/) const override { return 0; } }; /*! Int4QuantizerReformer for InnerProduct Metric */ class Int4QuantizerReformer4IP : public IndexReformer { public: //! 
Initialize Reformer virtual int init(const ailego::Params &) override { return 0; } //! Cleanup Reformer virtual int cleanup(void) override { return 0; } //! Load index from container virtual int load(IndexStorage::Pointer) override { return 0; } //! Unload index virtual int unload(void) override { return 0; } //! Transform query virtual int transform(const void * /*query*/, const IndexQueryMeta & /*qmeta*/, std::string * /*out*/, IndexQueryMeta * /*ometa*/) const override { return IndexError_NotImplemented; } //! Transform queries virtual int transform(const void *query, const IndexQueryMeta &qmeta, uint32_t count, std::string *oquery, IndexQueryMeta *ometa) const override { if (qmeta.dimension() & 0x1) { LOG_ERROR("Unsuuport dim=%u for transform", qmeta.dimension()); return IndexError_Unsupported; } size_t dim = qmeta.dimension(); oquery->resize(count * IndexMeta::ElementSizeof(IndexMeta::DataType::DT_INT4, dim)); ometa->set_meta(IndexMeta::DataType::DT_INT4, dim); const float *ivec = reinterpret_cast(query); uint8_t *ovec = reinterpret_cast(&(*oquery)[0]); for (size_t q = 0; q < count; ++q) { float abs_max = 0.0f; float max = -std::numeric_limits::max(); const float *in = &ivec[q * dim]; uint8_t *out = &ovec[q * dim / 2]; for (size_t i = 0; i < dim; ++i) { float abs = std::abs(in[i]); abs_max = std::max(abs_max, abs); max = std::max(max, in[i]); } if (abs_max > 0.0f) { float scale = ((7 * abs_max > 8 * max) ? 8 : 7) / abs_max; for (size_t i = 0; i < dim; i += 2) { auto v1 = static_cast(std::round(in[i] * scale)); auto v2 = static_cast(std::round(in[i + 1] * scale)); out[i / 2] = (static_cast(v1) << 4) | (static_cast(v2) & 0xF); } } else { std::fill(out, out + dim / 2, static_cast(9)); } } return 0; } //! 
Normalize results virtual int normalize(const void * /*query*/, const IndexQueryMeta & /*qmeta*/, IndexDocumentList & /*result*/) const override { return 0; } }; int IVFCentroidIndex::init(const IndexMeta &meta, const ailego::Params ¶ms) { meta_ = meta; params.get(PARAM_IVF_BUILDER_OPTIMIZER_CLASS, &builder_class_); params.get(PARAM_IVF_BUILDER_OPTIMIZER_PARAMS, &builder_params_); params.get(PARAM_IVF_SEARCHER_OPTIMIZER, &searcher_class_); params.get(PARAM_IVF_SEARCHER_OPTIMIZER_PARAMS, &searcher_params_); return 0; } int IVFCentroidIndex::search(const void *query, const IndexQueryMeta &qmeta, size_t count, IndexSearcher::Context::Pointer &ctx) { int ret = 0; if (reformer_) { std::string buffer; IndexQueryMeta ometa; ret = reformer_->transform(query, qmeta, count, &buffer, &ometa); if (ret != 0) { LOG_ERROR("Failed to transform querys by reformer"); return ret; } ret = searcher_->search_impl(buffer.data(), ometa, count, ctx); } else { ret = searcher_->search_impl(query, qmeta, count, ctx); } ivf_check_with_msg(ret, "Failed to search in centroid index for %s", IndexError::What(ret)); return 0; } uint32_t IVFCentroidIndex::search_nearest_centroid(const void *query, size_t len) { //! 
Called in building index precedure, so transform the query is needless if (len != meta_.element_size()) { LOG_ERROR("Invalid query size actual: %zu, expected: %u", len, meta_.element_size()); return kInvalidID; } thread_local IndexSearcher::Context::Pointer context( searcher_->create_context()); context->set_topk(1); IndexQueryMeta qmeta(meta_.data_type(), meta_.dimension()); int ret = searcher_->search_impl(query, qmeta, context); if (ret != 0 || context->result().empty()) { LOG_ERROR("Failed to search nearest centroid, with ret %d", ret); return kInvalidID; } return static_cast(context->result()[0].key()); } uint32_t IVFCentroidIndex::transform_and_search_nearest_centroid( const void *record, const IndexQueryMeta &rmeta, IndexSearcher::Context::Pointer &ctx) const { int ret = 0; if (reformer_) { std::string buffer; IndexQueryMeta ometa; ret = reformer_->convert(record, rmeta, &buffer, &ometa); if (ret != 0) { LOG_ERROR("Failed to transform querys by reformer"); return kInvalidID; } ret = searcher_->search_impl(buffer.data(), ometa, ctx); } else { ret = searcher_->search_impl(record, rmeta, ctx); } if (ret != 0 || ctx->result().empty()) { LOG_ERROR("Failed to search in centroid index for %s", IndexError::What(ret)); return kInvalidID; } return static_cast(ctx->result()[0].key()); } IndexHolder::Pointer IVFCentroidIndex::quantize_holder( const IndexHolder::Pointer &holder) { auto input = holder; if (meta_.reformer_name() == kMipsReformerName && meta_.metric_name() == kL2MetricName && (quantizer_class_ == kInt8QuantizerName || quantizer_class_ == kInt4QuantizerName)) { //! 
Reverse for Mips if do convert by integer quantizer auto reverse = IndexFactory::CreateConverter(kMipsRevConverterName); if (!reverse) { LOG_ERROR("Failed to create converter %s", kMipsRevConverterName); return nullptr; } ailego::Params params; auto p = meta_.reformer_params(); params.set(MIPS_REVERSE_CONVERTER_M_VALUE, p.get_as_uint32(MIPS_REFORMER_M_VALUE)); params.set(MIPS_REVERSE_CONVERTER_U_VALUE, p.get_as_float(MIPS_REFORMER_U_VALUE)); params.set(MIPS_REVERSE_CONVERTER_L2_NORM, p.get_as_uint32(MIPS_REFORMER_L2_NORM)); params.set(MIPS_REVERSE_CONVERTER_FORCED_SINGLE_FLOAT, p.get_as_float(MIPS_REFORMER_FORCED_HALF_FLOAT)); int ret = reverse->init(meta_, params); if (ret != 0) { LOG_ERROR("Fail to init converter %s", kMipsRevConverterName); return nullptr; } ret = IndexConverter::TrainAndTransform(reverse, holder); if (ret != 0) { LOG_ERROR("Fail to transform converter %s", kMipsRevConverterName); return nullptr; } input = reverse->result(); meta_ = reverse->meta(); meta_.set_metric(kIPMetricName, 0, ailego::Params()); meta_.set_reformer("", 0, ailego::Params()); } auto converter = IndexFactory::CreateConverter(quantizer_class_); if (!converter) { LOG_ERROR("Failed to create converter %s", quantizer_class_.c_str()); return nullptr; } int ret = converter->init(meta_, quantizer_params_); if (ret != 0) { LOG_ERROR("Fail to init converter %s", quantizer_class_.c_str()); return nullptr; } ret = IndexConverter::TrainAndTransform(converter, input); if (ret != 0) { LOG_ERROR("Fail to tranform converter %s", quantizer_class_.c_str()); return nullptr; } meta_ = converter->meta(); return converter->result(); } int IVFCentroidIndex::build_index( const IndexCluster::CentroidList ¢roid_list, const IndexDumper::Pointer &dumper) { IndexBuilder::Pointer builder = IndexFactory::CreateBuilder(builder_class_); if (!builder) { LOG_ERROR("Failed to create builder %s", builder_class_.c_str()); return IndexError_NoExist; } IndexHolder::Pointer holder = std::make_shared(meta_, 
centroid_list); if (!holder) { return IndexError_NoMemory; } if (holder->count() == 0) { LOG_ERROR("No centroids to build"); return IndexError_InvalidArgument; } centroids_count_ = holder->count(); //! Set default params if not given auto count = std::to_string( static_cast(std::ceil(std::sqrt(centroids_count_ / 10.0)))); // if (IsHcBuilder(builder_class_) && // !builder_params_.has(hc::PARAM_HC_BUILDER_CENTROID_COUNT)) { // builder_params_.set(hc::PARAM_HC_BUILDER_CENTROID_COUNT, count); // } else if (builder_class_ == "GcBuilder" && // !builder_params_.has(hc::PARAM_GC_BUILDER_CENTROID_COUNT)) { // builder_params_.set(hc::PARAM_GC_BUILDER_CENTROID_COUNT, count); // } if (!quantizer_class_.empty()) { holder = this->quantize_holder(holder); if (!holder) { return IndexError_Runtime; } } const auto name = builder_class_.c_str(); int ret = builder->init(meta_, builder_params_); ivf_check_with_msg(ret, "%s init failed, ret=%d", name, ret); // if (IsHcBuilder(builder_class_) && quantizer_class_.empty()) { // auto trainer = this->prepare_trainer(centroid_list); // ret = trainer ? builder->train(trainer) : builder->train(holder); // } else { // ret = builder->train(holder); // } ret = builder->train(holder); ivf_check_with_msg(ret, "%s train failed, ret=%d", name, ret); ret = builder->build(holder); ivf_check_with_msg(ret, "%s build failed, ret=%d", name, ret); ret = builder->dump(dumper); ivf_check_with_msg(ret, "%s dump failed, ret=%d", name, ret); ret = dumper->close(); ivf_check_error_code(ret); return 0; } int IVFCentroidIndex::build(const IndexCluster::CentroidList ¢roid_list) { index_building_ = true; //! 
Build and dump the index IndexDumper::Pointer dumper = IndexFactory::CreateDumper("MemoryDumper"); if (!dumper) { LOG_ERROR("Failed to create MemoryDumper"); return IndexError_NoExist; } path_ = IVFUtility::GenerateRandomPath(kTempralPathPrefix); int ret = dumper->create(path_); if (ret != 0) { LOG_ERROR("IndexDumper create path %s failed", path_.c_str()); return ret; } ret = this->build_index(centroid_list, dumper); ivf_check_error_code(ret); auto rope = IndexMemory::Instance()->open(path_); if (!rope) { LOG_ERROR("Open memory path %s failed.", path_.c_str()); return ret; } if (rope->count() != 1) { LOG_ERROR("Graph Rope block count not equal with 1."); return ret; } (*rope)[0].read(0, &data_, 0); size_ = (*rope)[0].size(); //! Load the index IndexStorage::Pointer container = IndexFactory::CreateStorage("MemoryReadStorage"); if (!container) { LOG_ERROR("Failed to create MemoryReadStorage"); return IndexError_NoExist; } ret = container->init(ailego::Params()); ivf_check_with_msg(ret, "Failed to initialize MemoryReadStorage for %s", IndexError::What(ret)); ret = container->open(path_, false); ivf_check_with_msg(ret, "Failed to load path in MemoryReadStorage for %s", IndexError::What(ret)); ailego::Params searcher_params; if (!searcher_class_.empty()) { searcher_params.set(PARAM_IVF_SEARCHER_OPTIMIZER, searcher_class_); } if (!searcher_params_.empty()) { searcher_params.set(PARAM_IVF_SEARCHER_OPTIMIZER_PARAMS, searcher_params_); } ret = this->load(container, searcher_params); ivf_check_with_msg(ret, "IVFCentroidIndex load failed with %s", IndexError::What(ret)); return 0; } int IVFCentroidIndex::load(const IndexStorage::Pointer &container, const ailego::Params params) { if (!container) { LOG_ERROR("Invalid container"); return IndexError_InvalidArgument; } int ret = IndexHelper::DeserializeFromStorage(container.get(), &meta_); if (ret != 0) { LOG_ERROR("Failed to deserialize meta from container"); return ret; } auto reformer_name = meta_.reformer_name(); if 
(!reformer_name.empty()) { LOG_DEBUG("Load CentroidIndex with reformer %s, metric %s", reformer_name.c_str(), meta_.metric_name().c_str()); if ((reformer_name == kInt8ReformerName || reformer_name == kInt4ReformerName) && meta_.metric_name() == kIPMetricName) { if (reformer_name == kInt8ReformerName) { reformer_ = std::make_shared(); } else { reformer_ = std::make_shared(); } if (!reformer_) { return IndexError_NoMemory; } } else { reformer_ = IndexFactory::CreateReformer(reformer_name); if (!reformer_) { LOG_ERROR("Failed to create reformer %s", reformer_name.c_str()); return IndexError_NoExist; } } ret = reformer_->init(meta_.reformer_params()); ivf_check_with_msg(ret, "Failed to initialize reformer %s", reformer_name.c_str()); } searcher_class_ = meta_.searcher_name(); params.get(PARAM_IVF_SEARCHER_OPTIMIZER, &searcher_class_); params.get(PARAM_IVF_SEARCHER_OPTIMIZER_PARAMS, &searcher_params_); searcher_ = IndexFactory::CreateSearcher(searcher_class_); if (!searcher_) { LOG_ERROR("Failed to create searcher %s", searcher_class_.c_str()); return IndexError_Runtime; } auto searcher_params = meta_.searcher_params(); searcher_params.merge(searcher_params_); ret = searcher_->init(searcher_params); ivf_check_with_msg(ret, "Failed to initialize searcher %s", searcher_class_.c_str()); IndexMetric::Pointer metric; if (index_building_) { // The searcher index metric should specified in building process, // otherwise the query_metric will be used in searching metric = IndexFactory::CreateMetric(meta_.metric_name()); ivf_assert_with_msg(metric, IndexError_NoExist, "Failed to create metric %s", meta_.metric_name().c_str()); ret = metric->init(meta_, meta_.metric_params()); ivf_check_with_msg(ret, "Failed to initialize metric"); } ret = searcher_->load(container, metric); ivf_check_with_msg(ret, "Failed to load searcher %s", searcher_class_.c_str()); return 0; } IndexTrainer::Pointer IVFCentroidIndex::prepare_trainer( const IndexCluster::CentroidList ¢roid_list) { 
IndexCluster::CentroidList level1_centroids; bool two_level = false; for (auto &it : centroid_list) { auto centroid = it; if (!centroid.subitems().empty()) { two_level = true; } centroid.mutable_subitems()->clear(); centroid.mutable_similars()->clear(); level1_centroids.emplace_back(centroid); } if (!two_level) { return IndexTrainer::Pointer(); } IndexBundle::Pointer bundle; IndexCluster::Serialize(meta_, level1_centroids, &bundle); return std::make_shared(meta_, bundle); } } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/ivf/ivf_centroid_index.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include "ivf_params.h" #include "ivf_utility.h" namespace zvec { namespace core { /*! IVF Centroid Index */ class IVFCentroidIndex { public: typedef std::shared_ptr Pointer; //! Constructor IVFCentroidIndex(void) {} //! Destructor ~IVFCentroidIndex(void) { IndexMemory *instance = IndexMemory::Instance(); if (instance) { if (instance->has(path_)) { instance->remove(path_); } } } //! Initialize int init(const IndexMeta &meta, const ailego::Params ¶ms); //! Set Quantizer for the index void set_quantizer(const std::string &quantizer_name, ailego::Params &quantizer_params) { quantizer_class_ = quantizer_name; quantizer_params_ = quantizer_params; } //! 
Retrieve data address of the index const void *data(void) const { return data_; } //! Retrieve size of the index size_t size(void) const { return size_; } //! Create searcher context for centroid index IndexSearcher::Context::Pointer create_context() const { return searcher_ ? searcher_->create_context() : nullptr; } //! Similarity search int search(const void *query, const IndexQueryMeta &qmeta, size_t count, IndexSearcher::Context::Pointer &ctx); //! Search the nearest point, must be called in local thread pool uint32_t search_nearest_centroid(const void *query, size_t len); //! Transform Data and Search the nearest point, called while adding record uint32_t transform_and_search_nearest_centroid( const void *record, const IndexQueryMeta &rmeta, IndexSearcher::Context::Pointer &ctx) const; //! Build Centroid Index From Centroid List int build(const IndexCluster::CentroidList ¢roid_list); //! Load Centroid Index From container int load(const IndexStorage::Pointer &container, const ailego::Params params); //! Retrieve centroid count of the index size_t centroids_count(void) const { return centroids_count_; } //! Retrieve meta const IndexMeta &meta() const { return meta_; } //! Retrieve reformer of the index const IndexReformer::Pointer reformer(void) const { return reformer_; } static constexpr uint32_t kInvalidID = std::numeric_limits::max(); private: /*! Centroids IndexHolder */ class CentroidsIndexHolder : public IndexHolder { public: class Iterator : public IndexHolder::Iterator { public: //! Index Holder Iterator Pointer typedef std::unique_ptr Pointer; //! Constructor Iterator(std::vector *features) : features_(features) {} //! Destructor virtual ~Iterator(void) {} //! Retrieve pointer of data virtual const void *data(void) const override { return (*features_)[id_]; } //! Test if the iterator is valid virtual bool is_valid(void) const override { return id_ < features_->size(); } //! 
Retrieve primary key virtual uint64_t key(void) const override { return id_; } //! Next iterator virtual void next(void) override { ++id_; } private: //! Members std::vector *features_{nullptr}; uint32_t id_{0}; }; //! Constructor CentroidsIndexHolder(const IndexMeta &meta, const IndexCluster::CentroidList ¢roid_list) : dimension_(meta.dimension()), element_size_(meta.element_size()), data_type_(meta.data_type()) { using CentroidList = IndexCluster::CentroidList; std::function get_leaf_features = [&](const CentroidList ¢s) { if (cents.empty()) { return; } for (const auto &it : cents) { if (it.subitems().empty()) { features_.emplace_back(it.feature()); } else { get_leaf_features(it.subitems()); } } }; get_leaf_features(centroid_list); } //! Retrieve count of elements in holder (-1 indicates unknown) virtual size_t count(void) const override { return features_.size(); } //! Retrieve dimension virtual size_t dimension(void) const override { return dimension_; } //! Retrieve type information virtual IndexMeta::DataType data_type(void) const override { return data_type_; } //! Retrieve element size in bytes virtual size_t element_size(void) const override { return element_size_; } //! Retrieve if it can multi-pass virtual bool multipass(void) const override { return true; } //! Create a new iterator virtual IndexHolder::Iterator::Pointer create_iterator(void) override { return IndexHolder::Iterator::Pointer( new CentroidsIndexHolder::Iterator(&features_)); } private: //! Members std::vector features_{}; size_t dimension_{0}; size_t element_size_{0}; IndexMeta::DataType data_type_{IndexMeta::DataType::DT_UNDEFINED}; }; int build_index(const IndexCluster::CentroidList ¢roid_list, const IndexDumper::Pointer &dumper); //! Prepare trainer for clustering index IndexTrainer::Pointer prepare_trainer( const IndexCluster::CentroidList ¢roid_list); //! Quantize the centroid vectors in holder IndexHolder::Pointer quantize_holder(const IndexHolder::Pointer &holder); private: //! 
Constants constexpr static const char *kDefaultBuilder = "FlatBuilder"; constexpr static const char *kTempralPathPrefix = "IVF"; //! Members IndexMeta meta_{}; IndexSearcher::Pointer searcher_{}; IndexReformer::Pointer reformer_{}; std::string builder_class_{kDefaultBuilder}; std::string searcher_class_{}; std::string quantizer_class_{}; std::string path_{}; ailego::Params builder_params_{}; ailego::Params searcher_params_{}; ailego::Params quantizer_params_{}; const void *data_{}; size_t size_{}; size_t centroids_count_{0}; bool index_building_{false}; }; } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/ivf/ivf_distance_calculator.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "ivf_distance_calculator.h"
// NOTE(review): the name of the next header was lost when this file was
// extracted; restore it from upstream.
#include

namespace zvec {
namespace core {

/*! Cache the metric kernels: the per-pair row distance, the blocked
 *  block_vec_cnt x 1 matrix kernel, and blocked kernels for batches of
 *  32, 16, 8, 4, 2 and 1 queries (indexed by batch size).
 */
IVFDistanceCalculator::IVFDistanceCalculator(const IndexMeta &meta,
                                             const IndexMetric::Pointer &metric,
                                             uint32_t block_vec_cnt)
    : metric_ptr_(metric), block_vec_cnt_(block_vec_cnt) {
  row_distance_ = metric->distance();
  distanceXx1_ = metric->distance_matrix(block_vec_cnt, 1);
  distances_.resize(33);
  for (size_t batch = 32; batch > 0; batch >>= 1) {
    distances_[batch] = metric->distance_matrix(block_vec_cnt, batch);
  }
  element_size_ = meta.element_size();
  dimension_ = meta.dimension();
  column_major_order_ =
      (meta.major_order() == IndexMeta::MajorOrder::MO_COLUMN);
}

//! Drop the cached kernels.
IVFDistanceCalculator::~IVFDistanceCalculator() {
  distances_.clear();
  distanceXx1_ = nullptr;
  row_distance_ = nullptr;
}

}  // namespace core
}  // namespace zvec
================================================
FILE: src/core/algorithm/ivf/ivf_distance_calculator.h
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

// NOTE(review): the names of the next two headers were lost when this file
// was extracted; restore them from upstream.
#include
#include

namespace zvec {
namespace core {

/*! Computes query-to-centroid and query-to-feature distances. Full blocks of
 *  block_vec_cnt vectors in column-major order use the metric's matrix
 *  kernels; everything else falls back to per-pair row distances.
 */
class IVFDistanceCalculator {
 public:
  typedef std::shared_ptr<IVFDistanceCalculator> Pointer;

  //! Constructor
  IVFDistanceCalculator(const IndexMeta &meta,
                        const IndexMetric::Pointer &metric,
                        uint32_t block_vec_cnt);

  virtual ~IVFDistanceCalculator();

 public:
  inline void query_centroids_distance(const void *query, size_t qnum,
                                       const void *feature, size_t fnum,
                                       float *distances);

  inline void query_centroids_distance(const void *query, const void *feature,
                                       size_t fnum, float *distances);

  inline void query_features_distance(const void *query, const void *feature,
                                      size_t fnum, float *distances);

  inline void query_features_distance(const void *query, const void *feature,
                                      bool column_major, size_t fnum,
                                      float *distances);

 protected:
  //! Row Major Distances -> Online
  inline void row_major_distance(const void *query, size_t qnum,
                                 const void *feature, size_t fnum, float *out);

  inline void row_major_distance(const void *query, const void *feature,
                                 size_t fnum, float *out);

  template <size_t Q>
  inline void batch_query_centroids_distance(const void *query,
                                             const void *feature, size_t fnum,
                                             float *distances);

 protected:
  IndexMetric::Pointer metric_ptr_{};
  IndexMetric::MatrixDistance row_distance_{nullptr};
  IndexMetric::MatrixDistance distanceXx1_{nullptr};
  std::vector<IndexMetric::MatrixDistance> distances_{};
  size_t element_size_{0};
  size_t dimension_{0};
  uint32_t block_vec_cnt_{0};
  bool column_major_order_{false};
};

//! Distances of `qnum` queries to `fnum` features. Column-major layouts
//! support qnum in {1, 2, 4, 8, 16, 32}; any other qnum is logged and the
//! output is left untouched. In row-major layouts query q writes at
//! distances + q * block_vec_cnt_.
// NOTE(review): the column-major fallback for partial blocks writes query q
// at distances + q * fnum instead; confirm callers expect both strides.
void IVFDistanceCalculator::query_centroids_distance(const void *query,
                                                     size_t qnum,
                                                     const void *feature,
                                                     size_t fnum,
                                                     float *distances) {
  if (!column_major_order_) {
    const uint8_t *queries = reinterpret_cast<const uint8_t *>(query);
    for (size_t q = 0; q < qnum; ++q) {
      this->row_major_distance(queries + q * element_size_, feature, fnum,
                               distances + q * block_vec_cnt_);
    }
    return;
  }

  switch (qnum) {
    case 1:
      batch_query_centroids_distance<1>(query, feature, fnum, distances);
      break;
    case 2:
      batch_query_centroids_distance<2>(query, feature, fnum, distances);
      break;
    case 4:
      batch_query_centroids_distance<4>(query, feature, fnum, distances);
      break;
    case 8:
      batch_query_centroids_distance<8>(query, feature, fnum, distances);
      break;
    case 16:
      batch_query_centroids_distance<16>(query, feature, fnum, distances);
      break;
    case 32:
      batch_query_centroids_distance<32>(query, feature, fnum, distances);
      break;
    default:
      LOG_ERROR("Unsupported query num %zu.", qnum);
      break;
  }
}

//! Single-query form: same as query_features_distance().
void IVFDistanceCalculator::query_centroids_distance(const void *query,
                                                     const void *feature,
                                                     size_t fnum,
                                                     float *distances) {
  this->query_features_distance(query, feature, fnum, distances);
}

//! Distances of one query to `fnum` features in the calculator's own layout.
void IVFDistanceCalculator::query_features_distance(const void *query,
                                                    const void *feature,
                                                    size_t fnum,
                                                    float *distances) {
  const bool full_column_block =
      column_major_order_ && fnum == block_vec_cnt_;
  if (full_column_block) {
    distanceXx1_(feature, query, dimension_, distances);
    return;
  }
  this->row_major_distance(query, feature, fnum, distances);
}

//! Distances of one query to `fnum` features in an explicitly given layout;
//! a column-major input must be a full block.
void IVFDistanceCalculator::query_features_distance(const void *query,
                                                    const void *feature,
                                                    bool column_major,
                                                    size_t fnum,
                                                    float *distances) {
  if (!column_major) {
    this->row_major_distance(query, feature, fnum, distances);
    return;
  }
  ailego_assert_with(fnum == block_vec_cnt_, "Invalid Block");
  distanceXx1_(feature, query, dimension_, distances);
}

//! Q queries against a column-major block: matrix kernel for full blocks,
//! per-pair distances otherwise.
template <size_t Q>
void IVFDistanceCalculator::batch_query_centroids_distance(const void *query,
                                                           const void *feature,
                                                           size_t fnum,
                                                           float *distances) {
  if (fnum != block_vec_cnt_) {
    row_major_distance(query, Q, feature, fnum, distances);
    return;
  }
  distances_[Q](feature, query, dimension_, distances);
}

//! Per-pair distances of `qnum` queries to `fnum` row-major features; query q
//! writes fnum results at out + q * fnum.
void IVFDistanceCalculator::row_major_distance(const void *query, size_t qnum,
                                               const void *feature, size_t fnum,
                                               float *out) {
  const uint8_t *queries = reinterpret_cast<const uint8_t *>(query);
  const uint8_t *features = reinterpret_cast<const uint8_t *>(feature);
  for (size_t q = 0; q < qnum; ++q) {
    const uint8_t *cur_query = queries + q * element_size_;
    float *row_out = out + q * fnum;
    for (size_t f = 0; f < fnum; ++f) {
      row_distance_(cur_query, features + f * element_size_, dimension_,
                    row_out + f);
    }
  }
}

void
IVFDistanceCalculator::row_major_distance(const void *query, const void *feature, size_t fnum, float *out) { const uint8_t *cur_feature = reinterpret_cast(feature); for (size_t f = 0; f < fnum; ++f) { row_distance_(query, cur_feature, dimension_, out + f); cur_feature += element_size_; } } } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/ivf/ivf_dumper.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "ivf_dumper.h" namespace zvec { namespace core { int IVFDumper::dump_inverted_vector(uint32_t inverted_list_id, uint64_t key, const void *vec) { int ret = this->check_dump_inverted_list(inverted_list_id); ivf_check_error_code(ret); ++inverted_lists_meta_[cur_list_id_].vector_count; ++header_.total_vector_count; block_.emplace(key, vec, IndexMeta::MajorOrder::MO_ROW); if (block_.full()) { ret = this->dump_block(); ivf_check_error_code(ret); } return 0; } int IVFDumper::dump_inverted_block(uint32_t inverted_list_id, const uint64_t *keys, const void *vecs, uint32_t vector_count, bool column_major) { int ret = this->check_dump_inverted_list(inverted_list_id); ivf_check_error_code(ret); if (block_.match_order(column_major ? 
IndexMeta::MajorOrder::MO_COLUMN : IndexMeta::MajorOrder::MO_ROW) && vector_count == block_.capacity()) { // Dump the block directly size_t size = vector_count * meta_.element_size(); size_t pd_size = ailego_align(size, 32) - size; if (dumper_->write(vecs, size) != size) { LOG_ERROR("Failed to write data into dumper %s", dumper_->name().c_str()); return IndexError_WriteData; } if (pd_size > 0) { std::string padding(pd_size, '\0'); if (dumper_->write(padding.data(), pd_size) != pd_size) { return IndexError_WriteData; } } std::copy(keys, keys + vector_count, std::back_inserter(keys_)); ++inverted_lists_meta_[cur_list_id_].block_count; ++header_.block_count; header_.inverted_body_size += size; } else { size_t step_size = meta_.element_size(); if (column_major) { step_size = IndexMeta::AlignSizeof(meta_.data_type()); } for (size_t i = 0; i < vector_count; ++i) { auto v = reinterpret_cast(vecs) + i * step_size; block_.emplace(keys[i], v, column_major ? IndexMeta::MajorOrder::MO_COLUMN : IndexMeta::MajorOrder::MO_ROW); if (block_.full()) { ret = this->dump_block(); ivf_check_error_code(ret); } } } inverted_lists_meta_[cur_list_id_].vector_count += vector_count; header_.total_vector_count += vector_count; return 0; } int IVFDumper::dump_container_segment(const IndexStorage::Pointer &container, const std::string &segmemt_id) { auto seg = container->get(segmemt_id, 2); if (!seg) { LOG_ERROR("Failed to fetch segment %s from %s", segmemt_id.c_str(), container->name().c_str()); return IndexError_InvalidFormat; } const size_t batch_size = 32 * 1024; const size_t total_size = seg->data_size() + seg->padding_size(); size_t off = 0; while (off < total_size) { const void *data = nullptr; size_t rd_size = std::min(batch_size, total_size - off); if (seg->read(off, &data, rd_size) != rd_size) { LOG_ERROR("Failed to read data, off=%zu size=%zu", off, rd_size); return IndexError_ReadData; } if (dumper_->write(data, rd_size) != rd_size) { LOG_ERROR("Failed to write data, size=%zu", 
rd_size); return IndexError_WriteData; } off += rd_size; } int ret = dumper_->append(segmemt_id, seg->data_size(), seg->padding_size(), seg->data_crc()); ivf_check_with_msg(ret, "Failed to append %s", segmemt_id.c_str()); dumped_size_ += total_size; return 0; } int IVFDumper::dump_inverted_vector_finished(void) { //! Dump Inverted Index Segment if (!block_.empty()) { int ret = this->dump_block(); ivf_check_error_code(ret); } header_.block_size = block_.block_size(); size_t segment_size = header_.inverted_body_size; int ret = dumper_->append(IVF_INVERTED_BODY_SEG_ID, segment_size, 0, 0); if (ret != 0) { LOG_ERROR("Failed to append to segment %s, ret=%d", IVF_INVERTED_BODY_SEG_ID.c_str(), ret); return ret; } dumped_size_ += segment_size; //! Dump Inverted Index Header Segment std::string str; meta_.serialize(&str); header_.header_size = sizeof(header_) + str.size(); header_.index_meta_size = str.size(); header_.inverted_list_count = inverted_lists_meta_.size(); if (dumper_->write(&header_, sizeof(header_)) != sizeof(header_)) { LOG_ERROR("Failed to write data, size %zu", sizeof(header_)); return IndexError_WriteData; } if (dumper_->write(str.data(), str.size()) != str.size()) { LOG_ERROR("Failed to write data, size %zu", str.size()); return IndexError_WriteData; } size_t padding_size = 0; ret = this->dump_padding(header_.header_size, &padding_size); ivf_check_error_code(ret); ret = dumper_->append(IVF_INVERTED_HEADER_SEG_ID, header_.header_size, padding_size, 0); if (ret != 0) { LOG_ERROR("Failed to append to segment %s, ret:%d", IVF_INVERTED_HEADER_SEG_ID.c_str(), ret); return ret; } dumped_size_ += header_.header_size + padding_size; LOG_DEBUG( "Dump header info: blocks=%u block_size=%u block_vec_count=%u " "inverted_list_count=%u total_vecs=%u inverted_size=%zu", header_.block_count, header_.block_size, header_.block_vector_count, header_.inverted_list_count, header_.total_vector_count, static_cast(header_.inverted_body_size)); //! 
Dump Inverted Lists Meta Segment segment_size = inverted_lists_meta_.size() * sizeof(InvertedListMeta); ret = this->dump_segment(IVF_INVERTED_META_SEG_ID, inverted_lists_meta_.data(), segment_size); ivf_check_error_code(ret); //! Dump Keys Segment ret = this->dump_segment(IVF_KEYS_SEG_ID, keys_.data(), keys_.size() * sizeof(keys_[0])); ivf_check_error_code(ret); //! Dump Mapping Segment auto mapping = std::make_shared>(); IVFUtility::Sort(keys_.data(), mapping.get(), keys_.size()); ret = this->dump_segment(IVF_MAPPING_SEG_ID, mapping->data(), mapping->size() * sizeof(uint32_t)); ivf_check_error_code(ret); mapping.reset(); //! Dump the Offsets Segment return this->dump_offsets_segment(); } int IVFDumper::dump_centroid_index(const void *data, size_t size) { int ret = this->dump_segment(IVF_CENTROID_SEG_ID, data, size); ivf_check_error_code(ret); return 0; } int IVFDumper::dump_quantizer_params( const std::vector &quantizers) { if (meta_.reformer_name() != kInt8ReformerName && meta_.reformer_name() != kInt4ReformerName) { // IntegerQuantizer params is support only return 0; } if (quantizers.size() == 1) { //! Donot dump, using reformer params in IndexMeta return 0; } if (quantizers.size() != header_.inverted_list_count) { LOG_ERROR("Mismatch size, quantizers=%zu, inverted_list_count=%u", quantizers.size(), header_.inverted_list_count); return IndexError_Logic; } bool int8_quantizer = meta_.reformer_name() == kInt8ReformerName; std::vector params; params.resize(header_.inverted_list_count); for (size_t i = 0; i < quantizers.size(); ++i) { auto &p = quantizers[i]->meta().reformer_params(); auto &scale_key = int8_quantizer ? INT8_QUANTIZER_REFORMER_SCALE : INT4_QUANTIZER_REFORMER_SCALE; auto &bias_key = int8_quantizer ? 
INT8_QUANTIZER_REFORMER_BIAS : INT4_QUANTIZER_REFORMER_BIAS; if (inverted_lists_meta_[i].vector_count > 0 && (!p.has(scale_key) || !p.has(bias_key))) { LOG_ERROR("Miss reformer params %s or %s", bias_key.c_str(), scale_key.c_str()); return IndexError_Logic; } params[i].bias = p.get_as_float(bias_key); params[i].scale = p.get_as_float(scale_key); } return this->dump_segment( int8_quantizer ? IVF_INT8_QUANTIZED_PARAMS_SEG_ID : IVF_INT4_QUANTIZED_PARAMS_SEG_ID, params.data(), params.size() * sizeof(InvertedIntegerQuantizerParams)); } int IVFDumper::dump_original_vector(const void *data, size_t size) { if (dumped_feature_count_ >= header_.total_vector_count) { LOG_ERROR("Dump too much orignal features, expect=%u", header_.total_vector_count); return IndexError_Logic; } if (dumper_->write(data, size) != size) { LOG_ERROR("Dumper write features failed"); return IndexError_WriteData; } dumped_features_size_ += size; ++dumped_feature_count_; if (dumped_feature_count_ == header_.total_vector_count) { //! Dump features finished, dump the meta size_t padding_size = 0; int ret = this->dump_padding(size, &padding_size); ivf_check_error_code(ret); ret = dumper_->append(IVF_FEATURES_SEG_ID, dumped_features_size_, padding_size, 0); if (ret != 0) { LOG_ERROR("Dumper append segment %s failed, ret:%d", IVF_FEATURES_SEG_ID.c_str(), ret); return ret; } dumped_size_ += dumped_features_size_; } return 0; } int IVFDumper::check_dump_inverted_list(uint32_t inverted_list_id) { if (inverted_list_id < cur_list_id_) { LOG_ERROR("Invalid backward vector dumping, want=%u cur=%u", inverted_list_id, cur_list_id_); return IndexError_Logic; } if (inverted_list_id >= inverted_lists_meta_.size()) { LOG_ERROR("Invalid inverted_list_id=%u, lists_size=%zu", inverted_list_id, inverted_lists_meta_.size()); return IndexError_Logic; } if (inverted_list_id != cur_list_id_) { //! 
flush previous inverted_list block int ret = this->dump_block(); ivf_check_error_code(ret); for (auto idx = cur_list_id_ + 1; idx <= inverted_list_id; ++idx) { inverted_lists_meta_[idx].offset = header_.inverted_body_size; inverted_lists_meta_[idx].id_offset = header_.total_vector_count; } cur_list_id_ = inverted_list_id; } return 0; } int IVFDumper::dump_offsets_segment(void) const { bool col_pri = meta_.major_order() == IndexMeta::MajorOrder::MO_COLUMN; size_t total_size = 0; for (size_t i = 0; i < inverted_lists_meta_.size(); ++i) { std::vector offsets; const auto &m = inverted_lists_meta_[i]; size_t vec_cnt = m.vector_count; size_t idx = 0; uint64_t off = m.offset; size_t align_idx = vec_cnt - vec_cnt % block_vector_count_; for (size_t j = 0; j < vec_cnt; ++j) { if (col_pri && j < align_idx) { offsets.emplace_back(off + idx * block_.align_size(), true); } else { offsets.emplace_back(off + idx * block_.element_size(), false); } ++idx; if (idx == block_vector_count_) { off += header_.block_size; idx = 0; } } if (idx != 0) { off += (vec_cnt - align_idx) * meta_.element_size(); } size_t len = offsets.size() * sizeof(offsets[0]); size_t actual_len = dumper_->write(offsets.data(), len); if (actual_len != len) { LOG_ERROR("Write offsets failed expect %zu, actual: %zu.", len, actual_len); return IndexError_WriteData; } total_size += len; } size_t padding_size = 0; int ret = this->dump_padding(total_size, &padding_size); ivf_check_error_code(ret); ret = dumper_->append(IVF_OFFSETS_SEG_ID, total_size, padding_size, 0); if (ret != 0) { LOG_ERROR("Dumper append segment %s failed, ret:%d", IVF_OFFSETS_SEG_ID.c_str(), ret); return ret; } dumped_size_ += total_size + padding_size; return 0; } int IVFDumper::dump_segment(const std::string &segment_id, const void *data, size_t size) const { size_t len = dumper_->write(data, size); if (len != size) { LOG_ERROR("Dump segment %s data failed, expect=%zu, actual=%zu", segment_id.c_str(), size, len); return IndexError_WriteData; } 
size_t padding_size = 0; int ret = this->dump_padding(size, &padding_size); ivf_check_error_code(ret); uint32_t crc = ailego::Crc32c::Hash(data, size); ret = dumper_->append(segment_id, size, padding_size, crc); if (ret != 0) { LOG_ERROR("Dump segment %s meta failed, ret=%d", segment_id.c_str(), ret); return ret; } dumped_size_ += size + padding_size; return 0; } int IVFDumper::dump_padding(size_t data_size, size_t *padding_size) const { *padding_size = IVFUtility::AlignedSize(data_size) - data_size; if (*padding_size == 0) { return 0; } std::string padding(*padding_size, '\0'); if (dumper_->write(padding.data(), *padding_size) != *padding_size) { LOG_ERROR("Append padding failed, size %lu", *padding_size); return IndexError_WriteData; } return 0; } int IVFDumper::dump_block(void) { if (block_.empty()) { return 0; } size_t size = ailego_align(block_.bytes(), 32); if (dumper_->write(block_.data(), size) != size) { LOG_ERROR("Failed to write data into dumper %s", dumper_->name().c_str()); return IndexError_WriteData; } auto &keys = block_.keys(); std::copy(keys.begin(), keys.end(), std::back_inserter(keys_)); ++inverted_lists_meta_[cur_list_id_].block_count; ++header_.block_count; header_.inverted_body_size += size; block_.clear(); return 0; } } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/ivf/ivf_dumper.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include
#include "metric/metric_params.h"
#include "ivf_index_format.h"
#include "ivf_params.h"
#include "ivf_utility.h"

namespace zvec {
namespace core {

/*! Quantized Clustering Dumper
 *
 * Writes an IVF index through an IndexDumper. Inverted-list vectors are
 * staged in fixed-size blocks and written as the inverted body segment;
 * the remaining segments (inverted-list meta, keys, key mapping, per-vector
 * offsets, centroid index, optional per-list integer quantizer params and
 * optional original features) are written by the dump_* methods below.
 * Inverted lists have to be dumped in non-decreasing list id order
 * (check_dump_inverted_list rejects going backwards).
 */
class IVFDumper {
 public:
  typedef std::shared_ptr Pointer;

  //! Vectors block
  //!
  //! Staging buffer that collects up to `max_vec_count` vectors of the
  //! current inverted list together with their keys. The buffer is sized
  //! with IVFUtility::AlignedSize(max_vec_count, element_size). For a
  //! column-major index (MO_COLUMN) the buffer is transposed as soon as it
  //! becomes full.
  class Block {
   public:
    //! Initialize block
    //!
    //! Derives the vector layout from `meta`: units_ is the number of
    //! AlignSizeof(data_type)-byte scalars in one element. Resets the count.
    void init(const IndexMeta &meta, uint32_t max_vec_count) {
      element_size_ = meta.element_size();
      auto bsize = IVFUtility::AlignedSize(max_vec_count, element_size_);
      data_.resize(bsize);
      count_ = 0u;
      major_order_ = meta.major_order();
      align_size_ = IndexMeta::AlignSizeof(meta.data_type());
      units_ = element_size_ / align_size_;
      max_vec_count_ = max_vec_count;
      keys_.reserve(max_vec_count_);
    }

    //! Add a vector to the block in row major order
    //! If the block is full and the block order is column, make a
    //! transpose
    //!
    //! The vector is copied as units_ scalars of align_size_ bytes; only
    //! 2, 4 and 8 byte scalars are handled, anything else trips
    //! ailego_check_with. `order` describes the layout of `vec`: MO_ROW
    //! reads consecutive scalars, otherwise every max_vec_count_-th scalar
    //! is read (presumably one column of a column-major source block --
    //! confirm against callers). The key is appended after the copy.
    void emplace(uint64_t key, const void *vec, IndexMeta::MajorOrder order) {
      switch (align_size_) {
        case 2:
          do_emplace(vec, order);
          break;
        case 4:
          do_emplace(vec, order);
          break;
        case 8:
          do_emplace(vec, order);
          break;
        default:
          ailego_check_with(false, "Unsupport Aligned Size");
      }
      keys_.emplace_back(key);
    }

    //! True when the block holds max_vec_count_ vectors
    bool full(void) const {
      return count_ == max_vec_count_;
    }

    //! Raw block buffer (transposed if the block is column-major and full)
    const void *data(void) const {
      return data_.data();
    }

    //! Drop all vectors and keys; the buffer keeps its allocated size
    void clear(void) {
      count_ = 0u;
      keys_.clear();
    }

    //! True when no vector has been added since init()/clear()
    bool empty(void) const {
      return count_ == 0u;
    }

    //! Number of vectors currently in the block
    size_t size(void) const {
      return count_;
    }

    //! Maximum number of vectors the block can hold
    size_t capacity(void) const {
      return max_vec_count_;
    }

    //! Size in bytes of one scalar unit of a vector
    size_t align_size(void) const {
      return align_size_;
    }

    //! Size in bytes of one vector
    size_t element_size(void) const {
      return element_size_;
    }

    //! Retrieve block data size (bytes occupied by the current vectors)
    size_t bytes(void) const {
      return element_size_ * count_;
    }

    //! Retrieve max block size in bytes (the allocated buffer size)
    size_t block_size(void) const {
      return data_.size();
    }

    //! Major order taken from the index meta in init()
    IndexMeta::MajorOrder major_order(void) const {
      return major_order_;
    }

    //! Keys of the vectors currently in the block, in insertion order
    const std::vector &keys(void) const {
      return keys_;
    }

    //! True if the block's major order equals `column_major`
    bool match_order(IndexMeta::MajorOrder column_major) const {
      return major_order_ == column_major;
    }

   private:
    //! Transpose the block vectors
    //!
    //! Builds the transposed layout of the count_ x units_ scalar matrix
    //! in a fresh buffer via IVFUtility::Transpose and swaps it in.
    void transpose() {
      std::vector buf(data_.size());
      IVFUtility::Transpose(align_size_, data_.data(), count_, units_,
                            buf.data());
      data_.swap(buf);
    }

    //! Copy one vector as units_ scalars of type T into slot count_.
    //! The source is read with stride 1 (MO_ROW) or max_vec_count_.
    //! Filling a column-major block triggers transpose().
    template
    void do_emplace(const void *vec, IndexMeta::MajorOrder order) {
      ailego_assert_with(count_ < max_vec_count_, "emplace a full block");
      T *dst = reinterpret_cast(data_.data() + element_size_ * count_);
      const T *src = reinterpret_cast(vec);
      size_t step = order == IndexMeta::MO_ROW ? 1 : max_vec_count_;
      for (auto i = 0u; i < units_; ++i) {
        *dst = *src;
        dst++;
        src += step;
      }
      count_++;
      if (full() && major_order_ == IndexMeta::MO_COLUMN) {
        transpose();
      }
    }

   private:
    //! Members
    std::vector data_{};      // block buffer
    std::vector keys_{};      // keys of the staged vectors
    uint32_t count_{0u};          // vectors currently staged
    uint32_t units_{0u};          // scalars per vector
    uint32_t align_size_{0u};     // bytes per scalar
    uint32_t element_size_{0u};   // bytes per vector
    uint32_t max_vec_count_{0u};  // block capacity in vectors
    IndexMeta::MajorOrder major_order_{};
  };

  //! Constructor
  //!
  //! Prepares `inverted_list_count` inverted-list meta entries and a staging
  //! block of `block_vector_count` vectors; the block vector count is also
  //! recorded in the index header.
  IVFDumper(const IndexMeta &meta, const IndexDumper::Pointer &dumper,
            size_t inverted_list_count, size_t block_vector_count)
      : meta_(meta),
        dumper_(dumper),
        block_vector_count_(block_vector_count),
        inverted_lists_meta_(inverted_list_count) {
    block_.init(meta, block_vector_count_);
    header_.block_vector_count = block_vector_count_;
  }

  //! Constructor (uses kDefaultBlockCount vectors per block)
  IVFDumper(const IndexMeta &meta, const IndexDumper::Pointer &dumper,
            size_t inverted_list_count)
      : IVFDumper(meta, dumper, inverted_list_count, kDefaultBlockCount) {}

  //! Destructor
  //!
  //! Logs and asserts when the original features were dumped only
  //! partially (some, but not header_.total_vector_count, of them).
  ~IVFDumper() {
    // Check the dumper status
    if (dumped_feature_count_ > 0 &&
        dumped_feature_count_ != header_.total_vector_count) {
      LOG_ERROR("Dumped features=%u mismatch from invertedVecs=%u",
                dumped_feature_count_, header_.total_vector_count);
      ailego_assert_with(false, "invalid status");
    }
  }

  //! Dump a vector in row major order
  int dump_inverted_vector(uint32_t inverted_list_id, uint64_t key,
                           const void *vec);

  //! Dump `vector_count` vectors and their keys into inverted list
  //! `inverted_list_id`; `column_major` describes the layout of `vecs`
  int dump_inverted_block(uint32_t inverted_list_id, const uint64_t *keys,
                          const void *vecs, uint32_t vector_count,
                          bool column_major);

  //! Finish dump the inverted vectors (flushes the pending block and writes
  //! the body, header, list meta, keys, mapping and offsets segments)
  int dump_inverted_vector_finished(void);

  //! Dump the centroids index
  int dump_centroid_index(const void *data, size_t size);

  //! Dump params for each inverted list quantizer
  int dump_quantizer_params(
      const std::vector &quantizers);

  //! Dump the original vector, which doesnot been quantized; the features
  //! segment is appended once header_.total_vector_count vectors are written
  int dump_original_vector(const void *data, size_t size);

  //! Retrieve total dumped size
  size_t dumped_size(void) const {
    return dumped_size_;
  }

  //! Retrieve total dumped vector count
  size_t dumped_count(void) const {
    return header_.total_vector_count;
  }

  //! Dump the segment from container (copied in fixed-size batches)
  int dump_container_segment(const IndexStorage::Pointer &container,
                             const std::string &segmemt_id);

 private:
  //! Validate the target list id and, when switching lists, flush the
  //! pending block and initialize the meta of the lists skipped over
  int check_dump_inverted_list(uint32_t inverted_list_id);

  //! Dump offsets segment
  int dump_offsets_segment(void) const;

  //! Dump a segment
  int dump_segment(const std::string &segment_id, const void *data,
                   size_t size) const;

  //! Dump segment padding
  int dump_padding(size_t data_size, size_t *padding_size) const;

  //! Dump a vector block
  int dump_block(void);

 private:
  //! Constants
  static constexpr size_t kDefaultBlockCount = 32u;

  //! Members
  Block block_{};           // vectors grouped in block
  const IndexMeta meta_{};  // IndexMeta of the inverted index
  const IndexDumper::Pointer dumper_{};
  size_t block_vector_count_{kDefaultBlockCount};
  std::vector inverted_lists_meta_{};  // one entry per inverted list
  std::vector keys_{};                 // keys in dump order
  InvertedIndexHeader header_{};
  uint32_t cur_list_id_{0};             // list currently being dumped
  uint32_t dumped_feature_count_{0};    // original features written
  size_t dumped_features_size_{0};      // bytes of original features
  mutable size_t dumped_size_{0};       // total bytes written
};

}  // namespace core
}  // namespace zvec
================================================
FILE: src/core/algorithm/ivf/ivf_entity.cc
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ivf_entity.h"
#include
#include "ivf_utility.h"

namespace zvec {
namespace core {

//!
Initialize int IVFEntity::IVFReformerWrapper::init(const IndexMeta &imeta) { auto &name = imeta.reformer_name(); if (name.empty()) { type_ = kReformerTpNone; return 0; } auto reformer = IndexFactory::CreateReformer(name); if (!reformer) { LOG_ERROR("Failed to create reformer %s", name.c_str()); return IndexError_NoExist; } int ret = reformer->init(imeta.reformer_params()); ivf_check_with_msg(ret, "Failed to init reformer %s", name.c_str()); reformer_ = std::move(reformer); if (name == kInt8ReformerName) { if (imeta.metric_name() == kIPMetricName) { type_ = kReformerTpInnerProductInt8; return 0; } auto &key = INT8_QUANTIZER_REFORMER_SCALE; if (!imeta.reformer_params().has(key)) { LOG_ERROR("Missing param %s in reformer %s", key.c_str(), name.c_str()); return IndexError_InvalidArgument; }; float scale = imeta.reformer_params().get_as_float(key); reciprocal_ = scale == 0.0 ? 1.0 : (1.0 / scale); type_ = kReformerTpInt8; } else if (name == kInt4ReformerName) { if (imeta.metric_name() == kIPMetricName) { type_ = kReformerTpInnerProductInt4; return 0; } auto &key = INT4_QUANTIZER_REFORMER_SCALE; if (!imeta.reformer_params().has(key)) { LOG_ERROR("Missing param %s in reformer %s", key.c_str(), name.c_str()); return IndexError_InvalidArgument; }; float scale = imeta.reformer_params().get_as_float(key); reciprocal_ = scale == 0.0 ? 1.0 : (1.0 / scale); type_ = kReformerTpInt4; } else { type_ = kReformerTpDefault; } LOG_DEBUG("Init QcReformer with %s, type=%u", name.c_str(), type_); return 0; } //! 
Update the params, Called by gpu searcher only int IVFEntity::IVFReformerWrapper::update(const IndexMeta &meta) { auto &name = meta.reformer_name(); if (name == kInt4ReformerName && meta.metric_name() == kL2MetricName) { auto &key = INT4_QUANTIZER_REFORMER_SCALE; if (!meta.reformer_params().has(key)) { LOG_ERROR("Missing param %s in reformer %s", key.c_str(), name.c_str()); return IndexError_InvalidArgument; }; float scale = meta.reformer_params().get_as_float(key); reciprocal_ = scale == 0.0 ? 1.0 : (1.0 / scale / kNormalizeScaleFactor); type_ = kReformerTpInt8; ailego::Params params; float int8_scale = scale * kNormalizeScaleFactor; params.set(INT8_QUANTIZER_REFORMER_SCALE, int8_scale); float bias = meta.reformer_params().get_as_float(INT4_QUANTIZER_REFORMER_BIAS); params.set(INT8_QUANTIZER_REFORMER_BIAS, bias); params.set( INT4_QUANTIZER_REFORMER_METRIC, meta.reformer_params().get_as_string(INT4_QUANTIZER_REFORMER_METRIC)); auto reformer = IndexFactory::CreateReformer(kInt8ReformerName); if (!reformer) { LOG_ERROR("Failed to create reformer %s", name.c_str()); return IndexError_NoExist; } int ret = reformer->init(params); ivf_check_with_msg(ret, "Failed to init reformer %s", name.c_str()); reformer_ = reformer; LOG_DEBUG("Init QcReformer with %s, type=%u", name.c_str(), type_); } return 0; } //! 
Transform a query int IVFEntity::IVFReformerWrapper::transform(const void *query, const IndexQueryMeta &qmeta, const void **out, IndexQueryMeta *ometa) { int ret = 0; switch (type_) { case kReformerTpNone: *out = query; *ometa = qmeta; break; case kReformerTpInnerProductInt8: if (qmeta.data_type() != IndexMeta::DataType::DT_FP32) { return IndexError_Unsupported; } scales_.resize(1); buffer_.resize(IndexMeta::ElementSizeof(IndexMeta::DataType::DT_INT8, qmeta.dimension())); this->transform(0, static_cast(query), qmeta.dimension(), reinterpret_cast(&buffer_[0])); *ometa = qmeta; ometa->set_meta(IndexMeta::DataType::DT_INT8, qmeta.dimension()); *out = buffer_.data(); break; case kReformerTpInnerProductInt4: if (qmeta.data_type() != IndexMeta::DataType::DT_FP32) { return IndexError_Unsupported; } scales_.resize(1); buffer_.resize(IndexMeta::ElementSizeof(IndexMeta::DataType::DT_INT4, qmeta.dimension())); this->transform(0, static_cast(query), qmeta.dimension(), reinterpret_cast(&buffer_[0])); *ometa = qmeta; ometa->set_meta(IndexMeta::DataType::DT_INT4, qmeta.dimension()); *out = buffer_.data(); break; case kReformerTpInt8: case kReformerTpInt4: /* FALLTHRU */ case kReformerTpDefault: ret = reformer_->transform(query, qmeta, &buffer_, ometa); *out = buffer_.data(); break; default: ret = IndexError_Unsupported; break; } return ret; } //! 
Transform querys int IVFEntity::IVFReformerWrapper::transform(const void *query, const IndexQueryMeta &qmeta, uint32_t count, const void **out, IndexQueryMeta *ometa) { int ret = 0; switch (type_) { case kReformerTpNone: *out = query; *ometa = qmeta; break; case kReformerTpInnerProductInt8: if (qmeta.data_type() != IndexMeta::DataType::DT_FP32) { return IndexError_Unsupported; } scales_.resize(count); buffer_.resize(count * IndexMeta::ElementSizeof(IndexMeta::DataType::DT_INT8, qmeta.dimension())); { const float *ivec = reinterpret_cast(query); int8_t *ovec = reinterpret_cast(&buffer_[0]); for (size_t i = 0; i < count; ++i) { this->transform(i, &ivec[i * qmeta.dimension()], qmeta.dimension(), &ovec[i * qmeta.dimension()]); } } *ometa = qmeta; ometa->set_meta(IndexMeta::DataType::DT_INT8, qmeta.dimension()); *out = buffer_.data(); break; case kReformerTpInnerProductInt4: if (qmeta.data_type() != IndexMeta::DataType::DT_FP32) { return IndexError_Unsupported; } scales_.resize(count); buffer_.resize(count * IndexMeta::ElementSizeof(IndexMeta::DataType::DT_INT4, qmeta.dimension())); { const float *ivec = reinterpret_cast(query); uint8_t *ovec = reinterpret_cast(&buffer_[0]); for (size_t i = 0; i < count; ++i) { this->transform(i, &ivec[i * qmeta.dimension()], qmeta.dimension(), &ovec[i * qmeta.dimension() / 2]); } } *ometa = qmeta; ometa->set_meta(IndexMeta::DataType::DT_INT4, qmeta.dimension()); *out = buffer_.data(); break; case kReformerTpInt8: case kReformerTpInt4: /* FALLTHRU */ case kReformerTpDefault: ret = reformer_->transform(query, qmeta, count, &buffer_, ometa); *out = buffer_.data(); break; default: ret = IndexError_Unsupported; break; } return ret; } //! 
Transform querys int IVFEntity::IVFReformerWrapper::transform_gpu(const void *query, const IndexQueryMeta &qmeta, uint32_t count, const void **out, IndexQueryMeta *ometa) { int ret = 0; switch (type_) { case kReformerTpNone: case kReformerTpDefault: *out = query; *ometa = qmeta; break; case kReformerTpInnerProductInt4: case kReformerTpInnerProductInt8: if (qmeta.data_type() != IndexMeta::DataType::DT_FP32) { return IndexError_Unsupported; } scales_.resize(count); buffer_.resize(count * IndexMeta::ElementSizeof(IndexMeta::DataType::DT_INT8, qmeta.dimension())); { const float *ivec = reinterpret_cast(query); int8_t *ovec = reinterpret_cast(&buffer_[0]); for (size_t i = 0; i < count; ++i) { this->transform(i, &ivec[i * qmeta.dimension()], qmeta.dimension(), &ovec[i * qmeta.dimension()]); } } *ometa = qmeta; ometa->set_meta(IndexMeta::DataType::DT_INT8, qmeta.dimension()); *out = buffer_.data(); break; case kReformerTpInt8: case kReformerTpInt4: ret = reformer_->transform(query, qmeta, count, &buffer_, ometa); *out = buffer_.data(); break; default: ret = IndexError_Unsupported; break; } return ret; } //! Convert a record int IVFEntity::IVFReformerWrapper::convert(const void *record, const IndexQueryMeta &rmeta, const void **out, IndexQueryMeta *ometa) { if (type_ == kReformerTpNone) { *out = record; *ometa = rmeta; return 0; } int ret = reformer_->convert(record, rmeta, &buffer_, ometa); *out = buffer_.data(); return ret; } //! Convert records int IVFEntity::IVFReformerWrapper::convert(const void *records, const IndexQueryMeta &rmeta, uint32_t count, const void **out, IndexQueryMeta *ometa) { if (type_ == kReformerTpNone) { *out = records; *ometa = rmeta; return 0; } int ret = reformer_->convert(records, rmeta, count, &buffer_, ometa); *out = buffer_.data(); return ret; } //! 
Normalize score void IVFEntity::IVFReformerWrapper::normalize(size_t qidx, IndexDocumentHeap *heap) const { switch (type_) { case kReformerTpNone: return; case kReformerTpInnerProductInt8: case kReformerTpInnerProductInt4: ailego_assert_with(qidx < scales_.size(), "invalid index"); { auto reciprocal = 1.0f / scales_[qidx]; for (auto &it : *heap) { *it.mutable_score() *= reciprocal; } } break; case kReformerTpInt8: case kReformerTpInt4: for (auto &it : *heap) { *it.mutable_score() *= reciprocal_; } break; default: // Not support break; } } //! Normalize score void IVFEntity::IVFReformerWrapper::normalize(size_t qidx, const void *query, const IndexQueryMeta &qmeta, IndexDocumentHeap *heap) const { switch (type_) { case kReformerTpNone: return; case kReformerTpInnerProductInt8: case kReformerTpInnerProductInt4: ailego_assert_with(qidx < scales_.size(), "invalid index"); { auto reciprocal = 1.0f / scales_[qidx]; for (auto &it : *heap) { *it.mutable_score() *= reciprocal; } } break; case kReformerTpInt8: case kReformerTpInt4: for (auto &it : *heap) { *it.mutable_score() *= reciprocal_; } break; case kReformerTpDefault: reformer_->normalize(query, qmeta, *heap); break; default: // Not support LOG_ERROR("Not a supported type in QC reformer, type: %u", type_); break; } } void IVFEntity::IVFReformerWrapper::transform(size_t qidx, const float *in, size_t dim, int8_t *out) { ailego_assert_with(qidx < scales_.size(), "invalid index"); float abs_max = 0.0f; for (size_t i = 0; i < dim; ++i) { auto abs = std::abs(in[i]); if (abs > abs_max) { abs_max = abs; } } if (abs_max > 0.0f) { float scale = 127 / abs_max; for (size_t i = 0; i < dim; ++i) { out[i] = static_cast(std::round(in[i] * scale)); } scales_[qidx] = scale; } else { std::fill(out, out + dim, static_cast(1)); scales_[qidx] = std::numeric_limits::max(); } } void IVFEntity::IVFReformerWrapper::transform(size_t qidx, const float *in, size_t dim, uint8_t *out) { ailego_assert_with(qidx < scales_.size(), "invalid index"); 
ailego_assert_with(dim % 2 == 0, "invalid dim"); float abs_max = 0.0f; float max = -std::numeric_limits::max(); for (size_t i = 0; i < dim; ++i) { float abs = std::abs(in[i]); abs_max = std::max(abs_max, abs); max = std::max(max, in[i]); } if (abs_max > 0.0f) { float scale = ((7 * abs_max > 8 * max) ? 8 : 7) / abs_max; for (size_t i = 0; i < dim; i += 2) { auto v1 = static_cast(std::round(in[i] * scale)); auto v2 = static_cast(std::round(in[i + 1] * scale)); out[i / 2] = (static_cast(v1) & 0xF) | (static_cast(v2) << 4); } scales_[qidx] = scale; } else { std::fill(out, out + dim / 2, static_cast(9)); scales_[qidx] = std::numeric_limits::max(); } } int IVFEntity::load_header(const IndexStorage::Pointer &container) { //! Load the Header Segment auto header = container->get(IVF_INVERTED_HEADER_SEG_ID); if (!header) { LOG_ERROR("Failed to get segment %s", IVF_INVERTED_HEADER_SEG_ID.c_str()); return IndexError_InvalidFormat; } if (header->data_size() < sizeof(header_)) { LOG_ERROR("Invalid format for segment %s", IVF_INVERTED_HEADER_SEG_ID.c_str()); return IndexError_InvalidFormat; } const void *data = nullptr; if (header->read(0, &data, header->data_size()) != header->data_size()) { LOG_ERROR("Failed to read data, segment %s", IVF_INVERTED_HEADER_SEG_ID.c_str()); return IndexError_ReadData; } std::memcpy(&header_, data, sizeof(header_)); if (header_.header_size < sizeof(header_) + header_.index_meta_size || header_.header_size > header->data_size()) { LOG_ERROR("Invalid header size %u", header_.header_size); return IndexError_InvalidFormat; } //! Load the index meta if (!meta_.deserialize( reinterpret_cast(data) + sizeof(header_), header_.index_meta_size)) { LOG_ERROR("Failed to deserialize index meta"); return IndexError_InvalidFormat; } int ret = reformer_.init(meta_); ivf_check_error_code(ret); //! 
Create the distance calculator auto metric = IndexFactory::CreateMetric(meta_.metric_name()); if (!metric) { LOG_ERROR("Failed to create metric %s", meta_.metric_name().c_str()); return IndexError_NoExist; } ret = metric->init(meta_, meta_.metric_params()); if (ret != 0) { LOG_ERROR("Failed to initialize metric %s", meta_.metric_name().c_str()); return ret; } calculator_ = std::make_shared( meta_, metric->query_metric() ? metric->query_metric() : metric, header_.block_vector_count); if (!calculator_) { return IndexError_NoMemory; } return 0; } int IVFEntity::load(const IndexStorage::Pointer &container) { int ret = this->load_header(container); ivf_check_error_code(ret); //! Load the remaining segments container_ = container; size_t expect_size = header_.inverted_body_size; inverted_ = load_segment(IVF_INVERTED_BODY_SEG_ID, expect_size); if (!inverted_) { LOG_ERROR("Failed to load segment, inverted_size=%zu block_count=%u", static_cast(header_.inverted_body_size), header_.block_count); return IndexError_InvalidFormat; } expect_size = header_.inverted_list_count * sizeof(InvertedListMeta); inverted_meta_ = load_segment(IVF_INVERTED_META_SEG_ID, expect_size); if (!inverted_meta_) { LOG_ERROR("Failed to load segment, inverted_lists=%u", header_.inverted_list_count); return IndexError_InvalidFormat; } expect_size = header_.total_vector_count * sizeof(uint64_t); keys_ = load_segment(IVF_KEYS_SEG_ID, expect_size); if (!keys_) { return IndexError_InvalidFormat; } expect_size = header_.total_vector_count * sizeof(InvertedVecLocation); offsets_ = load_segment(IVF_OFFSETS_SEG_ID, expect_size); if (!offsets_) { return IndexError_InvalidFormat; } expect_size = header_.total_vector_count * sizeof(uint32_t); mapping_ = load_segment(IVF_MAPPING_SEG_ID, expect_size); if (!mapping_) { return IndexError_InvalidFormat; } norm_value_sqrt_ = meta_.metric_name() == "Euclidean" || meta_.metric_name() == "Manhattan"; if (container_->get(IVF_INT8_QUANTIZED_PARAMS_SEG_ID) || 
container->get(IVF_INT4_QUANTIZED_PARAMS_SEG_ID)) { expect_size = header_.inverted_list_count * sizeof(InvertedIntegerQuantizerParams); auto &seg_id = meta_.reformer_name() == kInt8ReformerName ? IVF_INT8_QUANTIZED_PARAMS_SEG_ID : IVF_INT4_QUANTIZED_PARAMS_SEG_ID; integer_quantizer_params_ = load_segment(seg_id, expect_size); if (!integer_quantizer_params_) { return IndexError_InvalidFormat; } norm_value_ = 0.0f; } else if (meta_.reformer_name() == kInt8ReformerName || meta_.reformer_name() == kInt4ReformerName) { auto &scale_key = meta_.reformer_name() == kInt8ReformerName ? INT8_QUANTIZER_REFORMER_SCALE : INT4_QUANTIZER_REFORMER_SCALE; auto scale = meta_.reformer_params().get_as_float(scale_key); norm_value_ = this->convert_to_normalize_value(scale); } else { norm_value_ = 1.0f; } if (container_->get(IVF_FEATURES_SEG_ID)) { features_ = load_segment(IVF_FEATURES_SEG_ID, 0); if (!features_) { return IndexError_InvalidFormat; } if (features_->data_size() % vector_count() != 0) { LOG_ERROR("Invalid featureSegment size=%zu, totalVecs=%zu", features_->data_size(), vector_count()); return IndexError_InvalidFormat; } } LOG_DEBUG( "Load inverted index done, docs=%u invertedListCnt=%u " "elementSize=%u metric=%s reformer=%s", header_.total_vector_count, header_.inverted_list_count, meta_.element_size(), meta_.metric_name().c_str(), meta_.reformer_name().c_str()); return 0; } int IVFEntity::search(size_t inverted_list_id, const void *query, const IndexFilter &filter, uint32_t *scan_count, IndexDocumentHeap *heap, IndexContext::Stats *context_stats) const { ailego_assert_with(inverted_list_id < header_.inverted_list_count, "invalid id"); auto list_meta = this->inverted_list_meta(inverted_list_id); ivf_assert(list_meta, IndexError_ReadData); const void *data = nullptr; const size_t block_vecs = header_.block_vector_count; std::vector distances(block_vecs); const size_t batch_size = kBatchBlocks; const size_t block_size = header_.block_size; const auto norm_val = 
this->inverted_list_normalize_value(inverted_list_id); for (size_t i = 0; i < list_meta->block_count; i += batch_size) { //! Read vecs const size_t off = list_meta->offset + i * block_size; const size_t blocks = std::min(batch_size, list_meta->block_count - i); const size_t size = std::min(blocks * block_size, static_cast(header_.inverted_body_size - off)); if (inverted_->read(off, &data, size) != size) { LOG_ERROR("Failed to read block, off=%zu, size=%zu", off, size); return IndexError_ReadData; } //! Read keys size_t items = std::min(blocks * block_vecs, list_meta->vector_count - (i * block_vecs)); auto keys = get_keys(list_meta->id_offset + i * block_vecs, items); if (!keys) { return IndexError_ReadData; } //! Compute distances for each block for (size_t b = 0; b < blocks; ++b) { const size_t vecs_count = std::min(block_vecs, list_meta->vector_count - (i + b) * block_vecs); auto block_keys = keys + b * block_vecs; size_t keeps = 0; ailego_assert_with(block_vecs < sizeof(keeps) * 8, "bits overflow"); for (size_t k = 0; k < vecs_count; ++k) { if (!filter(block_keys[k])) { keeps |= (1 << k); } else { ++(*context_stats->mutable_filtered_count()); } } if (keeps == 0) { continue; } const void *block_data = static_cast(data) + b * block_size; calculator_->query_features_distance(query, block_data, vecs_count, distances.data()); *(context_stats->mutable_dist_calced_count()) += vecs_count; uint32_t id_off = list_meta->id_offset + (i + b) * block_vecs; for (size_t k = 0; k < vecs_count; ++k) { if (keeps & (1 << k)) { if (block_keys[k] != kInvalidKey) { heap->emplace(block_keys[k], distances[k] * norm_val, id_off + k); } } } } } *scan_count = list_meta->vector_count; return 0; } //! 
search in inverted list without filter int IVFEntity::search(size_t inverted_list_id, const void *query, uint32_t *scan_count, IndexDocumentHeap *heap, IndexContext::Stats *context_stats) const { ailego_assert_with(inverted_list_id < header_.inverted_list_count, "invalid id"); auto list_meta = inverted_list_meta(inverted_list_id); ivf_assert(list_meta, IndexError_ReadData); const void *data = nullptr; const size_t block_vecs = header_.block_vector_count; std::vector distances(block_vecs); const size_t batch_size = kBatchBlocks; const size_t block_size = header_.block_size; const auto norm_val = this->inverted_list_normalize_value(inverted_list_id); for (size_t i = 0; i < list_meta->block_count; i += batch_size) { //! Read vecs const size_t off = list_meta->offset + i * block_size; const size_t blocks = std::min(batch_size, list_meta->block_count - i); const size_t size = std::min(blocks * block_size, static_cast(header_.inverted_body_size - off)); if (inverted_->read(off, &data, size) != size) { LOG_ERROR("Failed to read block, off=%zu, size=%zu", off, size); return IndexError_ReadData; } //! Read keys size_t items = std::min(blocks * block_vecs, list_meta->vector_count - (i * block_vecs)); auto keys = get_keys(list_meta->id_offset + i * block_vecs, items); if (!keys) { return IndexError_ReadData; } //! 
Compute distances for each block for (size_t b = 0; b < blocks; ++b) { const size_t vecs_count = std::min(block_vecs, list_meta->vector_count - (i + b) * block_vecs); auto block_keys = keys + b * block_vecs; const void *block_data = static_cast(data) + b * block_size; calculator_->query_features_distance(query, block_data, vecs_count, distances.data()); for (size_t k = 0; k < vecs_count; ++k) { if (block_keys[k] != kInvalidKey) { uint32_t id = list_meta->id_offset + (i + b) * block_vecs + k; heap->emplace(block_keys[k], distances[k] * norm_val, id); } } *(context_stats->mutable_dist_calced_count()) += vecs_count; } } *scan_count = list_meta->vector_count; return 0; } //! search all inverted list with filter int IVFEntity::search(const void *query, const IndexFilter &filter, IndexDocumentHeap *heap, IndexContext::Stats *context_stats) const { for (size_t i = 0; i < header_.inverted_list_count; ++i) { uint32_t scan_count; int ret = this->search(i, query, filter, &scan_count, heap, context_stats); if (ret != 0) { return ret; } } return 0; } //! 
search all inverted list without filter int IVFEntity::search(const void *query, IndexDocumentHeap *heap, IndexContext::Stats *context_stats) const { for (size_t i = 0; i < header_.inverted_list_count; ++i) { uint32_t scan_count; int ret = this->search(i, query, &scan_count, heap, context_stats); if (ret != 0) { return ret; } } return 0; } const void *IVFEntity::get_vector(size_t id) const { if (features_) { const void *data = nullptr; size_t element_size = features_->data_size() / vector_count(); size_t off = id * element_size; if (features_->read(off, &data, element_size) != element_size) { LOG_ERROR("Failed to read segment, off=%zu size=%zu", off, element_size); return nullptr; } return data; } const void *data = nullptr; size_t size = sizeof(InvertedVecLocation); if (offsets_->read(id * size, &data, size) != size) { LOG_ERROR("Failed to read offsets segment, id=%zu", id); return nullptr; } auto &loc = *reinterpret_cast(data); if (loc.column_major) { vector_.resize(meta_.element_size()); auto unit_size = IndexMeta::AlignSizeof(meta_.data_type()); size_t cols = meta_.element_size() / unit_size; size_t step = block_vector_count() * unit_size; size_t rd_size = step * (cols - 1) + unit_size; if (inverted_->read(loc.offset, &data, rd_size) != rd_size) { LOG_ERROR("Failed to read data, off=%zu size=%zu", static_cast(loc.offset), rd_size); return nullptr; } for (size_t c = 0; c < cols; ++c) { vector_.replace(c * unit_size, unit_size, reinterpret_cast(data) + c * step, unit_size); } return vector_.data(); } else { if (inverted_->read(loc.offset, &data, meta_.element_size()) != meta_.element_size()) { LOG_ERROR("Failed to read data, off=%zu size=%u", static_cast(loc.offset), meta_.element_size()); return nullptr; } return data; } } int IVFEntity::get_vector(size_t id, IndexStorage::MemoryBlock &block) const { if (features_) { size_t element_size = features_->data_size() / vector_count(); size_t off = id * element_size; if (features_->read(off, block, element_size) != 
element_size) { LOG_ERROR("Failed to read segment, off=%zu size=%zu", off, element_size); return IndexError_Runtime; } return 0; } IndexStorage::MemoryBlock data_block; size_t size = sizeof(InvertedVecLocation); if (offsets_->read(id * size, data_block, size) != size) { LOG_ERROR("Failed to read offsets segment, id=%zu", id); return IndexError_Runtime; } const void *data = data_block.data(); auto &loc = *reinterpret_cast(data); if (loc.column_major) { vector_.resize(meta_.element_size()); auto unit_size = IndexMeta::AlignSizeof(meta_.data_type()); size_t cols = meta_.element_size() / unit_size; size_t step = block_vector_count() * unit_size; size_t rd_size = step * (cols - 1) + unit_size; if (inverted_->read(loc.offset, &data, rd_size) != rd_size) { LOG_ERROR("Failed to read data, off=%zu size=%zu", static_cast(loc.offset), rd_size); return IndexError_Runtime; } for (size_t c = 0; c < cols; ++c) { vector_.replace(c * unit_size, unit_size, reinterpret_cast(data) + c * step, unit_size); } block.reset(vector_.data()); return 0; } else { if (inverted_->read(loc.offset, block, meta_.element_size()) != meta_.element_size()) { LOG_ERROR("Failed to read data, off=%zu size=%u", static_cast(loc.offset), meta_.element_size()); return IndexError_Runtime; } return 0; } } uint32_t IVFEntity::key_to_id(uint64_t key) const { //! 
Do binary search uint32_t start = 0UL; uint32_t end = vector_count(); const void *data = nullptr; uint32_t idx = 0u; while (start < end) { idx = start + (end - start) / 2; if (ailego_unlikely(mapping_->read(idx * sizeof(uint32_t), &data, sizeof(uint32_t)) != sizeof(uint32_t))) { LOG_ERROR("Failed to read mapping segment, idx=%u", idx); return std::numeric_limits::max(); } const uint64_t *mkey; uint32_t local_id = *reinterpret_cast(data); if (ailego_unlikely(keys_->read(local_id * sizeof(uint64_t), (const void **)(&mkey), sizeof(uint64_t)) != sizeof(uint64_t))) { LOG_ERROR("Read key from segment failed"); return std::numeric_limits::max(); } if (*mkey < key) { start = idx + 1; } else if (*mkey > key) { end = idx; } else { return local_id; } } return std::numeric_limits::max(); } const void *IVFEntity::get_vector_by_key(uint64_t key) const { uint32_t id = this->key_to_id(key); if (id != std::numeric_limits::max()) { return get_vector(id); } else { return nullptr; } } int IVFEntity::get_vector_by_key(uint64_t key, IndexStorage::MemoryBlock &block) const { uint32_t id = this->key_to_id(key); if (id != std::numeric_limits::max()) { return get_vector(id, block); } else { return IndexError_Runtime; } } IVFEntity::Pointer IVFEntity::clone(void) const { auto entity = std::make_shared(); return clone(entity); } IVFEntity::Pointer IVFEntity::clone(const IVFEntity::Pointer &entity) const { if (!entity) { LOG_ERROR("Failed to alloc IVFEntity"); return nullptr; } auto inverted = inverted_->clone(); ivf_assert_with_msg(inverted, nullptr, "Failed to clone inverted segment"); auto inverted_meta = inverted_meta_->clone(); ivf_assert_with_msg(inverted_meta, nullptr, "Failed to clone inverted meta segment"); auto keys = keys_->clone(); ivf_assert_with_msg(keys, nullptr, "Failed to clone keys segment"); auto offsets = offsets_->clone(); ivf_assert_with_msg(offsets, nullptr, "Failed to clone offsets segment"); auto mapping = mapping_->clone(); ivf_assert_with_msg(mapping, nullptr, 
"Failed to clone mapping segment"); IndexStorage::Segment::Pointer integer_quantizer_params; if (integer_quantizer_params_) { integer_quantizer_params = integer_quantizer_params_->clone(); if (!integer_quantizer_params) { LOG_ERROR("Failed to clone integer quantizer params segment"); return nullptr; } } IndexStorage::Segment::Pointer features; if (features_) { features = features_->clone(); if (!features) { LOG_ERROR("Failed to clone features segment"); return nullptr; } } entity->meta_ = this->meta_; entity->reformer_ = this->reformer_; entity->calculator_ = this->calculator_; entity->header_ = this->header_; entity->container_ = this->container_; entity->inverted_ = inverted; entity->inverted_meta_ = inverted_meta; entity->keys_ = keys; entity->offsets_ = offsets; entity->mapping_ = mapping; entity->integer_quantizer_params_ = integer_quantizer_params; entity->features_ = features; entity->norm_value_ = this->norm_value_; entity->norm_value_sqrt_ = this->norm_value_sqrt_; return entity; } IndexStorage::Segment::Pointer IVFEntity::load_segment( const std::string &seg_id, size_t expect_size) const { auto segment = container_->get(seg_id); if (!segment) { LOG_ERROR("Failed to get segment %s", seg_id.c_str()); return nullptr; } if (expect_size && segment->data_size() != expect_size) { LOG_ERROR("Invalid segment %s size=%zu, total_vecs=%u", seg_id.c_str(), segment->data_size(), header_.total_vector_count); return nullptr; } return segment; } } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/ivf/ivf_entity.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include
#include "metric/metric_params.h"
#include "ivf_distance_calculator.h"
#include "ivf_index_format.h"
#include "ivf_params.h"

namespace zvec {
namespace core {

/*! IVF Entity
 * Read-side view of an IVF inverted index. It holds the index meta, the
 * header and the storage segments (inverted body, inverted list meta, keys,
 * offsets, key->id mapping, optional original features and optional integer
 * quantizer params), and scans one or all inverted lists for a query.
 */
class IVFEntity {
 public:
  typedef std::shared_ptr Pointer;
  class IVFReformerWrapper;

  //! Constructor
  IVFEntity() {}

  //! Destructor
  virtual ~IVFEntity() {}

  //! Disable them (copying is deleted; use clone() to duplicate an entity)
  IVFEntity(const IVFEntity &) = delete;
  IVFEntity &operator=(const IVFEntity &) = delete;

  //! load the index from container
  virtual int load(const IndexStorage::Pointer &container);

  //! search in inverted list with filter
  //! Keys for which `filter(key)` returns true are skipped and counted in
  //! the context's filtered count; others are scored and pushed into `heap`.
  int search(size_t inverted_list_id, const void *query,
             const IndexFilter &filter, uint32_t *scan_count,
             IndexDocumentHeap *heap,
             IndexContext::Stats *context_stats) const;

  //! search in inverted list without filter
  //! Every entry of the list whose key is not kInvalidKey is scored and pushed
  //! into `heap`; `*scan_count` receives the list's vector count.
  int search(size_t inverted_list_id, const void *query, uint32_t *scan_count,
             IndexDocumentHeap *heap,
             IndexContext::Stats *context_stats) const;

  //! search all inverted list with filter
  int search(const void *query, const IndexFilter &filter,
             IndexDocumentHeap *heap,
             IndexContext::Stats *context_stats) const;

  //! search all inverted list without filter
  int search(const void *query, IndexDocumentHeap *heap,
             IndexContext::Stats *context_stats) const;

  //! Clone the entity (into a newly allocated IVFEntity)
  virtual IVFEntity::Pointer clone(void) const;

  //! Clone the entity into `entity`: segments are cloned, the remaining state
  //! is copied; returns `entity`, or nullptr on failure
  IVFEntity::Pointer clone(const IVFEntity::Pointer &entity) const;

  //! Retrieve the primary keys by local id in heap
  //! Overwrites each entry's key with the key stored for its local id
  //! (it.index()); returns IndexError_ReadData if any lookup yields
  //! kInvalidKey.
  int retrieve_keys(IndexDocumentHeap *heap) const {
    for (auto &it : (*heap)) {
      uint64_t key = this->get_key(it.index());
      if (key == kInvalidKey) {
        return IndexError_ReadData;
      }
      it.set_key(key);
    }
    return 0;
  }

  //! Retrieve the total vectors in the index
  size_t vector_count(void) const {
    return header_.total_vector_count;
  }

  //! Retrieve the inverted list count
  size_t inverted_list_count(void) const {
    return header_.inverted_list_count;
  }

  //! Retrieve block size of the inverted vector (in bytes)
  size_t inverted_block_size(void) const {
    return header_.block_size;
  }

  //! Retrieve the vectors count in one block
  size_t block_vector_count(void) const {
    return header_.block_vector_count;
  }

  //! Retrieve IndexMeta of the inverted index
  const IndexMeta &meta(void) const {
    return meta_;
  }

  //! Retrieve a block of vectors
  //! Returns a pointer to block `local_block_id` of list `inverted_list_id`
  //! and stores the number of vectors in it (the last block may be partial)
  //! in `*vecs_count`. Only `*vecs_count * element_size()` bytes are read.
  //! Returns nullptr (after logging) on an invalid block or a read failure.
  const void *read_block(size_t inverted_list_id, size_t local_block_id,
                         size_t *vecs_count) const {
    auto iv_meta = this->inverted_list_meta(inverted_list_id);
    if (!iv_meta || local_block_id >= iv_meta->block_count) {
      LOG_ERROR("Failed to read inverted list, listId=%zu blockIdx=%zu",
                inverted_list_id, local_block_id);
      return nullptr;
    }
    size_t block_vecs = header_.block_vector_count;
    *vecs_count = std::min(block_vecs,
                           iv_meta->vector_count - local_block_id * block_vecs);
    ailego_assert_with(*vecs_count <= header_.block_vector_count,
                       "invalid vecs");
    const size_t off = iv_meta->offset + local_block_id * header_.block_size;
    const size_t size = *vecs_count * meta_.element_size();
    const void *data = nullptr;
    if (inverted_->read(off, &data, size) != size) {
      LOG_ERROR("Failed to read block off=%zu size=%zu", off, size);
      return nullptr;
    }
    return data;
  }

  //! Retrieve the inverted list meta
  //! Reads record `inverted_list_id` (at id * sizeof(InvertedListMeta)) of the
  //! inverted meta segment; nullptr on read failure. No bounds check is done
  //! here.
  const InvertedListMeta *inverted_list_meta(size_t inverted_list_id) const {
    const void *data = nullptr;
    const size_t size = sizeof(InvertedListMeta);
    const size_t offset = inverted_list_id * size;
    if (inverted_meta_->read(offset, &data, size) != size) {
      LOG_ERROR("Failed to read inverted meta, id=%zu, size=%zu",
                inverted_list_id, size);
      return nullptr;
    }
    return static_cast(data);
  }

  //! Retrieve the keys by consecutive local ids
  //! Returns a pointer to `count` keys starting at local id `id`, or nullptr
  //! on read failure.
  const uint64_t *get_keys(size_t id, size_t count) const {
    const void *data = nullptr;
    const size_t offset = id * sizeof(uint64_t);
    const size_t size = count * sizeof(uint64_t);
    if (keys_->read(offset, &data, size) != size) {
      LOG_ERROR("Failed to read keys, id=%zu, size=%zu", id, size);
      return nullptr;
    }
    return static_cast(data);
  }

  //! Retrieve the key by local id (kInvalidKey on read failure)
  uint64_t get_key(size_t id) const {
    const void *data = nullptr;
    const size_t offset = id * sizeof(uint64_t);
    const size_t size = sizeof(uint64_t);
    if (keys_->read(offset, &data, size) != size) {
      LOG_ERROR("Failed to read key, id=%zu", id);
      return kInvalidKey;
    }
    return *static_cast(data);
  }

  //! Retrieve vector by local id
  const void *get_vector(size_t id) const;

  //! Retrieve vector by primary key
  const void *get_vector_by_key(uint64_t key) const;

  //! Retrieve vector by local id into `block`
  int get_vector(size_t id, IndexStorage::MemoryBlock &block) const;

  //! Retrieve vector by primary key into `block`
  int get_vector_by_key(uint64_t key, IndexStorage::MemoryBlock &block) const;

  //! Map a primary key to its local id
  uint32_t key_to_id(uint64_t key) const;

  //! Transform a query
  //! NOTE(review): forwards to the non-const wrapper through the `mutable`
  //! reformer_, so a const entity still mutates state here; concurrent use of
  //! one entity is presumably unsafe — confirm.
  int transform(const void *query, const IndexQueryMeta &qmeta,
                const void **out, IndexQueryMeta *ometa) const {
    return reformer_.transform(query, qmeta, out, ometa);
  }

  //! Transform queries (batch of `count`)
  int transform(const void *query, const IndexQueryMeta &qmeta, uint32_t count,
                const void **out, IndexQueryMeta *ometa) const {
    return reformer_.transform(query, qmeta, count, out, ometa);
  }

  //! Normalize the score in query part
  void normalize(size_t qidx, IndexDocumentHeap *heap) const {
    return reformer_.normalize(qidx, heap);
  }

  //! Retrieve the value for each inverted list to multiply for normalizing
  //! A non-zero norm_value_ applies to every list. Otherwise, when quantizer
  //! params are present, the list's stored scale is converted with
  //! convert_to_normalize_value (1.0f if that read fails).
  //! NOTE(review): when norm_value_ is 0 and there are no quantizer params
  //! this returns 0.0f, which would zero every score in search — confirm that
  //! load() always sets one of them.
  float inverted_list_normalize_value(size_t inverted_list_id) const {
    if (norm_value_ != 0.0f) {
      return norm_value_;
    }

    // ailego_assert_with(integer_quantizer_params_, "nullptr");
    if (integer_quantizer_params_ != nullptr) {
      const void *data = nullptr;
      size_t size = sizeof(InvertedIntegerQuantizerParams);
      size_t off = inverted_list_id * size;
      if (integer_quantizer_params_->read(off, &data, size) != size) {
        LOG_ERROR("Failed to read data from segment, off=%zu", off);
        return 1.0f;
      }
      auto scale = static_cast(data)->scale;
      return this->convert_to_normalize_value(scale);
    }
    return norm_value_;
  }

  //! Check whether the original feature segment exists
  bool has_orignal_feature() const {
    return !!features_;
  }

  //! Retrieve reformer
  const IVFReformerWrapper &reformer(void) const {
    return reformer_;
  }

  /*! Index Reformer Wrapper
   * To transform query in inverted index searching, and normalize the score
   * NOTE(review): declares virtual member functions but no virtual
   * destructor.
   */
  class IVFReformerWrapper {
   public:
    //! Constructor
    IVFReformerWrapper() {}

    //! Assignment
    //! Copies reformer_, type_ and reciprocal_; the scratch buffer_ is
    //! released rather than copied.
    //! NOTE(review): scales_ is neither copied nor cleared — confirm this is
    //! intended (it looks like per-query state).
    IVFReformerWrapper &operator=(const IVFReformerWrapper &wrapper) {
      reformer_ = wrapper.reformer_;
      type_ = wrapper.type_;
      buffer_.clear();
      buffer_.shrink_to_fit();
      reciprocal_ = wrapper.reciprocal_;
      return *this;
    }

    //! Initialize
    int init(const IndexMeta &imeta);

    //! Update
    int update(const IndexMeta &meta);

    //! Transform a query
    int transform(const void *query, const IndexQueryMeta &qmeta,
                  const void **out, IndexQueryMeta *ometa);

    //! Transform queries
    int transform(const void *query, const IndexQueryMeta &qmeta,
                  uint32_t count, const void **out, IndexQueryMeta *ometa);

    //! Convert a record
    virtual int convert(const void *record, const IndexQueryMeta &rmeta,
                        const void **out, IndexQueryMeta *ometa);

    //! Convert records
    virtual int convert(const void *records, const IndexQueryMeta &rmeta,
                        uint32_t count, const void **out,
                        IndexQueryMeta *ometa);

    //! Transform queries (GPU path)
    int transform_gpu(const void *query, const IndexQueryMeta &qmeta,
                      uint32_t count, const void **out, IndexQueryMeta *ometa);

    //! Normalize the score in query part
    void normalize(size_t qidx, IndexDocumentHeap *heap) const;

    //! Normalize the score in query part
    void normalize(size_t qidx, const void *query, const IndexQueryMeta &qmeta,
                   IndexDocumentHeap *heap) const;

   private:
    //! Transform query from fp32 to int8
    void transform(size_t qidx, const float *in, size_t dim, int8_t *out);

    //! Transform query from fp32 to int4
    void transform(size_t qidx, const float *in, size_t dim, uint8_t *out);

   private:
    //! Constants
    enum Type {
      kReformerTpNone = 0,
      kReformerTpInnerProductInt8 = 1,
      kReformerTpInnerProductInt4 = 2,
      kReformerTpInt8 = 3,
      kReformerTpInt4 = 4,
      kReformerTpDefault = 7,
    };

    //! Members
    Type type_{kReformerTpNone};
    IndexReformer::Pointer reformer_{};
    std::string buffer_{};
    float reciprocal_{0.0};  // for int8
    std::vector scales_{};   // for int8 IP
  };

 private:
  //! Load the segment by seg_id in expect_size segment size
  IndexStorage::Segment::Pointer load_segment(const std::string &seg_id,
                                              size_t expect_size) const;

  //! Load the header segment
  int load_header(const IndexStorage::Pointer &container);

  //! Convert the int8 quantizer scale to normalize value
  //! Returns 1 / scale (1 when scale is 0), square-rooted when
  //! norm_value_sqrt_ is set.
  float convert_to_normalize_value(float scale) const {
    auto v = scale == 0.0 ? 1.0 : (1.0 / scale);
    return !norm_value_sqrt_ ? v : std::sqrt(v);
  }

 protected:
  //! Constants
  //! Number of consecutive blocks fetched per segment read during search
  static constexpr size_t kBatchBlocks = 10u;

  //! Members
  IndexMeta meta_{};
  mutable IVFReformerWrapper reformer_{};
  IVFDistanceCalculator::Pointer calculator_{};
  InvertedIndexHeader header_{};
  IndexStorage::Pointer container_{};
  IndexStorage::Segment::Pointer inverted_{};
  IndexStorage::Segment::Pointer inverted_meta_{};
  IndexStorage::Segment::Pointer keys_{};
  IndexStorage::Segment::Pointer offsets_{};
  IndexStorage::Segment::Pointer mapping_{};
  IndexStorage::Segment::Pointer features_{};
  IndexStorage::Segment::Pointer integer_quantizer_params_{};
  mutable std::string vector_{};  // temporary buffer for column-major order
  float norm_value_{0.0f};  // normalize the inverted vector to original score
  bool norm_value_sqrt_{false};  // does the norm value need to sqrt
};

}  // namespace core
}  // namespace zvec
================================================
FILE: src/core/algorithm/ivf/ivf_index_format.h
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include
#include

namespace zvec {
namespace core {

//! Sentinel key (max uint64): search skips entries carrying it, and key
//! lookups return it on read failure.
static constexpr uint64_t kInvalidKey = std::numeric_limits::max();

/*! Index Format of Inverted Index Header
 * NOTE(review): this is a persisted layout; the fixed part appears to be 64
 * bytes with 8-byte alignment — consider locking it with a static_assert.
 * `index_meta` is a zero-length trailing array, presumably followed by
 * `index_meta_size` bytes of serialized meta — TODO confirm.
 */
struct InvertedIndexHeader {
  uint32_t header_size{0};
  uint32_t total_vector_count{0};  // reported by IVFEntity::vector_count()
  uint64_t inverted_body_size{0};  // bytes in the inverted body segment
  uint32_t inverted_list_count{0};
  uint32_t block_vector_count{0};  // vectors per block
  uint32_t block_size{0};          // bytes per block
  uint32_t block_count{0};
  uint32_t index_meta_size{0};
  char reserved_[28];
  char index_meta[0];
};

/*!
Index Format of Inverted Index Meta for each Inverted list:
 * one fixed-size record per list, stored at
 * list_id * sizeof(InvertedListMeta) in the inverted meta segment.
 */
struct InvertedListMeta {
  uint64_t offset{0};  // byte offset of the list's first block in the body
  uint32_t block_count{0};   // number of blocks in the list
  uint32_t vector_count{0};  // number of vectors; last block may be partial
  uint32_t id_offset{0};     // local id of the list's first vector
  char reserved_[16];
};

/*! Index Format of Location in Inverted Index for each vector
 * (one record per local id in the offsets segment).
 */
struct InvertedVecLocation {
  InvertedVecLocation(size_t off, bool col)
      : offset(off), column_major(col), reserved(0u) {}

  uint64_t offset : 48;       // feature offset in posting block segment
  uint64_t column_major : 1;  // column major if true
  uint64_t reserved : 15;
};

/*! Index Format of Integer Quantizer params for each inverted list
 * (record `list_id` of the quantizer-params segment; `scale` is turned into
 * the list's score multiplier by the entity).
 */
struct InvertedIntegerQuantizerParams {
  float scale{1.0};
  float bias{0.0};
};

/*! Location of Vectors Block in Storage Segment */
struct BlockLocation {
  uint16_t segment_id;
  uint16_t block_index;
};

/*! The Header of a Block in Storage Segment
 * `next` presumably links to the following block of the same list —
 * TODO confirm against the streamer code.
 */
struct BlockHeader {
  BlockLocation next;
  uint16_t vector_count;
  uint16_t column_major : 1;
  uint16_t reserved_ : 15;
};

//! Fixed 32-bit bitmap (presumably one bit per deleted slot — confirm);
//! is_dirty() reports whether any bit is set.
struct DeletionMap {
  void set(uint32_t index) {
    bitset.set(index);
  }
  void reset(uint32_t index) {
    bitset.reset(index);
  }
  bool test(uint32_t index) const {
    return bitset.test(index);
  }
  bool is_dirty() const {
    return bitset.test_any();
  }
  ailego::FixedBitset<32> bitset{};
};

static_assert(sizeof(DeletionMap) == 4, "DeletionMap must be 4 bytes");

/*! Meta Information of Streamer Entity */
struct StreamerInvertedMeta {
  uint64_t create_time{0};
  uint64_t update_time{0};
  uint64_t revision_id{0};
  uint32_t segment_count{0};
  uint32_t segment_size{0};
  uint8_t reserved_[32];
  InvertedIndexHeader header;
};

/*! Location of Vector in Storage Segment */
struct VectorLocation {
  //! Constructor
  //! NOTE(review): leaves every field uninitialized.
  VectorLocation(void) {}

  //! Constructor (reserved_ is left uninitialized)
  VectorLocation(uint16_t id, bool col, uint32_t off)
      : segment_id(id), column_major(col), offset(off) {}

  uint16_t segment_id;
  uint16_t column_major : 1;
  uint16_t reserved_ : 15;
  uint32_t offset;

 public:
  //! Equal when segment_id, column_major and offset match (reserved_ ignored)
  bool operator==(const VectorLocation &other) const {
    return segment_id == other.segment_id &&
           column_major == other.column_major && offset == other.offset;
  }
};

static_assert(sizeof(VectorLocation) == sizeof(uint64_t),
              "VectorLocation must be size of 8 bytes");

//! NOTE(review): the default and single-argument constructors leave
//! `centroid_idx` uninitialized, and KeyInfo(VectorLocation) allows implicit
//! conversion from VectorLocation.
struct KeyInfo {
  KeyInfo(void) {}
  KeyInfo(uint32_t idx, const VectorLocation &loc)
      : centroid_idx(idx), location(loc) {}
  KeyInfo(VectorLocation loc) : location(loc) {}

  uint32_t centroid_idx;
  VectorLocation location;
};

// Segments ID: names under which segments are looked up in the index
// storage; changing them would presumably break reading existing indexes.
// Note IVF_KEYS_SEG_ID uses an "hc." prefix, unlike the "ivf." ids.
const std::string IVF_CENTROID_SEG_ID("ivf.centroid");
const std::string IVF_INVERTED_BODY_SEG_ID("ivf.inverted_body");
const std::string IVF_INVERTED_HEADER_SEG_ID("ivf.inverted_header");
const std::string IVF_INVERTED_META_SEG_ID("ivf.inverted_meta");
const std::string IVF_KEYS_SEG_ID("hc.keys");
const std::string IVF_OFFSETS_SEG_ID("ivf.offsets");
const std::string IVF_MAPPING_SEG_ID("ivf.mapping");
const std::string IVF_FEATURES_SEG_ID("ivf.features");
const std::string IVF_INT8_QUANTIZED_PARAMS_SEG_ID("ivf.int8_quantized_params");
const std::string IVF_INT4_QUANTIZED_PARAMS_SEG_ID("ivf.int4_quantized_params");
const std::string IVF_INVERTED_LIST_HEAD_SEG_ID("ivf.inverted_list_head");
const std::string IVF_STORAGE_SEGMENT_ID("ivf.S");

}  // namespace core
}  // namespace zvec
================================================
FILE: src/core/algorithm/ivf/ivf_index_provider.h
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include "ivf_entity.h" namespace zvec { namespace core { /*! IVF IndexProvider */ class IVFIndexProvider : public IndexProvider { public: IVFIndexProvider(const IndexMeta &meta, const IVFEntity::Pointer &entity, const std::string &owner) : meta_(meta), entity_(entity), owner_class_(owner) {} IVFIndexProvider(const IVFIndexProvider &) = delete; IVFIndexProvider &operator=(const IVFIndexProvider &) = delete; public: //! Create a new iterator virtual Iterator::Pointer create_iterator(void) override { return Iterator::Pointer(new (std::nothrow) Iterator(entity_)); } //! Retrieve count of vectors virtual size_t count(void) const override { return entity_->vector_count(); } //! Retrieve dimension of vector virtual size_t dimension(void) const override { return meta_.dimension(); } //! Retrieve type of vector virtual IndexMeta::DataType data_type(void) const override { return meta_.data_type(); } //! Retrieve vector size in bytes virtual size_t element_size(void) const override { return meta_.element_size(); } //! Retrieve a vector using a primary key virtual const void *get_vector(uint64_t key) const override { return entity_->get_vector_by_key(key); } //! Retrieve the owner class virtual const std::string &owner_class(void) const override { return owner_class_; } private: class Iterator : public IndexProvider::Iterator { public: Iterator(const IVFEntity::Pointer &entity) : entity_(entity) {} //! Retrieve pointer of data //! NOTICE: the vec feature will be changed after iterating to next, so //! 
the caller need to keep a copy of it before iterator to next vector virtual const void *data(void) const override { return entity_->get_vector(index_); } //! Test if the iterator is valid virtual bool is_valid(void) const override { return index_ < entity_->vector_count(); } //! Retrieve primary key virtual uint64_t key(void) const override { return entity_->get_key(index_); } //! Next iterator virtual void next(void) override { ++index_; } private: //! Members IVFEntity::Pointer entity_; size_t index_{0}; }; private: //! Members const IndexMeta &meta_; IVFEntity::Pointer entity_; std::string owner_class_; }; } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/ivf/ivf_params.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#pragma once #include namespace zvec { namespace core { static const std::string SEPARATOR("/"); static const std::string CENTROID_SEPERATOR = "*"; // builder params static const std::string PARAM_IVF_BUILDER_CENTROID_COUNT( "proxima.ivf.builder.centroid_count"); static const std::string PARAM_IVF_BUILDER_CLUSTER_CLASS( "proxima.ivf.builder.cluster_class"); static const std::string PARAM_IVF_BUILDER_THREAD_COUNT( "proxima.ivf.builder.thread_count"); static const std::string PARAM_IVF_BUILDER_CLUSTER_AUTO_TUNING( "proxima.ivf.builder.cluster_auto_tuning"); static const std::string PARAM_IVF_BUILDER_TRAIN_SAMPLE_COUNT( "proxima.ivf.builder.train_sample_count"); static const std::string PARAM_IVF_BUILDER_TRAIN_SAMPLE_RATIO( "proxima.ivf.builder.train_sample_ratio"); static const std::string PARAM_IVF_BUILDER_CONVERTER_PARAMS( "proxima.ivf.builder.converter_params"); static const std::string PARAM_IVF_BUILDER_CONVERTER_CLASS( "proxima.ivf.builder.converter_class"); static const std::string PARAM_IVF_BUILDER_STORE_ORIGINAL_FEATURES( "proxima.ivf.builder.store_original_features"); static const std::string PARAM_IVF_BUILDER_QUANTIZER_CLASS( "proxima.ivf.builder.quantizer_class"); static const std::string PARAM_IVF_BUILDER_QUANTIZE_BY_CENTROID( "proxima.ivf.builder.quantize_by_centroid"); static const std::string PARAM_IVF_BUILDER_QUANTIZER_PARAMS( "proxima.ivf.builder.quantizer_params"); static const std::string PARAM_IVF_BUILDER_CLUSTER_PARAMS_IN_LEVEL_PREFIX( "proxima.ivf.builder.cluster_params_in_level_"); static const std::string PARAM_IVF_BUILDER_OPTIMIZER_CLASS( "proxima.ivf.builder.optimizer_class"); static const std::string PARAM_IVF_BUILDER_OPTIMIZER_PARAMS( "proxima.ivf.builder.optimizer_params"); static const std::string PARAM_IVF_BUILDER_OPTIMIZER_QUANTIZER_CLASS( "proxima.ivf.builder.optimizer_quantizer_class"); static const std::string PARAM_IVF_BUILDER_OPTIMIZER_QUANTIZER_PARAMS( "proxima.ivf.builder.optimizer_quantizer_params"); static const std::string 
PARAM_IVF_BUILDER_BLOCK_VECTOR_COUNT( "proxima.ivf.builder.block_vector_count"); // searcher params static const std::string PARAM_IVF_SEARCHER_SCAN_RATIO( "proxima.ivf.searcher.scan_ratio"); static const std::string PARAM_IVF_SEARCHER_BRUTE_FORCE_THRESHOLD( "proxima.ivf.searcher.brute_force_threshold"); static const std::string PARAM_IVF_SEARCHER_OPTIMIZER( "proxima.ivf.searcher.optimizer"); static const std::string PARAM_IVF_SEARCHER_OPTIMIZER_PARAMS( "proxima.ivf.searcher.optimizer_params"); static const std::string PARAM_IVF_SEARCHER_CONVERTER_REFORMER( "proxima.ivf.searcher.converter_reformer"); // Constants static constexpr char const *kIPMetricName = "InnerProduct"; static constexpr char const *kMipsMetricName = "MipsSquaredEuclidean"; static constexpr char const *kL2MetricName = "SquaredEuclidean"; static constexpr char const *kMipsConverterName = "MipsConverter"; static constexpr char const *kMipsRevConverterName = "MipsReverseConverter"; static constexpr char const *kMipsReformerName = "MipsReformer"; static constexpr char const *kInt8QuantizerName = "Int8QuantizerConverter"; static constexpr char const *kInt4QuantizerName = "Int4QuantizerConverter"; static constexpr char const *kInt8ReformerName = "Int8QuantizerReformer"; static constexpr char const *kInt4ReformerName = "Int4QuantizerReformer"; static constexpr float kNormalizeScaleFactor = 16.0f; } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/ivf/ivf_searcher.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "ivf_searcher.h" #include #include #include "ivf_centroid_index.h" #include "ivf_index_provider.h" #include "ivf_params.h" namespace zvec { namespace core { int IVFSearcher::init(const ailego::Params ¶meters) { params_ = parameters; params_.get(PARAM_IVF_SEARCHER_BRUTE_FORCE_THRESHOLD, &bruteforce_threshold_); searcher_state_ = STATE_INITED; return 0; } int IVFSearcher::cleanup(void) { this->unload(); params_.clear(); bruteforce_threshold_ = kDefaultBfThreshold; searcher_state_ = STATE_INIT; return 0; } int IVFSearcher::load(IndexStorage::Pointer container, IndexMetric::Pointer /*metric*/) { if (!container) { LOG_ERROR("Invalid container"); return IndexError_InvalidArgument; } if (searcher_state_ != STATE_INITED) { LOG_ERROR("Initalize the searcher first before load index"); return IndexError_Runtime; } ailego::ElapsedTime timer; int ret = IndexHelper::DeserializeFromStorage(container.get(), &meta_); if (ret != 0) { LOG_ERROR("Failed to deserialize meta from container"); return ret; } //! 
Load centroid index centroid_index_ = std::make_shared(); if (!centroid_index_) { return IndexError_NoMemory; } auto seg = container->get(IVF_CENTROID_SEG_ID, 0); if (!seg) { LOG_ERROR("Failed to get segment %s", IVF_CENTROID_SEG_ID.c_str()); return IndexError_InvalidFormat; } IndexStorage::Pointer seg_container = std::make_shared(seg); if (!seg_container) { return IndexError_NoMemory; } ret = seg_container->open(std::string(), false); if (ret != 0) { LOG_ERROR("IndexSegmentStorage load failed for %s", IndexError::What(ret)); return ret; } ret = centroid_index_->load(seg_container, params_); if (ret != 0) { LOG_ERROR("Failed to load index for %s", IndexError::What(ret)); return ret; } auto reformer = centroid_index_->reformer(); params_.set(PARAM_IVF_SEARCHER_CONVERTER_REFORMER, reformer); //! load iverted index entity_ = std::make_shared(); if (!entity_) { return IndexError_NoMemory; } ret = entity_->load(container); ivf_check_error_code(ret); magic_ = IndexContext::GenerateMagic(); stats_.set_loaded_count(entity_->vector_count()); stats_.set_loaded_costtime(timer.milli_seconds()); searcher_state_ = STATE_LOADED; return 0; } int IVFSearcher::unload(void) { magic_ = 0; centroid_index_.reset(); entity_.reset(); stats_.set_loaded_count(0UL); stats_.set_loaded_costtime(0UL); stats_.clear_attributes(); searcher_state_ = STATE_INITED; return 0; } int IVFSearcher::search_bf_impl(const void *query, const IndexQueryMeta &qmeta, Context::Pointer &context) const { return search_bf_impl(query, qmeta, 1, context); } int IVFSearcher::search_bf_impl(const void *query, const IndexQueryMeta &qmeta, uint32_t count, Context::Pointer &context) const { if (!query || qmeta.element_size() != meta_.element_size()) { LOG_ERROR("Null query or invalid qmeta"); return IndexError_InvalidArgument; } IVFSearcherContext *ctx = dynamic_cast(context.get()); if (!ctx || ctx->topk() == 0) { LOG_ERROR("Invalid context or topk not set yet"); return IndexError_InvalidArgument; } if (ctx->magic() != 
magic_) { //! context is created by another searcher int ret = this->update_context(ctx); ivf_check_error_code(ret); } ctx->reset_results(count); auto &entity = ctx->entity(); auto &filter = ctx->filter(); //! Transform the querys for querying in inverted vector index later IndexQueryMeta iv_qmeta; int ret = entity->transform(query, qmeta, count, &query, &iv_qmeta); ivf_check_with_msg(ret, "Failed to transform querys"); // TODO: do batch search in matrix for (size_t q = 0; q < count; ++q) { auto &context_stats = ctx->mutable_stats(q); auto &heap = ctx->mutable_result_heap(); heap.clear(); if (!filter.is_valid()) { ret = entity->search(query, &heap, &context_stats); } else { ret = entity->search(query, filter, &heap, &context_stats); } ivf_check_with_msg(ret, "Failed to search in entity for %s", IndexError::What(ret)); heap.sort(); // sort the results if (!filter.is_valid()) { // mapping the local id to key if query without filter ret = entity->retrieve_keys(&heap); ivf_check_error_code(ret); } entity->normalize(q, &heap); ctx->topk_to_result(q); query = static_cast(query) + iv_qmeta.element_size(); } return 0; } int IVFSearcher::search_impl(const void *query, const IndexQueryMeta &qmeta, Context::Pointer &context) const { return this->search_impl(query, qmeta, 1, context); } int IVFSearcher::search_impl(const void *query, const IndexQueryMeta &qmeta, uint32_t count, Context::Pointer &context) const { if (entity_->vector_count() <= bruteforce_threshold_) { return this->search_bf_impl(query, qmeta, count, context); } if (!query || qmeta.element_size() != meta_.element_size()) { LOG_ERROR("Null query or invalid qmeta"); return IndexError_InvalidArgument; } IVFSearcherContext *ctx = dynamic_cast(context.get()); if (!ctx || ctx->topk() == 0) { LOG_ERROR("Invalid context or topk not set yet"); return IndexError_InvalidArgument; } if (ctx->magic() != magic_) { //! 
context is created by another searcher int ret = update_context(ctx); ivf_check_error_code(ret); } ctx->reset_results(count); auto &entity = ctx->entity(); auto &filter = ctx->filter(); auto ¢roid_index_ctx = ctx->centroid_searcher_ctx(); int ret = centroid_index_->search(query, qmeta, count, centroid_index_ctx); ivf_check_error_code(ret); //! Transform the querys for querying in inverted vector index later IndexQueryMeta iv_qmeta; ret = entity->transform(query, qmeta, count, &query, &iv_qmeta); ivf_check_with_msg(ret, "Failed to transform querys"); for (size_t q = 0; q < count; ++q) { auto ¢roids = centroid_index_ctx->result(q); auto &context_stats = ctx->mutable_stats(q); auto &heap = ctx->mutable_result_heap(); heap.clear(); uint32_t total_scan_count = 0; for (size_t i = 0; i < centroids.size() && total_scan_count < ctx->max_scan_count(); ++i) { auto cid = centroids[i].key(); uint32_t scan_count = 0; if (!filter.is_valid()) { ret = entity->search(cid, query, &scan_count, &heap, &context_stats); } else { ret = entity->search(cid, query, filter, &scan_count, &heap, &context_stats); } ivf_check_with_msg(ret, "Failed to search in entity for %s", IndexError::What(ret)); total_scan_count += scan_count; } heap.sort(); // sort the results if (!filter.is_valid()) { // mapping the local id to key if query without filter ret = entity->retrieve_keys(&heap); ivf_check_error_code(ret); } entity->normalize(q, &heap); ctx->topk_to_result(q); query = static_cast(query) + iv_qmeta.element_size(); } return 0; } const IndexSearcher::Stats &IVFSearcher::stats(void) const { return stats_; } IndexSearcher::Context::Pointer IVFSearcher::create_context() const { if (searcher_state_ != STATE_LOADED) { LOG_ERROR("Load the index first before create context"); return nullptr; } auto entity = entity_->clone(); if (!entity) { LOG_ERROR("Failed to clone IVFEntity"); return nullptr; } auto centroid_index_ctx = centroid_index_->create_context(); if (!centroid_index_ctx) { LOG_ERROR("Failed to 
create centroid index context"); return nullptr; } auto context = new (std::nothrow) IVFSearcherContext(entity, centroid_index_ctx); if (!context) { LOG_ERROR("Failed to alloc IVFSearcherContext"); return nullptr; } int ret = context->init(params_); if (ret != 0) { delete context; return nullptr; } context->set_magic(magic_); return Context::Pointer(context); } IndexProvider::Pointer IVFSearcher::create_provider(void) const { if (searcher_state_ != STATE_LOADED) { LOG_ERROR("Load the index first before create provider"); return nullptr; } auto entity = entity_->clone(); if (!entity) { LOG_ERROR("Failed to clone IVFEntity"); return Provider::Pointer(); } auto *provider = new (std::nothrow) IVFIndexProvider(entity->has_orignal_feature() ? meta_ : entity->meta(), entity, "IVFSearcher"); if (!provider) { LOG_ERROR("Failed to alloc IVFIndexProvider"); return Provider::Pointer(); } return Provider::Pointer(provider); } INDEX_FACTORY_REGISTER_SEARCHER(IVFSearcher); } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/ivf/ivf_searcher.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include "ivf_centroid_index.h" #include "ivf_entity.h" #include "ivf_searcher_context.h" namespace zvec { namespace core { /*! IVF Searcher */ class IVFSearcher : public IndexSearcher { public: //! 
Initialize Searcher virtual int init(const ailego::Params ¶meters) override; //! Cleanup Searcher virtual int cleanup(void) override; //! Load index from container virtual int load(IndexStorage::Pointer container, IndexMetric::Pointer metric) override; //! Unload index virtual int unload(void) override; //! Similarity brute force search virtual int search_bf_impl(const void *query, const IndexQueryMeta &qmeta, Context::Pointer &context) const override; //! Similarity brute force search virtual int search_bf_impl(const void *query, const IndexQueryMeta &qmeta, uint32_t count, Context::Pointer &context) const override; //! Similarity search virtual int search_impl(const void *query, const IndexQueryMeta &qmeta, Context::Pointer &context) const override; //! Similarity search virtual int search_impl(const void *query, const IndexQueryMeta &qmeta, uint32_t count, Context::Pointer &context) const override; //! Retrieve statistics virtual const Stats &stats(void) const override; //! Create a searcher context virtual Context::Pointer create_context(void) const override; //! Create a new iterator virtual IndexProvider::Pointer create_provider(void) const override; //! Retrieve meta of index virtual const IndexMeta &meta(void) const override { return meta_; } //! Retrieve params of index virtual const ailego::Params ¶ms(void) const override { return params_; } protected: int update_context(IVFSearcherContext *ctx) const { auto entity = entity_->clone(); if (!entity) { LOG_ERROR("Failed to clone QcEntity"); return IndexError_Runtime; } //! The centroid index searcher may be different, so need to create one auto centroid_ctx = centroid_index_->create_context(); if (!centroid_ctx) { LOG_ERROR("Failed to create centroid index searcher context"); return IndexError_Runtime; } return ctx->update_context(entity, centroid_ctx, params_, magic_); } private: //! 
Constants static constexpr uint32_t kDefaultBfThreshold = 1000u; enum State { STATE_INIT = 0, STATE_INITED = 1, STATE_LOADED = 2 }; //! Members IndexMeta meta_{}; ailego::Params params_{}; IVFCentroidIndex::Pointer centroid_index_{}; IVFEntity::Pointer entity_{}; uint32_t bruteforce_threshold_{kDefaultBfThreshold}; uint32_t magic_{0}; Stats stats_{}; State searcher_state_{STATE_INIT}; }; } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/ivf/ivf_searcher_context.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include "ivf_entity.h" #include "ivf_utility.h" namespace zvec { namespace core { /*! IVF Searcher Context */ class IVFSearcherContext : public IndexSearcher::Context { public: IVFSearcherContext(const IVFEntity::Pointer &ivf_entity, IndexSearcher::Context::Pointer ¢roid_ctx) : entity_(ivf_entity), centroid_searcher_ctx_(std::move(centroid_ctx)) {} public: //! Set topk of search result virtual void set_topk(uint32_t k) override { topk_ = k; result_heap_.limit(topk_); result_heap_.set_threshold(this->threshold()); } //! Retrieve search result virtual const IndexDocumentList &result(void) const override { return results_[0]; } //! 
Retrieve search result with index virtual const IndexDocumentList &result(size_t idx) const override { ailego_assert_with(results_.size() > idx, "invalid index"); return results_[idx]; } //! Retrieve mutable result with index virtual IndexDocumentList *mutable_result(size_t idx) override { ailego_assert_with(idx < results_.size(), "invalid idx"); return &results_[idx]; } inline IndexDocumentHeap *result_heap() { return &result_heap_; } //! Update the parameters of context virtual int update(const ailego::Params ¶ms) override { params.get(PARAM_IVF_SEARCHER_BRUTE_FORCE_THRESHOLD, &bruteforce_threshold_); params.get(PARAM_IVF_SEARCHER_SCAN_RATIO, &scan_ratio_); if (scan_ratio_ <= 0.0) { LOG_ERROR("Invalid params %s=%f", PARAM_IVF_SEARCHER_SCAN_RATIO.c_str(), scan_ratio_); return IndexError_InvalidArgument; } size_t topk_val = std::max(static_cast( std::round(entity_->inverted_list_count() * scan_ratio_)), 1u); centroid_searcher_ctx_->set_topk(topk_val); max_scan_count_ = static_cast(std::ceil(entity_->vector_count() * scan_ratio_)); max_scan_count_ = std::max(bruteforce_threshold_, max_scan_count_); return 0; } //! Retrieve magic number virtual uint32_t magic(void) const override { return magic_; } public: //! Initialize the context int init(const ailego::Params ¶ms) { return this->update(params); } //! Update the magic number void set_magic(uint32_t mag) { magic_ = mag; } //! Get Topk Value uint32_t topk() const override { return topk_; } //! Retrieve scan ratio float scan_ratio(void) const { return scan_ratio_; } //! Retrieve max scan count uint32_t max_scan_count(void) const { return max_scan_count_; } uint32_t bruteforce_threshold() const { return bruteforce_threshold_; } //! Retrieve magic number const IVFEntity::Pointer &entity() const { return entity_; } //! 
Retrieve Mutable Query Result By Query Index IndexDocumentHeap &mutable_result_heap() { return result_heap_; } void set_fetch_vector(bool v) override { fetch_vector_ = v; } bool fetch_vector(void) const override { return fetch_vector_; } //! Reset all the query results void reset_results(size_t qnum) { results_.resize(qnum); stats_vec_.resize(qnum); for (size_t i = 0; i < qnum; ++i) { results_[i].clear(); stats_vec_[i].clear(); } result_heap_.clear(); result_heap_.limit(topk_); result_heap_.set_threshold(this->threshold()); } //! Update context, the context may be shared by different searcher int update_context(IVFEntity::Pointer &new_entity, IndexSearcher::Context::Pointer ¢roid_ctx, const ailego::Params ¶ms, uint32_t magic_num) { entity_ = new_entity; centroid_searcher_ctx_ = std::move(centroid_ctx); int ret = this->update(params); ivf_check_error_code(ret); magic_ = magic_num; return 0; } //! Retrieve the centroid index context IndexSearcher::Context::Pointer ¢roid_searcher_ctx(void) { return centroid_searcher_ctx_; } const Stats &stats(size_t idx = 0) const { ailego_assert_with(stats_vec_.size() > idx, "invalid index"); return stats_vec_[idx]; } Stats &mutable_stats(size_t idx = 0) { ailego_assert_with(stats_vec_.size() > idx, "invalid index"); return stats_vec_[idx]; } void topk_to_result(uint32_t idx) { if (ailego_unlikely(result_heap_.size() == 0)) { return; } ailego_assert_with(idx < results_.size(), "invalid idx"); int size = std::min(topk_, static_cast(result_heap_.size())); result_heap_.sort(); results_[idx].clear(); for (int i = 0; i < size; ++i) { auto score = result_heap_[i].score(); if (score > this->threshold()) { break; } key_t key = result_heap_[i].key(); if (fetch_vector_) { IndexStorage::MemoryBlock block; entity_->get_vector_by_key(key, block); results_[idx].emplace_back(key, score, key, block); } else { results_[idx].emplace_back(key, score); } } } private: //! 
Constants static constexpr float kDefaultScanRatio = 0.1f; static constexpr uint32_t kDefaultBfThreshold = 1000u; //! Members IVFEntity::Pointer entity_{}; IndexSearcher::Context::Pointer centroid_searcher_ctx_{}; IndexDocumentHeap result_heap_; std::vector results_{}; std::vector stats_vec_{}; bool fetch_vector_{false}; uint32_t topk_{0}; uint32_t magic_{0}; float scan_ratio_{kDefaultScanRatio}; uint32_t max_scan_count_{0}; uint32_t bruteforce_threshold_{kDefaultBfThreshold}; }; } // namespace core } // namespace zvec ================================================ FILE: src/core/algorithm/ivf/ivf_streamer.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "ivf_streamer.h"
// NOTE(review): the names of the next two headers were lost when this file
// was extracted; restore them from the upstream source.
#include
#include
#include "ivf_centroid_index.h"
#include "ivf_index_provider.h"
#include "ivf_params.h"

namespace zvec {
namespace core {

// NOTE(review): template arguments and `&` parameter names in this file were
// reconstructed from context after extraction damage; confirm upstream.

//! Store meta and parameters; reads only the brute-force threshold here.
//! The meta is replaced by the stored one in open().
int IVFStreamer::init(const IndexMeta &meta,
                      const ailego::Params &parameters) {
  meta_ = meta;
  params_ = parameters;
  params_.get(PARAM_IVF_SEARCHER_BRUTE_FORCE_THRESHOLD, &bruteforce_threshold_);
  searcher_state_ = STATE_INITED;
  return 0;
}

//! Unload any opened index and return to the pre-init state.
int IVFStreamer::cleanup(void) {
  this->unload();
  params_.clear();
  bruteforce_threshold_ = kDefaultBfThreshold;
  searcher_state_ = STATE_INIT;
  return 0;
}

//! Load meta, centroid index (segment IVF_CENTROID_SEG_ID) and the
//! inverted-list entity from `storage`. init() must be called first.
int IVFStreamer::open(IndexStorage::Pointer storage) {
  if (!storage) {
    LOG_ERROR("Invalid storage");
    return IndexError_InvalidArgument;
  }
  if (searcher_state_ != STATE_INITED) {
    LOG_ERROR("Initalize the searcher first before load index");
    return IndexError_Runtime;
  }

  ailego::ElapsedTime timer;
  int ret = IndexHelper::DeserializeFromStorage(storage.get(), &meta_);
  if (ret != 0) {
    LOG_ERROR("Failed to deserialize meta from storage");
    return ret;
  }

  //! Load centroid index
  centroid_index_ = std::make_shared<IVFCentroidIndex>();
  if (!centroid_index_) {
    return IndexError_NoMemory;
  }
  auto seg = storage->get(IVF_CENTROID_SEG_ID, 0);
  if (!seg) {
    LOG_ERROR("Failed to get segment %s", IVF_CENTROID_SEG_ID.c_str());
    return IndexError_InvalidFormat;
  }
  IndexStorage::Pointer seg_container =
      std::make_shared<IndexSegmentStorage>(seg);
  if (!seg_container) {
    return IndexError_NoMemory;
  }
  ret = seg_container->open(std::string(), false);
  if (ret != 0) {
    LOG_ERROR("IndexSegmentStorage load failed for %s", IndexError::What(ret));
    return ret;
  }
  ret = centroid_index_->load(seg_container, params_);
  if (ret != 0) {
    LOG_ERROR("Failed to load index for %s", IndexError::What(ret));
    return ret;
  }
  // Expose the centroid index's reformer to contexts through the params.
  auto reformer = centroid_index_->reformer();
  params_.set(PARAM_IVF_SEARCHER_CONVERTER_REFORMER, reformer);

  //! Load inverted index
  entity_ = std::make_shared<IVFEntity>();
  if (!entity_) {
    return IndexError_NoMemory;
  }
  ret = entity_->load(storage);
  ivf_check_with_msg(ret, "Failed to load inverted index for %s",
                     IndexError::What(ret));

  magic_ = IndexContext::GenerateMagic();
  stats_.set_loaded_count(entity_->vector_count());
  stats_.set_loaded_costtime(timer.milli_seconds());
  searcher_state_ = STATE_LOADED;
  return 0;
}

//! Release the opened index and reset load statistics.
int IVFStreamer::unload(void) {
  magic_ = 0;
  centroid_index_.reset();
  entity_.reset();
  stats_.set_loaded_count(0UL);
  stats_.set_loaded_costtime(0UL);
  stats_.clear_attributes();
  searcher_state_ = STATE_INITED;
  return 0;
}

//! Brute-force search for a single query.
int IVFStreamer::search_bf_impl(const void *query, const IndexQueryMeta &qmeta,
                                Context::Pointer &context) const {
  return search_bf_impl(query, qmeta, 1, context);
}

//! Brute-force search of `count` consecutive queries over all vectors.
int IVFStreamer::search_bf_impl(const void *query, const IndexQueryMeta &qmeta,
                                uint32_t count,
                                Context::Pointer &context) const {
  // Without an opened index, update_context() would dereference null.
  if (searcher_state_ != STATE_LOADED) {
    LOG_ERROR("Load the index first before search");
    return IndexError_Runtime;
  }
  if (!query || qmeta.element_size() != meta_.element_size()) {
    LOG_ERROR("Null query or invalid qmeta");
    return IndexError_InvalidArgument;
  }
  IVFSearcherContext *ctx = dynamic_cast<IVFSearcherContext *>(context.get());
  if (!ctx || ctx->topk() == 0) {
    LOG_ERROR("Invalid context or topk not set yet");
    return IndexError_InvalidArgument;
  }
  if (ctx->magic() != magic_) {
    //! context is created by another searcher
    int ret = this->update_context(ctx);
    ivf_check_error_code(ret);
  }

  ctx->reset_results(count);
  auto &entity = ctx->entity();
  auto &filter = ctx->filter();

  //! Transform the queries for querying in inverted vector index later
  IndexQueryMeta iv_qmeta;
  int ret = entity->transform(query, qmeta, count, &query, &iv_qmeta);
  ivf_check_with_msg(ret, "Failed to transform querys");

  // TODO: do batch search in matrix
  for (uint32_t q = 0; q < count; ++q) {
    auto &context_stats = ctx->mutable_stats(q);
    auto &heap = ctx->mutable_result_heap();
    heap.clear();
    if (!filter.is_valid()) {
      ret = entity->search(query, &heap, &context_stats);
    } else {
      ret = entity->search(query, filter, &heap, &context_stats);
    }
    ivf_check_with_msg(ret, "Failed to search in entity for %s",
                       IndexError::What(ret));
    heap.sort();  // sort the results
    if (!filter.is_valid()) {
      // mapping the local id to key if query without filter
      ret = entity->retrieve_keys(&heap);
      ivf_check_error_code(ret);
    }
    entity->normalize(q, &heap);
    ctx->topk_to_result(q);
    // Advance to the next transformed query.
    query = static_cast<const char *>(query) + iv_qmeta.element_size();
  }
  return 0;
}

//! IVF search for a single query.
int IVFStreamer::search_impl(const void *query, const IndexQueryMeta &qmeta,
                             Context::Pointer &context) const {
  return this->search_impl(query, qmeta, 1, context);
}

//! IVF search of `count` consecutive queries; falls back to brute force
//! when the index holds at most `bruteforce_threshold_` vectors.
int IVFStreamer::search_impl(const void *query, const IndexQueryMeta &qmeta,
                             uint32_t count, Context::Pointer &context) const {
  // entity_ is dereferenced right below; it is only set once opened.
  if (searcher_state_ != STATE_LOADED) {
    LOG_ERROR("Load the index first before search");
    return IndexError_Runtime;
  }
  if (entity_->vector_count() <= bruteforce_threshold_) {
    return this->search_bf_impl(query, qmeta, count, context);
  }

  if (!query || qmeta.element_size() != meta_.element_size()) {
    LOG_ERROR("Null query or invalid qmeta");
    return IndexError_InvalidArgument;
  }
  IVFSearcherContext *ctx = dynamic_cast<IVFSearcherContext *>(context.get());
  if (!ctx || ctx->topk() == 0) {
    LOG_ERROR("Invalid context or topk not set yet");
    return IndexError_InvalidArgument;
  }
  if (ctx->magic() != magic_) {
    //! context is created by another searcher
    int ret = update_context(ctx);
    ivf_check_error_code(ret);
  }

  ctx->reset_results(count);
  auto &entity = ctx->entity();
  auto &filter = ctx->filter();
  auto &centroid_index_ctx = ctx->centroid_searcher_ctx();
  int ret = centroid_index_->search(query, qmeta, count, centroid_index_ctx);
  ivf_check_error_code(ret);

  //! Transform the queries for querying in inverted vector index later
  IndexQueryMeta iv_qmeta;
  ret = entity->transform(query, qmeta, count, &query, &iv_qmeta);
  ivf_check_with_msg(ret, "Failed to transform querys");

  for (uint32_t q = 0; q < count; ++q) {
    auto &centroids = centroid_index_ctx->result(q);
    auto &context_stats = ctx->mutable_stats(q);
    auto &heap = ctx->mutable_result_heap();
    heap.clear();
    uint32_t total_scan_count = 0;
    for (size_t i = 0;
         i < centroids.size() && total_scan_count < ctx->max_scan_count();
         ++i) {
      auto cid = centroids[i].key();
      uint32_t scan_count = 0;
      if (!filter.is_valid()) {
        ret = entity->search(cid, query, &scan_count, &heap, &context_stats);
      } else {
        ret = entity->search(cid, query, filter, &scan_count, &heap,
                             &context_stats);
      }
      ivf_check_with_msg(ret, "Failed to search in entity for %s",
                         IndexError::What(ret));
      total_scan_count += scan_count;
    }
    heap.sort();  // sort the results
    if (!filter.is_valid()) {
      // mapping the local id to key if query without filter
      ret = entity->retrieve_keys(&heap);
      ivf_check_error_code(ret);
    }
    entity->normalize(q, &heap);
    ctx->topk_to_result(q);
    // Advance to the next transformed query.
    query = static_cast<const char *>(query) + iv_qmeta.element_size();
  }
  return 0;
}

//! Retrieve load statistics.
const IndexSearcher::Stats &IVFStreamer::stats(void) const {
  return stats_;
}

//! Create a context holding a cloned entity and a fresh centroid context.
IndexSearcher::Context::Pointer IVFStreamer::create_context() const {
  if (searcher_state_ != STATE_LOADED) {
    LOG_ERROR("Load the index first before create context");
    return nullptr;
  }
  auto entity = entity_->clone();
  if (!entity) {
    LOG_ERROR("Failed to clone IVFEntity");
    return nullptr;
  }
  auto centroid_index_ctx = centroid_index_->create_context();
  if (!centroid_index_ctx) {
    LOG_ERROR("Failed to create centroid index context");
    return nullptr;
  }
  auto context =
      new (std::nothrow) IVFSearcherContext(entity, centroid_index_ctx);
  if (!context) {
    LOG_ERROR("Failed to alloc IVFSearcherContext");
    return nullptr;
  }
  int ret = context->init(params_);
  if (ret != 0) {
    LOG_ERROR("Failed to init IVFSearcherContext for %s",
              IndexError::What(ret));
    delete context;
    return nullptr;
  }
  context->set_magic(magic_);
  return Context::Pointer(context);
}

//! Create a provider over a cloned entity.
IndexProvider::Pointer IVFStreamer::create_provider(void) const {
  if (searcher_state_ != STATE_LOADED) {
    LOG_ERROR("Load the index first before create provider");
    return nullptr;
  }
  auto entity = entity_->clone();
  if (!entity) {
    LOG_ERROR("Failed to clone IVFEntity");
    return Provider::Pointer();
  }
  auto *provider = new (std::nothrow) IVFIndexProvider(
      entity->has_orignal_feature() ? meta_ : entity->meta(), entity,
      "IVFStreamer");
  if (!provider) {
    LOG_ERROR("Failed to alloc IVFIndexProvider");
    return Provider::Pointer();
  }
  return Provider::Pointer(provider);
}

INDEX_FACTORY_REGISTER_STREAMER(IVFStreamer);

}  // namespace core
}  // namespace zvec

================================================
FILE: src/core/algorithm/ivf/ivf_streamer.h
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef __IVF_STREAMER_H__
#define __IVF_STREAMER_H__

// NOTE(review): header name lost during extraction; restore from upstream.
#include
#include "ivf_centroid_index.h"
#include "ivf_entity.h"
#include "ivf_searcher_context.h"

namespace zvec {
namespace core {

/*! IVF Searcher
 *  Streamer front-end over a stored IVF index: open() loads it; flush() is a
 *  no-op and close() unloads.
 */
class IVFStreamer : public IndexStreamer {
 public:
  //! Initialize Searcher
  virtual int init(const IndexMeta & /*meta*/,
                   const ailego::Params & /*params*/) override;

  //! Cleanup Searcher
  virtual int cleanup(void) override;

  //! Load index from container
  virtual int open(IndexStorage::Pointer storage) override;

  virtual int flush(uint64_t /*check_point*/) override {
    return 0;
  }

  virtual int close(void) override {
    return this->unload();
  }

  //! Unload index
  virtual int unload(void) override;

  //! Similarity brute force search
  virtual int search_bf_impl(const void *query, const IndexQueryMeta &qmeta,
                             Context::Pointer &context) const override;

  //! Similarity brute force search
  virtual int search_bf_impl(const void *query, const IndexQueryMeta &qmeta,
                             uint32_t count,
                             Context::Pointer &context) const override;

  //! Similarity search
  virtual int search_impl(const void *query, const IndexQueryMeta &qmeta,
                          Context::Pointer &context) const override;

  //! Similarity search
  virtual int search_impl(const void *query, const IndexQueryMeta &qmeta,
                          uint32_t count,
                          Context::Pointer &context) const override;

  //! Retrieve statistics
  virtual const Stats &stats(void) const override;

  //! Create a searcher context
  virtual Context::Pointer create_context(void) const override;

  //! Create a new iterator
  virtual IndexProvider::Pointer create_provider(void) const override;

  //! Retrieve meta of index
  virtual const IndexMeta &meta(void) const override {
    return meta_;
  }

  //! Fetch a stored vector; `id` is forwarded as the entity key.
  virtual int get_vector_by_id(
      const uint32_t id, IndexStorage::MemoryBlock &block) const override {
    // entity_ is null until open() succeeds (and after close()/unload()).
    if (!entity_) {
      LOG_ERROR("Load the index first before get vector");
      return IndexError_Runtime;
    }
    return entity_->get_vector_by_key(id, block);
  }

 protected:
  //! Rebind a context created by another searcher to this streamer.
  int update_context(IVFSearcherContext *ctx) const {
    // Both members are null until open() succeeds.
    if (!entity_ || !centroid_index_) {
      LOG_ERROR("Load the index first before update context");
      return IndexError_Runtime;
    }
    auto entity = entity_->clone();
    if (!entity) {
      LOG_ERROR("Failed to clone IVFEntity");
      return IndexError_Runtime;
    }
    //! The centroid index searcher may be different, so need to create one
    auto centroid_ctx = centroid_index_->create_context();
    if (!centroid_ctx) {
      LOG_ERROR("Failed to create centroid index searcher context");
      return IndexError_Runtime;
    }
    return ctx->update_context(entity, centroid_ctx, params_, magic_);
  }

 private:
  //! Constants
  static constexpr uint32_t kDefaultBfThreshold = 1000u;

  enum State { STATE_INIT = 0, STATE_INITED = 1, STATE_LOADED = 2 };

  //! Members
  IndexMeta meta_{};
  ailego::Params params_{};
  IndexBuilder::Pointer builder_;
  IVFCentroidIndex::Pointer centroid_index_{};
  IVFEntity::Pointer entity_{};
  uint32_t bruteforce_threshold_{kDefaultBfThreshold};
  uint32_t magic_{0};
  Stats stats_{};
  State searcher_state_{STATE_INIT};
};

}  // namespace core
}  // namespace zvec

#endif  //__IVF_STREAMER_H__

================================================
FILE: src/core/algorithm/ivf/ivf_utility.h
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

// NOTE(review): the seven header names below were lost during extraction;
// restore them from the upstream source.
#include
#include
#include
#include
#include
#include
#include

namespace zvec {
namespace core {

// The helper macros below are wrapped in do { } while (0) so each expands to
// exactly one statement and is safe inside an unbraced if/else.

#ifndef ivf_check_error_code
//! Return `code` from the enclosing function when it is non-zero.
#define ivf_check_error_code(code)      \
  do {                                  \
    if (ailego_unlikely((code) != 0)) { \
      return (code);                    \
    }                                   \
  } while (0)
#endif

#ifndef ivf_assert
//! Return `code` from the enclosing function when `cond` is false.
#define ivf_assert(cond, code)       \
  do {                               \
    if (ailego_unlikely(!(cond))) {  \
      return (code);                 \
    }                                \
  } while (0)
#endif

#ifndef ivf_check_with_msg
//! Log the formatted message and return `code` when it is non-zero.
#define ivf_check_with_msg(code, fmt, args...) \
  do {                                         \
    if (ailego_unlikely((code) != 0)) {        \
      LOG_ERROR(fmt, ##args);                  \
      return code;                             \
    }                                          \
  } while (0)
#endif

#ifndef ivf_assert_with_msg
//! Log the formatted message and return `err` when `cond` is false.
#define ivf_assert_with_msg(cond, err, fmt, args...) \
  do {                                               \
    if (ailego_unlikely(!(cond))) {                  \
      LOG_ERROR(fmt, ##args);                        \
      return err;                                    \
    }                                                \
  } while (0)
#endif

/*! Quantized Clustering Utility
 */
class IVFUtility {
 public:
  //! Generate a random path: `prefix` followed by a monotonic microsecond
  //! timestamp.
  static inline std::string GenerateRandomPath(const std::string &prefix) {
    uint64_t timestamp = ailego::Monotime::MicroSeconds();
    return prefix + std::to_string(timestamp);
  }

  //! Compute the default scan ratio for `vector_count` total vectors.
  //! The result is never below 0.0001 and never above 0.0751.
  static inline float ComputeScanRatio(size_t vector_count) {
    // The fitting function for the following points: 1000000(0.02)
    // 10000000(0.01) 50000000(0.005) 100000000(0.001).
    // The count is clamped to >= 1: log(0) is -inf, which would make the
    // ratio +inf (and later NaN once multiplied by a zero count).
    const double n = static_cast<double>(std::max<size_t>(vector_count, 1u));
    float scan_ratio = static_cast<float>(-0.004 * std::log(n) + 0.0751);
    scan_ratio = std::max(scan_ratio, 0.0001f);
    return scan_ratio;
  }

  //! Transpose the vectors in row major order to column major order
  static inline void Transpose(size_t align_size, const void *src, size_t m,
                               size_t dim, void *dst);

  //! Transpose the vectors in column major order to row major order
  static inline void ReverseTranspose(size_t align_size, const void *src,
                                      size_t m, size_t dim, void *dst);

  //! Aligned size of a block vectors buffer
  static inline size_t AlignedSize(size_t fnum, size_t element_size);

  //! Aligned size of one vector buffer
  static inline size_t AlignedSize(size_t element_size);

  //! Sort arr with size in ascending order, and keep the index position.
  //! n2o keeps the mapping: new position => origin position.
  //! For example, the input arr = [5, 3, 9, 6, 7], size = 5, after sort
  //!   arr = [3, 5, 6, 7, 9]
  //!   n2o = [1, 0, 3, 4, 2]
  //! arr is permuted in place (no copy of the elements is made); an index
  //! buffer of `size` entries (o2n) is still allocated.
  template <typename T, typename I>
  static void Sort(T *arr, std::vector<I> *n2o, size_t size) {
    std::vector<I> o2n;
    o2n.resize(size);
    n2o->resize(size);
    std::iota(n2o->begin(), n2o->end(), 0U);
    std::sort(n2o->begin(), n2o->end(),
              [&](I i, I j) { return arr[i] < arr[j]; });
    for (I i = 0U; i < size; ++i) {
      o2n[(*n2o)[i]] = i;
    }

    //! reorder arr in place, according to given n2o index (cycle walking)
    for (I i = 0; i < size; ++i) {
      if (i != (*n2o)[i]) {
        T tmp = arr[i];
        I j = i, k;
        while (i != (k = (*n2o)[j])) {
          arr[j] = arr[k];
          (*n2o)[j] = j;
          j = k;
        }
        arr[j] = tmp;
        (*n2o)[j] = j;
      }
    }

    // Rebuild n2o, which the cycle walk above overwrote, from o2n.
    for (I i = 0U; i < size; ++i) {
      (*n2o)[o2n[i]] = i;
    }
  }

  //! Transpose one vector in block: dst[i] = src[i * M] for i in [0, N).
  template <typename T>
  static inline void TransposeOne(const void *src, size_t M, size_t N,
                                  void *dst) {
    for (size_t i = 0; i < N; ++i) {
      reinterpret_cast<T *>(dst)[i] = reinterpret_cast<const T *>(src)[i * M];
    }
  }
};

// NOTE(review): the MatrixHelper template arguments below were lost during
// extraction and are reconstructed as the element type of each align size.
void IVFUtility::Transpose(size_t align_size, const void *src, size_t m,
                           size_t dim, void *dst) {
  switch (align_size) {
    case 2:
      ailego::MatrixHelper::Transpose<uint16_t>(src, m, dim, dst);
      break;
    case 4:
      ailego::MatrixHelper::Transpose<uint32_t>(src, m, dim, dst);
      break;
    case 8:
      ailego::MatrixHelper::Transpose<uint64_t>(src, m, dim, dst);
      break;
  }
}

void IVFUtility::ReverseTranspose(size_t align_size, const void *src, size_t m,
                                  size_t dim, void *dst) {
  switch (align_size) {
    case 2:
      ailego::MatrixHelper::ReverseTranspose<uint16_t>(src, m, dim, dst);
      break;
    case 4:
      ailego::MatrixHelper::ReverseTranspose<uint32_t>(src, m, dim, dst);
      break;
    case 8:
      ailego::MatrixHelper::ReverseTranspose<uint64_t>(src, m, dim, dst);
      break;
  }
}

//! fnum * element_size rounded up to a multiple of 32.
size_t IVFUtility::AlignedSize(size_t fnum, size_t element_size) {
  return ailego_align(fnum * element_size, 32);
}

//! element_size rounded up to a multiple of 32.
size_t IVFUtility::AlignedSize(size_t element_size) {
  return ailego_align(element_size, 32);
}

}  // namespace core
}  // namespace zvec

================================================
FILE: src/core/framework/CMakeLists.txt
================================================
include(${PROJECT_ROOT_DIR}/cmake/bazel.cmake)
include(${PROJECT_ROOT_DIR}/cmake/option.cmake) cc_library( NAME core_framework STATIC STRICT ALWAYS_LINK SRCS *.cc LIBS zvec_ailego INCS . ${PROJECT_ROOT_DIR}/src/core VERSION "${PROXIMA_ZVEC_VERSION}" ) ================================================ FILE: src/core/framework/index_cluster.cc ================================================ // namespace aitheta2 // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include #include namespace zvec { namespace core { static const std::string CLUSTER_CENTROIDS_FEATURES_NAME = "IndexCluster.Centroids.Features"; static const std::string CLUSTER_CENTROIDS_INDEXES_NAME = "IndexCluster.Centroids.Indexes"; /*! 
Item Centroid Format */ struct ItemCentroidFormat { uint32_t parent; uint32_t reserved0_; uint64_t follows; double score; uint64_t reserved1_; }; static inline bool GatherSubitemsCount(const ItemCentroidFormat *format, size_t count, std::vector *out) { out->resize(count + 1); for (const ItemCentroidFormat *it = format, *end = format + count; it != end; ++it) { uint32_t parent = it->parent + 1; if (parent > count) { return false; } (*out)[parent] += 1; } return (out->front() != 0); } int IndexCluster::Deserialize(const IndexMeta &meta, IndexBundle::Pointer bundle, CentroidList *cents) { if (!bundle || !cents) { return IndexError_InvalidArgument; } ailego::BlobWrap features = bundle->get(CLUSTER_CENTROIDS_FEATURES_NAME); ailego::BlobWrap indexes = bundle->get(CLUSTER_CENTROIDS_INDEXES_NAME); if (!features.is_valid() || !indexes.is_valid()) { return IndexError_InvalidArgument; } if (features.size() % meta.element_size() != 0 || indexes.size() % sizeof(ItemCentroidFormat) != 0) { return IndexError_InvalidLength; } size_t count = features.size() / meta.element_size(); if (indexes.size() / sizeof(ItemCentroidFormat) != count) { return IndexError_InvalidLength; } const ItemCentroidFormat *format = reinterpret_cast(indexes.buffer()); std::vector subitems; if (!GatherSubitemsCount(format, count, &subitems)) { return IndexError_InvalidFormat; } std::vector items; items.reserve(count); cents->clear(); cents->reserve(subitems.front()); const uint8_t *feat = reinterpret_cast(features.buffer()); size_t feat_size = meta.element_size(); for (size_t i = 0; i < count; ++i, ++format, feat += feat_size) { CentroidList *current = cents; if (format->parent != static_cast(-1)) { if (format->parent >= items.size()) { return IndexError_InvalidFormat; } current = items[format->parent]->mutable_subitems(); } current->emplace_back(feat, feat_size); // Update information Centroid *last_one = &(current->back()); last_one->set_follows(static_cast(format->follows)); 
last_one->set_score(format->score); last_one->mutable_subitems()->reserve(subitems[i + 1]); items.push_back(last_one); } return 0; } static void SerializeToBuffers(const IndexCluster::CentroidList ¢s, std::string *features, std::string *indexes) { uint32_t parent = static_cast(indexes->size() / sizeof(ItemCentroidFormat)) - 1; for (const auto &it : cents) { ItemCentroidFormat format{parent, 0, it.follows(), it.score(), 0}; indexes->append(reinterpret_cast(&format), sizeof(format)); features->append(reinterpret_cast(it.feature()), it.size()); if (!it.subitems().empty()) { SerializeToBuffers(it.subitems(), features, indexes); } } } int IndexCluster::Serialize(const IndexMeta &meta, const CentroidList ¢s, IndexBundle::Pointer *out) { size_t cents_total = cents.size(); // Check the centroids for (const auto &it : cents) { if (!it.is_matched(meta)) { return IndexError_Mismatch; } cents_total += it.subcount(); } std::string features, indexes; features.reserve(cents_total * meta.element_size()); indexes.reserve(cents_total * sizeof(ItemCentroidFormat)); SerializeToBuffers(cents, &features, &indexes); std::shared_ptr bundle = std::make_shared(); bundle->set(CLUSTER_CENTROIDS_FEATURES_NAME, std::move(features)); bundle->set(CLUSTER_CENTROIDS_INDEXES_NAME, std::move(indexes)); *out = std::move(bundle); return 0; } } // namespace core } // namespace zvec ================================================ FILE: src/core/framework/index_context.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and // limitations under the License. #include #include #include namespace zvec { namespace core { uint32_t IndexContext::GenerateMagic(void) { static std::atomic_uint32_t magic_number{std::random_device()()}; return magic_number.fetch_add(1); } } // namespace core } // namespace zvec ================================================ FILE: src/core/framework/index_converter.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include #include #include namespace zvec { namespace core { int IndexConverter::TrainAndTransform(const IndexConverter::Pointer &converter, IndexHolder::Pointer holder) { auto two_pass_holder = IndexHelper::MakeTwoPassHolder(std::move(holder)); int ret = converter->train(two_pass_holder); if (ret == 0) { ret = converter->transform(std::move(two_pass_holder)); } return ret; } int IndexConverter::TrainTransformAndDump( const IndexConverter::Pointer &converter, IndexHolder::Pointer holder, const IndexDumper::Pointer &dumper) { int ret = IndexConverter::TrainAndTransform(converter, std::move(holder)); if (ret == 0) { ret = converter->dump(dumper); } return ret; } } // namespace core } // namespace zvec ================================================ FILE: src/core/framework/index_error.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include namespace zvec { namespace core { INDEX_ERROR_CODE_DEFINE(Success, 0, "Success"); INDEX_ERROR_CODE_DEFINE(Runtime, 1, "Runtime error"); INDEX_ERROR_CODE_DEFINE(Logic, 2, "Logic error"); INDEX_ERROR_CODE_DEFINE(Type, 3, "Type error"); INDEX_ERROR_CODE_DEFINE(System, 4, "System call error"); INDEX_ERROR_CODE_DEFINE(Cast, 5, "Cast error"); INDEX_ERROR_CODE_DEFINE(IO, 6, "IO error"); INDEX_ERROR_CODE_DEFINE(AuthExpired, 7, "Auth expired error"); INDEX_ERROR_CODE_DEFINE(NotImplemented, 11, "Not implemented"); INDEX_ERROR_CODE_DEFINE(Unsupported, 12, "Unsupported"); INDEX_ERROR_CODE_DEFINE(Denied, 13, "Permission denied"); INDEX_ERROR_CODE_DEFINE(Canceled, 14, "Operation canceled"); INDEX_ERROR_CODE_DEFINE(Overflow, 15, "Overflow"); INDEX_ERROR_CODE_DEFINE(Underflow, 16, "Underflow"); INDEX_ERROR_CODE_DEFINE(OutOfRange, 17, "Out of range"); INDEX_ERROR_CODE_DEFINE(NoBuffer, 18, "No buffer space available"); INDEX_ERROR_CODE_DEFINE(NoMemory, 19, "Not enough space"); INDEX_ERROR_CODE_DEFINE(NoParamFound, 20, "No parameter found"); INDEX_ERROR_CODE_DEFINE(NoReady, 21, "No ready"); INDEX_ERROR_CODE_DEFINE(NoExist, 22, "No exist"); INDEX_ERROR_CODE_DEFINE(Exist, 23, "Already exist"); INDEX_ERROR_CODE_DEFINE(Mismatch, 24, "Mismatch"); INDEX_ERROR_CODE_DEFINE(Duplicate, 25, "Duplicate"); INDEX_ERROR_CODE_DEFINE(Uninitialized, 26, "Uninitialized"); INDEX_ERROR_CODE_DEFINE(InvalidArgument, 31, "Invalid argument"); INDEX_ERROR_CODE_DEFINE(InvalidFormat, 32, "Invalid format"); INDEX_ERROR_CODE_DEFINE(InvalidLength, 33, "Invalid length"); INDEX_ERROR_CODE_DEFINE(InvalidChecksum, 34, "Invalid checksum"); INDEX_ERROR_CODE_DEFINE(InvalidValue, 35, "Invalid value"); INDEX_ERROR_CODE_DEFINE(CreateDirectory, 101, "Create directory error"); INDEX_ERROR_CODE_DEFINE(OpenDirectory, 102, "Open directory error"); INDEX_ERROR_CODE_DEFINE(Serialize, 105, "Serialize error"); INDEX_ERROR_CODE_DEFINE(Deserialize, 106, "Deserialize error"); INDEX_ERROR_CODE_DEFINE(CreateFile, 111, "Create 
file error"); INDEX_ERROR_CODE_DEFINE(OpenFile, 112, "Open file error"); INDEX_ERROR_CODE_DEFINE(SeekFile, 113, "Seek file error"); INDEX_ERROR_CODE_DEFINE(CloseFile, 114, "Close file error"); INDEX_ERROR_CODE_DEFINE(TruncateFile, 115, "TruncateFile file error"); INDEX_ERROR_CODE_DEFINE(MMapFile, 116, "MMap file error"); INDEX_ERROR_CODE_DEFINE(FlushFile, 117, "Flush file error"); INDEX_ERROR_CODE_DEFINE(WriteData, 121, "Write data error"); INDEX_ERROR_CODE_DEFINE(ReadData, 122, "Read data error"); INDEX_ERROR_CODE_DEFINE(PackIndex, 201, "Read data error"); INDEX_ERROR_CODE_DEFINE(UnpackIndex, 202, "Read data error"); INDEX_ERROR_CODE_DEFINE(IndexLoaded, 203, "Index loaded"); INDEX_ERROR_CODE_DEFINE(NoIndexLoaded, 204, "No index loaded"); INDEX_ERROR_CODE_DEFINE(NoTrained, 205, "No trained"); INDEX_ERROR_CODE_DEFINE(IndexFull, 206, "Index full"); } // namespace core } // namespace zvec ================================================ FILE: src/core/framework/index_factory.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include

namespace zvec {
namespace core {

//! IndexFactory: named-lookup wrappers over the ailego::Factory registry.
//! Every component kind below follows the same three-function pattern:
//!   Create<Kind>(name): obtains an instance via ailego::Factory::MakeShared
//!     for `name`; if the result is non-null its name is set to `name`, and
//!     the pointer (possibly null — presumably "not registered", confirm in
//!     ailego's factory) is returned.
//!   Has<Kind>(name):   forwards to ailego::Factory::Has.
//!   All<Kind>s():      forwards to ailego::Factory::Classes (presumably the
//!     registered names — verify).

// --- Metric ---
IndexMetric::Pointer IndexFactory::CreateMetric(const std::string &name) {
  IndexMetric::Pointer obj = ailego::Factory::MakeShared(name.c_str());
  if (obj) {
    obj->set_name(name);
  }
  return obj;
}

bool IndexFactory::HasMetric(const std::string &name) {
  return ailego::Factory::Has(name.c_str());
}

std::vector IndexFactory::AllMetrics(void) {
  return ailego::Factory::Classes();
}

// --- Logger ---
IndexLogger::Pointer IndexFactory::CreateLogger(const std::string &name) {
  IndexLogger::Pointer obj = ailego::Factory::MakeShared(name.c_str());
  if (obj) {
    obj->set_name(name);
  }
  return obj;
}

bool IndexFactory::HasLogger(const std::string &name) {
  return ailego::Factory::Has(name.c_str());
}

std::vector IndexFactory::AllLoggers(void) {
  return ailego::Factory::Classes();
}

// --- Dumper ---
IndexDumper::Pointer IndexFactory::CreateDumper(const std::string &name) {
  IndexDumper::Pointer obj = ailego::Factory::MakeShared(name.c_str());
  if (obj) {
    obj->set_name(name);
  }
  return obj;
}

bool IndexFactory::HasDumper(const std::string &name) {
  return ailego::Factory::Has(name.c_str());
}

std::vector IndexFactory::AllDumpers(void) {
  return ailego::Factory::Classes();
}

// --- Storage ---
IndexStorage::Pointer IndexFactory::CreateStorage(const std::string &name) {
  IndexStorage::Pointer obj = ailego::Factory::MakeShared(name.c_str());
  if (obj) {
    obj->set_name(name);
  }
  return obj;
}

bool IndexFactory::HasStorage(const std::string &name) {
  return ailego::Factory::Has(name.c_str());
}

std::vector IndexFactory::AllStorages(void) {
  return ailego::Factory::Classes();
}

// --- Converter ---
IndexConverter::Pointer IndexFactory::CreateConverter(const std::string &name) {
  IndexConverter::Pointer obj = ailego::Factory::MakeShared(name.c_str());
  if (obj) {
    obj->set_name(name);
  }
  return obj;
}

bool IndexFactory::HasConverter(const std::string &name) {
  return ailego::Factory::Has(name.c_str());
}

std::vector IndexFactory::AllConverters(void) {
  return ailego::Factory::Classes();
}

// --- Reformer ---
IndexReformer::Pointer IndexFactory::CreateReformer(const std::string &name) {
  IndexReformer::Pointer obj = ailego::Factory::MakeShared(name.c_str());
  if (obj) {
    obj->set_name(name);
  }
  return obj;
}

bool IndexFactory::HasReformer(const std::string &name) {
  return ailego::Factory::Has(name.c_str());
}

std::vector IndexFactory::AllReformers(void) {
  return ailego::Factory::Classes();
}

// --- Trainer ---
IndexTrainer::Pointer IndexFactory::CreateTrainer(const std::string &name) {
  IndexTrainer::Pointer obj = ailego::Factory::MakeShared(name.c_str());
  if (obj) {
    obj->set_name(name);
  }
  return obj;
}

bool IndexFactory::HasTrainer(const std::string &name) {
  return ailego::Factory::Has(name.c_str());
}

std::vector IndexFactory::AllTrainers(void) {
  return ailego::Factory::Classes();
}

// --- Builder ---
IndexBuilder::Pointer IndexFactory::CreateBuilder(const std::string &name) {
  IndexBuilder::Pointer obj = ailego::Factory::MakeShared(name.c_str());
  if (obj) {
    obj->set_name(name);
  }
  return obj;
}

bool IndexFactory::HasBuilder(const std::string &name) {
  return ailego::Factory::Has(name.c_str());
}

std::vector IndexFactory::AllBuilders(void) {
  return ailego::Factory::Classes();
}

// --- Searcher ---
IndexSearcher::Pointer IndexFactory::CreateSearcher(const std::string &name) {
  IndexSearcher::Pointer obj = ailego::Factory::MakeShared(name.c_str());
  if (obj) {
    obj->set_name(name);
  }
  return obj;
}

bool IndexFactory::HasSearcher(const std::string &name) {
  return ailego::Factory::Has(name.c_str());
}

std::vector IndexFactory::AllSearchers(void) {
  return ailego::Factory::Classes();
}

// --- Streamer ---
IndexStreamer::Pointer IndexFactory::CreateStreamer(const std::string &name) {
  IndexStreamer::Pointer obj = ailego::Factory::MakeShared(name.c_str());
  if (obj) {
    obj->set_name(name);
  }
  return obj;
}

bool IndexFactory::HasStreamer(const std::string &name) {
  return ailego::Factory::Has(name.c_str());
}

std::vector IndexFactory::AllStreamers(void) {
  return ailego::Factory::Classes();
}

// --- Reducer ---
IndexReducer::Pointer IndexFactory::CreateReducer(const std::string &name) {
  IndexReducer::Pointer obj
      = ailego::Factory::MakeShared(name.c_str());
  if (obj) {
    obj->set_name(name);
  }
  return obj;
}

bool IndexFactory::HasReducer(const std::string &name) {
  return ailego::Factory::Has(name.c_str());
}

std::vector IndexFactory::AllReducers(void) {
  return ailego::Factory::Classes();
}

// --- Cluster ---
IndexCluster::Pointer IndexFactory::CreateCluster(const std::string &name) {
  IndexCluster::Pointer obj = ailego::Factory::MakeShared(name.c_str());
  if (obj) {
    obj->set_name(name);
  }
  return obj;
}

bool IndexFactory::HasCluster(const std::string &name) {
  return ailego::Factory::Has(name.c_str());
}

std::vector IndexFactory::AllClusters(void) {
  return ailego::Factory::Classes();
}

// --- StreamerReducer ---
IndexStreamerReducer::Pointer IndexFactory::CreateStreamerReducer(
    const std::string &name) {
  IndexStreamerReducer::Pointer obj =
      ailego::Factory::MakeShared(name.c_str());
  if (obj) {
    obj->set_name(name);
  }
  return obj;
}

bool IndexFactory::HasStreamerReducer(const std::string &name) {
  return ailego::Factory::Has(name.c_str());
}

std::vector IndexFactory::AllStreamerReducers(void) {
  return ailego::Factory::Classes();
}

// --- Refiner ---
IndexRefiner::Pointer IndexFactory::CreateRefiner(const std::string &name) {
  IndexRefiner::Pointer obj = ailego::Factory::MakeShared(name.c_str());
  if (obj) {
    obj->set_name(name);
  }
  return obj;
}

bool IndexFactory::HasRefiner(const std::string &name) {
  return ailego::Factory::Has(name.c_str());
}

std::vector IndexFactory::AllRefiners(void) {
  return ailego::Factory::Classes();
}

}  // namespace core
}  // namespace zvec

================================================
FILE: src/core/framework/index_flow.cc
================================================
// Copyright 2025-present the zvec project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include #include #include #include //! Default storage #define INDEX_FLOW_STORAGE_DEFAULT "MMapFileReadStorage" namespace zvec { namespace core { // Index Flow int IndexFlow::set_storage(const std::string &name, const ailego::Params ¶ms) { storage_ = IndexFactory::CreateStorage(name); if (!storage_) { LOG_ERROR("Failed to create a index storage with name: %s", name.c_str()); return IndexError_NoExist; } int ret = storage_->init(params); if (ret < 0) { storage_ = nullptr; LOG_ERROR("Failed to initialize index storage %s", name.c_str()); return ret; } return 0; } int IndexFlow::set_searcher(IndexSearcher::Pointer searcher) { user_searcher_ = searcher; return 0; } int IndexFlow::set_searcher(const std::string &name, const ailego::Params ¶ms) { user_searcher_ = IndexFactory::CreateSearcher(name); if (!user_searcher_) { LOG_ERROR("Failed to create a index searcher with name: %s", name.c_str()); return IndexError_NoExist; } int ret = user_searcher_->init(params); if (ret < 0) { user_searcher_ = nullptr; LOG_ERROR("Failed to initialize index searcher %s", name.c_str()); return ret; } return 0; } int IndexFlow::set_reformer(const std::string &name, const ailego::Params ¶ms) { user_reformer_ = IndexFactory::CreateReformer(name); if (!user_reformer_) { LOG_ERROR("Failed to create a index reformer with name: %s", name.c_str()); return IndexError_NoExist; } int ret = user_reformer_->init(params); if (ret < 0) { user_reformer_ = nullptr; LOG_ERROR("Failed to initialize index reformer %s", name.c_str()); return ret; } return 0; } int IndexFlow::set_metric(const 
std::string &name, const ailego::Params ¶ms) { if (!IndexFactory::HasMetric(name)) { LOG_ERROR("The index metric with name %s does not exist.", name.c_str()); return IndexError_NoExist; } user_metric_name_ = name; user_metric_params_ = params; return 0; } int IndexFlow::load(const std::string &path) { // Prepare storage if (!storage_) { this->set_storage(INDEX_FLOW_STORAGE_DEFAULT, ailego::Params()); } if (!storage_) { LOG_ERROR("The index storage is uninitialized."); return IndexError_Uninitialized; } int ret = storage_->open(path, false); if (ret != 0) { LOG_ERROR("Failed to load index with storage %s", storage_->name().c_str()); return ret; } ret = IndexHelper::DeserializeFromStorage(storage_.get(), &meta_); if (ret != 0) { LOG_ERROR("Failed to deserialize index meta with storage %s", storage_->name().c_str()); return ret; } ret = load_internal(); if (ret != 0) { LOG_ERROR("Failed to load index with storage %s", storage_->name().c_str()); return ret; } return 0; } int IndexFlow::load_internal() { // Prepare metric const std::string &metric_name = user_metric_name_.empty() ? meta_.metric_name() : user_metric_name_; const ailego::Params &metric_params = user_metric_name_.empty() ? 
meta_.metric_params() : user_metric_params_; if (metric_name.empty()) { LOG_ERROR("The metric name from index file is empty."); return IndexError_NoExist; } metric_ = IndexFactory::CreateMetric(metric_name); if (!metric_) { LOG_ERROR("Failed to create a index metric with name: %s", metric_name.c_str()); return IndexError_NoExist; } int ret = metric_->init(meta_, metric_params); if (ret < 0) { LOG_ERROR("Failed to initialize index metric %s", metric_name.c_str()); metric_ = nullptr; return ret; } if (!metric_->is_matched(meta_)) { LOG_ERROR("The index meta is unmatched for index metric %s", metric_->name().c_str()); return IndexError_Mismatch; } auto query_metric = metric_->query_metric(); if (query_metric) { metric_ = query_metric; } // Prepare reformer if (!user_reformer_) { const std::string &reformer_name = meta_.reformer_name(); if (!reformer_name.empty()) { reformer_ = IndexFactory::CreateReformer(reformer_name); if (!reformer_) { LOG_ERROR("Failed to create a index reformer with name: %s", reformer_name.c_str()); return IndexError_NoExist; } ret = reformer_->init(meta_.reformer_params()); if (ret < 0) { LOG_ERROR("Failed to initialize index reformer %s", reformer_name.c_str()); reformer_ = nullptr; return ret; } } } else { // Using user reformer reformer_ = user_reformer_; } if (reformer_) { ret = reformer_->load(storage_); if (ret < 0) { LOG_ERROR("Failed to load index with reformer %s, storage %s", reformer_->name().c_str(), storage_->name().c_str()); return ret; } } // Prepare searcher if (!user_searcher_) { const std::string &name = meta_.searcher_name(); if (name.empty()) { LOG_ERROR("The searcher name from index file is empty."); return IndexError_NoExist; } searcher_ = IndexFactory::CreateSearcher(name); if (!searcher_) { LOG_ERROR("Failed to create a index searcher with name: %s", name.c_str()); return IndexError_NoExist; } ret = searcher_->init(meta_.searcher_params()); if (ret < 0) { LOG_ERROR("Failed to initialize index searcher %s", name.c_str()); 
searcher_ = nullptr; return ret; } } else { // Using user searcher searcher_ = user_searcher_; } ret = searcher_->load(storage_, metric_); if (ret < 0) { LOG_ERROR("Failed to load index with searcher %s, storage %s, metric %s", searcher_->name().c_str(), storage_->name().c_str(), metric_->name().c_str()); return ret; } // searcher_->print_all_neighbour(); return 0; } int IndexFlow::unload(void) { if (searcher_) { int ret = searcher_->unload(); if (ret < 0) { LOG_WARN("Unload index searcher %s error, %d", searcher_->name().c_str(), ret); } searcher_ = nullptr; } if (reformer_) { int ret = reformer_->unload(); if (ret < 0) { LOG_WARN("Unload index reformer %s error, %d", reformer_->name().c_str(), ret); } reformer_ = nullptr; } if (metric_) { int ret = metric_->cleanup(); if (ret < 0) { LOG_WARN("Cleanup index metric %s error, %d", metric_->name().c_str(), ret); } metric_ = nullptr; } if (storage_) { int ret = storage_->cleanup(); if (ret < 0) { LOG_WARN("Unload index searcher %s error, %d", storage_->name().c_str(), ret); } storage_ = nullptr; } return 0; } int IndexFlow::search_bf_impl(const void *query, const IndexQueryMeta &qmeta, Context::Pointer &context) const { if (ailego_unlikely(!query || !context)) { return IndexError_InvalidArgument; } int error_code = 0; if (reformer_) { IndexQueryMeta new_qmeta; error_code = reformer_->transform(query, qmeta, context->mutable_features(), &new_qmeta); if (error_code == 0) { if (ailego_unlikely(!metric_->is_matched(meta_, new_qmeta))) { return IndexError_Mismatch; } error_code = searcher_->search_bf_impl( reinterpret_cast(context->features().data()), new_qmeta, context->searcher_context()); } } else { if (ailego_unlikely(!metric_->is_matched(meta_, qmeta))) { return IndexError_Mismatch; } error_code = searcher_->search_bf_impl(query, qmeta, context->searcher_context()); } if (error_code == 0) { if (metric_->support_normalize()) { for (auto &it : const_cast( context->searcher_context()->result())) { 
metric_->normalize(it.mutable_score()); } } if (reformer_) { error_code = reformer_->normalize(query, qmeta, const_cast( context->searcher_context()->result())); } } return error_code; } int IndexFlow::search_impl(const void *query, const IndexQueryMeta &qmeta, Context::Pointer &context) const { if (ailego_unlikely(!query || !context)) { return IndexError_InvalidArgument; } int error_code = 0; if (reformer_) { IndexQueryMeta new_qmeta; error_code = reformer_->transform(query, qmeta, context->mutable_features(), &new_qmeta); if (error_code == 0) { if (ailego_unlikely(!metric_->is_matched(meta_, new_qmeta))) { return IndexError_Mismatch; } error_code = searcher_->search_impl( reinterpret_cast(context->features().data()), new_qmeta, context->searcher_context()); } } else { if (ailego_unlikely(!metric_->is_matched(meta_, qmeta))) { return IndexError_Mismatch; } error_code = searcher_->search_impl(query, qmeta, context->searcher_context()); } if (error_code == 0) { if (metric_->support_normalize()) { for (auto &it : const_cast( context->searcher_context()->result())) { metric_->normalize(it.mutable_score()); } } if (reformer_) { error_code = reformer_->normalize(query, qmeta, const_cast( context->searcher_context()->result())); } } return error_code; } int IndexFlow::search_bf_impl(const void *query, const IndexQueryMeta &qmeta, uint32_t count, Context::Pointer &context) const { if (ailego_unlikely(!query || !count || !context)) { return IndexError_InvalidArgument; } int error_code = 0; if (reformer_) { IndexQueryMeta new_qmeta; error_code = reformer_->transform(query, qmeta, count, context->mutable_features(), &new_qmeta); if (error_code == 0) { if (ailego_unlikely(!metric_->is_matched(meta_, new_qmeta))) { return IndexError_Mismatch; } error_code = searcher_->search_bf_impl( reinterpret_cast(context->features().data()), new_qmeta, count, context->searcher_context()); } } else { if (ailego_unlikely(!metric_->is_matched(meta_, qmeta))) { return IndexError_Mismatch; } 
error_code = searcher_->search_bf_impl(query, qmeta, count, context->searcher_context()); } if (error_code == 0) { if (metric_->support_normalize()) { for (uint32_t i = 0; i < count; ++i) { IndexDocumentList &result = const_cast( context->searcher_context()->result(i)); for (auto &it : result) { metric_->normalize(it.mutable_score()); } } } if (reformer_) { size_t offset = 0; for (uint32_t i = 0; i < count; ++i) { error_code = reformer_->normalize( reinterpret_cast(query) + offset, qmeta, const_cast( context->searcher_context()->result(i))); if (error_code != 0) { break; } offset += qmeta.element_size(); } } } return error_code; } int IndexFlow::search_impl(const void *query, const IndexQueryMeta &qmeta, uint32_t count, Context::Pointer &context) const { if (ailego_unlikely(!query || !count || !context)) { return IndexError_InvalidArgument; } int error_code = 0; if (reformer_) { IndexQueryMeta new_qmeta; error_code = reformer_->transform(query, qmeta, count, context->mutable_features(), &new_qmeta); if (error_code == 0) { if (ailego_unlikely(!metric_->is_matched(meta_, new_qmeta))) { return IndexError_Mismatch; } error_code = searcher_->search_impl( reinterpret_cast(context->features().data()), new_qmeta, count, context->searcher_context()); } } else { if (ailego_unlikely(!metric_->is_matched(meta_, qmeta))) { return IndexError_Mismatch; } error_code = searcher_->search_impl(query, qmeta, count, context->searcher_context()); } if (error_code == 0) { if (metric_->support_normalize()) { for (uint32_t i = 0; i < count; ++i) { IndexDocumentList &result = const_cast( context->searcher_context()->result(i)); for (auto &it : result) { metric_->normalize(it.mutable_score()); } } } if (reformer_) { size_t offset = 0; for (uint32_t i = 0; i < count; ++i) { error_code = reformer_->normalize( reinterpret_cast(query) + offset, qmeta, const_cast( context->searcher_context()->result(i))); if (error_code != 0) { break; } offset += qmeta.element_size(); } } } return error_code; } 
// Index Sparse Flow int IndexSparseFlow::set_storage(const std::string &name, const ailego::Params ¶ms) { storage_ = IndexFactory::CreateStorage(name); if (!storage_) { LOG_ERROR("Failed to create a index storage with name: %s", name.c_str()); return IndexError_NoExist; } int ret = storage_->init(params); if (ret < 0) { storage_ = nullptr; LOG_ERROR("Failed to initialize index storage %s", name.c_str()); return ret; } return 0; } int IndexSparseFlow::set_searcher(IndexSearcher::Pointer searcher) { user_searcher_ = searcher; return 0; } int IndexSparseFlow::set_searcher(const std::string &name, const ailego::Params ¶ms) { user_searcher_ = IndexFactory::CreateSearcher(name); if (!user_searcher_) { LOG_ERROR("Failed to create a index sparse searcher with name: %s", name.c_str()); return IndexError_NoExist; } int ret = user_searcher_->init(params); if (ret < 0) { user_searcher_ = nullptr; LOG_ERROR("Failed to initialize index sparse searcher %s", name.c_str()); return ret; } return 0; } int IndexSparseFlow::set_reformer(const std::string &name, const ailego::Params ¶ms) { user_reformer_ = IndexFactory::CreateReformer(name); if (!user_reformer_) { LOG_ERROR("Failed to create a index sparse reformer with name: %s", name.c_str()); return IndexError_NoExist; } int ret = user_reformer_->init(params); if (ret < 0) { user_reformer_ = nullptr; LOG_ERROR("Failed to initialize index sparse reformer %s", name.c_str()); return ret; } return 0; } int IndexSparseFlow::set_metric(const std::string &name, const ailego::Params ¶ms) { if (!IndexFactory::HasMetric(name)) { LOG_ERROR("The index metric with name %s does not exist.", name.c_str()); return IndexError_NoExist; } user_metric_name_ = name; user_metric_params_ = params; return 0; } int IndexSparseFlow::load(const std::string &path) { // Prepare storage if (!storage_) { this->set_storage(INDEX_FLOW_STORAGE_DEFAULT, ailego::Params()); } if (!storage_) { LOG_ERROR("The index storage is uninitialized."); return 
IndexError_Uninitialized; } int ret = storage_->open(path, false); if (ret != 0) { LOG_ERROR("Failed to load index with storage %s", storage_->name().c_str()); return ret; } ret = IndexHelper::DeserializeFromStorage(storage_.get(), &meta_); if (ret != 0) { LOG_ERROR("Failed to deserialize index meta with storage %s", storage_->name().c_str()); return ret; } ret = load_internal(); if (ret != 0) { LOG_ERROR("Failed to load index with storage %s", storage_->name().c_str()); return ret; } return 0; } int IndexSparseFlow::load_internal() { // Prepare metric const std::string &metric_name = user_metric_name_.empty() ? meta_.metric_name() : user_metric_name_; const ailego::Params &metric_params = user_metric_name_.empty() ? meta_.metric_params() : user_metric_params_; if (metric_name.empty()) { LOG_ERROR("The metric name from index file is empty."); return IndexError_NoExist; } metric_ = IndexFactory::CreateMetric(metric_name); if (!metric_) { LOG_ERROR("Failed to create a index metric with name: %s", metric_name.c_str()); return IndexError_NoExist; } int ret = metric_->init(meta_, metric_params); if (ret < 0) { LOG_ERROR("Failed to initialize index metric %s", metric_name.c_str()); metric_ = nullptr; return ret; } auto query_metric = metric_->query_metric(); if (query_metric) { metric_ = query_metric; } // Prepare reformer if (!user_reformer_) { const std::string &reformer_name = meta_.reformer_name(); if (!reformer_name.empty()) { reformer_ = IndexFactory::CreateReformer(reformer_name); if (!reformer_) { LOG_ERROR("Failed to create a index sparse reformer with name: %s", reformer_name.c_str()); return IndexError_NoExist; } ret = reformer_->init(meta_.reformer_params()); if (ret < 0) { LOG_ERROR("Failed to initialize index reformer %s", reformer_name.c_str()); reformer_ = nullptr; return ret; } } } else { // Using user reformer reformer_ = user_reformer_; } if (reformer_) { ret = reformer_->load(storage_); if (ret < 0) { LOG_ERROR("Failed to load index with reformer %s, 
storage %s", reformer_->name().c_str(), storage_->name().c_str()); return ret; } } // Prepare searcher if (!user_searcher_) { const std::string &name = meta_.searcher_name(); if (name.empty()) { LOG_ERROR("The searcher name from index file is empty."); return IndexError_NoExist; } searcher_ = IndexFactory::CreateSearcher(name); if (!searcher_) { LOG_ERROR("Failed to create a index searcher with name: %s", name.c_str()); return IndexError_NoExist; } ret = searcher_->init(meta_.searcher_params()); if (ret < 0) { LOG_ERROR("Failed to initialize index searcher %s", name.c_str()); searcher_ = nullptr; return ret; } } else { // Using user searcher searcher_ = user_searcher_; } ret = searcher_->load(storage_, metric_); if (ret < 0) { LOG_ERROR("Failed to load index with searcher %s, storage %s, metric %s", searcher_->name().c_str(), storage_->name().c_str(), metric_->name().c_str()); return ret; } // searcher_->print_all_neighbour(); return 0; } int IndexSparseFlow::unload(void) { if (searcher_) { int ret = searcher_->unload(); if (ret < 0) { LOG_WARN("Unload index searcher %s error, %d", searcher_->name().c_str(), ret); } searcher_ = nullptr; } if (reformer_) { int ret = reformer_->unload(); if (ret < 0) { LOG_WARN("Unload index reformer %s error, %d", reformer_->name().c_str(), ret); } reformer_ = nullptr; } if (metric_) { int ret = metric_->cleanup(); if (ret < 0) { LOG_WARN("Cleanup index metric %s error, %d", metric_->name().c_str(), ret); } metric_ = nullptr; } if (storage_) { int ret = storage_->cleanup(); if (ret < 0) { LOG_WARN("Unload index searcher %s error, %d", storage_->name().c_str(), ret); } storage_ = nullptr; } return 0; } int IndexSparseFlow::search_bf_impl(const uint32_t sparse_count, const uint32_t *sparse_indices, const void *sparse_query, const IndexQueryMeta &qmeta, Context::Pointer &context) const { if (ailego_unlikely(!context)) { return IndexError_InvalidArgument; } int error_code = 0; if (reformer_) { std::string ovec; IndexQueryMeta new_qmeta; 
error_code = reformer_->transform(sparse_count, sparse_indices, sparse_query, qmeta, &ovec, &new_qmeta); if (ailego_unlikely(!metric_->is_matched(meta_, new_qmeta))) { return IndexError_Mismatch; } if (error_code == 0) { error_code = searcher_->search_bf_impl(sparse_count, sparse_indices, ovec.data(), new_qmeta, context->searcher_context()); } } else { if (ailego_unlikely(!metric_->is_matched(meta_, qmeta))) { return IndexError_Mismatch; } error_code = searcher_->search_bf_impl(sparse_count, sparse_indices, sparse_query, qmeta, context->searcher_context()); } if (error_code == 0) { if (metric_->support_normalize()) { for (auto &it : const_cast( context->searcher_context()->result())) { metric_->normalize(it.mutable_score()); } } } return error_code; } int IndexSparseFlow::search_impl(const uint32_t sparse_count, const uint32_t *sparse_indices, const void *sparse_query, const IndexQueryMeta &qmeta, Context::Pointer &context) const { if (ailego_unlikely(!context)) { return IndexError_InvalidArgument; } int error_code = 0; if (reformer_) { std::string ovec; IndexQueryMeta new_qmeta; error_code = reformer_->transform(sparse_count, sparse_indices, sparse_query, qmeta, &ovec, &new_qmeta); if (ailego_unlikely(!metric_->is_matched(meta_, new_qmeta))) { return IndexError_Mismatch; } if (error_code == 0) { error_code = searcher_->search_impl(sparse_count, sparse_indices, ovec.data(), new_qmeta, context->searcher_context()); } } else { if (ailego_unlikely(!metric_->is_matched(meta_, qmeta))) { return IndexError_Mismatch; } error_code = searcher_->search_impl(sparse_count, sparse_indices, sparse_query, qmeta, context->searcher_context()); } if (error_code == 0) { if (metric_->support_normalize()) { for (auto &it : const_cast( context->searcher_context()->result())) { metric_->normalize(it.mutable_score()); } } } return error_code; } int IndexSparseFlow::search_bf_impl(const uint32_t *sparse_count, const uint32_t *sparse_indices, const void *sparse_query, const IndexQueryMeta 
&qmeta, uint32_t count, Context::Pointer &context) const { if (ailego_unlikely(!count || !context)) { return IndexError_InvalidArgument; } int error_code = 0; if (reformer_) { std::string ovec; IndexQueryMeta new_qmeta; error_code = reformer_->transform(sparse_count, sparse_indices, sparse_query, qmeta, count, &ovec, &new_qmeta); if (ailego_unlikely(!metric_->is_matched(meta_, new_qmeta))) { return IndexError_Mismatch; } if (error_code == 0) { error_code = searcher_->search_bf_impl(sparse_count, sparse_indices, ovec.data(), new_qmeta, count, context->searcher_context()); } } else { if (ailego_unlikely(!metric_->is_matched(meta_, qmeta))) { return IndexError_Mismatch; } error_code = searcher_->search_bf_impl(sparse_count, sparse_indices, sparse_query, qmeta, count, context->searcher_context()); } if (error_code == 0) { if (metric_->support_normalize()) { for (uint32_t i = 0; i < count; ++i) { IndexDocumentList &result = const_cast( context->searcher_context()->result(i)); for (auto &it : result) { metric_->normalize(it.mutable_score()); } } } } return error_code; } int IndexSparseFlow::search_impl(const uint32_t *sparse_count, const uint32_t *sparse_indices, const void *sparse_query, const IndexQueryMeta &qmeta, uint32_t count, Context::Pointer &context) const { if (ailego_unlikely(!count || !context)) { return IndexError_InvalidArgument; } int error_code = 0; if (reformer_) { std::string ovec; IndexQueryMeta new_qmeta; error_code = reformer_->transform(sparse_count, sparse_indices, sparse_query, qmeta, count, &ovec, &new_qmeta); if (ailego_unlikely(!metric_->is_matched(meta_, new_qmeta))) { return IndexError_Mismatch; } if (error_code == 0) { error_code = searcher_->search_impl(sparse_count, sparse_indices, ovec.data(), new_qmeta, count, context->searcher_context()); } } else { if (ailego_unlikely(!metric_->is_matched(meta_, qmeta))) { return IndexError_Mismatch; } error_code = searcher_->search_impl(sparse_count, sparse_indices, sparse_query, qmeta, count, 
context->searcher_context()); } if (error_code == 0) { if (metric_->support_normalize()) { for (uint32_t i = 0; i < count; ++i) { IndexDocumentList &result = const_cast( context->searcher_context()->result(i)); for (auto &it : result) { metric_->normalize(it.mutable_score()); } } } } return error_code; } } // namespace core } // namespace zvec ================================================ FILE: src/core/framework/index_helper.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include #include #include namespace zvec { namespace core { int IndexHelper::SerializeToDumper(const IndexMeta &mt, IndexDumper *dumper, const std::string &key) { std::string buffer; mt.serialize(&buffer); size_t data_size = buffer.size(); uint32_t data_crc = ailego::Crc32c::Hash(buffer.data(), buffer.size(), 0); buffer.resize((data_size + 31u) & ~31u); if (dumper->write(buffer.data(), buffer.size()) != buffer.size()) { return IndexError_WriteData; } if (dumper->append(key, data_size, buffer.size() - data_size, data_crc) != 0) { return IndexError_WriteData; } return IndexError_Success; } int IndexHelper::SerializeToStorage(const IndexMeta &mt, IndexStorage *storage, const std::string &key) { std::string buffer; mt.serialize(&buffer); auto segment = storage->get(key); if (!segment) { const size_t align_size = 4096 * 4; size_t meta_size = (buffer.size() + align_size - 1) / align_size * align_size; if (storage->append(key, meta_size) != 0) { return IndexError_WriteData; } segment = storage->get(key); if (!segment) { return IndexError_NoExist; } } if (segment->write(0, buffer.data(), buffer.size()) != buffer.size()) { return IndexError_WriteData; } segment->resize(buffer.size()); segment->update_data_crc( ailego::Crc32c::Hash(buffer.data(), buffer.size(), 0)); return IndexError_Success; } int IndexHelper::DeserializeFromStorage(IndexStorage *storage, const std::string &key, IndexMeta *out) { auto segment = storage->get(key); if (!segment) { return IndexError_NoExist; } uint32_t crc = segment->data_crc(); size_t len = segment->data_size(); const void *data = nullptr; if (segment->read(0, &data, len) != len) { return IndexError_ReadData; } if (crc != 0u && ailego::Crc32c::Hash(data, len, 0u) != crc) { return IndexError_InvalidChecksum; } if (!out->deserialize(data, len)) { return IndexError_Deserialize; } return IndexError_Success; } /*! Two Pass Index Holder */ class TwoPassIndexHolder : public IndexHolder { private: /*! 
First Pass Iterator * store elements during iterating for second iterating. */ class FirstPassIterator : public IndexHolder::Iterator { public: //! Index Holder Iterator Pointer typedef std::unique_ptr Pointer; //! Constructor FirstPassIterator(TwoPassIndexHolder *owner, IndexHolder::Iterator::Pointer &&iter) : holder_(owner), front_iter_(std::move(iter)) {} //! Destructor virtual ~FirstPassIterator(void) {} //! Retrieve pointer of data const void *data(void) const override { return front_iter_->data(); } //! Test if the iterator is valid bool is_valid(void) const override { return front_iter_->is_valid(); } //! Retrieve primary key uint64_t key(void) const override { return front_iter_->key(); } //! Next iterator void next(void) override { holder_->features_.emplace_back( front_iter_->key(), std::string((const char *)front_iter_->data(), holder_->front_->element_size())); front_iter_->next(); } private: TwoPassIndexHolder *holder_{nullptr}; IndexHolder::Iterator::Pointer front_iter_{}; }; class SecondPassIterator : public IndexHolder::Iterator { public: //! Second Pass Iterator Pointer typedef std::unique_ptr Pointer; //! Constructor SecondPassIterator(TwoPassIndexHolder *owner) : holder_(owner) { features_iter_ = holder_->features_.begin(); } //! Destructor virtual ~SecondPassIterator(void) {} //! Retrieve pointer of data const void *data(void) const override { return features_iter_->second.data(); } //! Test if the iterator is valid bool is_valid(void) const override { return (features_iter_ != holder_->features_.end()); } //! Retrieve primary key uint64_t key(void) const override { return features_iter_->first; } //! Next iterator void next(void) override { holder_->features_.erase(features_iter_++); } private: TwoPassIndexHolder *holder_{nullptr}; typename std::list>::iterator features_iter_{}; }; public: //! 
Constructor TwoPassIndexHolder(IndexHolder::Pointer &&front) : front_(std::move(front)), data_type_(front_->data_type()), dimension_(front_->dimension()), element_size_(front_->element_size()), count_(front_->count()) {} //! Retrieve count of elements in holder (-1 indicates unknown) size_t count(void) const override { return count_; } //! Retrieve dimension size_t dimension(void) const override { return dimension_; } //! Retrieve type information IndexMeta::DataType data_type(void) const override { return data_type_; } //! Retrieve element size in bytes size_t element_size(void) const override { return element_size_; } //! Retrieve if it can multi-pass bool multipass(void) const override { return false; } //! Create a new iterator IndexHolder::Iterator::Pointer create_iterator(void) override { ++pass_; if (pass_ == 1) { IndexHolder::Iterator::Pointer iter = front_->create_iterator(); return iter ? IndexHolder::Iterator::Pointer( new TwoPassIndexHolder::FirstPassIterator( this, std::move(iter))) : IndexHolder::Iterator::Pointer(); } else if (pass_ == 2) { return IndexHolder::Iterator::Pointer( new TwoPassIndexHolder::SecondPassIterator(this)); } return nullptr; } private: //! Disable them TwoPassIndexHolder(void) = delete; //! 
Members IndexHolder::Pointer front_{}; std::list> features_{}; size_t pass_{0}; IndexMeta::DataType data_type_{IndexMeta::DataType::DT_UNDEFINED}; size_t dimension_; size_t element_size_; size_t count_; }; IndexHolder::Pointer IndexHelper::MakeTwoPassHolder( IndexHolder::Pointer holder) { if (holder->multipass()) { return holder; } return IndexHolder::Pointer(new TwoPassIndexHolder(std::move(holder))); } } // namespace core } // namespace zvec ================================================ FILE: src/core/framework/index_logger.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include #include #include #include #include namespace zvec { namespace core { const int IndexLogger::LEVEL_DEBUG = 0; const int IndexLogger::LEVEL_INFO = 1; const int IndexLogger::LEVEL_WARN = 2; const int IndexLogger::LEVEL_ERROR = 3; const int IndexLogger::LEVEL_FATAL = 4; /*! Console Logger */ struct ConsoleLogger : public IndexLogger { //! Initialize Logger int init(const zvec::ailego::Params &) override { return 0; } //! Cleanup Logger int cleanup(void) override { return 0; } //! 
Log Message void log(int level, const char *file, int line, const char *format, va_list args) override { char buffer[8192]; std::ostringstream stream; ailego::Realtime::Localtime(buffer, sizeof(buffer)); stream << '[' << LevelString(level) << ' ' << buffer << ' ' << std::this_thread::get_id() << ' ' << ailego::File::BaseName(file) << ':' << line << "] "; vsnprintf(buffer, sizeof(buffer), format, args); stream << buffer << '\n'; if (level <= LEVEL_INFO) { std::cout << stream.str() << std::flush; } else { std::cerr << stream.str() << std::flush; } } }; //! Logger Level int IndexLoggerBroker::logger_level_ = 0; //! Logger IndexLogger::Pointer IndexLoggerBroker::logger_(new ConsoleLogger); } // namespace core } // namespace zvec ================================================ FILE: src/core/framework/index_mapping.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include #include #include #include #include "ailego/utility/memory_helper.h" #ifdef __linux__ #include #include #ifndef HUGETLBFS_MAGIC #define HUGETLBFS_MAGIC 0x958458f6 #endif #endif namespace zvec { namespace core { static inline size_t CalcPageAlignedSize(size_t size, bool huge_size) { size_t page_size = ailego::MemoryHelper::PageSize(); if (huge_size) { page_size = ailego::MemoryHelper::HugePageSize(); } return (size + page_size - 1) / page_size * page_size; } static inline bool WritePadding(ailego::File &file, size_t size) { std::string padding(ailego::MemoryHelper::PageSize(), 0); for (size_t i = 0, count = size / padding.size(); i < count; ++i) { if (file.write(padding.data(), padding.size()) != padding.size()) { return false; } } padding.resize(size % padding.size()); if (padding.size()) { if (file.write(padding.data(), padding.size()) != padding.size()) { return false; } } return true; } static inline int UnpackMappingSize(ailego::File &file, size_t *len) { IndexFormat::MetaHeader header; if (file.read(&header, sizeof(header)) != sizeof(header)) { LOG_ERROR("Failed to read file, errno %d, %s", errno, std::strerror(errno)); return IndexError_ReadData; } if (header.meta_header_size != sizeof(IndexFormat::MetaHeader) || header.meta_footer_size != sizeof(IndexFormat::MetaFooter)) { return IndexError_InvalidValue; } if (ailego::Crc32c::Hash(&header, sizeof(header), header.header_crc) != header.header_crc) { return IndexError_InvalidChecksum; } if ((int32_t)header.meta_footer_offset < 0) { return IndexError_Unsupported; } *len = header.meta_footer_offset + header.meta_footer_size; if (*len > file.size()) { return IndexError_InvalidLength; } return 0; } int IndexMapping::open(const std::string &path, bool cow, bool full_mode) { path_ = path; full_mode_ = full_mode; copy_on_write_ = cow; huge_page_ = Ishugetlbfs(path); bool read_only = copy_on_write_ && !full_mode_; if (!file_.open(path.c_str(), read_only, false)) { LOG_ERROR("Failed to open file %s, errno %d, 
%s", path.c_str(), errno, std::strerror(errno)); return IndexError_OpenFile; } size_t mapping_size = 0u; int error_code = UnpackMappingSize(file_, &mapping_size); if (error_code != 0) { file_.close(); return error_code; } if (!file_.seek(0, ailego::File::Origin::End)) { LOG_ERROR("Failed to seek file %s, errno %d, %s", path.c_str(), errno, std::strerror(errno)); return IndexError_SeekFile; } return this->init_index_mapping(mapping_size); } int IndexMapping::create(const std::string &path, size_t seg_meta_capacity) { path_ = path; seg_meta_capacity_ = seg_meta_capacity; current_header_start_offset_ = 0; // write() & copying to mmap() will auto extend the file size if (!file_.create(path.c_str(), 0)) { LOG_ERROR("Failed to create file %s, errno %d, %s", path.c_str(), errno, std::strerror(errno)); return IndexError_CreateFile; } huge_page_ = Ishugetlbfs(path); if (huge_page_) { return init_hugepage_meta_section(); } return init_meta_section(); } int IndexMapping::init_meta_section() { if (current_header_start_offset_ % ailego::MemoryHelper::PageSize() != 0) { LOG_ERROR("File offset %zu is not a multiple of the page size: %zu", current_header_start_offset_, ailego::MemoryHelper::PageSize()); return IndexError_InvalidValue; } auto &path = path_; size_t len = CalcPageAlignedSize(seg_meta_capacity_ + sizeof(IndexFormat::MetaHeader) + sizeof(IndexFormat::MetaFooter), false); IndexFormat::MetaHeader meta_header; IndexFormat::MetaFooter meta_footer; // Write index header IndexFormat::SetupMetaHeader(&meta_header, len - sizeof(meta_footer), len); if (!file_.seek(current_header_start_offset_, ailego::File::Origin::Begin)) { LOG_ERROR("Failed to seek file %s, errno %d, %s", path.c_str(), errno, std::strerror(errno)); return IndexError_SeekFile; } if (file_.write(&meta_header, sizeof(meta_header)) != sizeof(meta_header)) { LOG_ERROR("Failed to write file: %s, errno %d, %s", path.c_str(), errno, std::strerror(errno)); return IndexError_WriteData; } // Write padding data uint32_t 
segments_meta_size = static_cast(len - (sizeof(meta_header) + sizeof(meta_footer))); if (!WritePadding(file_, segments_meta_size)) { LOG_ERROR("Failed to write file: %s, errno %d, %s", path.c_str(), errno, std::strerror(errno)); return IndexError_WriteData; } // Write index footer IndexFormat::SetupMetaFooter(&meta_footer); meta_footer.segments_meta_size = segments_meta_size; meta_footer.total_size = len; IndexFormat::UpdateMetaFooter(&meta_footer, 0); if (file_.write(&meta_footer, sizeof(meta_footer)) != sizeof(meta_footer)) { LOG_ERROR("Failed to write file: %s, errno %d, %s", path.c_str(), errno, std::strerror(errno)); return IndexError_WriteData; } return this->init_index_mapping(len); } int IndexMapping::init_hugepage_meta_section() { ssize_t file_offset = (ssize_t)current_header_start_offset_; if (file_offset % ailego::MemoryHelper::HugePageSize() != 0) { LOG_ERROR("File offset %zu is not a multiple of the page size: %zu", file_offset, ailego::MemoryHelper::HugePageSize()); return IndexError_InvalidValue; } size_t len = CalcPageAlignedSize(seg_meta_capacity_ + sizeof(IndexFormat::MetaHeader) + sizeof(IndexFormat::MetaFooter), true); int opts = ailego::File::MMAP_SHARED | ailego::File::MMAP_HUGE_PAGE; void *addr = ailego::File::MemoryMap(file_.native_handle(), file_offset, len, opts); IndexFormat::MetaHeader meta_header; IndexFormat::MetaFooter meta_footer; // Write index header IndexFormat::SetupMetaHeader(&meta_header, len - sizeof(meta_footer), len); memcpy((char *)addr + file_offset, &meta_header, sizeof(meta_header)); file_offset += sizeof(meta_header); // Write padding data uint32_t segments_meta_size = static_cast(len - (sizeof(meta_header) + sizeof(meta_footer))); std::string padding(ailego::MemoryHelper::HugePageSize(), 0); for (size_t i = 0, count = segments_meta_size / padding.size(); i < count; ++i) { memcpy((char *)addr + file_offset, padding.data(), padding.size()); file_offset += padding.size(); } padding.resize(segments_meta_size % 
padding.size()); if (padding.size()) { memcpy((char *)addr + file_offset, padding.data(), padding.size()); file_offset += padding.size(); } // Write index footer IndexFormat::SetupMetaFooter(&meta_footer); meta_footer.segments_meta_size = segments_meta_size; meta_footer.total_size = len; IndexFormat::UpdateMetaFooter(&meta_footer, 0); memcpy((char *)addr + file_offset, &meta_footer, sizeof(meta_footer)); file_offset += sizeof(meta_footer); return this->init_index_mapping(len); } void IndexMapping::close(void) { // Unmap all memory this->unmap_all(); if (header_) { for (auto item : header_addr_map_) { auto header = item.second; ailego::File::MemoryUnmap(header, header->content_offset); } } // Reset members segment_ids_offset_ = 0; segment_start_ = nullptr; header_ = nullptr; header_addr_map_.clear(); footer_ = nullptr; index_size_ = 0u; segments_.clear(); file_.close(); copy_on_write_ = false; full_mode_ = false; header_dirty_ = false; huge_page_ = false; } void IndexMapping::refresh(uint64_t check_point) { // support add_with_id for (auto item : header_addr_map_) { auto header_start_offset = item.first; auto header = item.second; auto footer = reinterpret_cast( reinterpret_cast(header) + header->meta_footer_offset); auto segment_start = reinterpret_cast( reinterpret_cast(header) + (header->meta_footer_offset - footer->segments_meta_size)); footer->segments_meta_crc = ailego::Crc32c::Hash(segment_start, footer->segments_meta_size, 0); IndexFormat::UpdateMetaFooter(footer, check_point); } header_dirty_ = true; } int IndexMapping::append(const std::string &id, size_t size) { size = CalcPageAlignedSize(size, huge_page_); if (size == 0) { return IndexError_InvalidArgument; } if (segments_.find(id) != segments_.end()) { return IndexError_Duplicate; } size_t id_size = std::strlen(id.c_str()) + 1; size_t need_size = sizeof(IndexFormat::SegmentMeta) + id_size; if (sizeof(IndexFormat::SegmentMeta) * footer_->segment_count + need_size > segment_ids_offset_) { 
LOG_DEBUG("segment meta section expanded: %s", path_.c_str()); footer_->next_meta_header_offset = index_size_; refresh(0); flush(); // mmap file storage write() will update segment's meta // ailego::File::MemoryUnmap(header_, header_->content_offset); header_ = nullptr; footer_ = nullptr; current_header_start_offset_ = index_size_; const int ret = huge_page_ ? init_hugepage_meta_section() : init_meta_section(); if (ret != 0) { return ret; } } if (!copy_on_write_ && !file_.truncate(index_size_ + size)) { LOG_ERROR("Failed to truncate file, errno %d, %s", errno, std::strerror(errno)); return IndexError_TruncateFile; } // Update segment table segment_ids_offset_ -= static_cast(id_size); IndexFormat::SegmentMeta *segment = segment_start_ + footer_->segment_count; segment->segment_id_offset = segment_ids_offset_; segment->data_index = index_size_ - header_->content_offset - current_header_start_offset_; segment->data_size = 0; segment->data_crc = 0; segment->padding_size = size; memcpy((uint8_t *)segment_start_ + segment_ids_offset_, id.c_str(), id_size); index_size_ += size; // Update index footer footer_->segments_meta_crc = ailego::Crc32c::Hash(segment_start_, footer_->segments_meta_size, 0); footer_->segment_count += 1; footer_->content_size += size; footer_->total_size += size; IndexFormat::UpdateMetaFooter(footer_, 0); segments_.emplace( id, SegmentInfo{Segment{segment}, current_header_start_offset_, header_}); header_dirty_ = true; return 0; } IndexMapping::Segment *IndexMapping::map(const std::string &id, bool warmup, bool locked) { auto iter = segments_.find(id); if (iter == segments_.end()) { return nullptr; } SegmentInfo &segment_info = iter->second; Segment *item = &segment_info.segment; if (!item->data()) { auto meta = item->meta(); size_t mapping_size = meta->data_size + meta->padding_size; size_t offset = segment_info.segment_header_start_offset + segment_info.segment_header->content_offset + meta->data_index; void *addr = nullptr; if (!copy_on_write_) { 
int opts = ailego::File::MMAP_SHARED; if (huge_page_) { opts |= ailego::File::MMAP_HUGE_PAGE; } addr = ailego::File::MemoryMap(file_.native_handle(), offset, mapping_size, opts); } else { size_t file_size = file_.size(); int opts = ailego::File::MMAP_POPULATE; if (huge_page_) { opts |= ailego::File::MMAP_HUGE_PAGE; } if (offset < file_size) { ailego_assert(offset + mapping_size <= file_size); addr = ailego::File::MemoryMap(file_.native_handle(), offset, mapping_size, opts); } else { addr = ailego::File::MemoryMap(mapping_size, opts); } } if (!addr) { LOG_ERROR("Map segment failed, segment id %s", id.c_str()); return nullptr; } item->set_data(addr); // Lock memory if (locked) { ailego::File::MemoryLock(item->data(), mapping_size); } // Warmup memory if (warmup && meta->data_size) { ailego::File::MemoryWarmup(item->data(), meta->data_size); } } return item; } void IndexMapping::unmap(const std::string &id) { auto iter = segments_.find(id); if (iter != segments_.end()) { SegmentInfo &segment_info = iter->second; Segment *item = &segment_info.segment; if (item->data()) { ailego::File::MemoryUnmap( item->data(), item->meta()->data_size + item->meta()->padding_size); item->set_data(nullptr); } } } void IndexMapping::unmap_all(void) { for (auto iter = segments_.begin(); iter != segments_.end(); ++iter) { SegmentInfo &segment_info = iter->second; Segment *item = &segment_info.segment; if (item->data()) { ailego::File::MemoryUnmap( item->data(), item->meta()->data_size + item->meta()->padding_size); item->set_data(nullptr); } } } int IndexMapping::flush(void) { if ((file_.size() < index_size_) && !file_.truncate(index_size_)) { LOG_ERROR("Failed to truncate file size %zu, errno %d, %s", index_size_, errno, std::strerror(errno)); return IndexError_TruncateFile; } for (auto iter = segments_.begin(); iter != segments_.end(); ++iter) { SegmentInfo &segment_info = iter->second; Segment *item = &segment_info.segment; if (!item->data() || !item->dirty()) { continue; } size_t 
segment_size = item->meta()->data_size + item->meta()->padding_size; if (full_mode_ && copy_on_write_) { size_t off = segment_info.segment_header_start_offset + segment_info.segment_header->content_offset + item->meta()->data_index; if (file_.write(off, item->data(), segment_size) != segment_size) { LOG_ERROR("Failed to write segment, size %zu, errno %d, %s", segment_size, errno, std::strerror(errno)); return IndexError_WriteData; } } else { ailego::File::MemoryFlush(item->data(), segment_size); } item->reset_dirty(); } if (!header_dirty_) { return 0; } header_dirty_ = false; if (full_mode_ && copy_on_write_) { for (auto item : header_addr_map_) { auto header_start_offset = item.first; auto header = item.second; if (file_.write(header_start_offset, header, header->content_offset) != header->content_offset) { LOG_ERROR("Failed to write segment, size %lu, errno %d, %s", header->content_offset, errno, std::strerror(errno)); return IndexError_WriteData; } } } else { for (auto item : header_addr_map_) { auto header = item.second; ailego::File::MemoryFlush(header, header->content_offset); } } return 0; } int IndexMapping::init_index_mapping(size_t len) { int opts = copy_on_write_ ? 
ailego::File::MMAP_POPULATE : ailego::File::MMAP_SHARED; if (huge_page_) { opts |= ailego::File::MMAP_HUGE_PAGE; } uint8_t *start = reinterpret_cast(ailego::File::MemoryMap( file_.native_handle(), current_header_start_offset_, len, opts)); if (!start) { LOG_ERROR("Failed to map file, errno %d, %s", errno, std::strerror(errno)); return IndexError_MMapFile; } // Unpack header header_ = reinterpret_cast(start); header_addr_map_.insert({current_header_start_offset_, header_}); if (header_->meta_header_size != sizeof(IndexFormat::MetaHeader)) { return IndexError_InvalidLength; } if (ailego::Crc32c::Hash(header_, sizeof(*header_), header_->header_crc) != header_->header_crc) { return IndexError_InvalidChecksum; } switch (header_->version) { case IndexFormat::FORMAT_VERSION: break; default: LOG_ERROR("Unsupported index version: %u", header_->version); return IndexError_Unsupported; } // Unpack footer if (header_->meta_footer_size != sizeof(IndexFormat::MetaFooter)) { return IndexError_InvalidLength; } if ((int32_t)header_->meta_footer_offset < 0) { return IndexError_Unsupported; } size_t footer_offset = header_->meta_footer_offset; if (footer_offset + header_->meta_footer_size > len) { return IndexError_InvalidLength; } footer_ = reinterpret_cast(start + footer_offset); if (footer_offset < footer_->segments_meta_size) { return IndexError_InvalidLength; } index_size_ = file_.size(); if ((footer_->total_size > index_size_) || (footer_->content_size + footer_->content_padding_size + header_->content_offset > index_size_)) { return IndexError_InvalidLength; } if (ailego::Crc32c::Hash(footer_, sizeof(*footer_), footer_->footer_crc) != footer_->footer_crc) { return IndexError_InvalidChecksum; } // Unpack segment table if (sizeof(IndexFormat::SegmentMeta) * footer_->segment_count > footer_->segments_meta_size) { return IndexError_InvalidLength; } segment_start_ = reinterpret_cast( start + (footer_offset - footer_->segments_meta_size)); if (ailego::Crc32c::Hash(segment_start_, 
footer_->segments_meta_size, 0u) != footer_->segments_meta_crc) { LOG_ERROR("Index segments meta checksum is invalid."); return IndexError_InvalidChecksum; } segment_ids_offset_ = footer_->segments_meta_size; for (IndexFormat::SegmentMeta *iter = segment_start_, *end = segment_start_ + footer_->segment_count; iter != end; ++iter) { if (iter->segment_id_offset > footer_->segments_meta_size) { return IndexError_InvalidValue; } if (iter->data_index > footer_->content_size) { return IndexError_InvalidValue; } if (iter->data_index + iter->data_size > footer_->content_size) { return IndexError_InvalidLength; } if (iter->segment_id_offset < segment_ids_offset_) { segment_ids_offset_ = iter->segment_id_offset; } segments_.emplace( std::string(reinterpret_cast(segment_start_) + iter->segment_id_offset), SegmentInfo{Segment{iter}, current_header_start_offset_, header_}); } if (sizeof(IndexFormat::SegmentMeta) * footer_->segment_count > segment_ids_offset_) { return IndexError_InvalidLength; } // if (header_->version == IndexFormat::COMPATIBLE_FORMAT_VERSION_0X0002) { // header_->version = IndexFormat::CURRENT_FORMAT_VERSION; // LOG_INFO("Index file format upgraded"); // IndexFormat::UpdateMetaHeader(header_); // footer_->segments_meta_crc = // ailego::Crc32c::Hash(segment_start_, footer_->segments_meta_size, 0); // IndexFormat::UpdateMetaFooter(footer_, 0); // header_dirty_ = true; // } if (footer_->next_meta_header_offset > 0) { current_header_start_offset_ = footer_->next_meta_header_offset; // Meta sections have all the same size, so we can use the same size to map // the next meta section return this->init_index_mapping(len); } return 0; } bool IndexMapping::Ishugetlbfs(const std::string &path) const { #ifdef __linux__ struct statfs buf; if (statfs(path.c_str(), &buf) != 0) { perror("statfs"); return false; } return static_cast(buf.f_type) == HUGETLBFS_MAGIC; #else static_cast(path); return false; #endif } } // namespace core } // namespace zvec 
================================================ FILE: src/core/framework/index_meta.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include namespace zvec { namespace core { /*! Index Meta Buffer Format */ struct IndexMetaFormatHeader { uint32_t header_size; uint32_t meta_type; uint32_t major_order; uint32_t data_type; uint32_t dimension; uint32_t unit_size; uint32_t space_id; uint32_t attachment_offset; uint32_t attachment_size; uint8_t reserved_[4092]; }; static_assert(sizeof(IndexMetaFormatHeader) % 32 == 0, "IndexMetaBufferFormat must be aligned with 32 bytes"); void IndexMeta::serialize(std::string *out) const { ailego::Params attachment; IndexMetaFormatHeader format; memset(&format, 0, sizeof(format)); format.header_size = sizeof(format); format.meta_type = static_cast(meta_type_); format.major_order = static_cast(major_order_); format.data_type = static_cast(data_type_); format.dimension = dimension_; format.unit_size = unit_size_; format.space_id = space_id_; if (!metric_name_.empty()) { ailego::Params item; item.set("name", metric_name_); item.set("revision", metric_revision_); item.set("params", metric_params_); attachment.set("metric", std::move(item)); } if (!converter_name_.empty()) { ailego::Params item; item.set("name", converter_name_); item.set("revision", converter_revision_); item.set("params", converter_params_); attachment.set("converter", 
std::move(item)); } if (!reformer_name_.empty()) { ailego::Params item; item.set("name", reformer_name_); item.set("revision", reformer_revision_); item.set("params", reformer_params_); attachment.set("reformer", std::move(item)); } if (!trainer_name_.empty()) { ailego::Params item; item.set("name", trainer_name_); item.set("revision", trainer_revision_); item.set("params", trainer_params_); attachment.set("trainer", std::move(item)); } if (!builder_name_.empty()) { ailego::Params item; item.set("name", builder_name_); item.set("revision", builder_revision_); item.set("params", builder_params_); attachment.set("builder", std::move(item)); } if (!reducer_name_.empty()) { ailego::Params item; item.set("name", reducer_name_); item.set("revision", reducer_revision_); item.set("params", reducer_params_); attachment.set("reducer", std::move(item)); } if (!searcher_name_.empty()) { ailego::Params item; item.set("name", searcher_name_); item.set("revision", searcher_revision_); item.set("params", searcher_params_); attachment.set("searcher", std::move(item)); } if (!streamer_name_.empty()) { ailego::Params item; item.set("name", streamer_name_); item.set("revision", streamer_revision_); item.set("params", streamer_params_); attachment.set("streamer", std::move(item)); } if (!attributes_.empty()) { attachment.set("attributes", attributes_); } out->assign(reinterpret_cast(&format), sizeof(format)); size_t offset = static_cast(out->size()); if (!attachment.empty()) { std::string buf; ailego::Params::SerializeToBuffer(attachment, &buf); out->append(buf.data(), buf.size()); IndexMetaFormatHeader *header = (IndexMetaFormatHeader *)out->data(); header->attachment_offset = static_cast(offset); header->attachment_size = static_cast(buf.size()); offset += buf.size(); } } bool IndexMeta::deserialize(const void *data, size_t len) { const IndexMetaFormatHeader *format = reinterpret_cast(data); this->clear(); if (sizeof(IndexMetaFormatHeader) > len) { return false; } if 
(sizeof(IndexMetaFormatHeader) > format->header_size) { return false; } meta_type_ = static_cast(format->meta_type); major_order_ = static_cast(format->major_order); data_type_ = static_cast(format->data_type); dimension_ = format->dimension; unit_size_ = format->unit_size; element_size_ = IndexMeta::ElementSizeof(data_type_, unit_size_, dimension_); space_id_ = format->space_id; // Read attachment ailego::Params attachment; if (format->attachment_size) { if (format->attachment_offset + format->attachment_size > len) { return false; } std::string str( reinterpret_cast(data) + format->attachment_offset, format->attachment_size); if (!ailego::Params::ParseFromBuffer(str, &attachment)) { return false; } } ailego::Params item; if (attachment.get("metric", &item)) { item.get("name", &metric_name_); item.get("revision", &metric_revision_); item.get("params", &metric_params_); } if (attachment.get("converter", &item)) { item.get("name", &converter_name_); item.get("revision", &converter_revision_); item.get("params", &converter_params_); } if (attachment.get("reformer", &item)) { item.get("name", &reformer_name_); item.get("revision", &reformer_revision_); item.get("params", &reformer_params_); } if (attachment.get("trainer", &item)) { item.get("name", &trainer_name_); item.get("revision", &trainer_revision_); item.get("params", &trainer_params_); } if (attachment.get("builder", &item)) { item.get("name", &builder_name_); item.get("revision", &builder_revision_); item.get("params", &builder_params_); } if (attachment.get("reducer", &item)) { item.get("name", &reducer_name_); item.get("revision", &reducer_revision_); item.get("params", &reducer_params_); } if (attachment.get("searcher", &item)) { item.get("name", &searcher_name_); item.get("revision", &searcher_revision_); item.get("params", &searcher_params_); } if (attachment.get("streamer", &item)) { item.get("name", &streamer_name_); item.get("revision", &streamer_revision_); item.get("params", &streamer_params_); } 
attachment.get("attributes", &attributes_); return true; } } // namespace core } // namespace zvec ================================================ FILE: src/core/framework/index_plugin.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include namespace zvec { namespace core { bool IndexPlugin::load(const std::string &path) { if (handle_) { return false; } handle_ = ailego::DLHelper::Load(path, nullptr); return (!!handle_); } bool IndexPlugin::load(const std::string &path, std::string *err) { if (handle_) { *err = "plugin loaded"; return false; } handle_ = ailego::DLHelper::Load(path, err); return !!handle_; } void IndexPlugin::unload(void) { if (handle_) { ailego::DLHelper::Unload(handle_); handle_ = nullptr; } } bool IndexPluginBroker::emplace(IndexPlugin &&plugin) { if (!plugin.is_valid()) { return false; } for (auto iter = plugins_.begin(); iter != plugins_.end(); ++iter) { if (iter->handle() == plugin.handle()) { plugin.unload(); return true; } } plugins_.push_back(std::move(plugin)); return true; } } // namespace core } // namespace zvec ================================================ FILE: src/core/framework/index_version.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include namespace zvec { namespace core { static const char AITHETA_VERSION_DETAILS[] = AILEGO_VERSION_COMPILE_DETAILS("All rights reserved.\n"); const char *IndexVersion::String(void) { return AITHETA_VERSION_DETAILS; } const char *IndexVersion::Details(void) { return AITHETA_VERSION_DETAILS; } } // namespace core } // namespace zvec ================================================ FILE: src/core/interface/CMakeLists.txt ================================================ include(${PROJECT_ROOT_DIR}/cmake/bazel.cmake) include(${PROJECT_ROOT_DIR}/cmake/option.cmake) cc_library( NAME core_interface STATIC STRICT ALWAYS_LINK SRCS *.cc indexes/*.cc INCS . ${PROJECT_ROOT_DIR}/src/ ${PROJECT_ROOT_DIR}/src/core LIBS zvec_ailego core_framework sparsehash magic_enum rabitqlib VERSION "${PROXIMA_ZVEC_VERSION}" ) ================================================ FILE: src/core/interface/index.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include #include #include #include #include "mixed_reducer/mixed_reducer_params.h" namespace zvec::core_interface { // eliminate the pre-alloc of the context pool thread_local static std::array() - 1) * 2> _context_list; bool Index::init_context() { context_index_ = (magic_enum::enum_integer(param_.index_type) - 1) * 2 + static_cast(is_sparse_); if (_context_list[context_index_] == nullptr) { if ((_context_list[context_index_] = streamer_->create_context()) == nullptr) { LOG_ERROR("Failed to create context"); return false; } } return true; } core::IndexContext::Pointer &Index::acquire_context() { init_context(); return _context_list[context_index_]; } int Index::ParseMetricName(const BaseIndexParam ¶m) { std::string metric_name; if (is_sparse_) { // only inner product is supported for sparse index switch (param.metric_type) { case MetricType::kInnerProduct: metric_name = "InnerProductSparse"; break; case MetricType::kMIPSL2sq: metric_name = "MipsSquaredEuclideanSparse"; break; default: LOG_ERROR("Unsupported metric type"); return core::IndexError_Runtime; } } else { switch (param.metric_type) { case MetricType::kL2sq: metric_name = "SquaredEuclidean"; break; case MetricType::kInnerProduct: metric_name = "InnerProduct"; break; case MetricType::kCosine: metric_name = "Cosine"; // This is already the normalizedCosine break; case MetricType::kMIPSL2sq: metric_name = "MipsSquaredEuclidean"; break; default: LOG_ERROR("Unsupported metric type"); return core::IndexError_Runtime; } } // TODO: MIPS need to set some param // for streamer open() proxima_index_meta_.set_metric(metric_name, 0, ailego::Params()); return 0; } int Index::CreateAndInitMetric(const BaseIndexParam & /*param*/) { auto &metric_name = proxima_index_meta_.metric_name(); metric_ = core::IndexFactory::CreateMetric(metric_name); if (!metric_) { LOG_ERROR("Failed to create metric, name %s", metric_name.c_str()); return core::IndexError_Runtime; } if (const auto ret = metric_->init(proxima_index_meta_, 
proxima_index_meta_.metric_params()); ret != 0) { LOG_ERROR("Failed to create and init metric, name %s, code %d, desc: %s", metric_name.c_str(), ret, core::IndexError::What(ret)); return core::IndexError_Runtime; } if (metric_->query_metric()) { metric_ = metric_->query_metric(); } return core::IndexError_Success; } int Index::CreateAndInitConverterReformer(const QuantizerParam ¶m, const BaseIndexParam &index_param) { ailego::Params converter_params; std::string converter_name; if (is_sparse_) { switch (param.type) { case QuantizerType::kNone: return core::IndexError_Success; case QuantizerType::kFP16: converter_name = "HalfFloatSparseConverter"; break; default: LOG_ERROR("Unsupported quantizer type: "); return core::IndexError_Unsupported; } } else { if (index_param.metric_type == MetricType::kCosine) { switch (param.type) { case QuantizerType::kNone: if (index_param.data_type == DataType::DT_FP16) { converter_name = "CosineHalfFloatConverter"; } else if (index_param.data_type == DataType::DT_FP32) { converter_name = "CosineNormalizeConverter"; } else { LOG_ERROR("Unsupported data type: "); return core::IndexError_Unsupported; } break; case QuantizerType::kRabitq: if (index_param.data_type == DataType::DT_FP32) { converter_name = "CosineNormalizeConverter"; } else { LOG_ERROR("Unsupported data type: "); return core::IndexError_Unsupported; } break; case QuantizerType::kFP16: converter_name = "CosineFp16Converter"; break; case QuantizerType::kInt8: converter_name = "CosineInt8Converter"; break; case QuantizerType::kInt4: converter_name = "CosineInt4Converter"; break; default: LOG_ERROR("Unsupported quantizer type: "); return core::IndexError_Unsupported; } } else { switch (param.type) { case QuantizerType::kNone: return core::IndexError_Success; case QuantizerType::kFP16: converter_name = "HalfFloatConverter"; break; case QuantizerType::kInt8: converter_name = "Int8StreamingConverter"; break; case QuantizerType::kInt4: converter_name = "Int4StreamingConverter"; 
break; case QuantizerType::kRabitq: // no converter here return 0; default: LOG_ERROR("Unsupported quantizer type: "); return core::IndexError_Unsupported; } } } proxima_index_meta_.set_converter(converter_name, 0, converter_params); converter_ = core::IndexFactory::CreateConverter(converter_name); if (converter_ == nullptr || converter_->init(proxima_index_meta_, converter_params) != 0) { LOG_ERROR("Failed to create and init converter"); return core::IndexError_Runtime; } proxima_index_meta_ = converter_->meta(); reformer_ = core::IndexFactory::CreateReformer(proxima_index_meta_.reformer_name()); if (reformer_ == nullptr || reformer_->init(proxima_index_meta_.reformer_params()) != 0) { LOG_ERROR("Failed to create and init reformer"); return core::IndexError_Runtime; } streamer_vector_meta_.set_meta(proxima_index_meta_.data_type(), proxima_index_meta_.dimension()); streamer_vector_meta_.set_meta_type(proxima_index_meta_.meta_type()); return core::IndexError_Success; } int Index::Init(const BaseIndexParam ¶m) { param_ = param; // will lose the original type info is_sparse_ = param.is_sparse; is_huge_page_ = param.is_huge_page; proxima_index_meta_.set_meta(param.data_type, param.dimension); proxima_index_meta_.set_meta_type(is_sparse_ ? IndexMeta::MetaType::MT_SPARSE : IndexMeta::MetaType::MT_DENSE); input_vector_meta_.set_meta(proxima_index_meta_.data_type(), proxima_index_meta_.dimension()); input_vector_meta_.set_meta_type(proxima_index_meta_.meta_type()); streamer_vector_meta_ = input_vector_meta_; // when quantizer=int8/int4, the converter.init() will change the metric to // QuantizedInteger with params if (ParseMetricName(param) != 0) { LOG_ERROR("Failed to parse metric name"); return core::IndexError_Runtime; } if (CreateAndInitConverterReformer(param.quantizer_param, param) != 0) { LOG_ERROR("Failed to create and init converter"); return core::IndexError_Runtime; } // must after quantizer handled. 
e.g., cosine doesn't support int8 quantizer if (CreateAndInitMetric(param) != 0) { LOG_ERROR("Failed to create and init metric"); return core::IndexError_Runtime; } if (CreateAndInitStreamer(param) != 0) { LOG_ERROR("Failed to create and init streamer"); return core::IndexError_Runtime; } return 0; } int Index::Open(const std::string &file_path, StorageOptions storage_options) { ailego::Params storage_params; // storage_params.set("proxima.mmap_file.storage.memory_warmup", true); // storage_params.set("proxima.mmap_file.storage.segment_meta_capacity", // 1024); switch (storage_options.type) { case StorageOptions::StorageType::kMMAP: { storage_ = core::IndexFactory::CreateStorage("MMapFileStorage"); if (storage_ == nullptr) { LOG_ERROR("Failed to create MMapFileStorage"); return core::IndexError_Runtime; } int ret = storage_->init(storage_params); if (ret != 0) { LOG_ERROR("Failed to init MMapFileStorage, path: %s, err: %s", file_path.c_str(), core::IndexError::What(ret)); return ret; } break; } case StorageOptions::StorageType::kBufferPool: { storage_ = core::IndexFactory::CreateStorage("BufferStorage"); if (storage_ == nullptr) { LOG_ERROR("Failed to create BufferStorage"); return core::IndexError_Runtime; } int ret = storage_->init(storage_params); if (ret != 0) { LOG_ERROR("Failed to init BufferStorage, path: %s, err: %s", file_path.c_str(), core::IndexError::What(ret)); return ret; } break; } default: LOG_ERROR("Unsupported storage type"); return core::IndexError_Unsupported; } // read_options.create_new int ret = storage_->open(file_path, storage_options.create_new); if (ret != 0) { LOG_ERROR("Failed to open storage, path: %s, err: %s", file_path.c_str(), core::IndexError::What(ret)); return core::IndexError_Runtime; } if (streamer_ == nullptr || streamer_->open(storage_) != 0) { LOG_ERROR("Failed to open streamer, path: %s", file_path.c_str()); return core::IndexError_Runtime; } // converter/reformer/metric are created in IndexFactory::CreateIndex // TODO: 
init // TODO: context pool if (!init_context()) { // to validate if any error, will be overwritten LOG_ERROR("Failed to init context"); return core::IndexError_Runtime; } is_open_ = true; is_read_only_ = storage_options.read_only; return 0; } int Index::Close() { if (!is_open_) { LOG_ERROR("Index is not open"); return core::IndexError_Runtime; } if (!is_read_only_) { if (ailego_unlikely(Flush() != 0)) { LOG_ERROR("Failed to cleanup streamer"); return core::IndexError_Runtime; } } if (ailego_unlikely(streamer_->cleanup() != 0)) { LOG_ERROR("Failed to cleanup streamer"); return core::IndexError_Runtime; } if (ailego_unlikely(storage_->close() != 0)) { LOG_ERROR("Failed to close storage"); return core::IndexError_Runtime; } is_open_ = false; return 0; } int Index::Flush() { if (!is_open_) { LOG_ERROR("Index is not open"); return core::IndexError_Runtime; } if (is_read_only_) { LOG_ERROR("Cannot flush read-only index"); return core::IndexError_Runtime; } if (ailego_unlikely(streamer_->flush(0) != 0)) { LOG_ERROR("Failed to flush streamer"); return core::IndexError_Runtime; } if (ailego_unlikely(storage_->flush() != 0)) { LOG_ERROR("Failed to flush storage"); return core::IndexError_Runtime; } return 0; } int Index::Fetch(const uint32_t doc_id, VectorDataBuffer *vector_data_buffer) { if (!is_open_) { LOG_ERROR("Index is not open"); return core::IndexError_Runtime; } if (is_sparse_) { return _sparse_fetch(doc_id, vector_data_buffer); } return _dense_fetch(doc_id, vector_data_buffer); } int Index::Add(const VectorData &vector_data, const uint32_t doc_id) { if (!is_open_) { LOG_ERROR("Index is not open"); return core::IndexError_Runtime; } if (is_read_only_) { LOG_ERROR("Cannot add to read-only index"); return core::IndexError_Runtime; } auto &context = acquire_context(); if (!context) { LOG_ERROR("Failed to acquire context"); return core::IndexError_Runtime; } int ret = 0; if (is_sparse_) { ret = _sparse_add(vector_data, doc_id, context); } else { ret = 
_dense_add(vector_data, doc_id, context); } context->reset(); return ret; } int Index::Search(const VectorData &vector_data, const BaseIndexQueryParam::Pointer &search_param, SearchResult *result) { if (!is_open_) { LOG_ERROR("Index is not open"); return core::IndexError_Runtime; } if (!is_trained_ && this->Train() != 0) { LOG_ERROR("Failed to train index"); return core::IndexError_Runtime; } auto &context = acquire_context(); if (!context) { LOG_ERROR("Failed to acquire context"); return core::IndexError_Runtime; } if (_prepare_for_search(vector_data, search_param, context) != 0) { LOG_ERROR("Failed to prepare for search"); context->reset(); return core::IndexError_Runtime; } if (is_sparse_) { int ret = _sparse_search(vector_data, search_param, result, context); context->reset(); return ret; } // dense support refiner, but sparse doesn't int ret = 0; if (search_param->refiner_param == nullptr) { ret = _dense_search(vector_data, search_param, result, context); context->reset(); } else { auto &reference_index = search_param->refiner_param->reference_index; if (reference_index == nullptr) { LOG_ERROR("Reference index is not set"); context->reset(); return core::IndexError_Runtime; } // TODO: tackle query_param's type info loss to loosen the constraint if (reference_index->param_.index_type != IndexType::kFlat) { LOG_ERROR("Reference index is not flat"); context->reset(); return core::IndexError_Runtime; } context->set_topk(_get_coarse_search_topk(search_param)); context->set_fetch_vector(false); // no need to fetch vector if (_dense_search(vector_data, search_param, result, context) != 0) { LOG_ERROR("Failed to search"); context->reset(); return core::IndexError_Runtime; } auto &base_result = context->result(); std::vector keys(base_result.size()); for (size_t i = 0; i < base_result.size(); ++i) { keys[i] = base_result[i].key(); } FlatQueryParam::Pointer flat_search_param = std::make_shared(); flat_search_param->topk = search_param->topk; 
flat_search_param->fetch_vector = search_param->fetch_vector; flat_search_param->filter = search_param->filter; // TODO: should copy other params? flat_search_param->bf_pks = std::make_shared>(keys); ret = reference_index->Search(vector_data, flat_search_param, result); } context->reset(); return ret; } int Index::_dense_fetch(const uint32_t doc_id, VectorDataBuffer *vector_data_buffer) { core::IndexStorage::MemoryBlock vector_block; int ret = streamer_->get_vector_by_id(doc_id, vector_block); if (ret != 0) { LOG_ERROR("Failed to fetch vector, doc_id: %u", doc_id); return core::IndexError_Runtime; } const void *vector = vector_block.data(); DenseVectorBuffer dense_vector_buffer; std::string &out_vector_buffer = dense_vector_buffer.data; // for int4, unit_size * dim != element_size out_vector_buffer.resize(input_vector_meta_.element_size()); if (reformer_ != nullptr) { if (reformer_->revert(vector, streamer_vector_meta_, &out_vector_buffer) != 0) { LOG_ERROR("Failed to convert vector"); return core::IndexError_Runtime; } } else { out_vector_buffer = std::string( static_cast(vector), input_vector_meta_.dimension() * input_vector_meta_.unit_size()); } vector_data_buffer->vector_buffer = std::move(dense_vector_buffer); return 0; } int Index::_sparse_fetch(const uint32_t doc_id, VectorDataBuffer *vector_data_buffer) { SparseVectorBuffer sparse_vector_buffer; if (0 != streamer_->get_sparse_vector_by_id( doc_id, &sparse_vector_buffer.count, &sparse_vector_buffer.indices, &sparse_vector_buffer.values)) { LOG_ERROR("Failed to fetch vector"); return core::IndexError_Runtime; } if (reformer_ != nullptr) { std::string reverted_sparse_values_buffer; if (reformer_->revert( sparse_vector_buffer.count, sparse_vector_buffer.get_indices(), sparse_vector_buffer.get_values(), streamer_vector_meta_, &reverted_sparse_values_buffer) != 0) { LOG_ERROR("Failed to convert vector"); return core::IndexError_Runtime; } sparse_vector_buffer.values = std::move(reverted_sparse_values_buffer); } 
vector_data_buffer->vector_buffer = std::move(sparse_vector_buffer); return 0; } int Index::_dense_add(const VectorData &vector_data, const uint32_t doc_id, core::IndexContext::Pointer &context) { if (!std::holds_alternative(vector_data.vector)) { LOG_ERROR("Invalid vector data"); return core::IndexError_Runtime; } const DenseVector &dense_vector = std::get(vector_data.vector); if (reformer_ != nullptr) { core::IndexQueryMeta new_meta; std::string new_vector; int ret; ret = reformer_->convert(dense_vector.data, input_vector_meta_, &new_vector, &new_meta); if (ret != 0) { LOG_ERROR("Failed to convert vector"); return core::IndexError_Runtime; } ret = streamer_->add_with_id_impl(doc_id, new_vector.data(), new_meta, context); if (ret != 0) { LOG_ERROR("Failed to add vector"); return core::IndexError_Runtime; } } else { int ret = streamer_->add_with_id_impl(doc_id, dense_vector.data, input_vector_meta_, context); if (ret != 0) { LOG_ERROR("Failed to add vector"); return core::IndexError_Runtime; } } return 0; } int Index::_sparse_add(const VectorData &vector_data, const uint32_t doc_id, core::IndexContext::Pointer &context) { if (!std::holds_alternative(vector_data.vector)) { LOG_ERROR("Invalid vector data"); return core::IndexError_Runtime; } const SparseVector &sparse_vector = std::get(vector_data.vector); if (reformer_ != nullptr) { std::string converted_sparse_values_buffer; core::IndexQueryMeta new_meta; int ret; ret = reformer_->convert(sparse_vector.count, sparse_vector.get_indices(), sparse_vector.get_values(), input_vector_meta_, &converted_sparse_values_buffer, &new_meta); if (ret != 0) { LOG_ERROR("Failed to convert vector"); return core::IndexError_Runtime; } ret = streamer_->add_with_id_impl( doc_id, sparse_vector.count, sparse_vector.get_indices(), converted_sparse_values_buffer.data(), new_meta, context); if (ret != 0) { LOG_ERROR("Failed to add vector"); return core::IndexError_Runtime; } } else { int ret = streamer_->add_with_id_impl( doc_id, 
sparse_vector.count, sparse_vector.get_indices(), sparse_vector.get_values(), input_vector_meta_, context); if (ret != 0) { LOG_ERROR("Failed to add vector"); return core::IndexError_Runtime; } } return 0; } int Index::_dense_search(const VectorData &vector_data, const BaseIndexQueryParam::Pointer &search_param, SearchResult *result, core::IndexContext::Pointer &context) { if (!std::holds_alternative(vector_data.vector)) { LOG_ERROR("Invalid vector data"); return core::IndexError_Runtime; } const DenseVector &dense_vector = std::get(vector_data.vector); auto vector = dense_vector.data; // Check if need to transform feature std::string new_vector; core::IndexQueryMeta new_meta = input_vector_meta_; if (reformer_ != nullptr) { if (reformer_->transform(dense_vector.data, input_vector_meta_, &new_vector, &new_meta) != 0) { LOG_ERROR("Failed to transform vector"); return core::IndexError_Runtime; } vector = new_vector.data(); } // TODO: group by if (search_param->bf_pks != nullptr) { // should we eliminate the copy of bf_pks? 
if (streamer_->search_bf_by_p_keys_impl( vector, std::vector>{*search_param->bf_pks}, new_meta, 1, context) != 0) { LOG_ERROR("Failed to search_bf_by_p_keys_impl vector"); return core::IndexError_Runtime; } } else if (search_param->is_linear) { if (streamer_->search_bf_impl(vector, new_meta, 1, context) != 0) { LOG_ERROR("Failed to search vector"); return core::IndexError_Runtime; } } else { if (streamer_->search_impl(vector, new_meta, 1, context) != 0) { LOG_ERROR("Failed to search vector"); return core::IndexError_Runtime; } } result->doc_list_ = std::move(context->result()); if (metric_->support_normalize()) { for (uint32_t i = 0; i < result->doc_list_.size(); ++i) { metric_->normalize(result->doc_list_[i].mutable_score()); } } if (reformer_) { if (reformer_->normalize(dense_vector.data, input_vector_meta_, result->doc_list_) != 0) { LOG_ERROR("Failed to normalize vector"); return core::IndexError_Runtime; } if (context->fetch_vector() && reformer_->need_revert()) { // TODO: use std::pmr to optimize memory allocation result->reverted_vector_list_.resize(context->result().size()); for (uint32_t i = 0; i < context->result().size(); ++i) { std::string &reverted_vector = result->reverted_vector_list_[i]; reverted_vector.resize(input_vector_meta_.dimension() * input_vector_meta_.unit_size()); if (reformer_->revert(context->result()[i].vector(), new_meta, &reverted_vector) != 0) { LOG_ERROR("Failed to revert vector"); return core::IndexError_Runtime; } } } } return 0; } int Index::_sparse_search(const VectorData &vector_data, const BaseIndexQueryParam::Pointer &search_param, SearchResult *result, core::IndexContext::Pointer &context) { if (!std::holds_alternative(vector_data.vector)) { LOG_ERROR("Invalid vector data"); return core::IndexError_Runtime; } const SparseVector &sparse_vector = std::get(vector_data.vector); auto indices = sparse_vector.get_indices(); auto values = sparse_vector.get_values(); std::string converted_sparse_values_buffer; core::IndexQueryMeta 
new_meta = input_vector_meta_; if (reformer_ != nullptr) { if (reformer_->transform(sparse_vector.count, indices, values, input_vector_meta_, &converted_sparse_values_buffer, &new_meta) != 0) { LOG_ERROR("Failed to transform vector"); return core::IndexError_Runtime; } values = converted_sparse_values_buffer.data(); } if (search_param->bf_pks != nullptr) { if (streamer_->search_bf_by_p_keys_impl( sparse_vector.count, indices, values, std::vector>{*search_param->bf_pks}, new_meta, context) != 0) { LOG_ERROR("Failed to search_bf_by_p_keys_impl vector"); return core::IndexError_Runtime; } } else if (search_param->is_linear) { if (streamer_->search_bf_impl(sparse_vector.count, indices, values, new_meta, context) != 0) { LOG_ERROR("Failed to search vector"); return core::IndexError_Runtime; } } else { if (streamer_->search_impl(sparse_vector.count, indices, values, new_meta, context) != 0) { LOG_ERROR("Failed to search vector"); return core::IndexError_Runtime; } } result->doc_list_ = std::move(context->result()); if (metric_->support_normalize()) { for (uint32_t i = 0; i < result->doc_list_.size(); ++i) { metric_->normalize(result->doc_list_[i].mutable_score()); } } if (reformer_) { // TODO: no need to call reformer_->normalize() when sparse? 
if (context->fetch_vector() && reformer_->need_revert()) { // TODO: use std::pmr to optimize memory allocation auto &result_doc_list = context->result(); result->reverted_sparse_values_list_.resize(result_doc_list.size()); for (uint32_t i = 0; i < result_doc_list.size(); ++i) { auto &result_doc = result_doc_list[i].sparse_doc(); std::string &reverted_sparse_values = result->reverted_sparse_values_list_[i]; reverted_sparse_values.resize(result_doc.sparse_count() * input_vector_meta_.unit_size()); if (reformer_->revert(result_doc.sparse_count(), reinterpret_cast( result_doc.sparse_indices().data()), reinterpret_cast( result_doc.sparse_values().data()), new_meta, &reverted_sparse_values) != 0) { LOG_ERROR("Failed to revert sparse vector"); return core::IndexError_Runtime; } } } } return 0; } int Index::Merge(const std::vector &indexes, const IndexFilter &filter, const MergeOptions &options) { if (indexes.empty()) { return core::IndexError_Success; } // ivf need builder auto reducer = core::IndexFactory::CreateStreamerReducer("MixedStreamerReducer"); if (reducer == nullptr) { LOG_ERROR("Failed to create reducer"); return core::IndexError_Runtime; } if (options.write_concurrency == 0) { LOG_ERROR("Write concurrency must be greater than 0"); return core::IndexError_InvalidArgument; } // must declare here to ensure its lifespan can cover reducer->reduce() std::unique_ptr local_thread_pool = nullptr; if (options.pool != nullptr) { reducer->set_thread_pool(options.pool); } else { local_thread_pool = std::make_unique(options.write_concurrency); reducer->set_thread_pool(local_thread_pool.get()); } ailego::Params reducer_params; reducer_params.set(core::PARAM_MIXED_STREAMER_REDUCER_ENABLE_PK_REWRITE, true); reducer_params.set(core::PARAM_MIXED_STREAMER_REDUCER_NUM_OF_ADD_THREADS, options.write_concurrency); if (reducer->init(reducer_params) != 0) { LOG_ERROR("Failed to init reducer"); return core::IndexError_Runtime; } if (reducer->set_target_streamer_wiht_info(builder_, 
streamer_, converter_, reformer_, input_vector_meta_) != 0) { LOG_ERROR("Failed to set target streamer"); return core::IndexError_Runtime; } for (const auto &index : indexes) { if (reducer->feed_streamer_with_reformer(index->streamer_, index->reformer_) != 0) { LOG_ERROR("Failed to feed streamer"); return core::IndexError_Runtime; } } if (reducer->reduce(filter) != 0) { LOG_ERROR("Failed to reduce"); return core::IndexError_Runtime; } is_trained_ = true; return 0; } int Index::_get_coarse_search_topk( const BaseIndexQueryParam::Pointer &search_param) { float scale_factor = search_param->refiner_param->scale_factor_; if (scale_factor == 0) { scale_factor = 1; } return floor(search_param->topk * scale_factor); } std::string Index::get_metric_name(MetricType metric_type, bool is_sparse) { if (is_sparse) { switch (metric_type) { case MetricType::kInnerProduct: return "InnerProductSparse"; case MetricType::kMIPSL2sq: return "MipsSquaredEuclideanSparse"; default: return ""; } } else { switch (metric_type) { case MetricType::kL2sq: return "SquaredEuclidean"; case MetricType::kInnerProduct: return "InnerProduct"; case MetricType::kCosine: return "Cosine"; case MetricType::kMIPSL2sq: return "MipsSquaredEuclidean"; default: return ""; } } } } // namespace zvec::core_interface ================================================ FILE: src/core/interface/index_factory.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and // limitations under the License. #include #include #include #include #include #include "core/interface/utils/utils.h" namespace zvec::core_interface { Index::Pointer IndexFactory::CreateAndInitIndex(const BaseIndexParam ¶m) { Index::Pointer ptr = nullptr; // if (param.index_type == IndexType::kIVF) { // const IVFIndexParam *_param = dynamic_cast(¶m); ptr = std::make_shared(param); // if (_param->l1Index) { // // TODO: create l1 index // } // if (_param->l2Index) { // // TODO: create l2 index // } // } // if (param.index_type == IndexType::kHNSW) { // ptr = std::make_shared(param); // } if (param.index_type == IndexType::kFlat) { // ptr = std::make_shared(param); ptr = std::make_shared(); } else if (param.index_type == IndexType::kHNSW) { ptr = std::make_shared(); } else if (param.index_type == IndexType::kIVF) { ptr = std::make_shared(); } else if (param.index_type == IndexType::kHNSWRabitq) { ptr = std::make_shared(); } else { LOG_ERROR("Unsupported index type: "); return nullptr; } if (!ptr) { LOG_ERROR("Failed to create index"); return nullptr; } if (0 != ptr->Init(param)) { LOG_ERROR("Failed to init index"); return nullptr; } return ptr; } BaseIndexParam::Pointer IndexFactory::DeserializeIndexParamFromJson( const std::string &json_str) { ailego::JsonValue json_value; if (!json_value.parse(json_str)) { LOG_ERROR("Failed to parse json string: %s", json_str.c_str()); return nullptr; } ailego::JsonObject json_obj = json_value.as_object(); ailego::JsonValue tmp_json_value; IndexType index_type; if (!extract_enum_from_json(json_obj, "index_type", index_type, tmp_json_value)) { LOG_ERROR("Failed to deserialize index type"); return nullptr; } switch (index_type) { case IndexType::kFlat: { FlatIndexParam::Pointer param = std::make_shared(); if (!param->DeserializeFromJson(json_str)) { LOG_ERROR("Failed to deserialize flat index param"); return nullptr; } return param; } case IndexType::kHNSW: { 
HNSWIndexParam::Pointer param = std::make_shared(); if (!param->DeserializeFromJson(json_str)) { LOG_ERROR("Failed to deserialize hnsw index param"); return nullptr; } return param; } case IndexType::kIVF: { IVFIndexParam::Pointer param = std::make_shared(); if (!param->DeserializeFromJson(json_str)) { LOG_ERROR("Failed to deserialize hnsw index param"); return nullptr; } return param; } case IndexType::kHNSWRabitq: { HNSWRabitqIndexParam::Pointer param = std::make_shared(); if (!param->DeserializeFromJson(json_str)) { LOG_ERROR("Failed to deserialize hnsqrabitq index param"); return nullptr; } return param; } default: LOG_ERROR("Unsupported index type: %s", magic_enum::enum_name(index_type).data()); return nullptr; } } template , bool> > std::string IndexFactory::QueryParamSerializeToJson(const QueryParamType ¶m, bool omit_empty_value) { ailego::JsonObject json_obj; // BaseIndexQueryParam // omit filter & bf_pks if (!omit_empty_value || param.topk != 0) { json_obj.set("topk", ailego::JsonValue(param.topk)); } if (!omit_empty_value || param.fetch_vector) { json_obj.set("fetch_vector", ailego::JsonValue(param.fetch_vector)); } if (!omit_empty_value || param.radius != 0.0f) { json_obj.set("radius", ailego::JsonValue(param.radius)); } if (!omit_empty_value || param.is_linear) { json_obj.set("is_linear", ailego::JsonValue(param.is_linear)); } IndexType index_type{IndexType::kNone}; if constexpr (std::is_same_v) { // index_type index_type = IndexType::kFlat; } else if constexpr (std::is_same_v) { if (!omit_empty_value || param.ef_search != 0) { json_obj.set("ef_search", ailego::JsonValue(param.ef_search)); } index_type = IndexType::kHNSW; } else if constexpr (std::is_same_v) { if (!omit_empty_value || param.nprobe != 0) { json_obj.set("nprobe", ailego::JsonValue(param.nprobe)); } index_type = IndexType::kIVF; // json_obj.set("l1QueryParam", // ailego::JsonValue(QueryParamSerializeToJson(param.l1QueryParam))); // json_obj.set("l2QueryParam", // 
ailego::JsonValue(QueryParamSerializeToJson(param.l2QueryParam))); } else if constexpr (std::is_same_v) { if (!omit_empty_value || param.ef_search != 0) { json_obj.set("ef_search", ailego::JsonValue(param.ef_search)); } index_type = IndexType::kHNSWRabitq; } json_obj.set("index_type", ailego::JsonValue(magic_enum::enum_name(index_type).data())); return ailego::JsonValue(json_obj).as_json_string().as_stl_string(); } template std::string IndexFactory::QueryParamSerializeToJson( const BaseIndexQueryParam ¶m, bool omit_empty_value); template std::string IndexFactory::QueryParamSerializeToJson( const FlatQueryParam ¶m, bool omit_empty_value); template std::string IndexFactory::QueryParamSerializeToJson( const HNSWQueryParam ¶m, bool omit_empty_value); template std::string IndexFactory::QueryParamSerializeToJson( const IVFQueryParam ¶m, bool omit_empty_value); template , bool> > typename QueryParamType::Pointer IndexFactory::QueryParamDeserializeFromJson( const std::string &json_str) { ailego::JsonValue tmp_json_value; if (!tmp_json_value.parse(json_str)) { LOG_ERROR("Failed to parse json string: %s", json_str.c_str()); return nullptr; } ailego::JsonObject json_obj = tmp_json_value.as_object(); auto parse_common_fields = [&](auto ¶m) -> bool { if (!extract_value_from_json(json_obj, "topk", param->topk, tmp_json_value)) { LOG_ERROR("Failed to deserialize topk"); return false; } if (!extract_value_from_json(json_obj, "fetch_vector", param->fetch_vector, tmp_json_value)) { LOG_ERROR("Failed to deserialize fetch_vector"); return false; } if (!extract_value_from_json(json_obj, "radius", param->radius, tmp_json_value)) { LOG_ERROR("Failed to deserialize radius"); return false; } if (!extract_value_from_json(json_obj, "is_linear", param->is_linear, tmp_json_value)) { LOG_ERROR("Failed to deserialize is_linear"); return false; } return true; }; IndexType index_type; if (!extract_enum_from_json(json_obj, "index_type", index_type, tmp_json_value)) { LOG_ERROR("Failed to 
deserialize index type"); return nullptr; } if constexpr (std::is_same_v) { if (index_type == IndexType::kFlat) { auto param = std::make_shared(); if (!parse_common_fields(param)) { return nullptr; } return param; } else if (index_type == IndexType::kHNSW) { auto param = std::make_shared(); if (!parse_common_fields(param)) { return nullptr; } if (!extract_value_from_json(json_obj, "ef_search", param->ef_search, tmp_json_value)) { LOG_ERROR("Failed to deserialize ef_search"); return nullptr; } return param; } else if (index_type == IndexType::kIVF) { auto param = std::make_shared(); if (!parse_common_fields(param)) { return nullptr; } if (!extract_value_from_json(json_obj, "nprobe", param->nprobe, tmp_json_value)) { LOG_ERROR("Failed to deserialize nprobe"); return nullptr; } return param; } else if (index_type == IndexType::kHNSWRabitq) { auto param = std::make_shared(); if (!parse_common_fields(param)) { return nullptr; } if (!extract_value_from_json(json_obj, "ef_search", param->ef_search, tmp_json_value)) { LOG_ERROR("Failed to deserialize ef_search"); return nullptr; } return param; } else { LOG_ERROR("Unsupported index type: %s", magic_enum::enum_name(index_type).data()); return nullptr; } } else { auto param = std::make_shared(); if (!parse_common_fields(param)) { return nullptr; } if constexpr (std::is_same_v) { } else if constexpr (std::is_same_v) { if (!extract_value_from_json(json_obj, "ef_search", param->ef_search, tmp_json_value)) { LOG_ERROR("Failed to deserialize ef_search"); return nullptr; } } else if constexpr (std::is_same_v) { if (!extract_value_from_json(json_obj, "nprobe", param->nprobe, tmp_json_value)) { LOG_ERROR("Failed to deserialize nprobe"); return nullptr; } } else if constexpr (std::is_same_v) { if (!extract_value_from_json(json_obj, "ef_search", param->ef_search, tmp_json_value)) { LOG_ERROR("Failed to deserialize ef_search"); return nullptr; } } else { LOG_ERROR("Unsupported index type: %s", magic_enum::enum_name(index_type).data()); 
return nullptr; } return param; } } template BaseIndexQueryParam::Pointer IndexFactory::QueryParamDeserializeFromJson( const std::string &json_str); template FlatQueryParam::Pointer IndexFactory::QueryParamDeserializeFromJson< FlatQueryParam>(const std::string &json_str); template HNSWQueryParam::Pointer IndexFactory::QueryParamDeserializeFromJson< HNSWQueryParam>(const std::string &json_str); template IVFQueryParam::Pointer IndexFactory::QueryParamDeserializeFromJson< IVFQueryParam>(const std::string &json_str); } // namespace zvec::core_interface ================================================ FILE: src/core/interface/index_param.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include #include #include "core/interface/utils/utils.h" namespace zvec { namespace core_interface { ailego::JsonObject BaseIndexParam::SerializeToJsonObject( bool omit_empty_value) const { ailego::JsonObject json_obj; if (!omit_empty_value || index_type != IndexType::kNone) { json_obj.set("index_type", ailego::JsonValue(magic_enum::enum_name(index_type).data())); } if (!omit_empty_value || metric_type != MetricType::kNone) { json_obj.set("metric_type", ailego::JsonValue(magic_enum::enum_name(metric_type).data())); } if (!omit_empty_value || dimension != 0) { json_obj.set("dimension", ailego::JsonValue(dimension)); } if (!omit_empty_value || version != 0) { json_obj.set("version", ailego::JsonValue(version)); } if (!omit_empty_value || is_sparse) { json_obj.set("is_sparse", ailego::JsonValue(is_sparse)); } if (!omit_empty_value || data_type != DataType::DT_UNDEFINED) { json_obj.set("data_type", ailego::JsonValue(magic_enum::enum_name(data_type).data())); } if (!omit_empty_value || use_id_map) { json_obj.set("use_id_map", ailego::JsonValue(use_id_map)); } if (!omit_empty_value || is_huge_page) { json_obj.set("is_huge_page", ailego::JsonValue(is_huge_page)); } // if (preprocess_param) { // json.set("preprocess_param", preprocess_param->SerializeToJson()); // } if (!omit_empty_value || quantizer_param.type != QuantizerType::kNone) { json_obj.set("quantizer_param", quantizer_param.SerializeToJsonObject(omit_empty_value)); } // if (refiner_param) { // json.set("refiner_param", refiner_param->SerializeToJson()); // } // if (default_query_param) { // json.set("default_query_param", // default_query_param->SerializeToJson()); // } return json_obj; } ailego::JsonObject FlatIndexParam::SerializeToJsonObject( bool omit_empty_value) const { auto json_obj = BaseIndexParam::SerializeToJsonObject(omit_empty_value); if (!omit_empty_value || major_order != IndexMeta::MajorOrder::MO_UNDEFINED) { json_obj.set("major_order", 
ailego::JsonValue(magic_enum::enum_name(major_order).data())); } return json_obj; } ailego::JsonObject HNSWIndexParam::SerializeToJsonObject( bool omit_empty_value) const { auto json_obj = BaseIndexParam::SerializeToJsonObject(omit_empty_value); json_obj.set("m", ailego::JsonValue(m)); json_obj.set("ef_construction", ailego::JsonValue(ef_construction)); return json_obj; } bool BaseIndexParam::DeserializeFromJsonObject( const ailego::JsonObject &json_obj) { DESERIALIZE_ENUM_FIELD(json_obj, index_type, IndexType); DESERIALIZE_ENUM_FIELD(json_obj, metric_type, MetricType); DESERIALIZE_ENUM_FIELD(json_obj, data_type, DataType); DESERIALIZE_VALUE_FIELD(json_obj, dimension); DESERIALIZE_VALUE_FIELD(json_obj, version); DESERIALIZE_VALUE_FIELD(json_obj, is_sparse); DESERIALIZE_VALUE_FIELD(json_obj, use_id_map); DESERIALIZE_VALUE_FIELD(json_obj, is_huge_page); ailego::JsonValue tmp_json_value; if (json_obj.has("quantizer_param")) { if (json_obj.get("quantizer_param", &tmp_json_value); tmp_json_value.is_object()) { quantizer_param.DeserializeFromJsonObject(tmp_json_value.as_object()); } } return true; } bool FlatIndexParam::DeserializeFromJsonObject( const ailego::JsonObject &json_obj) { if (!BaseIndexParam::DeserializeFromJsonObject(json_obj)) { return false; } if (index_type != IndexType::kFlat) { LOG_ERROR("index_type is not kFlat"); return false; } DESERIALIZE_ENUM_FIELD(json_obj, major_order, IndexMeta::MajorOrder); return true; } bool HNSWIndexParam::DeserializeFromJsonObject( const ailego::JsonObject &json_obj) { if (!BaseIndexParam::DeserializeFromJsonObject(json_obj)) { return false; } if (index_type != IndexType::kHNSW) { LOG_ERROR("index_type is not kHNSW"); return false; } DESERIALIZE_VALUE_FIELD(json_obj, m); DESERIALIZE_VALUE_FIELD(json_obj, ef_construction); return true; } bool HNSWRabitqIndexParam::DeserializeFromJsonObject( const ailego::JsonObject &json_obj) { if (!BaseIndexParam::DeserializeFromJsonObject(json_obj)) { return false; } if (index_type != 
IndexType::kHNSWRabitq) { LOG_ERROR("index_type is not kHNSWRabitq"); return false; } DESERIALIZE_VALUE_FIELD(json_obj, m); DESERIALIZE_VALUE_FIELD(json_obj, ef_construction); DESERIALIZE_VALUE_FIELD(json_obj, total_bits); DESERIALIZE_VALUE_FIELD(json_obj, num_clusters); DESERIALIZE_VALUE_FIELD(json_obj, sample_count); return true; } ailego::JsonObject HNSWRabitqIndexParam::SerializeToJsonObject( bool omit_empty_value) const { auto json_obj = BaseIndexParam::SerializeToJsonObject(omit_empty_value); json_obj.set("m", ailego::JsonValue(m)); json_obj.set("ef_construction", ailego::JsonValue(ef_construction)); json_obj.set("total_bits", ailego::JsonValue(total_bits)); json_obj.set("num_clusters", ailego::JsonValue(num_clusters)); if (!omit_empty_value || sample_count != 0) { json_obj.set("sample_count", ailego::JsonValue(sample_count)); } return json_obj; } ailego::JsonObject QuantizerParam::SerializeToJsonObject( bool omit_empty_value) const { ailego::JsonObject json_obj; if (!omit_empty_value || type != QuantizerType::kNone) { json_obj.set("type", zvec::ailego::JsonValue(magic_enum::enum_name(type).data())); } return json_obj; } bool QuantizerParam::DeserializeFromJsonObject( const ailego::JsonObject &json_obj) { DESERIALIZE_ENUM_FIELD(json_obj, type, QuantizerType); return true; } // bool BaseIndexQueryParam::DeserializeFromJsonObject( // const ailego::JsonObject &json_obj) { // DESERIALIZE_ENUM_FIELD(json_obj, index_type, IndexType); // DESERIALIZE_VALUE_FIELD(json_obj, topk); // DESERIALIZE_VALUE_FIELD(json_obj, fetch_vector); // DESERIALIZE_VALUE_FIELD(json_obj, radius); // DESERIALIZE_VALUE_FIELD(json_obj, is_linear); // return true; // } // ailego::JsonObject BaseIndexQueryParam::SerializeToJsonObject( // bool omit_empty_value) const { // ailego::JsonObject json_obj; // if (!omit_empty_value || index_type != IndexType::kNone) { // json_obj.set("index_type", // ailego::JsonValue(magic_enum::enum_name(index_type).data())); // } // if (!omit_empty_value || topk != 
0) { // json_obj.set("topk", ailego::JsonValue(topk)); // } // if (!omit_empty_value || fetch_vector) { // json_obj.set("fetch_vector", ailego::JsonValue(fetch_vector)); // } // if (!omit_empty_value || radius != 0.0f) { // json_obj.set("radius", ailego::JsonValue(radius)); // } // if (!omit_empty_value || is_linear) { // json_obj.set("is_linear", ailego::JsonValue(is_linear)); // } // return json_obj; // } // bool FlatQueryParam::DeserializeFromJsonObject( // const ailego::JsonObject &json_obj) { // if (!BaseIndexQueryParam::DeserializeFromJsonObject(json_obj)) { // return false; // } // if (index_type != IndexType::kFlat) { // LOG_ERROR("index_type is not kFlat"); // return false; // } // return true; // } // ailego::JsonObject FlatQueryParam::SerializeToJsonObject( // bool omit_empty_value) const { // auto json_obj = // BaseIndexQueryParam::SerializeToJsonObject(omit_empty_value); // return json_obj; // } // bool HNSWQueryParam::DeserializeFromJsonObject( // const ailego::JsonObject &json_obj) { // if (!BaseIndexQueryParam::DeserializeFromJsonObject(json_obj)) { // return false; // } // if (index_type != IndexType::kHNSW) { // LOG_ERROR("index_type is not kHNSW"); // return false; // } // DESERIALIZE_VALUE_FIELD(json_obj, ef_search); // return true; // } // ailego::JsonObject HNSWQueryParam::SerializeToJsonObject( // bool omit_empty_value) const { // auto json_obj = // BaseIndexQueryParam::SerializeToJsonObject(omit_empty_value); // if (!omit_empty_value || ef_search != 0) { // json_obj.set("ef_search", ailego::JsonValue(ef_search)); // } // return json_obj; // } } // namespace core_interface } // namespace zvec ================================================ FILE: src/core/interface/indexes/flat_index.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include #include #include "algorithm/flat/flat_utility.h" namespace zvec::core_interface { int FlatIndex::CreateAndInitStreamer(const BaseIndexParam ¶m) { param_ = dynamic_cast(param); proxima_index_params_.set(core::PARAM_FLAT_COLUMN_MAJOR_ORDER, param_.major_order == IndexMeta::MO_COLUMN); proxima_index_params_.set(core::PARAM_FLAT_USE_ID_MAP, param_.use_id_map); if (is_sparse_) { streamer_ = core::IndexFactory::CreateStreamer("FlatSparseStreamer"); } else { streamer_ = core::IndexFactory::CreateStreamer("FlatStreamer"); } if (ailego_unlikely(!streamer_)) { LOG_ERROR("Failed to create streamer"); return core::IndexError_Runtime; } if (ailego_unlikely( streamer_->init(proxima_index_meta_, proxima_index_params_) != 0)) { LOG_ERROR("Failed to init streamer"); return core::IndexError_Runtime; } return 0; } int FlatIndex::_prepare_for_search( const VectorData & /*vector_data*/, const BaseIndexQueryParam::Pointer &search_param, core::IndexContext::Pointer &context) { auto flat_search_param = std::dynamic_pointer_cast(search_param); if (ailego_unlikely(!flat_search_param)) { LOG_ERROR("Invalid search param type, expected FlatQueryParam"); return core::IndexError_Runtime; } context->set_topk(flat_search_param->topk); context->set_fetch_vector(flat_search_param->fetch_vector); if (flat_search_param->filter) { context->set_filter(std::move(*flat_search_param->filter)); } if (flat_search_param->radius > 0.0f) { context->set_threshold(flat_search_param->radius); } return 0; } } // namespace zvec::core_interface 
================================================ FILE: src/core/interface/indexes/hnsw_index.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include #include #include "algorithm/hnsw/hnsw_params.h" #include "algorithm/hnsw_sparse/hnsw_sparse_params.h" namespace zvec::core_interface { int HNSWIndex::CreateAndInitStreamer(const BaseIndexParam ¶m) { param_ = dynamic_cast(param); // valid param_.ef_construction = std::max(1, std::min(2048, param_.ef_construction)); param_.m = std::max(5, std::min(1024, param_.m)); if (is_sparse_) { proxima_index_params_.set(core::PARAM_HNSW_SPARSE_STREAMER_EFCONSTRUCTION, param_.ef_construction); proxima_index_params_.set( core::PARAM_HNSW_SPARSE_STREAMER_MAX_NEIGHBOR_COUNT, param_.m); // TODO: add_vector_with_id & fetch_by_id don't rely on this param proxima_index_params_.set( core::PARAM_HNSW_SPARSE_STREAMER_GET_VECTOR_ENABLE, true); // TODO: use index params' default query param here proxima_index_params_.set(core::PARAM_HNSW_SPARSE_STREAMER_EF, kDefaultHnswEfSearch); streamer_ = core::IndexFactory::CreateStreamer("HnswSparseStreamer"); } else { proxima_index_params_.set(core::PARAM_HNSW_STREAMER_EFCONSTRUCTION, param_.ef_construction); proxima_index_params_.set(core::PARAM_HNSW_STREAMER_MAX_NEIGHBOR_COUNT, param_.m); // TODO: add_vector_with_id & fetch_by_id don't rely on this param 
proxima_index_params_.set(core::PARAM_HNSW_STREAMER_GET_VECTOR_ENABLE, true); // TODO: use index params' default query param here proxima_index_params_.set(core::PARAM_HNSW_STREAMER_EF, kDefaultHnswEfSearch); proxima_index_params_.set(core::PARAM_HNSW_STREAMER_USE_ID_MAP, param_.use_id_map); streamer_ = core::IndexFactory::CreateStreamer("HnswStreamer"); } if (ailego_unlikely(!streamer_)) { LOG_ERROR("Failed to create streamer"); return core::IndexError_Runtime; } if (ailego_unlikely( streamer_->init(proxima_index_meta_, proxima_index_params_) != 0)) { LOG_ERROR("Failed to init streamer"); return core::IndexError_Runtime; } return 0; } int HNSWIndex::_prepare_for_search( const VectorData & /*vector_data*/, const BaseIndexQueryParam::Pointer &search_param, core::IndexContext::Pointer &context) { const auto &hnsw_search_param = std::dynamic_pointer_cast(search_param); if (ailego_unlikely(!hnsw_search_param)) { LOG_ERROR("Invalid search param type, expected HNSWQueryParam"); return core::IndexError_Runtime; } if (0 >= hnsw_search_param->ef_search || hnsw_search_param->ef_search > 2048) { LOG_ERROR( "ef_search must be greater than 0 and less than or equal to 2048."); return core::IndexError_Runtime; } context->set_topk(hnsw_search_param->topk); context->set_fetch_vector(hnsw_search_param->fetch_vector); if (hnsw_search_param->filter) { context->set_filter(std::move(*hnsw_search_param->filter)); } if (hnsw_search_param->radius > 0.0f) { context->set_threshold(hnsw_search_param->radius); } ailego::Params params; const int real_search_ef = std::max(1u, std::min(2048u, hnsw_search_param->ef_search)); params.set(core::PARAM_HNSW_STREAMER_EF, real_search_ef); context->update(params); return 0; } int HNSWIndex::_get_coarse_search_topk( const BaseIndexQueryParam::Pointer &search_param) { const auto &hnsw_search_param = std::dynamic_pointer_cast(search_param); // scale_factor doesn't take effect for hnsw. 
auto ret = std::max(search_param->topk, hnsw_search_param->ef_search); return ret; } } // namespace zvec::core_interface ================================================ FILE: src/core/interface/indexes/hnsw_rabitq_index.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include #include #include "zvec/core/framework/index_error.h" #if RABITQ_SUPPORTED #include "algorithm/hnsw_rabitq/hnsw_rabitq_params.h" #include "algorithm/hnsw_rabitq/hnsw_rabitq_streamer.h" #include "algorithm/hnsw_rabitq/rabitq_params.h" #endif namespace zvec::core_interface { int HNSWRabitqIndex::CreateAndInitStreamer(const BaseIndexParam ¶m) { #if !RABITQ_SUPPORTED LOG_ERROR("RaBitQ is not supported on this platform (Linux x86_64 only)"); return core::IndexError_Unsupported; #else param_ = dynamic_cast(param); if (is_sparse_) { LOG_ERROR("Sparse index is not supported"); return core::IndexError_Runtime; } if (param.dimension < core::kMinRabitqDimSize || param.dimension > core::kMaxRabitqDimSize) { LOG_ERROR("Unsupported dimension: %d", param.dimension); return core::IndexError_Unsupported; } // validate parameters param_.ef_construction = std::max(1, std::min(2048, param_.ef_construction)); param_.m = std::max(5, std::min(1024, param_.m)); proxima_index_params_.set(core::PARAM_HNSW_RABITQ_STREAMER_EFCONSTRUCTION, param_.ef_construction); 
proxima_index_params_.set(core::PARAM_HNSW_RABITQ_STREAMER_MAX_NEIGHBOR_COUNT, param_.m); proxima_index_params_.set(core::PARAM_HNSW_RABITQ_STREAMER_GET_VECTOR_ENABLE, true); proxima_index_params_.set(core::PARAM_HNSW_RABITQ_STREAMER_EF, kDefaultHnswEfSearch); proxima_index_params_.set(core::PARAM_HNSW_RABITQ_STREAMER_USE_ID_MAP, param_.use_id_map); proxima_index_params_.set(core::PARAM_HNSW_RABITQ_GENERAL_DIMENSION, input_vector_meta_.dimension()); proxima_index_params_.set(core::PARAM_RABITQ_TOTAL_BITS, param_.total_bits); // num_clusters, sample_count are parameters for rabitq converter // proxima_index_params_.set(core::PARAM_RABITQ_NUM_CLUSTERS, // param_.num_clusters); auto streamer = std::make_shared(); streamer->set_provider(param_.provider); streamer->set_reformer(param_.reformer); streamer_ = streamer; if (ailego_unlikely(!streamer_)) { LOG_ERROR("Failed to create HnswRabitqStreamer"); return core::IndexError_Runtime; } if (ailego_unlikely( streamer_->init(proxima_index_meta_, proxima_index_params_) != 0)) { LOG_ERROR("Failed to init HnswRabitqStreamer"); return core::IndexError_Runtime; } return 0; #endif // RABITQ_SUPPORTED } int HNSWRabitqIndex::_prepare_for_search( const VectorData & /*vector_data*/, const BaseIndexQueryParam::Pointer &search_param, core::IndexContext::Pointer &context) { #if !RABITQ_SUPPORTED LOG_ERROR("RaBitQ is not supported on this platform (Linux x86_64 only)"); return core::IndexError_Unsupported; #else const auto &hnsw_search_param = std::dynamic_pointer_cast(search_param); if (ailego_unlikely(!hnsw_search_param)) { LOG_ERROR("Invalid search param type, expected HNSWRabitqQueryParam"); return core::IndexError_Runtime; } if (0 >= hnsw_search_param->ef_search || hnsw_search_param->ef_search > 2048) { LOG_ERROR( "ef_search must be greater than 0 and less than or equal to 2048."); return core::IndexError_Runtime; } context->set_topk(hnsw_search_param->topk); context->set_fetch_vector(hnsw_search_param->fetch_vector); if 
(hnsw_search_param->filter) { context->set_filter(std::move(*hnsw_search_param->filter)); } if (hnsw_search_param->radius > 0.0f) { context->set_threshold(hnsw_search_param->radius); } ailego::Params params; const int real_search_ef = std::max(1u, std::min(2048u, hnsw_search_param->ef_search)); params.set(core::PARAM_HNSW_RABITQ_STREAMER_EF, real_search_ef); context->update(params); return 0; #endif // RABITQ_SUPPORTED } int HNSWRabitqIndex::_get_coarse_search_topk( const BaseIndexQueryParam::Pointer &search_param) { #if !RABITQ_SUPPORTED LOG_ERROR("RaBitQ is not supported on this platform (Linux x86_64 only)"); return -1; #else const auto &hnsw_search_param = std::dynamic_pointer_cast(search_param); auto ret = std::max(search_param->topk, hnsw_search_param->ef_search); return ret; #endif // RABITQ_SUPPORTED } } // namespace zvec::core_interface ================================================ FILE: src/core/interface/indexes/ivf_index.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include #include #include #include "algorithm/ivf/ivf_params.h" namespace zvec::core_interface { static constexpr uint64_t kInvalidKey = std::numeric_limits::max(); int IVFIndex::CreateAndInitStreamer(const BaseIndexParam ¶m) { if (is_sparse_) { LOG_ERROR("IVF Index not support sparse vector"); return core::IndexError_InvalidArgument; } param_ = dynamic_cast(param); param_.nlist = std::max(1, std::min(1024, param_.nlist)); param_.niters = std::max(1, std::min(1024, param_.niters)); proxima_index_params_.set(core::PARAM_IVF_BUILDER_CENTROID_COUNT, param_.nlist); // TODO: add_vector_with_id & fetch_by_id don't rely on this param builder_ = core::IndexFactory::CreateBuilder("IVFBuilder"); streamer_ = core::IndexFactory::CreateStreamer("IVFStreamer"); if (ailego_unlikely(!builder_)) { LOG_ERROR("Failed to create builder"); return core::IndexError_Runtime; } if (ailego_unlikely(!streamer_)) { LOG_ERROR("Failed to create streamer"); return core::IndexError_Runtime; } IndexMeta real_meta; if (converter_) { real_meta = converter_->meta(); } else { real_meta = proxima_index_meta_; } if (ailego_unlikely(builder_->init(real_meta, proxima_index_params_) != 0)) { LOG_ERROR("Failed to init builder"); return core::IndexError_Runtime; } if (ailego_unlikely(streamer_->init(real_meta, proxima_index_params_) != 0)) { LOG_ERROR("Failed to init streamer"); return core::IndexError_Runtime; } return 0; } int IVFIndex::Open(const std::string &file_path, StorageOptions storage_options) { ailego::Params storage_params; file_path_ = file_path; is_read_only_ = storage_options.read_only; switch (storage_options.type) { case StorageOptions::StorageType::kMMAP: { storage_ = core::IndexFactory::CreateStorage("MMapFileReadStorage"); if (storage_ == nullptr) { LOG_ERROR("Failed to create MMapFileStorage"); return core::IndexError_Runtime; } int ret = storage_->init(storage_params); if (ret != 0) { LOG_ERROR("Failed to init MMapFileStorage, path: %s, err: %s", file_path_.c_str(), 
core::IndexError::What(ret)); return ret; } break; } case StorageOptions::StorageType::kBufferPool: { storage_ = core::IndexFactory::CreateStorage("BufferStorage"); if (storage_ == nullptr) { LOG_ERROR("Failed to create BufferStorage"); return core::IndexError_Runtime; } int ret = storage_->init(storage_params); if (ret != 0) { LOG_ERROR("Failed to init BufferStorage, path: %s, err: %s", file_path_.c_str(), core::IndexError::What(ret)); return ret; } break; } default: { LOG_ERROR("Unsupported storage type"); return core::IndexError_Unsupported; } } if (is_read_only_ || !storage_options.create_new) { // read_options.create_new int ret = storage_->open(file_path_, false); if (ret != 0) { LOG_ERROR("Failed to open storage, path: %s, err: %s", file_path_.c_str(), core::IndexError::What(ret)); return core::IndexError_Runtime; } if (streamer_ == nullptr || streamer_->open(storage_) != 0) { LOG_ERROR("Failed to open streamer, path: %s", file_path_.c_str()); return core::IndexError_Runtime; } is_trained_ = true; } is_open_ = true; return 0; } int IVFIndex::GenerateHolder() { if (param_.data_type == DataType::DT_FP16) { auto holder = std::make_shared>( param_.dimension); for (auto doc : doc_cache_) { ailego::NumericalVector vec(doc.second); if (doc.first == kInvalidKey) { continue; } if (!holder->emplace(doc.first, vec)) { LOG_ERROR("Failed to add vector"); return core::IndexError_Runtime; } } holder_ = holder; } else if (param_.data_type == DataType::DT_FP32) { auto holder = std::make_shared>( param_.dimension); for (auto doc : doc_cache_) { ailego::NumericalVector vec(doc.second); if (doc.first == kInvalidKey) { continue; } if (!holder->emplace(doc.first, vec)) { LOG_ERROR("Failed to add vector"); return core::IndexError_Runtime; } } holder_ = holder; } else if (param_.data_type == DataType::DT_INT8) { auto holder = std::make_shared>( param_.dimension); for (auto doc : doc_cache_) { ailego::NumericalVector vec(doc.second); if (doc.first == kInvalidKey) { continue; } if 
(!holder->emplace(doc.first, vec)) { LOG_ERROR("Failed to add vector"); return core::IndexError_Runtime; } } holder_ = holder; } else { LOG_ERROR("data_type is not support"); return core::IndexError_Runtime; } if (converter_) { core::IndexConverter::TrainAndTransform(converter_, holder_); holder_ = converter_->result(); } return 0; } int IVFIndex::Add(const VectorData &vector, uint32_t doc_id) { if (is_trained_) { LOG_ERROR("this IVF index is trained"); return core::IndexError_Runtime; } if (!std::holds_alternative(vector.vector)) { LOG_ERROR("Invalid vector data"); return core::IndexError_Runtime; } const DenseVector &dense_vector = std::get(vector.vector); std::string out_vector_buffer = std::string( static_cast(dense_vector.data), input_vector_meta_.dimension() * input_vector_meta_.unit_size()); std::lock_guard lock(mutex_); while (doc_cache_.size() <= doc_id) { std::string fake_data( input_vector_meta_.dimension() * input_vector_meta_.unit_size(), 0); doc_cache_.push_back(std::make_pair(kInvalidKey, fake_data)); } doc_cache_[doc_id] = std::make_pair(doc_id, out_vector_buffer); return 0; } int IVFIndex::Train() { GenerateHolder(); builder_->train(holder_); builder_->build(holder_); auto dumper = core::IndexFactory::CreateDumper("FileDumper"); dumper->create(file_path_); builder_->dump(dumper); dumper->close(); int ret = storage_->open(file_path_, false); if (ret != 0) { LOG_ERROR("Failed to open storage, path: %s, err: %s", file_path_.c_str(), core::IndexError::What(ret)); return core::IndexError_Runtime; } if (streamer_ == nullptr || streamer_->open(storage_) != 0) { LOG_ERROR("Failed to open streamer, path: %s", file_path_.c_str()); return core::IndexError_Runtime; } is_trained_ = true; return 0; } int IVFIndex::_dense_fetch(const uint32_t doc_id, VectorDataBuffer *vector_data_buffer) { if (is_trained_) { return Index::_dense_fetch(doc_id, vector_data_buffer); } else { DenseVectorBuffer dense_vector_buffer; std::string &out_vector_buffer = 
dense_vector_buffer.data; out_vector_buffer = doc_cache_[doc_id].second; vector_data_buffer->vector_buffer = std::move(dense_vector_buffer); return 0; } } int IVFIndex::_prepare_for_search( const VectorData & /*query*/, const BaseIndexQueryParam::Pointer &search_param, core::IndexContext::Pointer &context) { const auto &ivf_search_param = std::dynamic_pointer_cast(search_param); context->set_topk(ivf_search_param->topk); context->set_fetch_vector(ivf_search_param->fetch_vector); if (ivf_search_param->filter) { context->set_filter(std::move(*ivf_search_param->filter)); } if (ivf_search_param->radius > 0.0f) { context->set_threshold(ivf_search_param->radius); } if (ivf_search_param->nprobe > 0) { // TODO: 1. sparse; 2. default ef ailego::Params params; // need fix params.set(core::PARAM_IVF_BUILDER_CENTROID_COUNT, ivf_search_param->nprobe); context->update(params); } return 0; } int IVFIndex::Merge(const std::vector &indexes, const IndexFilter &filter, const MergeOptions &options) { int pre_ret = Index::Merge(indexes, filter, options); if (pre_ret != 0) { return pre_ret; } auto dumper = core::IndexFactory::CreateDumper("FileDumper"); dumper->create(file_path_); builder_->dump(dumper); dumper->close(); int ret = storage_->open(file_path_, false); if (ret != 0) { LOG_ERROR("Failed to open storage, path: %s, err: %s", file_path_.c_str(), core::IndexError::What(ret)); return core::IndexError_Runtime; } if (streamer_ == nullptr || streamer_->open(storage_) != 0) { LOG_ERROR("Failed to open streamer, path: %s", file_path_.c_str()); return core::IndexError_Runtime; } is_trained_ = true; return 0; } } // namespace zvec::core_interface ================================================ FILE: src/core/interface/utils/utils.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include namespace zvec { namespace core_interface { template constexpr bool extract_enum_from_json(const ailego::JsonObject &json_obj, const char *key, EnumType &enum_value, ailego::JsonValue &tmp_json_value) { if (json_obj.has(key)) { if (json_obj.get(key, &tmp_json_value); tmp_json_value.is_string()) { auto optional_enum_value = magic_enum::enum_cast(tmp_json_value.as_stl_string()); if (optional_enum_value.has_value()) { enum_value = optional_enum_value.value(); } else { LOG_ERROR("Invalid enum value for key: %s, value: %s", key, tmp_json_value.as_c_string()); return false; } } else { LOG_ERROR("Invalid json field type for key: %s", key); return false; } } return true; } template constexpr bool extract_value_from_json(const ailego::JsonObject &json_obj, const char *key, T &value, ailego::JsonValue &tmp_json_value) { if (json_obj.has(key)) { json_obj.get(key, &tmp_json_value); if constexpr (std::is_same_v) { if (tmp_json_value.is_boolean()) { value = tmp_json_value.as_bool(); } else { LOG_ERROR("Invalid json field type for key: %s; expected: boolean", key); return false; } } else if constexpr (std::is_floating_point_v) { if (tmp_json_value.is_float() || tmp_json_value.is_integer()) { value = static_cast(tmp_json_value.as_float()); } else { LOG_ERROR("Invalid json field type for key: %s; expected: float", key); return false; } } else if constexpr (std::is_integral_v) { if (tmp_json_value.is_integer()) { value = static_cast(tmp_json_value.as_integer()); } else { LOG_ERROR("Invalid json field type for key: %s; expected: 
integer", key); return false; } } else { abort(); } } return true; } #define DESERIALIZE_ENUM_FIELD(json_obj, field_name, EnumType) \ { \ ailego::JsonValue tmp_json_value; \ if (!extract_enum_from_json(json_obj, #field_name, field_name, \ tmp_json_value)) { \ LOG_ERROR("Error when deserialize json - field:%s", #field_name); \ return false; \ } \ } #define DESERIALIZE_VALUE_FIELD(json_obj, field_name) \ { \ ailego::JsonValue tmp_json_value; \ if (!extract_value_from_json(json_obj, #field_name, field_name, \ tmp_json_value)) { \ LOG_ERROR("Error when deserialize json - field:%s", #field_name); \ return false; \ } \ } } // namespace core_interface } // namespace zvec ================================================ FILE: src/core/metric/CMakeLists.txt ================================================ include(${PROJECT_ROOT_DIR}/cmake/bazel.cmake) include(${PROJECT_ROOT_DIR}/cmake/option.cmake) cc_library( NAME core_metric STATIC SHARED STRICT ALWAYS_LINK SRCS *.cc LIBS zvec_ailego zvec_turbo core_framework INCS . ${PROJECT_ROOT_DIR}/src/core VERSION "${PROXIMA_ZVEC_VERSION}" ) ================================================ FILE: src/core/metric/cosine_metric.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include #include #include #include namespace zvec { namespace core { //! 
//! Retrieve distance function for index features
//!
//! Looks up the cosine matrix kernel (fp32 variant, per the function name)
//! for an m x n block. Row = ailego_ctz(m), column = ailego_ctz(n). The guard
//! below returns nullptr unless m and n are single-bit values no larger than
//! 32. Only the n <= m half of the table is populated; e.g. (m=1, n=2) yields
//! nullptr.
//!
//! NOTE(review): the template arguments of the reinterpret_casts and of
//! ailego::CosineDistanceMatrix were stripped from this copy (nothing between
//! '<' and '>' survived). As written the 21 entries are indistinguishable and
//! do not compile; restore them from upstream.
inline IndexMetric::MatrixDistanceHandle CosineDistanceMatrixFp32(size_t m,
                                                                  size_t n) {
  static const IndexMetric::MatrixDistanceHandle distance_table[6][6] = {
      // m == 1
      {reinterpret_cast(ailego::CosineDistanceMatrix::Compute),
       nullptr, nullptr, nullptr, nullptr, nullptr},
      // m == 2
      {reinterpret_cast(ailego::CosineDistanceMatrix::Compute),
       reinterpret_cast(ailego::CosineDistanceMatrix::Compute),
       nullptr, nullptr, nullptr, nullptr},
      // m == 4
      {reinterpret_cast(ailego::CosineDistanceMatrix::Compute),
       reinterpret_cast(ailego::CosineDistanceMatrix::Compute),
       reinterpret_cast(ailego::CosineDistanceMatrix::Compute),
       nullptr, nullptr, nullptr},
      // m == 8
      {reinterpret_cast(ailego::CosineDistanceMatrix::Compute),
       reinterpret_cast(ailego::CosineDistanceMatrix::Compute),
       reinterpret_cast(ailego::CosineDistanceMatrix::Compute),
       reinterpret_cast(ailego::CosineDistanceMatrix::Compute),
       nullptr, nullptr},
      // m == 16
      {reinterpret_cast(ailego::CosineDistanceMatrix::Compute),
       reinterpret_cast(ailego::CosineDistanceMatrix::Compute),
       reinterpret_cast(ailego::CosineDistanceMatrix::Compute),
       reinterpret_cast(ailego::CosineDistanceMatrix::Compute),
       reinterpret_cast(ailego::CosineDistanceMatrix::Compute),
       nullptr},
      // m == 32
      {reinterpret_cast(ailego::CosineDistanceMatrix::Compute),
       reinterpret_cast(ailego::CosineDistanceMatrix::Compute),
       reinterpret_cast(ailego::CosineDistanceMatrix::Compute),
       reinterpret_cast(ailego::CosineDistanceMatrix::Compute),
       reinterpret_cast(ailego::CosineDistanceMatrix::Compute),
       reinterpret_cast(ailego::CosineDistanceMatrix::Compute)},
  };
  // Only single-bit sizes up to 32 map onto a table row/column.
  if (m > 32 || n > 32 || ailego_popcount(m) != 1 || ailego_popcount(n) != 1) {
    return nullptr;
  }
  return distance_table[ailego_ctz(m)][ailego_ctz(n)];
}
//!
Retrieve distance function for index features inline IndexMetric::MatrixDistanceHandle CosineDistanceMatrixFp16(size_t m, size_t n) { static const IndexMetric::MatrixDistanceHandle distance_table[6][6] = { {reinterpret_cast( ailego::CosineDistanceMatrix::Compute), nullptr, nullptr, nullptr, nullptr, nullptr}, {reinterpret_cast( ailego::CosineDistanceMatrix::Compute), reinterpret_cast( ailego::CosineDistanceMatrix::Compute), nullptr, nullptr, nullptr, nullptr}, {reinterpret_cast( ailego::CosineDistanceMatrix::Compute), reinterpret_cast( ailego::CosineDistanceMatrix::Compute), reinterpret_cast( ailego::CosineDistanceMatrix::Compute), nullptr, nullptr, nullptr}, {reinterpret_cast( ailego::CosineDistanceMatrix::Compute), reinterpret_cast( ailego::CosineDistanceMatrix::Compute), reinterpret_cast( ailego::CosineDistanceMatrix::Compute), reinterpret_cast( ailego::CosineDistanceMatrix::Compute), nullptr, nullptr}, {reinterpret_cast( ailego::CosineDistanceMatrix::Compute), reinterpret_cast( ailego::CosineDistanceMatrix::Compute), reinterpret_cast( ailego::CosineDistanceMatrix::Compute), reinterpret_cast( ailego::CosineDistanceMatrix::Compute), reinterpret_cast( ailego::CosineDistanceMatrix::Compute), nullptr}, {reinterpret_cast( ailego::CosineDistanceMatrix::Compute), reinterpret_cast( ailego::CosineDistanceMatrix::Compute), reinterpret_cast( ailego::CosineDistanceMatrix::Compute), reinterpret_cast( ailego::CosineDistanceMatrix::Compute), reinterpret_cast( ailego::CosineDistanceMatrix::Compute), reinterpret_cast( ailego::CosineDistanceMatrix::Compute)}, }; if (m > 32 || n > 32 || ailego_popcount(m) != 1 || ailego_popcount(n) != 1) { return nullptr; } return distance_table[ailego_ctz(m)][ailego_ctz(n)]; } /*! Cosine Metric */ class CosineMetric : public IndexMetric { public: //! 
Initialize Metric int init(const IndexMeta &meta, const ailego::Params &index_params) override { IndexMeta::DataType ft = meta.data_type(); if (ft != IndexMeta::DataType::DT_FP16 && ft != IndexMeta::DataType::DT_FP32) { return IndexError_Unsupported; } if (IndexMeta::UnitSizeof(ft) != meta.unit_size()) { return IndexError_Unsupported; } data_type_ = ft; params_ = index_params; return 0; } //! Cleanup Metric int cleanup(void) override { return 0; } //! Retrieve if it matched bool is_matched(const IndexMeta &meta) const override { return (meta.data_type() == data_type_ && meta.unit_size() == IndexMeta::UnitSizeof(data_type_)); } //! Retrieve if it matched bool is_matched(const IndexMeta &meta, const IndexQueryMeta &qmeta) const override { return (qmeta.data_type() == data_type_ && qmeta.unit_size() == IndexMeta::UnitSizeof(data_type_) && qmeta.dimension() == meta.dimension()); } //! Retrieve distance function for query MatrixDistance distance(void) const override { switch (data_type_) { case IndexMeta::DataType::DT_FP16: return reinterpret_cast( ailego::CosineDistanceMatrix::Compute); case IndexMeta::DataType::DT_FP32: return reinterpret_cast( ailego::CosineDistanceMatrix::Compute); default: return nullptr; } } //! Retrieve distance function for index features MatrixDistance distance_matrix(size_t m, size_t n) const override { if (m != 1 || n != 1) { return nullptr; } return distance(); } //! Retrieve distance function for query MatrixBatchDistance batch_distance(void) const override { switch (data_type_) { case IndexMeta::DataType::DT_FP32: return reinterpret_cast( ailego::BaseDistance::ComputeBatch); case IndexMeta::DataType::DT_FP16: return reinterpret_cast( ailego::BaseDistance::ComputeBatch); default: return nullptr; } } //! Retrieve params of Metric const ailego::Params ¶ms(void) const override { return params_; } //! 
Retrieve query metric object of this index metric Pointer query_metric(void) const override { return nullptr; } private: IndexMeta::DataType data_type_{IndexMeta::DataType::DT_FP32}; ailego::Params params_{}; }; INDEX_FACTORY_REGISTER_METRIC_ALIAS(Cosine, CosineMetric); } // namespace core } // namespace zvec ================================================ FILE: src/core/metric/euclidean_metric.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include #include #include #include #include namespace zvec { namespace core { //! 
//! Retrieve distance function for index features
//!
//! Looks up the squared-Euclidean matrix kernel (fp32 variant, per the
//! function name) for an m x n block. Row = ailego_ctz(m), column =
//! ailego_ctz(n). The guard below returns nullptr unless m and n are
//! single-bit values no larger than 32. Only the n <= m half of the table is
//! populated.
//!
//! NOTE(review): the template arguments of the reinterpret_casts and of
//! ailego::SquaredEuclideanDistanceMatrix were stripped from this copy;
//! restore them from upstream.
static inline IndexMetric::MatrixDistanceHandle
SquaredEuclideanDistanceMatrixFp32(size_t m, size_t n) {
  static const IndexMetric::MatrixDistanceHandle distance_table[6][6] = {
      // m == 1
      {reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       nullptr, nullptr, nullptr, nullptr, nullptr},
      // m == 2
      {reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       nullptr, nullptr, nullptr, nullptr},
      // m == 4
      {reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       nullptr, nullptr, nullptr},
      // m == 8
      {reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       nullptr, nullptr},
      // m == 16
      {reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       nullptr},
      // m == 32
      {reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute)},
  };
  // Only single-bit sizes up to 32 map onto a table row/column.
  if (m > 32 || n > 32 || ailego_popcount(m) != 1 || ailego_popcount(n) != 1) {
    return nullptr;
  }
  return distance_table[ailego_ctz(m)][ailego_ctz(n)];
}
//!
//! Retrieve distance function for index features
//!
//! Looks up the squared-Euclidean matrix kernel (fp16 variant, per the
//! function name) for an m x n block. Row = ailego_ctz(m), column =
//! ailego_ctz(n). nullptr unless m and n are single-bit values <= 32; only
//! the n <= m half of the table is populated.
//!
//! NOTE(review): cast/kernel template arguments were stripped from this
//! copy; restore them from upstream.
static inline IndexMetric::MatrixDistanceHandle
SquaredEuclideanDistanceMatrixFp16(size_t m, size_t n) {
  static const IndexMetric::MatrixDistanceHandle distance_table[6][6] = {
      // m == 1
      {reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       nullptr, nullptr, nullptr, nullptr, nullptr},
      // m == 2
      {reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       nullptr, nullptr, nullptr, nullptr},
      // m == 4
      {reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       nullptr, nullptr, nullptr},
      // m == 8
      {reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       nullptr, nullptr},
      // m == 16
      {reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       nullptr},
      // m == 32
      {reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute)},
  };
  // Only single-bit sizes up to 32 map onto a table row/column.
  if (m > 32 || n > 32 || ailego_popcount(m) != 1 || ailego_popcount(n) != 1) {
    return nullptr;
  }
  return distance_table[ailego_ctz(m)][ailego_ctz(n)];
}

//! Retrieve distance function for index features (int8 variant, per the
//! function name); same table layout and shape rules as the fp16 lookup.
//!
//! NOTE(review): cast/kernel template arguments were stripped from this
//! copy; restore them from upstream.
static inline IndexMetric::MatrixDistanceHandle
SquaredEuclideanDistanceMatrixInt8(size_t m, size_t n) {
  static const IndexMetric::MatrixDistanceHandle distance_table[6][6] = {
      // m == 1
      {reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       nullptr, nullptr, nullptr, nullptr, nullptr},
      // m == 2
      {reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       nullptr, nullptr, nullptr, nullptr},
      // m == 4
      {reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       nullptr, nullptr, nullptr},
      // m == 8
      {reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       nullptr, nullptr},
      // m == 16
      {reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       nullptr},
      // m == 32
      {reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute)},
  };
  // Only single-bit sizes up to 32 map onto a table row/column.
  if (m > 32 || n > 32 || ailego_popcount(m) != 1 || ailego_popcount(n) != 1) {
    return nullptr;
  }
  return distance_table[ailego_ctz(m)][ailego_ctz(n)];
}
//!
//! Retrieve distance function for index features in Int4
//!
//! Looks up the squared-Euclidean matrix kernel (int4 variant, per the
//! function name) for an m x n block. Row = ailego_ctz(m), column =
//! ailego_ctz(n). nullptr unless m and n are single-bit values <= 32; only
//! the n <= m half of the table is populated.
//!
//! NOTE(review): cast/kernel template arguments were stripped from this
//! copy; restore them from upstream.
static inline IndexMetric::MatrixDistanceHandle
SquaredEuclideanDistanceMatrixInt4(size_t m, size_t n) {
  static const IndexMetric::MatrixDistanceHandle distance_table[6][6] = {
      // m == 1
      {reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       nullptr, nullptr, nullptr, nullptr, nullptr},
      // m == 2
      {reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       nullptr, nullptr, nullptr, nullptr},
      // m == 4
      {reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       nullptr, nullptr, nullptr},
      // m == 8
      {reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       nullptr, nullptr},
      // m == 16
      {reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       nullptr},
      // m == 32
      {reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute),
       reinterpret_cast(ailego::SquaredEuclideanDistanceMatrix::Compute)},
  };
  // Only single-bit sizes up to 32 map onto a table row/column.
  if (m > 32 || n > 32 || ailego_popcount(m) != 1 || ailego_popcount(n) != 1) {
    return nullptr;
  }
  return distance_table[ailego_ctz(m)][ailego_ctz(n)];
}
//!
//! Retrieve distance function for index features
//!
//! Looks up the (non-squared) Euclidean matrix kernel (fp32 variant, per the
//! function name) for an m x n block. Row = ailego_ctz(m), column =
//! ailego_ctz(n). nullptr unless m and n are single-bit values <= 32; only
//! the n <= m half of the table is populated.
//!
//! NOTE(review): cast/kernel template arguments were stripped from this
//! copy; restore them from upstream.
static inline IndexMetric::MatrixDistanceHandle EuclideanDistanceMatrixFp32(
    size_t m, size_t n) {
  static const IndexMetric::MatrixDistanceHandle distance_table[6][6] = {
      // m == 1
      {reinterpret_cast(ailego::EuclideanDistanceMatrix::Compute),
       nullptr, nullptr, nullptr, nullptr, nullptr},
      // m == 2
      {reinterpret_cast(ailego::EuclideanDistanceMatrix::Compute),
       reinterpret_cast(ailego::EuclideanDistanceMatrix::Compute),
       nullptr, nullptr, nullptr, nullptr},
      // m == 4
      {reinterpret_cast(ailego::EuclideanDistanceMatrix::Compute),
       reinterpret_cast(ailego::EuclideanDistanceMatrix::Compute),
       reinterpret_cast(ailego::EuclideanDistanceMatrix::Compute),
       nullptr, nullptr, nullptr},
      // m == 8
      {reinterpret_cast(ailego::EuclideanDistanceMatrix::Compute),
       reinterpret_cast(ailego::EuclideanDistanceMatrix::Compute),
       reinterpret_cast(ailego::EuclideanDistanceMatrix::Compute),
       reinterpret_cast(ailego::EuclideanDistanceMatrix::Compute),
       nullptr, nullptr},
      // m == 16
      {reinterpret_cast(ailego::EuclideanDistanceMatrix::Compute),
       reinterpret_cast(ailego::EuclideanDistanceMatrix::Compute),
       reinterpret_cast(ailego::EuclideanDistanceMatrix::Compute),
       reinterpret_cast(ailego::EuclideanDistanceMatrix::Compute),
       reinterpret_cast(ailego::EuclideanDistanceMatrix::Compute),
       nullptr},
      // m == 32
      {reinterpret_cast(ailego::EuclideanDistanceMatrix::Compute),
       reinterpret_cast(ailego::EuclideanDistanceMatrix::Compute),
       reinterpret_cast(ailego::EuclideanDistanceMatrix::Compute),
       reinterpret_cast(ailego::EuclideanDistanceMatrix::Compute),
       reinterpret_cast(ailego::EuclideanDistanceMatrix::Compute),
       reinterpret_cast(ailego::EuclideanDistanceMatrix::Compute)},
  };
  // Only single-bit sizes up to 32 map onto a table row/column.
  if (m > 32 || n > 32 || ailego_popcount(m) != 1 || ailego_popcount(n) != 1) {
    return nullptr;
  }
  return distance_table[ailego_ctz(m)][ailego_ctz(n)];
}
//!
//! Retrieve distance function for index features
//!
//! Returns the FP16 EuclideanDistanceMatrix kernel specialised for an
//! `m` x `n` block, or nullptr when either count is not a power of two in
//! [1, 32] or when n > m (only the lower triangle of the table is filled).
static inline IndexMetric::MatrixDistanceHandle EuclideanDistanceMatrixFp16(
    size_t m, size_t n) {
  using Handle = IndexMetric::MatrixDistanceHandle;
  // distance_table[i][j] holds the kernel for M = 2^i, N = 2^j (j <= i);
  // omitted upper-triangle entries are value-initialized to nullptr.
  // NOTE(review): ailego::Float16 as the FP16 element type is assumed --
  // confirm against the ailego headers.
  static const Handle distance_table[6][6] = {
      {reinterpret_cast<Handle>(
          ailego::EuclideanDistanceMatrix<ailego::Float16, 1, 1>::Compute)},
      {reinterpret_cast<Handle>(
           ailego::EuclideanDistanceMatrix<ailego::Float16, 2, 1>::Compute),
       reinterpret_cast<Handle>(
           ailego::EuclideanDistanceMatrix<ailego::Float16, 2, 2>::Compute)},
      {reinterpret_cast<Handle>(
           ailego::EuclideanDistanceMatrix<ailego::Float16, 4, 1>::Compute),
       reinterpret_cast<Handle>(
           ailego::EuclideanDistanceMatrix<ailego::Float16, 4, 2>::Compute),
       reinterpret_cast<Handle>(
           ailego::EuclideanDistanceMatrix<ailego::Float16, 4, 4>::Compute)},
      {reinterpret_cast<Handle>(
           ailego::EuclideanDistanceMatrix<ailego::Float16, 8, 1>::Compute),
       reinterpret_cast<Handle>(
           ailego::EuclideanDistanceMatrix<ailego::Float16, 8, 2>::Compute),
       reinterpret_cast<Handle>(
           ailego::EuclideanDistanceMatrix<ailego::Float16, 8, 4>::Compute),
       reinterpret_cast<Handle>(
           ailego::EuclideanDistanceMatrix<ailego::Float16, 8, 8>::Compute)},
      {reinterpret_cast<Handle>(
           ailego::EuclideanDistanceMatrix<ailego::Float16, 16, 1>::Compute),
       reinterpret_cast<Handle>(
           ailego::EuclideanDistanceMatrix<ailego::Float16, 16, 2>::Compute),
       reinterpret_cast<Handle>(
           ailego::EuclideanDistanceMatrix<ailego::Float16, 16, 4>::Compute),
       reinterpret_cast<Handle>(
           ailego::EuclideanDistanceMatrix<ailego::Float16, 16, 8>::Compute),
       reinterpret_cast<Handle>(
           ailego::EuclideanDistanceMatrix<ailego::Float16, 16, 16>::Compute)},
      {reinterpret_cast<Handle>(
           ailego::EuclideanDistanceMatrix<ailego::Float16, 32, 1>::Compute),
       reinterpret_cast<Handle>(
           ailego::EuclideanDistanceMatrix<ailego::Float16, 32, 2>::Compute),
       reinterpret_cast<Handle>(
           ailego::EuclideanDistanceMatrix<ailego::Float16, 32, 4>::Compute),
       reinterpret_cast<Handle>(
           ailego::EuclideanDistanceMatrix<ailego::Float16, 32, 8>::Compute),
       reinterpret_cast<Handle>(
           ailego::EuclideanDistanceMatrix<ailego::Float16, 32, 16>::Compute),
       reinterpret_cast<Handle>(
           ailego::EuclideanDistanceMatrix<ailego::Float16, 32, 32>::Compute)},
  };
  if (m > 32 || n > 32 || ailego_popcount(m) != 1 ||
      ailego_popcount(n) != 1) {
    return nullptr;
  }
  return distance_table[ailego_ctz(m)][ailego_ctz(n)];
}

//! Retrieve distance function for index features in Int8
//!
//! Returns the INT8 EuclideanDistanceMatrix kernel specialised for an
//! `m` x `n` block, or nullptr when either count is not a power of two in
//! [1, 32] or when n > m (only the lower triangle of the table is filled).
static inline IndexMetric::MatrixDistanceHandle EuclideanDistanceMatrixInt8(
    size_t m, size_t n) {
  using Handle = IndexMetric::MatrixDistanceHandle;
  // distance_table[i][j] holds the kernel for M = 2^i, N = 2^j (j <= i);
  // omitted upper-triangle entries are value-initialized to nullptr.
  static const Handle distance_table[6][6] = {
      {reinterpret_cast<Handle>(
          ailego::EuclideanDistanceMatrix<int8_t, 1, 1>::Compute)},
      {reinterpret_cast<Handle>(
           ailego::EuclideanDistanceMatrix<int8_t, 2, 1>::Compute),
       reinterpret_cast<Handle>(
           ailego::EuclideanDistanceMatrix<int8_t, 2, 2>::Compute)},
      {reinterpret_cast<Handle>(
           ailego::EuclideanDistanceMatrix<int8_t, 4, 1>::Compute),
       reinterpret_cast<Handle>(
           ailego::EuclideanDistanceMatrix<int8_t, 4, 2>::Compute),
       reinterpret_cast<Handle>(
           ailego::EuclideanDistanceMatrix<int8_t, 4, 4>::Compute)},
      {reinterpret_cast<Handle>(
           ailego::EuclideanDistanceMatrix<int8_t, 8, 1>::Compute),
       reinterpret_cast<Handle>(
           ailego::EuclideanDistanceMatrix<int8_t, 8, 2>::Compute),
       reinterpret_cast<Handle>(
           ailego::EuclideanDistanceMatrix<int8_t, 8, 4>::Compute),
       reinterpret_cast<Handle>(
           ailego::EuclideanDistanceMatrix<int8_t, 8, 8>::Compute)},
      {reinterpret_cast<Handle>(
           ailego::EuclideanDistanceMatrix<int8_t, 16, 1>::Compute),
       reinterpret_cast<Handle>(
           ailego::EuclideanDistanceMatrix<int8_t, 16, 2>::Compute),
       reinterpret_cast<Handle>(
           ailego::EuclideanDistanceMatrix<int8_t, 16, 4>::Compute),
       reinterpret_cast<Handle>(
           ailego::EuclideanDistanceMatrix<int8_t, 16, 8>::Compute),
       reinterpret_cast<Handle>(
           ailego::EuclideanDistanceMatrix<int8_t, 16, 16>::Compute)},
      {reinterpret_cast<Handle>(
           ailego::EuclideanDistanceMatrix<int8_t, 32, 1>::Compute),
       reinterpret_cast<Handle>(
           ailego::EuclideanDistanceMatrix<int8_t, 32, 2>::Compute),
       reinterpret_cast<Handle>(
           ailego::EuclideanDistanceMatrix<int8_t, 32, 4>::Compute),
       reinterpret_cast<Handle>(
           ailego::EuclideanDistanceMatrix<int8_t, 32, 8>::Compute),
       reinterpret_cast<Handle>(
           ailego::EuclideanDistanceMatrix<int8_t, 32, 16>::Compute),
       reinterpret_cast<Handle>(
           ailego::EuclideanDistanceMatrix<int8_t, 32, 32>::Compute)},
  };
  if (m > 32 || n > 32 || ailego_popcount(m) != 1 ||
      ailego_popcount(n) != 1) {
    return nullptr;
  }
  return distance_table[ailego_ctz(m)][ailego_ctz(n)];
}
//!
Retrieve distance function for index features in Int4 static inline IndexMetric::MatrixDistanceHandle EuclideanDistanceMatrixInt4( size_t m, size_t n) { static const IndexMetric::MatrixDistanceHandle distance_table[6][6] = { {reinterpret_cast( ailego::EuclideanDistanceMatrix::Compute), nullptr, nullptr, nullptr, nullptr, nullptr}, {reinterpret_cast( ailego::EuclideanDistanceMatrix::Compute), reinterpret_cast( ailego::EuclideanDistanceMatrix::Compute), nullptr, nullptr, nullptr, nullptr}, {reinterpret_cast( ailego::EuclideanDistanceMatrix::Compute), reinterpret_cast( ailego::EuclideanDistanceMatrix::Compute), reinterpret_cast( ailego::EuclideanDistanceMatrix::Compute), nullptr, nullptr, nullptr}, {reinterpret_cast( ailego::EuclideanDistanceMatrix::Compute), reinterpret_cast( ailego::EuclideanDistanceMatrix::Compute), reinterpret_cast( ailego::EuclideanDistanceMatrix::Compute), reinterpret_cast( ailego::EuclideanDistanceMatrix::Compute), nullptr, nullptr}, {reinterpret_cast( ailego::EuclideanDistanceMatrix::Compute), reinterpret_cast( ailego::EuclideanDistanceMatrix::Compute), reinterpret_cast( ailego::EuclideanDistanceMatrix::Compute), reinterpret_cast( ailego::EuclideanDistanceMatrix::Compute), reinterpret_cast( ailego::EuclideanDistanceMatrix::Compute), nullptr}, {reinterpret_cast( ailego::EuclideanDistanceMatrix::Compute), reinterpret_cast( ailego::EuclideanDistanceMatrix::Compute), reinterpret_cast( ailego::EuclideanDistanceMatrix::Compute), reinterpret_cast( ailego::EuclideanDistanceMatrix::Compute), reinterpret_cast( ailego::EuclideanDistanceMatrix::Compute), reinterpret_cast( ailego::EuclideanDistanceMatrix::Compute)}, }; if (m > 32 || n > 32 || ailego_popcount(m) != 1 || ailego_popcount(n) != 1) { return nullptr; } return distance_table[ailego_ctz(m)][ailego_ctz(n)]; } //! 
Retrieve distance function for index features static inline IndexMetric::MatrixDistanceHandle HammingDistanceMatrix32( size_t m, size_t n) { static const IndexMetric::MatrixDistanceHandle distance_table[6][6] = { {reinterpret_cast( ailego::HammingDistanceMatrix::Compute), nullptr, nullptr, nullptr, nullptr, nullptr}, {reinterpret_cast( ailego::HammingDistanceMatrix::Compute), reinterpret_cast( ailego::HammingDistanceMatrix::Compute), nullptr, nullptr, nullptr, nullptr}, {reinterpret_cast( ailego::HammingDistanceMatrix::Compute), reinterpret_cast( ailego::HammingDistanceMatrix::Compute), reinterpret_cast( ailego::HammingDistanceMatrix::Compute), nullptr, nullptr, nullptr}, {reinterpret_cast( ailego::HammingDistanceMatrix::Compute), reinterpret_cast( ailego::HammingDistanceMatrix::Compute), reinterpret_cast( ailego::HammingDistanceMatrix::Compute), reinterpret_cast( ailego::HammingDistanceMatrix::Compute), nullptr, nullptr}, {reinterpret_cast( ailego::HammingDistanceMatrix::Compute), reinterpret_cast( ailego::HammingDistanceMatrix::Compute), reinterpret_cast( ailego::HammingDistanceMatrix::Compute), reinterpret_cast( ailego::HammingDistanceMatrix::Compute), reinterpret_cast( ailego::HammingDistanceMatrix::Compute), nullptr}, {reinterpret_cast( ailego::HammingDistanceMatrix::Compute), reinterpret_cast( ailego::HammingDistanceMatrix::Compute), reinterpret_cast( ailego::HammingDistanceMatrix::Compute), reinterpret_cast( ailego::HammingDistanceMatrix::Compute), reinterpret_cast( ailego::HammingDistanceMatrix::Compute), reinterpret_cast( ailego::HammingDistanceMatrix::Compute)}, }; if (m > 32 || n > 32 || ailego_popcount(m) != 1 || ailego_popcount(n) != 1) { return nullptr; } return distance_table[ailego_ctz(m)][ailego_ctz(n)]; } #if defined(AILEGO_M64) static inline IndexMetric::MatrixDistanceHandle HammingDistanceMatrix64( size_t m, size_t n) { static const IndexMetric::MatrixDistanceHandle distance_table[6][6] = { {reinterpret_cast( 
ailego::HammingDistanceMatrix::Compute), nullptr, nullptr, nullptr, nullptr, nullptr}, {reinterpret_cast( ailego::HammingDistanceMatrix::Compute), reinterpret_cast( ailego::HammingDistanceMatrix::Compute), nullptr, nullptr, nullptr, nullptr}, {reinterpret_cast( ailego::HammingDistanceMatrix::Compute), reinterpret_cast( ailego::HammingDistanceMatrix::Compute), reinterpret_cast( ailego::HammingDistanceMatrix::Compute), nullptr, nullptr, nullptr}, {reinterpret_cast( ailego::HammingDistanceMatrix::Compute), reinterpret_cast( ailego::HammingDistanceMatrix::Compute), reinterpret_cast( ailego::HammingDistanceMatrix::Compute), reinterpret_cast( ailego::HammingDistanceMatrix::Compute), nullptr, nullptr}, {reinterpret_cast( ailego::HammingDistanceMatrix::Compute), reinterpret_cast( ailego::HammingDistanceMatrix::Compute), reinterpret_cast( ailego::HammingDistanceMatrix::Compute), reinterpret_cast( ailego::HammingDistanceMatrix::Compute), reinterpret_cast( ailego::HammingDistanceMatrix::Compute), nullptr}, {reinterpret_cast( ailego::HammingDistanceMatrix::Compute), reinterpret_cast( ailego::HammingDistanceMatrix::Compute), reinterpret_cast( ailego::HammingDistanceMatrix::Compute), reinterpret_cast( ailego::HammingDistanceMatrix::Compute), reinterpret_cast( ailego::HammingDistanceMatrix::Compute), reinterpret_cast( ailego::HammingDistanceMatrix::Compute)}, }; if (m > 32 || n > 32 || ailego_popcount(m) != 1 || ailego_popcount(n) != 1) { return nullptr; } return distance_table[ailego_ctz(m)][ailego_ctz(n)]; } #endif // AILEGO_M64 //! 
Retrieve distance function for index features static inline IndexMetric::MatrixDistanceHandle HammingSquareRootDistanceMatrix32(size_t m, size_t n) { static const IndexMetric::MatrixDistanceHandle distance_table[6][6] = { {reinterpret_cast( ailego::HammingSquareRootDistanceMatrix::Compute), nullptr, nullptr, nullptr, nullptr, nullptr}, {reinterpret_cast( ailego::HammingSquareRootDistanceMatrix::Compute), reinterpret_cast( ailego::HammingSquareRootDistanceMatrix::Compute), nullptr, nullptr, nullptr, nullptr}, {reinterpret_cast( ailego::HammingSquareRootDistanceMatrix::Compute), reinterpret_cast( ailego::HammingSquareRootDistanceMatrix::Compute), reinterpret_cast( ailego::HammingSquareRootDistanceMatrix::Compute), nullptr, nullptr, nullptr}, {reinterpret_cast( ailego::HammingSquareRootDistanceMatrix::Compute), reinterpret_cast( ailego::HammingSquareRootDistanceMatrix::Compute), reinterpret_cast( ailego::HammingSquareRootDistanceMatrix::Compute), reinterpret_cast( ailego::HammingSquareRootDistanceMatrix::Compute), nullptr, nullptr}, {reinterpret_cast( ailego::HammingSquareRootDistanceMatrix::Compute), reinterpret_cast( ailego::HammingSquareRootDistanceMatrix::Compute), reinterpret_cast( ailego::HammingSquareRootDistanceMatrix::Compute), reinterpret_cast( ailego::HammingSquareRootDistanceMatrix::Compute), reinterpret_cast( ailego::HammingSquareRootDistanceMatrix::Compute), nullptr}, {reinterpret_cast( ailego::HammingSquareRootDistanceMatrix::Compute), reinterpret_cast( ailego::HammingSquareRootDistanceMatrix::Compute), reinterpret_cast( ailego::HammingSquareRootDistanceMatrix::Compute), reinterpret_cast( ailego::HammingSquareRootDistanceMatrix::Compute), reinterpret_cast( ailego::HammingSquareRootDistanceMatrix::Compute), reinterpret_cast( ailego::HammingSquareRootDistanceMatrix::Compute)}, }; if (m > 32 || n > 32 || ailego_popcount(m) != 1 || ailego_popcount(n) != 1) { return nullptr; } return distance_table[ailego_ctz(m)][ailego_ctz(n)]; } #if defined(AILEGO_M64) 
static inline IndexMetric::MatrixDistanceHandle HammingSquareRootDistanceMatrix64(size_t m, size_t n) { static const IndexMetric::MatrixDistanceHandle distance_table[6][6] = { {reinterpret_cast( ailego::HammingSquareRootDistanceMatrix::Compute), nullptr, nullptr, nullptr, nullptr, nullptr}, {reinterpret_cast( ailego::HammingSquareRootDistanceMatrix::Compute), reinterpret_cast( ailego::HammingSquareRootDistanceMatrix::Compute), nullptr, nullptr, nullptr, nullptr}, {reinterpret_cast( ailego::HammingSquareRootDistanceMatrix::Compute), reinterpret_cast( ailego::HammingSquareRootDistanceMatrix::Compute), reinterpret_cast( ailego::HammingSquareRootDistanceMatrix::Compute), nullptr, nullptr, nullptr}, {reinterpret_cast( ailego::HammingSquareRootDistanceMatrix::Compute), reinterpret_cast( ailego::HammingSquareRootDistanceMatrix::Compute), reinterpret_cast( ailego::HammingSquareRootDistanceMatrix::Compute), reinterpret_cast( ailego::HammingSquareRootDistanceMatrix::Compute), nullptr, nullptr}, {reinterpret_cast( ailego::HammingSquareRootDistanceMatrix::Compute), reinterpret_cast( ailego::HammingSquareRootDistanceMatrix::Compute), reinterpret_cast( ailego::HammingSquareRootDistanceMatrix::Compute), reinterpret_cast( ailego::HammingSquareRootDistanceMatrix::Compute), reinterpret_cast( ailego::HammingSquareRootDistanceMatrix::Compute), nullptr}, {reinterpret_cast( ailego::HammingSquareRootDistanceMatrix::Compute), reinterpret_cast( ailego::HammingSquareRootDistanceMatrix::Compute), reinterpret_cast( ailego::HammingSquareRootDistanceMatrix::Compute), reinterpret_cast( ailego::HammingSquareRootDistanceMatrix::Compute), reinterpret_cast( ailego::HammingSquareRootDistanceMatrix::Compute), reinterpret_cast( ailego::HammingSquareRootDistanceMatrix::Compute)}, }; if (m > 32 || n > 32 || ailego_popcount(m) != 1 || ailego_popcount(n) != 1) { return nullptr; } return distance_table[ailego_ctz(m)][ailego_ctz(n)]; } #endif // AILEGO_M64 /*! 
Squared Euclidean Distance Metric */ class SquaredEuclideanMetric : public IndexMetric { public: //! Initialize Metric int init(const IndexMeta &meta, const ailego::Params &index_params) override { IndexMeta::DataType dt = meta.data_type(); if (dt != IndexMeta::DataType::DT_FP16 && dt != IndexMeta::DataType::DT_FP32 && dt != IndexMeta::DataType::DT_INT8 && dt != IndexMeta::DataType::DT_INT4 && dt != IndexMeta::DataType::DT_BINARY32 && dt != IndexMeta::DataType::DT_BINARY64) { return IndexError_Unsupported; } if (IndexMeta::UnitSizeof(dt) != meta.unit_size()) { return IndexError_Unsupported; } data_type_ = dt; params_ = index_params; return 0; } //! Cleanup Metric int cleanup(void) override { return 0; } //! Retrieve if it matched bool is_matched(const IndexMeta &meta) const override { return (meta.data_type() == data_type_ && meta.unit_size() == IndexMeta::UnitSizeof(data_type_)); } //! Retrieve if it matched bool is_matched(const IndexMeta &meta, const IndexQueryMeta &qmeta) const override { return (qmeta.data_type() == data_type_ && qmeta.unit_size() == IndexMeta::UnitSizeof(data_type_) && qmeta.dimension() == meta.dimension()); } //! Retrieve distance function for query MatrixDistance distance(void) const override { switch (data_type_) { case IndexMeta::DataType::DT_BINARY32: return reinterpret_cast( ailego::HammingDistanceMatrix::Compute); #if defined(AILEGO_M64) case IndexMeta::DataType::DT_BINARY64: return reinterpret_cast( ailego::HammingDistanceMatrix::Compute); #endif // AILEGO_M64 case IndexMeta::DataType::DT_FP16: return reinterpret_cast( ailego::SquaredEuclideanDistanceMatrix::Compute); case IndexMeta::DataType::DT_FP32: return reinterpret_cast( ailego::SquaredEuclideanDistanceMatrix::Compute); case IndexMeta::DataType::DT_INT8: return reinterpret_cast( ailego::SquaredEuclideanDistanceMatrix::Compute); case IndexMeta::DataType::DT_INT4: return reinterpret_cast( ailego::SquaredEuclideanDistanceMatrix::Compute); default: return nullptr; } } //! 
Retrieve sparse distance function for query MatrixSparseDistance sparse_distance(void) const override { return reinterpret_cast( ailego::SquaredEuclideanSparseDistanceMatrix::Compute); } //! Retrieve distance function for index features MatrixDistance distance_matrix(size_t m, size_t n) const override { switch (data_type_) { case IndexMeta::DataType::DT_BINARY32: return HammingDistanceMatrix32(m, n); #if defined(AILEGO_M64) case IndexMeta::DataType::DT_BINARY64: return HammingDistanceMatrix64(m, n); #endif // AILEGO_M64 case IndexMeta::DataType::DT_FP16: return SquaredEuclideanDistanceMatrixFp16(m, n); case IndexMeta::DataType::DT_FP32: return SquaredEuclideanDistanceMatrixFp32(m, n); case IndexMeta::DataType::DT_INT8: return SquaredEuclideanDistanceMatrixInt8(m, n); case IndexMeta::DataType::DT_INT4: return SquaredEuclideanDistanceMatrixInt4(m, n); default: return nullptr; } } //! Retrieve distance function for query MatrixBatchDistance batch_distance(void) const override { switch (data_type_) { case IndexMeta::DataType::DT_BINARY32: return reinterpret_cast( ailego::BaseDistance::ComputeBatch); #if defined(AILEGO_M64) case IndexMeta::DataType::DT_BINARY64: return reinterpret_cast( ailego::BaseDistance::ComputeBatch); #endif // AILEGO_M64 case IndexMeta::DataType::DT_FP16: return reinterpret_cast( ailego::BaseDistance::ComputeBatch); case IndexMeta::DataType::DT_FP32: return reinterpret_cast( ailego::BaseDistance::ComputeBatch); case IndexMeta::DataType::DT_INT8: return reinterpret_cast( ailego::BaseDistance::ComputeBatch); case IndexMeta::DataType::DT_INT4: return reinterpret_cast( ailego::BaseDistance::ComputeBatch); default: return nullptr; } } //! Retrieve params of Metric const ailego::Params ¶ms(void) const override { return params_; } //! Retrieve query metric object of this index metric Pointer query_metric(void) const override { return nullptr; } private: IndexMeta::DataType data_type_{IndexMeta::DataType::DT_FP32}; ailego::Params params_{}; }; /*! 
Euclidean Distance Metric */ class EuclideanMetric : public IndexMetric { public: //! Initialize Metric int init(const IndexMeta &meta, const ailego::Params &index_params) override { IndexMeta::DataType dt = meta.data_type(); if (dt != IndexMeta::DataType::DT_FP16 && dt != IndexMeta::DataType::DT_FP32 && dt != IndexMeta::DataType::DT_INT8 && dt != IndexMeta::DataType::DT_INT4 && dt != IndexMeta::DataType::DT_BINARY32 && dt != IndexMeta::DataType::DT_BINARY64) { return IndexError_Unsupported; } if (IndexMeta::UnitSizeof(dt) != meta.unit_size()) { return IndexError_Unsupported; } data_type_ = dt; params_ = index_params; return 0; } //! Cleanup Metric int cleanup(void) override { return 0; } //! Retrieve if it matched bool is_matched(const IndexMeta &meta) const override { return (meta.data_type() == data_type_ && meta.unit_size() == IndexMeta::UnitSizeof(data_type_)); } //! Retrieve if it matched bool is_matched(const IndexMeta &meta, const IndexQueryMeta &qmeta) const override { return (qmeta.data_type() == data_type_ && qmeta.unit_size() == IndexMeta::UnitSizeof(data_type_) && qmeta.dimension() == meta.dimension()); } //! Retrieve distance function for query MatrixDistance distance(void) const override { switch (data_type_) { case IndexMeta::DataType::DT_BINARY32: return reinterpret_cast( ailego::HammingSquareRootDistanceMatrix::Compute); #if defined(AILEGO_M64) case IndexMeta::DataType::DT_BINARY64: return reinterpret_cast( ailego::HammingSquareRootDistanceMatrix::Compute); #endif // AILEGO_M64 case IndexMeta::DataType::DT_FP16: return reinterpret_cast( ailego::EuclideanDistanceMatrix::Compute); case IndexMeta::DataType::DT_FP32: return reinterpret_cast( ailego::EuclideanDistanceMatrix::Compute); case IndexMeta::DataType::DT_INT8: return reinterpret_cast( ailego::EuclideanDistanceMatrix::Compute); case IndexMeta::DataType::DT_INT4: return reinterpret_cast( ailego::EuclideanDistanceMatrix::Compute); default: return nullptr; } } //! 
Retrieve distance function for index features MatrixDistance distance_matrix(size_t m, size_t n) const override { switch (data_type_) { case IndexMeta::DataType::DT_BINARY32: return HammingSquareRootDistanceMatrix32(m, n); #if defined(AILEGO_M64) case IndexMeta::DataType::DT_BINARY64: return HammingSquareRootDistanceMatrix64(m, n); #endif // AILEGO_M64 case IndexMeta::DataType::DT_FP16: return EuclideanDistanceMatrixFp16(m, n); case IndexMeta::DataType::DT_FP32: return EuclideanDistanceMatrixFp32(m, n); case IndexMeta::DataType::DT_INT8: return EuclideanDistanceMatrixInt8(m, n); case IndexMeta::DataType::DT_INT4: return EuclideanDistanceMatrixInt4(m, n); default: return nullptr; } } //! Retrieve params of Metric const ailego::Params ¶ms(void) const override { return params_; } //! Retrieve query metric object of this index metric Pointer query_metric(void) const override { return nullptr; } private: IndexMeta::DataType data_type_{IndexMeta::DataType::DT_FP32}; ailego::Params params_{}; }; /*! Squared Euclidean Sparse Metric */ class SquaredEuclideanSparseMetric : public IndexMetric { public: //! Initialize Metric int init(const IndexMeta &meta, const ailego::Params &index_params) override { IndexMeta::DataType data_type = meta.data_type(); if (data_type != IndexMeta::DataType::DT_FP16 && data_type != IndexMeta::DataType::DT_FP32) { return IndexError_Unsupported; } if (IndexMeta::UnitSizeof(data_type) != meta.unit_size()) { return IndexError_Unsupported; } data_type_ = data_type; params_ = index_params; return 0; } //! Cleanup Metric int cleanup(void) override { return 0; } //! Retrieve if it matched bool is_matched(const IndexMeta &meta) const override { return (meta.data_type() == data_type_ && meta.unit_size() == IndexMeta::UnitSizeof(data_type_)); } //! 
Retrieve if it matched bool is_matched(const IndexMeta &meta, const IndexQueryMeta &qmeta) const override { return (qmeta.data_type() == data_type_ && qmeta.data_type() == meta.data_type() && qmeta.unit_size() == IndexMeta::UnitSizeof(data_type_) && qmeta.unit_size() == meta.unit_size()); } //! Retrieve sparse distance function for query MatrixSparseDistance sparse_distance(void) const override { return reinterpret_cast( ailego::SquaredEuclideanSparseDistanceMatrix::Compute); } //! Retrieve params of Metric const ailego::Params ¶ms(void) const override { return params_; } //! Retrieve query metric object of this index metric Pointer query_metric(void) const override { return nullptr; } private: IndexMeta::DataType data_type_{IndexMeta::DataType::DT_FP32}; ailego::Params params_{}; }; INDEX_FACTORY_REGISTER_METRIC_ALIAS(SquaredEuclidean, SquaredEuclideanMetric); INDEX_FACTORY_REGISTER_METRIC_ALIAS(Euclidean, EuclideanMetric); INDEX_FACTORY_REGISTER_METRIC_ALIAS(SquaredEuclideanSparse, SquaredEuclideanSparseMetric); } // namespace core } // namespace zvec ================================================ FILE: src/core/metric/hamming_metric.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include #include #include "ailego/math_batch/distance_batch.h" namespace zvec { namespace core { //! 
Retrieve distance function for index features static inline IndexMetric::MatrixDistanceHandle HammingDistanceMatrix32( size_t m, size_t n) { static const IndexMetric::MatrixDistanceHandle distance_table[6][6] = { {reinterpret_cast( ailego::HammingDistanceMatrix::Compute), nullptr, nullptr, nullptr, nullptr, nullptr}, {reinterpret_cast( ailego::HammingDistanceMatrix::Compute), reinterpret_cast( ailego::HammingDistanceMatrix::Compute), nullptr, nullptr, nullptr, nullptr}, {reinterpret_cast( ailego::HammingDistanceMatrix::Compute), reinterpret_cast( ailego::HammingDistanceMatrix::Compute), reinterpret_cast( ailego::HammingDistanceMatrix::Compute), nullptr, nullptr, nullptr}, {reinterpret_cast( ailego::HammingDistanceMatrix::Compute), reinterpret_cast( ailego::HammingDistanceMatrix::Compute), reinterpret_cast( ailego::HammingDistanceMatrix::Compute), reinterpret_cast( ailego::HammingDistanceMatrix::Compute), nullptr, nullptr}, {reinterpret_cast( ailego::HammingDistanceMatrix::Compute), reinterpret_cast( ailego::HammingDistanceMatrix::Compute), reinterpret_cast( ailego::HammingDistanceMatrix::Compute), reinterpret_cast( ailego::HammingDistanceMatrix::Compute), reinterpret_cast( ailego::HammingDistanceMatrix::Compute), nullptr}, {reinterpret_cast( ailego::HammingDistanceMatrix::Compute), reinterpret_cast( ailego::HammingDistanceMatrix::Compute), reinterpret_cast( ailego::HammingDistanceMatrix::Compute), reinterpret_cast( ailego::HammingDistanceMatrix::Compute), reinterpret_cast( ailego::HammingDistanceMatrix::Compute), reinterpret_cast( ailego::HammingDistanceMatrix::Compute)}, }; if (m > 32 || n > 32 || ailego_popcount(m) != 1 || ailego_popcount(n) != 1) { return nullptr; } return distance_table[ailego_ctz(m)][ailego_ctz(n)]; } #if defined(AILEGO_M64) static inline IndexMetric::MatrixDistanceHandle HammingDistanceMatrix64( size_t m, size_t n) { static const IndexMetric::MatrixDistanceHandle distance_table[6][6] = { {reinterpret_cast( 
ailego::HammingDistanceMatrix::Compute), nullptr, nullptr, nullptr, nullptr, nullptr}, {reinterpret_cast( ailego::HammingDistanceMatrix::Compute), reinterpret_cast( ailego::HammingDistanceMatrix::Compute), nullptr, nullptr, nullptr, nullptr}, {reinterpret_cast( ailego::HammingDistanceMatrix::Compute), reinterpret_cast( ailego::HammingDistanceMatrix::Compute), reinterpret_cast( ailego::HammingDistanceMatrix::Compute), nullptr, nullptr, nullptr}, {reinterpret_cast( ailego::HammingDistanceMatrix::Compute), reinterpret_cast( ailego::HammingDistanceMatrix::Compute), reinterpret_cast( ailego::HammingDistanceMatrix::Compute), reinterpret_cast( ailego::HammingDistanceMatrix::Compute), nullptr, nullptr}, {reinterpret_cast( ailego::HammingDistanceMatrix::Compute), reinterpret_cast( ailego::HammingDistanceMatrix::Compute), reinterpret_cast( ailego::HammingDistanceMatrix::Compute), reinterpret_cast( ailego::HammingDistanceMatrix::Compute), reinterpret_cast( ailego::HammingDistanceMatrix::Compute), nullptr}, {reinterpret_cast( ailego::HammingDistanceMatrix::Compute), reinterpret_cast( ailego::HammingDistanceMatrix::Compute), reinterpret_cast( ailego::HammingDistanceMatrix::Compute), reinterpret_cast( ailego::HammingDistanceMatrix::Compute), reinterpret_cast( ailego::HammingDistanceMatrix::Compute), reinterpret_cast( ailego::HammingDistanceMatrix::Compute)}, }; if (m > 32 || n > 32 || ailego_popcount(m) != 1 || ailego_popcount(n) != 1) { return nullptr; } return distance_table[ailego_ctz(m)][ailego_ctz(n)]; } #endif // AILEOG_M64 /*! Hamming Metric */ class HammingMetric : public IndexMetric { public: //! 
Initialize Metric int init(const IndexMeta &meta, const ailego::Params &index_params) override { if (meta.data_type() != IndexMeta::DataType::DT_BINARY32 && meta.data_type() != IndexMeta::DataType::DT_BINARY64) { return IndexError_Unsupported; } if (IndexMeta::UnitSizeof(meta.data_type()) != meta.unit_size()) { return IndexError_Unsupported; } feature_type_ = meta.data_type(); params_ = index_params; return 0; } //! Cleanup Metric int cleanup(void) override { return 0; } //! Retrieve if it matched bool is_matched(const IndexMeta &meta) const override { return (meta.data_type() == feature_type_ && meta.unit_size() == IndexMeta::UnitSizeof(feature_type_)); } //! Retrieve if it matched bool is_matched(const IndexMeta &meta, const IndexQueryMeta &qmeta) const override { return (qmeta.data_type() == feature_type_ && qmeta.unit_size() == IndexMeta::UnitSizeof(feature_type_) && qmeta.dimension() == meta.dimension()); } //! Retrieve distance function for query MatrixDistance distance(void) const override { #if defined(AILEGO_M64) if (feature_type_ == IndexMeta::DataType::DT_BINARY64) { return reinterpret_cast( ailego::HammingDistanceMatrix::Compute); } #endif if (feature_type_ == IndexMeta::DataType::DT_BINARY32) { return reinterpret_cast( ailego::HammingDistanceMatrix::Compute); } return nullptr; } MatrixBatchDistance batch_distance(void) const override { #if defined(AILEGO_M64) if (feature_type_ == IndexMeta::DataType::DT_BINARY64) { return reinterpret_cast( ailego::BaseDistance::ComputeBatch); } #endif if (feature_type_ == IndexMeta::DataType::DT_BINARY32) { return reinterpret_cast( ailego::BaseDistance::ComputeBatch); } return nullptr; } //! 
Retrieve distance function for index features MatrixDistance distance_matrix(size_t m, size_t n) const override { #if defined(AILEGO_M64) if (feature_type_ == IndexMeta::DataType::DT_BINARY64) { return HammingDistanceMatrix64(m, n); } #endif if (feature_type_ == IndexMeta::DataType::DT_BINARY32) { return HammingDistanceMatrix32(m, n); } return nullptr; } //! Retrieve params of Metric const ailego::Params ¶ms(void) const override { return params_; } //! Retrieve query metric object of this index metric Pointer query_metric(void) const override { return nullptr; } private: IndexMeta::DataType feature_type_{IndexMeta::DataType::DT_BINARY32}; ailego::Params params_{}; }; INDEX_FACTORY_REGISTER_METRIC_ALIAS(Hamming, HammingMetric); } // namespace core } // namespace zvec ================================================ FILE: src/core/metric/inner_product_metric.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include #include #include #include namespace zvec { namespace core { //! 
Retrieve distance function for index features static inline IndexMetric::MatrixDistanceHandle MinusInnerProductMatrixFp32( size_t m, size_t n) { static const IndexMetric::MatrixDistanceHandle distance_table[6][6] = { {reinterpret_cast( ailego::MinusInnerProductMatrix::Compute), nullptr, nullptr, nullptr, nullptr, nullptr}, {reinterpret_cast( ailego::MinusInnerProductMatrix::Compute), reinterpret_cast( ailego::MinusInnerProductMatrix::Compute), nullptr, nullptr, nullptr, nullptr}, {reinterpret_cast( ailego::MinusInnerProductMatrix::Compute), reinterpret_cast( ailego::MinusInnerProductMatrix::Compute), reinterpret_cast( ailego::MinusInnerProductMatrix::Compute), nullptr, nullptr, nullptr}, {reinterpret_cast( ailego::MinusInnerProductMatrix::Compute), reinterpret_cast( ailego::MinusInnerProductMatrix::Compute), reinterpret_cast( ailego::MinusInnerProductMatrix::Compute), reinterpret_cast( ailego::MinusInnerProductMatrix::Compute), nullptr, nullptr}, {reinterpret_cast( ailego::MinusInnerProductMatrix::Compute), reinterpret_cast( ailego::MinusInnerProductMatrix::Compute), reinterpret_cast( ailego::MinusInnerProductMatrix::Compute), reinterpret_cast( ailego::MinusInnerProductMatrix::Compute), reinterpret_cast( ailego::MinusInnerProductMatrix::Compute), nullptr}, {reinterpret_cast( ailego::MinusInnerProductMatrix::Compute), reinterpret_cast( ailego::MinusInnerProductMatrix::Compute), reinterpret_cast( ailego::MinusInnerProductMatrix::Compute), reinterpret_cast( ailego::MinusInnerProductMatrix::Compute), reinterpret_cast( ailego::MinusInnerProductMatrix::Compute), reinterpret_cast( ailego::MinusInnerProductMatrix::Compute)}, }; if (m > 32 || n > 32 || ailego_popcount(m) != 1 || ailego_popcount(n) != 1) { return nullptr; } return distance_table[ailego_ctz(m)][ailego_ctz(n)]; } //! 
//! Retrieve distance function for index features
//!
//! Returns the FP16 MinusInnerProductMatrix kernel specialised for an
//! `m` x `n` block, or nullptr when either count is not a power of two in
//! [1, 32] or when n > m (only the lower triangle of the table is filled).
static inline IndexMetric::MatrixDistanceHandle MinusInnerProductMatrixFp16(
    size_t m, size_t n) {
  using Handle = IndexMetric::MatrixDistanceHandle;
  // distance_table[i][j] holds the kernel for M = 2^i, N = 2^j (j <= i);
  // omitted upper-triangle entries are value-initialized to nullptr.
  // NOTE(review): ailego::Float16 as the FP16 element type is assumed --
  // confirm against the ailego headers.
  static const Handle distance_table[6][6] = {
      {reinterpret_cast<Handle>(
          ailego::MinusInnerProductMatrix<ailego::Float16, 1, 1>::Compute)},
      {reinterpret_cast<Handle>(
           ailego::MinusInnerProductMatrix<ailego::Float16, 2, 1>::Compute),
       reinterpret_cast<Handle>(
           ailego::MinusInnerProductMatrix<ailego::Float16, 2, 2>::Compute)},
      {reinterpret_cast<Handle>(
           ailego::MinusInnerProductMatrix<ailego::Float16, 4, 1>::Compute),
       reinterpret_cast<Handle>(
           ailego::MinusInnerProductMatrix<ailego::Float16, 4, 2>::Compute),
       reinterpret_cast<Handle>(
           ailego::MinusInnerProductMatrix<ailego::Float16, 4, 4>::Compute)},
      {reinterpret_cast<Handle>(
           ailego::MinusInnerProductMatrix<ailego::Float16, 8, 1>::Compute),
       reinterpret_cast<Handle>(
           ailego::MinusInnerProductMatrix<ailego::Float16, 8, 2>::Compute),
       reinterpret_cast<Handle>(
           ailego::MinusInnerProductMatrix<ailego::Float16, 8, 4>::Compute),
       reinterpret_cast<Handle>(
           ailego::MinusInnerProductMatrix<ailego::Float16, 8, 8>::Compute)},
      {reinterpret_cast<Handle>(
           ailego::MinusInnerProductMatrix<ailego::Float16, 16, 1>::Compute),
       reinterpret_cast<Handle>(
           ailego::MinusInnerProductMatrix<ailego::Float16, 16, 2>::Compute),
       reinterpret_cast<Handle>(
           ailego::MinusInnerProductMatrix<ailego::Float16, 16, 4>::Compute),
       reinterpret_cast<Handle>(
           ailego::MinusInnerProductMatrix<ailego::Float16, 16, 8>::Compute),
       reinterpret_cast<Handle>(
           ailego::MinusInnerProductMatrix<ailego::Float16, 16, 16>::Compute)},
      {reinterpret_cast<Handle>(
           ailego::MinusInnerProductMatrix<ailego::Float16, 32, 1>::Compute),
       reinterpret_cast<Handle>(
           ailego::MinusInnerProductMatrix<ailego::Float16, 32, 2>::Compute),
       reinterpret_cast<Handle>(
           ailego::MinusInnerProductMatrix<ailego::Float16, 32, 4>::Compute),
       reinterpret_cast<Handle>(
           ailego::MinusInnerProductMatrix<ailego::Float16, 32, 8>::Compute),
       reinterpret_cast<Handle>(
           ailego::MinusInnerProductMatrix<ailego::Float16, 32, 16>::Compute),
       reinterpret_cast<Handle>(
           ailego::MinusInnerProductMatrix<ailego::Float16, 32, 32>::Compute)},
  };
  if (m > 32 || n > 32 || ailego_popcount(m) != 1 ||
      ailego_popcount(n) != 1) {
    return nullptr;
  }
  return distance_table[ailego_ctz(m)][ailego_ctz(n)];
}

//! Retrieve distance function for index features in Int8
//!
//! Returns the INT8 MinusInnerProductMatrix kernel specialised for an
//! `m` x `n` block, or nullptr when either count is not a power of two in
//! [1, 32] or when n > m (only the lower triangle of the table is filled).
static inline IndexMetric::MatrixDistanceHandle MinusInnerProductMatrixInt8(
    size_t m, size_t n) {
  using Handle = IndexMetric::MatrixDistanceHandle;
  // distance_table[i][j] holds the kernel for M = 2^i, N = 2^j (j <= i);
  // omitted upper-triangle entries are value-initialized to nullptr.
  static const Handle distance_table[6][6] = {
      {reinterpret_cast<Handle>(
          ailego::MinusInnerProductMatrix<int8_t, 1, 1>::Compute)},
      {reinterpret_cast<Handle>(
           ailego::MinusInnerProductMatrix<int8_t, 2, 1>::Compute),
       reinterpret_cast<Handle>(
           ailego::MinusInnerProductMatrix<int8_t, 2, 2>::Compute)},
      {reinterpret_cast<Handle>(
           ailego::MinusInnerProductMatrix<int8_t, 4, 1>::Compute),
       reinterpret_cast<Handle>(
           ailego::MinusInnerProductMatrix<int8_t, 4, 2>::Compute),
       reinterpret_cast<Handle>(
           ailego::MinusInnerProductMatrix<int8_t, 4, 4>::Compute)},
      {reinterpret_cast<Handle>(
           ailego::MinusInnerProductMatrix<int8_t, 8, 1>::Compute),
       reinterpret_cast<Handle>(
           ailego::MinusInnerProductMatrix<int8_t, 8, 2>::Compute),
       reinterpret_cast<Handle>(
           ailego::MinusInnerProductMatrix<int8_t, 8, 4>::Compute),
       reinterpret_cast<Handle>(
           ailego::MinusInnerProductMatrix<int8_t, 8, 8>::Compute)},
      {reinterpret_cast<Handle>(
           ailego::MinusInnerProductMatrix<int8_t, 16, 1>::Compute),
       reinterpret_cast<Handle>(
           ailego::MinusInnerProductMatrix<int8_t, 16, 2>::Compute),
       reinterpret_cast<Handle>(
           ailego::MinusInnerProductMatrix<int8_t, 16, 4>::Compute),
       reinterpret_cast<Handle>(
           ailego::MinusInnerProductMatrix<int8_t, 16, 8>::Compute),
       reinterpret_cast<Handle>(
           ailego::MinusInnerProductMatrix<int8_t, 16, 16>::Compute)},
      {reinterpret_cast<Handle>(
           ailego::MinusInnerProductMatrix<int8_t, 32, 1>::Compute),
       reinterpret_cast<Handle>(
           ailego::MinusInnerProductMatrix<int8_t, 32, 2>::Compute),
       reinterpret_cast<Handle>(
           ailego::MinusInnerProductMatrix<int8_t, 32, 4>::Compute),
       reinterpret_cast<Handle>(
           ailego::MinusInnerProductMatrix<int8_t, 32, 8>::Compute),
       reinterpret_cast<Handle>(
           ailego::MinusInnerProductMatrix<int8_t, 32, 16>::Compute),
       reinterpret_cast<Handle>(
           ailego::MinusInnerProductMatrix<int8_t, 32, 32>::Compute)},
  };
  if (m > 32 || n > 32 || ailego_popcount(m) != 1 ||
      ailego_popcount(n) != 1) {
    return nullptr;
  }
  return distance_table[ailego_ctz(m)][ailego_ctz(n)];
}
//!
Retrieve distance function for index features in Int4
// Variant used by InnerProductMetric::distance_matrix for DT_INT4: a
// [6][6] table indexed by [ailego_ctz(m)][ailego_ctz(n)], returning
// nullptr for m or n above 32, for sizes whose ailego_popcount is not 1,
// and for the nullptr cells above the diagonal.
static inline IndexMetric::MatrixDistanceHandle MinusInnerProductMatrixInt4(
    size_t m, size_t n) {
  static const IndexMetric::MatrixDistanceHandle distance_table[6][6] = {
      {reinterpret_cast(
           ailego::MinusInnerProductMatrix::Compute),
       nullptr, nullptr, nullptr, nullptr, nullptr},
      {reinterpret_cast(
           ailego::MinusInnerProductMatrix::Compute),
       reinterpret_cast(
           ailego::MinusInnerProductMatrix::Compute),
       nullptr, nullptr, nullptr, nullptr},
      {reinterpret_cast(
           ailego::MinusInnerProductMatrix::Compute),
       reinterpret_cast(
           ailego::MinusInnerProductMatrix::Compute),
       reinterpret_cast(
           ailego::MinusInnerProductMatrix::Compute),
       nullptr, nullptr, nullptr},
      {reinterpret_cast(
           ailego::MinusInnerProductMatrix::Compute),
       reinterpret_cast(
           ailego::MinusInnerProductMatrix::Compute),
       reinterpret_cast(
           ailego::MinusInnerProductMatrix::Compute),
       reinterpret_cast(
           ailego::MinusInnerProductMatrix::Compute),
       nullptr, nullptr},
      {reinterpret_cast(
           ailego::MinusInnerProductMatrix::Compute),
       reinterpret_cast(
           ailego::MinusInnerProductMatrix::Compute),
       reinterpret_cast(
           ailego::MinusInnerProductMatrix::Compute),
       reinterpret_cast(
           ailego::MinusInnerProductMatrix::Compute),
       reinterpret_cast(
           ailego::MinusInnerProductMatrix::Compute),
       nullptr},
      {reinterpret_cast(
           ailego::MinusInnerProductMatrix::Compute),
       reinterpret_cast(
           ailego::MinusInnerProductMatrix::Compute),
       reinterpret_cast(
           ailego::MinusInnerProductMatrix::Compute),
       reinterpret_cast(
           ailego::MinusInnerProductMatrix::Compute),
       reinterpret_cast(
           ailego::MinusInnerProductMatrix::Compute),
       reinterpret_cast(
           ailego::MinusInnerProductMatrix::Compute)},
  };
  // Only sizes <= 32 with exactly one set bit map into the table.
  if (m > 32 || n > 32 || ailego_popcount(m) != 1 ||
      ailego_popcount(n) != 1) {
    return nullptr;
  }
  return distance_table[ailego_ctz(m)][ailego_ctz(n)];
}

/*! Inner Product Metric
 *  Dense metric over FP16/FP32/INT8/INT4 data built on the
 *  MinusInnerProductMatrix kernels; normalize() negates scores.
 */
class InnerProductMetric : public IndexMetric {
 public:
  //!
Initialize Metric
  // Accepts only dense meta whose data type is FP16, FP32, INT8 or INT4 and
  // whose unit size equals IndexMeta::UnitSizeof(data type). Returns
  // IndexError_Unsupported otherwise, before any member is modified.
  int init(const IndexMeta &meta,
           const ailego::Params &index_params) override {
    IndexMeta::MetaType mt = meta.meta_type();
    if (mt != IndexMeta::MetaType::MT_DENSE) {
      return IndexError_Unsupported;
    }
    IndexMeta::DataType dt = meta.data_type();
    if (dt != IndexMeta::DataType::DT_FP16 &&
        dt != IndexMeta::DataType::DT_FP32 &&
        dt != IndexMeta::DataType::DT_INT8 &&
        dt != IndexMeta::DataType::DT_INT4) {
      return IndexError_Unsupported;
    }
    if (IndexMeta::UnitSizeof(dt) != meta.unit_size()) {
      return IndexError_Unsupported;
    }
    meta_type_ = mt;
    data_type_ = dt;
    params_ = index_params;
    return 0;
  }

  //! Cleanup Metric (nothing to release; always returns 0)
  int cleanup(void) override {
    return 0;
  }

  //! Retrieve if it matched: same data type and a unit size consistent
  //! with that data type.
  bool is_matched(const IndexMeta &meta) const override {
    return (meta.data_type() == data_type_ &&
            meta.unit_size() == IndexMeta::UnitSizeof(data_type_));
  }

  //! Retrieve if it matched: the query must use this metric's data type and
  //! unit size, and its dimension must equal the index dimension.
  bool is_matched(const IndexMeta &meta,
                  const IndexQueryMeta &qmeta) const override {
    return (qmeta.data_type() == data_type_ &&
            qmeta.unit_size() == IndexMeta::UnitSizeof(data_type_) &&
            qmeta.dimension() == meta.dimension());
  }

  //! Retrieve distance function for query; nullptr for a data type other
  //! than FP16/FP32/INT8/INT4.
  MatrixDistance distance(void) const override {
    switch (data_type_) {
      case IndexMeta::DataType::DT_FP16:
        return reinterpret_cast(
            ailego::MinusInnerProductMatrix::Compute);
      case IndexMeta::DataType::DT_FP32:
        return reinterpret_cast(
            ailego::MinusInnerProductMatrix::Compute);
      case IndexMeta::DataType::DT_INT8:
        return reinterpret_cast(
            ailego::MinusInnerProductMatrix::Compute);
      case IndexMeta::DataType::DT_INT4:
        return reinterpret_cast(
            ailego::MinusInnerProductMatrix::Compute);
      default:
        return nullptr;
    }
  }

  //! Retrieve sparse distance function for query. The same function is
  //! returned whatever data_type_ is.
  MatrixSparseDistance sparse_distance(void) const override {
    return reinterpret_cast(
        ailego::MinusInnerProductSparseMatrix::Compute);
  }

  //!
Retrieve distance function for index features MatrixDistance distance_matrix(size_t m, size_t n) const override { switch (data_type_) { case IndexMeta::DataType::DT_FP16: return MinusInnerProductMatrixFp16(m, n); case IndexMeta::DataType::DT_FP32: return MinusInnerProductMatrixFp32(m, n); case IndexMeta::DataType::DT_INT8: return MinusInnerProductMatrixInt8(m, n); case IndexMeta::DataType::DT_INT4: return MinusInnerProductMatrixInt4(m, n); default: return nullptr; } } //! Retrieve distance function for query MatrixBatchDistance batch_distance(void) const override { switch (data_type_) { case IndexMeta::DataType::DT_FP32: return reinterpret_cast( ailego::BaseDistance::ComputeBatch); case IndexMeta::DataType::DT_FP16: return reinterpret_cast( ailego::BaseDistance::ComputeBatch); case IndexMeta::DataType::DT_INT8: return reinterpret_cast( ailego::BaseDistance::ComputeBatch); case IndexMeta::DataType::DT_INT4: return reinterpret_cast( ailego::BaseDistance::ComputeBatch); default: return nullptr; } } //! Normalize result void normalize(float *score) const override { *score = -(*score); } //! Denormalize threshold void denormalize(float *score) const override { *score = -(*score); } //! Retrieve if it supports normalization bool support_normalize(void) const override { return true; } //! Retrieve params of Metric const ailego::Params ¶ms(void) const override { return params_; } //! Retrieve query measure object of this index measure Pointer query_metric(void) const override { return nullptr; } private: IndexMeta::MetaType meta_type_{IndexMeta::MetaType::MT_DENSE}; IndexMeta::DataType data_type_{IndexMeta::DataType::DT_FP32}; ailego::Params params_{}; }; /*! Normalized Cosine Metric */ class NormalizedCosineMetric : public InnerProductMetric { public: //! 
Initialize Metric int init(const IndexMeta &meta, const ailego::Params &index_params) override { IndexMeta::DataType dt = meta.data_type(); if (dt != IndexMeta::DataType::DT_FP16 && dt != IndexMeta::DataType::DT_FP32) { return IndexError_Unsupported; } InnerProductMetric::init(meta, index_params); return 0; } //! Normalize result void normalize(float *score) const override { *score = 1 + (*score); } //! Denormalize threshold void denormalize(float *score) const override { *score -= 1; } }; /*! Inner Product Sparse Metric */ class InnerProductSparseMetric : public IndexMetric { public: //! Initialize Metric int init(const IndexMeta &meta, const ailego::Params &index_params) override { IndexMeta::DataType dt = meta.data_type(); if (dt != IndexMeta::DataType::DT_FP16 && dt != IndexMeta::DataType::DT_FP32) { return IndexError_Unsupported; } if (IndexMeta::UnitSizeof(dt) != meta.unit_size()) { return IndexError_Unsupported; } data_type_ = dt; params_ = index_params; return 0; } //! Cleanup Metric int cleanup(void) override { return 0; } //! Retrieve if it matched bool is_matched(const IndexMeta &meta) const override { return (meta.data_type() == data_type_ && meta.unit_size() == IndexMeta::UnitSizeof(data_type_)); } //! Retrieve if it matched bool is_matched(const IndexMeta &meta, const IndexQueryMeta &qmeta) const override { return (qmeta.data_type() == data_type_ && qmeta.data_type() == meta.data_type() && qmeta.unit_size() == IndexMeta::UnitSizeof(data_type_) && qmeta.unit_size() == meta.unit_size()); } //! Retrieve distance function for query MatrixDistance distance(void) const override { return nullptr; } //! 
Retrieve sparse distance function for query MatrixSparseDistance sparse_distance(void) const override { switch (data_type_) { case IndexMeta::DataType::DT_FP16: return reinterpret_cast( ailego::MinusInnerProductSparseMatrix::Compute); case IndexMeta::DataType::DT_FP32: return reinterpret_cast( ailego::MinusInnerProductSparseMatrix::Compute); default: return nullptr; } } //! Normalize result void normalize(float *score) const override { *score = -(*score); } //! Denormalize threshold void denormalize(float *score) const override { *score = -(*score); } //! Retrieve if it supports normalization bool support_normalize(void) const override { return true; } //! Retrieve params of Metric const ailego::Params ¶ms(void) const override { return params_; } //! Retrieve query measure object of this index measure Pointer query_metric(void) const override { return nullptr; } private: IndexMeta::DataType data_type_{IndexMeta::DataType::DT_FP32}; ailego::Params params_{}; }; INDEX_FACTORY_REGISTER_METRIC_ALIAS(InnerProduct, InnerProductMetric); INDEX_FACTORY_REGISTER_METRIC_ALIAS(NormalizedCosine, NormalizedCosineMetric); INDEX_FACTORY_REGISTER_METRIC_ALIAS(InnerProductSparse, InnerProductSparseMetric); } // namespace core } // namespace zvec ================================================ FILE: src/core/metric/metric_params.h ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#pragma once #include namespace zvec { namespace core { //! MipsEuclideanMetric static const std::string MIPS_EUCLIDEAN_METRIC_M_VALUE = "mips_euclidean.metric.m_value"; static const std::string MIPS_EUCLIDEAN_METRIC_U_VALUE = "mips_euclidean.metric.u_value"; static const std::string MIPS_EUCLIDEAN_METRIC_MAX_L2_NORM = "mips_euclidean.metric.max_l2_norm"; static const std::string MIPS_EUCLIDEAN_METRIC_INJECTION_TYPE = "mips_euclidean.metric.injection_type"; //! QuantizedInteger Metric static const std::string QUANTIZED_INTEGER_METRIC_ORIGIN_METRIC_NAME = "proxima.quantized_integer.metric.origin_metric_name"; static const std::string QUANTIZED_INTEGER_METRIC_ORIGIN_METRIC_PARAMS = "proxima.quantized_integer.metric.origin_metric_params"; } // namespace core } // namespace zvec ================================================ FILE: src/core/metric/mips_euclidean_metric.cc ================================================ // Copyright 2025-present the zvec project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include #include #include #include #include #include "metric_params.h" namespace zvec { namespace core { /*! Mips Squared Euclidean Metric */ template class MipsSquaredEuclideanMetric : public IndexMetric { public: //! 
Initialize Metric int init(const IndexMeta &meta, const ailego::Params &index_params) override { data_type_ = meta.data_type(); dimension_ = meta.dimension(); int injection_type = static_cast(kDefaultInjectionType); index_params.get(MIPS_EUCLIDEAN_METRIC_INJECTION_TYPE, &injection_type); if (injection_type >= static_cast(Injection::kNumInjections)) { LOG_WARN("Unsupported injection_type %u, using '%s' instead", injection_type, InjectionName(0)); injection_type = static_cast(Injection::kLocalizedSpherical); } injection_ = static_cast(injection_type); LOG_DEBUG( "Initializing MipsSquaredEuclideanMetric with injection %s" " type %d dimension %d", InjectionName(injection_), data_type_, dimension_); float max_l2_norm = 0.0f; float u_value = 0.0f; index_params.get(MIPS_EUCLIDEAN_METRIC_M_VALUE, &m_value_); index_params.get(MIPS_EUCLIDEAN_METRIC_U_VALUE, &u_value); index_params.get(MIPS_EUCLIDEAN_METRIC_MAX_L2_NORM, &max_l2_norm); CheckAndFixM(injection_, &m_value_); CheckAndFixU(injection_, m_value_, &u_value); squared_u_value_ = u_value * u_value; max_squared_l2_norm_ = max_l2_norm * max_l2_norm; if (injection_ == Injection::kIdentity || injection_ == Injection::kLocalizedSpherical) { eta_ = 0.0f; } else if (max_squared_l2_norm_ < std::numeric_limits::epsilon()) { eta_ = kDefaultEta; } else { eta_ = squared_u_value_ / max_squared_l2_norm_; } switch (data_type_) { case IndexMeta::DataType::DT_FP32: squared_norm2_handle_ = reinterpret_cast( ailego::SquaredNorm2Matrix::Compute); break; case IndexMeta::DataType::DT_FP16: squared_norm2_handle_ = reinterpret_cast( ailego::SquaredNorm2Matrix::Compute); break; case IndexMeta::DataType::DT_INT8: squared_norm2_handle_ = reinterpret_cast( ailego::SquaredNorm2Matrix::Compute); break; case IndexMeta::DataType::DT_INT4: squared_norm2_handle_ = reinterpret_cast( ailego::SquaredNorm2Matrix::Compute); break; default: return IndexError_Unsupported; } query_metric_ = IndexFactory::CreateMetric(kQueryMetric); if (!query_metric_) { 
LOG_ERROR("Failed to create metric %s", kQueryMetric); return IndexError_NoExist; } int ret = query_metric_->init(meta, ailego::Params()); if (ret != 0) { LOG_ERROR("Failed to initialize metric %s", kQueryMetric); return ret; } params_ = index_params; return 0; } //! Cleanup Metric int cleanup(void) override { eta_ = 0.0f; m_value_ = 0; squared_u_value_ = 0.0f; max_squared_l2_norm_ = 0.0f; query_metric_.reset(); return 0; } //! Retrieve if it matched bool is_matched(const IndexMeta &meta) const override { return (meta.data_type() == data_type_ && meta.unit_size() == IndexMeta::UnitSizeof(data_type_)); } //! Retrieve if it matched bool is_matched(const IndexMeta &meta, const IndexQueryMeta &qmeta) const override { return (qmeta.data_type() == data_type_ && qmeta.unit_size() == IndexMeta::UnitSizeof(data_type_) && qmeta.dimension() == meta.dimension()); } //! Retrieve distance function for query MatrixBatchDistance batch_distance() const override { MatrixDistance dist_func = distance(); return [=](const void **m, const void *q, size_t num, size_t dim, float *out) { for (size_t i = 0; i < num; ++i) { dist_func(m[i], q, dim, out + i); } }; } //! 
Retrieve distance function for query
  // Returns the distance function for the configured injection and data
  // type, or nullptr when either is unsupported.
  // The lambdas that use m_value_/eta_ reach them through the implicitly
  // captured `this`. They therefore see values later updated by train(),
  // and must not outlive this metric.
  MatrixDistance distance(void) const override {
    if (injection_ == Injection::kLocalizedSpherical) {
      // LocalizedSpherical passes a fixed 0.0f as the eta argument.
      switch (data_type_) {
        case IndexMeta::DataType::DT_FP32:
          return [&](const void *m, const void *q, size_t dim, float *out) {
            ailego::MipsSquaredEuclideanDistanceMatrix::Compute(
                reinterpret_cast(m), reinterpret_cast(q), dim, 0.0f, out);
          };
        case IndexMeta::DataType::DT_FP16:
          return [&](const void *m, const void *q, size_t dim, float *out) {
            ailego::MipsSquaredEuclideanDistanceMatrix::
                Compute(reinterpret_cast(m), reinterpret_cast(q), dim,
                        0.0f, out);
          };
        case IndexMeta::DataType::DT_INT8:
          return [&](const void *m, const void *q, size_t dim, float *out) {
            ailego::MipsSquaredEuclideanDistanceMatrix::Compute(
                reinterpret_cast(m), reinterpret_cast(q), dim, 0.0f, out);
          };
        case IndexMeta::DataType::DT_INT4:
          return [&](const void *m, const void *q, size_t dim, float *out) {
            ailego::MipsSquaredEuclideanDistanceMatrix::Compute(
                reinterpret_cast(m), reinterpret_cast(q), dim, 0.0f, out);
          };
        default:
          return nullptr;
      }
    }
    if (injection_ == Injection::kRepeatedQuadratic) {
      // RepeatedQuadratic passes the current m_value_ and eta_.
      switch (data_type_) {
        case IndexMeta::DataType::DT_FP32:
          return [&](const void *m, const void *q, size_t dim, float *out) {
            ailego::MipsSquaredEuclideanDistanceMatrix::Compute(
                reinterpret_cast(m), reinterpret_cast(q), dim, m_value_,
                eta_, out);
          };
        case IndexMeta::DataType::DT_FP16:
          return [&](const void *m, const void *q, size_t dim, float *out) {
            ailego::MipsSquaredEuclideanDistanceMatrix::
                Compute(reinterpret_cast(m), reinterpret_cast(q), dim,
                        m_value_, eta_, out);
          };
        case IndexMeta::DataType::DT_INT8:
          return [&](const void *m, const void *q, size_t dim, float *out) {
            ailego::MipsSquaredEuclideanDistanceMatrix::Compute(
                reinterpret_cast(m), reinterpret_cast(q), dim, m_value_,
                eta_, out);
          };
        case IndexMeta::DataType::DT_INT4:
          return [&](const void *m, const void *q, size_t dim, float *out) {
            ailego::MipsSquaredEuclideanDistanceMatrix::Compute(
                reinterpret_cast(m), reinterpret_cast(q), dim, m_value_,
                eta_, out);
          };
        default:
          return nullptr;
      }
    }
    if (injection_ == Injection::kSpherical) {
      // Spherical passes the current eta_.
      switch (data_type_) {
        case IndexMeta::DataType::DT_FP32:
          return [&](const void *m, const void *q, size_t dim, float *out) {
            ailego::MipsSquaredEuclideanDistanceMatrix::Compute(
                reinterpret_cast(m), reinterpret_cast(q), dim, eta_, out);
          };
        case IndexMeta::DataType::DT_FP16:
          return [&](const void *m, const void *q, size_t dim, float *out) {
            ailego::MipsSquaredEuclideanDistanceMatrix::
                Compute(reinterpret_cast(m), reinterpret_cast(q), dim, eta_,
                        out);
          };
        case IndexMeta::DataType::DT_INT8:
          return [&](const void *m, const void *q, size_t dim, float *out) {
            ailego::MipsSquaredEuclideanDistanceMatrix::Compute(
                reinterpret_cast(m), reinterpret_cast(q), dim, eta_, out);
          };
        case IndexMeta::DataType::DT_INT4:
          return [&](const void *m, const void *q, size_t dim, float *out) {
            ailego::MipsSquaredEuclideanDistanceMatrix::Compute(
                reinterpret_cast(m), reinterpret_cast(q), dim, eta_, out);
          };
        default:
          return nullptr;
      }
    }
    if (injection_ == Injection::kIdentity) {
      // Identity: the plain squared Euclidean kernel, returned unwrapped.
      switch (data_type_) {
        case IndexMeta::DataType::DT_FP32:
          return reinterpret_cast(
              ailego::SquaredEuclideanDistanceMatrix::Compute);
        case IndexMeta::DataType::DT_FP16:
          return reinterpret_cast(
              ailego::SquaredEuclideanDistanceMatrix::Compute);
        case IndexMeta::DataType::DT_INT8:
          return reinterpret_cast(
              ailego::SquaredEuclideanDistanceMatrix::Compute);
        case IndexMeta::DataType::DT_INT4:
          return reinterpret_cast(
              ailego::SquaredEuclideanDistanceMatrix::Compute);
        default:
          return nullptr;
      }
    }
    return nullptr;
  }

  //!
Retrieve distance function for query MatrixSparseDistance sparse_distance(void) const override { if (injection_ == Injection::kLocalizedSpherical) { return [&](const void *m_sparse, const void *q_sparse, float *out) { ailego::MipsSquaredEuclideanSparseDistanceMatrix::Compute( m_sparse, q_sparse, out); }; } if (injection_ == Injection::kRepeatedQuadratic) { LOG_ERROR( "Repeated Quadratic is not supported in MipsEuclideanMetric for " "Hybrid Vector!"); return nullptr; } if (injection_ == Injection::kSpherical) { LOG_ERROR( "Spherical is not supported in MipsEuclideanMetric for Hybrid " "Vector!"); return nullptr; } if (injection_ == Injection::kIdentity) { LOG_ERROR( "Identity is not supported in MipsEuclideanMetric for Hybrid " "Vector!"); return nullptr; } return nullptr; } //! Retrieve matrix distance function for index features MatrixDistance distance_matrix(size_t m, size_t n) const override { if (injection_ == Injection::kLocalizedSpherical) { SphericalHandle compute; switch (data_type_) { case IndexMeta::DataType::DT_FP32: compute = DistanceMatrixCompute(m, n); break; case IndexMeta::DataType::DT_FP16: compute = DistanceMatrixCompute(m, n); break; case IndexMeta::DataType::DT_INT8: compute = DistanceMatrixCompute(m, n); break; case IndexMeta::DataType::DT_INT4: compute = DistanceMatrixCompute(m, n); break; default: return nullptr; } return [=](const void *d, const void *q, size_t dim, float *out) { compute(d, q, dim, 0.0f, out); }; } if (injection_ == Injection::kRepeatedQuadratic) { RepeatedQuadraticHandle compute; switch (data_type_) { case IndexMeta::DataType::DT_FP32: compute = DistanceMatrixCompute(m, n); break; case IndexMeta::DataType::DT_FP16: compute = DistanceMatrixCompute( m, n); break; case IndexMeta::DataType::DT_INT8: compute = DistanceMatrixCompute(m, n); break; case IndexMeta::DataType::DT_INT4: compute = DistanceMatrixCompute(m, n); break; default: return nullptr; } return [=](const void *d, const void *q, size_t dim, float *out) { compute(d, 
q, dim, m_value_, eta_, out); }; } if (injection_ == Injection::kSpherical) { SphericalHandle compute; switch (data_type_) { case IndexMeta::DataType::DT_FP32: compute = DistanceMatrixCompute(m, n); break; case IndexMeta::DataType::DT_FP16: compute = DistanceMatrixCompute(m, n); break; case IndexMeta::DataType::DT_INT8: compute = DistanceMatrixCompute(m, n); break; case IndexMeta::DataType::DT_INT4: compute = DistanceMatrixCompute(m, n); break; default: return nullptr; } return [=](const void *d, const void *q, size_t dim, float *out) { compute(d, q, dim, eta_, out); }; } if (injection_ == Injection::kIdentity) { switch (data_type_) { case IndexMeta::DataType::DT_FP32: return DistanceMatrixCompute(m, n); case IndexMeta::DataType::DT_FP16: return DistanceMatrixCompute(m, n); case IndexMeta::DataType::DT_INT8: return DistanceMatrixCompute(m, n); case IndexMeta::DataType::DT_INT4: return DistanceMatrixCompute(m, n); default: return nullptr; } } return nullptr; } //! Normalize result void normalize(float *score) const override { query_metric_->normalize(score); } //! Denormalize threshold void denormalize(float *score) const override { query_metric_->denormalize(score); } //! Retrieve if it supports normalization bool support_normalize(void) const override { return query_metric_->support_normalize(); } //! Retrieve params of Metric const ailego::Params ¶ms(void) const override { return params_; } //! Train the metric int train(const void *vec, size_t dim) override { if (eta_ == 0.0f) { // No global norm scaling => no training. 
return 0; } if (!squared_norm2_handle_) { return IndexError_Unsupported; } float score; squared_norm2_handle_(vec, dim, &score); if (score > max_squared_l2_norm_) { max_squared_l2_norm_ = score; const float max_l2_norm = std::sqrt(score); params_.set(MIPS_EUCLIDEAN_METRIC_MAX_L2_NORM, max_l2_norm); if (max_squared_l2_norm_ < 1.0 && max_squared_l2_norm_ > squared_u_value_) { squared_u_value_ = max_squared_l2_norm_; params_.set(MIPS_EUCLIDEAN_METRIC_U_VALUE, max_l2_norm); } eta_ = squared_u_value_ / max_squared_l2_norm_; } return 0; } //! Retrieve if it supports training bool support_train(void) const override { // No global norm scaling => eta_ == 0 => no training. return eta_ != 0.0f; } //! Retrieve query metric object of this index metric Pointer query_metric(void) const override { return query_metric_; } private: //! Type of MipsSquaredEuclideanDistanceMatrix::Compute overloaded for // Spherical injection and LocalizedSpherical nonmetric. template using SphericalHandle = void (*)(const T *m, const T *q, size_t dim, float eta, float *out); //! Type of MipsSquaredEuclideanDistanceMatrix::Compute overloaded for // RepeatedQuadratic injection. template using RepeatedQuadraticHandle = void (*)(const T *m, const T *q, size_t dim, size_t m_value, float eta, float *out); //! Type of squared L2 norm function. using SquaredNorm2Handle = void (*)(const void *m, size_t dim, float *out); enum struct Injection { // Type of injective mapping into Euclidean space. 
kLocalizedSpherical = 0, // spherical with pair-only max-norm kSpherical = 1, // require global scaling/training kRepeatedQuadratic = 2, // require global scaling/training kIdentity = 3, // plain Euclidean distance kNumInjections = 4 }; static const char *InjectionName(int injection) { static const char *injection_names[] = {"LocalizedSpherical", "Spherical", "RepeatedQuadratic", "Identity"}; if (injection >= 0 && injection < static_cast(Injection::kNumInjections)) { return injection_names[injection]; } return "Invalid"; } static const char *InjectionName(Injection injection) { return InjectionName(static_cast(injection)); } // Checks (and fixes) `*m_value`, no. additional dimensions for injection. // `dim` is the original dimension, used ONLY by RepeatedQuadratic // injection, where dim = 1 induces the default *m_value = 3. It's // positioned last to allow other injections to skip it. // Returns true if `*m_value` is modified. static bool CheckAndFixM(Injection injection, uint32_t *m_value) { if (injection == Injection::kRepeatedQuadratic) { if (*m_value == 0) { *m_value = 3u; // Recommend value in paper (3.5 Practical // Recommendation of Parameters) return true; } } else if (injection == Injection::kSpherical) { if (*m_value != 1) { if (*m_value != 0) { LOG_WARN("M value (%u) set to 1 for Spherical injection", *m_value); } *m_value = 1; return true; } } else { // kLocalizedSpherical, kIdentity, or kInvalid if (*m_value != 0) { LOG_WARN("M value (%u) set to 0 for %s injections", *m_value, InjectionName(injection)); *m_value = 0; return true; } } return false; } // Checks and fixes `*u_value`, global L2 norm scalar. // `m_value` is no. additional dimensions, used ONLY by RepeatedQuadratic // injection. It's positioned last to allow other injections to skip it. // Returns true if `*u_value` is set to a new value. 
static bool CheckAndFixU(Injection injection, uint32_t m_value, float *u_value) { if (injection == Injection::kRepeatedQuadratic) { if (*u_value <= std::numeric_limits::epsilon() || *u_value >= 1.0) { // Try computing a default U value constexpr float kLogError = -5.0; // log_10(distance_error) float new_u_value = std::pow(10, kLogError / (1 << (m_value + 1))); if (*u_value != 0) { LOG_WARN("U value (%f) set to %f for RepeatedQuadratic injection", *u_value, new_u_value); } *u_value = new_u_value; return true; } else if (std::pow(*u_value, (1 << m_value)) < std::numeric_limits::epsilon()) { LOG_WARN( "U value %f is too small, may cause loss of distance precision", *u_value); } } else if (injection == Injection::kSpherical) { // Spherical injection requires ||x'|| <= 1.0 for computing // std::sqrt(1 - ||x'||^2), x' = u_value * x / max_norm. Set u_value // to slightly < 1.0 in case of precision loss in float computation. if (*u_value <= std::numeric_limits::epsilon() || *u_value >= 1.0) { static constexpr float kSphericalUValue = 1.0f - 1e-3; if (*u_value != 0.0f) { LOG_WARN("U value (%f) set to %f for Spherical injection", *u_value, kSphericalUValue); } *u_value = kSphericalUValue; return true; } } else { // kLocalizedSpherical, kIdentity, or kInvalid if (*u_value != 1.0) { if (*u_value != 0) { LOG_WARN("U value (%f) set to 1.0 for %s injection", *u_value, InjectionName(injection)); } *u_value = 1.0; return true; } } return false; } private: //! Type of basic DistanceMatrix::Compute function with typed parameter. template using TypedDistanceHandle = void (*)(const T *m, const T *q, size_t dim, float *out); //! Returns m x n distance matrix compute function. // Handle is used to resolve potential DistanceMatrix::Compute overload. template