[
  {
    "path": ".clang-format",
    "content": "################################################################################\n## Copyright 2014-2025 Lawrence Livermore National Security, LLC and other\n## LBANN Project Developers. See the top-level LICENSE file for details.\n##\n## SPDX-License-Identifier: Apache-2.0\n################################################################################\n\nLanguage: Cpp\nBasedOnStyle: LLVM\nAccessModifierOffset: -2\nAlignAfterOpenBracket: Align\nAlignArrayOfStructures: None\nAlignConsecutiveMacros: None\nAlignConsecutiveAssignments: None\nAlignConsecutiveBitFields: None\nAlignConsecutiveDeclarations: None\nAlignEscapedNewlines: Right\nAlignOperands: Align\nAlignTrailingComments: true\nAllowAllArgumentsOnNextLine: true\n# AllowAllConstructorInitializersOnNextLine: true\nAllowAllParametersOfDeclarationOnNextLine: true\nAllowShortEnumsOnASingleLine: true\nAllowShortBlocksOnASingleLine: Always\nAllowShortCaseLabelsOnASingleLine: true\nAllowShortFunctionsOnASingleLine: InlineOnly\nAllowShortIfStatementsOnASingleLine: Never\nAllowShortLambdasOnASingleLine: Inline\nAllowShortLoopsOnASingleLine: false\nAlwaysBreakAfterReturnType: None\nAlwaysBreakBeforeMultilineStrings: false\nAlwaysBreakTemplateDeclarations: Yes\nBinPackArguments: false\nBinPackParameters: false\nBraceWrapping:\n  AfterCaseLabel: true\n  AfterClass: true\n  AfterControlStatement: true\n  AfterEnum: true\n  AfterFunction: true\n  AfterNamespace: true\n  AfterStruct: true\n  AfterUnion: true\n  AfterExternBlock: true\n  BeforeCatch: true\n  BeforeElse: true\n  BeforeLambdaBody: false\n  BeforeWhile:     false\n  IndentBraces: false\n  SplitEmptyFunction: false\n  SplitEmptyRecord: false\n  SplitEmptyNamespace: false\nBreakBeforeBinaryOperators: NonAssignment\nBreakBeforeConceptDeclarations: true\nBreakBeforeBraces: Custom\nBreakBeforeInheritanceComma: false\nBreakBeforeTernaryOperators: true\nBreakConstructorInitializersBeforeComma: false\nBreakConstructorInitializers: BeforeColon\nBreakInheritanceList: BeforeColon\nBreakStringLiterals: true\nColumnLimit: 80\nCommentPragmas: \"^ H2_DISPATCH\"\nCompactNamespaces: false\nConstructorInitializerAllOnOneLineOrOnePerLine: true\nConstructorInitializerIndentWidth: 2\nContinuationIndentWidth: 2\nCpp11BracedListStyle: true\nDerivePointerAlignment: false\nDisableFormat: false\nFixNamespaceComments: true\n# TypenameMacros: ['TODO']\nIncludeBlocks: Regroup\n# Hopefully this will keep the config file at the top...\nIncludeCategories:\n  - Regex:           '^<catch2/catch.hpp>'\n    Priority:        -1                         # Always up front\n  - Regex:           '^((<|\").+_(config|export)\\.h(pp)?(>|\"))'        # Configure headers\n    Priority:        -1\n  - Regex:           '^((<|\")lbannv2/)'   # Project headers\n    Priority:         2\n  - Regex:           '^((<|\")h2/)'   # Project headers\n    Priority:         2\n  - Regex:           '^\"[[:alnum:]_.]+\\.(hpp|cuh)\"'   # Project headers\n    Priority:         3\n  - Regex:           '^<[[:alnum:]_.]+\\.(hpp|cuh)>'  # \"Normal headers\"\n    Priority:         4\n  - Regex:           '^(<(ATen|c10|torch|pybind11|python)/)'\n    Priority:         4\n  - Regex:           '<[[:alnum:]_]+>'          # STL Headers last\n    Priority:         6\n  - Regex:           '*.h'\n    Priority:         7\nIncludeIsMainRegex: \"(Test)?$\"\nIndentCaseLabels: false\nIndentPPDirectives: None\nIndentWidth: 2\nIndentWrappedFunctionNames: false\nKeepEmptyLinesAtTheStartOfBlocks: false\nMaxEmptyLinesToKeep: 1\nNamespaceIndentation: None\n# 
PenaltyBreakAssignment\n# PenaltyBreakBeforeFirstCallParameter\n# PenaltyBreakComment\n# PenaltyBreakFirstLessLess\n# PenaltyBreakString\n# PenaltyBreakTemplateDeclaration\n# PenaltyExcessCharacter\n# PenaltyReturnTypeOnItsOwnLine\nPointerAlignment: Left\nQualifierAlignment: Right\nReflowComments: true\nSortIncludes: true\nSortUsingDeclarations: true\nSpaceAfterCStyleCast: true\nSpaceAfterLogicalNot: false\nSpaceAfterTemplateKeyword: true\nSpaceBeforeAssignmentOperators: true\nSpaceBeforeCpp11BracedList: true\nSpaceBeforeCtorInitializerColon: true\nSpaceBeforeInheritanceColon: true\nSpaceBeforeParens: ControlStatements\nSpaceBeforeRangeBasedForLoopColon: true\nSpaceInEmptyParentheses: false\nSpacesBeforeTrailingComments: 2\nSpacesInAngles: false\nSpacesInCStyleCastParentheses: false\nSpacesInParentheses: false\nSpacesInSquareBrackets: false\nStandard: Latest\n# StatementMacros\nTabWidth: 2\nUseTab: Never\n"
  },
  {
    "path": ".gitignore",
    "content": "################################################################################\n## Copyright 2019-2020 Lawrence Livermore National Security, LLC and other\n## DiHydrogen Project Developers. See the top-level LICENSE file for details.\n##\n## SPDX-License-Identifier: Apache-2.0\n################################################################################\n\n# Emacs backup garbage\n.backup/\n.cache/\n\n# Other standard ignores\n*~\n*.pyc\n\\#*#\n.#*\n.*.swp\n.DS_Store\n*.bak\n.dir-locals.el\ncompile_commands.json\n\n# building/install not-entirely-out-of-source stuff\nbuild*/\ninstall*/\n\n"
  },
  {
    "path": "CMakeLists.txt",
    "content": "################################################################################\n## Copyright 2014-2025 Lawrence Livermore National Security, LLC and other\n## LBANN Project Developers. See the top-level LICENSE file for details.\n##\n## SPDX-License-Identifier: Apache-2.0\n################################################################################\ncmake_minimum_required(VERSION 3.27)\nproject(LBANNv2\n  VERSION 0.0.1\n  DESCRIPTION \"DiHydrogen integration with PyTorch\"\n  HOMEPAGE_URL \"https://github.com/lbann\"\n  LANGUAGES CXX\n)\n\noption(LBANNV2_DEBUG_MODE\n  \"Enable extra assertions helpful in debugging.\"\n  OFF)\n\n# Make Tom's life easier\nset(CMAKE_EXPORT_COMPILE_COMMANDS ON\n  CACHE BOOL \"Write compile_commands.json\" FORCE)\n\n# FIXME (trb): This is probably the right thing, but we should think\n# about if this is strictly needed.\nset(BUILD_SHARED_LIBS ON)\nset(CMAKE_CXX_STANDARD 20) # For DiHydrogen\n\n# FIXME (trb): These are generally useful for development and\n# debugging. I should probably pass them on cmd line, but again, lazy.\nset(CMAKE_CXX_FLAGS_DEBUG \"-g3 -O0 -fno-omit-frame-pointer\")\nset(CMAKE_HIP_FLAGS_DEBUG \"-g3 -O0 -fno-omit-frame-pointer\")\n\n# Language support\n#\n# Just set things for CUDA *and* HIP hoping they'll be ignored on\n# irrelevant platforms.\n\n# ampere, hopper\nset(CMAKE_CUDA_ARCHITECTURES 80 90)\nset(TORCH_CUDA_ARCH_LIST 8.0 9.0)\nset(CMAKE_CUDA_STANDARD 17)\n\n# MI50, MI250X, MI300A, MI300X\nset(CMAKE_HIP_ARCHITECTURES gfx906 gfx90a gfx942)\nset(ENV{PYTORCH_ROCM_ARCH} \"${CMAKE_HIP_ARCHITECTURES}\")\nset(PYTORCH_ROCM_ARCH ${CMAKE_HIP_ARCHITECTURES})\n\n# Setup dependencies\n\nset(LBANNV2_MINIMUM_Python_VERSION 3.10)\nset(LBANNV2_MINIMUM_H2_VERSION 0.4.0)\nset(LBANNV2_MINIMUM_Torch_VERSION 2.9.0)\n\nfind_package(Python\n  ${LBANNV2_MINIMUM_Python_VERSION}\n  REQUIRED\n  COMPONENTS Interpreter Development.Module)\n\n# Interrogate the Python environment (via pip) to detect NVIDIA\n# dependencies in the environment. Currently, this is based on the\n# Torch module that's installed in the environment, if any exists, and\n# meaningful values will only be returned if such a module exists.\n#\n# FIXME (trb): We just handle cuDNN and NCCL here because those are\n# the only ones that overlap with Al/H2 needs, but we might consider\n# adding paths for the rest of them since Torch will (presumably)\n# depend on them.\n#\n# An alternative approach _could_ be to detect all NVIDIA modules\n# known to pip and simply parse those. I'm not sure how realistic this\n# might be in practice, but presumably one _could_ have\n# nvidia-cudnn-cu11 and nvidia-cudnn-cu12 in the same environment, and\n# one could imagine that those packages would provide distinct\n# installations of these libraries (fun fact: they don't). Hence the\n# preference to let PyTorch tell me which modules it should use. 
If\n# someone was trying to use a Torch that Pip couldn't detect but with\n# pip-managed NVIDIA modules, I would classify them as a \"power user\"\n# and expect that they can handle adding command line arguments to the\n# LBANNv2 build.\nlist(PREPEND CMAKE_MODULE_PATH \"${PROJECT_SOURCE_DIR}/cmake\")\n\ninclude(LBANNv2DetectTorchNVIDIALibraries)\ndetect_torch_nvidia_libraries(LIBRARIES cudnn nccl)\n\nforeach (pkg cudnn nccl)\n  if (LBANNV2_DETECTED_${pkg})\n    string(TOUPPER \"${pkg}\" pkg_upper)\n    set(${pkg_upper}_LIBRARY\n      \"${LBANNV2_DETECTED_${pkg}_LIBRARY}\"\n      CACHE\n      FILEPATH\n      \"Path to ${pkg_upper} library.\" FORCE)\n    set(${pkg_upper}_INCLUDE_PATH\n      \"${LBANNV2_DETECTED_${pkg}_INCLUDE_PATH}\"\n      CACHE\n      PATH\n      \"Include directory for ${pkg}\" FORCE)\n  endif()\nendforeach ()\n\n# Special handling for Torch+cuDNN\nif (LBANNV2_DETECTED_cudnn)\n  # Torch uses \"LIBRARY_PATH\" for the location of the main cuDNN\n  # library. Because why wouldn't they??\n  set(CUDNN_LIBRARY_PATH\n    \"${LBANNV2_DETECTED_cudnn_LIBRARY}\"\n    CACHE\n    FILEPATH\n    \"Path to cuDNN library.\")\n\n  set(CAFFE2_USE_CUDNN ON CACHE BOOL \"Have the build search for cuDNN\")\nendif ()\n\n# Ok, the CMake here gets a little rocky. The goal is to \"pip install\n# .\" and it should just build \"the right thing\". So we need to\n# auto-detect as much as we can under the weakest assumptions possible\n# (e.g., we should not assume \"torch.cuda.is_available()\" gives\n# meaningful information, as we may be building on a GPU-less head\n# node). It seems reasonable to just find Torch and see what its CMake\n# export can tell us. For instance, \"torch_hip\" will be found on ROCm\n# platforms, and \"torch_cuda\" will be found on CUDA platforms -- we\n# assume (hope!) that these are truly orthogonal! From there, we can\n# pull a few additional flags in by further interrogating the targets,\n# if needed.\n\nfind_package(Torch\n  ${LBANNV2_MINIMUM_Torch_VERSION}\n  REQUIRED\n)\n\n# We also don't care about the limited API nonsense, so we can use\n# libtorch. Let's find it.\nif (TORCH_LIBRARY)\n  get_filename_component(TORCH_LIB_DIR \"${TORCH_LIBRARY}\" DIRECTORY)\nendif ()\nfind_library(TORCH_PYTHON_LIBRARY\n  torch_python\n  HINTS\n  ${TORCH_LIB_DIR}\n  ${Python_SITELIB}/torch/lib64\n  ${Python_SITELIB}/torch/lib\n  NO_DEFAULT_PATH)\nfind_library(TORCH_PYTHON_LIBRARY torch_python REQUIRED)\n\n# MI300A only becomes a factor when doing a ROCm build. So start by\n# assuming we don't have it.\n#\n# FIXME (trb): This should, of course, be relaxed to just represent\n# memory coherence. However, I don't have access to any non-MI300A\n# memory-coherent architectures. If anyone does, I'm happy to abstract\n# this now; otherwise, I'll wait until I acquire such access myself.\nset(LBANNV2_WITHOUT_MI300A ON)\nunset(LBANNV2_WITH_MI300A)\nunset(LBANNV2_UNKNOWN_MI300A)\nunset(LBANNV2_HAS_CUDA)\nunset(LBANNV2_HAS_ROCM)\n\nif (TARGET torch_cuda)\n  set(LBANNV2_HAS_CUDA ON)\n  # We need to edit the CUDA arch flags out. Or at least edit them\n  # down to supported archs (>=70).\nelseif (TARGET torch_hip)\n  enable_language(HIP)\n  set(LBANNV2_HAS_ROCM ON)\n\n  # Handle MI300A configure checks.\n  include(LBANNv2DetermineMI300A)\n  set(_valid_mi300a_status \"WITH\" \"WITHOUT\" \"UNKNOWN\")\n  set(LBANNV2_MI300A_STATUS \"DETECT\"\n    CACHE STRING\n    \"On MI300A? 
Valid values: WITH, WITHOUT, UNKNOWN, DETECT\")\n  string(TOUPPER \"${LBANNV2_MI300A_STATUS}\" _mi300a_status_upper)\n  if (NOT _mi300a_status_upper IN_LIST _valid_mi300a_status)\n    determine_mi300a_support(_mi300a_status_upper)\n  endif ()\n  unset(LBANNV2_WITH_MI300A)\n  unset(LBANNV2_WITHOUT_MI300A)\n  unset(LBANNV2_UNKNOWN_MI300A)\n  set(LBANNV2_${_mi300a_status_upper}_MI300A ON)\n  # If we determine that we have MI300A, we can make some static\n  # optimizations and eliminate some flow control. In the \"UNKNOWN\"\n  # case, these static branches are replaced by dynamic ones, possibly\n  # incurring some small overhead.\n  #\n  # As far as I can figure, the only case in which this could cause\n  # problems (rather than just being suboptimal) is if we declare (or\n  # decide) that we have MI300A when we actually do not. In\n  # particular, this would cause our assumptions about CPU/GPU memory\n  # visibility to be invalid -- hipMalloc'd memory would not be valid\n  # on the CPU.\n\n  # We need to remove any \"std=c++<XY>\" type options because we're\n  # ahead of PyTorch's minimum requirements there.\n  get_target_property(\n    _torch_hip_compile_opts\n    torch_hip\n    INTERFACE_COMPILE_OPTIONS)\n  foreach (_opt ${_torch_hip_compile_opts})\n    if (_opt MATCHES \"-std=c\\\\+\\\\+[0-9a-z]+\")\n      list(REMOVE_ITEM _torch_hip_compile_opts \"${_opt}\")\n    endif ()\n  endforeach()\n  set_target_properties(torch_hip\n    PROPERTIES INTERFACE_COMPILE_OPTIONS \"${_torch_hip_compile_opts}\")\n\n  # FIXME (trb): So I really, truly hate this, but this seems to be\n  # the shortest approach to dealing with version-based switches in\n  # the C++ code. Other approaches involve obscure preprocessor macros\n  # or long-winded SFINAE tricks, and because I want you, dear human\n  # reader, to be happy, I opted for this very simple, highly readable\n  # implementation instead.\n  if (Torch_VERSION VERSION_LESS \"2.11.0\")\n    set(LBANNV2_USE_C10_HIP_NAMESPACE_AND_SYMBOLS TRUE)\n  endif ()\nendif ()\n\n# We need to determine if we should be using a CXX11_ABI macro or not\n# so we can forward as appropriate to spdlog/Catch2/etc. We need to do\n# this *BEFORE* adding DiHydrogen(/spdlog/Catch2); otherwise it won't\n# get picked up and we'd have to add it to the respective targets\n# later on.\nif (TORCH_CXX_FLAGS AND TORCH_CXX_FLAGS MATCHES \"GLIBCXX_USE_CXX11_ABI=([01])\")\n  add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=${CMAKE_MATCH_1})\nendif ()\n\n# spdlog\ninclude(FetchContent)\nFetchContent_Declare(\n  spdlog\n  GIT_REPOSITORY https://github.com/gabime/spdlog.git\n  GIT_TAG 79524ddd08a4ec981b7fea76afd08ee05f83755d # v1.17.0\n  GIT_SHALLOW 1\n  FIND_PACKAGE_ARGS CONFIG\n)\n\n# Ensure spdlog gets installed and exported properly. 
I can probably\n# make this a non-cached (or a FORCE'd) variable, but this is fine.\nset(SPDLOG_INSTALL ON CACHE INTERNAL \"Install spdlog\")\nFetchContent_MakeAvailable(spdlog)\n\n# Python module stuff\nfind_package(pybind11 CONFIG REQUIRED)\n\n# Set a few RPATH handling things\nset(CMAKE_INSTALL_RPATH_USE_LINK_PATH TRUE)\nif(APPLE)\n  list(PREPEND CMAKE_INSTALL_RPATH \"@loader_path\")\nelse()\n  list(PREPEND CMAKE_INSTALL_RPATH \"\\$ORIGIN\")\nendif()\n\n# Add the library\nadd_library(lbannv2 SHARED)\nadd_library(lbann::lbannv2 ALIAS lbannv2)\ntarget_sources(lbannv2\n  PUBLIC\n  FILE_SET HEADERS\n  BASE_DIRS src\n)\ntarget_link_libraries(lbannv2\n  PUBLIC\n  torch\n  spdlog::spdlog\n)\nset_target_properties(lbannv2\n  PROPERTIES\n  CXX_STANDARD 20\n  CXX_STANDARD_REQUIRED ON\n  CXX_EXTENSIONS OFF\n  VERSION ${LBANNv2_VERSION}\n  SOVERSION ${LBANNv2_VERSION_MAJOR}\n)\n\n# Create the Python module\npython_add_library(_lbannv2 MODULE WITH_SOABI)\ntarget_link_libraries(_lbannv2\n  PUBLIC\n  lbann::lbannv2\n  \"${TORCH_PYTHON_LIBRARY}\"\n  PRIVATE\n  pybind11::headers\n)\nset_target_properties(_lbannv2\n  PROPERTIES\n  CXX_STANDARD 20\n  CXX_STANDARD_REQUIRED ON\n  CXX_EXTENSIONS OFF\n)\n\n# Handle logging. If `LBANNV2_LOG_LEVEL` is not set,\n# SPDLOG_ACTIVE_LEVEL will not be set on the command line and will\n# default to `SPDLOG_LEVEL_TRACE` in the C++ code\n# (src/lbannv2/utils/logging.hpp).\n#\n# NOTE that this is the *compile time* log level. That is, if\n# LBANNV2_LOG_LEVEL is set to \"TRACE\", every log message (using the\n# LBANNV2_LOG* macros) will be compiled; if it's set to \"INFO\",\n# messages flagged as \"TRACE\" or \"DEBUG\" will not even be compiled.\n# The default is set to \"TRACE\" so that all log messages are\n# available, depending on the log level selected at runtime.\nset(lbannv2_ok_log_levels\n  \"TRACE\" \"DEBUG\" \"INFO\" \"WARN\" \"ERROR\" \"CRITICAL\" \"OFF\")\nif (LBANNV2_LOG_LEVEL IN_LIST lbannv2_ok_log_levels)\n  target_compile_definitions(\n    lbannv2\n    PRIVATE\n    SPDLOG_ACTIVE_LEVEL=SPDLOG_LEVEL_${LBANNV2_LOG_LEVEL}\n  )\n\n  target_compile_definitions(\n    _lbannv2\n    PRIVATE\n    SPDLOG_ACTIVE_LEVEL=SPDLOG_LEVEL_${LBANNV2_LOG_LEVEL}\n  )\nendif ()\n\n# Add the sources to the library\nadd_subdirectory(src/lbannv2)\n\n# Generate the export header\ninclude(GenerateExportHeader)\ngenerate_export_header(lbannv2)\n\n# Generate the configuration header\nconfigure_file(\n  ${PROJECT_SOURCE_DIR}/cmake/lbannv2_config.h.in\n  ${CMAKE_CURRENT_BINARY_DIR}/lbannv2_config.h\n  @ONLY\n)\n\n# Include it in the file set\ntarget_sources(lbannv2 PUBLIC\n  FILE_SET HEADERS\n  BASE_DIRS ${CMAKE_CURRENT_BINARY_DIR}\n  FILES\n  ${CMAKE_CURRENT_BINARY_DIR}/lbannv2_config.h\n  ${CMAKE_CURRENT_BINARY_DIR}/lbannv2_export.h\n)\n\n# Handle unit testing\ninclude(CTest)\nif (BUILD_TESTING)\n  add_subdirectory(test)\nendif ()\n\n# Install stuff\n#\n# When building the Python bindings, we still install the whole C++\n# library. We might want to clean this up. 
Also, we set\n# tools.scikit-build.wheel.install-dir=lbannv2 so it installs into\n# <site-packages>/lbannv2.\ninclude(GNUInstallDirs)\n\nset(\n  CMAKE_INSTALL_CMAKEDIR\n  \"${CMAKE_INSTALL_LIBDIR}/cmake/lbannv2\"\n)\n\ninstall(TARGETS lbannv2\n  EXPORT lbannv2Targets\n  FILE_SET HEADERS\n)\n\ninstall(EXPORT lbannv2Targets\n  DESTINATION ${CMAKE_INSTALL_CMAKEDIR}\n  NAMESPACE lbann::\n)\n\ninstall(TARGETS _lbannv2\n  DESTINATION ${CMAKE_INSTALL_LIBDIR}\n)\n\ninclude(CMakePackageConfigHelpers)\nconfigure_package_config_file(\n  cmake/lbannv2Config.cmake.in\n  \"${CMAKE_BINARY_DIR}/lbannv2Config.cmake\"\n  INSTALL_DESTINATION \"${CMAKE_INSTALL_CMAKEDIR}\"\n)\nwrite_basic_package_version_file(\n  lbannv2ConfigVersion.cmake\n  COMPATIBILITY SameMinorVersion\n)\ninstall(\n  FILES\n  \"${CMAKE_BINARY_DIR}/lbannv2Config.cmake\"\n  \"${CMAKE_BINARY_DIR}/lbannv2ConfigVersion.cmake\"\n  DESTINATION \"${CMAKE_INSTALL_CMAKEDIR}\"\n)\n"
  },
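The ABI handling above can be sanity-checked from Python: the flag that CMakeLists.txt scrapes out of `TORCH_CXX_FLAGS` is also reported by Torch directly. A quick sketch, assuming only a working `torch` install:

```python
import torch

# Mirrors the _GLIBCXX_USE_CXX11_ABI value the build forwards to
# spdlog/Catch2/etc. via add_compile_definitions().
print(torch.compiled_with_cxx11_abi())

# Where find_package(Torch) finds TorchConfig.cmake when the wheel's
# CMake config is not already on CMAKE_PREFIX_PATH.
print(torch.utils.cmake_prefix_path)
```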
  {
    "path": "CONTRIBUTING.md",
    "content": "# Contributing Guidelines for LBANN\n\nWe welcome any contributions to LBANN in the form of Pull Requests.\nPlease follow the guidelines below for more information.\n\n## Attribution\n\nIf you have not added yourself to the authors list in \n[CONTRIBUTORS](https://github.com/LLNL/lbann/blob/develop/CONTRIBUTORS), please do so in the appropriate place.\n\n## git guidelines\n\nWhen ready for review and merge, Pull Requests must match the latest `develop` branch commit.\nIf not ready, **rebase** the commits onto the latest commit. Avoid merge commits.\n\n## Style guidelines\n\nFor C/C++ and GPU code, we follow the [LLVM coding style](https://llvm.org/docs/CodingStandards.html) with\nadaptations, see the [coding style README](https://github.com/LLNL/lbann/blob/develop/README_coding_style.txt) and the\n[clang-format configuration](https://github.com/LLNL/lbann/blob/develop/.clang-format) for more information.\n\nFor Python code, we follow the [Google coding style](https://google.github.io/styleguide/pyguide.html) guidelines,\nbut allow some exceptions to create layers in the LBANN Python frontend.\n\n## Setting up automatic formatting\n\nTo enforce file formatting at every commit, you can use the pre-commit hook provided in the repository.\nMake a symbolic link from `.git/hooks/pre-commit` to our script by running the following command\n**from the root of your git repository**:\n\n```sh\nuser@/path/to/lbann$ ln -s ../../scripts/pre-commit-hook.sh .git/hooks/pre-commit\n```\n\nMake sure you have `clang-format` installed for C/C++ formatting. If you do not have it installed in the path,\nyou may override it by setting the `$CLANG_FORMAT` environment variable to its path.\n"
  },
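For reference, the hook boils down to running `clang-format` over the staged C/C++ files. A hypothetical Python rendition of that behavior (the actual script is `scripts/pre-commit-hook.sh`; `$CLANG_FORMAT` is honored as described above):

```python
import os
import subprocess
import sys

# Fall back to whatever 'clang-format' is in the path, per the docs above.
clang_format = os.environ.get("CLANG_FORMAT", "clang-format")

# Staged files only: added, copied, or modified.
files = subprocess.run(
    ["git", "diff", "--cached", "--name-only", "--diff-filter=ACM"],
    capture_output=True, text=True, check=True).stdout.split()

cxx = [f for f in files if f.endswith((".hpp", ".cpp", ".h", ".cu", ".cuh"))]
if cxx:
    sys.exit(subprocess.run([clang_format, "-i", *cxx]).returncode)
```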
  {
    "path": "CONTRIBUTORS",
    "content": "LLNL Core Team:\n  Brian Van Essen <vanessen1@llnl.gov> [@bvanessen]\n  Tom Benson <benson31@llnl.gov> [@benson31]\n  Nikoli Dryden <dryden1@llnl.gov> [@ndryden]\n  Tal Ben-Nun <talbn@llnl.gov> [@tbennun]\n  Pier Fiedorowicz <fiedorowicz1@llnl.gov> [@fiedorowicz1]\n  \nCollaborators:\n  Shehtab Zaman [@szaman19]\n\nNotable Prior LLNL Team Members\n  Ryan Forsyth [@forsyth2]\n  David Hysom [@davidHysom]\n  Katie Graham [@graham63]\n  Keita Iwabuchi [@KIwabuchi]\n  Sam Ade Jacobs [@samadejacobs]\n  Arpan Jain [@aj-prime]\n  Hyojin Kim\n  Naoya Maruyama [@naoyam]\n  Erin McCarthy\n  Adam Moody [@adammoody]\n  Tim Moon [@timmoon10]\n  Yosuke Oyama [@oyamay}\n  Michael Wyatt [@mrwyattii]\n  Jae-Seung Yeom [@JaeseungYeom]\n\n"
  },
  {
    "path": "LICENSE",
    "content": "################################################################################\n## Copyright 2014-2025 Lawrence Livermore National Security, LLC and other\n## LBANN Project Developers. See the top-level LICENSE file for details.\n##\n## SPDX-License-Identifier: Apache-2.0\n################################################################################\nCopyright (c) 2014-2024, Lawrence Livermore National Security, LLC.\nProduced at the Lawrence Livermore National Laboratory.\nWritten by the LBANN Research Team (B. Van Essen, et al.) listed in\nthe CONTRIBUTORS file. <lbann-dev@llnl.gov>\n\nLLNL-CODE-697807.\nAll rights reserved.\n\nThis file is part of LBANN: Livermore Big Artificial Neural Network\nToolkit. For details, see http://software.llnl.gov/LBANN or\nhttps://github.com/LBANN and https://github.com/LLNL/LBANN.\n\nLicensed under the Apache License, Version 2.0 (the \"Licensee\"); you\nmay not use this file except in compliance with the License.  You may\nobtain a copy of the License at:\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\nimplied. See the License for the specific language governing\npermissions and limitations under the license.\n"
  },
  {
    "path": "NOTICE",
    "content": "This work was produced under the auspices of the U.S. Department of Energy by\nLawrence Livermore National Laboratory under Contract DE-AC52-07NA27344.\n\nThis work was prepared as an account of work sponsored by an agency of the\nUnited States Government. Neither the United States Government nor Lawrence\nLivermore National Security, LLC, nor any of their employees makes any warranty,\nexpressed or implied, or assumes any legal liability or responsibility for the\naccuracy, completeness, or usefulness of any information, apparatus, product, or\nprocess disclosed, or represents that its use would not infringe privately owned\nrights. Reference herein to any specific commercial product, process, or service\nby trade name, trademark, manufacturer, or otherwise does not necessarily\nconstitute or imply its endorsement, recommendation, or favoring by the United\nStates Government or Lawrence Livermore National Security, LLC. The views and\nopinions of authors expressed herein do not necessarily state or reflect those\nof the United States Government or Lawrence Livermore National Security, LLC,\nand shall not be used for advertising or product endorsement purposes.\n"
  },
  {
    "path": "README.md",
    "content": "# Build\n\nTo save some pip-related heartburn, LBANNv2 is currently BYOT (\"bring\nyour own Torch\").\n\n```\npip install torch\npip install .\n```\n\n# License\n\nCopyright 2014-2025 Lawrence Livermore National Security, LLC and other\nLBANN Project Developers. See the top-level LICENSE file for details.\n\nSPDX-License-Identifier: Apache-2.0\n\nLLNL-CODE-697807\n"
  },
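Assuming the build succeeded, a minimal smoke test (`is_available()` is defined in `python/lbannv2/__init__.py`):

```python
import lbannv2

# True only when the LBANNv2 GPU backend is usable on this machine;
# False (rather than an exception) on unsupported systems.
print(lbannv2.is_available())
```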
  {
    "path": "cmake/LBANNv2DetectTorchNVIDIALibraries.cmake",
    "content": "################################################################################\n## Copyright 2014-2025 Lawrence Livermore National Security, LLC and other\n## LBANN Project Developers. See the top-level LICENSE file for details.\n##\n## SPDX-License-Identifier: Apache-2.0\n################################################################################\nfunction(detect_torch_nvidia_libraries)\n  set(_detect_opts)\n  set(_detect_single_val_args)\n  set(_detect_multi_value_args LIBRARIES)\n  cmake_parse_arguments(PARSE_ARGV 0 _detect\n    \"${_detect_opts}\" \"${_detect_single_value_args}\" \"${_detect_multi_value_args}\")\n\n  find_package(Python 3.9 REQUIRED COMPONENTS Interpreter Development.Module)\n\n  # Get information about torch. If Pip doesn't know about torch, that's\n  # fine. We just stop and fall back on the user's environment, assuming\n  # Torch to have been built from source.\n  execute_process(\n    COMMAND \"${Python_EXECUTABLE}\" -m pip show --no-color torch\n    ERROR_VARIABLE _detect_pip_show_error\n    OUTPUT_VARIABLE _detect_pip_show_output\n    RESULT_VARIABLE _detect_pip_show_result)\n\n  # Split the string on newlines\n  string(REPLACE \"\\n\" \";\" _detect_torch_show_lines \"${_detect_pip_show_output}\")\n\n  # And find the \"requires\" line:\n  list(FILTER _detect_torch_show_lines INCLUDE REGEX \"^Requires\")\n\n  # Now filter that down to the NVIDIA modules:\n  string(REGEX MATCHALL \"nvidia-[-a-z]+-cu[0-9]+\" _detect_nvidia_modules \"${_detect_torch_show_lines}\")\n\n  # Now that we have a list of modules, we need to search the file lists\n  # of these. There are at least 2 approaches.\n  #\n  #  1. We can interrogate 'pip show --files <module name>', parse the\n  #     base from the \"Location\" line and prepend it to any matching\n  #     lines under the \"Files\" header to get the full paths to any\n  #     relevant files.\n  #\n  #  2. We can use 'importlib.metadata' to parse the metadata associated\n  #     with each module. This has the advantage that we don't have to\n  #     do as much manual parsing and string manipulation -- the data we\n  #     need can be generated with a simple list comprehension.\n  #\n  # While neither approach is particularly difficult, I've opted for\n  # number 2. I especially like that by joining the output string with\n  # semicolons, CMake will natively interpret the list of paths as a\n  # CMake list, further simplifying things.\n\n  # Get the list of paths out of the metadata. 
Separate with semicolon\n  # so CMake interprets the output as a list directly.\n  set(_detect_get_paths_program\n    \"import importlib.metadata as md; import sys; print(\\\";\\\".join([str(f.locate()) for f in md.files(sys.argv[1])]))\")\n\n  foreach (lib IN LISTS _detect_LIBRARIES)\n    string(REGEX MATCH \"nvidia-${lib}-cu[0-9]+\" _detect_nvidia_lib_module \"${_detect_nvidia_modules}\")\n\n    # Find paths\n    execute_process(\n      COMMAND \"${Python_EXECUTABLE}\" -c \"${_detect_get_paths_program}\" \"${_detect_nvidia_lib_module}\"\n      ERROR_VARIABLE _detect_get_paths_error\n      OUTPUT_VARIABLE _detect_get_paths_output\n      RESULT_VARIABLE _detect_get_paths_result)\n\n    foreach (path ${_detect_get_paths_output})\n      if (path MATCHES \".*${lib}\\\\.h$\")\n\n        cmake_path(GET\n          path\n          PARENT_PATH\n          _detect_parent_path)\n        set(LBANNV2_DETECTED_${lib}_INCLUDE_PATH\n          \"${_detect_parent_path}\"\n          CACHE\n          PATH\n          \"Include directory for ${lib}\")\n\n      elseif (path MATCHES \".*lib${lib}${CMAKE_SHARED_LIBRARY_SUFFIX}.*\")\n\n        set(LBANNV2_DETECTED_${lib}_LIBRARY\n          \"${path}\"\n          CACHE\n          FILEPATH\n          \"Library for ${lib}\")\n\n      endif ()\n    endforeach ()\n\n    # Consider the thing found if both the include path and the\n    # library are available.\n    if (LBANNV2_DETECTED_${lib}_LIBRARY AND LBANNV2_DETECTED_${lib}_INCLUDE_PATH)\n      set(LBANNV2_DETECTED_${lib} TRUE PARENT_SCOPE)\n    endif ()\n  endforeach ()\nendfunction ()\n"
  },
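The `importlib.metadata` one-liner embedded above is easier to read expanded. A rough Python equivalent (the module name `nvidia-nccl-cu12` is only an example; the real names are parsed from `pip show torch`):

```python
import importlib.metadata as md

# Every file installed by the wheel, resolved to an absolute path.
paths = [str(f.locate()) for f in md.files("nvidia-nccl-cu12")]

# Joining on ';' is what lets CMake treat the output as a native list.
print(";".join(paths))
```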
  {
    "path": "cmake/LBANNv2DetermineMI300A.cmake",
    "content": "################################################################################\n## Copyright 2014-2025 Lawrence Livermore National Security, LLC and other\n## LBANN Project Developers. See the top-level LICENSE file for details.\n##\n## SPDX-License-Identifier: Apache-2.0\n################################################################################\ncmake_minimum_required(VERSION 3.24.0)\n\n# Tries to determine whether the machine in question is MI300A. The\n# check hinges on \"rocm-smi\" returning sane things. Sadly we must do\n# this rather than just testing the arch flag because \"gfx942\" refers\n# to both the MI300A and the MI300X.\n#\n# The return value is ternary:\n#\n#   - \"WITH\" means we have determined that we have MI300A\n#   - \"WITHOUT\" means we have determined that we do NOT have MI300A\n#   - \"UNKNOWN\" means that we cannot determine whether we have MI300A,\n#     generally because \"rocm-smi\" produced no usable output.\n#\n# In cases where the node on which LBANNv2 is being built does not\n# have GPUs, or if it happens to have *different* GPUs from the ones\n# on the compute nodes, users are advised to provide this information\n# directly, if possible.\n#\n# As this all comes down to \"rocm-smi\" output on the node on which\n# LBANNv2 is configured, users are advised that there is high risk for\n# incorrect or suboptimal information if not configuring on a compute\n# node.\nfunction (determine_mi300a_support OUTPUT_VARIABLE)\n\n  # Call rocm-smi (should be in the PATH). If rsmi fails, then we say \"unknown\".\n  execute_process(\n    COMMAND rocm-smi --showproductname --json\n    OUTPUT_VARIABLE _rsmi_info\n    ERROR_VARIABLE _rsmi_error\n    ERROR_QUIET\n  )\n\n  if (_rsmi_error AND _rsmi_error MATCHES \".*ERROR.*\")\n    set(${OUTPUT_VARIABLE} \"UNKNOWN\" PARENT_SCOPE)\n    return ()\n  endif ()\n\n  string(JSON _gfx_version\n    ERROR_VARIABLE _json_err\n    GET \"${_rsmi_info}\" \"card0\" \"GFX Version\")\n\n  # To get here, rsmi returned something valid, and this path just was not right.\n  if (_json_err)\n    message(DEBUG\n      \"JSON Error: ${_json_err}\\n\\nAssuming MI300A status is 'UNKNOWN'.\")\n    set(${OUTPUT_VARIABLE} \"UNKNOWN\" PARENT_SCOPE)\n    return ()\n  endif ()\n\n  if (_gfx_version MATCHES \".*gfx942.*\")\n    execute_process(\n      COMMAND rocminfo\n      OUTPUT_VARIABLE _rocminfo_output\n      ERROR_VARIABLE _rocminfo_error\n      ERROR_QUIET\n    )\n    string(FIND \"${_rocminfo_output}\" \"MI300A\" _mi300a_exists)\n    if (_mi300a_exists EQUAL -1)\n      set(${OUTPUT_VARIABLE} \"WITHOUT\" PARENT_SCOPE)\n    else ()\n      set(${OUTPUT_VARIABLE} \"WITH\" PARENT_SCOPE)\n    endif ()\n  else ()\n    set(${OUTPUT_VARIABLE} \"WITHOUT\" PARENT_SCOPE)\n  endif ()\nendfunction ()\n"
  },
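The same ternary decision, sketched in Python for clarity (assumes rocm-smi's JSON layout matches what the module above expects):

```python
import json
import subprocess

def mi300a_status() -> str:
    """Mirror LBANNv2DetermineMI300A.cmake: WITH, WITHOUT, or UNKNOWN."""
    try:
        out = subprocess.run(
            ["rocm-smi", "--showproductname", "--json"],
            capture_output=True, text=True, check=True).stdout
        gfx = json.loads(out)["card0"]["GFX Version"]
    except Exception:
        return "UNKNOWN"  # no usable rocm-smi output
    if "gfx942" not in gfx:
        return "WITHOUT"
    # gfx942 covers both MI300A and MI300X; rocminfo disambiguates.
    info = subprocess.run(["rocminfo"], capture_output=True, text=True).stdout
    return "WITH" if "MI300A" in info else "WITHOUT"
```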
  {
    "path": "cmake/lbannv2Config.cmake.in",
    "content": "################################################################################\n## Copyright 2014-2025 Lawrence Livermore National Security, LLC and other\n## LBANN Project Developers. See the top-level LICENSE file for details.\n##\n## SPDX-License-Identifier: Apache-2.0\n################################################################################\ninclude(\"${CMAKE_CURRENT_LIST_DIR}/lbannv2ConfigVersion.cmake\")\nset(LBANNv2_VERSION ${PACKAGE_VERSION})\n\ninclude(CMakeFindDependencyMacro)\n\nset(lbannv2_MINIMUM_H2_VERSION @lbannv2_MINIMUM_H2_VERSION@)\nset(lbannv2_MINIMUM_Torch_VERSION @lbannv2_MINIMUM_Torch_VERSION@)\n\nfind_dependency(DiHydrogen\n  ${lbannv2_MINIMUM_H2_VERSION}\n  COMPONENTS Core Meta Patterns\n)\nfind_dependency(Torch\n  ${lbannv2_MINIMUM_Torch_VERSION}\n)\n\n@PACKAGE_INIT@\n\nif (NOT TARGET lbann::lbannv2)\n  include(\"${CMAKE_CURRENT_LIST_DIR}/lbannv2Targets.cmake\")\nendif ()\n\ncheck_required_components(lbannv2)\nset(LBANNv2_LIBRARIES lbann::lbannv2)\n"
  },
  {
    "path": "cmake/lbannv2_config.h.in",
    "content": "////////////////////////////////////////////////////////////////////////////////\n// Copyright 2014-2025 Lawrence Livermore National Security, LLC and other\n// LBANN Project Developers. See the top-level LICENSE file for details.\n//\n// SPDX-License-Identifier: Apache-2.0\n////////////////////////////////////////////////////////////////////////////////\n#pragma once\n\n// clang-format off\n\n#include <lbannv2_export.h>\n\n// Version information\n#define LBANNV2_VERSION_MAJOR @PROJECT_VERSION_MAJOR@\n#define LBANNV2_VERSION_MINOR @PROJECT_VERSION_MINOR@\n#define LBANNV2_VERSION_PATCH @PROJECT_VERSION_PATCH@\n#define LBANNV2_VERSION \"@PROJECT_VERSION@\"\n\n#cmakedefine01 LBANNV2_DEBUG_MODE\n\n#cmakedefine01 LBANNV2_HAS_CUDA\n#cmakedefine01 LBANNV2_HAS_ROCM\n#define LBANNV2_HAS_GPU (LBANNV2_HAS_CUDA + LBANNV2_HAS_ROCM)\n\n#cmakedefine01 LBANNV2_WITH_MI300A\n#cmakedefine01 LBANNV2_WITHOUT_MI300A\n#cmakedefine01 LBANNV2_UNKNOWN_MI300A\n\n#cmakedefine01 LBANNV2_USE_C10_HIP_NAMESPACE_AND_SYMBOLS\n\n#ifndef SPDLOG_ACTIVE_LEVEL\n// This defaults to \"TRACE\" so that all messages are compiled and\n// available. Use the runtime environment variable to control which\n// are seen.\n#define SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_TRACE\n#endif\n\n// clang-format on\n"
  },
  {
    "path": "pyproject.toml",
    "content": "################################################################################\n## Copyright 2014-2025 Lawrence Livermore National Security, LLC and other\n## LBANN Project Developers. See the top-level LICENSE file for details.\n##\n## SPDX-License-Identifier: Apache-2.0\n################################################################################\n[build-system]\nrequires = [\n  \"scikit-build-core>=0.10\",\n  \"pybind11\"\n]\nbuild-backend = \"scikit_build_core.build\"\n\n[project]\nname = \"lbannv2\"\nversion = \"0.0.1\"\ndescription = \"LBANN's core integration with PyTorch\"\nauthors = [\n  { name = \"Tal Ben Nun\", email = \"bennun2@llnl.gov\" },\n  { name = \"Tom Benson\", email = \"benson31@llnl.gov\" },\n  { name = \"Nikoli Dryden\", email = \"dryden1@llnl.gov\" },\n  { name = \"Pier Fiedorowicz\", email = \"fiedorowicz1@llnl.gov\" },\n  { name = \"Brian Van Essen\", email = \"vanessen1@llnl.gov\" },\n]\nlicense = { file = \"LICENSE\" }\nreadme = \"README.md\"\nrequires-python = \">=3.10\"\nclassifiers = [\n  \"Development Status :: 2 - Pre-Alpha\",\n\n  \"License :: OSI Approved :: Apache Software License\",\n\n  \"Programming Language :: C++\",\n\n  \"Programming Language :: Python :: 3\",\n  \"Programming Language :: Python :: 3.9\",\n  \"Programming Language :: Python :: 3.10\",\n  \"Programming Language :: Python :: 3.11\",\n  \"Programming Language :: Python :: 3.12\",\n  \"Programming Language :: Python :: 3.13\",\n  \"Programming Language :: Python :: 3.14\",\n\n  \"Topic :: Scientific/Engineering :: Artificial Intelligence\",\n  \"Topic :: Software Development :: Libraries\",\n  \"Topic :: Software Development :: Libraries :: Python Modules\",\n  \"Topic :: Software Development :: Version Control :: Git\",\n\n  \"Private :: Do Not Upload\"\n  ]\ndependencies = [\n  \"pybind11\",\n  \"torch>=2.9\"\n  ]\n\n[project.optional-dependencies]\ntest = [\"pytest\"]\n\n[tool.scikit-build]\nminimum-version = \"build-system.requires\"\nbuild-dir = \"build\"\ncmake.version = \">=3.30.0\"\nninja.version = \">=1.11\"\nninja.make-fallback = false\nwheel.expand-macos-universal-tags = true\nwheel.install-dir = \"lbannv2\"\n\n[tool.pytest]\nminversion = \"9.0\"\ntestpaths = [\n    \"test/py\",\n]"
  },
  {
    "path": "python/lbannv2/__init__.py",
    "content": "################################################################################\n## Copyright 2014-2025 Lawrence Livermore National Security, LLC and other\n## LBANN Project Developers. See the top-level LICENSE file for details.\n##\n## SPDX-License-Identifier: Apache-2.0\n################################################################################\nimport sys\nimport torch\n\ntry:\n    from .lib._lbannv2 import *\nexcept ModuleNotFoundError:\n    from .lib64._lbannv2 import *\n\nfrom ._automigrate import automigrate\n\n# Setup state needed by the library\ninit_lbannv2()\n\ndef is_available():\n    try:\n        return bool(is_lbannv2_gpu_available())\n    except Exception:\n        return False\n\nclass MigratableMemory:\n    \"\"\"Use LBANNv2's allocator for the given device\"\"\"\n\n    def __enter__(self):\n        use_mi300a_host_allocator()\n\n    def __exit__(self, exc_type, exc_value, traceback):\n        use_pytorch_host_allocator()\n\n\ndef make_migratory_tensor(ctor, *args, **kwargs):\n    with MigratableMemory():\n        return ctor(*args, **kwargs)\n"
  },
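A usage sketch for the pieces defined above (assumes a ROCm build on which the MI300A allocator is actually available):

```python
import torch
import lbannv2

# Host allocations made inside the context come from LBANNv2's
# allocator, so the resulting storage is known to be migratable.
with lbannv2.MigratableMemory():
    x = torch.empty(1024, device="cpu")

# The helper wraps any tensor constructor the same way.
y = lbannv2.make_migratory_tensor(torch.zeros, 16, 16)
```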
  {
    "path": "python/lbannv2/_automigrate.py",
    "content": "import torch\nfrom typing import Callable, Union\n\ntry:\n    from .lib._lbannv2 import migrate\nexcept ModuleNotFoundError:\n    from .lib64._lbannv2 import migrate\n\n\ndef automigrate(f: Union[Callable, torch.fx.GraphModule]) -> torch.fx.GraphModule:\n    \"\"\"Check the graph for candidates for automatic pointer migration,\n    replacing them with appropriate calls to 'migrate'. This function\n    operates at the ATen IR (FX Graph) level, so it cannot perfectly\n    determine all cases in which a migrate is possible. Symbolic\n    tracing cannot, for instance, tell the device on which inputs or\n    \"member tensors\" (e.g., of some nn layer) reside. We can make some\n    inferences, though (e.g., all nodes downstream of a memory\n    relocation call (\"to\", \"cpu\", etc) can be assumed to live on that\n    device until the next such relocation call). Additionally, we\n    cannot, in general, know the provenance of the underlying memory\n    of a given tensor. At this time, we only support migration that\n    comes from LBANNv2 allocators, as we are 100% sure of its\n    allocation. We could likely support any memory that was allocated\n    by or registered with the HIP runtime, though -- \"future work\".\n\n    Args:\n        f (Union[Callable, torch.fx.GraphModule]): Any callable\n            amenable to symbolic_trace()-ing. If this is a\n            torch.fx.GraphModule, it will be modified in-place and\n            returned.\n\n    Returns:\n        A torch.fx.GraphModule representing the input Callable, with\n            \"data movement\" nodes replaced with LBANNv2 pointer\n            \"migration\", when appropriate. If the input was already a\n            torch.fx.GraphModule, it is modified in-place and\n            returned.\n\n    \"\"\"\n\n    def safe_for_migrate(n: torch.fx.graph.Node) -> bool:\n        \"\"\"\n        At the IR level, a tensor is a candidate for migration if\n        it isn't used multiple places and if the underlying operation\n        isn't trying to change more than just the device.\n        \"\"\"\n        input_ok = len(n.args[0].users) == 1\n        args_ok = len(node.kwargs) == 1 and \"device\" in node.kwargs\n        # If we're dealing with 'cuda' or 'cpu', then it's ok for the\n        # kwargs to be empty (note that 'cuda' also supports 'device'\n        # as a kwarg).\n        if n.target != \"to\":\n            args_ok = args_ok or len(node.kwargs) == 0\n        return input_ok and args_ok\n\n    def get_target_device(n: torch.fx.graph.Node) -> torch.device:\n        return (\n            torch.device(n.kwargs[\"device\"])\n            if \"device\" in n.kwargs\n            else torch.device(str(n.target))\n        )\n\n    if isinstance(f, torch.fx.GraphModule):\n        gm = f\n    else:\n        gm = torch.fx.symbolic_trace(f)\n\n    # We can handle \"to\" or the device-specific methods (\"cuda\", e.g.).\n    migrate_candidates = [\"to\", \"cuda\", \"cpu\"]\n    for node in gm.graph.nodes:\n        if node.target in migrate_candidates and safe_for_migrate(node):\n            with gm.graph.inserting_before(node):\n                # Add a new node\n                new_node = gm.graph.call_function(\n                    migrate,\n                    args=(*node.args, get_target_device(node)),\n                )\n                node.replace_all_uses_with(new_node)\n\n            gm.graph.erase_node(node)\n\n    gm.recompile()\n    return gm\n"
  },
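A small worked example of the rewrite (a sketch; on systems without MI300A support the inserted `migrate` calls will raise at runtime, per `migrate_ptr`):

```python
import torch
from lbannv2 import automigrate

def f(t: torch.Tensor) -> torch.Tensor:
    # 'cuda' is one of the migrate candidates ('to', 'cuda', 'cpu').
    return t.cuda() + 1

gm = automigrate(f)
# The 'cuda' call_method node is now a call_function node invoking
# migrate(t, torch.device('cuda')).
print(gm.graph)
```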
  {
    "path": "src/lbannv2/CMakeLists.txt",
    "content": "################################################################################\n## Copyright 2014-2025 Lawrence Livermore National Security, LLC and other\n## LBANN Project Developers. See the top-level LICENSE file for details.\n##\n## SPDX-License-Identifier: Apache-2.0\n################################################################################\ntarget_sources(lbannv2\n  PUBLIC\n  FILE_SET HEADERS\n  FILES\n  types.hpp\n)\n\nadd_subdirectory(memory)\nadd_subdirectory(ops)\nadd_subdirectory(utils)\n\n# Pybind/Torch registration\nif (SKBUILD)\n  add_subdirectory(python)\nendif ()\n"
  },
  {
    "path": "src/lbannv2/memory/CMakeLists.txt",
    "content": "################################################################################\n## Copyright 2014-2025 Lawrence Livermore National Security, LLC and other\n## LBANN Project Developers. See the top-level LICENSE file for details.\n##\n## SPDX-License-Identifier: Apache-2.0\n################################################################################\ntarget_sources(lbannv2\n  PUBLIC\n  FILE_SET HEADERS\n  FILES\n  allocator.hpp\n  # h2_allocator_wrappers.hpp\n  registry.hpp\n)\ntarget_sources(lbannv2\n  PRIVATE\n  allocator.cpp\n  registry.cpp\n)\n\nif (LBANNV2_UNKNOWN_MI300A OR LBANNV2_WITH_MI300A)\n  target_sources(lbannv2\n    PUBLIC\n    FILE_SET HEADERS\n    FILES\n    mi300a_allocator.hpp\n  )\n  target_sources(lbannv2\n    PRIVATE\n    mi300a_allocator.cpp\n  )\nendif ()\n"
  },
  {
    "path": "src/lbannv2/memory/allocator.cpp",
    "content": "////////////////////////////////////////////////////////////////////////////////\n// Copyright 2014-2025 Lawrence Livermore National Security, LLC and other\n// LBANN Project Developers. See the top-level LICENSE file for details.\n//\n// SPDX-License-Identifier: Apache-2.0\n////////////////////////////////////////////////////////////////////////////////\n#include \"lbannv2/memory/allocator.hpp\"\n\n#include \"lbannv2/memory/registry.hpp\"\n#include \"lbannv2/utils/errors.hpp\"\n#include \"lbannv2/utils/logging.hpp\"\n\n#include <c10/core/CPUAllocator.h>\n\n#if LBANNV2_HAS_CUDA\n#include <ATen/cuda/CUDAContextLight.h>\n#elif LBANNV2_HAS_ROCM\n#include <ATen/hip/HIPContextLight.h>\n#endif\n\n#if LBANNV2_WITH_MI300A || LBANNV2_UNKNOWN_MI300A\n#include \"lbannv2/memory/mi300a_allocator.hpp\"\n#endif\n\nnamespace lbannv2\n{\n\nc10::DataPtr Allocator::allocate(size_t n)\n{\n  // Do the allocation\n  void* const buffer = this->raw_alloc(n);\n\n  // Log the allocation\n  LBANNV2_TRACE(\"Allocator::allocate(n={}, ptr={})\", n, buffer);\n  pointer_registry().add(buffer, n, this);\n\n  // Decorate the allocation.\n  return {buffer, buffer, this->raw_deleter(), this->get_device()};\n}\n\n}  // namespace lbannv2\n\nbool lbannv2::is_managed_ptr(void const* const ptr) noexcept\n{\n  return pointer_registry().known(ptr);\n}\n\nnamespace\n{\n\nc10::Allocator* pt_orig_cpu_alloc_ = nullptr;\n\n}  // namespace\n\nvoid lbannv2::use_mi300a_cpu_allocator()\n{\n#if LBANNV2_WITH_MI300A || LBANNV2_UNKNOWN_MI300A\n#if LBANNV2_UNKNOWN_MI300A\n  if (gpu::is_integrated())\n#endif\n  {\n    if (!pt_orig_cpu_alloc_)\n      pt_orig_cpu_alloc_ = c10::GetCPUAllocator();\n    c10::SetCPUAllocator(&MI300Allocator::instance());\n    return;\n  }\n#endif\n  LBANNV2_WARN(\"No MI300A allocator available\");\n}\n\nvoid lbannv2::use_torch_cpu_allocator()\n{\n  if (pt_orig_cpu_alloc_)\n    c10::SetCPUAllocator(pt_orig_cpu_alloc_);\n}\n"
  },
  {
    "path": "src/lbannv2/memory/allocator.hpp",
    "content": "////////////////////////////////////////////////////////////////////////////////\n// Copyright 2014-2025 Lawrence Livermore National Security, LLC and other\n// LBANN Project Developers. See the top-level LICENSE file for details.\n//\n// SPDX-License-Identifier: Apache-2.0\n////////////////////////////////////////////////////////////////////////////////\n#pragma once\n#include <lbannv2_config.h>\n\n#include <c10/core/Allocator.h>\n\nnamespace lbannv2\n{\n\n/** @class Allocator\n *  @brief A simplistic interface for LBANN allocators.\n */\nclass LBANNV2_EXPORT Allocator : public c10::Allocator\n{\npublic:\n  virtual void* raw_alloc(size_t nbytes) = 0;\n  virtual void raw_dealloc(void* ptr) = 0;\n  virtual c10::Device get_device() const noexcept = 0;\n\n  c10::DataPtr allocate(size_t n) final;\n};  // class Allocator\n\nLBANNV2_EXPORT bool is_managed_ptr(void const* ptr) noexcept;\n\nLBANNV2_EXPORT void use_mi300a_cpu_allocator();\nLBANNV2_EXPORT void use_torch_cpu_allocator();\n\n}  // namespace lbannv2\n"
  },
  {
    "path": "src/lbannv2/memory/h2_allocator_wrappers.cpp",
    "content": "////////////////////////////////////////////////////////////////////////////////\n// Copyright 2014-2025 Lawrence Livermore National Security, LLC and other\n// LBANN Project Developers. See the top-level LICENSE file for details.\n//\n// SPDX-License-Identifier: Apache-2.0\n////////////////////////////////////////////////////////////////////////////////\n#include <lbannv2/memory/h2_allocator_wrappers.hpp>\n\nnamespace lbannv2\n{\n\ntemplate <h2::Device D>\nH2AllocatorWrapper<D>& H2AllocatorWrapper<D>::instance()\n{\n  static H2AllocatorWrapper<D> allocator;\n  return allocator;\n}\n\n}  // namespace lbannv2\n"
  },
  {
    "path": "src/lbannv2/memory/h2_allocator_wrappers.hpp",
    "content": "////////////////////////////////////////////////////////////////////////////////\n// Copyright 2014-2025 Lawrence Livermore National Security, LLC and other\n// LBANN Project Developers. See the top-level LICENSE file for details.\n//\n// SPDX-License-Identifier: Apache-2.0\n////////////////////////////////////////////////////////////////////////////////\n#pragma once\n\n#include <lbannv2_config.h>\n\n#include <lbannv2/memory/allocator.hpp>\n#include <lbannv2/utils/logging.hpp>\n\n#include <h2/core/allocator.hpp>\n\n#include <c10/core/Allocator.h>\n\nnamespace lbannv2\n{\n\ntemplate <h2::Device D>\nclass H2AllocatorWrapper : public Allocator\n{\n  using AllocatorType = h2::internal::Allocator<std::byte, D>;\n\npublic:\n  /** @name Virtual function overrides */\n  ///@{\n\n  // memcpy\n  void copy_data(void* dst, void const* src, size_t n) const final\n  {\n    if constexpr (D == h2::Device::CPU)\n    {\n      LBANNV2_TRACE(\"H2AllocatorWrapper<CPU>::copy_data(dst={}, src={}, bytes={})\",\n                    dst, src, n);\n      std::memcpy(dst, src, n);\n    }\n#if LBANNV2_HAS_GPU\n    if constexpr (D == h2::Device::GPU)\n    {\n      LBANNV2_TRACE(\"H2AllocatorWrapper<GPU>::copy_data(dst={}, src={}, bytes={})\",\n                    dst, src, n);\n      h2::gpu::mem_copy(dst, src, n);\n    }\n#endif\n  }\n\n  void* raw_allocate(size_t n) final\n  {\n    return reinterpret_cast<void*>(\n      AllocatorType::allocate(n, h2::ComputeStream {D}));\n  }\n\n  void raw_deallocate(void* ptr) final\n  {\n    AllocatorType::deallocate(reinterpret_cast<std::byte*>(ptr),\n                              h2::ComputeStream {D});\n  }\n\n  c10::Device get_device() const noexcept final\n  {\n#if LBANNV2_HAS_GPU\n    if constexpr (D == h2::Device::GPU)\n      return c10::Device {\n        c10::kCUDA, static_cast<c10::DeviceIndex>(h2::gpu::current_gpu())};\n#endif\n    return c10::Device {c10::kCPU};\n  }\n\n  ///@}\n\n  // Singleton\n  static H2AllocatorWrapper& instance()\n  {\n    static H2AllocatorWrapper<D> allocator;\n    return allocator;\n  }\n\nprivate:\n  H2AllocatorWrapper() = default;\n  ~H2AllocatorWrapper() = default;\n  H2AllocatorWrapper(H2AllocatorWrapper const&) = delete;\n  H2AllocatorWrapper(H2AllocatorWrapper&&) = delete;\n  H2AllocatorWrapper& operator=(H2AllocatorWrapper const&) = delete;\n  H2AllocatorWrapper& operator=(H2AllocatorWrapper&&) = delete;\n};\n\nusing H2CPUAllocatorWrapper = H2AllocatorWrapper<h2::Device::CPU>;\n#if LBANNV2_HAS_GPU\nusing H2GPUAllocatorWrapper = H2AllocatorWrapper<h2::Device::GPU>;\n#endif\n\n}  // namespace lbannv2\n"
  },
  {
    "path": "src/lbannv2/memory/memory_utils.hpp",
    "content": "////////////////////////////////////////////////////////////////////////////////\n// Copyright 2014-2025 Lawrence Livermore National Security, LLC and other\n// LBANN Project Developers. See the top-level LICENSE file for details.\n//\n// SPDX-License-Identifier: Apache-2.0\n////////////////////////////////////////////////////////////////////////////////\n#pragma once\n\n#include <lbannv2_config.h>\n\n#include <lbannv2/memory/allocator.hpp>\n\nnamespace lbannv2\n{\n\n/** @class AllocatorWrapper\n *  @brief Wrap an allocator with a different device.\n *\n *  This wraps a c10::Allocator instance. Allocations from that\n *  allocator are intercepted and the DataPtr is updated to have the\n *  specified Device.\n *\n *  The primary intention is to wrap LBANN allocators as \"native\n *  device\" allocators, though it could be used the other way, too.\n *  However, there is no pointer registration in this class -- LBANNv2\n *  allocators handle this internally, so including that here would\n *  \"double register\" pointers. This could be cleaned up a bit down\n *  the road.\n */\nclass AllocatorWrapper : public c10::Allocator\n{\npublic:\n  /** @brief Constructor\n   *\n   *  @param[in] alloc The allocator to wrap.\n   *  @param[in] device The device to use for DataPtrs produced by\n   *                    this allocator.\n   */\n  AllocatorWrapper(c10::Allocator& alloc, c10::Device device)\n    : m_alloc {alloc}, m_device {std::move(device)}\n  {}\n  ~AllocatorWrapper() = default;\n\n  c10::DataPtr allocate(size_t n) final\n  {\n    auto dptr = m_alloc.allocate(n);\n    dptr.unsafe_set_device(m_device);\n    // NOTE (trb): We could replace the deleter fn to be\n    // this->raw_deleter, but since this->raw_deleter() just calls\n    // that->raw_deleter(), what would be the point?? This story\n    // changes if we start tracking memory allocations in the registry\n    // through this class.\n    return dptr;\n  }\n\n  c10::DeleterFnPtr raw_deleter() const noexcept final\n  {\n    return m_alloc.raw_deleter();\n  }\n\n  void copy_data(void* dst, void const* src, size_t n) const final\n  {\n    m_alloc.copy_data(dst, src, n);\n  }\n\nprivate:\n  c10::Allocator& m_alloc;\n  c10::Device m_device;\n};  // class AllocatorWrapper\n\n}  // namespace lbannv2\n"
  },
  {
    "path": "src/lbannv2/memory/mi300a_allocator.cpp",
    "content": "////////////////////////////////////////////////////////////////////////////////\n// Copyright 2014-2025 Lawrence Livermore National Security, LLC and other\n// LBANN Project Developers. See the top-level LICENSE file for details.\n//\n// SPDX-License-Identifier: Apache-2.0\n////////////////////////////////////////////////////////////////////////////////\n#include \"lbannv2_config.h\"\n\n#include \"lbannv2/memory/mi300a_allocator.hpp\"\n\n#include \"lbannv2/memory/registry.hpp\"\n#include \"lbannv2/utils/errors.hpp\"\n#include \"lbannv2/utils/gpu_utils.hpp\"\n#include \"lbannv2/utils/logging.hpp\"\n\n#if LBANNV2_HAS_CUDA\n#include <ATen/cuda/CUDAContextLight.h>\n#include <c10/cuda/CUDAStream.h>\n#elif LBANNV2_HAS_ROCM\n#include <ATen/hip/HIPContextLight.h>\n#include <c10/hip/HIPStream.h>\n#endif\n\n#include <c10/core/CachingDeviceAllocator.h>\n\nnamespace\n{\nbool get_use_nonblocking_stream_env_var()\n{\n  char* env = std::getenv(\"LBANNV2_NONBLOCKING_HOST_ALLOC_STREAM\");\n  return env && std::strlen(env) && env[0] != '0';\n}\n\nbool use_nonblocking_stream()\n{\n  static bool const nonblock = get_use_nonblocking_stream_env_var();\n  LBANNV2_DEBUG(\"Using nonblocking MI300A allocation stream? {}\", nonblock);\n  return nonblock;\n}\n\nstruct StreamRAII\n{\n  ::lbannv2::TorchGPUStream_t stream;\n\n  StreamRAII()\n    : stream {lbannv2::c10_gpu::getStreamFromExternal(\n        use_nonblocking_stream() ? lbannv2::gpu::make_nonblocking_stream()\n                                 : lbannv2::gpu::make_stream(),\n        lbannv2::gpu::current_device())}\n  {}\n  ~StreamRAII()\n  {\n    try\n    {\n      lbannv2::gpu::destroy_stream(stream.stream());\n    }\n    catch (...)\n    {}\n  }\n};  // struct StreamRAII\n\n// Internal stream for managing \"host\" allocations through CUB\n::lbannv2::TorchGPUStream_t host_allocation_stream(c10::DeviceIndex const idx)\n{\n  static std::vector<StreamRAII> stream_raii(lbannv2::gpu::num_devices());\n  LBANNV2_ASSERT_ALWAYS(idx >= 0 && idx < lbannv2::gpu::num_devices());\n  return stream_raii[idx].stream;\n}\n\nc10::Device resolve_device(c10::Device const& d)\n{\n  if (d.is_cuda() && !d.has_index())\n    return {c10::kCUDA, lbannv2::gpu::current_device()};\n\n  return d;\n}\n\n#if LBANNV2_USE_C10_HIP_NAMESPACE_AND_SYMBOLS\nnamespace DeviceAlloc_ns = c10::hip::HIPCachingAllocator;\n#else\nnamespace DeviceAlloc_ns = c10::cuda::CUDACachingAllocator;\n#endif\n\nvoid lbannv2_report_free(DeviceAlloc_ns::TraceEntry const& entry)\n{\n  try\n  {\n    void* const ptr = reinterpret_cast<void*>(entry.addr_);\n    lbannv2::pointer_registry().remove(ptr);\n    LBANNV2_TRACE(\"Deallocate (ptr={})\", (void const*) ptr);\n  }\n  catch (lbannv2::UnknownAddress const&)\n  {\n    // ignore -- ptr allocated in Torch\n  }\n}\n\nvoid lbannv2_trace_alloc(DeviceAlloc_ns::TraceEntry const& entry)\n{\n  if (entry.action_ == DeviceAlloc_ns::TraceEntry::FREE_COMPLETED)\n    lbannv2_report_free(entry);\n}\n}  // namespace\n\nnamespace lbannv2\n{\n\nMI300Allocator::MI300Allocator()\n{\n#if LBANNV2_WITHOUT_MI300A || LBANNV2_UNKNOWN_MI300A\n#if LBANNV2_UNKNOWN_MI300A\n  if (!lbannv2::gpu::is_integrated())\n#endif\n    throw std::runtime_error(\"MI300Allocator is only supported on MI300A\");\n#endif\n\n  auto* const dev_alloc =\n    dynamic_cast<DeviceAlloc_t*>(at::cuda::getCUDADeviceAllocator());\n  LBANNV2_ASSERT_ALWAYS(dev_alloc);\n  if (!dev_alloc->initialized())\n    dev_alloc->init(gpu::num_devices());\n\n  // Trace memory stuff\n  
dev_alloc->attachAllocatorTraceTracker(lbannv2_trace_alloc);\n\n  alloc_ = dev_alloc;\n}\n\nvoid MI300Allocator::copy_data(void* const dst,\n                               void const* const src,\n                               size_t const bytes) const\n{\n  LBANNV2_TRACE(\n    \"MI300Allocator::copy_data(dst={}, src={}, bytes={})\", dst, src, bytes);\n  std::memcpy(dst, src, bytes);\n}\n\nvoid* MI300Allocator::raw_alloc(size_t const nbytes)\n{\n  auto* const ptr = alloc_->raw_alloc_with_stream(\n    nbytes, host_allocation_stream(lbannv2::gpu::current_device()));\n\n  LBANNV2_TRACE(\n    \"MI300Allocator::raw_alloc(nbytes={}): ptr={}, current_device={}\",\n    nbytes,\n    ptr,\n    lbannv2::gpu::current_device());\n  lbannv2::gpu::sync(host_allocation_stream(lbannv2::gpu::current_device()));\n\n  return ptr;\n}\n\nvoid MI300Allocator::raw_dealloc(void* ptr)\n{\n  LBANNV2_TRACE(\"MI300Allocator::raw_dealloc(ptr={})\", ptr);\n  alloc_->raw_delete(ptr);\n}\n\nc10::Device MI300Allocator::get_device() const noexcept\n{\n  return c10::Device {c10::kCPU};\n}\n\nc10::DeleterFnPtr MI300Allocator::raw_deleter() const\n{\n  return alloc_->raw_deleter();\n}\n\nMI300Allocator& MI300Allocator::instance()\n{\n  static MI300Allocator alloc;\n  return alloc;\n}\n\n}  // namespace lbannv2\n\nc10::DeviceIndex lbannv2::get_device_idx(void const* const ptr) noexcept\n{\n  int device_idx;\n  auto const hip_status = hipPointerGetAttribute(\n    &device_idx, HIP_POINTER_ATTRIBUTE_DEVICE_ORDINAL, const_cast<void*>(ptr));\n  if (hip_status == hipSuccess)\n  {\n    return static_cast<c10::DeviceIndex>(device_idx);\n  }\n  else\n  {\n    LBANNV2_DEBUG(\"lbannv2::get_device_idx(ptr={}) failed. Error: {}\",\n                  ptr,\n                  hipGetErrorString(hip_status));\n    return -1;\n  }\n}\n\n// Let's aim for a fully robust implementation here. We must consider:\n//   1. Migrating from D(:m) -> D(:m) is a no-op.\n//   2. Migrating from D:m -> D:n is a deep copy.\nvoid lbannv2::migrate_ptr(c10::DataPtr& ptr,\n                          c10::Device to_device,\n                          c10::Stream with_stream)\n{\n  auto const real_tgt_device = resolve_device(to_device);\n\n  // If no migration actually happens, just short-circuit...\n  if (ptr.device() == real_tgt_device)\n    return;\n\n#if LBANNV2_WITHOUT_MI300A || LBANNV2_UNKNOWN_MI300A\n#if LBANNV2_UNKNOWN_MI300A\n  if (!lbannv2::gpu::is_integrated())\n#endif\n  {\n    throw std::runtime_error(\"migrate_ptr is only supported on MI300A\");\n  }\n#endif\n\n  // Check that the migration is valid\n  auto const ptr_dev_idx = get_device_idx(ptr.get_context());\n  c10::Device const real_src_device = ptr_dev_idx == -1\n                                        ? c10::Device {c10::kCPU}\n                                        : c10::Device {c10::kCUDA, ptr_dev_idx};\n  LBANNV2_ASSERT(real_tgt_device.is_cpu() || real_src_device == real_tgt_device,\n                 std::runtime_error,\n                 \"lbannv2::migrate_ptr: invalid src/tgt device combo\");\n\n  // Update the stream\n  auto const new_stream = real_tgt_device.is_cpu()\n                            ? host_allocation_stream(ptr_dev_idx)\n                            : TorchGPUStream_t(with_stream);\n\n  // UGH. Oh well.\n  MI300Allocator::instance().alloc_->recordStream(ptr, new_stream);\n\n  // Finally, update the DataPtr itself\n  ptr.unsafe_set_device(real_tgt_device);\n}\n"
  },
  {
    "path": "src/lbannv2/memory/mi300a_allocator.hpp",
    "content": "////////////////////////////////////////////////////////////////////////////////\n// Copyright 2014-2025 Lawrence Livermore National Security, LLC and other\n// LBANN Project Developers. See the top-level LICENSE file for details.\n//\n// SPDX-License-Identifier: Apache-2.0\n////////////////////////////////////////////////////////////////////////////////\n#pragma once\n\n#include <lbannv2/memory/allocator.hpp>\n#include <lbannv2/utils/gpu_utils.hpp>\n\n#include <c10/core/Stream.h>\n\n#if LBANNV2_HAS_CUDA\n#include <c10/cuda/CUDACachingAllocator.h>\n#elif LBANNV2_HAS_ROCM\n#include <c10/hip/HIPCachingAllocator.h>\n#endif\n\nnamespace lbannv2\n{\n\n// Call when moving pointer to a different device\nvoid migrate_ptr(c10::DataPtr& ptr,\n                 c10::Device to_device,\n                 c10::Stream with_stream);\n\nclass MI300Allocator final : public Allocator\n{\npublic:\n  void copy_data(void* dst, void const* src, size_t bytes) const final;\n\n  void* raw_alloc(size_t nbytes) final;\n\n  void raw_dealloc(void* ptr) final;\n\n  c10::DeleterFnPtr raw_deleter() const final;\n\n  c10::Device get_device() const noexcept final;\n\n  static MI300Allocator& instance();\n\nprivate:\n  MI300Allocator();\n  ~MI300Allocator() = default;\n  MI300Allocator(MI300Allocator const&) = delete;\n  MI300Allocator(MI300Allocator&&) = delete;\n  MI300Allocator& operator=(MI300Allocator const&) = delete;\n  MI300Allocator& operator=(MI300Allocator&&) = delete;\n\n#if LBANNV2_USE_C10_HIP_NAMESPACE_AND_SYMBOLS\n  using DeviceAlloc_t = ::c10::hip::HIPCachingAllocator::HIPAllocator;\n#else\n  using DeviceAlloc_t = ::c10::cuda::CUDACachingAllocator::CUDAAllocator;\n#endif\n  DeviceAlloc_t* alloc_;\n\n  friend void migrate_ptr(c10::DataPtr&, c10::Device, c10::Stream);\n\n};\n\n/** @brief Get the device with which the allocation is associated.\n *\n * @note From what I can tell, this is any valid pointer -- it doesn't\n *       have to be the \"context\" pointer, for instance.\n *\n * @param[in] A pointer to valid memory.\n *\n * @returns The (GPU) device index with which the allocation is\n *          associated. -1 if not GPU memory or nullptr.\n */\nc10::DeviceIndex get_device_idx(void const* const ptr) noexcept;\n\n}  // namespace lbannv2\n"
  },
  {
    "path": "src/lbannv2/memory/registry.cpp",
    "content": "////////////////////////////////////////////////////////////////////////////////\n// Copyright 2014-2025 Lawrence Livermore National Security, LLC and other\n// LBANN Project Developers. See the top-level LICENSE file for details.\n//\n// SPDX-License-Identifier: Apache-2.0\n////////////////////////////////////////////////////////////////////////////////\n#include \"registry.hpp\"\n\n#include \"lbannv2/utils/errors.hpp\"\n#include \"lbannv2/utils/logging.hpp\"\n\nnamespace\n{\n\n// Syntactic sugar. Iterators kinda suck for readability.\nauto const& get_ptr_range(std::input_iterator auto const& map_iter) noexcept\n{\n  return map_iter->first;\n}\n\nauto const& get_allocator_ptr(std::input_iterator auto const& map_iter) noexcept\n{\n  return map_iter->second;\n}\n\nstd::size_t range_bytes(std::pair<void*, void*> const& r) noexcept\n{\n  return std::distance((std::byte*) r.first, (std::byte*) r.second);\n}\n\n}  // namespace\n\nnamespace lbannv2\n{\n\nvoid PointerRegistry::add(void* const ptr,\n                          size_t const size,\n                          c10::Allocator* const allocator)\n{\n  if (!ptr)\n    return;\n\n  std::lock_guard<std::mutex> lock(m_registry_mtx);\n  auto const [it, added] = m_registry.emplace(\n    KeyT {ptr, static_cast<std::byte*>(ptr) + size}, allocator);\n  LBANNV2_ASSERT(\n    added, std::runtime_error, \"Address range overlaps existing range\");\n\n  LBANNV2_TRACE(\"Registered pointer range start={}, size={}, allocator={}\",\n                ptr,\n                size,\n                (void*) allocator);\n}\n\nvoid PointerRegistry::remove(void* const ptr)\n{\n  if (!ptr)\n    return;\n\n  std::lock_guard<std::mutex> lock(m_registry_mtx);\n  auto const it = m_registry.find(ptr);\n  if (it == m_registry.cend())\n    throw UnknownAddress {};\n  else if (get_ptr_range(it).first != ptr)\n    throw std::runtime_error(\"Cannot remove ptr; not beginning of range.\");\n\n  {\n    [[maybe_unused]] auto const& [ptr_range, alloc_ptr] = *it;\n    LBANNV2_TRACE(\"Deregistered pointer range start={}, size={}, allocator={}\",\n                  ptr_range.first,\n                  range_bytes(ptr_range),\n                  (void*) alloc_ptr);\n  }\n\n  m_registry.erase(it);\n}\n\nbool PointerRegistry::known(void const* const ptr) const noexcept\n{\n  std::lock_guard<std::mutex> lock(m_registry_mtx);\n  return m_registry.contains(ptr);\n}\n\nc10::Allocator* PointerRegistry::get_allocator(void const* const ptr) const\n{\n  std::lock_guard<std::mutex> lock(m_registry_mtx);\n  auto const it = m_registry.find(ptr);\n  if (it == m_registry.cend())\n    throw UnknownAddress {};\n  return get_allocator_ptr(it);\n}\n\nvoid PointerRegistry::unsafe_reset_allocator(void const* const ptr,\n                                             c10::Allocator* const new_alloc)\n{\n  std::lock_guard<std::mutex> lock(m_registry_mtx);\n  auto const it = m_registry.find(ptr);\n  if (it == m_registry.cend())\n    throw UnknownAddress {};\n  it->second = new_alloc;\n}\n\nvoid* PointerRegistry::get_context(void const* const ptr) const\n{\n  std::lock_guard<std::mutex> lock(m_registry_mtx);\n  auto const it = m_registry.find(ptr);\n  if (it == m_registry.cend())\n    throw UnknownAddress {};\n  return get_ptr_range(it).first;\n}\n\nstd::size_t PointerRegistry::bytes_registered() const noexcept\n{\n  std::lock_guard<std::mutex> lock(m_registry_mtx);\n  size_t bytes = 0UL;\n  for (auto const& kvp : m_registry)\n  {\n    bytes += range_bytes(kvp.first);\n  }\n  return 
bytes;\n}\n\nstd::size_t\nPointerRegistry::bytes_registered(void const* const ptr) const noexcept\n{\n  std::lock_guard<std::mutex> lock(m_registry_mtx);\n  auto const it = m_registry.find(ptr);\n  if (it != m_registry.cend())\n  {\n    return range_bytes(it->first);\n  }\n  return 0;\n}\n\n}  // namespace lbannv2\n\nauto lbannv2::pointer_registry() -> PointerRegistry&\n{\n  static PointerRegistry registry;\n  return registry;\n}\n"
  },
  {
    "path": "src/lbannv2/memory/registry.hpp",
    "content": "////////////////////////////////////////////////////////////////////////////////\n// Copyright 2014-2025 Lawrence Livermore National Security, LLC and other\n// LBANN Project Developers. See the top-level LICENSE file for details.\n//\n// SPDX-License-Identifier: Apache-2.0\n////////////////////////////////////////////////////////////////////////////////\n#pragma once\n\n#include <lbannv2_config.h>\n\n#include <lbannv2/memory/allocator.hpp>\n\n#include <map>\n#include <mutex>\n#include <stdexcept>\n\n#include <c10/core/DeviceType.h>\n\nnamespace lbannv2\n{\n\nstruct LBANNV2_EXPORT UnknownAddress : std::runtime_error\n{\n  UnknownAddress() : std::runtime_error {\"Unknown address\"} {}\n};\n\n// We should consider the issue of registering nullptr or equivalent\n// zero-size allocations. Note that if ISO C++ is the only source of\n// memory, this should be an error. But I'm not sure how all of the\n// allocators we encounter might handle a zero-size allocation (e.g.,\n// cudaMalloc and friends). ISO C++, however, requires zero-size\n// allocations to still return unique, non-null pointers (section\n// 6.7.5.5.2, paragraph 2).\n\n/** @class PointerRegistry\n *  @brief Tracks known memory regions\n */\nclass LBANNV2_EXPORT PointerRegistry\n{\npublic:\n  /** @brief Register an allocation.\n   *\n   *  @param[in] ptr The beginning of the allocated range.\n   *  @param[in] size The size in bytes of the allocated range.\n   *  @param[in] allocator The allocator responsible for deleting the range.\n   */\n  void add(void* ptr, size_t size, c10::Allocator* allocator);\n\n  /** @brief Deregister an allocation.\n   *\n   *  The pointer passed must match a pointer registered with add().\n   *\n   *  @param[in] ptr The (context) pointer to deregister.\n   */\n  void remove(void* ptr);\n\n  /** @brief Query whether this address is part of a registered\n   *         allocation.\n   *\n   *  Returns @c true for any address that is included in a registered\n   *  allocation, that is, in the range [ptr, ptr + size) for any\n   *  (ptr, size) passed to add().\n   *\n   *  @param[in] ptr The pointer in question.\n   */\n  bool known(void const* ptr) const noexcept;\n\n  /** @brief Get the allocator used to allocate this pointer.\n   *\n   *  @param[in] ptr The pointer whose allocator is needed.\n   *\n   *  @throws UnknownAddress if the pointer is not part of a\n   *          registered allocation.\n   */\n  c10::Allocator* get_allocator(void const* ptr) const;\n\n  /** @brief Reset the allocator associated with a pointer.\n   *\n   *  In cases of MI300A pointer migration, this allows us to keep our\n   *  internal bookkeeping consistent. It should not be used outside\n   *  of this context.\n   */\n  void unsafe_reset_allocator(void const* ptr, c10::Allocator* new_alloc);\n  // FIXME (trb): An alternative would be to make this similar to\n  // \"compare and swap\" semantics (i.e., having to provide what the\n  // user thinks the current allocator is); see also, replacing a\n  // deleter on a DataPtr. My concern is this will never be called\n  // \"properly\" but rather just with a dummy\n  // \"registry.get_allocator(ptr)\" in that argument, so what would the\n  // point really be?\n\n  /** @brief Get the context of the given pointer.\n   *\n   *  The context is the address returned by the raw allocator when\n   *  the allocation is requested. 
It is the pointer that must be\n   *  passed to @c delete.\n   *\n   *  @param[in] ptr The pointer whose context is needed.\n   *\n   *  @throws UnknownAddress if the pointer is not part of a\n   *          registered allocation.\n   */\n  void* get_context(void const* ptr) const;\n\n  /** @brief Get the current number of registered ranges */\n  size_t num_registered() const noexcept\n  {\n    std::lock_guard<std::mutex> lock(m_registry_mtx);\n    return m_registry.size();\n  }\n\n  /** @brief Get the current number of registered bytes */\n  size_t bytes_registered() const noexcept;\n\n  /** @brief Get the number of bytes associated with the given\n   *         pointer.\n   *\n   *  Unregistered pointers return 0. Since zero-sized ranges are\n   *  allowed in the registry, this function cannot serve as a proxy\n   *  for known().\n   *\n   *  @param[in] ptr Any valid address.\n   *\n   *  @returns The number of bytes in an allocation associated with\n   *           the pointer.\n   */\n  size_t bytes_registered(void const*) const noexcept;\n\npublic:\n  using KeyT = std::pair<void*, void*>;\n  /** @class RangeLessAndDisjoint\n   *  @brief Comparison operator for pointer ranges\n   *\n   *  'a' is RangeLessAndDisjoint from 'b' if its upper bound is <=\n   *  the lower bound of 'b', and, because we consider zero-size\n   *  ranges to be valid, if its lower bound is strictly less than the\n   *  lower bound of 'b'. A consequence of this definition is that two\n   *  ranges will be \"equivalent\", by the STL's definition of the\n   *  concept, if and only if they overlap. Thus, using this as the\n   *  `compare` operator in an associative map keyed on ranges [a,b),\n   *  a<=b (with the equality case denoting a valid but zero-sized\n   *  range) allows us to quickly identify overlapping ranges.\n   *\n   *  This provides benefits to our use-case in two ways. First,\n   *  overlapping regions are forbidden. Thus, we will never add a\n   *  range that overlaps a previously added range because the new key\n   *  will present as equivalent to an existing key. Second, we can\n   *  search for pointers p efficiently, using `key_type{p,p}` as the\n   *  key. Searching this way will yield a range containing `p`, if\n   *  one exists. I have included comparison operators that take a\n   *  single pointer to facilitate this computation directly. Because\n   *  they operate exactly \"as though\" we had passed a zero-size\n   *  range, the ordering remains consistent and searches maintain\n   *  their logarithmic complexity.\n   */\n  struct RangeLessAndDisjoint\n  {\n    /** @brief Needed to enable the templated overloads to find,\n     *         contains, etc.\n     */\n    typedef std::true_type is_transparent;\n\n    bool operator()(KeyT const& a, KeyT const& b) const noexcept\n    {\n      return a.second <= b.first && a.first != b.first;\n    }\n\n    bool operator()(void const* const a, KeyT const& b) const noexcept\n    {\n      return a < b.first;\n    }\n\n    bool operator()(KeyT const& a, void const* const b) const noexcept\n    {\n      return a.first != b && a.second <= b;\n    }\n  };\n\nprivate:\n  using MapType = std::map<KeyT, c10::Allocator*, RangeLessAndDisjoint>;\n  MapType m_registry;\n  mutable std::mutex m_registry_mtx;\n};  // struct PointerRegistry\n\nLBANNV2_EXPORT PointerRegistry& pointer_registry();\n\n}  // namespace lbannv2\n"
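// Usage sketch (illustrative; 'base' and 'alloc' are placeholders):\n// lookups are transparent, so any interior address of a registered\n// range resolves to that range:\n//   auto& reg = lbannv2::pointer_registry();\n//   reg.add(base, 64, alloc);                 // registers [base, base + 64)\n//   reg.known((std::byte*) base + 17);        // true\n//   reg.get_context((std::byte*) base + 17);  // returns base\n\n}  // namespace lbannv2\n"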
  },
  {
    "path": "src/lbannv2/ops/CMakeLists.txt",
    "content": "################################################################################\n## Copyright 2014-2025 Lawrence Livermore National Security, LLC and other\n## LBANN Project Developers. See the top-level LICENSE file for details.\n##\n## SPDX-License-Identifier: Apache-2.0\n################################################################################\ntarget_sources(lbannv2\n  PUBLIC\n  FILE_SET HEADERS\n  FILES\n  migrate.hpp\n)\ntarget_sources(lbannv2\n  PRIVATE\n  migrate.cpp\n)\n\n# Note that LBANNV2_HAS_ROCM is implicit in either of these cases.\n#\n# FIXME trb: \"migrate\" includes all the dynamic mi300a handling, etc.\n# Should it always be available at this level? (vs just in\n# register_ops.cpp)\nif (LBANNV2_UNKNOWN_MI300A OR LBANNV2_WITH_MI300A)\n  target_sources(lbannv2\n    PUBLIC\n    FILE_SET HEADERS\n    FILES\n    nonzero.hpp\n    scalar.hpp\n  )\n  target_sources(lbannv2\n    PRIVATE\n    migrate.cpp\n    nonzero.hip\n    scalar.cpp\n  )\nendif ()\n"
  },
  {
    "path": "src/lbannv2/ops/migrate.cpp",
    "content": "////////////////////////////////////////////////////////////////////////////////\n// Copyright 2014-2025 Lawrence Livermore National Security, LLC and other\n// LBANN Project Developers. See the top-level LICENSE file for details.\n//\n// SPDX-License-Identifier: Apache-2.0\n////////////////////////////////////////////////////////////////////////////////\n#include <lbannv2_config.h>\n\n#include <lbannv2/memory/mi300a_allocator.hpp>\n#include <lbannv2/memory/registry.hpp>\n#include <lbannv2/ops/migrate.hpp>\n#include <lbannv2/utils/gpu_utils.hpp>\n#include <lbannv2/utils/logging.hpp>\n#include <lbannv2/utils/tensor_helpers.hpp>\n\n#include <ATen/Tensor.h>\n#include <c10/core/Device.h>\n#if LBANNV2_HAS_ROCM\n#include <c10/hip/HIPFunctions.h>\n#endif\n\n#if LBANNV2_WITH_MI300A || LBANNV2_UNKNOWN_MI300A\nnamespace\n{\n\n// NOTE: This function assumes a binary view of memory: pointers only\n// come from \"CPU\" or \"CUDA\" (i.e., HIP).\nat::Device get_origin_device(void const* const ptr)\n{\n  // Note to future!me: the HIP runtime can give us both the context\n  // pointer and the buffer size for any pointer allocated by HIP.\n  // HOWEVER, so can the pytorch DataPtr object, which we have in the\n  // context in which this function is used...\n  int device_idx;\n  if (hipPointerGetAttribute(&device_idx,\n                             HIP_POINTER_ATTRIBUTE_DEVICE_ORDINAL,\n                             const_cast<void*>(ptr))\n      == hipSuccess)\n  {\n    return {c10::kCUDA, static_cast<c10::DeviceIndex>(device_idx)};\n  }\n  return c10::kCPU;\n}\n\n// PyTorch still admits the possibility of a single process using\n// multiple GPUs, though this historically has not been LBANN's\n// preferred approach (instead preferring 1 GPU per rank and 1 rank\n// per GPU). On MI300A, we can migrate a pointer from any \"GPU\" to the\n// CPU freely. HOWEVER, we can only migrate from the CPU to the\n// specific device on which the migrateable memory was allocated.\nbool is_ok_device(c10::Device const& d)\n{\n  return d.is_cpu()\n#if LBANNV2_HAS_GPU\n         || d.is_cuda()\n#endif\n    ;\n}\n\nc10::DispatchKeySet get_default_keyset(c10::Device const& d)\n{\n  switch (d.type())\n  {\n  case c10::kCPU: return c10::DispatchKeySet {c10::DispatchKey::CPU};\n  case c10::kCUDA: return c10::DispatchKeySet {c10::DispatchKey::CUDA};\n  default: throw std::runtime_error(\"Unknown device type\");\n  }\n}\n\n}  // namespace\n#endif\n\nat::Tensor lbannv2::migrate(at::Tensor& t, c10::Device const& d)\n{\n  auto const src_d = t.device();\n  LBANNV2_TRACE(\n    \"migrate(ptr={}, from={}, to={})\", t.data_ptr(), src_d.str(), d.str());\n\n  // Short-circuit\n  if (src_d == d)\n    return t;\n\n#if LBANNV2_UNKNOWN_MI300A || LBANNV2_WITH_MI300A\n  // NOTE: \"LBANNV2_HAS_ROCM\" is implied here.\n\n  // At its heart, this isn't really \"migrate\", it's \"rebrand\"... I\n  // don't actually care what the device annotations on the Tensor or\n  // Storage are, I care about the origin of the pointer. 
It might\n  // also be good to look into p2p memory access, but I don't know how\n  // to query that just given the pointer (i.e., even if p2p mem\n  // access is enabled *now*, I haven't discovered a way to tell if it\n  // was enabled when a particular buffer was allocated (well, other\n  // than trying to read it and letting the segfault happen)).\n  auto const real_src_d = get_origin_device(t.const_data_ptr());\n\n  // We need to get the \"real\" \"CUDA\" target.\n  c10::Device const real_tgt_d =\n    (d.is_cuda() && !d.has_index())\n      ? c10::Device {c10::kCUDA, gpu::current_device()}\n      : d;\n\n  // If the real_src_d is \"cpu\", it can be migrated to \"cpu\".\n  // If the real_src_d is \"cuda:N\", it can be migrated to \"cpu\" or \"cuda:N\".\n  LBANNV2_ASSERT(real_tgt_d.is_cpu() || (real_src_d == real_tgt_d),\n                 std::runtime_error,\n                 \"Migrate: ptr is not migrateable to given device.\");\n  LBANNV2_ASSERT(\n    is_ok_device(real_src_d),\n    std::runtime_error,\n    \"Migrate: source tensor's device type not supported by LBANNv2.\");\n  LBANNV2_ASSERT(is_ok_device(real_tgt_d),\n                 std::runtime_error,\n                 \"Migrate: destination device type not supported by LBANNv2.\");\n\n  // FIXME: If the pointer is not owned by LBANNv2, how do we handle\n  // its associated stream?\n  //  ---> The PyTorch CUDA caching allocator provides \"recordStream\" :)\n\n#if LBANNV2_UNKNOWN_MI300A\n  if (lbannv2::gpu::is_integrated())\n#endif\n  {\n    c10::Stream stream = real_tgt_d.is_cpu()\n                           ? c10::Stream {c10::Stream::DEFAULT, d}\n                           : getDeviceCurrentStream(real_tgt_d.index());\n\n    lbannv2::migrate_ptr(t.storage().mutable_data_ptr(), d, stream);\n\n    // Report the number of meaningful bytes migrated. This is\n    // inherently based on the tensor shape rather than the allocated\n    // buffer size (think: binned allocations, subtensor \"views\",\n    // etc).\n    LBANNV2_TRACE(\"migrated {} bytes (ptr={})\",\n                  std::accumulate(t.sizes().cbegin(),\n                                  t.sizes().cend(),\n                                  static_cast<int64_t>(1),\n                                  std::multiplies<int64_t> {})\n                    * t.dtype().itemsize(),\n                  t.const_data_ptr());\n\n    auto storage = t.storage();\n    // FIXME (trb): I initially created this as a 'VIEW', but that\n    // puts it in \"inference mode\" (i.e., out.is_inference() == true).\n    // This is bad for training workloads. We may need to be a bit\n    // more careful in general here... E.g., migrating views to views,\n    // etc.\n    auto out =\n      at::detail::make_tensor<at::TensorImpl>(  // at::TensorImpl::VIEW,\n        std::move(storage),\n        get_default_keyset(d),\n        t.dtype());\n    sync_metadata(t, out);\n\n    if (src_d.is_cuda())\n    {\n      getDeviceCurrentStream(src_d.index()).synchronize();\n    }\n\n    return out;\n  }\n#endif\n  return t.to(t.options().device(d));\n}\n"
  },
  {
    "path": "src/lbannv2/ops/migrate.hpp",
    "content": "////////////////////////////////////////////////////////////////////////////////\n// Copyright 2014-2025 Lawrence Livermore National Security, LLC and other\n// LBANN Project Developers. See the top-level LICENSE file for details.\n//\n// SPDX-License-Identifier: Apache-2.0\n////////////////////////////////////////////////////////////////////////////////\n#pragma once\n\n#include <lbannv2_config.h>\n\n#include <ATen/Tensor.h>\n#include <c10/core/Device.h>\n\nnamespace lbannv2\n{\n\n/** @brief Migrate a tensor to a new device, eliding copies when\n *         possible.\n *\n *  If we have an APU (e.g., MI300A), we are able to zero-copy migrate\n *  the memory between the \"cpu\" backend and the \"cuda\" backend, under\n *  certain circumstances. The semantics differ from the \"to\" operator\n *  in the sense that the original tensor is considered \"invalid\"\n *  (implicitly, of course) after the migration.\n *\n *  The primary prerequisite for migrating a tensor is that its\n *  backing memory must have been allocated using a \"cuda\" allocator\n *  (that is, somewhere in the allocator stack, the raw memory must\n *  come from \"hipMalloc\" in the case of MI300A). LBANNv2 provides a\n *  context manager that replaces the underlying CPU allocator with\n *  one that allocates \"cuda\" memory, essentially providing\n *  \"migrateable\" CPU tensors.\n *\n *  At this time, we do NOT support IPC memory buffers or P2P device\n *  memory access. Thus, tensors are only migrateable between the CPU\n *  and whichever CUDA device their allocation is tied to. In the case\n *  of CPU tensors allocated using the LBANNv2 allocator, this will be\n *  whichever CUDA device was selected at the time of its allocation.\n *\n *  If we do not have an APU, this is just a direct call to \"to\".\n *\n *  Upon successful migration, the input tensor is invalidated to\n *  prevent foot wounds.\n *\n *  Schema: migrate(Tensor(a!), Device) -> Tensor(a!)\n *\n *  @param[in] t The tensor to (possibly) migrate.\n *  @param[in] d The target device.\n *\n *  @returns A tensor associated with the given target device.\n */\nat::Tensor migrate(at::Tensor& t, c10::Device const& d);\n\n}// namespace lbannv2\n"
  },
  {
    "path": "src/lbannv2/ops/nonzero.hip",
    "content": "////////////////////////////////////////////////////////////////////////////////\n// Copyright 2014-2025 Lawrence Livermore National Security, LLC and other\n// LBANN Project Developers. See the top-level LICENSE file for details.\n//\n// SPDX-License-Identifier: Apache-2.0\n////////////////////////////////////////////////////////////////////////////////\n#include \"lbannv2/ops/nonzero.hpp\"\n#include <lbannv2/memory/allocator.hpp>\n#include <lbannv2/utils/gpu_utils.hpp>\n#include <lbannv2/utils/logging.hpp>\n\n#include <ATen/hip/EmptyTensor.h>\n#include <ATen/hip/HIPContext.h>\n#include <hipcub/hipcub.hpp>\n\n// Note (trb): UGH PyTorch 2.11\n#ifdef C10_HIP_KERNEL_LAUNCH_CHECK\n#define LBANNV2_KERNEL_LAUNCH_CHECK() C10_HIP_KERNEL_LAUNCH_CHECK()\n#else\n#define LBANNV2_KERNEL_LAUNCH_CHECK() C10_CUDA_KERNEL_LAUNCH_CHECK()\n#endif\n\nnamespace\n{\ntemplate <typename T>\nT const* get_const(c10::DataPtr const& ptr)\n{\n  return static_cast<T const*>(ptr.get());\n}\n\n// Hoisted from PyTorch; clang-format to LBANNv2's style.\n//\n//   path: aten/src/ATen/native/cuda/Nonzero.cu\n//   commit: 36eb64d60ea6371e3a617ba5026d27be7f88a6af\n//\n// FIXME: Point to <pytorch>/LICENSE or copy thereof.\n\ntemplate <typename T>\nstruct NonZeroOp\n{\n  __host__ __device__ __forceinline__ bool operator()(T const& a) const\n  {\n    return (a != T(0));\n  }\n};\n\n#define MAX_DIMS 16\ntemplate <typename index_t>\nstruct TensorDims\n{\n  index_t sizes[MAX_DIMS];\n};\n\ntemplate <typename index_t>\n__global__ void write_indices(int64_t* inp,\n                              TensorDims<index_t> dims,\n                              int ndim,\n                              index_t n,\n                              int64_t* total = nullptr,\n                              int64_t fill_value = -1)\n{\n  auto index = threadIdx.x + (int64_t) blockIdx.x * blockDim.x;\n  bool cond = (total == nullptr || index < *total);\n  if (index < n && cond)\n  {\n    index_t div = 1;\n    int64_t idx_flat = inp[index];\n#pragma unroll\n    for (int dim = MAX_DIMS; dim >= 0; dim--)\n    {\n      if (dim > ndim - 1)\n        continue;\n      auto dim_size = dims.sizes[dim];\n      inp[index + dim * n] = (idx_flat / div) % dim_size;\n      div *= dim_size;\n    }\n  }\n  else if (index < n)\n  {\n    // 0th dim has correct values already\n    for (int dim = ndim - 1; dim > 0; dim--)\n    {\n      inp[index + dim * n] = fill_value;\n    }\n  }\n}\n\n// NOTE (trb): Majority from PyTorch. We removed use of host-based\n// pinned_num_nonzeros_h. Instead we just sync the stream and use the\n// memory directly on the CPU. 
We also change\n// `const_data_ptr<scalar_t>()` to `static_cast<scalar_t\n// const*>(const_data_ptr())` to sidestep a linker error with\n// amdclang++.\ntemplate <typename scalar_t>\nvoid nonzero_out_mi300a_impl(at::Tensor const& self, at::Tensor& out)\n{\n  at::Tensor self_ = self.contiguous();\n  hipStream_t const stream = at::hip::getCurrentHIPStream();\n  int64_t chunk_size, num_chunks;\n  if (self.numel() < std::numeric_limits<int>::max())\n  {\n    chunk_size = self.numel();\n    num_chunks = 1;\n  }\n  else\n  {\n    chunk_size = std::numeric_limits<int>::max() / 2 + 1;  // 2**30\n    num_chunks = (self.numel() + chunk_size - 1) / chunk_size;\n  }\n  // compute number of nonzero elements\n  size_t temp_storage_bytes = 0;\n  auto* const allocator = c10::GetAllocator(self.device().type());\n\n  auto num_nonzeros = allocator->allocate(sizeof(int) * num_chunks);\n  for (int64_t idx = 0; idx < num_chunks; idx++)\n  {\n    int64_t remaining = std::min(chunk_size, self.numel() - idx * chunk_size);\n    hipcub::TransformInputIterator<bool, NonZeroOp<scalar_t>, scalar_t const*>\n      itr(static_cast<scalar_t const*>(self_.const_data_ptr()) + idx * chunk_size,\n          NonZeroOp<scalar_t>());\n    AT_CUDA_CHECK(hipcub::DeviceReduce::Sum(nullptr,\n                                            temp_storage_bytes,\n                                            itr,\n                                            ((int*) num_nonzeros.get()) + idx,\n                                            remaining,\n                                            stream));\n    auto temp_storage = allocator->allocate(temp_storage_bytes);\n    AT_CUDA_CHECK(hipcub::DeviceReduce::Sum(temp_storage.get(),\n                                            temp_storage_bytes,\n                                            itr,\n                                            ((int*) num_nonzeros.get()) + idx,\n                                            remaining,\n                                            stream));\n  }\n\n  // TOM: Skip the copy...\n\n  // auto pinned_num_nonzeros_h = at::detail::empty_cpu(\n  //     {num_chunks}, /* size */\n  //     c10::CppTypeToScalarType<int>(), /* dtype */\n  //     std::nullopt, /* layout */\n  //     std::nullopt, /* device */\n  //     true, /* pin_memory */\n  //     std::nullopt /* memory format */\n  // );\n  // at::cuda::memcpy_and_sync(\n  //     (void*)pinned_num_nonzeros_h.const_data_ptr<int>(),\n  //     num_nonzeros.get(),\n  //     sizeof(int) * num_chunks,\n  //     cudaMemcpyDeviceToHost,\n  //     stream);\n\n  // TOM: ...just sync the stream...\n  LBANNV2_CHECK_GPU(hipStreamSynchronize(stream));\n\n  int64_t num_nonzeros_h = 0;\n\n  // TOM: ...and use the pointer.\n  for (int64_t idx = 0; idx < num_chunks; idx++)\n  {\n    num_nonzeros_h += (int) *(get_const<int>(num_nonzeros) + idx);\n  }\n\n  // num_nonzeros_h = (int)*(pinned_num_nonzeros_h.const_data_ptr<int>());\n  // expected output size is num_nonzeros x ndim\n  // we are producing output with size {num_nonzeros, ndim} and strides {1,\n  // num_nonzeros} (that is, transposed ndim x num_nonzeros output) we are able\n  // to directly use passed output with this size and strides, and we can also\n  // (per contract) resize passed output with incorrect sizes anyway we want.\n  // However, out with correct sizes and incorrect strides will have to be\n  // copied to from the intermediate we've produced.\n  bool need_to_copy = out.dim() == 2 && out.sizes()[0] == num_nonzeros_h\n                      && out.sizes()[1] == 
self.dim()\n                      && !out.t().is_contiguous();\n  at::Tensor out_temp = need_to_copy\n                          ? at::Tensor(at::detail::empty_cuda(\n                              {self.dim(), num_nonzeros_h}, out.options()))\n                          : out.resize_({self.dim(), num_nonzeros_h});\n  // Scalars are expected to produce output of size (1,0), so we can't write to\n  // it\n  int64_t curr_nonzeros = 0;\n  if (self.dim() > 0)\n  {\n    for (int64_t idx = 0; idx < num_chunks; idx++)\n    {\n      int remaining = std::min(chunk_size, self.numel() - idx * chunk_size);\n\n      hipcub::CountingInputIterator<int64_t> counting_itr(idx * chunk_size);\n      hipcub::TransformInputIterator<bool, NonZeroOp<scalar_t>, scalar_t const*>\n        itr(static_cast<scalar_t const*>(self_.const_data_ptr()) + idx * chunk_size,\n            NonZeroOp<scalar_t>());\n      temp_storage_bytes = 0;\n      AT_CUDA_CHECK(\n        hipcub::DeviceSelect::Flagged(nullptr,\n                                      temp_storage_bytes,\n                                      counting_itr,\n                                      itr,\n                                      out_temp.mutable_data_ptr<int64_t>(),\n                                      ((int*) num_nonzeros.get()) + idx,\n                                      remaining,\n                                      stream));\n      auto temp_storage = allocator->allocate(temp_storage_bytes);\n      AT_CUDA_CHECK(hipcub::DeviceSelect::Flagged(\n        temp_storage.get(),\n        temp_storage_bytes,\n        counting_itr,\n        itr,\n        out_temp.mutable_data_ptr<int64_t>() + curr_nonzeros,\n        ((int*) num_nonzeros.get()) + idx,\n        remaining,\n        stream));\n      // TOM: Oh look, we use it again.\n      curr_nonzeros += (int) *(get_const<int>(num_nonzeros) + idx);\n    }\n    if (num_nonzeros_h > 0 && self.dim() > 1)\n    {\n      TensorDims<int64_t> dims;\n      for (int i = 0; i < self.dim(); i++)\n      {\n        dims.sizes[i] = self.sizes()[i];\n      }\n      int const nthreads = 256;\n      int const nblocks = (num_nonzeros_h + nthreads - 1) / nthreads;\n      write_indices<<<nblocks, nthreads, 0, stream>>>(\n        out_temp.mutable_data_ptr<int64_t>(), dims, self.dim(), num_nonzeros_h);\n      LBANNV2_KERNEL_LAUNCH_CHECK();\n    }\n  }\n  if (need_to_copy)\n  {\n    out.copy_(out_temp.t());\n  }\n  else\n  {\n    // transpose out so it is correct size\n    at::Tensor out_ = out_temp.t();\n    out.set_(out_);\n  }\n}\n}  // namespace\n\nat::Tensor& lbannv2::nonzero_out(at::Tensor const& self, at::Tensor& out)\n{\n  c10::ScalarType const dtype = self.scalar_type();\n\n  LBANNV2_TRACE(\"lbannv2::nonzero_out(device={}, dtype={})\",\n                self.device().str(),\n                c10::toString(dtype));\n\n  switch (dtype)\n  {\n  case c10::ScalarType::Bool: nonzero_out_mi300a_impl<bool>(self, out); break;\n  case c10::ScalarType::Float:  nonzero_out_mi300a_impl<float>(self, out); break;\n  case c10::ScalarType::Double: nonzero_out_mi300a_impl<double>(self, out); break;\n  case c10::ScalarType::Int:  nonzero_out_mi300a_impl<int>(self, out); break;\n  case c10::ScalarType::UInt32: nonzero_out_mi300a_impl<std::uint32_t>(self, out); break;\n  case c10::ScalarType::Long:  nonzero_out_mi300a_impl<long>(self, out); break;\n  default: return at::native::nonzero_out_cuda(self, out);\n  }\n\n  return out;\n}\n\nat::Tensor lbannv2::nonzero(at::Tensor const& self)\n{\n  at::Tensor out =\n    at::detail::empty_cuda({0}, 
self.options().dtype(c10::kLong));\n  return nonzero_out(self, out);\n}\n"
  },
  {
    "path": "src/lbannv2/ops/nonzero.hpp",
    "content": "////////////////////////////////////////////////////////////////////////////////\n// Copyright 2014-2025 Lawrence Livermore National Security, LLC and other\n// LBANN Project Developers. See the top-level LICENSE file for details.\n//\n// SPDX-License-Identifier: Apache-2.0\n////////////////////////////////////////////////////////////////////////////////\n#pragma once\n\n#include <ATen/ATen.h>\n\nnamespace lbannv2\n{\n\nat::Tensor nonzero(at::Tensor const& self);\nat::Tensor& nonzero_out(at::Tensor const& self, at::Tensor& out);\n\n} // namespace lbannv2\n"
  },
  {
    "path": "src/lbannv2/ops/scalar.cpp",
    "content": "////////////////////////////////////////////////////////////////////////////////\n// Copyright 2014-2025 Lawrence Livermore National Security, LLC and other\n// LBANN Project Developers. See the top-level LICENSE file for details.\n//\n// SPDX-License-Identifier: Apache-2.0\n////////////////////////////////////////////////////////////////////////////////\n#include <lbannv2_config.h>\n\n#include <lbannv2/ops/scalar.hpp>\n#include <lbannv2/utils/errors.hpp>\n\n#include <ATen/ops/_local_scalar_dense_native.h>\n\n#if LBANNV2_WITH_MI300A || LBANNV2_UNKNOWN_MI300A\n#include <lbannv2/types.hpp>\n#include <lbannv2/utils/gpu_utils.hpp>\n#include <lbannv2/utils/logging.hpp>\n\n#include <ATen/core/TensorBase.h>\n#include <c10/core/Scalar.h>\n#include <c10/core/ScalarType.h>\n#include <c10/hip/HIPStream.h>\n\n// FIXME: We should integrate this better with either H2 dispatch or\n// Torch dispatch (I don't really care, honestly).\nnamespace\n{\n\ntemplate <typename ScalarT>\nat::Scalar mi300a_impl(at::Tensor const& self)\n{\n  // The contract is a sync, so we sync. (It's also likely a\n  // requirement for correctness, so we can assume the value can be\n  // safely accessed.\n  auto const stream = at::hip::getCurrentHIPStream();\n  lbannv2::gpu::sync(stream);\n  return at::Scalar(*reinterpret_cast<ScalarT const*>(self.const_data_ptr()));\n}\n\nat::Scalar mi300a_dispatch(at::Tensor const& self)\n{\n  c10::ScalarType const dtype = self.scalar_type();\n\n  LBANNV2_TRACE(\"lbannv2::_local_scalar_dense_mi300a(device={}, dtype={})\",\n                self.device().str(),\n                c10::toString(dtype));\n  switch (dtype)\n  {\n  case c10::ScalarType::Bool: return mi300a_impl<bool>(self);\n  case c10::ScalarType::Float: return mi300a_impl<float>(self);\n  case c10::ScalarType::Double: return mi300a_impl<double>(self);\n  case c10::ScalarType::Int: return mi300a_impl<int>(self);\n  case c10::ScalarType::UInt32: return mi300a_impl<std::uint32_t>(self);\n  case c10::ScalarType::Long: return mi300a_impl<long>(self);\n  default: return at::native::_local_scalar_dense_cuda(self);\n  }\n}\n}  // namespace\n#endif  // LBANNV2_WITH_MI300A || LBANNV2_UNKNOWN_MI300A\n\nat::Scalar lbannv2::local_scalar_dense_hip(at::Tensor const& self)\n{\n  // self.numel() == 1 is asserted elsewhere.\n  c10::ScalarType const dtype = self.dtype().toScalarType();\n\n  // Technically, the \"right\" fallback is implemented in all\n  // subsequent code paths, but I want to know about it if there's\n  // another type we should be supporting.\n  LBANNV2_ASSERT(is_supported(dtype), std::runtime_error, c10::toString(dtype));\n\n#if LBANNV2_WITH_MI300A || LBANNV2_UNKNOWN_MI300A\n#if LBANNV2_UNKNOWN_MI300A\n  if (lbannv2::gpu::is_integrated())\n#endif  // LBANNV2_UNKNOWN_MI300A\n    return mi300a_dispatch(self);\n#endif  // LBANNV2_WITH_MI300A || LBANNV2_UNKNOWN_MI300A\n\n  // Fallback to the Torch impl (cannot call at::_local_scalar_dense\n  // -- it will cause an infinite recursion through this function).\n  return at::native::_local_scalar_dense_cuda(self);\n}\n"
  },
  {
    "path": "src/lbannv2/ops/scalar.hpp",
    "content": "////////////////////////////////////////////////////////////////////////////////\n// Copyright 2014-2025 Lawrence Livermore National Security, LLC and other\n// LBANN Project Developers. See the top-level LICENSE file for details.\n//\n// SPDX-License-Identifier: Apache-2.0\n////////////////////////////////////////////////////////////////////////////////\n#pragma once\n\n#include <lbannv2_config.h>\n\n#include <ATen/core/Tensor.h>\n\nnamespace lbannv2\n{\n\nLBANNV2_EXPORT at::Scalar local_scalar_dense_hip(at::Tensor const&);\n\n}  // namespace lbannv2\n"
  },
  {
    "path": "src/lbannv2/python/CMakeLists.txt",
    "content": "################################################################################\n## Copyright 2014-2025 Lawrence Livermore National Security, LLC and other\n## LBANN Project Developers. See the top-level LICENSE file for details.\n##\n## SPDX-License-Identifier: Apache-2.0\n################################################################################\nif (NOT SKBUILD)\n  message(FATAL_ERROR \"You should not be here. Not doing a SKBUILD.\")\nendif ()\n\ntarget_sources(_lbannv2\n  PRIVATE\n  register_lbannv2.cpp\n  register_memory_funcs.cpp\n)\n\nif (LBANNV2_WITH_MI300A OR LBANNV2_UNKNOWN_MI300A)\n  target_sources(_lbannv2\n    PRIVATE\n    register_mi300a_ops.cpp\n  )\nendif ()\n"
  },
  {
    "path": "src/lbannv2/python/register_lbannv2.cpp",
    "content": "////////////////////////////////////////////////////////////////////////////////\n// Copyright 2014-2025 Lawrence Livermore National Security, LLC and other\n// LBANN Project Developers. See the top-level LICENSE file for details.\n//\n// SPDX-License-Identifier: Apache-2.0\n////////////////////////////////////////////////////////////////////////////////\n#include <lbannv2_config.h>\n\n#include <lbannv2/utils/gpu_utils.hpp>\n#include <lbannv2/utils/logging.hpp>\n\n#include <c10/core/Device.h>\n#include <pybind11/pybind11.h>\n\n#include <cstdlib>\n#include <iostream>\n\n#include <sys/types.h>\n#include <unistd.h>\n\nnamespace\n{\nbool _lbannv2_initialized = false;\nbool _lbannv2_gpu_initialized = false;\n\nvoid init_lbannv2()\n{\n  if (_lbannv2_initialized)\n    return;\n\n  if (std::getenv(\"LBANNV2_HANG_FOR_DEBUG\"))\n  {\n    // Raw vs spdlog here because I want to force the flush.\n    std::cout << \"LBANNV2 WAITING ON PID \" << getpid() << std::endl;\n    int volatile wait = 1;\n    while (wait) {}\n  }\n\n#if LBANNV2_HAS_GPU\n  if (!_lbannv2_gpu_initialized)\n  {\n    // There's nothing to do here since getting rid of H2.\n    _lbannv2_gpu_initialized = true;\n  }\n#endif\n\n  _lbannv2_initialized = true;\n}\n\nbool is_lbannv2_initialized() noexcept\n{\n  return _lbannv2_initialized;\n}\n\nbool is_lbannv2_gpu_initialized() noexcept\n{\n  return _lbannv2_gpu_initialized;\n}\n\nbool is_lbannv2_gpu_available() noexcept\n{\n  return LBANNV2_HAS_GPU;\n}\n\n}  // namespace\n\nnamespace _lbannv2\n{\nvoid add_memory_funcs(pybind11::module_& m);\n}  // namespace _lbannv2\n\nPYBIND11_MODULE(_lbannv2, m)\n{\n  m.def(\"init_lbannv2\", &init_lbannv2, \"Initialize state for LBANNv2\");\n  m.def(\"is_lbannv2_initialized\",\n        &is_lbannv2_initialized,\n        \"Query initialization state for LBANNv2\");\n  m.def(\"is_lbannv2_gpu_initialized\",\n        &is_lbannv2_gpu_initialized,\n        \"Query initialization state for LBANNv2 GPU support.\");\n  m.def(\"is_lbannv2_gpu_available\",\n        &is_lbannv2_gpu_available,\n        \"Query whether LBANNv2 has GPU support.\");\n  m.def(\"set_log_level\",\n        &lbannv2::set_log_level,\n        \"Set the output level for LBANNv2 logging.\");\n\n  _lbannv2::add_memory_funcs(m);\n}\n"
  },
  {
    "path": "src/lbannv2/python/register_memory_funcs.cpp",
    "content": "////////////////////////////////////////////////////////////////////////////////\n// Copyright 2014-2025 Lawrence Livermore National Security, LLC and other\n// LBANN Project Developers. See the top-level LICENSE file for details.\n//\n// SPDX-License-Identifier: Apache-2.0\n////////////////////////////////////////////////////////////////////////////////\n#include <lbannv2_config.h>\n\n#include <lbannv2/memory/memory_utils.hpp>\n#include <lbannv2/memory/registry.hpp>\n#include <lbannv2/ops/migrate.hpp>\n#include <lbannv2/utils/logging.hpp>\n\n#if LBANNV2_HAS_GPU\n#include <lbannv2/utils/gpu_utils.hpp>\n#endif\n\n#include <lbannv2/memory/allocator.hpp>\n\n#include <ATen/ops/to_native.h>\n#include <c10/core/Device.h>\n#include <pybind11/pybind11.h>\n#include <torch/csrc/utils/pybind.h>\n#include <torch/extension.h>\n#include <torch/library.h>\n\nnamespace\n{\n\n// Migrate\nat::Tensor py_migrate(at::Tensor& t, at::Device const& d)\n{\n  return lbannv2::migrate(t, d);\n}\n\nbool py_supports_migrate() noexcept\n{\n#if LBANNV2_WITH_MI300A\n  return true;\n#elif LBANNV2_HAS_GPU\n  return lbannv2::gpu::is_integrated();\n#else\n  return false;\n#endif\n}\n\nvoid py_use_mi300a_host_allocator()\n{\n  lbannv2::use_mi300a_cpu_allocator();\n}\n\nvoid py_use_torch_host_allocator()\n{\n  lbannv2::use_torch_cpu_allocator();\n}\n\nbool py_using_lbannv2_memory(torch::Tensor const& t)\n{\n  return lbannv2::pointer_registry().known(t.const_data_ptr());\n}\n\n}  // namespace\n\nnamespace _lbannv2\n{\n\nvoid add_memory_funcs(pybind11::module_& m)\n{\n  // Pointer migration\n  m.def(\"supports_migrate\",\n        &py_supports_migrate,\n        \"Determine whether device migration is supported\");\n\n  m.def(\"migrate\",\n        &py_migrate,\n        \"Try to migrate an LBANNv2-owned pointer to a new device.\");\n\n  m.def(\"use_mi300a_host_allocator\",\n        &py_use_mi300a_host_allocator,\n        \"Use the LBANNv2 MI300A allocator for CPU allocations\");\n\n  m.def(\"use_pytorch_host_allocator\",\n        &py_use_torch_host_allocator,\n        \"Use the default pytorch CPU allocator for CPU allocations\");\n\n  m.def(\n    \"using_lbannv2_memory\",\n    &py_using_lbannv2_memory,\n    \"Determine whether LBANNv2 allocated the memory backing a given tensor\");\n}\n\n}  // namespace _lbannv2\n"
  },
  {
    "path": "src/lbannv2/python/register_mi300a_ops.cpp",
    "content": "// NOTE: this file is only compiled when LBANNV2_WITH_MI300A or\n// LBANNV2_UNKNOWN_MI300A, so the \"#else\" clauses below are really\n// \"#elif LBANNV2_UNKNOWN_MI300A\".\n#include \"lbannv2_config.h\"\n\n#include <lbannv2/memory/mi300a_allocator.hpp>\n#include <lbannv2/ops/nonzero.hpp>\n#include <lbannv2/ops/scalar.hpp>\n#include <lbannv2/utils/gpu_utils.hpp>\n\n#include <torch/extension.h>\n#include <torch/library.h>\n\nnamespace\n{\n\nat::Scalar lbannv2__local_scalar_dense_cuda(at::Tensor const& self)\n{\n#if LBANNV2_WITH_MI300A\n  return lbannv2::local_scalar_dense_hip(self);\n#else\n  if (lbannv2::gpu::is_integrated())\n    return lbannv2::local_scalar_dense_hip(self);\n  return at::native::_local_scalar_dense_cuda(self);\n#endif\n}\n\nat::Tensor lbannv2_nonzero(at::Tensor const& self)\n{\n#if LBANNV2_WITH_MI300A\n  return lbannv2::nonzero(self);\n#else\n  if (lbannv2::gpu::is_integrated())\n    return lbannv2::nonzero(self);\n  return at::native::nonzero_cuda(self);\n#endif\n}\n\nat::Tensor& lbannv2_nonzero_out(at::Tensor const& self, at::Tensor& out)\n{\n#if LBANNV2_WITH_MI300A\n  return lbannv2::nonzero_out(self, out);\n#else\n  if (lbannv2::gpu::is_integrated())\n    return lbannv2::nonzero_out(self, out);\n  return at::native::nonzero_out_cuda(self, out);\n#endif\n}\n\n} // namespace\n\nTORCH_LIBRARY_IMPL(aten, CUDA, m)\n{\n  m.impl(\"_local_scalar_dense\", TORCH_FN(lbannv2__local_scalar_dense_cuda));\n  m.impl(\"nonzero\", TORCH_FN(lbannv2_nonzero));\n  m.impl(\"nonzero.out\", TORCH_FN(lbannv2_nonzero_out));\n}\n"
  },
  {
    "path": "src/lbannv2/types.hpp",
    "content": "////////////////////////////////////////////////////////////////////////////////\n// Copyright 2014-2025 Lawrence Livermore National Security, LLC and other\n// LBANN Project Developers. See the top-level LICENSE file for details.\n//\n// SPDX-License-Identifier: Apache-2.0\n////////////////////////////////////////////////////////////////////////////////\n#pragma once\n\n// FIXME (trb): Where should this file live??\n\n#include <c10/core/ScalarType.h>\n\nnamespace lbannv2\n{\n\n/** @brief Decide if a data type is supported by LBANNv2. */\ninline bool is_supported(c10::ScalarType t) noexcept\n{\n  switch (t)\n  {\n  case c10::ScalarType::Bool:\n  case c10::ScalarType::Float:\n  case c10::ScalarType::Double:\n  case c10::ScalarType::Int:\n  case c10::ScalarType::UInt32:\n  case c10::ScalarType::Long: return true;\n  default: return false;\n  }\n}\n\n}  // namespace lbannv2\n"
  },
  {
    "path": "src/lbannv2/utils/CMakeLists.txt",
    "content": "################################################################################\n## Copyright 2014-2025 Lawrence Livermore National Security, LLC and other\n## LBANN Project Developers. See the top-level LICENSE file for details.\n##\n## SPDX-License-Identifier: Apache-2.0\n################################################################################\ntarget_sources(lbannv2\n  PUBLIC\n  FILE_SET HEADERS\n  FILES\n  debugging_helpers.hpp\n  errors.hpp\n  gpu_utils.hpp\n  logging.hpp\n  tensor_helpers.hpp\n)\ntarget_sources(lbannv2\n  PRIVATE\n  gpu_utils.cpp\n  logging.cpp\n)\n"
  },
  {
    "path": "src/lbannv2/utils/debugging_helpers.hpp",
    "content": "////////////////////////////////////////////////////////////////////////////////\n// Copyright 2014-2025 Lawrence Livermore National Security, LLC and other\n// LBANN Project Developers. See the top-level LICENSE file for details.\n//\n// SPDX-License-Identifier: Apache-2.0\n////////////////////////////////////////////////////////////////////////////////\n#pragma once\n\n#include <cxxabi.h>\n#include <execinfo.h>\n\n#include <iomanip>\n#include <iostream>\n#include <sstream>\n#include <string>\n#include <vector>\n\nnamespace lbannv2\n{\n\ninline std::string demngl(std::string symb)\n{\n  int status;\n  char* const demangled_name =\n    abi::__cxa_demangle(symb.data(), nullptr, nullptr, &status);\n  if (demangled_name && status == 0)\n  {\n    std::string out(demangled_name);\n    free(demangled_name);\n    return out;\n  }\n\n  std::ostringstream oss;\n  oss << symb << \" (demangling failed)\";\n  return oss.str();\n}\n\ninline void print_bt(size_t nframes = 128, std::ostream& os = std::cout)\n{\n  std::vector<void*> frames(nframes);\n  nframes = backtrace(frames.data(), nframes);\n  char** symbs = backtrace_symbols(frames.data(), nframes);\n\n  os << \"-------------------------------------------------\\n\";\n  for (size_t i = 0; i < nframes; ++i)\n  {\n    os << std::setw(4) << std::right << i << \": (\" << frames[i]\n       << \"): \" << demngl(symbs[i]) << \"\\n\";\n  }\n  os << \"-------------------------------------------------\" << std::endl;\n  free(symbs);\n}\n\n}  // namespace lbannv2\n"
  },
  {
    "path": "src/lbannv2/utils/errors.hpp",
    "content": "////////////////////////////////////////////////////////////////////////////////\n// Copyright 2014-2025 Lawrence Livermore National Security, LLC and other\n// LBANN Project Developers. See the top-level LICENSE file for details.\n//\n// SPDX-License-Identifier: Apache-2.0\n////////////////////////////////////////////////////////////////////////////////\n#pragma once\n\n#include <lbannv2_config.h>\n\n#define LBANNV2_ASSERT(cond, excpt, msg)                                       \\\n  do                                                                           \\\n  {                                                                            \\\n    if (!(cond))                                                               \\\n    {                                                                          \\\n      throw excpt(msg);                                                        \\\n    }                                                                          \\\n  } while (0)\n\n#define LBANNV2_ASSERT_ALWAYS(cond)                                            \\\n  LBANNV2_ASSERT(cond, std::runtime_error, \"Assertion \\\"\" #cond \"\\\" failed.\")\n\n#if LBANNV2_DEBUG\n#define LBANNV2_ASSERT_DEBUG(cond) (void)\n#else\n#define LBANNV2_ASSERT_DEBUG(cond) LBANNV2_ASSERT_ALWAYS(cond)\n#endif\n"
  },
  {
    "path": "src/lbannv2/utils/gpu_utils.cpp",
    "content": "#include \"gpu_utils.hpp\"\n\n#include \"errors.hpp\"\n#include \"logging.hpp\"\n\nbool lbannv2::gpu::is_integrated() noexcept\n{\n#if LBANNV2_WITH_MI300A\n  return true;\n#else\n#if LBANNV2_HAS_ROCM\n  hipDeviceProp_t props;\n  if (hipGetDeviceProperties(&props, current_device()) == hipSuccess)\n    return props.integrated;\n  LBANNV2_ERROR(\"Failed to get device properties of current HIP device {}.\",\n                current_device());\n#endif\n#endif\n  return false;\n}\n\nc10::DeviceIndex lbannv2::gpu::num_devices() noexcept\n{\n#if LBANNV2_HAS_GPU\n  return c10_gpu::device_count();\n#else\n  return 0;\n#endif\n}\n\nc10::DeviceIndex lbannv2::gpu::current_device()\n{\n#if LBANNV2_HAS_GPU\n  return c10_gpu::current_device();\n#else\n  return -1;\n#endif\n}\n\nvoid lbannv2::gpu::set_device(c10::DeviceIndex const d)\n{\n  LBANNV2_TRACE(\"lbannv2::gpu::set_device(d={})\", d);\n  LBANNV2_ASSERT_ALWAYS(d >= 0 && d < num_devices());\n#if LBANNV2_HAS_GPU\n  c10_gpu::set_device(d, false);\n#endif\n}\n\n#if LBANNV2_HAS_GPU\n#if LBANNV2_HAS_CUDA\n#define lbannv2StreamCreate cudaStreamCreate\n#define lbannv2StreamCreateWithFlags cudaStreamCreateWithFlags\n#define lbannv2StreamNonBlocking cudaStreamNonBlocking\n#define lbannv2StreamSync cudaStreamSynchronize\n#define lbannv2StreamDestroy cudaStreamDestroy\n#elif LBANNV2_HAS_ROCM\n#define lbannv2StreamCreate hipStreamCreate\n#define lbannv2StreamCreateWithFlags hipStreamCreateWithFlags\n#define lbannv2StreamNonBlocking hipStreamNonBlocking\n#define lbannv2StreamSync hipStreamSynchronize\n#define lbannv2StreamDestroy hipStreamDestroy\n#endif\n\nauto lbannv2::gpu::make_stream() -> Stream_t\n{\n  Stream_t stream;\n  LBANNV2_CHECK_GPU(lbannv2StreamCreate(&stream));\n  LBANNV2_TRACE(\"lbannv2::gpu::make_stream(): created stream {}\",\n                (void*) stream);\n  return stream;\n}\n\nauto lbannv2::gpu::make_nonblocking_stream() -> Stream_t\n{\n  Stream_t stream;\n  LBANNV2_CHECK_GPU(\n    lbannv2StreamCreateWithFlags(&stream, lbannv2StreamNonBlocking));\n  LBANNV2_TRACE(\"lbannv2::gpu::make_nonblocking_stream(): created stream {}\",\n                (void*) stream);\n  return stream;\n}\n\nvoid lbannv2::gpu::sync(Stream_t const stream)\n{\n  LBANNV2_CHECK_GPU(lbannv2StreamSync(stream));\n  LBANNV2_TRACE(\"lbannv2::gpu::sync(stream={})\", (void const*) stream);\n}\n\nvoid lbannv2::gpu::destroy_stream(Stream_t const stream)\n{\n  LBANNV2_CHECK_GPU(lbannv2StreamDestroy(stream));\n  LBANNV2_TRACE(\"lbannv2::gpu::destroy_stream(stream={})\", (void*) stream);\n}\n\n#endif\n"
  },
  {
    "path": "src/lbannv2/utils/gpu_utils.hpp",
    "content": "#pragma once\n#include <lbannv2_config.h>\n\n#include <c10/core/Device.h>\n\n#if LBANNV2_HAS_CUDA\n\n#include <lbannv2/utils/logging.hpp>\n\n#include <c10/cuda/CUDAFunctions.h>\n\n#include <stdexcept>\n\n#include <cuda_runtime.h>\n\n#define LBANNV2_CHECK_GPU(cmd)                                                 \\\n  do                                                                           \\\n  {                                                                            \\\n    auto const lbannv2_check_gpu_status = (cmd);                               \\\n    if (lbannv2_check_gpu_status != cudaSuccess)                               \\\n    {                                                                          \\\n      LBANNV2_DEBUG(\"CUDA command \\\"\" #cmd \"\\\" failed. Error: {}\",             \\\n                    cudaGetErrorString(lbannv2_check_gpu_status));             \\\n      throw std::runtime_error(\"CUDA command \\\"\" #cmd \"\\\" failed.\");           \\\n    }                                                                          \\\n  } while (0)\n\n#elif LBANNV2_HAS_ROCM\n\n#include <lbannv2/utils/logging.hpp>\n\n#include <c10/hip/HIPFunctions.h>\n#include <c10/hip/HIPStream.h>\n\n#include <stdexcept>\n\n#include <hip/hip_runtime.h>\n\n#define LBANNV2_CHECK_GPU(cmd)                                                 \\\n  do                                                                           \\\n  {                                                                            \\\n    auto const lbannv2_check_gpu_status = (cmd);                               \\\n    if (lbannv2_check_gpu_status != hipSuccess)                                \\\n    {                                                                          \\\n      LBANNV2_DEBUG(\"HIP command \\\"\" #cmd \"\\\" failed. Error: {}\",              \\\n                    hipGetErrorString(lbannv2_check_gpu_status));              \\\n      throw std::runtime_error(\"HIP command \\\"\" #cmd \"\\\" failed.\");            \\\n    }                                                                          \\\n  } while (0)\n#endif\n\nnamespace lbannv2\n{\n#if LBANNV2_USE_C10_HIP_NAMESPACE_AND_SYMBOLS\nnamespace c10_gpu = c10::hip;\nusing TorchGPUStream_t = c10::hip::HIPStream;\ninline auto& getDeviceCurrentStream = c10::hip::getCurrentHIPStream;\n#elif LBANNV2_HAS_GPU\nnamespace c10_gpu = c10::cuda;\nusing TorchGPUStream_t = c10::cuda::CUDAStream;\ninline auto& getDeviceCurrentStream = c10::cuda::getCurrentCUDAStream;\n#endif\n\ninline constexpr bool has_cuda() noexcept\n{\n  return LBANNV2_HAS_CUDA;\n}\n\ninline constexpr bool has_hip() noexcept\n{\n  return LBANNV2_HAS_ROCM;\n}\n\ninline constexpr bool has_gpu() noexcept\n{\n  return LBANNV2_HAS_GPU;\n}\n\nnamespace gpu\n{\n\n#if LBANNV2_HAS_CUDA\nusing Stream_t = cudaStream_t;\n#elif LBANNV2_HAS_ROCM\nusing Stream_t = hipStream_t;\n#endif\n\n// Returns 'false' if no GPU support\nbool is_integrated() noexcept;\n\n// Returns 0 if no GPU support\nc10::DeviceIndex num_devices() noexcept;\n\n// Returns -1 if no GPU support\nc10::DeviceIndex current_device();\n\n// Throws if d >= num_devices() or d < 0.\nvoid set_device(c10::DeviceIndex d);\n\n#if LBANNV2_HAS_GPU\nStream_t make_stream();\nStream_t make_nonblocking_stream();\nvoid sync(Stream_t);\nvoid destroy_stream(Stream_t);\n#endif\n\n}  // namespace gpu\n\n}  // namespace lbannv2\n"
  },
  {
    "path": "src/lbannv2/utils/logging.cpp",
    "content": "////////////////////////////////////////////////////////////////////////////////\n// Copyright 2014-2025 Lawrence Livermore National Security, LLC and other\n// LBANN Project Developers. See the top-level LICENSE file for details.\n//\n// SPDX-License-Identifier: Apache-2.0\n////////////////////////////////////////////////////////////////////////////////\n#include \"lbannv2/utils/logging.hpp\"\n\n#include <memory>\n#include <string>\n\n#include <spdlog/pattern_formatter.h>\n#include <spdlog/sinks/basic_file_sink.h>\n#include <spdlog/sinks/stdout_color_sinks.h>\n\n#if __has_include(<unistd.h>)\n#include <unistd.h>\n#define _HAVE_UNISTD_H\n#endif\n\nnamespace\n{\nspdlog::level::level_enum get_env_log_level()\n{\n  if (char const* const var = std::getenv(\"LBANNV2_LOG_LEVEL\"))\n  {\n    std::string level_str {var};\n    std::for_each(begin(level_str), end(level_str), [](char& c) {\n      c = static_cast<char>(std::tolower(static_cast<unsigned char>(c)));\n    });\n\n    if (level_str == \"trace\")\n      return ::spdlog::level::trace;\n    if (level_str == \"debug\")\n      return ::spdlog::level::debug;\n    if (level_str == \"info\")\n      return ::spdlog::level::info;\n    if (level_str == \"warn\")\n      return ::spdlog::level::warn;\n    if (level_str == \"err\")\n      return ::spdlog::level::err;\n    if (level_str == \"critical\")\n      return ::spdlog::level::critical;\n    if (level_str == \"off\")\n      return ::spdlog::level::off;\n  }\n  return ::spdlog::level::info;\n}\n\nstd::string get_hostname()\n{\n#ifdef _HAVE_UNISTD_H\n  char buf[256];\n  if (gethostname(buf, 256) == 0)\n    return std::string {buf, std::find(buf, buf + 256, '\\0')};\n#endif\n\n  return \"<unknownhost>\";\n}\n\n// The one in H2 is not exported, but it's a quick reimplementation.\nclass HostFlag final : public spdlog::custom_flag_formatter\n{\npublic:\n  std::unique_ptr<custom_flag_formatter> clone() const final\n  {\n    return spdlog::details::make_unique<HostFlag>();\n  }\n  void format(::spdlog::details::log_msg const&,\n              ::std::tm const&,\n              ::spdlog::memory_buf_t& dest)\n  {\n    static std::string const hostname = get_hostname();\n    dest.append(hostname);\n  }\n};  // class HostFlag\n\nstd::unique_ptr<::spdlog::pattern_formatter> make_default_formatter()\n{\n  auto formatter = std::make_unique<::spdlog::pattern_formatter>();\n  formatter->add_flag<HostFlag>('h');\n  formatter->set_pattern(\"[%h:%P:%t] [%n:%^%l%$] %v\");\n  // formatter->set_pattern(\"[%m-%d-%Y %T.%f] [%h:%P] [%n] [%^%l%$] %v\");\n  return formatter;\n}\n\n::spdlog::sink_ptr make_default_sink()\n{\n  char const* sink_name = std::getenv(\"LBANNV2_LOG_FILE\");\n  std::string const sink_name_str(sink_name ? 
sink_name : \"stdout\");\n  if (sink_name_str == \"stdout\")\n    return std::make_shared<spdlog::sinks::stdout_color_sink_mt>();\n  if (sink_name_str == \"stderr\")\n    return std::make_shared<spdlog::sinks::stderr_color_sink_mt>();\n  return std::make_shared<spdlog::sinks::basic_file_sink_mt>(sink_name_str);\n}\n\nstd::shared_ptr<::spdlog::logger> make_default_logger()\n{\n  auto logger =\n    std::make_shared<::spdlog::logger>(\"lbannv2\", make_default_sink());\n  logger->set_formatter(make_default_formatter());\n  logger->set_level(get_env_log_level());\n  return logger;\n}\n\n}  // namespace\n\nvoid lbannv2::set_log_level(std::string const& lvl_str)\n{\n  // Valid inputs: trace, debug, info, warn, error, critical, off\n  auto const lvl = spdlog::level::from_str(lvl_str);\n  lbannv2::default_logger()->set_level(lvl);\n}\n\nstd::shared_ptr<::spdlog::logger>& lbannv2::default_logger()\n{\n  static std::shared_ptr<::spdlog::logger> logger_ = make_default_logger();\n  return logger_;\n}\n"
  },
  {
    "path": "src/lbannv2/utils/logging.hpp",
    "content": "////////////////////////////////////////////////////////////////////////////////\n// Copyright 2014-2025 Lawrence Livermore National Security, LLC and other\n// LBANN Project Developers. See the top-level LICENSE file for details.\n//\n// SPDX-License-Identifier: Apache-2.0\n////////////////////////////////////////////////////////////////////////////////\n#pragma once\n\n#include <lbannv2_config.h>\n\n/**\n * @file Enable spdlog logging for LBANNv2.\n *\n * The symbols in this file are not exported by default so any\n * hypothetical downstream doesn't take over our logger.\n *\n * The logger macros that include `LOG` in their names take a logger\n * pointer as their first argument. The other macros use the default\n * LBANNv2 logger.\n */\n\n#include <spdlog/spdlog.h>\n\n// These dispatch through SPDLOG's default macros. Hence, their\n// behavior is ultimately determined by the SPDLOG_ACTIVE_LEVEL macro.\n#define LBANNV2_LOG_TRACE(logger, ...) SPDLOG_LOGGER_TRACE(logger, __VA_ARGS__)\n#define LBANNV2_LOG_DEBUG(logger, ...) SPDLOG_LOGGER_DEBUG(logger, __VA_ARGS__)\n#define LBANNV2_LOG_INFO(logger, ...) SPDLOG_LOGGER_INFO(logger, __VA_ARGS__)\n#define LBANNV2_LOG_WARN(logger, ...) SPDLOG_LOGGER_WARN(logger, __VA_ARGS__)\n#define LBANNV2_LOG_ERROR(logger, ...) SPDLOG_LOGGER_ERROR(logger, __VA_ARGS__)\n#define LBANNV2_LOG_CRITICAL(logger, ...)                                      \\\n  SPDLOG_LOGGER_CRITICAL(logger, __VA_ARGS__)\n\n#define LBANNV2_TRACE(...)                                                     \\\n  LBANNV2_LOG_TRACE(::lbannv2::default_logger(), __VA_ARGS__)\n#define LBANNV2_DEBUG(...)                                                     \\\n  LBANNV2_LOG_DEBUG(::lbannv2::default_logger(), __VA_ARGS__)\n#define LBANNV2_INFO(...)                                                      \\\n  LBANNV2_LOG_INFO(::lbannv2::default_logger(), __VA_ARGS__)\n#define LBANNV2_WARN(...)                                                      \\\n  LBANNV2_LOG_WARN(::lbannv2::default_logger(), __VA_ARGS__)\n#define LBANNV2_ERROR(...)                                                     \\\n  LBANNV2_LOG_ERROR(::lbannv2::default_logger(), __VA_ARGS__)\n#define LBANNV2_CRITICAL(...)                                                  \\\n  LBANNV2_LOG_CRITICAL(::lbannv2::default_logger(), __VA_ARGS__)\n\nnamespace lbannv2\n{\n/** @brief Get LBANNv2's default logger.\n *\n *  The default logger is configured through the environment variable\n *  `LBANNV2_LOG_FILE`. Acceptable values are 'stdout', 'stderr', and\n *  a valid filename pattern.\n *\n *  @todo Enable logging to a process-specific file.\n */\nstd::shared_ptr<::spdlog::logger>& default_logger();\n\n/** @brief Set the logging level.\n *\n *  \\param[in] level Desired log level. Valid choices are \"trace\", \"debug\",\n *                   \"info\", \"warn\", \"error\", \"critical\", and \"off\".\n */\nvoid set_log_level(std::string const& level);\n\n}  // namespace lbannv2\n"
  },
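  {
    "path": "doc/examples/logging_example.cpp",
    "content": "// Illustrative sketch only: this file is NOT part of the build, and the\n// path is hypothetical. It demonstrates the two macro families declared\n// in lbannv2/utils/logging.hpp; every symbol below comes from that\n// header, and messages use spdlog's {fmt}-style formatting.\n#include <lbannv2/utils/logging.hpp>\n\nint main()\n{\n  // The short macros log through the default LBANNv2 logger, which is\n  // configured by the LBANNV2_LOG_FILE and LBANNV2_LOG_LEVEL\n  // environment variables.\n  LBANNV2_INFO(\"tensors allocated: {}\", 3);\n  LBANNV2_WARN(\"falling back to {}\", \"CPU\");\n\n  // The macros with LOG in their names take an explicit logger pointer\n  // as their first argument.\n  LBANNV2_LOG_INFO(lbannv2::default_logger(), \"pi ~ {:.2f}\", 3.14159);\n\n  // The level can also be changed at runtime. Calls below\n  // SPDLOG_ACTIVE_LEVEL are compiled out entirely, so raising the\n  // runtime level cannot resurrect them.\n  lbannv2::set_log_level(\"warn\");\n  LBANNV2_INFO(\"this message is now filtered out\");\n  return 0;\n}\n"
  },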
  {
    "path": "src/lbannv2/utils/tensor_helpers.hpp",
    "content": "////////////////////////////////////////////////////////////////////////////////\n// Copyright 2014-2025 Lawrence Livermore National Security, LLC and other\n// LBANN Project Developers. See the top-level LICENSE file for details.\n//\n// SPDX-License-Identifier: Apache-2.0\n////////////////////////////////////////////////////////////////////////////////\n#pragma once\n\n#include \"lbannv2/utils/errors.hpp\"\n\n#include <ATen/NamedTensorUtils.h>\n#include <ATen/Tensor.h>\n#include <c10/util/ArrayRef.h>\n\nnamespace lbannv2\n{\n\n/** @brief Determines if t is associated with LBANN */\ninline bool is_lbann(at::Tensor const& t)\n{\n  return t.is_privateuseone();\n}\n\ninline bool is_scalar(at::Tensor const& t)\n{\n  return t.defined() && (t.dim() == 0);\n}\n\ninline void set_data_ptr_device(c10::DataPtr& dp, c10::Device d)\n{\n  dp.unsafe_set_device(std::move(d));\n}\n\ninline void set_data_ptr_device(c10::Storage const& s, c10::Device d)\n{\n  set_data_ptr_device(s.mutable_data_ptr(), std::move(d));\n}\n\ninline void sync_metadata(at::Tensor const& src, at::Tensor& dst)\n{\n  auto* dst_tensor_info = dst.unsafeGetTensorImpl();\n  dst_tensor_info->set_storage_offset(src.storage_offset());\n  dst_tensor_info->set_sizes_and_strides(src.sizes(), src.strides());\n\n  // I assume this restores named dimensions? Not sure if it\n  // should be here or not. See \"alias_with_sizes_and_strides\"\n  // in <pytorch>/aten/src/ATen/native/TensorShape.cpp\n  at::namedinference::propagate_names(dst, src);\n}\n\n/** @brief Make an alias of the tensor on a new backend\n *\n *  This function can be used to produce aliases with diffent devices,\n *  different dispatch keys, or both (or neither, I suppose).\n *\n *  @post The original tensor will keep its device type and keys, but\n *        its DataPtr will appear to be on the new device if queried.\n */\ninline at::Tensor alias_as_device(at::Tensor const& orig_tensor,\n                                  c10::Device const& d,\n                                  c10::DispatchKeySet ks)\n{\n  // Make (soft) copy of the storage and set the device to be the real\n  // underlying device.\n  at::Storage aliased_storage(orig_tensor.storage());\n  set_data_ptr_device(aliased_storage, d);\n\n  // Set up a view with this storage, using the modified keyset.\n  auto alias_tensor =\n    at::detail::make_tensor<at::TensorImpl>(c10::TensorImpl::VIEW,\n                                            std::move(aliased_storage),\n                                            std::move(ks),\n                                            orig_tensor.dtype());\n\n  // Setup sizes, strides, and storage offset.\n  sync_metadata(orig_tensor, alias_tensor);\n\n  // Quick sanity check before we go\n  LBANNV2_ASSERT(alias_tensor.const_data_ptr() == orig_tensor.const_data_ptr(),\n                 std::runtime_error,\n                 \"Aliasing tensor data has failed\");\n\n  return alias_tensor;\n}\n\n/** @brief Minimal tensor stringification.\n *\n *  Returns \"[ {device type}{data type}[d1, d2, ..., dn] ]\", for\n *  example, \"[ lbannFloatType[2, 2] ]\" for a 2x2 Float32 tensor on\n *  the LBANN backend.\n */\ninline std::string to_str(at::Tensor const& t)\n{\n  std::ostringstream oss;\n  oss << \"[ \" << t.toString() << t.sizes() << \" ]\";\n  return oss.str();\n}\n\n/** @brief ArrayRef stringification */\ntemplate <typename T>\nstd::string to_str(c10::ArrayRef<T> const& ar)\n{\n  std::ostringstream oss;\n  oss << ar;\n  return oss.str();\n}\n\n}  // namespace lbannv2\n"
  },
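  {
    "path": "doc/examples/alias_as_device_example.cpp",
    "content": "// Illustrative sketch only: this file is NOT part of the build, and the\n// path is hypothetical. It walks through lbannv2::alias_as_device from\n// lbannv2/utils/tensor_helpers.hpp; the tensor construction mirrors\n// test/cpp/test_tensor_helpers.cpp.\n#include <lbannv2/ops/empty_tensor.hpp>\n#include <lbannv2/utils/tensor_helpers.hpp>\n\n#include <iostream>\n\nint main()\n{\n  // An LBANN (PrivateUse1) tensor...\n  at::Tensor t = lbannv2::empty_lbann({2, 3},\n                                      c10::ScalarType::Float,\n                                      std::nullopt,\n                                      std::nullopt,\n                                      false,\n                                      std::nullopt);\n\n  // ...aliased as a plain CPU tensor. The data pointer is shared; only\n  // the device tag and dispatch keys of the alias differ.\n  at::Tensor cpu_view = lbannv2::alias_as_device(\n    t, c10::DeviceType::CPU, c10::DispatchKeySet {c10::DispatchKey::CPU});\n\n  // to_str gives the minimal \"[ <device><dtype>[sizes] ]\" rendering.\n  std::cout << lbannv2::to_str(t) << \" aliases \" << lbannv2::to_str(cpu_view)\n            << '\\n';\n  return 0;\n}\n"
  },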
  {
    "path": "test/CMakeLists.txt",
    "content": "################################################################################\n## Copyright 2014-2025 Lawrence Livermore National Security, LLC and other\n## LBANN Project Developers. See the top-level LICENSE file for details.\n##\n## SPDX-License-Identifier: Apache-2.0\n################################################################################\ninclude(FetchContent)\nFetchContent_Declare(\n  Catch2\n  GIT_REPOSITORY https://github.com/catchorg/Catch2\n  GIT_TAG fa43b77429ba76c462b1898d6cd2f2d7a9416b14 # v3.7.1\n  FIND_PACKAGE_ARGS 3.0.0 CONFIG)\nFetchContent_MakeAvailable(Catch2)\n\nadd_executable(catch-tests\n  cpp/test_pointer_registry.cpp\n)\n\nif (LBANNV2_UNKNOWN_MI300A OR LBANNV2_WITH_MI300A)\n  target_sources(catch-tests\n    PRIVATE\n    cpp/test_mi300a_allocator.cpp\n  )\nendif ()\n\ntarget_link_libraries(catch-tests\n  PRIVATE\n  lbann::lbannv2\n  Catch2::Catch2WithMain\n)\n\nset_target_properties(catch-tests\n  PROPERTIES\n  CXX_STANDARD 20\n  CXX_STANDARD_REQUIRED ON\n  CXX_EXTENSIONS ON\n)\n"
  },
  {
    "path": "test/cpp/test_empty_tensor.cpp",
    "content": "////////////////////////////////////////////////////////////////////////////////\n// Copyright 2014-2025 Lawrence Livermore National Security, LLC and other\n// LBANN Project Developers. See the top-level LICENSE file for details.\n//\n// SPDX-License-Identifier: Apache-2.0\n////////////////////////////////////////////////////////////////////////////////\n#include <lbannv2_config.h>\n\n#include <lbannv2/ops/empty_tensor.hpp>\n#include <lbannv2/utils/device_helpers.hpp>\n\n#include <ATen/Tensor.h>\n#include <c10/core/ScalarType.h>\n#include <c10/util/ArrayRef.h>\n#include <catch2/catch_test_macros.hpp>\n#include <catch2/generators/catch_generators.hpp>\n#include <catch2/matchers/catch_matchers_string.hpp>\n\nnamespace\n{\n// This factory function can throw, and we cannot wrap an assignment\n// in Catch's `REQUIRE_NOTHROW`/`REQUIRE_THROWS*` macros. We use this\n// simple wrapper to facilitate things. eglot+clangd is able to still\n// forward the inlay hints from `empty_lbann` out to the\n// `make_empty_tensor` signature, so that's cool.\ntemplate <typename... Args>\nvoid make_empty_tensor(at::Tensor& t, Args&&... args)\n{\n  t = lbannv2::empty_lbann(std::forward<Args>(args)...);\n}\n}  // namespace\n\nTEST_CASE(\"empty_lbann\", \"[ops][empty]\")\n{\n  at::Tensor t;\n  c10::Device lbann_cpu {lbannv2::LBANNDeviceT, 0},\n    lbann_gpu {lbannv2::LBANNDeviceT, 1};\n  SECTION(\"Zero-size tensor is ok\")\n  {\n#if LBANNV2_HAS_GPU\n    auto lbann_device = GENERATE_COPY(values({lbann_cpu, lbann_gpu}));\n#else\n    auto lbann_device = lbann_cpu;\n#endif\n\n    REQUIRE_NOTHROW(make_empty_tensor(t,\n                                      c10::IntArrayRef {0},\n                                      c10::ScalarType::Float,\n                                      std::nullopt,\n                                      lbann_device,\n                                      false,\n                                      std::nullopt));\n    REQUIRE(t.dim() == 1);\n    REQUIRE(t.sizes() == c10::IntArrayRef {0});\n    REQUIRE(t.strides() == c10::IntArrayRef {1});\n    REQUIRE(t.is_privateuseone());\n    REQUIRE(t.dtype().toScalarType() == c10::ScalarType::Float);\n    REQUIRE(t.key_set().has(lbannv2::LBANNDispKey));\n    REQUIRE_FALSE(t.is_pinned());\n  }\n\n  SECTION(\"Nonzero tensor is ok\")\n  {\n#if LBANNV2_HAS_GPU\n    auto lbann_device = GENERATE_COPY(values({lbann_cpu, lbann_gpu}));\n#else\n    auto lbann_device = lbann_cpu;\n#endif\n    REQUIRE_NOTHROW(make_empty_tensor(t,\n                                      c10::IntArrayRef {3, 4},\n                                      c10::ScalarType::Float,\n                                      std::nullopt,\n                                      lbann_device,\n                                      false,\n                                      std::nullopt));\n    REQUIRE(t.dim() == 2);\n    REQUIRE(t.sizes() == c10::IntArrayRef {3, 4});\n    REQUIRE(t.strides() == c10::IntArrayRef {4, 1});\n    REQUIRE(t.is_privateuseone());\n    REQUIRE(t.dtype().toScalarType() == c10::ScalarType::Float);\n    REQUIRE(t.key_set().has(lbannv2::LBANNDispKey));\n    REQUIRE_FALSE(t.is_pinned());\n\n    REQUIRE_NOTHROW(make_empty_tensor(t,\n                                      c10::IntArrayRef {2, 3, 4, 5},\n                                      c10::ScalarType::Float,\n                                      std::nullopt,\n                                      lbann_device,\n                                      false,\n                                      
std::nullopt));\n    REQUIRE(t.dim() == 4);\n    REQUIRE(t.sizes() == c10::IntArrayRef {2, 3, 4, 5});\n    REQUIRE(t.strides() == c10::IntArrayRef {60, 20, 5, 1});\n    REQUIRE(t.is_privateuseone());\n    REQUIRE(t.dtype().toScalarType() == c10::ScalarType::Float);\n    REQUIRE(t.key_set().has(lbannv2::LBANNDispKey));\n    REQUIRE_FALSE(t.is_pinned());\n  }\n\n  SECTION(\"Non-LBANN devices throw\")\n  {\n    REQUIRE_THROWS_WITH(\n      make_empty_tensor(t,\n                        c10::IntArrayRef {3, 4},\n                        c10::ScalarType::Float,\n                        std::nullopt,\n                        c10::DeviceType::CPU,\n                        false,\n                        std::nullopt),\n      \"LBANN should only be constructing tensors on \\\"PrivateUse1\\\" backend\");\n  }\n}\n\nnamespace\n{\n\ntemplate <typename... Args>\nvoid make_empty_strided_tensor(at::Tensor& t, Args&&... args)\n{\n  t = lbannv2::empty_strided_lbann(std::forward<Args>(args)...);\n}\n\n}  // namespace\n\nTEST_CASE(\"empty_strided_lbann\", \"[ops][empty]\")\n{\n  at::Tensor t;\n  c10::Device lbann_cpu {lbannv2::LBANNDeviceT, 0},\n    lbann_gpu {lbannv2::LBANNDeviceT, 1};\n\n  SECTION(\"Zero-size tensor is ok\")\n  {\n#if LBANNV2_HAS_GPU\n    auto lbann_device = GENERATE_COPY(values({lbann_cpu, lbann_gpu}));\n#else\n    auto lbann_device = lbann_cpu;\n#endif\n    REQUIRE_NOTHROW(make_empty_strided_tensor(t,\n                                              c10::IntArrayRef {0},\n                                              c10::IntArrayRef {1},\n                                              c10::ScalarType::Float,\n                                              std::nullopt,\n                                              lbann_device,\n                                              false));\n    REQUIRE(t.dim() == 1);\n    REQUIRE(t.sizes() == c10::IntArrayRef {0});\n    REQUIRE(t.strides() == c10::IntArrayRef {1});\n    REQUIRE(t.is_privateuseone());\n    REQUIRE(t.dtype().toScalarType() == c10::ScalarType::Float);\n    REQUIRE(t.key_set().has(lbannv2::LBANNDispKey));\n    REQUIRE_FALSE(t.is_pinned());\n  }\n\n  SECTION(\"Nonzero tensor is ok\")\n  {\n#if LBANNV2_HAS_GPU\n    auto lbann_device = GENERATE_COPY(values({lbann_cpu, lbann_gpu}));\n#else\n    auto lbann_device = lbann_cpu;\n#endif\n    REQUIRE_NOTHROW(make_empty_strided_tensor(t,\n                                              c10::IntArrayRef {3, 4},\n                                              c10::IntArrayRef {8, 2},\n                                              c10::ScalarType::Float,\n                                              std::nullopt,\n                                              lbann_device,\n                                              false));\n    REQUIRE(t.dim() == 2);\n    REQUIRE(t.sizes() == c10::IntArrayRef {3, 4});\n    REQUIRE(t.strides() == c10::IntArrayRef {8, 2});\n    REQUIRE(t.is_privateuseone());\n    REQUIRE(t.dtype().toScalarType() == c10::ScalarType::Float);\n    REQUIRE(t.key_set().has(lbannv2::LBANNDispKey));\n    REQUIRE_FALSE(t.is_pinned());\n\n    REQUIRE_NOTHROW(make_empty_strided_tensor(t,\n                                              c10::IntArrayRef {2, 3, 4, 5},\n                                              c10::IntArrayRef {120, 40, 10, 2},\n                                              c10::ScalarType::Float,\n                                              std::nullopt,\n                                              std::nullopt,\n                                              false));\n 
   REQUIRE(t.dim() == 4);\n    REQUIRE(t.sizes() == c10::IntArrayRef {2, 3, 4, 5});\n    REQUIRE(t.strides() == c10::IntArrayRef {120, 40, 10, 2});\n    REQUIRE(t.is_privateuseone());\n    REQUIRE(t.dtype().toScalarType() == c10::ScalarType::Float);\n    REQUIRE(t.key_set().has(lbannv2::LBANNDispKey));\n    REQUIRE_FALSE(t.is_pinned());\n  }\n\n  SECTION(\"Non-LBANN devices throw\")\n  {\n    REQUIRE_THROWS_WITH(\n      make_empty_strided_tensor(t,\n                                c10::IntArrayRef {3, 4},\n                                c10::IntArrayRef {8, 2},\n                                c10::ScalarType::Float,\n                                std::nullopt,\n                                c10::DeviceType::CPU,\n                                false),\n      \"LBANN should only be constructing tensors on \\\"PrivateUse1\\\" backend\");\n  }\n}\n"
  },
  {
    "path": "test/cpp/test_helpers.hpp",
    "content": "////////////////////////////////////////////////////////////////////////////////\n// Copyright 2014-2025 Lawrence Livermore National Security, LLC and other\n// LBANN Project Developers. See the top-level LICENSE file for details.\n//\n// SPDX-License-Identifier: Apache-2.0\n////////////////////////////////////////////////////////////////////////////////\n#pragma once\n\n#include \"lbannv2_config.h\"\n\n#include <lbannv2/utils/gpu_utils.hpp>\n\n#include <catch2/catch_test_macros.hpp>\n\n#if LBANNV2_WITH_MI300A\n#define SKIP_WHEN_NO_MI300A()\n#elif LBANNV2_WITHOUT_MI300A\n#define SKIP_WHEN_NO_MI300A() SKIP(\"No MI300A support\")\n#elif LBANNV2_UNKNOWN_MI300A\n#include <h2/gpu/runtime.hpp>\n#define SKIP_WHEN_NO_MI300A()                                                  \\\n  do                                                                           \\\n  {                                                                            \\\n    if (!lbannv2::gpu::is_integrated())                                        \\\n    {                                                                          \\\n      SKIP(\"No MI300A support\");                                               \\\n    }                                                                          \\\n  } while (0)\n#endif\n"
  },
  {
    "path": "test/cpp/test_mi300a_allocator.cpp",
    "content": "////////////////////////////////////////////////////////////////////////////////\n// Copyright 2014-2025 Lawrence Livermore National Security, LLC and other\n// LBANN Project Developers. See the top-level LICENSE file for details.\n//\n// SPDX-License-Identifier: Apache-2.0\n////////////////////////////////////////////////////////////////////////////////\n#include <lbannv2/memory/mi300a_allocator.hpp>\n#include <lbannv2/memory/registry.hpp>\n#include <lbannv2/utils/gpu_utils.hpp>\n\n#include \"test_helpers.hpp\"\n\n#include <ATen/hip/HIPContextLight.h>\n#include <c10/core/Allocator.h>\n#include <c10/core/CPUAllocator.h>\n#include <c10/hip/HIPCachingAllocator.h>\n#include <c10/hip/HIPStream.h>\n\n#include <catch2/catch_test_macros.hpp>\n#include <catch2/matchers/catch_matchers_string.hpp>\n\nnamespace\n{\n\nvoid do_raw_allocate(void** ptr, size_t size, lbannv2::MI300Allocator& alloc)\n{\n  *ptr = alloc.raw_allocate(size);\n}\n\n}  // namespace\n\nTEST_CASE(\"MI300Allocator::raw_allocate and MI300Allocator::raw_deallocate\",\n          \"[memory][mi300a]\")\n{\n  SKIP_WHEN_NO_MI300A();\n\n  auto& alloc = lbannv2::MI300Allocator::instance();\n  size_t const size = 64;\n  void* ptr = nullptr;\n  REQUIRE_NOTHROW(do_raw_allocate(&ptr, size, alloc));\n\n  CHECK(ptr != nullptr);\n\n  REQUIRE_NOTHROW(alloc.raw_deallocate(ptr));\n}\n\nnamespace\n{\nc10::Device lbann_cpu() noexcept\n{\n  return c10::Device {c10::kCPU};\n}\nc10::Device lbann_gpu() noexcept\n{\n  return c10::Device {\n    c10::kCUDA, static_cast<c10::DeviceIndex>(lbannv2::gpu::current_device())};\n}\n}  // namespace\n\nTEST_CASE(\"MI300Allocator::allocate and MI300Allocator::deallocate\",\n          \"[memory][mi300a]\")\n{\n  SKIP_WHEN_NO_MI300A();\n\n  auto& alloc = lbannv2::MI300Allocator::instance();\n  size_t const size = 64;\n\n  void* raw_ptr = nullptr;\n  {\n    auto ptr = alloc.allocate(size);\n    raw_ptr = ptr.get();\n    CHECK(ptr.device() == lbann_cpu());\n    CHECK(lbannv2::pointer_registry().known(raw_ptr));\n  }\n\n  // DataPtr goes out of scope, should be deleted.\n\n  CHECK(!lbannv2::pointer_registry().known(raw_ptr));\n}\n\n// The \"kernel\" here is loosely inspired by Aluminum's \"GPUWait\", but\n// less fussy about things like \"cache-line allocation\" and\n// \"atomics\"... 
All I need is something to guarantee the stream isn't\n// synced before the second allocation, and this saves me the trouble\n// of compiling a HIP kernel.\nTEST_CASE(\"MI300Allocator stream semantics are working\", \"[memory][mi300a]\")\n{\n  SKIP_WHEN_NO_MI300A();\n\n  auto const gpu = lbann_gpu();\n\n  // Some memory we can use later.\n  int32_t* wait_mem;\n  LBANNV2_CHECK_GPU(hipMalloc(&wait_mem, sizeof(int32_t)));\n  *wait_mem = 0;  // host write; fine on MI300A's unified memory\n\n  int32_t const wait_value = 1;\n  auto& alloc = lbannv2::MI300Allocator::instance();\n  size_t const size = 64;\n\n  // open block\n  //   do an allocation\n  //   migrate allocation to GPU\n  //   \"run a kernel\" on the same stream\n  // close block (delete the allocation)\n  // allocate new buffer\n  // check old and new buffers have different addresses\n\n  auto torch_stream = lbannv2::getDeviceCurrentStream(gpu.index());\n  void* orig_ptr = nullptr;  // never dereferenced\n  {\n    auto ptr = alloc.allocate(size);\n    // cache the buffer address -- NEVER DEREFERENCED\n    orig_ptr = ptr.get();\n\n    // Add the ptr to the stream on GPU\n    lbannv2::migrate_ptr(ptr, gpu, torch_stream);\n    // Fake a kernel on the stream\n    LBANNV2_CHECK_GPU(hipStreamWaitValue32(\n      torch_stream, wait_mem, wait_value, hipStreamWaitValueEq));\n  }\n  // GPU allocation will \"FREE_REQUESTED\" here, but it should NOT be\n  // available for reuse\n\n  auto ptr = alloc.allocate(size);\n  CHECK(ptr.get() != orig_ptr);  // NOT REQUIRE -- need to clean up.\n\n  // Write the new value\n  *wait_mem = wait_value;\n\n  // Ensure the \"kernel\" is done.\n  torch_stream.synchronize();\n\n  // Free our wait memory\n  LBANNV2_CHECK_GPU(hipFree(wait_mem));\n}\n"
  },
  {
    "path": "test/cpp/test_pointer_registry.cpp",
    "content": "////////////////////////////////////////////////////////////////////////////////\n// Copyright 2014-2025 Lawrence Livermore National Security, LLC and other\n// LBANN Project Developers. See the top-level LICENSE file for details.\n//\n// SPDX-License-Identifier: Apache-2.0\n////////////////////////////////////////////////////////////////////////////////\n#include <lbannv2/memory/registry.hpp>\n\n#include <catch2/catch_test_macros.hpp>\n#include <catch2/matchers/catch_matchers_string.hpp>\n\nTEST_CASE(\"RangeLessAndDisjoint\", \"[memory][registry]\")\n{\n  std::vector<unsigned char> buffer(8);\n  lbannv2::PointerRegistry::RangeLessAndDisjoint rng_less;\n\n  SECTION(\"Non-overlapping ranges behave sanely\")\n  {\n    CHECK(rng_less({&buffer[1], &buffer[2]}, {&buffer[3], &buffer[4]}));\n    CHECK_FALSE(rng_less({&buffer[3], &buffer[4]}, {&buffer[1], &buffer[2]}));\n  }\n\n  SECTION(\"Abutting ranges are nonoverlapping\")\n  {\n    CHECK(rng_less({&buffer[1], &buffer[2]}, {&buffer[2], &buffer[3]}));\n    CHECK_FALSE(rng_less({&buffer[2], &buffer[3]}, {&buffer[1], &buffer[2]}));\n  }\n\n  SECTION(\"Identical ranges\")\n  {\n    CHECK_FALSE(rng_less({&buffer[1], &buffer[4]}, {&buffer[1], &buffer[4]}));\n  }\n\n  SECTION(\"Partially overlapping ranges\")\n  {\n    CHECK_FALSE(rng_less({&buffer[1], &buffer[4]}, {&buffer[2], &buffer[5]}));\n    CHECK_FALSE(rng_less({&buffer[2], &buffer[5]}, {&buffer[1], &buffer[4]}));\n  }\n\n  SECTION(\"One range proper subset of the other\")\n  {\n    CHECK_FALSE(rng_less({&buffer[1], &buffer[8]}, {&buffer[3], &buffer[4]}));\n    CHECK_FALSE(rng_less({&buffer[3], &buffer[4]}, {&buffer[1], &buffer[8]}));\n  }\n\n  SECTION(\"Zero-size ranges work appropriately\")\n  {\n    CHECK(rng_less({&buffer[1], &buffer[1]}, {&buffer[2], &buffer[2]}));\n    CHECK_FALSE(rng_less({&buffer[2], &buffer[2]}, {&buffer[1], &buffer[1]}));\n\n    CHECK(rng_less({&buffer[1], &buffer[1]}, {&buffer[2], &buffer[4]}));\n    CHECK_FALSE(rng_less({&buffer[2], &buffer[4]}, {&buffer[1], &buffer[1]}));\n\n    CHECK(rng_less({&buffer[1], &buffer[2]}, {&buffer[2], &buffer[2]}));\n    CHECK_FALSE(rng_less({&buffer[2], &buffer[2]}, {&buffer[1], &buffer[2]}));\n\n    CHECK(rng_less({&buffer[1], &buffer[2]}, &buffer[2]));\n    CHECK(rng_less(&buffer[1], {&buffer[2], &buffer[3]}));\n\n    CHECK_FALSE(rng_less(&buffer[1], {&buffer[1], &buffer[2]}));\n    CHECK_FALSE(rng_less({&buffer[1], &buffer[2]}, &buffer[1]));\n\n    CHECK_FALSE(rng_less(&buffer[1], {&buffer[1], &buffer[1]}));\n    CHECK_FALSE(rng_less({&buffer[1], &buffer[1]}, &buffer[1]));\n  }\n}\n\nnamespace\n{\nsize_t rng_bytes(std::pair<void*, void*> const& r)\n{\n  return std::distance((std::byte*) r.first, (std::byte*) r.second);\n}\n}  // namespace\n\nTEST_CASE(\"PointerRegistry::add()\", \"[memory][registry]\")\n{\n  using RangeT = std::pair<void*, void*>;\n\n  lbannv2::PointerRegistry registry;\n  std::vector<unsigned char> buffer(32);\n\n  // Establish preconditions\n  REQUIRE(registry.num_registered() == 0UL);\n  REQUIRE(registry.bytes_registered() == 0UL);\n\n  SECTION(\"Adding nonoverlapping regions is successful.\")\n  {\n    RangeT const rng1 = {&buffer[4], &buffer[8]};\n    RangeT const rng2 = {&buffer[12], &buffer[16]};\n    RangeT const rng3 = {&buffer[16], &buffer[20]};\n    RangeT const rng4 = {&buffer[8], &buffer[12]};\n\n    size_t expected_bytes = 0UL;\n    REQUIRE_NOTHROW(registry.add(rng1.first, rng_bytes(rng1), nullptr));\n    expected_bytes += rng_bytes(rng1);\n\n    REQUIRE(registry.num_registered() == 
1UL);\n    REQUIRE(registry.bytes_registered() == expected_bytes);\n\n    REQUIRE_NOTHROW(registry.add(rng2.first, rng_bytes(rng2), nullptr));\n    expected_bytes += rng_bytes(rng2);\n\n    REQUIRE(registry.num_registered() == 2UL);\n    REQUIRE(registry.bytes_registered() == expected_bytes);\n\n    REQUIRE_NOTHROW(registry.add(rng3.first, rng_bytes(rng3), nullptr));\n    expected_bytes += rng_bytes(rng3);\n\n    REQUIRE(registry.num_registered() == 3UL);\n    REQUIRE(registry.bytes_registered() == expected_bytes);\n\n    REQUIRE_NOTHROW(registry.add(rng4.first, rng_bytes(rng4), nullptr));\n    expected_bytes += rng_bytes(rng4);\n\n    REQUIRE(registry.num_registered() == 4UL);\n    REQUIRE(registry.bytes_registered() == expected_bytes);\n  }\n\n  SECTION(\"Zero-size regions\")\n  {\n    SECTION(\"Adding zero-size regions is ok.\")\n    {\n      RangeT const rng1 = {&buffer[0], &buffer[0]};\n      RangeT const rng2 = {&buffer[2], &buffer[2]};\n\n      REQUIRE_NOTHROW(registry.add(rng1.first, rng_bytes(rng1), nullptr));\n      REQUIRE(registry.num_registered() == 1UL);\n      REQUIRE(registry.bytes_registered() == 0UL);\n\n      REQUIRE_NOTHROW(registry.add(rng2.first, rng_bytes(rng2), nullptr));\n      REQUIRE(registry.num_registered() == 2UL);\n      REQUIRE(registry.bytes_registered() == 0UL);\n    }\n\n    SECTION(\"Zero-size regions are not valid start points for other regions\")\n    {\n      RangeT const zero_rng = {&buffer[0], &buffer[0]};\n      RangeT const other_rng = {&buffer[0], &buffer[2]};\n\n      REQUIRE_NOTHROW(\n        registry.add(zero_rng.first, rng_bytes(zero_rng), nullptr));\n      REQUIRE(registry.num_registered() == 1UL);\n      REQUIRE(registry.bytes_registered() == 0UL);\n\n      REQUIRE_THROWS_WITH(\n        registry.add(other_rng.first, rng_bytes(other_rng), nullptr),\n        \"Address range overlaps existing range\");\n    }\n\n    SECTION(\"Zero-size regions are valid end points for other regions\")\n    {\n      RangeT const other_rng = {&buffer[0], &buffer[2]};\n      RangeT const zero_rng = {&buffer[2], &buffer[2]};\n\n      REQUIRE_NOTHROW(\n        registry.add(zero_rng.first, rng_bytes(zero_rng), nullptr));\n      REQUIRE(registry.num_registered() == 1UL);\n      REQUIRE(registry.bytes_registered() == 0UL);\n\n      REQUIRE_NOTHROW(\n        registry.add(other_rng.first, rng_bytes(other_rng), nullptr));\n      REQUIRE(registry.num_registered() == 2UL);\n      REQUIRE(registry.bytes_registered() == rng_bytes(other_rng));\n    }\n  }\n}\n\nTEST_CASE(\"PointerRegistry::remove()\", \"[memory][registry]\")\n{\n  using RangeT = std::pair<void*, void*>;\n\n  lbannv2::PointerRegistry registry;\n  std::vector<unsigned char> buffer(32);\n\n  // Establish preconditions\n  REQUIRE(registry.num_registered() == 0UL);\n  REQUIRE(registry.bytes_registered() == 0UL);\n\n  SECTION(\"Removing a context pointer works\")\n  {\n    RangeT const rng = {&buffer[4], &buffer[8]};\n    REQUIRE_NOTHROW(registry.add(rng.first, rng_bytes(rng), nullptr));\n    REQUIRE(registry.num_registered() == 1UL);\n    REQUIRE(registry.bytes_registered() == rng_bytes(rng));\n\n    REQUIRE_NOTHROW(registry.remove(rng.first));\n    REQUIRE(registry.num_registered() == 0UL);\n    REQUIRE(registry.bytes_registered() == 0UL);\n  }\n\n  SECTION(\"Removing a known non-context pointer fails\")\n  {\n    RangeT const rng = {&buffer[4], &buffer[8]};\n    void* const noncontext_ptr = &buffer[6];\n\n    REQUIRE_NOTHROW(registry.add(rng.first, rng_bytes(rng), nullptr));\n    
REQUIRE(registry.num_registered() == 1UL);\n    REQUIRE(registry.bytes_registered() == rng_bytes(rng));\n\n    REQUIRE_THROWS_WITH(registry.remove(noncontext_ptr),\n                        \"Cannot remove ptr; not beginning of range.\");\n    REQUIRE(registry.num_registered() == 1UL);\n    REQUIRE(registry.bytes_registered() == rng_bytes(rng));\n  }\n\n  SECTION(\"Removing an unknown pointer fails\")\n  {\n    RangeT const rng = {&buffer[4], &buffer[8]};\n    void* const unknown_ptr = &buffer[16];\n\n    REQUIRE_NOTHROW(registry.add(rng.first, rng_bytes(rng), nullptr));\n    REQUIRE(registry.num_registered() == 1UL);\n    REQUIRE(registry.bytes_registered() == rng_bytes(rng));\n\n    REQUIRE_THROWS_AS(registry.remove(unknown_ptr), lbannv2::UnknownAddress);\n    REQUIRE(registry.num_registered() == 1UL);\n    REQUIRE(registry.bytes_registered() == rng_bytes(rng));\n  }\n\n  SECTION(\"Removing a zero-size region is ok\")\n  {\n    RangeT const rng = {&buffer[2], &buffer[2]};\n    REQUIRE_NOTHROW(registry.add(rng.first, rng_bytes(rng), nullptr));\n    REQUIRE(registry.num_registered() == 1UL);\n    REQUIRE(registry.bytes_registered() == 0UL);\n\n    REQUIRE_NOTHROW(registry.remove(rng.first));\n    REQUIRE(registry.num_registered() == 0UL);\n    REQUIRE(registry.bytes_registered() == 0UL);\n  }\n}\n\nTEST_CASE(\"PointerRegistry::known()\", \"[memory][registry]\")\n{\n  using RangeT = std::pair<void*, void*>;\n\n  lbannv2::PointerRegistry registry;\n  std::vector<unsigned char> buffer(32);\n\n  // Establish preconditions\n  REQUIRE(registry.num_registered() == 0UL);\n  REQUIRE(registry.bytes_registered() == 0UL);\n\n  SECTION(\"Pointers in registered ranges are known\")\n  {\n    RangeT const rng = {&buffer[4], &buffer[8]};\n    void const* const context_ptr = &buffer[4];\n    void const* const noncontext_ptr = &buffer[6];\n    REQUIRE_NOTHROW(registry.add(rng.first, rng_bytes(rng), nullptr));\n    REQUIRE(registry.num_registered() == 1UL);\n    REQUIRE(registry.bytes_registered() == rng_bytes(rng));\n\n    REQUIRE(registry.known(context_ptr));\n    REQUIRE(registry.known(noncontext_ptr));\n  }\n\n  SECTION(\"Registered pointers in size-zero ranges are known\")\n  {\n    RangeT const rng = {&buffer[4], &buffer[4]};\n    void const* const context_ptr = &buffer[4];\n    REQUIRE_NOTHROW(registry.add(rng.first, rng_bytes(rng), nullptr));\n    REQUIRE(registry.num_registered() == 1UL);\n    REQUIRE(registry.bytes_registered() == rng_bytes(rng));\n\n    REQUIRE(registry.known(context_ptr));\n  }\n\n  SECTION(\"Pointers outside registered ranges are not known\")\n  {\n    RangeT const rng = {&buffer[4], &buffer[8]};\n    void const* const unknown_low_ptr = &buffer[2];\n    void const* const unknown_ub_ptr = &buffer[8];\n    void const* const unknown_high_ptr = &buffer[14];\n\n    REQUIRE_NOTHROW(registry.add(rng.first, rng_bytes(rng), nullptr));\n    REQUIRE(registry.num_registered() == 1UL);\n    REQUIRE(registry.bytes_registered() == rng_bytes(rng));\n\n    REQUIRE_FALSE(registry.known(unknown_low_ptr));\n    REQUIRE_FALSE(registry.known(unknown_ub_ptr));\n    REQUIRE_FALSE(registry.known(unknown_high_ptr));\n  }\n}\n\nTEST_CASE(\"PointerRegistry::get_context()\", \"[memory][registry]\")\n{\n  using RangeT = std::pair<void*, void*>;\n\n  lbannv2::PointerRegistry registry;\n  std::vector<unsigned char> buffer(32);\n\n  // Establish preconditions\n  REQUIRE(registry.num_registered() == 0UL);\n  REQUIRE(registry.bytes_registered() == 0UL);\n\n  SECTION(\"Context pointers are their own context\")\n  {\n 
   RangeT const rng1 = {&buffer[4], &buffer[8]};\n    RangeT const rng2 = {&buffer[12], &buffer[16]};\n    RangeT const zero_rng = {&buffer[20], &buffer[20]};\n\n    void const* const context_ptr1 = &buffer[4];\n    void const* const context_ptr2 = &buffer[12];\n    void const* const zero_context_ptr = &buffer[20];\n\n    REQUIRE_NOTHROW(registry.add(rng1.first, rng_bytes(rng1), nullptr));\n    REQUIRE_NOTHROW(registry.add(rng2.first, rng_bytes(rng2), nullptr));\n    REQUIRE_NOTHROW(registry.add(zero_rng.first, rng_bytes(zero_rng), nullptr));\n    REQUIRE(registry.num_registered() == 3UL);\n    REQUIRE(registry.bytes_registered()\n            == rng_bytes(rng1) + rng_bytes(rng2) + rng_bytes(zero_rng));\n\n    REQUIRE(registry.get_context(context_ptr1) == rng1.first);\n    REQUIRE(registry.get_context(context_ptr2) == rng2.first);\n    REQUIRE(registry.get_context(zero_context_ptr) == zero_rng.first);\n  }\n\n  SECTION(\"Noncontext pointers return the proper context pointer\")\n  {\n    RangeT const rng1 = {&buffer[4], &buffer[8]};\n    RangeT const rng2 = {&buffer[12], &buffer[16]};\n\n    void const* const noncontext_ptr1 = &buffer[6];\n    void const* const noncontext_ptr2 = &buffer[14];\n\n    REQUIRE_NOTHROW(registry.add(rng1.first, rng_bytes(rng1), nullptr));\n    REQUIRE_NOTHROW(registry.add(rng2.first, rng_bytes(rng2), nullptr));\n\n    REQUIRE(registry.num_registered() == 2UL);\n    REQUIRE(registry.bytes_registered() == rng_bytes(rng1) + rng_bytes(rng2));\n\n    REQUIRE(registry.get_context(noncontext_ptr1) == rng1.first);\n    REQUIRE(registry.get_context(noncontext_ptr2) == rng2.first);\n  }\n\n  SECTION(\"Unknown pointers fail\")\n  {\n    RangeT const rng1 = {&buffer[4], &buffer[8]};\n    RangeT const rng2 = {&buffer[12], &buffer[16]};\n\n    void const* const ptr1 = &buffer[2];\n    void const* const ptr2 = &buffer[8];\n    void const* const ptr3 = &buffer[10];\n    void const* const ptr4 = &buffer[16];\n    void const* const ptr5 = &buffer[20];\n\n    REQUIRE_NOTHROW(registry.add(rng1.first, rng_bytes(rng1), nullptr));\n    REQUIRE_NOTHROW(registry.add(rng2.first, rng_bytes(rng2), nullptr));\n\n    REQUIRE(registry.num_registered() == 2UL);\n    REQUIRE(registry.bytes_registered() == rng_bytes(rng1) + rng_bytes(rng2));\n\n    REQUIRE_THROWS_AS(registry.get_context(ptr1), lbannv2::UnknownAddress);\n    REQUIRE_THROWS_AS(registry.get_context(ptr2), lbannv2::UnknownAddress);\n    REQUIRE_THROWS_AS(registry.get_context(ptr3), lbannv2::UnknownAddress);\n    REQUIRE_THROWS_AS(registry.get_context(ptr4), lbannv2::UnknownAddress);\n    REQUIRE_THROWS_AS(registry.get_context(ptr5), lbannv2::UnknownAddress);\n  }\n}\n\nTEST_CASE(\"PointerRegistry::unsafe_reset_allocator()\", \"[memory][registry]\")\n{\n  using RangeT = std::pair<void*, void*>;\n\n  lbannv2::PointerRegistry registry;\n  std::vector<unsigned char> buffer(32);\n\n  // Establish preconditions\n  REQUIRE(registry.num_registered() == 0UL);\n  REQUIRE(registry.bytes_registered() == 0UL);\n\n  RangeT const rng = {&buffer[4], &buffer[8]};\n\n  void const* const ctxt_ptr = &buffer[4];\n  void const* const mid_ptr = &buffer[6];\n  void const* const bad_ptr = &buffer[0];\n\n  c10::Allocator& alloc = *c10::GetAllocator(c10::kCPU);\n  c10::Allocator* const orig_alloc = &alloc;\n\n  // FAKE -- DO NOT DEREFERENCE! (one past orig_alloc, so the two\n  // pointers compare unequal)\n  c10::Allocator* const other_alloc = orig_alloc + 1;\n\n  // Get the allocator setup\n  REQUIRE_NOTHROW(registry.add(rng.first, rng_bytes(rng), orig_alloc));\n  REQUIRE(registry.get_allocator(ctxt_ptr) == orig_alloc);\n  
REQUIRE(registry.get_allocator(mid_ptr) == orig_alloc);\n\n  SECTION(\"Resetting by context is ok\")\n  {\n    REQUIRE_NOTHROW(registry.unsafe_reset_allocator(ctxt_ptr, other_alloc));\n    REQUIRE(registry.get_allocator(ctxt_ptr) == other_alloc);\n    REQUIRE(registry.get_allocator(mid_ptr) == other_alloc);\n  }\n\n  SECTION(\"Resetting by an interior pointer is ok\")\n  {\n    // FIXME: Perhaps this should actually be disallowed??\n    REQUIRE_NOTHROW(registry.unsafe_reset_allocator(mid_ptr, other_alloc));\n    REQUIRE(registry.get_allocator(ctxt_ptr) == other_alloc);\n    REQUIRE(registry.get_allocator(mid_ptr) == other_alloc);\n  }\n\n  SECTION(\"Resetting an unknown pointer fails\")\n  {\n    REQUIRE_THROWS_AS(registry.unsafe_reset_allocator(bad_ptr, other_alloc),\n                      lbannv2::UnknownAddress);\n  }\n}\n\nTEST_CASE(\"PointerRegistry::bytes_registered()\", \"[memory][registry]\")\n{\n  using RangeT = std::pair<void*, void*>;\n\n  lbannv2::PointerRegistry registry;\n  std::vector<unsigned char> buffer(16);\n\n  // Establish preconditions\n  REQUIRE(registry.num_registered() == 0UL);\n  REQUIRE(registry.bytes_registered() == 0UL);\n\n  RangeT const rng = {&buffer[4], &buffer[8]};\n  size_t const rng_size = rng_bytes(rng);\n\n  void const* const ctxt_ptr = &buffer[4];\n  void const* const mid_ptr = &buffer[6];\n  void const* const extern_ptr_1 = &buffer[0];\n  // One past the end; indexing buffer[16] would be out of range.\n  void const* const extern_ptr_2 = buffer.data() + buffer.size();\n\n  REQUIRE_NOTHROW(registry.add(rng.first, rng_bytes(rng), nullptr));\n  REQUIRE(registry.bytes_registered() == rng_size);\n\n  CHECK(registry.bytes_registered(ctxt_ptr) == rng_size);\n  CHECK(registry.bytes_registered(mid_ptr) == rng_size);\n  CHECK(registry.bytes_registered(extern_ptr_1) == 0UL);\n  CHECK(registry.bytes_registered(extern_ptr_2) == 0UL);\n}\n"
  },
  {
    "path": "test/cpp/test_tensor_helpers.cpp",
    "content": "////////////////////////////////////////////////////////////////////////////////\n// Copyright 2014-2025 Lawrence Livermore National Security, LLC and other\n// LBANN Project Developers. See the top-level LICENSE file for details.\n//\n// SPDX-License-Identifier: Apache-2.0\n////////////////////////////////////////////////////////////////////////////////\n#include <lbannv2/ops/empty_tensor.hpp>\n#include <lbannv2/utils/tensor_helpers.hpp>\n\n#include <ATen/EmptyTensor.h>\n\n// A c10 header file in PyTorch has left a macro called `CHECK`\n// defined. To prevent warnings, we need to clear that out. This\n// should not cause problems as we don't use the PyTorch macro\n// directly, and all PyTorch includes should precede this line in this\n// source code.\n#ifdef CHECK\n#undef CHECK\n#endif\n\n#include <catch2/catch_test_macros.hpp>\n\nTEST_CASE(\"alias_as_device\", \"[tensor][utils]\")\n{\n  SECTION(\"Aliasing from LBANN to native device\")\n  {\n    at::Tensor t = lbannv2::empty_lbann({2, 3, 4},\n                                        c10::ScalarType::Float,\n                                        std::nullopt,\n                                        std::nullopt,\n                                        false,\n                                        std::nullopt);\n    auto const orig_keys = t.key_set();\n    auto const orig_device = t.device();\n\n    at::Tensor cpu_alias = lbannv2::alias_as_device(\n      t, c10::DeviceType::CPU, c10::DispatchKeySet {c10::DispatchKey::CPU});\n\n    CHECK(t.is_privateuseone());\n    CHECK(t.key_set() == orig_keys);\n    CHECK(t.device() == orig_device);\n\n    CHECK(cpu_alias.is_alias_of(t));\n    CHECK(cpu_alias.is_cpu());\n\n    // This is documented to change\n    CHECK(t.storage().data_ptr().device().is_cpu());\n\n    // Metadata should match\n    CHECK(cpu_alias.sizes() == t.sizes());\n    CHECK(cpu_alias.strides() == t.strides());\n    CHECK(cpu_alias.names() == t.names());\n    CHECK(cpu_alias.dtype() == t.dtype());\n  }\n}\n\nTEST_CASE(\"alias_as_native_device\", \"[tensor][utils]\")\n{\n  SECTION(\"Aliasing a native PyTorch tensor does nothing\")\n  {\n    at::Tensor t = at::detail::empty_cpu({3, 2, 4},\n                                         c10::ScalarType::Float,\n                                         std::nullopt,\n                                         std::nullopt,\n                                         std::nullopt,\n                                         std::nullopt);\n    at::Tensor alias = lbannv2::alias_as_native_device(t);\n    CHECK(alias.is_alias_of(t));\n    CHECK(alias.key_set() == t.key_set());\n    CHECK(alias.device() == t.device());\n    CHECK(alias.dtype() == t.dtype());\n    CHECK(alias.unsafeGetTensorImpl() == t.unsafeGetTensorImpl());\n  }\n\n  SECTION(\"Aliasing an LBANN tensor is ok\")\n  {\n    using namespace lbannv2;\n    static constexpr auto LBANNbit = c10::BackendComponent::PrivateUse1Bit;\n\n    at::Tensor t = lbannv2::empty_lbann({2, 3, 4},\n                                        c10::ScalarType::Float,\n                                        std::nullopt,\n                                        c10::Device {LBANNDeviceT, LBANN_CPU},\n                                        false,\n                                        std::nullopt);\n    at::Tensor lbann_alias = lbannv2::alias_as_native_device(t);\n\n    // Still an alias (based on storage objects)\n    CHECK(lbann_alias.is_alias_of(t));\n    CHECK(lbann_alias.key_set() == t.key_set().remove_backend(LBANNbit));\n    
CHECK(lbann_alias.sizes() == t.sizes());\n    CHECK(lbann_alias.strides() == t.strides());\n    CHECK(lbann_alias.dtype() == t.dtype());\n    CHECK(lbann_alias.device() != t.device());\n    CHECK(lbann_alias.device().is_cpu());\n    CHECK(lbann_alias.unsafeGetTensorImpl()->data()\n          == t.unsafeGetTensorImpl()->data());\n    CHECK(lbann_alias.unsafeGetTensorImpl()->storage_offset()\n          == t.unsafeGetTensorImpl()->storage_offset());\n  }\n}\n"
  }
]