Full Code of LLNL/lbann for AI

main b8445e793297 cached
52 files
144.9 KB
38.8k tokens
107 symbols
1 requests
Download .txt
Repository: LLNL/lbann
Branch: main
Commit: b8445e793297
Files: 52
Total size: 144.9 KB

Directory structure:
gitextract_2qpyp2vm/

├── .clang-format
├── .gitignore
├── CMakeLists.txt
├── CONTRIBUTING.md
├── CONTRIBUTORS
├── LICENSE
├── NOTICE
├── README.md
├── cmake/
│   ├── LBANNv2DetectTorchNVIDIALibraries.cmake
│   ├── LBANNv2DetermineMI300A.cmake
│   ├── lbannv2Config.cmake.in
│   └── lbannv2_config.h.in
├── pyproject.toml
├── python/
│   └── lbannv2/
│       ├── __init__.py
│       └── _automigrate.py
├── src/
│   └── lbannv2/
│       ├── CMakeLists.txt
│       ├── memory/
│       │   ├── CMakeLists.txt
│       │   ├── allocator.cpp
│       │   ├── allocator.hpp
│       │   ├── h2_allocator_wrappers.cpp
│       │   ├── h2_allocator_wrappers.hpp
│       │   ├── memory_utils.hpp
│       │   ├── mi300a_allocator.cpp
│       │   ├── mi300a_allocator.hpp
│       │   ├── registry.cpp
│       │   └── registry.hpp
│       ├── ops/
│       │   ├── CMakeLists.txt
│       │   ├── migrate.cpp
│       │   ├── migrate.hpp
│       │   ├── nonzero.hip
│       │   ├── nonzero.hpp
│       │   ├── scalar.cpp
│       │   └── scalar.hpp
│       ├── python/
│       │   ├── CMakeLists.txt
│       │   ├── register_lbannv2.cpp
│       │   ├── register_memory_funcs.cpp
│       │   └── register_mi300a_ops.cpp
│       ├── types.hpp
│       └── utils/
│           ├── CMakeLists.txt
│           ├── debugging_helpers.hpp
│           ├── errors.hpp
│           ├── gpu_utils.cpp
│           ├── gpu_utils.hpp
│           ├── logging.cpp
│           ├── logging.hpp
│           └── tensor_helpers.hpp
└── test/
    ├── CMakeLists.txt
    └── cpp/
        ├── test_empty_tensor.cpp
        ├── test_helpers.hpp
        ├── test_mi300a_allocator.cpp
        ├── test_pointer_registry.cpp
        └── test_tensor_helpers.cpp

================================================
FILE CONTENTS
================================================

================================================
FILE: .clang-format
================================================
################################################################################
## Copyright 2014-2025 Lawrence Livermore National Security, LLC and other
## LBANN Project Developers. See the top-level LICENSE file for details.
##
## SPDX-License-Identifier: Apache-2.0
################################################################################

Language: Cpp
BasedOnStyle: LLVM
AccessModifierOffset: -2
AlignAfterOpenBracket: Align
AlignArrayOfStructures: None
AlignConsecutiveMacros: None
AlignConsecutiveAssignments: None
AlignConsecutiveBitFields: None
AlignConsecutiveDeclarations: None
AlignEscapedNewlines: Right
AlignOperands: Align
AlignTrailingComments: true
AllowAllArgumentsOnNextLine: true
# AllowAllConstructorInitializersOnNextLine: true
AllowAllParametersOfDeclarationOnNextLine: true
AllowShortEnumsOnASingleLine: true
AllowShortBlocksOnASingleLine: Always
AllowShortCaseLabelsOnASingleLine: true
AllowShortFunctionsOnASingleLine: InlineOnly
AllowShortIfStatementsOnASingleLine: Never
AllowShortLambdasOnASingleLine: Inline
AllowShortLoopsOnASingleLine: false
AlwaysBreakAfterReturnType: None
AlwaysBreakBeforeMultilineStrings: false
AlwaysBreakTemplateDeclarations: Yes
BinPackArguments: false
BinPackParameters: false
BraceWrapping:
  AfterCaseLabel: true
  AfterClass: true
  AfterControlStatement: true
  AfterEnum: true
  AfterFunction: true
  AfterNamespace: true
  AfterStruct: true
  AfterUnion: true
  AfterExternBlock: true
  BeforeCatch: true
  BeforeElse: true
  BeforeLambdaBody: false
  BeforeWhile:     false
  IndentBraces: false
  SplitEmptyFunction: false
  SplitEmptyRecord: false
  SplitEmptyNamespace: false
BreakBeforeBinaryOperators: NonAssignment
BreakBeforeConceptDeclarations: true
BreakBeforeBraces: Custom
BreakBeforeInheritanceComma: false
BreakBeforeTernaryOperators: true
BreakConstructorInitializersBeforeComma: false
BreakConstructorInitializers: BeforeColon
BreakInheritanceList: BeforeColon
BreakStringLiterals: true
ColumnLimit: 80
CommentPragmas: "^ H2_DISPATCH"
CompactNamespaces: false
ConstructorInitializerAllOnOneLineOrOnePerLine: true
ConstructorInitializerIndentWidth: 2
ContinuationIndentWidth: 2
Cpp11BracedListStyle: true
DerivePointerAlignment: false
DisableFormat: false
FixNamespaceComments: true
# TypenameMacros: ['TODO']
IncludeBlocks: Regroup
# Hopefully this will keep the config file at the top...
IncludeCategories:
  - Regex:           '^<catch2/catch.hpp>'
    Priority:        -1                         # Always up front
  - Regex:           '^((<|").+_(config|export)\.h(pp)?(>|"))'        # Configure headers
    Priority:        -1
  - Regex:           '^((<|")lbannv2/)'   # Project headers
    Priority:         2
  - Regex:           '^((<|")h2/)'   # DiHydrogen headers
    Priority:         2
  - Regex:           '^"[[:alnum:]_.]+\.(hpp|cuh)"'   # Project headers
    Priority:         3
  - Regex:           '^<[[:alnum:]_.]+\.(hpp|cuh)>'  # "Normal headers"
    Priority:         4
  - Regex:           '^(<(ATen|c10|torch|pybind11|python)/)'
    Priority:         4
  - Regex:           '<[[:alnum:]_]+>'          # STL Headers last
    Priority:         6
  # NOTE: clang-format include regexes are POSIX ERE; the previous
  # glob-style '*.h' was not a valid regex. Catch-all for C headers.
  - Regex:           '.*\.h'
    Priority:         7
IncludeIsMainRegex: "(Test)?$"
IndentCaseLabels: false
IndentPPDirectives: None
IndentWidth: 2
IndentWrappedFunctionNames: false
KeepEmptyLinesAtTheStartOfBlocks: false
MaxEmptyLinesToKeep: 1
NamespaceIndentation: None
# PenaltyBreakAssignment
# PenaltyBreakBeforeFirstCallParameter
# PenaltyBreakComment
# PenaltyBreakFirstLessLess
# PenaltyBreakString
# PenaltyBreakTemplateDeclaration
# PenaltyExcessCharacter
# PenaltyReturnTypeOnItsOwnLine
PointerAlignment: Left
QualifierAlignment: Right
ReflowComments: true
SortIncludes: true
SortUsingDeclarations: true
SpaceAfterCStyleCast: true
SpaceAfterLogicalNot: false
SpaceAfterTemplateKeyword: true
SpaceBeforeAssignmentOperators: true
SpaceBeforeCpp11BracedList: true
SpaceBeforeCtorInitializerColon: true
SpaceBeforeInheritanceColon: true
SpaceBeforeParens: ControlStatements
SpaceBeforeRangeBasedForLoopColon: true
SpaceInEmptyParentheses: false
SpacesBeforeTrailingComments: 2
SpacesInAngles: false
SpacesInCStyleCastParentheses: false
SpacesInParentheses: false
SpacesInSquareBrackets: false
Standard: Latest
# StatementMacros
TabWidth: 2
UseTab: Never


================================================
FILE: .gitignore
================================================
################################################################################
## Copyright 2019-2020 Lawrence Livermore National Security, LLC and other
## DiHydrogen Project Developers. See the top-level LICENSE file for details.
##
## SPDX-License-Identifier: Apache-2.0
################################################################################

# Emacs backup garbage
.backup/
.cache/

# Other standard ignores
*~
*.pyc
\#*#
.#*
.*.swp
.DS_Store
*.bak
.dir-locals.el
compile_commands.json

# building/install not-entirely-out-of-source stuff
build*/
install*/



================================================
FILE: CMakeLists.txt
================================================
################################################################################
## Copyright 2014-2025 Lawrence Livermore National Security, LLC and other
## LBANN Project Developers. See the top-level LICENSE file for details.
##
## SPDX-License-Identifier: Apache-2.0
################################################################################
cmake_minimum_required(VERSION 3.27)
project(LBANNv2
  VERSION 0.0.1
  DESCRIPTION "DiHydrogen integration with PyTorch"
  # NOTE: fixed to point at the actual repository (was
  # "https://github.com/lbann", which is not this project).
  HOMEPAGE_URL "https://github.com/LLNL/lbann"
  LANGUAGES CXX
)

option(LBANNV2_DEBUG_MODE
  "Enable extra assertions helpful in debugging."
  OFF)

# Make Tom's life easier
set(CMAKE_EXPORT_COMPILE_COMMANDS ON
  CACHE BOOL "Write compile_commands.json" FORCE)

# FIXME (trb): This is probably the right thing, but we should think
# about if this is strictly needed.
set(BUILD_SHARED_LIBS ON)
set(CMAKE_CXX_STANDARD 20) # For DiHydrogen

# FIXME (trb): These are generally useful for development and
# debugging. I should probably pass them on cmd line, but again, lazy.
set(CMAKE_CXX_FLAGS_DEBUG "-g3 -O0 -fno-omit-frame-pointer")
set(CMAKE_HIP_FLAGS_DEBUG "-g3 -O0 -fno-omit-frame-pointer")

# Language support
#
# Just set things for CUDA *and* HIP hoping they'll be ignored on
# irrelevant platforms.

# ampere, hopper
set(CMAKE_CUDA_ARCHITECTURES 80 90)
set(TORCH_CUDA_ARCH_LIST 8.0 9.0)
set(CMAKE_CUDA_STANDARD 17)

# MI50, MI250X, MI300A, MI300X
set(CMAKE_HIP_ARCHITECTURES gfx906 gfx90a gfx942)
set(ENV{PYTORCH_ROCM_ARCH} "${CMAKE_HIP_ARCHITECTURES}")
set(PYTORCH_ROCM_ARCH ${CMAKE_HIP_ARCHITECTURES})

# Setup dependencies

set(LBANNV2_MINIMUM_Python_VERSION 3.10)
set(LBANNV2_MINIMUM_H2_VERSION 0.4.0)
set(LBANNV2_MINIMUM_Torch_VERSION 2.9.0)

find_package(Python
  ${LBANNV2_MINIMUM_Python_VERSION}
  REQUIRED
  COMPONENTS Interpreter Development.Module)

# Interrogate the Python environment (via pip) to detect NVIDIA
# dependencies in the environment. Currently, this is based on the
# Torch module that's installed in the environment, if any exists, and
# meaningful values will only be returned if such a module exists.
#
# FIXME (trb): We just handle cuDNN and NCCL here because those are
# the only ones that overlap with Al/H2 needs, but we might consider
# adding paths for the rest of them since Torch will (presumably)
# depend on them.
#
# An alternative approach _could_ be to detect all NVIDIA modules
# known to pip and simply parse those. I'm not sure how realistic this
# might be in practice, but presumably one _could_ have
# nvidia-cudnn-cu11 and nvidia-cudnn-cu12 in the same environment, and
# one could imagine that those packages would provide distinct
# installations of these libraries (fun fact: they don't). Hence the
# preference to let PyTorch tell me which modules it should use. If
# someone was trying to use a Torch that Pip couldn't detect but with
# pip-managed NVIDIA modules, I would classify them as a "power user"
# and expect that they can handle adding command line arguments to the
# LBANNv2 build.
list(PREPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")

include(LBANNv2DetectTorchNVIDIALibraries)
detect_torch_nvidia_libraries(LIBRARIES cudnn nccl)

foreach (pkg cudnn nccl)
  if (LBANNV2_DETECTED_${pkg})
    string(TOUPPER "${pkg}" pkg_upper)
    set(${pkg_upper}_LIBRARY
      "${LBANNV2_DETECTED_${pkg}_LIBRARY}"
      CACHE
      FILEPATH
      "Path to ${pkg_upper} library." FORCE)
    set(${pkg_upper}_INCLUDE_PATH
      "${LBANNV2_DETECTED_${pkg}_INCLUDE_PATH}"
      CACHE
      PATH
      "Include directory for ${pkg}" FORCE)
  endif()
endforeach ()

# Special handling for Torch+cuDNN
if (LBANNV2_DETECTED_cudnn)
  # Torch uses "LIBRARY_PATH" for the location of the main cuDNN
  # library. Because why wouldn't they??
  set(CUDNN_LIBRARY_PATH
    "${LBANNV2_DETECTED_cudnn_LIBRARY}"
    CACHE
    FILEPATH
    "Path to cuDNN library.")

  set(CAFFE2_USE_CUDNN ON CACHE BOOL "Have the build search for cuDNN")
endif ()

# Ok, the CMake here gets a little rocky. The goal is to "pip install
# ." and it should just build "the right thing". So we need to
# auto-detect as much as we can under the weakest assumptions possible
# (e.g., we should not assume "torch.cuda.is_available()" gives
# meaningful information, as we may be building on a GPU-less head
# node). It seems reasonable to just find Torch and see what its CMake
# export can tell us. For instance, "torch_hip" will be found on ROCm
# platforms, and "torch_cuda" will be found on CUDA platforms -- we
# assume (hope!) that these are truly orthogonal! From there, we can
# pull a few additional flags in by further interrogating the targets,
# if needed.

find_package(Torch
  ${LBANNV2_MINIMUM_Torch_VERSION}
  REQUIRED
)

# We also don't care about the limited API nonsense, so we can use
# libtorch. Let's find it.
if (TORCH_LIBRARY)
  get_filename_component(TORCH_LIB_DIR "${TORCH_LIBRARY}" DIRECTORY)
endif ()
# First look next to libtorch (and in the usual wheel layout); if that
# fails, fall back to a default search and hard-error if still absent.
find_library(TORCH_PYTHON_LIBRARY
  torch_python
  HINTS
  ${TORCH_LIB_DIR}
  ${Python_SITELIB}/torch/lib64
  ${Python_SITELIB}/torch/lib
  NO_DEFAULT_PATH)
find_library(TORCH_PYTHON_LIBRARY torch_python REQUIRED)

# MI300A only becomes a factor when doing a ROCm build. So start by
# assuming we don't have it.
#
# FIXME (trb): This should, of course, be relaxed to just represent
# memory coherence. However, I don't have access to any non-MI300A
# memory-coherent architectures. If anyone does, I'm happy to abstract
# this now; otherwise, I'll wait until I acquire such access myself.
set(LBANNV2_WITHOUT_MI300A ON)
unset(LBANNV2_WITH_MI300A)
unset(LBANNV2_UNKNOWN_MI300A)
unset(LBANNV2_HAS_CUDA)
unset(LBANNV2_HAS_ROCM)

if (TARGET torch_cuda)
  set(LBANNV2_HAS_CUDA ON)
  # We need to edit out the CUDA arch flags out. Or at least edit them
  # down to supported archs (>=70).
elseif (TARGET torch_hip)
  enable_language(HIP)
  set(LBANNV2_HAS_ROCM ON)

  # Handle MI300A configure checks.
  include(LBANNv2DetermineMI300A)
  set(_valid_mi300a_status "WITH" "WITHOUT" "UNKNOWN")
  set(LBANNV2_MI300A_STATUS "DETECT"
    CACHE STRING
    "On MI300A? Valid values: WITH, WITHOUT, UNKNOWN, DETECT")
  string(TOUPPER "${LBANNV2_MI300A_STATUS}" _mi300a_status_upper)
  if (NOT _mi300a_status_upper IN_LIST _valid_mi300a_status)
    determine_mi300a_support(_mi300a_status_upper)
  endif ()
  unset(LBANNV2_WITH_MI300A)
  unset(LBANNV2_WITHOUT_MI300A)
  unset(LBANNV2_UNKNOWN_MI300A)
  set(LBANNV2_${_mi300a_status_upper}_MI300A ON)
  # If we determine that we have MI300A, we can make some static
  # optimizations and eliminate some flow control. In the "UNKNOWN"
  # case, these static branches are replaced by dynamic ones, possibly
  # incurring some small overhead.
  #
  # As far as I can figure, the only case in which this could cause
  # problems (rather than just being suboptimal) is if we declare (or
  # decide) that we have MI300A when we actually do not. In
  # particular, this would cause our assumptions about CPU/GPU memory
  # visibility to be invalid -- hipMalloc'd memory would not be valid
  # on the CPU.

  # We need to remove any "std=c++<XY>" type options because we're
  # ahead of PyTorch's minimum requirements there.
  get_target_property(
    _torch_hip_compile_opts
    torch_hip
    INTERFACE_COMPILE_OPTIONS)
  foreach (_opt ${_torch_hip_compile_opts})
    if (_opt MATCHES "-std=c\\+\\+[0-9a-z]+")
      list(REMOVE_ITEM _torch_hip_compile_opts "${_opt}")
    endif ()
  endforeach()
  set_target_properties(torch_hip
    PROPERTIES INTERFACE_COMPILE_OPTIONS "${_torch_hip_compile_opts}")

  # FIXME (trb): So I really, truly hate this, but this seems to be
  # the shortest approach to dealing with version-based switches in
  # the C++ code. Other approaches involve obscure preprocessor macros
  # or long-winded SFINAE tricks, and because I want you, dear human
  # reader, to be happy, I opted for this very simple, highly readable
  # implementation instead.
  if (Torch_VERSION VERSION_LESS "2.11.0")
    set(LBANNV2_USE_C10_HIP_NAMESPACE_AND_SYMBOLS TRUE)
  endif ()
endif ()

# We need to determine if we should be using a CXX11_ABI macro or not
# so we can forward as appropriate to spdlog/Catch2/etc. We need to do
# this *BEFORE* adding DiHydrogen(/spdlog/Catch2); otherwise it won't
# get picked up and we'd have to add it to the respective targets
# later on.
if (TORCH_CXX_FLAGS AND TORCH_CXX_FLAGS MATCHES "GLIBCXX_USE_CXX11_ABI=([01])")
  add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=${CMAKE_MATCH_1})
endif ()

# spdlog
include(FetchContent)
FetchContent_Declare(
  spdlog
  GIT_REPOSITORY https://github.com/gabime/spdlog.git
  GIT_TAG 79524ddd08a4ec981b7fea76afd08ee05f83755d # v1.17.0
  GIT_SHALLOW 1
  FIND_PACKAGE_ARGS CONFIG
)

# Ensure spdlog gets installed and exported properly. I can probably
# make this a non-cached (or a FORCE'd) variable, but this is fine.
set(SPDLOG_INSTALL ON CACHE INTERNAL "Install spdlog")
FetchContent_MakeAvailable(spdlog)

# Python module stuff
find_package(pybind11 CONFIG REQUIRED)

# Set a few RPATH handling things
set(CMAKE_INSTALL_RPATH_USE_LINK_PATH TRUE)
if(APPLE)
  list(PREPEND CMAKE_INSTALL_RPATH "@loader_path")
else()
  list(PREPEND CMAKE_INSTALL_RPATH "\$ORIGIN")
endif()

# Add the library
add_library(lbannv2 SHARED)
add_library(lbann::lbannv2 ALIAS lbannv2)
target_sources(lbannv2
  PUBLIC
  FILE_SET HEADERS
  BASE_DIRS src
)
target_link_libraries(lbannv2
  PUBLIC
  torch
  spdlog::spdlog
)
set_target_properties(lbannv2
  PROPERTIES
  CXX_STANDARD 20
  CXX_STANDARD_REQUIRED ON
  CXX_EXTENSIONS OFF
  VERSION ${LBANNv2_VERSION}
  SOVERSION ${LBANNv2_VERSION_MAJOR}
)

# Create the Python module
python_add_library(_lbannv2 MODULE WITH_SOABI)
target_link_libraries(_lbannv2
  PUBLIC
  lbann::lbannv2
  "${TORCH_PYTHON_LIBRARY}"
  PRIVATE
  pybind11::headers
)
set_target_properties(_lbannv2
  PROPERTIES
  CXX_STANDARD 20
  CXX_STANDARD_REQUIRED ON
  CXX_EXTENSIONS OFF
)

# Handle logging. If `LBANNV2_LOG_LEVEL` is not set,
# SPDLOG_ACTIVE_LEVEL will not be set on the command line and will
# default to `SPDLOG_LEVEL_TRACE` in the C++ code
# (src/lbannv2/utils/logging.hpp).
#
# NOTE that this is the *compile time* log level. That is, if
# LBANN_LOG_LEVEL is set to "TRACE", every log message (*using the
# LBANNV2_LOG* macros) will be compiled; if it's set to "INFO",
# messages flagged as "TRACE" or "DEBUG" will not even be compiled.
# The default is set to "TRACE" so that all log messages are
# available, depending on the log level selected at runtime.
set(lbannv2_ok_log_levels
  "TRACE" "DEBUG" "INFO" "WARN" "ERROR" "CRITICAL" "OFF")
if (LBANNV2_LOG_LEVEL IN_LIST lbannv2_ok_log_levels)
  target_compile_definitions(
    lbannv2
    PRIVATE
    SPDLOG_ACTIVE_LEVEL=SPDLOG_LEVEL_${LBANNV2_LOG_LEVEL}
  )

  target_compile_definitions(
    _lbannv2
    PRIVATE
    SPDLOG_ACTIVE_LEVEL=SPDLOG_LEVEL_${LBANNV2_LOG_LEVEL}
  )
endif ()

# Add the sources to the library
add_subdirectory(src/lbannv2)

# Generate the export header
include(GenerateExportHeader)
generate_export_header(lbannv2)

# Generate the configuration header
configure_file(
  ${PROJECT_SOURCE_DIR}/cmake/lbannv2_config.h.in
  ${CMAKE_CURRENT_BINARY_DIR}/lbannv2_config.h
  @ONLY
)

# Include it in the file set
target_sources(lbannv2 PUBLIC
  FILE_SET HEADERS
  BASE_DIRS ${CMAKE_CURRENT_BINARY_DIR}
  FILES
  ${CMAKE_CURRENT_BINARY_DIR}/lbannv2_config.h
  ${CMAKE_CURRENT_BINARY_DIR}/lbannv2_export.h
)

# Handle unit testing
include(CTest)
if (BUILD_TESTING)
  add_subdirectory(test)
endif ()

# Install stuff
#
# When building the Python bindings, we still install the whole C++
# library. We might want to clean this up. Also, we set
# tools.scikit-build.wheel.install-dir=lbannv2 so it installs into
# <site-packages>/lbannv2.
include(GNUInstallDirs)

set(
  CMAKE_INSTALL_CMAKEDIR
  "${CMAKE_INSTALL_LIBDIR}/cmake/lbannv2"
)

install(TARGETS lbannv2
  EXPORT lbannv2Targets
  FILE_SET HEADERS
)

install(EXPORT lbannv2Targets
  DESTINATION ${CMAKE_INSTALL_CMAKEDIR}
  NAMESPACE lbann::
)

install(TARGETS _lbannv2
  DESTINATION ${CMAKE_INSTALL_LIBDIR}
)

include(CMakePackageConfigHelpers)
configure_package_config_file(
  cmake/lbannv2Config.cmake.in
  "${CMAKE_BINARY_DIR}/lbannv2Config.cmake"
  INSTALL_DESTINATION "${CMAKE_INSTALL_CMAKEDIR}"
)
write_basic_package_version_file(
  lbannv2ConfigVersion.cmake
  COMPATIBILITY SameMinorVersion
)
install(
  FILES
  "${CMAKE_BINARY_DIR}/lbannv2Config.cmake"
  "${CMAKE_BINARY_DIR}/lbannv2ConfigVersion.cmake"
  DESTINATION "${CMAKE_INSTALL_CMAKEDIR}"
)


================================================
FILE: CONTRIBUTING.md
================================================
# Contributing Guidelines for LBANN

We welcome any contributions to LBANN in the form of Pull Requests.
Please follow the guidelines below for more information.

## Attribution

If you have not added yourself to the authors list in 
[CONTRIBUTORS](https://github.com/LLNL/lbann/blob/develop/CONTRIBUTORS), please do so in the appropriate place.

## git guidelines

When ready for review and merge, Pull Requests must be up to date with the latest `develop` branch commit.
If they are not, **rebase** the commits onto the latest `develop` commit. Avoid merge commits.

## Style guidelines

For C/C++ and GPU code, we follow the [LLVM coding style](https://llvm.org/docs/CodingStandards.html) with
adaptations, see the [coding style README](https://github.com/LLNL/lbann/blob/develop/README_coding_style.txt) and the
[clang-format configuration](https://github.com/LLNL/lbann/blob/develop/.clang-format) for more information.

For Python code, we follow the [Google coding style](https://google.github.io/styleguide/pyguide.html) guidelines,
but allow some exceptions to create layers in the LBANN Python frontend.

## Setting up automatic formatting

To enforce file formatting at every commit, you can use the pre-commit hook provided in the repository.
Make a symbolic link from `.git/hooks/pre-commit` to our script by running the following command
**from the root of your git repository**:

```sh
user@/path/to/lbann$ ln -s ../../scripts/pre-commit-hook.sh .git/hooks/pre-commit
```

Make sure you have `clang-format` installed for C/C++ formatting. If you do not have it installed in the path,
you may override it by setting the `$CLANG_FORMAT` environment variable to its path.


================================================
FILE: CONTRIBUTORS
================================================
LLNL Core Team:
  Brian Van Essen <vanessen1@llnl.gov> [@bvanessen]
  Tom Benson <benson31@llnl.gov> [@benson31]
  Nikoli Dryden <dryden1@llnl.gov> [@ndryden]
  Tal Ben-Nun <talbn@llnl.gov> [@tbennun]
  Pier Fiedorowicz <fiedorowicz1@llnl.gov> [@fiedorowicz1]
  
Collaborators:
  Shehtab Zaman [@szaman19]

Notable Prior LLNL Team Members
  Ryan Forsyth [@forsyth2]
  David Hysom [@davidHysom]
  Katie Graham [@graham63]
  Keita Iwabuchi [@KIwabuchi]
  Sam Ade Jacobs [@samadejacobs]
  Arpan Jain [@aj-prime]
  Hyojin Kim
  Naoya Maruyama [@naoyam]
  Erin McCarthy
  Adam Moody [@adammoody]
  Tim Moon [@timmoon10]
  Yosuke Oyama [@oyamay]
  Michael Wyatt [@mrwyattii]
  Jae-Seung Yeom [@JaeseungYeom]



================================================
FILE: LICENSE
================================================
################################################################################
## Copyright 2014-2025 Lawrence Livermore National Security, LLC and other
## LBANN Project Developers. See the top-level LICENSE file for details.
##
## SPDX-License-Identifier: Apache-2.0
################################################################################
Copyright (c) 2014-2024, Lawrence Livermore National Security, LLC.
Produced at the Lawrence Livermore National Laboratory.
Written by the LBANN Research Team (B. Van Essen, et al.) listed in
the CONTRIBUTORS file. <lbann-dev@llnl.gov>

LLNL-CODE-697807.
All rights reserved.

This file is part of LBANN: Livermore Big Artificial Neural Network
Toolkit. For details, see http://software.llnl.gov/LBANN or
https://github.com/LBANN and https://github.com/LLNL/LBANN.

Licensed under the Apache License, Version 2.0 (the "License"); you
may not use this file except in compliance with the License.  You may
obtain a copy of the License at:

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied. See the License for the specific language governing
permissions and limitations under the License.


================================================
FILE: NOTICE
================================================
This work was produced under the auspices of the U.S. Department of Energy by
Lawrence Livermore National Laboratory under Contract DE-AC52-07NA27344.

This work was prepared as an account of work sponsored by an agency of the
United States Government. Neither the United States Government nor Lawrence
Livermore National Security, LLC, nor any of their employees makes any warranty,
expressed or implied, or assumes any legal liability or responsibility for the
accuracy, completeness, or usefulness of any information, apparatus, product, or
process disclosed, or represents that its use would not infringe privately owned
rights. Reference herein to any specific commercial product, process, or service
by trade name, trademark, manufacturer, or otherwise does not necessarily
constitute or imply its endorsement, recommendation, or favoring by the United
States Government or Lawrence Livermore National Security, LLC. The views and
opinions of authors expressed herein do not necessarily state or reflect those
of the United States Government or Lawrence Livermore National Security, LLC,
and shall not be used for advertising or product endorsement purposes.


================================================
FILE: README.md
================================================
# Build

To save some pip-related heartburn, LBANNv2 is currently BYOT ("bring
your own Torch").

```
pip install torch
pip install .
```

# License

Copyright 2014-2025 Lawrence Livermore National Security, LLC and other
LBANN Project Developers. See the top-level LICENSE file for details.

SPDX-License-Identifier: Apache-2.0

LLNL-CODE-697807


================================================
FILE: cmake/LBANNv2DetectTorchNVIDIALibraries.cmake
================================================
################################################################################
## Copyright 2014-2025 Lawrence Livermore National Security, LLC and other
## LBANN Project Developers. See the top-level LICENSE file for details.
##
## SPDX-License-Identifier: Apache-2.0
################################################################################
# Interrogate the active Python environment (via pip and
# importlib.metadata) to locate NVIDIA libraries (e.g., cudnn, nccl)
# that were installed as PyPI wheel dependencies of torch.
#
# Usage:
#
#   detect_torch_nvidia_libraries(LIBRARIES <lib> [<lib> ...])
#
# For each requested <lib> that is found, the following cache
# variables are set:
#
#   LBANNV2_DETECTED_<lib>_LIBRARY      - Full path to lib<lib> shared library
#   LBANNV2_DETECTED_<lib>_INCLUDE_PATH - Directory containing <lib>.h
#
# and LBANNV2_DETECTED_<lib> is set TRUE in the caller's scope iff
# both of the above were found.
function(detect_torch_nvidia_libraries)
  set(_detect_opts)
  # NOTE: this variable name previously read "_detect_single_val_args"
  # while the cmake_parse_arguments call below expanded
  # "_detect_single_value_args" -- a latent typo that only worked
  # because both expanded to the empty string.
  set(_detect_single_value_args)
  set(_detect_multi_value_args LIBRARIES)
  cmake_parse_arguments(PARSE_ARGV 0 _detect
    "${_detect_opts}" "${_detect_single_value_args}" "${_detect_multi_value_args}")

  find_package(Python 3.9 REQUIRED COMPONENTS Interpreter Development.Module)

  # Get information about torch. If Pip doesn't know about torch, that's
  # fine. We just stop and fall back on the user's environment, assuming
  # Torch to have been built from source.
  execute_process(
    COMMAND "${Python_EXECUTABLE}" -m pip show --no-color torch
    ERROR_VARIABLE _detect_pip_show_error
    OUTPUT_VARIABLE _detect_pip_show_output
    RESULT_VARIABLE _detect_pip_show_result
    OUTPUT_STRIP_TRAILING_WHITESPACE)

  # Pip doesn't know about torch; nothing to detect.
  if (NOT _detect_pip_show_result EQUAL 0)
    return ()
  endif ()

  # Split the string on newlines
  string(REPLACE "\n" ";" _detect_torch_show_lines "${_detect_pip_show_output}")

  # And find the "requires" line:
  list(FILTER _detect_torch_show_lines INCLUDE REGEX "^Requires")

  # Now filter that down to the NVIDIA modules:
  string(REGEX MATCHALL "nvidia-[-a-z]+-cu[0-9]+" _detect_nvidia_modules "${_detect_torch_show_lines}")

  # Now that we have a list of modules, we need to search the file lists
  # of these. There are at least 2 approaches.
  #
  #  1. We can interrogate 'pip show --files <module name>', parse the
  #     base from the "Location" line and prepend it to any matching
  #     lines under the "Files" header to get the full paths to any
  #     relevant files.
  #
  #  2. We can use 'importlib.metadata' to parse the metadata associated
  #     with each module. This has the advantage that we don't have to
  #     do as much manual parsing and string manipulation -- the data we
  #     need can be generated with a simple list comprehension.
  #
  # While neither approach is particularly difficult, I've opted for
  # number 2. I especially like that by joining the output string with
  # semicolons, CMake will natively interpret the list of paths as a
  # CMake list, further simplifying things.

  # Get the list of paths out of the metadata. Separate with semicolon
  # so CMake interprets the output as a list directly.
  set(_detect_get_paths_program
    "import importlib.metadata as md; import sys; print(\";\".join([str(f.locate()) for f in md.files(sys.argv[1])]))")

  foreach (lib IN LISTS _detect_LIBRARIES)
    string(REGEX MATCH "nvidia-${lib}-cu[0-9]+" _detect_nvidia_lib_module "${_detect_nvidia_modules}")

    # No matching NVIDIA wheel provides this library; skip it rather
    # than invoking the Python helper with an empty module name.
    if (NOT _detect_nvidia_lib_module)
      continue ()
    endif ()

    # Find paths. Strip trailing whitespace so the final path in the
    # semicolon-joined list does not carry print()'s newline.
    execute_process(
      COMMAND "${Python_EXECUTABLE}" -c "${_detect_get_paths_program}" "${_detect_nvidia_lib_module}"
      ERROR_VARIABLE _detect_get_paths_error
      OUTPUT_VARIABLE _detect_get_paths_output
      RESULT_VARIABLE _detect_get_paths_result
      OUTPUT_STRIP_TRAILING_WHITESPACE)

    foreach (path ${_detect_get_paths_output})
      if (path MATCHES ".*${lib}\\.h$")

        cmake_path(GET
          path
          PARENT_PATH
          _detect_parent_path)
        set(LBANNV2_DETECTED_${lib}_INCLUDE_PATH
          "${_detect_parent_path}"
          CACHE
          PATH
          "Include directory for ${lib}")

      elseif (path MATCHES ".*lib${lib}${CMAKE_SHARED_LIBRARY_SUFFIX}.*")

        set(LBANNV2_DETECTED_${lib}_LIBRARY
          "${path}"
          CACHE
          FILEPATH
          "Library for ${lib}")

      endif ()
    endforeach ()

    # Consider the thing found if both the include path and the
    # library are available.
    if (LBANNV2_DETECTED_${lib}_LIBRARY AND LBANNV2_DETECTED_${lib}_INCLUDE_PATH)
      set(LBANNV2_DETECTED_${lib} TRUE PARENT_SCOPE)
    endif ()
  endforeach ()
endfunction ()


================================================
FILE: cmake/LBANNv2DetermineMI300A.cmake
================================================
################################################################################
## Copyright 2014-2025 Lawrence Livermore National Security, LLC and other
## LBANN Project Developers. See the top-level LICENSE file for details.
##
## SPDX-License-Identifier: Apache-2.0
################################################################################
cmake_minimum_required(VERSION 3.24.0)

# Tries to determine whether the machine in question is MI300A. The
# check hinges on "rocm-smi" returning sane things. Sadly we must do
# this rather than just testing the arch flag because "gfx942" refers
# to both the MI300A and the MI300X.
#
# The return value is ternary:
#
#   - "WITH" means we have determined that we have MI300A
#   - "WITHOUT" means we have determined that we do NOT have MI300A
#   - "UNKNOWN" means that we cannot determine whether we have MI300A,
#     generally because "rocm-smi" produced no usable output.
#
# In cases where the node on which LBANNv2 is being built does not
# have GPUs, or if it happens to have *different* GPUs from the ones
# on the compute nodes, users are advised to provide this information
# directly, if possible.
#
# As this all comes down to "rocm-smi" output on the node on which
# LBANNv2 is configured, users are advised that there is high risk for
# incorrect or suboptimal information if not configuring on a compute
# node.
function (determine_mi300a_support OUTPUT_VARIABLE)

  # Call rocm-smi (should be in the PATH). If rsmi fails, then we say "unknown".
  execute_process(
    COMMAND rocm-smi --showproductname --json
    OUTPUT_VARIABLE _rsmi_info
    ERROR_VARIABLE _rsmi_error
    ERROR_QUIET
  )

  # Anything resembling an error on stderr means we cannot trust the
  # output at all -> "UNKNOWN".
  if (_rsmi_error AND _rsmi_error MATCHES ".*ERROR.*")
    set(${OUTPUT_VARIABLE} "UNKNOWN" PARENT_SCOPE)
    return ()
  endif ()

  # Pull the GFX architecture of the first card out of the JSON blob.
  string(JSON _gfx_version
    ERROR_VARIABLE _json_err
    GET "${_rsmi_info}" "card0" "GFX Version")

  # To get here, rsmi returned something valid, and this path just was not right.
  if (_json_err)
    message(DEBUG
      "JSON Error: ${_json_err}\n\nAssuming MI300A status is 'UNKNOWN'.")
    set(${OUTPUT_VARIABLE} "UNKNOWN" PARENT_SCOPE)
    return ()
  endif ()

  # "gfx942" is ambiguous (MI300A and MI300X), so disambiguate by
  # searching rocminfo's output for the product name.
  if (_gfx_version MATCHES ".*gfx942.*")
    execute_process(
      COMMAND rocminfo
      OUTPUT_VARIABLE _rocminfo_output
      ERROR_VARIABLE _rocminfo_error
      ERROR_QUIET
    )
    string(FIND "${_rocminfo_output}" "MI300A" _mi300a_exists)
    if (_mi300a_exists EQUAL -1)
      set(${OUTPUT_VARIABLE} "WITHOUT" PARENT_SCOPE)
    else ()
      set(${OUTPUT_VARIABLE} "WITH" PARENT_SCOPE)
    endif ()
  else ()
    # Not gfx942 at all -> definitely not an MI300A.
    set(${OUTPUT_VARIABLE} "WITHOUT" PARENT_SCOPE)
  endif ()
endfunction ()


================================================
FILE: cmake/lbannv2Config.cmake.in
================================================
################################################################################
## Copyright 2014-2025 Lawrence Livermore National Security, LLC and other
## LBANN Project Developers. See the top-level LICENSE file for details.
##
## SPDX-License-Identifier: Apache-2.0
################################################################################
# Establish LBANNv2_VERSION from the generated version file.
include("${CMAKE_CURRENT_LIST_DIR}/lbannv2ConfigVersion.cmake")
set(LBANNv2_VERSION ${PACKAGE_VERSION})

include(CMakeFindDependencyMacro)

# Minimum dependency versions, baked in at configure time of lbannv2
# itself.
set(lbannv2_MINIMUM_H2_VERSION @lbannv2_MINIMUM_H2_VERSION@)
set(lbannv2_MINIMUM_Torch_VERSION @lbannv2_MINIMUM_Torch_VERSION@)

# Downstream consumers need the same DiHydrogen components and Torch
# that lbannv2 was built against.
find_dependency(DiHydrogen
  ${lbannv2_MINIMUM_H2_VERSION}
  COMPONENTS Core Meta Patterns
)
find_dependency(Torch
  ${lbannv2_MINIMUM_Torch_VERSION}
)

@PACKAGE_INIT@

# Import the exported target(s), guarding against double inclusion.
if (NOT TARGET lbann::lbannv2)
  include("${CMAKE_CURRENT_LIST_DIR}/lbannv2Targets.cmake")
endif ()

check_required_components(lbannv2)
set(LBANNV2_LIBRARIES lbann::lbannv2)


================================================
FILE: cmake/lbannv2_config.h.in
================================================
////////////////////////////////////////////////////////////////////////////////
// Copyright 2014-2025 Lawrence Livermore National Security, LLC and other
// LBANN Project Developers. See the top-level LICENSE file for details.
//
// SPDX-License-Identifier: Apache-2.0
////////////////////////////////////////////////////////////////////////////////
#pragma once

// clang-format off

// Export/visibility macros (e.g., LBANNV2_EXPORT).
#include <lbannv2_export.h>

// Version information
#define LBANNV2_VERSION_MAJOR @PROJECT_VERSION_MAJOR@
#define LBANNV2_VERSION_MINOR @PROJECT_VERSION_MINOR@
#define LBANNV2_VERSION_PATCH @PROJECT_VERSION_PATCH@
#define LBANNV2_VERSION "@PROJECT_VERSION@"

// 1 when the build was configured with debug checks, else 0.
#cmakedefine01 LBANNV2_DEBUG_MODE

// GPU runtime flags (each is 0 or 1).
#cmakedefine01 LBANNV2_HAS_CUDA
#cmakedefine01 LBANNV2_HAS_ROCM
// Nonzero iff some GPU runtime is enabled.
#define LBANNV2_HAS_GPU (LBANNV2_HAS_CUDA + LBANNV2_HAS_ROCM)

// Ternary MI300A detection result (see LBANNv2DetermineMI300A.cmake).
// The "UNKNOWN" configuration defers the decision to runtime checks.
#cmakedefine01 LBANNV2_WITH_MI300A
#cmakedefine01 LBANNV2_WITHOUT_MI300A
#cmakedefine01 LBANNV2_UNKNOWN_MI300A

// 1 when Torch exposes the HIP caching allocator under c10::hip
// rather than the CUDA-named symbols.
#cmakedefine01 LBANNV2_USE_C10_HIP_NAMESPACE_AND_SYMBOLS

#ifndef SPDLOG_ACTIVE_LEVEL
// This defaults to "TRACE" so that all messages are compiled and
// available. Use the runtime environment variable to control which
// are seen.
#define SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_TRACE
#endif

// clang-format on


================================================
FILE: pyproject.toml
================================================
################################################################################
## Copyright 2014-2025 Lawrence Livermore National Security, LLC and other
## LBANN Project Developers. See the top-level LICENSE file for details.
##
## SPDX-License-Identifier: Apache-2.0
################################################################################
[build-system]
requires = [
  "scikit-build-core>=0.10",
  "pybind11"
]
build-backend = "scikit_build_core.build"

[project]
name = "lbannv2"
version = "0.0.1"
description = "LBANN's core integration with PyTorch"
authors = [
  { name = "Tal Ben Nun", email = "bennun2@llnl.gov" },
  { name = "Tom Benson", email = "benson31@llnl.gov" },
  { name = "Nikoli Dryden", email = "dryden1@llnl.gov" },
  { name = "Pier Fiedorowicz", email = "fiedorowicz1@llnl.gov" },
  { name = "Brian Van Essen", email = "vanessen1@llnl.gov" },
]
license = { file = "LICENSE" }
readme = "README.md"
requires-python = ">=3.10"
classifiers = [
  "Development Status :: 2 - Pre-Alpha",

  "License :: OSI Approved :: Apache Software License",

  "Programming Language :: C++",

  "Programming Language :: Python :: 3",
  # NOTE: requires-python is ">=3.10", so 3.9 is intentionally not listed.
  "Programming Language :: Python :: 3.10",
  "Programming Language :: Python :: 3.11",
  "Programming Language :: Python :: 3.12",
  "Programming Language :: Python :: 3.13",
  "Programming Language :: Python :: 3.14",

  "Topic :: Scientific/Engineering :: Artificial Intelligence",
  "Topic :: Software Development :: Libraries",
  "Topic :: Software Development :: Libraries :: Python Modules",
  "Topic :: Software Development :: Version Control :: Git",

  "Private :: Do Not Upload"
  ]
dependencies = [
  "pybind11",
  "torch>=2.9"
  ]

[project.optional-dependencies]
test = ["pytest"]

[tool.scikit-build]
minimum-version = "build-system.requires"
build-dir = "build"
cmake.version = ">=3.30.0"
ninja.version = ">=1.11"
ninja.make-fallback = false
wheel.expand-macos-universal-tags = true
wheel.install-dir = "lbannv2"

[tool.pytest.ini_options]
minversion = "9.0"
testpaths = [
    "test/py",
]

================================================
FILE: python/lbannv2/__init__.py
================================================
################################################################################
## Copyright 2014-2025 Lawrence Livermore National Security, LLC and other
## LBANN Project Developers. See the top-level LICENSE file for details.
##
## SPDX-License-Identifier: Apache-2.0
################################################################################
import sys
import torch

# The compiled extension is installed under lib/ or lib64/ depending
# on the platform's install-layout convention; try both locations.
try:
    from .lib._lbannv2 import *
except ModuleNotFoundError:
    from .lib64._lbannv2 import *

from ._automigrate import automigrate

# Setup state needed by the library
init_lbannv2()

def is_available():
    """Report whether an LBANNv2-capable GPU is usable.

    Any failure while querying (including a missing binding) is
    treated as "not available" rather than propagated to the caller.
    """
    try:
        result = is_lbannv2_gpu_available()
    except Exception:
        return False
    return bool(result)

class MigratableMemory:
    """Use LBANNv2's allocator for the given device.

    While the context is active, host allocations are serviced by the
    MI300A host allocator; on exit (even via an exception) the stock
    PyTorch host allocator is restored.
    """

    def __enter__(self):
        use_mi300a_host_allocator()
        # Return self so "with MigratableMemory() as mem:" binds the
        # manager instead of None (previously this returned None).
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        use_pytorch_host_allocator()
        # Never suppress exceptions raised inside the context.
        return False


def make_migratory_tensor(ctor, *args, **kwargs):
    """Call a tensor constructor while the LBANNv2 host allocator is
    active, so the resulting tensor's storage is migratable.

    ``ctor`` is invoked with the remaining positional and keyword
    arguments inside a ``MigratableMemory`` context.
    """
    context = MigratableMemory()
    with context:
        return ctor(*args, **kwargs)


================================================
FILE: python/lbannv2/_automigrate.py
================================================
import torch
from typing import Callable, Union

try:
    from .lib._lbannv2 import migrate
except ModuleNotFoundError:
    from .lib64._lbannv2 import migrate


def automigrate(f: Union[Callable, torch.fx.GraphModule]) -> torch.fx.GraphModule:
    """Check the graph for candidates for automatic pointer migration,
    replacing them with appropriate calls to 'migrate'. This function
    operates at the ATen IR (FX Graph) level, so it cannot perfectly
    determine all cases in which a migrate is possible. Symbolic
    tracing cannot, for instance, tell the device on which inputs or
    "member tensors" (e.g., of some nn layer) reside. We can make some
    inferences, though (e.g., all nodes downstream of a memory
    relocation call ("to", "cpu", etc) can be assumed to live on that
    device until the next such relocation call). Additionally, we
    cannot, in general, know the provenance of the underlying memory
    of a given tensor. At this time, we only support migration that
    comes from LBANNv2 allocators, as we are 100% sure of its
    allocation. We could likely support any memory that was allocated
    by or registered with the HIP runtime, though -- "future work".

    Args:
        f (Union[Callable, torch.fx.GraphModule]): Any callable
            amenable to symbolic_trace()-ing. If this is a
            torch.fx.GraphModule, it will be modified in-place and
            returned.

    Returns:
        A torch.fx.GraphModule representing the input Callable, with
            "data movement" nodes replaced with LBANNv2 pointer
            "migration", when appropriate. If the input was already a
            torch.fx.GraphModule, it is modified in-place and
            returned.

    """

    def safe_for_migrate(n: torch.fx.graph.Node) -> bool:
        """
        At the IR level, a tensor is a candidate for migration if
        it isn't used multiple places and if the underlying operation
        isn't trying to change more than just the device.
        """
        # Defensive: all candidates are method calls whose first arg
        # is the "self" tensor node; anything else is not migratable.
        if not n.args or not isinstance(n.args[0], torch.fx.graph.Node):
            return False
        # BUG FIX: this helper previously read the enclosing loop
        # variable 'node' instead of its own parameter 'n'. The two
        # alias at the one existing call site, but any other use would
        # have silently inspected the wrong node.
        input_ok = len(n.args[0].users) == 1
        args_ok = len(n.kwargs) == 1 and "device" in n.kwargs
        # If we're dealing with 'cuda' or 'cpu', then it's ok for the
        # kwargs to be empty (note that 'cuda' also supports 'device'
        # as a kwarg).
        if n.target != "to":
            args_ok = args_ok or len(n.kwargs) == 0
        return input_ok and args_ok

    def get_target_device(n: torch.fx.graph.Node) -> torch.device:
        # Prefer an explicit 'device' kwarg; otherwise the method name
        # itself ("cuda" or "cpu") names the device.
        return (
            torch.device(n.kwargs["device"])
            if "device" in n.kwargs
            else torch.device(str(n.target))
        )

    if isinstance(f, torch.fx.GraphModule):
        gm = f
    else:
        gm = torch.fx.symbolic_trace(f)

    # We can handle "to" or the device-specific methods ("cuda", e.g.).
    migrate_candidates = ["to", "cuda", "cpu"]
    for node in gm.graph.nodes:
        if node.target in migrate_candidates and safe_for_migrate(node):
            with gm.graph.inserting_before(node):
                # Insert the migrate call and rewire all consumers of
                # the old node onto it before erasing the original.
                new_node = gm.graph.call_function(
                    migrate,
                    args=(*node.args, get_target_device(node)),
                )
                node.replace_all_uses_with(new_node)

            gm.graph.erase_node(node)

    gm.recompile()
    return gm


================================================
FILE: src/lbannv2/CMakeLists.txt
================================================
################################################################################
## Copyright 2014-2025 Lawrence Livermore National Security, LLC and other
## LBANN Project Developers. See the top-level LICENSE file for details.
##
## SPDX-License-Identifier: Apache-2.0
################################################################################
# Top-level public headers; installed via the HEADERS file set.
target_sources(lbannv2
  PUBLIC
  FILE_SET HEADERS
  FILES
  types.hpp
)

# Component subdirectories each add their own sources to the target.
add_subdirectory(memory)
add_subdirectory(ops)
add_subdirectory(utils)

# Pybind/Torch registration -- only built when packaging the Python
# wheel via scikit-build.
if (SKBUILD)
  add_subdirectory(python)
endif ()


================================================
FILE: src/lbannv2/memory/CMakeLists.txt
================================================
################################################################################
## Copyright 2014-2025 Lawrence Livermore National Security, LLC and other
## LBANN Project Developers. See the top-level LICENSE file for details.
##
## SPDX-License-Identifier: Apache-2.0
################################################################################
# Core allocator/registry sources.
target_sources(lbannv2
  PUBLIC
  FILE_SET HEADERS
  FILES
  allocator.hpp
  # h2_allocator_wrappers.hpp
  registry.hpp
)
target_sources(lbannv2
  PRIVATE
  allocator.cpp
  registry.cpp
)

# The MI300A allocator is compiled in when MI300A support is certain
# ("WITH") or undecided at configure time ("UNKNOWN"); in the
# "UNKNOWN" case the decision falls to runtime checks.
if (LBANNV2_UNKNOWN_MI300A OR LBANNV2_WITH_MI300A)
  target_sources(lbannv2
    PUBLIC
    FILE_SET HEADERS
    FILES
    mi300a_allocator.hpp
  )
  target_sources(lbannv2
    PRIVATE
    mi300a_allocator.cpp
  )
endif ()


================================================
FILE: src/lbannv2/memory/allocator.cpp
================================================
////////////////////////////////////////////////////////////////////////////////
// Copyright 2014-2025 Lawrence Livermore National Security, LLC and other
// LBANN Project Developers. See the top-level LICENSE file for details.
//
// SPDX-License-Identifier: Apache-2.0
////////////////////////////////////////////////////////////////////////////////
#include "lbannv2/memory/allocator.hpp"

#include "lbannv2/memory/registry.hpp"
#include "lbannv2/utils/errors.hpp"
#include "lbannv2/utils/logging.hpp"

#include <c10/core/CPUAllocator.h>

#if LBANNV2_HAS_CUDA
#include <ATen/cuda/CUDAContextLight.h>
#elif LBANNV2_HAS_ROCM
#include <ATen/hip/HIPContextLight.h>
#endif

#if LBANNV2_WITH_MI300A || LBANNV2_UNKNOWN_MI300A
#include "lbannv2/memory/mi300a_allocator.hpp"
#endif

namespace lbannv2
{

// c10::Allocator entry point: obtain raw memory from the derived
// class, record it in the global registry, and hand back a DataPtr
// (data and context are the same pointer here).
c10::DataPtr Allocator::allocate(size_t n)
{
  void* const ptr = this->raw_alloc(n);

  LBANNV2_TRACE("Allocator::allocate(n={}, ptr={})", n, ptr);

  // Track the range so is_managed_ptr() and friends can find it later.
  pointer_registry().add(ptr, n, this);

  return c10::DataPtr {ptr, ptr, this->raw_deleter(), this->get_device()};
}

}  // namespace lbannv2

// A pointer is "managed" iff it appears in the global pointer
// registry, i.e., it was produced by an lbannv2::Allocator.
bool lbannv2::is_managed_ptr(void const* const ptr) noexcept
{
  return pointer_registry().known(ptr);
}

namespace
{

// The CPU allocator PyTorch had installed before
// use_mi300a_cpu_allocator() swapped in the MI300A allocator; used by
// use_torch_cpu_allocator() to restore it. nullptr until first swap.
c10::Allocator* pt_orig_cpu_alloc_ = nullptr;

}  // namespace

// Swap PyTorch's CPU allocator for the MI300A host allocator, saving
// the original so use_torch_cpu_allocator() can restore it. Warns and
// does nothing when no MI300A allocator is available.
void lbannv2::use_mi300a_cpu_allocator()
{
#if LBANNV2_WITH_MI300A || LBANNV2_UNKNOWN_MI300A
#if LBANNV2_UNKNOWN_MI300A
  // Compile-time detection was inconclusive; decide at runtime.
  if (gpu::is_integrated())
#endif
  {
    // Save the stock allocator exactly once, so repeated calls cannot
    // clobber the saved pointer with our own allocator.
    if (!pt_orig_cpu_alloc_)
      pt_orig_cpu_alloc_ = c10::GetCPUAllocator();
    c10::SetCPUAllocator(&MI300Allocator::instance());
    return;
  }
#endif
  LBANNV2_WARN("No MI300A allocator available");
}

// Restore the CPU allocator that was active before the first call to
// use_mi300a_cpu_allocator(); a no-op if no swap ever happened.
void lbannv2::use_torch_cpu_allocator()
{
  if (pt_orig_cpu_alloc_)
    c10::SetCPUAllocator(pt_orig_cpu_alloc_);
}


================================================
FILE: src/lbannv2/memory/allocator.hpp
================================================
////////////////////////////////////////////////////////////////////////////////
// Copyright 2014-2025 Lawrence Livermore National Security, LLC and other
// LBANN Project Developers. See the top-level LICENSE file for details.
//
// SPDX-License-Identifier: Apache-2.0
////////////////////////////////////////////////////////////////////////////////
#pragma once
#include <lbannv2_config.h>

#include <c10/core/Allocator.h>

namespace lbannv2
{

/** @class Allocator
 *  @brief A simplistic interface for LBANN allocators.
 */
class LBANNV2_EXPORT Allocator : public c10::Allocator
{
public:
  /** @brief Allocate @p nbytes bytes of raw memory. */
  virtual void* raw_alloc(size_t nbytes) = 0;
  /** @brief Release memory previously returned by raw_alloc(). */
  virtual void raw_dealloc(void* ptr) = 0;
  /** @brief The device with which this allocator's memory is
   *         associated.
   */
  virtual c10::Device get_device() const noexcept = 0;

  /** @brief c10::Allocator hook: allocates via raw_alloc(), records
   *         the range in the global pointer registry, and wraps the
   *         result in a DataPtr. Implemented in allocator.cpp.
   */
  c10::DataPtr allocate(size_t n) final;
};  // class Allocator

/** @brief True iff @p ptr was allocated by an LBANNv2 allocator
 *         (i.e., it is present in the pointer registry).
 */
LBANNV2_EXPORT bool is_managed_ptr(void const* ptr) noexcept;

/** @brief Replace PyTorch's CPU allocator with the MI300A host
 *         allocator, when one is available.
 */
LBANNV2_EXPORT void use_mi300a_cpu_allocator();
/** @brief Restore the CPU allocator active before the swap. */
LBANNV2_EXPORT void use_torch_cpu_allocator();

}  // namespace lbannv2


================================================
FILE: src/lbannv2/memory/h2_allocator_wrappers.cpp
================================================
////////////////////////////////////////////////////////////////////////////////
// Copyright 2014-2025 Lawrence Livermore National Security, LLC and other
// LBANN Project Developers. See the top-level LICENSE file for details.
//
// SPDX-License-Identifier: Apache-2.0
////////////////////////////////////////////////////////////////////////////////
#include <lbannv2/memory/h2_allocator_wrappers.hpp>

namespace lbannv2
{

// NOTE(review): H2AllocatorWrapper<D>::instance() is also defined
// inline in h2_allocator_wrappers.hpp; compiling both together would
// be a redefinition. This translation unit is not listed in
// src/lbannv2/memory/CMakeLists.txt (and the header entry there is
// commented out), so it appears to be excluded from the build --
// confirm before re-enabling.
template <h2::Device D>
H2AllocatorWrapper<D>& H2AllocatorWrapper<D>::instance()
{
  static H2AllocatorWrapper<D> allocator;
  return allocator;
}

}  // namespace lbannv2


================================================
FILE: src/lbannv2/memory/h2_allocator_wrappers.hpp
================================================
////////////////////////////////////////////////////////////////////////////////
// Copyright 2014-2025 Lawrence Livermore National Security, LLC and other
// LBANN Project Developers. See the top-level LICENSE file for details.
//
// SPDX-License-Identifier: Apache-2.0
////////////////////////////////////////////////////////////////////////////////
#pragma once

#include <lbannv2_config.h>

#include <lbannv2/memory/allocator.hpp>
#include <lbannv2/utils/logging.hpp>

#include <h2/core/allocator.hpp>

#include <c10/core/Allocator.h>

namespace lbannv2
{

/** @class H2AllocatorWrapper
 *  @brief Adapts DiHydrogen's per-device allocator to the
 *         lbannv2::Allocator (and thus c10::Allocator) interface.
 */
template <h2::Device D>
class H2AllocatorWrapper : public Allocator
{
  using AllocatorType = h2::internal::Allocator<std::byte, D>;

public:
  /** @name Virtual function overrides */
  ///@{

  /** @brief Copy @p n bytes: memcpy on CPU, H2's mem_copy on GPU. */
  void copy_data(void* dst, void const* src, size_t n) const final
  {
    if constexpr (D == h2::Device::CPU)
    {
      LBANNV2_TRACE("H2AllocatorWrapper<CPU>::copy_data(dst={}, src={}, bytes={})",
                    dst, src, n);
      std::memcpy(dst, src, n);
    }
#if LBANNV2_HAS_GPU
    if constexpr (D == h2::Device::GPU)
    {
      LBANNV2_TRACE("H2AllocatorWrapper<GPU>::copy_data(dst={}, src={}, bytes={})",
                    dst, src, n);
      h2::gpu::mem_copy(dst, src, n);
    }
#endif
  }

  // FIX: these two overrides were previously named "raw_allocate" /
  // "raw_deallocate", which do not exist in the lbannv2::Allocator
  // base class -- "final" on a non-virtual function is ill-formed,
  // and the base's pure virtual raw_alloc/raw_dealloc were left
  // unimplemented, leaving the class abstract (so instance() could
  // never instantiate it). Renamed to implement the real interface.
  /** @brief Allocate @p n bytes via H2's allocator on stream D. */
  void* raw_alloc(size_t n) final
  {
    return reinterpret_cast<void*>(
      AllocatorType::allocate(n, h2::ComputeStream {D}));
  }

  /** @brief Return memory obtained from raw_alloc() to H2. */
  void raw_dealloc(void* ptr) final
  {
    AllocatorType::deallocate(reinterpret_cast<std::byte*>(ptr),
                              h2::ComputeStream {D});
  }

  /** @brief CPU device, or the current GPU for the GPU instance. */
  c10::Device get_device() const noexcept final
  {
#if LBANNV2_HAS_GPU
    if constexpr (D == h2::Device::GPU)
      return c10::Device {
        c10::kCUDA, static_cast<c10::DeviceIndex>(h2::gpu::current_gpu())};
#endif
    return c10::Device {c10::kCPU};
  }

  ///@}

  // Singleton (Meyers-style; construction is lazy and thread-safe).
  static H2AllocatorWrapper& instance()
  {
    static H2AllocatorWrapper<D> allocator;
    return allocator;
  }

private:
  H2AllocatorWrapper() = default;
  ~H2AllocatorWrapper() = default;
  H2AllocatorWrapper(H2AllocatorWrapper const&) = delete;
  H2AllocatorWrapper(H2AllocatorWrapper&&) = delete;
  H2AllocatorWrapper& operator=(H2AllocatorWrapper const&) = delete;
  H2AllocatorWrapper& operator=(H2AllocatorWrapper&&) = delete;
};

using H2CPUAllocatorWrapper = H2AllocatorWrapper<h2::Device::CPU>;
#if LBANNV2_HAS_GPU
using H2GPUAllocatorWrapper = H2AllocatorWrapper<h2::Device::GPU>;
#endif

}  // namespace lbannv2


================================================
FILE: src/lbannv2/memory/memory_utils.hpp
================================================
////////////////////////////////////////////////////////////////////////////////
// Copyright 2014-2025 Lawrence Livermore National Security, LLC and other
// LBANN Project Developers. See the top-level LICENSE file for details.
//
// SPDX-License-Identifier: Apache-2.0
////////////////////////////////////////////////////////////////////////////////
#pragma once

#include <lbannv2_config.h>

#include <lbannv2/memory/allocator.hpp>

namespace lbannv2
{

/** @class AllocatorWrapper
 *  @brief Wrap an allocator with a different device.
 *
 *  This wraps a c10::Allocator instance. Allocations from that
 *  allocator are intercepted and the DataPtr is updated to have the
 *  specified Device.
 *
 *  The primary intention is to wrap LBANN allocators as "native
 *  device" allocators, though it could be used the other way, too.
 *  However, there is no pointer registration in this class -- LBANNv2
 *  allocators handle this internally, so including that here would
 *  "double register" pointers. This could be cleaned up a bit down
 *  the road.
 */
class AllocatorWrapper : public c10::Allocator
{
public:
  /** @brief Construct a wrapper around an existing allocator.
   *
   *  @param[in] alloc The allocator that actually services requests.
   *  @param[in] device The device stamped onto every DataPtr this
   *                    wrapper hands out.
   */
  AllocatorWrapper(c10::Allocator& alloc, c10::Device device)
    : m_alloc {alloc}, m_device {std::move(device)}
  {}
  ~AllocatorWrapper() = default;

  /** @brief Allocate via the wrapped allocator, then relabel the
   *         resulting DataPtr with the configured device.
   */
  c10::DataPtr allocate(size_t n) final
  {
    auto data = m_alloc.allocate(n);
    data.unsafe_set_device(m_device);
    // The deleter is deliberately left as the wrapped allocator's: a
    // forwarding raw_deleter here would add a hop without changing
    // behavior. Revisit if this class ever registers pointers itself.
    return data;
  }

  /** @brief Expose the wrapped allocator's deleter unchanged. */
  c10::DeleterFnPtr raw_deleter() const noexcept final
  {
    return m_alloc.raw_deleter();
  }

  /** @brief Defer byte copies to the wrapped allocator. */
  void copy_data(void* dst, void const* src, size_t n) const final
  {
    m_alloc.copy_data(dst, src, n);
  }

private:
  c10::Allocator& m_alloc;  // The allocator being wrapped.
  c10::Device m_device;     // Device reported on produced DataPtrs.
};  // class AllocatorWrapper

}  // namespace lbannv2


================================================
FILE: src/lbannv2/memory/mi300a_allocator.cpp
================================================
////////////////////////////////////////////////////////////////////////////////
// Copyright 2014-2025 Lawrence Livermore National Security, LLC and other
// LBANN Project Developers. See the top-level LICENSE file for details.
//
// SPDX-License-Identifier: Apache-2.0
////////////////////////////////////////////////////////////////////////////////
#include "lbannv2_config.h"

#include "lbannv2/memory/mi300a_allocator.hpp"

#include "lbannv2/memory/registry.hpp"
#include "lbannv2/utils/errors.hpp"
#include "lbannv2/utils/gpu_utils.hpp"
#include "lbannv2/utils/logging.hpp"

#if LBANNV2_HAS_CUDA
#include <ATen/cuda/CUDAContextLight.h>
#include <c10/cuda/CUDAStream.h>
#elif LBANNV2_HAS_ROCM
#include <ATen/hip/HIPContextLight.h>
#include <c10/hip/HIPStream.h>
#endif

#include <c10/core/CachingDeviceAllocator.h>

namespace
{
// The env var is "on" for any nonempty value whose first character is
// not '0'.
bool get_use_nonblocking_stream_env_var()
{
  char const* const value =
    std::getenv("LBANNV2_NONBLOCKING_HOST_ALLOC_STREAM");
  if (!value)
    return false;
  return value[0] != '\0' && value[0] != '0';
}

// Process-wide cached view of the env var (read once via the static).
// Note the DEBUG log fires on every call, not just the first.
bool use_nonblocking_stream()
{
  static bool const nonblock = get_use_nonblocking_stream_env_var();
  LBANNV2_DEBUG("Using nonblocking MI300A allocation stream? {}", nonblock);
  return nonblock;
}

// Owns one GPU stream wrapped as a Torch stream: created at
// construction (blocking or nonblocking per the env var, on the
// current device) and destroyed with the object.
struct StreamRAII
{
  ::lbannv2::TorchGPUStream_t stream;

  StreamRAII()
    : stream {lbannv2::c10_gpu::getStreamFromExternal(
        use_nonblocking_stream() ? lbannv2::gpu::make_nonblocking_stream()
                                 : lbannv2::gpu::make_stream(),
        lbannv2::gpu::current_device())}
  {}
  ~StreamRAII()
  {
    // Best-effort teardown: these objects live in a function-local
    // static vector, so they are destroyed during static destruction,
    // where a throwing destructor would terminate the program.
    try
    {
      lbannv2::gpu::destroy_stream(stream.stream());
    }
    catch (...)
    {}
  }
};  // struct StreamRAII

// Internal stream for managing "host" allocations through CUB
// One allocation stream per device, created lazily (and thread-safely
// via the magic static) on first call. Asserts idx is a valid device
// index; note a negative idx (e.g., -1 from get_device_idx) trips the
// assertion.
::lbannv2::TorchGPUStream_t host_allocation_stream(c10::DeviceIndex const idx)
{
  static std::vector<StreamRAII> stream_raii(lbannv2::gpu::num_devices());
  LBANNV2_ASSERT_ALWAYS(idx >= 0 && idx < lbannv2::gpu::num_devices());
  return stream_raii[idx].stream;
}

// Normalize a device: a CUDA device with no explicit index resolves
// to the current device; everything else passes through untouched.
c10::Device resolve_device(c10::Device const& d)
{
  bool const needs_index = d.is_cuda() && !d.has_index();
  return needs_index
           ? c10::Device {c10::kCUDA, lbannv2::gpu::current_device()}
           : d;
}

// Select the caching device allocator namespace for this Torch build
// (HIP-named vs. CUDA-named symbols).
#if LBANNV2_USE_C10_HIP_NAMESPACE_AND_SYMBOLS
namespace DeviceAlloc_ns = c10::hip::HIPCachingAllocator;
#else
namespace DeviceAlloc_ns = c10::cuda::CUDACachingAllocator;
#endif

// Deregister a freed block from the pointer registry. Frees of
// pointers we never registered (allocated by Torch directly rather
// than through lbannv2) throw UnknownAddress and are ignored.
void lbannv2_report_free(DeviceAlloc_ns::TraceEntry const& entry)
{
  try
  {
    void* const ptr = reinterpret_cast<void*>(entry.addr_);
    lbannv2::pointer_registry().remove(ptr);
    LBANNV2_TRACE("Deallocate (ptr={})", (void const*) ptr);
  }
  catch (lbannv2::UnknownAddress const&)
  {
    // ignore -- ptr allocated in Torch
  }
}

// Allocator trace hook: only completed frees matter for bookkeeping.
void lbannv2_trace_alloc(DeviceAlloc_ns::TraceEntry const& entry)
{
  if (entry.action_ == DeviceAlloc_ns::TraceEntry::FREE_COMPLETED)
    lbannv2_report_free(entry);
}
}  // namespace

namespace lbannv2
{

// Wires this "host" allocator up to Torch's caching *device*
// allocator; get_device() nevertheless reports kCPU, which relies on
// MI300A-style unified memory being directly host-addressable.
MI300Allocator::MI300Allocator()
{
#if LBANNV2_WITHOUT_MI300A || LBANNV2_UNKNOWN_MI300A
#if LBANNV2_UNKNOWN_MI300A
  // Compile-time detection was inconclusive; check at runtime.
  if (!lbannv2::gpu::is_integrated())
#endif
    throw std::runtime_error("MI300Allocator is only supported on MI300A");
#endif

  auto* const dev_alloc =
    dynamic_cast<DeviceAlloc_t*>(at::cuda::getCUDADeviceAllocator());
  LBANNV2_ASSERT_ALWAYS(dev_alloc);
  if (!dev_alloc->initialized())
    dev_alloc->init(gpu::num_devices());

  // Trace memory stuff -- the hook removes freed blocks from the
  // pointer registry (see lbannv2_trace_alloc above).
  dev_alloc->attachAllocatorTraceTracker(lbannv2_trace_alloc);

  alloc_ = dev_alloc;
}

// A plain host memcpy; memory served by this allocator is directly
// CPU-addressable on MI300A.
void MI300Allocator::copy_data(void* const dst,
                               void const* const src,
                               size_t const bytes) const
{
  LBANNV2_TRACE(
    "MI300Allocator::copy_data(dst={}, src={}, bytes={})", dst, src, bytes);
  std::memcpy(dst, src, bytes);
}

// Allocate nbytes through the caching device allocator on this
// allocator's private per-device stream, then synchronize that stream
// so the memory is immediately usable by the host.
void* MI300Allocator::raw_alloc(size_t const nbytes)
{
  // Resolve the device and its allocation stream once; the original
  // looked both up on every use.
  auto const device = lbannv2::gpu::current_device();
  auto const stream = host_allocation_stream(device);

  auto* const ptr = alloc_->raw_alloc_with_stream(nbytes, stream);

  // (Trace label fixed: this method is raw_alloc, the old message
  // said "raw_allocate".)
  LBANNV2_TRACE(
    "MI300Allocator::raw_alloc(nbytes={}): ptr={}, current_device={}",
    nbytes,
    ptr,
    device);

  lbannv2::gpu::sync(stream);

  return ptr;
}

// Return the block to the caching device allocator. Deregistration
// from the pointer registry happens via the allocator trace hook, not
// here.
void MI300Allocator::raw_dealloc(void* ptr)
{
  // (Trace label fixed: this method is raw_dealloc, the old message
  // said "raw_deallocate".)
  LBANNV2_TRACE("MI300Allocator::raw_dealloc(ptr={})", ptr);
  alloc_->raw_delete(ptr);
}

// Every allocation from this allocator is presented to Torch as CPU
// memory, regardless of which GPU's allocator served it.
c10::Device MI300Allocator::get_device() const noexcept
{
  return c10::Device {c10::kCPU};
}

// Forward the deleter of the underlying caching device allocator.
c10::DeleterFnPtr MI300Allocator::raw_deleter() const
{
  return alloc_->raw_deleter();
}

// Meyers-singleton accessor. Note the constructor can throw on
// non-MI300A hardware, so first access may fail.
MI300Allocator& MI300Allocator::instance()
{
  static MI300Allocator alloc;
  return alloc;
}

}  // namespace lbannv2

// Ask the HIP runtime for the device ordinal owning this allocation.
// Returns -1 for anything the runtime rejects (non-GPU memory,
// nullptr, ...), logging the HIP error at debug level.
c10::DeviceIndex lbannv2::get_device_idx(void const* const ptr) noexcept
{
  int device_idx;
  auto const hip_status = hipPointerGetAttribute(
    &device_idx, HIP_POINTER_ATTRIBUTE_DEVICE_ORDINAL, const_cast<void*>(ptr));
  if (hip_status == hipSuccess)
  {
    return static_cast<c10::DeviceIndex>(device_idx);
  }
  else
  {
    LBANNV2_DEBUG("lbannv2::get_device_idx(ptr={}) failed. Error: {}",
                  ptr,
                  hipGetErrorString(hip_status));
    return -1;
  }
}

// Let's aim for a fully robust implementation here. We must consider:
//   1. Migrating from D(:m) -> D(:m) is a no-op.
//   2. Migrating from D:m -> D:n is a deep copy
void lbannv2::migrate_ptr(c10::DataPtr& ptr,
                          c10::Device to_device,
                          c10::Stream with_stream)
{
  // "cuda" with no index means "current device" (see resolve_device).
  auto const real_tgt_device = resolve_device(to_device);

  // If no migration actually happens, just short-circuit...
  if (ptr.device() == real_tgt_device)
    return;

#if LBANNV2_WITHOUT_MI300A || LBANNV2_UNKNOWN_MI300A
#if LBANNV2_UNKNOWN_MI300A
  // Compile-time detection was inconclusive; check at runtime.
  if (!lbannv2::gpu::is_integrated())
#endif
  {
    throw std::runtime_error("migrate_ptr is only supported on MI300A");
  }
#endif

  // Check that the migration is valid. The HIP runtime reports where
  // the allocation actually lives (-1 == not GPU memory).
  auto const ptr_dev_idx = get_device_idx(ptr.get_context());
  c10::Device const real_src_device = ptr_dev_idx == -1
                                        ? c10::Device {c10::kCPU}
                                        : c10::Device {c10::kCUDA, ptr_dev_idx};
  // Allowed moves: anything -> CPU, or a same-device relabel.
  LBANNV2_ASSERT(real_tgt_device.is_cpu() || real_src_device == real_tgt_device,
                 std::runtime_error,
                 "lbannv2::migrate_ptr: invalid src/tgt device combo");

  // Update the stream: host-bound data is tied to the internal
  // allocation stream; device-bound data uses the caller's stream.
  // NOTE(review): if ptr_dev_idx == -1 on the CPU-target path,
  // host_allocation_stream will assert -- presumably the source is
  // always GPU-resident when relabeling to CPU; confirm.
  auto const new_stream = real_tgt_device.is_cpu()
                            ? host_allocation_stream(ptr_dev_idx)
                            : TorchGPUStream_t(with_stream);

  // UGH. Oh well. (recordStream keeps the caching allocator from
  // reusing this block until work queued on new_stream completes.)
  MI300Allocator::instance().alloc_->recordStream(ptr, new_stream);

  // Finally, update the DataPtr itself
  ptr.unsafe_set_device(real_tgt_device);
}


================================================
FILE: src/lbannv2/memory/mi300a_allocator.hpp
================================================
////////////////////////////////////////////////////////////////////////////////
// Copyright 2014-2025 Lawrence Livermore National Security, LLC and other
// LBANN Project Developers. See the top-level LICENSE file for details.
//
// SPDX-License-Identifier: Apache-2.0
////////////////////////////////////////////////////////////////////////////////
#pragma once

#include <lbannv2/memory/allocator.hpp>
#include <lbannv2/utils/gpu_utils.hpp>

#include <c10/core/Stream.h>

#if LBANNV2_HAS_CUDA
#include <c10/cuda/CUDACachingAllocator.h>
#elif LBANNV2_HAS_ROCM
#include <c10/hip/HIPCachingAllocator.h>
#endif

namespace lbannv2
{

/** @brief Relabel @p ptr as residing on @p to_device.
 *
 *  Call when moving a pointer to a different device. On unified
 *  memory no bytes move; only the device tag and stream bookkeeping
 *  change (see mi300a_allocator.cpp for the validity rules).
 */
void migrate_ptr(c10::DataPtr& ptr,
                 c10::Device to_device,
                 c10::Stream with_stream);

/** @class MI300Allocator
 *  @brief Host allocator backed by Torch's caching device allocator,
 *         for MI300A unified-memory systems. Singleton.
 */
class MI300Allocator final : public Allocator
{
public:
  // Plain host memcpy; see the .cpp.
  void copy_data(void* dst, void const* src, size_t bytes) const final;

  // Allocate on a private per-device stream, then sync that stream.
  void* raw_alloc(size_t nbytes) final;

  // Return the block to the caching device allocator.
  void raw_dealloc(void* ptr) final;

  // Deleter of the underlying caching device allocator.
  c10::DeleterFnPtr raw_deleter() const final;

  // Always reports kCPU.
  c10::Device get_device() const noexcept final;

  // Lazily constructed singleton; construction throws off-MI300A.
  static MI300Allocator& instance();

private:
  MI300Allocator();
  ~MI300Allocator() = default;
  MI300Allocator(MI300Allocator const&) = delete;
  MI300Allocator(MI300Allocator&&) = delete;
  MI300Allocator& operator=(MI300Allocator const&) = delete;
  MI300Allocator& operator=(MI300Allocator&&) = delete;

#if LBANNV2_USE_C10_HIP_NAMESPACE_AND_SYMBOLS
  using DeviceAlloc_t = ::c10::hip::HIPCachingAllocator::HIPAllocator;
#else
  using DeviceAlloc_t = ::c10::cuda::CUDACachingAllocator::CUDAAllocator;
#endif
  // The Torch caching device allocator that services all requests.
  DeviceAlloc_t* alloc_;

  // migrate_ptr needs direct access to alloc_ for recordStream().
  friend void migrate_ptr(c10::DataPtr&, c10::Device, c10::Stream);

};

/** @brief Get the device with which the allocation is associated.
 *
 * @note From what I can tell, this is any valid pointer -- it doesn't
 *       have to be the "context" pointer, for instance.
 *
 * @param[in] ptr A pointer to valid memory.
 *
 * @returns The (GPU) device index with which the allocation is
 *          associated. -1 if not GPU memory or nullptr.
 */
c10::DeviceIndex get_device_idx(void const* const ptr) noexcept;

}  // namespace lbannv2


================================================
FILE: src/lbannv2/memory/registry.cpp
================================================
////////////////////////////////////////////////////////////////////////////////
// Copyright 2014-2025 Lawrence Livermore National Security, LLC and other
// LBANN Project Developers. See the top-level LICENSE file for details.
//
// SPDX-License-Identifier: Apache-2.0
////////////////////////////////////////////////////////////////////////////////
#include "registry.hpp"

#include "lbannv2/utils/errors.hpp"
#include "lbannv2/utils/logging.hpp"

namespace
{

// Syntactic sugar. Iterators kinda suck for readability.
// The [begin, end) address pair that keys a registry entry.
auto const& get_ptr_range(std::input_iterator auto const& map_iter) noexcept
{
  return map_iter->first;
}

// The allocator recorded for a registry entry.
auto const& get_allocator_ptr(std::input_iterator auto const& map_iter) noexcept
{
  return map_iter->second;
}

// Number of bytes spanned by a [begin, end) pointer pair.
std::size_t range_bytes(std::pair<void*, void*> const& r) noexcept
{
  auto const* const begin = static_cast<std::byte const*>(r.first);
  auto const* const end = static_cast<std::byte const*>(r.second);
  return static_cast<std::size_t>(end - begin);
}

}  // namespace

namespace lbannv2
{

// Record the range [ptr, ptr + size) together with the allocator that
// produced it. nullptr is silently ignored. Throws std::runtime_error
// when emplace reports the key as a duplicate (overlapping range).
// Thread-safe via the registry mutex.
void PointerRegistry::add(void* const ptr,
                          size_t const size,
                          c10::Allocator* const allocator)
{
  if (!ptr)
    return;

  std::lock_guard<std::mutex> lock(m_registry_mtx);
  auto const [it, added] = m_registry.emplace(
    KeyT {ptr, static_cast<std::byte*>(ptr) + size}, allocator);
  LBANNV2_ASSERT(
    added, std::runtime_error, "Address range overlaps existing range");

  LBANNV2_TRACE("Registered pointer range start={}, size={}, allocator={}",
                ptr,
                size,
                (void*) allocator);
}

void PointerRegistry::remove(void* const ptr)
{
  if (!ptr)
    return;

  std::lock_guard<std::mutex> lock(m_registry_mtx);
  auto const it = m_registry.find(ptr);
  if (it == m_registry.cend())
    throw UnknownAddress {};
  else if (get_ptr_range(it).first != ptr)
    throw std::runtime_error("Cannot remove ptr; not beginning of range.");

  {
    [[maybe_unused]] auto const& [ptr_range, alloc_ptr] = *it;
    LBANNV2_TRACE("Deregistered pointer range start={}, size={}, allocator={}",
                  ptr_range.first,
                  range_bytes(ptr_range),
                  (void*) alloc_ptr);
  }

  m_registry.erase(it);
}

bool PointerRegistry::known(void const* const ptr) const noexcept
{
  // Heterogeneous lookup: the transparent comparator treats ptr as a
  // zero-size range, matching any registered range containing it.
  std::lock_guard<std::mutex> lock(m_registry_mtx);
  return m_registry.find(ptr) != m_registry.cend();
}

c10::Allocator* PointerRegistry::get_allocator(void const* const ptr) const
{
  std::lock_guard<std::mutex> lock(m_registry_mtx);
  if (auto const it = m_registry.find(ptr); it != m_registry.cend())
    return get_allocator_ptr(it);
  throw UnknownAddress {};
}

void PointerRegistry::unsafe_reset_allocator(void const* const ptr,
                                             c10::Allocator* const new_alloc)
{
  std::lock_guard<std::mutex> lock(m_registry_mtx);
  if (auto const it = m_registry.find(ptr); it != m_registry.cend())
  {
    it->second = new_alloc;
    return;
  }
  throw UnknownAddress {};
}

void* PointerRegistry::get_context(void const* const ptr) const
{
  std::lock_guard<std::mutex> lock(m_registry_mtx);
  if (auto const it = m_registry.find(ptr); it != m_registry.cend())
    return get_ptr_range(it).first;  // Context is the range's start address.
  throw UnknownAddress {};
}

std::size_t PointerRegistry::bytes_registered() const noexcept
{
  std::lock_guard<std::mutex> lock(m_registry_mtx);
  // Sum the extents of every registered range.
  std::size_t total = 0UL;
  for (auto const& entry : m_registry)
    total += range_bytes(entry.first);
  return total;
}

std::size_t
PointerRegistry::bytes_registered(void const* const ptr) const noexcept
{
  std::lock_guard<std::mutex> lock(m_registry_mtx);
  auto const it = m_registry.find(ptr);
  // Unregistered pointers report zero bytes.
  return (it == m_registry.cend()) ? 0 : range_bytes(it->first);
}

}  // namespace lbannv2

auto lbannv2::pointer_registry() -> PointerRegistry&
{
  // Function-local static: constructed on first use; initialization
  // is thread-safe per the C++11 rules for local statics.
  static PointerRegistry s_instance;
  return s_instance;
}


================================================
FILE: src/lbannv2/memory/registry.hpp
================================================
////////////////////////////////////////////////////////////////////////////////
// Copyright 2014-2025 Lawrence Livermore National Security, LLC and other
// LBANN Project Developers. See the top-level LICENSE file for details.
//
// SPDX-License-Identifier: Apache-2.0
////////////////////////////////////////////////////////////////////////////////
#pragma once

#include <lbannv2_config.h>

#include <lbannv2/memory/allocator.hpp>

#include <map>
#include <mutex>
#include <stdexcept>

#include <c10/core/DeviceType.h>

namespace lbannv2
{

/** @brief Exception thrown when a queried pointer is not part of any
 *         registered allocation.
 */
struct LBANNV2_EXPORT UnknownAddress : std::runtime_error
{
  UnknownAddress() : std::runtime_error("Unknown address") {}
};

// We should consider the issue of registering nullptr or equivalent
// zero-size allocations. Note that if ISO C++ is the only source of
// memory, this should be an error. But I'm not sure how all of the
// allocators we encounter might handle a zero-size allocation (e.g.,
// cudaMalloc and friends). ISO C++, however, requires zero-size
// allocations to still return unique, non-null pointers (section
// 6.7.5.5.2, paragraph 2).

/** @class PointerRegistry
 *  @brief Tracks known memory regions
 *
 *  Each registered allocation is stored as a half-open address range
 *  together with the allocator that owns it. Member functions lock an
 *  internal mutex, so the registry may be used from multiple threads.
 */
class LBANNV2_EXPORT PointerRegistry
{
public:
  /** @brief Register an allocation.
   *
   *  Registering a null pointer is a no-op.
   *
   *  @param[in] ptr The beginning of the allocated range.
   *  @param[in] size The size in bytes of the allocated range.
   *  @param[in] allocator The allocator responsible for deleting the range.
   *
   *  @throws std::runtime_error if [ptr, ptr + size) overlaps a
   *          previously registered range.
   */
  void add(void* ptr, size_t size, c10::Allocator* allocator);

  /** @brief Deregister an allocation.
   *
   *  The pointer passed must match a pointer registered with add().
   *  Removing a null pointer is a no-op.
   *
   *  @param[in] ptr The (context) pointer to deregister.
   *
   *  @throws UnknownAddress if the pointer is not part of a
   *          registered allocation.
   *  @throws std::runtime_error if the pointer falls inside a
   *          registered range but is not that range's first address.
   */
  void remove(void* ptr);

  /** @brief Query whether this address is part of a registered
   *         allocation.
   *
   *  Returns @c true for any address that is included in a registered
   *  allocation, that is, in the range [ptr, ptr + size) for any
   *  (ptr, size) passed to add().
   *
   *  @param[in] ptr The pointer in question.
   */
  bool known(void const* ptr) const noexcept;

  /** @brief Get the allocator used to allocate this pointer.
   *
   *  @param[in] ptr The pointer whose allocator is needed.
   *
   *  @throws UnknownAddress if the pointer is not part of a
   *          registered allocation.
   */
  c10::Allocator* get_allocator(void const* ptr) const;

  /** @brief Reset the allocator associated with a pointer.
   *
   *  In cases of MI300A pointer migration, this allows us to keep our
   *  internal bookkeeping consistent. It should not be used outside
   *  of this context.
   *
   *  @throws UnknownAddress if the pointer is not part of a
   *          registered allocation.
   */
  void unsafe_reset_allocator(void const* ptr, c10::Allocator* new_alloc);
  // FIXME (trb): An alternative would be to make this similar to
  // "compare and swap" semantics (i.e., having to provide what the
  // user thinks the current allocator is); see also, replacing a
  // deleter on a DataPtr. My concern is this will never be called
  // "properly" but rather just with a dummy
  // "registry.get_allocator(ptr)" in that argument, so what would the
  // point really be?

  /** @brief Get the context of the given pointer.
   *
   *  The context is the address returned by the raw allocator when
   *  the allocation is requested. It is the pointer that must be
   *  passed to @c delete.
   *
   *  @param[in] ptr The pointer whose context is needed.
   *
   *  @throws UnknownAddress if the pointer is not part of a
   *          registered allocation.
   */
  void* get_context(void const* ptr) const;

  /** @brief Get the current number of registered ranges */
  size_t num_registered() const noexcept
  {
    std::lock_guard<std::mutex> lock(m_registry_mtx);
    return m_registry.size();
  }

  /** @brief Get the current number of registered bytes */
  size_t bytes_registered() const noexcept;

  /** @brief Get the number of bytes associated with the given
   *         pointer.
   *
   *  Unregistered pointers return 0. Since zero-sized ranges are
   *  allowed in the registry, this function cannot serve as a proxy
   *  for known().
   *
   *  @param[in] ptr Any valid address.
   *
   *  @returns The number of bytes in an allocation associated with
   *           the pointer.
   */
  size_t bytes_registered(void const*) const noexcept;

public:
  /** @brief A half-open address range [first, second). */
  using KeyT = std::pair<void*, void*>;
  /** @class RangeLessAndDisjoint
   *  @brief Comparison operator for pointer ranges
   *
   *  'a' is RangeLessAndDisjoint from 'b' if its upper bound is <=
   *  the lower bound of 'b', and, because we consider zero-size
   *  ranges to be valid, if its lower bound is strictly less than the
   *  lower bound of 'b'. A consequence of this definition is that two
   *  ranges will be "equivalent", by the STL's definition of the
   *  concept, if and only if they overlap. Thus, using this as the
   *  `compare` operator in an associative map keyed on ranges [a,b),
   *  a<=b (with the equality case denoting a valid but zero-sized
   *  range) allows us to quickly identify overlapping ranges.
   *
   *  This provides benefits to our use-case in two ways. First,
   *  overlapping regions are forbidden. Thus, we will never add a
   *  range that overlaps a previously added range because the new key
   *  will present as equivalent to an existing key. Second, we can
   *  search for pointers p efficiently, using `key_type{p,p}` as the
   *  key. Searching this way will yield a range containing `p`, if
   *  one exists. I have included comparison operators that take a
   *  single pointer to facilitate this computation directly. Because
   *  they operate exactly "as though" we had passed a zero-size
   *  range, the ordering remains consistent and searches maintain
   *  their logarithmic complexity.
   */
  struct RangeLessAndDisjoint
  {
    /** @brief Needed to enable the templated overloads to find,
     *         contains, etc.
     */
    typedef std::true_type is_transparent;

    // Range-vs-range: 'a' precedes 'b' iff 'a' ends at or before the
    // start of 'b' AND does not share b's start (so a zero-size range
    // at b's start compares equivalent to, i.e. overlapping, 'b').
    bool operator()(KeyT const& a, KeyT const& b) const noexcept
    {
      return a.second <= b.first && a.first != b.first;
    }

    // Pointer-vs-range: behaves as though 'a' were the zero-size
    // range [a, a).
    bool operator()(void const* const a, KeyT const& b) const noexcept
    {
      return a < b.first;
    }

    // Range-vs-pointer: behaves as though 'b' were the zero-size
    // range [b, b).
    bool operator()(KeyT const& a, void const* const b) const noexcept
    {
      return a.first != b && a.second <= b;
    }
  };

private:
  // Ranges ordered by RangeLessAndDisjoint; overlapping keys are
  // "equivalent" and therefore cannot coexist in the map.
  using MapType = std::map<KeyT, c10::Allocator*, RangeLessAndDisjoint>;
  MapType m_registry;
  // Guards m_registry; mutable so const queries can lock it.
  mutable std::mutex m_registry_mtx;
};  // struct PointerRegistry

LBANNV2_EXPORT PointerRegistry& pointer_registry();

}  // namespace lbannv2


================================================
FILE: src/lbannv2/ops/CMakeLists.txt
================================================
################################################################################
## Copyright 2014-2025 Lawrence Livermore National Security, LLC and other
## LBANN Project Developers. See the top-level LICENSE file for details.
##
## SPDX-License-Identifier: Apache-2.0
################################################################################
target_sources(lbannv2
  PUBLIC
  FILE_SET HEADERS
  FILES
  migrate.hpp
)
# migrate.cpp is always compiled; it degrades to a plain "to" when no
# APU support is available.
target_sources(lbannv2
  PRIVATE
  migrate.cpp
)

# Note that LBANNV2_HAS_ROCM is implicit in either of these cases.
#
# FIXME trb: "migrate" includes all the dynamic mi300a handling, etc.
# Should it always be available at this level? (vs just in
# register_ops.cpp)
if (LBANNV2_UNKNOWN_MI300A OR LBANNV2_WITH_MI300A)
  target_sources(lbannv2
    PUBLIC
    FILE_SET HEADERS
    FILES
    nonzero.hpp
    scalar.hpp
  )
  # NOTE: migrate.cpp is intentionally NOT repeated here -- it is
  # already added unconditionally above; only the MI300A-specific
  # sources belong in this branch.
  target_sources(lbannv2
    PRIVATE
    nonzero.hip
    scalar.cpp
  )
endif ()


================================================
FILE: src/lbannv2/ops/migrate.cpp
================================================
////////////////////////////////////////////////////////////////////////////////
// Copyright 2014-2025 Lawrence Livermore National Security, LLC and other
// LBANN Project Developers. See the top-level LICENSE file for details.
//
// SPDX-License-Identifier: Apache-2.0
////////////////////////////////////////////////////////////////////////////////
#include <lbannv2_config.h>

#include <lbannv2/memory/mi300a_allocator.hpp>
#include <lbannv2/memory/registry.hpp>
#include <lbannv2/ops/migrate.hpp>
#include <lbannv2/utils/gpu_utils.hpp>
#include <lbannv2/utils/logging.hpp>
#include <lbannv2/utils/tensor_helpers.hpp>

#include <ATen/Tensor.h>
#include <c10/core/Device.h>
#if LBANNV2_HAS_ROCM
#include <c10/hip/HIPFunctions.h>
#endif

#if LBANNV2_WITH_MI300A || LBANNV2_UNKNOWN_MI300A
namespace
{

// NOTE: This function assumes a binary view of memory: pointers only
// come from "CPU" or "CUDA" (i.e., HIP).
// NOTE: This function assumes a binary view of memory: pointers only
// come from "CPU" or "CUDA" (i.e., HIP).
at::Device get_origin_device(void const* const ptr)
{
  // If the HIP runtime recognizes this pointer, it can report the
  // device ordinal it was allocated on; any failure is treated as
  // "this is host memory". (The PyTorch DataPtr available to callers
  // could supply the same information, as could the runtime's
  // context-pointer/buffer-size queries.)
  int device_idx;
  hipError_t const status =
    hipPointerGetAttribute(&device_idx,
                           HIP_POINTER_ATTRIBUTE_DEVICE_ORDINAL,
                           const_cast<void*>(ptr));
  if (status != hipSuccess)
    return c10::kCPU;
  return {c10::kCUDA, static_cast<c10::DeviceIndex>(device_idx)};
}

// PyTorch still admits the possibility of a single process using
// multiple GPUs, though this historically has not been LBANN's
// preferred approach (instead preferring 1 GPU per rank and 1 rank
// per GPU). On MI300A, we can migrate a pointer from any "GPU" to the
// CPU freely. HOWEVER, we can only migrate from the CPU to the
// specific device on which the migrateable memory was allocated.
// A device is usable by migrate() if it is the CPU or (when built
// with GPU support) a CUDA/HIP device.
bool is_ok_device(c10::Device const& d)
{
#if LBANNV2_HAS_GPU
  return d.is_cpu() || d.is_cuda();
#else
  return d.is_cpu();
#endif
}

// Map a device type to the dispatch key set a tensor on that device
// should carry.
c10::DispatchKeySet get_default_keyset(c10::Device const& d)
{
  if (d.type() == c10::kCPU)
    return c10::DispatchKeySet {c10::DispatchKey::CPU};
  if (d.type() == c10::kCUDA)
    return c10::DispatchKeySet {c10::DispatchKey::CUDA};
  throw std::runtime_error("Unknown device type");
}

}  // namespace
#endif

at::Tensor lbannv2::migrate(at::Tensor& t, c10::Device const& d)
{
  auto const src_d = t.device();
  LBANNV2_TRACE(
    "migrate(ptr={}, from={}, to={})", t.data_ptr(), src_d.str(), d.str());

  // Short-circuit
  if (src_d == d)
    return t;

#if LBANNV2_UNKNOWN_MI300A || LBANNV2_WITH_MI300A
  // NOTE: "LBANNV2_HAS_ROCM" is implied here.

  // At its heart, this isn't really "migrate", it's "rebrand"... I
  // don't actually care what the device annotations on the Tensor or
  // Storage are, I care about the origin of the pointer. It might
  // also be good to look into p2p memory access, but I don't know how
  // to query that just given the pointer (i.e., even if p2p mem
  // access is enabled *now*, I haven't discovered a way to tell if it
  // was enabled when a particular buffer was allocated (well, other
  // than trying to read it and letting the segfault happen)).
  auto const real_src_d = get_origin_device(t.const_data_ptr());

  // We need to get the "real" "CUDA" target.
  // (An index-less "cuda" request is resolved to the current device.)
  c10::Device const real_tgt_d =
    (d.is_cuda() && !d.has_index())
      ? c10::Device {c10::kCUDA, gpu::current_device()}
      : d;

  // If the real_src_d is "cpu", it can be migrated to "cpu".
  // If the real_src_d is "cuda:N", it can be migrated to "cpu" or "cuda:N".
  LBANNV2_ASSERT(real_tgt_d.is_cpu() || (real_src_d == real_tgt_d),
                 std::runtime_error,
                 "Migrate: ptr is not migrateable to given device.");
  LBANNV2_ASSERT(
    is_ok_device(real_src_d),
    std::runtime_error,
    "Migrate: source tensor's device type not supported by LBANNv2.");
  LBANNV2_ASSERT(is_ok_device(real_tgt_d),
                 std::runtime_error,
                 "Migrate: destination device type not supported by LBANNv2.");

  // FIXME: If the pointer is not owned by LBANNv2, how do we handle
  // its associated stream?
  //  ---> The PyTorch CUDA caching allocator provides "recordStream" :)

#if LBANNV2_UNKNOWN_MI300A
  // MI300A-ness is a runtime property in this configuration; only
  // take the zero-copy path on an actual integrated (APU) device.
  if (lbannv2::gpu::is_integrated())
#endif
  {
    // CPU targets use the default stream; GPU targets use that
    // device's current stream.
    c10::Stream stream = real_tgt_d.is_cpu()
                           ? c10::Stream {c10::Stream::DEFAULT, d}
                           : getDeviceCurrentStream(real_tgt_d.index());

    lbannv2::migrate_ptr(t.storage().mutable_data_ptr(), d, stream);

    // Report the number of meaningful bytes migrated. This is
    // inherently based on the tensor shape rather than the allocated
    // buffer size (think: binned allocations, subtensor "views",
    // etc).
    LBANNV2_TRACE("migrated {} bytes (ptr={})",
                  std::accumulate(t.sizes().cbegin(),
                                  t.sizes().cend(),
                                  static_cast<int64_t>(1),
                                  std::multiplies<int64_t> {})
                    * t.dtype().itemsize(),
                  t.const_data_ptr());

    // Rewrap the (now-migrated) storage in a fresh TensorImpl whose
    // dispatch keys match the destination device.
    auto storage = t.storage();
    // FIXME (trb): I initially created this as a 'VIEW', but that
    // puts it in "inference mode" (i.e., out.is_inference() == true).
    // This is bad for training workloads. We may need to be a bit
    // more careful in general here... E.g., migrating views to views,
    // etc.
    auto out =
      at::detail::make_tensor<at::TensorImpl>(  // at::TensorImpl::VIEW,
        std::move(storage),
        get_default_keyset(d),
        t.dtype());
    sync_metadata(t, out);

    // Wait for any work queued on the source GPU's current stream
    // before handing the buffer over.
    if (src_d.is_cuda())
    {
      getDeviceCurrentStream(src_d.index()).synchronize();
    }

    return out;
  }
#endif
  // No zero-copy path available: a plain device-to-device copy.
  return t.to(t.options().device(d));
}


================================================
FILE: src/lbannv2/ops/migrate.hpp
================================================
////////////////////////////////////////////////////////////////////////////////
// Copyright 2014-2025 Lawrence Livermore National Security, LLC and other
// LBANN Project Developers. See the top-level LICENSE file for details.
//
// SPDX-License-Identifier: Apache-2.0
////////////////////////////////////////////////////////////////////////////////
#pragma once

#include <lbannv2_config.h>

#include <ATen/Tensor.h>
#include <c10/core/Device.h>

namespace lbannv2
{

/** @brief Migrate a tensor to a new device, eliding copies when
 *         possible.
 *
 *  If we have an APU (e.g., MI300A), we are able to zero-copy migrate
 *  the memory between the "cpu" backend and the "cuda" backend, under
 *  certain circumstances. The semantics differ from the "to" operator
 *  in the sense that the original tensor is considered "invalid"
 *  (implicitly, of course) after the migration.
 *
 *  The primary prerequisite for migrating a tensor is that its
 *  backing memory must have been allocated using a "cuda" allocator
 *  (that is, somewhere in the allocator stack, the raw memory must
 *  come from "hipMalloc" in the case of MI300A). LBANNv2 provides a
 *  context manager that replaces the underlying CPU allocator with
 *  one that allocates "cuda" memory, essentially providing
 *  "migrateable" CPU tensors.
 *
 *  At this time, we do NOT support IPC memory buffers or P2P device
 *  memory access. Thus, tensors are only migrateable between the CPU
 *  and whichever CUDA device their allocation is tied to. In the case
 *  of CPU tensors allocated using the LBANNv2 allocator, this will be
 *  whichever CUDA device was selected at the time of its allocation.
 *
 *  If we do not have an APU, this is just a direct call to "to".
 *
 *  Upon successful migration, the input tensor is invalidated to
 *  prevent foot wounds.
 *
 *  Schema: migrate(Tensor(a!), Device) -> Tensor(a!)
 *
 *  @param[in] t The tensor to (possibly) migrate.
 *  @param[in] d The target device.
 *
 *  @returns A tensor associated with the given target device.
 */
at::Tensor migrate(at::Tensor& t, c10::Device const& d);

}// namespace lbannv2


================================================
FILE: src/lbannv2/ops/nonzero.hip
================================================
////////////////////////////////////////////////////////////////////////////////
// Copyright 2014-2025 Lawrence Livermore National Security, LLC and other
// LBANN Project Developers. See the top-level LICENSE file for details.
//
// SPDX-License-Identifier: Apache-2.0
////////////////////////////////////////////////////////////////////////////////
#include "lbannv2/ops/nonzero.hpp"
#include <lbannv2/memory/allocator.hpp>
#include <lbannv2/utils/gpu_utils.hpp>
#include <lbannv2/utils/logging.hpp>

#include <ATen/hip/EmptyTensor.h>
#include <ATen/hip/HIPContext.h>
#include <hipcub/hipcub.hpp>

// Note (trb): UGH PyTorch 2.11
#ifdef C10_HIP_KERNEL_LAUNCH_CHECK
#define LBANNV2_KERNEL_LAUNCH_CHECK() C10_HIP_KERNEL_LAUNCH_CHECK()
#else
#define LBANNV2_KERNEL_LAUNCH_CHECK() C10_CUDA_KERNEL_LAUNCH_CHECK()
#endif

namespace
{
// Read-only view of a DataPtr's payload, typed as T.
template <typename T>
T const* get_const(c10::DataPtr const& ptr)
{
  void const* const raw = ptr.get();
  return static_cast<T const*>(raw);
}

// Hoisted from PyTorch; clang-format to LBANNv2's style.
//
//   path: aten/src/ATen/native/cuda/Nonzero.cu
//   commit: 36eb64d60ea6371e3a617ba5026d27be7f88a6af
//
// FIXME: Point to <pytorch>/LICENSE or copy thereof.

// Predicate functor: true iff the element differs from zero. Marked
// __host__ __device__ so it can feed the hipcub transform iterators
// used below.
template <typename T>
struct NonZeroOp
{
  __host__ __device__ __forceinline__ bool operator()(T const& a) const
  {
    return (a != T(0));
  }
};

// Upper bound on the tensor rank supported by write_indices below.
#define MAX_DIMS 16
// Plain-old-data copy of a tensor's sizes, passable by value to a
// kernel launch.
template <typename index_t>
struct TensorDims
{
  index_t sizes[MAX_DIMS];
};

// Expand flat (linearized) indices into per-dimension coordinates, in
// place. On entry, inp[0..n) holds flat indices of nonzero elements;
// on exit, inp is treated as an ndim-row, n-column array (row `dim`
// at offset dim * n) of coordinates. When `total` is provided, rows
// 1..ndim-1 of entries at or past *total are set to `fill_value`.
template <typename index_t>
__global__ void write_indices(int64_t* inp,
                              TensorDims<index_t> dims,
                              int ndim,
                              index_t n,
                              int64_t* total = nullptr,
                              int64_t fill_value = -1)
{
  // One thread per (potential) nonzero entry.
  auto index = threadIdx.x + (int64_t) blockIdx.x * blockDim.x;
  bool cond = (total == nullptr || index < *total);
  if (index < n && cond)
  {
    index_t div = 1;
    int64_t idx_flat = inp[index];
    // NOTE(review): the loop starts at MAX_DIMS rather than
    // MAX_DIMS - 1; the `dim > ndim - 1` guard skips out-of-range
    // dims, so this is safe as long as ndim <= MAX_DIMS. Hoisted
    // as-is from PyTorch.
#pragma unroll
    for (int dim = MAX_DIMS; dim >= 0; dim--)
    {
      if (dim > ndim - 1)
        continue;
      auto dim_size = dims.sizes[dim];
      inp[index + dim * n] = (idx_flat / div) % dim_size;
      div *= dim_size;
    }
  }
  else if (index < n)
  {
    // 0th dim has correct values already
    for (int dim = ndim - 1; dim > 0; dim--)
    {
      inp[index + dim * n] = fill_value;
    }
  }
}

// NOTE (trb): Majority from PyTorch. We removed use of host-based
// pinned_num_nonzeros_h. Instead we just sync the stream and use the
// memory directly on the CPU. We also change
// `const_data_ptr<scalar_t>()` to `static_cast<scalar_t
// const*>(const_data_ptr())` to sidestep a linker error with
// amdclang++.
// nonzero for MI300A: write the coordinates of self's nonzero
// elements into `out` (final shape {num_nonzero, self.dim()}). See
// the provenance NOTE above; the host reads the per-chunk counters
// directly after a stream sync instead of staging through pinned
// memory.
template <typename scalar_t>
void nonzero_out_mi300a_impl(at::Tensor const& self, at::Tensor& out)
{
  at::Tensor self_ = self.contiguous();
  hipStream_t const stream = at::hip::getCurrentHIPStream();
  // hipcub counts with 32-bit ints, so very large inputs are
  // processed in chunks of at most ~2**30 elements.
  int64_t chunk_size, num_chunks;
  if (self.numel() < std::numeric_limits<int>::max())
  {
    chunk_size = self.numel();
    num_chunks = 1;
  }
  else
  {
    chunk_size = std::numeric_limits<int>::max() / 2 + 1;  // 2**30
    num_chunks = (self.numel() + chunk_size - 1) / chunk_size;
  }
  // compute number of nonzero elements
  size_t temp_storage_bytes = 0;
  auto* const allocator = c10::GetAllocator(self.device().type());

  // One int counter per chunk. The host dereferences this buffer
  // below after synchronizing the stream, which relies on the
  // allocation being host-accessible (the MI300A case).
  auto num_nonzeros = allocator->allocate(sizeof(int) * num_chunks);
  for (int64_t idx = 0; idx < num_chunks; idx++)
  {
    int64_t remaining = std::min(chunk_size, self.numel() - idx * chunk_size);
    hipcub::TransformInputIterator<bool, NonZeroOp<scalar_t>, scalar_t const*>
      itr(static_cast<scalar_t const*>(self_.const_data_ptr()) + idx * chunk_size,
          NonZeroOp<scalar_t>());
    // Standard hipcub two-call pattern: the first (nullptr) call only
    // computes the scratch size; the second does the reduction.
    AT_CUDA_CHECK(hipcub::DeviceReduce::Sum(nullptr,
                                            temp_storage_bytes,
                                            itr,
                                            ((int*) num_nonzeros.get()) + idx,
                                            remaining,
                                            stream));
    auto temp_storage = allocator->allocate(temp_storage_bytes);
    AT_CUDA_CHECK(hipcub::DeviceReduce::Sum(temp_storage.get(),
                                            temp_storage_bytes,
                                            itr,
                                            ((int*) num_nonzeros.get()) + idx,
                                            remaining,
                                            stream));
  }

  // TOM: Skip the copy...

  // auto pinned_num_nonzeros_h = at::detail::empty_cpu(
  //     {num_chunks}, /* size */
  //     c10::CppTypeToScalarType<int>(), /* dtype */
  //     std::nullopt, /* layout */
  //     std::nullopt, /* device */
  //     true, /* pin_memory */
  //     std::nullopt /* memory format */
  // );
  // at::cuda::memcpy_and_sync(
  //     (void*)pinned_num_nonzeros_h.const_data_ptr<int>(),
  //     num_nonzeros.get(),
  //     sizeof(int) * num_chunks,
  //     cudaMemcpyDeviceToHost,
  //     stream);

  // TOM: ...just sync the stream...
  LBANNV2_CHECK_GPU(hipStreamSynchronize(stream));

  int64_t num_nonzeros_h = 0;

  // TOM: ...and use the pointer.
  for (int64_t idx = 0; idx < num_chunks; idx++)
  {
    num_nonzeros_h += (int) *(get_const<int>(num_nonzeros) + idx);
  }

  // num_nonzeros_h = (int)*(pinned_num_nonzeros_h.const_data_ptr<int>());
  // expected output size is num_nonzeros x ndim
  // we are producing output with size {num_nonzeros, ndim} and strides {1,
  // num_nonzeros} (that is, transposed ndim x num_nonzeros output) we are able
  // to directly use passed output with this size and strides, and we can also
  // (per contract) resize passed output with incorrect sizes anyway we want.
  // However, out with correct sizes and incorrect strides will have to be
  // copied to from the intermediate we've produced.
  bool need_to_copy = out.dim() == 2 && out.sizes()[0] == num_nonzeros_h
                      && out.sizes()[1] == self.dim()
                      && !out.t().is_contiguous();
  at::Tensor out_temp = need_to_copy
                          ? at::Tensor(at::detail::empty_cuda(
                              {self.dim(), num_nonzeros_h}, out.options()))
                          : out.resize_({self.dim(), num_nonzeros_h});
  // Scalars are expected to produce output of size (1,0), so we can't write to
  // it
  int64_t curr_nonzeros = 0;
  if (self.dim() > 0)
  {
    // Select the flat indices of nonzero elements, chunk by chunk,
    // appending to out_temp at offset curr_nonzeros.
    for (int64_t idx = 0; idx < num_chunks; idx++)
    {
      int remaining = std::min(chunk_size, self.numel() - idx * chunk_size);

      hipcub::CountingInputIterator<int64_t> counting_itr(idx * chunk_size);
      hipcub::TransformInputIterator<bool, NonZeroOp<scalar_t>, scalar_t const*>
        itr(static_cast<scalar_t const*>(self_.const_data_ptr()) + idx * chunk_size,
            NonZeroOp<scalar_t>());
      temp_storage_bytes = 0;
      // Two-call hipcub pattern again: size query, then the real run.
      AT_CUDA_CHECK(
        hipcub::DeviceSelect::Flagged(nullptr,
                                      temp_storage_bytes,
                                      counting_itr,
                                      itr,
                                      out_temp.mutable_data_ptr<int64_t>(),
                                      ((int*) num_nonzeros.get()) + idx,
                                      remaining,
                                      stream));
      auto temp_storage = allocator->allocate(temp_storage_bytes);
      AT_CUDA_CHECK(hipcub::DeviceSelect::Flagged(
        temp_storage.get(),
        temp_storage_bytes,
        counting_itr,
        itr,
        out_temp.mutable_data_ptr<int64_t>() + curr_nonzeros,
        ((int*) num_nonzeros.get()) + idx,
        remaining,
        stream));
      // TOM: Oh look, we use it again.
      // NOTE(review): this host read follows an asynchronous
      // DeviceSelect that rewrites the same counter; it appears to be
      // ordered only by the earlier hipStreamSynchronize. Confirm the
      // ordering is sufficient if chunking (num_chunks > 1) is ever
      // exercised.
      curr_nonzeros += (int) *(get_const<int>(num_nonzeros) + idx);
    }
    if (num_nonzeros_h > 0 && self.dim() > 1)
    {
      // Expand flat indices into per-dimension coordinates.
      TensorDims<int64_t> dims;
      for (int i = 0; i < self.dim(); i++)
      {
        dims.sizes[i] = self.sizes()[i];
      }
      int const nthreads = 256;
      int const nblocks = (num_nonzeros_h + nthreads - 1) / nthreads;
      write_indices<<<nblocks, nthreads, 0, stream>>>(
        out_temp.mutable_data_ptr<int64_t>(), dims, self.dim(), num_nonzeros_h);
      LBANNV2_KERNEL_LAUNCH_CHECK();
    }
  }
  if (need_to_copy)
  {
    out.copy_(out_temp.t());
  }
  else
  {
    // transpose out so it is correct size
    at::Tensor out_ = out_temp.t();
    out.set_(out_);
  }
}
}  // namespace

at::Tensor& lbannv2::nonzero_out(at::Tensor const& self, at::Tensor& out)
{
  c10::ScalarType const dtype = self.scalar_type();

  LBANNV2_TRACE("lbannv2::nonzero_out(device={}, dtype={})",
                self.device().str(),
                c10::toString(dtype));

  // Element types handled by the MI300A implementation directly;
  // everything else falls back to PyTorch's CUDA path.
  switch (dtype)
  {
  case c10::ScalarType::Bool:
    nonzero_out_mi300a_impl<bool>(self, out);
    break;
  case c10::ScalarType::Float:
    nonzero_out_mi300a_impl<float>(self, out);
    break;
  case c10::ScalarType::Double:
    nonzero_out_mi300a_impl<double>(self, out);
    break;
  case c10::ScalarType::Int:
    nonzero_out_mi300a_impl<int>(self, out);
    break;
  case c10::ScalarType::UInt32:
    nonzero_out_mi300a_impl<std::uint32_t>(self, out);
    break;
  case c10::ScalarType::Long:
    nonzero_out_mi300a_impl<long>(self, out);
    break;
  default:
    return at::native::nonzero_out_cuda(self, out);
  }

  return out;
}

at::Tensor lbannv2::nonzero(at::Tensor const& self)
{
  // Allocate an empty (resizable) int64 result on the device and let
  // the out-variant do the actual work.
  auto const opts = self.options().dtype(c10::kLong);
  at::Tensor result = at::detail::empty_cuda({0}, opts);
  return nonzero_out(self, result);
}


================================================
FILE: src/lbannv2/ops/nonzero.hpp
================================================
////////////////////////////////////////////////////////////////////////////////
// Copyright 2014-2025 Lawrence Livermore National Security, LLC and other
// LBANN Project Developers. See the top-level LICENSE file for details.
//
// SPDX-License-Identifier: Apache-2.0
////////////////////////////////////////////////////////////////////////////////
#pragma once

#include <ATen/ATen.h>

namespace lbannv2
{

at::Tensor nonzero(at::Tensor const& self);
at::Tensor& nonzero_out(at::Tensor const& self, at::Tensor& out);

} // namespace lbannv2


================================================
FILE: src/lbannv2/ops/scalar.cpp
================================================
////////////////////////////////////////////////////////////////////////////////
// Copyright 2014-2025 Lawrence Livermore National Security, LLC and other
// LBANN Project Developers. See the top-level LICENSE file for details.
//
// SPDX-License-Identifier: Apache-2.0
////////////////////////////////////////////////////////////////////////////////
#include <lbannv2_config.h>

#include <lbannv2/ops/scalar.hpp>
#include <lbannv2/utils/errors.hpp>

#include <ATen/ops/_local_scalar_dense_native.h>

#if LBANNV2_WITH_MI300A || LBANNV2_UNKNOWN_MI300A
#include <lbannv2/types.hpp>
#include <lbannv2/utils/gpu_utils.hpp>
#include <lbannv2/utils/logging.hpp>

#include <ATen/core/TensorBase.h>
#include <c10/core/Scalar.h>
#include <c10/core/ScalarType.h>
#include <c10/hip/HIPStream.h>

// FIXME: We should integrate this better with either H2 dispatch or
// Torch dispatch (I don't really care, honestly).
namespace
{

// Read a one-element tensor's value directly from host-accessible
// memory and wrap it in a Scalar.
template <typename ScalarT>
at::Scalar mi300a_impl(at::Tensor const& self)
{
  // The op's contract is a sync, so we sync. (It is also likely
  // required for correctness, after which the value can be safely
  // read from the host.)
  lbannv2::gpu::sync(at::hip::getCurrentHIPStream());
  auto const* const value =
    static_cast<ScalarT const*>(self.const_data_ptr());
  return at::Scalar(*value);
}

at::Scalar mi300a_dispatch(at::Tensor const& self)
{
  c10::ScalarType const dtype = self.scalar_type();

  LBANNV2_TRACE("lbannv2::_local_scalar_dense_mi300a(device={}, dtype={})",
                self.device().str(),
                c10::toString(dtype));
  // Types we can read directly from host-accessible memory; anything
  // else falls back to PyTorch's CUDA implementation.
  switch (dtype)
  {
  case c10::ScalarType::Bool:
    return mi300a_impl<bool>(self);
  case c10::ScalarType::Float:
    return mi300a_impl<float>(self);
  case c10::ScalarType::Double:
    return mi300a_impl<double>(self);
  case c10::ScalarType::Int:
    return mi300a_impl<int>(self);
  case c10::ScalarType::UInt32:
    return mi300a_impl<std::uint32_t>(self);
  case c10::ScalarType::Long:
    return mi300a_impl<long>(self);
  default:
    return at::native::_local_scalar_dense_cuda(self);
  }
}
}  // namespace
#endif  // LBANNV2_WITH_MI300A || LBANNV2_UNKNOWN_MI300A

at::Scalar lbannv2::local_scalar_dense_hip(at::Tensor const& self)
{
  // self.numel() == 1 is asserted elsewhere.
  c10::ScalarType const dtype = self.dtype().toScalarType();

  // Technically, the "right" fallback is implemented in all
  // subsequent code paths, but I want to know about it if there's
  // another type we should be supporting.
  LBANNV2_ASSERT(is_supported(dtype), std::runtime_error, c10::toString(dtype));

#if LBANNV2_WITH_MI300A || LBANNV2_UNKNOWN_MI300A
#if LBANNV2_UNKNOWN_MI300A
  // MI300A-ness is a runtime property in this configuration; only
  // take the zero-copy path on an actual integrated (APU) device.
  if (lbannv2::gpu::is_integrated())
#endif  // LBANNV2_UNKNOWN_MI300A
    return mi300a_dispatch(self);
#endif  // LBANNV2_WITH_MI300A || LBANNV2_UNKNOWN_MI300A

  // Fallback to the Torch impl (cannot call at::_local_scalar_dense
  // -- it will cause an infinite recursion through this function).
  return at::native::_local_scalar_dense_cuda(self);
}


================================================
FILE: src/lbannv2/ops/scalar.hpp
================================================
////////////////////////////////////////////////////////////////////////////////
// Copyright 2014-2025 Lawrence Livermore National Security, LLC and other
// LBANN Project Developers. See the top-level LICENSE file for details.
//
// SPDX-License-Identifier: Apache-2.0
////////////////////////////////////////////////////////////////////////////////
#pragma once

#include <lbannv2_config.h>

#include <ATen/core/Tensor.h>

namespace lbannv2
{

/** @brief Extract the value of a one-element tensor as a scalar.
 *
 *  HIP-platform replacement for aten's `_local_scalar_dense`. The
 *  tensor must contain exactly one element (checked by the caller).
 *  Depending on build configuration and the runtime integrated-GPU
 *  check, either an MI300A-specific host-side read or the stock
 *  Torch CUDA implementation is used.
 */
LBANNV2_EXPORT at::Scalar local_scalar_dense_hip(at::Tensor const&);

}  // namespace lbannv2


================================================
FILE: src/lbannv2/python/CMakeLists.txt
================================================
################################################################################
## Copyright 2014-2025 Lawrence Livermore National Security, LLC and other
## LBANN Project Developers. See the top-level LICENSE file for details.
##
## SPDX-License-Identifier: Apache-2.0
################################################################################
# These bindings only make sense inside a scikit-build-core build of
# the Python extension module.
if (NOT SKBUILD)
  message(FATAL_ERROR "You should not be here. Not doing a SKBUILD.")
endif ()

# Collect the binding sources, adding the MI300A-specific op
# registrations only when that support is enabled.
set(_lbannv2_python_sources
  register_lbannv2.cpp
  register_memory_funcs.cpp)

if (LBANNV2_WITH_MI300A OR LBANNV2_UNKNOWN_MI300A)
  list(APPEND _lbannv2_python_sources register_mi300a_ops.cpp)
endif ()

target_sources(_lbannv2 PRIVATE ${_lbannv2_python_sources})


================================================
FILE: src/lbannv2/python/register_lbannv2.cpp
================================================
////////////////////////////////////////////////////////////////////////////////
// Copyright 2014-2025 Lawrence Livermore National Security, LLC and other
// LBANN Project Developers. See the top-level LICENSE file for details.
//
// SPDX-License-Identifier: Apache-2.0
////////////////////////////////////////////////////////////////////////////////
#include <lbannv2_config.h>

#include <lbannv2/utils/gpu_utils.hpp>
#include <lbannv2/utils/logging.hpp>

#include <c10/core/Device.h>
#include <pybind11/pybind11.h>

#include <cstdlib>
#include <iostream>

#include <sys/types.h>
#include <unistd.h>

namespace
{
// Module-wide one-shot initialization flags.
bool g_lbannv2_initialized = false;
bool g_lbannv2_gpu_initialized = false;

// Idempotent global initialization for LBANNv2.
//
// If LBANNV2_HANG_FOR_DEBUG is set in the environment, the process
// prints its PID and spins so a debugger can attach and clear `spin`.
void init_lbannv2()
{
  if (g_lbannv2_initialized)
    return;

  if (std::getenv("LBANNV2_HANG_FOR_DEBUG") != nullptr)
  {
    // Raw iostream rather than spdlog so the message is flushed
    // immediately (std::endl forces the flush).
    std::cout << "LBANNV2 WAITING ON PID " << getpid() << std::endl;
    int volatile spin = 1;
    while (spin)
    {
    }
  }

#if LBANNV2_HAS_GPU
  if (!g_lbannv2_gpu_initialized)
  {
    // There's nothing to do here since getting rid of H2.
    g_lbannv2_gpu_initialized = true;
  }
#endif

  g_lbannv2_initialized = true;
}

// Has init_lbannv2() completed?
bool is_lbannv2_initialized() noexcept
{
  return g_lbannv2_initialized;
}

// Has the GPU-specific part of initialization completed?
bool is_lbannv2_gpu_initialized() noexcept
{
  return g_lbannv2_gpu_initialized;
}

// Was this build compiled with GPU support?
bool is_lbannv2_gpu_available() noexcept
{
  return LBANNV2_HAS_GPU;
}

}  // namespace

namespace _lbannv2
{
// Defined in register_memory_funcs.cpp; attaches the memory-related
// bindings (migrate, allocator selection, ...) to the module.
void add_memory_funcs(pybind11::module_& m);
}  // namespace _lbannv2

// Python entry point: defines the `_lbannv2` extension module and
// binds its top-level initialization/query functions.
PYBIND11_MODULE(_lbannv2, m)
{
  m.def("init_lbannv2", &init_lbannv2, "Initialize state for LBANNv2");
  m.def("is_lbannv2_initialized",
        &is_lbannv2_initialized,
        "Query initialization state for LBANNv2");
  m.def("is_lbannv2_gpu_initialized",
        &is_lbannv2_gpu_initialized,
        "Query initialization state for LBANNv2 GPU support.");
  m.def("is_lbannv2_gpu_available",
        &is_lbannv2_gpu_available,
        "Query whether LBANNv2 has GPU support.");
  m.def("set_log_level",
        &lbannv2::set_log_level,
        "Set the output level for LBANNv2 logging.");

  // Memory-related bindings live in register_memory_funcs.cpp.
  _lbannv2::add_memory_funcs(m);
}


================================================
FILE: src/lbannv2/python/register_memory_funcs.cpp
================================================
////////////////////////////////////////////////////////////////////////////////
// Copyright 2014-2025 Lawrence Livermore National Security, LLC and other
// LBANN Project Developers. See the top-level LICENSE file for details.
//
// SPDX-License-Identifier: Apache-2.0
////////////////////////////////////////////////////////////////////////////////
#include <lbannv2_config.h>

#include <lbannv2/memory/memory_utils.hpp>
#include <lbannv2/memory/registry.hpp>
#include <lbannv2/ops/migrate.hpp>
#include <lbannv2/utils/logging.hpp>

#if LBANNV2_HAS_GPU
#include <lbannv2/utils/gpu_utils.hpp>
#endif

#include <lbannv2/memory/allocator.hpp>

#include <ATen/ops/to_native.h>
#include <c10/core/Device.h>
#include <pybind11/pybind11.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/extension.h>
#include <torch/library.h>

namespace
{

// Thin wrapper so pybind can bind lbannv2::migrate directly.
at::Tensor py_migrate(at::Tensor& t, at::Device const& d)
{
  return lbannv2::migrate(t, d);
}

// True when device migration is available: always on known-MI300A
// builds; on other GPU builds only if the runtime reports an
// integrated GPU; never on CPU-only builds.
bool py_supports_migrate() noexcept
{
#if LBANNV2_WITH_MI300A
  return true;
#elif LBANNV2_HAS_GPU
  return lbannv2::gpu::is_integrated();
#else
  return false;
#endif
}

// Switch CPU allocations to the LBANNv2 MI300A allocator.
void py_use_mi300a_host_allocator()
{
  lbannv2::use_mi300a_cpu_allocator();
}

// Switch CPU allocations back to the default Torch allocator.
void py_use_torch_host_allocator()
{
  lbannv2::use_torch_cpu_allocator();
}

// True when the tensor's backing pointer is known to the LBANNv2
// pointer registry (i.e., LBANNv2 allocated it).
bool py_using_lbannv2_memory(torch::Tensor const& t)
{
  return lbannv2::pointer_registry().known(t.const_data_ptr());
}

}  // namespace

namespace _lbannv2
{

// Attach the memory-related bindings to the `_lbannv2` module.
// Called from PYBIND11_MODULE in register_lbannv2.cpp.
void add_memory_funcs(pybind11::module_& m)
{
  // Pointer migration
  m.def("supports_migrate",
        &py_supports_migrate,
        "Determine whether device migration is supported");

  m.def("migrate",
        &py_migrate,
        "Try to migrate an LBANNv2-owned pointer to a new device.");

  m.def("use_mi300a_host_allocator",
        &py_use_mi300a_host_allocator,
        "Use the LBANNv2 MI300A allocator for CPU allocations");

  // NOTE(review): the Python name says "pytorch" while the C++
  // helper says "torch" -- intentional? confirm before renaming
  // anything, since Python callers depend on this spelling.
  m.def("use_pytorch_host_allocator",
        &py_use_torch_host_allocator,
        "Use the default pytorch CPU allocator for CPU allocations");

  m.def(
    "using_lbannv2_memory",
    &py_using_lbannv2_memory,
    "Determine whether LBANNv2 allocated the memory backing a given tensor");
}

}  // namespace _lbannv2


================================================
FILE: src/lbannv2/python/register_mi300a_ops.cpp
================================================
// NOTE: this file is only compiled when LBANNV2_WITH_MI300A or
// LBANNV2_UNKNOWN_MI300A, so the "#else" clauses below are really
// "#elif LBANNV2_UNKNOWN_MI300A".
#include "lbannv2_config.h"

#include <lbannv2/memory/mi300a_allocator.hpp>
#include <lbannv2/ops/nonzero.hpp>
#include <lbannv2/ops/scalar.hpp>
#include <lbannv2/utils/gpu_utils.hpp>

#include <torch/extension.h>
#include <torch/library.h>

namespace
{

// Replacement for aten::_local_scalar_dense on the CUDA/HIP dispatch
// key. Known-MI300A builds always take the LBANNv2 path; otherwise
// the integrated-GPU check happens at runtime and discrete GPUs fall
// back to the stock Torch kernel.
at::Scalar lbannv2__local_scalar_dense_cuda(at::Tensor const& self)
{
#if LBANNV2_WITH_MI300A
  return lbannv2::local_scalar_dense_hip(self);
#else
  if (lbannv2::gpu::is_integrated())
    return lbannv2::local_scalar_dense_hip(self);
  return at::native::_local_scalar_dense_cuda(self);
#endif
}

// Replacement for aten::nonzero; same dispatch policy as above.
at::Tensor lbannv2_nonzero(at::Tensor const& self)
{
#if LBANNV2_WITH_MI300A
  return lbannv2::nonzero(self);
#else
  if (lbannv2::gpu::is_integrated())
    return lbannv2::nonzero(self);
  return at::native::nonzero_cuda(self);
#endif
}

// Replacement for aten::nonzero.out; same dispatch policy as above.
at::Tensor& lbannv2_nonzero_out(at::Tensor const& self, at::Tensor& out)
{
#if LBANNV2_WITH_MI300A
  return lbannv2::nonzero_out(self, out);
#else
  if (lbannv2::gpu::is_integrated())
    return lbannv2::nonzero_out(self, out);
  return at::native::nonzero_out_cuda(self, out);
#endif
}

} // namespace

// Override the selected aten ops on the CUDA dispatch key with the
// MI300A-aware wrappers defined above.
TORCH_LIBRARY_IMPL(aten, CUDA, m)
{
  m.impl("_local_scalar_dense", TORCH_FN(lbannv2__local_scalar_dense_cuda));
  m.impl("nonzero", TORCH_FN(lbannv2_nonzero));
  m.impl("nonzero.out", TORCH_FN(lbannv2_nonzero_out));
}


================================================
FILE: src/lbannv2/types.hpp
================================================
////////////////////////////////////////////////////////////////////////////////
// Copyright 2014-2025 Lawrence Livermore National Security, LLC and other
// LBANN Project Developers. See the top-level LICENSE file for details.
//
// SPDX-License-Identifier: Apache-2.0
////////////////////////////////////////////////////////////////////////////////
#pragma once

// FIXME (trb): Where should this file live??

#include <c10/core/ScalarType.h>

namespace lbannv2
{

/** @brief Decide if a data type is supported by LBANNv2.
 *
 *  The supported set is: Bool, Float, Double, Int, UInt32, Long.
 */
inline bool is_supported(c10::ScalarType t) noexcept
{
  using ST = c10::ScalarType;
  return t == ST::Bool || t == ST::Float || t == ST::Double || t == ST::Int
         || t == ST::UInt32 || t == ST::Long;
}

}  // namespace lbannv2


================================================
FILE: src/lbannv2/utils/CMakeLists.txt
================================================
################################################################################
## Copyright 2014-2025 Lawrence Livermore National Security, LLC and other
## LBANN Project Developers. See the top-level LICENSE file for details.
##
## SPDX-License-Identifier: Apache-2.0
################################################################################
# Public headers are exported through the HEADERS file set; the .cpp
# files are compiled into the library itself.
target_sources(lbannv2
  PUBLIC
    FILE_SET HEADERS
    FILES
      debugging_helpers.hpp
      errors.hpp
      gpu_utils.hpp
      logging.hpp
      tensor_helpers.hpp
  PRIVATE
    gpu_utils.cpp
    logging.cpp
)


================================================
FILE: src/lbannv2/utils/debugging_helpers.hpp
================================================
////////////////////////////////////////////////////////////////////////////////
// Copyright 2014-2025 Lawrence Livermore National Security, LLC and other
// LBANN Project Developers. See the top-level LICENSE file for details.
//
// SPDX-License-Identifier: Apache-2.0
////////////////////////////////////////////////////////////////////////////////
#pragma once

#include <cxxabi.h>
#include <execinfo.h>

#include <iomanip>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

namespace lbannv2
{

/** @brief Demangle a C++ ABI (Itanium) mangled symbol name.
 *
 *  @param symb The mangled name.
 *  @return The demangled name on success; otherwise the input with
 *          " (demangling failed)" appended.
 */
inline std::string demngl(std::string symb)
{
  int status = 0;  // initialized in case __cxa_demangle leaves it untouched
  char* const demangled_name =
    abi::__cxa_demangle(symb.data(), nullptr, nullptr, &status);
  if (demangled_name && status == 0)
  {
    std::string out(demangled_name);
    // __cxa_demangle allocates with malloc; release accordingly.
    free(demangled_name);
    return out;
  }
  // On failure __cxa_demangle returns NULL, so free(NULL) is a no-op;
  // freeing unconditionally guards against any nonnull/error combo.
  free(demangled_name);

  std::ostringstream oss;
  oss << symb << " (demangling failed)";
  return oss.str();
}

/** @brief Print a (demangled, best-effort) backtrace of the calling
 *         thread's stack.
 *
 *  @param nframes Maximum number of frames to capture.
 *  @param os Output stream; defaults to std::cout.
 */
inline void print_bt(size_t nframes = 128, std::ostream& os = std::cout)
{
  std::vector<void*> frames(nframes);
  int const captured = backtrace(frames.data(), static_cast<int>(nframes));
  if (captured <= 0)
    return;

  // backtrace_symbols returns NULL on allocation failure; the
  // original dereferenced it unconditionally.
  char** const symbs = backtrace_symbols(frames.data(), captured);
  if (!symbs)
    return;

  os << "-------------------------------------------------\n";
  for (int i = 0; i < captured; ++i)
  {
    os << std::setw(4) << std::right << i << ": (" << frames[i]
       << "): " << demngl(symbs[i]) << "\n";
  }
  os << "-------------------------------------------------" << std::endl;
  // backtrace_symbols returns a single malloc'd block.
  free(symbs);
}

}  // namespace lbannv2


================================================
FILE: src/lbannv2/utils/errors.hpp
================================================
////////////////////////////////////////////////////////////////////////////////
// Copyright 2014-2025 Lawrence Livermore National Security, LLC and other
// LBANN Project Developers. See the top-level LICENSE file for details.
//
// SPDX-License-Identifier: Apache-2.0
////////////////////////////////////////////////////////////////////////////////
#pragma once

#include <lbannv2_config.h>

/** @brief Throw `excpt(msg)` when `cond` evaluates to false. */
#define LBANNV2_ASSERT(cond, excpt, msg)                                       \
  do                                                                           \
  {                                                                            \
    if (!(cond))                                                               \
    {                                                                          \
      throw excpt(msg);                                                        \
    }                                                                          \
  } while (0)

/** @brief Assert `cond` in every build type (throws std::runtime_error). */
#define LBANNV2_ASSERT_ALWAYS(cond)                                            \
  LBANNV2_ASSERT(cond, std::runtime_error, "Assertion \"" #cond "\" failed.")

// Debug-only assertion. The original version had the branches
// inverted (no-op in debug builds, active in release builds) and the
// no-op expanded to the ill-formed statement "(void);". Now: active
// in debug builds, a well-formed no-op otherwise.
#if LBANNV2_DEBUG
#define LBANNV2_ASSERT_DEBUG(cond) LBANNV2_ASSERT_ALWAYS(cond)
#else
#define LBANNV2_ASSERT_DEBUG(cond) ((void) 0)
#endif


================================================
FILE: src/lbannv2/utils/gpu_utils.cpp
================================================
#include "gpu_utils.hpp"

#include "errors.hpp"
#include "logging.hpp"

// Report whether the current GPU is integrated (shares memory with
// the host). Compile-time true for known-MI300A builds; otherwise
// queried from the HIP runtime. Returns false for CUDA-only and
// CPU-only builds, or when the property query fails.
bool lbannv2::gpu::is_integrated() noexcept
{
#if LBANNV2_WITH_MI300A
  return true;
#else
#if LBANNV2_HAS_ROCM
  hipDeviceProp_t props;
  if (hipGetDeviceProperties(&props, current_device()) == hipSuccess)
    return props.integrated;
  // Query failed: log it and fall through to "false".
  LBANNV2_ERROR("Failed to get device properties of current HIP device {}.",
                current_device());
#endif
#endif
  return false;
}

// Number of visible GPU devices; 0 in builds without GPU support.
c10::DeviceIndex lbannv2::gpu::num_devices() noexcept
{
#if LBANNV2_HAS_GPU
  return c10_gpu::device_count();
#else
  return 0;
#endif
}

// Index of the current GPU device; -1 in builds without GPU support.
c10::DeviceIndex lbannv2::gpu::current_device()
{
#if LBANNV2_HAS_GPU
  return c10_gpu::current_device();
#else
  return -1;
#endif
}

// Make device `d` current. Throws via LBANNV2_ASSERT_ALWAYS when `d`
// is out of range. NOTE(review): in builds without GPU support,
// num_devices() is 0, so this always throws -- presumably intended,
// but worth confirming.
void lbannv2::gpu::set_device(c10::DeviceIndex const d)
{
  LBANNV2_TRACE("lbannv2::gpu::set_device(d={})", d);
  LBANNV2_ASSERT_ALWAYS(d >= 0 && d < num_devices());
#if LBANNV2_HAS_GPU
  c10_gpu::set_device(d, false);
#endif
}

#if LBANNV2_HAS_GPU
// Map a single set of lbannv2Stream* names onto either the CUDA or
// the HIP runtime API so the stream helpers below can be written
// once.
#if LBANNV2_HAS_CUDA
#define lbannv2StreamCreate cudaStreamCreate
#define lbannv2StreamCreateWithFlags cudaStreamCreateWithFlags
#define lbannv2StreamNonBlocking cudaStreamNonBlocking
#define lbannv2StreamSync cudaStreamSynchronize
#define lbannv2StreamDestroy cudaStreamDestroy
#elif LBANNV2_HAS_ROCM
#define lbannv2StreamCreate hipStreamCreate
#define lbannv2StreamCreateWithFlags hipStreamCreateWithFlags
#define lbannv2StreamNonBlocking hipStreamNonBlocking
#define lbannv2StreamSync hipStreamSynchronize
#define lbannv2StreamDestroy hipStreamDestroy
#endif

// Create a new (default-flags) stream. Throws on runtime failure via
// LBANNV2_CHECK_GPU.
auto lbannv2::gpu::make_stream() -> Stream_t
{
  Stream_t stream;
  LBANNV2_CHECK_GPU(lbannv2StreamCreate(&stream));
  LBANNV2_TRACE("lbannv2::gpu::make_stream(): created stream {}",
                (void*) stream);
  return stream;
}

// Create a stream that does not implicitly synchronize with the
// default stream. Throws on runtime failure.
auto lbannv2::gpu::make_nonblocking_stream() -> Stream_t
{
  Stream_t stream;
  LBANNV2_CHECK_GPU(
    lbannv2StreamCreateWithFlags(&stream, lbannv2StreamNonBlocking));
  LBANNV2_TRACE("lbannv2::gpu::make_nonblocking_stream(): created stream {}",
                (void*) stream);
  return stream;
}

// Block the host until all work on `stream` completes. Throws on
// runtime failure.
void lbannv2::gpu::sync(Stream_t const stream)
{
  LBANNV2_CHECK_GPU(lbannv2StreamSync(stream));
  LBANNV2_TRACE("lbannv2::gpu::sync(stream={})", (void const*) stream);
}

// Destroy a stream created by one of the helpers above. Throws on
// runtime failure.
void lbannv2::gpu::destroy_stream(Stream_t const stream)
{
  LBANNV2_CHECK_GPU(lbannv2StreamDestroy(stream));
  LBANNV2_TRACE("lbannv2::gpu::destroy_stream(stream={})", (void*) stream);
}

#endif


================================================
FILE: src/lbannv2/utils/gpu_utils.hpp
================================================
#pragma once
#include <lbannv2_config.h>

#include <c10/core/Device.h>

#if LBANNV2_HAS_CUDA

#include <lbannv2/utils/logging.hpp>

#include <c10/cuda/CUDAFunctions.h>

#include <stdexcept>

#include <cuda_runtime.h>

#define LBANNV2_CHECK_GPU(cmd)                                                 \
  do                                                                           \
  {                                                                            \
    auto const lbannv2_check_gpu_status = (cmd);                               \
    if (lbannv2_check_gpu_status != cudaSuccess)                               \
    {                                                                          \
      LBANNV2_DEBUG("CUDA command \"" #cmd "\" failed. Error: {}",             \
                    cudaGetErrorString(lbannv2_check_gpu_status));             \
      throw std::runtime_error("CUDA command \"" #cmd "\" failed.");           \
    }                                                                          \
  } while (0)

#elif LBANNV2_HAS_ROCM

#include <lbannv2/utils/logging.hpp>

#include <c10/hip/HIPFunctions.h>
#include <c10/hip/HIPStream.h>

#include <stdexcept>

#include <hip/hip_runtime.h>

#define LBANNV2_CHECK_GPU(cmd)                                                 \
  do                                                                           \
  {                                                                            \
    auto const lbannv2_check_gpu_status = (cmd);                               \
    if (lbannv2_check_gpu_status != hipSuccess)                                \
    {                                                                          \
      LBANNV2_DEBUG("HIP command \"" #cmd "\" failed. Error: {}",              \
                    hipGetErrorString(lbannv2_check_gpu_status));              \
      throw std::runtime_error("HIP command \"" #cmd "\" failed.");            \
    }                                                                          \
  } while (0)
#endif

namespace lbannv2
{
// Alias the c10 GPU namespace and stream type so the rest of the
// code can be written once for both CUDA and ROCm backends.
#if LBANNV2_USE_C10_HIP_NAMESPACE_AND_SYMBOLS
namespace c10_gpu = c10::hip;
using TorchGPUStream_t = c10::hip::HIPStream;
inline auto& getDeviceCurrentStream = c10::hip::getCurrentHIPStream;
#elif LBANNV2_HAS_GPU
namespace c10_gpu = c10::cuda;
using TorchGPUStream_t = c10::cuda::CUDAStream;
inline auto& getDeviceCurrentStream = c10::cuda::getCurrentCUDAStream;
#endif

// Compile-time feature queries, usable in constant expressions.
inline constexpr bool has_cuda() noexcept
{
  return LBANNV2_HAS_CUDA;
}

inline constexpr bool has_hip() noexcept
{
  return LBANNV2_HAS_ROCM;
}

inline constexpr bool has_gpu() noexcept
{
  return LBANNV2_HAS_GPU;
}

namespace gpu
{

// Raw runtime stream handle for the active backend.
#if LBANNV2_HAS_CUDA
using Stream_t = cudaStream_t;
#elif LBANNV2_HAS_ROCM
using Stream_t = hipStream_t;
#endif

// Returns 'false' if no GPU support
bool is_integrated() noexcept;

// Returns 0 if no GPU support
c10::DeviceIndex num_devices() noexcept;

// Returns -1 if no GPU support
c10::DeviceIndex current_device();

// Throws if d >= num_devices() or d < 0.
void set_device(c10::DeviceIndex d);

// Raw-stream lifecycle helpers; defined in gpu_utils.cpp. All throw
// std::runtime_error on runtime API failure.
#if LBANNV2_HAS_GPU
Stream_t make_stream();
Stream_t make_nonblocking_stream();
void sync(Stream_t);
void destroy_stream(Stream_t);
#endif

}  // namespace gpu

}  // namespace lbannv2


================================================
FILE: src/lbannv2/utils/logging.cpp
================================================
////////////////////////////////////////////////////////////////////////////////
// Copyright 2014-2025 Lawrence Livermore National Security, LLC and other
// LBANN Project Developers. See the top-level LICENSE file for details.
//
// SPDX-License-Identifier: Apache-2.0
////////////////////////////////////////////////////////////////////////////////
#include "lbannv2/utils/logging.hpp"

#include <algorithm>
#include <cctype>
#include <cstdlib>
#include <memory>
#include <string>

#include <spdlog/pattern_formatter.h>
#include <spdlog/sinks/basic_file_sink.h>
#include <spdlog/sinks/stdout_color_sinks.h>

#if __has_include(<unistd.h>)
#include <unistd.h>
#define _HAVE_UNISTD_H
#endif

namespace
{
// Parse LBANNV2_LOG_LEVEL from the environment (case-insensitively)
// into an spdlog level. Unset or unrecognized values yield "info".
//
// Both "warn"/"warning" and "err"/"error" are accepted so the
// environment variable matches the level names documented for
// lbannv2::set_log_level() (which go through spdlog's from_str).
spdlog::level::level_enum get_env_log_level()
{
  char const* const var = std::getenv("LBANNV2_LOG_LEVEL");
  if (!var)
    return ::spdlog::level::info;

  std::string level_str {var};
  std::transform(level_str.begin(),
                 level_str.end(),
                 level_str.begin(),
                 [](unsigned char c) { return static_cast<char>(std::tolower(c)); });

  if (level_str == "trace")
    return ::spdlog::level::trace;
  if (level_str == "debug")
    return ::spdlog::level::debug;
  if (level_str == "info")
    return ::spdlog::level::info;
  if (level_str == "warn" || level_str == "warning")
    return ::spdlog::level::warn;
  if (level_str == "err" || level_str == "error")
    return ::spdlog::level::err;
  if (level_str == "critical")
    return ::spdlog::level::critical;
  if (level_str == "off")
    return ::spdlog::level::off;
  return ::spdlog::level::info;
}

// Best-effort hostname lookup; "<unknownhost>" when unavailable.
std::string get_hostname()
{
#ifdef _HAVE_UNISTD_H
  char buf[256];
  if (gethostname(buf, 256) == 0)
  {
    // gethostname may not NUL-terminate on truncation; find the
    // terminator explicitly instead of trusting strlen.
    return std::string {buf, std::find(buf, buf + 256, '\0')};
  }
#endif

  return "<unknownhost>";
}

// Custom "%h" pattern flag expanding to the (cached) hostname.
// The one in H2 is not exported, but it's a quick reimplementation.
class HostFlag final : public spdlog::custom_flag_formatter
{
public:
  std::unique_ptr<custom_flag_formatter> clone() const final
  {
    return spdlog::details::make_unique<HostFlag>();
  }
  void format(::spdlog::details::log_msg const&,
              ::std::tm const&,
              ::spdlog::memory_buf_t& dest) final
  {
    // Cached once; the hostname is not expected to change mid-run.
    static std::string const hostname = get_hostname();
    dest.append(hostname);
  }
};  // class HostFlag

// Default pattern: "[%h:%P:%t] [%n:%^%l%$] %v" -- custom host flag,
// then spdlog's process/thread/logger-name/level/message flags.
std::unique_ptr<::spdlog::pattern_formatter> make_default_formatter()
{
  auto formatter = std::make_unique<::spdlog::pattern_formatter>();
  formatter->add_flag<HostFlag>('h');
  formatter->set_pattern("[%h:%P:%t] [%n:%^%l%$] %v");
  // formatter->set_pattern("[%m-%d-%Y %T.%f] [%h:%P] [%n] [%^%l%$] %v");
  return formatter;
}

// Sink selection via LBANNV2_LOG_FILE: "stdout" (the default when
// unset), "stderr", or any other value treated as a file name.
::spdlog::sink_ptr make_default_sink()
{
  char const* sink_name = std::getenv("LBANNV2_LOG_FILE");
  std::string const sink_name_str(sink_name ? sink_name : "stdout");
  if (sink_name_str == "stdout")
    return std::make_shared<spdlog::sinks::stdout_color_sink_mt>();
  if (sink_name_str == "stderr")
    return std::make_shared<spdlog::sinks::stderr_color_sink_mt>();
  return std::make_shared<spdlog::sinks::basic_file_sink_mt>(sink_name_str);
}

// Build the "lbannv2" logger with the default sink, formatter, and
// environment-selected level.
std::shared_ptr<::spdlog::logger> make_default_logger()
{
  auto logger =
    std::make_shared<::spdlog::logger>("lbannv2", make_default_sink());
  logger->set_formatter(make_default_formatter());
  logger->set_level(get_env_log_level());
  return logger;
}

}  // namespace

// Set the default logger's level from a string.
// Valid inputs: trace, debug, info, warn, error, critical, off
void lbannv2::set_log_level(std::string const& lvl_str)
{
  lbannv2::default_logger()->set_level(spdlog::level::from_str(lvl_str));
}

// Lazily-constructed process-wide logger (Meyers singleton); C++11
// guarantees the static-local initialization is thread-safe.
std::shared_ptr<::spdlog::logger>& lbannv2::default_logger()
{
  static std::shared_ptr<::spdlog::logger> logger_ = make_default_logger();
  return logger_;
}


================================================
FILE: src/lbannv2/utils/logging.hpp
================================================
////////////////////////////////////////////////////////////////////////////////
// Copyright 2014-2025 Lawrence Livermore National Security, LLC and other
// LBANN Project Developers. See the top-level LICENSE file for details.
//
// SPDX-License-Identifier: Apache-2.0
////////////////////////////////////////////////////////////////////////////////
#pragma once

#include <lbannv2_config.h>

/**
 * @file Enable spdlog logging for LBANNv2.
 *
 * The symbols in this file are not exported by default so any
 * hypothetical downstream doesn't take over our logger.
 *
 * The logger macros that include `LOG` in their names take a logger
 * pointer as their first argument. The other macros use the default
 * LBANNv2 logger.
 */

#include <spdlog/spdlog.h>

// These dispatch through SPDLOG's default macros. Hence, their
// behavior is ultimately determined by the SPDLOG_ACTIVE_LEVEL macro.
#define LBANNV2_LOG_TRACE(logger, ...) SPDLOG_LOGGER_TRACE(logger, __VA_ARGS__)
#define LBANNV2_LOG_DEBUG(logger, ...) SPDLOG_LOGGER_DEBUG(logger, __VA_ARGS__)
#define LBANNV2_LOG_INFO(logger, ...) SPDLOG_LOGGER_INFO(logger, __VA_ARGS__)
#define LBANNV2_LOG_WARN(logger, ...) SPDLOG_LOGGER_WARN(logger, __VA_ARGS__)
#define LBANNV2_LOG_ERROR(logger, ...) SPDLOG_LOGGER_ERROR(logger, __VA_ARGS__)
#define LBANNV2_LOG_CRITICAL(logger, ...)                                      \
  SPDLOG_LOGGER_CRITICAL(logger, __VA_ARGS__)

#define LBANNV2_TRACE(...)                                                     \
  LBANNV2_LOG_TRACE(::lbannv2::default_logger(), __VA_ARGS__)
#define LBANNV2_DEBUG(...)                                                     \
  LBANNV2_LOG_DEBUG(::lbannv2::default_logger(), __VA_ARGS__)
#define LBANNV2_INFO(...)                                                      \
  LBANNV2_LOG_INFO(::lbannv2::default_logger(), __VA_ARGS__)
#define LBANNV2_WARN(...)                                                      \
  LBANNV2_LOG_WARN(::lbannv2::default_logger(), __VA_ARGS__)
#define LBANNV2_ERROR(...)                                                     \
  LBANNV2_LOG_ERROR(::lbannv2::default_logger(), __VA_ARGS__)
#define LBANNV2_CRITICAL(...)                                                  \
  LBANNV2_LOG_CRITICAL(::lbannv2::default_logger(), __VA_ARGS__)

namespace lbannv2
{
/** @brief Get LBANNv2's default logger.
 *
 *  The default logger is configured through the environment variable
 *  `LBANNV2_LOG_FILE`. Acceptable values are 'stdout', 'stderr', and
 *  a valid filename pattern.
 *
 *  @todo Enable logging to a process-specific file.
 */
std::shared_ptr<::spdlog::logger>& default_logger();

/** @brief Set the logging level.
 *
 *  \param[in] level Desired log level. Valid choices are "trace", "debug",
 *                   "info", "warn", "error", "critical", and "off".
 */
void set_log_level(std::string const& level);

}  // namespace lbannv2


================================================
FILE: src/lbannv2/utils/tensor_helpers.hpp
================================================
////////////////////////////////////////////////////////////////////////////////
// Copyright 2014-2025 Lawrence Livermore National Security, LLC and other
// LBANN Project Developers. See the top-level LICENSE file for details.
//
// SPDX-License-Identifier: Apache-2.0
////////////////////////////////////////////////////////////////////////////////
#pragma once

#include "lbannv2/utils/errors.hpp"

#include <ATen/NamedTensorUtils.h>
#include <ATen/Tensor.h>
#include <c10/util/ArrayRef.h>

namespace lbannv2
{

/** @brief Determines if t is associated with LBANN (the PrivateUse1
 *         backend, which LBANNv2 claims). */
inline bool is_lbann(at::Tensor const& t)
{
  return t.is_privateuseone();
}

/** @brief True for a defined zero-dimensional (scalar) tensor. */
inline bool is_scalar(at::Tensor const& t)
{
  return t.defined() && (t.dim() == 0);
}

/** @brief Overwrite the device recorded on a DataPtr.
 *
 *  Uses c10's "unsafe" setter, so no consistency checks are
 *  performed against the actual allocation.
 */
inline void set_data_ptr_device(c10::DataPtr& dp, c10::Device d)
{
  dp.unsafe_set_device(std::move(d));
}

/** @brief Overwrite the device recorded on a Storage's DataPtr. */
inline void set_data_ptr_device(c10::Storage const& s, c10::Device d)
{
  set_data_ptr_device(s.mutable_data_ptr(), std::move(d));
}

/** @brief Copy storage offset, sizes, and strides from src to dst. */
inline void sync_metadata(at::Tensor const& src, at::Tensor& dst)
{
  auto* dst_tensor_info = dst.unsafeGetTensorImpl();
  dst_tensor_info->set_storage_offset(src.storage_offset());
  dst_tensor_info->set_sizes_and_strides(src.sizes(), src.strides());

  // I assume this restores named dimensions? Not sure if it
  // should be here or not. See "alias_with_sizes_and_strides"
  // in <pytorch>/aten/src/ATen/native/TensorShape.cpp
  at::namedinference::propagate_names(dst, src);
}

/** @brief Make an alias of the tensor on a new backend
 *
 *  This function can be used to produce aliases with diffent devices,
 *  different dispatch keys, or both (or neither, I suppose).
 *
 *  @post The original tensor will keep its device type and keys, but
 *        its DataPtr will appear to be on the new device if queried.
 */
inline at::Tensor alias_as_device(at::Tensor const& orig_tensor,
                                  c10::Device const& d,
                                  c10::DispatchKeySet ks)
{
  // Make (soft) copy of the storage and set the device to be the real
  // underlying device.
  at::Storage aliased_storage(orig_tensor.storage());
  set_data_ptr_device(aliased_storage, d);

  // Set up a view with this storage, using the modified keyset.
  auto alias_tensor =
    at::detail::make_tensor<at::TensorImpl>(c10::TensorImpl::VIEW,
                                            std::move(aliased_storage),
                                            std::move(ks),
                                            orig_tensor.dtype());

  // Setup sizes, strides, and storage offset.
  sync_metadata(orig_tensor, alias_tensor);

  // Quick sanity check before we go
  LBANNV2_ASSERT(alias_tensor.const_data_ptr() == orig_tensor.const_data_ptr(),
                 std::runtime_error,
                 "Aliasing tensor data has failed");

  return alias_tensor;
}

/** @brief Minimal tensor stringification.
 *
 *  Returns "[ {device type}{data type}[d1, d2, ..., dn] ]", for
 *  example, "[ lbannFloatType[2, 2] ]" for a 2x2 Float32 tensor on
 *  the LBANN backend.
 */
inline std::string to_str(at::Tensor const& t)
{
  std::ostringstream oss;
  oss << "[ " << t.toString() << t.sizes() << " ]";
  return oss.str();
}

/** @brief ArrayRef stringification */
template <typename T>
std::string to_str(c10::ArrayRef<T> const& ar)
{
  std::ostringstream oss;
  oss << ar;
  return oss.str();
}

}  // namespace lbannv2


================================================
FILE: test/CMakeLists.txt
================================================
################################################################################
## Copyright 2014-2025 Lawrence Livermore National Security, LLC and other
## LBANN Project Developers. See the top-level LICENSE file for details.
##
## SPDX-License-Identifier: Apache-2.0
################################################################################
# Prefer an installed Catch2 (>= 3.0.0) if one is found; otherwise
# fetch and build the pinned tag.
include(FetchContent)
FetchContent_Declare(
  Catch2
  GIT_REPOSITORY https://github.com/catchorg/Catch2
  GIT_TAG fa43b77429ba76c462b1898d6cd2f2d7a9416b14 # v3.7.1
  FIND_PACKAGE_ARGS 3.0.0 CONFIG)
FetchContent_MakeAvailable(Catch2)

add_executable(catch-tests
  cpp/test_pointer_registry.cpp
)

# MI300A allocator tests only apply when that support is compiled in.
if (LBANNV2_UNKNOWN_MI300A OR LBANNV2_WITH_MI300A)
  target_sources(catch-tests
    PRIVATE
    cpp/test_mi300a_allocator.cpp
  )
endif ()

target_link_libraries(catch-tests
  PRIVATE
  lbann::lbannv2
  Catch2::Catch2WithMain
)

# NOTE(review): CXX_EXTENSIONS ON enables compiler-specific C++20
# extensions (e.g. gnu++20) -- confirm this is deliberate; projects
# usually set it OFF.
set_target_properties(catch-tests
  PROPERTIES
  CXX_STANDARD 20
  CXX_STANDARD_REQUIRED ON
  CXX_EXTENSIONS ON
)


================================================
FILE: test/cpp/test_empty_tensor.cpp
================================================
////////////////////////////////////////////////////////////////////////////////
// Copyright 2014-2025 Lawrence Livermore National Security, LLC and other
// LBANN Project Developers. See the top-level LICENSE file for details.
//
// SPDX-License-Identifier: Apache-2.0
////////////////////////////////////////////////////////////////////////////////
#include <lbannv2_config.h>

#include <lbannv2/ops/empty_tensor.hpp>
#include <lbannv2/utils/device_helpers.hpp>

#include <ATen/Tensor.h>
#include <c10/core/ScalarType.h>
#include <c10/util/ArrayRef.h>
#include <catch2/catch_test_macros.hpp>
#include <catch2/generators/catch_generators.hpp>
#include <catch2/matchers/catch_matchers_string.hpp>

namespace
{
// This factory function can throw, and we cannot wrap an assignment
// in Catch's `REQUIRE_NOTHROW`/`REQUIRE_THROWS*` macros. We use this
// simple wrapper to facilitate things. eglot+clangd is able to still
// forward the inlay hints from `empty_lbann` out to the
// `make_empty_tensor` signature, so that's cool.
//
// The result is written into `t` so the caller can still inspect the
// tensor after the Catch macro returns.
template <typename... Args>
void make_empty_tensor(at::Tensor& t, Args&&... args)
{
  t = lbannv2::empty_lbann(std::forward<Args>(args)...);
}
}  // namespace

TEST_CASE("empty_lbann", "[ops][empty]")
{
  at::Tensor t;
  // Two LBANN (PrivateUse1) devices; presumably index 0 is the
  // CPU-backed one and index 1 the GPU-backed one -- TODO confirm
  // against the device mapping.
  c10::Device lbann_cpu {lbannv2::LBANNDeviceT, 0},
    lbann_gpu {lbannv2::LBANNDeviceT, 1};
  SECTION("Zero-size tensor is ok")
  {
#if LBANNV2_HAS_GPU
    auto lbann_device = GENERATE_COPY(values({lbann_cpu, lbann_gpu}));
#else
    auto lbann_device = lbann_cpu;
#endif

    // A 1-D tensor with extent 0 must construct cleanly.
    REQUIRE_NOTHROW(make_empty_tensor(t,
                                      c10::IntArrayRef {0},
                                      c10::ScalarType::Float,
                                      std::nullopt,
                                      lbann_device,
                                      false,
                                      std::nullopt));
    REQUIRE(t.dim() == 1);
    REQUIRE(t.sizes() == c10::IntArrayRef {0});
    REQUIRE(t.strides() == c10::IntArrayRef {1});
    REQUIRE(t.is_privateuseone());
    REQUIRE(t.dtype().toScalarType() == c10::ScalarType::Float);
    REQUIRE(t.key_set().has(lbannv2::LBANNDispKey));
    REQUIRE_FALSE(t.is_pinned());
  }

  SECTION("Nonzero tensor is ok")
  {
#if LBANNV2_HAS_GPU
    auto lbann_device = GENERATE_COPY(values({lbann_cpu, lbann_gpu}));
#else
    auto lbann_device = lbann_cpu;
#endif
    // 2-D case: expect contiguous (row-major) strides.
    REQUIRE_NOTHROW(make_empty_tensor(t,
                                      c10::IntArrayRef {3, 4},
                                      c10::ScalarType::Float,
                                      std::nullopt,
                                      lbann_device,
                                      false,
                                      std::nullopt));
    REQUIRE(t.dim() == 2);
    REQUIRE(t.sizes() == c10::IntArrayRef {3, 4});
    REQUIRE(t.strides() == c10::IntArrayRef {4, 1});
    REQUIRE(t.is_privateuseone());
    REQUIRE(t.dtype().toScalarType() == c10::ScalarType::Float);
    REQUIRE(t.key_set().has(lbannv2::LBANNDispKey));
    REQUIRE_FALSE(t.is_pinned());

    // 4-D case: again expect contiguous strides.
    REQUIRE_NOTHROW(make_empty_tensor(t,
                                      c10::IntArrayRef {2, 3, 4, 5},
                                      c10::ScalarType::Float,
                                      std::nullopt,
                                      lbann_device,
                                      false,
                                      std::nullopt));
    REQUIRE(t.dim() == 4);
    REQUIRE(t.sizes() == c10::IntArrayRef {2, 3, 4, 5});
    REQUIRE(t.strides() == c10::IntArrayRef {60, 20, 5, 1});
    REQUIRE(t.is_privateuseone());
    REQUIRE(t.dtype().toScalarType() == c10::ScalarType::Float);
    REQUIRE(t.key_set().has(lbannv2::LBANNDispKey));
    REQUIRE_FALSE(t.is_pinned());
  }

  SECTION("Non-LBANN devices throw")
  {
    // Constructing on a plain CPU device must be rejected with the
    // documented error message.
    REQUIRE_THROWS_WITH(
      make_empty_tensor(t,
                        c10::IntArrayRef {3, 4},
                        c10::ScalarType::Float,
                        std::nullopt,
                        c10::DeviceType::CPU,
                        false,
                        std::nullopt),
      "LBANN should only be constructing tensors on \"PrivateUse1\" backend");
  }
}

namespace
{

template <typename... Args>
void make_empty_strided_tensor(at::Tensor& t, Args&&... args)
{
  t = lbannv2::empty_strided_lbann(std::forward<Args>(args)...);
}

}  // namespace

// Exercises the LBANN `empty_strided` factory: caller-specified
// (non-contiguous) strides must be preserved exactly.
TEST_CASE("empty_strided_lbann", "[ops][empty]")
{
  at::Tensor t;
  // Device index 0 appears to map to LBANN-CPU and index 1 to
  // LBANN-GPU -- TODO confirm against the device-index convention.
  c10::Device lbann_cpu {lbannv2::LBANNDeviceT, 0},
    lbann_gpu {lbannv2::LBANNDeviceT, 1};

  SECTION("Zero-size tensor is ok")
  {
#if LBANNV2_HAS_GPU
    // With GPU support compiled in, run this section once per device.
    auto lbann_device = GENERATE_COPY(values({lbann_cpu, lbann_gpu}));
#else
    auto lbann_device = lbann_cpu;
#endif
    REQUIRE_NOTHROW(make_empty_strided_tensor(t,
                                              c10::IntArrayRef {0},
                                              c10::IntArrayRef {1},
                                              c10::ScalarType::Float,
                                              std::nullopt,
                                              lbann_device,
                                              false));
    REQUIRE(t.dim() == 1);
    REQUIRE(t.sizes() == c10::IntArrayRef {0});
    REQUIRE(t.strides() == c10::IntArrayRef {1});
    REQUIRE(t.is_privateuseone());
    REQUIRE(t.dtype().toScalarType() == c10::ScalarType::Float);
    REQUIRE(t.key_set().has(lbannv2::LBANNDispKey));
    REQUIRE_FALSE(t.is_pinned());
  }

  SECTION("Nonzero tensor is ok")
  {
#if LBANNV2_HAS_GPU
    auto lbann_device = GENERATE_COPY(values({lbann_cpu, lbann_gpu}));
#else
    auto lbann_device = lbann_cpu;
#endif
    // Strides {8, 2} are deliberately NOT the contiguous layout; the
    // factory must honor them verbatim.
    REQUIRE_NOTHROW(make_empty_strided_tensor(t,
                                              c10::IntArrayRef {3, 4},
                                              c10::IntArrayRef {8, 2},
                                              c10::ScalarType::Float,
                                              std::nullopt,
                                              lbann_device,
                                              false));
    REQUIRE(t.dim() == 2);
    REQUIRE(t.sizes() == c10::IntArrayRef {3, 4});
    REQUIRE(t.strides() == c10::IntArrayRef {8, 2});
    REQUIRE(t.is_privateuseone());
    REQUIRE(t.dtype().toScalarType() == c10::ScalarType::Float);
    REQUIRE(t.key_set().has(lbannv2::LBANNDispKey));
    REQUIRE_FALSE(t.is_pinned());

    // NOTE(review): unlike the sibling `empty_lbann` test, this second
    // call passes std::nullopt for the device rather than the generated
    // `lbann_device` -- possibly a copy-paste slip, possibly deliberate
    // coverage of the default-device path. Confirm which is intended.
    REQUIRE_NOTHROW(make_empty_strided_tensor(t,
                                              c10::IntArrayRef {2, 3, 4, 5},
                                              c10::IntArrayRef {120, 40, 10, 2},
                                              c10::ScalarType::Float,
                                              std::nullopt,
                                              std::nullopt,
                                              false));
    REQUIRE(t.dim() == 4);
    REQUIRE(t.sizes() == c10::IntArrayRef {2, 3, 4, 5});
    REQUIRE(t.strides() == c10::IntArrayRef {120, 40, 10, 2});
    REQUIRE(t.is_privateuseone());
    REQUIRE(t.dtype().toScalarType() == c10::ScalarType::Float);
    REQUIRE(t.key_set().has(lbannv2::LBANNDispKey));
    REQUIRE_FALSE(t.is_pinned());
  }

  SECTION("Non-LBANN devices throw")
  {
    // Requesting a plain CPU device must be rejected with this exact
    // message.
    REQUIRE_THROWS_WITH(
      make_empty_strided_tensor(t,
                                c10::IntArrayRef {3, 4},
                                c10::IntArrayRef {8, 2},
                                c10::ScalarType::Float,
                                std::nullopt,
                                c10::DeviceType::CPU,
                                false),
      "LBANN should only be constructing tensors on \"PrivateUse1\" backend");
  }
}


================================================
FILE: test/cpp/test_helpers.hpp
================================================
////////////////////////////////////////////////////////////////////////////////
// Copyright 2014-2025 Lawrence Livermore National Security, LLC and other
// LBANN Project Developers. See the top-level LICENSE file for details.
//
// SPDX-License-Identifier: Apache-2.0
////////////////////////////////////////////////////////////////////////////////
#pragma once

#include "lbannv2_config.h"

#include <lbannv2/utils/gpu_utils.hpp>

#include <catch2/catch_test_macros.hpp>

#if LBANNV2_WITH_MI300A
#define SKIP_WHEN_NO_MI300A()
#elif LBANNV2_WITHOUT_MI300A
#define SKIP_WHEN_NO_MI300A() SKIP("No MI300A support")
#elif LBANNV2_UNKNOWN_MI300A
#include <h2/gpu/runtime.hpp>
#define SKIP_WHEN_NO_MI300A()                                                  \
  do                                                                           \
  {                                                                            \
    if (!lbannv2::gpu::is_integrated())                                        \
    {                                                                          \
      SKIP("No MI300A support");                                               \
    }                                                                          \
  } while (0)
#endif


================================================
FILE: test/cpp/test_mi300a_allocator.cpp
================================================
////////////////////////////////////////////////////////////////////////////////
// Copyright 2014-2025 Lawrence Livermore National Security, LLC and other
// LBANN Project Developers. See the top-level LICENSE file for details.
//
// SPDX-License-Identifier: Apache-2.0
////////////////////////////////////////////////////////////////////////////////
#include <lbannv2/memory/mi300a_allocator.hpp>
#include <lbannv2/memory/registry.hpp>
#include <lbannv2/utils/gpu_utils.hpp>

#include "test_helpers.hpp"

#include <ATen/hip/HIPContextLight.h>
#include <c10/core/Allocator.h>
#include <c10/core/CPUAllocator.h>
#include <c10/hip/HIPCachingAllocator.h>
#include <c10/hip/HIPStream.h>

#include <catch2/catch_test_macros.hpp>
#include <catch2/matchers/catch_matchers_string.hpp>

namespace
{

// Assignment expressions cannot appear directly inside Catch's
// REQUIRE_NOTHROW, so route the raw allocation result through an
// out-parameter.
void do_raw_allocate(void** out, size_t nbytes, lbannv2::MI300Allocator& alloc)
{
  *out = alloc.raw_allocate(nbytes);
}

}  // namespace

// Round-trips a raw 64-byte allocation through the singleton MI300A
// allocator: allocate must not throw, must return non-null, and the
// matching deallocate must not throw.
TEST_CASE("MI300Allocator::raw_allocate and MI300Allocator::raw_deallocate",
          "[memory][mi300a]")
{
  SKIP_WHEN_NO_MI300A();

  auto& alloc = lbannv2::MI300Allocator::instance();
  size_t const size = 64;
  void* ptr = nullptr;
  REQUIRE_NOTHROW(do_raw_allocate(&ptr, size, alloc));

  CHECK(ptr != nullptr);

  REQUIRE_NOTHROW(alloc.raw_deallocate(ptr));
}

namespace
{
// The plain (native) CPU device.
c10::Device lbann_cpu() noexcept
{
  return c10::Device {c10::kCPU};
}
// The native GPU device whose index is whatever LBANN's GPU runtime
// reports as current.
c10::Device lbann_gpu() noexcept
{
  auto const dev_idx =
    static_cast<c10::DeviceIndex>(lbannv2::gpu::current_device());
  return c10::Device {c10::kCUDA, dev_idx};
}
}  // namespace

// Allocations from the MI300A allocator report a CPU device and are
// tracked in the global pointer registry for exactly as long as the
// owning DataPtr lives.
TEST_CASE("MI300Allocator::allocate and MI300Allocator::deallocate",
          "[memory][mi300a]")
{
  SKIP_WHEN_NO_MI300A();

  auto& alloc = lbannv2::MI300Allocator::instance();
  size_t const size = 64;

  void* raw_ptr = nullptr;
  {
    auto ptr = alloc.allocate(size);
    // Keep the raw address so we can query the registry after `ptr`
    // is destroyed.
    raw_ptr = ptr.get();
    CHECK(ptr.device() == lbann_cpu());
    CHECK(lbannv2::pointer_registry().known(raw_ptr));
  }

  // DataPtr goes out of scope, should be deleted.

  CHECK(!lbannv2::pointer_registry().known(raw_ptr));
}

// The "kernel" here is loosely inspired by Aluminum's "GPUWait", but
// less fussy about things like "cache-line allocation" and
// "atomics"... All I need is something to guarantee the stream isn't
// synced before the second allocation, and this saves me the trouble
// of compiling a HIP kernel.
TEST_CASE("MI300Allocator stream semantics are working", "[memory][mi300a]")
{
  // FIX: this test was missing the skip guard that the other MI300A
  // tests have. It host-dereferences hipMalloc'd memory (`*wait_mem`)
  // and uses the MI300A allocator, so it must not run on systems
  // without MI300A (unified-memory) support.
  SKIP_WHEN_NO_MI300A();

  auto const gpu = lbann_gpu();

  // Some memory we can use later. Written directly from the host,
  // which relies on MI300A unified memory.
  int32_t* wait_mem;
  LBANNV2_CHECK_GPU(hipMalloc(&wait_mem, sizeof(int32_t)));
  *wait_mem = 0;

  int32_t const wait_value = 1;
  auto& alloc = lbannv2::MI300Allocator::instance();
  size_t const size = 64;

  // Test plan:
  // open block
  //   do an allocation
  //   migrate allocation to GPU
  //   "run a kernel" on the same stream
  // close block (delete the allocation)
  // allocate new buffer
  // check old and new buffers have different addresses
  //
  // If the allocator respected stream semantics, the first buffer is
  // still in use by the in-flight "kernel" and must NOT be recycled
  // into the second allocation.

  auto torch_stream = lbannv2::getDeviceCurrentStream(gpu.index());
  void* orig_ptr = nullptr;  // never dereferenced
  {
    auto ptr = alloc.allocate(size);
    // cache the buffer address -- NEVER DEREFERENCED
    orig_ptr = ptr.get();

    // Add the ptr to the stream on GPU
    lbannv2::migrate_ptr(ptr, gpu, torch_stream);
    // Fake a kernel on the stream: blocks the stream until *wait_mem
    // becomes wait_value.
    LBANNV2_CHECK_GPU(hipStreamWaitValue32(
      torch_stream, wait_mem, wait_value, hipStreamWaitValueEq));
  }
  // GPU allocation will "FREE_REQUESTED" here, but it should NOT be
  // available for reuse

  auto ptr = alloc.allocate(size);
  CHECK(ptr.get() != orig_ptr);  // NOT REQUIRE -- need to clean up.

  // Write the new value, releasing the fake kernel.
  *wait_mem = wait_value;

  // Ensure the "kernel" is done.
  torch_stream.synchronize();

  // Free our wait memory
  LBANNV2_CHECK_GPU(hipFree(wait_mem));
}


================================================
FILE: test/cpp/test_pointer_registry.cpp
================================================
////////////////////////////////////////////////////////////////////////////////
// Copyright 2014-2025 Lawrence Livermore National Security, LLC and other
// LBANN Project Developers. See the top-level LICENSE file for details.
//
// SPDX-License-Identifier: Apache-2.0
////////////////////////////////////////////////////////////////////////////////
#include <lbannv2/memory/registry.hpp>

#include <catch2/catch_test_macros.hpp>
#include <catch2/matchers/catch_matchers_string.hpp>

// Tests the ordering predicate used by PointerRegistry: a range
// compares "less" only when it lies entirely before the other
// (abutting is allowed); any overlap compares false in both
// directions.
TEST_CASE("RangeLessAndDisjoint", "[memory][registry]")
{
  std::vector<unsigned char> buffer(8);
  lbannv2::PointerRegistry::RangeLessAndDisjoint rng_less;

  // One-past-the-end address of the buffer. FIX: computed via
  // data() + size() because the previous `&buffer[8]` applies
  // operator[] out of range, which is undefined behavior even when
  // only the address is taken.
  unsigned char* const buffer_end = buffer.data() + buffer.size();

  SECTION("Non-overlapping ranges behave sanely")
  {
    CHECK(rng_less({&buffer[1], &buffer[2]}, {&buffer[3], &buffer[4]}));
    CHECK_FALSE(rng_less({&buffer[3], &buffer[4]}, {&buffer[1], &buffer[2]}));
  }

  SECTION("Abutting ranges are nonoverlapping")
  {
    CHECK(rng_less({&buffer[1], &buffer[2]}, {&buffer[2], &buffer[3]}));
    CHECK_FALSE(rng_less({&buffer[2], &buffer[3]}, {&buffer[1], &buffer[2]}));
  }

  SECTION("Identical ranges")
  {
    CHECK_FALSE(rng_less({&buffer[1], &buffer[4]}, {&buffer[1], &buffer[4]}));
  }

  SECTION("Partially overlapping ranges")
  {
    CHECK_FALSE(rng_less({&buffer[1], &buffer[4]}, {&buffer[2], &buffer[5]}));
    CHECK_FALSE(rng_less({&buffer[2], &buffer[5]}, {&buffer[1], &buffer[4]}));
  }

  SECTION("One range proper subset of the other")
  {
    CHECK_FALSE(rng_less({&buffer[1], buffer_end}, {&buffer[3], &buffer[4]}));
    CHECK_FALSE(rng_less({&buffer[3], &buffer[4]}, {&buffer[1], buffer_end}));
  }

  SECTION("Zero-size ranges work appropriately")
  {
    // Empty-vs-empty, empty-vs-nonempty, and the bare-pointer
    // overloads must all respect the same "entirely before" ordering.
    CHECK(rng_less({&buffer[1], &buffer[1]}, {&buffer[2], &buffer[2]}));
    CHECK_FALSE(rng_less({&buffer[2], &buffer[2]}, {&buffer[1], &buffer[1]}));

    CHECK(rng_less({&buffer[1], &buffer[1]}, {&buffer[2], &buffer[4]}));
    CHECK_FALSE(rng_less({&buffer[2], &buffer[4]}, {&buffer[1], &buffer[1]}));

    CHECK(rng_less({&buffer[1], &buffer[2]}, {&buffer[2], &buffer[2]}));
    CHECK_FALSE(rng_less({&buffer[2], &buffer[2]}, {&buffer[1], &buffer[2]}));

    CHECK(rng_less({&buffer[1], &buffer[2]}, &buffer[2]));
    CHECK(rng_less(&buffer[1], {&buffer[2], &buffer[3]}));

    CHECK_FALSE(rng_less(&buffer[1], {&buffer[1], &buffer[2]}));
    CHECK_FALSE(rng_less({&buffer[1], &buffer[2]}, &buffer[1]));

    CHECK_FALSE(rng_less(&buffer[1], {&buffer[1], &buffer[1]}));
    CHECK_FALSE(rng_less({&buffer[1], &buffer[1]}, &buffer[1]));
  }
}

namespace
{
// Size in bytes of the half-open address range [r.first, r.second).
// Assumes r.first <= r.second (all callers in this file construct
// ranges that way). Uses static_cast instead of the previous C-style
// casts, per modern C++ idiom.
size_t rng_bytes(std::pair<void*, void*> const& r)
{
  return static_cast<size_t>(static_cast<std::byte*>(r.second)
                             - static_cast<std::byte*>(r.first));
}
}  // namespace

// PointerRegistry::add(): disjoint (possibly abutting) ranges register
// cleanly and the count/byte accounting tracks them; zero-size ranges
// are allowed but interact specially with the start points of other
// ranges.
TEST_CASE("PointerRegistry::add()", "[memory][registry]")
{
  using RangeT = std::pair<void*, void*>;

  lbannv2::PointerRegistry registry;
  std::vector<unsigned char> buffer(32);

  // Establish preconditions
  REQUIRE(registry.num_registered() == 0UL);
  REQUIRE(registry.bytes_registered() == 0UL);

  SECTION("Adding nonoverlapping regions is successful.")
  {
    // rng3 abuts rng2 and rng4 abuts both rng1 and rng2; ranges are
    // also deliberately added out of address order.
    RangeT const rng1 = {&buffer[4], &buffer[8]};
    RangeT const rng2 = {&buffer[12], &buffer[16]};
    RangeT const rng3 = {&buffer[16], &buffer[20]};
    RangeT const rng4 = {&buffer[8], &buffer[12]};

    size_t expected_bytes = 0UL;
    REQUIRE_NOTHROW(registry.add(rng1.first, rng_bytes(rng1), nullptr));
    expected_bytes += rng_bytes(rng1);

    REQUIRE(registry.num_registered() == 1UL);
    REQUIRE(registry.bytes_registered() == expected_bytes);

    REQUIRE_NOTHROW(registry.add(rng2.first, rng_bytes(rng2), nullptr));
    expected_bytes += rng_bytes(rng2);

    REQUIRE(registry.num_registered() == 2UL);
    REQUIRE(registry.bytes_registered() == expected_bytes);

    REQUIRE_NOTHROW(registry.add(rng3.first, rng_bytes(rng3), nullptr));
    expected_bytes += rng_bytes(rng3);

    REQUIRE(registry.num_registered() == 3UL);
    REQUIRE(registry.bytes_registered() == expected_bytes);

    REQUIRE_NOTHROW(registry.add(rng4.first, rng_bytes(rng4), nullptr));
    expected_bytes += rng_bytes(rng4);

    REQUIRE(registry.num_registered() == 4UL);
    REQUIRE(registry.bytes_registered() == expected_bytes);
  }

  SECTION("Zero-size regions")
  {
    SECTION("Adding zero-size regions is ok.")
    {
      RangeT const rng1 = {&buffer[0], &buffer[0]};
      RangeT const rng2 = {&buffer[2], &buffer[2]};

      // Zero-size regions count toward num_registered but contribute
      // no bytes.
      REQUIRE_NOTHROW(registry.add(rng1.first, rng_bytes(rng1), nullptr));
      REQUIRE(registry.num_registered() == 1UL);
      REQUIRE(registry.bytes_registered() == 0UL);

      REQUIRE_NOTHROW(registry.add(rng2.first, rng_bytes(rng2), nullptr));
      REQUIRE(registry.num_registered() == 2UL);
      REQUIRE(registry.bytes_registered() == 0UL);
    }

    SECTION("Zero-size regions are not valid start points for other regions")
    {
      RangeT const zero_rng = {&buffer[0], &buffer[0]};
      RangeT const other_rng = {&buffer[0], &buffer[2]};

      REQUIRE_NOTHROW(
        registry.add(zero_rng.first, rng_bytes(zero_rng), nullptr));
      REQUIRE(registry.num_registered() == 1UL);
      REQUIRE(registry.bytes_registered() == 0UL);

      // A nonempty region starting at an already-registered zero-size
      // region's address is rejected as an overlap.
      REQUIRE_THROWS_WITH(
        registry.add(other_rng.first, rng_bytes(other_rng), nullptr),
        "Address range overlaps existing range");
    }

    SECTION("Zero-size regions are valid end points for other regions")
    {
      RangeT const other_rng = {&buffer[0], &buffer[2]};
      RangeT const zero_rng = {&buffer[2], &buffer[2]};

      REQUIRE_NOTHROW(
        registry.add(zero_rng.first, rng_bytes(zero_rng), nullptr));
      REQUIRE(registry.num_registered() == 1UL);
      REQUIRE(registry.bytes_registered() == 0UL);

      // The nonempty region's (exclusive) end coincides with the
      // zero-size region's address -- this is NOT an overlap.
      REQUIRE_NOTHROW(
        registry.add(other_rng.first, rng_bytes(other_rng), nullptr));
      REQUIRE(registry.num_registered() == 2UL);
      REQUIRE(registry.bytes_registered() == rng_bytes(other_rng));
    }
  }
}

// PointerRegistry::remove(): only the exact start ("context") pointer
// of a registered range may be removed; interior or unknown pointers
// fail and leave the registry untouched.
TEST_CASE("PointerRegistry::remove()", "[memory][registry]")
{
  using RangeT = std::pair<void*, void*>;

  lbannv2::PointerRegistry registry;
  std::vector<unsigned char> buffer(32);

  // Establish preconditions
  REQUIRE(registry.num_registered() == 0UL);
  REQUIRE(registry.bytes_registered() == 0UL);

  SECTION("Removing a context pointer works")
  {
    RangeT const rng = {&buffer[4], &buffer[8]};
    REQUIRE_NOTHROW(registry.add(rng.first, rng_bytes(rng), nullptr));
    REQUIRE(registry.num_registered() == 1UL);
    REQUIRE(registry.bytes_registered() == rng_bytes(rng));

    REQUIRE_NOTHROW(registry.remove(rng.first));
    REQUIRE(registry.num_registered() == 0UL);
    REQUIRE(registry.bytes_registered() == 0UL);
  }

  SECTION("Removing a known non-context pointer fails")
  {
    RangeT const rng = {&buffer[4], &buffer[8]};
    // Interior of the registered range -- known, but not its start.
    void* const noncontext_ptr = &buffer[6];

    REQUIRE_NOTHROW(registry.add(rng.first, rng_bytes(rng), nullptr));
    REQUIRE(registry.num_registered() == 1UL);
    REQUIRE(registry.bytes_registered() == rng_bytes(rng));

    REQUIRE_THROWS_WITH(registry.remove(noncontext_ptr),
                        "Cannot remove ptr; not beginning of range.");
    // A failed remove must not alter the registry.
    REQUIRE(registry.num_registered() == 1UL);
    REQUIRE(registry.bytes_registered() == rng_bytes(rng));
  }

  SECTION("Removing an unknown pointer fails")
  {
    RangeT const rng = {&buffer[4], &buffer[8]};
    // Entirely outside the registered range.
    void* const unknown_ptr = &buffer[16];

    REQUIRE_NOTHROW(registry.add(rng.first, rng_bytes(rng), nullptr));
    REQUIRE(registry.num_registered() == 1UL);
    REQUIRE(registry.bytes_registered() == rng_bytes(rng));

    REQUIRE_THROWS_AS(registry.remove(unknown_ptr), lbannv2::UnknownAddress);
    // A failed remove must not alter the registry.
    REQUIRE(registry.num_registered() == 1UL);
    REQUIRE(registry.bytes_registered() == rng_bytes(rng));
  }

  SECTION("Removing a zero-size region is ok")
  {
    RangeT const rng = {&buffer[2], &buffer[2]};
    REQUIRE_NOTHROW(registry.add(rng.first, rng_bytes(rng), nullptr));
    REQUIRE(registry.num_registered() == 1UL);
    REQUIRE(registry.bytes_registered() == 0UL);

    REQUIRE_NOTHROW(registry.remove(rng.first));
    REQUIRE(registry.num_registered() == 0UL);
    REQUIRE(registry.bytes_registered() == 0UL);
  }
}

// PointerRegistry::known(): any address inside a registered half-open
// range [start, end) is known, including the start of a zero-size
// range; addresses at or past `end`, or before `start`, are not.
TEST_CASE("PointerRegistry::known()", "[memory][registry]")
{
  using RangeT = std::pair<void*, void*>;

  lbannv2::PointerRegistry registry;
  std::vector<unsigned char> buffer(32);

  // Establish preconditions
  REQUIRE(registry.num_registered() == 0UL);
  REQUIRE(registry.bytes_registered() == 0UL);

  SECTION("Pointers in registered ranges are known")
  {
    RangeT const rng = {&buffer[4], &buffer[8]};
    void const* const context_ptr = &buffer[4];     // start of range
    void const* const noncontext_ptr = &buffer[6];  // interior of range
    REQUIRE_NOTHROW(registry.add(rng.first, rng_bytes(rng), nullptr));
    REQUIRE(registry.num_registered() == 1UL);
    REQUIRE(registry.bytes_registered() == rng_bytes(rng));

    REQUIRE(registry.known(context_ptr));
    REQUIRE(registry.known(noncontext_ptr));
  }

  SECTION("Registered pointer in size-zero ranges are known")
  {
    RangeT const rng = {&buffer[4], &buffer[4]};
    void const* const context_ptr = &buffer[4];
    REQUIRE_NOTHROW(registry.add(rng.first, rng_bytes(rng), nullptr));
    REQUIRE(registry.num_registered() == 1UL);
    REQUIRE(registry.bytes_registered() == rng_bytes(rng));

    REQUIRE(registry.known(context_ptr));
  }

  SECTION("Pointers outside registered ranges are not known")
  {
    RangeT const rng = {&buffer[4], &buffer[8]};
    void const* const unknown_low_ptr = &buffer[2];   // before the range
    void const* const unknown_ub_ptr = &buffer[8];    // exclusive end
    void const* const unknown_high_ptr = &buffer[14]; // past the range

    REQUIRE_NOTHROW(registry.add(rng.first, rng_bytes(rng), nullptr));
    REQUIRE(registry.num_registered() == 1UL);
    REQUIRE(registry.bytes_registered() == rng_bytes(rng));

    REQUIRE_FALSE(registry.known(unknown_low_ptr));
    REQUIRE_FALSE(registry.known(unknown_ub_ptr));
    REQUIRE_FALSE(registry.known(unknown_high_ptr));
  }
}

// PointerRegistry::get_context(): any pointer inside a registered
// range maps back to that range's start ("context") pointer; pointers
// outside every range throw UnknownAddress.
TEST_CASE("PointerRegistry::get_context()", "[memory][registry]")
{
  using RangeT = std::pair<void*, void*>;

  lbannv2::PointerRegistry registry;
  std::vector<unsigned char> buffer(32);

  // Establish preconditions
  REQUIRE(registry.num_registered() == 0UL);
  REQUIRE(registry.bytes_registered() == 0UL);

  SECTION("Context pointers are their own context")
  {
    RangeT const rng1 = {&buffer[4], &buffer[8]};
    RangeT const rng2 = {&buffer[12], &buffer[16]};
    RangeT const zero_rng = {&buffer[20], &buffer[20]};

    void const* const context_ptr1 = &buffer[4];
    void const* const context_ptr2 = &buffer[12];
    void const* const zero_context_ptr = &buffer[20];

    REQUIRE_NOTHROW(registry.add(rng1.first, rng_bytes(rng1), nullptr));
    REQUIRE_NOTHROW(registry.add(rng2.first, rng_bytes(rng2), nullptr));
    REQUIRE_NOTHROW(registry.add(zero_rng.first, rng_bytes(zero_rng), nullptr));
    REQUIRE(registry.num_registered() == 3UL);
    REQUIRE(registry.bytes_registered()
            == rng_bytes(rng1) + rng_bytes(rng2) + rng_bytes(zero_rng));

    // Even a zero-size range's start pointer resolves to itself.
    REQUIRE(registry.get_context(context_ptr1) == rng1.first);
    REQUIRE(registry.get_context(context_ptr2) == rng2.first);
    REQUIRE(registry.get_context(zero_context_ptr) == zero_rng.first);
  }

  SECTION("Noncontext pointers return the proper context pointer")
  {
    RangeT const rng1 = {&buffer[4], &buffer[8]};
    RangeT const rng2 = {&buffer[12], &buffer[16]};

    // Interior addresses of each range.
    void const* const noncontext_ptr1 = &buffer[6];
    void const* const noncontext_ptr2 = &buffer[14];

    REQUIRE_NOTHROW(registry.add(rng1.first, rng_bytes(rng1), nullptr));
    REQUIRE_NOTHROW(registry.add(rng2.first, rng_bytes(rng2), nullptr));

    REQUIRE(registry.num_registered() == 2UL);
    REQUIRE(registry.bytes_registered() == rng_bytes(rng1) + rng_bytes(rng2));

    REQUIRE(registry.get_context(noncontext_ptr1) == rng1.first);
    REQUIRE(registry.get_context(noncontext_ptr2) == rng2.first);
  }

  SECTION("Unknown pointers fail")
  {
    RangeT const rng1 = {&buffer[4], &buffer[8]};
    RangeT const rng2 = {&buffer[12], &buffer[16]};

    // Addresses before, between, at the (exclusive) ends of, and after
    // the registered ranges -- none should resolve.
    void const* const ptr1 = &buffer[2];
    void const* const ptr2 = &buffer[8];
    void const* const ptr3 = &buffer[10];
    void const* const ptr4 = &buffer[16];
    void const* const ptr5 = &buffer[20];

    REQUIRE_NOTHROW(registry.add(rng1.first, rng_bytes(rng1), nullptr));
    REQUIRE_NOTHROW(registry.add(rng2.first, rng_bytes(rng2), nullptr));

    REQUIRE(registry.num_registered() == 2UL);
    REQUIRE(registry.bytes_registered() == rng_bytes(rng1) + rng_bytes(rng2));

    REQUIRE_THROWS_AS(registry.get_context(ptr1), lbannv2::UnknownAddress);
    REQUIRE_THROWS_AS(registry.get_context(ptr2), lbannv2::UnknownAddress);
    REQUIRE_THROWS_AS(registry.get_context(ptr3), lbannv2::UnknownAddress);
    REQUIRE_THROWS_AS(registry.get_context(ptr4), lbannv2::UnknownAddress);
    REQUIRE_THROWS_AS(registry.get_context(ptr5), lbannv2::UnknownAddress);
  }
}

// PointerRegistry::unsafe_reset_allocator(): rebinding a registered
// range to a different allocator, looked up by either the context or
// an interior pointer; unknown pointers throw.
TEST_CASE("PointerRegistry::unsafe_reset_allocator()", "[memory][registry]")
{
  using RangeT = std::pair<void*, void*>;

  lbannv2::PointerRegistry registry;
  std::vector<unsigned char> buffer(32);

  // Establish preconditions
  REQUIRE(registry.num_registered() == 0UL);
  REQUIRE(registry.bytes_registered() == 0UL);

  RangeT const rng = {&buffer[4], &buffer[8]};

  void const* const ctxt_ptr = &buffer[4];  // start of the range
  void const* const mid_ptr = &buffer[6];   // interior of the range
  void const* const bad_ptr = &buffer[0];   // outside the range

  c10::Allocator& alloc = *c10::GetAllocator(c10::kCPU);
  c10::Allocator* orig_alloc = &alloc;

  // A distinct allocator pointer to reset to. FAKE -- DO NOT
  // DEREFERENCE! FIX: previously this was `++orig_alloc`, which
  // mutated orig_alloc so that orig_alloc == other_alloc; the
  // post-reset checks would then pass even if unsafe_reset_allocator
  // did nothing. Keep orig_alloc intact so the two are distinguishable.
  c10::Allocator* other_alloc = orig_alloc + 1;

  // Get the allocator setup
  REQUIRE_NOTHROW(registry.add(rng.first, rng_bytes(rng), orig_alloc));
  REQUIRE(registry.get_allocator(ctxt_ptr) == orig_alloc);
  REQUIRE(registry.get_allocator(mid_ptr) == orig_alloc);

  SECTION("Resetting by context is ok")
  {
    REQUIRE_NOTHROW(registry.unsafe_reset_allocator(ctxt_ptr, other_alloc));
    REQUIRE(registry.get_allocator(ctxt_ptr) == other_alloc);
    REQUIRE(registry.get_allocator(mid_ptr) == other_alloc);
  }

  SECTION("Resetting by an interior pointer is ok")
  {
    // FIXME: Perhaps this should actually be disallowed??
    // FIX: this section previously reset via ctxt_ptr (copy-paste from
    // the section above), so the interior-pointer path was never
    // exercised.
    REQUIRE_NOTHROW(registry.unsafe_reset_allocator(mid_ptr, other_alloc));
    REQUIRE(registry.get_allocator(ctxt_ptr) == other_alloc);
    REQUIRE(registry.get_allocator(mid_ptr) == other_alloc);
  }

  SECTION("Resetting an unknown pointer fails")
  {
    REQUIRE_THROWS_AS(registry.unsafe_reset_allocator(bad_ptr, other_alloc),
                      lbannv2::UnknownAddress);
  }
}

// PointerRegistry::bytes_registered(): the per-pointer overload
// reports the owning range's size for any address inside it, and 0
// for addresses outside every range.
TEST_CASE("PointerRegistry::bytes_registered()", "[memory][registry]")
{
  using RangeT = std::pair<void*, void*>;

  lbannv2::PointerRegistry registry;
  std::vector<unsigned char> buffer(16);

  // Establish preconditions
  REQUIRE(registry.num_registered() == 0UL);
  REQUIRE(registry.bytes_registered() == 0UL);

  RangeT const rng = {&buffer[4], &buffer[8]};
  size_t const rng_size = rng_bytes(rng);

  void const* const ctxt_ptr = &buffer[4];      // start of the range
  void const* const mid_ptr = &buffer[6];       // interior of the range
  void const* const extern_ptr_1 = &buffer[0];  // before the range
  // One past the end of the whole buffer. FIX: previously `&buffer[16]`,
  // which applies operator[] out of range on a size-16 vector --
  // undefined behavior even when only the address is taken.
  void const* const extern_ptr_2 = buffer.data() + buffer.size();

  REQUIRE_NOTHROW(registry.add(rng.first, rng_bytes(rng), nullptr));
  REQUIRE(registry.bytes_registered() == rng_size);

  CHECK(registry.bytes_registered(ctxt_ptr) == rng_size);
  CHECK(registry.bytes_registered(mid_ptr) == rng_size);
  CHECK(registry.bytes_registered(extern_ptr_1) == 0UL);
  CHECK(registry.bytes_registered(extern_ptr_2) == 0UL);
}


================================================
FILE: test/cpp/test_tensor_helpers.cpp
================================================
////////////////////////////////////////////////////////////////////////////////
// Copyright 2014-2025 Lawrence Livermore National Security, LLC and other
// LBANN Project Developers. See the top-level LICENSE file for details.
//
// SPDX-License-Identifier: Apache-2.0
////////////////////////////////////////////////////////////////////////////////
#include <lbannv2/ops/empty_tensor.hpp>
#include <lbannv2/utils/tensor_helpers.hpp>

#include <ATen/EmptyTensor.h>

// A c10 header file in PyTorch has left a macro called `CHECK`
// defined. To prevent warnings, we need to clear that out. This
// should not cause problems as we don't use the PyTorch macro
// directly, and all PyTorch includes should precede this line in this
// source code.
#ifdef CHECK
#undef CHECK
#endif

#include <catch2/catch_test_macros.hpp>

TEST_CASE("alias_as_device", "[tensor][utils]")
{
  SECTION("Aliasing from LBANN to native device")
  {
    at::Tensor lbann_tensor = lbannv2::empty_lbann({2, 3, 4},
                                                   c10::ScalarType::Float,
                                                   std::nullopt,
                                                   std::nullopt,
                                                   false,
                                                   std::nullopt);
    auto const keys_before = lbann_tensor.key_set();
    auto const device_before = lbann_tensor.device();

    at::Tensor cpu_view = lbannv2::alias_as_device(
      lbann_tensor,
      c10::DeviceType::CPU,
      c10::DispatchKeySet {c10::DispatchKey::CPU});

    // The source tensor itself is not modified by the aliasing call.
    CHECK(lbann_tensor.is_privateuseone());
    CHECK(lbann_tensor.key_set() == keys_before);
    CHECK(lbann_tensor.device() == device_before);

    // The result shares storage with the source and reports the
    // requested (CPU) device.
    CHECK(cpu_view.is_alias_of(lbann_tensor));
    CHECK(cpu_view.is_cpu());

    // This is documented to change
    CHECK(lbann_tensor.storage().data_ptr().device().is_cpu());

    // All shape/stride/name/dtype metadata carries over to the alias.
    CHECK(cpu_view.sizes() == lbann_tensor.sizes());
    CHECK(cpu_view.strides() == lbann_tensor.strides());
    CHECK(cpu_view.names() == lbann_tensor.names());
    CHECK(cpu_view.dtype() == lbann_tensor.dtype());
  }
}

TEST_CASE("alias_as_native_device", "[tensor][utils]")
{
  SECTION("Aliasing a native PyTorch tensor does nothing")
  {
    at::Tensor native = at::detail::empty_cpu({3, 2, 4},
                                              c10::ScalarType::Float,
                                              std::nullopt,
                                              std::nullopt,
                                              std::nullopt,
                                              std::nullopt);
    at::Tensor result = lbannv2::alias_as_native_device(native);
    // An already-native tensor passes through as the exact same impl
    // object, so every piece of metadata matches trivially.
    CHECK(result.is_alias_of(native));
    CHECK(result.key_set() == native.key_set());
    CHECK(result.device() == native.device());
    CHECK(result.dtype() == native.dtype());
    CHECK(result.unsafeGetTensorImpl() == native.unsafeGetTensorImpl());
  }

  SECTION("Aliasing an LBANN tensor is ok")
  {
    using namespace lbannv2;
    static constexpr auto LBANNbit = c10::BackendComponent::PrivateUse1Bit;

    at::Tensor lbann_tensor = empty_lbann({2, 3, 4},
                                          c10::ScalarType::Float,
                                          std::nullopt,
                                          c10::Device {LBANNDeviceT, LBANN_CPU},
                                          false,
                                          std::nullopt);
    at::Tensor native_view = alias_as_native_device(lbann_tensor);

    // Still an alias (based on storage objects)
    CHECK(native_view.is_alias_of(lbann_tensor));
    // The alias's key set is the original's minus the LBANN backend bit.
    CHECK(native_view.key_set()
          == lbann_tensor.key_set().remove_backend(LBANNbit));
    CHECK(native_view.sizes() == lbann_tensor.sizes());
    CHECK(native_view.strides() == lbann_tensor.strides());
    CHECK(native_view.dtype() == lbann_tensor.dtype());
    // The reported device changes (LBANN device -> native CPU)...
    CHECK(native_view.device() != lbann_tensor.device());
    CHECK(native_view.device().is_cpu());
    // ...while the underlying data pointer and storage offset are shared.
    CHECK(native_view.unsafeGetTensorImpl()->data()
          == lbann_tensor.unsafeGetTensorImpl()->data());
    CHECK(native_view.unsafeGetTensorImpl()->storage_offset()
          == lbann_tensor.unsafeGetTensorImpl()->storage_offset());
  }
}
Download .txt
gitextract_2qpyp2vm/

├── .clang-format
├── .gitignore
├── CMakeLists.txt
├── CONTRIBUTING.md
├── CONTRIBUTORS
├── LICENSE
├── NOTICE
├── README.md
├── cmake/
│   ├── LBANNv2DetectTorchNVIDIALibraries.cmake
│   ├── LBANNv2DetermineMI300A.cmake
│   ├── lbannv2Config.cmake.in
│   └── lbannv2_config.h.in
├── pyproject.toml
├── python/
│   └── lbannv2/
│       ├── __init__.py
│       └── _automigrate.py
├── src/
│   └── lbannv2/
│       ├── CMakeLists.txt
│       ├── memory/
│       │   ├── CMakeLists.txt
│       │   ├── allocator.cpp
│       │   ├── allocator.hpp
│       │   ├── h2_allocator_wrappers.cpp
│       │   ├── h2_allocator_wrappers.hpp
│       │   ├── memory_utils.hpp
│       │   ├── mi300a_allocator.cpp
│       │   ├── mi300a_allocator.hpp
│       │   ├── registry.cpp
│       │   └── registry.hpp
│       ├── ops/
│       │   ├── CMakeLists.txt
│       │   ├── migrate.cpp
│       │   ├── migrate.hpp
│       │   ├── nonzero.hip
│       │   ├── nonzero.hpp
│       │   ├── scalar.cpp
│       │   └── scalar.hpp
│       ├── python/
│       │   ├── CMakeLists.txt
│       │   ├── register_lbannv2.cpp
│       │   ├── register_memory_funcs.cpp
│       │   └── register_mi300a_ops.cpp
│       ├── types.hpp
│       └── utils/
│           ├── CMakeLists.txt
│           ├── debugging_helpers.hpp
│           ├── errors.hpp
│           ├── gpu_utils.cpp
│           ├── gpu_utils.hpp
│           ├── logging.cpp
│           ├── logging.hpp
│           └── tensor_helpers.hpp
└── test/
    ├── CMakeLists.txt
    └── cpp/
        ├── test_empty_tensor.cpp
        ├── test_helpers.hpp
        ├── test_mi300a_allocator.cpp
        ├── test_pointer_registry.cpp
        └── test_tensor_helpers.cpp
Download .txt
SYMBOL INDEX (107 symbols across 28 files)

FILE: python/lbannv2/__init__.py
  function is_available (line 20) | def is_available():
  class MigratableMemory (line 26) | class MigratableMemory:
    method __enter__ (line 29) | def __enter__(self):
    method __exit__ (line 32) | def __exit__(self, exc_type, exc_value, traceback):
  function make_migratory_tensor (line 36) | def make_migratory_tensor(ctor, *args, **kwargs):

FILE: python/lbannv2/_automigrate.py
  function automigrate (line 10) | def automigrate(f: Union[Callable, torch.fx.GraphModule]) -> torch.fx.Gr...

FILE: src/lbannv2/memory/allocator.cpp
  type lbannv2 (line 25) | namespace lbannv2

FILE: src/lbannv2/memory/allocator.hpp
  type lbannv2 (line 12) | namespace lbannv2
    class LBANNV2_EXPORT (line 18) | class LBANNV2_EXPORT

FILE: src/lbannv2/memory/h2_allocator_wrappers.cpp
  type lbannv2 (line 9) | namespace lbannv2

FILE: src/lbannv2/memory/h2_allocator_wrappers.hpp
  type lbannv2 (line 18) | namespace lbannv2
    class H2AllocatorWrapper (line 22) | class H2AllocatorWrapper : public Allocator
      method copy_data (line 31) | void copy_data(void* dst, void const* src, size_t n) const final
      method raw_deallocate (line 55) | void raw_deallocate(void* ptr) final
      method get_device (line 61) | c10::Device get_device() const noexcept final
      method H2AllocatorWrapper (line 74) | static H2AllocatorWrapper& instance()
      method H2AllocatorWrapper (line 81) | H2AllocatorWrapper() = default;
      method H2AllocatorWrapper (line 83) | H2AllocatorWrapper(H2AllocatorWrapper const&) = delete;
      method H2AllocatorWrapper (line 84) | H2AllocatorWrapper(H2AllocatorWrapper&&) = delete;
      method H2AllocatorWrapper (line 85) | H2AllocatorWrapper& operator=(H2AllocatorWrapper const&) = delete;
      method H2AllocatorWrapper (line 86) | H2AllocatorWrapper& operator=(H2AllocatorWrapper&&) = delete;

FILE: src/lbannv2/memory/memory_utils.hpp
  type lbannv2 (line 13) | namespace lbannv2
    class AllocatorWrapper (line 30) | class AllocatorWrapper : public c10::Allocator
      method AllocatorWrapper (line 39) | AllocatorWrapper(c10::Allocator& alloc, c10::Device device)
      method allocate (line 44) | c10::DataPtr allocate(size_t n) final
      method raw_deleter (line 56) | c10::DeleterFnPtr raw_deleter() const noexcept final
      method copy_data (line 61) | void copy_data(void* dst, void const* src, size_t n) const final

FILE: src/lbannv2/memory/mi300a_allocator.cpp
  function get_use_nonblocking_stream_env_var (line 28) | bool get_use_nonblocking_stream_env_var()
  function use_nonblocking_stream (line 34) | bool use_nonblocking_stream()
  type StreamRAII (line 41) | struct StreamRAII
    method StreamRAII (line 45) | StreamRAII()
  function host_allocation_stream (line 63) | ::lbannv2::TorchGPUStream_t host_allocation_stream(c10::DeviceIndex cons...
  function resolve_device (line 70) | c10::Device resolve_device(c10::Device const& d)
  function lbannv2_report_free (line 84) | void lbannv2_report_free(DeviceAlloc_ns::TraceEntry const& entry)
  function lbannv2_trace_alloc (line 98) | void lbannv2_trace_alloc(DeviceAlloc_ns::TraceEntry const& entry)
  type lbannv2 (line 105) | namespace lbannv2
    function MI300Allocator (line 169) | MI300Allocator& MI300Allocator::instance()

FILE: src/lbannv2/memory/mi300a_allocator.hpp
  type lbannv2 (line 20) | namespace lbannv2
    class MI300Allocator (line 28) | class MI300Allocator final : public Allocator
      method MI300Allocator (line 46) | MI300Allocator(MI300Allocator const&) = delete;
      method MI300Allocator (line 47) | MI300Allocator(MI300Allocator&&) = delete;
      method MI300Allocator (line 48) | MI300Allocator& operator=(MI300Allocator const&) = delete;
      method MI300Allocator (line 49) | MI300Allocator& operator=(MI300Allocator&&) = delete;

FILE: src/lbannv2/memory/registry.cpp
  function range_bytes (line 26) | std::size_t range_bytes(std::pair<void*, void*> const& r) noexcept
  type lbannv2 (line 33) | namespace lbannv2

FILE: src/lbannv2/memory/registry.hpp
  type lbannv2 (line 19) | namespace lbannv2
    type LBANNV2_EXPORT (line 22) | struct LBANNV2_EXPORT
    function UnknownAddress (line 24) | UnknownAddress() : std::runtime_error {"Unknown address"} {}
    function PointerRegistry (line 38) | class LBANNV2_EXPORT PointerRegistry

FILE: src/lbannv2/ops/migrate.cpp
  function get_origin_device (line 28) | at::Device get_origin_device(void const* const ptr)
  function is_ok_device (line 51) | bool is_ok_device(c10::Device const& d)
  function get_default_keyset (line 60) | c10::DispatchKeySet get_default_keyset(c10::Device const& d)

FILE: src/lbannv2/ops/migrate.hpp
  type lbannv2 (line 14) | namespace lbannv2

FILE: src/lbannv2/ops/nonzero.hpp
  type lbannv2 (line 11) | namespace lbannv2

FILE: src/lbannv2/ops/scalar.cpp
  function mi300a_impl (line 30) | at::Scalar mi300a_impl(at::Tensor const& self)
  function mi300a_dispatch (line 40) | at::Scalar mi300a_dispatch(at::Tensor const& self)

FILE: src/lbannv2/ops/scalar.hpp
  type lbannv2 (line 13) | namespace lbannv2

FILE: src/lbannv2/python/register_lbannv2.cpp
  function init_lbannv2 (line 26) | void init_lbannv2()
  function is_lbannv2_initialized (line 50) | bool is_lbannv2_initialized() noexcept
  function is_lbannv2_gpu_initialized (line 55) | bool is_lbannv2_gpu_initialized() noexcept
  function is_lbannv2_gpu_available (line 60) | bool is_lbannv2_gpu_available() noexcept
  type _lbannv2 (line 67) | namespace _lbannv2
  function PYBIND11_MODULE (line 72) | PYBIND11_MODULE(_lbannv2, m)

FILE: src/lbannv2/python/register_memory_funcs.cpp
  function py_migrate (line 31) | at::Tensor py_migrate(at::Tensor& t, at::Device const& d)
  function py_supports_migrate (line 36) | bool py_supports_migrate() noexcept
  function py_use_mi300a_host_allocator (line 47) | void py_use_mi300a_host_allocator()
  function py_use_torch_host_allocator (line 52) | void py_use_torch_host_allocator()
  function py_using_lbannv2_memory (line 57) | bool py_using_lbannv2_memory(torch::Tensor const& t)
  type _lbannv2 (line 64) | namespace _lbannv2
    function add_memory_funcs (line 67) | void add_memory_funcs(pybind11::module_& m)

FILE: src/lbannv2/python/register_mi300a_ops.cpp
  function lbannv2__local_scalar_dense_cuda (line 17) | at::Scalar lbannv2__local_scalar_dense_cuda(at::Tensor const& self)
  function lbannv2_nonzero (line 28) | at::Tensor lbannv2_nonzero(at::Tensor const& self)
  function TORCH_LIBRARY_IMPL (line 52) | TORCH_LIBRARY_IMPL(aten, CUDA, m)

FILE: src/lbannv2/types.hpp
  type lbannv2 (line 13) | namespace lbannv2
    function is_supported (line 17) | inline bool is_supported(c10::ScalarType t) noexcept

FILE: src/lbannv2/utils/debugging_helpers.hpp
  type lbannv2 (line 18) | namespace lbannv2
    function demngl (line 21) | inline std::string demngl(std::string symb)
    function print_bt (line 38) | inline void print_bt(size_t nframes = 128, std::ostream& os = std::cout)

FILE: src/lbannv2/utils/gpu_utils.hpp
  type lbannv2 (line 52) | namespace lbannv2
    function has_cuda (line 64) | inline constexpr bool has_cuda() noexcept
    function has_hip (line 69) | inline constexpr bool has_hip() noexcept
    function has_gpu (line 74) | inline constexpr bool has_gpu() noexcept
    type gpu (line 79) | namespace gpu

FILE: src/lbannv2/utils/logging.cpp
  function get_env_log_level (line 23) | spdlog::level::level_enum get_env_log_level()
  function get_hostname (line 50) | std::string get_hostname()
  class HostFlag (line 62) | class HostFlag final : public spdlog::custom_flag_formatter
    method clone (line 65) | std::unique_ptr<custom_flag_formatter> clone() const final
    method format (line 69) | void format(::spdlog::details::log_msg const&,
  function make_default_formatter (line 78) | std::unique_ptr<::spdlog::pattern_formatter> make_default_formatter()
  function make_default_sink (line 87) | ::spdlog::sink_ptr make_default_sink()
  function make_default_logger (line 98) | std::shared_ptr<::spdlog::logger> make_default_logger()

FILE: src/lbannv2/utils/logging.hpp
  type lbannv2 (line 47) | namespace lbannv2

FILE: src/lbannv2/utils/tensor_helpers.hpp
  type lbannv2 (line 15) | namespace lbannv2
    function is_lbann (line 19) | inline bool is_lbann(at::Tensor const& t)
    function is_scalar (line 24) | inline bool is_scalar(at::Tensor const& t)
    function set_data_ptr_device (line 29) | inline void set_data_ptr_device(c10::DataPtr& dp, c10::Device d)
    function set_data_ptr_device (line 34) | inline void set_data_ptr_device(c10::Storage const& s, c10::Device d)
    function sync_metadata (line 39) | inline void sync_metadata(at::Tensor const& src, at::Tensor& dst)
    function alias_as_device (line 59) | inline at::Tensor alias_as_device(at::Tensor const& orig_tensor,
    function to_str (line 92) | inline std::string to_str(at::Tensor const& t)
    function to_str (line 101) | std::string to_str(c10::ArrayRef<T> const& ar)

FILE: test/cpp/test_empty_tensor.cpp
  function make_empty_tensor (line 27) | void make_empty_tensor(at::Tensor& t, Args&&... args)
  function make_empty_strided_tensor (line 118) | void make_empty_strided_tensor(at::Tensor& t, Args&&... args)

FILE: test/cpp/test_mi300a_allocator.cpp
  function do_raw_allocate (line 25) | void do_raw_allocate(void** ptr, size_t size, lbannv2::MI300Allocator& a...
  function lbann_cpu (line 49) | c10::Device lbann_cpu() noexcept
  function lbann_gpu (line 53) | c10::Device lbann_gpu() noexcept

FILE: test/cpp/test_pointer_registry.cpp
  function rng_bytes (line 70) | size_t rng_bytes(std::pair<void*, void*> const& r)
Condensed preview — 52 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (157K chars).
[
  {
    "path": ".clang-format",
    "chars": 4310,
    "preview": "################################################################################\n## Copyright 2014-2025 Lawrence Livermo"
  },
  {
    "path": ".gitignore",
    "chars": 574,
    "preview": "################################################################################\n## Copyright 2019-2020 Lawrence Livermo"
  },
  {
    "path": "CMakeLists.txt",
    "chars": 12522,
    "preview": "################################################################################\n## Copyright 2014-2025 Lawrence Livermo"
  },
  {
    "path": "CONTRIBUTING.md",
    "chars": 1658,
    "preview": "# Contributing Guidelines for LBANN\n\nWe welcome any contributions to LBANN in the form of Pull Requests.\nPlease follow t"
  },
  {
    "path": "CONTRIBUTORS",
    "chars": 703,
    "preview": "LLNL Core Team:\n  Brian Van Essen <vanessen1@llnl.gov> [@bvanessen]\n  Tom Benson <benson31@llnl.gov> [@benson31]\n  Nikol"
  },
  {
    "path": "LICENSE",
    "chars": 1341,
    "preview": "################################################################################\n## Copyright 2014-2025 Lawrence Livermo"
  },
  {
    "path": "NOTICE",
    "chars": 1165,
    "preview": "This work was produced under the auspices of the U.S. Department of Energy by\nLawrence Livermore National Laboratory und"
  },
  {
    "path": "README.md",
    "chars": 347,
    "preview": "# Build\n\nTo save some pip-related heartburn, LBANNv2 is currently BYOT (\"bring\nyour own Torch\").\n\n```\npip install torch\n"
  },
  {
    "path": "cmake/LBANNv2DetectTorchNVIDIALibraries.cmake",
    "chars": 3922,
    "preview": "################################################################################\n## Copyright 2014-2025 Lawrence Livermo"
  },
  {
    "path": "cmake/LBANNv2DetermineMI300A.cmake",
    "chars": 2663,
    "preview": "################################################################################\n## Copyright 2014-2025 Lawrence Livermo"
  },
  {
    "path": "cmake/lbannv2Config.cmake.in",
    "chars": 964,
    "preview": "################################################################################\n## Copyright 2014-2025 Lawrence Livermo"
  },
  {
    "path": "cmake/lbannv2_config.h.in",
    "chars": 1227,
    "preview": "////////////////////////////////////////////////////////////////////////////////\n// Copyright 2014-2025 Lawrence Livermo"
  },
  {
    "path": "pyproject.toml",
    "chars": 2083,
    "preview": "################################################################################\n## Copyright 2014-2025 Lawrence Livermo"
  },
  {
    "path": "python/lbannv2/__init__.py",
    "chars": 1042,
    "preview": "################################################################################\n## Copyright 2014-2025 Lawrence Livermo"
  },
  {
    "path": "python/lbannv2/_automigrate.py",
    "chars": 3329,
    "preview": "import torch\nfrom typing import Callable, Union\n\ntry:\n    from .lib._lbannv2 import migrate\nexcept ModuleNotFoundError:\n"
  },
  {
    "path": "src/lbannv2/CMakeLists.txt",
    "chars": 575,
    "preview": "################################################################################\n## Copyright 2014-2025 Lawrence Livermo"
  },
  {
    "path": "src/lbannv2/memory/CMakeLists.txt",
    "chars": 763,
    "preview": "################################################################################\n## Copyright 2014-2025 Lawrence Livermo"
  },
  {
    "path": "src/lbannv2/memory/allocator.cpp",
    "chars": 1837,
    "preview": "////////////////////////////////////////////////////////////////////////////////\n// Copyright 2014-2025 Lawrence Livermo"
  },
  {
    "path": "src/lbannv2/memory/allocator.hpp",
    "chars": 987,
    "preview": "////////////////////////////////////////////////////////////////////////////////\n// Copyright 2014-2025 Lawrence Livermo"
  },
  {
    "path": "src/lbannv2/memory/h2_allocator_wrappers.cpp",
    "chars": 598,
    "preview": "////////////////////////////////////////////////////////////////////////////////\n// Copyright 2014-2025 Lawrence Livermo"
  },
  {
    "path": "src/lbannv2/memory/h2_allocator_wrappers.hpp",
    "chars": 2500,
    "preview": "////////////////////////////////////////////////////////////////////////////////\n// Copyright 2014-2025 Lawrence Livermo"
  },
  {
    "path": "src/lbannv2/memory/memory_utils.hpp",
    "chars": 2210,
    "preview": "////////////////////////////////////////////////////////////////////////////////\n// Copyright 2014-2025 Lawrence Livermo"
  },
  {
    "path": "src/lbannv2/memory/mi300a_allocator.cpp",
    "chars": 6660,
    "preview": "////////////////////////////////////////////////////////////////////////////////\n// Copyright 2014-2025 Lawrence Livermo"
  },
  {
    "path": "src/lbannv2/memory/mi300a_allocator.hpp",
    "chars": 2183,
    "preview": "////////////////////////////////////////////////////////////////////////////////\n// Copyright 2014-2025 Lawrence Livermo"
  },
  {
    "path": "src/lbannv2/memory/registry.cpp",
    "chars": 3799,
    "preview": "////////////////////////////////////////////////////////////////////////////////\n// Copyright 2014-2025 Lawrence Livermo"
  },
  {
    "path": "src/lbannv2/memory/registry.hpp",
    "chars": 6588,
    "preview": "////////////////////////////////////////////////////////////////////////////////\n// Copyright 2014-2025 Lawrence Livermo"
  },
  {
    "path": "src/lbannv2/ops/CMakeLists.txt",
    "chars": 945,
    "preview": "################################################################################\n## Copyright 2014-2025 Lawrence Livermo"
  },
  {
    "path": "src/lbannv2/ops/migrate.cpp",
    "chars": 5871,
    "preview": "////////////////////////////////////////////////////////////////////////////////\n// Copyright 2014-2025 Lawrence Livermo"
  },
  {
    "path": "src/lbannv2/ops/migrate.hpp",
    "chars": 2134,
    "preview": "////////////////////////////////////////////////////////////////////////////////\n// Copyright 2014-2025 Lawrence Livermo"
  },
  {
    "path": "src/lbannv2/ops/nonzero.hip",
    "chars": 9509,
    "preview": "////////////////////////////////////////////////////////////////////////////////\n// Copyright 2014-2025 Lawrence Livermo"
  },
  {
    "path": "src/lbannv2/ops/nonzero.hpp",
    "chars": 545,
    "preview": "////////////////////////////////////////////////////////////////////////////////\n// Copyright 2014-2025 Lawrence Livermo"
  },
  {
    "path": "src/lbannv2/ops/scalar.cpp",
    "chars": 2983,
    "preview": "////////////////////////////////////////////////////////////////////////////////\n// Copyright 2014-2025 Lawrence Livermo"
  },
  {
    "path": "src/lbannv2/ops/scalar.hpp",
    "chars": 541,
    "preview": "////////////////////////////////////////////////////////////////////////////////\n// Copyright 2014-2025 Lawrence Livermo"
  },
  {
    "path": "src/lbannv2/python/CMakeLists.txt",
    "chars": 667,
    "preview": "################################################################################\n## Copyright 2014-2025 Lawrence Livermo"
  },
  {
    "path": "src/lbannv2/python/register_lbannv2.cpp",
    "chars": 2152,
    "preview": "////////////////////////////////////////////////////////////////////////////////\n// Copyright 2014-2025 Lawrence Livermo"
  },
  {
    "path": "src/lbannv2/python/register_memory_funcs.cpp",
    "chars": 2200,
    "preview": "////////////////////////////////////////////////////////////////////////////////\n// Copyright 2014-2025 Lawrence Livermo"
  },
  {
    "path": "src/lbannv2/python/register_mi300a_ops.cpp",
    "chars": 1481,
    "preview": "// NOTE: this file is only compiled when LBANNV2_WITH_MI300A or\n// LBANNV2_UNKNOWN_MI300A, so the \"#else\" clauses below "
  },
  {
    "path": "src/lbannv2/types.hpp",
    "chars": 854,
    "preview": "////////////////////////////////////////////////////////////////////////////////\n// Copyright 2014-2025 Lawrence Livermo"
  },
  {
    "path": "src/lbannv2/utils/CMakeLists.txt",
    "chars": 566,
    "preview": "################################################################################\n## Copyright 2014-2025 Lawrence Livermo"
  },
  {
    "path": "src/lbannv2/utils/debugging_helpers.hpp",
    "chars": 1452,
    "preview": "////////////////////////////////////////////////////////////////////////////////\n// Copyright 2014-2025 Lawrence Livermo"
  },
  {
    "path": "src/lbannv2/utils/errors.hpp",
    "chars": 1273,
    "preview": "////////////////////////////////////////////////////////////////////////////////\n// Copyright 2014-2025 Lawrence Livermo"
  },
  {
    "path": "src/lbannv2/utils/gpu_utils.cpp",
    "chars": 2442,
    "preview": "#include \"gpu_utils.hpp\"\n\n#include \"errors.hpp\"\n#include \"logging.hpp\"\n\nbool lbannv2::gpu::is_integrated() noexcept\n{\n#i"
  },
  {
    "path": "src/lbannv2/utils/gpu_utils.hpp",
    "chars": 3270,
    "preview": "#pragma once\n#include <lbannv2_config.h>\n\n#include <c10/core/Device.h>\n\n#if LBANNV2_HAS_CUDA\n\n#include <lbannv2/utils/lo"
  },
  {
    "path": "src/lbannv2/utils/logging.cpp",
    "chars": 3589,
    "preview": "////////////////////////////////////////////////////////////////////////////////\n// Copyright 2014-2025 Lawrence Livermo"
  },
  {
    "path": "src/lbannv2/utils/logging.hpp",
    "chars": 2883,
    "preview": "////////////////////////////////////////////////////////////////////////////////\n// Copyright 2014-2025 Lawrence Livermo"
  },
  {
    "path": "src/lbannv2/utils/tensor_helpers.hpp",
    "chars": 3406,
    "preview": "////////////////////////////////////////////////////////////////////////////////\n// Copyright 2014-2025 Lawrence Livermo"
  },
  {
    "path": "test/CMakeLists.txt",
    "chars": 992,
    "preview": "################################################################################\n## Copyright 2014-2025 Lawrence Livermo"
  },
  {
    "path": "test/cpp/test_empty_tensor.cpp",
    "chars": 7707,
    "preview": "////////////////////////////////////////////////////////////////////////////////\n// Copyright 2014-2025 Lawrence Livermo"
  },
  {
    "path": "test/cpp/test_helpers.hpp",
    "chars": 1262,
    "preview": "////////////////////////////////////////////////////////////////////////////////\n// Copyright 2014-2025 Lawrence Livermo"
  },
  {
    "path": "test/cpp/test_mi300a_allocator.cpp",
    "chars": 3804,
    "preview": "////////////////////////////////////////////////////////////////////////////////\n// Copyright 2014-2025 Lawrence Livermo"
  },
  {
    "path": "test/cpp/test_pointer_registry.cpp",
    "chars": 15449,
    "preview": "////////////////////////////////////////////////////////////////////////////////\n// Copyright 2014-2025 Lawrence Livermo"
  },
  {
    "path": "test/cpp/test_tensor_helpers.cpp",
    "chars": 3873,
    "preview": "////////////////////////////////////////////////////////////////////////////////\n// Copyright 2014-2025 Lawrence Livermo"
  }
]

About this extraction

This page contains the full source code of the LLNL/lbann GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 52 files (144.9 KB), approximately 38.8k tokens, and a symbol index with 107 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.

Copied to clipboard!