Full Code of NVIDIA-Merlin/HierarchicalKV for AI

Repository: NVIDIA-Merlin/HierarchicalKV
Branch: master
Commit: ae24eecde0b4
Files: 104
Total size: 2.0 MB

Directory structure:
gitextract_3c35qd95/

├── .bazeliskrc
├── .bazelrc
├── .clang-format
├── .github/
│   └── workflows/
│       ├── blossom-ci.yml
│       ├── docs-build.yaml
│       ├── docs-preview-pr.yaml
│       ├── docs-remove-stale-reviews.yaml
│       └── docs-sched-rebuild.yaml
├── .gitignore
├── .gitmodules
├── CMakeLists.txt
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── STYLE_GUIDE.md
├── WORKSPACE
├── bazel_build.sh
├── benchmark/
│   ├── BUILD
│   ├── benchmark_util.cuh
│   ├── dual_bucket_benchmark.cc.cu
│   ├── find_with_missed_keys_benchmark.cc.cu
│   └── merlin_hashtable_benchmark.cc.cu
├── build_deps/
│   ├── gpus/
│   │   ├── BUILD
│   │   ├── check_cuda_libs.py
│   │   ├── configure.bzl
│   │   ├── crosstool/
│   │   │   ├── BUILD
│   │   │   ├── BUILD.tpl
│   │   │   ├── cc_toolchain_config.bzl.tpl
│   │   │   └── crosstool_compiler_wrapper.tpl
│   │   ├── cuda/
│   │   │   ├── BUILD
│   │   │   ├── BUILD.tpl
│   │   │   ├── build_defs.bzl.tpl
│   │   │   ├── cuda_config.h.tpl
│   │   │   └── cuda_config.py.tpl
│   │   └── find_cuda_config.py
│   └── remote_config/
│       ├── BUILD
│       ├── BUILD.tpl
│       ├── common.bzl
│       └── remote_platform_configure.bzl
├── cmake/
│   └── modules/
│       └── ClangFormat.cmake
├── docs/
│   ├── Makefile
│   ├── README.md
│   ├── make.bat
│   ├── requirements-doc.txt
│   └── source/
│       ├── _static/
│       │   ├── .gitkeep
│       │   └── css/
│       │       ├── banner.css
│       │       └── custom.css
│       ├── _templates/
│       │   ├── footer.html
│       │   └── versions.html
│       ├── conf.py
│       ├── index.rst
│       └── toc.yaml
├── include/
│   ├── BUILD
│   ├── merlin/
│   │   ├── BUILD
│   │   ├── allocator.cuh
│   │   ├── array_kernels.cuh
│   │   ├── core_kernels/
│   │   │   ├── BUILD
│   │   │   ├── accum_or_assign.cuh
│   │   │   ├── contains.cuh
│   │   │   ├── dual_bucket_lookup.cuh
│   │   │   ├── dual_bucket_upsert.cuh
│   │   │   ├── dual_bucket_utils.cuh
│   │   │   ├── find_or_insert.cuh
│   │   │   ├── find_ptr_or_insert.cuh
│   │   │   ├── group_lock_kernels.cuh
│   │   │   ├── kernel_utils.cuh
│   │   │   ├── lookup.cuh
│   │   │   ├── lookup_ptr.cuh
│   │   │   ├── update.cuh
│   │   │   ├── update_score.cuh
│   │   │   ├── update_values.cuh
│   │   │   ├── upsert.cuh
│   │   │   └── upsert_and_evict.cuh
│   │   ├── core_kernels.cuh
│   │   ├── debug.hpp
│   │   ├── flexible_buffer.cuh
│   │   ├── group_lock.cuh
│   │   ├── memory_pool.cuh
│   │   ├── multi_vector.hpp
│   │   ├── optimizers.cuh
│   │   ├── types.cuh
│   │   └── utils.cuh
│   ├── merlin_hashtable.cuh
│   └── merlin_localfile.hpp
├── run_all_tests.sh
└── tests/
    ├── accum_or_assign_test.cc.cu
    ├── assign_score_test.cc.cu
    ├── assign_values_test.cc.cu
    ├── dual_bucket_test.cc.cu
    ├── dynamic_max_capacity_test.cc.cu
    ├── export_batch_if_test.cc.cu
    ├── find_or_insert_ptr_lock_test.cc.cu
    ├── find_or_insert_ptr_test.cc.cu
    ├── find_or_insert_test.cc.cu
    ├── find_with_missed_keys_test.cc.cu
    ├── group_lock_test.cc.cu
    ├── insert_and_evict_test.cc.cu
    ├── lock_unlock_test.cc.cu
    ├── memory_pool_test.cc.cu
    ├── merlin_hashtable_test.cc.cu
    ├── reserved_keys_test.cc.cu
    ├── save_and_load_test.cc.cu
    ├── test_util.cuh
    └── uint32_score_test.cc.cu

================================================
FILE CONTENTS
================================================

================================================
FILE: .bazeliskrc
================================================
USE_BAZEL_VERSION=5.0.0


================================================
FILE: .bazelrc
================================================
build -c opt
build --copt -O3
build --copt -pthread
build --linkopt -pthread
build --linkopt -ldl
build --incompatible_linkopts_to_linklibs
build --copt -g --strip=never
build --experimental_repo_remote_exec

# By default, build HKV in C++ 17 mode.
build --cxxopt=-std=c++17
build --host_cxxopt=-std=c++17

# This config refers to building CUDA kernels with nvcc.
build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain

# CUDA options
build:cuda --action_env GCC_HOST_COMPILER_PATH="/usr/bin/gcc"
build:cuda --action_env CUDA_TOOLKIT_PATH="/usr/local/cuda"
build:cuda --action_env CUDA_VERSION="11"
build:cuda --action_env CUDNN_VERSION="8"
build:cuda --action_env CUDNN_INSTALL_PATH="/usr/"
build:cuda --action_env CUDA_COMPUTE_CAPABILITIES="7.5"


================================================
FILE: .clang-format
================================================
BasedOnStyle: Google
DerivePointerAlignment: false
IncludeBlocks: Merge
SortIncludes: true


================================================
FILE: .github/workflows/blossom-ci.yml
================================================
# Copyright (c) 2020-2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# A workflow to trigger ci on hybrid infra (github + self hosted runner)
name: Blossom-CI
on:
  issue_comment:
    types: [created]
  workflow_dispatch:
    inputs:
      platform:
        description: 'runs-on argument'
        required: false
      args:
        description: 'argument'
        required: false
jobs:
  Authorization:
    name: Authorization
    runs-on: blossom
    outputs:
      args: ${{ env.args }}

    # This job only runs for pull request comments
    if: |
         (github.actor == 'EmmaQiaoCh' || github.actor == 'rhdong' || github.actor == 'Ranjeet-Nvidia' || github.actor == 'jiashuy') &&
         github.event.comment.body == '/blossom-ci'
    steps:
      - name: Check if comment is issued by authorized person
        run: blossom-ci
        env:
          OPERATION: 'AUTH'
          REPO_TOKEN: ${{ secrets.GITHUB_TOKEN }}
          REPO_KEY_DATA: ${{ secrets.BLOSSOM_KEY }}
        
  Vulnerability-scan:
    name: Vulnerability scan
    needs: [Authorization]
    runs-on: ubuntu-latest
    steps:
      - name: Checkout code
        uses: actions/checkout@v2
        with:
          repository: ${{ fromJson(needs.Authorization.outputs.args).repo }}
          ref: ${{ fromJson(needs.Authorization.outputs.args).ref }}
          lfs: 'true'
         
      - name: Run blossom action
        uses: NVIDIA/blossom-action@main
        env:
          REPO_TOKEN: ${{ secrets.GITHUB_TOKEN }}
          REPO_KEY_DATA: ${{ secrets.BLOSSOM_KEY }}
        with:
          args1: ${{ fromJson(needs.Authorization.outputs.args).args1 }}
          args2: ${{ fromJson(needs.Authorization.outputs.args).args2 }}
          args3: ${{ fromJson(needs.Authorization.outputs.args).args3 }}
          
  Job-trigger:
    name: Start ci job
    needs: [Vulnerability-scan]
    runs-on: blossom
    steps:
      - name: Start ci job
        run: blossom-ci
        env:
          OPERATION: 'START-CI-JOB'
          CI_SERVER: ${{ secrets.CI_SERVER }}
          REPO_TOKEN: ${{ secrets.GITHUB_TOKEN }}
              
  Upload-Log:
    name: Upload log
    runs-on: blossom
    if: github.event_name == 'workflow_dispatch'
    steps:
      - name: Jenkins log for pull request ${{ fromJson(github.event.inputs.args).pr }} (click here)
        run: blossom-ci
        env:
          OPERATION: 'POST-PROCESSING'
          CI_SERVER: ${{ secrets.CI_SERVER }}
          REPO_TOKEN: ${{ secrets.GITHUB_TOKEN }}


================================================
FILE: .github/workflows/docs-build.yaml
================================================
name: docs-build

on:
  pull_request:
    branches: [master]

jobs:
  build:
    runs-on: "ubuntu-latest"

    steps:
      - uses: actions/checkout@v3
      - name: Set up Python 3.8
        uses: actions/setup-python@v4
        with:
          python-version: '3.8'
      - name: Install Ubuntu packages
        run: |
          sudo apt-get update -y
          sudo apt-get install -y --no-install-recommends doxygen
      - name: Install dependencies
        run: |
          python -m pip install -r docs/requirements-doc.txt
      - name: Building docs
        run: |
          make -C docs html
      - name: Upload HTML
        uses: actions/upload-artifact@v4
        with:
          name: html-build-artifact
          path: docs/build/html
          if-no-files-found: error
          retention-days: 1
      - name: Store PR information
        run: |
          mkdir ./pr
          echo ${{ github.event.number }}              > ./pr/pr.txt
          echo ${{ github.event.pull_request.merged }} > ./pr/merged.txt
          echo ${{ github.event.action }}              > ./pr/action.txt
      - name: Upload PR information
        uses: actions/upload-artifact@v4
        with:
          name: pr
          path: pr/


================================================
FILE: .github/workflows/docs-preview-pr.yaml
================================================
name: docs-preview-pr

on:
  workflow_run:
    workflows: [docs-build]
    types: [completed]

env:
  WF_ID: ${{ github.event.workflow_run.id }}

jobs:
  preview:
    uses: nvidia-merlin/.github/.github/workflows/docs-preview-pr-common.yaml@main

================================================
FILE: .github/workflows/docs-remove-stale-reviews.yaml
================================================
name: docs-remove-stale-reviews

on:
  schedule:
    # 42 minutes after 0:00 UTC on Sundays
    - cron: "42 0 * * 0"
  workflow_dispatch:

jobs:
  remove:
    uses: nvidia-merlin/.github/.github/workflows/docs-remove-stale-reviews-common.yaml@main


================================================
FILE: .github/workflows/docs-sched-rebuild.yaml
================================================
name: docs-sched-rebuild

on:
  push:
    branches: [master]
    tags:
      - v*
  workflow_dispatch:

jobs:
  build:
    runs-on: "ubuntu-latest"

    steps:
      - uses: actions/checkout@v3
        with:
          fetch-depth: 0
      - name: Set up Python 3.8
        uses: actions/setup-python@v4
        with:
          python-version: '3.8'
      - name: Install Ubuntu packages
        run: |
          sudo apt-get update -y
          sudo apt-get install -y doxygen
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          python -m pip install -r docs/requirements-doc.txt
      - name: Report the versions to build
        run: |
          sphinx-multiversion -D 'exhale_args.containmentFolder=${sourcedir}/api' --dump-metadata docs/source docs/build/html | jq "keys"
      - name: Building docs (multiversion)
        run: |
          sphinx-multiversion -D 'exhale_args.containmentFolder=${sourcedir}/api' docs/source docs/build/html
      - name: Delete unnecessary files
        run: |
          find docs/build -name .doctrees -prune -exec rm -rf {} \;
          find docs/build -name .buildinfo -exec rm {} \;
      - name: Upload HTML
        uses: actions/upload-artifact@v4
        with:
          name: html-build-artifact
          path: docs/build/html
          if-no-files-found: error
          retention-days: 1

  # Identify the dir for the HTML.
  store-html:
    needs: [build]
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
        with:
          ref: "gh-pages"
      - name: Initialize Git configuration
        run: |
          git config user.name docs-sched-rebuild
          git config user.email do-not-send-@github.com
      - name: Download artifacts
        uses: actions/download-artifact@v4
        with:
          name: html-build-artifact
      - name: Copy HTML directories
        run: |
          ls -asl
          for i in `ls -d *`
          do
            echo "Git adding ${i}"
            git add "${i}"
          done
      - name: Check or create dot-no-jekyll file
        run: |
          if [ -f ".nojekyll" ]; then
            echo "The dot-no-jekyll file already exists."
            exit 0
          fi
          touch .nojekyll
          git add .nojekyll
      - name: Check or create redirect page
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
        run: |
          resp=$(grep 'http-equiv="refresh"' index.html 2>/dev/null) || true
          if [ -n "${resp}" ]; then
            echo "The redirect file already exists."
            exit 0
          fi
          # If any of these commands fail, fail the build.
          def_branch=$(gh api "repos/${GITHUB_REPOSITORY}" --jq ".default_branch")
          html_url=$(gh api "repos/${GITHUB_REPOSITORY}/pages" --jq ".html_url")
          # Beware the ugly quotation-mark avoidance in the following lines.
          echo '<!DOCTYPE html>'                                                                         > index.html
          echo '<html>'                                                                                 >> index.html
          echo '  <head>'                                                                               >> index.html
          echo '    <title>Redirect to documentation</title>'                                           >> index.html
          echo '    <meta charset="utf-8">'                                                             >> index.html
          echo '    <meta http-equiv="refresh" content="3; URL='${html_url}${def_branch}'/index.html">' >> index.html
          echo '    <link rel="canonical" href="'${html_url}${def_branch}'/index.html">'                >> index.html
          echo '    <script language="javascript">'                                                     >> index.html
          echo '      function redirect() {'                                                            >> index.html
          echo '        window.location.assign("'${html_url}${def_branch}'/index.html")'                >> index.html
          echo '      }'                                                                                >> index.html
          echo '    </script>'                                                                          >> index.html
          echo '  </head>'                                                                              >> index.html
          echo '  <body onload="redirect()">'                                                           >> index.html
          echo '    <p>Please follow the link to the <a href="'${html_url}${def_branch}'/index.html">'  >> index.html
          echo      ${def_branch}'</a> branch documentation.</p>'                                       >> index.html
          echo '  </body>'                                                                              >> index.html
          echo '</html>'                                                                                >> index.html
          git add index.html
      - name: Commit changes to the GitHub Pages branch
        run: |
          git status
          if git commit -m 'Pushing changes to GitHub Pages.'; then
            git push -f
          else
           echo "Nothing changed."
          fi


================================================
FILE: .gitignore
================================================
.DS_Store
.idea
.vscode
build
.clwb
cmake-build-debug/
docs/build
docs/source/README.md
docs/source/CONTRIBUTING.md
docs/source/api

================================================
FILE: .gitmodules
================================================
[submodule "tests/googletest"]
	path = tests/googletest
	url = https://github.com/google/googletest.git
	ignore = dirty


================================================
FILE: CMakeLists.txt
================================================
# Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

cmake_minimum_required(VERSION 3.10)
project(merlin-hkvs LANGUAGES CXX CUDA)
find_package(CUDAToolkit)

# TODO(Q3): target_compile_features below still declare cxx_std_14, which is
# inconsistent with the project-level C++17.  Update them to cxx_std_17 (or
# remove the per-target lines entirely) once downstream compatibility is
# confirmed.
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CUDA_STANDARD 17)
list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake/modules)

option(CLANGFORMAT "Clangformat code files before compiling" OFF)
if(CLANGFORMAT)
  include(ClangFormat)
  file(GLOB_RECURSE clangformat_includes
    ${PROJECT_SOURCE_DIR}/include/*.h
    ${PROJECT_SOURCE_DIR}/include/*.hpp
    ${PROJECT_SOURCE_DIR}/include/*.cuh
  )
  file(GLOB clangformat_tests
    ${PROJECT_SOURCE_DIR}/tests/*.c
    ${PROJECT_SOURCE_DIR}/tests/*.h
    ${PROJECT_SOURCE_DIR}/tests/*.cpp
    ${PROJECT_SOURCE_DIR}/tests/*.hpp
    ${PROJECT_SOURCE_DIR}/tests/*.cu
    ${PROJECT_SOURCE_DIR}/tests/*.cuh
  )
  set(clangformat_files ${clangformat_includes} ${clangformat_tests})
  clangformat_setup("${clangformat_files}")
endif()

# Default to release build.
if (NOT CMAKE_BUILD_TYPE)
    set(CMAKE_BUILD_TYPE "Release")
    message(STATUS "Setting default CMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}")
endif()

# Some neat defaults.
set(CUDA_SEPARABLE_COMPILATION ON)

# Select target CUDA binary architecture.
foreach(cuda_arch ${sm})
  list(APPEND cuda_arch_list ${cuda_arch})
  message(STATUS "Assign GPU architecture (sm=${cuda_arch})")
endforeach()

list(LENGTH cuda_arch_list cuda_arch_list_length)
if(cuda_arch_list_length EQUAL 0)
  list(APPEND cuda_arch_list "80")
  message(STATUS "Assign default GPU architecture sm=80")
endif()

if (CMAKE_BUILD_TYPE STREQUAL "Debug")
  add_compile_definitions(CUDA_ERROR_CHECK)
  set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -lineinfo")
endif()

foreach(cuda_arch ${cuda_arch_list})
  set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -gencode arch=compute_${cuda_arch},code=sm_${cuda_arch}")
endforeach()

message(STATUS "CMAKE_CUDA_FLAGS=${CMAKE_CUDA_FLAGS}")

include_directories(
  ${PROJECT_SOURCE_DIR}/include
  ${PROJECT_SOURCE_DIR}/tests/googletest/googletest/include
)

ADD_SUBDIRECTORY(tests/googletest)

link_directories(
)

file(GLOB_RECURSE merlin_hkvs_src RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} *.cpp *.cu)

# TODO:
# add_library(hierarchical_kv STATIC ${hierarchical_kv_src})
# target_compile_features(hierarchical_kv PUBLIC cxx_std_14)
# target_link_libraries(hierarchical_kv PUBLIC ...)


add_executable(merlin_hashtable_benchmark benchmark/merlin_hashtable_benchmark.cc.cu)
target_compile_features(merlin_hashtable_benchmark PUBLIC cxx_std_14)
set_target_properties(merlin_hashtable_benchmark PROPERTIES  CUDA_ARCHITECTURES OFF)

add_executable(find_with_missed_keys_benchmark benchmark/find_with_missed_keys_benchmark.cc.cu)
target_compile_features(find_with_missed_keys_benchmark PUBLIC cxx_std_14)
set_target_properties(find_with_missed_keys_benchmark PROPERTIES  CUDA_ARCHITECTURES OFF)

add_executable(merlin_hashtable_test tests/merlin_hashtable_test.cc.cu)
target_compile_features(merlin_hashtable_test PUBLIC cxx_std_14)
set_target_properties(merlin_hashtable_test PROPERTIES  CUDA_ARCHITECTURES OFF)
TARGET_LINK_LIBRARIES(merlin_hashtable_test gtest_main)

add_executable(find_or_insert_test tests/find_or_insert_test.cc.cu)
target_compile_features(find_or_insert_test PUBLIC cxx_std_14)
set_target_properties(find_or_insert_test PROPERTIES  CUDA_ARCHITECTURES OFF)
TARGET_LINK_LIBRARIES(find_or_insert_test gtest_main)

add_executable(merlin_memory_pool_test tests/memory_pool_test.cc.cu)
target_compile_features(merlin_memory_pool_test PUBLIC cxx_std_14)
set_target_properties(merlin_memory_pool_test PROPERTIES  CUDA_ARCHITECTURES OFF)
TARGET_LINK_LIBRARIES(merlin_memory_pool_test gtest_main)

set(CMAKE_BUILD_TYPE "Debug")
add_executable(save_and_load_test tests/save_and_load_test.cc.cu)
target_compile_features(save_and_load_test PUBLIC cxx_std_14)
set_target_properties(save_and_load_test PROPERTIES  CUDA_ARCHITECTURES OFF)
TARGET_LINK_LIBRARIES(save_and_load_test gtest_main)

add_executable(insert_and_evict_test tests/insert_and_evict_test.cc.cu)
target_compile_features(insert_and_evict_test PUBLIC cxx_std_14)
set_target_properties(insert_and_evict_test PROPERTIES  CUDA_ARCHITECTURES OFF)
TARGET_LINK_LIBRARIES(insert_and_evict_test gtest_main)

add_executable(dynamic_max_capacity_test tests/dynamic_max_capacity_test.cc.cu)
target_compile_features(dynamic_max_capacity_test PUBLIC cxx_std_14)
set_target_properties(dynamic_max_capacity_test PROPERTIES  CUDA_ARCHITECTURES OFF)
TARGET_LINK_LIBRARIES(dynamic_max_capacity_test gtest_main)

add_executable(group_lock_test tests/group_lock_test.cc.cu)
target_compile_features(group_lock_test PUBLIC cxx_std_14)
set_target_properties(group_lock_test PROPERTIES  CUDA_ARCHITECTURES OFF)
TARGET_LINK_LIBRARIES(group_lock_test gtest_main)

add_executable(find_or_insert_ptr_test tests/find_or_insert_ptr_test.cc.cu)
target_compile_features(find_or_insert_ptr_test PUBLIC cxx_std_14)
set_target_properties(find_or_insert_ptr_test PROPERTIES  CUDA_ARCHITECTURES OFF)
TARGET_LINK_LIBRARIES(find_or_insert_ptr_test gtest_main)

add_executable(assign_score_test tests/assign_score_test.cc.cu)
target_compile_features(assign_score_test PUBLIC cxx_std_14)
set_target_properties(assign_score_test PROPERTIES  CUDA_ARCHITECTURES OFF)
TARGET_LINK_LIBRARIES(assign_score_test gtest_main)

add_executable(uint32_score_test tests/uint32_score_test.cc.cu)
target_compile_features(uint32_score_test PUBLIC cxx_std_14)
set_target_properties(uint32_score_test PROPERTIES  CUDA_ARCHITECTURES OFF)
TARGET_LINK_LIBRARIES(uint32_score_test gtest_main)

add_executable(accum_or_assign_test tests/accum_or_assign_test.cc.cu)
target_compile_features(accum_or_assign_test PUBLIC cxx_std_14)
set_target_properties(accum_or_assign_test PROPERTIES  CUDA_ARCHITECTURES OFF)
TARGET_LINK_LIBRARIES(accum_or_assign_test gtest_main)

add_executable(assign_values_test tests/assign_values_test.cc.cu)
target_compile_features(assign_values_test PUBLIC cxx_std_14)
set_target_properties(assign_values_test PROPERTIES  CUDA_ARCHITECTURES OFF)
TARGET_LINK_LIBRARIES(assign_values_test gtest_main)

add_executable(find_with_missed_keys_test tests/find_with_missed_keys_test.cc.cu)
target_compile_features(find_with_missed_keys_test PUBLIC cxx_std_14)
set_target_properties(find_with_missed_keys_test PROPERTIES  CUDA_ARCHITECTURES OFF)
TARGET_LINK_LIBRARIES(find_with_missed_keys_test gtest_main)

add_executable(reserved_keys_test tests/reserved_keys_test.cc.cu)
target_compile_features(reserved_keys_test PUBLIC cxx_std_14)
set_target_properties(reserved_keys_test PROPERTIES  CUDA_ARCHITECTURES OFF)
TARGET_LINK_LIBRARIES(reserved_keys_test gtest_main)

add_executable(export_batch_if_test tests/export_batch_if_test.cc.cu)
target_compile_features(export_batch_if_test PUBLIC cxx_std_14)
set_target_properties(export_batch_if_test PROPERTIES  CUDA_ARCHITECTURES OFF)
TARGET_LINK_LIBRARIES(export_batch_if_test gtest_main)

add_executable(find_or_insert_ptr_lock_test tests/find_or_insert_ptr_lock_test.cc.cu)
target_compile_features(find_or_insert_ptr_lock_test PUBLIC cxx_std_14)
set_target_properties(find_or_insert_ptr_lock_test PROPERTIES  CUDA_ARCHITECTURES OFF)
TARGET_LINK_LIBRARIES(find_or_insert_ptr_lock_test gtest_main)

add_executable(lock_unlock_test tests/lock_unlock_test.cc.cu)
target_compile_features(lock_unlock_test PUBLIC cxx_std_14)
set_target_properties(lock_unlock_test PROPERTIES  CUDA_ARCHITECTURES OFF)
TARGET_LINK_LIBRARIES(lock_unlock_test gtest_main)

add_executable(dual_bucket_test tests/dual_bucket_test.cc.cu)
target_compile_features(dual_bucket_test PUBLIC cxx_std_14)
set_target_properties(dual_bucket_test PROPERTIES  CUDA_ARCHITECTURES OFF)
TARGET_LINK_LIBRARIES(dual_bucket_test gtest_main)

add_executable(dual_bucket_benchmark benchmark/dual_bucket_benchmark.cc.cu)
target_compile_features(dual_bucket_benchmark PUBLIC cxx_std_14)
set_target_properties(dual_bucket_benchmark PROPERTIES  CUDA_ARCHITECTURES OFF)


================================================
FILE: CONTRIBUTING.md
================================================
# Contributing

## About HierarchicalKV

HierarchicalKV is a part of NVIDIA Merlin and provides hierarchical key-value storage to meet RecSys requirements.

The key capability of HierarchicalKV is to store key-value pairs (feature embeddings) in the high-bandwidth memory (HBM) of GPUs and in host memory.

You can also use the library for generic key-value storage.

## Maintainership

HierarchicalKV is co-maintained by the [NVIDIA Merlin Team](https://github.com/NVIDIA-Merlin) and NVIDIA product end-users,
and is also open to public contributions, bug fixes, and documentation. This project adheres to NVIDIA's Code of Conduct.

## Contributing

We’re grateful for your interest in HierarchicalKV and value your contributions.
We welcome contributions via pull requests (PRs).

Before sending out a pull request for a significant change to the end-user API, we recommend that you open an issue to
discuss your proposed change. Some changes may require a design review.
All submissions require review by project reviewers.

### Coding Style

Refer to the [Style Guide](https://github.com/NVIDIA-Merlin/HierarchicalKV/blob/master/STYLE_GUIDE.md).

### Additional Requirements

In addition to the above requirements, contributions must also meet the following criteria:
* Include unit tests and, where applicable, integration tests.
* Provide the documentation needed to explain when and how to use the change.

## Community

* HierarchicalKV code (https://github.com/NVIDIA-Merlin/HierarchicalKV)

## License
Apache License 2.0



================================================
FILE: LICENSE
================================================
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "{}"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright 2022 NVIDIA Corporation

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.

================================================
FILE: README.md
================================================
# [NVIDIA HierarchicalKV (Beta)](https://github.com/NVIDIA-Merlin/HierarchicalKV)

[![Version](https://img.shields.io/github/v/release/NVIDIA-Merlin/HierarchicalKV?color=orange&include_prereleases)](https://github.com/NVIDIA-Merlin/HierarchicalKV/releases)
[![GitHub License](https://img.shields.io/github/license/NVIDIA-Merlin/HierarchicalKV)](https://github.com/NVIDIA-Merlin/HierarchicalKV/blob/master/LICENSE)
[![Documentation](https://img.shields.io/badge/documentation-blue.svg)](https://nvidia-merlin.github.io/HierarchicalKV/master/README.html)

## About HierarchicalKV

HierarchicalKV is part of NVIDIA Merlin and provides hierarchical key-value storage to meet RecSys requirements.

The key capability of HierarchicalKV is to store key-value pairs (feature embeddings) in the high-bandwidth memory (HBM) of GPUs and in host memory.

You can also use the library for generic key-value storage.

## Benefits

When building large recommender systems, machine learning (ML) engineers face the following challenges:

- GPUs are needed, but HBM on a single GPU is too small for the large DLRMs that scale to several terabytes.
- Improving communication performance is getting more difficult in larger and larger CPU clusters.
- It is difficult to efficiently control consumption growth of limited HBM with customized strategies.
- Most generic key-value libraries provide low HBM and host memory utilization.

HierarchicalKV alleviates these challenges and helps ML engineers working on RecSys with the following benefits:

- Supports training large RecSys models on **HBM and host memory** at the same time.
- Provides better performance by **fully bypassing CPUs** and reducing the communication workload.
- Implements table-size restraint strategies that are based on **LRU or customized strategies**.
  The strategies are implemented by CUDA kernels.
- Operates at a high working-status load factor that is close to 1.0.


## Key ideas

- Buckets are locally ordered
- Store keys and values separately
- Store all the keys in HBM
- Built-in and customizable eviction strategies

HierarchicalKV makes NVIDIA GPUs more suitable for training large and super-large models of ***search, recommendations, and advertising***.
The library simplifies the common challenges of building, evaluating, and serving sophisticated recommender models.

## API Documentation

The main classes and structs are listed below; we also recommend reading the comments in the source code:

- [`class HashTable`](https://github.com/NVIDIA-Merlin/HierarchicalKV/blob/master/include/merlin_hashtable.cuh#L151)
- [`class EvictStrategy`](https://github.com/NVIDIA-Merlin/HierarchicalKV/blob/master/include/merlin_hashtable.cuh#L52)
- [`struct HashTableOptions`](https://github.com/NVIDIA-Merlin/HierarchicalKV/blob/master/include/merlin_hashtable.cuh#L60)

For the complete API documentation, refer to the [API Docs](https://nvidia-merlin.github.io/HierarchicalKV/master/api/index.html).

### API Maturity Matrix

`industry-validated` means the API has been well-tested and verified in at least one real-world scenario.

| Name                 | Description                                                                                                         | Maturity           |
|:---------------------|:--------------------------------------------------------------------------------------------------------------------|:-------------------|
| __insert_or_assign__ | Insert or assign the specified keys. <br>When a bucket is full, overwrite the key with the minimum score.           | industry-validated |
| __insert_and_evict__ | Insert new keys, and evict the keys with the minimum score when a bucket is full.                                   | industry-validated |
| __find_or_insert__   | Search for the specified keys, and insert the keys that are missing.                                                | well-tested        |
| __assign__           | Update each existing key, and skip the keys that are missing.                                                       | well-tested        |
| __accum_or_assign__  | Search and update each key. If found, add the value as a delta to the original value. <br>If missing, write the value directly. | well-tested        |
| __find_or_insert\*__ | Search for the specified keys and return pointers to their values, inserting missing keys first.                    | well-tested        |
| __find__             | Search for the specified keys.                                                                                      | industry-validated |
| __find\*__           | Search for the specified keys and return pointers to their values; not thread-safe, but high-performance.           | well-tested        |
| __export_batch__     | Export a certain number of key-value-score tuples.                                                                  | industry-validated |
| __export_batch_if__  | Export a certain number of key-value-score tuples that match specific conditions.                                   | industry-validated |
| __warmup__           | Move the hot key-values from HMEM to HBM.                                                                           | June 15, 2023      |
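
The bucket-full behavior described for `insert_or_assign` and `insert_and_evict` can be pictured with a toy, single-threaded model of one bucket. This is a hedged sketch, not HierarchicalKV's CUDA implementation: `Slot` and `ToyBucket` are hypothetical names, values are omitted, and the model assumes that the incoming key always replaces the slot with the minimum score. The library may apply further rules.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// One occupied slot of the toy bucket (hypothetical, for illustration only).
struct Slot {
  uint64_t key;
  uint64_t score;
};

// Toy model of a single bucket with a fixed number of slots.
struct ToyBucket {
  std::size_t capacity;
  std::vector<Slot> slots;

  // Inserts or assigns `key` with `score`.
  // Returns true, and reports the victim through `evicted_key`, if a key
  // had to be evicted because the bucket was full.
  bool insert_or_assign(uint64_t key, uint64_t score, uint64_t* evicted_key) {
    assert(capacity > 0);
    // Existing key: assign the new score in place (values are omitted here).
    for (Slot& s : slots) {
      if (s.key == key) {
        s.score = score;
        return false;
      }
    }
    // Free slot available: plain insert, no eviction.
    if (slots.size() < capacity) {
      slots.push_back(Slot{key, score});
      return false;
    }
    // Bucket is full: overwrite the slot with the minimum score.
    auto victim = std::min_element(
        slots.begin(), slots.end(),
        [](const Slot& a, const Slot& b) { return a.score < b.score; });
    if (evicted_key != nullptr) *evicted_key = victim->key;
    *victim = Slot{key, score};
    return true;
  }
};
```

In this model, a two-slot bucket holding keys with scores 10 and 5 evicts the score-5 key when a third key arrives. The real table processes many buckets in parallel on the GPU; the sketch only illustrates which key is chosen as the victim.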


### Evict Strategy

The `score` defines the importance of each key: the larger the score, the more important the key and the less likely it is to be evicted. Eviction only happens when a bucket is full.
The `score_type` must be `uint64_t`. For more details, refer to [`class EvictStrategy`](https://github.com/NVIDIA-Merlin/HierarchicalKV/blob/master/include/merlin_hashtable.cuh#L52).

| Name           | Definition of `Score`                                                                                                                                                                                    |
|:---------------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| __Lru__        | Device clock in nanoseconds, which may differ slightly from the host clock.                                                                                                                              |
| __Lfu__        | Frequency. The caller provides the frequency increment via the `scores` parameter of the insert-like APIs.                                                                                               |
| __EpochLru__   | The high 32 bits are the global epoch provided via the `global_epoch` parameter, <br>the low 32 bits are `(device_clock >> 20) & 0xffffffff`, with a granularity of about 1 ms.                          |
| __EpochLfu__   | The high 32 bits are the global epoch provided via the `global_epoch` parameter, <br>the low 32 bits are the frequency, <br>which stays constant after reaching the maximum value of `0xffffffff`.        |
| __Customized__ | Fully provided by the caller via the `scores` parameter of the insert-like APIs.                                                                                                                         |


* __Note__:
  - The insert-like APIs are `insert_or_assign`, `insert_and_evict`, `find_or_insert`, `accum_or_assign`, and `find_or_insert*`.
  - The caller maintains `global_epoch` and passes it as the `global_epoch` parameter of the insert-like APIs.
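
The `EpochLru` and `EpochLfu` bit layouts in the table above can be written out in plain C++. This is a host-side sketch under stated assumptions: `epoch_lru_score` and `epoch_lfu_score` are hypothetical helpers that only mirror the documented layout, not the library's device code. The frequency clamp models the documented saturation at `0xffffffff`.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helpers that mirror the documented score layouts on the host.

// EpochLru: high 32 bits = global epoch,
// low 32 bits = (device_clock >> 20) & 0xffffffff (about 1 ms per tick).
inline uint64_t epoch_lru_score(uint32_t global_epoch, uint64_t device_clock_ns) {
  const uint64_t low = (device_clock_ns >> 20) & 0xffffffffULL;
  return (static_cast<uint64_t>(global_epoch) << 32) | low;
}

// EpochLfu: high 32 bits = global epoch, low 32 bits = frequency,
// which stays at 0xffffffff once it reaches that value.
inline uint64_t epoch_lfu_score(uint32_t global_epoch, uint64_t frequency) {
  const uint64_t low = frequency > 0xffffffffULL ? 0xffffffffULL : frequency;
  return (static_cast<uint64_t>(global_epoch) << 32) | low;
}
```

Because the epoch occupies the high 32 bits, a key scored in a newer epoch always outranks a key last scored in an older epoch. Within a bucket, keys from older epochs therefore have the smallest scores and are evicted first.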

### Configuration Options

We recommend keeping the default values for the options marked with `*`.

| Name                       | Type   | Default | Description                                           |
|:---------------------------|:-------|:--------|:------------------------------------------------------|
| __init_capacity__          | size_t | 0       | The initial capacity of the hash table.               |
| __max_capacity__           | size_t | 0       | The maximum capacity of the hash table.               |
| __max_hbm_for_vectors__    | size_t | 0       | The maximum HBM for vectors, in bytes.                |
| __dim__                    | size_t | 64      | The dimension of the value vectors.                   |
| __max_bucket_size*__       | size_t | 128     | The length of each bucket.                            |
| __max_load_factor*__       | float  | 0.5f    | The max load factor before rehashing.                 |
| __block_size*__            | int    | 128     | The default block size for CUDA kernels.              |
| __io_block_size*__         | int    | 1024    | The block size for IO CUDA kernels.                   |
| __device_id*__             | int    | -1      | The ID of device. Managed internally when set to `-1` |
| __io_by_cpu*__             | bool   | false   | The flag indicating if the CPU handles IO.            |
| __reserved_key_start_bit__ | int    | 0       | The start bit offset of the reserved keys within the 64-bit key. |

- For more details, refer to [`struct HashTableOptions`](https://github.com/NVIDIA-Merlin/HierarchicalKV/blob/master/include/merlin_hashtable.cuh#L60).
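
How `max_hbm_for_vectors` splits storage between HBM and host memory comes down to simple arithmetic: the value vectors of a full table take `capacity * dim * sizeof(V)` bytes. The hypothetical helpers below sketch that calculation. They assume that value storage beyond `max_hbm_for_vectors` is placed in host memory (HMEM), which is consistent with the hybrid-mode benchmark configurations later in this README. They count only value vectors, not keys, scores, or bucket metadata, and treat "Million" as 2^20 and "GB" as 2^30 bytes.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helpers (not part of HierarchicalKV) for sizing value storage.

// Bytes needed for all value vectors of a full table: capacity * dim * sizeof(V).
// Keys, scores, and bucket metadata are not included.
template <typename V>
constexpr uint64_t value_bytes(uint64_t capacity, uint64_t dim) {
  return capacity * dim * sizeof(V);
}

// Assumption: value bytes beyond max_hbm_for_vectors are kept in host memory.
constexpr uint64_t hmem_bytes(uint64_t total_value_bytes,
                              uint64_t max_hbm_for_vectors) {
  return total_value_bytes > max_hbm_for_vectors
             ? total_value_bytes - max_hbm_for_vectors
             : 0;
}
```

Under these assumptions, the `How to use` example (16 Mi keys, `dim = 16`, `float` values) needs 1 GiB of value storage, which fits within `GB(16)` of HBM. By contrast, 128 Mi keys with `dim = 64` need 32 GiB, which matches the 16 GB HBM plus 16 GB HMEM benchmark setup.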

#### Reserved Keys
- By default, the keys `0xFFFFFFFFFFFFFFFD`, `0xFFFFFFFFFFFFFFFE`, and `0xFFFFFFFFFFFFFFFF` are reserved for internal use.
  Change `options.reserved_key_start_bit` if you want to use these keys.
  The valid range of `reserved_key_start_bit` is 0 to 62. The default value, 0, selects the default reserved keys listed above. When `reserved_key_start_bit` is set to any value other than 0, the least significant bit (bit 0) of every reserved key is `0`.

- Setting `reserved_key_start_bit = 1`:
  - This setting uses bits 1 and 2 to distinguish the reserved keys.
  - In binary, the last four bits range from `1000` to `1110`. The least significant bit (bit 0) is always `0`, and bits 3 to 63 are set to `1`.
  - The new reserved keys in hexadecimal representation are as follows:
    - `0xFFFFFFFFFFFFFFFE`
    - `0xFFFFFFFFFFFFFFFC`
    - `0xFFFFFFFFFFFFFFF8`
    - `0xFFFFFFFFFFFFFFFA`

- Setting `reserved_key_start_bit = 2`:
  - This setting uses bits 2 and 3 to distinguish the reserved keys.
  - In binary, the last five bits range from `10010` to `11110`. The least significant bit (bit 0) is always `0`, and bits 4 to 63 are set to `1`.

- If you change `reserved_key_start_bit`, use the same value when saving and loading the table.
  For more details, refer to [`init_reserved_keys`](https://github.com/search?q=repo%3ANVIDIA-Merlin%2FHierarchicalKV%20init_reserved_keys&type=code).
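
The bit patterns described above can be reproduced with a short sketch. `reserved_keys` is a hypothetical helper, not the library's `init_reserved_keys`. For a nonzero `reserved_key_start_bit`, it assumes that the reserved keys are exactly the values with bit 0 cleared, bits `start_bit` and `start_bit + 1` taking all four combinations, and every other bit set. This matches both examples above but may differ from the library in details.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical sketch of the reserved-key patterns described in this README.
// It is NOT the library's init_reserved_keys().
inline std::vector<uint64_t> reserved_keys(int start_bit) {
  assert(start_bit >= 0 && start_bit <= 62);
  if (start_bit == 0) {
    // Default reserved keys.
    return {0xFFFFFFFFFFFFFFFDULL, 0xFFFFFFFFFFFFFFFEULL, 0xFFFFFFFFFFFFFFFFULL};
  }
  // All bits set except bit 0 and the two bits [start_bit, start_bit + 1].
  const uint64_t base = ~UINT64_C(1) & ~(UINT64_C(3) << start_bit);
  std::vector<uint64_t> keys;
  // Enumerate the four combinations of the two variable bits.
  for (uint64_t v = 0; v < 4; ++v) {
    keys.push_back(base | (v << start_bit));
  }
  return keys;
}
```

For `reserved_key_start_bit = 2`, the sketch yields `0xFFFFFFFFFFFFFFF2`, `0xFFFFFFFFFFFFFFF6`, `0xFFFFFFFFFFFFFFFA`, and `0xFFFFFFFFFFFFFFFE`, which is the hexadecimal form of the `10010` to `11110` range.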

### How to use:
```cpp
#include <cstdint>
#include <memory>

#include "merlin_hashtable.cuh"

using TableOptions = nv::merlin::HashTableOptions;
using EvictStrategy = nv::merlin::EvictStrategy;

int main(int argc, char *argv[])
{
  using K = uint64_t;
  using V = float;
  using S = uint64_t;

  // 1. Define the table and use the LRU eviction strategy.
  using HKVTable = nv::merlin::HashTable<K, V, S, EvictStrategy::kLru>;
  std::unique_ptr<HKVTable> table = std::make_unique<HKVTable>();

  // 2. Define the configuration options.
  TableOptions options;
  options.init_capacity = 16 * 1024 * 1024;
  options.max_capacity = options.init_capacity;
  options.dim = 16;
  options.max_hbm_for_vectors = nv::merlin::GB(16);

  // 3. Initialize the table memory resource.
  table->init(options);

  // 4. Use the table (for example, insert_or_assign or find).

  return 0;
}
```

### Usage restrictions

- The `key_type` must be `int64_t` or `uint64_t`.
- The `score_type` must be `uint64_t`.

## Contributors

HierarchicalKV is co-maintained by the [NVIDIA Merlin Team](https://github.com/NVIDIA-Merlin) and NVIDIA product end-users,
and is also open to public contributions, bug fixes, and documentation. [[Contribute](CONTRIBUTING.md)]

## How to build

HierarchicalKV is a header-only library; the commands below only build the binaries for benchmarks and unit tests.

Your environment must meet the following requirements:

- CUDA version >= 11.2
- NVIDIA GPU with compute capability 8.0, 8.6, 8.7 or 9.0
- GCC with support for the `C++17` standard or later
- Bazel version >= 3.7.2 (Bazel compile only)

### With CMake
```shell
git clone --recursive https://github.com/NVIDIA-Merlin/HierarchicalKV.git
cd HierarchicalKV && mkdir -p build && cd build
cmake -DCMAKE_BUILD_TYPE=Release -Dsm=80 .. && make -j
```

For Debug:
```shell
cmake -DCMAKE_BUILD_TYPE=Debug -Dsm=80 .. && make -j
```

For Benchmark:
```shell
./merlin_hashtable_benchmark
```

For Unit Test:
```shell
./merlin_hashtable_test
```

### With Bazel

- Don't use the `--recursive` option with `git clone`.
- If you use a customized Docker image, modify the environment variables in the `.bazelrc` file beforehand.
- The Docker images maintained at `nvcr.io/nvidia/tensorflow` are highly recommended.

Pull and run the Docker image:
```shell
docker pull nvcr.io/nvidia/tensorflow:22.09-tf2-py3
docker run --gpus all -it --rm nvcr.io/nvidia/tensorflow:22.09-tf2-py3
```

Compile in the Docker container:
```shell
git clone https://github.com/NVIDIA-Merlin/HierarchicalKV.git
cd HierarchicalKV && bash bazel_build.sh
```

For Benchmark:
```shell
./bazel-bin/benchmark/benchmark_util
```


## Benchmark & Performance (W.I.P.)

* GPU: 1 x NVIDIA A100 80GB PCIe (compute capability 8.0)
* Key Type = uint64_t
* Value Type = float32 * {dim}
* Key-Values per OP = 1048576
* Evict strategy: LRU
* `λ`: load factor
* `find*` means the `find` API that directly returns the addresses of values.
* `find_or_insert*` means the `find_or_insert` API that directly returns the addresses of values.
* ***Throughput Unit: Billion-KV/second***

### On pure HBM mode: 

* dim = 8, capacity = 128 Million-KV, HBM = 4 GB, HMEM = 0 GB

|    λ | insert_or_assign |   find | find_or_insert | assign |  find* | find_or_insert* | insert_and_evict |
|-----:|-----------------:|-------:|---------------:|-------:|-------:|----------------:|-----------------:|
| 0.50 |            1.093 |  2.470 |          1.478 |  1.770 |  3.726 |           1.447 |            1.075 |
| 0.75 |            1.045 |  2.452 |          1.335 |  1.807 |  3.374 |           1.309 |            1.013 |
| 1.00 |            0.655 |  2.481 |          0.612 |  1.815 |  1.865 |           0.619 |            0.511 |

|    λ | export_batch | export_batch_if | contains |
|-----:|-------------:|----------------:|---------:|
| 0.50 |        2.087 |          12.258 |    3.121 |
| 0.75 |        2.045 |          12.447 |    3.094 |
| 1.00 |        1.950 |           2.657 |    3.096 |

* dim = 32, capacity = 128 Million-KV, HBM = 16 GB, HMEM = 0 GB

|    λ | insert_or_assign |   find | find_or_insert | assign |  find* | find_or_insert* | insert_and_evict |
|-----:|-----------------:|-------:|---------------:|-------:|-------:|----------------:|-----------------:|
| 0.50 |            0.961 |  2.272 |          1.278 |  1.706 |  3.718 |           1.435 |            0.931 |
| 0.75 |            0.930 |  2.238 |          1.177 |  1.693 |  3.369 |           1.316 |            0.866 |
| 1.00 |            0.646 |  2.321 |          0.572 |  1.783 |  1.873 |           0.618 |            0.469 |

|    λ | export_batch | export_batch_if | contains |
|-----:|-------------:|----------------:|---------:|
| 0.50 |        0.692 |          10.784 |    3.100 |
| 0.75 |        0.569 |          10.240 |    3.075 |
| 1.00 |        0.551 |           0.765 |    3.096 |

* dim = 64, capacity = 64 Million-KV, HBM = 16 GB, HMEM = 0 GB

|    λ | insert_or_assign |   find | find_or_insert | assign |  find* | find_or_insert* | insert_and_evict |
|-----:|-----------------:|-------:|---------------:|-------:|-------:|----------------:|-----------------:|
| 0.50 |            0.834 |  1.982 |          1.113 |  1.499 |  3.950 |           1.502 |            0.805 |
| 0.75 |            0.801 |  1.951 |          1.033 |  1.493 |  3.545 |           1.359 |            0.773 |
| 1.00 |            0.621 |  2.021 |          0.608 |  1.541 |  1.965 |           0.613 |            0.481 |

|    λ | export_batch | export_batch_if | contains |
|-----:|-------------:|----------------:|---------:|
| 0.50 |        0.316 |           8.199 |    3.239 |
| 0.75 |        0.296 |           8.549 |    3.198 |
| 1.00 |        0.288 |           0.395 |    3.225 |

### On HBM+HMEM hybrid mode: 

* dim = 64, capacity = 128 Million-KV, HBM = 16 GB, HMEM = 16 GB

|    λ | insert_or_assign |   find | find_or_insert | assign |  find* | find_or_insert* |
|-----:|-----------------:|-------:|---------------:|-------:|-------:|----------------:|
| 0.50 |            0.083 |  0.124 |          0.109 |  0.131 |  3.705 |           1.435 |
| 0.75 |            0.083 |  0.122 |          0.111 |  0.129 |  3.221 |           1.274 |
| 1.00 |            0.073 |  0.123 |          0.095 |  0.126 |  1.854 |           0.617 |

|    λ | export_batch | export_batch_if | contains |
|-----:|-------------:|----------------:|---------:|
| 0.50 |        0.318 |           8.086 |    3.122 |
| 0.75 |        0.294 |           5.549 |    3.111 |
| 1.00 |        0.287 |           0.393 |    3.075 |

* dim = 64, capacity = 512 Million-KV, HBM = 32 GB, HMEM = 96 GB

|    λ | insert_or_assign |   find | find_or_insert | assign |  find* | find_or_insert* |
|-----:|-----------------:|-------:|---------------:|-------:|-------:|----------------:|
| 0.50 |            0.049 |  0.069 |          0.049 |  0.069 |  3.484 |           1.370 |
| 0.75 |            0.049 |  0.069 |          0.049 |  0.069 |  3.116 |           1.242 |
| 1.00 |            0.047 |  0.072 |          0.047 |  0.070 |  1.771 |           0.607 |

|    λ | export_batch | export_batch_if | contains |
|-----:|-------------:|----------------:|---------:|
| 0.50 |        0.316 |           8.181 |    3.073 |
| 0.75 |        0.293 |           8.950 |    3.052 |
| 1.00 |        0.292 |           0.394 |    3.026 |

### Support and Feedback:

If you encounter any issues or have questions, go to [https://github.com/NVIDIA-Merlin/HierarchicalKV/issues](https://github.com/NVIDIA-Merlin/HierarchicalKV/issues) and submit an issue so that we can provide you with the necessary resolutions and answers.

### Acknowledgment
We are very grateful to external initial contributors [@Zhangyafei](https://github.com/zhangyafeikimi) and [@Lifan](https://github.com/Lifann) for their design, coding, and review work.

### License
Apache License 2.0


================================================
FILE: STYLE_GUIDE.md
================================================
#### C++
C++ code should conform to [Google C++ Style Guide](https://google.github.io/styleguide/cppguide.html).

HierarchicalKV uses [clang-format](https://clang.llvm.org/docs/ClangFormat.html)
to check your C/C++ changes. If you have manually formatted
code that you don’t want clang-format to touch,
you can disable formatting like this:

```cpp
int formatted_code;
// clang-format off
    int    unformatted_code  ;
// clang-format on
int formatted_code_again;
```

Install clang-format (version 18.1.3 is required) on Ubuntu:

```bash
sudo apt install clang-format-18
```

Format all source files with:
```bash
find ./ \( -path ./tests/googletest -prune \) -o \( -iname '*.h' -o -iname '*.cpp' -o -iname '*.cc' -o -iname '*.cu' -o -iname '*.cuh' -o -iname '*.hpp' \) -print | xargs clang-format-18 -i --style=file
```


================================================
FILE: WORKSPACE
================================================
workspace(name = "HierarchicalKV")

load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
load("//build_deps/gpus:configure.bzl", "cuda_configure")

http_archive(
    name = "bazel_skylib",
    sha256 = "1dde365491125a3db70731e25658dfdd3bc5dbdfd11b840b3e987ecf043c7ca0",
    urls = [
        "https://github.com/bazelbuild/bazel-skylib/releases/download/0.9.0/bazel_skylib-0.9.0.tar.gz",
    ],
)

cuda_configure(name = "local_config_cuda")


================================================
FILE: bazel_build.sh
================================================
#!/bin/bash

# Usage : `./bazel_build.sh` or `bash bazel_build.sh`
set -e
export $(cat .bazeliskrc | xargs)

bazel build --config=cuda //...


================================================
FILE: benchmark/BUILD
================================================
load("@local_config_cuda//cuda:build_defs.bzl", "cuda_cc_library")

cc_binary(
    name = "benchmark_util",
    deps = [
        ":benchmark_lib",
    ],
)

cuda_cc_library(
    name = "benchmark_lib",
    srcs = [
        "merlin_hashtable_benchmark.cc.cu",
    ],
    hdrs = [
        "benchmark_util.cuh",
    ],
    copts = ["-Iinclude/"],
    linkopts = ["-pthread"],
    deps = [
        "//include:merlin_hashtable",
        "@local_config_cuda//cuda",
    ],
)


================================================
FILE: benchmark/benchmark_util.cuh
================================================
/*
 * Copyright (c) 2023, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

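#include <chrono>
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <limits>
#include <random>
#include <unordered_set>
#include "merlin/utils.cuh"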

namespace benchmark {

enum class TimeUnit {
  Second = 0,
  MilliSecond = 3,
  MicroSecond = 6,
  NanoSecond = 9,
};

enum class API_Select {
  find = 0,
  insert_or_assign = 1,
  find_or_insert = 2,
  assign = 3,
  insert_and_evict = 4,
  find_ptr = 5,
  find_or_insert_ptr = 6,
  export_batch = 7,
  export_batch_if = 8,
  contains = 9,
};

enum class Hit_Mode {
  random = 0,
  last_insert = 1,
};

template <typename Rep>
struct Timer {
  explicit Timer(TimeUnit tu = TimeUnit::Second) : tu_(tu) {}
  void start() { startRecord = std::chrono::steady_clock::now(); }
  void end() { endRecord = std::chrono::steady_clock::now(); }
  Rep getResult() {
    auto duration_ = std::chrono::duration_cast<std::chrono::nanoseconds>(
        endRecord - startRecord);
    auto pow_ =
        static_cast<int32_t>(tu_) - static_cast<int32_t>(TimeUnit::NanoSecond);
    auto factor = static_cast<Rep>(std::pow(10, pow_));
    return static_cast<Rep>(duration_.count()) * factor;
  }

 private:
  TimeUnit tu_;
  std::chrono::time_point<std::chrono::steady_clock> startRecord{};
  std::chrono::time_point<std::chrono::steady_clock> endRecord{};
};

// RAII Timer using CUDA Event
template <typename Rep>
struct KernelTimer {
  explicit KernelTimer(TimeUnit tu = TimeUnit::Second) : tu_(tu) {
    CUDA_CHECK(cudaEventCreate(&start_));
    CUDA_CHECK(cudaEventCreate(&end_));
  }
  ~KernelTimer() {
    CUDA_CHECK(cudaEventDestroy(start_));
    CUDA_CHECK(cudaEventDestroy(end_));
  }
  void start() { CUDA_CHECK(cudaEventRecord(start_)); }
  void end() {
    CUDA_CHECK(cudaEventRecord(end_));
    CUDA_CHECK(cudaEventSynchronize(end_));
    CUDA_CHECK(cudaEventElapsedTime(&time, start_, end_));
  }
  Rep getResult() {
    auto pow_ =
        static_cast<int32_t>(tu_) - static_cast<int32_t>(TimeUnit::MilliSecond);
    auto factor = static_cast<Rep>(std::pow(10, pow_));
    return static_cast<Rep>(time * factor);
  }

 private:
  TimeUnit tu_;
  float time{-1.0f};
  cudaEvent_t start_;
  cudaEvent_t end_;
};

inline uint64_t getTimestamp() {
  return std::chrono::duration_cast<std::chrono::milliseconds>(
             std::chrono::system_clock::now().time_since_epoch())
      .count();
}

template <class K, class S>
void create_continuous_keys(K* h_keys, S* h_scores, const int key_num_per_op,
                            const K start = 0, int freq_range = 1000) {
  for (K i = 0; i < key_num_per_op; i++) {
    h_keys[i] = start + static_cast<K>(i);
    if (h_scores != nullptr) h_scores[i] = h_keys[i] % freq_range;
  }
}

template <class K, class S>
void create_random_keys(K* h_keys, S* h_scores, const int key_num_per_op) {
  std::unordered_set<K> numbers;
  std::random_device rd;
  std::mt19937_64 eng(rd());
  std::uniform_int_distribution<K> distr;
  int i = 0;

  while (numbers.size() < key_num_per_op) {
    numbers.insert(distr(eng));
  }
  for (const K num : numbers) {
    h_keys[i] = num;
    if (h_scores != nullptr) h_scores[i] = getTimestamp();
    i++;
  }
}

template <typename K, typename S>
void create_keys_for_hitrate(K* h_keys, S* h_scores, const int key_num_per_op,
                             const float hitrate = 0.6f,
                             const Hit_Mode hit_mode = Hit_Mode::last_insert,
                             const K end = 0, const bool reset = false,
                             int freq_range = 1000) {
  int divide = static_cast<int>(key_num_per_op * hitrate);
  if (Hit_Mode::random == hit_mode) {
    std::random_device rd;
    std::mt19937_64 eng(rd());
    K existed_max = end == 0 ? 1 : (end - 1);
    std::uniform_int_distribution<K> distr(0, existed_max);

    if (existed_max < divide) {
      std::cout << "# Cannot generate enough keys for hit!" << std::endl;
      exit(-1);
    }
    std::unordered_set<K> numbers;
    while (numbers.size() < divide) {
      numbers.insert(distr(eng));
    }
    int i = 0;
    for (auto existed_value : numbers) {
      h_keys[i] = existed_value;
      if (h_scores != nullptr) h_scores[i] = h_keys[i] % freq_range;
      i++;
    }
  } else {
    // else keep its original value, but update scores
    for (int i = 0; i < divide; i++) {
      if (h_scores != nullptr) h_scores[i] = getTimestamp() % freq_range;
    }
  }

  static K new_value = std::numeric_limits<K>::max();
  if (reset) {
    new_value = std::numeric_limits<K>::max();
  }
  for (int i = divide; i < key_num_per_op; i++) {
    h_keys[i] = new_value--;
    if (h_scores != nullptr) h_scores[i] = getTimestamp() % freq_range;
  }
}
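
// Hypothetical helper (not part of HierarchicalKV or the original benchmark):
// counts how many of the `n` keys fall in [0, end), i.e. how many would hit a
// table pre-filled with the continuous keys 0..end-1. For a batch built by
// create_keys_for_hitrate() in Hit_Mode::random this should equal the
// `divide` computed there, because the remaining "miss" keys are taken from
// the top of K's range, counting down from std::numeric_limits<K>::max().
template <class K>
size_t count_hits_in_prefix(const K* h_keys, const size_t n, const K end) {
  size_t hits = 0;
  for (size_t i = 0; i < n; i++) {
    // Keys below `end` were inserted earlier; everything else is a miss.
    if (h_keys[i] < end) hits++;
  }
  return hits;
}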

template <typename S>
void refresh_scores(S* h_scores, const int key_num_per_op) {
  for (int i = 0; i < key_num_per_op; i++) {
    h_scores[i] = getTimestamp();
  }
}

template <class K, class V>
void init_value_using_key(K* h_keys, V* h_vectors, const int key_num_per_op,
                          int dim) {
  for (size_t i = 0; i < key_num_per_op; i++) {
    for (size_t j = 0; j < dim; j++) {
      h_vectors[i * dim + j] = static_cast<V>(h_keys[i] * 0.00001);
    }
  }
}
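
// Hypothetical host-side check (not part of the original benchmark): since
// init_value_using_key() fills every element of row i with
// static_cast<V>(h_keys[i] * 0.00001), a row copied back from find() can be
// validated against its key without keeping the inserted values around.
template <class K, class V>
bool row_matches_key(const K key, const V* row, const int dim) {
  // Recompute the value exactly as init_value_using_key() does.
  const V expected = static_cast<V>(key * 0.00001);
  for (int j = 0; j < dim; j++) {
    if (row[j] != expected) return false;
  }
  return true;
}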

template <class V>
__global__ void read_from_ptr_kernel(const V* const* __restrict src,
                                     V* __restrict dst, const size_t dim,
                                     size_t N) {
  size_t tid = (blockIdx.x * blockDim.x) + threadIdx.x;

  for (size_t t = tid; t < N; t += blockDim.x * gridDim.x) {
    size_t vec_index = t / dim;
    size_t dim_index = t % dim;
    if (src[vec_index]) {
      dst[vec_index * dim + dim_index] = src[vec_index][dim_index];
    }
  }
}

template <class V>
void read_from_ptr(const V* const* __restrict src, V* __restrict dst,
                   const size_t dim, size_t n, cudaStream_t stream) {
  const size_t block_size = 1024;
  const size_t N = n * dim;
  const size_t grid_size = nv::merlin::SAFE_GET_GRID_SIZE(N, block_size);

  read_from_ptr_kernel<V>
      <<<grid_size, block_size, 0, stream>>>(src, dst, dim, N);
}
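
// Hypothetical host reference (for illustration only) of what
// read_from_ptr_kernel computes: the n x dim copy is flattened into
// N = n * dim work items, item t copies element t % dim of row t / dim, and
// rows whose source pointer is null (e.g. keys that find() did not locate)
// are left untouched in dst.
template <class V>
void read_from_ptr_host(const V* const* src, V* dst, const size_t dim,
                        const size_t n) {
  for (size_t t = 0; t < n * dim; t++) {
    const size_t vec_index = t / dim;
    const size_t dim_index = t % dim;
    if (src[vec_index]) {
      dst[vec_index * dim + dim_index] = src[vec_index][dim_index];
    }
  }
}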

template <class V>
__global__ void array2ptr_kernel(V** ptr, V* __restrict array, const size_t dim,
                                 size_t N) {
  size_t tid = (blockIdx.x * blockDim.x) + threadIdx.x;

  for (size_t t = tid; t < N; t += blockDim.x * gridDim.x) {
    size_t vec_index = t;
    ptr[vec_index] = array + vec_index * dim;
  }
}

template <class V>
void array2ptr(V** ptr, V* __restrict array, const size_t dim, size_t n,
               cudaStream_t stream) {
  const size_t block_size = 1024;
  const size_t N = n;
  const size_t grid_size = nv::merlin::SAFE_GET_GRID_SIZE(N, block_size);

  array2ptr_kernel<V><<<grid_size, block_size, 0, stream>>>(ptr, array, dim, N);
}
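
// Hypothetical host reference (for illustration only) of array2ptr_kernel:
// it turns a contiguous n x dim value array into n row pointers. The find*
// and find_or_insert* benchmarks pass such a V** array to the pointer
// overloads of find() and find_or_insert().
template <class V>
void array2ptr_host(V** ptr, V* array, const size_t dim, const size_t n) {
  for (size_t i = 0; i < n; i++) {
    ptr[i] = array + i * dim;  // Row i starts dim elements after row i - 1.
  }
}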

template <class S>
__global__ void host_nano_kernel(S* d_clk) {
  S mclk;
  asm volatile("mov.u64 %0,%%globaltimer;" : "=l"(mclk));
  *d_clk = mclk;
}

template <class S>
S host_nano(cudaStream_t stream = 0) {
  S h_clk = 0;
  S* d_clk;

  CUDA_CHECK(cudaMalloc((void**)&(d_clk), sizeof(S)));
  host_nano_kernel<S><<<1, 1, 0, stream>>>(d_clk);
  CUDA_CHECK(cudaStreamSynchronize(stream));

  CUDA_CHECK(cudaMemcpy(&h_clk, d_clk, sizeof(S), cudaMemcpyDeviceToHost));
  CUDA_CHECK(cudaFree(d_clk));
  return h_clk;
}

template <class K, class S>
struct ExportIfPredFunctor {
  __forceinline__ __device__ bool operator()(const K& key, S& score,
                                             const K& pattern,
                                             const S& threshold) {
    return score > threshold;
  }
};
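
// Hypothetical host sketch (not the library's implementation): with
// ExportIfPredFunctor, export_batch_if keeps only entries whose score is
// strictly greater than `threshold`; the key and `pattern` arguments are
// ignored. The loop below applies the same condition to host arrays and
// returns how many entries were kept, the quantity export_batch_if reports
// through its counter argument.
template <class K, class S>
size_t export_if_host(const K* keys, const S* scores, const size_t n,
                      const S threshold, K* out_keys, S* out_scores) {
  size_t count = 0;
  for (size_t i = 0; i < n; i++) {
    if (scores[i] > threshold) {  // Same test as ExportIfPredFunctor.
      out_keys[count] = keys[i];
      out_scores[count] = scores[i];
      count++;
    }
  }
  return count;
}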

}  // namespace benchmark


================================================
FILE: benchmark/dual_bucket_benchmark.cc.cu
================================================
/*
 * Copyright (c) 2024, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <chrono>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <numeric>
#include <random>
#include <vector>
#include "merlin_hashtable.cuh"

using K = uint64_t;
using V = float;
using S = uint64_t;
using TableOptions = nv::merlin::HashTableOptions;
using TableMode = nv::merlin::TableMode;
using EvictStrategy = nv::merlin::EvictStrategy;

template <typename Table>
double benchmark_insert(Table& table, size_t n, K* d_keys, V* d_values,
                        S* d_scores, cudaStream_t stream) {
  CUDA_CHECK(cudaStreamSynchronize(stream));
  auto start = std::chrono::high_resolution_clock::now();
  table.insert_or_assign(n, d_keys, d_values, d_scores, stream, true);
  CUDA_CHECK(cudaStreamSynchronize(stream));
  auto end = std::chrono::high_resolution_clock::now();
  double ms = std::chrono::duration_cast<std::chrono::microseconds>(end - start)
                  .count() /
              1000.0;
  return static_cast<double>(n) / ms / 1000.0;  // Mops/s
}

template <typename Table>
double benchmark_find(Table& table, size_t n, K* d_keys, V* d_values,
                      bool* d_founds, cudaStream_t stream) {
  CUDA_CHECK(cudaStreamSynchronize(stream));
  auto start = std::chrono::high_resolution_clock::now();
  table.find(n, d_keys, d_values, d_founds, nullptr, stream);
  CUDA_CHECK(cudaStreamSynchronize(stream));
  auto end = std::chrono::high_resolution_clock::now();
  double ms = std::chrono::duration_cast<std::chrono::microseconds>(end - start)
                  .count() /
              1000.0;
  return static_cast<double>(n) / ms / 1000.0;  // Mops/s
}
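
// Hypothetical helper (not called above) spelling out the unit conversion used
// by benchmark_insert() and benchmark_find(): n / ms is operations per
// millisecond, and 1 op/ms = 1000 op/s = 0.001 Mop/s, so
// Mops/s = n / ms / 1000.
inline double to_mops_per_sec(const size_t n, const double ms) {
  return static_cast<double>(n) / ms / 1000.0;
}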

void run_benchmark(size_t capacity, size_t dim, TableMode mode,
                   const char* mode_name) {
  using Table = nv::merlin::HashTable<K, V, S, EvictStrategy::kCustomized>;

  Table table;
  TableOptions options;
  options.init_capacity = capacity;
  options.max_capacity = capacity;
  options.max_hbm_for_vectors = 0;
  options.dim = dim;
  options.max_bucket_size = 128;
  options.table_mode = mode;
  table.init(options);

  cudaStream_t stream;
  CUDA_CHECK(cudaStreamCreate(&stream));

  // Generate keys.
  size_t max_n = capacity;
  std::vector<K> h_keys(max_n);
  std::vector<V> h_values(max_n * dim, 1.0f);
  std::vector<S> h_scores(max_n);
  std::iota(h_keys.begin(), h_keys.end(), 1);
  for (size_t i = 0; i < max_n; i++) h_scores[i] = i + 1;

  K* d_keys;
  V* d_values;
  S* d_scores;
  bool* d_founds;
  V* d_found_values;
  CUDA_CHECK(cudaMalloc(&d_keys, max_n * sizeof(K)));
  CUDA_CHECK(cudaMalloc(&d_values, max_n * dim * sizeof(V)));
  CUDA_CHECK(cudaMalloc(&d_scores, max_n * sizeof(S)));
  CUDA_CHECK(cudaMalloc(&d_founds, max_n * sizeof(bool)));
  CUDA_CHECK(cudaMalloc(&d_found_values, max_n * dim * sizeof(V)));

  CUDA_CHECK(cudaMemcpy(d_keys, h_keys.data(), max_n * sizeof(K),
                        cudaMemcpyHostToDevice));
  CUDA_CHECK(cudaMemcpy(d_values, h_values.data(), max_n * dim * sizeof(V),
                        cudaMemcpyHostToDevice));
  CUDA_CHECK(cudaMemcpy(d_scores, h_scores.data(), max_n * sizeof(S),
                        cudaMemcpyHostToDevice));

  printf("--- %s (capacity=%zuK, dim=%zu) ---\n", mode_name, capacity / 1024,
         dim);
  printf("  %-12s  %-18s  %-18s\n", "Load Factor", "Insert (Mops/s)",
         "Find (Mops/s)");

  float load_factors[] = {0.25f, 0.50f, 0.75f, 0.90f, 0.95f, 1.00f};
  size_t prev_n = 0;

  for (float lf : load_factors) {
    size_t target_n = static_cast<size_t>(capacity * lf);
    if (target_n > max_n) break;
    size_t batch_n = target_n - prev_n;
    if (batch_n == 0) continue;

    // Insert to reach target load factor.
    double insert_mops =
        benchmark_insert(table, batch_n, d_keys + prev_n,
                         d_values + prev_n * dim, d_scores + prev_n, stream);

    // Find all inserted keys.
    double find_mops = benchmark_find(table, target_n, d_keys, d_found_values,
                                      d_founds, stream);

    printf("  %-12.2f  %-18.1f  %-18.1f\n", lf, insert_mops, find_mops);
    prev_n = target_n;
  }

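  // Report the final occupancy as a measure of memory efficiency: after
  // inserting `capacity` unique keys, any shortfall from `capacity` consists
  // of keys that could not be kept because the buckets they map to were full.
  // (Already covered in test, report here too.)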
  size_t table_size = table.size(stream);
  CUDA_CHECK(cudaStreamSynchronize(stream));
  printf("  Final size: %zu / %zu (LF=%.4f)\n", table_size, capacity,
         static_cast<float>(table_size) / capacity);

  CUDA_CHECK(cudaFree(d_keys));
  CUDA_CHECK(cudaFree(d_values));
  CUDA_CHECK(cudaFree(d_scores));
  CUDA_CHECK(cudaFree(d_founds));
  CUDA_CHECK(cudaFree(d_found_values));
  CUDA_CHECK(cudaStreamDestroy(stream));
}
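
// Hypothetical helper mirroring the load-factor loop in run_benchmark()
// (assuming ascending load factors, and using double where the original uses
// float): each step inserts only the new keys in [prev_n, target_n), so the
// table fills monotonically and no key is inserted twice, while find() is
// timed over all target_n keys inserted so far. Writes each batch's first
// index and size and returns the number of batches.
inline size_t plan_load_factor_batches(const size_t capacity,
                                       const double* load_factors,
                                       const size_t num_load_factors,
                                       size_t* starts, size_t* counts) {
  size_t prev_n = 0;
  size_t num_batches = 0;
  for (size_t i = 0; i < num_load_factors; i++) {
    const size_t target_n = static_cast<size_t>(capacity * load_factors[i]);
    if (target_n > capacity) break;    // Never go past the key buffer.
    if (target_n == prev_n) continue;  // Nothing new to insert.
    starts[num_batches] = prev_n;
    counts[num_batches] = target_n - prev_n;
    num_batches++;
    prev_n = target_n;
  }
  return num_batches;
}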

int main(int argc, char** argv) {
  printf("=== Dual-Bucket Benchmark Results ===\n\n");

  // Defaults: capacity = 2^20 slots, dim = 64. Override them on the command
  // line as: <capacity> [<dim>].
  size_t capacity = 128 * 1024 * 8;  // 2^20 = 1,048,576
  size_t dim = 64;

  if (argc > 1) capacity = static_cast<size_t>(atol(argv[1]));
  if (argc > 2) dim = static_cast<size_t>(atol(argv[2]));

  run_benchmark(capacity, dim, TableMode::kThroughput, "THROUGHPUT_MODE");
  printf("\n");
  run_benchmark(capacity, dim, TableMode::kMemory, "MEMORY_MODE");
  printf("\n");

  printf("=== Benchmark Complete ===\n");
  return 0;
}


================================================
FILE: benchmark/find_with_missed_keys_benchmark.cc.cu
================================================
/*
 * Copyright (c) 2024, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <assert.h>
#include <algorithm>
#include <chrono>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <iomanip>
#include <iostream>
#include <limits>
#include <random>
#include <string>
#include <thread>
#include <unordered_map>
#include <unordered_set>
#include "benchmark_util.cuh"
#include "merlin_hashtable.cuh"

using K = uint64_t;
using V = float;
using S = uint64_t;
using EvictStrategy = nv::merlin::EvictStrategy;
using TableOptions = nv::merlin::HashTableOptions;
using Table = nv::merlin::HashTable<K, V, S, EvictStrategy::kCustomized>;

void print_tile() {
  std::cout << std::endl
            << "|    \u03BB " << "| capacity " << "| max_hbm_for_vectors "
            << "| max_bucket_size " << "| dim " << "| missed_ratio "
            << "| throughput(BillionKV/secs)";
  std::cout << "|\n";

  //<< "| load_factor "
  std::cout << "|------"
            //<< "| capacity "
            << "|----------"
            //<< "| max_hbm_for_vectors "
            << "|---------------------"
            //<< "| max_bucket_size "
            << "|-----------------"
            //<< "| dim "
            << "|-----"
            //<< "| missed_ratio "
            << "|--------------"
            //<< "| throughput(BillionKV/secs)"
            << "|---------------------------";
  std::cout << "|\n";
}

template <typename T>
void print_w(const T& t, size_t width) {
  std::cout << "|" << std::setw(width) << t;
}

void print_result(double load_factor, size_t capacity,
                  size_t max_hbm_for_vectors, size_t max_bucket_size,
                  size_t dim, double missed_ratio, float througput) {
  print_w(load_factor, 6);
  print_w(capacity, 10);
  print_w(max_hbm_for_vectors, 21);
  print_w(max_bucket_size, 17);
  print_w(dim, 5);
  print_w(missed_ratio, 14);
  print_w(througput, 27);
  std::cout << "|\n";
}

void test_find(size_t capacity, size_t dim, size_t max_hbm_for_vectors,
               double load_factor, size_t max_bucket_size,
               double missed_ratio) {
  MERLIN_CHECK(load_factor >= 0.0 && load_factor <= 1.0,
               "Invalid `load_factor`");
  K* h_keys;
  S* h_scores;
  V* h_vectors;

  TableOptions options;
  options.init_capacity = capacity;
  options.max_capacity = capacity;
  options.dim = dim;

  options.max_hbm_for_vectors = nv::merlin::MB(max_hbm_for_vectors);
  options.max_bucket_size = max_bucket_size;

  size_t key_num = capacity;
  CUDA_CHECK(cudaMallocHost(&h_keys, key_num * sizeof(K)));
  CUDA_CHECK(cudaMallocHost(&h_scores, key_num * sizeof(S)));
  CUDA_CHECK(cudaMallocHost(&h_vectors, key_num * options.dim * sizeof(V)));

  K* d_keys;
  S* d_scores;
  V* d_vectors;
  K* d_missed_keys;
  int* d_missed_indices;
  int* d_missed_size;

  CUDA_CHECK(cudaMalloc(&d_keys, key_num * sizeof(K)));
  CUDA_CHECK(cudaMalloc(&d_scores, key_num * sizeof(S)));
  CUDA_CHECK(cudaMalloc(&d_vectors, key_num * sizeof(V) * options.dim));
  CUDA_CHECK(cudaMalloc(&d_missed_keys, key_num * sizeof(K)));
  CUDA_CHECK(cudaMalloc(&d_missed_indices, key_num * sizeof(int)));
  CUDA_CHECK(cudaMalloc(&d_missed_size, sizeof(int)));

  cudaStream_t stream;
  CUDA_CHECK(cudaStreamCreate(&stream));
  // insert key-value
  size_t insert_num = (double)key_num * load_factor;
  benchmark::create_continuous_keys<K, S>(h_keys, h_scores, insert_num,
                                          0 /*start*/);
  benchmark::init_value_using_key<K, V>(h_keys, h_vectors, insert_num,
                                        options.dim);
  CUDA_CHECK(cudaMemcpy(d_keys, h_keys, insert_num * sizeof(K),
                        cudaMemcpyHostToDevice));
  CUDA_CHECK(cudaMemcpy(d_scores, h_scores, insert_num * sizeof(S),
                        cudaMemcpyHostToDevice));
  CUDA_CHECK(cudaMemcpy(d_vectors, h_vectors,
                        insert_num * sizeof(V) * options.dim,
                        cudaMemcpyHostToDevice));
  Table table;
  table.init(options);
  table.insert_or_assign(insert_num, d_keys, d_vectors, d_scores, stream);
  CUDA_CHECK(cudaStreamSynchronize(stream));

  // find key-value
  size_t find_num = (double)insert_num * (1.0 - missed_ratio);
  benchmark::create_continuous_keys<K, S>(h_keys, nullptr, find_num,
                                          0 /*start*/);
  benchmark::create_continuous_keys<K, S>(
      h_keys + find_num, nullptr, insert_num - find_num, insert_num /*start*/);
  CUDA_CHECK(cudaMemcpy(d_keys, h_keys, insert_num * sizeof(K),
                        cudaMemcpyHostToDevice));

  auto timer = benchmark::Timer<double>();
  timer.start();
  table.find(insert_num, d_keys, d_vectors, d_missed_keys, d_missed_indices,
             d_missed_size, d_scores, stream);
  CUDA_CHECK(cudaStreamSynchronize(stream));
  timer.end();

  CUDA_CHECK(cudaFreeHost(h_keys));
  CUDA_CHECK(cudaFreeHost(h_scores));
  CUDA_CHECK(cudaFreeHost(h_vectors));
  CUDA_CHECK(cudaFree(d_keys));
  CUDA_CHECK(cudaFree(d_scores));
  CUDA_CHECK(cudaFree(d_vectors));
  CUDA_CHECK(cudaFree(d_missed_keys));
  CUDA_CHECK(cudaFree(d_missed_indices));
  CUDA_CHECK(cudaFree(d_missed_size));

  CudaCheckError();
  float througput = insert_num / timer.getResult() / (1024 * 1024 * 1024.0f);
  print_result(load_factor, capacity, max_hbm_for_vectors, max_bucket_size, dim,
               missed_ratio, througput);
}
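
// Hypothetical helper reproducing how test_find() lays out its query batch:
// the first find_num = insert_num * (1 - missed_ratio) keys are
// 0..find_num-1, which were inserted above (barring evictions), and the
// remaining insert_num - find_num keys start at insert_num, which was never
// inserted, so missed_ratio is the fraction of queries that cannot hit.
// Returns find_num.
inline size_t build_find_queries(uint64_t* keys, const size_t insert_num,
                                 const double missed_ratio) {
  const size_t find_num = static_cast<size_t>(
      static_cast<double>(insert_num) * (1.0 - missed_ratio));
  for (size_t i = 0; i < find_num; i++) keys[i] = i;  // Present keys.
  for (size_t i = find_num; i < insert_num; i++) {
    keys[i] = insert_num + (i - find_num);  // Keys never inserted.
  }
  return find_num;
}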

void test_main(double load_factor, double missed_ratio) {
  constexpr size_t CAPACITY = 100000000UL;
  print_tile();
  // pure HBM
  test_find(CAPACITY, 8, 8 * 1024UL, load_factor, 256, missed_ratio);
  test_find(CAPACITY, 8, 8 * 1024UL, load_factor, 128, missed_ratio);
  // hybrid
  test_find(CAPACITY, 8, 1 * 1024UL, load_factor, 256, missed_ratio);
  test_find(CAPACITY, 8, 1 * 1024UL, load_factor, 128, missed_ratio);
  // pure HMEM
  test_find(CAPACITY, 8, 0, load_factor, 256, missed_ratio);
  test_find(CAPACITY, 8, 0, load_factor, 128, missed_ratio);
}

int main() {
  test_main(0.2, 0);
  test_main(0.2, 0.5);
  test_main(0.2, 1.0);
  test_main(0.5, 0);
  test_main(0.5, 0.5);
  test_main(0.5, 1.0);
  test_main(1.0, 0);
  test_main(1.0, 0.5);
  test_main(1.0, 1.0);
  return 0;
}


================================================
FILE: benchmark/merlin_hashtable_benchmark.cc.cu
================================================
/*
 * Copyright (c) 2022, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <assert.h>
#include <algorithm>
#include <chrono>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <iomanip>
#include <iostream>
#include <limits>
#include <random>
#include <thread>
#include <unordered_map>
#include <unordered_set>
#include "benchmark_util.cuh"
#include "merlin_hashtable.cuh"

using std::cerr;
using std::cout;
using std::endl;
using std::fixed;
using std::setfill;
using std::setprecision;
using std::setw;

using namespace nv::merlin;
using namespace benchmark;

enum class Test_Mode {
  pure_hbm = 0,
  hybrid = 1,
};

const float EPSILON = 0.001f;

std::string rep(int n) { return std::string(n, ' '); }

using K = uint64_t;
using S = uint64_t;
using V = float;
using EvictStrategy = nv::merlin::EvictStrategy;
using TableOptions = nv::merlin::HashTableOptions;

template <class Table>
float test_one_api(std::shared_ptr<Table>& table, const API_Select api,
                   const size_t dim, const size_t init_capacity,
                   const size_t key_num_per_op, const float load_factor,
                   const float hitrate = 0.6f) {
  K* h_keys;
  S* h_scores;
  V* h_vectors;
  bool* h_found;

  CUDA_CHECK(cudaMallocHost(&h_keys, key_num_per_op * sizeof(K)));
  CUDA_CHECK(cudaMallocHost(&h_scores, key_num_per_op * sizeof(S)));
  CUDA_CHECK(cudaMallocHost(&h_vectors, key_num_per_op * sizeof(V) * dim));
  CUDA_CHECK(cudaMallocHost(&h_found, key_num_per_op * sizeof(bool)));

  CUDA_CHECK(cudaMemset(h_vectors, 0, key_num_per_op * sizeof(V) * dim));

  bool need_scores = (Table::evict_strategy == EvictStrategy::kLfu ||
                      Table::evict_strategy == EvictStrategy::kEpochLfu ||
                      Table::evict_strategy == EvictStrategy::kCustomized);

  K* d_keys;
  S* d_scores_real;
  S* d_scores;
  V* d_vectors;
  V* d_def_val;
  V** d_vectors_ptr;
  bool* d_found;
  K* d_keys_out;

  K* d_evict_keys;
  S* d_evict_scores;

  CUDA_CHECK(cudaMalloc(&d_keys, key_num_per_op * sizeof(K)));
  CUDA_CHECK(cudaMalloc(&d_scores_real, key_num_per_op * sizeof(S)));
  CUDA_CHECK(cudaMalloc(&d_vectors, key_num_per_op * sizeof(V) * dim));
  CUDA_CHECK(cudaMalloc(&d_def_val, key_num_per_op * sizeof(V) * dim));
  CUDA_CHECK(cudaMalloc(&d_vectors_ptr, key_num_per_op * sizeof(V*)));
  CUDA_CHECK(cudaMalloc(&d_found, key_num_per_op * sizeof(bool)));
  CUDA_CHECK(cudaMalloc(&d_keys_out, key_num_per_op * sizeof(K)));

  CUDA_CHECK(cudaMalloc(&d_evict_keys, key_num_per_op * sizeof(K)));
  CUDA_CHECK(cudaMalloc(&d_evict_scores, key_num_per_op * sizeof(S)));

  CUDA_CHECK(cudaMemset(d_vectors, 1, key_num_per_op * sizeof(V) * dim));
  CUDA_CHECK(cudaMemset(d_def_val, 2, key_num_per_op * sizeof(V) * dim));
  CUDA_CHECK(cudaMemset(d_vectors_ptr, 0, key_num_per_op * sizeof(V*)));
  CUDA_CHECK(cudaMemset(d_found, 0, key_num_per_op * sizeof(bool)));

  d_scores = need_scores ? d_scores_real : nullptr;

  cudaStream_t stream;
  CUDA_CHECK(cudaStreamCreate(&stream));

  // Pre-fill the table up to the target load factor.
  // Step 1: insert key_num_init continuous keys in batches of key_num_per_op.
  uint64_t key_num_init = static_cast<uint64_t>(init_capacity * load_factor);
  const float target_load_factor = key_num_init * 1.0f / init_capacity;
  uint64_t key_num_remain = key_num_init % key_num_per_op == 0
                                ? key_num_per_op
                                : key_num_init % key_num_per_op;
  int32_t loop_num_init = (key_num_init + key_num_per_op - 1) / key_num_per_op;

  K start = 0UL;

  S threshold = benchmark::host_nano<S>();
  int global_epoch = 0;
  for (; global_epoch < loop_num_init; global_epoch++) {
    table->set_global_epoch(global_epoch);
    uint64_t key_num_cur_insert =
        global_epoch == loop_num_init - 1 ? key_num_remain : key_num_per_op;
    create_continuous_keys<K, S>(h_keys, h_scores, key_num_cur_insert, start);
    CUDA_CHECK(cudaMemcpy(d_keys, h_keys, key_num_cur_insert * sizeof(K),
                          cudaMemcpyHostToDevice));
    CUDA_CHECK(cudaMemcpy(d_scores_real, h_scores,
                          key_num_cur_insert * sizeof(S),
                          cudaMemcpyHostToDevice));
    table->find_or_insert(key_num_cur_insert, d_keys, d_vectors_ptr, d_found,
                          d_scores, stream);
    CUDA_CHECK(cudaStreamSynchronize(stream));

    start += key_num_cur_insert;
  }

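  // Step 2: evictions can leave the table below the target load factor, so
  // top it up with insert_or_assign until the measured load factor is within
  // EPSILON of the target.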
  float real_load_factor = table->load_factor(stream);
  CUDA_CHECK(cudaStreamSynchronize(stream));
  while (target_load_factor - real_load_factor > EPSILON) {
    auto key_num_append = static_cast<int64_t>(
        (target_load_factor - real_load_factor) * init_capacity);
    if (key_num_append <= 0) break;
    key_num_append =
        std::min(static_cast<int64_t>(key_num_per_op), key_num_append);
    create_continuous_keys<K, S>(h_keys, h_scores, key_num_append, start);
    CUDA_CHECK(cudaMemcpy(d_keys, h_keys, key_num_append * sizeof(K),
                          cudaMemcpyHostToDevice));
    CUDA_CHECK(cudaMemcpy(d_scores_real, h_scores, key_num_append * sizeof(S),
                          cudaMemcpyHostToDevice));
    table->insert_or_assign(key_num_append, d_keys, d_vectors, d_scores,
                            stream);
    CUDA_CHECK(cudaStreamSynchronize(stream));
    start += key_num_append;
    real_load_factor = table->load_factor(stream);
    CUDA_CHECK(cudaStreamSynchronize(stream));
  }

  // Warm-up: run the selected API a few times on a single key so that kernel
  // selection is triggered before the timed run.
  int key_num_per_op_warmup = 1;
  for (int i = 0; i < 9; i++, global_epoch++) {
    table->set_global_epoch(global_epoch);
    switch (api) {
      case API_Select::find: {
        table->find(key_num_per_op_warmup, d_keys, d_vectors, d_found, d_scores,
                    stream);
        CUDA_CHECK(cudaStreamSynchronize(stream));
        break;
      }
      case API_Select::insert_or_assign: {
        table->insert_or_assign(key_num_per_op_warmup, d_keys, d_vectors,
                                d_scores, stream);
        CUDA_CHECK(cudaStreamSynchronize(stream));
        break;
      }
      case API_Select::find_or_insert: {
        table->find_or_insert(key_num_per_op_warmup, d_keys, d_vectors,
                              d_scores, stream);
        CUDA_CHECK(cudaStreamSynchronize(stream));
        break;
      }
      case API_Select::assign: {
        table->assign(key_num_per_op_warmup, d_keys, d_def_val, d_scores,
                      stream);
        CUDA_CHECK(cudaStreamSynchronize(stream));
        break;
      }
      case API_Select::insert_and_evict: {
        table->insert_and_evict(key_num_per_op_warmup, d_keys, d_vectors,
                                d_scores, d_evict_keys, d_def_val,
                                d_evict_scores, stream);
        CUDA_CHECK(cudaStreamSynchronize(stream));
        break;
      }
      case API_Select::find_ptr: {
        V** d_vectors_ptr = nullptr;
        CUDA_CHECK(
            cudaMalloc(&d_vectors_ptr, key_num_per_op_warmup * sizeof(V*)));
        benchmark::array2ptr(d_vectors_ptr, d_vectors, dim,
                             key_num_per_op_warmup, stream);

        CUDA_CHECK(cudaStreamSynchronize(stream));
        table->find(key_num_per_op_warmup, d_keys, d_vectors_ptr, d_found,
                    d_scores, stream);
        CUDA_CHECK(cudaStreamSynchronize(stream));
        benchmark::read_from_ptr(d_vectors_ptr, d_vectors, dim,
                                 key_num_per_op_warmup, stream);
        CUDA_CHECK(cudaStreamSynchronize(stream));
        CUDA_CHECK(cudaFree(d_vectors_ptr));
        break;
      }
      case API_Select::find_or_insert_ptr: {
        V** d_vectors_ptr = nullptr;
        bool* d_found;
        CUDA_CHECK(cudaMalloc(&d_found, key_num_per_op_warmup * sizeof(bool)));
        CUDA_CHECK(
            cudaMalloc(&d_vectors_ptr, key_num_per_op_warmup * sizeof(V*)));
        benchmark::array2ptr(d_vectors_ptr, d_vectors, dim,
                             key_num_per_op_warmup, stream);
        CUDA_CHECK(cudaStreamSynchronize(stream));
        table->find_or_insert(key_num_per_op_warmup, d_keys, d_vectors_ptr,
                              d_found, d_scores, stream);
        CUDA_CHECK(cudaStreamSynchronize(stream));
        CUDA_CHECK(cudaFree(d_vectors_ptr));
        CUDA_CHECK(cudaFree(d_found));
        break;
      }
      case API_Select::export_batch: {
        size_t* d_dump_counter = nullptr;
        CUDA_CHECK(cudaMalloc(&d_dump_counter, sizeof(size_t)));
        CUDA_CHECK(cudaMemset(d_dump_counter, 0, sizeof(size_t)));

        table->export_batch(key_num_per_op_warmup, 0, d_dump_counter, d_keys,
                            d_vectors, d_scores, stream);
        CUDA_CHECK(cudaStreamSynchronize(stream));
        CUDA_CHECK(cudaFree(d_dump_counter));
        break;
      }
      case API_Select::export_batch_if: {
        size_t* d_dump_counter = nullptr;
        CUDA_CHECK(cudaMalloc(&d_dump_counter, sizeof(size_t)));
        CUDA_CHECK(cudaMemset(d_dump_counter, 0, sizeof(size_t)));
        K pattern = 0;
        table->template export_batch_if<ExportIfPredFunctor>(
            pattern, threshold, key_num_per_op_warmup, 0, d_dump_counter,
            d_keys, d_vectors, d_scores, stream);
        CUDA_CHECK(cudaStreamSynchronize(stream));
        CUDA_CHECK(cudaFree(d_dump_counter));
        break;
      }
      case API_Select::contains: {
        table->contains(key_num_per_op_warmup, d_keys, d_found, stream);
        CUDA_CHECK(cudaStreamSynchronize(stream));
        break;
      }
      default: {
        std::cout << "[Unsupported API]\n";
      }
    }
  }
  create_keys_for_hitrate<K, S>(h_keys, h_scores, key_num_per_op, hitrate,
                                Hit_Mode::last_insert, start, true /*reset*/);
  CUDA_CHECK(cudaMemcpy(d_keys, h_keys, key_num_per_op * sizeof(K),
                        cudaMemcpyHostToDevice));
  CUDA_CHECK(cudaMemcpy(d_scores_real, h_scores, key_num_per_op * sizeof(S),
                        cudaMemcpyHostToDevice));
  auto timer = benchmark::Timer<double>();
  global_epoch++;
  table->set_global_epoch(global_epoch);
  switch (api) {
    case API_Select::find: {
      timer.start();
      table->find(key_num_per_op, d_keys, d_vectors, d_found, d_scores, stream);
      CUDA_CHECK(cudaStreamSynchronize(stream));
      timer.end();
      break;
    }
    case API_Select::insert_or_assign: {
      timer.start();
      table->insert_or_assign(key_num_per_op, d_keys, d_vectors, d_scores,
                              stream);
      CUDA_CHECK(cudaStreamSynchronize(stream));
      timer.end();
      break;
    }
    case API_Select::find_or_insert: {
      timer.start();
      table->find_or_insert(key_num_per_op, d_keys, d_vectors, d_scores,
                            stream);
      CUDA_CHECK(cudaStreamSynchronize(stream));
      timer.end();
      break;
    }
    case API_Select::assign: {
      timer.start();
      table->assign(key_num_per_op, d_keys, d_def_val, d_scores, stream);
      CUDA_CHECK(cudaStreamSynchronize(stream));
      timer.end();
      break;
    }
    case API_Select::insert_and_evict: {
      timer.start();
      table->insert_and_evict(key_num_per_op, d_keys, d_vectors, d_scores,
                              d_evict_keys, d_def_val, d_evict_scores, stream);
      CUDA_CHECK(cudaStreamSynchronize(stream));
      timer.end();
      break;
    }
    case API_Select::find_ptr: {
      V** d_vectors_ptr = nullptr;
      CUDA_CHECK(cudaMalloc(&d_vectors_ptr, key_num_per_op * sizeof(V*)));
      benchmark::array2ptr(d_vectors_ptr, d_vectors, dim, key_num_per_op,
                           stream);

      CUDA_CHECK(cudaStreamSynchronize(stream));
      timer.start();
      table->find(key_num_per_op, d_keys, d_vectors_ptr, d_found, d_scores,
                  stream);
      CUDA_CHECK(cudaStreamSynchronize(stream));
      timer.end();
      benchmark::read_from_ptr(d_vectors_ptr, d_vectors, dim, key_num_per_op,
                               stream);
      CUDA_CHECK(cudaStreamSynchronize(stream));
      CUDA_CHECK(cudaFree(d_vectors_ptr));
      break;
    }
    case API_Select::find_or_insert_ptr: {
      V** d_vectors_ptr = nullptr;
      bool* d_found;
      CUDA_CHECK(cudaMalloc(&d_found, key_num_per_op * sizeof(bool)));
      CUDA_CHECK(cudaMalloc(&d_vectors_ptr, key_num_per_op * sizeof(V*)));
      benchmark::array2ptr(d_vectors_ptr, d_vectors, dim, key_num_per_op,
                           stream);
      CUDA_CHECK(cudaStreamSynchronize(stream));
      timer.start();
      table->find_or_insert(key_num_per_op, d_keys, d_vectors_ptr, d_found,
                            d_scores, stream);
      CUDA_CHECK(cudaStreamSynchronize(stream));
      timer.end();
      CUDA_CHECK(cudaFree(d_vectors_ptr));
      CUDA_CHECK(cudaFree(d_found));
      break;
    }
    case API_Select::export_batch: {
      size_t* d_dump_counter;

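      // `n` is sized as key_num_per_op / target_load_factor so that, at the
      // average load factor, about key_num_per_op * safe_ratio keys are
      // exported. The output buffers hold only key_num_per_op entries, so a
      // denser-than-average region can still overflow them and occasionally
      // cause an `illegal memory access` error.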
      float safe_ratio = 0.995;

      CUDA_CHECK(cudaMalloc(&d_dump_counter, sizeof(size_t)));
      CUDA_CHECK(cudaMemset(d_dump_counter, 0, sizeof(size_t)));
      timer.start();
      table->export_batch(key_num_per_op / target_load_factor * safe_ratio, 0,
                          d_dump_counter, d_keys, d_vectors, d_scores, stream);
      CUDA_CHECK(cudaStreamSynchronize(stream));
      timer.end();
      CUDA_CHECK(cudaFree(d_dump_counter));
      break;
    }
    case API_Select::export_batch_if: {
      size_t* d_dump_counter;

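      // `n` is sized as key_num_per_op / target_load_factor so that, at the
      // average load factor, about key_num_per_op * safe_ratio keys are
      // exported. The output buffers hold only key_num_per_op entries, so a
      // denser-than-average region can still overflow them and occasionally
      // cause an `illegal memory access` error.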
      float safe_ratio = 0.995;

      CUDA_CHECK(cudaMalloc(&d_dump_counter, sizeof(size_t)));
      CUDA_CHECK(cudaMemset(d_dump_counter, 0, sizeof(size_t)));
      timer.start();
      K pattern = 0;
      table->template export_batch_if<ExportIfPredFunctor>(
          pattern, threshold, key_num_per_op / target_load_factor * safe_ratio,
          0, d_dump_counter, d_keys, d_vectors, d_scores, stream);
      CUDA_CHECK(cudaStreamSynchronize(stream));
      timer.end();
      CUDA_CHECK(cudaFree(d_dump_counter));
      break;
    }
    case API_Select::contains: {
      timer.start();
      table->contains(key_num_per_op, d_keys, d_found, stream);
      CUDA_CHECK(cudaStreamSynchronize(stream));
      timer.end();
      break;
    }
    default: {
      std::cout << "[Unsupported API]\n";
    }
  }

  CUDA_CHECK(cudaStreamDestroy(stream));

  CUDA_CHECK(cudaFreeHost(h_keys));
  CUDA_CHECK(cudaFreeHost(h_scores));
  CUDA_CHECK(cudaFreeHost(h_vectors));
  CUDA_CHECK(cudaFreeHost(h_found));

  CUDA_CHECK(cudaFree(d_keys));
  CUDA_CHECK(cudaFree(d_scores_real));
  CUDA_CHECK(cudaFree(d_vectors));
  CUDA_CHECK(cudaFree(d_def_val));
  CUDA_CHECK(cudaFree(d_vectors_ptr));
  CUDA_CHECK(cudaFree(d_found));
  CUDA_CHECK(cudaFree(d_keys_out));
  CUDA_CHECK(cudaFree(d_evict_keys));
  CUDA_CHECK(cudaFree(d_evict_scores));

  CUDA_CHECK(cudaDeviceSynchronize());
  CudaCheckError();

  float througput =
      key_num_per_op / timer.getResult() / (1024 * 1024 * 1024.0f);
  return througput;
}
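
// Hypothetical helper reproducing the step-1 pre-fill arithmetic of
// test_one_api(): key_num_init = init_capacity * load_factor keys are written
// in ceil(key_num_init / key_num_per_op) batches, and only the last batch can
// be short (key_num_init % key_num_per_op keys, or a full batch when that
// remainder is 0). Writes the batch sizes and returns the batch count.
inline int plan_prefill_batches(const uint64_t init_capacity,
                                const float load_factor,
                                const uint64_t key_num_per_op,
                                uint64_t* batch_sizes) {
  const uint64_t key_num_init =
      static_cast<uint64_t>(init_capacity * load_factor);
  const uint64_t key_num_remain = key_num_init % key_num_per_op == 0
                                      ? key_num_per_op
                                      : key_num_init % key_num_per_op;
  const int loop_num_init =
      static_cast<int>((key_num_init + key_num_per_op - 1) / key_num_per_op);
  for (int i = 0; i < loop_num_init; i++) {
    batch_sizes[i] = (i == loop_num_init - 1) ? key_num_remain : key_num_per_op;
  }
  return loop_num_init;
}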

static Test_Mode test_mode = Test_Mode::pure_hbm;
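
// Hypothetical helper illustrating how test_one_api() sizes export_batch()
// and export_batch_if(). As the benchmark's own sizing implies, `n` there
// behaves as a count of table slots to scan from `offset` rather than a count
// of keys, so scanning key_num_per_op / load_factor slots is expected to yield
// about key_num_per_op keys. safe_ratio < 1 leaves headroom because the output
// buffers hold exactly key_num_per_op entries and the scanned region may be
// denser than average.
inline size_t export_scan_slots(const size_t key_num_per_op,
                                const float load_factor,
                                const float safe_ratio) {
  return static_cast<size_t>(key_num_per_op / load_factor * safe_ratio);
}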

void print_title_a() {
  cout << endl
       << "|    \u03BB " << "| insert_or_assign " << "|   find "
       << "| find_or_insert " << "| assign " << "|  find* "
       << "| find_or_insert* ";
  if (Test_Mode::pure_hbm == test_mode) {
    cout << "| insert_and_evict ";
  }
  cout << "|\n";

  //<< "| load_factor "
  cout << "|-----:"
       //<< "| insert_or_assign "
       << "|-----------------:"
       //<< "|   find "
       << "|-------:"
       //<< "| find_or_insert "
       << "|---------------:"
       //<< "| assign "
       << "|-------:"
       //<< "|   find* "
       << "|-------:"
       //<< "| find_or_insert* "
       << "|----------------:";
  if (Test_Mode::pure_hbm == test_mode) {
    //<< "| insert_and_evict "
    cout << "|-----------------:";
  }
  cout << "|\n";
}

void print_title_b() {
  cout << endl
       << "|    \u03BB " << "| export_batch " << "| export_batch_if "
       << "|  contains ";
  cout << "|\n";

  //<< "| load_factor "
  cout << "|-----:"
       //<< "| export_batch "
       << "|-------------:"
       //<< "| export_batch_if "
       << "|----------------:"
       //<< "|  contains "
       << "|----------:";
  cout << "|\n";
}

void test_main(std::vector<API_Select>& apis, const size_t dim,
               const size_t init_capacity = 64 * 1024 * 1024UL,
               const size_t key_num_per_op = 1 * 1024 * 1024UL,
               const size_t hbm4values = 16, const float load_factor = 1.0f,
               const bool io_by_cpu = false,
               const std::vector<float> load_factors = {0.50f, 0.75f, 1.00f}) {
  size_t free, total;
  CUDA_CHECK(cudaSetDevice(0));
  CUDA_CHECK(cudaMemGetInfo(&free, &total));

  if (free / (1 << 30) < hbm4values) {
    std::cout << "free HBM is not enough, ignore current benchmark!"
              << std::endl;
    return;
  }
  TableOptions options;

  options.init_capacity = init_capacity;
  options.max_capacity = init_capacity;
  options.dim = dim;
  options.max_hbm_for_vectors = nv::merlin::GB(hbm4values);
  options.io_by_cpu = io_by_cpu;
  using Table = nv::merlin::HashTable<K, V, S, EvictStrategy::kLru, Sm80>;

  std::shared_ptr<Table> table = std::make_shared<Table>();
  table->init(options);

  for (float load_factor : load_factors) {
    std::cout << "|" << rep(1) << fixed << setprecision(2) << load_factor
              << " ";

    for (auto api : apis) {
      table->clear();
      CUDA_CHECK(cudaDeviceSynchronize());
      // The load factor is sampled only after the target API has been called
      // several times. Two consecutive calls (keeping the better result) avoid
      // the impact of this sampling.
      auto res1 = test_one_api<Table>(table, api, dim, init_capacity,
                                      key_num_per_op, load_factor);
      auto res2 = test_one_api<Table>(table, api, dim, init_capacity,
                                      key_num_per_op, load_factor);
      auto res = std::max(res1, res2);
      std::cout << "|";
      switch (api) {
        case API_Select::find: {
          std::cout << rep(1);
          break;
        }
        case API_Select::insert_or_assign: {
          std::cout << rep(11);
          break;
        }
        case API_Select::find_or_insert: {
          std::cout << rep(9);
          break;
        }
        case API_Select::assign: {
          std::cout << rep(1);
          break;
        }
        case API_Select::insert_and_evict: {
          std::cout << rep(11);
          break;
        }
        case API_Select::find_ptr: {
          std::cout << rep(1);
          break;
        }
        case API_Select::find_or_insert_ptr: {
          std::cout << rep(10);
          break;
        }
        case API_Select::export_batch: {
          std::cout << rep(7);
          break;
        }
        case API_Select::export_batch_if: {
          std::cout << rep(10);
          break;
        }
        case API_Select::contains: {
          std::cout << rep(4);
          break;
        }
        default: {
          std::cout << "[Unsupported API]";
        }
      }
      std::cout << fixed << setprecision(3) << setw(6) << setfill(' ') << res
                << " ";
    }
    std::cout << "|\n";
  }
}

int main() {
  size_t key_num_per_op = 1 * 1024 * 1024UL;
  cudaDeviceProp props;
  CUDA_CHECK(cudaGetDeviceProperties(&props, 0));
  cout << endl
       << "## Benchmark" << endl
       << endl
       << "* GPU: 1 x " << props.name << ": " << props.major << "."
       << props.minor << endl
       << "* Key Type = uint64_t" << endl
       << "* Value Type = float32 * {dim}" << endl
       << "* Key-Values per OP = " << key_num_per_op << endl
       << "* Evict strategy: LRU" << endl
       << "* `\u03BB`" << ": load factor" << endl
       << "* `find*` means the `find` API that directly returns the addresses "
          "of values."
       << endl
       << "* `find_or_insert*` means the `find_or_insert` API that directly "
          "returns the addresses of values."
       << endl
       << "* ***Throughput Unit: Billion-KV/second***" << endl
       << endl;
  auto print_configuration = [](const size_t dim, const size_t init_capacity,
                                const size_t hbm4values) {
    using V = float;
    int32_t capacity = static_cast<int32_t>(init_capacity / (1024 * 1024));
    size_t hmem4values = init_capacity * dim * sizeof(V) / (1024 * 1024 * 1024);
    hmem4values = hmem4values < hbm4values ? 0 : (hmem4values - hbm4values);
    cout << "\n* dim = " << dim << ", " << "capacity = " << capacity
         << " Million-KV, " << "HBM = " << hbm4values << " GB, "
         << "HMEM = " << hmem4values << " GB\n";
  };

  try {
    {
      std::vector<API_Select> apis_a{
          API_Select::insert_or_assign, API_Select::find,
          API_Select::find_or_insert,   API_Select::assign,
          API_Select::find_ptr,         API_Select::find_or_insert_ptr,
          API_Select::insert_and_evict};

      std::vector<API_Select> apis_b{API_Select::export_batch,
                                     API_Select::export_batch_if,
                                     API_Select::contains};
      test_mode = Test_Mode::pure_hbm;

      cout << "### On pure HBM mode: " << endl;
      print_configuration(8, 128 * 1024 * 1024UL, 4);
      print_title_a();
      test_main(apis_a, 8, 128 * 1024 * 1024UL, key_num_per_op, 4);

      print_title_b();
      test_main(apis_b, 8, 128 * 1024 * 1024UL, key_num_per_op, 4);

      print_configuration(32, 128 * 1024 * 1024UL, 16);
      print_title_a();
      test_main(apis_a, 32, 128 * 1024 * 1024UL, key_num_per_op, 16);

      print_title_b();
      test_main(apis_b, 32, 128 * 1024 * 1024UL, key_num_per_op, 16);

      print_configuration(64, 64 * 1024 * 1024UL, 16);
      print_title_a();
      test_main(apis_a, 64, 64 * 1024 * 1024UL, key_num_per_op, 16);

      print_title_b();
      test_main(apis_b, 64, 64 * 1024 * 1024UL, key_num_per_op, 16);

      cout << endl;
    }

    {
      std::vector<API_Select> apis_a{
          API_Select::insert_or_assign, API_Select::find,
          API_Select::find_or_insert,   API_Select::assign,
          API_Select::find_ptr,         API_Select::find_or_insert_ptr};

      std::vector<API_Select> apis_b{API_Select::export_batch,
                                     API_Select::export_batch_if,
                                     API_Select::contains};

      cout << "### On HBM+HMEM hybrid mode: " << endl;
      test_mode = Test_Mode::hybrid;
      print_configuration(64, 128 * 1024 * 1024UL, 16);
      print_title_a();
      test_main(apis_a, 64, 128 * 1024 * 1024UL, key_num_per_op, 16);

      print_title_b();
      test_main(apis_b, 64, 128 * 1024 * 1024UL, key_num_per_op, 16);

      print_configuration(64, 512 * 1024 * 1024UL, 32);
      print_title_a();
      test_main(apis_a, 64, 512 * 1024 * 1024UL, key_num_per_op, 32);

      print_title_b();
      test_main(apis_b, 64, 512 * 1024 * 1024UL, key_num_per_op, 32);
      cout << endl;
    }

    CUDA_CHECK(cudaDeviceSynchronize());
  } catch (const nv::merlin::CudaException& e) {
    cerr << e.what() << endl;
  }
  CUDA_CHECK(cudaDeviceSynchronize());
  return 0;
}


================================================
FILE: build_deps/gpus/BUILD
================================================


================================================
FILE: build_deps/gpus/check_cuda_libs.py
================================================
# Copyright (c) 2023, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Verifies that a list of libraries is installed on the system.

Takes a list of arguments in which every two consecutive arguments form a
logical tuple of (path, check_soname): the path to the library, and either
True or False to indicate whether to check the soname field of the shared
library.

Example Usage:
./check_cuda_libs.py /path/to/lib1.so True /path/to/lib2.so False
"""
import os
import os.path
import platform
import subprocess
import sys

# pylint: disable=g-import-not-at-top,g-importing-member
try:
    from shutil import which
except ImportError:
    from distutils.spawn import find_executable as which
# pylint: enable=g-import-not-at-top,g-importing-member


class ConfigError(Exception):
    pass


def check_cuda_lib(path, check_soname=True):
    """Tests if a library exists on disk and whether its soname matches the filename.

    Args:
      path: the path to the library.
      check_soname: whether to check the soname as well.

    Raises:
      ConfigError: If the library does not exist or if its soname does not match
      the filename.
    """
    if not os.path.isfile(path):
        raise ConfigError("No library found under: " + path)
    objdump = which("objdump")
    if check_soname and objdump is not None:
        # Decode is necessary as in py3 the return type changed from str to bytes
        output = subprocess.check_output([objdump, "-p", path]).decode("utf-8")
        output = [line for line in output.splitlines() if "SONAME" in line]
        sonames = [line.strip().split(" ")[-1] for line in output]
        if not any(soname == os.path.basename(path) for soname in sonames):
            raise ConfigError("None of the libraries match their SONAME: " +
                              path)


def main():
    try:
        args = [argv for argv in sys.argv[1:]]
        if len(args) % 2 == 1:
            raise ConfigError("Expected even number of arguments")
        checked_paths = []
        for i in range(0, len(args), 2):
            path = args[i]
            check_cuda_lib(path, check_soname=args[i + 1] == "True")
            checked_paths.append(path)
        # pylint: disable=superfluous-parens
        print(os.linesep.join(checked_paths))
        # pylint: enable=superfluous-parens
    except ConfigError as e:
        sys.stderr.write(str(e))
        sys.exit(1)


if __name__ == "__main__":
    main()
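The SONAME check in `check_cuda_lib` can be exercised without a real library by feeding the same parsing logic some sample text. The sketch below factors that logic into two hypothetical helpers, `sonames_from_objdump` and `soname_matches` (they are not part of the script), and uses a made-up `objdump -p` excerpt; real objdump output may be laid out differently.

```python
import os.path


def sonames_from_objdump(objdump_output):
    """Extract SONAME values from `objdump -p` text, as check_cuda_lib does."""
    lines = [line for line in objdump_output.splitlines() if "SONAME" in line]
    # The SONAME value is the last space-separated token on the line.
    return [line.strip().split(" ")[-1] for line in lines]


def soname_matches(path, objdump_output):
    """Return True if any SONAME equals the basename of `path`."""
    return any(soname == os.path.basename(path)
               for soname in sonames_from_objdump(objdump_output))


# Hypothetical objdump excerpt for illustration only.
sample = """
Dynamic Section:
  NEEDED               libc.so.6
  SONAME               libcudart.so.11.0
"""

print(sonames_from_objdump(sample))  # ['libcudart.so.11.0']
print(soname_matches("/usr/local/cuda/lib64/libcudart.so.11.0", sample))  # True
print(soname_matches("/usr/local/cuda/lib64/libcudart.so", sample))  # False
```

This is why an unversioned symlink such as `libcudart.so` fails the check when `check_soname` is True: its basename differs from the SONAME recorded in the library.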


================================================
FILE: build_deps/gpus/configure.bzl
================================================
"""Repository rule for CUDA autoconfiguration.

`cuda_configure` depends on the following environment variables:

  * `NEED_CUDA`: Whether to enable building with CUDA.
  * `GCC_HOST_COMPILER_PATH`: The GCC host compiler path.
  * `SYSROOT`: The sysroot to use when compiling.
  * `CUDA_PATHS`: The base paths to look for CUDA and cuDNN. Default is
    `/usr/local/cuda,/usr/`.
  * `CUDA_TOOLKIT_PATH` (deprecated): The path to the CUDA toolkit. Default is
    `/usr/local/cuda`.
  * `CUDA_VERSION`: The version of the CUDA toolkit. If this is blank, then
    use the system default.
  * `CUDNN_VERSION`: The version of the cuDNN library.
  * `CUDNN_INSTALL_PATH` (deprecated): The path to the cuDNN library. Default is
    `/usr/local/cuda`.
  * `CUDA_COMPUTE_CAPABILITIES`: The CUDA compute capabilities. Default is
    `compute_35,compute_52`.
  * `PYTHON_BIN_PATH`: The Python binary path.
"""

load(
    "@bazel_tools//tools/cpp:lib_cc_configure.bzl",
    "escape_string",
    "get_env_var",
)
load(
    "//build_deps/remote_config:common.bzl",
    "config_repo_label",
    "err_out",
    "execute",
    "get_bash_bin",
    "get_cpu_value",
    "get_host_environ",
    "get_python_bin",
    "raw_exec",
    "read_dir",
    "realpath",
    "which",
)

_GCC_HOST_COMPILER_PATH = "GCC_HOST_COMPILER_PATH"
_GCC_HOST_COMPILER_PREFIX = "GCC_HOST_COMPILER_PREFIX"
_SYSROOT = "SYSROOT"
_CUDA_TOOLKIT_PATH = "CUDA_TOOLKIT_PATH"
_CUDA_VERSION = "CUDA_VERSION"
_CUDNN_VERSION = "CUDNN_VERSION"
_CUDNN_INSTALL_PATH = "CUDNN_INSTALL_PATH"
_CUDA_COMPUTE_CAPABILITIES = "CUDA_COMPUTE_CAPABILITIES"
_CUDA_CONFIG_REPO = "CUDA_CONFIG_REPO"
_PYTHON_BIN_PATH = "PYTHON_BIN_PATH"

_TENSORRT_VERSION = "TENSORRT_VERSION"
_TENSORRT_INSTALL_PATH = "TENSORRT_INSTALL_PATH"
_TENSORRT_STATIC_PATH = "TENSORRT_STATIC_PATH"
_TENSORRT_LIBS = [
    "nvinfer",
    "nvinfer_plugin",
    "nvonnxparser",
    "nvparsers",
]
_TENSORRT_HEADERS = [
    "NvInfer.h",
    "NvUtils.h",
    "NvInferPlugin.h",
]
_TENSORRT_HEADERS_V6 = [
    "NvInfer.h",
    "NvUtils.h",
    "NvInferPlugin.h",
    "NvInferVersion.h",
    "NvInferRuntime.h",
    "NvInferRuntimeCommon.h",
    "NvInferPluginUtils.h",
    "NvOnnxParser.h",
    "NvOnnxConfig.h",
]
_TENSORRT_HEADERS_V8 = [
    "NvInfer.h",
    "NvInferLegacyDims.h",
    "NvInferImpl.h",
    "NvUtils.h",
    "NvInferPlugin.h",
    "NvInferVersion.h",
    "NvInferRuntime.h",
    "NvInferRuntimeCommon.h",
    "NvInferPluginUtils.h",
    "NvOnnxParser.h",
    "NvOnnxConfig.h",
]

def _at_least_version(actual_version, required_version):
    actual = [int(v) for v in actual_version.split(".")]
    required = [int(v) for v in required_version.split(".")]
    return actual >= required

def _get_tensorrt_headers(tensorrt_version):
    if _at_least_version(tensorrt_version, "8"):
        return _TENSORRT_HEADERS_V8
    if _at_least_version(tensorrt_version, "6"):
        return _TENSORRT_HEADERS_V6
    return _TENSORRT_HEADERS

def to_list_of_strings(elements):
    """Convert the list of ["a", "b", "c"] into '"a", "b", "c"'.

    This is to be used to put a list of strings into the bzl file templates
    so it gets interpreted as list of strings in Starlark.

    Args:
      elements: list of string elements

    Returns:
      single string of elements wrapped in quotes separated by a comma."""
    quoted_strings = ["\"" + element + "\"" for element in elements]
    return ", ".join(quoted_strings)

def verify_build_defines(params):
    """Verify all variables that crosstool/BUILD.tpl expects are substituted.

    Args:
      params: dict of variables that will be passed to the BUILD.tpl template.
    """
    missing = []
    for param in [
        "cxx_builtin_include_directories",
        "extra_no_canonical_prefixes_flags",
        "host_compiler_path",
        "host_compiler_prefix",
        "host_compiler_warnings",
        "linker_bin_path",
        "compiler_deps",
        "unfiltered_compile_flags",
    ]:
        if ("%{" + param + "}") not in params:
            missing.append(param)

    if missing:
        auto_configure_fail(
            "BUILD.tpl template is missing these variables: " + str(missing) +
            ".\nWe only got: " + str(params) + ".",
        )

# TODO(dzc): Once these functions have been factored out of Bazel's
# cc_configure.bzl, load them from @bazel_tools instead.
# BEGIN cc_configure common functions.
def find_cc(repository_ctx):
    """Find the C++ compiler."""
    target_cc_name = "gcc"
    cc_path_envvar = _GCC_HOST_COMPILER_PATH
    cc_name = target_cc_name

    cc_name_from_env = get_host_environ(repository_ctx, cc_path_envvar)
    if cc_name_from_env:
        cc_name = cc_name_from_env
    if cc_name.startswith("/"):
        # Absolute path, maybe we should make this supported by our which function.
        return cc_name
    cc = which(repository_ctx, cc_name)
    if cc == None:
        fail(("Cannot find {}, either correct your path or set the {}" +
              " environment variable").format(target_cc_name, cc_path_envvar))
    return cc

_INC_DIR_MARKER_BEGIN = "#include <...>"

# OSX adds " (framework directory)" at the end of the line; strip it.
_OSX_FRAMEWORK_SUFFIX = " (framework directory)"
_OSX_FRAMEWORK_SUFFIX_LEN = len(_OSX_FRAMEWORK_SUFFIX)

def _cxx_inc_convert(path):
    """Convert a path returned by `cc -E -xc++` into a complete path."""
    path = path.strip()
    if path.endswith(_OSX_FRAMEWORK_SUFFIX):
        path = path[:-_OSX_FRAMEWORK_SUFFIX_LEN].strip()
    return path

def _normalize_include_path(repository_ctx, path):
    """Normalizes include paths before writing them to the crosstool.

      If path points inside the 'crosstool' folder of the repository, a relative
      path is returned.
      If path points outside the 'crosstool' folder, an absolute path is returned.
      """
    path = str(repository_ctx.path(path))
    crosstool_folder = str(repository_ctx.path(".").get_child("crosstool"))

    if path.startswith(crosstool_folder):
        # We drop the path to "$REPO/crosstool" and a trailing path separator.
        return path[len(crosstool_folder) + 1:]
    return path

def _is_compiler_option_supported(repository_ctx, cc, option):
    """Checks that `option` is supported by the C compiler. Doesn't %-escape the option."""
    result = repository_ctx.execute([
        cc,
        option,
        "-o",
        "/dev/null",
        "-c",
        str(repository_ctx.path("tools/cpp/empty.cc")),
    ])
    return result.stderr.find(option) == -1

def _get_cxx_inc_directories_impl(repository_ctx, cc, lang_is_cpp, tf_sysroot):
    """Compute the list of default C or C++ include directories."""
    if lang_is_cpp:
        lang = "c++"
    else:
        lang = "c"
    sysroot = []
    if tf_sysroot:
        sysroot += ["--sysroot", tf_sysroot]
    result = raw_exec(
        repository_ctx,
        [cc, "-E", "-x" + lang, "-", "-v"] + sysroot,
    )
    stderr = err_out(result)
    index1 = stderr.find(_INC_DIR_MARKER_BEGIN)
    if index1 == -1:
        return []
    index1 = stderr.find("\n", index1)
    if index1 == -1:
        return []
    index2 = stderr.rfind("\n ")
    if index2 == -1 or index2 < index1:
        return []
    index2 = stderr.find("\n", index2 + 1)
    if index2 == -1:
        inc_dirs = stderr[index1 + 1:]
    else:
        inc_dirs = stderr[index1 + 1:index2].strip()

    print_resource_dir_supported = _is_compiler_option_supported(
        repository_ctx,
        cc,
        "-print-resource-dir",
    )

    if print_resource_dir_supported:
        resource_dir = repository_ctx.execute(
            [cc, "-print-resource-dir"],
        ).stdout.strip() + "/share"
        inc_dirs += "\n" + resource_dir

    return [
        _normalize_include_path(repository_ctx, _cxx_inc_convert(p))
        for p in inc_dirs.split("\n")
    ]

def get_cxx_inc_directories(repository_ctx, cc, tf_sysroot):
    """Compute the list of default C and C++ include directories."""

    includes_cpp = _get_cxx_inc_directories_impl(
        repository_ctx,
        cc,
        True,
        tf_sysroot,
    )
    includes_c = _get_cxx_inc_directories_impl(
        repository_ctx,
        cc,
        False,
        tf_sysroot,
    )

    return includes_cpp + [
        inc
        for inc in includes_c
        if inc not in includes_cpp
    ]

def auto_configure_fail(msg):
    """Output failure message when cuda configuration fails."""
    red = "\033[0;31m"
    no_color = "\033[0m"
    fail("\n%sCuda Configuration Error:%s %s\n" % (red, no_color, msg))

# END cc_configure common functions (see TODO above).

def _cuda_include_path(repository_ctx, cuda_config):
    """Generates the list of CUDA include directories.

      Args:
        repository_ctx: The repository context.
        cuda_config: The CUDA config as returned by _get_cuda_config.

      Returns:
        A list of the CUDA include directories.
      """
    nvcc_path = repository_ctx.path("%s/bin/nvcc%s" % (
        cuda_config.cuda_toolkit_path,
        ".exe" if cuda_config.cpu_value == "Windows" else "",
    ))

    # The expected exit code of nvcc here is non-zero, but Bazel remote
    # execution only caches commands with a zero exit code, so force one.
    cmd = "%s -v /dev/null -o /dev/null ; [ $? -eq 1 ]" % str(nvcc_path)
    result = raw_exec(
        repository_ctx,
        [get_bash_bin(repository_ctx), "-c", cmd],
    )
    target_dir = ""
    for one_line in err_out(result).splitlines():
        if one_line.startswith("#$ _TARGET_DIR_="):
            target_dir = (cuda_config.cuda_toolkit_path + "/" +
                          one_line.replace(
                              "#$ _TARGET_DIR_=",
                              "",
                          ) + "/include")
    inc_entries = []
    if target_dir != "":
        inc_entries.append(realpath(repository_ctx, target_dir))
    inc_entries.append(
        realpath(repository_ctx, cuda_config.cuda_toolkit_path + "/include"),
    )
    return inc_entries

def matches_version(environ_version, detected_version):
    """Checks whether the user-specified version matches the detected version.

      This function performs a weak matching so that if the user specifies only
      the major or major and minor versions, the versions are still considered
      matching if the version parts match. To illustrate:

          environ_version  detected_version  result
          -----------------------------------------
          5.1.3            5.1.3             True
          5.1              5.1.3             True
          5                5.1               True
          5.1.3            5.1               False
          5.2.3            5.1.3             False

      Args:
        environ_version: The version specified by the user via environment
          variables.
        detected_version: The version autodetected from the CUDA installation on
          the system.
      Returns: True if user-specified version matches detected version and False
        otherwise.
    """
    environ_version_parts = environ_version.split(".")
    detected_version_parts = detected_version.split(".")
    if len(detected_version_parts) < len(environ_version_parts):
        return False
    for i, part in enumerate(detected_version_parts):
        if i >= len(environ_version_parts):
            break
        if part != environ_version_parts[i]:
            return False
    return True
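The weak matching rule in `matches_version` amounts to a prefix comparison on the dot-separated parts. The sketch below restates it that way as a hypothetical helper, `matches_version_prefix` (not part of this file), and checks it against the rows of the table in the docstring above.

```python
def matches_version_prefix(environ_version, detected_version):
    """Prefix-based restatement of matches_version (hypothetical helper)."""
    environ_parts = environ_version.split(".")
    detected_parts = detected_version.split(".")
    # Compare only as many leading parts as the user specified. If the
    # detected version has fewer parts, the slice is shorter and the lists
    # can never be equal, which reproduces the early `return False`.
    return detected_parts[:len(environ_parts)] == environ_parts


# Rows copied from the matches_version docstring.
TABLE = [
    ("5.1.3", "5.1.3", True),
    ("5.1", "5.1.3", True),
    ("5", "5.1", True),
    ("5.1.3", "5.1", False),
    ("5.2.3", "5.1.3", False),
]
for environ, detected, expected in TABLE:
    assert matches_version_prefix(environ, detected) == expected
```

Note that parts are compared as strings, so `"11.0"` does not match a detected `"11.00"`; the original loop behaves the same way.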

_NVCC_VERSION_PREFIX = "Cuda compilation tools, release "

_DEFINE_CUDNN_MAJOR = "#define CUDNN_MAJOR"

def compute_capabilities(repository_ctx):
    """Returns a list of strings representing cuda compute capabilities.

    Args:
      repository_ctx: the repo rule's context.
    Returns: list of cuda architectures to compile for. 'compute_xy' refers to
      both PTX and SASS, 'sm_xy' refers to SASS only.
    """
    capabilities = get_host_environ(
        repository_ctx,
        _CUDA_COMPUTE_CAPABILITIES,
        "compute_35,compute_52",
    ).split(",")

    # Map old 'x.y' capabilities to 'compute_xy'.
    if len(capabilities) > 0 and all(
        [len(x.split(".")) == 2 for x in capabilities],
    ):
        # If all capabilities are in 'x.y' format, only include PTX for the
        # highest capability.
        cc_list = sorted([x.replace(".", "") for x in capabilities])
        capabilities = [
            "sm_%s" % x
            for x in cc_list[:-1]
        ] + ["compute_%s" % cc_list[-1]]
    for i, capability in enumerate(capabilities):
        parts = capability.split(".")
        if len(parts) != 2:
            continue
        capabilities[i] = "compute_%s%s" % (parts[0], parts[1])

    # Make list unique
    capabilities = dict(zip(capabilities, capabilities)).keys()

    # Validate capabilities.
    for capability in capabilities:
        if not capability.startswith(("compute_", "sm_")):
            auto_configure_fail("Invalid compute capability: %s" % capability)
        for prefix in ["compute_", "sm_"]:
            if not capability.startswith(prefix):
                continue
            if len(capability) == len(prefix) + 2 and capability[-2:].isdigit(
            ):
                continue
            auto_configure_fail("Invalid compute capability: %s" % capability)

    return capabilities
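The normalization in `compute_capabilities` can be easier to follow with concrete inputs. The sketch below is a Python port of only the mapping and de-duplication steps; it skips the environment-variable lookup and the validation loop, and `normalize_capabilities` is a hypothetical name.

```python
def normalize_capabilities(capabilities):
    """Map 'x.y' entries the way compute_capabilities does (sketch)."""
    # If every entry is 'x.y', emit SASS-only 'sm_xy' for all but the highest
    # capability and PTX+SASS 'compute_xy' for the highest one.
    if capabilities and all(len(c.split(".")) == 2 for c in capabilities):
        cc_list = sorted(c.replace(".", "") for c in capabilities)
        capabilities = (["sm_%s" % c for c in cc_list[:-1]] +
                        ["compute_%s" % cc_list[-1]])
    # In a mixed list, any remaining 'x.y' entry becomes 'compute_xy'.
    capabilities = [
        "compute_%s%s" % tuple(c.split(".")) if len(c.split(".")) == 2 else c
        for c in capabilities
    ]
    # Deduplicate while preserving insertion order.
    return list(dict.fromkeys(capabilities))


print(normalize_capabilities(["7.0", "8.0", "7.5"]))
# ['sm_70', 'sm_75', 'compute_80']
print(normalize_capabilities(["compute_80", "8.6"]))
# ['compute_80', 'compute_86']
```

Embedding PTX only for the highest capability keeps binaries smaller while still allowing JIT compilation on newer GPUs.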

def lib_name(base_name, cpu_value, version = None, static = False):
    """Constructs the platform-specific name of a library.

      Args:
        base_name: The name of the library, such as "cudart"
        cpu_value: The name of the host operating system.
        version: The version of the library.
        static: True if the library is static or False if it is a shared object.

      Returns:
        The platform-specific name of the library.
      """
    version = "" if not version else "." + version
    if cpu_value in ("Linux", "FreeBSD"):
        if static:
            return "lib%s.a" % base_name
        return "lib%s.so%s" % (base_name, version)
    elif cpu_value == "Windows":
        return "%s.lib" % base_name
    elif cpu_value == "Darwin":
        if static:
            return "lib%s.a" % base_name
        return "lib%s%s.dylib" % (base_name, version)
    else:
        auto_configure_fail("Invalid cpu_value: %s" % cpu_value)

def _lib_path(lib, cpu_value, basedir, version, static):
    file_name = lib_name(lib, cpu_value, version, static)
    return "%s/%s" % (basedir, file_name)

def _should_check_soname(version, static):
    return version and not static

def _check_cuda_lib_params(lib, cpu_value, basedir, version, static = False):
    return (
        _lib_path(lib, cpu_value, basedir, version, static),
        _should_check_soname(version, static),
    )

def _check_cuda_libs(repository_ctx, script_path, libs):
    python_bin = get_python_bin(repository_ctx)
    contents = repository_ctx.read(script_path).splitlines()

    cmd = "from os import linesep;"
    cmd += "f = open('script.py', 'w');"
    for line in contents:
        cmd += "f.write('%s' + linesep);" % line
    cmd += "f.close();"
    cmd += "from os import system;"
    args = " ".join(["\"" + path + "\" " + str(check) for path, check in libs])
    cmd += "system('%s script.py %s');" % (python_bin, args)

    all_paths = [path for path, _ in libs]
    checked_paths = execute(
        repository_ctx,
        [python_bin, "-c", cmd],
    ).stdout.splitlines()

    # Filter out empty lines from splitting on '\r\n' on Windows
    checked_paths = [path for path in checked_paths if len(path) > 0]
    if all_paths != checked_paths:
        auto_configure_fail(
            "Error with installed CUDA libs. Expected '%s'. Actual '%s'." %
            (all_paths, checked_paths),
        )
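`_check_cuda_libs` hands each `(path, check)` pair to `check_cuda_libs.py` as a quoted path followed by `str(check)`, and the script's `main()` reads the arguments back in pairs, treating only the literal `True` as "check the soname". The sketch below replays that round trip; `build_args` and `parse_args` are hypothetical helpers, the library paths are made up, and `shlex.split` stands in for the shell that `os.system` invokes.

```python
import shlex


def build_args(libs):
    # Mirrors the string built in _check_cuda_libs: each path is quoted and
    # followed by the str() of its check flag.
    return " ".join('"' + path + '" ' + str(check) for path, check in libs)


def parse_args(argv):
    # Mirrors main() in check_cuda_libs.py: consume (path, flag) pairs.
    if len(argv) % 2 == 1:
        raise ValueError("Expected even number of arguments")
    return [(argv[i], argv[i + 1] == "True") for i in range(0, len(argv), 2)]


# Flags follow _should_check_soname: `version and not static`.
libs = [
    ("/usr/local/cuda/lib64/stubs/libcuda.so", None),  # version=None
    ("/usr/local/cuda/lib64/libcudart.so.11.0", True),  # versioned, shared
    ("/usr/local/cuda/lib64/libcudart_static.a", False),  # static
]
arg_string = build_args(libs)
print(arg_string)
print(parse_args(shlex.split(arg_string)))
```

Because `None` is stringified as `None`, the stub library is passed through without a soname check.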

def _find_libs(repository_ctx, check_cuda_libs_script, cuda_config):
    """Returns the CUDA and cuDNN libraries on the system.

      Also verifies that the libraries actually exist.

      Args:
        repository_ctx: The repository context.
        check_cuda_libs_script: The path to a script verifying that the cuda
          libraries exist on the system.
        cuda_config: The CUDA config as returned by _get_cuda_config

      Returns:
        Map of library names to library paths.
      """
    cpu_value = cuda_config.cpu_value
    stub_dir = "/stubs"

    check_cuda_libs_params = {
        "cuda": _check_cuda_lib_params(
            "cuda",
            cpu_value,
            cuda_config.config["cuda_library_dir"] + stub_dir,
            version = None,
            static = False,
        ),
        "cudart": _check_cuda_lib_params(
            "cudart",
            cpu_value,
            cuda_config.config["cuda_library_dir"],
            cuda_config.cudart_version,
            static = False,
        ),
        "cudart_static": _check_cuda_lib_params(
            "cudart_static",
            cpu_value,
            cuda_config.config["cuda_library_dir"],
            cuda_config.cudart_version,
            static = True,
        ),
        "cublas": _check_cuda_lib_params(
            "cublas",
            cpu_value,
            cuda_config.config["cublas_library_dir"],
            cuda_config.cublas_version,
            static = False,
        ),
        "cublasLt": _check_cuda_lib_params(
            "cublasLt",
            cpu_value,
            cuda_config.config["cublas_library_dir"],
            cuda_config.cublas_version,
            static = False,
        ),
        "cusolver": _check_cuda_lib_params(
            "cusolver",
            cpu_value,
            cuda_config.config["cusolver_library_dir"],
            cuda_config.cusolver_version,
            static = False,
        ),
        "curand": _check_cuda_lib_params(
            "curand",
            cpu_value,
            cuda_config.config["curand_library_dir"],
            cuda_config.curand_version,
            static = False,
        ),
        "cufft": _check_cuda_lib_params(
            "cufft",
            cpu_value,
            cuda_config.config["cufft_library_dir"],
            cuda_config.cufft_version,
            static = False,
        ),
        "cudnn": _check_cuda_lib_params(
            "cudnn",
            cpu_value,
            cuda_config.config["cudnn_library_dir"],
            cuda_config.cudnn_version,
            static = False,
        ),
        "cupti": _check_cuda_lib_params(
            "cupti",
            cpu_value,
            cuda_config.config["cupti_library_dir"],
            cuda_config.cupti_version,
            static = False,
        ),
        "cusparse": _check_cuda_lib_params(
            "cusparse",
            cpu_value,
            cuda_config.config["cusparse_library_dir"],
            cuda_config.cusparse_version,
            static = False,
        ),
    }

    # Verify that the libs actually exist at their locations.
    _check_cuda_libs(
        repository_ctx,
        check_cuda_libs_script,
        check_cuda_libs_params.values(),
    )

    paths = {
        filename: v[0]
        for (filename, v) in check_cuda_libs_params.items()
    }
    return paths

def _cudart_static_linkopt(cpu_value):
    """Returns additional platform-specific linkopts for cudart."""
    return "" if cpu_value == "Darwin" else "\"-lrt\","

def _exec_find_cuda_config(repository_ctx, script_path, cuda_libraries):
    python_bin = get_python_bin(repository_ctx)
    cmd = "from os import system;" + "system('\"%s\" %s %s');" % (
        python_bin,
        script_path,
        " ".join(cuda_libraries),
    )
    return execute(repository_ctx, [python_bin, "-c", cmd])

# TODO(csigg): Only call once instead of from here, tensorrt_configure.bzl,
# and nccl_configure.bzl.
def find_cuda_config(repository_ctx, script_path, cuda_libraries):
    """Returns CUDA config dictionary from running find_cuda_config.py"""
    exec_result = _exec_find_cuda_config(
        repository_ctx,
        script_path,
        cuda_libraries,
    )

    if exec_result.return_code:
        auto_configure_fail("Failed to run find_cuda_config.py: %s" %
                            err_out(exec_result))

    # Parse the dict from stdout.
    return dict(
        [tuple(x.split(": ")) for x in exec_result.stdout.splitlines()],
    )
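`find_cuda_config` expects `find_cuda_config.py` to print exactly one `key: value` pair per line. The sketch below replays the parsing expression in plain Python on made-up sample output; `parse_find_cuda_config_output` is a hypothetical name, and the error behavior shown is Python's `dict()`, which is assumed (not verified here) to match Starlark's.

```python
def parse_find_cuda_config_output(stdout):
    """Python replay of the dict(...) expression in find_cuda_config."""
    # Each line must split into exactly two parts on ": " to form one entry.
    return dict([tuple(x.split(": ")) for x in stdout.splitlines()])


# Hypothetical script output for illustration only.
sample = "\n".join([
    "cuda_version: 11.8",
    "cuda_toolkit_path: /usr/local/cuda",
    "cudnn_version: 8",
])
config = parse_find_cuda_config_output(sample)
print(config["cuda_version"])  # 11.8

# A line without ": " yields a 1-tuple, which dict() rejects.
try:
    parse_find_cuda_config_output("cuda_version 11.8")
    rejected = False
except ValueError:
    rejected = True
print(rejected)  # True
```

A value that itself contains `": "` would also produce a 3-tuple and be rejected, so the script's output format is load-bearing.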

def _get_cuda_config(repository_ctx, find_cuda_config_script):
    """Detects and returns information about the CUDA installation on the system.

      Args:
        repository_ctx: The repository context.
        find_cuda_config_script: The path to the find_cuda_config.py script.

      Returns:
        A struct containing the following fields:
          cuda_toolkit_path: The CUDA toolkit installation directory.
          cudnn_install_basedir: The cuDNN installation directory.
          cuda_version: The version of CUDA on the system.
          cudart_version: The CUDA runtime version on the system.
          cudnn_version: The version of cuDNN on the system.
          compute_capabilities: A list of the system's CUDA compute capabilities.
          cpu_value: The name of the host operating system.
      """
    config = find_cuda_config(
        repository_ctx,
        find_cuda_config_script,
        ["cuda", "cudnn"],
    )

    cpu_value = get_cpu_value(repository_ctx)
    toolkit_path = config["cuda_toolkit_path"]

    cuda_version = config["cuda_version"].split(".")
    cuda_major = cuda_version[0]
    cuda_minor = cuda_version[1]

    cuda_version = "%s.%s" % (cuda_major, cuda_minor)
    cudnn_version = "%s" % config["cudnn_version"]

    if int(cuda_major) >= 11:
        # The libcudart soname in CUDA 11.x is versioned as 11.0 for backward compatibility.
        if int(cuda_major) == 11:
            cudart_version = "11.0"
            cupti_version = cuda_version
        else:
            cudart_version = ("%s") % cuda_major
            cupti_version = cudart_version
        cublas_version = ("%s") % config["cublas_version"].split(".")[0]
        cusolver_version = ("%s") % config["cusolver_version"].split(".")[0]
        curand_version = ("%s") % config["curand_version"].split(".")[0]
        cufft_version = ("%s") % config["cufft_version"].split(".")[0]
        cusparse_version = ("%s") % config["cusparse_version"].split(".")[0]
    elif (int(cuda_major), int(cuda_minor)) >= (10, 1):
        # cuda_lib_version is for libraries like cuBLAS, cuFFT, cuSOLVER, etc.
        # It changed from 'x.y' to just 'x' in CUDA 10.1.
        cuda_lib_version = ("%s") % cuda_major
        cudart_version = cuda_version
        cupti_version = cuda_version
        cublas_version = cuda_lib_version
        cusolver_version = cuda_lib_version
        curand_version = cuda_lib_version
        cufft_version = cuda_lib_version
        cusparse_version = cuda_lib_version
    else:
        cudart_version = cuda_version
        cupti_version = cuda_version
        cublas_version = cuda_version
        cusolver_version = cuda_version
        curand_version = cuda_version
        cufft_version = cuda_version
        cusparse_version = cuda_version

    return struct(
        cuda_toolkit_path = toolkit_path,
        cuda_version = cuda_version,
        cupti_version = cupti_version,
        cuda_version_major = cuda_major,
        cudart_version = cudart_version,
        cublas_version = cublas_version,
        cusolver_version = cusolver_version,
        curand_version = curand_version,
        cufft_version = cufft_version,
        cusparse_version = cusparse_version,
        cudnn_version = cudnn_version,
        compute_capabilities = compute_capabilities(repository_ctx),
        cpu_value = cpu_value,
        config = config,
    )

def _tpl(repository_ctx, tpl, substitutions = {}, out = None):
    if not out:
        out = tpl.replace(":", "/")
    repository_ctx.template(
        out,
        Label("//build_deps/gpus/%s.tpl" % tpl),
        substitutions,
    )

def _file(repository_ctx, label):
    repository_ctx.template(
        label.replace(":", "/"),
        Label("//build_deps/gpus/%s.tpl" % label),
        {},
    )

_DUMMY_CROSSTOOL_BZL_FILE = """
def error_gpu_disabled():
  fail("ERROR: Building with --config=cuda but TensorFlow is not configured " +
       "to build with GPU support. Please re-run ./configure and enter 'Y' " +
       "at the prompt to build with GPU support.")

  native.genrule(
      name = "error_gen_crosstool",
      outs = ["CROSSTOOL"],
      cmd = "echo 'Should not be run.' && exit 1",
  )

  native.filegroup(
      name = "crosstool",
      srcs = [":CROSSTOOL"],
      output_licenses = ["unencumbered"],
  )
"""

_DUMMY_CROSSTOOL_BUILD_FILE = """
load("//crosstool:error_gpu_disabled.bzl", "error_gpu_disabled")

error_gpu_disabled()
"""

def _norm_path(path):
    """Returns a path with '/' separators and without a trailing slash."""
    path = path.replace("\\", "/")
    if path.endswith("/"):
        path = path[:-1]
    return path

def make_copy_files_rule(repository_ctx, name, srcs, outs):
    """Returns the text of a genrule that copies each file in `srcs` to the matching entry in `outs`."""
    cmds = []

    # Copy files.
    for src, out in zip(srcs, outs):
        cmds.append('cp -f "%s" "$(location %s)"' % (src, out))
    outs = [('        "%s",' % out) for out in outs]
    return """genrule(
    name = "%s",
    outs = [
%s
    ],
    cmd = \"""%s \""",
)""" % (name, "\n".join(outs), " && \\\n".join(cmds))

def make_copy_dir_rule(
        repository_ctx,
        name,
        src_dir,
        out_dir,
        exceptions = None):
    """Returns a rule to recursively copy a directory.
    If exceptions is not None, it must be a list of files or directories in
    'src_dir'; these will be excluded from copying.
    """
    src_dir = _norm_path(src_dir)
    out_dir = _norm_path(out_dir)
    outs = read_dir(repository_ctx, src_dir)
    post_cmd = ""
    if exceptions != None:
        outs = [
            x
            for x in outs
            if not any([x.startswith(src_dir + "/" + y) for y in exceptions])
        ]
    outs = [('        "%s",' % out.replace(src_dir, out_dir)) for out in outs]

    # '@D' already contains the relative path for a single file, see
    # http://docs.bazel.build/versions/master/be/make-variables.html#predefined_genrule_variables
    out_dir = "$(@D)/%s" % out_dir if len(outs) > 1 else "$(@D)"
    if exceptions != None:
        for x in exceptions:
            post_cmd += " ; rm -fR " + out_dir + "/" + x
    return """genrule(
    name = "%s",
    outs = [
%s
    ],
    cmd = \"""cp -rLf "%s/." "%s/" %s\""",
)""" % (name, "\n".join(outs), src_dir, out_dir, post_cmd)

def _flag_enabled(repository_ctx, flag_name):
    return get_host_environ(repository_ctx, flag_name) == "1"

def _tf_sysroot(repository_ctx):
    return get_host_environ(repository_ctx, _SYSROOT, "")

def _compute_cuda_extra_copts(repository_ctx, compute_capabilities):
    copts = []
    for capability in compute_capabilities:
        if capability.startswith("compute_"):
            capability = capability.replace("compute_", "sm_")
            copts.append("--cuda-include-ptx=%s" % capability)
        copts.append("--cuda-gpu-arch=%s" % capability)

    return str(copts)

def _tpl_path(repository_ctx, filename):
    return repository_ctx.path(Label("//build_deps/gpus/%s.tpl" % filename))

def _basename(repository_ctx, path_str):
    """Returns the basename of a path of type string."""

    num_chars = len(path_str)
    for i in range(num_chars):
        r_i = num_chars - 1 - i
        if path_str[r_i] == "/":
            return path_str[r_i + 1:]
    return path_str

def _create_local_cuda_repository(repository_ctx):
    """Creates the repository containing files set up to build with CUDA."""
    tpl_paths = {
        filename: _tpl_path(repository_ctx, filename)
        for filename in [
            "cuda:build_defs.bzl",
            "crosstool:crosstool_compiler_wrapper",
            "crosstool:BUILD",
            "crosstool:cc_toolchain_config.bzl",
            "cuda:cuda_config.h",
            "cuda:cuda_config.py",
        ]
    }
    tpl_paths["cuda:BUILD"] = _tpl_path(repository_ctx, "cuda:BUILD")
    find_cuda_config_script = repository_ctx.path(
        Label("//build_deps/gpus:find_cuda_config.py"),
    )

    cuda_config = _get_cuda_config(repository_ctx, find_cuda_config_script)

    cuda_include_path = cuda_config.config["cuda_include_dir"]
    cublas_include_path = cuda_config.config["cublas_include_dir"]
    cudnn_header_dir = cuda_config.config["cudnn_include_dir"]
    cupti_header_dir = cuda_config.config["cupti_include_dir"]
    nvvm_libdevice_dir = cuda_config.config["nvvm_library_dir"]

    # Create genrule to copy files from the installed CUDA toolkit into execroot.
    copy_rules = [
        make_copy_dir_rule(
            repository_ctx,
            name = "cuda-include",
            src_dir = cuda_include_path,
            out_dir = "cuda/include",
        ),
        make_copy_dir_rule(
            repository_ctx,
            name = "cuda-nvvm",
            src_dir = nvvm_libdevice_dir,
            out_dir = "cuda/nvvm/libdevice",
        ),
        make_copy_dir_rule(
            repository_ctx,
            name = "cuda-extras",
            src_dir = cupti_header_dir,
            out_dir = "cuda/extras/CUPTI/include",
        ),
    ]

    copy_rules.append(
        make_copy_files_rule(
            repository_ctx,
            name = "cublas-include",
            srcs = [
                cublas_include_path + "/cublas.h",
                cublas_include_path + "/cublas_v2.h",
                cublas_include_path + "/cublas_api.h",
                cublas_include_path + "/cublasLt.h",
            ],
            outs = [
                "cublas/include/cublas.h",
                "cublas/include/cublas_v2.h",
                "cublas/include/cublas_api.h",
                "cublas/include/cublasLt.h",
            ],
        ),
    )

    cusolver_include_path = cuda_config.config["cusolver_include_dir"]
    copy_rules.append(
        make_copy_files_rule(
            repository_ctx,
            name = "cusolver-include",
            srcs = [
                cusolver_include_path + "/cusolver_common.h",
                cusolver_include_path + "/cusolverDn.h",
            ],
            outs = [
                "cusolver/include/cusolver_common.h",
                "cusolver/include/cusolverDn.h",
            ],
        ),
    )

    cufft_include_path = cuda_config.config["cufft_include_dir"]
    copy_rules.append(
        make_copy_files_rule(
            repository_ctx,
            name = "cufft-include",
            srcs = [
                cufft_include_path + "/cufft.h",
            ],
            outs = [
                "cufft/include/cufft.h",
            ],
        ),
    )

    cusparse_include_path = cuda_config.config["cusparse_include_dir"]
    copy_rules.append(
        make_copy_files_rule(
            repository_ctx,
            name = "cusparse-include",
            srcs = [
                cusparse_include_path + "/cusparse.h",
            ],
            outs = [
                "cusparse/include/cusparse.h",
            ],
        ),
    )

    curand_include_path = cuda_config.config["curand_include_dir"]
    copy_rules.append(
        make_copy_files_rule(
            repository_ctx,
            name = "curand-include",
            srcs = [
                curand_include_path + "/curand.h",
            ],
            outs = [
                "curand/include/curand.h",
            ],
        ),
    )

    check_cuda_libs_script = repository_ctx.path(
        Label("//build_deps/gpus:check_cuda_libs.py"),
    )
    cuda_libs = _find_libs(repository_ctx, check_cuda_libs_script, cuda_config)
    cuda_lib_srcs = []
    cuda_lib_outs = []
    for path in cuda_libs.values():
        cuda_lib_srcs.append(path)
        cuda_lib_outs.append("cuda/lib/" + _basename(repository_ctx, path))
    copy_rules.append(
        make_copy_files_rule(
            repository_ctx,
            name = "cuda-lib",
            srcs = cuda_lib_srcs,
            outs = cuda_lib_outs,
        ),
    )

    file_ext = ""
    bin_files = (
        ["crt/link.stub"] +
        [f + file_ext for f in ["bin2c", "fatbinary", "nvlink", "nvprune"]]
    )
    copy_rules.append(
        make_copy_files_rule(
            repository_ctx,
            name = "cuda-bin",
            srcs = [
                cuda_config.cuda_toolkit_path + "/bin/" + f
                for f in bin_files
            ],
            outs = ["cuda/bin/" + f for f in bin_files],
        ),
    )

    # Select the headers based on the cuDNN version (strip '64_' for Windows).
    cudnn_headers = ["cudnn.h"]
    if cuda_config.cudnn_version.rsplit("_", 1)[-1] >= "8":
        cudnn_headers += [
            "cudnn_backend.h",
            "cudnn_adv_infer.h",
            "cudnn_adv_train.h",
            "cudnn_cnn_infer.h",
            "cudnn_cnn_train.h",
            "cudnn_ops_infer.h",
            "cudnn_ops_train.h",
            "cudnn_version.h",
        ]

    cudnn_srcs = []
    cudnn_outs = []
    for header in cudnn_headers:
        cudnn_srcs.append(cudnn_header_dir + "/" + header)
        cudnn_outs.append("cudnn/include/" + header)

    copy_rules.append(
        make_copy_files_rule(
            repository_ctx,
            name = "cudnn-include",
            srcs = cudnn_srcs,
            outs = cudnn_outs,
        ),
    )

    # Set up BUILD file for cuda/
    repository_ctx.template(
        "cuda/build_defs.bzl",
        tpl_paths["cuda:build_defs.bzl"],
        {
            "%{cuda_is_configured}": "True",
            "%{cuda_extra_copts}": _compute_cuda_extra_copts(
                repository_ctx,
                cuda_config.compute_capabilities,
            ),
            "%{cuda_gpu_architectures}": str(cuda_config.compute_capabilities),
        },
    )

    cub_actual = "@cub_archive//:cub"
    if int(cuda_config.cuda_version_major) >= 11:
        cub_actual = ":cuda_headers"

    repository_ctx.template(
        "cuda/BUILD",
        tpl_paths["cuda:BUILD"],
        {
            "%{cuda_driver_lib}": _basename(repository_ctx, cuda_libs["cuda"]),
            "%{cudart_static_lib}": _basename(repository_ctx, cuda_libs["cudart_static"]),
            "%{cudart_static_linkopt}": _cudart_static_linkopt(cuda_config.cpu_value),
            "%{cudart_lib}": _basename(repository_ctx, cuda_libs["cudart"]),
            "%{cublas_lib}": _basename(repository_ctx, cuda_libs["cublas"]),
            "%{cublasLt_lib}": _basename(repository_ctx, cuda_libs["cublasLt"]),
            "%{cusolver_lib}": _basename(repository_ctx, cuda_libs["cusolver"]),
            "%{cudnn_lib}": _basename(repository_ctx, cuda_libs["cudnn"]),
            "%{cufft_lib}": _basename(repository_ctx, cuda_libs["cufft"]),
            "%{curand_lib}": _basename(repository_ctx, cuda_libs["curand"]),
            "%{cupti_lib}": _basename(repository_ctx, cuda_libs["cupti"]),
            "%{cusparse_lib}": _basename(repository_ctx, cuda_libs["cusparse"]),
            "%{cub_actual}": cub_actual,
            "%{copy_rules}": "\n".join(copy_rules),
        },
    )

    tf_sysroot = _tf_sysroot(repository_ctx)

    # Set up crosstool/
    cc = find_cc(repository_ctx)
    cc_fullpath = cc

    host_compiler_includes = get_cxx_inc_directories(
        repository_ctx,
        cc_fullpath,
        tf_sysroot,
    )
    cuda_defines = {}
    cuda_defines["%{builtin_sysroot}"] = tf_sysroot
    cuda_defines["%{cuda_toolkit_path}"] = ""
    cuda_defines["%{compiler}"] = "unknown"

    host_compiler_prefix = get_host_environ(
        repository_ctx,
        _GCC_HOST_COMPILER_PREFIX,
    )
    if not host_compiler_prefix:
        host_compiler_prefix = "/usr/bin"

    cuda_defines["%{host_compiler_prefix}"] = host_compiler_prefix
    cuda_defines["%{linker_bin_path}"] = host_compiler_prefix
    cuda_defines["%{extra_no_canonical_prefixes_flags}"] = ""
    cuda_defines["%{unfiltered_compile_flags}"] = ""

    cuda_defines["%{host_compiler_path}"] = "crosstool_compiler_wrapper"
    cuda_defines["%{host_compiler_warnings}"] = ""

    # nvcc has the system include paths built in and will automatically
    # search them; we cannot work around that, so we add the relevant CUDA
    # system paths to the allowed compiler-specific include paths.
    cuda_defines["%{cxx_builtin_include_directories}"] = to_list_of_strings(
        host_compiler_includes + _cuda_include_path(
            repository_ctx,
            cuda_config,
        ) + [cupti_header_dir, cudnn_header_dir],
    )

    # For gcc, do not canonicalize system header paths; some versions of gcc
    # pick the shortest possible path for system includes when creating the
    # .d file. Because includes prefixed with "../" multiple times quickly
    # climb above the root of the tree, this can cause bazel's header check
    # to fail.
    cuda_defines["%{extra_no_canonical_prefixes_flags}"] = "\"-fno-canonical-system-headers\""

    file_ext = ""
    nvcc_path = "%s/nvcc%s" % (cuda_config.config["cuda_binary_dir"], file_ext)
    cuda_defines["%{compiler_deps}"] = ":crosstool_compiler"

    wrapper_defines = {
        "%{cpu_compiler}": str(cc),
        "%{cuda_version}": cuda_config.cuda_version,
        "%{nvcc_path}": nvcc_path,
        "%{gcc_host_compiler_path}": str(cc),
    }
    repository_ctx.template(
        "crosstool/crosstool_compiler_wrapper",
        tpl_paths["crosstool:crosstool_compiler_wrapper"],
        wrapper_defines,
    )

    verify_build_defines(cuda_defines)

    # Only expand template variables in the BUILD file
    repository_ctx.template(
        "crosstool/BUILD",
        tpl_paths["crosstool:BUILD"],
        cuda_defines,
    )

    # No templating of cc_toolchain_config - use attributes and templatize the
    # BUILD file.
    repository_ctx.template(
        "crosstool/cc_toolchain_config.bzl",
        tpl_paths["crosstool:cc_toolchain_config.bzl"],
        {},
    )

    # Set up cuda_config.h
    repository_ctx.template(
        "cuda/cuda/cuda_config.h",
        tpl_paths["cuda:cuda_config.h"],
        {
            "%{cuda_version}": cuda_config.cuda_version,
            "%{cudart_version}": cuda_config.cudart_version,
            "%{cupti_version}": cuda_config.cupti_version,
            "%{cublas_version}": cuda_config.cublas_version,
            "%{cusolver_version}": cuda_config.cusolver_version,
            "%{curand_version}": cuda_config.curand_version,
            "%{cufft_version}": cuda_config.cufft_version,
            "%{cusparse_version}": cuda_config.cusparse_version,
            "%{cudnn_version}": cuda_config.cudnn_version,
            "%{cuda_toolkit_path}": cuda_config.cuda_toolkit_path,
            "%{cuda_compute_capabilities}": ", ".join(
                [cc.split("_")[1] for cc in cuda_config.compute_capabilities],
            ),
        },
    )

    # Set up cuda_config.py, which is used by gen_build_info to provide
    # static build environment info to the API
    repository_ctx.template(
        "cuda/cuda/cuda_config.py",
        tpl_paths["cuda:cuda_config.py"],
        _py_tmpl_dict({
            "cuda_version": cuda_config.cuda_version,
            "cudnn_version": cuda_config.cudnn_version,
            "cuda_compute_capabilities": cuda_config.compute_capabilities,
            "cpu_compiler": str(cc),
        }),
    )

def _get_tensorrt_static_path(repository_ctx):
    return get_host_environ(repository_ctx, _TENSORRT_STATIC_PATH, None)

def _create_local_tensorrt_repository(repository_ctx):
    find_cuda_config_path = repository_ctx.path(
        Label("//build_deps/gpus:find_cuda_config.py"),
    )
    config = find_cuda_config(
        repository_ctx,
        find_cuda_config_path,
        ["tensorrt"],
    )
    tensorrt_version = config["tensorrt_version"]
    cpu_value = get_cpu_value(repository_ctx)

    # Copy the library and header files
    libraries = [
        lib_name(lib, cpu_value, tensorrt_version)
        for lib in _TENSORRT_LIBS
    ]
    library_dir = config["tensorrt_library_dir"] + "/"
    headers = _get_tensorrt_headers(tensorrt_version)
    include_dir = config["tensorrt_include_dir"] + "/"
    copy_rules = [
        make_copy_files_rule(
            repository_ctx,
            name = "tensorrt_lib",
            srcs = [library_dir + library for library in libraries],
            outs = ["tensorrt/lib/" + library for library in libraries],
        ),
        make_copy_files_rule(
            repository_ctx,
            name = "tensorrt_include",
            srcs = [include_dir + header for header in headers],
            outs = ["tensorrt/include/" + header for header in headers],
        ),
    ]

    tensorrt_static_path = _get_tensorrt_static_path(repository_ctx)
    if tensorrt_static_path:
        tensorrt_static_path = tensorrt_static_path + "/"
        if _at_least_version(tensorrt_version, "8"):
            raw_static_library_names = _TENSORRT_LIBS
        else:
            raw_static_library_names = _TENSORRT_LIBS + [
                "nvrtc",
                "myelin_compiler",
                "myelin_executor",
                "myelin_pattern_library",
                "myelin_pattern_runtime",
            ]

        static_library_names = [
            "%s_static" % name
            for name in raw_static_library_names
        ]
        static_libraries = [
            lib_name(lib, cpu_value, tensorrt_version, static = True)
            for lib in static_library_names
        ]
        copy_rules = copy_rules + [
            make_copy_files_rule(
                repository_ctx,
                name = "tensorrt_static_lib",
                srcs = [
                    tensorrt_static_path + library
                    for library in static_libraries
                ],
                outs = [
                    "tensorrt/lib/" + library
                    for library in static_libraries
                ],
            ),
        ]

    tpl_paths = {
        "tensorrt/build_defs.bzl": _tpl_path(repository_ctx, "tensorrt:build_defs.bzl"),
        "tensorrt/BUILD": _tpl_path(repository_ctx, "tensorrt:BUILD"),
        "tensorrt/tensorrt_config.h": _tpl_path(repository_ctx, "tensorrt:tensorrt_config.h"),
        "tensorrt/tensorrt_config.py": _tpl_path(repository_ctx, "tensorrt:tensorrt_config.py"),
    }

    # Set up config file.
    repository_ctx.template(
        "tensorrt/build_defs.bzl",
        tpl_paths["tensorrt/build_defs.bzl"],
        {"%{if_tensorrt}": "if_true"},
    )

    # Set up BUILD file.
    repository_ctx.template(
        "tensorrt/BUILD",
        tpl_paths["tensorrt/BUILD"],
        {
            "%{copy_rules}": "\n".join(copy_rules),
        },
    )

    # Set up tensorrt_config.h, which is used by
    # tensorflow/stream_executor/dso_loader.cc.
    repository_ctx.template(
        "tensorrt/tensorrt_config.h",
        tpl_paths["tensorrt/tensorrt_config.h"],
        {"%{tensorrt_version}": tensorrt_version},
    )

    # Set up tensorrt_config.py, which is used by gen_build_info to provide
    # build environment info to the API
    repository_ctx.template(
        "tensorrt/tensorrt_config.py",
        tpl_paths["tensorrt/tensorrt_config.py"],
        _py_tmpl_dict({
            "tensorrt_version": tensorrt_version,
        }),
    )

def _py_tmpl_dict(d):
    return {"%{cuda_config}": str(d)}

_CUDA_ENVIRONS = [
    _GCC_HOST_COMPILER_PATH,
    _GCC_HOST_COMPILER_PREFIX,
    "NEED_CUDA",
    _CUDA_TOOLKIT_PATH,
    _CUDNN_INSTALL_PATH,
    _CUDA_VERSION,
    _CUDNN_VERSION,
    _CUDA_COMPUTE_CAPABILITIES,
    "NVVMIR_LIBRARY_DIR",
    _PYTHON_BIN_PATH,
    "TMP",
    "TMPDIR",
    "CUDA_PATHS",
]

cuda_configure = repository_rule(
    implementation = _create_local_cuda_repository,
    environ = _CUDA_ENVIRONS,
)

_TENSORRT_ENVIRONS = [
    _TENSORRT_INSTALL_PATH,
    _TENSORRT_VERSION,
    _TENSORRT_STATIC_PATH,
    "CUDA_PATHS",
]

tensorrt_configure = repository_rule(
    implementation = _create_local_tensorrt_repository,
    environ = _TENSORRT_ENVIRONS,
)
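
Because Starlark is a dialect of Python, the version-selection branches in `_get_cuda_config` can be exercised in plain Python. The following is a minimal port for illustration, not part of the repository: `parse_find_cuda_config_output`, `cuda_lib_versions`, and the sample stdout values are hypothetical. It parses `key: value` lines the same way `find_cuda_config()` does and then applies the same three version branches (CUDA 11+, 10.1 to 10.x, older).

```python
# Hypothetical Python port of the version logic in _get_cuda_config().
LIBS = ("cublas", "cusolver", "curand", "cufft", "cusparse")


def parse_find_cuda_config_output(stdout):
    """Parse "key: value" lines, splitting on ": " like find_cuda_config()."""
    return dict(tuple(line.split(": ")) for line in stdout.splitlines())


def cuda_lib_versions(config):
    """Return soname versions following the branches in _get_cuda_config()."""
    major, minor = config["cuda_version"].split(".")[:2]
    cuda_version = "%s.%s" % (major, minor)
    if int(major) >= 11:
        # The 11.x runtime keeps the "11.0" soname; 12+ uses the bare major.
        cudart = "11.0" if int(major) == 11 else major
        cupti = cuda_version if int(major) == 11 else cudart
        # Math libraries are versioned independently; keep only their major.
        libs = {lib: config[lib + "_version"].split(".")[0] for lib in LIBS}
    elif (int(major), int(minor)) >= (10, 1):
        cudart = cupti = cuda_version
        libs = {lib: major for lib in LIBS}  # "x.y" became "x" in CUDA 10.1
    else:
        cudart = cupti = cuda_version
        libs = {lib: cuda_version for lib in LIBS}
    versions = {
        "cuda_version": cuda_version,
        "cudart_version": cudart,
        "cupti_version": cupti,
    }
    versions.update({lib + "_version": v for lib, v in libs.items()})
    return versions


# Illustrative input only; real values come from find_cuda_config.py.
sample = "\n".join([
    "cuda_version: 11.8",
    "cublas_version: 11.11.3",
    "cusolver_version: 11.4.1",
    "curand_version: 10.3.0",
    "cufft_version: 10.9.0",
    "cusparse_version: 11.7.5",
])
print(cuda_lib_versions(parse_find_cuda_config_output(sample)))
```

Note that the per-library `*_version` keys are only read in the CUDA 11+ branch, so older toolkits need nothing beyond `cuda_version`.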

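`_compute_cuda_extra_copts` turns each compute capability into compiler flags, and the `cuda_config.h` template receives the bare numbers joined with ", ". The sketch below ports both, assuming capabilities are strings of the form `sm_XY` or `compute_XY` (the forms the `startswith("compute_")` check and the `split("_")[1]` call expect). The helper names are hypothetical. The original returns `str(copts)` for substitution into `%{cuda_extra_copts}`; this port returns the list itself, because Python's `str()` quotes list elements with single quotes whereas Starlark uses double quotes.

```python
# Hypothetical Python port of _compute_cuda_extra_copts() and of the
# %{cuda_compute_capabilities} value built for cuda_config.h.


def compute_cuda_extra_copts(compute_capabilities):
    """Return the flag list that the Starlark version stringifies."""
    copts = []
    for capability in compute_capabilities:
        if capability.startswith("compute_"):
            # A "compute_XY" entry additionally requests PTX for sm_XY.
            capability = capability.replace("compute_", "sm_")
            copts.append("--cuda-include-ptx=%s" % capability)
        copts.append("--cuda-gpu-arch=%s" % capability)
    return copts


def header_capabilities(compute_capabilities):
    """Return the comma-separated numbers substituted into cuda_config.h."""
    return ", ".join(cap.split("_")[1] for cap in compute_capabilities)


caps = ["sm_80", "compute_90"]  # illustrative capability list
print(compute_cuda_extra_copts(caps))
print(header_capabilities(caps))
```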

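`make_copy_files_rule` does not copy anything itself; it returns the source text of a `genrule` that is later spliced into `cuda/BUILD` through `%{copy_rules}`. For the `cuda-lib` rule, each output is `cuda/lib/` plus the basename of the detected library path. The sketch below reproduces that text generation in Python, dropping the unused `repository_ctx` argument. `basename` uses `rsplit`, which returns the same result as the character loop in `_basename`. The library paths are hypothetical examples.

```python
# Hypothetical Python port of make_copy_files_rule() and _basename().


def basename(path_str):
    """Return the text after the last "/", or the whole string if none."""
    return path_str.rsplit("/", 1)[-1]


def make_copy_files_rule(name, srcs, outs):
    """Return genrule source text that copies srcs[i] to outs[i]."""
    cmds = ['cp -f "%s" "$(location %s)"' % (src, out) for src, out in zip(srcs, outs)]
    out_lines = ['        "%s",' % out for out in outs]
    # Commands are chained with "&&" and a backslash-newline continuation.
    return """genrule(
    name = "%s",
    outs = [
%s
    ],
    cmd = \"""%s \""",
)""" % (name, "\n".join(out_lines), " && \\\n".join(cmds))


lib_paths = [  # illustrative paths only
    "/usr/local/cuda/lib64/libcudart.so.11.0",
    "/usr/local/cuda/lib64/libcublas.so.11",
]
print(make_copy_files_rule(
    "cuda-lib", lib_paths, ["cuda/lib/" + basename(p) for p in lib_paths]))
```

Generating rule text this way lets the repository rule copy files that live outside the workspace into Bazel-managed outputs, so downstream targets can depend on them by label.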
================================================
FILE: build_deps/gpus/crosstool/BUILD
================================================


================================================
FILE: build_deps/gpus/crosstool/BUILD.tpl
================================================
# This file is expanded from a template by cuda_configure.bzl
# Update cuda_configure.bzl#verify_build_defines when adding new variables.

load(":cc_toolchain_config.bzl", "cc_toolchain_config")

licenses(["restricted"])

package(default_visibility = ["//visibility:public"])

toolchain(
    name = "toolchain-linux-x86_64",
    exec_compatible_with = [
        "@platforms//os:linux",
        "@platforms//cpu:x86_64",
    ],
    target_compatible_with = [
        "@platforms//os:linux",
        "@platforms//cpu:x86_64",
    ],
    toolchain = ":cc-compiler-local",
    toolchain_type = "@bazel_tools//tools/cpp:toolchain_type",
)

cc_toolchain_suite(
    name = "toolchain",
    toolchains = {
        "local|compiler": ":cc-compiler-local",
        "darwin|compiler": ":cc-compiler-darwin",
        "arm": ":cc-compiler-local",
        "aarch64": ":cc-compiler-local",
        "k8": ":cc-compiler-local",
        "piii": ":cc-compiler-local",
        "ppc": ":cc-compiler-local",
        "darwin": ":cc-compiler-darwin",
    },
)

cc_toolchain(
    name = "cc-compiler-local",
    all_files = "%{compiler_deps}",
    compiler_files = "%{compiler_deps}",
    ar_files = "%{compiler_deps}",
    as_files = "%{compiler_deps}",
    dwp_files = ":empty",
    linker_files = "%{compiler_deps}",
    objcopy_files = ":empty",
    strip_files = ":empty",
    # To support linker flags that need to go at the start of the command line,
    # the toolchain must support parameter files. Parameter files come
    # last on the command line and contain all shared libraries to link, so all
    # regular options precede them.
    supports_param_files = 1,
    toolchain_identifier = "local_linux",
    toolchain_config = ":cc-compiler-local-config",
)

cc_toolchain_config(
    name = "cc-compiler-local-config",
    cpu = "local",
    builtin_include_directories = [%{cxx_builtin_include_directories}],
    extra_no_canonical_prefixes_flags = [%{extra_no_canonical_prefixes_flags}],
    host_compiler_path = "%{host_compiler_path}",
    host_compiler_prefix = "%{host_compiler_prefix}",
    host_compiler_warnings = [%{host_compiler_warnings}],
    host_unfiltered_compile_flags = [%{unfiltered_compile_flags}],
    linker_bin_path = "%{linker_bin_path}",
    builtin_sysroot = "%{builtin_sysroot}",
    cuda_path = "%{cuda_toolkit_path}",
    compiler = "%{compiler}",
)

cc_toolchain(
    name = "cc-compiler-darwin",
    all_files = "%{compiler_deps}",
    compiler_files = "%{compiler_deps}",
    ar_files = "%{compiler_deps}",
    as_files = "%{compiler_deps}",
    dwp_files = ":empty",
    linker_files = "%{compiler_deps}",
    objcopy_files = ":empty",
    strip_files = ":empty",
    supports_param_files = 0,
    toolchain_identifier = "local_darwin",
    toolchain_config = ":cc-compiler-local-darwin",
)

cc_toolchain_config(
    name = "cc-compiler-local-darwin",
    cpu = "darwin",
    builtin_include_directories = [%{cxx_builtin_include_directories}],
    extra_no_canonical_prefixes_flags = [%{extra_no_canonical_prefixes_flags}],
    host_compiler_path = "%{host_compiler_path}",
    host_compiler_prefix = "%{host_compiler_prefix}",
    host_compiler_warnings = [%{host_compiler_warnings}],
    host_unfiltered_compile_flags = [%{unfiltered_compile_flags}],
    linker_bin_path = "%{linker_bin_path}",
)


filegroup(
    name = "empty",
    srcs = [],
)

filegroup(
    name = "crosstool_compiler",
    srcs = ["crosstool_compiler_wrapper"],
)
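
The `%{...}` placeholders in `BUILD.tpl` above are filled by `repository_ctx.template`, which replaces every occurrence of each substitution key with its value. Placeholders wrapped in brackets, such as `[%{cxx_builtin_include_directories}]`, therefore expect an already-quoted, comma-separated list (compare `"\"-fno-canonical-system-headers\""` in `cuda_configure.bzl`). The sketch below mimics that literal replacement in Python; `expand_template` is a hypothetical helper, and its leftover-placeholder check is an assumption of this sketch rather than a description of `verify_build_defines`, whose body is not shown here.

```python
import re

# Matches any unexpanded %{name} placeholder.
PLACEHOLDER = re.compile(r"%\{[A-Za-z0-9_]+\}")


def expand_template(text, substitutions):
    """Replace each key literally, then fail if any placeholder remains."""
    for key, value in substitutions.items():
        text = text.replace(key, value)
    leftover = sorted(set(PLACEHOLDER.findall(text)))
    if leftover:
        raise ValueError("unexpanded placeholders: %s" % ", ".join(leftover))
    return text


# Excerpt modeled on cc-compiler-local-config in BUILD.tpl.
tpl = """cc_toolchain_config(
    name = "cc-compiler-local-config",
    builtin_include_directories = [%{cxx_builtin_include_directories}],
    host_compiler_prefix = "%{host_compiler_prefix}",
)"""

print(expand_template(tpl, {
    "%{cxx_builtin_include_directories}": '"/usr/include", "/usr/local/cuda/include"',
    "%{host_compiler_prefix}": "/usr/bin",
}))
```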


================================================
FILE: build_deps/gpus/crosstool/cc_toolchain_config.bzl.tpl
================================================
"""cc_toolchain_config rule for configuring CUDA toolchains on Linux, Mac, and Windows."""

load(
    "@bazel_tools//tools/cpp:cc_toolchain_config_lib.bzl",
    "action_config",
    "artifact_name_pattern",
    "env_entry",
    "env_set",
    "feature",
    "feature_set",
    "flag_group",
    "flag_set",
    "tool",
    "tool_path",
    "variable_with_value",
    "with_feature_set",
)
load("@bazel_tools//tools/build_defs/cc:action_names.bzl", "ACTION_NAMES")

def all_assembly_actions():
    return [
        ACTION_NAMES.assemble,
        ACTION_NAMES.preprocess_assemble,
    ]

def all_compile_actions():
    return [
        ACTION_NAMES.assemble,
        ACTION_NAMES.c_compile,
        ACTION_NAMES.cpp_compile,
        ACTION_NAMES.cpp_header_parsing,
        ACTION_NAMES.cpp_module_codegen,
        ACTION_NAMES.cpp_module_compile,
        ACTION_NAMES.linkstamp_compile,
        ACTION_NAMES.preprocess_assemble,
    ]

def all_c_compile_actions():
    return [
        ACTION_NAMES.c_compile,
    ]

def all_cpp_compile_actions():
    return [
        ACTION_NAMES.cpp_compile,
        ACTION_NAMES.cpp_header_parsing,
        ACTION_NAMES.cpp_module_codegen,
        ACTION_NAMES.cpp_module_compile,
        ACTION_NAMES.linkstamp_compile,
    ]

def all_preprocessed_actions():
    return [
        ACTION_NAMES.c_compile,
        ACTION_NAMES.cpp_compile,
        ACTION_NAMES.cpp_header_parsing,
        ACTION_NAMES.cpp_module_codegen,
        ACTION_NAMES.cpp_module_compile,
        ACTION_NAMES.linkstamp_compile,
        ACTION_NAMES.preprocess_assemble,
    ]

def all_link_actions():
    return [
        ACTION_NAMES.cpp_link_executable,
        ACTION_NAMES.cpp_link_dynamic_library,
        ACTION_NAMES.cpp_link_nodeps_dynamic_library,
    ]

def all_executable_link_actions():
    return [
        ACTION_NAMES.cpp_link_executable,
    ]

def all_shared_library_link_actions():
    return [
        ACTION_NAMES.cpp_link_dynamic_library,
        ACTION_NAMES.cpp_link_nodeps_dynamic_library,
    ]

def all_archive_actions():
    return [ACTION_NAMES.cpp_link_static_library]

def all_strip_actions():
    return [ACTION_NAMES.strip]

def _library_to_link(flag_prefix, value, iterate = None):
    return flag_group(
        flags = [
            "{}%{{libraries_to_link.{}}}".format(
                flag_prefix,
                iterate if iterate else "name",
            ),
        ],
        iterate_over = ("libraries_to_link." + iterate if iterate else None),
        expand_if_equal = variable_with_value(
            name = "libraries_to_link.type",
            value = value,
        ),
    )

def _surround_static_library(prefix, suffix):
    return [
        flag_group(
            flags = [prefix, "%{libraries_to_link.name}", suffix],
            expand_if_true = "libraries_to_link.is_whole_archive",
        ),
        flag_group(
            flags = ["%{libraries_to_link.name}"],
            expand_if_false = "libraries_to_link.is_whole_archive",
        ),
    ]

def _prefix_static_library(prefix):
    return [
        flag_group(
            flags = ["%{libraries_to_link.name}"],
            expand_if_false = "libraries_to_link.is_whole_archive",
        ),
        flag_group(
            flags = [prefix + "%{libraries_to_link.name}"],
            expand_if_true = "libraries_to_link.is_whole_archive",
        ),
    ]

def _static_library_to_link(alwayslink_prefix, alwayslink_suffix = None):
    if alwayslink_suffix:
        flag_groups = _surround_static_library(alwayslink_prefix, alwayslink_suffix)
    else:
        flag_groups = _prefix_static_library(alwayslink_prefix)
    return flag_group(
        flag_groups = flag_groups,
        expand_if_equal = variable_with_value(
            name = "libraries_to_link.type",
            value = "static_library",
        ),
    )

def _iterate_flag_group(iterate_over, flags = [], flag_groups = []):
    return flag_group(
        iterate_over = iterate_over,
        expand_if_available = iterate_over,
        flag_groups = flag_groups,
        flags = flags,
    )

def _libraries_to_link_group(flavour):
    if flavour == "linux":
        return _iterate_flag_group(
            iterate_over = "libraries_to_link",
            flag_groups = [
                flag_group(
                    flags = ["-Wl,--start-lib"],
                    expand_if_equal = variable_with_value(
                        name = "libraries_to_link.type",
                        value = "object_file_group",
                    ),
                ),
                _library_to_link("", "object_file_group", "object_files"),
                flag_group(
                    flags = ["-Wl,--end-lib"],
                    expand_if_equal = variable_with_value(
                        name = "libraries_to_link.type",
                        value = "object_file_group",
                    ),
                ),
                _library_to_link("", "object_file"),
                _library_to_link("", "interface_library"),
                _static_library_to_link("-Wl,-whole-archive", "-Wl,-no-whole-archive"),
                _library_to_link("-l", "dynamic_library"),
                _library_to_link("-l:", "versioned_dynamic_library"),
            ],
        )
    elif flavour == "darwin":
        return _iterate_flag_group(
            iterate_over = "libraries_to_link",
            flag_groups = [
                _library_to_link("", "object_file_group", "object_files"),
                _library_to_link("", "object_file"),
                _library_to_link("", "interface_library"),
                _static_library_to_link("-Wl,-force_load,"),
                _library_to_link("-l", "dynamic_library"),
                _library_to_link("-l:", "versioned_dynamic_library"),
            ],
        )

def _action_configs_with_tool(path, actions):
    return [
        action_config(
            action_name = name,
            enabled = True,
            tools = [tool(path = path)],
        )
        for name in actions
    ]

def _action_configs(assembly_path, c_compiler_path, cc_compiler_path, archiver_path, linker_path, strip_path):
    return _action_configs_with_tool(
        assembly_path,
        all_assembly_actions(),
    ) + _action_configs_with_tool(
        c_compiler_path,
        all_c_compile_actions(),
    ) + _action_configs_with_tool(
        cc_compiler_path,
        all_cpp_compile_actions(),
    ) + _action_configs_with_tool(
        archiver_path,
        all_archive_actions(),
    ) + _action_configs_with_tool(
        linker_path,
        all_link_actions(),
    ) + _action_configs_with_tool(
        strip_path,
        all_strip_actions(),
    )

def _tool_paths(cpu, ctx):
    if cpu in ["local", "darwin"]:
        return [
            tool_path(name = "gcc", path = ctx.attr.host_compiler_path),
            tool_path(name = "ar", path = ctx.attr.host_compiler_prefix + (
                "/ar" if cpu == "local" else "/libtool"
            )),
            tool_path(name = "compat-ld", path = ctx.attr.host_compiler_prefix + "/ld"),
            tool_path(name = "cpp", path = ctx.attr.host_compiler_prefix + "/cpp"),
            tool_path(name = "dwp", path = ctx.attr.host_compiler_prefix + "/dwp"),
            tool_path(name = "gcov", path = ctx.attr.host_compiler_prefix + "/gcov"),
            tool_path(name = "ld", path = ctx.attr.host_compiler_prefix + "/ld"),
            tool_path(name = "nm", path = ctx.attr.host_compiler_prefix + "/nm"),
            tool_path(name = "objcopy", path = ctx.attr.host_compiler_prefix + "/objcopy"),
            tool_path(name = "objdump", path = ctx.attr.host_compiler_prefix + "/objdump"),
            tool_path(name = "strip", path = ctx.attr.host_compiler_prefix + "/strip"),
        ]
    else:
        fail("Unreachable")

def _sysroot_group():
    return flag_group(
        flags = ["--sysroot=%{sysroot}"],
        expand_if_available = "sysroot",
    )

def _no_canonical_prefixes_group(extra_flags):
    return flag_group(
        flags = [
            "-no-canonical-prefixes",
        ] + extra_flags,
    )

def _cuda_set(cuda_path, actions):
    if cuda_path:
        return [flag_set(
            actions = actions,
            flag_groups = [
                flag_group(
                    flags = ["--cuda-path=" + cuda_path],
                ),
            ],
        )]
    else:
        return []

def _nologo():
    return flag_group(flags = ["/nologo"])

def _features(cpu, compiler, ctx):
    if cpu in ["local", "darwin"]:
        return [
            feature(name = "no_legacy_features"),
            feature(
                name = "all_compile_flags",
                enabled = True,
                flag_sets = [
                    flag_set(
                        actions = all_compile_actions(),
                        flag_groups = [
                            flag_group(
                                flags = ["-MD", "-MF", "%{dependency_file}"],
                                expand_if_available = "dependency_file",
                            ),
                            flag_group(
                                flags = ["-gsplit-dwarf"],
                                expand_if_available = "per_object_debug_info_file",
                            ),
                        ],
                    ),
                    flag_set(
                        actions = all_preprocessed_actions(),
                        flag_groups = [
                            flag_group(
                                flags = ["-frandom-seed=%{output_file}"],
                                expand_if_available = "output_file",
                            ),
                            _iterate_flag_group(
                                flags = ["-D%{preprocessor_defines}"],
                                iterate_over = "preprocessor_defines",
                            ),
                            _iterate_flag_group(
                                flags = ["-include", "%{includes}"],
                                iterate_over = "includes",
                            ),
                            _iterate_flag_group(
                                flags = ["-iquote", "%{quote_include_paths}"],
                                iterate_over = "quote_include_paths",
                            ),
                            _iterate_flag_group(
                                flags = ["-I%{include_paths}"],
                                iterate_over = "include_paths",
                            ),
                            _iterate_flag_group(
                                flags = ["-isystem", "%{system_include_paths}"],
                                iterate_over = "system_include_paths",
                            ),
                            _iterate_flag_group(
                                flags = ["-F", "%{framework_include_paths}"],
                                iterate_over = "framework_include_paths",
                            ),
                        ],
                    ),
                    flag_set(
                        actions = all_cpp_compile_actions(),
                        flag_groups = [],
                    ),
                    flag_set(
                        actions = all_compile_actions(),
                        flag_groups = [
                            flag_group(
                                flags = [
                                    "-Wno-builtin-macro-redefined",
                                    "-D__DATE__=\"redacted\"",
                                    "-D__TIMESTAMP__=\"redacted\"",
                                    "-D__TIME__=\"redacted\"",
                                ],
                            ),
                            flag_group(
                                flags = ["-fPIC"],
                                expand_if_available = "pic",
                            ),
                            flag_group(
                                flags = ["-fPIE"],
                                expand_if_not_available = "pic",
                            ),
                            flag_group(
                                flags = [
                                    "-U_FORTIFY_SOURCE",
                                    "-D_FORTIFY_SOURCE=1",
                                    "-fstack-protector",
                                    "-Wall",
                                ] + ctx.attr.host_compiler_warnings + [
                                    "-fno-omit-frame-pointer",
                                ],
                            ),
                            _no_canonical_prefixes_group(
                                ctx.attr.extra_no_canonical_prefixes_flags,
                            ),
                        ],
                    ),
                    flag_set(
                        actions = all_compile_actions(),
                        flag_groups = [flag_group(flags = ["-DNDEBUG"])],
                        with_features = [with_feature_set(features = ["disable-assertions"])],
                    ),
                    flag_set(
                        actions = all_compile_actions(),
                        flag_groups = [
                            flag_group(
                                flags = [
                                    "-g0",
                                    "-O2",
                                    "-ffunction-sections",
                                    "-fdata-sections",
                                ],
                            ),
                        ],
                        with_features = [with_feature_set(features = ["opt"])],
                    ),
                    flag_set(
                        actions = all_compile_actions(),
                        flag_groups = [flag_group(flags = ["-g"])],
                        with_features = [with_feature_set(features = ["dbg"])],
                    ),
                ] + _cuda_set(
                    ctx.attr.cuda_path,
                    all_compile_actions(),
                ) + [
                    flag_set(
                        actions = all_compile_actions(),
                        flag_groups = [
                            _iterate_flag_group(
                                flags = ["%{user_compile_flags}"],
                                iterate_over = "user_compile_flags",
                            ),
                            _sysroot_group(),
                            flag_group(
                                expand_if_available = "source_file",
                                flags = ["-c", "%{source_file}"],
                            ),
                            flag_group(
                                expand_if_available = "output_assembly_file",
                                flags = ["-S"],
                            ),
                            flag_group(
                                expand_if_available = "output_preprocess_file",
                                flags = ["-E"],
                            ),
                            flag_group(
                                expand_if_available = "output_file",
                                flags = ["-o", "%{output_file}"],
                            ),
                        ],
                    ),
                ],
            ),
            feature(
                name = "all_archive_flags",
                enabled = True,
                flag_sets = [
                    flag_set(
                        actions = all_archive_actions(),
                        flag_groups = [
                            flag_group(
                                expand_if_available = "linker_param_file",
                                flags = ["@%{linker_param_file}"],
                            ),
                            flag_group(flags = ["rcsD"]),
                            flag_group(
                                flags = ["%{output_execpath}"],
                                expand_if_available = "output_execpath",
                            ),
                            flag_group(
                                iterate_over = "libraries_to_link",
                                flag_groups = [
                                    flag_group(
                                        flags = ["%{libraries_to_link.name}"],
                                        expand_if_equal = variable_with_value(
                                            name = "libraries_to_link.type",
                                            value = "object_file",
                                        ),
                                    ),
                                    flag_group(
                                        flags = ["%{libraries_to_link.object_files}"],
                                        iterate_over = "libraries_to_link.object_files",
                                        expand_if_equal = variable_with_value(
                                            name = "libraries_to_link.type",
                                            value = "object_file_group",
                                        ),
                                    ),
                                ],
                                expand_if_available = "libraries_to_link",
                            ),
                        ],
                    ),
                ],
            ),
            feature(
                name = "all_link_flags",
                enabled = True,
                flag_sets = [
                    flag_set(
                        actions = all_shared_library_link_actions(),
                        flag_groups = [flag_group(flags = ["-shared"])],
                    ),
                    flag_set(
                        actions = all_link_actions(),
                        flag_groups = ([
                            flag_group(flags = ["-Wl,-no-as-needed"])
                        ] if cpu == "local" else []) + ([
                            flag_group(flags = ["-B" + ctx.attr.linker_bin_path])
                        ] if ctx.attr.linker_bin_path else []) + [
                            flag_group(
                                flags = ["@%{linker_param_file}"],
                                expand_if_available = "linker_param_file",
                            ),
                            _iterate_flag_group(
                                flags = ["%{linkstamp_paths}"],
                                iterate_over = "linkstamp_paths",
                            ),
                            flag_group(
                                flags = ["-o", "%{output_execpath}"],
                                expand_if_available = "output_execpath",
                            ),
                            _iterate_flag_group(
                                flags = ["-L%{library_search_directories}"],
                                iterate_over = "library_search_directories",
                            ),
                            _iterate_flag_group(
                                iterate_over = "runtime_library_search_directories",
                                flags = [
                                    "-Wl,-rpath,$ORIGIN/%{runtime_library_search_directories}",
                                ] if cpu == "local" else [
                                    "-Wl,-rpath,@loader_path/%{runtime_library_search_directories}",
                                ],
                            ),
                            _libraries_to_link_group("darwin" if cpu == "darwin" else "linux"),
                            _iterate_flag_group(
                                flags = ["%{user_link_flags}"],
                                iterate_over = "user_link_flags",
                            ),
                            flag_group(
                                flags = ["-Wl,--gdb-index"],
                                expand_if_available = "is_using_fission",
                            ),
                            flag_group(
                                flags = ["-Wl,-S"],
                                expand_if_available = "strip_debug_symbols",
                            ),
                            flag_group(flags = ["-lc++" if cpu == "darwin" else "-lstdc++"]),
                            _no_canonical_prefixes_group(
                                ctx.attr.extra_no_canonical_prefixes_flags,
                            ),
                        ],
                    ),
                    flag_set(
                        actions = all_executable_link_actions(),
                        flag_groups = [flag_group(flags = ["-pie"])],
                    ),
                ] + ([
                    flag_set(
                        actions = all_link_actions(),
                        flag_groups = [flag_group(flags = [
                            "-Wl,-z,relro,-z,now",
                        ])],
                    ),
                ] if cpu == "local" else []) + ([
                    flag_set(
                        actions = all_link_actions(),
                        flag_groups = [
                            flag_group(flags = ["-Wl,--gc-sections"]),
                            flag_group(
                                flags = ["-Wl,--build-id=md5", "-Wl,--hash-style=gnu"],
                            ),
                        ],
                    ),
                ] if cpu == "local" else []) + ([
                    flag_set(
                        actions = all_link_actions(),
                        flag_groups = [flag_group(flags = ["-undefined", "dynamic_lookup"])],
                    ),
                ] if cpu == "darwin" else []) + _cuda_set(
                    ctx.attr.cuda_path,
                    all_link_actions(),
                ) + [
                    flag_set(
                        actions = all_link_actions(),
                        flag_groups = [
                            _sysroot_group(),
                        ],
                    ),
                ],
            ),
            feature(name = "disable-assertions"),
            feature(
                name = "opt",
                implies = ["disable-assertions"],
            ),
            feature(name = "fastbuild"),
            feature(name = "dbg"),
            feature(name = "supports_dynamic_linker", enabled = True),
            feature(name = "pic", enabled = True),
            feature(name = "supports_pic", enabled = True),
            feature(name = "has_configured_linker_path", enabled = True),
        ]
    else:
        fail("Unreachable")

def _impl(ctx):
    cpu = ctx.attr.cpu
    compiler = ctx.attr.compiler

    if (cpu == "darwin"):
        toolchain_identifier = "local_darwin"
        target_cpu = "darwin"
        target_libc = "macosx"
        compiler = "compiler"
        action_configs = _action_configs(
            assembly_path = ctx.attr.host_compiler_path,
            c_compiler_path = ctx.attr.host_compiler_path,
            cc_compiler_path = ctx.attr.host_compiler_path,
            archiver_path = ctx.attr.host_compiler_prefix + "/libtool",
            linker_path = ctx.attr.host_compiler_path,
            strip_path = ctx.attr.host_compiler_prefix + "/strip",
        )
        artifact_name_patterns = []
    elif (cpu == "local"):
        toolchain_identifier = "local_linux"
        target_cpu = "local"
        target_libc = "local"
        action_configs = _action_configs(
            assembly_path = ctx.attr.host_compiler_path,
            c_compiler_path = ctx.attr.host_compiler_path,
            cc_compiler_path = ctx.attr.host_compiler_path,
            archiver_path = ctx.attr.host_compiler_prefix + "/ar",
            linker_path = ctx.attr.host_compiler_path,
            strip_path = ctx.attr.host_compiler_prefix + "/strip",
        )
        artifact_name_patterns = []
    else:
        fail("Unreachable")

    out = ctx.actions.declare_file(ctx.label.name)
    ctx.actions.write(out, "Fake executable")
    return [
        cc_common.create_cc_toolchain_config_info(
            ctx = ctx,
            features = _features(cpu, compiler, ctx),
            action_configs = action_configs,
            artifact_name_patterns = artifact_name_patterns,
            cxx_builtin_include_directories = ctx.attr.builtin_include_directories,
            toolchain_identifier = toolchain_identifier,
            host_system_name = "local",
            target_system_name = "local",
            target_cpu = target_cpu,
            target_libc = target_libc,
            compiler = compiler,
            abi_version = "local",
            abi_libc_version = "local",
            tool_paths = _tool_paths(cpu, ctx),
            make_variables = [],
            builtin_sysroot = ctx.attr.builtin_sysroot,
            cc_target_os = None,
        ),
        DefaultInfo(
            executable = out,
        ),
    ]

cc_toolchain_config = rule(
    implementation = _impl,
    attrs = {
        "cpu": attr.string(mandatory = True, values = ["darwin", "local"]),
        "compiler": attr.string(values = ["unknown"], default = "unknown"),
        "builtin_include_directories": attr.string_list(),
        "extra_no_canonical_prefixes_flags": attr.string_list(),
        "host_compiler_path": attr.string(),
        "host_compiler_prefix": attr.string(),
        "host_compiler_warnings": attr.string_list(),
        "host_unfiltered_compile_flags": attr.string_list(),
        "linker_bin_path": attr.string(),
        "builtin_sysroot": attr.string(),
        "cuda_path": attr.string(),
    },
    provides = [CcToolchainConfigInfo],
    executable = True,
)
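The `_libraries_to_link_group` helper above decides how each entry of Bazel's `libraries_to_link` variable becomes linker arguments on Linux and macOS. The Python sketch below is a simplified, hypothetical model of that expansion, not Bazel's implementation. It assumes that `_library_to_link(prefix, type)` emits `prefix + name` (or each object file for groups) and that `_surround_static_library` wraps whole-archive libraries in the prefix/suffix pair, because the full definitions of those helpers lie outside this excerpt. The `Library` class and `link_flags` function are illustrative names.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Library:
    """Hypothetical stand-in for one entry of Bazel's `libraries_to_link`."""
    type: str
    name: str = ""
    is_whole_archive: bool = False
    object_files: List[str] = field(default_factory=list)


def link_flags(libraries: List[Library], flavour: str = "linux") -> List[str]:
    """Approximate the flags produced by _libraries_to_link_group(flavour)."""
    flags: List[str] = []
    for lib in libraries:
        if lib.type == "object_file_group":
            # Linux brackets object-file groups with --start-lib/--end-lib; darwin does not.
            if flavour == "linux":
                flags.append("-Wl,--start-lib")
            flags.extend(lib.object_files)
            if flavour == "linux":
                flags.append("-Wl,--end-lib")
        elif lib.type in ("object_file", "interface_library"):
            flags.append(lib.name)
        elif lib.type == "static_library":
            if not lib.is_whole_archive:
                flags.append(lib.name)
            elif flavour == "linux":
                # _static_library_to_link("-Wl,-whole-archive", "-Wl,-no-whole-archive")
                flags += ["-Wl,-whole-archive", lib.name, "-Wl,-no-whole-archive"]
            else:
                # _static_library_to_link("-Wl,-force_load,") joins prefix and name.
                flags.append("-Wl,-force_load," + lib.name)
        elif lib.type == "dynamic_library":
            flags.append("-l" + lib.name)
        elif lib.type == "versioned_dynamic_library":
            flags.append("-l:" + lib.name)
    return flags


libs = [
    Library("static_library", "libkernels.a", is_whole_archive=True),
    Library("dynamic_library", "cudart"),
]
print(link_flags(libs))                    # ['-Wl,-whole-archive', 'libkernels.a', '-Wl,-no-whole-archive', '-lcudart']
print(link_flags(libs, flavour="darwin"))  # ['-Wl,-force_load,libkernels.a', '-lcudart']
```

The whole-archive wrapping matters for CUDA kernels and static registrations that would otherwise be dropped by the linker because nothing references their symbols directly.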

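`_iterate_flag_group` sets `iterate_over` and `expand_if_available` to the same variable, so the group is skipped entirely when the variable is unset and otherwise its flags are emitted once per list element. A minimal Python model of that behaviour, assuming plain `%{name}` textual substitution (Bazel's real expansion also handles nested groups and structured variables); `expand_iterate_group` is a hypothetical name:

```python
from typing import Dict, List


def expand_iterate_group(
    variables: Dict[str, List[str]], iterate_over: str, flags: List[str]
) -> List[str]:
    """Model of _iterate_flag_group(iterate_over = ..., flags = ...)."""
    # expand_if_available: an unset variable drops the whole group.
    if iterate_over not in variables:
        return []
    placeholder = "%{" + iterate_over + "}"
    expanded: List[str] = []
    # iterate_over: emit every flag once per element, substituting the placeholder.
    for value in variables[iterate_over]:
        expanded.extend(flag.replace(placeholder, value) for flag in flags)
    return expanded


build_vars = {
    "include_paths": ["include", "cuda/include"],
    "system_include_paths": ["/usr/local/cuda/include"],
}
print(expand_iterate_group(build_vars, "include_paths", ["-I%{include_paths}"]))
print(expand_iterate_group(build_vars, "system_include_paths", ["-isystem", "%{system_include_paths}"]))
print(expand_iterate_group(build_vars, "quote_include_paths", ["-iquote", "%{quote_include_paths}"]))
```

The first call yields `-Iinclude` and `-Icuda/include`, the second a `-isystem` pair, and the third an empty list because `quote_include_paths` is not set.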

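Several compile `flag_set`s in `_features` are gated by features: `opt` implies `disable-assertions` (which adds `-DNDEBUG`) and itself adds `-g0 -O2 -ffunction-sections -fdata-sections`, `dbg` adds `-g`, and the `pic` build variable selects `-fPIC` or `-fPIE`. The sketch below models only those mode-dependent flags, in the order their flag sets appear; it omits the unconditional warning, hardening, include, and output flags, and `mode_dependent_flags` is a hypothetical name.

```python
from typing import List


def mode_dependent_flags(compilation_mode: str, pic: bool = True) -> List[str]:
    """Model the mode-dependent compile flags of this toolchain config."""
    enabled = {compilation_mode}
    if compilation_mode == "opt":
        # feature(name = "opt", implies = ["disable-assertions"])
        enabled.add("disable-assertions")
    # expand_if_available / expand_if_not_available = "pic"
    flags = ["-fPIC"] if pic else ["-fPIE"]
    if "disable-assertions" in enabled:
        flags.append("-DNDEBUG")
    if "opt" in enabled:
        flags += ["-g0", "-O2", "-ffunction-sections", "-fdata-sections"]
    if "dbg" in enabled:
        flags.append("-g")
    return flags


print(mode_dependent_flags("opt"))
print(mode_dependent_flags("dbg"))
print(mode_dependent_flags("fastbuild", pic=False))
```

With `bazel build -c opt`, this model predicts `-fPIC -DNDEBUG -g0 -O2 -ffunction-sections -fdata-sections`; `-c fastbuild` adds none of the mode-specific flags.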
================================================
FILE: build_deps/gpus/crosstool/crosstool_compiler_wrapper.tpl
================================================
#!/usr/bin/env python
# Copyright (c) 2023, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Crosstool wrapper for compiling CUDA programs.

SYNOPSIS:
  crosstool_compiler_wrapper [options passed in by cc_library()
                                or cc_binary() rule]

DESCRIPTION:
  This script is expected to be called by the cc_library() or cc_binary() bazel
  rules. When the option "-x cuda" is present in the list of arguments passed
  to this script, it invokes the nvcc CUDA compiler. Most arguments are passed
  as-is, as a single string, to nvcc's --compiler-options. When "-x cuda" is
  not present, this wrapper invokes the CPU compiler (CPU_COMPILER) with the
  input arguments, minus the wrapper's own --cuda_log flag.
"""

__author__ = 'keveman@google.com (Manjunath Kudlur)'

import os
import pipes
import re
import subprocess
import sys
from argparse import ArgumentParser

# Template values set by cuda_autoconf.
CPU_COMPILER = ('%{cpu_compiler}')
GCC_HOST_COMPILER_PATH = ('%{gcc_host_compiler_path}')

NVCC_PATH = '%{nvcc_path}'
PREFIX_DIR = os.path.dirname(GCC_HOST_COMPILER_PATH)
NVCC_VERSION = '%{cuda_version}'


def Log(s):
    print('gpus/crosstool: {0}'.format(s))


def GetOptionValue(argv, option):
    """Extract the list of values for option from the argv list.

  Args:
    argv: A list of strings, possibly the argv passed to main().
    option: The option whose value to extract, with the leading '-'.

  Returns:
    A list of values, either those directly following the option
    (e.g., -opt val1 val2) or those collected from multiple occurrences of
    the option (e.g., -opt val1 -opt val2).
  """

    parser = ArgumentParser()
    parser.add_argument(option, nargs='*', action='append')
    option = option.lstrip('-').replace('-', '_')
    args, _ = parser.parse_known_args(argv)
    if not args or not vars(args)[option]:
        return []
    else:
        return sum(vars(args)[option], [])


def GetHostCompilerOptions(argv):
    """Collect the host compiler options (-isystem, -iquote, -g, --sysroot, etc.) from argv.

  Args:
    argv: A list of strings, possibly the argv passed to main().

  Returns:
    The string that can be used as the --compiler-options to nvcc.
  """

    parser = ArgumentParser()
    parser.add_argument('-isystem', nargs='*', action='append')
    parser.add_argument('-iquote', nargs='*', action='append')
    parser.add_argument('--sysroot', nargs=1)
    parser.add_argument('-g', nargs='*', action='append')
    parser.add_argument('-fno-canonical-system-headers', action='store_true')
    parser.add_argument('-no-canonical-prefixes', action='store_true')

    args, _ = parser.parse_known_args(argv)

    opts = ''

    if args.isystem:
        opts += ' -isystem ' + ' -isystem '.join(sum(args.isystem, []))
    if args.iquote:
        opts += ' -iquote ' + ' -iquote '.join(sum(args.iquote, []))
    if args.g:
        opts += ' -g' + ' -g'.join(sum(args.g, []))
    if args.fno_canonical_system_headers:
        opts += ' -fno-canonical-system-headers'
    if args.no_canonical_prefixes:
        opts += ' -no-canonical-prefixes'
    if args.sysroot:
        opts += ' --sysroot ' + args.sysroot[0]

    return opts


def _update_options(nvcc_options):
    if NVCC_VERSION in ("7.0", ):
        return nvcc_options

    update_options = {"relaxed-constexpr": "expt-relaxed-constexpr"}
    return [
        update_options[opt] if opt in update_options else opt
        for opt in nvcc_options
    ]


def GetNvccOptions(argv):
    """Collect the -nvcc_options values from argv.

  Args:
    argv: A list of strings, possibly the argv passed to main().

  Returns:
    The string that can be passed directly to nvcc.
  """

    parser = ArgumentParser()
    parser.add_argument('-nvcc_options', nargs='*', action='append')

    args, _ = parser.parse_known_args(argv)

    if args.nvcc_options:
        options = _update_options(sum(args.nvcc_options, []))
        return ' '.join(['--' + a for a in options])
    return ''


def system(cmd):
    """Invokes cmd with os.system().

  Args:
    cmd: The command.

  Returns:
    The exit code if the process exited with exit() or -signal
    if the process was terminated by a signal.
  """
    retv = os.system(cmd)
    if os.WIFEXITED(retv):
        return os.WEXITSTATUS(retv)
    else:
        return -os.WTERMSIG(retv)


def InvokeNvcc(argv, log=False):
    """Call nvcc with arguments assembled from argv.

  Args:
    argv: A list of strings, possibly the argv passed to main().
    log: True if logging is requested.

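  Returns:
    1 if no source file or not exactly one output file is given; otherwise
    the system() result of the nvcc dependency-file command if that fails,
    else the system() result of the nvcc compile command.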

    host_compiler_options = GetHostCompilerOptions(argv)
    nvcc_compiler_options = GetNvccOptions(argv)
    opt_option = GetOptionValue(argv, '-O')
    m_options = GetOptionValue(argv, '-m')
    m_options = ''.join([' -m' + m for m in m_options if m in ['32', '64']])
    include_options = GetOptionValue(argv, '-I')
    out_file = GetOptionValue(argv, '-o')
    depfiles = GetOptionValue(argv, '-MF')
    defines = GetOptionValue(argv, '-D')
    defines = ''.join([' -D' + define for define in defines])
    undefines = GetOptionValue(argv, '-U')
    undefines = ''.join([' -U' + define for define in undefines])
    std_options = GetOptionValue(argv, '-std')
    nvcc_allowed_std_options = ["c++03", "c++11", "c++14"]
    nvcc_std_map = {}
    if int(NVCC_VERSION.split('.')[0]) >= 11:
        nvcc_std_map["c++1z"] = "c++17"
        nvcc_allowed_std_options += ["c++17", "c++1z"]
    std_options = ''.join([
        ' -std=' + (nvcc_std_map[define] if define in nvcc_std_map else define)
        for define in std_options if define in nvcc_allowed_std_options
    ][-1:])
    fatbin_options = ''.join([
        ' --fatbin-options=' + option
        for option in GetOptionValue(argv, '-Xcuda-fatbinary')
    ])

    # The list of source files gets passed after the -c option. I don't know of
    # any other reliable way to just get the list of source files to be compiled.
    src_files = GetOptionValue(argv, '-c')

    # Pass -w through from host to nvcc, but don't do anything fancier with
    # warnings-related flags, since they're not necessarily the same across
    # compilers.
    warning_options = ' -w' if '-w' in argv else ''

    if len(src_files) == 0:
        return 1
    if len(out_file) != 1:
        return 1

    opt = (' -O2' if
           (len(opt_option) > 0 and int(opt_option[0]) > 0) else ' -g')

    includes = (' -I ' + ' -I '.join(include_options)
                if len(include_options) > 0 else '')

    # Unfortunately, other options also start with -c, so keep only the
    # arguments that look like C/C++/CUDA source or header files. A raw string
    # avoids invalid-escape warnings, and every extension is anchored at the end.
    src_files = [
        f for f in src_files
        if re.search(r'\.(cpp|cc|c|cxx|C|cu|cuh)$', f)
    ]
    srcs = ' '.join(src_files)
    out = ' -o ' + out_file[0]

    nvccopts = '-D_FORCE_INLINES '
    capabilities_sm = set(GetOptionValue(argv, "--cuda-gpu-arch"))
    capabilities_compute = set(GetOptionValue(argv, '--cuda-include-ptx'))
    # When both "code=sm_xy" and "code=compute_xy" are requested for a single
    # arch, they can be combined using "code=sm_xy,compute_xy", which avoids a
    # redundant PTX generation during compilation.
    capabilities_both = capabilities_sm.intersection(capabilities_compute)
    for capability in capabilities_both:
        capability = capability[len('sm_'):]
        nvccopts += r'-gencode=arch=compute_%s,code=\"sm_%s,compute_%s\" ' % (
            capability, capability, capability)
    for capability in capabilities_sm - capabilities_both:
        capability = capability[len('sm_'):]
        nvccopts += r'-gencode=arch=compute_%s,\"code=sm_%s\" ' % (capability,
                                                                   capability)
    for capability in capabilities_compute - capabilities_both:
        capability = capability[len('sm_'):]
        nvccopts += r'-gencode=arch=compute_%s,\"code=compute_%s\" ' % (
            capability, capability)
    nvccopts += nvcc_compiler_options
    nvccopts += undefines
    nvccopts += defines
    nvccopts += std_options
    nvccopts += m_options
    nvccopts += warning_options
    # Force C++17 dialect (note, everything in just one string!)
    nvccopts += ' --std c++17 '
    nvccopts += fatbin_options

    if depfiles:
        # Generate the dependency file
        depfile = depfiles[0]
        cmd = (NVCC_PATH + ' ' + nvccopts + ' --compiler-options "' +
               host_compiler_options + '"' + ' --compiler-bindir=' +
               GCC_HOST_COMPILER_PATH + ' -I .' + ' -x cu ' + opt + includes +
               ' ' + srcs + ' -M -o ' + depfile)
        if log:
            Log(cmd)
        exit_status = system(cmd)
        if exit_status != 0:
            return exit_status

    cmd = (NVCC_PATH + ' ' + nvccopts + ' --compiler-options "' +
           host_compiler_options + ' -fPIC"' + ' --compiler-bindir=' +
           GCC_HOST_COMPILER_PATH + ' -I .' + ' -x cu ' + opt + includes +
           ' -c ' + srcs + out)

    # TODO(zhengxq): for some reason, 'gcc' needs this help to find 'as'.
    # Need to investigate and fix.
    cmd = 'PATH=' + PREFIX_DIR + ':$PATH ' + cmd
    if log:
        Log(cmd)
    return system(cmd)


def main():
    parser = ArgumentParser()
    parser.add_argument('-x', nargs=1)
    parser.add_argument('--cuda_log', action='store_true')
    args, leftover = parser.parse_known_args(sys.argv[1:])

    if args.x and args.x[0] == 'cuda':
        if args.cuda_log:
            Log('-x cuda')
        leftover = [pipes.quote(s) for s in leftover]
        if args.cuda_log:
            Log('using nvcc')
        return InvokeNvcc(leftover, log=args.cuda_log)

    # Strip our flags before passing through to the CPU compiler for files which
    # are not -x cuda. We can't just pass 'leftover' because it also strips -x.
    # We not only want to pass -x to the CPU compiler, but also keep it in its
    # relative location in the argv list (the compiler is actually sensitive to
    # this).
    cpu_compiler_flags = [
        flag for flag in sys.argv[1:] if not flag.startswith('--cuda_log')
    ]

    return subprocess.call([CPU_COMPILER] + cpu_compiler_flags)


if __name__ == '__main__':
    sys.exit(main())
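`InvokeNvcc` forwards at most one `-std=` value to nvcc. It keeps only the dialects in `nvcc_allowed_std_options`, rewrites `c++1z` to `c++17` when the nvcc major version is 11 or newer, and takes the last surviving value. The standalone sketch below reproduces that selection; `select_std_option` is a hypothetical name. The wrapper also appends ` --std c++17 ` to `nvccopts` unconditionally, which this sketch does not model.

```python
from typing import List


def select_std_option(std_values: List[str], nvcc_version: str) -> str:
    """Reproduce the -std filtering performed in InvokeNvcc."""
    allowed = ["c++03", "c++11", "c++14"]
    std_map = {}
    if int(nvcc_version.split(".")[0]) >= 11:
        std_map["c++1z"] = "c++17"
        allowed += ["c++17", "c++1z"]
    # Filter on the original spelling, then apply the rename.
    kept = [std_map.get(v, v) for v in std_values if v in allowed]
    # Only the last accepted value survives ([-1:] in the wrapper).
    return "".join(" -std=" + v for v in kept[-1:])


print(repr(select_std_option(["c++14", "c++1z"], "11.8")))  # ' -std=c++17'
print(repr(select_std_option(["c++1z"], "10.2")))           # ''
```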

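The `-gencode` loop in `InvokeNvcc` emits one combined SASS+PTX entry for architectures requested through both `--cuda-gpu-arch` and `--cuda-include-ptx`, and separate entries otherwise. Both option lists are expected in the `sm_XY` spelling, since the wrapper slices `len('sm_')` characters off each value. The sketch below rebuilds those arguments as an argv list. The strings contain the literal `"` characters that nvcc receives after the shell consumes the wrapper's `\"` escapes, and the output is sorted because the wrapper iterates over sets in unspecified order. `gencode_flags` is a hypothetical name.

```python
from typing import Iterable, List


def gencode_flags(sm_archs: Iterable[str], ptx_archs: Iterable[str]) -> List[str]:
    """Build -gencode arguments the way InvokeNvcc does, as nvcc receives them."""
    sm, ptx = set(sm_archs), set(ptx_archs)
    both = sm & ptx
    flags: List[str] = []
    # SASS and PTX for the same arch share one -gencode, avoiding a redundant PTX pass.
    for cap in sorted(both):
        n = cap[len("sm_"):]
        flags.append(f'-gencode=arch=compute_{n},code="sm_{n},compute_{n}"')
    # SASS only.
    for cap in sorted(sm - both):
        n = cap[len("sm_"):]
        flags.append(f'-gencode=arch=compute_{n},"code=sm_{n}"')
    # PTX only.
    for cap in sorted(ptx - both):
        n = cap[len("sm_"):]
        flags.append(f'-gencode=arch=compute_{n},"code=compute_{n}"')
    return flags


for flag in gencode_flags(["sm_80", "sm_90"], ["sm_90"]):
    print(flag)
```

Here `sm_90` gets a single entry carrying both SASS and PTX, while `sm_80` gets SASS only.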

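The wrapper's `system()` helper turns the raw wait status from `os.system()` into an exit code, or into a negated signal number when the shell was killed by a signal. On POSIX with Python 3.9 or later, `os.waitstatus_to_exitcode()` performs the same conversion and raises `ValueError` for other statuses, such as a stopped process. A possible equivalent, offered as a sketch rather than a change the repository makes:

```python
import os


def system(cmd: str) -> int:
    """Run cmd through the shell; return its exit code, or -signal if it was killed."""
    # os.waitstatus_to_exitcode (Python 3.9+) applies the same
    # WIFEXITED/WEXITSTATUS and WTERMSIG logic as the original helper.
    return os.waitstatus_to_exitcode(os.system(cmd))


print(system("exit 3"))  # 3 under a POSIX shell
```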
================================================
FILE: build_deps/gpus/cuda/BUILD
================================================


================================================
FILE: build_deps/gpus/cuda/BUILD.tpl
================================================
load(":build_defs.bzl", "cuda_header_library")
load("@bazel_skylib//:bzl_library.bzl", "bzl_library")
load("@bazel_skylib//lib:selects.bzl", "selects")
load("@bazel_skylib//rules:common_settings.bzl", "bool_flag")

licenses(["restricted"])  # MPL2, portions GPL v3, LGPL v3, BSD-like

package(default_visibility = ["//visibility:public"])

bool_flag(
    name = "enable_cuda",
    build_setting_default = False,
)

config_setting(
    name = "is_cuda_enabled",
    flag_values = {":enable_cuda": "True"},
)


# Config setting for whether CUDA support is built using nvcc.
#
# TODO(b/174244321), DEPRECATED: this target will be removed when all users
# have been converted to :is_cuda_enabled (most) or :is_cuda_compiler_nvcc.
selects.config_setting_group(
    name = "using_nvcc",
    match_all = [
        "//:is_cuda_enabled",
        "//:is_cuda_compiler_nvcc",
    ],
)

config_setting(
    name = "_opt",
    values = {"compilation_mode": "opt"},
    visibility = ["//visibility:private"],
)

# Provides CUDA headers for '#include "third_party/gpus/cuda/include/cuda.h"'
# All clients including TensorFlow should use these directives.
cuda_header_library(
    name = "cuda_headers",
    hdrs = [
        "cuda/cuda_config.h",
        ":cuda-include",
    ],
    include_prefix = "third_party/gpus",
    includes = [
        ".",  # required to include cuda/cuda/cuda_config.h as cuda/config.h
        "cuda/include",
    ],
)

cc_library(
    name = "cudart_static",
    srcs = ["cuda/lib/%{cudart_static_lib}"],
    linkopts = [
        "-ldl",
        "-lpthread",
        %{cudart_static_linkopt}
    ],
)

cc_library(
    name = "cuda_driver",
    srcs = ["cuda/lib/%{cuda_driver_lib}"],
)

cc_library(
    name = "cudart",
    srcs = ["cuda/lib/%{cudart_lib}"],
    data = ["cuda/lib/%{cudart_lib}"],
    linkstatic = 1,
)

cuda_header_library(
    name = "cublas_headers",
    hdrs = [":cublas-include"],
    include_prefix = "third_party/gpus/cuda/include",
    includes = ["cublas/include"],
    strip_include_prefix = "cublas/include",
    deps = [":cuda_headers"],
)

cuda_header_library(
    name = "cusolver_headers",
    hdrs = [":cusolver-include"],
    include_prefix = "third_party/gpus/cuda/include",
    includes = ["cusolver/include"],
    strip_include_prefix = "cusolver/include",
    deps = [":cuda_headers"],
)

cuda_header_library(
    name = "cufft_headers",
    hdrs = [":cufft-include"],
    include_prefix = "third_party/gpus/cuda/include",
    includes = ["cufft/include"],
    strip_include_prefix = "cufft/include",
    deps = [":cuda_headers"],
)

cuda_header_library(
    name = "cusparse_headers",
    hdrs = [":cusparse-include"],
    include_prefix = "third_party/gpus/cuda/include",
    includes = ["cusparse/include"],
    strip_include_prefix = "cusparse/include",
    deps = [":cuda_headers"],
)

cuda_header_library(
    name = "curand_headers",
    hdrs = [":curand-include"],
    include_prefix = "third_party/gpus/cuda/include",
    includes = ["curand/include"],
    strip_include_prefix = "curand/include",
    deps = [":cuda_headers"],
)

cc_library(
    name = "cublas",
    srcs = ["cuda/lib/%{cublas_lib}"],
    data = ["cuda/lib/%{cublas_lib}"],
    linkstatic = 1,
)

cc_library(
    name = "cublasLt",
    srcs = ["cuda/lib/%{cublasLt_lib}"],
    data = ["cuda/lib/%{cublasLt_lib}"],
    linkstatic = 1,
)

cc_library(
    name = "cusolver",
    srcs = ["cuda/lib/%{cusolver_lib}"],
    data = ["cuda/lib/%{cusolver_lib}"],
    linkopts = ["-lgomp"],
    linkstatic = 1,
)

cc_library(
    name = "cudnn",
    srcs = ["cuda/lib/%{cudnn_lib}"],
    data = ["cuda/lib/%{cudnn_lib}"],
    linkstatic = 1,
)

cc_library(
    name = "cudnn_header",
    hdrs = [":cudnn-include"],
    include_prefix = "third_party/gpus/cudnn",
    strip_include_prefix = "cudnn/include",
    deps = [":cuda_headers"],
)

cc_library(
    name = "cufft",
    srcs = ["cuda/lib/%{cufft_lib}"],
    data = ["cuda/lib/%{cufft_lib}"],
    linkstatic = 1,
)

cc_library(
    name = "curand",
    srcs = ["cuda/lib/%{curand_lib}"],
    data = ["cuda/lib/%{curand_lib}"],
    linkstatic = 1,
)

cc_library(
    name = "cuda",
    deps = [
        ":cublas",
        ":cublasLt",
        ":cuda_headers",
        ":cudart",
        ":cudnn",
        ":cufft",
        ":curand",
    ],
)

alias(
    name = "cub_headers",
    actual = "%{cub_actual}",
)

cuda_header_library(
    name = "cupti_headers",
    hdrs = [":cuda-extras"],
    include_prefix = "third_party/gpus",
    includes = ["cuda/extras/CUPTI/include/"],
    deps = [":cuda_headers"],
)

cc_library(
    name = "cupti_dsos",
    data = ["cuda/lib/%{cupti_lib}"],
)

cc_library(
    name = "cusparse",
    srcs = ["cuda/lib/%{cusparse_lib}"],
    data = ["cuda/lib/%{cusparse_lib}"],
    linkopts = ["-lgomp"],
    linkstatic = 1,
)

cc_library(
    name = "libdevice_root",
    data = [":cuda-nvvm"],
)

bzl_library(
    name = "build_defs_bzl",
    srcs = ["build_defs.bzl"],
    deps = [
        "@bazel_skylib//lib:selects",
    ],
)

py_library(
    name = "cuda_config_py",
    srcs = ["cuda/cuda_config.py"],
)

%{copy_rules}


================================================
FILE: build_deps/gpus/cuda/build_defs.bzl.tpl
================================================
# Macros for building CUDA code.
def cuda_default_copts():
    """Default options for all CUDA compilations."""
    return [
        "-x",
        "cuda",
        "-DUSE_CUDA=1",
        "-Xcuda-fatbinary=--compress-all",
    ] + %{cuda_extra_copts}


def cuda_gpu_architectures():
    """Returns a list of supported GPU architectures."""
    return %{cuda_gpu_architectures}


def cuda_header_library(name,
                        hdrs,
                        include_prefix=None,
                        strip_include_prefix=None,
                        deps=[],
                        **kwargs):
    """Generates a cc_library containing both virtual and system include paths.

    Generates both a header-only target with virtual includes plus the full
    target without virtual includes. This works around the fact that Bazel can't
    mix 'includes' and 'include_prefix' in the same target."""

    native.cc_library(
        name=name + "_virtual",
        hdrs=hdrs,
        include_prefix=include_prefix,
        strip_include_prefix=strip_include_prefix,
        deps=deps,
        visibility=["//visibility:private"],
    )

    native.cc_library(name=name,
                      textual_hdrs=hdrs,
                      deps=deps + [":%s_virtual" % name],
                      **kwargs)


def cuda_cc_library(copts=[], **kwargs):
    """Wrapper over cc_library which adds default CUDA options."""
    native.cc_library(copts=cuda_default_copts() + copts, **kwargs)


def cuda_cc_binary(copts=[], **kwargs):
    """Wrapper over cc_binary which adds default CUDA options."""
    native.cc_binary(copts=cuda_default_copts() + copts, **kwargs)


def cuda_cc_test(copts=[], **kwargs):
    """Wrapper over cc_test which adds default CUDA options."""
    native.cc_test(copts=cuda_default_copts() + copts, **kwargs)


================================================
FILE: build_deps/gpus/cuda/cuda_config.h.tpl
================================================
/*
 * Copyright (c) 2023, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef CUDA_CUDA_CONFIG_H_
#define CUDA_CUDA_CONFIG_H_

#define CUDA_VERSION "%{cuda_version}"
#define CUDART_VERSION "%{cudart_version}"
#define CUPTI_VERSION "%{cupti_version}"
#define CUBLAS_VERSION "%{cublas_version}"
#define CUSOLVER_VERSION "%{cusolver_version}"
#define CURAND_VERSION "%{curand_version}"
#define CUFFT_VERSION "%{cufft_version}"
#define CUSPARSE_VERSION "%{cusparse_version}"
#define CUDNN_VERSION "%{cudnn_version}"

#define CUDA_TOOLKIT_PATH "%{cuda_toolkit_path}"

#define CUDA_COMPUTE_CAPABILITIES %{cuda_compute_capabilities}

#endif  // CUDA_CUDA_CONFIG_H_


================================================
FILE: build_deps/gpus/cuda/cuda_config.py.tpl
================================================
# Copyright (c) 2023, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

config = %{cuda_config}


================================================
FILE: build_deps/gpus/find_cuda_config.py
================================================
# Copyright (c) 2023, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Prints CUDA library and header directories and versions found on the system.

The script searches for CUDA library and header files on the system, inspects
them to determine their version and prints the configuration to stdout.
The paths to inspect and the required versions are specified through environment
variables. If no valid configuration is found, the script prints to stderr and
returns an error code.

The list of libraries to find is specified as arguments. Supported libraries are
CUDA (includes cuBLAS), cuDNN, NCCL, and TensorRT.

The script takes a list of base directories specified by the CUDA_PATHS
environment variable as a comma-separated glob list. The script looks for headers
and library files in a hard-coded set of subdirectories from these base paths.
If CUDA_PATHS is not specified, an OS-specific default is used:

  Linux:   /usr/local/cuda, /usr, and paths from 'ldconfig -p'.
  Windows: CUDA_PATH environment variable, or
           C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\*

For backwards compatibility, some libraries also use alternative base
directories from other environment variables if they are specified. List of
library-specific environment variables:

  Library   Version env variable  Additional base directories
  ----------------------------------------------------------------
  CUDA      CUDA_VERSION          CUDA_TOOLKIT_PATH
  cuBLAS    CUBLAS_VERSION        CUDA_TOOLKIT_PATH
  cuDNN     CUDNN_VERSION         CUDNN_INSTALL_PATH
  NCCL      NCCL_VERSION          NCCL_INSTALL_PATH, NCCL_HDR_PATH
  TensorRT  TENSORRT_VERSION      TENSORRT_INSTALL_PATH

Version environment variables can be of the form 'x' or 'x.y' to request a
specific version; leave them empty or unset to accept any version.

The output of a foun
SYMBOL INDEX (60 symbols across 5 files)

FILE: build_deps/gpus/check_cuda_libs.py
  class ConfigError (line 39) | class ConfigError(Exception):
  function check_cuda_lib (line 43) | def check_cuda_lib(path, check_soname=True):
  function main (line 67) | def main():

FILE: build_deps/gpus/find_cuda_config.py
  class ConfigError (line 73) | class ConfigError(Exception):
  function _is_linux (line 77) | def _is_linux():
  function _is_macos (line 81) | def _is_macos():
  function _matches_version (line 85) | def _matches_version(actual_version, required_version):
  function _at_least_version (line 112) | def _at_least_version(actual_version, required_version):
  function _get_header_version (line 118) | def _get_header_version(path, name):
  function _cartesian_product (line 127) | def _cartesian_product(first, second):
  function _get_ld_config_paths (line 132) | def _get_ld_config_paths():
  function _get_default_cuda_paths (line 150) | def _get_default_cuda_paths(cuda_version):
  function _header_paths (line 162) | def _header_paths():
  function _library_paths (line 175) | def _library_paths():
  function _not_found_error (line 189) | def _not_found_error(base_paths, relative_paths, filepattern):
  function _find_file (line 199) | def _find_file(base_paths, relative_paths, filepattern):
  function _find_library (line 206) | def _find_library(base_paths, library_name, required_version):
  function _find_versioned_file (line 217) | def _find_versioned_file(base_paths, relative_paths, filepatterns,
  function _find_header (line 233) | def _find_header(base_paths, header_name, required_version, get_version):
  function _find_cuda_config (line 239) | def _find_cuda_config(base_paths, required_version):
  function _find_cublas_config (line 304) | def _find_cublas_config(base_paths, required_version, cuda_version):
  function _find_cusolver_config (line 338) | def _find_cusolver_config(base_paths, required_version, cuda_version):
  function _find_curand_config (line 371) | def _find_curand_config(base_paths, required_version, cuda_version):
  function _find_cufft_config (line 402) | def _find_cufft_config(base_paths, required_version, cuda_version):
  function _find_cudnn_config (line 433) | def _find_cudnn_config(base_paths, required_version):
  function _find_cusparse_config (line 457) | def _find_cusparse_config(base_paths, required_version, cuda_version):
  function _find_nccl_config (line 488) | def _find_nccl_config(base_paths, required_version):
  function _find_tensorrt_config (line 509) | def _find_tensorrt_config(base_paths, required_version):
  function _list_from_env (line 537) | def _list_from_env(env_name, default=[]):
  function _get_legacy_path (line 544) | def _get_legacy_path(env_name, default=[]):
  function _normalize_path (line 559) | def _normalize_path(path):
  function find_cuda_config (line 564) | def find_cuda_config():
  function main (line 638) | def main():

FILE: include/merlin/debug.hpp
  type nv (line 24) | namespace nv {
    type merlin (line 25) | namespace merlin {
      class CudaException (line 27) | class CudaException : public std::runtime_error {
        method CudaException (line 29) | CudaException(const std::string& what) : runtime_error(what) {}
      function cuda_check_ (line 32) | inline void cuda_check_(cudaError_t val, const char* file, int line) {
      class MerlinException (line 50) | class MerlinException : public std::runtime_error {
        method MerlinException (line 52) | MerlinException(const std::string& what) : runtime_error(what) {}
      function merlin_check_ (line 56) | inline void merlin_check_(bool cond, const Msg& msg, const char* file,

FILE: include/merlin/multi_vector.hpp
  type nv (line 26) | namespace nv {
    type merlin (line 27) | namespace merlin {
      class MultiVector (line 45) | class MultiVector {
        method MultiVector (line 51) | explicit MultiVector(Lens... lens) {
        method get (line 62) | auto get(uint8_t* data) {
        method length (line 67) | size_t length(size_t idx) const { return lengths_[idx]; }
        method offset (line 69) | size_t offset(size_t idx) const { return offsets_[idx]; }
        method total_size (line 71) | size_t total_size() const { return total_size_; }
        method align_up (line 78) | constexpr size_t align_up(size_t n, size_t alignment) {
        method compute_offsets (line 82) | void compute_offsets() {
      function get_vector (line 95) | auto get_vector(MultiVector<Ts...>& mv, uint8_t* data) {

FILE: include/merlin_localfile.hpp
  type nv (line 24) | namespace nv {
    type merlin (line 25) | namespace merlin {
      class LocalKVFile (line 45) | class LocalKVFile : public BaseKVFile<K, V, M> {
        method LocalKVFile (line 47) | LocalKVFile() : keys_fp_(nullptr), values_fp_(nullptr), scores_fp_...
        method open (line 61) | bool open(const std::string& keys_path, const std::string& values_...
        method close (line 85) | void close() noexcept {
        method read (line 115) | size_t read(const size_t n, const size_t dim, K* keys, V* vectors,
        method write (line 141) | size_t write(const size_t n, const size_t dim, const K* keys,
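The find_cuda_config.py docstring above describes two input conventions. CUDA_PATHS is a comma-separated glob list of base directories, and version variables take the form 'x' or 'x.y', where empty or unset means any version. The following is a minimal Python sketch of those documented rules. `matches_version` and `expand_base_paths` are hypothetical helpers written for illustration, not the script's own `_matches_version` or path functions listed in the index, whose exact behavior may differ:

```python
import glob
import os


def matches_version(actual, required):
    """Hypothetical helper: True if each dotted component of `required`
    equals the corresponding leading component of `actual`.

    An empty or unset `required` accepts any version; a missing `actual`
    never matches.
    """
    if actual is None:
        return False
    required = (required or '').strip()
    if not required:
        return True
    req_parts = required.split('.')
    act_parts = actual.strip().split('.')
    return act_parts[:len(req_parts)] == req_parts


def expand_base_paths(env_value):
    """Hypothetical helper: expand a comma-separated glob list (the format
    described for CUDA_PATHS) into existing directories, in listed order,
    without duplicates."""
    result = []
    for pattern in env_value.split(','):
        pattern = pattern.strip()
        if not pattern:
            continue
        for path in sorted(glob.glob(pattern)):
            if os.path.isdir(path) and path not in result:
                result.append(path)
    return result


print(matches_version('11.8', '11'))    # True: '11' is the leading component
print(matches_version('11.8', '11.2'))  # False: minor versions differ
print(matches_version('12.1', ''))      # True: empty accepts any version
print(matches_version('110.0', '11'))   # False: compares components, not characters
print(expand_base_paths(os.environ.get('CUDA_PATHS', '/usr/local/cuda*,/usr')))
```

Comparing dotted components, rather than raw string prefixes, keeps a request for '11' from accidentally accepting '110.0'. Whether the real script draws that distinction is not shown in this excerpt.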
Condensed preview: 104 files, each showing path, character count, and a content snippet.
[
  {
    "path": ".bazeliskrc",
    "chars": 24,
    "preview": "USE_BAZEL_VERSION=5.0.0\n"
  },
  {
    "path": ".bazelrc",
    "chars": 760,
    "preview": "build -c opt\nbuild --copt -O3\nbuild --copt -pthread\nbuild --linkopt -pthread\nbuild --linkopt -ldl\nbuild --incompatible_l"
  },
  {
    "path": ".clang-format",
    "chars": 91,
    "preview": "BasedOnStyle: Google\nDerivePointerAlignment: false\nIncludeBlocks: Merge\nSortIncludes: true\n"
  },
  {
    "path": ".github/workflows/blossom-ci.yml",
    "chars": 3067,
    "preview": "# Copyright (c) 2020-2022, NVIDIA CORPORATION.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you"
  },
  {
    "path": ".github/workflows/docs-build.yaml",
    "chars": 1230,
    "preview": "name: docs-build\n\non:\n  pull_request:\n    branches: [master]\n\njobs:\n  build:\n    runs-on: \"ubuntu-latest\"\n\n    steps:\n  "
  },
  {
    "path": ".github/workflows/docs-preview-pr.yaml",
    "chars": 245,
    "preview": "name: docs-preview-pr\n\non:\n  workflow_run:\n    workflows: [docs-build]\n    types: [completed]\n\nenv:\n  WF_ID: ${{ github."
  },
  {
    "path": ".github/workflows/docs-remove-stale-reviews.yaml",
    "chars": 248,
    "preview": "name: docs-remove-stale-reviews\n\non:\n  schedule:\n    # 42 minutes after 0:00 UTC on Sundays\n    - cron: \"42 0 * * 0\"\n  w"
  },
  {
    "path": ".github/workflows/docs-sched-rebuild.yaml",
    "chars": 5306,
    "preview": "name: docs-sched-rebuild\n\non:\n  push:\n    branches: [master]\n    tags:\n      - v*\n  workflow_dispatch:\n\njobs:\n  build:\n "
  },
  {
    "path": ".gitignore",
    "chars": 131,
    "preview": ".DS_Store\n.idea\n.vscode\nbuild\n.clwb\ncmake-build-debug/\ndocs/build\ndocs/source/README.md\ndocs/source/CONTRIBUTING.md\ndocs"
  },
  {
    "path": ".gitmodules",
    "chars": 120,
    "preview": "[submodule \"tests/googletest\"]\n\tpath = tests/googletest\n\turl = https://github.com/google/googletest.git\n\tignore = dirty\n"
  },
  {
    "path": "CMakeLists.txt",
    "chars": 8677,
    "preview": "# Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# Apache-2.0\n#\n# Licensed under the Apache Li"
  },
  {
    "path": "CONTRIBUTING.md",
    "chars": 1507,
    "preview": "# Contributing\n\n## About HierarchicalKV\n\nHierarchicalKV is a part of NVIDIA Merlin and provides hierarchical key-value s"
  },
  {
    "path": "LICENSE",
    "chars": 11347,
    "preview": "                                 Apache License\n                           Version 2.0, January 2004\n                   "
  },
  {
    "path": "README.md",
    "chars": 18418,
    "preview": "# [NVIDIA HierarchicalKV(Beta)](https://github.com/NVIDIA-Merlin/HierarchicalKV)\n\n[![Version](https://img.shields.io/git"
  },
  {
    "path": "STYLE_GUIDE.md",
    "chars": 818,
    "preview": "#### C++\nC++ code should conform to [Google C++ Style Guide](https://google.github.io/styleguide/cppguide.html).\n\nHierar"
  },
  {
    "path": "WORKSPACE",
    "chars": 456,
    "preview": "workspace(name = \"HierarchicalKV\")\n\nload(\"@bazel_tools//tools/build_defs/repo:http.bzl\", \"http_archive\")\nload(\"//build_d"
  },
  {
    "path": "bazel_build.sh",
    "chars": 141,
    "preview": "#!/bin/bash\n\n# Usage : `./bazel_build.sh` or `bash bazel_build.sh`\nset -e\nexport $(cat .bazeliskrc | xargs)\n\nbazel build"
  },
  {
    "path": "benchmark/BUILD",
    "chars": 469,
    "preview": "load(\"@local_config_cuda//cuda:build_defs.bzl\", \"cuda_cc_library\")\n\ncc_binary(\n    name = \"benchmark_util\",\n    deps = ["
  },
  {
    "path": "benchmark/benchmark_util.cuh",
    "chars": 8088,
    "preview": "/*\n * Copyright (c) 2023, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * y"
  },
  {
    "path": "benchmark/dual_bucket_benchmark.cc.cu",
    "chars": 5722,
    "preview": "/*\n * Copyright (c) 2024, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * y"
  },
  {
    "path": "benchmark/find_with_missed_keys_benchmark.cc.cu",
    "chars": 6705,
    "preview": "/*\n * Copyright (c) 2024, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * y"
  },
  {
    "path": "benchmark/merlin_hashtable_benchmark.cc.cu",
    "chars": 23793,
    "preview": "/*\n * Copyright (c) 2022, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * y"
  },
  {
    "path": "build_deps/gpus/BUILD",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "build_deps/gpus/check_cuda_libs.py",
    "chars": 2944,
    "preview": "# Copyright (c) 2023, NVIDIA CORPORATION.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may "
  },
  {
    "path": "build_deps/gpus/configure.bzl",
    "chars": 43905,
    "preview": "\"\"\"Repository rule for CUDA autoconfiguration.\n\n`cuda_configure` depends on the following environment variables:\n\n  * `N"
  },
  {
    "path": "build_deps/gpus/crosstool/BUILD",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "build_deps/gpus/crosstool/BUILD.tpl",
    "chars": 3467,
    "preview": "# This file is expanded from a template by cuda_configure.bzl\n# Update cuda_configure.bzl#verify_build_defines when addi"
  },
  {
    "path": "build_deps/gpus/crosstool/cc_toolchain_config.bzl.tpl",
    "chars": 25793,
    "preview": "\"\"\"cc_toolchain_config rule for configuring CUDA toolchains on Linux, Mac, and Windows.\"\"\"\n\nload(\n    \"@bazel_tools//too"
  },
  {
    "path": "build_deps/gpus/crosstool/crosstool_compiler_wrapper.tpl",
    "chars": 10732,
    "preview": "#!/usr/bin/env python\n# Copyright (c) 2023, NVIDIA CORPORATION.\n#\n# Licensed under the Apache License, Version 2.0 (the "
  },
  {
    "path": "build_deps/gpus/cuda/BUILD",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "build_deps/gpus/cuda/BUILD.tpl",
    "chars": 5122,
    "preview": "load(\":build_defs.bzl\", \"cuda_header_library\")\nload(\"@bazel_skylib//:bzl_library.bzl\", \"bzl_library\")\nload(\"@bazel_skyli"
  },
  {
    "path": "build_deps/gpus/cuda/build_defs.bzl.tpl",
    "chars": 1803,
    "preview": "# Macros for building CUDA code.\ndef cuda_default_copts():\n    \"\"\"Default options for all CUDA compilations.\"\"\"\n    retu"
  },
  {
    "path": "build_deps/gpus/cuda/cuda_config.h.tpl",
    "chars": 1197,
    "preview": "/*\n * Copyright (c) 2023, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * y"
  },
  {
    "path": "build_deps/gpus/cuda/cuda_config.py.tpl",
    "chars": 615,
    "preview": "# Copyright (c) 2023, NVIDIA CORPORATION.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may "
  },
  {
    "path": "build_deps/gpus/find_cuda_config.py",
    "chars": 24400,
    "preview": "# Copyright (c) 2023, NVIDIA CORPORATION.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may "
  },
  {
    "path": "build_deps/remote_config/BUILD",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "build_deps/remote_config/BUILD.tpl",
    "chars": 832,
    "preview": "# Each platform creates a constraint @<platform>//:platform_constraint that\n# is listed in its constraint_values; rule t"
  },
  {
    "path": "build_deps/remote_config/common.bzl",
    "chars": 8622,
    "preview": "\"\"\"Functions common across configure rules.\"\"\"\n\nBAZEL_SH = \"BAZEL_SH\"\nPYTHON_BIN_PATH = \"PYTHON_BIN_PATH\"\nPYTHON_LIB_PAT"
  },
  {
    "path": "build_deps/remote_config/remote_platform_configure.bzl",
    "chars": 1769,
    "preview": "\"\"\"Repository rule to create a platform for a docker image to be used with RBE.\"\"\"\n\n\ndef _remote_platform_configure_impl"
  },
  {
    "path": "cmake/modules/ClangFormat.cmake",
    "chars": 1282,
    "preview": "# Copyright Tomas Zeman 2018.\n# Distributed under the Boost Software License, Version 1.0.\n# (See accompanying file LICE"
  },
  {
    "path": "docs/Makefile",
    "chars": 785,
    "preview": "# Minimal makefile for Sphinx documentation\n#\n\n# You can set these variables from the command line, and also\n# from the "
  },
  {
    "path": "docs/README.md",
    "chars": 4132,
    "preview": "# Documentation\n\nThis folder contains the scripts necessary to build the documentation for HierarchicalKV.\nYou can view "
  },
  {
    "path": "docs/make.bat",
    "chars": 764,
    "preview": "@ECHO OFF\n\npushd %~dp0\n\nREM Command file for Sphinx documentation\n\nif \"%SPHINXBUILD%\" == \"\" (\n\tset SPHINXBUILD=sphinx-bu"
  },
  {
    "path": "docs/requirements-doc.txt",
    "chars": 443,
    "preview": "# packages necessary to run tests and push PRs\n# assumes requirements for nvtabular logic are already installed\n\nwheel\n\n"
  },
  {
    "path": "docs/source/_static/.gitkeep",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "docs/source/_static/css/banner.css",
    "chars": 557,
    "preview": ".wy-nav-content {\n    margin: 0;\n    background: #fcfcfc;\n    padding-top: 40px;\n}\n\n.wy-side-nav-search {\n    display: b"
  },
  {
    "path": "docs/source/_static/css/custom.css",
    "chars": 382,
    "preview": "dl.cpp > dt > span.pre { padding-right: 2px; }\n\n/* dl.cpp > dt > a > span.pre { padding-right: 2px; } */\n\ndl > dt > em >"
  },
  {
    "path": "docs/source/_templates/footer.html",
    "chars": 907,
    "preview": "{% extends '!footer.html' %}\n{% block contentinfo %}\n{{ super() }}\n<p>\n<a href=\"https://www.nvidia.com/en-us/about-nvidi"
  },
  {
    "path": "docs/source/_templates/versions.html",
    "chars": 806,
    "preview": "{%- if current_version %}\n<div class=\"rst-versions\" data-toggle=\"rst-versions\" role=\"note\" aria-label=\"versions\">\n  <spa"
  },
  {
    "path": "docs/source/conf.py",
    "chars": 5478,
    "preview": "\"\"\"\n Copyright (c) 2021, NVIDIA CORPORATION.\n\n Licensed under the Apache License, Version 2.0 (the \"License\");\n you may "
  },
  {
    "path": "docs/source/index.rst",
    "chars": 805,
    "preview": "Merlin Key-Value Storage\n========================\n\nMerlin Key-Value Storage is an open source library that provides hier"
  },
  {
    "path": "docs/source/toc.yaml",
    "chars": 472,
    "preview": "root: index\nsubtrees:\n  - caption: Contents\n    entries:\n      - file: README.md\n        title: Introduction\n      - fil"
  },
  {
    "path": "include/BUILD",
    "chars": 555,
    "preview": "load(\"@local_config_cuda//cuda:build_defs.bzl\", \"cuda_cc_library\")\n\ncuda_cc_library(\n    name = \"merlin_localfile\",\n    "
  },
  {
    "path": "include/merlin/BUILD",
    "chars": 804,
    "preview": "load(\"@local_config_cuda//cuda:build_defs.bzl\", \"cuda_cc_library\")\n\ncuda_cc_library(\n    name = \"types_and_utils\",\n    s"
  },
  {
    "path": "include/merlin/allocator.cuh",
    "chars": 4516,
    "preview": "/*\n * Copyright (c) 2022, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * y"
  },
  {
    "path": "include/merlin/array_kernels.cuh",
    "chars": 4479,
    "preview": "/*\n * Copyright (c) 2023, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * y"
  },
  {
    "path": "include/merlin/core_kernels/BUILD",
    "chars": 332,
    "preview": "load(\"@local_config_cuda//cuda:build_defs.bzl\", \"cuda_cc_library\")\n\ncuda_cc_library(\n    name = \"core_kernels\",\n    srcs"
  },
  {
    "path": "include/merlin/core_kernels/accum_or_assign.cuh",
    "chars": 11689,
    "preview": "/*\n * Copyright (c) 2023, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * y"
  },
  {
    "path": "include/merlin/core_kernels/contains.cuh",
    "chars": 11425,
    "preview": "/*\n * Copyright (c) 2023, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * y"
  },
  {
    "path": "include/merlin/core_kernels/dual_bucket_lookup.cuh",
    "chars": 29842,
    "preview": "/*\n * Copyright (c) 2024, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * y"
  },
  {
    "path": "include/merlin/core_kernels/dual_bucket_upsert.cuh",
    "chars": 33454,
    "preview": "/*\n * Copyright (c) 2024, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * y"
  },
  {
    "path": "include/merlin/core_kernels/dual_bucket_utils.cuh",
    "chars": 3670,
    "preview": "/*\n * Copyright (c) 2024, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * y"
  },
  {
    "path": "include/merlin/core_kernels/find_or_insert.cuh",
    "chars": 76797,
    "preview": "/*\n * Copyright (c) 2023, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * y"
  },
  {
    "path": "include/merlin/core_kernels/find_ptr_or_insert.cuh",
    "chars": 15280,
    "preview": "/*\n * Copyright (c) 2023, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * y"
  },
  {
    "path": "include/merlin/core_kernels/group_lock_kernels.cuh",
    "chars": 4616,
    "preview": "/*\n * Copyright (c) 2023, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * y"
  },
  {
    "path": "include/merlin/core_kernels/kernel_utils.cuh",
    "chars": 34835,
    "preview": "/*\n * Copyright (c) 2023, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * y"
  },
  {
    "path": "include/merlin/core_kernels/lookup.cuh",
    "chars": 49693,
    "preview": "/*\n * Copyright (c) 2023, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * y"
  },
  {
    "path": "include/merlin/core_kernels/lookup_ptr.cuh",
    "chars": 18127,
    "preview": "/*\n * Copyright (c) 2023, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * y"
  },
  {
    "path": "include/merlin/core_kernels/update.cuh",
    "chars": 37607,
    "preview": "/*\n * Copyright (c) 2023, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * y"
  },
  {
    "path": "include/merlin/core_kernels/update_score.cuh",
    "chars": 24682,
    "preview": "/*\n * Copyright (c) 2023, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * y"
  },
  {
    "path": "include/merlin/core_kernels/update_values.cuh",
    "chars": 34498,
    "preview": "/*\n * Copyright (c) 2023, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * y"
  },
  {
    "path": "include/merlin/core_kernels/upsert.cuh",
    "chars": 72952,
    "preview": "/*\n * Copyright (c) 2023, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * y"
  },
  {
    "path": "include/merlin/core_kernels/upsert_and_evict.cuh",
    "chars": 77398,
    "preview": "/*\n * Copyright (c) 2023, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * y"
  },
  {
    "path": "include/merlin/core_kernels.cuh",
    "chars": 53960,
    "preview": "/*\n * Copyright (c) 2022, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * y"
  },
  {
    "path": "include/merlin/debug.hpp",
    "chars": 2230,
    "preview": "/*\n * Copyright (c) 2022, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * y"
  },
  {
    "path": "include/merlin/flexible_buffer.cuh",
    "chars": 1482,
    "preview": "/*\n * Copyright (c) 2022, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * y"
  },
  {
    "path": "include/merlin/group_lock.cuh",
    "chars": 10480,
    "preview": "/*\n * Copyright (c) 2023, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * y"
  },
  {
    "path": "include/merlin/memory_pool.cuh",
    "chars": 20676,
    "preview": "/*\n * Copyright (c) 2022, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * y"
  },
  {
    "path": "include/merlin/multi_vector.hpp",
    "chars": 2703,
    "preview": "/*\n * Copyright (c) 2025, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * y"
  },
  {
    "path": "include/merlin/optimizers.cuh",
    "chars": 2803,
    "preview": "/*\n * Copyright (c) 2022, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * y"
  },
  {
    "path": "include/merlin/types.cuh",
    "chars": 11728,
    "preview": "/*\n * Copyright (c) 2022, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * y"
  },
  {
    "path": "include/merlin/utils.cuh",
    "chars": 9720,
    "preview": "/*\n * Copyright (c) 2022, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * y"
  },
  {
    "path": "include/merlin_hashtable.cuh",
    "chars": 165832,
    "preview": "/*\n * Copyright (c) 2022, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * y"
  },
  {
    "path": "include/merlin_localfile.hpp",
    "chars": 5397,
    "preview": "/*\n * Copyright (c) 2022, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * y"
  },
  {
    "path": "run_all_tests.sh",
    "chars": 371,
    "preview": "#!/bin/bash\n\n# Usage : `bash run_all_tests.sh`\n\n# Search for all binary files that end with \"test\"\nfiles=$(find ./build/"
  },
  {
    "path": "tests/accum_or_assign_test.cc.cu",
    "chars": 116784,
    "preview": "/*\n * Copyright (c) 2023, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * y"
  },
  {
    "path": "tests/assign_score_test.cc.cu",
    "chars": 74374,
    "preview": "/*\n * Copyright (c) 2023, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * y"
  },
  {
    "path": "tests/assign_values_test.cc.cu",
    "chars": 30099,
    "preview": "/*\n * Copyright (c) 2023, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * y"
  },
  {
    "path": "tests/dual_bucket_test.cc.cu",
    "chars": 68402,
    "preview": "/*\n * Copyright (c) 2024, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * y"
  },
  {
    "path": "tests/dynamic_max_capacity_test.cc.cu",
    "chars": 4938,
    "preview": "/*\n * Copyright (c) 2023, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * y"
  },
  {
    "path": "tests/export_batch_if_test.cc.cu",
    "chars": 11059,
    "preview": "#include <cooperative_groups.h>\n#include <math.h>\n#include <stdio.h>\n#include <stdlib.h>\n#include <algorithm>\n#include <"
  },
  {
    "path": "tests/find_or_insert_ptr_lock_test.cc.cu",
    "chars": 122803,
    "preview": "/*\n * Copyright (c) 2025, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * y"
  },
  {
    "path": "tests/find_or_insert_ptr_test.cc.cu",
    "chars": 139003,
    "preview": "/*\n * Copyright (c) 2023, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * y"
  },
  {
    "path": "tests/find_or_insert_test.cc.cu",
    "chars": 143547,
    "preview": "/*\n * Copyright (c) 2023, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * y"
  },
  {
    "path": "tests/find_with_missed_keys_test.cc.cu",
    "chars": 7052,
    "preview": "/*\n * Copyright (c) 2024, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * y"
  },
  {
    "path": "tests/group_lock_test.cc.cu",
    "chars": 5523,
    "preview": "/*\n * Copyright (c) 2023, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * y"
  },
  {
    "path": "tests/insert_and_evict_test.cc.cu",
    "chars": 70882,
    "preview": "/*\n * Copyright (c) 2023, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * y"
  },
  {
    "path": "tests/lock_unlock_test.cc.cu",
    "chars": 3611,
    "preview": "/*\n * Copyright (c) 2025, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * y"
  },
  {
    "path": "tests/memory_pool_test.cc.cu",
    "chars": 18921,
    "preview": "/*\n * Copyright (c) 2022, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * y"
  },
  {
    "path": "tests/merlin_hashtable_test.cc.cu",
    "chars": 148362,
    "preview": "/*\n * Copyright (c) 2022, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * y"
  },
  {
    "path": "tests/reserved_keys_test.cc.cu",
    "chars": 3231,
    "preview": "/*\n * Copyright (c) 2024, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * y"
  },
  {
    "path": "tests/save_and_load_test.cc.cu",
    "chars": 4769,
    "preview": "/*\n * Copyright (c) 2022, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * y"
  },
  {
    "path": "tests/test_util.cuh",
    "chars": 23894,
    "preview": "/*\n * Copyright (c) 2022, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * y"
  },
  {
    "path": "tests/uint32_score_test.cc.cu",
    "chars": 10641,
    "preview": "#include <gtest/gtest.h>\n#include <algorithm>\n#include <cstdint>\n#include <iostream>\n#include <limits>\n#include <memory>"
  }
]

About this extraction

This page contains the full source code of the NVIDIA-Merlin/HierarchicalKV GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 104 files (2.0 MB), approximately 526.9k tokens, and a symbol index with 60 extracted functions, classes, methods, constants, and types.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.