Showing preview only (3,850K chars total). Download the full file or copy to clipboard to get everything.
Repository: NVIDIA/apex
Branch: master
Commit: ba32a259b7aa
Files: 419
Total size: 3.6 MB
Directory structure:
gitextract_8yaiblk9/
├── .clang-format
├── .git-blame-ignore-revs
├── .github/
│ └── ISSUE_TEMPLATE/
│ └── bug_report.md
├── .gitignore
├── .gitmodules
├── .nojekyll
├── .pre-commit-config.yaml
├── LICENSE
├── README.md
├── apex/
│ ├── __init__.py
│ ├── _autocast_utils.py
│ ├── contrib/
│ │ ├── __init__.py
│ │ ├── bottleneck/
│ │ │ ├── __init__.py
│ │ │ ├── bottleneck.py
│ │ │ ├── halo_exchangers.py
│ │ │ └── test.py
│ │ ├── clip_grad/
│ │ │ ├── __init__.py
│ │ │ └── clip_grad.py
│ │ ├── conv_bias_relu/
│ │ │ ├── __init__.py
│ │ │ └── conv_bias_relu.py
│ │ ├── csrc/
│ │ │ ├── bottleneck/
│ │ │ │ └── bottleneck.cpp
│ │ │ ├── conv_bias_relu/
│ │ │ │ └── conv_bias_relu.cpp
│ │ │ ├── cudnn_gbn/
│ │ │ │ ├── cudnn_gbn.cpp
│ │ │ │ ├── norm_sample.cpp
│ │ │ │ └── norm_sample.h
│ │ │ ├── fmha/
│ │ │ │ ├── fmha_api.cpp
│ │ │ │ └── src/
│ │ │ │ ├── fmha/
│ │ │ │ │ ├── gemm.h
│ │ │ │ │ ├── gmem_tile.h
│ │ │ │ │ ├── kernel_traits.h
│ │ │ │ │ ├── mask.h
│ │ │ │ │ ├── smem_tile.h
│ │ │ │ │ ├── softmax.h
│ │ │ │ │ └── utils.h
│ │ │ │ ├── fmha.h
│ │ │ │ ├── fmha_dgrad_fp16_128_64_kernel.sm80.cu
│ │ │ │ ├── fmha_dgrad_fp16_256_64_kernel.sm80.cu
│ │ │ │ ├── fmha_dgrad_fp16_384_64_kernel.sm80.cu
│ │ │ │ ├── fmha_dgrad_fp16_512_64_kernel.sm80.cu
│ │ │ │ ├── fmha_dgrad_kernel_1xN_reload.h
│ │ │ │ ├── fmha_dgrad_kernel_1xN_reload_nl.h
│ │ │ │ ├── fmha_fill.cu
│ │ │ │ ├── fmha_fprop_fp16_128_64_kernel.sm80.cu
│ │ │ │ ├── fmha_fprop_fp16_256_64_kernel.sm80.cu
│ │ │ │ ├── fmha_fprop_fp16_384_64_kernel.sm80.cu
│ │ │ │ ├── fmha_fprop_fp16_512_64_kernel.sm80.cu
│ │ │ │ ├── fmha_fprop_kernel_1xN.h
│ │ │ │ ├── fmha_kernel.h
│ │ │ │ ├── fmha_noloop_reduce.cu
│ │ │ │ └── fmha_utils.h
│ │ │ ├── focal_loss/
│ │ │ │ ├── focal_loss_cuda.cpp
│ │ │ │ └── focal_loss_cuda_kernel.cu
│ │ │ ├── gpu_direct_storage/
│ │ │ │ ├── gds.cpp
│ │ │ │ ├── gds.h
│ │ │ │ └── gds_pybind.cpp
│ │ │ ├── group_norm/
│ │ │ │ ├── group_norm_nhwc.cpp
│ │ │ │ ├── group_norm_nhwc.h
│ │ │ │ ├── group_norm_nhwc_bwd_one_pass.h
│ │ │ │ ├── group_norm_nhwc_bwd_one_pass_kernel.cuh
│ │ │ │ ├── group_norm_nhwc_bwd_two_pass.cu
│ │ │ │ ├── group_norm_nhwc_fwd_one_pass.h
│ │ │ │ ├── group_norm_nhwc_fwd_one_pass_kernel.cuh
│ │ │ │ ├── group_norm_nhwc_fwd_two_pass.cu
│ │ │ │ ├── group_norm_nhwc_one_pass_10.cu
│ │ │ │ ├── group_norm_nhwc_one_pass_112.cu
│ │ │ │ ├── group_norm_nhwc_one_pass_12.cu
│ │ │ │ ├── group_norm_nhwc_one_pass_120.cu
│ │ │ │ ├── group_norm_nhwc_one_pass_128.cu
│ │ │ │ ├── group_norm_nhwc_one_pass_14.cu
│ │ │ │ ├── group_norm_nhwc_one_pass_16.cu
│ │ │ │ ├── group_norm_nhwc_one_pass_160.cu
│ │ │ │ ├── group_norm_nhwc_one_pass_20.cu
│ │ │ │ ├── group_norm_nhwc_one_pass_24.cu
│ │ │ │ ├── group_norm_nhwc_one_pass_26.cu
│ │ │ │ ├── group_norm_nhwc_one_pass_28.cu
│ │ │ │ ├── group_norm_nhwc_one_pass_30.cu
│ │ │ │ ├── group_norm_nhwc_one_pass_32.cu
│ │ │ │ ├── group_norm_nhwc_one_pass_4.cu
│ │ │ │ ├── group_norm_nhwc_one_pass_40.cu
│ │ │ │ ├── group_norm_nhwc_one_pass_42.cu
│ │ │ │ ├── group_norm_nhwc_one_pass_48.cu
│ │ │ │ ├── group_norm_nhwc_one_pass_56.cu
│ │ │ │ ├── group_norm_nhwc_one_pass_60.cu
│ │ │ │ ├── group_norm_nhwc_one_pass_64.cu
│ │ │ │ ├── group_norm_nhwc_one_pass_70.cu
│ │ │ │ ├── group_norm_nhwc_one_pass_8.cu
│ │ │ │ ├── group_norm_nhwc_one_pass_80.cu
│ │ │ │ ├── group_norm_nhwc_one_pass_84.cu
│ │ │ │ ├── group_norm_nhwc_one_pass_96.cu
│ │ │ │ ├── group_norm_nhwc_one_pass_98.cu
│ │ │ │ ├── group_norm_nhwc_op.cpp
│ │ │ │ ├── macros.h
│ │ │ │ └── traits.h
│ │ │ ├── group_norm_v2/
│ │ │ │ ├── generate_gn_cuda_inst.py
│ │ │ │ ├── gn.cpp
│ │ │ │ ├── gn.hpp
│ │ │ │ ├── gn_cuda.cu
│ │ │ │ ├── gn_cuda_host_template.cuh
│ │ │ │ ├── gn_cuda_inst_1024_1280.cu
│ │ │ │ ├── gn_cuda_inst_1024_1920.cu
│ │ │ │ ├── gn_cuda_inst_1024_320.cu
│ │ │ │ ├── gn_cuda_inst_1024_640.cu
│ │ │ │ ├── gn_cuda_inst_1024_960.cu
│ │ │ │ ├── gn_cuda_inst_256_1280.cu
│ │ │ │ ├── gn_cuda_inst_256_1920.cu
│ │ │ │ ├── gn_cuda_inst_256_2560.cu
│ │ │ │ ├── gn_cuda_inst_256_640.cu
│ │ │ │ ├── gn_cuda_inst_4096_320.cu
│ │ │ │ ├── gn_cuda_inst_4096_640.cu
│ │ │ │ ├── gn_cuda_inst_4096_960.cu
│ │ │ │ ├── gn_cuda_inst_64_1280.cu
│ │ │ │ ├── gn_cuda_inst_64_2560.cu
│ │ │ │ ├── gn_cuda_kernel.cuh
│ │ │ │ ├── gn_dispatch_hw_c.hpp
│ │ │ │ ├── gn_utils.cpp
│ │ │ │ └── gn_utils.hpp
│ │ │ ├── groupbn/
│ │ │ │ ├── batch_norm.cu
│ │ │ │ ├── batch_norm.h
│ │ │ │ ├── batch_norm_add_relu.cu
│ │ │ │ ├── batch_norm_add_relu.h
│ │ │ │ ├── cuda_utils.h
│ │ │ │ ├── interface.cpp
│ │ │ │ ├── ipc.cu
│ │ │ │ └── nhwc_batch_norm_kernel.h
│ │ │ ├── index_mul_2d/
│ │ │ │ ├── index_mul_2d_cuda.cpp
│ │ │ │ └── index_mul_2d_cuda_kernel.cu
│ │ │ ├── layer_norm/
│ │ │ │ ├── ln.h
│ │ │ │ ├── ln_api.cpp
│ │ │ │ ├── ln_bwd_kernels.cuh
│ │ │ │ ├── ln_bwd_semi_cuda_kernel.cu
│ │ │ │ ├── ln_fwd_cuda_kernel.cu
│ │ │ │ ├── ln_fwd_kernels.cuh
│ │ │ │ ├── ln_kernel_traits.h
│ │ │ │ └── ln_utils.cuh
│ │ │ ├── multihead_attn/
│ │ │ │ ├── additive_masked_softmax_dropout_cuda.cu
│ │ │ │ ├── dropout.cuh
│ │ │ │ ├── encdec_multihead_attn_cuda.cu
│ │ │ │ ├── encdec_multihead_attn_norm_add_cuda.cu
│ │ │ │ ├── layer_norm.cuh
│ │ │ │ ├── masked_softmax_dropout_cuda.cu
│ │ │ │ ├── multihead_attn_frontend.cpp
│ │ │ │ ├── philox.cuh
│ │ │ │ ├── self_multihead_attn_bias_additive_mask_cuda.cu
│ │ │ │ ├── self_multihead_attn_bias_cuda.cu
│ │ │ │ ├── self_multihead_attn_cuda.cu
│ │ │ │ ├── self_multihead_attn_norm_add_cuda.cu
│ │ │ │ ├── softmax.cuh
│ │ │ │ └── strided_batched_gemm.cuh
│ │ │ ├── nccl_allocator/
│ │ │ │ └── NCCLAllocator.cpp
│ │ │ ├── nccl_p2p/
│ │ │ │ ├── nccl_p2p.cpp
│ │ │ │ ├── nccl_p2p_cuda.cu
│ │ │ │ ├── nccl_p2p_cuda.cuh
│ │ │ │ ├── nccl_version.cpp
│ │ │ │ └── nccl_version_check.cu
│ │ │ ├── optimizers/
│ │ │ │ ├── fused_adam_cuda.cpp
│ │ │ │ ├── fused_adam_cuda_kernel.cu
│ │ │ │ ├── fused_lamb_cuda.cpp
│ │ │ │ ├── fused_lamb_cuda_kernel.cu
│ │ │ │ ├── multi_tensor_distopt_adam.cpp
│ │ │ │ ├── multi_tensor_distopt_adam_kernel.cu
│ │ │ │ ├── multi_tensor_distopt_lamb.cpp
│ │ │ │ └── multi_tensor_distopt_lamb_kernel.cu
│ │ │ ├── peer_memory/
│ │ │ │ ├── peer_memory.cpp
│ │ │ │ ├── peer_memory_cuda.cu
│ │ │ │ └── peer_memory_cuda.cuh
│ │ │ ├── transducer/
│ │ │ │ ├── transducer_joint.cpp
│ │ │ │ ├── transducer_joint_kernel.cu
│ │ │ │ ├── transducer_loss.cpp
│ │ │ │ └── transducer_loss_kernel.cu
│ │ │ └── xentropy/
│ │ │ ├── interface.cpp
│ │ │ └── xentropy_kernel.cu
│ │ ├── cudnn_gbn/
│ │ │ ├── __init__.py
│ │ │ └── batch_norm.py
│ │ ├── examples/
│ │ │ ├── gpu_direct_storage/
│ │ │ │ ├── benchmark_load.py
│ │ │ │ ├── benchmark_save.py
│ │ │ │ ├── example_load.py
│ │ │ │ └── example_save.py
│ │ │ ├── multihead_attn/
│ │ │ │ ├── func_test_multihead_attn.py
│ │ │ │ └── perf_test_multihead_attn.py
│ │ │ └── nccl_allocator/
│ │ │ ├── allreduce.py
│ │ │ ├── cache.py
│ │ │ ├── change_cuda_allocator.py
│ │ │ └── toy_ddp.py
│ │ ├── fmha/
│ │ │ ├── __init__.py
│ │ │ └── fmha.py
│ │ ├── focal_loss/
│ │ │ ├── __init__.py
│ │ │ └── focal_loss.py
│ │ ├── gpu_direct_storage/
│ │ │ ├── README.md
│ │ │ └── __init__.py
│ │ ├── group_norm/
│ │ │ ├── __init__.py
│ │ │ └── group_norm.py
│ │ ├── groupbn/
│ │ │ ├── __init__.py
│ │ │ └── batch_norm.py
│ │ ├── index_mul_2d/
│ │ │ ├── __init__.py
│ │ │ └── index_mul_2d.py
│ │ ├── layer_norm/
│ │ │ ├── __init__.py
│ │ │ └── layer_norm.py
│ │ ├── multihead_attn/
│ │ │ ├── README.md
│ │ │ ├── __init__.py
│ │ │ ├── encdec_multihead_attn.py
│ │ │ ├── encdec_multihead_attn_func.py
│ │ │ ├── fast_encdec_multihead_attn_func.py
│ │ │ ├── fast_encdec_multihead_attn_norm_add_func.py
│ │ │ ├── fast_self_multihead_attn_func.py
│ │ │ ├── fast_self_multihead_attn_norm_add_func.py
│ │ │ ├── mask_softmax_dropout_func.py
│ │ │ ├── self_multihead_attn.py
│ │ │ └── self_multihead_attn_func.py
│ │ ├── nccl_allocator/
│ │ │ ├── README.md
│ │ │ ├── __init__.py
│ │ │ └── nccl_allocator.py
│ │ ├── openfold_triton/
│ │ │ ├── README.md
│ │ │ ├── __init__.py
│ │ │ ├── _layer_norm_backward_kernels.py
│ │ │ ├── _layer_norm_config_ampere.py
│ │ │ ├── _layer_norm_config_hopper.py
│ │ │ ├── _layer_norm_forward_kernels.py
│ │ │ ├── _mha_kernel.py
│ │ │ ├── fused_adam_swa.py
│ │ │ ├── layer_norm.py
│ │ │ └── mha.py
│ │ ├── optimizers/
│ │ │ ├── __init__.py
│ │ │ ├── distributed_fused_adam.py
│ │ │ ├── distributed_fused_lamb.py
│ │ │ ├── fp16_optimizer.py
│ │ │ ├── fused_adam.py
│ │ │ ├── fused_lamb.py
│ │ │ └── fused_sgd.py
│ │ ├── peer_memory/
│ │ │ ├── __init__.py
│ │ │ ├── peer_halo_exchanger_1d.py
│ │ │ └── peer_memory.py
│ │ ├── sparsity/
│ │ │ ├── COPYRIGHT
│ │ │ ├── README.md
│ │ │ ├── __init__.py
│ │ │ ├── asp.py
│ │ │ ├── permutation_lib.py
│ │ │ ├── permutation_search_kernels/
│ │ │ │ ├── CUDA_kernels/
│ │ │ │ │ └── permutation_search_kernels.cu
│ │ │ │ ├── __init__.py
│ │ │ │ ├── call_permutation_search_kernels.py
│ │ │ │ ├── channel_swap.py
│ │ │ │ ├── exhaustive_search.py
│ │ │ │ └── permutation_utilities.py
│ │ │ ├── permutation_tests/
│ │ │ │ ├── README.md
│ │ │ │ ├── ablation_studies.sh
│ │ │ │ ├── permutation_test.py
│ │ │ │ ├── runtime_table.sh
│ │ │ │ └── unstructured_study.sh
│ │ │ ├── sparse_masklib.py
│ │ │ └── test/
│ │ │ ├── checkpointing_test_part1.py
│ │ │ ├── checkpointing_test_part2.py
│ │ │ ├── checkpointing_test_reference.py
│ │ │ ├── test_permutation_application.py
│ │ │ └── toy_problem.py
│ │ ├── test/
│ │ │ ├── __init__.py
│ │ │ ├── bottleneck/
│ │ │ │ ├── __init__.py
│ │ │ │ └── test_bottleneck_module.py
│ │ │ ├── clip_grad/
│ │ │ │ ├── __init__.py
│ │ │ │ └── test_clip_grad.py
│ │ │ ├── conv_bias_relu/
│ │ │ │ ├── __init__.py
│ │ │ │ └── test_conv_bias_relu.py
│ │ │ ├── cudnn_gbn/
│ │ │ │ ├── __init__.py
│ │ │ │ └── test_cudnn_gbn_with_two_gpus.py
│ │ │ ├── fmha/
│ │ │ │ ├── __init__.py
│ │ │ │ └── test_fmha.py
│ │ │ ├── focal_loss/
│ │ │ │ ├── __init__.py
│ │ │ │ └── test_focal_loss.py
│ │ │ ├── fused_dense/
│ │ │ │ └── test_fused_dense.py
│ │ │ ├── group_norm/
│ │ │ │ ├── __init__.py
│ │ │ │ └── test_group_norm.py
│ │ │ ├── index_mul_2d/
│ │ │ │ ├── __init__.py
│ │ │ │ └── test_index_mul_2d.py
│ │ │ ├── layer_norm/
│ │ │ │ ├── __init__.py
│ │ │ │ └── test_fast_layer_norm.py
│ │ │ ├── multihead_attn/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── test_encdec_multihead_attn.py
│ │ │ │ ├── test_encdec_multihead_attn_norm_add.py
│ │ │ │ ├── test_fast_self_multihead_attn_bias.py
│ │ │ │ ├── test_mha_fused_softmax.py
│ │ │ │ ├── test_self_multihead_attn.py
│ │ │ │ └── test_self_multihead_attn_norm_add.py
│ │ │ ├── openfold_triton/
│ │ │ │ ├── test_fused_adam_swa.py
│ │ │ │ ├── test_openfold_mha.py
│ │ │ │ └── test_sync_triton_auto_tune_cache_across_gpus.py
│ │ │ ├── optimizers/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── test_dist_adam.py
│ │ │ │ └── test_distributed_fused_lamb.py
│ │ │ ├── peer_memory/
│ │ │ │ ├── __init__.py
│ │ │ │ └── test_peer_halo_exchange_module.py
│ │ │ ├── transducer/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── test_transducer_joint.py
│ │ │ │ └── test_transducer_loss.py
│ │ │ └── xentropy/
│ │ │ ├── __init__.py
│ │ │ └── test_label_smoothing.py
│ │ ├── torchsched/
│ │ │ ├── __init__.py
│ │ │ ├── backend.py
│ │ │ ├── config.py
│ │ │ ├── inductor/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── _utils.py
│ │ │ │ ├── event.py
│ │ │ │ ├── graph.py
│ │ │ │ ├── scheduler.py
│ │ │ │ └── wrapper.py
│ │ │ ├── ops/
│ │ │ │ ├── __init__.py
│ │ │ │ └── layer_norm.py
│ │ │ └── passes/
│ │ │ ├── __init__.py
│ │ │ └── pre_grad_passes.py
│ │ ├── transducer/
│ │ │ ├── __init__.py
│ │ │ ├── _transducer_ref.py
│ │ │ └── transducer.py
│ │ └── xentropy/
│ │ ├── __init__.py
│ │ └── softmax_xentropy.py
│ ├── distributed_testing/
│ │ ├── __init__.py
│ │ ├── _ucc_util.py
│ │ └── distributed_test_base.py
│ ├── fused_dense/
│ │ ├── __init__.py
│ │ └── fused_dense.py
│ ├── mlp/
│ │ ├── __init__.py
│ │ └── mlp.py
│ ├── multi_tensor_apply/
│ │ ├── __init__.py
│ │ └── multi_tensor_apply.py
│ ├── normalization/
│ │ ├── __init__.py
│ │ └── fused_layer_norm.py
│ └── optimizers/
│ ├── __init__.py
│ ├── fused_adagrad.py
│ ├── fused_adam.py
│ ├── fused_lamb.py
│ ├── fused_mixed_precision_lamb.py
│ ├── fused_novograd.py
│ └── fused_sgd.py
├── csrc/
│ ├── amp_C_frontend.cpp
│ ├── flatten_unflatten.cpp
│ ├── fused_dense.cpp
│ ├── fused_dense_cuda.cu
│ ├── layer_norm_cuda.cpp
│ ├── layer_norm_cuda_kernel.cu
│ ├── megatron/
│ │ ├── fused_rotary_positional_embedding.cpp
│ │ ├── fused_rotary_positional_embedding.h
│ │ ├── fused_rotary_positional_embedding_cuda.cu
│ │ ├── fused_weight_gradient_dense.cpp
│ │ ├── fused_weight_gradient_dense_16bit_prec_cuda.cu
│ │ ├── fused_weight_gradient_dense_cuda.cu
│ │ ├── generic_scaled_masked_softmax.cpp
│ │ ├── generic_scaled_masked_softmax.h
│ │ ├── generic_scaled_masked_softmax_cuda.cu
│ │ ├── scaled_masked_softmax.cpp
│ │ ├── scaled_masked_softmax.h
│ │ ├── scaled_masked_softmax_cuda.cu
│ │ ├── scaled_softmax.cpp
│ │ ├── scaled_softmax_cuda.cu
│ │ ├── scaled_upper_triang_masked_softmax.cpp
│ │ ├── scaled_upper_triang_masked_softmax.h
│ │ └── scaled_upper_triang_masked_softmax_cuda.cu
│ ├── mlp.cpp
│ ├── mlp_cuda.cu
│ ├── multi_tensor_adagrad.cu
│ ├── multi_tensor_adam.cu
│ ├── multi_tensor_apply.cuh
│ ├── multi_tensor_axpby_kernel.cu
│ ├── multi_tensor_l2norm_kernel.cu
│ ├── multi_tensor_l2norm_kernel_mp.cu
│ ├── multi_tensor_l2norm_scale_kernel.cu
│ ├── multi_tensor_lamb.cu
│ ├── multi_tensor_lamb_mp.cu
│ ├── multi_tensor_lamb_stage_1.cu
│ ├── multi_tensor_lamb_stage_2.cu
│ ├── multi_tensor_novograd.cu
│ ├── multi_tensor_scale_kernel.cu
│ ├── multi_tensor_sgd_kernel.cu
│ ├── static_switch.h
│ ├── syncbn.cpp
│ ├── type_shim.h
│ ├── update_scale_hysteresis.cu
│ └── welford.cu
├── docs/
│ ├── Makefile
│ └── source/
│ ├── _static/
│ │ └── css/
│ │ └── pytorch_theme.css
│ ├── _templates/
│ │ └── layout.html
│ ├── conf.py
│ ├── index.rst
│ ├── layernorm.rst
│ └── optimizers.rst
├── examples/
│ ├── README.md
│ ├── dcgan/
│ │ ├── README.md
│ │ └── main_amp.py
│ ├── docker/
│ │ ├── Dockerfile
│ │ └── README.md
│ ├── imagenet/
│ │ ├── README.md
│ │ └── main_amp.py
│ └── simple/
│ └── distributed/
│ ├── README.md
│ ├── distributed_data_parallel.py
│ └── run.sh
├── pyproject.toml
├── requirements.txt
├── requirements_dev.txt
├── setup.py
└── tests/
├── L0/
│ ├── run_fused_layer_norm/
│ │ └── test_fused_layer_norm.py
│ ├── run_mlp/
│ │ └── test_mlp.py
│ ├── run_optimizers/
│ │ ├── __init__.py
│ │ ├── test_adam.py
│ │ ├── test_fused_novograd.py
│ │ ├── test_fused_optimizer.py
│ │ └── test_lamb.py
│ └── run_test.py
├── L1/
│ ├── common/
│ │ ├── compare.py
│ │ ├── main_amp.py
│ │ └── run_test.sh
│ ├── cross_product/
│ │ └── run.sh
│ └── cross_product_distributed/
│ └── run.sh
├── distributed/
│ ├── DDP/
│ │ ├── ddp_race_condition_test.py
│ │ └── run_race_test.sh
│ ├── amp_master_params/
│ │ ├── amp_master_params.py
│ │ ├── compare.py
│ │ └── run.sh
│ └── synced_batchnorm/
│ ├── python_single_gpu_unit_test.py
│ ├── single_gpu_unit_test.py
│ ├── test_batchnorm1d.py
│ ├── test_groups.py
│ ├── two_gpu_test_different_batch_size.py
│ ├── two_gpu_unit_test.py
│ └── unit_test.sh
└── docker_extension_builds/
└── run.sh
================================================
FILE CONTENTS
================================================
================================================
FILE: .clang-format
================================================
# Start with a built-in style and modify it
BasedOnStyle: Google
# Overrides
ColumnLimit: 120
================================================
FILE: .git-blame-ignore-revs
================================================
# Commits to ignore in git-blame
# These commits are bulk formatting or refactoring changes that should be skipped when viewing blame history
# Add pre-commit and GitHub Actions workflow for it (#1949)
1f20398756f0eeba37d6887a2d3f65e0687ec94f
# Remove github actions config of pre-commit in favor of pre-commit ci (#1958)
27e0e8951352d9d58c88b2895cd8f2c752bda963
# Enable Ruff pre-commit hooks (#1957)
16fadfe71c0d57312351c2d8b056251a0c8ce1ef
================================================
FILE: .github/ISSUE_TEMPLATE/bug_report.md
================================================
---
name: Bug report
about: Create a report to help us improve apex
title: ''
labels: bug
assignees: ''
---
**Describe the Bug**
**Minimal Steps/Code to Reproduce the Bug**
<!--
Please list the *minimal* steps or provide a code snippet for us to be able to reproduce the bug.
A helpful guide on how to craft a minimal bug report: http://matthewrocklin.com/blog/work/2018/02/28/minimal-bug-reports.
-->
**Expected Behavior**
<!-- A clear and concise description of what you expected to happen. -->
**Environment**
<!-- OS, version of Python, CUDA, PyTorch; collect these via `python -m torch.utils.collect_env` -->
================================================
FILE: .gitignore
================================================
apex.egg-info
dist
build
docs/build
*~
__pycache__
.vscode
# Copied from https://raw.githubusercontent.com/github/gitignore/master/Python.gitignore
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
================================================
FILE: .gitmodules
================================================
[submodule "apex/contrib/csrc/multihead_attn/cutlass"]
path = apex/contrib/csrc/multihead_attn/cutlass
url = https://github.com/NVIDIA/cutlass.git
branch = v1.2.0
[submodule "apex/contrib/csrc/cudnn-frontend"]
path = apex/contrib/csrc/cudnn-frontend
url = https://github.com/NVIDIA/cudnn-frontend.git
================================================
FILE: .nojekyll
================================================
================================================
FILE: .pre-commit-config.yaml
================================================
repos:
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v22.1.1 # Or pin to your preferred clang-format version
hooks:
- id: clang-format
files: \.(c|h|cpp|hpp|proto|cu|cuh)$
exclude: ^(apex/contrib/csrc/multihead_attn/cutlass|apex/contrib/csrc/cudnn-frontend)/
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.15.6
hooks:
- id: ruff-check
args: ["--fix"]
- id: ruff-format
types_or: [python]
exclude: "examples"
================================================
FILE: LICENSE
================================================
All rights reserved.
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
================================================
FILE: README.md
================================================
# Introduction
This repository holds NVIDIA-maintained utilities to streamline mixed precision and distributed training in PyTorch.
Some of the code here will be included in upstream PyTorch eventually.
The intent of Apex is to make up-to-date utilities available to users as quickly as possible.
# Installation
Each [`apex.contrib`](./apex/contrib) module requires one or more install options other than `--cpp_ext` and `--cuda_ext`.
Note that contrib modules do not necessarily support stable PyTorch releases, some of them might only be compatible with nightlies.
## Containers
NVIDIA PyTorch Containers are available on NGC: https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch.
The containers come with all the custom extensions available at the moment.
See [the NGC documentation](https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/index.html) for details such as:
- how to pull a container
- how to run a pulled container
- release notes
## From Source
To install Apex from source, we recommend using the nightly PyTorch obtainable from https://github.com/pytorch/pytorch.
The latest stable release obtainable from https://pytorch.org should also work.
We recommend installing [`Ninja`](https://ninja-build.org/) to make compilation faster.
### Linux
For performance and full functionality, we recommend installing Apex with CUDA and C++ extensions using environment variables:
#### Using Environment Variables (Recommended)
```bash
git clone https://github.com/NVIDIA/apex
cd apex
# Build with core extensions (cpp and cuda)
APEX_CPP_EXT=1 APEX_CUDA_EXT=1 pip install -v --no-build-isolation .
# To build with additional extensions, specify them with environment variables
APEX_CPP_EXT=1 APEX_CUDA_EXT=1 APEX_FAST_MULTIHEAD_ATTN=1 APEX_FUSED_CONV_BIAS_RELU=1 pip install -v --no-build-isolation .
# To build all contrib extensions at once
APEX_CPP_EXT=1 APEX_CUDA_EXT=1 APEX_ALL_CONTRIB_EXT=1 pip install -v --no-build-isolation .
```
To reduce the build time, parallel building can be enabled:
```bash
NVCC_APPEND_FLAGS="--threads 4" APEX_PARALLEL_BUILD=8 APEX_CPP_EXT=1 APEX_CUDA_EXT=1 pip install -v --no-build-isolation .
```
When CPU cores or memory are limited, the `--parallel` option (`APEX_PARALLEL_BUILD` in the example above) is generally preferred over `--threads`. See [pull#1882](https://github.com/NVIDIA/apex/pull/1882) for more details.
#### Using Command-Line Flags (Legacy Method)
The traditional command-line flags are still supported:
```bash
# Using pip config-settings (pip >= 23.1)
pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./
# For older pip versions
pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --global-option="--cpp_ext" --global-option="--cuda_ext" ./
# To build with additional extensions
pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--fast_multihead_attn" ./
```
#### Python-Only Build
Apex also supports a Python-only build via:
```bash
pip install -v --disable-pip-version-check --no-build-isolation --no-cache-dir ./
```
A Python-only build omits:
- Fused kernels required to use `apex.optimizers.FusedAdam`.
- Fused kernels required to use `apex.normalization.FusedLayerNorm` and `apex.normalization.FusedRMSNorm`.
- Fused kernels that improve the performance and numerical stability of `apex.parallel.SyncBatchNorm`.
- Fused kernels that improve the performance of `apex.parallel.DistributedDataParallel` and `apex.amp`.
`DistributedDataParallel`, `amp`, and `SyncBatchNorm` will still be usable, but they may be slower.
### [Experimental] Windows
`pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" .` may work if you are able to build PyTorch from source
on your system. A Python-only build via `pip install -v --no-cache-dir .` is more likely to work.
If you installed PyTorch in a Conda environment, make sure to install Apex in that same environment.
## Custom C++/CUDA Extensions and Install Options
If a requirement of a module is not met, then it will not be built.
| Module Name | Environment Variable | Install Option | Misc |
|---------------|------------------------|------------------|--------|
| `apex_C` | `APEX_CPP_EXT=1` | `--cpp_ext` | |
| `amp_C` | `APEX_CUDA_EXT=1` | `--cuda_ext` | |
| `syncbn` | `APEX_CUDA_EXT=1` | `--cuda_ext` | |
| `fused_layer_norm_cuda` | `APEX_CUDA_EXT=1` | `--cuda_ext` | [`apex.normalization`](./apex/normalization) |
| `mlp_cuda` | `APEX_CUDA_EXT=1` | `--cuda_ext` | |
| `scaled_upper_triang_masked_softmax_cuda` | `APEX_CUDA_EXT=1` | `--cuda_ext` | |
| `generic_scaled_masked_softmax_cuda` | `APEX_CUDA_EXT=1` | `--cuda_ext` | |
| `scaled_masked_softmax_cuda` | `APEX_CUDA_EXT=1` | `--cuda_ext` | |
| `fused_weight_gradient_mlp_cuda` | `APEX_CUDA_EXT=1` | `--cuda_ext` | Requires CUDA>=11 |
| `permutation_search_cuda` | `APEX_PERMUTATION_SEARCH=1` | `--permutation_search` | [`apex.contrib.sparsity`](./apex/contrib/sparsity) |
| `bnp` | `APEX_BNP=1` | `--bnp` | [`apex.contrib.groupbn`](./apex/contrib/groupbn) |
| `xentropy` | `APEX_XENTROPY=1` | `--xentropy` | [`apex.contrib.xentropy`](./apex/contrib/xentropy) |
| `focal_loss_cuda` | `APEX_FOCAL_LOSS=1` | `--focal_loss` | [`apex.contrib.focal_loss`](./apex/contrib/focal_loss) |
| `fused_index_mul_2d` | `APEX_INDEX_MUL_2D=1` | `--index_mul_2d` | [`apex.contrib.index_mul_2d`](./apex/contrib/index_mul_2d) |
| `fused_adam_cuda` | `APEX_DEPRECATED_FUSED_ADAM=1` | `--deprecated_fused_adam` | [`apex.contrib.optimizers`](./apex/contrib/optimizers) |
| `fused_lamb_cuda` | `APEX_DEPRECATED_FUSED_LAMB=1` | `--deprecated_fused_lamb` | [`apex.contrib.optimizers`](./apex/contrib/optimizers) |
| `fast_layer_norm` | `APEX_FAST_LAYER_NORM=1` | `--fast_layer_norm` | [`apex.contrib.layer_norm`](./apex/contrib/layer_norm). different from `fused_layer_norm` |
| `fmhalib` | `APEX_FMHA=1` | `--fmha` | [`apex.contrib.fmha`](./apex/contrib/fmha) |
| `fast_multihead_attn` | `APEX_FAST_MULTIHEAD_ATTN=1` | `--fast_multihead_attn` | [`apex.contrib.multihead_attn`](./apex/contrib/multihead_attn) |
| `transducer_joint_cuda` | `APEX_TRANSDUCER=1` | `--transducer` | [`apex.contrib.transducer`](./apex/contrib/transducer) |
| `transducer_loss_cuda` | `APEX_TRANSDUCER=1` | `--transducer` | [`apex.contrib.transducer`](./apex/contrib/transducer) |
| `cudnn_gbn_lib` | `APEX_CUDNN_GBN=1` | `--cudnn_gbn` | Requires cuDNN>=8.5, [`apex.contrib.cudnn_gbn`](./apex/contrib/cudnn_gbn) |
| `peer_memory_cuda` | `APEX_PEER_MEMORY=1` | `--peer_memory` | [`apex.contrib.peer_memory`](./apex/contrib/peer_memory) |
| `nccl_p2p_cuda` | `APEX_NCCL_P2P=1` | `--nccl_p2p` | Requires NCCL >= 2.10, [`apex.contrib.nccl_p2p`](./apex/contrib/nccl_p2p) |
| `fast_bottleneck` | `APEX_FAST_BOTTLENECK=1` | `--fast_bottleneck` | Requires `peer_memory_cuda` and `nccl_p2p_cuda`, [`apex.contrib.bottleneck`](./apex/contrib/bottleneck) |
| `fused_conv_bias_relu` | `APEX_FUSED_CONV_BIAS_RELU=1` | `--fused_conv_bias_relu` | Requires cuDNN>=8.4, [`apex.contrib.conv_bias_relu`](./apex/contrib/conv_bias_relu) |
| `distributed_adam_cuda` | `APEX_DISTRIBUTED_ADAM=1` | `--distributed_adam` | [`apex.contrib.optimizers`](./apex/contrib/optimizers) |
| `distributed_lamb_cuda` | `APEX_DISTRIBUTED_LAMB=1` | `--distributed_lamb` | [`apex.contrib.optimizers`](./apex/contrib/optimizers) |
| `_apex_nccl_allocator` | `APEX_NCCL_ALLOCATOR=1` | `--nccl_allocator` | Requires NCCL >= 2.19, [`apex.contrib.nccl_allocator`](./apex/contrib/nccl_allocator) |
| `_apex_gpu_direct_storage` | `APEX_GPU_DIRECT_STORAGE=1` | `--gpu_direct_storage` | [`apex.contrib.gpu_direct_storage`](./apex/contrib/gpu_direct_storage) |
You can also build all contrib extensions at once by setting `APEX_ALL_CONTRIB_EXT=1`.
================================================
FILE: apex/__init__.py
================================================
import logging
import warnings
# May help avoid undefined symbol errors https://pytorch.org/cppdocs/notes/faq.html#undefined-symbol-errors-from-pytorch-aten
import torch
# For optimizers and normalization there is no Python fallback.
# Absence of cuda backend is a hard error.
# I would like the errors from importing fused_adam_cuda or fused_layer_norm_cuda
# to be triggered lazily, because if someone has installed with --cpp_ext and --cuda_ext
# and so expects those backends to be available, but for some reason they actually aren't
# available (for example because they built improperly in a way that isn't revealed until
# load time), the error message is timely and visible.
from . import optimizers
from . import normalization
__all__ = ["optimizers", "normalization"]
def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int) -> bool:
    """Report whether cuDNN is present and at least ``required_cudnn_version``.

    Args:
        global_option: Name of the feature/module requiring cuDNN; used in the warning text.
        required_cudnn_version: Minimum acceptable value of ``torch.backends.cudnn.version()``.

    Returns:
        True when the requirement is met. Otherwise emits a ``UserWarning`` describing
        what was found and returns False.
    """
    cudnn_available = torch.backends.cudnn.is_available()
    # Only query the version when cuDNN exists; version() is meaningless otherwise.
    cudnn_version = torch.backends.cudnn.version() if cudnn_available else None

    if cudnn_available and cudnn_version >= required_cudnn_version:
        return True

    warnings.warn(
        f"`{global_option}` depends on cuDNN {required_cudnn_version} or later, "
        f"but {'cuDNN is not available' if not cudnn_available else cudnn_version}"
    )
    return False
class DeprecatedFeatureWarning(FutureWarning):
    """Warning category emitted by ``deprecated_warning`` for deprecated Apex features.

    It derives from ``FutureWarning`` rather than ``DeprecationWarning``, so Python's
    default warning filters show it to end users.
    """
def deprecated_warning(msg: str) -> None:
    """Emit ``msg`` as a ``DeprecatedFeatureWarning``, at most from rank 0 in distributed runs.

    The warning is always emitted when ``torch.distributed`` is unavailable or no process
    group has been initialized. Once a process group exists, only rank 0 warns, so the
    message is not repeated by every rank.

    Args:
        msg: The deprecation message to emit.
    """
    # ``is_available`` is a function and must be called; the bare attribute is always
    # truthy, which previously made this clause a no-op.
    if (
        not torch.distributed.is_available()
        or not torch.distributed.is_initialized()
        # Reached only when a process group is initialized (short-circuit above).
        or torch.distributed.get_rank() == 0
    ):
        warnings.warn(msg, DeprecatedFeatureWarning)
================================================
FILE: apex/_autocast_utils.py
================================================
from typing import Optional, Sequence
import torch
__all__ = ["_cast_if_autocast_enabled"]
def _get_autocast_dtypes() -> Sequence[torch.dtype]:
if torch.cuda.is_bf16_supported():
return [torch.half, torch.bfloat16]
return [torch.half]
def _get_current_dtype(dtype: Optional[torch.dtype] = None) -> torch.dtype:
    """Return the dtype computations should currently run in.

    Args:
        dtype: Preferred dtype when autocast is disabled. ``None`` means ``torch.float``.

    Returns:
        The autocast GPU dtype when autocast is enabled. Otherwise ``dtype``, or
        ``torch.float`` when ``dtype`` is None.
    """
    if torch.is_autocast_enabled():
        return torch.get_autocast_gpu_dtype()
    # The previous ``torch.float or dtype`` always returned torch.float (dtype objects
    # are truthy), silently ignoring the caller's ``dtype`` argument.
    return dtype if dtype is not None else torch.float
def _cast_if_autocast_enabled(*args):
if not torch.is_autocast_enabled():
return args
else:
return torch.cuda.amp.autocast_mode._cast(args, torch.get_autocast_gpu_dtype())
================================================
FILE: apex/contrib/__init__.py
================================================
================================================
FILE: apex/contrib/bottleneck/__init__.py
================================================
from .bottleneck import Bottleneck, SpatialBottleneck
from .halo_exchangers import (
HaloExchangerNoComm,
HaloExchangerAllGather,
HaloExchangerSendRecv,
HaloExchangerPeer,
)
================================================
FILE: apex/contrib/bottleneck/bottleneck.py
================================================
import functools as func
import torch
from torch import nn
from apex import check_cudnn_version_and_warn
import fast_bottleneck
import nccl_p2p_cuda as inc
# Import-time guard: this module refuses to import unless cuDNN >= 8400 (in the
# integer format of torch.backends.cudnn.version()) is available.
# check_cudnn_version_and_warn emits a warning with the reason before the
# assert fails.
# NOTE(review): `assert` is stripped under `python -O`, which skips this check
# (and its warning) entirely.
assert check_cudnn_version_and_warn(__name__, 8400)
def kaiming_uniform_(tensor, a=0, mode="fan_in", nonlinearity="leaky_relu"):
    """Fill ``tensor`` in place with Kaiming-uniform values; returns ``None``.

    Thin wrapper over ``torch.nn.init.kaiming_uniform_``. The tensor is
    initialized in whatever layout it currently has; in this module the
    weights are initialized in NCHW, before any NHWC permutation.
    """
    nn.init.kaiming_uniform_(tensor, a=a, mode=mode, nonlinearity=nonlinearity)
def compute_scale_bias_one(nhwc, weight, bias, running_mean, running_var, w_scale, w_bias):
    """Fold one frozen batch norm into a scale/bias pair, written in place.

    ``w_scale`` receives ``weight * rsqrt(running_var)`` and ``w_bias``
    receives ``bias - running_mean * scale`` via ``copy_``. ``nhwc`` is
    accepted for signature compatibility but is not used. Returns ``None``.
    """
    folded_scale = weight * running_var.rsqrt()
    folded_bias = bias - running_mean * folded_scale
    for destination, source in ((w_scale, folded_scale), (w_bias, folded_bias)):
        destination.copy_(source)
def compute_scale_bias_method(nhwc, args):
    """Refresh every folded batch-norm scale/bias described by ``args``.

    ``args`` is an iterable of tuples
    ``(weight, bias, running_mean, running_var, w_scale, w_bias)``. Each tuple
    is forwarded, after ``nhwc``, to ``compute_scale_bias_one``, which writes
    into that tuple's ``w_scale`` and ``w_bias`` in place.
    """
    for bn_state in args:
        compute_scale_bias_one(nhwc, *bn_state)
class FrozenBatchNorm2d(torch.jit.ScriptModule):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.

    ``weight``, ``bias``, ``running_mean`` and ``running_var`` are registered
    as buffers rather than parameters, so they appear in the state dict but are
    not returned by ``parameters()``.
    """

    def __init__(self, n):
        super(FrozenBatchNorm2d, self).__init__()
        # n channels; the initial values give the identity transform
        # (weight=1, bias=0, mean=0, var=1 -> scale=1, bias=0).
        self.register_buffer("weight", torch.ones(n))
        self.register_buffer("bias", torch.zeros(n))
        self.register_buffer("running_mean", torch.zeros(n))
        self.register_buffer("running_var", torch.ones(n))

    # Fold the frozen statistics into a broadcastable (scale, bias) pair:
    # scale = weight * rsqrt(running_var), bias = bias - running_mean * scale,
    # shaped (1, 1, 1, n) when nhwc is True and (1, n, 1, 1) otherwise.
    # NOTE(review): no epsilon is added before rsqrt, so a zero entry in
    # running_var yields an infinite scale.
    @torch.jit.script_method
    def get_scale_bias(self, nhwc):
        # type: (bool) -> List[torch.Tensor]
        scale = self.weight * self.running_var.rsqrt()
        bias = self.bias - self.running_mean * scale
        if nhwc:
            scale = scale.reshape(1, 1, 1, -1)
            bias = bias.reshape(1, 1, 1, -1)
        else:
            scale = scale.reshape(1, -1, 1, 1)
            bias = bias.reshape(1, -1, 1, 1)
        return scale, bias

    @torch.jit.script_method
    def forward(self, x):
        # Always uses the channel-dim-1 broadcast shape (nhwc=False);
        # presumably x is (N, C, H, W) -- explicit NHWC callers go through
        # get_scale_bias(True) instead.
        scale, bias = self.get_scale_bias(False)
        return x * scale + bias
@torch.jit.script
def drelu_dscale1(grad_o, output, scale1):
    # ReLU backward followed by a scale: zero grad_o wherever the forward
    # output was not positive, then multiply by scale1.
    # Returns (masked grad * scale1, masked grad).
    masked_grad = (output > 0) * grad_o
    return masked_grad * scale1, masked_grad
@torch.jit.script
def drelu_dscale2(grad_o, output, scale1, scale2):
    # Same ReLU mask as drelu_dscale1, but the masked gradient is scaled by two
    # different factors; the unscaled masked gradient is not returned.
    masked_grad = (output > 0) * grad_o
    return masked_grad * scale1, masked_grad * scale2
class BottleneckFunction(torch.autograd.Function):
    # Autograd wrapper around the fused ``fast_bottleneck`` forward/backward
    # kernels. nhwc, stride_1x1, scale and bias receive no gradient (None);
    # the remaining gradients are whatever fast_bottleneck.backward returns
    # for x and the conv weights.

    @staticmethod
    def forward(ctx, nhwc, stride_1x1, scale, bias, x, *conv):
        # TODO: clean up order of tensors
        # args layout (and therefore the layout of ctx.saved_tensors):
        #   [0] x, [1:4] conv1-3 weights, [4:7] scale1-3, [7:10] bias1-3,
        #   and only with a downsample branch: [10] conv4, [11] scale4, [12] bias4.
        args = [x, *conv[0:3], *scale[0:3], *bias[0:3]]
        # A fourth conv weight means the block has a downsample branch.
        ctx.downsample = len(conv) > 3
        if ctx.downsample:
            args.append(conv[3])
            args.append(scale[3])
            args.append(bias[3])

        # weight buffers are always in nhwc while shape can be nhwc or channels_last
        # here we pass in flag and let c++ handle it
        # alternatively, we can put all sizes into a fixed format and pass it in
        outputs = fast_bottleneck.forward(nhwc, stride_1x1, args)
        # Saved as args + outputs; backward takes the last three saved tensors
        # as the outputs, so this assumes fast_bottleneck.forward returns three.
        ctx.save_for_backward(*(args + outputs))
        # save relu outputs for drelu
        ctx.nhwc = nhwc
        ctx.stride_1x1 = stride_1x1
        return outputs[2]

    # backward relu is not exposed, MUL with mask used now
    # only support dgrad
    @staticmethod
    def backward(ctx, grad_o):
        outputs = ctx.saved_tensors[-3:]

        # Mask grad_o by the final ReLU (outputs[2] > 0) and scale it.
        # saved_tensors[6] is scale3; saved_tensors[11] is scale4 (downsample).
        # Without downsample, grad_conv4 is the unscaled masked gradient.
        if ctx.downsample:
            grad_conv3, grad_conv4 = drelu_dscale2(
                grad_o, outputs[2], ctx.saved_tensors[6], ctx.saved_tensors[11]
            )
        else:
            grad_conv3, grad_conv4 = drelu_dscale1(grad_o, outputs[2], ctx.saved_tensors[6])

        # create input vector for backward
        # t_list layout: [0:10] as in args, [10] grad_conv3, [11] grad_conv4,
        # [12] outputs[0], [13] outputs[1], then [14] conv4 weight if downsample.
        t_list = [*ctx.saved_tensors[0:10]]
        t_list.append(grad_conv3)
        t_list.append(grad_conv4)

        # outputs used for wgrad and generating drelu mask
        t_list.append(outputs[0])
        t_list.append(outputs[1])

        # in case there is downsample
        if ctx.downsample:
            t_list.append(ctx.saved_tensors[10])

        grads = fast_bottleneck.backward(ctx.nhwc, ctx.stride_1x1, t_list)

        return (None, None, None, None, *grads)
# Functional entry point; Bottleneck.forward calls it as
# bottleneck_function(explicit_nhwc, stride, w_scale, w_bias, x, *w_conv).
bottleneck_function = BottleneckFunction.apply
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding.

    Padding equals ``dilation``, so at stride 1 the spatial size is preserved.
    The layer has no bias term.
    """
    conv_options = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )
    return nn.Conv2d(in_planes, out_planes, **conv_options)
def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution (no padding, no bias)."""
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
class Bottleneck(torch.nn.Module):
    """ResNet bottleneck block (1x1 -> 3x3 -> 1x1) with frozen batch norms.

    Bottleneck in torchvision places the stride for downsampling at the 3x3
    convolution (self.conv2), while the original implementation places the
    stride at the first 1x1 convolution (self.conv1), according to "Deep
    residual learning for image recognition" (https://arxiv.org/abs/1512.03385).
    The torchvision variant is also known as ResNet V1.5 and improves accuracy
    according to
    https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
    Here the stride is put at the 1x1 convolution (self.conv1).

    Args:
        in_channels: Input channels.
        bottleneck_channels: Channels of conv1/conv2 outputs.
        out_channels: Output channels.
        stride: Stride of conv1 and of the downsample branch.
        groups: Only 1 is supported.
        dilation: Only 1 is supported.
        norm_func: Only ``None`` is supported; ``FrozenBatchNorm2d`` is used.
        use_cudnn: Run the fused ``bottleneck_function`` instead of native ops.
        explicit_nhwc: Permute all conv weights to NHWC at construction; only
            supported together with ``use_cudnn``.

    Raises:
        RuntimeError: for unsupported ``groups``, ``dilation`` or ``norm_func``.
    """

    # TODO: prevent unsupported case usage
    # support cases
    #                 native      cudnn
    # normal             yes         no
    # channel_last       yes        yes
    # explicit_nhwc       no        yes

    def __init__(
        self,
        in_channels,
        bottleneck_channels,
        out_channels,
        stride=1,
        groups=1,
        dilation=1,
        norm_func=None,
        use_cudnn=False,
        explicit_nhwc=False,
    ):
        super(Bottleneck, self).__init__()
        if groups != 1:
            raise RuntimeError("Only support groups == 1")
        if dilation != 1:
            raise RuntimeError("Only support dilation == 1")
        if norm_func is not None:
            raise RuntimeError("Only support frozen BN now.")
        norm_func = FrozenBatchNorm2d

        if stride != 1 or in_channels != out_channels:
            self.downsample = nn.Sequential(
                conv1x1(in_channels, out_channels, stride),
                norm_func(out_channels),
            )
        else:
            self.downsample = None

        # Both self.conv1 and self.downsample downsample the input when
        # stride != 1 (the stride sits on the first 1x1 convolution).
        self.conv1 = conv1x1(in_channels, bottleneck_channels, stride)
        self.conv2 = conv3x3(bottleneck_channels, bottleneck_channels)
        self.conv3 = conv1x1(bottleneck_channels, out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.stride = stride

        self.bn1 = norm_func(bottleneck_channels)
        self.bn2 = norm_func(bottleneck_channels)
        self.bn3 = norm_func(out_channels)
        # Preallocated folded scale/bias; set by get_scale_bias_callable().
        self.w_scale = None
        self.w_bias = None

        self.use_cudnn = use_cudnn

        # setup conv weights (order matches what bottleneck_function expects)
        self.w_conv = [self.conv1.weight, self.conv2.weight, self.conv3.weight]
        if self.downsample is not None:
            self.w_conv.append(self.downsample[0].weight)

        # init weight in nchw format before possible transpose
        for w in self.w_conv:
            kaiming_uniform_(w, a=1)

        self.explicit_nhwc = explicit_nhwc
        if self.explicit_nhwc:
            # Only conv weights are parameters (FrozenBatchNorm2d uses buffers).
            # Reassigning .data keeps the Parameter objects held in self.w_conv.
            with torch.no_grad():
                for p in self.parameters():
                    p.data = p.data.permute(0, 2, 3, 1).contiguous()

    def _frozen_batch_norms(self):
        """Return bn1, bn2, bn3 and, if present, the downsample batch norm."""
        batch_norms = [self.bn1, self.bn2, self.bn3]
        if self.downsample is not None:
            batch_norms.append(self.downsample[1])
        return batch_norms

    def get_scale_bias_callable(self):
        """Preallocate folded scale/bias tensors for all frozen batch norms.

        Returns a single callable that recomputes scale and bias for all frozen
        batch norms in place; it can be called at any time. This method must be
        called before CUDA graphing, and calling it prevents scale/bias from
        being recomputed on every forward call.

        NOTE: the buffers come from ``torch.empty_like`` and are uninitialized
        until the returned callable has been invoked at least once.
        """
        self.w_scale, self.w_bias, args = [], [], []
        shape = (1, 1, 1, -1) if self.explicit_nhwc else (1, -1, 1, 1)
        for bn in self._frozen_batch_norms():
            s = torch.empty_like(bn.weight)
            b = torch.empty_like(s)
            args.append((bn.weight, bn.bias, bn.running_mean, bn.running_var, s, b))
            # Reshaped views share storage with s/b, so in-place refreshes are seen.
            self.w_scale.append(s.reshape(shape))
            self.w_bias.append(b.reshape(shape))
        return func.partial(compute_scale_bias_method, self.explicit_nhwc, args)

    def forward(self, x):
        if self.use_cudnn:
            if self.w_scale is None:
                # calculate scale/bias from registered buffers
                # TODO: make this better
                w_scale, w_bias = [], []
                for bn in self._frozen_batch_norms():
                    s, b = bn.get_scale_bias(self.explicit_nhwc)
                    w_scale.append(s)
                    w_bias.append(b)
            else:
                w_scale, w_bias = self.w_scale, self.w_bias
            return bottleneck_function(
                self.explicit_nhwc, self.stride, w_scale, w_bias, x, *self.w_conv
            )

        if self.explicit_nhwc:
            raise RuntimeError("explicit nhwc with native ops is not supported.")

        # fallback to native ops
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out
class SpatialBottleneckFunction(torch.autograd.Function):
    # Autograd wrapper around the fused ``fast_bottleneck`` kernels for a
    # bottleneck block whose H dimension is split over ``spatial_group_size``
    # ranks. Each rank holds Hs rows. One-row halos of out1 (forward) and of
    # grad_out2 (backward) are exchanged with the neighbouring ranks through
    # ``spatial_halo_exchanger`` so the border rows of the 3x3 stage can be
    # fixed up.
    #
    # spatial_method (only consulted when spatial_group_size > 1):
    #   1 - recompute the border rows from a 3-row "fat halo" on side streams,
    #       overlapping the halo work with the main kernels (fwd and bwd);
    #   2 - forward: wait for the halos and run forward_out2_pad on the padded
    #       out1_pad; backward: uses the same halo recompute as method 1;
    #   3 - run the *_mask kernels (presumably masking rows outside
    #       thresholdTop/thresholdBottom) and then a *_halo_corr kernel on the
    #       border rows.
    # Any other value trips an assert.
    # Backward returns None for the first 12 inputs, then the fast_bottleneck
    # grads.

    @staticmethod
    def forward(
        ctx,
        spatial_group_size,
        spatial_group_rank,
        spatial_communicator,
        spatial_halo_exchanger,
        spatial_method,
        use_delay_kernel,
        explicit_nhwc,
        stride_1x1,
        scale,
        bias,
        thresholdTop,
        thresholdBottom,
        x,
        *conv,
    ):
        # NOTE: spatial_communicator is accepted but not used here.
        if spatial_group_size > 1:
            # Side streams owned by the halo exchanger.
            stream1 = spatial_halo_exchanger.stream1
            stream2 = spatial_halo_exchanger.stream2
            stream3 = spatial_halo_exchanger.stream3

        # TODO: clean up order of tensors
        # args layout: [0] x, [1:4] conv1-3 weights, [4:7] scale1-3,
        # [7:10] bias1-3, then (downsample only) [10] conv4, [11] scale4, [12] bias4.
        args = [x, *conv[0:3], *scale[0:3], *bias[0:3]]
        ctx.downsample = len(conv) > 3
        if ctx.downsample:
            args.append(conv[3])
            args.append(scale[3])
            args.append(bias[3])

        # weight buffers are always in explicit_nhwc while shape can be explicit_nhwc or channels_last
        # here we pass in flag and let c++ handle it
        # alternatively, we can put all sizes into a fixed format and pass it in
        outputs = fast_bottleneck.forward_init(explicit_nhwc, stride_1x1, args)
        # outputs[0] is out1 (computed here); outputs[1] is out2 and outputs[2]
        # the block output, filled by the forward_out2* / forward_rest calls below.
        fast_bottleneck.forward_out1(explicit_nhwc, stride_1x1, args, outputs)

        if spatial_group_size > 1:
            out1 = outputs[0]
            # out1_pad = out1 with one extra halo row above and below (Hs + 2 rows).
            if explicit_nhwc:
                N, Hs, W, C = list(out1.shape)
                memory_format = torch.contiguous_format
                out1_pad = torch.empty([N, Hs + 2, W, C], dtype=out1.dtype, device="cuda")
            else:
                N, C, Hs, W = list(out1.shape)
                memory_format = (
                    torch.channels_last
                    if out1.is_contiguous(memory_format=torch.channels_last)
                    else torch.contiguous_format
                )
                out1_pad = torch.empty(
                    [N, C, Hs + 2, W],
                    dtype=out1.dtype,
                    device="cuda",
                    memory_format=memory_format,
                )
            stream1.wait_stream(torch.cuda.current_stream())
            if spatial_method != 2:
                stream3.wait_stream(torch.cuda.current_stream())
            # Send the first/last row of out1 to the neighbours; the received
            # halos are written straight into the first/last row of out1_pad.
            with torch.cuda.stream(stream1):
                if explicit_nhwc:
                    top_out1_halo = out1_pad[:, :1, :, :]
                    btm_out1_halo = out1_pad[:, Hs + 1 : Hs + 2, :, :]
                    spatial_halo_exchanger.left_right_halo_exchange(
                        out1[:, :1, :, :],
                        out1[:, Hs - 1 :, :, :],
                        top_out1_halo,
                        btm_out1_halo,
                    )
                else:
                    top_out1_halo = out1_pad[:, :, :1, :]
                    btm_out1_halo = out1_pad[:, :, Hs + 1 : Hs + 2, :]
                    spatial_halo_exchanger.left_right_halo_exchange(
                        out1[:, :, :1, :],
                        out1[:, :, Hs - 1 :, :],
                        top_out1_halo,
                        btm_out1_halo,
                    )
            if spatial_method == 1:
                # overlap mid convolution with halo transfer
                # Bottom: last two rows of out1 + received bottom halo form a
                # 3-row fat halo; forward_out2_halo on it yields the last out2 row.
                if spatial_group_rank < spatial_group_size - 1:
                    stream2.wait_stream(stream1)
                    with torch.cuda.stream(stream2):
                        if explicit_nhwc:
                            btm_fat_halo = torch.empty(
                                (N, 3, W, C), dtype=out1.dtype, device=out1.device
                            )
                            btm_fat_halo[:, 0:2, :, :].copy_(out1[:, Hs - 2 :, :, :])
                            btm_fat_halo[:, 2:, :, :].copy_(btm_out1_halo)
                        else:
                            btm_fat_halo = torch.empty(
                                (N, C, 3, W), dtype=out1.dtype, device=out1.device
                            )
                            btm_fat_halo[:, :, 0:2, :].copy_(out1[:, :, Hs - 2 :, :])
                            btm_fat_halo[:, :, 2:, :].copy_(btm_out1_halo)
                        btm_out2 = fast_bottleneck.forward_out2_halo(
                            explicit_nhwc, btm_fat_halo, args
                        )
                # Top: received top halo + first two rows of out1.
                if spatial_group_rank > 0:
                    with torch.cuda.stream(stream1):
                        if explicit_nhwc:
                            top_fat_halo = torch.empty(
                                (N, 3, W, C), dtype=out1.dtype, device=out1.device
                            )
                            top_fat_halo[:, :1, :, :].copy_(top_out1_halo)
                            top_fat_halo[:, 1:3, :, :].copy_(out1[:, :2, :, :])
                        else:
                            top_fat_halo = torch.empty(
                                (N, C, 3, W), dtype=out1.dtype, device=out1.device
                            )
                            top_fat_halo[:, :, :1, :].copy_(top_out1_halo)
                            top_fat_halo[:, :, 1:3, :].copy_(out1[:, :, :2, :])
                        top_out2 = fast_bottleneck.forward_out2_halo(
                            explicit_nhwc, top_fat_halo, args
                        )
                if use_delay_kernel:
                    inc.add_delay(10)
            elif spatial_method != 2 and spatial_method != 3:
                assert False, "spatial_method must be 1, 2 or 3"

        if spatial_group_size <= 1:
            fast_bottleneck.forward_out2(explicit_nhwc, stride_1x1, args, outputs)
        elif spatial_method == 1:
            fast_bottleneck.forward_out2(explicit_nhwc, stride_1x1, args, outputs)
            # Fill the interior of out1_pad (saved for backward_wgrad2_pad).
            with torch.cuda.stream(stream3):
                if explicit_nhwc:
                    out1_pad[:, 1 : Hs + 1, :, :].copy_(out1)
                else:
                    out1_pad[:, :, 1 : Hs + 1, :].copy_(out1)
        elif spatial_method == 2:
            # wait for halo transfer to finish before doing a full convolution of padded x
            if explicit_nhwc:
                out1_pad[:, 1 : Hs + 1, :, :].copy_(out1)
            else:
                out1_pad[:, :, 1 : Hs + 1, :].copy_(out1)
            torch.cuda.current_stream().wait_stream(stream1)
            fast_bottleneck.forward_out2_pad(explicit_nhwc, stride_1x1, args, outputs, out1_pad)
        elif spatial_method == 3:
            fast_bottleneck.forward_out2_mask(
                explicit_nhwc, stride_1x1, args, outputs, thresholdTop, thresholdBottom
            )
            with torch.cuda.stream(stream3):
                if explicit_nhwc:
                    out1_pad[:, 1 : Hs + 1, :, :].copy_(out1)
                else:
                    out1_pad[:, :, 1 : Hs + 1, :].copy_(out1)

        # compute halo cells for outputs[1] (out2)
        if spatial_group_size > 1:
            out2 = outputs[1]
            # Views of the first/last out2 row; writes below update out2 in place.
            if explicit_nhwc:
                top_out2_halo = out2[:, :1, :, :]
                btm_out2_halo = out2[:, Hs - 1 :, :, :]
            else:
                top_out2_halo = out2[:, :, :1, :]
                btm_out2_halo = out2[:, :, Hs - 1 :, :]
            if spatial_method == 1:
                if spatial_group_rank > 0:
                    torch.cuda.current_stream().wait_stream(stream1)
                    top_out2_halo.copy_(top_out2)
                if spatial_group_rank < spatial_group_size - 1:
                    torch.cuda.current_stream().wait_stream(stream2)
                    btm_out2_halo.copy_(btm_out2)
            elif spatial_method == 3:
                # Note
                # out2 halo correction cannot overlap with anything since it has
                # to wait for out2_mask to finish, but itself has to finish before
                # the first kernel of _forward_rest can launch.
                # At least we can overlap the two halo correction kernels.
                if spatial_group_rank < spatial_group_size - 1:
                    stream2.wait_stream(stream1)  # wait for halo transfers to finish
                    stream2.wait_stream(
                        torch.cuda.current_stream()
                    )  # wait for *_out2_mask to finish
                    with torch.cuda.stream(stream2):
                        # Slice of the conv2 weight along dim 1 (third entry)
                        # used for the bottom correction.
                        w1by3 = args[2][:, 2:3, :, :].clone()
                        btm_out1_halo = btm_out1_halo.clone()
                        btm_out2 = fast_bottleneck.forward_out2_halo_corr(
                            explicit_nhwc,
                            btm_out1_halo,
                            args,
                            w1by3,
                            btm_out2_halo.clone(),
                        )
                        btm_out2_halo.copy_(btm_out2)
                if spatial_group_rank > 0:
                    stream1.wait_stream(
                        torch.cuda.current_stream()
                    )  # wait for *_out2_mask to finish
                    with torch.cuda.stream(stream1):
                        # First entry along dim 1 of the conv2 weight for the top.
                        w1by3 = args[2][:, :1, :, :].clone()
                        top_out1_halo = top_out1_halo.clone()
                        top_out2 = fast_bottleneck.forward_out2_halo_corr(
                            explicit_nhwc,
                            top_out1_halo,
                            args,
                            w1by3,
                            top_out2_halo.clone(),
                        )
                        top_out2_halo.copy_(top_out2)
                if spatial_group_rank < spatial_group_size - 1:
                    torch.cuda.current_stream().wait_stream(stream2)
                if spatial_group_rank > 0:
                    torch.cuda.current_stream().wait_stream(stream1)

        fast_bottleneck.forward_rest(explicit_nhwc, stride_1x1, args, outputs)
        # save halos for backward pass
        # Saved layout: args + outputs (+ out1_pad as the last entry when
        # spatial_group_size > 1); backward relies on this ordering.
        if spatial_group_size > 1:
            if spatial_method != 2:
                # make sure copy of mid-section of out1 into out1_pad is done before exiting
                torch.cuda.current_stream().wait_stream(stream3)
            ctx.save_for_backward(
                *(
                    args
                    + outputs
                    + [
                        out1_pad,
                    ]
                )
            )
        else:
            ctx.save_for_backward(*(args + outputs))
        # save relu outputs for drelu
        ctx.explicit_nhwc = explicit_nhwc
        ctx.stride_1x1 = stride_1x1
        ctx.spatial_group_size = spatial_group_size
        if spatial_group_size > 1:
            ctx.spatial_group_rank = spatial_group_rank
            ctx.spatial_halo_exchanger = spatial_halo_exchanger
            ctx.spatial_method = spatial_method
            ctx.use_delay_kernel = use_delay_kernel
            ctx.thresholdTop = thresholdTop
            ctx.thresholdBottom = thresholdBottom
            ctx.stream1 = stream1
            ctx.stream2 = stream2
            ctx.stream3 = stream3
        return outputs[2]

    # backward relu is not exposed, MUL with mask used now
    # only support dgrad
    @staticmethod
    def backward(ctx, grad_o):
        if ctx.spatial_group_size > 1:
            out1_pad = ctx.saved_tensors[-1]
            outputs = ctx.saved_tensors[-4:-1]
        else:
            outputs = ctx.saved_tensors[-3:]

        # ReLU mask of the block output; saved_tensors[6] is scale3 and
        # saved_tensors[11] is scale4 (downsample branch).
        if ctx.downsample:
            grad_conv3, grad_conv4 = drelu_dscale2(
                grad_o, outputs[2], ctx.saved_tensors[6], ctx.saved_tensors[11]
            )
        else:
            grad_conv3, grad_conv4 = drelu_dscale1(grad_o, outputs[2], ctx.saved_tensors[6])

        # create input vector for backward
        # t_list layout: [0:10] as in args, [10] grad_conv3, [11] grad_conv4,
        # [12] outputs[0] (relu1), [13] outputs[1], [14] conv4 weight if downsample.
        t_list = [*ctx.saved_tensors[0:10]]
        t_list.append(grad_conv3)
        t_list.append(grad_conv4)

        # outputs used for wgrad and generating drelu mask
        t_list.append(outputs[0])
        t_list.append(outputs[1])

        # in case there is downsample
        if ctx.downsample:
            t_list.append(ctx.saved_tensors[10])

        grads = fast_bottleneck.backward_init(ctx.explicit_nhwc, ctx.stride_1x1, t_list)
        # Each wgrad stream is forked as soon as its inputs are ready and joined
        # back into the current stream at the end.
        wgrad3_stream = torch.cuda.Stream()
        wgrad3_stream.wait_stream(torch.cuda.current_stream())
        grad_out2 = fast_bottleneck.backward_grad_out2(
            ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads
        )
        wgrad2_stream = torch.cuda.Stream()
        wgrad2_stream.wait_stream(torch.cuda.current_stream())
        # do halo exchange of grad_out2 here
        # compute halo cells for grad_out1
        if ctx.spatial_group_size > 1:
            if ctx.explicit_nhwc:
                N, Hs, W, C = list(grad_out2.shape)
            else:
                N, C, Hs, W = list(grad_out2.shape)
            relu1 = t_list[12]
            ctx.stream1.wait_stream(torch.cuda.current_stream())
            # NOTE(review): these slices are along dim 1 regardless of layout;
            # that is H only for explicit_nhwc. For an (N, C, H, W) grad_out2
            # they select channels, yet the later copies treat the halos as
            # rows -- verify the non-explicit_nhwc path.
            with torch.cuda.stream(ctx.stream1):
                top_halo, btm_halo = ctx.spatial_halo_exchanger.left_right_halo_exchange(
                    grad_out2[:, :1, :, :], grad_out2[:, Hs - 1 :, :, :]
                )
            # copy halos to send buffer
            if ctx.spatial_method == 1 or ctx.spatial_method == 2:
                # 1 -> halo recompute approach
                # 2 -> wait for concatenated halos, then do single conv on full input (not implemented yet for bprop)
                if ctx.spatial_group_rank < ctx.spatial_group_size - 1:
                    ctx.stream2.wait_stream(ctx.stream1)
                    with torch.cuda.stream(ctx.stream2):
                        if ctx.explicit_nhwc:
                            btm_fat_halo = torch.empty(
                                (N, 3, W, C),
                                dtype=grad_out2.dtype,
                                device=grad_out2.device,
                            )
                            btm_fat_halo[:, :2, :, :].copy_(grad_out2[:, Hs - 2 :, :, :])
                            btm_fat_halo[:, 2:, :, :].copy_(btm_halo)
                            btm_fat_relu_halo = torch.empty(
                                (N, 3, W, C),
                                dtype=grad_out2.dtype,
                                device=grad_out2.device,
                            )
                            btm_fat_relu_halo[:, :2, :, :].copy_(relu1[:, Hs - 2 :, :, :])
                            btm_fat_relu_halo[:, 2:, :, :].zero_()
                        else:
                            btm_fat_halo = torch.empty(
                                (N, C, 3, W),
                                dtype=grad_out2.dtype,
                                device=grad_out2.device,
                            )
                            btm_fat_halo[:, :, :2, :].copy_(grad_out2[:, :, Hs - 2 :, :])
                            btm_fat_halo[:, :, 2:, :].copy_(btm_halo)
                            btm_fat_relu_halo = torch.empty(
                                (N, C, 3, W),
                                dtype=grad_out2.dtype,
                                device=grad_out2.device,
                            )
                            btm_fat_relu_halo[:, :, :2, :].copy_(relu1[:, :, Hs - 2 :, :])
                            btm_fat_relu_halo[:, :, 2:, :].zero_()
                        btm_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(
                            ctx.explicit_nhwc,
                            ctx.stride_1x1,
                            t_list,
                            grads,
                            btm_fat_halo,
                            btm_fat_relu_halo,
                        )
                        # Keep only the middle row: the corrected last grad_out1 row.
                        if ctx.explicit_nhwc:
                            btm_grad_out1_halo = btm_grad_out1_halo[:, 1:2, :, :]
                        else:
                            btm_grad_out1_halo = btm_grad_out1_halo[:, :, 1:2, :]
                if ctx.spatial_group_rank > 0:
                    with torch.cuda.stream(ctx.stream1):
                        if ctx.explicit_nhwc:
                            top_fat_halo = torch.empty(
                                (N, 3, W, C),
                                dtype=grad_out2.dtype,
                                device=grad_out2.device,
                            )
                            top_fat_halo[:, :1, :, :].copy_(top_halo)
                            top_fat_halo[:, 1:, :, :].copy_(grad_out2[:, :2, :, :])
                            top_fat_relu_halo = torch.empty(
                                (N, 3, W, C),
                                dtype=grad_out2.dtype,
                                device=grad_out2.device,
                            )
                            top_fat_relu_halo[:, :1, :, :].zero_()
                            top_fat_relu_halo[:, 1:, :, :].copy_(relu1[:, :2, :, :])
                        else:
                            top_fat_halo = torch.empty(
                                (N, C, 3, W),
                                dtype=grad_out2.dtype,
                                device=grad_out2.device,
                            )
                            top_fat_halo[:, :, :1, :].copy_(top_halo)
                            top_fat_halo[:, :, 1:, :].copy_(grad_out2[:, :, :2, :])
                            top_fat_relu_halo = torch.empty(
                                (N, C, 3, W),
                                dtype=grad_out2.dtype,
                                device=grad_out2.device,
                            )
                            top_fat_relu_halo[:, :, :1, :].zero_()
                            top_fat_relu_halo[:, :, 1:, :].copy_(relu1[:, :, :2, :])
                        top_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(
                            ctx.explicit_nhwc,
                            ctx.stride_1x1,
                            t_list,
                            grads,
                            top_fat_halo,
                            top_fat_relu_halo,
                        )
                        # Keep only the middle row: the corrected first grad_out1 row.
                        if ctx.explicit_nhwc:
                            top_grad_out1_halo = top_grad_out1_halo[:, 1:2, :, :]
                        else:
                            top_grad_out1_halo = top_grad_out1_halo[:, :, 1:2, :]
                if ctx.use_delay_kernel:
                    inc.add_delay(10)
            elif ctx.spatial_method != 3:
                assert False, "spatial_method must be 1, 2 or 3"

        # compute grad_out1 for internal cells
        if ctx.spatial_group_size <= 1 or ctx.spatial_method == 1 or ctx.spatial_method == 2:
            grad_out1 = fast_bottleneck.backward_grad_out1(
                ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2
            )
        elif ctx.spatial_group_size > 1 and ctx.spatial_method == 3:
            grad_out1 = fast_bottleneck.backward_grad_out1_mask(
                ctx.explicit_nhwc,
                ctx.stride_1x1,
                t_list,
                grads,
                grad_out2,
                ctx.thresholdTop,
                ctx.thresholdBottom,
            )

        # apply halo cells to grad_out1
        if ctx.spatial_group_size > 1:
            w = t_list[2]
            z = t_list[4]
            relu1 = t_list[12]
            # print("w.shape = %s, z.shape = %s, relu1.shape = %s" % (str(list(w.shape)), str(list(z.shape)), str(list(relu1.shape))))
            if ctx.spatial_method == 1 or ctx.spatial_method == 2:
                if ctx.spatial_group_rank < ctx.spatial_group_size - 1:
                    torch.cuda.current_stream().wait_stream(ctx.stream2)
                    if ctx.explicit_nhwc:
                        grad_out1[:, Hs - 1 :, :, :].copy_(btm_grad_out1_halo)
                    else:
                        grad_out1[:, :, Hs - 1 :, :].copy_(btm_grad_out1_halo)
                    # print("ctx.spatial_group_rank = %d, apply grad_out1 btm halo (grad_out1.shape = %s)" % (ctx.spatial_group_rank, str(list(grad_out1.shape))))
                if ctx.spatial_group_rank > 0:
                    torch.cuda.current_stream().wait_stream(ctx.stream1)
                    if ctx.explicit_nhwc:
                        grad_out1[:, :1, :, :].copy_(top_grad_out1_halo)
                    else:
                        grad_out1[:, :, :1, :].copy_(top_grad_out1_halo)
                    # print("ctx.spatial_group_rank = %d, apply grad_out1 top halo (grad_out1.shape = %s)" % (ctx.spatial_group_rank, str(list(grad_out1.shape))))
            elif ctx.spatial_method == 3:
                # Note the weight slices are swapped relative to forward
                # (bottom uses the first entry along dim 1, top the last).
                if ctx.spatial_group_rank < ctx.spatial_group_size - 1:
                    if ctx.explicit_nhwc:
                        btm_relu_halo = relu1[:, Hs - 1 :, :, :].clone()
                        btm_grad_out1 = grad_out1[:, Hs - 1 :, :, :]
                    else:
                        btm_relu_halo = relu1[:, :, Hs - 1 :, :].clone()
                        btm_grad_out1 = grad_out1[:, :, Hs - 1 :, :]
                    w1by3 = w[:, :1, :, :].clone()
                    ctx.stream2.wait_stream(ctx.stream1)  # wait for halo transfers to finish
                    ctx.stream2.wait_stream(
                        torch.cuda.current_stream()
                    )  # wait for backward_grad_out1_mask to finish before launching halo correction kernel
                    with torch.cuda.stream(ctx.stream2):
                        btm_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo_corr(
                            ctx.explicit_nhwc,
                            ctx.stride_1x1,
                            t_list,
                            w1by3,
                            grads,
                            btm_halo,
                            btm_relu_halo,
                            btm_grad_out1.clone(),
                        )
                        btm_grad_out1.copy_(btm_grad_out1_halo)
                if ctx.spatial_group_rank > 0:
                    if ctx.explicit_nhwc:
                        top_relu_halo = relu1[:, :1, :, :].clone()
                        top_grad_out1 = grad_out1[:, :1, :, :]
                    else:
                        top_relu_halo = relu1[:, :, :1, :].clone()
                        top_grad_out1 = grad_out1[:, :, :1, :]
                    w1by3 = w[:, 2:, :, :].clone()
                    ctx.stream1.wait_stream(
                        torch.cuda.current_stream()
                    )  # wait for backward_grad_out1_mask to finish before launching halo correction kernel
                    with torch.cuda.stream(ctx.stream1):
                        top_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo_corr(
                            ctx.explicit_nhwc,
                            ctx.stride_1x1,
                            t_list,
                            w1by3,
                            grads,
                            top_halo,
                            top_relu_halo,
                            top_grad_out1.clone(),
                        )
                        top_grad_out1.copy_(top_grad_out1_halo)
                if ctx.spatial_group_rank < ctx.spatial_group_size - 1:
                    torch.cuda.current_stream().wait_stream(
                        ctx.stream2
                    )  # wait for halo correction to finish
                if ctx.spatial_group_rank > 0:
                    torch.cuda.current_stream().wait_stream(ctx.stream1)

        wgrad1_stream = torch.cuda.Stream()
        wgrad1_stream.wait_stream(torch.cuda.current_stream())
        fast_bottleneck.backward_rest(
            ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2, grad_out1
        )
        with torch.cuda.stream(wgrad3_stream):
            fast_bottleneck.backward_wgrad3(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads)
        with torch.cuda.stream(wgrad2_stream):
            if ctx.spatial_group_size > 1:
                # Uses the padded out1 (with neighbour halos) saved in forward.
                fast_bottleneck.backward_wgrad2_pad(
                    ctx.explicit_nhwc,
                    ctx.stride_1x1,
                    t_list,
                    grads,
                    out1_pad,
                    grad_out2,
                )
            else:
                fast_bottleneck.backward_wgrad2(
                    ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2
                )
        with torch.cuda.stream(wgrad1_stream):
            fast_bottleneck.backward_wgrad1(
                ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out1
            )
        torch.cuda.current_stream().wait_stream(wgrad3_stream)
        torch.cuda.current_stream().wait_stream(wgrad2_stream)
        torch.cuda.current_stream().wait_stream(wgrad1_stream)

        # One None per non-tensor forward input (12 of them), then grads.
        return (
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            *grads,
        )
# Functional entry point; SpatialBottleneck.forward calls it with
# (*spatial_parallel_args, explicit_nhwc, stride, w_scale, w_bias,
#  thresholdTop, thresholdBottom, x, *w_conv).
spatial_bottleneck_function = SpatialBottleneckFunction.apply
class SpatialBottleneck(torch.nn.Module):
    """ResNet bottleneck block (1x1 -> 3x3 -> 1x1) with optional H-split across ranks.

    Bottleneck in torchvision places the stride for downsampling at the 3x3
    convolution (self.conv2), while the original implementation places the
    stride at the first 1x1 convolution (self.conv1), according to "Deep
    residual learning for image recognition" (https://arxiv.org/abs/1512.03385).
    The torchvision variant is also known as ResNet V1.5 and improves accuracy
    according to
    https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
    Here the stride is put at the 1x1 convolution (self.conv1).

    Args:
        in_channels: Input channels.
        bottleneck_channels: Channels of conv1/conv2 outputs.
        out_channels: Output channels.
        stride: Stride of conv1 and of the downsample branch.
        groups: Only 1 is supported.
        dilation: Only 1 is supported.
        norm_func: Only ``None`` is supported; ``FrozenBatchNorm2d`` is used.
        use_cudnn: Run the fused ``spatial_bottleneck_function`` instead of
            native ops.
        explicit_nhwc: Permute all conv weights to NHWC at construction and
            read the input as ``(N, H, W, C)``; only supported with ``use_cudnn``.
        spatial_parallel_args: Tuple ``(spatial_group_size, spatial_group_rank,
            spatial_communicator, spatial_halo_exchanger, spatial_method,
            use_delay_kernel)`` forwarded positionally to
            ``spatial_bottleneck_function``. Defaults to
            ``(1, 0, None, None, 0, False)``, i.e. no spatial split.

    Raises:
        RuntimeError: for unsupported ``groups``, ``dilation`` or ``norm_func``.
    """

    # TODO: prevent unsupported case usage
    # support cases
    #                 native      cudnn
    # normal             yes         no
    # channel_last       yes        yes
    # explicit_nhwc       no        yes

    def __init__(
        self,
        in_channels,
        bottleneck_channels,
        out_channels,
        stride=1,
        groups=1,
        dilation=1,
        norm_func=None,
        use_cudnn=False,
        explicit_nhwc=False,
        spatial_parallel_args=None,
    ):
        super(SpatialBottleneck, self).__init__()
        if groups != 1:
            raise RuntimeError("Only support groups == 1")
        if dilation != 1:
            raise RuntimeError("Only support dilation == 1")
        if norm_func is not None:
            raise RuntimeError("Only support frozen BN now.")
        norm_func = FrozenBatchNorm2d

        if stride != 1 or in_channels != out_channels:
            self.downsample = nn.Sequential(
                conv1x1(in_channels, out_channels, stride),
                norm_func(out_channels),
            )
        else:
            self.downsample = None

        # Both self.conv1 and self.downsample downsample the input when
        # stride != 1 (the stride sits on the first 1x1 convolution).
        self.conv1 = conv1x1(in_channels, bottleneck_channels, stride)
        self.conv2 = conv3x3(bottleneck_channels, bottleneck_channels)
        self.conv3 = conv1x1(bottleneck_channels, out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.stride = stride

        self.bn1 = norm_func(bottleneck_channels)
        self.bn2 = norm_func(bottleneck_channels)
        self.bn3 = norm_func(out_channels)
        # Preallocated folded scale/bias; set by get_scale_bias_callable().
        self.w_scale = None
        self.w_bias = None

        self.use_cudnn = use_cudnn

        # setup conv weights (order matches what spatial_bottleneck_function expects)
        self.w_conv = [self.conv1.weight, self.conv2.weight, self.conv3.weight]
        if self.downsample is not None:
            self.w_conv.append(self.downsample[0].weight)

        # init weight in nchw format before possible transpose
        for w in self.w_conv:
            kaiming_uniform_(w, a=1)

        # Row thresholds for the masked kernels; computed lazily in forward().
        self.thresholdTop, self.thresholdBottom = None, None

        self.explicit_nhwc = explicit_nhwc
        if self.explicit_nhwc:
            # Only conv weights are parameters (FrozenBatchNorm2d uses buffers).
            # Reassigning .data keeps the Parameter objects held in self.w_conv.
            with torch.no_grad():
                for p in self.parameters():
                    p.data = p.data.permute(0, 2, 3, 1).contiguous()

        # spatial communicator
        if spatial_parallel_args is None:
            self.spatial_parallel_args = (1, 0, None, None, 0, False)
        else:
            self.spatial_parallel_args = spatial_parallel_args

    def _frozen_batch_norms(self):
        """Return bn1, bn2, bn3 and, if present, the downsample batch norm."""
        batch_norms = [self.bn1, self.bn2, self.bn3]
        if self.downsample is not None:
            batch_norms.append(self.downsample[1])
        return batch_norms

    def get_scale_bias_callable(self):
        """Preallocate folded scale/bias tensors for all frozen batch norms.

        Returns a single callable that recomputes scale and bias for all frozen
        batch norms in place; it can be called at any time. This method must be
        called before CUDA graphing, and calling it prevents scale/bias from
        being recomputed on every forward call.

        NOTE: the buffers come from ``torch.empty_like`` and are uninitialized
        until the returned callable has been invoked at least once.
        """
        self.w_scale, self.w_bias, args = [], [], []
        shape = (1, 1, 1, -1) if self.explicit_nhwc else (1, -1, 1, 1)
        for bn in self._frozen_batch_norms():
            s = torch.empty_like(bn.weight)
            b = torch.empty_like(s)
            args.append((bn.weight, bn.bias, bn.running_mean, bn.running_var, s, b))
            # Reshaped views share storage with s/b, so in-place refreshes are seen.
            self.w_scale.append(s.reshape(shape))
            self.w_bias.append(b.reshape(shape))
        return func.partial(compute_scale_bias_method, self.explicit_nhwc, args)

    def forward(self, x):
        if self.use_cudnn:
            if self.thresholdTop is None:
                # Computed once from the first input's height and cached.
                # NOTE(review): a later input with a different H reuses these.
                spatial_group_size, spatial_group_rank, _, _, _, _ = self.spatial_parallel_args
                if self.explicit_nhwc:
                    _, H, _, _ = list(x.shape)
                else:
                    _, _, H, _ = list(x.shape)
                # 1 when there is a neighbour above, else 0.
                self.thresholdTop = torch.tensor(
                    [1 if spatial_group_rank > 0 else 0],
                    dtype=torch.int32,
                    device="cuda",
                )
                # H - 2 when there is a neighbour below, else H - 1.
                self.thresholdBottom = torch.tensor(
                    [H - 2 if spatial_group_rank < spatial_group_size - 1 else H - 1],
                    dtype=torch.int32,
                    device="cuda",
                )

            if self.w_scale is None:
                # calculate scale/bias from registered buffers
                # TODO: make this better
                w_scale, w_bias = [], []
                for bn in self._frozen_batch_norms():
                    s, b = bn.get_scale_bias(self.explicit_nhwc)
                    w_scale.append(s)
                    w_bias.append(b)
            else:
                w_scale, w_bias = self.w_scale, self.w_bias
            return spatial_bottleneck_function(
                *self.spatial_parallel_args,
                self.explicit_nhwc,
                self.stride,
                w_scale,
                w_bias,
                self.thresholdTop,
                self.thresholdBottom,
                x,
                *self.w_conv,
            )

        if self.explicit_nhwc:
            raise RuntimeError("explicit nhwc with native ops is not supported.")

        # fallback to native ops
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out
================================================
FILE: apex/contrib/bottleneck/halo_exchangers.py
================================================
import torch
import nccl_p2p_cuda as inc
import peer_memory_cuda as pm
# Communication-free halo exchanger (HaloExchangerNoComm).
# NB! It does not exchange halos with neighbors as it should; it merely swaps the inputs.
# NB! It is only useful for performance testing.
# NB! Do not use it for actual production runs.
class HaloExchanger(object):
    """Base class with the rank bookkeeping shared by all halo exchangers.

    Args:
        ranks: Global ranks of the spatial group; ``ranks[i - 1]`` and
            ``ranks[i + 1]`` are the left and right neighbours of ``ranks[i]``.
        rank_in_group: Index of this process inside ``ranks``.

    Attributes set here: three CUDA side streams (``stream1``..``stream3``),
    wrap-around neighbour indices, and the non-periodic neighbour ranks, where
    ``-1`` together with ``left_zero``/``right_zero`` marks a group edge.
    """

    def __init__(self, ranks, rank_in_group):
        self.stream1, self.stream2, self.stream3 = (torch.cuda.Stream() for _ in range(3))
        group_size = len(ranks)
        self.group_size = group_size
        self.ranks = ranks
        self.rank_in_group = rank_in_group
        # Periodic neighbour indices (used by the all-gather exchanger).
        self.wrap_around_left_rank_in_group = (rank_in_group - 1 + group_size) % group_size
        self.wrap_around_right_rank_in_group = (rank_in_group + 1) % group_size
        # Non-periodic neighbours; the first/last member of the group has none.
        self.left_rank = -1 if not rank_in_group > 0 else ranks[rank_in_group - 1]
        self.left_zero = rank_in_group == 0
        self.right_rank = (
            -1 if not rank_in_group < group_size - 1 else ranks[rank_in_group + 1]
        )
        self.right_zero = rank_in_group == group_size - 1
class HaloExchangerNoComm(HaloExchanger):
    """Exchanger that performs no communication; for performance testing only.

    It swaps the two outgoing halos locally instead of talking to neighbours,
    so its results are not correct for real runs.
    """

    def __init__(self, ranks, rank_in_group):
        super(HaloExchangerNoComm, self).__init__(ranks, rank_in_group)

    def left_right_halo_exchange(
        self,
        left_output_halo,
        right_output_halo,
        left_input_halo=None,
        right_input_halo=None,
    ):
        """Swap the halos locally.

        Without input buffers, returns ``(right_output_halo, left_output_halo)``.
        With input buffers, copies ``right_output_halo`` into ``left_input_halo``
        and ``left_output_halo`` into ``right_input_halo``, and returns ``None``.
        """
        if left_input_halo is not None:
            left_input_halo.copy_(right_output_halo)
            right_input_halo.copy_(left_output_halo)
            return None
        return right_output_halo, left_output_halo
class HaloExchangerAllGather(HaloExchanger):
    """Halo exchanger built on ``torch.distributed.all_gather``.

    Every rank gathers both outgoing halos of every rank in ``comm`` and then
    picks out the ones sent by its neighbours. Edge ranks (no left/right
    neighbour) receive zero halos.

    Args:
        ranks: Global ranks of the spatial group, ordered by neighbour.
        rank_in_group: Index of this process inside ``ranks``.
        comm: Must be an NCCL process group created with
            ``torch.distributed.new_group(ranks=ranks)``.
    """

    def __init__(self, ranks, rank_in_group, comm):
        super(HaloExchangerAllGather, self).__init__(ranks, rank_in_group)
        # self.comm must be NCCL process_group created with torch.distributed.new_group(ranks=ranks)
        self.comm = comm

    def left_right_halo_exchange(
        self,
        left_output_halo,
        right_output_halo,
        left_input_halo=None,
        right_input_halo=None,
    ):
        """Exchange halos with the left and right neighbours.

        Halos are unpacked as 4-D ``(N, Hh, W, C)`` and concatenated along dim 1.

        Returns:
            ``(left_input_halo, right_input_halo)`` views into an internal
            gather buffer when no input buffers are given; otherwise the
            received halos are copied into the given buffers and ``None`` is
            returned.
        """
        N, Hh, W, C = list(left_output_halo.shape)
        send_halos = torch.empty(
            (N, 2 * Hh, W, C),
            dtype=left_output_halo.dtype,
            device=left_output_halo.device,
        )
        send_halos[:, :Hh, :, :].copy_(left_output_halo)
        send_halos[:, Hh:, :, :].copy_(right_output_halo)
        gather_buffer = torch.empty(
            (N, 2 * Hh * self.group_size, W, C),
            dtype=left_output_halo.dtype,
            device=left_output_halo.device,
        )
        # One (N, 2*Hh, W, C) view per rank into gather_buffer.
        per_rank_halos = [
            gather_buffer[:, i * 2 * Hh : (i + 1) * 2 * Hh, :, :]
            for i in range(self.group_size)
        ]
        # ``no_copy`` is not a parameter of stock torch.distributed.all_gather
        # (it raised TypeError); the plain call fills each output view.
        torch.distributed.all_gather(per_rank_halos, send_halos, group=self.comm)
        # The left neighbour's *right* halo is our left input, and vice versa.
        ag_left_input_halo = per_rank_halos[self.wrap_around_left_rank_in_group][:, Hh:, :, :]
        ag_right_input_halo = per_rank_halos[self.wrap_around_right_rank_in_group][:, :Hh, :, :]
        if left_input_halo is None:
            if self.left_zero:
                ag_left_input_halo.zero_()
            if self.right_zero:
                ag_right_input_halo.zero_()
            return ag_left_input_halo, ag_right_input_halo
        if self.left_zero:
            left_input_halo.zero_()
        else:
            left_input_halo.copy_(ag_left_input_halo)
        if self.right_zero:
            right_input_halo.zero_()
        else:
            right_input_halo.copy_(ag_right_input_halo)
class HaloExchangerSendRecv(HaloExchanger):
    """Halo exchanger using point-to-point NCCL transfers through the ``inc`` extension.

    A dedicated NCCL communicator spanning all torch.distributed ranks (global
    rank, world size) is created from a unique id broadcast from rank 0.
    """

    def __init__(self, ranks, rank_in_group):
        super(HaloExchangerSendRecv, self).__init__(ranks, rank_in_group)
        # Every rank generates an id; the broadcast overwrites it with rank 0's.
        unique_id = inc.get_unique_nccl_id(1).cuda()
        torch.distributed.broadcast(unique_id, 0)
        unique_id = unique_id.cpu()
        global_rank = torch.distributed.get_rank()
        print("%d :: nccl_id = %s" % (global_rank, str(unique_id)))
        # A second global NCCL communicator is needed in addition to the one
        # created by torch.distributed.init_process_group("nccl"). The latter
        # is a protected member and cannot be reached from this class.
        # TODO: Figure out a way to avoid creating a second global communicator
        assert global_rank == self.ranks[self.rank_in_group], (
            "ranks[%d](%d) != torch.distributed.get_rank()(%d)"
            % (self.rank_in_group, self.ranks[self.rank_in_group], global_rank)
        )
        world_size = torch.distributed.get_world_size()
        self.handle = inc.init_nccl_comm(unique_id, global_rank, world_size)

    def left_right_halo_exchange(
        self,
        left_output_halo,
        right_output_halo,
        left_input_halo=None,
        right_input_halo=None,
    ):
        """Send output halos to the neighbours and receive their halos.

        Returns a ``(left_input, right_input)`` tuple when ``left_input_halo``
        is None. Otherwise fills the given buffers in place and returns None.
        """
        if left_input_halo is not None:
            inc.left_right_halo_exchange_inplace(
                self.handle,
                self.left_rank,
                self.right_rank,
                left_output_halo,
                right_output_halo,
                left_input_halo,
                right_input_halo,
            )
            return None
        received_left, received_right = inc.left_right_halo_exchange(
            self.handle,
            self.left_rank,
            self.right_rank,
            left_output_halo,
            right_output_halo,
        )
        return received_left, received_right
class HaloExchangerPeer(HaloExchanger):
    """Halo exchanger that moves halos through peer-accessible buffers.

    Transfer buffers come from ``peer_pool.allocate_peer_tensors``. The copy
    itself is performed by ``pm.push_pull_halos_1d``.
    """

    def __init__(self, ranks, rank_in_group, peer_pool, explicit_nhwc, numSM=0):
        super(HaloExchangerPeer, self).__init__(ranks, rank_in_group)
        self.diagnostics = False
        self.explicit_nhwc = explicit_nhwc
        self.numSM = numSM
        self.peer_pool = peer_pool

    def _allocate_peer_tensor(self, halo):
        """Allocate a padded peer transfer buffer sized for ``halo``.

        The result appears to be indexable by rank in group (see its use in
        ``left_right_halo_exchange``).
        """
        # Compute size in bytes (4x the halo's bytes).
        # Note: Pad buffer so each CUDA block gets required buffer size
        size = 4 * halo.numel() * halo.element_size()
        size_per_block = 128 * 2 * 16  # 128 threads each require two 128b buffers
        size = (size + size_per_block - 1) // size_per_block * size_per_block
        # Construct dtype peer buffer with desired size
        shape = [1, 1, 1, size // halo.element_size()]
        return self.peer_pool.allocate_peer_tensors(shape, halo.dtype, False, True)

    def left_right_halo_exchange(
        self,
        left_output_halo,
        right_output_halo,
        left_input_halo=None,
        right_input_halo=None,
    ):
        """Push this rank's output halos to its neighbours and pull theirs.

        If both input halos are None, fresh buffers are allocated and returned
        as ``(left_input_halo, right_input_halo)``. Otherwise the supplied
        buffers are filled in place and None is returned.
        """
        inplace = left_input_halo is not None or right_input_halo is not None
        if not inplace:
            left_input_halo = torch.empty_like(right_output_halo)
            right_input_halo = torch.empty_like(left_output_halo)
        left_tx = self._allocate_peer_tensor(left_input_halo)
        right_tx = self._allocate_peer_tensor(right_input_halo)
        pm.push_pull_halos_1d(
            self.diagnostics,
            self.explicit_nhwc,
            self.numSM,
            self.rank_in_group,
            self.left_zero,
            left_output_halo,
            left_tx[self.rank_in_group],
            right_tx[self.wrap_around_left_rank_in_group],
            left_input_halo,
            self.right_zero,
            right_output_halo,
            right_tx[self.rank_in_group],
            left_tx[self.wrap_around_right_rank_in_group],
            right_input_halo,
        )
        if not inplace:
            return left_input_halo, right_input_halo
# Class that combines an input volume with halos from its neighbours (1d).
class HaloPadder:
    """Pad a 4-D tensor along one spatial dimension with halos from neighbouring ranks.

    The halo exchange runs on ``stream1`` and the copy of the interior on
    ``stream2``. Call :meth:`wait` before consuming the returned tensor on the
    current stream.
    """

    def __init__(self, halo_ex):
        # Invoked as halo_ex(left_output, right_output, left_input, right_input).
        self.halo_ex = halo_ex
        self.stream1 = torch.cuda.Stream()
        self.stream2 = torch.cuda.Stream()

    def __call__(self, y, half_halo, explicit_nhwc, H_split):
        """Return ``y`` padded by ``half_halo`` on both sides of the split dimension.

        Args:
            y: 4-D tensor. It is (N, H, W, C) when ``explicit_nhwc`` is true,
                otherwise (N, C, H, W).
            half_halo: halo width added on each side.
            explicit_nhwc: selects the layout interpretation above. The padded
                buffer is contiguous for explicit NHWC, else channels_last.
            H_split: pad along H when true, along W otherwise.

        Raises:
            ValueError: if ``y`` is not 4-D.
        """
        if y.dim() != 4:
            raise ValueError("HaloPadder expects a 4-D tensor, got shape %s" % (tuple(y.shape),))
        if explicit_nhwc:
            split_dim = 1 if H_split else 2
            memory_format = torch.contiguous_format
        else:
            split_dim = 2 if H_split else 3
            memory_format = torch.channels_last
        extent = y.shape[split_dim]
        padded_shape = list(y.shape)
        padded_shape[split_dim] = extent + 2 * half_halo
        # Size must be positional (or ``size=``); torch.empty has no ``shape=`` keyword.
        ypad = torch.empty(
            padded_shape,
            dtype=y.dtype,
            device=y.device,
            memory_format=memory_format,
        )
        yleft = ypad.narrow(split_dim, 0, half_halo)
        ymid = ypad.narrow(split_dim, half_halo, extent)
        yright = ypad.narrow(split_dim, extent + half_halo, half_halo)
        oleft = y.narrow(split_dim, 0, half_halo)
        oright = y.narrow(split_dim, extent - half_halo, half_halo)
        # y was produced (and ypad allocated) on the current stream. The side
        # streams must not read y / write ypad before that work is ordered.
        current_stream = torch.cuda.current_stream()
        self.stream1.wait_stream(current_stream)
        self.stream2.wait_stream(current_stream)
        with torch.cuda.stream(self.stream1):
            self.halo_ex(oleft, oright, yleft, yright)
        with torch.cuda.stream(self.stream2):
            ymid.copy_(y)
        return ypad

    def wait(self):
        """Make the current stream wait for the halo exchange and interior copy."""
        current_stream = torch.cuda.current_stream()
        current_stream.wait_stream(self.stream1)
        current_stream.wait_stream(self.stream2)
================================================
FILE: apex/contrib/bottleneck/test.py
================================================
import torch
from bottleneck import Bottleneck
torch.manual_seed(23337)
# Set to True to print layer-wise sums of all outputs on the reference code path.
DEBUG = False  # True


def _report(label, diff, ref):
    """Print the max absolute value of ``diff`` next to the max magnitude of ``ref``."""
    print(label, diff.abs().max().item(), "max elem:", ref.abs().max().item())


for stride, o_channel in [(1, 32), (1, 128), (2, 32)]:
    print("testing stride ==", stride, ", in_channel == 32 , out_channel ==", o_channel)
    a_ = torch.randn(17, 32, 28, 28)
    a = a_.cuda().half().to(memory_format=torch.channels_last).requires_grad_()
    model = Bottleneck(32, 8, o_channel, stride=stride).cuda().half().to(memory_format=torch.channels_last)

    # Reference (native) path.
    b = model(a)
    b.mean().backward()
    d_grad = a.grad.float()
    a.grad = None
    torch.cuda.synchronize()
    if DEBUG:
        print("[DEBUG] ref dx :", d_grad.sum().item())
        # wgrads need no reset here: the C++ side prints them before accumulating.
        for idx, weight in enumerate(model.w_conv):
            print("[DEBUG] ref wgrad{} :".format(idx + 1), weight.grad.sum().item())
    wgrads = [weight.grad.float() for weight in model.w_conv]

    # Same model via the cuDNN channels_last path.
    model.use_cudnn = True
    model.zero_grad()
    c = model(a)
    c.mean().backward()
    torch.cuda.synchronize()
    print("comparing native and channels_last:")
    _report("max error fprop:", b - c, b)
    _report("max error dgrad:", d_grad - a.grad.float(), d_grad)
    for idx, (weight, ref_wgrad) in enumerate(zip(model.w_conv, wgrads)):
        _report("max error wgrad{}:".format(idx + 1), ref_wgrad - weight.grad.float(), ref_wgrad)

    # Explicit-NHWC model initialised from the same parameters and buffers.
    nhwc_a = a_.permute(0, 2, 3, 1).contiguous().cuda().half().requires_grad_()
    nhwc_model = Bottleneck(32, 8, o_channel, stride=stride, explicit_nhwc=True, use_cudnn=True).cuda().half()
    for src, dst in zip(model.parameters(), nhwc_model.parameters()):
        # model's storage is already NHWC (channels_last); permute to explicit NHWC and copy.
        dst.data.copy_(src.data.permute(0, 2, 3, 1).contiguous())
    for src, dst in zip(model.buffers(), nhwc_model.buffers()):
        dst.data.copy_(src.data)
    d = nhwc_model(nhwc_a)
    d.mean().backward()
    torch.cuda.synchronize()

    # Bring the channels_last (cuDNN) reference results into explicit NHWC layout.
    c = c.contiguous(memory_format=torch.contiguous_format).permute(0, 2, 3, 1).contiguous()
    d_grad = a.grad.float().permute(0, 2, 3, 1).contiguous()
    wgrads = [weight.grad.float().permute(0, 2, 3, 1).contiguous() for weight in model.w_conv]
    torch.cuda.synchronize()
    print("comparing nhwc and channels_last:")
    _report("max error fprop:", d - c, c)
    _report("max error dgrad:", d_grad - nhwc_a.grad.float(), d_grad)
    for idx, (weight, ref_wgrad) in enumerate(zip(nhwc_model.w_conv, wgrads)):
        _report("max error wgrad{}:".format(idx + 1), ref_wgrad - weight.grad.float(), ref_wgrad)
================================================
FILE: apex/contrib/clip_grad/__init__.py
================================================
from .clip_grad import clip_grad_norm_
================================================
FILE: apex/contrib/clip_grad/clip_grad.py
================================================
from typing import Union, Iterable
import torch
# Fused kernels are optional: without them we defer to PyTorch's implementation.
try:
    import amp_C
    from apex.multi_tensor_apply import multi_tensor_applier
except ImportError:
    _kernel_import_succeeded = False
else:
    _kernel_import_succeeded = True
_tensor_or_tensors = Union[torch.Tensor, Iterable[torch.Tensor]]


def clip_grad_norm_(
    parameters: _tensor_or_tensors,
    max_norm: float,
    norm_type: float = 2.0,
    error_if_nonfinite: bool = False,
) -> torch.Tensor:
    r"""Clip, in place, the combined gradient norm of ``parameters``.

    Behaves like :func:`torch.nn.utils.clip_grad_norm_`. For the 2-norm, when
    apex's ``amp_C`` kernels are available and at least one gradient lives on
    a GPU, fp32 and fp16 gradients on that GPU are reduced and rescaled with
    fused multi-tensor kernels. Every other gradient is handled with ordinary
    tensor ops.

    Args:
        parameters (Iterable[Tensor] or Tensor): tensors whose ``.grad`` is
            clipped. Entries without a gradient are ignored.
        max_norm (float or int): maximum allowed total norm.
        norm_type (float or int): p of the p-norm. ``'inf'`` selects the
            infinity norm.
        error_if_nonfinite (bool): raise if the total norm is ``nan``, ``inf``
            or ``-inf``. Default: False (will switch to True in the future).

    Returns:
        The total norm of all gradients viewed as a single vector.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    params = [p for p in parameters if p.grad is not None]
    max_norm = float(max_norm)
    norm_type = float(norm_type)

    if not params:
        return torch.tensor(0.0)

    use_fused = _kernel_import_succeeded and norm_type == 2.0 and any(p.is_cuda for p in params)
    if not use_fused:
        return torch.nn.utils.clip_grad_norm_(
            params,
            max_norm,
            norm_type=norm_type,
            error_if_nonfinite=error_if_nonfinite,
        )

    # Split gradients: fp32 / fp16 on the first CUDA device go through the fused kernels.
    device = next(p.device for p in params if p.is_cuda)
    fp32_grads, fp16_grads, misc_grads = [], [], []
    buckets = {torch.float32: fp32_grads, torch.float16: fp16_grads}
    for p in params:
        grad = p.grad.detach()
        bucket = buckets.get(p.dtype) if p.device == device else None
        if bucket is None:
            misc_grads.append(grad)
        else:
            bucket.append(grad)

    overflow_buf = torch.zeros([1], dtype=torch.int32, device=device)
    partial_norms = []
    for grads in (fp32_grads, fp16_grads):
        if grads:
            partial_norms.append(
                multi_tensor_applier(amp_C.multi_tensor_l2norm, overflow_buf, [grads], False)[0]
            )
    partial_norms.extend(torch.linalg.norm(g).unsqueeze(0).to(device) for g in misc_grads)
    total_norm = torch.linalg.norm(torch.cat(partial_norms))

    if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
        raise RuntimeError(
            f"The total norm of order {norm_type} for gradients from "
            "`parameters` is non-finite, so it cannot be clipped. To disable "
            "this error and scale the gradients by the non-finite norm anyway, "
            "set `error_if_nonfinite=False`"
        )

    # Only ever scale down: the coefficient is clamped to at most 1.
    clip_coef_clamped = torch.clamp(max_norm / (total_norm + 1e-6), max=1.0)
    for grads in (fp32_grads, fp16_grads):
        if grads:
            multi_tensor_applier(amp_C.multi_tensor_scale, overflow_buf, [grads, grads], clip_coef_clamped)
    for g in misc_grads:
        g.mul_(clip_coef_clamped.to(g.device))
    return total_norm
================================================
FILE: apex/contrib/conv_bias_relu/__init__.py
================================================
from .conv_bias_relu import (
ConvBiasReLU,
ConvBias,
ConvBiasMaskReLU,
ConvFrozenScaleBiasReLU,
)
================================================
FILE: apex/contrib/conv_bias_relu/conv_bias_relu.py
================================================
import torch
from apex import check_cudnn_version_and_warn
import fused_conv_bias_relu
# Import-time check of the installed cuDNN version for this module. 8400
# presumably encodes cuDNN 8.4.0 as the minimum; see
# apex.check_cudnn_version_and_warn for the exact behaviour — confirm.
check_cudnn_version_and_warn(__name__, 8400)
class ConvBiasReLU_(torch.autograd.Function):
    """Fused convolution + bias + ReLU, backed by the ``fused_conv_bias_relu`` extension.

    Under autocast, inputs are cast to half. ``padding`` and ``stride`` are
    non-differentiable and receive ``None`` gradients.
    """

    @staticmethod
    @torch.amp.custom_fwd(cast_inputs=torch.half, device_type="cuda")
    def forward(ctx, x, weight, bias, padding, stride):
        result = fused_conv_bias_relu.forward([x, weight, bias], padding, stride)[0]
        # The activated output is saved too (presumably for the ReLU derivative in backward).
        ctx.save_for_backward(x, weight, result)
        ctx.padding, ctx.stride = padding, stride
        return result

    @staticmethod
    @torch.amp.custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        saved = ctx.saved_tensors
        grads = fused_conv_bias_relu.backward([*saved, grad_output], ctx.padding, ctx.stride)
        grad_x, grad_weight, grad_bias = grads[0], grads[1], grads[2]
        return grad_x, grad_weight, grad_bias, None, None
class ConvBiasMaskReLU_(torch.autograd.Function):
    """Fused convolution + bias + mask + ReLU via ``fused_conv_bias_relu.forward_mask``.

    Backward reuses the plain conv-bias-ReLU backward kernel. ``mask``,
    ``padding`` and ``stride`` receive ``None`` gradients.
    """

    @staticmethod
    @torch.amp.custom_fwd(cast_inputs=torch.half, device_type="cuda")
    def forward(ctx, x, weight, bias, mask, padding, stride):
        result = fused_conv_bias_relu.forward_mask([x, weight, bias, mask], padding, stride)[0]
        # The mask itself is not saved; only the activated output is.
        ctx.save_for_backward(x, weight, result)
        ctx.padding, ctx.stride = padding, stride
        return result

    @staticmethod
    @torch.amp.custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        saved = ctx.saved_tensors
        grads = fused_conv_bias_relu.backward([*saved, grad_output], ctx.padding, ctx.stride)
        grad_x, grad_weight, grad_bias = grads[0], grads[1], grads[2]
        return grad_x, grad_weight, grad_bias, None, None, None
class ConvBias_(torch.autograd.Function):
    """Fused convolution + bias (no activation) via ``fused_conv_bias_relu``.

    ``padding`` and ``stride`` receive ``None`` gradients.
    """

    @staticmethod
    @torch.amp.custom_fwd(cast_inputs=torch.half, device_type="cuda")
    def forward(ctx, x, weight, bias, padding, stride):
        result = fused_conv_bias_relu.forward_no_relu([x, weight, bias], padding, stride)[0]
        # No activation, so only the inputs are needed for backward.
        ctx.save_for_backward(x, weight)
        ctx.padding, ctx.stride = padding, stride
        return result

    @staticmethod
    @torch.amp.custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        saved = ctx.saved_tensors
        grads = fused_conv_bias_relu.backward_no_relu([*saved, grad_output], ctx.padding, ctx.stride)
        grad_x, grad_weight, grad_bias = grads[0], grads[1], grads[2]
        return grad_x, grad_weight, grad_bias, None, None
class ConvFrozenScaleBiasReLU_(torch.autograd.Function):
    """Fused convolution + per-channel scale + bias + ReLU with frozen scale and bias.

    Gradients are produced only for ``x`` and ``weight``. ``scale``, ``bias``,
    ``padding`` and ``stride`` receive ``None``.
    """

    @staticmethod
    @torch.amp.custom_fwd(cast_inputs=torch.half, device_type="cuda")
    def forward(ctx, x, weight, scale, bias, padding, stride):
        inputs = [x, weight, scale, bias]
        output = fused_conv_bias_relu.forward_cscale_cbias_relu(inputs, padding, stride)
        ctx.save_for_backward(x, weight, scale, output)
        ctx.padding, ctx.stride = padding, stride
        return output

    @staticmethod
    @torch.amp.custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        saved = ctx.saved_tensors
        grads = fused_conv_bias_relu.backward_cscale_cbias_relu([*saved, grad_output], ctx.padding, ctx.stride)
        grad_x, grad_weight = grads[0], grads[1]
        return grad_x, grad_weight, None, None, None, None
# Functional entry points. The arguments follow each Function's forward(), e.g.
#   ConvBiasReLU(x, weight, bias, padding, stride)
#   ConvBiasMaskReLU(x, weight, bias, mask, padding, stride)
#   ConvBias(x, weight, bias, padding, stride)
#   ConvFrozenScaleBiasReLU(x, weight, scale, bias, padding, stride)
ConvBiasReLU = ConvBiasReLU_.apply
ConvBiasMaskReLU = ConvBiasMaskReLU_.apply
ConvBias = ConvBias_.apply
ConvFrozenScaleBiasReLU = ConvFrozenScaleBiasReLU_.apply
================================================
FILE: apex/contrib/csrc/bottleneck/bottleneck.cpp
================================================
#include <ATen/ATen.h>
#include <ATen/cudnn/Handle.h> // for getcudnnhandle
#include <cudnn_frontend.h>
#include <torch/extension.h>
#include <torch/torch.h>
#include <iostream>
#include <vector>
// DEBUG_MSG(str): when compiled with -DDEBUG, streams `str` (plus a newline)
// to std::cout. Otherwise it expands to an empty statement.
#ifdef DEBUG
#define DEBUG_MSG(str) \
do { \
std::cout << str << std::endl; \
} while (false)
#else
#define DEBUG_MSG(str) \
do { \
} while (false)
#endif
// DEBUG_CUDNN_MSG(buf, str): when compiled with -DDEBUG_CUDNN, appends `str`
// (plus a newline) to the stream `buf`. Otherwise it expands to an empty statement.
#ifdef DEBUG_CUDNN
#define DEBUG_CUDNN_MSG(buf, str) \
do { \
buf << str << std::endl; \
} while (false)
#else
#define DEBUG_CUDNN_MSG(buf, str) \
do { \
} while (false)
#endif
// checkCudnnErr(status): reports a non-success cudnnStatus_t via checkCudnnError
// (with the stringified expression, file and line) and then executes a bare
// `return;`. It can therefore only be used inside functions returning void.
#define checkCudnnErr(...) \
do { \
int err = checkCudnnError(__VA_ARGS__, #__VA_ARGS__, __FILE__, __LINE__); \
if (err) { \
return; \
} \
} while (0)
// Report a failing cuDNN status.
// Returns 0 when `code` is CUDNN_STATUS_SUCCESS. Otherwise it prints the
// location, numeric code, cuDNN's error string and the failing expression to
// stdout, and returns 1.
int checkCudnnError(cudnnStatus_t code, const char* expr, const char* file, int line) {
  if (code == CUDNN_STATUS_SUCCESS) {
    return 0;
  }
  printf("CUDNN error at %s:%d, code=%d (%s) in '%s'\n", file, line, static_cast<int>(code),
         cudnnGetErrorString(code), expr);
  return 1;
}
void checkError(cudaError_t code, char const* func, const char* file, const int line, bool abort = true);
// checkCUDAError(call): evaluate a CUDA runtime call and report/abort on failure via checkError.
#define checkCUDAError(val) \
  { \
    checkError((val), #val, __FILE__, __LINE__); \
  } // in-line regular function
// Report a CUDA runtime error to stderr. If `abort` is set, it then resets the
// device and exits the process with the error code. Does nothing on cudaSuccess.
void checkError(cudaError_t code, char const* func, const char* file, const int line, bool abort) {
  if (code == cudaSuccess) {
    return;
  }
  const char* message = cudaGetErrorString(code);
  fprintf(stderr, "CUDA error returned from \"%s\" at %s:%d, Error code: %d (%s)\n", func, file, line, code,
          message);
  if (!abort) {
    return;
  }
  cudaDeviceReset();
  exit(code);
}
// Fill strideA with packed strides for a tensor of shape dimA (dims in N, C, spatial... order).
// CUDNN_TENSOR_NCHW: row-major, so the last dimension has stride 1.
// Any other format is treated as NHWC: C (index 1) is innermost, then the
// spatial dims from last to first, then N.
// For INT8x4 and INT8x32 standard strides are still computed here as input to
// the cuDNN functions; the CPU reference scales by resizeFactor manually.
void generateStrides(const int64_t* dimA, int64_t* strideA, int nbDims, cudnnTensorFormat_t filterFormat) {
  if (filterFormat != CUDNN_TENSOR_NCHW) {
    // Here we assume that the format is CUDNN_TENSOR_NHWC.
    strideA[1] = 1;
    int64_t span = dimA[1];
    for (int64_t d = nbDims - 1; d >= 2; --d) {
      strideA[d] = span;
      span *= dimA[d];
    }
    strideA[0] = span;
    return;
  }
  int64_t span = 1;
  for (int64_t d = nbDims - 1; d >= 0; --d) {
    strideA[d] = span;
    span *= dimA[d];
  }
}
// Effective filter extent once dilation is applied: (filterDim - 1) * dilation + 1.
int getFwdConvDilatedFilterDim(int filterDim, int dilation) {
  const int gaps = filterDim - 1;
  return gaps * dilation + 1;
}
// Input extent after padding both sides by `pad`.
int getFwdConvPaddedImageDim(int tensorDim, int pad) {
  return pad * 2 + tensorDim;
}
// Forward convolution output extent: (padded - dilated_filter) / stride + 1 (integer division).
int getFwdConvOutputDim(int tensorDim, int pad, int filterDim, int stride, int dilation) {
  const int span = getFwdConvPaddedImageDim(tensorDim, pad) - getFwdConvDilatedFilterDim(filterDim, dilation);
  return span / stride + 1;
}
// Indices into common_convbias_descriptors, for use with std::get<...>.
// The order must match the tuple built by create_conv_bias_add_act_descriptors:
// x, y, w, z, b, A, B, C, i, D.
enum {
X_TENSOR,
Y_TENSOR,
W_TENSOR,
Z_TENSOR,          // 'z': per-channel operand (used as the multiplier of the scale node)
B_TENSOR,          // 'b': per-channel bias
AFTERADD_TENSOR,   // 'A' (virtual): run_conv_scale_bias_add_activation uses it as the scale node's output
AFTERBIAS_TENSOR,  // 'B' (virtual): output of the bias add
AFTERCONV_TENSOR,  // 'C' (virtual): output of the convolution
OPTIONAL,          // 'i': optional residual input added before the activation
AFTEROPT_TENSOR,   // 'D' (virtual): output of the optional add
};
// (x, y, w, conv) descriptors for a plain 2-D convolution.
using common_conv_descriptors =
std::tuple<cudnn_frontend::Tensor, cudnn_frontend::Tensor, cudnn_frontend::Tensor, cudnn_frontend::ConvDesc>;
// Build x/y/w tensor descriptors and a 2-D convolution descriptor.
// - Tensors: element type `dataType`, NHWC strides (via generateStrides), 16-byte alignment,
//   UIDs 'x', 'y', 'w'.
// - Convolution: FLOAT compute type, math mode `mode`, the same `padA` as pre- and post-padding,
//   strides `convstrideA`, dilation `dilationA`.
// All dimension arrays are expected to hold 4 entries; pad/stride/dilation arrays hold 2.
common_conv_descriptors create_common_descriptors(int64_t* x_dim_padded, int64_t* padA, int64_t* convstrideA,
int64_t* dilationA, int64_t* w_dim_padded, int64_t* y_dim_padded,
cudnnDataType_t dataType, cudnnConvolutionMode_t mode) {
const int convDim = 2;
int64_t strideA_padded[4];
int64_t outstrideA_padded[4];
int64_t filterstrideA_padded[4];
generateStrides(w_dim_padded, filterstrideA_padded, 4, CUDNN_TENSOR_NHWC);
generateStrides(x_dim_padded, strideA_padded, 4, CUDNN_TENSOR_NHWC);
generateStrides(y_dim_padded, outstrideA_padded, 4, CUDNN_TENSOR_NHWC);
return common_conv_descriptors(cudnn_frontend::TensorBuilder()
.setDim(4, x_dim_padded)
.setStrides(4, strideA_padded)
.setId('x')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, outstrideA_padded)
.setId('y')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, w_dim_padded)
.setStrides(4, filterstrideA_padded)
.setId('w')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::ConvDescBuilder()
.setDataType(CUDNN_DATA_FLOAT)
.setMathMode(mode)
.setNDims(convDim)
.setStrides(convDim, convstrideA)
.setPrePadding(convDim, padA)
.setPostPadding(convDim, padA)
.setDilation(convDim, dilationA)
.build());
}
// Ten tensor descriptors for the conv -> scale -> bias -> (optional add) -> ReLU graph.
// Element order is indexed by the X_TENSOR ... AFTEROPT_TENSOR enum.
using common_convbias_descriptors =
std::tuple<cudnn_frontend::Tensor, cudnn_frontend::Tensor, cudnn_frontend::Tensor, cudnn_frontend::Tensor,
cudnn_frontend::Tensor, cudnn_frontend::Tensor, cudnn_frontend::Tensor, cudnn_frontend::Tensor,
cudnn_frontend::Tensor, cudnn_frontend::Tensor>;
// Build the descriptors for the fused conv/scale/bias/add/activation graph. All use NHWC
// strides and 16-byte alignment.
// - Real tensors (element type `dataType`): 'x' input, 'y' output, 'w' filter, 'i' optional
//   residual (y-shaped), and 'z' and 'b', which are per-channel with shape (1, K, 1, 1) where
//   K = y_dim_padded[1].
// - Virtual intermediates (FLOAT): 'A', 'B', 'C', 'D', all y-shaped.
// NOTE(review): padA, convstrideA, dilationA and convDim are unused here; the convolution
// descriptor is built by the caller.
common_convbias_descriptors create_conv_bias_add_act_descriptors(int64_t* x_dim_padded, int64_t* padA,
int64_t* convstrideA, int64_t* dilationA,
int64_t* w_dim_padded, int64_t* y_dim_padded,
cudnnDataType_t dataType) {
const int convDim = 2;
// Per-output-channel broadcast shape (1, K, 1, 1).
int64_t b_dim_padded[4];
b_dim_padded[0] = 1;
b_dim_padded[1] = y_dim_padded[1];
b_dim_padded[2] = 1;
b_dim_padded[3] = 1;
int64_t x_stride_padded[4];
int64_t y_stride_padded[4];
int64_t w_stride_padded[4];
int64_t b_stride_padded[4];
generateStrides(w_dim_padded, w_stride_padded, 4, CUDNN_TENSOR_NHWC);
generateStrides(x_dim_padded, x_stride_padded, 4, CUDNN_TENSOR_NHWC);
generateStrides(y_dim_padded, y_stride_padded, 4, CUDNN_TENSOR_NHWC);
generateStrides(b_dim_padded, b_stride_padded, 4, CUDNN_TENSOR_NHWC);
return common_convbias_descriptors(cudnn_frontend::TensorBuilder()
.setDim(4, x_dim_padded)
.setStrides(4, x_stride_padded)
.setId('x')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('y')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, w_dim_padded)
.setStrides(4, w_stride_padded)
.setId('w')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, b_dim_padded)
.setStrides(4, b_stride_padded)
.setId('z')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, b_dim_padded)
.setStrides(4, b_stride_padded)
.setId('b')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setVirtual()
.setId('A') // after add
.setAlignment(16)
.setDataType(CUDNN_DATA_FLOAT)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setVirtual()
.setId('B') // after bias
.setAlignment(16)
.setDataType(CUDNN_DATA_FLOAT)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('C') // after conv
.setAlignment(16)
.setVirtual()
.setDataType(CUDNN_DATA_FLOAT)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('i')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('D') // after optional add
.setAlignment(16)
.setVirtual()
.setDataType(CUDNN_DATA_FLOAT)
.build());
}
// Indices into dconv_descriptors (dgrad graph), for use with std::get<...>.
// The order must match the tuple built by create_dconv_descriptors: x, y, w, s, r, A, B.
enum {
X_OR_DX_TENSOR,      // 'x'
DY_TENSOR,           // 'y'
W_OR_DW_TENSOR,      // 'w'
SCALE_TENSOR,        // 's': per-channel, shape (1, C_in, 1, 1)
RELU_TENSOR,         // 'r': x-shaped
AFTER_DCONV_TENSOR,  // 'A' (virtual, FLOAT)
AFTER_DRELU_TENSOR,  // 'B' (virtual, FLOAT)
};
// Seven tensor descriptors for the dgrad graph; element order is indexed by the
// X_OR_DX_TENSOR ... AFTER_DRELU_TENSOR enum.
using dconv_descriptors =
std::tuple<cudnn_frontend::Tensor, cudnn_frontend::Tensor, cudnn_frontend::Tensor, cudnn_frontend::Tensor,
cudnn_frontend::Tensor, cudnn_frontend::Tensor, cudnn_frontend::Tensor>;
// Build the descriptors for a data-gradient convolution fused with a ReLU-backward style
// graph. All use NHWC strides and 16-byte alignment.
// - Real tensors (element type `dataType`): 'x', 'y', 'w', 'r' (x-shaped), and 's', which is
//   per-input-channel with shape (1, C, 1, 1) where C = x_dim_padded[1].
// - Virtual intermediates (FLOAT, x-shaped): 'A' and 'B'.
// NOTE(review): padA, convstrideA, dilationA and convDim are unused here.
dconv_descriptors create_dconv_descriptors(int64_t* x_dim_padded, int64_t* padA, int64_t* convstrideA,
int64_t* dilationA, int64_t* w_dim_padded, int64_t* y_dim_padded,
cudnnDataType_t dataType) {
const int convDim = 2;
// Per-input-channel broadcast shape (1, C, 1, 1).
int64_t b_dim_padded[4];
b_dim_padded[0] = 1;
b_dim_padded[1] = x_dim_padded[1];
b_dim_padded[2] = 1;
b_dim_padded[3] = 1;
int64_t x_stride_padded[4];
int64_t y_stride_padded[4];
int64_t w_stride_padded[4];
int64_t b_stride_padded[4];
generateStrides(w_dim_padded, w_stride_padded, 4, CUDNN_TENSOR_NHWC);
generateStrides(x_dim_padded, x_stride_padded, 4, CUDNN_TENSOR_NHWC);
generateStrides(y_dim_padded, y_stride_padded, 4, CUDNN_TENSOR_NHWC);
generateStrides(b_dim_padded, b_stride_padded, 4, CUDNN_TENSOR_NHWC);
return dconv_descriptors(cudnn_frontend::TensorBuilder()
.setDim(4, x_dim_padded)
.setStrides(4, x_stride_padded)
.setId('x')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('y')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, w_dim_padded)
.setStrides(4, w_stride_padded)
.setId('w')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, b_dim_padded)
.setStrides(4, b_stride_padded)
.setId('s')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, x_dim_padded)
.setStrides(4, x_stride_padded)
.setId('r')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, x_dim_padded)
.setStrides(4, x_stride_padded)
.setVirtual()
.setId('A') // after dconv
.setAlignment(16)
.setDataType(CUDNN_DATA_FLOAT)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, x_dim_padded)
.setStrides(4, x_stride_padded)
.setVirtual()
.setId('B') // after drelu
.setAlignment(16)
.setDataType(CUDNN_DATA_FLOAT)
.build());
}
// Process-wide cache of cuDNN execution plans, keyed by the string produced by
// getConvFusionString() and filled by getOrCreatePlan().
// NOTE(review): plain global map with no visible locking; assumes single-threaded use — confirm.
std::unordered_map<std::string, cudnn_frontend::ExecutionPlan> plan_cache;
// TODO: better name
// Build a plan-cache key. The caller-supplied prefix `fusion_string` (the op graph tag at the
// visible call site) is followed by tagged convolution geometry: 'X' + 4 input dims,
// 'W' + 4 filter dims, 'P' + 2 pads, 'S' + 2 strides, 'D' + 2 dilations, then 'T' + data type.
std::string getConvFusionString(int64_t* x_dim_padded, int64_t* padA, int64_t* convstrideA, int64_t* dilationA,
                                int64_t* w_dim_padded, cudnnDataType_t dataType, std::string fusion_string) {
  auto appendTagged = [&fusion_string](char tag, const int64_t* values, int count) {
    for (int k = 0; k < count; ++k) {
      fusion_string.push_back(tag);
      fusion_string.append(std::to_string(values[k]));
    }
  };
  appendTagged('X', x_dim_padded, 4);
  appendTagged('W', w_dim_padded, 4);
  appendTagged('P', padA, 2);
  appendTagged('S', convstrideA, 2);
  appendTagged('D', dilationA, 2);
  fusion_string.push_back('T');
  fusion_string.append(std::to_string(dataType));
  return fusion_string;
}
// Return the cached execution plan for `cache_string`, building and caching one on a miss.
//
// With use_heuristic (the default), up to three engine configs from the INSTANT heuristic are
// tried in order, and the first that builds is cached. If none build, the last
// cudnn_frontend::cudnnException is rethrown. If the heuristic returns no configs at all, a
// cudnnException is thrown (previously this indexed past the end of the config list).
// Without use_heuristic, global engine 0 is used, with its first knob (if any) set to
// min value + 1.
// The returned reference points into plan_cache. std::unordered_map keeps element references
// valid across rehashing.
cudnn_frontend::ExecutionPlan& getOrCreatePlan(cudnnHandle_t handle_, std::stringstream& log_buf,
                                               cudnn_frontend::OperationGraph& opGraph, std::string cache_string,
                                               bool use_heuristic = true) {
  auto cached = plan_cache.find(cache_string);
  if (cached != plan_cache.end()) {
    DEBUG_CUDNN_MSG(log_buf, "Found plan in cache");
    return cached->second;
  }

  if (use_heuristic) {
    // TODO: confirm which mode to use
    auto heuristics = cudnn_frontend::EngineHeuristicsBuilder()
                          .setOperationGraph(opGraph)
                          .setHeurMode(CUDNN_HEUR_MODE_INSTANT)
                          .build();
    // Try up to 3 configs as a workaround for heuristics that do not cover training.
    const int64_t max_tries = 3;
    auto& engine_configs = heuristics.getEngineConfig(max_tries);
    // The heuristic may return fewer configs than requested; never index past the end.
    const int64_t available = static_cast<int64_t>(engine_configs.size());
    const int64_t num_candidates = available < max_tries ? available : max_tries;
    if (num_candidates == 0) {
      throw cudnn_frontend::cudnnException("getOrCreatePlan: heuristics returned no engine configs",
                                           CUDNN_STATUS_NOT_SUPPORTED);
    }
    for (int64_t idx = 0;; ++idx) {
      try {
        auto inserted = plan_cache.emplace(cache_string, cudnn_frontend::ExecutionPlanBuilder()
                                                             .setHandle(handle_)
                                                             .setEngineConfig(engine_configs[idx], opGraph.getTag())
                                                             .build());
        return inserted.first->second;
      } catch (const cudnn_frontend::cudnnException&) {
        // Catch by reference and rethrow the original (not a sliced copy) once every candidate fails.
        if (idx + 1 >= num_candidates) throw;
      }
    }
  }

  DEBUG_CUDNN_MSG(log_buf, "No plan in cache");
  // How many engines support this operation graph ?
  auto total_engines = opGraph.getEngineCount();
  DEBUG_CUDNN_MSG(log_buf, opGraph.describe() << " has " << total_engines << " engines.");
  // We have to randomly pick one engine from [0, total_engines)
  // Selecting "0" by default
  auto engine = cudnn_frontend::EngineBuilder().setGlobalEngineIdx(0).setOperationGraph(opGraph).build();
  DEBUG_CUDNN_MSG(log_buf, engine.describe());
  auto& knobs = engine.getSupportedKnobs();
  for (auto knob = std::begin(knobs); knob != std::end(knobs); ++knob) {
    DEBUG_CUDNN_MSG(log_buf, knob->describe());
  }
  if (knobs.begin() != knobs.end()) {
    DEBUG_CUDNN_MSG(log_buf, "Updated knob choice");
    knobs.begin()->setChoice(knobs.begin()->getMinValue() + 1);
    DEBUG_CUDNN_MSG(log_buf, knobs.begin()->describe());
  }
  // Create the requisite engine config
  auto engine_config = cudnn_frontend::EngineConfigBuilder().setEngine(engine).build();
  DEBUG_CUDNN_MSG(log_buf, engine_config.describe());
  auto inserted = plan_cache.emplace(
      cache_string, cudnn_frontend::ExecutionPlanBuilder().setHandle(handle_).setEngineConfig(engine_config).build());
  return inserted.first->second;
}
void run_conv_scale_bias_add_activation(int64_t* x_dim_padded, int64_t* pad, int64_t* convstride, int64_t* dilation,
int64_t* w_dim_padded, int64_t* y_dim_padded, cudnnDataType_t dataType,
at::Half* devPtrX, at::Half* devPtrW, at::Half* devPtrY, at::Half* devPtrZ,
at::Half* devPtrB, at::Half* devPtrI) {
cudnnHandle_t handle_ = torch::native::getCudnnHandle();
std::stringstream log_buf;
try {
int convDim = 2;
// Creates the necessary tensor descriptors
common_convbias_descriptors tensors = create_conv_bias_add_act_descriptors(x_dim_padded, pad, convstride, dilation,
w_dim_padded, y_dim_padded, dataType);
DEBUG_CUDNN_MSG(log_buf, std::get<X_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<Y_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<W_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<Z_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<B_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTERADD_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTERBIAS_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTERCONV_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<OPTIONAL>(tensors).describe());
// Define the add operation
auto scaleDesc =
cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_MUL).setMathPrecision(CUDNN_DATA_FLOAT).build();
DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe());
// Define the bias operation
auto biasDesc =
cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_ADD).setMathPrecision(CUDNN_DATA_FLOAT).build();
DEBUG_CUDNN_MSG(log_buf, biasDesc.describe());
// optional add
auto addDesc =
cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_ADD).setMathPrecision(CUDNN_DATA_FLOAT).build();
DEBUG_CUDNN_MSG(log_buf, addDesc.describe());
// Define the activation operation
auto actDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_RELU_FWD)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, actDesc.describe());
// Define the convolution problem
auto convDesc = cudnn_frontend::ConvDescBuilder()
.setDataType(CUDNN_DATA_FLOAT)
.setMathMode(CUDNN_CROSS_CORRELATION)
.setNDims(convDim)
.setStrides(convDim, convstride)
.setPrePadding(convDim, pad)
.setPostPadding(convDim, pad)
.setDilation(convDim, dilation)
.build();
DEBUG_CUDNN_MSG(log_buf, convDesc.describe());
float alpha = 1.0f;
float beta = 0.0f;
// Create a convolution Node
auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR)
.setxDesc(std::get<X_TENSOR>(tensors))
.setwDesc(std::get<W_TENSOR>(tensors))
.setyDesc(std::get<AFTERCONV_TENSOR>(tensors))
.setcDesc(convDesc)
.setAlpha(alpha)
.setBeta(beta)
.build();
DEBUG_CUDNN_MSG(log_buf, conv_op.describe());
// Create a Add Node with scaling parameters.
auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(conv_op.getOutputTensor())
.setbDesc(std::get<Z_TENSOR>(tensors))
.setyDesc(std::get<AFTERADD_TENSOR>(tensors))
.setpwDesc(scaleDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, scale_op.describe());
// Create a Bias Node.
auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(scale_op.getOutputTensor())
.setbDesc(std::get<B_TENSOR>(tensors))
.setyDesc(std::get<AFTERBIAS_TENSOR>(tensors))
.setpwDesc(biasDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, bias_op.describe());
// Create a optional add Node.
auto add_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(bias_op.getOutputTensor())
.setbDesc(std::get<OPTIONAL>(tensors))
.setyDesc(std::get<AFTEROPT_TENSOR>(tensors))
.setpwDesc(addDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, add_op.describe());
// Create an Activation Node.
auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(devPtrI ? add_op.getOutputTensor() : bias_op.getOutputTensor())
.setyDesc(std::get<Y_TENSOR>(tensors))
.setpwDesc(actDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, act_op.describe());
// Create an Operation Graph. In this case it is convolution add bias activation
std::array<cudnn_frontend::Operation const*, 5> ops = {&conv_op, &scale_op, &bias_op, devPtrI ? &add_op : &act_op,
&act_op};
auto opGraph = cudnn_frontend::OperationGraphBuilder()
.setHandle(handle_)
.setOperationGraph(devPtrI ? ops.size() : 4, ops.data())
.build();
// Create string encoding for plan caching
auto cache_string =
getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag());
DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);
auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());
auto workspace_size = plan.getWorkspaceSize();
DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);
void* workspace_ptr = nullptr;
auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
if (workspace_size > 0) {
workspace_ptr = workspace_tensor.data_ptr<float>();
}
void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrB, devPtrI};
int64_t uids[] = {'x', 'y', 'w', 'z', 'b', 'i'};
auto variantPack = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace_ptr)
.setDataPointers(devPtrI ? 6 : 5, data_ptrs)
.setUids(devPtrI ? 6 : 5, uids)
.build();
DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
checkCudnnErr(status);
cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
} catch (cudnn_frontend::cudnnException e) {
std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
}
}
// Fused forward convolution + per-channel scale + bias, built as a cuDNN frontend
// operation graph:
//   y = conv(x, w) * z + b      (no activation)
// bottleneck_forward uses it for the downsample (identity) branch.
//
// x_dim_padded, w_dim_padded, y_dim_padded: 4-element dims in n,c,h,w order.
// pad, convstride, dilation: 2-element spatial convolution parameters.
// dataType: cuDNN data type of the x/w/y/z/b tensors.
// devPtrX, devPtrW, devPtrZ, devPtrB: device inputs bound to uids 'x', 'w', 'z', 'b'.
// devPtrY: device output bound to uid 'y'.
//
// cudnn_frontend exceptions are caught and printed to stdout together with the
// accumulated debug log. They are not rethrown, so on failure devPtrY may be left
// unwritten. NOTE(review): callers get no signal of such a failure.
void run_conv_scale_bias(int64_t* x_dim_padded, int64_t* pad, int64_t* convstride, int64_t* dilation,
                         int64_t* w_dim_padded, int64_t* y_dim_padded, cudnnDataType_t dataType, at::Half* devPtrX,
                         at::Half* devPtrW, at::Half* devPtrY, at::Half* devPtrZ, at::Half* devPtrB) {
  cudnnHandle_t handle_ = torch::native::getCudnnHandle();
  std::stringstream log_buf;
  try {
    int convDim = 2;
    // Creates the necessary tensor descriptors
    common_convbias_descriptors tensors = create_conv_bias_add_act_descriptors(x_dim_padded, pad, convstride, dilation,
                                                                               w_dim_padded, y_dim_padded, dataType);
    DEBUG_CUDNN_MSG(log_buf, std::get<X_TENSOR>(tensors).describe());
    DEBUG_CUDNN_MSG(log_buf, std::get<Y_TENSOR>(tensors).describe());
    DEBUG_CUDNN_MSG(log_buf, std::get<W_TENSOR>(tensors).describe());
    DEBUG_CUDNN_MSG(log_buf, std::get<Z_TENSOR>(tensors).describe());
    DEBUG_CUDNN_MSG(log_buf, std::get<B_TENSOR>(tensors).describe());
    DEBUG_CUDNN_MSG(log_buf, std::get<AFTERADD_TENSOR>(tensors).describe());
    DEBUG_CUDNN_MSG(log_buf, std::get<AFTERBIAS_TENSOR>(tensors).describe());
    DEBUG_CUDNN_MSG(log_buf, std::get<AFTERCONV_TENSOR>(tensors).describe());
    DEBUG_CUDNN_MSG(log_buf, std::get<OPTIONAL>(tensors).describe());
    // Define the scale operation (pointwise multiply)
    auto scaleDesc =
        cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_MUL).setMathPrecision(CUDNN_DATA_FLOAT).build();
    DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe());
    // Define the bias operation (pointwise add)
    auto addDesc =
        cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_ADD).setMathPrecision(CUDNN_DATA_FLOAT).build();
    DEBUG_CUDNN_MSG(log_buf, addDesc.describe());
    // Define the convolution problem
    auto convDesc = cudnn_frontend::ConvDescBuilder()
                        .setDataType(CUDNN_DATA_FLOAT)
                        .setMathMode(CUDNN_CROSS_CORRELATION)
                        .setNDims(convDim)
                        .setStrides(convDim, convstride)
                        .setPrePadding(convDim, pad)
                        .setPostPadding(convDim, pad)
                        .setDilation(convDim, dilation)
                        .build();
    DEBUG_CUDNN_MSG(log_buf, convDesc.describe());
    float alpha = 1.0f;
    float beta = 0.0f;
    // Create a convolution Node
    auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR)
                       .setxDesc(std::get<X_TENSOR>(tensors))
                       .setwDesc(std::get<W_TENSOR>(tensors))
                       .setyDesc(std::get<AFTERCONV_TENSOR>(tensors))
                       .setcDesc(convDesc)
                       .setAlpha(alpha)
                       .setBeta(beta)
                       .build();
    DEBUG_CUDNN_MSG(log_buf, conv_op.describe());
    // Create the scale Node: conv output * z
    auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
                        .setxDesc(conv_op.getOutputTensor())
                        .setbDesc(std::get<Z_TENSOR>(tensors))
                        .setyDesc(std::get<AFTERADD_TENSOR>(tensors))  // TODO: change enum to aftermul
                        .setpwDesc(scaleDesc)
                        .build();
    DEBUG_CUDNN_MSG(log_buf, scale_op.describe());
    // Create the bias Node: scaled output + b, written to y
    auto add_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
                      .setxDesc(scale_op.getOutputTensor())
                      .setbDesc(std::get<B_TENSOR>(tensors))
                      .setyDesc(std::get<Y_TENSOR>(tensors))
                      .setpwDesc(addDesc)
                      .build();
    DEBUG_CUDNN_MSG(log_buf, add_op.describe());
    // Create the Operation Graph: convolution -> scale -> bias
    std::array<cudnn_frontend::Operation const*, 3> ops = {&conv_op, &scale_op, &add_op};
    auto opGraph =
        cudnn_frontend::OperationGraphBuilder().setHandle(handle_).setOperationGraph(ops.size(), ops.data()).build();
    // Create string encoding for plan caching
    auto cache_string =
        getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag());
    DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);
    auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
    DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());
    auto workspace_size = plan.getWorkspaceSize();
    DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);
    // The workspace is backed by a float tensor: the byte size is rounded up to
    // whole 4-byte elements.
    void* workspace_ptr = nullptr;
    auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
    if (workspace_size > 0) {
      workspace_ptr = workspace_tensor.data_ptr<float>();
    }
    void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrB};
    int64_t uids[] = {'x', 'y', 'w', 'z', 'b'};
    auto variantPack = cudnn_frontend::VariantPackBuilder()
                           .setWorkspacePointer(workspace_ptr)
                           .setDataPointers(5, data_ptrs)
                           .setUids(5, uids)
                           .build();
    DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
    cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
    checkCudnnErr(status);
    cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
  } catch (const cudnn_frontend::cudnnException& e) {
    // Catch by const reference: no copy of the exception object.
    std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
  }
}
// Fused data-gradient of a convolution followed by ReLU backward and a per-channel
// scale, built as a cuDNN frontend operation graph:
//   dx = relu_bwd(dconv_data(dy, w), r) * z
//
// x_dim_padded, w_dim_padded, y_dim_padded: 4-element dims in n,c,h,w order.
// pad, convstride, dilation: 2-element spatial convolution parameters.
// devPtrY: incoming gradient dy (uid 'y').
// devPtrW: filter (uid 'w').
// devPtrR: tensor gating the ReLU backward (uid 'r').
// devPtrZ: scale (bound to uid 's', not 'z').
// devPtrX: output gradient dx (uid 'x').
//
// cudnn_frontend exceptions are caught and printed to stdout together with the
// accumulated debug log. They are not rethrown, so on failure devPtrX may be left
// unwritten.
void run_dconv_drelu_dscale(int64_t* x_dim_padded, int64_t* pad, int64_t* convstride, int64_t* dilation,
                            int64_t* w_dim_padded, int64_t* y_dim_padded, cudnnDataType_t dataType, at::Half* devPtrX,
                            at::Half* devPtrW, at::Half* devPtrY, at::Half* devPtrZ, at::Half* devPtrR) {
  cudnnHandle_t handle_ = torch::native::getCudnnHandle();
  std::stringstream log_buf;
  try {
    int convDim = 2;
    // Creates the necessary tensor descriptors
    dconv_descriptors tensors =
        create_dconv_descriptors(x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, dataType);
    DEBUG_CUDNN_MSG(log_buf, std::get<X_OR_DX_TENSOR>(tensors).describe());
    DEBUG_CUDNN_MSG(log_buf, std::get<DY_TENSOR>(tensors).describe());
    DEBUG_CUDNN_MSG(log_buf, std::get<W_OR_DW_TENSOR>(tensors).describe());
    DEBUG_CUDNN_MSG(log_buf, std::get<SCALE_TENSOR>(tensors).describe());
    DEBUG_CUDNN_MSG(log_buf, std::get<RELU_TENSOR>(tensors).describe());
    DEBUG_CUDNN_MSG(log_buf, std::get<AFTER_DCONV_TENSOR>(tensors).describe());
    DEBUG_CUDNN_MSG(log_buf, std::get<AFTER_DRELU_TENSOR>(tensors).describe());
    // Define the convolution problem
    auto convDesc = cudnn_frontend::ConvDescBuilder()
                        .setDataType(CUDNN_DATA_FLOAT)
                        .setMathMode(CUDNN_CROSS_CORRELATION)
                        .setNDims(convDim)
                        .setStrides(convDim, convstride)
                        .setPrePadding(convDim, pad)
                        .setPostPadding(convDim, pad)
                        .setDilation(convDim, dilation)
                        .build();
    DEBUG_CUDNN_MSG(log_buf, convDesc.describe());
    // Define the activation backward operation
    auto actDesc = cudnn_frontend::PointWiseDescBuilder()
                       .setMode(CUDNN_POINTWISE_RELU_BWD)
                       .setMathPrecision(CUDNN_DATA_FLOAT)
                       .build();
    DEBUG_CUDNN_MSG(log_buf, actDesc.describe());
    // Define the scale operation (pointwise multiply)
    auto scaleDesc =
        cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_MUL).setMathPrecision(CUDNN_DATA_FLOAT).build();
    DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe());
    float alpha = 1.0f;
    float beta = 0.0f;
    // Create the data-gradient convolution Node: dy -> AFTER_DCONV
    auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR)
                       .setdxDesc(std::get<AFTER_DCONV_TENSOR>(tensors))
                       .setwDesc(std::get<W_OR_DW_TENSOR>(tensors))
                       .setdyDesc(std::get<DY_TENSOR>(tensors))
                       .setcDesc(convDesc)
                       .setAlpha(alpha)
                       .setBeta(beta)
                       .build();
    DEBUG_CUDNN_MSG(log_buf, conv_op.describe());
    // TODO: do we need getOutputTensor(), and what it returns in backward case?
    // Create the ReLU backward Node: AFTER_DCONV gated by the RELU tensor -> AFTER_DRELU
    auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
                      .setdyDesc(std::get<AFTER_DCONV_TENSOR>(tensors))
                      .setxDesc(std::get<RELU_TENSOR>(tensors))
                      .setdxDesc(std::get<AFTER_DRELU_TENSOR>(tensors))
                      .setpwDesc(actDesc)
                      .build();
    DEBUG_CUDNN_MSG(log_buf, act_op.describe());
    // Create the Scale Node: AFTER_DRELU * scale, written to dx
    auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
                        .setxDesc(std::get<AFTER_DRELU_TENSOR>(tensors))
                        .setbDesc(std::get<SCALE_TENSOR>(tensors))
                        .setyDesc(std::get<X_OR_DX_TENSOR>(tensors))
                        .setpwDesc(scaleDesc)
                        .build();
    DEBUG_CUDNN_MSG(log_buf, scale_op.describe());
    // Create the Operation Graph: dgrad -> ReLU backward -> scale
    std::array<cudnn_frontend::Operation const*, 3> ops = {&conv_op, &act_op, &scale_op};
    auto opGraph =
        cudnn_frontend::OperationGraphBuilder().setHandle(handle_).setOperationGraph(ops.size(), ops.data()).build();
    // Create string encoding for plan caching
    auto cache_string =
        getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag());
    DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);
    auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
    DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());
    auto workspace_size = plan.getWorkspaceSize();
    DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);
    // The workspace is backed by a float tensor: the byte size is rounded up to
    // whole 4-byte elements.
    void* workspace_ptr = nullptr;
    auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
    if (workspace_size > 0) {
      workspace_ptr = workspace_tensor.data_ptr<float>();
    }
    // Note: the scale buffer devPtrZ is bound to uid 's'.
    void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrR};
    int64_t uids[] = {'x', 'y', 'w', 's', 'r'};
    auto variantPack = cudnn_frontend::VariantPackBuilder()
                           .setWorkspacePointer(workspace_ptr)
                           .setDataPointers(5, data_ptrs)
                           .setUids(5, uids)
                           .build();
    DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
    cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
    checkCudnnErr(status);
    cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
  } catch (const cudnn_frontend::cudnnException& e) {
    // Catch by const reference: no copy of the exception object.
    std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
  }
}
void run_dconv(int64_t* x_dim_padded, int64_t* pad, int64_t* convstride, int64_t* dilation, int64_t* w_dim_padded,
int64_t* y_dim_padded, cudnnDataType_t dataType, at::Half* devPtrX, at::Half* devPtrW, at::Half* devPtrY,
cudnnBackendDescriptorType_t mode) {
cudnnHandle_t handle_ = torch::native::getCudnnHandle();
std::stringstream log_buf;
try {
int convDim = 2;
// Creates the necessary tensor descriptors
dconv_descriptors tensors =
create_dconv_descriptors(x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, dataType);
DEBUG_CUDNN_MSG(log_buf, std::get<X_OR_DX_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<DY_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<W_OR_DW_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<SCALE_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<RELU_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTER_DCONV_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTER_DRELU_TENSOR>(tensors).describe());
// Define the convolution problem
auto convDesc = cudnn_frontend::ConvDescBuilder()
.setDataType(CUDNN_DATA_FLOAT)
.setMathMode(CUDNN_CROSS_CORRELATION)
.setNDims(convDim)
.setStrides(convDim, convstride)
.setPrePadding(convDim, pad)
.setPostPadding(convDim, pad)
.setDilation(convDim, dilation)
.build();
DEBUG_CUDNN_MSG(log_buf, convDesc.describe());
float alpha = 1.0f;
float beta = 0.0f;
// Create a convolution Node
// mode should be one of following
// CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR
// CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR
auto conv_op_builder = cudnn_frontend::OperationBuilder(mode);
if (mode == CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR) {
conv_op_builder.setdxDesc(std::get<X_OR_DX_TENSOR>(tensors))
.setwDesc(std::get<W_OR_DW_TENSOR>(tensors))
.setdyDesc(std::get<DY_TENSOR>(tensors))
.setcDesc(convDesc)
.setAlpha(alpha)
.setBeta(beta);
} else {
conv_op_builder.setxDesc(std::get<X_OR_DX_TENSOR>(tensors))
.setdwDesc(std::get<W_OR_DW_TENSOR>(tensors))
.setdyDesc(std::get<DY_TENSOR>(tensors))
.setcDesc(convDesc)
.setAlpha(alpha)
.setBeta(beta);
}
auto conv_op = conv_op_builder.build();
DEBUG_CUDNN_MSG(log_buf, conv_op.describe());
// Create an Operation Graph. In this case it is convolution add bias activation
std::array<cudnn_frontend::Operation const*, 1> ops = {&conv_op};
auto opGraph =
cudnn_frontend::OperationGraphBuilder().setHandle(handle_).setOperationGraph(ops.size(), ops.data()).build();
// Create string encoding for plan caching
auto cache_string =
getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag());
DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);
auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());
auto workspace_size = plan.getWorkspaceSize();
DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);
void* workspace_ptr = nullptr;
auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
if (workspace_size > 0) {
workspace_ptr = workspace_tensor.data_ptr<float>();
}
void* data_ptrs[] = {devPtrX, devPtrY, devPtrW};
int64_t uids[] = {'x', 'y', 'w'};
auto variantPack = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace_ptr)
.setDataPointers(3, data_ptrs)
.setUids(3, uids)
.build();
DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
checkCudnnErr(status);
cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
} catch (cudnn_frontend::cudnnException e) {
std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
}
}
// Fused data-gradient of a convolution plus an elementwise add, built as a cuDNN
// frontend operation graph:
//   dx = dconv_data(dy, w) + r
// bottleneck_backward uses it to fold the residual-branch gradient into dx.
//
// x_dim_padded, w_dim_padded, y_dim_padded: 4-element dims in n,c,h,w order.
// pad, convstride, dilation: 2-element spatial convolution parameters.
// devPtrY: incoming gradient dy (uid 'y').
// devPtrW: filter (uid 'w').
// devPtrR: addend (uid 'r'). It reuses the RELU_TENSOR descriptor slot.
// devPtrX: output gradient dx (uid 'x').
//
// cudnn_frontend exceptions are caught and printed to stdout together with the
// accumulated debug log. They are not rethrown, so on failure devPtrX may be left
// unwritten.
void run_dconv_add(int64_t* x_dim_padded, int64_t* pad, int64_t* convstride, int64_t* dilation, int64_t* w_dim_padded,
                   int64_t* y_dim_padded, cudnnDataType_t dataType, at::Half* devPtrX, at::Half* devPtrW,
                   at::Half* devPtrY, at::Half* devPtrR) {
  cudnnHandle_t handle_ = torch::native::getCudnnHandle();
  std::stringstream log_buf;
  try {
    int convDim = 2;
    // Creates the necessary tensor descriptors
    dconv_descriptors tensors =
        create_dconv_descriptors(x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, dataType);
    DEBUG_CUDNN_MSG(log_buf, std::get<X_OR_DX_TENSOR>(tensors).describe());
    DEBUG_CUDNN_MSG(log_buf, std::get<DY_TENSOR>(tensors).describe());
    DEBUG_CUDNN_MSG(log_buf, std::get<W_OR_DW_TENSOR>(tensors).describe());
    DEBUG_CUDNN_MSG(log_buf, std::get<SCALE_TENSOR>(tensors).describe());
    DEBUG_CUDNN_MSG(log_buf, std::get<RELU_TENSOR>(tensors).describe());
    DEBUG_CUDNN_MSG(log_buf, std::get<AFTER_DCONV_TENSOR>(tensors).describe());
    DEBUG_CUDNN_MSG(log_buf, std::get<AFTER_DRELU_TENSOR>(tensors).describe());
    // Define the convolution problem
    auto convDesc = cudnn_frontend::ConvDescBuilder()
                        .setDataType(CUDNN_DATA_FLOAT)
                        .setMathMode(CUDNN_CROSS_CORRELATION)
                        .setNDims(convDim)
                        .setStrides(convDim, convstride)
                        .setPrePadding(convDim, pad)
                        .setPostPadding(convDim, pad)
                        .setDilation(convDim, dilation)
                        .build();
    DEBUG_CUDNN_MSG(log_buf, convDesc.describe());
    // Define the add operation (pointwise add)
    auto addDesc =
        cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_ADD).setMathPrecision(CUDNN_DATA_FLOAT).build();
    DEBUG_CUDNN_MSG(log_buf, addDesc.describe());
    float alpha = 1.0f;
    float beta = 0.0f;
    // Create the data-gradient convolution Node: dy -> AFTER_DCONV
    auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR)
                       .setdxDesc(std::get<AFTER_DCONV_TENSOR>(tensors))
                       .setwDesc(std::get<W_OR_DW_TENSOR>(tensors))
                       .setdyDesc(std::get<DY_TENSOR>(tensors))
                       .setcDesc(convDesc)
                       .setAlpha(alpha)
                       .setBeta(beta)
                       .build();
    DEBUG_CUDNN_MSG(log_buf, conv_op.describe());
    // TODO: do we need getOutputTensor(), and what it returns in backward case?
    // Create the add Node: AFTER_DCONV + r (RELU_TENSOR slot), written to dx
    auto add_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
                      .setxDesc(std::get<AFTER_DCONV_TENSOR>(tensors))
                      .setbDesc(std::get<RELU_TENSOR>(tensors))
                      .setyDesc(std::get<X_OR_DX_TENSOR>(tensors))
                      .setpwDesc(addDesc)
                      .build();
    DEBUG_CUDNN_MSG(log_buf, add_op.describe());
    // Create the Operation Graph: dgrad -> add
    std::array<cudnn_frontend::Operation const*, 2> ops = {&conv_op, &add_op};
    auto opGraph =
        cudnn_frontend::OperationGraphBuilder().setHandle(handle_).setOperationGraph(ops.size(), ops.data()).build();
    // Create string encoding for plan caching
    auto cache_string =
        getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag());
    DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);
    auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
    DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());
    auto workspace_size = plan.getWorkspaceSize();
    DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);
    // The workspace is backed by a float tensor: the byte size is rounded up to
    // whole 4-byte elements.
    void* workspace_ptr = nullptr;
    auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
    if (workspace_size > 0) {
      workspace_ptr = workspace_tensor.data_ptr<float>();
    }
    void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrR};
    int64_t uids[] = {'x', 'y', 'w', 'r'};
    auto variantPack = cudnn_frontend::VariantPackBuilder()
                           .setWorkspacePointer(workspace_ptr)
                           .setDataPointers(4, data_ptrs)
                           .setUids(4, uids)
                           .build();
    DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
    cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
    checkCudnnErr(status);
    cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
  } catch (const cudnn_frontend::cudnnException& e) {
    // Catch by const reference: no copy of the exception object.
    std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
  }
}
// inputs contains x,w,z,b,(i)
std::vector<at::Tensor> bottleneck_forward(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs) {
std::cout << std::fixed;
// create output vector
std::vector<at::Tensor> outputs;
auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
// setup dimensions
int64_t dimA[] = {0, 0, 0, 0};
int64_t filterdimA1[] = {0, 0, 0, 0};
int64_t filterdimA2[] = {0, 0, 0, 0};
int64_t filterdimA3[] = {0, 0, 0, 0};
int64_t filterdimA4[] = {0, 0, 0, 0};
// All dim calculation after this order of n,c,h,w
int axis[]{0, 1, 2, 3};
if (explicit_nhwc) {
axis[0] = 0;
axis[1] = 3;
axis[2] = 1;
axis[3] = 2;
}
for (int dim = 0; dim < 4; dim++) {
dimA[dim] = inputs[0].size(axis[dim]);
filterdimA1[dim] = inputs[1].size(axis[dim]);
filterdimA2[dim] = inputs[2].size(axis[dim]);
filterdimA3[dim] = inputs[3].size(axis[dim]);
}
if (stride_1X1 != 1 || filterdimA3[0] != dimA[1]) {
for (int dim = 0; dim < 4; dim++) {
filterdimA4[dim] = inputs[10].size(axis[dim]);
}
}
// output dim in n,c,h,w used by backend
int64_t outdimA1[] = {0, 0, 0, 0}; // Computed Below
int64_t outdimA2[] = {0, 0, 0, 0}; // Computed Below
int64_t outdimA3[] = {0, 0, 0, 0}; // Computed Below
// use these fixed value for test run
int64_t padA[] = {0, 0};
int64_t padA1[] = {1, 1};
int64_t dilationA[] = {1, 1};
int64_t convstrideA[] = {1, 1};
int64_t convstride1X1[] = {stride_1X1, stride_1X1};
// compute output from pad/stride/dilation
outdimA1[0] = dimA[0];
outdimA1[1] = filterdimA1[0];
for (int dim = 0; dim < 2; dim++) {
outdimA1[dim + 2] =
getFwdConvOutputDim(dimA[dim + 2], padA[dim], filterdimA1[dim + 2], convstride1X1[dim], dilationA[dim]);
}
outdimA2[0] = outdimA1[0];
outdimA2[1] = filterdimA2[0];
for (int dim = 0; dim < 2; dim++) {
outdimA2[dim + 2] =
getFwdConvOutputDim(outdimA1[dim + 2], padA1[dim], filterdimA2[dim + 2], convstrideA[dim], dilationA[dim]);
}
outdimA3[0] = outdimA2[0];
outdimA3[1] = filterdimA3[0];
for (int dim = 0; dim < 2; dim++) {
outdimA3[dim + 2] =
getFwdConvOutputDim(outdimA2[dim + 2], padA[dim], filterdimA3[dim + 2], convstrideA[dim], dilationA[dim]);
}
// Create output tensor in the correct shape in pytorch's view
int64_t outdim1[] = {0, 0, 0, 0};
int64_t outdim2[] = {0, 0, 0, 0};
int64_t outdim3[] = {0, 0, 0, 0};
if (explicit_nhwc) {
axis[0] = 0;
axis[1] = 2;
axis[2] = 3;
axis[3] = 1;
}
for (int dim = 0; dim < 4; dim++) {
outdim1[dim] = outdimA1[axis[dim]];
outdim2[dim] = outdimA2[axis[dim]];
outdim3[dim] = outdimA3[axis[dim]];
}
// run
at::Half* x = inputs[0].data_ptr<at::Half>();
at::Half* w = inputs[1].data_ptr<at::Half>();
at::Half* z = inputs[4].data_ptr<at::Half>();
at::Half* b = inputs[7].data_ptr<at::Half>();
auto out1 = at::empty(outdim1, inputs[0].type(), output_format);
at::Half* y1 = out1.data_ptr<at::Half>();
run_conv_scale_bias_add_activation(dimA, padA, convstride1X1, dilationA, filterdimA1, outdimA1, CUDNN_DATA_HALF, x, w,
y1, z, b, nullptr);
DEBUG_MSG("[DEBUG] new relu1 : " << out1.to(at::kFloat).sum().item<float>());
w = inputs[2].data_ptr<at::Half>();
z = inputs[5].data_ptr<at::Half>();
b = inputs[8].data_ptr<at::Half>();
auto out2 = at::empty(outdim2, inputs[0].type(), output_format);
at::Half* y2 = out2.data_ptr<at::Half>();
run_conv_scale_bias_add_activation(outdimA1, padA1, convstrideA, dilationA, filterdimA2, outdimA2, CUDNN_DATA_HALF,
y1, w, y2, z, b, nullptr);
DEBUG_MSG("[DEBUG] new relu2 : " << out2.to(at::kFloat).sum().item<float>());
// create output of conv3
auto out3 = at::empty(outdim3, inputs[0].type(), output_format);
at::Half* y3 = out3.data_ptr<at::Half>();
// create output of conv4 that may exist
auto identity = at::empty_like(out3);
at::Half* yi = identity.data_ptr<at::Half>();
if (stride_1X1 != 1 || filterdimA3[0] != dimA[1]) {
w = inputs[10].data_ptr<at::Half>();
z = inputs[11].data_ptr<at::Half>();
b = inputs[12].data_ptr<at::Half>();
run_conv_scale_bias(dimA, padA, convstride1X1, dilationA, filterdimA4, outdimA3, CUDNN_DATA_HALF, x, w, yi, z, b);
DEBUG_MSG("[DEBUG] new downsample : " << identity.to(at::kFloat).sum().item<float>());
} else {
yi = x;
}
w = inputs[3].data_ptr<at::Half>();
z = inputs[6].data_ptr<at::Half>();
b = inputs[9].data_ptr<at::Half>();
run_conv_scale_bias_add_activation(outdimA2, padA, convstrideA, dilationA, filterdimA3, outdimA3, CUDNN_DATA_HALF, y2,
w, y3, z, b, yi);
DEBUG_MSG("[DEBUG] new relu3 : " << out3.to(at::kFloat).sum().item<float>());
outputs.push_back(out1);
outputs.push_back(out2);
outputs.push_back(out3);
return outputs;
}
// Fused backward pass matching bottleneck_forward.
// It computes the weight gradients for conv1..conv3, and for conv4 when the
// downsample branch exists. When inputs[0].requires_grad(), it also computes the
// gradient w.r.t. x.
//
// inputs (all must be at::Half, since data_ptr<at::Half>() is used):
//   [0] x, [1] w1, [2] w2, [3] w3, [4] z1, [5] z2,
//   [10] gradient entering conv3's backward (dy3),
//   [11] gradient after ReLU3's backward that feeds the residual branch (dy of conv4,
//        or added directly into dx when there is no downsample),
//   [12] presumably out1 from the forward pass (conv2 input / ReLU1 output),
//   [13] presumably out2 from the forward pass (conv3 input / ReLU2 output),
//   [14] w4 (only read when the downsample branch exists).
// explicit_nhwc / stride_1X1: same meaning as in bottleneck_forward.
// Returns {grad_x, wgrad1, wgrad2, wgrad3[, wgrad4]}. grad_x is allocated but left
// uninitialized when x does not require grad.
std::vector<at::Tensor> bottleneck_backward(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs) {
  bool requires_grad = inputs[0].requires_grad();
  std::cout << std::fixed;
  // create output vector
  std::vector<at::Tensor> outputs;
  auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
  // dtype/device/layout for newly allocated gradients (replaces the deprecated Tensor::type()).
  auto output_options = inputs[0].options();
  // setup dimensions
  int64_t dimA[] = {0, 0, 0, 0};
  int64_t filterdimA1[] = {0, 0, 0, 0};
  int64_t filterdimA2[] = {0, 0, 0, 0};
  int64_t filterdimA3[] = {0, 0, 0, 0};
  int64_t filterdimA4[] = {0, 0, 0, 0};
  // All dim calculation after this order of n,c,h,w.
  // axis[i] is the PyTorch dimension that holds logical dimension i.
  int axis[]{0, 1, 2, 3};
  if (explicit_nhwc) {
    axis[0] = 0;
    axis[1] = 3;
    axis[2] = 1;
    axis[3] = 2;
  }
  for (int dim = 0; dim < 4; dim++) {
    dimA[dim] = inputs[0].size(axis[dim]);
    filterdimA1[dim] = inputs[1].size(axis[dim]);
    filterdimA2[dim] = inputs[2].size(axis[dim]);
    filterdimA3[dim] = inputs[3].size(axis[dim]);
  }
  // Same condition as in bottleneck_forward: a downsample conv4 exists.
  const bool has_downsample = stride_1X1 != 1 || filterdimA3[0] != dimA[1];
  if (has_downsample) {
    for (int dim = 0; dim < 4; dim++) {
      filterdimA4[dim] = inputs[14].size(axis[dim]);
    }
  }
  // output dim in n,c,h,w used by backend
  int64_t outdimA1[] = {0, 0, 0, 0};  // Computed Below
  int64_t outdimA2[] = {0, 0, 0, 0};  // Computed Below
  int64_t outdimA3[] = {0, 0, 0, 0};  // Computed Below
  // use these fixed value for test run
  int64_t padA[] = {0, 0};
  int64_t padA1[] = {1, 1};
  int64_t dilationA[] = {1, 1};
  int64_t convstrideA[] = {1, 1};
  int64_t convstride1X1[] = {stride_1X1, stride_1X1};
  // compute output from pad/stride/dilation
  outdimA1[0] = dimA[0];
  outdimA1[1] = filterdimA1[0];
  for (int dim = 0; dim < 2; dim++) {
    outdimA1[dim + 2] =
        getFwdConvOutputDim(dimA[dim + 2], padA[dim], filterdimA1[dim + 2], convstride1X1[dim], dilationA[dim]);
  }
  outdimA2[0] = outdimA1[0];
  outdimA2[1] = filterdimA2[0];
  for (int dim = 0; dim < 2; dim++) {
    outdimA2[dim + 2] =
        getFwdConvOutputDim(outdimA1[dim + 2], padA1[dim], filterdimA2[dim + 2], convstrideA[dim], dilationA[dim]);
  }
  outdimA3[0] = outdimA2[0];
  outdimA3[1] = filterdimA3[0];
  for (int dim = 0; dim < 2; dim++) {
    outdimA3[dim + 2] =
        getFwdConvOutputDim(outdimA2[dim + 2], padA[dim], filterdimA3[dim + 2], convstrideA[dim], dilationA[dim]);
  }
  // Create output tensor in the correct shape in pytorch's view.
  // axis is reused here to map logical n,c,h,w back to PyTorch dimension order.
  int64_t outdim1[] = {0, 0, 0, 0};
  int64_t outdim2[] = {0, 0, 0, 0};
  int64_t outdim3[] = {0, 0, 0, 0};
  if (explicit_nhwc) {
    axis[0] = 0;
    axis[1] = 2;
    axis[2] = 3;
    axis[3] = 1;
  }
  for (int dim = 0; dim < 4; dim++) {
    outdim1[dim] = outdimA1[axis[dim]];
    outdim2[dim] = outdimA2[axis[dim]];
    outdim3[dim] = outdimA3[axis[dim]];
  }
  // dconv3+drelu2+dscale2
  at::Half* conv_in = inputs[13].data_ptr<at::Half>();
  at::Half* dy3 = inputs[10].data_ptr<at::Half>();
  DEBUG_MSG("[DEBUG] new dconv3 : " << inputs[10].to(at::kFloat).sum().item<float>());
  // wgrad
  auto wgrad3 = at::empty_like(inputs[3]);
  at::Half* dw3 = wgrad3.data_ptr<at::Half>();
  run_dconv(outdimA2, padA, convstrideA, dilationA, filterdimA3, outdimA3, CUDNN_DATA_HALF, conv_in, dw3, dy3,
            CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);
  // dgrad
  auto grad_out2 = at::empty(outdim2, output_options, output_format);
  at::Half* dy2 = grad_out2.data_ptr<at::Half>();
  at::Half* w = inputs[3].data_ptr<at::Half>();
  at::Half* z = inputs[5].data_ptr<at::Half>();
  at::Half* relu2 = inputs[13].data_ptr<at::Half>();
  run_dconv_drelu_dscale(outdimA2, padA, convstrideA, dilationA, filterdimA3, outdimA3, CUDNN_DATA_HALF, dy2, w, dy3, z,
                         relu2);
  DEBUG_MSG("[DEBUG] new dconv2 : " << grad_out2.to(at::kFloat).sum().item<float>());
  // dconv2+drelu1+dscale1
  conv_in = inputs[12].data_ptr<at::Half>();
  // wgrad
  auto wgrad2 = at::empty_like(inputs[2]);
  at::Half* dw2 = wgrad2.data_ptr<at::Half>();
  run_dconv(outdimA1, padA1, convstrideA, dilationA, filterdimA2, outdimA2, CUDNN_DATA_HALF, conv_in, dw2, dy2,
            CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);
  // dgrad
  auto grad_out1 = at::empty(outdim1, output_options, output_format);
  at::Half* dy1 = grad_out1.data_ptr<at::Half>();
  w = inputs[2].data_ptr<at::Half>();
  z = inputs[4].data_ptr<at::Half>();
  at::Half* relu1 = inputs[12].data_ptr<at::Half>();
  // conv2 always runs with unit stride (convstrideA), so its dgrad is always fused
  // with drelu1 + dscale1. An earlier, removed variant instead ran an unfused
  // dgrad followed by a multiply with a precomputed mask (inputs[15]) when
  // stride_1X1 != 1.
  run_dconv_drelu_dscale(outdimA1, padA1, convstrideA, dilationA, filterdimA2, outdimA2, CUDNN_DATA_HALF, dy1, w, dy2,
                         z, relu1);
  DEBUG_MSG("[DEBUG] new dconv1 : " << grad_out1.to(at::kFloat).sum().item<float>());
  // Gradient w.r.t. x coming through the residual branch. grad_x_conv4 is only
  // materialized when the downsample dgrad actually runs.
  at::Tensor grad_x_conv4;
  at::Half* dx_conv4 = nullptr;
  at::Tensor wgrad4;
  // x used for dconv1 and dconv4 wgrad
  at::Half* x = inputs[0].data_ptr<at::Half>();
  if (has_downsample) {
    w = inputs[14].data_ptr<at::Half>();
    at::Half* dy_conv4 = inputs[11].data_ptr<at::Half>();
    if (requires_grad) {
      grad_x_conv4 = at::empty_like(inputs[0]);
      dx_conv4 = grad_x_conv4.data_ptr<at::Half>();
      run_dconv(dimA, padA, convstride1X1, dilationA, filterdimA4, outdimA3, CUDNN_DATA_HALF, dx_conv4, w, dy_conv4,
                CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR);
      // we don't print here since we can't hook out this grad in pytorch alone to compare, due to addition with dx
      // DEBUG_MSG("[DEBUG] new dx_identity : " << grad_x_conv4.to(at::kFloat).sum().item<float>());
    }
    // wgrad
    wgrad4 = at::empty_like(inputs[14]);
    at::Half* dw4 = wgrad4.data_ptr<at::Half>();
    run_dconv(dimA, padA, convstride1X1, dilationA, filterdimA4, outdimA3, CUDNN_DATA_HALF, x, dw4, dy_conv4,
              CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);
  } else {
    // if there is no downsample, dx_conv4 is fork of drelu3
    dx_conv4 = inputs[11].data_ptr<at::Half>();
  }
  // dconv1+add
  // wgrad
  auto wgrad1 = at::empty_like(inputs[1]);
  at::Half* dw1 = wgrad1.data_ptr<at::Half>();
  run_dconv(dimA, padA, convstride1X1, dilationA, filterdimA1, outdimA1, CUDNN_DATA_HALF, x, dw1, dy1,
            CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);
  // dgrad
  w = inputs[1].data_ptr<at::Half>();
  auto grad_x = at::empty_like(inputs[0]);
  at::Half* dx = grad_x.data_ptr<at::Half>();
  // backward strided conv cannot be fused
  // if stride == 1 but channel changes, we can fuse here
  if (requires_grad) {
    if (stride_1X1 != 1) {
      run_dconv(dimA, padA, convstride1X1, dilationA, filterdimA1, outdimA1, CUDNN_DATA_HALF, dx, w, dy1,
                CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR);
      // add 2 together. stride_1X1 != 1 implies has_downsample, so grad_x_conv4
      // was allocated and filled above.
      grad_x.add_(grad_x_conv4);
    } else {
      run_dconv_add(dimA, padA, convstride1X1, dilationA, filterdimA1, outdimA1, CUDNN_DATA_HALF, dx, w, dy1, dx_conv4);
    }
  }
  DEBUG_MSG("[DEBUG] new dx : " << grad_x.to(at::kFloat).sum().item<float>());
  DEBUG_MSG("[DEBUG] new wgrad1 : " << wgrad1.to(at::kFloat).sum().item<float>());
  DEBUG_MSG("[DEBUG] new wgrad2 : " << wgrad2.to(at::kFloat).sum().item<float>());
  DEBUG_MSG("[DEBUG] new wgrad3 : " << wgrad3.to(at::kFloat).sum().item<float>());
  outputs.push_back(grad_x);
  outputs.push_back(wgrad1);
  outputs.push_back(wgrad2);
  outputs.push_back(wgrad3);
  if (has_downsample) {
    DEBUG_MSG("[DEBUG] new wgrad4 : " << wgrad4.to(at::kFloat).sum().item<float>());
    outputs.push_back(wgrad4);
  }
  return outputs;
}
namespace {
enum {
X_TENSOR,
Y_TENSOR,
W_TENSOR,
Z_TENSOR,
B_TENSOR,
AFTERADD_TENSOR,
AFTERBIAS_TENSOR,
AFTERCONV_TENSOR,
OPTIONAL,
AFTEROPT_TENSOR,
AFTERACT_TENSOR,
GEN_INDEX_TENSOR,
MASK_TOP_TENSOR,
MASK_BOTTOM_TENSOR,
MASK_TENSOR,
THRESHOLD_TOP_TENSOR,
THRESHOLD_BOTTOM_TENSOR,
};
// Seventeen tensor descriptors for the masked forward graph (conv -> scale -> bias -> [add] -> relu,
// followed by a GEN_INDEX / compare / OR / BINARY_SELECT mask). Index with the enum above
// (X_TENSOR .. THRESHOLD_BOTTOM_TENSOR).
using masked_convbias_descriptors =
    std::tuple<cudnn_frontend::Tensor, cudnn_frontend::Tensor, cudnn_frontend::Tensor, cudnn_frontend::Tensor,
               cudnn_frontend::Tensor, cudnn_frontend::Tensor, cudnn_frontend::Tensor, cudnn_frontend::Tensor,
               cudnn_frontend::Tensor, cudnn_frontend::Tensor, cudnn_frontend::Tensor, cudnn_frontend::Tensor,
               cudnn_frontend::Tensor, cudnn_frontend::Tensor, cudnn_frontend::Tensor, cudnn_frontend::Tensor,
               cudnn_frontend::Tensor>;
// Builds the 17 tensor descriptors used by run_conv_scale_bias_add_activation_mask, in the order given by
// the X_TENSOR .. THRESHOLD_BOTTOM_TENSOR enum.
//
//   x_dim_padded / w_dim_padded / y_dim_padded : 4-D shapes of input, filter and output.
//   threshold_dim                              : 4-D shape of the INT32 threshold tensors 't' and 'u'.
//   dataType                                   : data type of every non-virtual, non-index tensor.
//   padA / convstrideA / dilationA             : not used here; kept so the signature matches the
//                                                other descriptor factories.
//
// All strides are generated for NHWC layout, and every descriptor uses 16-byte alignment.
// Virtual tensors are graph intermediates: their UIDs never appear in a variant pack.
masked_convbias_descriptors create_conv_bias_add_act_mask_descriptors(int64_t* x_dim_padded, int64_t* padA,
                                                                      int64_t* convstrideA, int64_t* dilationA,
                                                                      int64_t* w_dim_padded, int64_t* y_dim_padded,
                                                                      int64_t* threshold_dim,
                                                                      cudnnDataType_t dataType) {
  // Shape (1, C_out, 1, 1), shared by the scale ('z') and bias ('b') operands.
  int64_t b_dim_padded[4] = {1, y_dim_padded[1], 1, 1};

  int64_t x_stride_padded[4];
  int64_t y_stride_padded[4];
  int64_t w_stride_padded[4];
  int64_t b_stride_padded[4];
  int64_t threshold_stride[4];
  generateStrides(w_dim_padded, w_stride_padded, 4, CUDNN_TENSOR_NHWC);
  generateStrides(x_dim_padded, x_stride_padded, 4, CUDNN_TENSOR_NHWC);
  generateStrides(y_dim_padded, y_stride_padded, 4, CUDNN_TENSOR_NHWC);
  generateStrides(b_dim_padded, b_stride_padded, 4, CUDNN_TENSOR_NHWC);
  generateStrides(threshold_dim, threshold_stride, 4, CUDNN_TENSOR_NHWC);

  constexpr bool kVirtual = true;
  constexpr bool kBacked = false;
  // One 4-D, 16-byte-aligned descriptor; isVirtual marks graph-internal intermediates.
  auto makeTensor = [](int64_t* dim, int64_t* stride, int64_t uid, cudnnDataType_t type, bool isVirtual) {
    cudnn_frontend::TensorBuilder builder;
    builder.setDim(4, dim).setStrides(4, stride).setId(uid).setAlignment(16).setDataType(type);
    if (isVirtual) {
      builder.setVirtual();
    }
    return builder.build();
  };

  return masked_convbias_descriptors(
      makeTensor(x_dim_padded, x_stride_padded, 'x', dataType, kBacked),                  // X_TENSOR
      makeTensor(y_dim_padded, y_stride_padded, 'y', dataType, kBacked),                  // Y_TENSOR
      makeTensor(w_dim_padded, w_stride_padded, 'w', dataType, kBacked),                  // W_TENSOR
      makeTensor(b_dim_padded, b_stride_padded, 'z', dataType, kBacked),                  // Z_TENSOR
      makeTensor(b_dim_padded, b_stride_padded, 'b', dataType, kBacked),                  // B_TENSOR
      makeTensor(y_dim_padded, y_stride_padded, 'A', CUDNN_DATA_FLOAT, kVirtual),         // AFTERADD_TENSOR
      makeTensor(y_dim_padded, y_stride_padded, 'B', CUDNN_DATA_FLOAT, kVirtual),         // AFTERBIAS_TENSOR
      makeTensor(y_dim_padded, y_stride_padded, 'C', CUDNN_DATA_FLOAT, kVirtual),         // AFTERCONV_TENSOR
      makeTensor(y_dim_padded, y_stride_padded, 'i', dataType, kBacked),                  // OPTIONAL
      makeTensor(y_dim_padded, y_stride_padded, 'D', CUDNN_DATA_FLOAT, kVirtual),         // AFTEROPT_TENSOR
      makeTensor(y_dim_padded, y_stride_padded, 'E', CUDNN_DATA_FLOAT, kVirtual),         // AFTERACT_TENSOR
      makeTensor(y_dim_padded, y_stride_padded, 'I', CUDNN_DATA_INT32, kVirtual),         // GEN_INDEX_TENSOR
      makeTensor(y_dim_padded, y_stride_padded, 'm', CUDNN_DATA_BOOLEAN, kVirtual),       // MASK_TOP_TENSOR
      makeTensor(y_dim_padded, y_stride_padded, 'n', CUDNN_DATA_BOOLEAN, kVirtual),       // MASK_BOTTOM_TENSOR
      makeTensor(y_dim_padded, y_stride_padded, 'M', CUDNN_DATA_BOOLEAN, kVirtual),       // MASK_TENSOR
      makeTensor(threshold_dim, threshold_stride, 't', CUDNN_DATA_INT32, kBacked),        // THRESHOLD_TOP_TENSOR
      makeTensor(threshold_dim, threshold_stride, 'u', CUDNN_DATA_INT32, kBacked));       // THRESHOLD_BOTTOM_TENSOR
}
// tensor descriptors used for dgrad
// Positions of the tensors inside dconv_add_descriptors (first nine entries) and dconv_mask_descriptors
// (all fifteen); use with std::get<...>(tensors). The character in each comment is the cuDNN UID
// assigned by create_dconv_add_descriptors / create_dconv_mask_descriptors.
enum {
  X_OR_DX_TENSOR,                 // 'x' dx buffer; output of the scale node in run_dconv_add_drelu_dscale
  DY_TENSOR,                      // 'y' incoming gradient dy
  W_OR_DW_TENSOR,                 // 'w' filter
  SCALE_TENSOR,                   // 's' shape (1, C_x, 1, 1) multiplier of the scale node
  RELU_TENSOR,                    // 'r' tensor passed as x to the RELU_BWD node
  AFTER_DCONV_TENSOR,             // 'A' virtual output of the backward-data convolution
  AFTER_DRELU_TENSOR,             // 'B' virtual output of RELU_BWD
  DGRAD_INPUT_TENSOR,             // 'i' addend summed with the dconv output
  DGRAD_OPTIONAL_TENSOR,          // 'D' virtual result of that add
  DGRAD_GEN_INDEX_TENSOR,         // 'I' virtual INT32 index tensor (mask variant only)
  DGRAD_MASK_TOP_TENSOR,          // 'm' virtual boolean top mask (mask variant only)
  DGRAD_MASK_BOTTOM_TENSOR,       // 'n' virtual boolean bottom mask (mask variant only)
  DGRAD_MASK_TENSOR,              // 'M' virtual combined mask (mask variant only)
  DGRAD_THRESHOLD_TOP_TENSOR,     // 't' INT32 top threshold (mask variant only)
  DGRAD_THRESHOLD_BOTTOM_TENSOR,  // 'u' INT32 bottom threshold (mask variant only)
};
// Nine tensor descriptors for the dgrad + add + drelu + dscale graph; index with
// X_OR_DX_TENSOR .. DGRAD_OPTIONAL_TENSOR.
using dconv_add_descriptors = std::tuple<cudnn_frontend::Tensor, cudnn_frontend::Tensor, cudnn_frontend::Tensor,
                                         cudnn_frontend::Tensor, cudnn_frontend::Tensor, cudnn_frontend::Tensor,
                                         cudnn_frontend::Tensor, cudnn_frontend::Tensor, cudnn_frontend::Tensor>;
// Builds the nine tensor descriptors (X_OR_DX_TENSOR .. DGRAD_OPTIONAL_TENSOR) for the fused dgrad graph
//
//     A  = dconv(dy, w)          // backward-data convolution, dx-shaped
//     D  = A + i
//     B  = relu_bwd(dy = D, x = r)
//     dx = B * s
//
// as assembled by run_dconv_add_drelu_dscale.
//
//   x_dim_padded / w_dim_padded / y_dim_padded : 4-D shapes of dx, filter and dy.
//   dataType                                   : data type of every non-virtual tensor.
//   padA / convstrideA / dilationA             : not used here; kept for signature parity with the
//                                                other descriptor factories.
//
// The addend 'i' and the virtual sum 'D' are combined with the dconv output 'A', and 'D' then feeds RELU_BWD
// next to 'r' and 'B'. All of those are dx-shaped, so 'i' and 'D' are described with x dims/strides.
// NOTE(review): earlier code used y dims here, which is only equivalent when x and y dims coincide.
dconv_add_descriptors create_dconv_add_descriptors(int64_t* x_dim_padded, int64_t* padA, int64_t* convstrideA,
                                                   int64_t* dilationA, int64_t* w_dim_padded, int64_t* y_dim_padded,
                                                   cudnnDataType_t dataType) {
  // Shape (1, C_x, 1, 1) of the scale operand 's'.
  int64_t s_dim_padded[4] = {1, x_dim_padded[1], 1, 1};

  int64_t x_stride_padded[4];
  int64_t y_stride_padded[4];
  int64_t w_stride_padded[4];
  int64_t s_stride_padded[4];
  generateStrides(w_dim_padded, w_stride_padded, 4, CUDNN_TENSOR_NHWC);
  generateStrides(x_dim_padded, x_stride_padded, 4, CUDNN_TENSOR_NHWC);
  generateStrides(y_dim_padded, y_stride_padded, 4, CUDNN_TENSOR_NHWC);
  generateStrides(s_dim_padded, s_stride_padded, 4, CUDNN_TENSOR_NHWC);

  constexpr bool kVirtual = true;
  constexpr bool kBacked = false;
  // One 4-D, 16-byte-aligned descriptor; isVirtual marks graph-internal intermediates.
  auto makeTensor = [](int64_t* dim, int64_t* stride, int64_t uid, cudnnDataType_t type, bool isVirtual) {
    cudnn_frontend::TensorBuilder builder;
    builder.setDim(4, dim).setStrides(4, stride).setId(uid).setAlignment(16).setDataType(type);
    if (isVirtual) {
      builder.setVirtual();
    }
    return builder.build();
  };

  return dconv_add_descriptors(
      makeTensor(x_dim_padded, x_stride_padded, 'x', dataType, kBacked),           // X_OR_DX_TENSOR
      makeTensor(y_dim_padded, y_stride_padded, 'y', dataType, kBacked),           // DY_TENSOR
      makeTensor(w_dim_padded, w_stride_padded, 'w', dataType, kBacked),           // W_OR_DW_TENSOR
      makeTensor(s_dim_padded, s_stride_padded, 's', dataType, kBacked),           // SCALE_TENSOR
      makeTensor(x_dim_padded, x_stride_padded, 'r', dataType, kBacked),           // RELU_TENSOR
      makeTensor(x_dim_padded, x_stride_padded, 'A', CUDNN_DATA_FLOAT, kVirtual),  // AFTER_DCONV_TENSOR
      makeTensor(x_dim_padded, x_stride_padded, 'B', CUDNN_DATA_FLOAT, kVirtual),  // AFTER_DRELU_TENSOR
      makeTensor(x_dim_padded, x_stride_padded, 'i', dataType, kBacked),           // DGRAD_INPUT_TENSOR
      makeTensor(x_dim_padded, x_stride_padded, 'D', CUDNN_DATA_FLOAT, kVirtual)); // DGRAD_OPTIONAL_TENSOR
}
// Fifteen tensor descriptors for the masked dgrad graph: the nine dconv_add entries plus the
// GEN_INDEX / mask / threshold tensors. Index with X_OR_DX_TENSOR .. DGRAD_THRESHOLD_BOTTOM_TENSOR.
using dconv_mask_descriptors =
    std::tuple<cudnn_frontend::Tensor, cudnn_frontend::Tensor, cudnn_frontend::Tensor, cudnn_frontend::Tensor,
               cudnn_frontend::Tensor, cudnn_frontend::Tensor, cudnn_frontend::Tensor, cudnn_frontend::Tensor,
               cudnn_frontend::Tensor, cudnn_frontend::Tensor, cudnn_frontend::Tensor, cudnn_frontend::Tensor,
               cudnn_frontend::Tensor, cudnn_frontend::Tensor, cudnn_frontend::Tensor>;
// Builds the fifteen tensor descriptors (X_OR_DX_TENSOR .. DGRAD_THRESHOLD_BOTTOM_TENSOR) for the masked
// dgrad graph.
//
//   x_dim_padded / w_dim_padded / y_dim_padded : 4-D shapes of dx, filter and dy.
//   threshold_dim                              : 4-D shape of the INT32 threshold tensors 't' and 'u'.
//   dataType                                   : data type of every non-virtual, non-index tensor.
//   padA / convstrideA / dilationA             : not used here; kept for signature parity with the
//                                                other descriptor factories.
//
// NOTE(review): 'i', 'D', 'I', 'm', 'n' and 'M' are described with y (dy) dims, while the dconv output 'A'
// and the drelu output 'B' are x-shaped. Confirm against the consuming graph that x and y dims always
// coincide for this path.
dconv_mask_descriptors create_dconv_mask_descriptors(int64_t* x_dim_padded, int64_t* padA, int64_t* convstrideA,
                                                     int64_t* dilationA, int64_t* w_dim_padded, int64_t* y_dim_padded,
                                                     int64_t* threshold_dim, cudnnDataType_t dataType) {
  // Shape (1, C_x, 1, 1) of the scale operand 's'.
  int64_t s_dim_padded[4] = {1, x_dim_padded[1], 1, 1};

  int64_t x_stride_padded[4];
  int64_t y_stride_padded[4];
  int64_t w_stride_padded[4];
  int64_t s_stride_padded[4];
  int64_t threshold_stride[4];
  generateStrides(w_dim_padded, w_stride_padded, 4, CUDNN_TENSOR_NHWC);
  generateStrides(x_dim_padded, x_stride_padded, 4, CUDNN_TENSOR_NHWC);
  generateStrides(y_dim_padded, y_stride_padded, 4, CUDNN_TENSOR_NHWC);
  generateStrides(s_dim_padded, s_stride_padded, 4, CUDNN_TENSOR_NHWC);
  generateStrides(threshold_dim, threshold_stride, 4, CUDNN_TENSOR_NHWC);

  constexpr bool kVirtual = true;
  constexpr bool kBacked = false;
  // One 4-D, 16-byte-aligned descriptor; isVirtual marks graph-internal intermediates.
  auto makeTensor = [](int64_t* dim, int64_t* stride, int64_t uid, cudnnDataType_t type, bool isVirtual) {
    cudnn_frontend::TensorBuilder builder;
    builder.setDim(4, dim).setStrides(4, stride).setId(uid).setAlignment(16).setDataType(type);
    if (isVirtual) {
      builder.setVirtual();
    }
    return builder.build();
  };

  return dconv_mask_descriptors(
      makeTensor(x_dim_padded, x_stride_padded, 'x', dataType, kBacked),             // X_OR_DX_TENSOR
      makeTensor(y_dim_padded, y_stride_padded, 'y', dataType, kBacked),             // DY_TENSOR
      makeTensor(w_dim_padded, w_stride_padded, 'w', dataType, kBacked),             // W_OR_DW_TENSOR
      makeTensor(s_dim_padded, s_stride_padded, 's', dataType, kBacked),             // SCALE_TENSOR
      makeTensor(x_dim_padded, x_stride_padded, 'r', dataType, kBacked),             // RELU_TENSOR
      makeTensor(x_dim_padded, x_stride_padded, 'A', CUDNN_DATA_FLOAT, kVirtual),    // AFTER_DCONV_TENSOR
      makeTensor(x_dim_padded, x_stride_padded, 'B', CUDNN_DATA_FLOAT, kVirtual),    // AFTER_DRELU_TENSOR
      makeTensor(y_dim_padded, y_stride_padded, 'i', dataType, kBacked),             // DGRAD_INPUT_TENSOR
      makeTensor(y_dim_padded, y_stride_padded, 'D', CUDNN_DATA_FLOAT, kVirtual),    // DGRAD_OPTIONAL_TENSOR
      makeTensor(y_dim_padded, y_stride_padded, 'I', CUDNN_DATA_INT32, kVirtual),    // DGRAD_GEN_INDEX_TENSOR
      makeTensor(y_dim_padded, y_stride_padded, 'm', CUDNN_DATA_BOOLEAN, kVirtual),  // DGRAD_MASK_TOP_TENSOR
      makeTensor(y_dim_padded, y_stride_padded, 'n', CUDNN_DATA_BOOLEAN, kVirtual),  // DGRAD_MASK_BOTTOM_TENSOR
      makeTensor(y_dim_padded, y_stride_padded, 'M', CUDNN_DATA_BOOLEAN, kVirtual),  // DGRAD_MASK_TENSOR
      makeTensor(threshold_dim, threshold_stride, 't', CUDNN_DATA_INT32, kBacked),   // DGRAD_THRESHOLD_TOP_TENSOR
      makeTensor(threshold_dim, threshold_stride, 'u', CUDNN_DATA_INT32, kBacked));  // DGRAD_THRESHOLD_BOTTOM_TENSOR
}
// Runs the fused forward graph
//
//     y = relu((conv(x, w) + i) * z + b)
//
// as a single cached cuDNN execution plan on the current cuDNN handle.
//
//   x_dim_padded / w_dim_padded / y_dim_padded : 4-D shapes of input, filter and output.
//   pad / convstride / dilation                : 2-D convolution parameters.
//   devPtrX / devPtrW / devPtrY                : input, filter and output buffers (UIDs 'x', 'w', 'y').
//   devPtrZ / devPtrB                          : scale and bias operands (UIDs 'z', 'b').
//   devPtrI                                    : addend applied to the raw conv output before scaling
//                                                (UID 'i'). It is always part of the graph.
//                                                NOTE(review): there is no null check; callers
//                                                presumably always pass a valid buffer.
//
// cudnn_frontend exceptions are caught, printed to stdout together with the accumulated debug log,
// and not rethrown.
void run_conv_add_scale_bias_activation(int64_t* x_dim_padded, int64_t* pad, int64_t* convstride, int64_t* dilation,
                                        int64_t* w_dim_padded, int64_t* y_dim_padded, cudnnDataType_t dataType,
                                        at::Half* devPtrX, at::Half* devPtrW, at::Half* devPtrY, at::Half* devPtrZ,
                                        at::Half* devPtrB, at::Half* devPtrI) {
  cudnnHandle_t handle_ = torch::native::getCudnnHandle();
  std::stringstream log_buf;
  try {
    int convDim = 2;
    // Creates the necessary tensor descriptors
    common_convbias_descriptors tensors = create_conv_bias_add_act_descriptors(x_dim_padded, pad, convstride, dilation,
                                                                               w_dim_padded, y_dim_padded, dataType);
    DEBUG_CUDNN_MSG(log_buf, std::get<X_TENSOR>(tensors).describe());
    DEBUG_CUDNN_MSG(log_buf, std::get<Y_TENSOR>(tensors).describe());
    DEBUG_CUDNN_MSG(log_buf, std::get<W_TENSOR>(tensors).describe());
    DEBUG_CUDNN_MSG(log_buf, std::get<Z_TENSOR>(tensors).describe());
    DEBUG_CUDNN_MSG(log_buf, std::get<B_TENSOR>(tensors).describe());
    DEBUG_CUDNN_MSG(log_buf, std::get<AFTERADD_TENSOR>(tensors).describe());
    DEBUG_CUDNN_MSG(log_buf, std::get<AFTERBIAS_TENSOR>(tensors).describe());
    DEBUG_CUDNN_MSG(log_buf, std::get<AFTERCONV_TENSOR>(tensors).describe());
    DEBUG_CUDNN_MSG(log_buf, std::get<OPTIONAL>(tensors).describe());
    DEBUG_CUDNN_MSG(log_buf, std::get<AFTEROPT_TENSOR>(tensors).describe());
    // Define the scale operation (pointwise multiply by 'z')
    auto scaleDesc =
        cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_MUL).setMathPrecision(CUDNN_DATA_FLOAT).build();
    DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe());
    // Define the bias operation
    auto biasDesc =
        cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_ADD).setMathPrecision(CUDNN_DATA_FLOAT).build();
    DEBUG_CUDNN_MSG(log_buf, biasDesc.describe());
    // Define the add of 'i' onto the raw convolution output
    auto addDesc =
        cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_ADD).setMathPrecision(CUDNN_DATA_FLOAT).build();
    DEBUG_CUDNN_MSG(log_buf, addDesc.describe());
    // Define the activation operation
    auto actDesc = cudnn_frontend::PointWiseDescBuilder()
                       .setMode(CUDNN_POINTWISE_RELU_FWD)
                       .setMathPrecision(CUDNN_DATA_FLOAT)
                       .build();
    DEBUG_CUDNN_MSG(log_buf, actDesc.describe());
    // Define the convolution problem
    auto convDesc = cudnn_frontend::ConvDescBuilder()
                        .setDataType(CUDNN_DATA_FLOAT)
                        .setMathMode(CUDNN_CROSS_CORRELATION)
                        .setNDims(convDim)
                        .setStrides(convDim, convstride)
                        .setPrePadding(convDim, pad)
                        .setPostPadding(convDim, pad)
                        .setDilation(convDim, dilation)
                        .build();
    DEBUG_CUDNN_MSG(log_buf, convDesc.describe());
    float alpha = 1.0f;
    float beta = 0.0f;
    // Create a convolution Node
    auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR)
                       .setxDesc(std::get<X_TENSOR>(tensors))
                       .setwDesc(std::get<W_TENSOR>(tensors))
                       .setyDesc(std::get<AFTERCONV_TENSOR>(tensors))
                       .setcDesc(convDesc)
                       .setAlpha(alpha)
                       .setBeta(beta)
                       .build();
    DEBUG_CUDNN_MSG(log_buf, conv_op.describe());
    // Create the add node: conv output + 'i'
    auto add_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
                      .setxDesc(conv_op.getOutputTensor())
                      .setbDesc(std::get<OPTIONAL>(tensors))
                      .setyDesc(std::get<AFTEROPT_TENSOR>(tensors))
                      .setpwDesc(addDesc)
                      .build();
    DEBUG_CUDNN_MSG(log_buf, add_op.describe());
    // Create the scale node: (conv + i) * 'z'
    auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
                        .setxDesc(add_op.getOutputTensor())
                        .setbDesc(std::get<Z_TENSOR>(tensors))
                        .setyDesc(std::get<AFTERADD_TENSOR>(tensors))
                        .setpwDesc(scaleDesc)
                        .build();
    DEBUG_CUDNN_MSG(log_buf, scale_op.describe());
    // Create a Bias Node.
    auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
                       .setxDesc(scale_op.getOutputTensor())
                       .setbDesc(std::get<B_TENSOR>(tensors))
                       .setyDesc(std::get<AFTERBIAS_TENSOR>(tensors))
                       .setpwDesc(biasDesc)
                       .build();
    DEBUG_CUDNN_MSG(log_buf, bias_op.describe());
    // Create an Activation Node writing the final output 'y'.
    auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
                      .setxDesc(bias_op.getOutputTensor())
                      .setyDesc(std::get<Y_TENSOR>(tensors))
                      .setpwDesc(actDesc)
                      .build();
    DEBUG_CUDNN_MSG(log_buf, act_op.describe());
    // Create an Operation Graph: conv -> add -> scale -> bias -> relu
    std::array<cudnn_frontend::Operation const*, 5> ops = {&conv_op, &add_op, &scale_op, &bias_op, &act_op};
    auto opGraph =
        cudnn_frontend::OperationGraphBuilder().setHandle(handle_).setOperationGraph(ops.size(), ops.data()).build();
    // Create string encoding for plan caching
    auto cache_string =
        getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag());
    DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);
    auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
    DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());
    auto workspace_size = plan.getWorkspaceSize();
    DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);
    // Workspace is allocated as float elements, rounding the byte count up to a multiple of 4.
    void* workspace_ptr = nullptr;
    auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
    if (workspace_size > 0) {
      workspace_ptr = workspace_tensor.data_ptr<float>();
    }
    void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrB, devPtrI};
    int64_t uids[] = {'x', 'y', 'w', 'z', 'b', 'i'};
    auto variantPack = cudnn_frontend::VariantPackBuilder()
                           .setWorkspacePointer(workspace_ptr)
                           .setDataPointers(6, data_ptrs)
                           .setUids(6, uids)
                           .build();
    DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
    cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
    checkCudnnErr(status);
    cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
  } catch (const cudnn_frontend::cudnnException& e) {
    // Catch by reference: avoids copying/slicing the exception object.
    std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
  }
}
void run_conv_scale_bias_add_activation_mask(int64_t* x_dim_padded, int64_t* pad, int64_t* convstride,
int64_t* dilation, int64_t* w_dim_padded, int64_t* y_dim_padded,
int64_t* threshold_dim, cudnnDataType_t dataType, at::Half* devPtrX,
at::Half* devPtrW, at::Half* devPtrY, at::Half* devPtrZ, at::Half* devPtrB,
at::Half* devPtrI, int* devPtrT, int* devPtrU, int axis) {
cudnnHandle_t handle_ = torch::native::getCudnnHandle();
std::stringstream log_buf;
try {
int convDim = 2;
// Creates the necessary tensor descriptors
masked_convbias_descriptors tensors = create_conv_bias_add_act_mask_descriptors(
x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, threshold_dim, dataType);
DEBUG_CUDNN_MSG(log_buf, std::get<X_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<Y_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<W_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<Z_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<B_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTERADD_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTERBIAS_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTERCONV_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<OPTIONAL>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTERACT_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<GEN_INDEX_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<MASK_TOP_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<MASK_BOTTOM_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<MASK_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<THRESHOLD_TOP_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<THRESHOLD_BOTTOM_TENSOR>(tensors).describe());
// Define the add operation
auto scaleDesc =
cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_MUL).setMathPrecision(CUDNN_DATA_FLOAT).build();
DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe());
// Define the bias operation
auto biasDesc =
cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_ADD).setMathPrecision(CUDNN_DATA_FLOAT).build();
DEBUG_CUDNN_MSG(log_buf, biasDesc.describe());
// optional add
auto addDesc =
cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_ADD).setMathPrecision(CUDNN_DATA_FLOAT).build();
DEBUG_CUDNN_MSG(log_buf, addDesc.describe());
// Define the activation operation
auto actDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_RELU_FWD)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, actDesc.describe());
// Define the convolution problem
auto convDesc = cudnn_frontend::ConvDescBuilder()
.setDataType(CUDNN_DATA_FLOAT)
.setMathMode(CUDNN_CROSS_CORRELATION)
.setNDims(convDim)
.setStrides(convDim, convstride)
.setPrePadding(convDim, pad)
.setPostPadding(convDim, pad)
.setDilation(convDim, dilation)
.build();
DEBUG_CUDNN_MSG(log_buf, convDesc.describe());
// Define the genIndex descriptor
auto genIndexDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_GEN_INDEX)
.setMathPrecision(CUDNN_DATA_FLOAT)
.setAxis(axis)
.build();
DEBUG_CUDNN_MSG(log_buf, genIndexDesc.describe());
// Define the lessThan descriptor
auto lessThanDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_CMP_LT)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, lessThanDesc.describe());
// Define the greaterThan descriptor
auto greaterThanDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_CMP_GT)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, greaterThanDesc.describe());
// Define the logical_or descriptor
auto logicalOrDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_LOGICAL_OR)
.setMathPrecision(CUDNN_DATA_BOOLEAN)
.build();
DEBUG_CUDNN_MSG(log_buf, logicalOrDesc.describe());
// Define the binary_selection descriptor
auto selectionDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_BINARY_SELECT)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, selectionDesc.describe());
float alpha = 1.0f;
float beta = 0.0f;
// Create a convolution Node
auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR)
.setxDesc(std::get<X_TENSOR>(tensors))
.setwDesc(std::get<W_TENSOR>(tensors))
.setyDesc(std::get<AFTERCONV_TENSOR>(tensors))
.setcDesc(convDesc)
.setAlpha(alpha)
.setBeta(beta)
.build();
DEBUG_CUDNN_MSG(log_buf, conv_op.describe());
// Create a Add Node with scaling parameters.
auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(conv_op.getOutputTensor())
.setbDesc(std::get<Z_TENSOR>(tensors))
.setyDesc(std::get<AFTERADD_TENSOR>(tensors))
.setpwDesc(scaleDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, scale_op.describe());
// Create a Bias Node.
auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(scale_op.getOutputTensor())
.setbDesc(std::get<B_TENSOR>(tensors))
.setyDesc(std::get<AFTERBIAS_TENSOR>(tensors))
.setpwDesc(biasDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, bias_op.describe());
// Create a optional add Node.
auto add_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(bias_op.getOutputTensor())
.setbDesc(std::get<OPTIONAL>(tensors))
.setyDesc(std::get<AFTEROPT_TENSOR>(tensors))
.setpwDesc(addDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, add_op.describe());
// Create an Activation Node.
auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(devPtrI ? add_op.getOutputTensor() : bias_op.getOutputTensor())
.setyDesc(std::get<AFTERACT_TENSOR>(tensors))
.setpwDesc(actDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, act_op.describe());
// Create a Gen_Index Node.
auto genIndex_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(std::get<AFTERACT_TENSOR>(tensors))
.setyDesc(std::get<GEN_INDEX_TENSOR>(tensors))
.setpwDesc(genIndexDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, genIndex_op.describe());
// Create a LessThan Node.
auto lessThan_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(std::get<GEN_INDEX_TENSOR>(tensors))
.setbDesc(std::get<THRESHOLD_TOP_TENSOR>(tensors))
.setyDesc(std::get<MASK_TOP_TENSOR>(tensors))
.setpwDesc(lessThanDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, lessThan_op.describe());
// Create a GreaterThan Node.
auto greaterThan_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(std::get<GEN_INDEX_TENSOR>(tensors))
.setbDesc(std::get<THRESHOLD_BOTTOM_TENSOR>(tensors))
.setyDesc(std::get<MASK_BOTTOM_TENSOR>(tensors))
.setpwDesc(greaterThanDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, greaterThan_op.describe());
// Create a LogicalOr Node.
auto logicalOr_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(std::get<MASK_TOP_TENSOR>(tensors))
.setbDesc(std::get<MASK_BOTTOM_TENSOR>(tensors))
.setyDesc(std::get<MASK_TENSOR>(tensors))
.setpwDesc(logicalOrDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, logicalOr_op.describe());
// Create a Binary_Selection Node.
auto selection_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(std::get<AFTERCONV_TENSOR>(tensors))
.setbDesc(std::get<AFTERACT_TENSOR>(tensors))
.settDesc(std::get<MASK_TENSOR>(tensors))
.setyDesc(std::get<Y_TENSOR>(tensors))
.setpwDesc(selectionDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, selection_op.describe());
// Create an Operation Graph. In this case it is convolution add bias activation
if (devPtrI) {
std::array<cudnn_frontend::Operation const*, 10> ops = {
&conv_op, &scale_op, &bias_op, &add_op, &act_op,
&genIndex_op, &lessThan_op, &greaterThan_op, &logicalOr_op, &selection_op};
auto opGraph =
cudnn_frontend::OperationGraphBuilder().setHandle(handle_).setOperationGraph(ops.size(), ops.data()).build();
// Create string encoding for plan caching
auto cache_string =
getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag());
DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);
auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());
auto workspace_size = plan.getWorkspaceSize();
DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);
void* workspace_ptr = nullptr;
auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
if (workspace_size > 0) {
workspace_ptr = workspace_tensor.data_ptr<float>();
}
void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrB, devPtrI, devPtrT, devPtrU};
int64_t uids[] = {'x', 'y', 'w', 'z', 'b', 'i', 't', 'u'};
auto variantPack = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace_ptr)
.setDataPointers(8, data_ptrs)
.setUids(8, uids)
.build();
DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
checkCudnnErr(status);
cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
} else {
std::array<cudnn_frontend::Operation const*, 9> ops = {&conv_op, &scale_op, &bias_op,
&act_op, &genIndex_op, &lessThan_op,
&greaterThan_op, &logicalOr_op, &selection_op};
auto opGraph =
cudnn_frontend::OperationGraphBuilder().setHandle(handle_).setOperationGraph(ops.size(), ops.data()).build();
// Create string encoding for plan caching
auto cache_string =
getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag());
DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);
auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());
auto workspace_size = plan.getWorkspaceSize();
DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);
void* workspace_ptr = nullptr;
auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
if (workspace_size > 0) {
workspace_ptr = workspace_tensor.data_ptr<float>();
}
void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrB, devPtrT, devPtrU};
int64_t uids[] = {'x', 'y', 'w', 'z', 'b', 't', 'u'};
auto variantPack = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace_ptr)
.setDataPointers(7, data_ptrs)
.setUids(7, uids)
.build();
DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
checkCudnnErr(status);
cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
}
} catch (cudnn_frontend::cudnnException e) {
std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
}
}
void run_dconv_add_drelu_dscale(int64_t* x_dim_padded, int64_t* pad, int64_t* convstride, int64_t* dilation,
int64_t* w_dim_padded, int64_t* y_dim_padded, cudnnDataType_t dataType,
at::Half* devPtrX, at::Half* devPtrW, at::Half* devPtrY, at::Half* devPtrZ,
at::Half* devPtrR, at::Half* devPtrI) {
cudnnHandle_t handle_ = torch::native::getCudnnHandle();
std::stringstream log_buf;
try {
int convDim = 2;
// Creates the necessary tensor descriptors
dconv_add_descriptors tensors =
create_dconv_add_descriptors(x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, dataType);
DEBUG_CUDNN_MSG(log_buf, std::get<X_OR_DX_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<DY_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<W_OR_DW_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<SCALE_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<RELU_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTER_DCONV_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTER_DRELU_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<DGRAD_INPUT_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<DGRAD_OPTIONAL_TENSOR>(tensors).describe());
// Define the convolution problem
auto convDesc = cudnn_frontend::ConvDescBuilder()
.setDataType(CUDNN_DATA_FLOAT)
.setMathMode(CUDNN_CROSS_CORRELATION)
.setNDims(convDim)
.setStrides(convDim, convstride)
.setPrePadding(convDim, pad)
.setPostPadding(convDim, pad)
.setDilation(convDim, dilation)
.build();
DEBUG_CUDNN_MSG(log_buf, convDesc.describe());
// optional add
auto addDesc =
cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_ADD).setMathPrecision(CUDNN_DATA_FLOAT).build();
DEBUG_CUDNN_MSG(log_buf, addDesc.describe());
// Define the activation backward operation
auto actDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_RELU_BWD)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, actDesc.describe());
// Define the scale backward operation
auto scaleDesc =
cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_MUL).setMathPrecision(CUDNN_DATA_FLOAT).build();
DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe());
float alpha = 1.0f;
float beta = 0.0f;
// Create a convolution Node
auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR)
.setdxDesc(std::get<AFTER_DCONV_TENSOR>(tensors))
.setwDesc(std::get<W_OR_DW_TENSOR>(tensors))
.setdyDesc(std::get<DY_TENSOR>(tensors))
.setcDesc(convDesc)
.setAlpha(alpha)
.setBeta(beta)
.build();
DEBUG_CUDNN_MSG(log_buf, conv_op.describe());
// Create add Node.
auto add_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(std::get<AFTER_DCONV_TENSOR>(tensors))
.setbDesc(std::get<DGRAD_INPUT_TENSOR>(tensors))
.setyDesc(std::get<DGRAD_OPTIONAL_TENSOR>(tensors))
.setpwDesc(addDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, add_op.describe());
// TODO: do we need getOutputTensor(), and what it returns in backward case?
// Create an relu backward Node.
auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setdyDesc(std::get<DGRAD_OPTIONAL_TENSOR>(tensors))
.setxDesc(std::get<RELU_TENSOR>(tensors))
.setdxDesc(std::get<AFTER_DRELU_TENSOR>(tensors))
.setpwDesc(actDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, act_op.describe());
// Create a Scale Node.
auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(std::get<AFTER_DRELU_TENSOR>(tensors))
.setbDesc(std::get<SCALE_TENSOR>(tensors))
.setyDesc(std::get<X_OR_DX_TENSOR>(tensors))
.setpwDesc(scaleDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, scale_op.describe());
// Create an Operation Graph. In this case it is convolution add bias activation
std::array<cudnn_frontend::Operation const*, 4> ops = {&conv_op, &add_op, &act_op, &scale_op};
auto opGraph =
cudnn_frontend::OperationGraphBuilder().setHandle(handle_).setOperationGraph(ops.size(), ops.data()).build();
// Create string encoding for plan caching
auto cache_string =
getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag());
DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);
auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());
auto workspace_size = plan.getWorkspaceSize();
DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);
void* workspace_ptr = nullptr;
auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
if (workspace_size > 0) {
workspace_ptr = workspace_tensor.data_ptr<float>();
}
void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrR, devPtrI};
int64_t uids[] = {'x', 'y', 'w', 's', 'r', 'i'};
auto variantPack = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace_ptr)
.setDataPointers(6, data_ptrs)
.setUids(6, uids)
.build();
DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
checkCudnnErr(status);
cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
} catch (cudnn_frontend::cudnnException e) {
std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
}
}
void run_dconv_drelu_dscale_mask(int64_t* x_dim_padded, int64_t* pad, int64_t* convstride, int64_t* dilation,
int64_t* w_dim_padded, int64_t* y_dim_padded, int64_t* threshold_dim,
cudnnDataType_t dataType, at::Half* devPtrX, at::Half* devPtrW, at::Half* devPtrY,
at::Half* devPtrZ, at::Half* devPtrR, int* devPtrT, int* devPtrU, int axis) {
cudnnHandle_t handle_ = torch::native::getCudnnHandle();
std::stringstream log_buf;
try {
int convDim = 2;
// Creates the necessary tensor descriptors
dconv_mask_descriptors tensors = create_dconv_mask_descriptors(x_dim_padded, pad, convstride, dilation,
w_dim_padded, y_dim_padded, threshold_dim, dataType);
DEBUG_CUDNN_MSG(log_buf, std::get<X_OR_DX_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<DY_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<W_OR_DW_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<SCALE_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<RELU_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTER_DCONV_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTER_DRELU_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<DGRAD_OPTIONAL_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<DGRAD_GEN_INDEX_TENSOR>(
gitextract_8yaiblk9/
├── .clang-format
├── .git-blame-ignore-revs
├── .github/
│ └── ISSUE_TEMPLATE/
│ └── bug_report.md
├── .gitignore
├── .gitmodules
├── .nojekyll
├── .pre-commit-config.yaml
├── LICENSE
├── README.md
├── apex/
│ ├── __init__.py
│ ├── _autocast_utils.py
│ ├── contrib/
│ │ ├── __init__.py
│ │ ├── bottleneck/
│ │ │ ├── __init__.py
│ │ │ ├── bottleneck.py
│ │ │ ├── halo_exchangers.py
│ │ │ └── test.py
│ │ ├── clip_grad/
│ │ │ ├── __init__.py
│ │ │ └── clip_grad.py
│ │ ├── conv_bias_relu/
│ │ │ ├── __init__.py
│ │ │ └── conv_bias_relu.py
│ │ ├── csrc/
│ │ │ ├── bottleneck/
│ │ │ │ └── bottleneck.cpp
│ │ │ ├── conv_bias_relu/
│ │ │ │ └── conv_bias_relu.cpp
│ │ │ ├── cudnn_gbn/
│ │ │ │ ├── cudnn_gbn.cpp
│ │ │ │ ├── norm_sample.cpp
│ │ │ │ └── norm_sample.h
│ │ │ ├── fmha/
│ │ │ │ ├── fmha_api.cpp
│ │ │ │ └── src/
│ │ │ │ ├── fmha/
│ │ │ │ │ ├── gemm.h
│ │ │ │ │ ├── gmem_tile.h
│ │ │ │ │ ├── kernel_traits.h
│ │ │ │ │ ├── mask.h
│ │ │ │ │ ├── smem_tile.h
│ │ │ │ │ ├── softmax.h
│ │ │ │ │ └── utils.h
│ │ │ │ ├── fmha.h
│ │ │ │ ├── fmha_dgrad_fp16_128_64_kernel.sm80.cu
│ │ │ │ ├── fmha_dgrad_fp16_256_64_kernel.sm80.cu
│ │ │ │ ├── fmha_dgrad_fp16_384_64_kernel.sm80.cu
│ │ │ │ ├── fmha_dgrad_fp16_512_64_kernel.sm80.cu
│ │ │ │ ├── fmha_dgrad_kernel_1xN_reload.h
│ │ │ │ ├── fmha_dgrad_kernel_1xN_reload_nl.h
│ │ │ │ ├── fmha_fill.cu
│ │ │ │ ├── fmha_fprop_fp16_128_64_kernel.sm80.cu
│ │ │ │ ├── fmha_fprop_fp16_256_64_kernel.sm80.cu
│ │ │ │ ├── fmha_fprop_fp16_384_64_kernel.sm80.cu
│ │ │ │ ├── fmha_fprop_fp16_512_64_kernel.sm80.cu
│ │ │ │ ├── fmha_fprop_kernel_1xN.h
│ │ │ │ ├── fmha_kernel.h
│ │ │ │ ├── fmha_noloop_reduce.cu
│ │ │ │ └── fmha_utils.h
│ │ │ ├── focal_loss/
│ │ │ │ ├── focal_loss_cuda.cpp
│ │ │ │ └── focal_loss_cuda_kernel.cu
│ │ │ ├── gpu_direct_storage/
│ │ │ │ ├── gds.cpp
│ │ │ │ ├── gds.h
│ │ │ │ └── gds_pybind.cpp
│ │ │ ├── group_norm/
│ │ │ │ ├── group_norm_nhwc.cpp
│ │ │ │ ├── group_norm_nhwc.h
│ │ │ │ ├── group_norm_nhwc_bwd_one_pass.h
│ │ │ │ ├── group_norm_nhwc_bwd_one_pass_kernel.cuh
│ │ │ │ ├── group_norm_nhwc_bwd_two_pass.cu
│ │ │ │ ├── group_norm_nhwc_fwd_one_pass.h
│ │ │ │ ├── group_norm_nhwc_fwd_one_pass_kernel.cuh
│ │ │ │ ├── group_norm_nhwc_fwd_two_pass.cu
│ │ │ │ ├── group_norm_nhwc_one_pass_10.cu
│ │ │ │ ├── group_norm_nhwc_one_pass_112.cu
│ │ │ │ ├── group_norm_nhwc_one_pass_12.cu
│ │ │ │ ├── group_norm_nhwc_one_pass_120.cu
│ │ │ │ ├── group_norm_nhwc_one_pass_128.cu
│ │ │ │ ├── group_norm_nhwc_one_pass_14.cu
│ │ │ │ ├── group_norm_nhwc_one_pass_16.cu
│ │ │ │ ├── group_norm_nhwc_one_pass_160.cu
│ │ │ │ ├── group_norm_nhwc_one_pass_20.cu
│ │ │ │ ├── group_norm_nhwc_one_pass_24.cu
│ │ │ │ ├── group_norm_nhwc_one_pass_26.cu
│ │ │ │ ├── group_norm_nhwc_one_pass_28.cu
│ │ │ │ ├── group_norm_nhwc_one_pass_30.cu
│ │ │ │ ├── group_norm_nhwc_one_pass_32.cu
│ │ │ │ ├── group_norm_nhwc_one_pass_4.cu
│ │ │ │ ├── group_norm_nhwc_one_pass_40.cu
│ │ │ │ ├── group_norm_nhwc_one_pass_42.cu
│ │ │ │ ├── group_norm_nhwc_one_pass_48.cu
│ │ │ │ ├── group_norm_nhwc_one_pass_56.cu
│ │ │ │ ├── group_norm_nhwc_one_pass_60.cu
│ │ │ │ ├── group_norm_nhwc_one_pass_64.cu
│ │ │ │ ├── group_norm_nhwc_one_pass_70.cu
│ │ │ │ ├── group_norm_nhwc_one_pass_8.cu
│ │ │ │ ├── group_norm_nhwc_one_pass_80.cu
│ │ │ │ ├── group_norm_nhwc_one_pass_84.cu
│ │ │ │ ├── group_norm_nhwc_one_pass_96.cu
│ │ │ │ ├── group_norm_nhwc_one_pass_98.cu
│ │ │ │ ├── group_norm_nhwc_op.cpp
│ │ │ │ ├── macros.h
│ │ │ │ └── traits.h
│ │ │ ├── group_norm_v2/
│ │ │ │ ├── generate_gn_cuda_inst.py
│ │ │ │ ├── gn.cpp
│ │ │ │ ├── gn.hpp
│ │ │ │ ├── gn_cuda.cu
│ │ │ │ ├── gn_cuda_host_template.cuh
│ │ │ │ ├── gn_cuda_inst_1024_1280.cu
│ │ │ │ ├── gn_cuda_inst_1024_1920.cu
│ │ │ │ ├── gn_cuda_inst_1024_320.cu
│ │ │ │ ├── gn_cuda_inst_1024_640.cu
│ │ │ │ ├── gn_cuda_inst_1024_960.cu
│ │ │ │ ├── gn_cuda_inst_256_1280.cu
│ │ │ │ ├── gn_cuda_inst_256_1920.cu
│ │ │ │ ├── gn_cuda_inst_256_2560.cu
│ │ │ │ ├── gn_cuda_inst_256_640.cu
│ │ │ │ ├── gn_cuda_inst_4096_320.cu
│ │ │ │ ├── gn_cuda_inst_4096_640.cu
│ │ │ │ ├── gn_cuda_inst_4096_960.cu
│ │ │ │ ├── gn_cuda_inst_64_1280.cu
│ │ │ │ ├── gn_cuda_inst_64_2560.cu
│ │ │ │ ├── gn_cuda_kernel.cuh
│ │ │ │ ├── gn_dispatch_hw_c.hpp
│ │ │ │ ├── gn_utils.cpp
│ │ │ │ └── gn_utils.hpp
│ │ │ ├── groupbn/
│ │ │ │ ├── batch_norm.cu
│ │ │ │ ├── batch_norm.h
│ │ │ │ ├── batch_norm_add_relu.cu
│ │ │ │ ├── batch_norm_add_relu.h
│ │ │ │ ├── cuda_utils.h
│ │ │ │ ├── interface.cpp
│ │ │ │ ├── ipc.cu
│ │ │ │ └── nhwc_batch_norm_kernel.h
│ │ │ ├── index_mul_2d/
│ │ │ │ ├── index_mul_2d_cuda.cpp
│ │ │ │ └── index_mul_2d_cuda_kernel.cu
│ │ │ ├── layer_norm/
│ │ │ │ ├── ln.h
│ │ │ │ ├── ln_api.cpp
│ │ │ │ ├── ln_bwd_kernels.cuh
│ │ │ │ ├── ln_bwd_semi_cuda_kernel.cu
│ │ │ │ ├── ln_fwd_cuda_kernel.cu
│ │ │ │ ├── ln_fwd_kernels.cuh
│ │ │ │ ├── ln_kernel_traits.h
│ │ │ │ └── ln_utils.cuh
│ │ │ ├── multihead_attn/
│ │ │ │ ├── additive_masked_softmax_dropout_cuda.cu
│ │ │ │ ├── dropout.cuh
│ │ │ │ ├── encdec_multihead_attn_cuda.cu
│ │ │ │ ├── encdec_multihead_attn_norm_add_cuda.cu
│ │ │ │ ├── layer_norm.cuh
│ │ │ │ ├── masked_softmax_dropout_cuda.cu
│ │ │ │ ├── multihead_attn_frontend.cpp
│ │ │ │ ├── philox.cuh
│ │ │ │ ├── self_multihead_attn_bias_additive_mask_cuda.cu
│ │ │ │ ├── self_multihead_attn_bias_cuda.cu
│ │ │ │ ├── self_multihead_attn_cuda.cu
│ │ │ │ ├── self_multihead_attn_norm_add_cuda.cu
│ │ │ │ ├── softmax.cuh
│ │ │ │ └── strided_batched_gemm.cuh
│ │ │ ├── nccl_allocator/
│ │ │ │ └── NCCLAllocator.cpp
│ │ │ ├── nccl_p2p/
│ │ │ │ ├── nccl_p2p.cpp
│ │ │ │ ├── nccl_p2p_cuda.cu
│ │ │ │ ├── nccl_p2p_cuda.cuh
│ │ │ │ ├── nccl_version.cpp
│ │ │ │ └── nccl_version_check.cu
│ │ │ ├── optimizers/
│ │ │ │ ├── fused_adam_cuda.cpp
│ │ │ │ ├── fused_adam_cuda_kernel.cu
│ │ │ │ ├── fused_lamb_cuda.cpp
│ │ │ │ ├── fused_lamb_cuda_kernel.cu
│ │ │ │ ├── multi_tensor_distopt_adam.cpp
│ │ │ │ ├── multi_tensor_distopt_adam_kernel.cu
│ │ │ │ ├── multi_tensor_distopt_lamb.cpp
│ │ │ │ └── multi_tensor_distopt_lamb_kernel.cu
│ │ │ ├── peer_memory/
│ │ │ │ ├── peer_memory.cpp
│ │ │ │ ├── peer_memory_cuda.cu
│ │ │ │ └── peer_memory_cuda.cuh
│ │ │ ├── transducer/
│ │ │ │ ├── transducer_joint.cpp
│ │ │ │ ├── transducer_joint_kernel.cu
│ │ │ │ ├── transducer_loss.cpp
│ │ │ │ └── transducer_loss_kernel.cu
│ │ │ └── xentropy/
│ │ │ ├── interface.cpp
│ │ │ └── xentropy_kernel.cu
│ │ ├── cudnn_gbn/
│ │ │ ├── __init__.py
│ │ │ └── batch_norm.py
│ │ ├── examples/
│ │ │ ├── gpu_direct_storage/
│ │ │ │ ├── benchmark_load.py
│ │ │ │ ├── benchmark_save.py
│ │ │ │ ├── example_load.py
│ │ │ │ └── example_save.py
│ │ │ ├── multihead_attn/
│ │ │ │ ├── func_test_multihead_attn.py
│ │ │ │ └── perf_test_multihead_attn.py
│ │ │ └── nccl_allocator/
│ │ │ ├── allreduce.py
│ │ │ ├── cache.py
│ │ │ ├── change_cuda_allocator.py
│ │ │ └── toy_ddp.py
│ │ ├── fmha/
│ │ │ ├── __init__.py
│ │ │ └── fmha.py
│ │ ├── focal_loss/
│ │ │ ├── __init__.py
│ │ │ └── focal_loss.py
│ │ ├── gpu_direct_storage/
│ │ │ ├── README.md
│ │ │ └── __init__.py
│ │ ├── group_norm/
│ │ │ ├── __init__.py
│ │ │ └── group_norm.py
│ │ ├── groupbn/
│ │ │ ├── __init__.py
│ │ │ └── batch_norm.py
│ │ ├── index_mul_2d/
│ │ │ ├── __init__.py
│ │ │ └── index_mul_2d.py
│ │ ├── layer_norm/
│ │ │ ├── __init__.py
│ │ │ └── layer_norm.py
│ │ ├── multihead_attn/
│ │ │ ├── README.md
│ │ │ ├── __init__.py
│ │ │ ├── encdec_multihead_attn.py
│ │ │ ├── encdec_multihead_attn_func.py
│ │ │ ├── fast_encdec_multihead_attn_func.py
│ │ │ ├── fast_encdec_multihead_attn_norm_add_func.py
│ │ │ ├── fast_self_multihead_attn_func.py
│ │ │ ├── fast_self_multihead_attn_norm_add_func.py
│ │ │ ├── mask_softmax_dropout_func.py
│ │ │ ├── self_multihead_attn.py
│ │ │ └── self_multihead_attn_func.py
│ │ ├── nccl_allocator/
│ │ │ ├── README.md
│ │ │ ├── __init__.py
│ │ │ └── nccl_allocator.py
│ │ ├── openfold_triton/
│ │ │ ├── README.md
│ │ │ ├── __init__.py
│ │ │ ├── _layer_norm_backward_kernels.py
│ │ │ ├── _layer_norm_config_ampere.py
│ │ │ ├── _layer_norm_config_hopper.py
│ │ │ ├── _layer_norm_forward_kernels.py
│ │ │ ├── _mha_kernel.py
│ │ │ ├── fused_adam_swa.py
│ │ │ ├── layer_norm.py
│ │ │ └── mha.py
│ │ ├── optimizers/
│ │ │ ├── __init__.py
│ │ │ ├── distributed_fused_adam.py
│ │ │ ├── distributed_fused_lamb.py
│ │ │ ├── fp16_optimizer.py
│ │ │ ├── fused_adam.py
│ │ │ ├── fused_lamb.py
│ │ │ └── fused_sgd.py
│ │ ├── peer_memory/
│ │ │ ├── __init__.py
│ │ │ ├── peer_halo_exchanger_1d.py
│ │ │ └── peer_memory.py
│ │ ├── sparsity/
│ │ │ ├── COPYRIGHT
│ │ │ ├── README.md
│ │ │ ├── __init__.py
│ │ │ ├── asp.py
│ │ │ ├── permutation_lib.py
│ │ │ ├── permutation_search_kernels/
│ │ │ │ ├── CUDA_kernels/
│ │ │ │ │ └── permutation_search_kernels.cu
│ │ │ │ ├── __init__.py
│ │ │ │ ├── call_permutation_search_kernels.py
│ │ │ │ ├── channel_swap.py
│ │ │ │ ├── exhaustive_search.py
│ │ │ │ └── permutation_utilities.py
│ │ │ ├── permutation_tests/
│ │ │ │ ├── README.md
│ │ │ │ ├── ablation_studies.sh
│ │ │ │ ├── permutation_test.py
│ │ │ │ ├── runtime_table.sh
│ │ │ │ └── unstructured_study.sh
│ │ │ ├── sparse_masklib.py
│ │ │ └── test/
│ │ │ ├── checkpointing_test_part1.py
│ │ │ ├── checkpointing_test_part2.py
│ │ │ ├── checkpointing_test_reference.py
│ │ │ ├── test_permutation_application.py
│ │ │ └── toy_problem.py
│ │ ├── test/
│ │ │ ├── __init__.py
│ │ │ ├── bottleneck/
│ │ │ │ ├── __init__.py
│ │ │ │ └── test_bottleneck_module.py
│ │ │ ├── clip_grad/
│ │ │ │ ├── __init__.py
│ │ │ │ └── test_clip_grad.py
│ │ │ ├── conv_bias_relu/
│ │ │ │ ├── __init__.py
│ │ │ │ └── test_conv_bias_relu.py
│ │ │ ├── cudnn_gbn/
│ │ │ │ ├── __init__.py
│ │ │ │ └── test_cudnn_gbn_with_two_gpus.py
│ │ │ ├── fmha/
│ │ │ │ ├── __init__.py
│ │ │ │ └── test_fmha.py
│ │ │ ├── focal_loss/
│ │ │ │ ├── __init__.py
│ │ │ │ └── test_focal_loss.py
│ │ │ ├── fused_dense/
│ │ │ │ └── test_fused_dense.py
│ │ │ ├── group_norm/
│ │ │ │ ├── __init__.py
│ │ │ │ └── test_group_norm.py
│ │ │ ├── index_mul_2d/
│ │ │ │ ├── __init__.py
│ │ │ │ └── test_index_mul_2d.py
│ │ │ ├── layer_norm/
│ │ │ │ ├── __init__.py
│ │ │ │ └── test_fast_layer_norm.py
│ │ │ ├── multihead_attn/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── test_encdec_multihead_attn.py
│ │ │ │ ├── test_encdec_multihead_attn_norm_add.py
│ │ │ │ ├── test_fast_self_multihead_attn_bias.py
│ │ │ │ ├── test_mha_fused_softmax.py
│ │ │ │ ├── test_self_multihead_attn.py
│ │ │ │ └── test_self_multihead_attn_norm_add.py
│ │ │ ├── openfold_triton/
│ │ │ │ ├── test_fused_adam_swa.py
│ │ │ │ ├── test_openfold_mha.py
│ │ │ │ └── test_sync_triton_auto_tune_cache_across_gpus.py
│ │ │ ├── optimizers/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── test_dist_adam.py
│ │ │ │ └── test_distributed_fused_lamb.py
│ │ │ ├── peer_memory/
│ │ │ │ ├── __init__.py
│ │ │ │ └── test_peer_halo_exchange_module.py
│ │ │ ├── transducer/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── test_transducer_joint.py
│ │ │ │ └── test_transducer_loss.py
│ │ │ └── xentropy/
│ │ │ ├── __init__.py
│ │ │ └── test_label_smoothing.py
│ │ ├── torchsched/
│ │ │ ├── __init__.py
│ │ │ ├── backend.py
│ │ │ ├── config.py
│ │ │ ├── inductor/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── _utils.py
│ │ │ │ ├── event.py
│ │ │ │ ├── graph.py
│ │ │ │ ├── scheduler.py
│ │ │ │ └── wrapper.py
│ │ │ ├── ops/
│ │ │ │ ├── __init__.py
│ │ │ │ └── layer_norm.py
│ │ │ └── passes/
│ │ │ ├── __init__.py
│ │ │ └── pre_grad_passes.py
│ │ ├── transducer/
│ │ │ ├── __init__.py
│ │ │ ├── _transducer_ref.py
│ │ │ └── transducer.py
│ │ └── xentropy/
│ │ ├── __init__.py
│ │ └── softmax_xentropy.py
│ ├── distributed_testing/
│ │ ├── __init__.py
│ │ ├── _ucc_util.py
│ │ └── distributed_test_base.py
│ ├── fused_dense/
│ │ ├── __init__.py
│ │ └── fused_dense.py
│ ├── mlp/
│ │ ├── __init__.py
│ │ └── mlp.py
│ ├── multi_tensor_apply/
│ │ ├── __init__.py
│ │ └── multi_tensor_apply.py
│ ├── normalization/
│ │ ├── __init__.py
│ │ └── fused_layer_norm.py
│ └── optimizers/
│ ├── __init__.py
│ ├── fused_adagrad.py
│ ├── fused_adam.py
│ ├── fused_lamb.py
│ ├── fused_mixed_precision_lamb.py
│ ├── fused_novograd.py
│ └── fused_sgd.py
├── csrc/
│ ├── amp_C_frontend.cpp
│ ├── flatten_unflatten.cpp
│ ├── fused_dense.cpp
│ ├── fused_dense_cuda.cu
│ ├── layer_norm_cuda.cpp
│ ├── layer_norm_cuda_kernel.cu
│ ├── megatron/
│ │ ├── fused_rotary_positional_embedding.cpp
│ │ ├── fused_rotary_positional_embedding.h
│ │ ├── fused_rotary_positional_embedding_cuda.cu
│ │ ├── fused_weight_gradient_dense.cpp
│ │ ├── fused_weight_gradient_dense_16bit_prec_cuda.cu
│ │ ├── fused_weight_gradient_dense_cuda.cu
│ │ ├── generic_scaled_masked_softmax.cpp
│ │ ├── generic_scaled_masked_softmax.h
│ │ ├── generic_scaled_masked_softmax_cuda.cu
│ │ ├── scaled_masked_softmax.cpp
│ │ ├── scaled_masked_softmax.h
│ │ ├── scaled_masked_softmax_cuda.cu
│ │ ├── scaled_softmax.cpp
│ │ ├── scaled_softmax_cuda.cu
│ │ ├── scaled_upper_triang_masked_softmax.cpp
│ │ ├── scaled_upper_triang_masked_softmax.h
│ │ └── scaled_upper_triang_masked_softmax_cuda.cu
│ ├── mlp.cpp
│ ├── mlp_cuda.cu
│ ├── multi_tensor_adagrad.cu
│ ├── multi_tensor_adam.cu
│ ├── multi_tensor_apply.cuh
│ ├── multi_tensor_axpby_kernel.cu
│ ├── multi_tensor_l2norm_kernel.cu
│ ├── multi_tensor_l2norm_kernel_mp.cu
│ ├── multi_tensor_l2norm_scale_kernel.cu
│ ├── multi_tensor_lamb.cu
│ ├── multi_tensor_lamb_mp.cu
│ ├── multi_tensor_lamb_stage_1.cu
│ ├── multi_tensor_lamb_stage_2.cu
│ ├── multi_tensor_novograd.cu
│ ├── multi_tensor_scale_kernel.cu
│ ├── multi_tensor_sgd_kernel.cu
│ ├── static_switch.h
│ ├── syncbn.cpp
│ ├── type_shim.h
│ ├── update_scale_hysteresis.cu
│ └── welford.cu
├── docs/
│ ├── Makefile
│ └── source/
│ ├── _static/
│ │ └── css/
│ │ └── pytorch_theme.css
│ ├── _templates/
│ │ └── layout.html
│ ├── conf.py
│ ├── index.rst
│ ├── layernorm.rst
│ └── optimizers.rst
├── examples/
│ ├── README.md
│ ├── dcgan/
│ │ ├── README.md
│ │ └── main_amp.py
│ ├── docker/
│ │ ├── Dockerfile
│ │ └── README.md
│ ├── imagenet/
│ │ ├── README.md
│ │ └── main_amp.py
│ └── simple/
│ └── distributed/
│ ├── README.md
│ ├── distributed_data_parallel.py
│ └── run.sh
├── pyproject.toml
├── requirements.txt
├── requirements_dev.txt
├── setup.py
└── tests/
├── L0/
│ ├── run_fused_layer_norm/
│ │ └── test_fused_layer_norm.py
│ ├── run_mlp/
│ │ └── test_mlp.py
│ ├── run_optimizers/
│ │ ├── __init__.py
│ │ ├── test_adam.py
│ │ ├── test_fused_novograd.py
│ │ ├── test_fused_optimizer.py
│ │ └── test_lamb.py
│ └── run_test.py
├── L1/
│ ├── common/
│ │ ├── compare.py
│ │ ├── main_amp.py
│ │ └── run_test.sh
│ ├── cross_product/
│ │ └── run.sh
│ └── cross_product_distributed/
│ └── run.sh
├── distributed/
│ ├── DDP/
│ │ ├── ddp_race_condition_test.py
│ │ └── run_race_test.sh
│ ├── amp_master_params/
│ │ ├── amp_master_params.py
│ │ ├── compare.py
│ │ └── run.sh
│ └── synced_batchnorm/
│ ├── python_single_gpu_unit_test.py
│ ├── single_gpu_unit_test.py
│ ├── test_batchnorm1d.py
│ ├── test_groups.py
│ ├── two_gpu_test_different_batch_size.py
│ ├── two_gpu_unit_test.py
│ └── unit_test.sh
└── docker_extension_builds/
└── run.sh
SYMBOL INDEX (1684 symbols across 192 files)
FILE: apex/__init__.py
function check_cudnn_version_and_warn (line 21) | def check_cudnn_version_and_warn(global_option: str, required_cudnn_vers...
class DeprecatedFeatureWarning (line 33) | class DeprecatedFeatureWarning(FutureWarning):
function deprecated_warning (line 37) | def deprecated_warning(msg: str) -> None:
FILE: apex/_autocast_utils.py
function _get_autocast_dtypes (line 9) | def _get_autocast_dtypes() -> Sequence[torch.dtype]:
function _get_current_dtype (line 15) | def _get_current_dtype(dtype: Optional[torch.dtype] = None) -> torch.dtype:
function _cast_if_autocast_enabled (line 22) | def _cast_if_autocast_enabled(*args):
FILE: apex/contrib/bottleneck/bottleneck.py
function kaiming_uniform_ (line 14) | def kaiming_uniform_(tensor, a=0, mode="fan_in", nonlinearity="leaky_rel...
function compute_scale_bias_one (line 19) | def compute_scale_bias_one(nhwc, weight, bias, running_mean, running_var...
function compute_scale_bias_method (line 26) | def compute_scale_bias_method(nhwc, args):
class FrozenBatchNorm2d (line 32) | class FrozenBatchNorm2d(torch.jit.ScriptModule):
method __init__ (line 37) | def __init__(self, n):
method get_scale_bias (line 45) | def get_scale_bias(self, nhwc):
method forward (line 58) | def forward(self, x):
function drelu_dscale1 (line 64) | def drelu_dscale1(grad_o, output, scale1):
function drelu_dscale2 (line 72) | def drelu_dscale2(grad_o, output, scale1, scale2):
class BottleneckFunction (line 80) | class BottleneckFunction(torch.autograd.Function):
method forward (line 82) | def forward(ctx, nhwc, stride_1x1, scale, bias, x, *conv):
method backward (line 104) | def backward(ctx, grad_o):
function conv3x3 (line 135) | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
function conv1x1 (line 149) | def conv1x1(in_planes, out_planes, stride=1):
class Bottleneck (line 154) | class Bottleneck(torch.nn.Module):
method __init__ (line 162) | def __init__(
method get_scale_bias_callable (line 233) | def get_scale_bias_callable(self):
method forward (line 250) | def forward(self, x):
class SpatialBottleneckFunction (line 304) | class SpatialBottleneckFunction(torch.autograd.Function):
method forward (line 306) | def forward(
method backward (line 546) | def backward(ctx, grad_o):
class SpatialBottleneck (line 833) | class SpatialBottleneck(torch.nn.Module):
method __init__ (line 841) | def __init__(
method get_scale_bias_callable (line 920) | def get_scale_bias_callable(self):
method forward (line 937) | def forward(self, x):
FILE: apex/contrib/bottleneck/halo_exchangers.py
class HaloExchanger (line 10) | class HaloExchanger(object):
method __init__ (line 11) | def __init__(self, ranks, rank_in_group):
class HaloExchangerNoComm (line 28) | class HaloExchangerNoComm(HaloExchanger):
method __init__ (line 29) | def __init__(self, ranks, rank_in_group):
method left_right_halo_exchange (line 32) | def left_right_halo_exchange(
class HaloExchangerAllGather (line 46) | class HaloExchangerAllGather(HaloExchanger):
method __init__ (line 47) | def __init__(self, ranks, rank_in_group, comm):
method left_right_halo_exchange (line 52) | def left_right_halo_exchange(
class HaloExchangerSendRecv (line 95) | class HaloExchangerSendRecv(HaloExchanger):
method __init__ (line 96) | def __init__(self, ranks, rank_in_group):
method left_right_halo_exchange (line 118) | def left_right_halo_exchange(
class HaloExchangerPeer (line 146) | class HaloExchangerPeer(HaloExchanger):
method __init__ (line 147) | def __init__(self, ranks, rank_in_group, peer_pool, explicit_nhwc, num...
method _allocate_peer_tensor (line 154) | def _allocate_peer_tensor(self, halo):
method left_right_halo_exchange (line 165) | def left_right_halo_exchange(
class HaloPadder (line 203) | class HaloPadder:
method __init__ (line 204) | def __init__(self, halo_ex):
method __call__ (line 209) | def __call__(self, y, half_halo, explicit_nhwc, H_split):
method wait (line 273) | def wait(self):
FILE: apex/contrib/clip_grad/clip_grad.py
function clip_grad_norm_ (line 17) | def clip_grad_norm_(
FILE: apex/contrib/conv_bias_relu/conv_bias_relu.py
class ConvBiasReLU_ (line 9) | class ConvBiasReLU_(torch.autograd.Function):
method forward (line 12) | def forward(ctx, x, weight, bias, padding, stride):
method backward (line 22) | def backward(ctx, grad_output):
class ConvBiasMaskReLU_ (line 31) | class ConvBiasMaskReLU_(torch.autograd.Function):
method forward (line 34) | def forward(ctx, x, weight, bias, mask, padding, stride):
method backward (line 44) | def backward(ctx, grad_output):
class ConvBias_ (line 53) | class ConvBias_(torch.autograd.Function):
method forward (line 56) | def forward(ctx, x, weight, bias, padding, stride):
method backward (line 66) | def backward(ctx, grad_output):
class ConvFrozenScaleBiasReLU_ (line 75) | class ConvFrozenScaleBiasReLU_(torch.autograd.Function):
method forward (line 78) | def forward(ctx, x, weight, scale, bias, padding, stride):
method backward (line 90) | def backward(ctx, grad_output):
FILE: apex/contrib/csrc/bottleneck/bottleneck.cpp
function checkCudnnError (line 40) | int checkCudnnError(cudnnStatus_t code, const char* expr, const char* fi...
function checkError (line 54) | void checkError(cudaError_t code, char const* func, const char* file, co...
function generateStrides (line 66) | void generateStrides(const int64_t* dimA, int64_t* strideA, int nbDims, ...
function getFwdConvDilatedFilterDim (line 85) | int getFwdConvDilatedFilterDim(int filterDim, int dilation) { return ((f...
function getFwdConvPaddedImageDim (line 87) | int getFwdConvPaddedImageDim(int tensorDim, int pad) { return tensorDim ...
function getFwdConvOutputDim (line 89) | int getFwdConvOutputDim(int tensorDim, int pad, int filterDim, int strid...
function common_conv_descriptors (line 110) | common_conv_descriptors create_common_descriptors(int64_t* x_dim_padded,...
function common_convbias_descriptors (line 160) | common_convbias_descriptors create_conv_bias_add_act_descriptors(int64_t...
function dconv_descriptors (line 273) | dconv_descriptors create_dconv_descriptors(int64_t* x_dim_padded, int64_...
function getConvFusionString (line 351) | std::string getConvFusionString(int64_t* x_dim_padded, int64_t* padA, in...
function run_conv_scale_bias_add_activation (line 437) | void run_conv_scale_bias_add_activation(int64_t* x_dim_padded, int64_t* ...
function run_conv_scale_bias (line 583) | void run_conv_scale_bias(int64_t* x_dim_padded, int64_t* pad, int64_t* c...
function run_dconv_drelu_dscale (line 696) | void run_dconv_drelu_dscale(int64_t* x_dim_padded, int64_t* pad, int64_t...
function run_dconv (line 810) | void run_dconv(int64_t* x_dim_padded, int64_t* pad, int64_t* convstride,...
function run_dconv_add (line 905) | void run_dconv_add(int64_t* x_dim_padded, int64_t* pad, int64_t* convstr...
function bottleneck_forward (line 1004) | std::vector<at::Tensor> bottleneck_forward(bool explicit_nhwc, int strid...
function bottleneck_backward (line 1143) | std::vector<at::Tensor> bottleneck_backward(bool explicit_nhwc, int stri...
function masked_convbias_descriptors (line 1410) | masked_convbias_descriptors create_conv_bias_add_act_mask_descriptors(in...
function dconv_add_descriptors (line 1588) | dconv_add_descriptors create_dconv_add_descriptors(int64_t* x_dim_padded...
function dconv_mask_descriptors (line 1683) | dconv_mask_descriptors create_dconv_mask_descriptors(int64_t* x_dim_padd...
function run_conv_add_scale_bias_activation (line 1820) | void run_conv_add_scale_bias_activation(int64_t* x_dim_padded, int64_t* ...
function run_conv_scale_bias_add_activation_mask (line 1964) | void run_conv_scale_bias_add_activation_mask(int64_t* x_dim_padded, int6...
function run_dconv_add_drelu_dscale (line 2235) | void run_dconv_add_drelu_dscale(int64_t* x_dim_padded, int64_t* pad, int...
function run_dconv_drelu_dscale_mask (line 2366) | void run_dconv_drelu_dscale_mask(int64_t* x_dim_padded, int64_t* pad, in...
type bottleneck_forward_status (line 2570) | struct bottleneck_forward_status {
method init (line 2603) | void init(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> ...
function bottleneck_forward_init (line 2731) | std::vector<at::Tensor> bottleneck_forward_init(bool explicit_nhwc, int ...
function bottleneck_forward_out1 (line 2754) | void bottleneck_forward_out1(bool explicit_nhwc, int stride_1X1, std::ve...
function bottleneck_forward_out2_halo (line 2775) | at::Tensor bottleneck_forward_out2_halo(bool explicit_nhwc, at::Tensor f...
function bottleneck_forward_out2_halo_corr (line 2797) | at::Tensor bottleneck_forward_out2_halo_corr(bool explicit_nhwc, at::Ten...
function bottleneck_forward_out2 (line 2821) | void bottleneck_forward_out2(bool explicit_nhwc, int stride_1X1, std::ve...
function bottleneck_forward_out2_mask (line 2852) | void bottleneck_forward_out2_mask(bool explicit_nhwc, int stride_1X1, st...
function bottleneck_forward_out2_pad (line 2886) | void bottleneck_forward_out2_pad(bool explicit_nhwc, int stride_1X1, std...
function bottleneck_forward_rest (line 2917) | void bottleneck_forward_rest(bool explicit_nhwc, int stride_1X1, std::ve...
type bottleneck_backward_state (line 2960) | struct bottleneck_backward_state {
method init (line 2996) | void init(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> ...
function bottleneck_backward_init (line 3136) | std::vector<at::Tensor> bottleneck_backward_init(bool explicit_nhwc, int...
function bottleneck_backward_wgrad3 (line 3162) | void bottleneck_backward_wgrad3(bool explicit_nhwc, int stride_1X1, std:...
function bottleneck_backward_grad_out2 (line 3177) | at::Tensor bottleneck_backward_grad_out2(bool explicit_nhwc, int stride_...
function bottleneck_backward_grad_out1 (line 3209) | at::Tensor bottleneck_backward_grad_out1(bool explicit_nhwc, int stride_...
function bottleneck_backward_grad_out1_mask (line 3238) | at::Tensor bottleneck_backward_grad_out1_mask(bool explicit_nhwc, int st...
function bottleneck_backward_grad_out1_halo_corr (line 3269) | at::Tensor bottleneck_backward_grad_out1_halo_corr(bool explicit_nhwc, i...
function bottleneck_backward_grad_out1_halo (line 3309) | at::Tensor bottleneck_backward_grad_out1_halo(bool explicit_nhwc, int st...
function bottleneck_backward_wgrad2_pad (line 3342) | void bottleneck_backward_wgrad2_pad(bool explicit_nhwc, int stride_1X1, ...
function bottleneck_backward_wgrad2 (line 3369) | void bottleneck_backward_wgrad2(bool explicit_nhwc, int stride_1X1, std:...
function bottleneck_backward_wgrad2_halo (line 3397) | at::Tensor bottleneck_backward_wgrad2_halo(bool explicit_nhwc, int strid...
function bottleneck_backward_wgrad1 (line 3434) | void bottleneck_backward_wgrad1(bool explicit_nhwc, int stride_1X1, std:...
function bottleneck_backward_rest (line 3448) | void bottleneck_backward_rest(bool explicit_nhwc, int stride_1X1, std::v...
function PYBIND11_MODULE (line 3557) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
FILE: apex/contrib/csrc/conv_bias_relu/conv_bias_relu.cpp
function checkCudnnError (line 46) | int checkCudnnError(cudnnStatus_t code, const char* expr, const char* fi...
function checkError (line 60) | void checkError(cudaError_t code, char const* func, const char* file, co...
function generateStrides (line 72) | void generateStrides(const int64_t* dimA, int64_t* strideA, int nbDims, ...
function getFwdConvDilatedFilterDim (line 91) | int getFwdConvDilatedFilterDim(int filterDim, int dilation) { return ((f...
function getFwdConvPaddedImageDim (line 93) | int getFwdConvPaddedImageDim(int tensorDim, int pad) { return tensorDim ...
function getFwdConvOutputDim (line 95) | int getFwdConvOutputDim(int tensorDim, int pad, int filterDim, int strid...
function getConvFusionString (line 103) | std::string getConvFusionString(int64_t* x_dim_padded, int64_t* padA, in...
function run_conv_bias (line 193) | void run_conv_bias(int64_t* x_dim, int64_t* w_dim, int64_t* y_dim, int64...
function run_conv_bias_mask_relu (line 331) | void run_conv_bias_mask_relu(int64_t* x_dim, int64_t* w_dim, int64_t* y_...
function run_conv_cscale_cbias_relu (line 530) | void run_conv_cscale_cbias_relu(int64_t* x_dim, int64_t* w_dim, int64_t*...
function run_conv_bias_relu (line 732) | void run_conv_bias_relu(int64_t* x_dim, int64_t* w_dim, int64_t* y_dim, ...
function run_drelu_dscale (line 897) | void run_drelu_dscale(int64_t* dy_dim, cudnnDataType_t dataType, at::Hal...
function run_drelu_dbias (line 1033) | void run_drelu_dbias(int64_t* dy_dim, cudnnDataType_t dataType, at::Half...
function run_dconv_drelu_dbias (line 1159) | void run_dconv_drelu_dbias(int64_t* x_dim, int64_t* w_dim, int64_t* y_di...
function run_dconv (line 1323) | void run_dconv(int64_t* x_dim, int64_t* w_dim, int64_t* y_dim, int64_t* ...
function run_dbias (line 1429) | void run_dbias(int64_t* x_dim, cudnnDataType_t dataType, at::Half* devPt...
function conv_bias_mask_relu_forward (line 1513) | std::vector<at::Tensor> conv_bias_mask_relu_forward(std::vector<at::Tens...
function conv_cscale_cbias_relu_forward (line 1564) | at::Tensor conv_cscale_cbias_relu_forward(std::vector<at::Tensor> inputs...
function conv_cscale_cbias_relu_backward (line 1609) | std::vector<at::Tensor> conv_cscale_cbias_relu_backward(std::vector<at::...
function conv_bias_relu_forward (line 1674) | std::vector<at::Tensor> conv_bias_relu_forward(std::vector<at::Tensor> i...
function conv_bias_relu_backward (line 1724) | std::vector<at::Tensor> conv_bias_relu_backward(std::vector<at::Tensor> ...
function conv_bias_forward (line 1789) | std::vector<at::Tensor> conv_bias_forward(std::vector<at::Tensor> inputs...
function conv_bias_backward (line 1839) | std::vector<at::Tensor> conv_bias_backward(std::vector<at::Tensor> input...
function PYBIND11_MODULE (line 1901) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
FILE: apex/contrib/csrc/cudnn_gbn/cudnn_gbn.cpp
type bn_type (line 11) | enum bn_type { BN_FWD, BN_BWD }
function gbn_forward (line 16) | at::Tensor gbn_forward(const at::Tensor& x, const at::Tensor& scale, con...
function gbn_backward (line 67) | std::vector<at::Tensor> gbn_backward(const at::Tensor& x, const at::Tens...
function PYBIND11_MODULE (line 120) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
FILE: apex/contrib/csrc/cudnn_gbn/norm_sample.cpp
function checkCudaError (line 33) | int64_t checkCudaError(cudaError_t code, const char* expr, const char* f...
function checkCudnnError (line 41) | int64_t checkCudnnError(cudnnStatus_t code, const char* expr, const char...
function AllowAll (line 49) | bool AllowAll(cudnnBackendDescriptor_t engine_config) {
function generateStrides (line 54) | void generateStrides(const int64_t* dimA, int64_t* strideA, int64_t nbDi...
function run_batch_norm_forward (line 74) | cudnn_frontend::ExecutionPlan run_batch_norm_forward(int64_t* tensorDims...
function execute_batch_norm_forward (line 222) | void execute_batch_norm_forward(cudnn_frontend::ExecutionPlan plan, void...
function run_batch_norm_backward (line 275) | cudnn_frontend::ExecutionPlan run_batch_norm_backward(int64_t* tensorDim...
function execute_batch_norm_backward (line 404) | void execute_batch_norm_backward(cudnn_frontend::ExecutionPlan plan, voi...
FILE: apex/contrib/csrc/fmha/fmha_api.cpp
function set_params (line 34) | void set_params(Fused_multihead_attention_fprop_params& params,
function mha_fwd (line 79) | std::vector<at::Tensor> mha_fwd(
function mha_bwd (line 160) | std::vector<at::Tensor> mha_bwd(
function mha_bwd_nl (line 240) | std::vector<at::Tensor> mha_bwd_nl(
function PYBIND11_MODULE (line 322) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
FILE: apex/contrib/csrc/fmha/src/fmha.h
type Qkv_params (line 51) | struct Qkv_params {
function Qkv_params (line 64) | struct Fused_multihead_attention_fprop_params : public Qkv_params {
FILE: apex/contrib/csrc/fmha/src/fmha/gemm.h
function namespace (line 34) | namespace fmha {
type Fragment_accumulator (line 134) | struct Fragment_accumulator
function add (line 140) | void add(const Other_fragment_& other) {
FILE: apex/contrib/csrc/fmha/src/fmha/gmem_tile.h
function namespace (line 30) | namespace fmha {
function __device__ (line 109) | inline __device__ void store(const uint4 (&data)[LDGS]) {
function __device__ (line 120) | inline __device__ void move() {
function __device__ (line 125) | inline __device__ void move(int steps) {
function __device__ (line 201) | inline __device__ void store(const uint4 (&src)[STGS_PER_LOOP], int mi) {
function __device__ (line 221) | inline __device__ void move() {
function __device__ (line 226) | inline __device__ void move(const int steps) {
function __device__ (line 276) | __device__ Gmem_tile_mma_sd(void* ptr, const Params& params, const int b...
function __device__ (line 286) | inline __device__ void store(const Type& data, const int mi, const int n...
function __device__ (line 298) | inline __device__ void move() { ptr_ += LOOP_STRIDE_BYTES; }
function __device__ (line 299) | inline __device__ void move(const int steps) { ptr_ += LOOP_STRIDE_BYTES...
function Base (line 308) | struct Gmem_tile_mma_s : public Base {
function Base (line 413) | struct Gmem_tile_dq : public Base {
FILE: apex/contrib/csrc/fmha/src/fmha/mask.h
function namespace (line 30) | namespace fmha {
FILE: apex/contrib/csrc/fmha/src/fmha/smem_tile.h
function namespace (line 33) | namespace fmha {
function __device__ (line 190) | inline __device__ void move_to_next_read_buffer() {
function __device__ (line 199) | inline __device__ void move_next_read_buffer() { this->move_to_next_read...
function __device__ (line 202) | inline __device__ void move_to_next_read_buffer(int N) {
function __device__ (line 210) | inline __device__ void move_next_read_buffer(int N) { this->move_to_next...
function __device__ (line 213) | inline __device__ void move_to_next_write_buffer() {
function __device__ (line 222) | inline __device__ void move_next_write_buffer() { this->move_to_next_wri...
function __device__ (line 225) | inline __device__ void move_read_offset(int delta) { this->smem_read_off...
function __device__ (line 228) | inline __device__ void move_write_offset(int delta) { this->smem_write_o...
function __device__ (line 357) | inline __device__ Smem_tile_row_a(void* smem, int tidx) : Base(smem, tid...
function __device__ (line 422) | inline __device__ void reset_read_offset() {
function __device__ (line 450) | inline __device__ Smem_tile_a(void* smem, int tidx) : Base(smem, tidx) {}
function __device__ (line 518) | inline __device__ Smem_tile_col_b(void* smem, int tidx) : Base(smem, tid...
function __device__ (line 587) | inline __device__ void reset_read_offset() {
function __device__ (line 615) | inline __device__ Smem_tile_b(void* smem, int tidx) : Base(smem, tidx) {}
function __device__ (line 660) | inline __device__ Smem_tile_row_b(void* smem, int tidx) : Base(smem, tid...
function __device__ (line 796) | inline __device__ Smem_tile_b(void* smem, int tidx) : Base(smem, tidx) {}
function __device__ (line 814) | inline __device__ Smem_tile_v(void* smem, int tidx) : Base(smem, tidx) {
function __device__ (line 904) | inline __device__ Smem_tile_o(void* smem, int tidx) {
function store (line 957) | void store(const Accumulator (&acc)[M][N], int mi) {
function __device__ (line 1027) | inline __device__ Smem_tile_mma(char* smem, int tidx) {
function store (line 1045) | void store(const uint4 (&regs)[M][N]) {
function __device__ (line 1075) | inline __device__ Smem_tile_mma_transposed(char* smem, int tidx) : Base(...
function load (line 1086) | void load(Fragment (&frag)[M][N]) {
function __device__ (line 1120) | inline __device__ Smem_tile_mma_epilogue(char* smem, int tidx) : Base(sm...
function store (line 1135) | void store(const Acc (&acc)[M][N]) {
function store (line 1167) | void store(const uint4 (&regs)[M][N]) {
FILE: apex/contrib/csrc/fmha/src/fmha/softmax.h
function namespace (line 30) | namespace fmha {
FILE: apex/contrib/csrc/fmha/src/fmha/utils.h
function namespace (line 38) | namespace fmha {
function __device__ (line 323) | static inline __device__ uint32_t hadd2(uint32_t a, uint32_t b) {
function __device__ (line 331) | static inline __device__ uint32_t hmin2(uint32_t a, uint32_t b) {
function __device__ (line 339) | static inline __device__ uint32_t hmul2(uint32_t a, uint32_t b) {
function __device__ (line 347) | static inline __device__ uint2 hmul4(uint2 a, uint2 b) {
function __device__ (line 356) | static inline __device__ uint4 hmul8(uint4 a, uint4 b) {
function __device__ (line 367) | static inline __device__ uint4 hmul8(uint32_t a, uint4 b) {
function __device__ (line 395) | static inline __device__ uint32_t habs2(uint32_t x) {
function __device__ (line 410) | static inline __device__ uint16_t clamp_to_zero(uint16_t x) {
function __device__ (line 418) | static inline __device__ uint16_t float_to_half(float f) {
function __device__ (line 426) | static inline __device__ uint32_t float2_to_half2(float a, float b) {
function __device__ (line 440) | static inline __device__ uint32_t float_to_half2(float a) { return float...
function __device__ (line 444) | static inline __device__ uint32_t float2_to_half2(const float2& f) { ret...
function __device__ (line 448) | static inline __device__ uint2 float4_to_half4(float x, float y, float z...
function __device__ (line 457) | static inline __device__ uint32_t hfma2(uint32_t a, uint32_t b, uint32_t...
function __device__ (line 465) | static inline __device__ uint32_t hfma2_relu(uint32_t a, uint32_t b, uin...
function __device__ (line 477) | static inline __device__ uint32_t h0_h0(uint32_t x) {
function __device__ (line 485) | static inline __device__ float h0_to_float(uint32_t h2) {
function __device__ (line 500) | static inline __device__ uint32_t h1_h1(uint32_t x) {
function __device__ (line 508) | static inline __device__ uint16_t hadd(uint16_t a, uint16_t b) {
function __device__ (line 516) | static inline __device__ uint32_t hadd(uint32_t a, uint32_t b) { return ...
function __device__ (line 520) | static inline __device__ uint2 hadd4(uint2 a, uint2 b) {
function __device__ (line 529) | static inline __device__ uint2 hadd(uint2 a, uint2 b) { return hadd4(a, ...
function __device__ (line 533) | static inline __device__ uint4 hadd8(uint4 a, uint4 b) {
function __device__ (line 544) | static inline __device__ uint4 fadd4(uint4 a, uint4 b) {
function __device__ (line 555) | static inline __device__ uint4 hadd(uint4 a, uint4 b) { return hadd8(a, ...
function __device__ (line 559) | static inline __device__ float half_to_float(uint16_t h) {
function __device__ (line 567) | static inline __device__ float2 half2_to_float2(uint32_t x) {
function __device__ (line 583) | static inline __device__ uint16_t hfma(uint16_t a, uint16_t b, uint16_t ...
function __device__ (line 591) | static inline __device__ uint16_t hmul(uint16_t a, uint16_t b) {
function __device__ (line 599) | static inline __device__ float sigmoid(float x) { return 1.f / (1.f + ex...
function __device__ (line 730) | inline __device__ Ldg_functor(Data_type (&fetch)[N], const void* (&ptrs)...
function __device__ (line 733) | inline __device__ void clear(int ii) { fmha::clear(fetch_[ii]); }
function __device__ (line 736) | inline __device__ void load(int ii, bool p) {
function __device__ (line 885) | inline __device__ void stg(void* ptr, uint8_t val) { *reinterpret_cast<u...
function __device__ (line 889) | inline __device__ void stg(void* ptr, uint16_t val) { *reinterpret_cast<...
function __device__ (line 893) | inline __device__ void stg(void* ptr, uint32_t val) { *reinterpret_cast<...
function __device__ (line 897) | inline __device__ void stg(void* ptr, uint2 val) { *reinterpret_cast<uin...
function __device__ (line 901) | inline __device__ void stg(void* ptr, uint4 val) { *reinterpret_cast<uin...
function __device__ (line 909) | inline __device__ void sts(uint32_t ptr, uint16_t val) {
function __device__ (line 915) | inline __device__ void sts(uint32_t ptr, uint32_t val) {
function __device__ (line 921) | inline __device__ void sts(uint32_t ptr, uint2 val) {
function __device__ (line 927) | inline __device__ void sts(uint32_t ptr, uint4 val) {
function __device__ (line 975) | __device__ inline T operator()(T const& x, T const& y) { return x > y ? ...
function __device__ (line 982) | __device__ inline T operator()(T const& x, T const& y) { return x + y; }
function T (line 991) | inline T run(T x, Operator& op) {
type Allreduce (line 1001) | struct Allreduce
function T (line 1003) | inline T run(T x, Operator& op) {
FILE: apex/contrib/csrc/fmha/src/fmha_dgrad_kernel_1xN_reload.h
function namespace (line 35) | namespace fmha {
FILE: apex/contrib/csrc/fmha/src/fmha_dgrad_kernel_1xN_reload_nl.h
function namespace (line 35) | namespace fmha {
FILE: apex/contrib/csrc/fmha/src/fmha_fprop_kernel_1xN.h
function namespace (line 35) | namespace fmha {
FILE: apex/contrib/csrc/fmha/src/fmha_kernel.h
function namespace (line 39) | namespace fmha {
function __device__ (line 76) | __device__ Noloop_traits(const int bidc, const Block_info& binfo) : bidc...
function move_all (line 89) | void move_all(Tiles&... tiles) const {
FILE: apex/contrib/csrc/fmha/src/fmha_utils.h
type Data_type (line 49) | enum Data_type { DATA_TYPE_FP16, DATA_TYPE_FP32, DATA_TYPE_INT32, DATA_T...
function set_alpha (line 53) | static inline void set_alpha(uint32_t& alpha, float norm, Data_type dtyp...
function get_size_in_bytes (line 71) | static inline size_t get_size_in_bytes(size_t n, Data_type dtype) {
FILE: apex/contrib/csrc/focal_loss/focal_loss_cuda.cpp
function focal_loss_forward (line 23) | std::vector<at::Tensor> focal_loss_forward(const at::Tensor& cls_output,...
function focal_loss_backward (line 34) | at::Tensor focal_loss_backward(const at::Tensor& grad_output, const at::...
function PYBIND11_MODULE (line 42) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
FILE: apex/contrib/csrc/gpu_direct_storage/gds.cpp
type apex::contrib::gds (line 16) | namespace apex::contrib::gds {
function cuFileGetErrorString (line 20) | std::string cuFileGetErrorString(T status) {
function cuFileGetErrorString (line 27) | std::string cuFileGetErrorString(T status) {
FILE: apex/contrib/csrc/gpu_direct_storage/gds.h
function namespace (line 10) | namespace apex::contrib::gds {
FILE: apex/contrib/csrc/gpu_direct_storage/gds_pybind.cpp
function PYBIND11_MODULE (line 10) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
FILE: apex/contrib/csrc/group_norm/group_norm_nhwc.cpp
function unpack (line 16) | float inline unpack(const T& x) {
function unpack (line 21) | float inline unpack(const __half& x) {
function unpack (line 26) | float inline unpack(const __nv_bfloat16& x) {
function unpack (line 31) | float inline unpack(const float& x) {
function check_results (line 38) | void check_results(const char* name, const T* out, const T* ref, size_t ...
function group_norm_nhwc_bwd_ (line 132) | static void group_norm_nhwc_bwd_(void* dx_h, float* dgamma_h, float* dbe...
function group_norm_nhwc_fwd_ (line 310) | static void group_norm_nhwc_fwd_(void* y_h, const void* x_h, const float...
function random_data (line 403) | void random_data(T* dst_h, size_t n, bool use_1s, int range = 3) {
type Mode (line 430) | enum class Mode { FWD_INFERENCE, FWD_TRAINING, BWD }
function main (line 434) | int main(int argc, char** argv) {
FILE: apex/contrib/csrc/group_norm/group_norm_nhwc.h
function div_up (line 28) | int div_up(int m, int n) { return (m + n - 1) / n; }
function sigmoid (line 32) | float sigmoid(float x) { return 1.f / (1.f + expf(-x)); }
function __device__ (line 36) | static inline __device__ void spin_wait_(int* barrier, int step, int exp...
type PrecisionMode (line 51) | enum PrecisionMode {
type Group_sums (line 65) | struct Group_sums {
type Group_sums_op (line 76) | struct Group_sums_op {
type Group_norm_nhwc_fwd_params (line 88) | struct Group_norm_nhwc_fwd_params {
type Group_norm_nhwc_bwd_params (line 147) | struct Group_norm_nhwc_bwd_params {
FILE: apex/contrib/csrc/group_norm/group_norm_nhwc_bwd_one_pass.h
function group_norm_nhwc_bwd_one_pass_setup (line 88) | void group_norm_nhwc_bwd_one_pass_setup(Group_norm_nhwc_bwd_params& para...
function group_norm_nhwc_bwd_one_pass_run (line 160) | inline void group_norm_nhwc_bwd_one_pass_run(const Group_norm_nhwc_bwd_p...
FILE: apex/contrib/csrc/group_norm/group_norm_nhwc_fwd_one_pass.h
function group_norm_nhwc_fwd_one_pass_setup (line 88) | inline void group_norm_nhwc_fwd_one_pass_setup(Group_norm_nhwc_fwd_param...
function group_norm_nhwc_fwd_one_pass_run (line 150) | inline void group_norm_nhwc_fwd_one_pass_run(const Group_norm_nhwc_fwd_p...
FILE: apex/contrib/csrc/group_norm/group_norm_nhwc_op.cpp
function group_norm_fwd (line 42) | std::vector<torch::Tensor> group_norm_fwd(torch::Tensor input, int group...
function group_norm_bwd (line 148) | std::vector<torch::Tensor> group_norm_bwd(torch::Tensor grad_output, tor...
function PYBIND11_MODULE (line 262) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
FILE: apex/contrib/csrc/group_norm/traits.h
type Fp32 (line 17) | struct Fp32 {
function float2 (line 27) | float2 pack(const float2& f2) { return f2; }
function __device__ (line 29) | static inline __device__ float2 zero() { return {0.f, 0.f}; }
type Fp16 (line 34) | struct Fp16 {
function __half2 (line 47) | __half2 pack(const float2& f2) {
function __device__ (line 52) | static inline __device__ __half2 zero() {
type Bf16 (line 60) | struct Bf16 {
function __nv_bfloat162 (line 73) | __nv_bfloat162 pack(const float2& f2) {
function __device__ (line 78) | static inline __device__ __nv_bfloat162 zero() {
type Fp32IOFp16W (line 85) | struct Fp32IOFp16W {
type Fp32IOBf16W (line 92) | struct Fp32IOBf16W {
type Fp32IOFp32W (line 99) | struct Fp32IOFp32W {
type Fp16IOFp16W (line 108) | struct Fp16IOFp16W {
type Fp16IOBf16W (line 115) | struct Fp16IOBf16W {
type Fp16IOFp32W (line 122) | struct Fp16IOFp32W {
type Bf16IOFp16W (line 130) | struct Bf16IOFp16W {
type Bf16IOBf16W (line 137) | struct Bf16IOBf16W {
type Bf16IOFp32W (line 144) | struct Bf16IOFp32W {
FILE: apex/contrib/csrc/group_norm_v2/generate_gn_cuda_inst.py
function run (line 22) | def run():
FILE: apex/contrib/csrc/group_norm_v2/gn.cpp
type group_norm_v2 (line 6) | namespace group_norm_v2 {
function gn (line 8) | torch::Tensor gn(torch::Tensor x, torch::Tensor w, torch::Tensor b, fl...
function gn_bwd (line 52) | auto gn_bwd(torch::Tensor grad_output, torch::Tensor x, torch::Tensor ...
function PYBIND11_MODULE (line 105) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
FILE: apex/contrib/csrc/group_norm_v2/gn.hpp
type group_norm_v2 (line 7) | namespace group_norm_v2 {
type Meta (line 9) | struct Meta {
FILE: apex/contrib/csrc/group_norm_v2/gn_utils.cpp
type group_norm_v2 (line 6) | namespace group_norm_v2 {
function cudaDeviceProp (line 8) | cudaDeviceProp const& get_device_prop(int device_id) {
FILE: apex/contrib/csrc/group_norm_v2/gn_utils.hpp
type group_norm_v2 (line 40) | namespace group_norm_v2 {
function __host__ (line 47) | __host__ __device__ inline int print_rank_0(char const* fmt, Ts&&... a...
FILE: apex/contrib/csrc/groupbn/batch_norm.h
function class (line 41) | class NhwcBatchNorm {
function createTensorDescriptor (line 185) | void createTensorDescriptor(cudnnTensorDescriptor_t* descriptor) {
function destroyTensorDescriptor (line 191) | void destroyTensorDescriptor(cudnnTensorDescriptor_t descriptor) {
type StorageType (line 215) | typedef uint16_t StorageType;
function _fwdKernelLauncher (line 245) | void _fwdKernelLauncher(cudaStream_t stream, NhwcBatchNormFwdParams para...
function _bwdKernelLauncher (line 298) | void _bwdKernelLauncher(cudaStream_t stream, NhwcBatchNormBwdParams para...
function smem_driven_bwd_occupancy (line 385) | static int smem_driven_bwd_occupancy(int device_id, const int max_cta_pe...
function std (line 394) | const std::vector<size_t> NhwcBatchNorm::numWorkspaceBytes() const {
function _setFwdParams (line 423) | void NhwcBatchNorm::_setFwdParams(NhwcBatchNormFwdParams* params) const {
function _setFwdInferenceParams (line 447) | void NhwcBatchNorm::_setFwdInferenceParams(NhwcBatchNormFwdInferencePara...
function _setBwdParams (line 460) | void NhwcBatchNorm::_setBwdParams(NhwcBatchNormBwdParams* params) const {
function fwdInference (line 481) | void NhwcBatchNorm::fwdInference(cudaStream_t stream, bool use_relu) {
function dim3 (line 516) | dim3 NhwcBatchNorm::calc_fwd_grid(int* loop, const int grid_dim_x) {
function dim3 (line 539) | dim3 NhwcBatchNorm::calc_bwd_grid(int* loop, const int grid_dim_x) {
function fwd (line 562) | void NhwcBatchNorm::fwd(cudaStream_t stream, bool use_relu, void* my_dat...
function dgrad (line 593) | void NhwcBatchNorm::dgrad(cudaStream_t stream, bool use_relu, void* my_d...
FILE: apex/contrib/csrc/groupbn/batch_norm_add_relu.h
function class (line 41) | class NhwcBatchNormAddRelu {
function createTensorDescriptor (line 189) | void createTensorDescriptor(cudnnTensorDescriptor_t* descriptor) {
function destroyTensorDescriptor (line 195) | void destroyTensorDescriptor(cudnnTensorDescriptor_t descriptor) {
type StorageType (line 220) | typedef uint16_t StorageType;
function _fwdKernelLauncher (line 249) | void _fwdKernelLauncher(cudaStream_t stream, NhwcBatchNormFwdParams para...
function _bwdKernelLauncher (line 292) | void _bwdKernelLauncher(cudaStream_t stream, NhwcBatchNormBwdParams para...
function smem_driven_bwd_occupancy (line 347) | static int smem_driven_bwd_occupancy(int device_id, const int max_cta_pe...
function std (line 356) | const std::vector<size_t> NhwcBatchNormAddRelu::numWorkspaceBytes() const {
function _setFwdParams (line 391) | void NhwcBatchNormAddRelu::_setFwdParams(NhwcBatchNormFwdParams* params)...
function _setFwdInferenceParams (line 415) | void NhwcBatchNormAddRelu::_setFwdInferenceParams(NhwcBatchNormFwdInfere...
function _setBwdParams (line 428) | void NhwcBatchNormAddRelu::_setBwdParams(NhwcBatchNormBwdParams* params)...
function fwdInference (line 449) | void NhwcBatchNormAddRelu::fwdInference(cudaStream_t stream) {
function dim3 (line 479) | dim3 NhwcBatchNormAddRelu::calc_fwd_grid(int* loop, const int grid_dim_x) {
function dim3 (line 502) | dim3 NhwcBatchNormAddRelu::calc_bwd_grid(int* loop, const int grid_dim_x) {
function fwd (line 525) | void NhwcBatchNormAddRelu::fwd(cudaStream_t stream, void* my_data, void*...
function dgrad (line 558) | void NhwcBatchNormAddRelu::dgrad(cudaStream_t stream, void* my_data, voi...
FILE: apex/contrib/csrc/groupbn/cuda_utils.h
function namespace (line 5) | namespace at {
FILE: apex/contrib/csrc/groupbn/interface.cpp
function PYBIND11_MODULE (line 75) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
FILE: apex/contrib/csrc/groupbn/nhwc_batch_norm_kernel.h
type T (line 44) | typedef T Type;
type Type (line 52) | typedef int Type;
function DEVICE_FUNCTION (line 239) | DEVICE_FUNCTION void write_to_gmem(float* gmem, int idx, const float (&s...
function DEVICE_FUNCTION (line 245) | DEVICE_FUNCTION void write_to_gmem(float* gmem, int idx, const float (&s...
function DEVICE_FUNCTION (line 251) | DEVICE_FUNCTION void scaled_write_to_gmem(float* gmem, int idx, const fl...
function DEVICE_FUNCTION (line 258) | DEVICE_FUNCTION void write_to_smem(float* smem, int idx, const float (&x...
function DEVICE_FUNCTION (line 264) | DEVICE_FUNCTION void write_to_smem(int* smem, int idx, const int (&x)[1]...
function DEVICE_FUNCTION (line 268) | DEVICE_FUNCTION void write_to_smem(float* smem, int idx, const float (&x...
function DEVICE_FUNCTION (line 274) | DEVICE_FUNCTION void write_to_smem(int* smem, int idx, const int (&x)[2]) {
function Storage (line 341) | Storage relu(Storage in) {
function parallel_sums (line 528) | void parallel_sums(float* smem, float (&x)[ELEMENTS_PER_LDG], int nhw) {
type ParallelSums (line 621) | struct ParallelSums
type ParallelSums (line 635) | struct ParallelSums
function div_up (line 646) | static inline int div_up(int m, int n) { return (m + n - 1) / n; }
function DEVICE_FUNCTION (line 651) | DEVICE_FUNCTION void inter_block_sync(int* gmem_retired_ctas, int expect...
type NhwcBatchNormFwdInferenceParams (line 677) | struct NhwcBatchNormFwdInferenceParams {
type NhwcBatchNormFwdParams (line 773) | struct NhwcBatchNormFwdParams {
type PackedStorage (line 834) | typedef PackedStorage<Storage, ELEMENTS_PER_LDG> PackedStorage_;
type typename (line 836) | typedef typename PackedStorage_::Type PackedStorageType;
type NhwcBatchNormBwdParams (line 1339) | struct NhwcBatchNormBwdParams {
function nhwc_batch_norm_bwd (line 1467) | void nhwc_batch_norm_bwd(NhwcBatchNormBwdParams params) {
function nhwc_batch_norm_bwd_relu (line 1816) | void nhwc_batch_norm_bwd_relu(NhwcBatchNormBwdParams params) {
function nhwc_batch_norm_bwd_add_relu (line 2188) | void nhwc_batch_norm_bwd_add_relu(NhwcBatchNormBwdParams params) {
FILE: apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda.cpp
function index_mul_2d_float_forward (line 34) | void index_mul_2d_float_forward(at::Tensor& out, const at::Tensor& in1, ...
function index_mul_2d_float_backward (line 38) | void index_mul_2d_float_backward(at::Tensor& grad_in1, at::Tensor& grad_...
function index_mul_2d_float_backwrad_backward (line 43) | void index_mul_2d_float_backwrad_backward(at::Tensor& grad_grad_out, at:...
function index_mul_2d_half_forward (line 51) | void index_mul_2d_half_forward(at::Tensor& out, const at::Tensor& in1, c...
function index_mul_2d_half_backward (line 55) | void index_mul_2d_half_backward(at::Tensor& grad_in1, at::Tensor& grad_i...
function index_mul_2d_half_backwrad_backward (line 60) | void index_mul_2d_half_backwrad_backward(at::Tensor& grad_grad_out, at::...
function PYBIND11_MODULE (line 68) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
FILE: apex/contrib/csrc/layer_norm/ln.h
function namespace (line 10) | namespace layer_norm {
FILE: apex/contrib/csrc/layer_norm/ln_api.cpp
type layer_norm (line 24) | namespace layer_norm {
function get_type_id (line 33) | uint32_t get_type_id(torch::Dtype dtype) {
function get_key (line 47) | uint64_t get_key(torch::Dtype wtype, torch::Dtype itype, torch::Dtype ...
function ln_fwd (line 83) | std::vector<at::Tensor> ln_fwd(const at::Tensor& x, // BxSxhidden_size
function ln_bwd (line 158) | std::vector<at::Tensor> ln_bwd(const at::Tensor& dz, ...
function PYBIND11_MODULE (line 253) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
FILE: apex/contrib/csrc/layer_norm/ln_kernel_traits.h
function namespace (line 5) | namespace layer_norm {
function Base (line 63) | struct Kernel_traits : public Base {
FILE: apex/contrib/csrc/multihead_attn/multihead_attn_frontend.cpp
type multihead_attn (line 12) | namespace multihead_attn {
type fused_softmax (line 13) | namespace fused_softmax {
type additive_mask_softmax_dropout (line 14) | namespace additive_mask_softmax_dropout {
function fwd (line 22) | std::vector<torch::Tensor> fwd(bool use_mask, bool is_training, in...
function bwd (line 35) | torch::Tensor bwd(bool use_mask, int heads, torch::Tensor const& o...
type mask_softmax_dropout (line 49) | namespace mask_softmax_dropout {
function fwd (line 57) | std::vector<torch::Tensor> fwd(bool use_mask, bool is_training, in...
function bwd (line 71) | torch::Tensor bwd(bool use_mask, int heads, torch::Tensor const& o...
type encdec (line 89) | namespace encdec {
type cublas_gemmex (line 90) | namespace cublas_gemmex {
function fwd (line 104) | std::vector<torch::Tensor> fwd(bool use_mask, bool use_time_mask, ...
function bwd (line 129) | std::vector<torch::Tensor> bwd(int heads, torch::Tensor const& out...
type encdec_norm_add (line 170) | namespace encdec_norm_add {
type cublas_gemmex (line 171) | namespace cublas_gemmex {
function fwd (line 190) | std::vector<torch::Tensor> fwd(bool use_mask, bool use_time_mask, ...
function bwd (line 221) | std::vector<torch::Tensor> bwd(int heads, torch::Tensor const& out...
type self (line 278) | namespace self {
type cublas_gemmex (line 279) | namespace cublas_gemmex {
function fwd (line 291) | std::vector<torch::Tensor> fwd(bool use_mask, bool use_time_mask, ...
function bwd (line 311) | std::vector<torch::Tensor> bwd(int heads, torch::Tensor const& out...
type self_bias (line 342) | namespace self_bias {
type cublas_gemmex (line 343) | namespace cublas_gemmex {
function fwd (line 358) | std::vector<torch::Tensor> fwd(bool use_mask, bool use_time_mask, ...
function bwd (line 379) | std::vector<torch::Tensor> bwd(int heads, torch::Tensor const& out...
type self_bias_additive_mask (line 410) | namespace self_bias_additive_mask {
type cublas_gemmex (line 411) | namespace cublas_gemmex {
function fwd (line 428) | std::vector<torch::Tensor> fwd(bool use_mask, bool use_time_mask, ...
function bwd (line 450) | std::vector<torch::Tensor> bwd(int heads, torch::Tensor const& out...
type self_norm_add (line 481) | namespace self_norm_add {
type cublas_gemmex (line 482) | namespace cublas_gemmex {
function fwd (line 498) | std::vector<torch::Tensor> fwd(bool use_mask, bool use_time_mask, ...
function bwd (line 523) | std::vector<torch::Tensor> bwd(int heads, torch::Tensor const& out...
function PYBIND11_MODULE (line 572) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
FILE: apex/contrib/csrc/nccl_allocator/NCCLAllocator.cpp
function nccl_free_plug (line 23) | void nccl_free_plug(void* ptr, std::size_t size, int device, void* strea...
function maybe_init (line 27) | void maybe_init() {
function get_nccl_allocator (line 34) | std::shared_ptr<c10::cuda::CUDACachingAllocator::CUDAAllocator> get_nccl...
function PYBIND11_MODULE (line 39) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
FILE: apex/contrib/csrc/nccl_p2p/nccl_p2p.cpp
function PYBIND11_MODULE (line 19) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
FILE: apex/contrib/csrc/nccl_p2p/nccl_version.cpp
function PYBIND11_MODULE (line 9) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("get_nccl_version", &ge...
FILE: apex/contrib/csrc/optimizers/fused_adam_cuda.cpp
function strided_check_finite (line 30) | void strided_check_finite(at::Tensor& overflow_flag, at::Tensor& p_copy,...
function adam (line 34) | void adam(at::Tensor& p, at::Tensor& p_copy, at::Tensor& m, at::Tensor& ...
function reversible_adam (line 50) | void reversible_adam(at::Tensor& p, at::Tensor& p_copy, at::Tensor& m, a...
function maybe_adam_undo (line 67) | void maybe_adam_undo(at::Tensor& overflow_flag, at::Tensor& p, at::Tenso...
function maybe_cast (line 82) | void maybe_cast(at::Tensor& overflow_flag, at::Tensor& p_in, at::Tensor&...
function PYBIND11_MODULE (line 91) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
FILE: apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp
function PYBIND11_MODULE (line 8) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
FILE: apex/contrib/csrc/optimizers/multi_tensor_distopt_adam.cpp
function PYBIND11_MODULE (line 19) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
FILE: apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb.cpp
function PYBIND11_MODULE (line 17) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
FILE: apex/contrib/csrc/peer_memory/peer_memory.cpp
function PYBIND11_MODULE (line 19) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
FILE: apex/contrib/csrc/transducer/transducer_joint.cpp
function transducer_joint_forward (line 19) | std::vector<torch::Tensor> transducer_joint_forward(torch::Tensor f, tor...
function transducer_joint_backward (line 32) | std::vector<torch::Tensor> transducer_joint_backward(std::vector<torch::...
function PYBIND11_MODULE (line 44) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
FILE: apex/contrib/csrc/transducer/transducer_loss.cpp
function transducer_loss_forward (line 20) | std::vector<torch::Tensor> transducer_loss_forward(torch::Tensor x, torc...
function transducer_loss_backward (line 31) | torch::Tensor transducer_loss_backward(torch::Tensor x, torch::Tensor lo...
function PYBIND11_MODULE (line 48) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
FILE: apex/contrib/csrc/xentropy/interface.cpp
function softmax_xentropy_forward (line 22) | std::vector<at::Tensor> softmax_xentropy_forward(const at::Tensor& input...
function softmax_xentropy_backward (line 30) | at::Tensor softmax_xentropy_backward(const at::Tensor& grad_loss, const ...
function PYBIND11_MODULE (line 41) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
FILE: apex/contrib/cudnn_gbn/batch_norm.py
class _GroupBatchNorm2d (line 10) | class _GroupBatchNorm2d(torch.autograd.Function):
method forward (line 13) | def forward(
method backward (line 51) | def backward(ctx, grad_output):
class GroupBatchNorm2d (line 85) | class GroupBatchNorm2d(_BatchNorm):
method __init__ (line 122) | def __init__(
method get_peer_buffers (line 147) | def get_peer_buffers(self, num_features):
method _check_input_dim (line 166) | def _check_input_dim(self, input):
method _check_input_channels (line 170) | def _check_input_channels(self, input):
method forward (line 174) | def forward(self, input: Tensor) -> Tensor:
FILE: apex/contrib/examples/gpu_direct_storage/benchmark_load.py
function run_benchmark_torch_load (line 5) | def run_benchmark_torch_load():
function run_benchmark (line 28) | def run_benchmark(func):
function load_data_yes_gds (line 53) | def load_data_yes_gds(tensor, filename):
function load_data_no_gds (line 57) | def load_data_no_gds(tensor, filename):
FILE: apex/contrib/examples/gpu_direct_storage/benchmark_save.py
function run_benchmark (line 6) | def run_benchmark(func):
function save_data_yes_gds (line 28) | def save_data_yes_gds(tensor, filename):
function save_data_no_gds (line 32) | def save_data_no_gds(tensor, filename):
FILE: apex/contrib/examples/nccl_allocator/cache.py
function set_device (line 5) | def set_device(dev):
function print_used_mem (line 11) | def print_used_mem(string, nvsmi, device_id = 0):
FILE: apex/contrib/examples/nccl_allocator/toy_ddp.py
class ToyModel (line 12) | class ToyModel(nn.Module):
method __init__ (line 13) | def __init__(self):
method forward (line 19) | def forward(self, x):
FILE: apex/contrib/fmha/fmha.py
class FMHAFun (line 33) | class FMHAFun(torch.autograd.Function):
method forward (line 35) | def forward(ctx, qkv, cu_seqlens, p_dropout, max_s, is_training, zero_...
method backward (line 69) | def backward(ctx, dout):
class FMHA (line 96) | class FMHA(torch.nn.Module):
method __init__ (line 97) | def __init__(self, config):
method forward (line 106) | def forward(self, qkv, cu_seqlens, max_s, is_training=True, zero_tenso...
FILE: apex/contrib/focal_loss/focal_loss.py
class FocalLoss (line 6) | class FocalLoss(torch.autograd.Function):
method forward (line 8) | def forward(
method backward (line 32) | def backward(ctx, grad_loss):
function focal_loss (line 42) | def focal_loss(
FILE: apex/contrib/gpu_direct_storage/__init__.py
function GDSFile (line 6) | def GDSFile(filename, mode):
FILE: apex/contrib/group_norm/group_norm.py
function one_time_warning (line 22) | def one_time_warning(msg: str):
function get_cc_and_sm_count (line 29) | def get_cc_and_sm_count(device_index: int):
function torch_group_norm (line 37) | def torch_group_norm(x, g, w, b, eps, act=""):
function group_norm_nhwc_fprop (line 50) | def group_norm_nhwc_fprop(
function fake_group_norm_nhwc_fprop (line 86) | def fake_group_norm_nhwc_fprop(
function group_norm_nhwc_bprop (line 104) | def group_norm_nhwc_bprop(
function fake_group_norm_nhwc_bprop (line 145) | def fake_group_norm_nhwc_bprop(
function backward (line 163) | def backward(ctx, grad_output, grad_sums):
function setup_context (line 178) | def setup_context(ctx, inputs, output):
function cuda_group_norm_nhwc_one_pass (line 193) | def cuda_group_norm_nhwc_one_pass(x, G, weight, bias, eps, act=None):
function cuda_group_norm_nhwc_two_pass (line 198) | def cuda_group_norm_nhwc_two_pass(x, G, weight, bias, eps, act=None):
function cuda_group_norm_v2_nhwc (line 203) | def cuda_group_norm_v2_nhwc(x, G, weight, bias, eps, act=None):
class GroupNorm (line 211) | class GroupNorm(torch.nn.Module):
method __init__ (line 327) | def __init__(
method reset_parameters (line 358) | def reset_parameters(self) -> None:
method _check_legality (line 363) | def _check_legality(self, input: Tensor) -> bool:
method _check_v2_legality (line 390) | def _check_v2_legality(self, input: Tensor) -> bool:
method forward (line 418) | def forward(self, input: Tensor) -> Tensor:
method extra_repr (line 448) | def extra_repr(self) -> str:
FILE: apex/contrib/groupbn/batch_norm.py
class bn_NHWC_impl (line 8) | class bn_NHWC_impl(torch.autograd.Function):
method forward (line 10) | def forward(
method backward (line 81) | def backward(ctx, grad_y):
class bn_addrelu_NHWC_impl (line 148) | class bn_addrelu_NHWC_impl(torch.autograd.Function):
method forward (line 150) | def forward(
method backward (line 223) | def backward(ctx, grad_y):
class BatchNorm2d_NHWC (line 290) | class BatchNorm2d_NHWC(_BatchNorm):
method __init__ (line 292) | def __init__(
method forward (line 406) | def forward(self, x, z=None):
method __del__ (line 462) | def __del__(self):
FILE: apex/contrib/index_mul_2d/index_mul_2d.py
class IndexMul2d_ (line 6) | class IndexMul2d_(torch.autograd.Function):
method forward (line 14) | def forward(ctx, in1: torch.Tensor, in2: torch.Tensor, idx1: torch.Ten...
method backward (line 47) | def backward(ctx, grad_out):
class IndexMul2dBackward_ (line 55) | class IndexMul2dBackward_(torch.autograd.Function):
method forward (line 57) | def forward(
method backward (line 90) | def backward(ctx, grad_grad_in1, grad_grad_in2):
FILE: apex/contrib/layer_norm/layer_norm.py
class FastLayerNormFN (line 8) | class FastLayerNormFN(torch.autograd.Function):
method forward (line 10) | def forward(ctx, x, gamma, beta, epsilon, memory_efficient=False):
method backward (line 27) | def backward(ctx, dy):
function _fast_layer_norm (line 39) | def _fast_layer_norm(x, weight, bias, epsilon, memory_efficient):
class FastLayerNorm (line 45) | class FastLayerNorm(torch.nn.Module):
method __init__ (line 46) | def __init__(self, hidden_size, eps=1e-5, memory_efficient=False):
method reset_parameters (line 54) | def reset_parameters(self):
method forward (line 58) | def forward(self, x):
FILE: apex/contrib/multihead_attn/encdec_multihead_attn.py
function jit_dropout_add (line 15) | def jit_dropout_add(x, residual, prob, is_training):
class EncdecMultiheadAttn (line 22) | class EncdecMultiheadAttn(nn.Module):
method __init__ (line 28) | def __init__(
method reset_parameters (line 92) | def reset_parameters(self):
method forward (line 111) | def forward(
FILE: apex/contrib/multihead_attn/encdec_multihead_attn_func.py
class EncdecAttnFunc (line 5) | class EncdecAttnFunc(torch.autograd.Function):
method forward (line 7) | def forward(
method backward (line 205) | def backward(ctx, output_grads):
FILE: apex/contrib/multihead_attn/fast_encdec_multihead_attn_func.py
class FastEncdecAttnFunc (line 6) | class FastEncdecAttnFunc(torch.autograd.Function):
method forward (line 8) | def forward(
method backward (line 75) | def backward(ctx, output_grads):
FILE: apex/contrib/multihead_attn/fast_encdec_multihead_attn_norm_add_func.py
class FastEncdecAttnNormAddFunc (line 13) | class FastEncdecAttnNormAddFunc(torch.autograd.Function):
method forward (line 15) | def forward(
method backward (line 96) | def backward(ctx, output_grads):
FILE: apex/contrib/multihead_attn/fast_self_multihead_attn_func.py
class FastSelfAttnFunc (line 6) | class FastSelfAttnFunc(torch.autograd.Function):
method forward (line 8) | def forward(
method backward (line 155) | def backward(ctx, output_grads):
FILE: apex/contrib/multihead_attn/fast_self_multihead_attn_norm_add_func.py
class FastSelfAttnNormAddFunc (line 6) | class FastSelfAttnNormAddFunc(torch.autograd.Function):
method forward (line 8) | def forward(
method backward (line 82) | def backward(ctx, output_grads):
FILE: apex/contrib/multihead_attn/mask_softmax_dropout_func.py
class MaskSoftmaxDropout (line 6) | class MaskSoftmaxDropout(torch.autograd.Function):
method forward (line 8) | def forward(ctx, is_training, heads, inputs, pad_mask, mask_additive, ...
method backward (line 62) | def backward(ctx, output_grads):
FILE: apex/contrib/multihead_attn/self_multihead_attn.py
function jit_dropout_add (line 15) | def jit_dropout_add(x, residual, prob, is_training):
class SelfMultiheadAttn (line 22) | class SelfMultiheadAttn(nn.Module):
method __init__ (line 28) | def __init__(
method reset_parameters (line 114) | def reset_parameters(self):
method forward (line 141) | def forward(
FILE: apex/contrib/multihead_attn/self_multihead_attn_func.py
class SelfAttnFunc (line 5) | class SelfAttnFunc(torch.autograd.Function):
method forward (line 7) | def forward(
method backward (line 182) | def backward(ctx, output_grads):
FILE: apex/contrib/nccl_allocator/nccl_allocator.py
function get_func_args (line 11) | def get_func_args(func):
function create_nccl_mem_pool (line 18) | def create_nccl_mem_pool(symmetric: bool | None = None) -> torch.cuda.Me...
function init (line 36) | def init() -> None:
class nccl_mem (line 41) | class nccl_mem:
method __init__ (line 42) | def __init__(self, pool, enabled=True, device=None, group=None):
method __enter__ (line 66) | def __enter__(self):
method __exit__ (line 75) | def __exit__(self, *args):
FILE: apex/contrib/openfold_triton/__init__.py
function _get_tuneable_triton_func_name (line 42) | def _get_tuneable_triton_func_name(f: Union[Autotuner, Heuristics, JITFu...
function _save_triton_auto_tune_cache (line 62) | def _save_triton_auto_tune_cache(strict: bool = True, verbose: bool = Fa...
function _load_triton_auto_tune_cache (line 82) | def _load_triton_auto_tune_cache(f: BinaryIO, strict: bool = True, verbo...
function sync_triton_auto_tune_cache_across_gpus (line 102) | def sync_triton_auto_tune_cache_across_gpus(strict: bool = True, verbose...
FILE: apex/contrib/openfold_triton/_layer_norm_backward_kernels.py
function _layer_norm_backward_dx (line 34) | def _layer_norm_backward_dx(
function _layer_norm_backward_dw_db_partial (line 112) | def _layer_norm_backward_dw_db_partial(
function _layer_norm_backward_dx_strided (line 161) | def _layer_norm_backward_dx_strided(
function _layer_norm_backward_dw_db_partial_strided (line 254) | def _layer_norm_backward_dw_db_partial_strided(
function _layer_norm_backward_buf_reduce (line 308) | def _layer_norm_backward_buf_reduce(
FILE: apex/contrib/openfold_triton/_layer_norm_forward_kernels.py
function _layer_norm_forward (line 34) | def _layer_norm_forward(
function _layer_norm_forward_strided (line 85) | def _layer_norm_forward_strided(
FILE: apex/contrib/openfold_triton/_mha_kernel.py
function init_to_zero (line 7) | def init_to_zero(name):
function get_configs_fwd (line 11) | def get_configs_fwd():
function _attention_core (line 47) | def _attention_core(
function _bwd_preprocess (line 303) | def _bwd_preprocess(
function get_configs_bwd (line 334) | def get_configs_bwd():
function _bwd_kernel (line 371) | def _bwd_kernel(
FILE: apex/contrib/openfold_triton/fused_adam_swa.py
class _DTypeEnum (line 22) | class _DTypeEnum(Enum):
class AdamMathType (line 47) | class AdamMathType(Enum):
function _adam_math (line 54) | def _adam_math(
function _swa_math (line 102) | def _swa_math(
function _multi_tensor_adam_swa (line 116) | def _multi_tensor_adam_swa(
class FusedAdamSWA (line 209) | class FusedAdamSWA(Optimizer):
method __init__ (line 210) | def __init__(
method _build_pointer_buffers (line 281) | def _build_pointer_buffers(self):
method step (line 372) | def step(
method from_optim (line 460) | def from_optim(
FILE: apex/contrib/openfold_triton/layer_norm.py
class LayerNormSmallShapeOptImpl (line 26) | class LayerNormSmallShapeOptImpl(Function):
method forward (line 28) | def forward(ctx, inputs, normalized_shape, weight, bias, eps=1e-05):
method backward (line 90) | def backward(ctx, d_y):
FILE: apex/contrib/openfold_triton/mha.py
function is_enabled (line 20) | def is_enabled() -> Optional[bool]:
function enable (line 25) | def enable() -> None:
function disable (line 30) | def disable() -> None:
function CanSchTriMHA (line 36) | def CanSchTriMHA(in_shape, has_bias=True, inf=1e9, training=True):
function schedule_triton_mha (line 90) | def schedule_triton_mha(in_shape, fwd=True):
class FusedAttenionCoreFunc (line 131) | class FusedAttenionCoreFunc(torch.autograd.Function):
method forward (line 133) | def forward(ctx, q, k, v, mask=None, bias=None, inf=1000000000.0, is_t...
method backward (line 246) | def backward(ctx, do):
function _attention_bias (line 396) | def _attention_bias(
function _attention_no_bias (line 434) | def _attention_no_bias(
FILE: apex/contrib/optimizers/distributed_fused_adam.py
function _coalescing_manager (line 64) | def _coalescing_manager(group, device, reqs):
class _CoalescingManager (line 74) | class _CoalescingManager:
method __init__ (line 75) | def __init__(self):
method append (line 78) | def append(self, work: torch.distributed.Work) -> None:
method wait (line 82) | def wait(self) -> None:
function _coalescing_manager (line 87) | def _coalescing_manager(
function _coalescing_manager_append_work (line 103) | def _coalescing_manager_append_work(
function _coalescing_manager_append_work (line 113) | def _coalescing_manager_append_work(
function _round_to_multiple (line 141) | def _round_to_multiple(
function _devices_match (line 150) | def _devices_match(device1: torch.device, device2: torch.device) -> bool:
function _multi_tensor_copy (line 168) | def _multi_tensor_copy(
function _disable_pre_forward_hook (line 225) | def _disable_pre_forward_hook(
function _bf16_rem_to_fp32 (line 244) | def _bf16_rem_to_fp32(
class DistributedFusedAdam (line 270) | class DistributedFusedAdam(torch.optim.Optimizer):
class ParameterFragment (line 389) | class ParameterFragment:
class StateBucket (line 416) | class StateBucket:
method __init__ (line 419) | def __init__(
method dtypes (line 478) | def dtypes(self) -> Tuple[torch.dtype, torch.dtype, torch.dtype]:
class GradientStatus (line 486) | class GradientStatus(enum.Enum):
class GradientBucket (line 498) | class GradientBucket:
method __init__ (line 501) | def __init__(self):
class ParameterStatus (line 513) | class ParameterStatus(enum.Enum):
class ParameterBucket (line 523) | class ParameterBucket:
method __init__ (line 526) | def __init__(self):
method __init__ (line 540) | def __init__(
method __repr__ (line 827) | def __repr__(self) -> str:
method _broadcast_params (line 847) | def _broadcast_params(self) -> None:
method _make_post_backward_hook (line 864) | def _make_post_backward_hook(
method _register_post_backward_hooks (line 899) | def _register_post_backward_hooks(self) -> None:
method _make_pre_forward_hook (line 915) | def _make_pre_forward_hook(
method _register_pre_forward_hooks (line 938) | def _register_pre_forward_hooks(self) -> None:
method init_param_buffer (line 1074) | def init_param_buffer(self) -> None:
method _init_grad_buffer (line 1156) | def _init_grad_buffer(self) -> None:
method parameters (line 1195) | def parameters(self) -> Iterable[torch.nn.Parameter]:
method parameter (line 1199) | def parameter(
method init_params (line 1228) | def init_params(
method init_params_bucket (line 1275) | def init_params_bucket(
method _init_param_state (line 1347) | def _init_param_state(
method zero_grad (line 1551) | def zero_grad(self, set_to_none: bool = False) -> None:
method _grad_copy (line 1600) | def _grad_copy(self, param: torch.nn.Parameter) -> None:
method _param_copy (line 1668) | def _param_copy(
method _param_copy_fragments (line 1716) | def _param_copy_fragments(
method grad_buffer_view (line 1769) | def grad_buffer_view(self, param: torch.nn.Parameter) -> torch.Tensor:
method _force_bucket_grad_sync (line 1796) | def _force_bucket_grad_sync(self) -> None:
method _try_start_bucket_grad_sync (line 1827) | def _try_start_bucket_grad_sync(
method _start_bucket_grad_sync (line 1877) | def _start_bucket_grad_sync(self, buckets: List[GradientBucket]) -> None:
method _finish_bucket_grad_sync (line 1966) | def _finish_bucket_grad_sync(self) -> None:
method _try_start_bucket_param_sync (line 1986) | def _try_start_bucket_param_sync(
method _start_bucket_param_sync (line 2032) | def _start_bucket_param_sync(self, buckets: List[ParameterBucket]) -> ...
method _finish_bucket_param_sync (line 2084) | def _finish_bucket_param_sync(self) -> None:
method no_sync (line 2095) | def no_sync(
method grad_sync (line 2124) | def grad_sync(self) -> None:
method param_sync (line 2138) | def param_sync(self) -> None:
method _local_grad_norm (line 2150) | def _local_grad_norm(
method grad_norm (line 2236) | def grad_norm(
method clip_grad_norm (line 2275) | def clip_grad_norm(
method unscale_grads (line 2307) | def unscale_grads(
method step (line 2368) | def step(
method _local_step (line 2505) | def _local_step(self, bucket_ids: List[int]) -> None:
method _local_step_with_param_remainders (line 2611) | def _local_step_with_param_remainders(
method _local_step_with_scaled_states (line 2694) | def _local_step_with_scaled_states(
method _check_params_shard_dtypes (line 2777) | def _check_params_shard_dtypes(
method _apply_state_scale (line 2834) | def _apply_state_scale(
method state_dict (line 2862) | def state_dict(
method _state_dict_v1 (line 2907) | def _state_dict_v1(self, gather_on_root: bool = True) -> Optional[dict]:
method _state_dict_v2 (line 3059) | def _state_dict_v2(self) -> Optional[dict]:
method load_state_dict (line 3329) | def load_state_dict(self, state_dict: dict) -> None:
method _load_state_dict_v1 (line 3351) | def _load_state_dict_v1(self, state_dict: dict) -> None:
method _load_state_dict_v2 (line 3397) | def _load_state_dict_v2(self, state_dict: dict) -> None:
FILE: apex/contrib/optimizers/distributed_fused_lamb.py
function get_process_group_ranks (line 15) | def get_process_group_ranks(group):
class DistributedFusedLAMB (line 26) | class DistributedFusedLAMB(torch.optim.Optimizer):
class AtomicCounter (line 86) | class AtomicCounter(object):
method __init__ (line 87) | def __init__(self):
method add (line 94) | def add(self, idx):
method __init__ (line 99) | def __init__(
method _lazy_init_stage1 (line 387) | def _lazy_init_stage1(self):
method _lazy_init_stage2 (line 586) | def _lazy_init_stage2(self):
method set_is_accumulation_step (line 787) | def set_is_accumulation_step(self, is_accumulation_step):
method set_last_step (line 790) | def set_last_step(self, last_step):
method _get_flush_block (line 793) | def _get_flush_block(self):
method _full_all_reduce_scale (line 816) | def _full_all_reduce_scale(self, block_id, scale):
method _full_all_reduce (line 845) | def _full_all_reduce(self, block_id):
method _reduce_scatter_and_all_reduce_scale (line 860) | def _reduce_scatter_and_all_reduce_scale(self, block_id, scale):
method _reduce_scatter_and_all_reduce (line 903) | def _reduce_scatter_and_all_reduce(self, block_id):
method _pipeline_block_reductions (line 943) | def _pipeline_block_reductions(self, block_id):
method __compute_contrib_param_norm (line 1014) | def __compute_contrib_param_norm(self):
method __compute_contrib_update_norm (line 1054) | def __compute_contrib_update_norm(self):
method _pipeline_step (line 1070) | def _pipeline_step(self):
method _flatten_grad_mt (line 1167) | def _flatten_grad_mt(self, scale):
method _do_overlapped_reduction (line 1206) | def _do_overlapped_reduction(self, param_i, param):
method set_global_scale (line 1222) | def set_global_scale(self, global_scale):
method global_scale (line 1227) | def global_scale(self):
method L2_grad_norm (line 1231) | def L2_grad_norm(self):
method complete_reductions (line 1235) | def complete_reductions(self):
method step (line 1257) | def step(self, closure=None, grad_scaler=None):
method state_dict (line 1295) | def state_dict(self):
method load_state_dict (line 1312) | def load_state_dict(self, state_dict):
FILE: apex/contrib/optimizers/fp16_optimizer.py
class FP16_Optimizer (line 5) | class FP16_Optimizer(object):
method __init__ (line 26) | def __init__(
method zero_grad (line 82) | def zero_grad(self, set_grads_to_None=True):
method step (line 97) | def step(self, closure=None):
method backward (line 137) | def backward(self, loss):
method _update_scale (line 147) | def _update_scale(self, skip):
method _get_state (line 166) | def _get_state(self):
method _set_state (line 169) | def _set_state(self, value):
method _get_param_groups (line 176) | def _get_param_groups(self):
method _set_param_groups (line 179) | def _set_param_groups(self, value):
method state_dict (line 184) | def state_dict(self):
method load_state_dict (line 207) | def load_state_dict(self, state_dict):
FILE: apex/contrib/optimizers/fused_adam.py
class FusedAdam (line 7) | class FusedAdam(torch.optim.Optimizer):
method __init__ (line 38) | def __init__(
method step (line 78) | def step(self, closure=None, grads=None, output_params=None, scale=1.0...
FILE: apex/contrib/optimizers/fused_lamb.py
class FusedLAMB (line 7) | class FusedLAMB(torch.optim.Optimizer):
method __init__ (line 63) | def __init__(
method zero_grad (line 102) | def zero_grad(self):
method step (line 110) | def step(self, closure=None):
FILE: apex/contrib/optimizers/fused_sgd.py
class FusedSGD (line 8) | class FusedSGD(Optimizer):
method __init__ (line 67) | def __init__(
method __setstate__ (line 107) | def __setstate__(self, state):
method get_momentums (line 112) | def get_momentums(self, params):
method step (line 129) | def step(self, closure=None, grads=None, output_params=None, scale=1.0...
FILE: apex/contrib/peer_memory/peer_halo_exchanger_1d.py
class PeerHaloExchanger1d (line 5) | class PeerHaloExchanger1d:
method __init__ (line 6) | def __init__(self, ranks, rank_in_group, peer_pool, half_halo):
method _allocate_peer_tensor (line 18) | def _allocate_peer_tensor(self, halo):
method __call__ (line 29) | def __call__(self, y, H_split=True, explicit_nhwc=False, numSM=0, diag...
FILE: apex/contrib/peer_memory/peer_memory.py
class PeerMemoryPool (line 6) | class PeerMemoryPool(object):
method __init__ (line 7) | def __init__(self, static_size, dynamic_size, peer_ranks=None):
method __del__ (line 47) | def __del__(self):
method reset (line 50) | def reset(self):
method allocate_peer_tensors (line 53) | def allocate_peer_tensors(self, shape, dtype, channels_last, dynamic):
FILE: apex/contrib/sparsity/asp.py
function eligible_modules (line 17) | def eligible_modules(model, whitelist_layer_types, allowed_layer_names, ...
class ASP (line 27) | class ASP:
method init_model_for_pruning (line 39) | def init_model_for_pruning(
method already_init_asp_model (line 257) | def already_init_asp_model(cls):
method init_optimizer_for_pruning (line 269) | def init_optimizer_for_pruning(cls, optimizer):
method compute_sparse_masks (line 314) | def compute_sparse_masks(cls):
method restore_pruned_weights (line 390) | def restore_pruned_weights(cls):
method is_sparsity_enabled (line 410) | def is_sparsity_enabled(cls):
method prune_trained_model (line 431) | def prune_trained_model(cls, model, optimizer):
method set_permutation_saving_params (line 444) | def set_permutation_saving_params(
FILE: apex/contrib/sparsity/permutation_lib.py
function convert_fx_node_name (line 24) | def convert_fx_node_name(fx_node_name):
function get_node_parent_children (line 29) | def get_node_parent_children(fx_node):
function node_name_matches (line 54) | def node_name_matches(node_name, module_name):
function replicate_sequence (line 76) | def replicate_sequence(sequence, replications):
class Permutation (line 88) | class Permutation:
method set_identical_seed (line 158) | def set_identical_seed(cls, identical_seed=1):
method reset_seed (line 172) | def reset_seed(cls):
method set_tcpstore_port (line 188) | def set_tcpstore_port(cls, tcpstore_port):
method set_permutation_saving_params (line 196) | def set_permutation_saving_params(
method set_permutation_params_from_asp (line 214) | def set_permutation_params_from_asp(cls, model, sparse_parameters, all...
method permute_model (line 249) | def permute_model(
method get_permutation_stats (line 331) | def get_permutation_stats(cls):
method apply_permutation_in_C_dim (line 341) | def apply_permutation_in_C_dim(cls, node_name, permutation_sequence, d...
method permute_attr (line 442) | def permute_attr(cls, node_name, permutation_sequence, fx_graph, dryrun):
method apply_permutation_in_K_dim (line 482) | def apply_permutation_in_K_dim(cls, node_name, permutation_sequence, f...
method check_graph_for_unpermuted_nodes (line 566) | def check_graph_for_unpermuted_nodes(cls, fx_graph):
method find_sparse_parameters_for_node (line 624) | def find_sparse_parameters_for_node(cls, node_name):
method find_permutation_for_matrix_group (line 661) | def find_permutation_for_matrix_group(cls, matrix_group):
method skip_sibling_group (line 744) | def skip_sibling_group(cls, fx_graph, sibling_group_id, reason):
method collect_sparse_weights (line 762) | def collect_sparse_weights(cls, fx_graph, sibling_group, sibling_group...
method find_sibling_group_permutation (line 808) | def find_sibling_group_permutation(cls, fx_graph, sibling_group_id):
method permute_sibling_group (line 847) | def permute_sibling_group(cls, fx_graph, sibling_group_id, group_permu...
method apply_permutation_in_K_dim_to_children (line 914) | def apply_permutation_in_K_dim_to_children(cls, fx_graph, node_name, p...
method defer_prints (line 962) | def defer_prints(cls):
method resume_prints (line 978) | def resume_prints(cls):
method find_permutations (line 991) | def find_permutations(cls, fx_graph):
method sync_permutations (line 1026) | def sync_permutations(cls, fx_graph):
method apply_permutations (line 1092) | def apply_permutations(cls, fx_graph):
method insert_MHA_out_proj (line 1104) | def insert_MHA_out_proj(fx_graph, MHA_node, verbosity):
method init_grouped_conv_permutation_flags (line 1152) | def init_grouped_conv_permutation_flags(fx_graph, node_name, node_grou...
method init_permutation_flags (line 1183) | def init_permutation_flags(cls, fx_graph):
method collect_siblings (line 1302) | def collect_siblings(fx_graph, node_name, all_siblings):
method propagate_sibling_group (line 1324) | def propagate_sibling_group(fx_graph, all_siblings, verbosity):
method collect_coparents (line 1362) | def collect_coparents(fx_graph, node_name, all_coparents):
method propagate_coparent_group (line 1392) | def propagate_coparent_group(fx_graph, all_coparents, verbosity):
method fixup_concats (line 1435) | def fixup_concats(cls, fx_graph):
method enforce_dimension_agreement (line 1491) | def enforce_dimension_agreement(cls, fx_graph):
method make_sibling_coparent_groups (line 1543) | def make_sibling_coparent_groups(cls, fx_graph):
method propagate_permutation_flags (line 1596) | def propagate_permutation_flags(cls, fx_graph):
method find_node_real_children (line 1662) | def find_node_real_children(cls, fx_graph, node_name, found_children):
method find_real_children (line 1683) | def find_real_children(cls, fx_graph):
method find_node_real_parents (line 1724) | def find_node_real_parents(cls, fx_graph, node_name, found_parents):
method find_real_parents (line 1745) | def find_real_parents(cls, fx_graph):
method build_fx_graph (line 1776) | def build_fx_graph(
method trace_and_print_raw_fx_graph (line 1999) | def trace_and_print_raw_fx_graph(cls, model, print_tabular=False, gene...
method save_graph_to_json (line 2060) | def save_graph_to_json(cls, graph, save_dumped_graph_path_with_name="....
FILE: apex/contrib/sparsity/permutation_search_kernels/call_permutation_search_kernels.py
function accelerated_search_for_good_permutation (line 6) | def accelerated_search_for_good_permutation(matrix_group, options=None, ...
FILE: apex/contrib/sparsity/permutation_search_kernels/channel_swap.py
function try_swap (line 12) | def try_swap(matrix, dst, src):
function stripes_and_swap_idx_to_columns (line 30) | def stripes_and_swap_idx_to_columns(stripe0, stripe1, idx):
function columns_to_stripes_and_swap_idx (line 41) | def columns_to_stripes_and_swap_idx(col0, col1):
function build_stripe_pairs (line 57) | def build_stripe_pairs(matrix, used_stripes):
function compute_swap_map (line 72) | def compute_swap_map(matrix, used_stripes):
function build_swap_map (line 94) | def build_swap_map(matrix, swap_map, swap_ids, used_stripes, verbosity):
function use_swap_map (line 141) | def use_swap_map(
function Channel_Swap (line 209) | def Channel_Swap(matrix, escape_attempts=0, verbosity=0, permutation=None):
FILE: apex/contrib/sparsity/permutation_search_kernels/exhaustive_search.py
function is_canonical (line 21) | def is_canonical(perm, col):
function generate_unique_combinations (line 36) | def generate_unique_combinations(
function generate_all_unique_combinations (line 76) | def generate_all_unique_combinations(C, M, must_use_all_groups=False):
function predict_unique_combinations (line 102) | def predict_unique_combinations(C, M):
function search_matrix (line 114) | def search_matrix(matrix, group_width):
function collect_stripes (line 156) | def collect_stripes(matrix, stripes, group_width):
function apply_stripe_group_permutation (line 166) | def apply_stripe_group_permutation(sgp, stripes, group_width, permutation):
function generate_stripe_groups (line 184) | def generate_stripe_groups(num_stripes, window_size):
function build_stripe_map (line 209) | def build_stripe_map(
function use_stripe_map (line 296) | def use_stripe_map(matrix, group_width, stripe_map, stripe_ids, perm_map...
function Exhaustive_Search (line 374) | def Exhaustive_Search(matrix, stripe_group_size=-1, escape_attempts=0, p...
FILE: apex/contrib/sparsity/permutation_search_kernels/permutation_utilities.py
function use_gpu (line 23) | def use_gpu(initial_override=True):
function apply_2_to_4 (line 46) | def apply_2_to_4(matrix):
function sum_after_2_to_4 (line 56) | def sum_after_2_to_4(matrix):
function unstructured_prune (line 87) | def unstructured_prune(matrix, sparsity):
function try_swap (line 98) | def try_swap(matrix, dst, src):
function efficacy (line 116) | def efficacy(optimal_lost_magnitude, base_lost_magnitude, cur_lost_magni...
function magnitude_after_pruning_rows (line 127) | def magnitude_after_pruning_rows(matrix, rate=0.5):
function try_permutations_on_matrix (line 144) | def try_permutations_on_matrix(matrix, permutations):
function find_permutation (line 174) | def find_permutation(A, B):
function make_grouped (line 193) | def make_grouped(A):
function common_groups (line 206) | def common_groups(A, B):
function remove_common_groups (line 226) | def remove_common_groups(A, B):
function group_differences (line 256) | def group_differences(A, B):
function dictify (line 274) | def dictify(wrong_entries):
function move_groups_to_match (line 286) | def move_groups_to_match(B, A, debug=False):
function swap_and_correct (line 404) | def swap_and_correct(permutation, src, tgt):
function move_permutation_towards (line 415) | def move_permutation_towards(B, A, debug=False):
function permutation_distance (line 558) | def permutation_distance(A, B, matrix=None, magnitude_targets=None, debu...
FILE: apex/contrib/sparsity/permutation_tests/permutation_test.py
function str2bool (line 15) | def str2bool(v):
function find_minimum_sparsity (line 72) | def find_minimum_sparsity(matrix, search_function, **kwargs):
FILE: apex/contrib/sparsity/sparse_masklib.py
function fill (line 11) | def fill(x):
function reshape_1d (line 18) | def reshape_1d(matrix, m):
function compute_valid_1d_patterns (line 35) | def compute_valid_1d_patterns(m, n):
function mn_1d_best (line 52) | def mn_1d_best(matrix, m, n):
function m4n2_1d (line 65) | def m4n2_1d(mat, density):
function mn_2d_greedy (line 86) | def mn_2d_greedy(matrix, m, n):
function m4n2_2d_greedy (line 118) | def m4n2_2d_greedy(mat, density):
function compute_valid_2d_patterns (line 126) | def compute_valid_2d_patterns(m, n):
function mn_2d_best (line 150) | def mn_2d_best(matrix, m, n):
function m4n2_2d_best (line 169) | def m4n2_2d_best(mat, density):
function create_mask (line 176) | def create_mask(tensor, pattern="m4n2_1d", density=0.5):
FILE: apex/contrib/sparsity/test/checkpointing_test_part1.py
function build_model (line 8) | def build_model(args):
function train_step (line 35) | def train_step(args, model, optimizer, input_batch, target_batch, step):
function train_loop (line 46) | def train_loop(args, model, optimizer, step, num_steps):
function main (line 54) | def main(args):
class Args (line 103) | class Args:
FILE: apex/contrib/sparsity/test/checkpointing_test_part2.py
function build_model (line 8) | def build_model(args):
function train_step (line 35) | def train_step(args, model, optimizer, input_batch, target_batch, step):
function train_loop (line 46) | def train_loop(args, model, optimizer, step, num_steps):
function main (line 54) | def main(step, args, model_state_dict, optimizer_state_dict):
class Args (line 85) | class Args:
FILE: apex/contrib/sparsity/test/checkpointing_test_reference.py
function build_model (line 12) | def build_model(args):
function train_step (line 39) | def train_step(args, model, optimizer, input_batch, target_batch, step):
function train_loop (line 50) | def train_loop(args, model, optimizer, step, num_steps):
function main (line 58) | def main(args):
class Args (line 102) | class Args:
FILE: apex/contrib/sparsity/test/test_permutation_application.py
class simple_convs (line 21) | class simple_convs(torch.nn.Module):
method __init__ (line 24) | def __init__(
method forward (line 98) | def forward(self, x: torch.Tensor):
class conv_1d (line 104) | class conv_1d(torch.nn.Module):
method __init__ (line 107) | def __init__(
method forward (line 131) | def forward(self, x: torch.Tensor):
class grouped_convs (line 144) | class grouped_convs(torch.nn.Module):
method __init__ (line 147) | def __init__(
method forward (line 208) | def forward(self, input: torch.Tensor):
class simple_forks_joins (line 212) | class simple_forks_joins(torch.nn.Module):
method __init__ (line 215) | def __init__(
method forward (line 289) | def forward(self, input: torch.Tensor):
class different_grouped_convs (line 296) | class different_grouped_convs(torch.nn.Module):
method __init__ (line 299) | def __init__(
method forward (line 338) | def forward(self, input: torch.Tensor):
class siblings_poison (line 346) | class siblings_poison(torch.nn.Module):
method __init__ (line 349) | def __init__(
method forward (line 380) | def forward(self, input: torch.Tensor):
class coparent_poison (line 386) | class coparent_poison(torch.nn.Module):
method __init__ (line 389) | def __init__(
method forward (line 420) | def forward(self, input: torch.Tensor):
class depthwise_child_is_sibling (line 426) | class depthwise_child_is_sibling(torch.nn.Module):
method __init__ (line 429) | def __init__(
method forward (line 463) | def forward(self, input: torch.Tensor):
class module_attribute (line 469) | class module_attribute(torch.nn.Module):
method __init__ (line 472) | def __init__(
method forward (line 511) | def forward(self, input: torch.Tensor):
class square_attribute (line 520) | class square_attribute(torch.nn.Module):
method __init__ (line 525) | def __init__(
method forward (line 543) | def forward(self, input: torch.Tensor):
class MHA_test (line 549) | class MHA_test(torch.nn.Module):
method __init__ (line 552) | def __init__(self, hidden_dim: int = 256, seq_len: int = 64, num_heads...
method forward (line 569) | def forward(self, input: torch.Tensor):
class one_sparse_sibling (line 575) | class one_sparse_sibling(torch.nn.Module):
method __init__ (line 578) | def __init__(
method forward (line 609) | def forward(self, input: torch.Tensor):
class test_concat (line 615) | class test_concat(torch.nn.Module):
method __init__ (line 618) | def __init__(
method forward (line 666) | def forward(self, input: torch.Tensor):
class test_flatten_op (line 679) | class test_flatten_op(torch.nn.Module):
method __init__ (line 682) | def __init__(
method forward (line 702) | def forward(self, input: torch.Tensor):
class test_flatten_module (line 708) | class test_flatten_module(torch.nn.Module):
method __init__ (line 711) | def __init__(
method forward (line 734) | def forward(self, input: torch.Tensor):
class test_trace_failure (line 738) | class test_trace_failure(torch.nn.Module):
method __init__ (line 741) | def __init__(self):
method forward (line 750) | def forward(self, input: torch.Tensor):
class already_sparse (line 760) | class already_sparse(torch.nn.Module):
method __init__ (line 763) | def __init__(self):
method forward (line 778) | def forward(self, input: torch.Tensor):
function test_model (line 783) | def test_model(model, tag, verbosity=0, save_onnx=False):
function main (line 919) | def main():
FILE: apex/contrib/sparsity/test/toy_problem.py
function build_model (line 8) | def build_model(args):
function train_step (line 35) | def train_step(args, model, optimizer, input_batch, target_batch, step):
function train_loop (line 46) | def train_loop(args, model, optimizer, step, num_steps):
function main (line 54) | def main(args):
class Args (line 95) | class Args:
FILE: apex/contrib/test/bottleneck/test_bottleneck_module.py
function ground_truth_bottleneck (line 17) | def ground_truth_bottleneck(C, dtype, explicit_nhwc):
function print_bottleneck_p_and_b (line 27) | def print_bottleneck_p_and_b(bottleneck):
function has_nan (line 35) | def has_nan(x):
function rel_diff_t (line 49) | def rel_diff_t(xx1, xx2):
function rel_diff (line 55) | def rel_diff(x1, x2):
function graph_it (line 64) | def graph_it(bottleneck, x):
function clone_inputs (line 73) | def clone_inputs(bottleneck, x, dy=None):
function fprop_and_bprop (line 85) | def fprop_and_bprop(bottleneck, x, dy):
function ground_truth (line 95) | def ground_truth(N, C, H, W, dtype, memory_format, bottleneck):
function print_ground_truth (line 111) | def print_ground_truth(gt):
function apply_to_different_bottleneck (line 119) | def apply_to_different_bottleneck(gt, bottleneck):
function compare_single_field (line 126) | def compare_single_field(results, f1, f2, l0, l1, l2):
function compare (line 137) | def compare(gt, bt):
function spatial_parallel_bottleneck (line 151) | def spatial_parallel_bottleneck(C, dtype, explicit_nhwc, gt_bottleneck, ...
function n_way_spatial (line 175) | def n_way_spatial(halex, gt_bottleneck, gt, explicit_nhwc, world_size, r...
function main (line 226) | def main():
class TestBottleneck (line 283) | class TestBottleneck(NcclDistributedTestBase):
method world_size (line 288) | def world_size(self) -> int:
method test_bottleneck_without_peer_memory (line 291) | def test_bottleneck_without_peer_memory(self) -> None:
method test_bottleneck_with_peer_memory (line 311) | def test_bottleneck_with_peer_memory(self) -> None:
FILE: apex/contrib/test/clip_grad/test_clip_grad.py
function make_params (line 13) | def make_params(
class ClipGradNormTest (line 46) | class ClipGradNormTest(unittest.TestCase):
method setUp (line 47) | def setUp(self, seed=1234):
method test_matches_pytorch (line 52) | def test_matches_pytorch(
method test_matches_pytorch_fp16 (line 106) | def test_matches_pytorch_fp16(self):
method test_matches_pytorch_fp32 (line 109) | def test_matches_pytorch_fp32(self):
method test_matches_pytorch_fp64 (line 112) | def test_matches_pytorch_fp64(self):
method test_matches_pytorch_cpu (line 115) | def test_matches_pytorch_cpu(self):
method test_matches_pytorch_infnorm (line 118) | def test_matches_pytorch_infnorm(self):
method test_matches_pytorch_1norm (line 121) | def test_matches_pytorch_1norm(self):
method test_raises_on_mismatch (line 124) | def test_raises_on_mismatch(self):
method test_raises_on_nan (line 161) | def test_raises_on_nan(self):
method test_raises_on_inf (line 166) | def test_raises_on_inf(self):
FILE: apex/contrib/test/conv_bias_relu/test_conv_bias_relu.py
class FusedDenseTest (line 24) | class FusedDenseTest(unittest.TestCase):
method setUp (line 25) | def setUp(self, seed=0):
method test_conv_bias_relu (line 119) | def test_conv_bias_relu(self):
method test_conv_bias (line 152) | def test_conv_bias(self):
method test_conv_bias_mask_relu (line 186) | def test_conv_bias_mask_relu(self):
method test_conv_frozen_scale_bias_relu (line 220) | def test_conv_frozen_scale_bias_relu(self):
FILE: apex/contrib/test/cudnn_gbn/test_cudnn_gbn_with_two_gpus.py
class BNModelRef (line 39) | class BNModelRef(nn.Module):
method __init__ (line 40) | def __init__(self, num_features, num_layers=1000):
method forward (line 55) | def forward(self, x):
class BNModel (line 59) | class BNModel(nn.Module):
method __init__ (line 60) | def __init__(self, num_features, num_layers=1000):
method forward (line 76) | def forward(self, x):
function get_rand_tensors (line 80) | def get_rand_tensors(global_shape, device):
class TestCudnnGBN (line 93) | class TestCudnnGBN(NcclDistributedTestBase):
method _prep (line 94) | def _prep(self):
method world_size (line 99) | def world_size(self) -> int:
method _test_cudnn_gbn (line 103) | def _test_cudnn_gbn(
method test_cudnngbn (line 160) | def test_cudnngbn(self):
FILE: apex/contrib/test/fmha/test_fmha.py
function _get_device_properties (line 41) | def _get_device_properties(device=torch.device("cuda")):
function py_mha (line 47) | def py_mha(qkv, amask, b, s, h, d):
class TestFMHA (line 68) | class TestFMHA(unittest.TestCase):
method run_test (line 69) | def run_test(self, s: int, b: int, zero_tensors: bool):
method test_128 (line 121) | def test_128(self):
method test_256 (line 127) | def test_256(self):
method test_384 (line 133) | def test_384(self):
method test_512 (line 139) | def test_512(self):
FILE: apex/contrib/test/focal_loss/test_focal_loss.py
class FocalLossTest (line 24) | class FocalLossTest(unittest.TestCase):
method test_focal_loss (line 31) | def test_focal_loss(self) -> None:
FILE: apex/contrib/test/fused_dense/test_fused_dense.py
class FusedDenseTest (line 16) | class FusedDenseTest(common_utils.TestCase):
method _test_fused_dense (line 17) | def _test_fused_dense(self, dtype, seed=0):
method test_fused_dense (line 49) | def test_fused_dense(self, dtype):
FILE: apex/contrib/test/group_norm/test_group_norm.py
function torch_group_norm_high_precision (line 27) | def torch_group_norm_high_precision(x, g, w, b, eps, act="", *, compute_...
function relative_ulp (line 49) | def relative_ulp(dtype, device):
function _ref_compute_type (line 56) | def _ref_compute_type(ref_func, xdtype: torch.dtype) -> torch.dtype:
function _estimate_group_norm_test_bytes (line 65) | def _estimate_group_norm_test_bytes(
function _has_sufficient_cuda_memory (line 102) | def _has_sufficient_cuda_memory(required_bytes: int, *, safety_factor: f...
class GroupNormTest (line 115) | class GroupNormTest(unittest.TestCase):
method setUp (line 116) | def setUp(self, seed=0):
method verify_group_norm (line 120) | def verify_group_norm(
method test_fp16_one_pass_algo (line 183) | def test_fp16_one_pass_algo(self):
method test_fp16_two_pass_algo (line 186) | def test_fp16_two_pass_algo(self):
method test_fp16_one_pass_algo_with_swish (line 189) | def test_fp16_one_pass_algo_with_swish(self):
method test_fp16_two_pass_algo_with_swish (line 192) | def test_fp16_two_pass_algo_with_swish(self):
method test_bf16_one_pass_algo (line 195) | def test_bf16_one_pass_algo(self):
method test_bf16_two_pass_algo (line 198) | def test_bf16_two_pass_algo(self):
method test_bf16_one_pass_algo_with_swish (line 201) | def test_bf16_one_pass_algo_with_swish(self):
method test_bf16_two_pass_algo_with_swish (line 204) | def test_bf16_two_pass_algo_with_swish(self):
method test_fp32_one_pass_algo (line 207) | def test_fp32_one_pass_algo(self):
method test_fp32_two_pass_algo (line 210) | def test_fp32_two_pass_algo(self):
method test_fp32_one_pass_algo_with_swish (line 213) | def test_fp32_one_pass_algo_with_swish(self):
method test_fp32_two_pass_algo_with_swish (line 216) | def test_fp32_two_pass_algo_with_swish(self):
method test_group_norm_module (line 219) | def test_group_norm_module(self):
method test_group_norm_inductor (line 222) | def test_group_norm_inductor(self):
method test_16_groups (line 254) | def test_16_groups(self):
method test_large_batch_two_pass (line 283) | def test_large_batch_two_pass(self):
method test_fp16_parameters (line 318) | def test_fp16_parameters(self):
method get_v2_hw_c_list (line 334) | def get_v2_hw_c_list():
method check_v2_cc_and_sm_count (line 344) | def check_v2_cc_and_sm_count(self):
method skip_if_v2_not_supported (line 351) | def skip_if_v2_not_supported(self):
method test_check_v2_legality (line 358) | def test_check_v2_legality(self):
method test_fp16_v2_32_groups (line 404) | def test_fp16_v2_32_groups(self):
method test_fp16_v2_16_groups_with_swish (line 422) | def test_fp16_v2_16_groups_with_swish(self):
method test_bf16_v2_32_groups (line 440) | def test_bf16_v2_32_groups(self):
method test_bf16_v2_16_groups_with_swish (line 458) | def test_bf16_v2_16_groups_with_swish(self):
FILE: apex/contrib/test/index_mul_2d/test_index_mul_2d.py
class IndexMul2dTest (line 16) | class IndexMul2dTest(unittest.TestCase):
method setUp (line 17) | def setUp(self, seed=0):
method test_index_mul_float (line 63) | def test_index_mul_float(self):
method test_index_mul_half (line 107) | def test_index_mul_half(self):
FILE: apex/contrib/test/layer_norm/test_fast_layer_norm.py
class GPUTimer (line 14) | class GPUTimer:
method __init__ (line 15) | def __init__(self, stream):
method start (line 20) | def start(self):
method stop (line 23) | def stop(self):
method sync (line 26) | def sync(self):
method millis (line 29) | def millis(self):
function size_in_bytes (line 33) | def size_in_bytes(t):
function metrics (line 37) | def metrics(y_ref, y, epsilon=1e-6):
function backward_ (line 53) | def backward_(dz, x, mu, rs, gamma):
function benchmark_ (line 74) | def benchmark_(S, B, hidden_size, itype, wtype, runs=100):
function _test_impl (line 141) | def _test_impl(S, B, hidden_size, itype, wtype, ctype=fp32, mem_eff=False):
class TestFastLayerNorm (line 211) | class TestFastLayerNorm(unittest.TestCase):
method assertAll (line 213) | def assertAll(self, l):
method test_all_configs (line 219) | def test_all_configs(self):
method test_run_benchmark (line 257) | def test_run_benchmark(self):
method test_compat_with_autocast (line 272) | def test_compat_with_autocast(self):
FILE: apex/contrib/test/multihead_attn/test_encdec_multihead_attn.py
class EncdecMultiheadAttnTest (line 13) | class EncdecMultiheadAttnTest(unittest.TestCase):
method setUp (line 14) | def setUp(self, seed=1234):
method test_encdec_multihead_attn (line 78) | def test_encdec_multihead_attn(self):
method test_encdec_multihead_attn_time_mask (line 111) | def test_encdec_multihead_attn_time_mask(self):
method test_encdec_multihead_attn_pad_mask (line 154) | def test_encdec_multihead_attn_pad_mask(self):
FILE: apex/contrib/test/multihead_attn/test_encdec_multihead_attn_norm_add.py
class EncdecMultiheadAttnNormAddTest (line 13) | class EncdecMultiheadAttnNormAddTest(unittest.TestCase):
method setUp (line 14) | def setUp(self, seed=1234):
method test_encdec_multihead_attn_norm_add (line 79) | def test_encdec_multihead_attn_norm_add(self):
FILE: apex/contrib/test/multihead_attn/test_fast_self_multihead_attn_bias.py
class SelfMultiheadAttnTest (line 13) | class SelfMultiheadAttnTest(unittest.TestCase):
method setUp (line 14) | def setUp(self, seed=1234):
method test_self_multihead_attn_additive_mask (line 68) | def test_self_multihead_attn_additive_mask(self):
FILE: apex/contrib/test/multihead_attn/test_mha_fused_softmax.py
class FusedSoftmaxTest (line 14) | class FusedSoftmaxTest(unittest.TestCase):
method setUp (line 15) | def setUp(self, seed=1234):
method test_fused_softmax (line 37) | def test_fused_softmax(self):
FILE: apex/contrib/test/multihead_attn/test_self_multihead_attn.py
class SelfMultiheadAttnTest (line 13) | class SelfMultiheadAttnTest(unittest.TestCase):
method setUp (line 14) | def setUp(self, seed=1234):
method test_self_multihead_attn (line 65) | def test_self_multihead_attn(self):
method test_self_multihead_attn_time_mask (line 97) | def test_self_multihead_attn_time_mask(self):
method test_self_multihead_attn_pad_mask (line 137) | def test_self_multihead_attn_pad_mask(self):
FILE: apex/contrib/test/multihead_attn/test_self_multihead_attn_norm_add.py
class SelfMultiheadAttnNormAddTest (line 13) | class SelfMultiheadAttnNormAddTest(unittest.TestCase):
method setUp (line 14) | def setUp(self, seed=1234):
method test_self_multihead_attn_norm_add (line 64) | def test_self_multihead_attn_norm_add(self):
FILE: apex/contrib/test/openfold_triton/test_fused_adam_swa.py
class AlphaFoldSWA (line 31) | class AlphaFoldSWA(nn.Module):
method __init__ (line 34) | def __init__(self, alphafold: nn.Module, enabled: bool, decay_rate: fl...
method update (line 46) | def update(self, alphafold: nn.Module) -> None:
method forward (line 50) | def forward(self, batch):
class swa_avg_fn (line 56) | class swa_avg_fn:
method __init__ (line 60) | def __init__(self, decay_rate: float) -> None:
method __call__ (line 63) | def __call__(
class FusedAdamSWATestCase (line 84) | class FusedAdamSWATestCase(unittest.TestCase):
method setUp (line 85) | def setUp(self):
method tearDown (line 96) | def tearDown(self):
method test_fused_update_on_random_data (line 99) | def test_fused_update_on_random_data(self):
method _run_fused_update_on_random_data (line 103) | def _run_fused_update_on_random_data(self):
FILE: apex/contrib/test/openfold_triton/test_openfold_mha.py
function openfold_attention_eager (line 14) | def openfold_attention_eager(
class OpenfoldMhaTest (line 54) | class OpenfoldMhaTest(unittest.TestCase):
method setUp (line 55) | def setUp(self, seed=1234):
method test_openfold_triton_mha (line 61) | def test_openfold_triton_mha(self, Z=256, H=4, N_CTX=256, D_HEAD=32, d...
FILE: apex/contrib/test/openfold_triton/test_sync_triton_auto_tune_cache_across_gpus.py
class SyncTritonAutoTuneCacheTest (line 19) | class SyncTritonAutoTuneCacheTest(MultiProcessTestCase):
method __init__ (line 22) | def __init__(self, *args, **kwargs) -> None:
method setUp (line 25) | def setUp(self) -> None:
method tearDown (line 29) | def tearDown(self) -> None:
method world_size (line 35) | def world_size(self) -> int:
method init_method (line 39) | def init_method(self):
method destroy_pg_upon_exit (line 43) | def destroy_pg_upon_exit(self) -> bool:
method _create_process_group_nccl (line 46) | def _create_process_group_nccl(self):
method test_sync_triton_auto_tune_cache_across_gpus (line 69) | def test_sync_triton_auto_tune_cache_across_gpus(self):
FILE: apex/contrib/test/optimizers/test_dist_adam.py
class SimpleModel (line 19) | class SimpleModel(torch.nn.Module):
method __init__ (line 20) | def __init__(self, num_layers, size):
method forward (line 26) | def forward(self, x):
function make_models (line 33) | def make_models(
function dummy_context (line 111) | def dummy_context():
class TestDistributedFusedAdam (line 119) | class TestDistributedFusedAdam(NcclDistributedTestBase):
method test_matches_pytorch (line 122) | def test_matches_pytorch(
method test_matches_pytorch_l2_reg (line 267) | def test_matches_pytorch_l2_reg(self):
method test_matches_pytorch_no_overlap (line 270) | def test_matches_pytorch_no_overlap(self):
method test_matches_pytorch_sync_every_step (line 276) | def test_matches_pytorch_sync_every_step(self):
method test_matches_pytorch_contiguous_buffers (line 279) | def test_matches_pytorch_contiguous_buffers(self):
method test_matches_pytorch_fp64 (line 282) | def test_matches_pytorch_fp64(self):
method test_matches_pytorch_fp16 (line 290) | def test_matches_pytorch_fp16(self):
method test_matches_pytorch_bf16 (line 299) | def test_matches_pytorch_bf16(self):
method test_matches_pytorch_fp16_params (line 308) | def test_matches_pytorch_fp16_params(self):
method test_matches_pytorch_bf16_grads (line 319) | def test_matches_pytorch_bf16_grads(self):
method test_matches_pytorch_bf16_param_remainders (line 329) | def test_matches_pytorch_bf16_param_remainders(self):
method test_matches_pytorch_multi_dtypes (line 341) | def test_matches_pytorch_multi_dtypes(self):
method test_matches_pytorch_int64_param_sync (line 353) | def test_matches_pytorch_int64_param_sync(self):
method test_matches_pytorch_int32_param_sync_contiguous_buffers (line 358) | def test_matches_pytorch_int32_param_sync_contiguous_buffers(self):
method test_matches_pytorch_uint8_param_sync (line 364) | def test_matches_pytorch_uint8_param_sync(self):
method test_matches_pytorch_scaled_state (line 374) | def test_matches_pytorch_scaled_state(self):
method test_matches_pytorch_nccl_ub (line 386) | def test_matches_pytorch_nccl_ub(self):
method test_raises_on_mismatch (line 392) | def test_raises_on_mismatch(self):
method test_clip_grad_norm (line 421) | def test_clip_grad_norm(self):
method test_grad_scaler (line 453) | def test_grad_scaler(self):
method test_checkpoint (line 492) | def test_checkpoint(
method test_checkpoint_save_1gpu (line 756) | def test_checkpoint_save_1gpu(self):
method test_checkpoint_load_1gpu (line 760) | def test_checkpoint_load_1gpu(self):
method test_checkpoint_bf16 (line 764) | def test_checkpoint_bf16(self):
method test_checkpoint_scaled_state (line 785) | def test_checkpoint_scaled_state(self):
method test_bucket_low_utilization_warning (line 806) | def test_bucket_low_utilization_warning(self):
method test_cuda_graph (line 834) | def test_cuda_graph(self):
FILE: apex/contrib/test/optimizers/test_distributed_fused_lamb.py
function flat_dist_call (line 12) | def flat_dist_call(param_list: list[torch.Tensor], op, args):
function get_init_weights_func (line 20) | def get_init_weights_func():
class ModelFoo (line 29) | class ModelFoo(torch.nn.Module):
method __init__ (line 30) | def __init__(self):
method forward (line 35) | def forward(self, input_tensor, gt):
class NcclDistributedFusedLAMB (line 45) | class NcclDistributedFusedLAMB(NcclDistributedTestBase):
method world_size (line 47) | def world_size(self) -> int:
method test_distributed_fused_lamb (line 73) | def test_distributed_fused_lamb(self, no_copy, opt_kwargs):
class NcclDistributedFusedLAMB_partial_ar (line 161) | class NcclDistributedFusedLAMB_partial_ar(NcclDistributedFusedLAMB):
method world_size (line 163) | def world_size(self) -> int:
FILE: apex/contrib/test/peer_memory/test_peer_halo_exchange_module.py
function nccl_halo_ex (line 19) | def nccl_halo_ex(peer_rank, peer_group_size, y, half_halo, explicit_nhwc...
function single_test (line 75) | def single_test(
function H_split_tests (line 170) | def H_split_tests(N, C, H, W, half_halo, rank, world_size, halo_ex, num_...
function W_split_tests (line 217) | def W_split_tests(N, C, H, W, half_halo, rank, world_size, halo_ex, num_...
function main (line 264) | def main():
class TestPeerMemory (line 284) | class TestPeerMemory(NcclDistributedTestBase):
method world_size (line 289) | def world_size(self) -> int:
method _check_world_size_and_may_skip (line 293) | def _check_world_size_and_may_skip(self) -> None:
method get_halo_excnahger_1d (line 297) | def get_halo_excnahger_1d(self):
method test_height_split (line 305) | def test_height_split(self):
method test_width_split (line 319) | def test_width_split(self):
FILE: apex/contrib/test/transducer/test_transducer_joint.py
class TransducerJointTest (line 14) | class TransducerJointTest(unittest.TestCase):
method setUp (line 15) | def setUp(self, seed=1234):
method gen_input (line 18) | def gen_input(self, for_vector_kernel):
method _pack (line 47) | def _pack(self, x, f_len, g_len):
method _unpack (line 59) | def _unpack(self, x, f_len, g_len):
method run_transducer_joint (line 74) | def run_transducer_joint(self, for_vector_kernel, pack_output, relu, d...
method test_transducer_joint (line 130) | def test_transducer_joint(self):
method test_transducer_joint_vec (line 135) | def test_transducer_joint_vec(self):
method test_transducer_joint_pack (line 140) | def test_transducer_joint_pack(self):
method test_transducer_joint_vec_pack (line 145) | def test_transducer_joint_vec_pack(self):
method test_transducer_joint_relu (line 150) | def test_transducer_joint_relu(self):
method test_transducer_joint_vec_relu (line 155) | def test_transducer_joint_vec_relu(self):
method test_transducer_joint_pack_relu (line 160) | def test_transducer_joint_pack_relu(self):
method test_transducer_joint_vec_pack_relu (line 165) | def test_transducer_joint_vec_pack_relu(self):
method test_transducer_joint_relu_dropout (line 171) | def test_transducer_joint_relu_dropout(self):
method test_transducer_joint_vec_relu_dropout (line 175) | def test_transducer_joint_vec_relu_dropout(self):
method test_transducer_joint_pack_relu_dropout (line 181) | def test_transducer_joint_pack_relu_dropout(self):
method test_transducer_joint_vec_pack_relu_dropout (line 187) | def test_transducer_joint_vec_pack_relu_dropout(self):
FILE: apex/contrib/test/transducer/test_transducer_loss.py
class TransducerLossTest (line 14) | class TransducerLossTest(unittest.TestCase):
method setUp (line 15) | def setUp(self, seed=1234):
method gen_input (line 18) | def gen_input(self, scalar_t, for_vector_kernel):
method _pack (line 54) | def _pack(self, x):
method _unpack (line 65) | def _unpack(self, x):
method run_transducer_loss (line 83) | def run_transducer_loss(self, scalar_t, fuse_softmax_backward, packed_...
method test_transducer_loss_fp32 (line 114) | def test_transducer_loss_fp32(self):
method test_transducer_loss_fp16 (line 124) | def test_transducer_loss_fp16(self):
method test_transducer_loss_fp16_backward_fusion (line 134) | def test_transducer_loss_fp16_backward_fusion(self):
method test_transducer_loss_fp16_backward_fusion_packed (line 144) | def test_transducer_loss_fp16_backward_fusion_packed(self):
method test_transducer_loss_fp16_backward_fusion_packed_vec (line 154) | def test_transducer_loss_fp16_backward_fusion_packed_vec(self):
FILE: apex/contrib/test/xentropy/test_label_smoothing.py
function label_smoothing_raw (line 16) | def label_smoothing_raw(x, target, padding_idx, smoothing):
function label_smoothing_opt_1 (line 27) | def label_smoothing_opt_1(x, target, padding_idx, smoothing):
class LabelSmoothingTest (line 39) | class LabelSmoothingTest(unittest.TestCase):
method setUp (line 40) | def setUp(self, seed=1234):
method gen_test_inputs (line 49) | def gen_test_inputs(self, N, T, H, smoothing, padding_idx, dtype=torch...
method print_max_diff_elem (line 58) | def print_max_diff_elem(self, ref, tst):
method _test_label_smoothing_function (line 68) | def _test_label_smoothing_function(self, dtype):
method test_label_smoothing_function_fp16 (line 101) | def test_label_smoothing_function_fp16(self):
method test_label_smoothing_function_bf16 (line 104) | def test_label_smoothing_function_bf16(self):
method test_label_smoothing_perf (line 107) | def test_label_smoothing_perf(self):
FILE: apex/contrib/torchsched/__init__.py
function torchsched (line 30) | def torchsched(
function set_default_backend (line 46) | def set_default_backend(backend: str) -> None:
function torchsched_compile (line 58) | def torchsched_compile(
FILE: apex/contrib/torchsched/backend.py
function enable_multi_stream_scheduling (line 37) | def enable_multi_stream_scheduling(compile_fn: Callable[P, R]) -> Callab...
function convolution_backward_decomp_dwb (line 51) | def convolution_backward_decomp_dwb(
function convolution_backward_decomp_wbd (line 116) | def convolution_backward_decomp_wbd(
class DecompositionsWrapper (line 181) | class DecompositionsWrapper(_TorchCompileInductorWrapper):
method __init__ (line 197) | def __init__(
method __eq__ (line 216) | def __eq__(self, rhs: object) -> bool:
method __call__ (line 232) | def __call__(
function get_backend (line 262) | def get_backend(
FILE: apex/contrib/torchsched/config.py
function _get_skip_post_grad_graph_ids (line 23) | def _get_skip_post_grad_graph_ids() -> set[int]:
function __get_dump_code_backends_and_dir (line 51) | def __get_dump_code_backends_and_dir(
FILE: apex/contrib/torchsched/inductor/_utils.py
function get_stream_name (line 30) | def get_stream_name(stream_idx: int) -> str:
class CUDAStreamPool (line 43) | class CUDAStreamPool:
method __init__ (line 51) | def __init__(self, device: int | None = None, pool_size: int = 8) -> N...
method acquire (line 67) | def acquire(self) -> torch.cuda.Stream:
method release (line 75) | def release(self, stream: torch.cuda.Stream | None) -> None:
method __enter__ (line 84) | def __enter__(self) -> torch.cuda.Stream:
method __exit__ (line 94) | def __exit__(
function get_cuda_stream_pool (line 114) | def get_cuda_stream_pool(device: int | None = None, pool_size: int = 32)...
FILE: apex/contrib/torchsched/inductor/event.py
class CudaEventSym (line 30) | class CudaEventSym:
method __lt__ (line 52) | def __lt__(self, rhs: CudaEventSym) -> bool:
method __eq__ (line 61) | def __eq__(self, rhs: object) -> bool:
method __str__ (line 71) | def __str__(self) -> str:
method __hash__ (line 82) | def __hash__(self) -> int:
method record (line 86) | def record(self, stream_idx: int) -> _CudaEventRecordLine:
method wait (line 105) | def wait(self, stream_idx: int) -> _CudaEventWaitLine:
class _CudaEventRecordLine (line 127) | class _CudaEventRecordLine(WrapperLine):
method codegen (line 132) | def codegen(self, code: IndentedBuffer) -> None:
class _CudaEventWaitLine (line 141) | class _CudaEventWaitLine(WrapperLine):
method codegen (line 145) | def codegen(self, code: IndentedBuffer) -> None:
class CudaEventFactory (line 157) | class CudaEventFactory:
method __init__ (line 165) | def __init__(self) -> None:
method get_entrance_event (line 173) | def get_entrance_event(self) -> CudaEventSym:
method get_sym_event (line 186) | def get_sym_event(self, originate_stream_idx: int) -> CudaEventSym:
method get_materialized_event (line 194) | def get_materialized_event(self, code: IndentedBuffer) -> str:
method deposit_materialized_event (line 203) | def deposit_materialized_event(self, event: str) -> None:
FILE: apex/contrib/torchsched/inductor/graph.py
function _torchsched_codegen (line 31) | def _torchsched_codegen(
function _mixed_codegen (line 77) | def _mixed_codegen(graph: GraphLowering) -> tuple[ValueWithLineMap, Valu...
function patch_graph_lowering (line 103) | def patch_graph_lowering(patch: bool = True) -> None:
FILE: apex/contrib/torchsched/inductor/scheduler.py
class MultiCudaStreamScheduler (line 39) | class MultiCudaStreamScheduler(Scheduler):
method __init__ (line 53) | def __init__(self, operations: list[ir.Operation]) -> None:
method current_stream_idx (line 72) | def current_stream_idx(self) -> int | None:
method current_stream_name (line 80) | def current_stream_name(self) -> str | None:
method buffers_recorded_on_current_stream (line 88) | def buffers_recorded_on_current_stream(self) -> set[str]:
method buffers_recorded_on_current_stream (line 94) | def buffers_recorded_on_current_stream(self, buffs: set[str]) -> None:
method debug_str_short (line 105) | def debug_str_short(self, node: BaseSchedulerNode) -> str:
method get_last_event (line 119) | def get_last_event(self, events: set[CudaEventSym]) -> CudaEventSym:
method schedule_multi_cuda_streams (line 123) | def schedule_multi_cuda_streams(self) -> None:
method get_final_events_to_sync (line 198) | def get_final_events_to_sync(self) -> set[CudaEventSym]:
method clear_unjoined_events (line 227) | def clear_unjoined_events(self) -> None:
method register_downstream_event (line 231) | def register_downstream_event(
method get_cross_stream_dependencies (line 271) | def get_cross_stream_dependencies(
method generate_stream_ctx_enter (line 344) | def generate_stream_ctx_enter(self, node: BaseSchedulerNode) -> None:
method generate_stream_ctx_exit (line 357) | def generate_stream_ctx_exit(self) -> None:
method propagate_cross_stream_dependencies (line 364) | def propagate_cross_stream_dependencies(self, node: BaseSchedulerNode)...
method generate_stream_ctx_switching (line 390) | def generate_stream_ctx_switching(self, node: BaseSchedulerNode) -> None:
method codegen (line 415) | def codegen(self) -> None:
FILE: apex/contrib/torchsched/inductor/wrapper.py
class EnterDeviceContextManagerWithStreamInfoLine (line 39) | class EnterDeviceContextManagerWithStreamInfoLine(EnterDeviceContextMana...
method codegen (line 46) | def codegen(self, code: IndentedBuffer) -> None:
class ExitDeviceContextManagerWithStreamInfoLine (line 70) | class ExitDeviceContextManagerWithStreamInfoLine(ExitDeviceContextManage...
method codegen (line 73) | def codegen(self, code: IndentedBuffer) -> None:
class EnterCudaStreamContextLine (line 84) | class EnterCudaStreamContextLine(WrapperLine):
method __post_init__ (line 100) | def __post_init__(self) -> None:
method codegen (line 104) | def codegen(self, code: IndentedBuffer) -> None:
class ExitCudaStreamContextLine (line 122) | class ExitCudaStreamContextLine(WrapperLine):
method codegen (line 131) | def codegen(self, code: IndentedBuffer) -> None:
class MultiStreamWrapperCodegen (line 137) | class MultiStreamWrapperCodegen(PythonWrapperCodegen):
method __init__ (line 140) | def __init__(self) -> None:
method create (line 154) | def create(
method _write_get_raw_stream (line 171) | def _write_get_raw_stream(self, device_idx: int, graph: GraphLowering ...
method codegen_graph_nvtx_range_push (line 181) | def codegen_graph_nvtx_range_push(self, post_grad_graph_id: int) -> None:
method codegen_graph_nvtx_range_pop (line 185) | def codegen_graph_nvtx_range_pop(self) -> None:
method codegen_device_guard_enter (line 189) | def codegen_device_guard_enter(self, device_idx: int) -> None:
method codegen_device_guard_exit (line 203) | def codegen_device_guard_exit(self) -> None:
method codegen_cuda_stream_enter (line 207) | def codegen_cuda_stream_enter(
method codegen_cuda_stream_exit (line 253) | def codegen_cuda_stream_exit(self) -> None:
method codegen_events_wait_stream (line 257) | def codegen_events_wait_stream(self, events: set[CudaEventSym], stream...
method codegen_buffers_record_stream (line 267) | def codegen_buffers_record_stream(
FILE: apex/contrib/torchsched/ops/layer_norm.py
class CuDNNManager (line 18) | class CuDNNManager:
method __init__ (line 24) | def __init__(self) -> None:
method __del__ (line 29) | def __del__(self) -> None:
method __enter__ (line 33) | def __enter__(self) -> CuDNNManager:
method __exit__ (line 39) | def __exit__(
method set_stream (line 49) | def set_stream(self, stream: torch.cuda.Stream) -> None:
method reset_stream (line 52) | def reset_stream(self) -> None:
method handle (line 56) | def handle(self) -> int:
method stream (line 60) | def stream(self) -> torch.cuda.Stream:
function get_cudnn_manager (line 67) | def get_cudnn_manager() -> CuDNNManager:
class LayerNormGraphFactory (line 79) | class LayerNormGraphFactory:
method get_forward_graph (line 99) | def get_forward_graph(
method get_backward_graph (line 181) | def get_backward_graph(
function layer_norm (line 269) | def layer_norm(
function layer_norm_fake (line 339) | def layer_norm_fake(
function layer_norm_backward (line 359) | def layer_norm_backward(
function layer_norm_backward_fake (line 417) | def layer_norm_backward_fake(
function layer_norm_setup_context (line 432) | def layer_norm_setup_context(
function layer_norm_backward_wrapper (line 446) | def layer_norm_backward_wrapper(
FILE: apex/contrib/torchsched/passes/pre_grad_passes.py
function register_pattern (line 24) | def register_pattern(name: str, pattern: Callable, replacement: Callable...
function replace_layer_norm (line 29) | def replace_layer_norm(
function run_pre_grad_pass (line 53) | def run_pre_grad_pass(
function pre_grad_custom_pass (line 95) | def pre_grad_custom_pass(graph: torch.fx.Graph) -> None:
FILE: apex/contrib/transducer/_transducer_ref.py
function transducer_loss_reference (line 4) | def transducer_loss_reference(x, label, f_len, y_len, blank_idx, loss_gr...
function transducer_joint_reference (line 83) | def transducer_joint_reference(
FILE: apex/contrib/transducer/transducer.py
class TransducerJoint (line 6) | class TransducerJoint(torch.nn.Module):
method __init__ (line 28) | def __init__(
method forward (line 51) | def forward(self, f, g, f_len, g_len, batch_offset=None, packed_batch=0):
class TransducerLoss (line 88) | class TransducerLoss(torch.nn.Module):
method __init__ (line 102) | def __init__(self, fuse_softmax_backward=True, opt=1, packed_input=Fal...
method forward (line 109) | def forward(
class TransducerLossFunc (line 169) | class TransducerLossFunc(torch.autograd.Function):
method forward (line 171) | def forward(
method backward (line 212) | def backward(ctx, loss_grad):
class TransducerJointFunc (line 234) | class TransducerJointFunc(torch.autograd.Function):
method forward (line 236) | def forward(
method backward (line 282) | def backward(ctx, loss_grad):
FILE: apex/contrib/xentropy/softmax_xentropy.py
class SoftmaxCrossEntropyLoss (line 6) | class SoftmaxCrossEntropyLoss(torch.autograd.Function):
method forward (line 8) | def forward(ctx, logits, labels, smoothing=0.0, padding_idx=0, half_to...
method backward (line 23) | def backward(ctx, grad_loss):
FILE: apex/distributed_testing/distributed_test_base.py
class DistributedTestBase (line 24) | class DistributedTestBase(common_distributed.MultiProcessTestCase):
method __init__ (line 25) | def __init__(self, *args, **kwargs) -> None:
method setUp (line 28) | def setUp(self) -> None:
method tearDown (line 33) | def tearDown(self) -> None:
method world_size (line 38) | def world_size(self) -> int:
method init_method (line 42) | def init_method(self):
method destroy_pg_upon_exit (line 46) | def destroy_pg_upon_exit(self) -> bool:
method _run (line 51) | def _run(cls, rank, test_name, file_name, pipe, **kwargs):
method _setup_pre_spawn (line 82) | def _setup_pre_spawn(self):
class NcclDistributedTestBase (line 86) | class NcclDistributedTestBase(DistributedTestBase):
class UccDistributedTestBase (line 99) | class UccDistributedTestBase(DistributedTestBase):
method _setup_pre_spawn (line 102) | def _setup_pre_spawn(self) -> None:
method tearDown (line 122) | def tearDown(self) -> None:
method init_method (line 130) | def init_method(self):
FILE: apex/fused_dense/fused_dense.py
class FusedDenseFunc (line 8) | class FusedDenseFunc(torch.autograd.Function):
method forward (line 10) | def forward(ctx, input, weight, bias):
method backward (line 16) | def backward(ctx, grad_output):
class DenseNoBiasFunc (line 24) | class DenseNoBiasFunc(torch.autograd.Function):
method forward (line 26) | def forward(ctx, input, weight):
method backward (line 32) | def backward(ctx, grad_output):
class FusedDenseGeluDenseFunc (line 39) | class FusedDenseGeluDenseFunc(torch.autograd.Function):
method forward (line 41) | def forward(ctx, input, weight1, bias1, weight2, bias2):
method backward (line 50) | def backward(ctx, grad_output):
function _fused_dense (line 60) | def _fused_dense(input, weight, bias):
function _dense_no_bias (line 66) | def _dense_no_bias(input, weight):
function _fused_dense_gelu_dense (line 72) | def _fused_dense_gelu_dense(input, weight1, bias1, weight2, bias2):
class FusedDense (line 78) | class FusedDense(nn.Module):
method __init__ (line 79) | def __init__(self, in_features, out_features, bias=True):
method forward (line 90) | def forward(self, input):
class FusedDenseGeluDense (line 97) | class FusedDenseGeluDense(nn.Module):
method __init__ (line 98) | def __init__(self, in_features, intermediate_features, out_features, b...
method forward (line 109) | def forward(self, input):
FILE: apex/mlp/mlp.py
class MlpFunction (line 11) | class MlpFunction(torch.autograd.Function):
method forward (line 13) | def forward(ctx, bias, activation, *args):
method backward (line 22) | def backward(ctx, grad_o):
function mlp_function (line 28) | def mlp_function(bias, activation, *args):
class MLP (line 33) | class MLP(torch.nn.Module):
method __init__ (line 42) | def __init__(self, mlp_sizes, bias=True, activation="relu"):
method reset_parameters (line 72) | def reset_parameters(self):
method forward (line 82) | def forward(self, input):
method extra_repr (line 85) | def extra_repr(self):
FILE: apex/multi_tensor_apply/multi_tensor_apply.py
class MultiTensorApply (line 1) | class MultiTensorApply(object):
method __init__ (line 5) | def __init__(self, chunk_size):
method check_avail (line 15) | def check_avail(self):
method __call__ (line 24) | def __call__(self, op, noop_flag_buffer, tensor_lists, *args):
FILE: apex/normalization/fused_layer_norm.py
function supports_custom_op (line 17) | def supports_custom_op() -> bool:
function manual_rms_norm (line 22) | def manual_rms_norm(input, normalized_shape, weight, eps):
class FusedLayerNormAffineFunction (line 38) | class FusedLayerNormAffineFunction(torch.autograd.Function):
method forward (line 40) | def forward(ctx, input, weight, bias, normalized_shape, eps, memory_ef...
method backward (line 60) | def backward(ctx, grad_output):
function fused_layer_norm_affine_fwd (line 80) | def fused_layer_norm_affine_fwd(
function fused_layer_norm_affine_fwd_fake (line 101) | def fused_layer_norm_affine_fwd_fake(
function fused_layer_norm_affine_bwd (line 125) | def fused_layer_norm_affine_bwd(
function fused_layer_norm_affine_bwd_fake (line 150) | def fused_layer_norm_affine_bwd_fake(
function _fused_layer_norm_affine_backward (line 166) | def _fused_layer_norm_affine_backward(ctx, grad_output, grad_mean, grad_...
function _fused_layer_norm_affine_setup_context (line 182) | def _fused_layer_norm_affine_setup_context(ctx, inputs, output):
class FusedRMSNormAffineFunction (line 202) | class FusedRMSNormAffineFunction(torch.autograd.Function):
method forward (line 204) | def forward(ctx, input, weight, normalized_shape, eps, memory_efficien...
method backward (line 223) | def backward(ctx, grad_output):
function fused_rms_norm_affine_fwd (line 241) | def fused_rms_norm_affine_fwd(
function fused_rms_norm_affine_fwd_fake (line 260) | def fused_rms_norm_affine_fwd_fake(
function fused_rms_norm_affine_bwd (line 289) | def fused_rms_norm_affine_bwd(
function fused_rms_norm_affine_bwd_fake (line 310) | def fused_rms_norm_affine_bwd_fake(
function _fused_rms_norm_affine_backward (line 323) | def _fused_rms_norm_affine_backward(ctx, grad_output, grad_invvar):
function _fused_rms_norm_affine_setup_context (line 337) | def _fused_rms_norm_affine_setup_context(ctx, inputs, output):
class FusedLayerNormAffineMixedDtypesFunction (line 356) | class FusedLayerNormAffineMixedDtypesFunction(FusedLayerNormAffineFuncti...
method forward (line 358) | def forward(ctx, input, weight, bias, normalized_shape, eps, memory_ef...
class FusedRMSNormAffineMixedDtypesFunction (line 378) | class FusedRMSNormAffineMixedDtypesFunction(FusedRMSNormAffineFunction):
method forward (line 380) | def forward(ctx, input, weight, normalized_shape, eps, memory_efficien...
class FusedLayerNormFunction (line 399) | class FusedLayerNormFunction(torch.autograd.Function):
method forward (line 401) | def forward(ctx, input, normalized_shape, eps, memory_efficient=False):
method backward (line 417) | def backward(ctx, grad_output):
function fused_layer_norm_fwd (line 434) | def fused_layer_norm_fwd(
function fused_layer_norm_fwd_fake (line 449) | def fused_layer_norm_fwd_fake(
function fused_layer_norm_bwd (line 469) | def fused_layer_norm_bwd(
function fused_layer_norm_bwd_fake (line 490) | def fused_layer_norm_bwd_fake(
function _fused_layer_norm_backward (line 502) | def _fused_layer_norm_backward(ctx, grad_output, grad_mean, grad_invvar):
function _fused_layer_norm_setup_context (line 515) | def _fused_layer_norm_setup_context(ctx, inputs, output):
class FusedRMSNormFunction (line 533) | class FusedRMSNormFunction(torch.autograd.Function):
method forward (line 535) | def forward(ctx, input, normalized_shape, eps, memory_efficient=False):
method backward (line 551) | def backward(ctx, grad_output):
function fused_rms_norm_fwd (line 568) | def fused_rms_norm_fwd(
function fused_rms_norm_fwd_fake (line 583) | def fused_rms_norm_fwd_fake(
function fused_rms_norm_bwd (line 610) | def fused_rms_norm_bwd(
function fused_rms_norm_bwd_fake (line 629) | def fused_rms_norm_bwd_fake(
function _fused_rms_norm_backward (line 640) | def _fused_rms_norm_backward(ctx, grad_output, grad_invvar):
function _fused_rms_norm_setup_context (line 653) | def _fused_rms_norm_setup_context(ctx, inputs, output):
function fused_layer_norm_affine (line 670) | def fused_layer_norm_affine(
function fused_layer_norm (line 681) | def fused_layer_norm(input, normalized_shape, eps=1e-6, memory_efficient...
function mixed_dtype_fused_layer_norm_affine (line 690) | def mixed_dtype_fused_layer_norm_affine(
function fused_rms_norm_affine (line 698) | def fused_rms_norm_affine(input, weight, normalized_shape, eps=1e-6, mem...
function fused_rms_norm (line 707) | def fused_rms_norm(input, normalized_shape, eps=1e-6, memory_efficient=F...
function mixed_dtype_fused_rms_norm_affine (line 716) | def mixed_dtype_fused_rms_norm_affine(
class FusedLayerNorm (line 724) | class FusedLayerNorm(torch.nn.Module):
method __init__ (line 784) | def __init__(
method reset_parameters (line 810) | def reset_parameters(self):
method forward (line 815) | def forward(self, input):
method extra_repr (line 835) | def extra_repr(self):
class FusedRMSNorm (line 841) | class FusedRMSNorm(torch.nn.Module):
method __init__ (line 901) | def __init__(
method reset_parameters (line 925) | def reset_parameters(self):
method forward (line 929) | def forward(self, input):
method extra_repr (line 949) | def extra_repr(self):
class MixedFusedLayerNorm (line 959) | class MixedFusedLayerNorm(FusedLayerNorm):
method __init__ (line 960) | def __init__(self, normalized_shape, eps=1e-5, *, memory_efficient=Fal...
method forward (line 978) | def forward(self, input: torch.Tensor):
class MixedFusedRMSNorm (line 1000) | class MixedFusedRMSNorm(FusedRMSNorm):
method __init__ (line 1001) | def __init__(self, normalized_shape, eps=1e-5, *, memory_efficient=Fal...
method forward (line 1019) | def forward(self, input: torch.Tensor):
FILE: apex/optimizers/fused_adagrad.py
class FusedAdagrad (line 5) | class FusedAdagrad(torch.optim.Optimizer):
method __init__ (line 44) | def __init__(
method zero_grad (line 67) | def zero_grad(self):
method step (line 75) | def step(self, closure=None):
FILE: apex/optimizers/fused_adam.py
class FusedAdam (line 5) | class FusedAdam(torch.optim.Optimizer):
method __init__ (line 68) | def __init__(
method zero_grad (line 138) | def zero_grad(self):
method step (line 146) | def step(
FILE: apex/optimizers/fused_lamb.py
class FusedLAMB (line 5) | class FusedLAMB(torch.optim.Optimizer):
method __init__ (line 63) | def __init__(
method zero_grad (line 106) | def zero_grad(self):
method step (line 114) | def step(self, closure=None):
FILE: apex/optimizers/fused_mixed_precision_lamb.py
class FusedMixedPrecisionLamb (line 9) | class FusedMixedPrecisionLamb(torch.optim.Optimizer):
method __init__ (line 10) | def __init__(
method load_state_dict (line 73) | def load_state_dict(self, state_dict):
method _setup_full_precision_params (line 140) | def _setup_full_precision_params(self):
method add_param_group (line 159) | def add_param_group(self, param_group):
method step (line 166) | def step(self, closure=None, grad_scaler=None):
FILE: apex/optimizers/fused_novograd.py
class FusedNovoGrad (line 5) | class FusedNovoGrad(torch.optim.Optimizer):
method __init__ (line 67) | def __init__(
method zero_grad (line 110) | def zero_grad(self):
method load_state_dict (line 118) | def load_state_dict(self, state_dict):
method step (line 126) | def step(self, closure=None):
FILE: apex/optimizers/fused_sgd.py
class FusedSGD (line 7) | class FusedSGD(Optimizer):
method __init__ (line 77) | def __init__(
method __setstate__ (line 124) | def __setstate__(self, state):
method zero_grad (line 129) | def zero_grad(self):
method get_momentums (line 137) | def get_momentums(self, params):
method step (line 154) | def step(self, closure=None):
FILE: csrc/amp_C_frontend.cpp
function PYBIND11_MODULE (line 83) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
FILE: csrc/flatten_unflatten.cpp
function flatten (line 5) | at::Tensor flatten(std::vector<at::Tensor> tensors) { return torch::util...
function unflatten (line 7) | std::vector<at::Tensor> unflatten(at::Tensor flat, std::vector<at::Tenso...
function PYBIND11_MODULE (line 11) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
FILE: csrc/fused_dense.cpp
function linear_bias_forward (line 26) | at::Tensor linear_bias_forward(at::Tensor input, at::Tensor weight, at::...
function linear_bias_backward (line 54) | std::vector<at::Tensor> linear_bias_backward(at::Tensor input, at::Tenso...
function linear_gelu_linear_forward (line 88) | std::vector<at::Tensor> linear_gelu_linear_forward(at::Tensor input, at:...
function linear_gelu_linear_backward (line 122) | std::vector<at::Tensor> linear_gelu_linear_backward(at::Tensor input, at...
function PYBIND11_MODULE (line 160) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
FILE: csrc/layer_norm_cuda.cpp
function compute_n1_n2 (line 8) | void compute_n1_n2(at::Tensor input, at::IntArrayRef normalized_shape, i...
function check_args (line 21) | void check_args(at::IntArrayRef normalized_shape, at::Tensor gamma, at::...
function check_args (line 26) | void check_args(at::IntArrayRef normalized_shape, at::Tensor gamma) {
function check_args (line 30) | void check_args(at::Tensor input, at::IntArrayRef normalized_shape, int&...
function check_args (line 56) | void check_args(at::Tensor input, at::IntArrayRef normalized_shape, at::...
function check_args (line 62) | void check_args(at::Tensor input, at::IntArrayRef normalized_shape, at::...
function layer_norm (line 78) | std::vector<at::Tensor> layer_norm(const at::Tensor& input, at::IntArray...
function layer_norm_affine (line 92) | std::vector<at::Tensor> layer_norm_affine(const at::Tensor& input, at::I...
function layer_norm_affine_mixed_dtypes (line 110) | std::vector<at::Tensor> layer_norm_affine_mixed_dtypes(const at::Tensor&...
function layer_norm_gradient (line 132) | at::Tensor layer_norm_gradient(at::Tensor& dout, const std::optional<at:...
function layer_norm_gradient_affine (line 147) | std::vector<at::Tensor> layer_norm_gradient_affine(at::Tensor& dout, con...
function rms_norm (line 175) | std::vector<at::Tensor> rms_norm(const at::Tensor& input, at::IntArrayRe...
function rms_norm_affine (line 188) | std::vector<at::Tensor> rms_norm_affine(const at::Tensor& input, at::Int...
function rms_norm_affine_mixed_dtypes (line 204) | std::vector<at::Tensor> rms_norm_affine_mixed_dtypes(const at::Tensor& i...
function rms_norm_gradient (line 223) | at::Tensor rms_norm_gradient(at::Tensor& dout, at::Tensor& invvar, at::T...
function rms_norm_gradient_affine (line 236) | std::vector<at::Tensor> rms_norm_gradient_affine(at::Tensor& dout, at::T...
function PYBIND11_MODULE (line 252) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
FILE: csrc/megatron/fused_rotary_positional_embedding.cpp
type fused_rope (line 19) | namespace fused_rope {
function fwd (line 42) | torch::Tensor fwd(const at::Tensor& input, const at::Tensor& freqs, co...
function bwd (line 56) | torch::Tensor bwd(const torch::Tensor& output_grads, const at::Tensor&...
function fwd_cached (line 71) | torch::Tensor fwd_cached(const at::Tensor& input, const at::Tensor& co...
function bwd_cached (line 89) | torch::Tensor bwd_cached(const torch::Tensor& output_grads, const at::...
function fwd_thd (line 109) | torch::Tensor fwd_thd(const torch::Tensor& input, const torch::Tensor&...
function bwd_thd (line 123) | torch::Tensor bwd_thd(const torch::Tensor& output_grads, const torch::...
function fwd_2d (line 137) | torch::Tensor fwd_2d(const torch::Tensor& input, const torch::Tensor& ...
function bwd_2d (line 154) | torch::Tensor bwd_2d(const torch::Tensor& output_grads, const torch::T...
function PYBIND11_MODULE (line 175) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
FILE: csrc/megatron/fused_rotary_positional_embedding.h
function fused_rope_block_forward (line 28) | void fused_rope_block_forward(const scalar_t* src, const float* freqs, s...
function fused_rope_block_backward (line 63) | void fused_rope_block_backward(const scalar_t* src, const float* freqs, ...
function fused_rope_forward (line 99) | void fused_rope_forward(const int h, const int d, const int d2, const in...
function fused_rope_backward (line 111) | void fused_rope_backward(const int h, const int d, const int d2, const i...
function fused_rope_cached_block_forward (line 123) | void fused_rope_cached_block_forward(const scalar_t_0* src, const scalar...
function fused_rope_cached_block_backward (line 158) | void fused_rope_cached_block_backward(const scalar_t_0* src, const scala...
function fused_rope_cached_forward (line 193) | void fused_rope_cached_forward(const int h, const int d, const int d2, c...
function fused_rope_cached_backward (line 206) | void fused_rope_cached_backward(const int h, const int d, const int d2, ...
function fused_rope_thd_forward (line 219) | void fused_rope_thd_forward(const int h, const int d, const int d2, cons...
function fused_rope_thd_backward (line 233) | void fused_rope_thd_backward(const int h, const int d, const int d2, con...
function fused_rope_2d_forward (line 247) | void fused_rope_2d_forward(const int ih, const int iw, const int h, cons...
function fused_rope_2d_backward (line 269) | void fused_rope_2d_backward(const int ih, const int iw, const int h, con...
FILE: csrc/megatron/fused_weight_gradient_dense.cpp
function PYBIND11_MODULE (line 10) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
FILE: csrc/megatron/generic_scaled_masked_softmax.cpp
type multihead_attn (line 22) | namespace multihead_attn {
type fused_softmax (line 23) | namespace fused_softmax {
type generic_scaled_masked_softmax (line 24) | namespace generic_scaled_masked_softmax {
function fwd (line 30) | torch::Tensor fwd(torch::Tensor const& input, torch::Tensor const&...
function bwd (line 39) | torch::Tensor bwd(torch::Tensor const& output_grads, torch::Tensor...
function PYBIND11_MODULE (line 57) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
FILE: csrc/megatron/generic_scaled_masked_softmax.h
function T (line 31) | T operator()(T a, T b) const { return a + b; }
function T (line 36) | T operator()(T a, T b) const { return a < b ? b : a; }
function scaled_masked_softmax_warp_backward_new (line 59) | int log2_elements>
FILE: csrc/megatron/scaled_masked_softmax.cpp
type multihead_attn (line 22) | namespace multihead_attn {
type fused_softmax (line 23) | namespace fused_softmax {
type scaled_masked_softmax (line 24) | namespace scaled_masked_softmax {
function fwd (line 32) | torch::Tensor fwd(torch::Tensor& input, torch::Tensor& mask, float...
function bwd (line 43) | torch::Tensor bwd(torch::Tensor& output_grads, torch::Tensor& soft...
function get_batch_per_block (line 59) | int get_batch_per_block(int query_seq_len, int key_seq_len, int ba...
function PYBIND11_MODULE (line 67) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
FILE: csrc/megatron/scaled_masked_softmax.h
function log2_ceil (line 62) | int log2_ceil(int value) {
function T (line 70) | T operator()(T a, T b) const { return a + b; }
function T (line 75) | T operator()(T a, T b) const { return a < b ? b : a; }
function scaled_softmax_warp_forward (line 105) | int log2_elements>
FILE: csrc/megatron/scaled_softmax.cpp
type multihead_attn (line 22) | namespace multihead_attn {
type fused_softmax (line 23) | namespace fused_softmax {
type scaled_softmax (line 24) | namespace scaled_softmax {
function fwd (line 30) | torch::Tensor fwd(torch::Tensor const& input, float scale_factor) {
function bwd (line 38) | torch::Tensor bwd(torch::Tensor const& output_grads, torch::Tensor...
function PYBIND11_MODULE (line 56) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
FILE: csrc/megatron/scaled_upper_triang_masked_softmax.cpp
type multihead_attn (line 22) | namespace multihead_attn {
type fused_softmax (line 23) | namespace fused_softmax {
type scaled_upper_triang_masked_softmax (line 24) | namespace scaled_upper_triang_masked_softmax {
function fwd (line 30) | torch::Tensor fwd(torch::Tensor const& input, float scale_factor) {
function bwd (line 38) | torch::Tensor bwd(torch::Tensor const& output_grads, torch::Tensor...
function PYBIND11_MODULE (line 56) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
FILE: csrc/megatron/scaled_upper_triang_masked_softmax.h
function log2_ceil (line 85) | int log2_ceil(int value) {
function T (line 93) | T operator()(T a, T b) const { return a + b; }
function T (line 98) | T operator()(T a, T b) const { return a < b ? b : a; }
function scaled_upper_triang_masked_softmax_warp_forward (line 129) | int log2_elements>
FILE: csrc/mlp.cpp
function mlp_forward (line 21) | std::vector<at::Tensor> mlp_forward(int use_bias, int activation, std::v...
function mlp_backward (line 61) | std::vector<at::Tensor> mlp_backward(int use_bias, int activation, at::T...
function PYBIND11_MODULE (line 109) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
FILE: csrc/syncbn.cpp
function PYBIND11_MODULE (line 71) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
FILE: docs/source/conf.py
function patched_make_field (line 205) | def patched_make_field(self, types, domain, items, **kw):
FILE: examples/dcgan/main_amp.py
function weights_init (line 114) | def weights_init(m):
class Generator (line 123) | class Generator(nn.Module):
method __init__ (line 124) | def __init__(self, ngpu):
method forward (line 150) | def forward(self, input):
class Discriminator (line 165) | class Discriminator(nn.Module):
method __init__ (line 166) | def __init__(self, ngpu):
method forward (line 189) | def forward(self, input):
FILE: examples/imagenet/main_amp.py
function to_python_float (line 22) | def to_python_float(scalar_tensor: torch.Tensor):
function fast_collate (line 25) | def fast_collate(batch, memory_format):
function parse (line 41) | def parse():
function main (line 92) | def main():
class data_prefetcher (line 247) | class data_prefetcher():
method __init__ (line 248) | def __init__(self, loader):
method preload (line 259) | def preload(self):
method next (line 290) | def next(self):
function train (line 302) | def train(train_loader, model, criterion, optimizer, scaler, epoch):
function validate (line 397) | def validate(val_loader, model, criterion):
function save_checkpoint (line 459) | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
class AverageMeter (line 465) | class AverageMeter(object):
method __init__ (line 467) | def __init__(self):
method reset (line 470) | def reset(self):
method update (line 476) | def update(self, val, n=1):
function adjust_learning_rate (line 483) | def adjust_learning_rate(optimizer, epoch, step, len_epoch):
function accuracy (line 503) | def accuracy(output, target, topk=(1,)):
function reduce_tensor (line 519) | def reduce_tensor(tensor):
FILE: setup.py
function has_flag (line 60) | def has_flag(flag, env_var):
function get_cuda_bare_metal_version (line 68) | def get_cuda_bare_metal_version(cuda_dir):
function check_cuda_torch_binary_vs_bare_metal (line 77) | def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
function raise_if_cuda_home_none (line 95) | def raise_if_cuda_home_none(global_option: str) -> None:
function check_cudnn_version_and_warn (line 105) | def check_cudnn_version_and_warn(global_option: str, required_cudnn_vers...
class BuildExtensionSeparateDir (line 1017) | class BuildExtensionSeparateDir(BuildExtension):
method finalize_options (line 1021) | def finalize_options(self):
method build_extension (line 1026) | def build_extension(self, ext):
FILE: tests/L0/run_fused_layer_norm/test_fused_layer_norm.py
function _prep_inputs (line 15) | def _prep_inputs(batch_size, normalized_shape, dtype):
class TestFusedLayerNorm (line 26) | class TestFusedLayerNorm(common_utils.TestCase):
method _test_fused_layer_norm (line 27) | def _test_fused_layer_norm(
method _test_fused_rms_norm (line 101) | def _test_fused_rms_norm(
method test_layer_norm_regular (line 195) | def test_layer_norm_regular(
method test_layer_norm_elemwise (line 226) | def test_layer_norm_elemwise(
method test_layer_norm_mixed (line 257) | def test_layer_norm_mixed(
method test_layer_norm_half (line 279) | def test_layer_norm_half(
method test_layer_norm_bfloat16 (line 312) | def test_layer_norm_bfloat16(
method test_rms_norm_regular (line 346) | def test_rms_norm_regular(
method test_rms_norm_elemwise (line 377) | def test_rms_norm_elemwise(
method test_rms_norm_mixed (line 409) | def test_rms_norm_mixed(
method test_rms_norm_half (line 432) | def test_rms_norm_half(
method test_rms_norm_bfloat16 (line 464) | def test_rms_norm_bfloat16(
method test_autocast_fused_layer_norm (line 488) | def test_autocast_fused_layer_norm(self, dtype, elementwise_affine, me...
method test_autocast_fused_rms_norm (line 529) | def test_autocast_fused_rms_norm(self, dtype, elementwise_affine, memo...
method _verify_export (line 563) | def _verify_export(self, fused, fused_x):
method test_rms_export (line 587) | def test_rms_export(self):
method test_layer_norm_export (line 596) | def test_layer_norm_export(self):
method test_compile_fused_layer_norm (line 606) | def test_compile_fused_layer_norm(self, elementwise_affine):
method test_compile_fused_rms_norm (line 630) | def test_compile_fused_rms_norm(self, elementwise_affine):
FILE: tests/L0/run_mlp/test_mlp.py
class TestMLP (line 21) | class TestMLP(common_utils.TestCase):
method test_creation (line 22) | def test_creation(self):
method test_numeric (line 25) | def test_numeric(self):
method _test_mlp_impl (line 55) | def _test_mlp_impl(self, use_activation: str, bias: bool, enable_autoc...
method test_mlp (line 102) | def test_mlp(self, use_activation: str, bias: bool):
method test_mlp_autocast_fp16 (line 109) | def test_mlp_autocast_fp16(self, use_activation: str, bias: bool):
method test_no_grad (line 112) | def test_no_grad(self):
method test_performance_half (line 137) | def test_performance_half(self):
FILE: tests/L0/run_optimizers/test_adam.py
class Model (line 16) | class Model(torch.nn.Module):
method __init__ (line 17) | def __init__(self):
method forward (line 32) | def forward(self, x):
class AdamTest (line 50) | class AdamTest(unittest.TestCase):
method setUp (line 51) | def setUp(self, seed=0):
method testGradScaler (line 63) | def testGradScaler(self):
method testGradScalerCapturable (line 114) | def testGradScalerCapturable(self):
method testGradScalerCapturableMaster (line 165) | def testGradScalerCapturableMaster(self):
method testNative (line 226) | def testNative(self):
method testLargeTensor (line 272) | def testLargeTensor(self):
FILE: tests/L0/run_optimizers/test_fused_novograd.py
class Novograd (line 10) | class Novograd(Optimizer):
method __init__ (line 29) | def __init__(
method __setstate__ (line 58) | def __setstate__(self, state):
method step (line 63) | def step(self, closure=None):
class TestFusedNovoGrad (line 130) | class TestFusedNovoGrad(TestFusedOptimizer):
method __init__ (line 131) | def __init__(self, *args, **kwargs):
method test_float (line 162) | def test_float(self):
method test_half (line 165) | def test_half(self):
method test_multi_device (line 169) | def test_multi_device(self):
method test_multi_params (line 176) | def test_multi_params(self):
FILE: tests/L0/run_optimizers/test_fused_optimizer.py
class TestFusedOptimizer (line 10) | class TestFusedOptimizer(unittest.TestCase):
method setUp (line 11) | def setUp(self, max_abs_diff=1e-3, max_rel_diff=1, iters=7):
method tearDown (line 17) | def tearDown(self):
method gen_param_optim (line 20) | def gen_param_optim(self, tensors, options, tst_options=None):
method gen_grad (line 38) | def gen_grad(self, ref_param, tst_param):
method gen_mixed_grad (line 43) | def gen_mixed_grad(self, ref_param, tst_param, scale=1.0):
method get_max_diff (line 50) | def get_max_diff(self, ref_param, tst_param):
method gen_single_type_test (line 63) | def gen_single_type_test(
class TestFusedAdam (line 93) | class TestFusedAdam(TestFusedOptimizer):
method setUp (line 94) | def setUp(self):
method test_float (line 106) | def test_float(self):
method test_half (line 111) | def test_half(self):
method test_bfloat16 (line 114) | def test_bfloat16(self):
method test_multi_device (line 118) | def test_multi_device(self):
method test_multi_params (line 125) | def test_multi_params(self):
method test_scale (line 142) | def test_scale(self):
method test_fp16_output (line 158) | def test_fp16_output(self):
method test_adam_option (line 179) | def test_adam_option(self):
method test_frozen_model (line 201) | def test_frozen_model(self):
class TestFusedAdagrad (line 227) | class TestFusedAdagrad(TestFusedOptimizer):
method __init__ (line 228) | def __init__(self, *args, **kwargs):
method test_float (line 234) | def test_float(self):
method test_half (line 238) | def test_half(self):
method test_multi_device (line 242) | def test_multi_device(self):
method test_multi_params (line 248) | def test_multi_params(self):
method test_multi_params_different_devices_throws (line 266) | def test_multi_params_different_devices_throws(self):
method test_adagrad_option (line 278) | def test_adagrad_option(self):
class TestFusedSGD (line 295) | class TestFusedSGD(TestFusedOptimizer):
method __init__ (line 296) | def __init__(self, *args, **kwargs):
method test_float (line 302) | def test_float(self):
method test_half (line 305) | def test_half(self):
method test_multi_device (line 309) | def test_multi_device(self):
FILE: tests/L0/run_optimizers/test_lamb.py
class RefLAMB (line 11) | class RefLAMB(Optimizer):
method __init__ (line 30) | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weig...
method step (line 53) | def step(self, closure=None):
class TestLamb (line 172) | class TestLamb(unittest.TestCase):
method setUp (line 173) | def setUp(self, max_abs_diff=1e-3, max_rel_diff=1, iters=7):
method tearDown (line 179) | def tearDown(self):
method gen_param_optim (line 182) | def gen_param_optim(self, tensors, lamb_option):
method gen_grad (line 194) | def gen_grad(self, ref_param, tst_param):
method gen_mixed_grad (line 199) | def gen_mixed_grad(self, ref_param, tst_param, scale=1.0):
method gen_single_type_test (line 206) | def gen_single_type_test(self, param_type=torch.float, device="cuda"):
class TestFusedLAMB (line 236) | class TestFusedLAMB(TestLamb):
method __init__ (line 237) | def __init__(self, *args, **kwargs):
method test_float (line 242) | def test_float(self):
method test_half (line 246) | def test_half(self):
method test_multi_device (line 250) | def test_multi_device(self):
method test_multi_params (line 256) | def test_multi_params(self):
method test_lamb_option (line 278) | def test_lamb_option(self):
class TestFusedMixedPrecisionLamb (line 299) | class TestFusedMixedPrecisionLamb(TestLamb):
method __init__ (line 300) | def __init__(self, *args, **kwargs):
method test_float (line 305) | def test_float(self):
method test_bfloat16 (line 308) | def test_bfloat16(self):
method test_half (line 312) | def test_half(self):
method test_multi_device (line 317) | def test_multi_device(self):
method test_multi_params (line 323) | def test_multi_params(self):
method test_lamb_option (line 345) | def test_lamb_option(self):
FILE: tests/L0/run_test.py
function parse_args (line 33) | def parse_args():
function main (line 61) | def main(args: argparse.Namespace) -> None:
FILE: tests/L1/common/main_amp.py
function fast_collate (line 137) | def fast_collate(batch):
function main (line 178) | def main():
class data_prefetcher (line 354) | class data_prefetcher:
method __init__ (line 355) | def __init__(self, loader):
method preload (line 366) | def preload(self):
method next (line 383) | def next(self):
function train (line 391) | def train(train_loader, model, criterion, optimizer, epoch):
function validate (line 503) | def validate(val_loader, model, criterion):
function save_checkpoint (line 569) | def save_checkpoint(state, is_best, filename="checkpoint.pth.tar"):
class AverageMeter (line 575) | class AverageMeter(object):
method __init__ (line 578) | def __init__(self):
method reset (line 581) | def reset(self):
method update (line 587) | def update(self, val, n=1):
function adjust_learning_rate (line 594) | def adjust_learning_rate(optimizer, epoch, step, len_epoch):
function accuracy (line 614) | def accuracy(output, target, topk=(1,)):
function reduce_tensor (line 630) | def reduce_tensor(tensor):
FILE: tests/distributed/DDP/ddp_race_condition_test.py
class Model (line 27) | class Model(Module):
method __init__ (line 28) | def __init__(self):
method forward (line 33) | def forward(self, input):
function info (line 60) | def info(name, param, val):
FILE: tests/distributed/synced_batchnorm/python_single_gpu_unit_test.py
function compare (line 5) | def compare(desc, inp1, inp2, error):
FILE: tests/distributed/synced_batchnorm/single_gpu_unit_test.py
function compare (line 15) | def compare(desc, inp1, inp2, error):
FILE: tests/distributed/synced_batchnorm/test_groups.py
function compare (line 10) | def compare(desc, inp1, inp2, error):
FILE: tests/distributed/synced_batchnorm/two_gpu_test_different_batch_size.py
function compare (line 13) | def compare(desc, inp1, inp2, error=1e-5):
FILE: tests/distributed/synced_batchnorm/two_gpu_unit_test.py
function compare (line 10) | def compare(desc, inp1, inp2, error):
Condensed preview — 419 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (3,909K chars).
[
{
"path": ".clang-format",
"chars": 95,
"preview": "# Start with a built-in style and modify it\nBasedOnStyle: Google\n\n# Overrides\nColumnLimit: 120\n"
},
{
"path": ".git-blame-ignore-revs",
"chars": 444,
"preview": "# Commits to ignore in git-blame\n# These commits are bulk formatting or refactoring changes that should be skipped when "
},
{
"path": ".github/ISSUE_TEMPLATE/bug_report.md",
"chars": 623,
"preview": "---\nname: Bug report\nabout: Create a report to help us improve apex\ntitle: ''\nlabels: bug\nassignees: ''\n\n---\n\n**Describe"
},
{
"path": ".gitignore",
"chars": 2184,
"preview": "apex.egg-info\ndist\nbuild\ndocs/build\n*~\n__pycache__\n.vscode\n\n# Copied from https://raw.githubusercontent.com/github/gitig"
},
{
"path": ".gitmodules",
"chars": 306,
"preview": "[submodule \"apex/contrib/csrc/multihead_attn/cutlass\"]\n\tpath = apex/contrib/csrc/multihead_attn/cutlass\n\turl = https://g"
},
{
"path": ".nojekyll",
"chars": 0,
"preview": ""
},
{
"path": ".pre-commit-config.yaml",
"chars": 475,
"preview": "repos:\n- repo: https://github.com/pre-commit/mirrors-clang-format\n rev: v22.1.1 # Or pin to your preferred clang-format"
},
{
"path": "LICENSE",
"chars": 1449,
"preview": "All rights reserved.\n\nRedistribution and use in source and binary forms, with or without modification, are permitted pro"
},
{
"path": "README.md",
"chars": 8301,
"preview": "# Introduction\n\nThis repository holds NVIDIA-maintained utilities to streamline mixed precision and distributed training"
},
{
"path": "apex/__init__.py",
"chars": 1682,
"preview": "import logging\nimport warnings\n\n# May help avoid undefined symbol errors https://pytorch.org/cppdocs/notes/faq.html#unde"
},
{
"path": "apex/_autocast_utils.py",
"chars": 664,
"preview": "from typing import Optional, Sequence\n\nimport torch\n\n\n__all__ = [\"_cast_if_autocast_enabled\"]\n\n\ndef _get_autocast_dtypes"
},
{
"path": "apex/contrib/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "apex/contrib/bottleneck/__init__.py",
"chars": 190,
"preview": "from .bottleneck import Bottleneck, SpatialBottleneck\nfrom .halo_exchangers import (\n HaloExchangerNoComm,\n HaloEx"
},
{
"path": "apex/contrib/bottleneck/bottleneck.py",
"chars": 42136,
"preview": "import functools as func\n\nimport torch\nfrom torch import nn\n\nfrom apex import check_cudnn_version_and_warn\nimport fast_b"
},
{
"path": "apex/contrib/bottleneck/halo_exchangers.py",
"chars": 11118,
"preview": "import torch\nimport nccl_p2p_cuda as inc\nimport peer_memory_cuda as pm\n\n\n# Communication free halo exchanger.\n# NB! This"
},
{
"path": "apex/contrib/bottleneck/test.py",
"chars": 3472,
"preview": "import torch\nfrom bottleneck import Bottleneck\n\ntorch.manual_seed(23337)\n\n# use True to print layerwise sum for all outp"
},
{
"path": "apex/contrib/clip_grad/__init__.py",
"chars": 39,
"preview": "from .clip_grad import clip_grad_norm_\n"
},
{
"path": "apex/contrib/clip_grad/clip_grad.py",
"chars": 4347,
"preview": "from typing import Union, Iterable\n\nimport torch\n\n_kernel_import_succeeded = False\ntry:\n import amp_C\n from apex.m"
},
{
"path": "apex/contrib/conv_bias_relu/__init__.py",
"chars": 115,
"preview": "from .conv_bias_relu import (\n ConvBiasReLU,\n ConvBias,\n ConvBiasMaskReLU,\n ConvFrozenScaleBiasReLU,\n)\n"
},
{
"path": "apex/contrib/conv_bias_relu/conv_bias_relu.py",
"chars": 3428,
"preview": "import torch\n\nfrom apex import check_cudnn_version_and_warn\nimport fused_conv_bias_relu\n\ncheck_cudnn_version_and_warn(__"
},
{
"path": "apex/contrib/csrc/bottleneck/bottleneck.cpp",
"chars": 170573,
"preview": "#include <ATen/ATen.h>\n#include <ATen/cudnn/Handle.h> // for getcudnnhandle\n#include <cudnn_frontend.h>\n#include <torch"
},
{
"path": "apex/contrib/csrc/conv_bias_relu/conv_bias_relu.cpp",
"chars": 80973,
"preview": "#include <ATen/ATen.h>\n#include <ATen/cudnn/Handle.h> // for getcudnnhandle\n#include <cudnn_frontend.h>\n#include <torch"
},
{
"path": "apex/contrib/csrc/cudnn_gbn/cudnn_gbn.cpp",
"chars": 4590,
"preview": "#include <ATen/ATen.h>\n#include <torch/extension.h>\n#include <torch/torch.h>\n\n#include <iostream>\n#include <vector>\n\n#in"
},
{
"path": "apex/contrib/csrc/cudnn_gbn/norm_sample.cpp",
"chars": 18421,
"preview": "/*\n * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.\n *\n * Permission is hereby granted, free of charge, t"
},
{
"path": "apex/contrib/csrc/cudnn_gbn/norm_sample.h",
"chars": 6543,
"preview": "#pragma once\n\n/*\n * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.\n *\n * Permission is hereby granted, fre"
},
{
"path": "apex/contrib/csrc/fmha/fmha_api.cpp",
"chars": 12126,
"preview": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPO"
},
{
"path": "apex/contrib/csrc/fmha/src/fmha/gemm.h",
"chars": 10907,
"preview": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPO"
},
{
"path": "apex/contrib/csrc/fmha/src/fmha/gmem_tile.h",
"chars": 16668,
"preview": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPO"
},
{
"path": "apex/contrib/csrc/fmha/src/fmha/kernel_traits.h",
"chars": 4925,
"preview": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPO"
},
{
"path": "apex/contrib/csrc/fmha/src/fmha/mask.h",
"chars": 3217,
"preview": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPO"
},
{
"path": "apex/contrib/csrc/fmha/src/fmha/smem_tile.h",
"chars": 46277,
"preview": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPO"
},
{
"path": "apex/contrib/csrc/fmha/src/fmha/softmax.h",
"chars": 13961,
"preview": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPO"
},
{
"path": "apex/contrib/csrc/fmha/src/fmha/utils.h",
"chars": 34298,
"preview": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPO"
},
{
"path": "apex/contrib/csrc/fmha/src/fmha.h",
"chars": 5851,
"preview": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPO"
},
{
"path": "apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_128_64_kernel.sm80.cu",
"chars": 3415,
"preview": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPO"
},
{
"path": "apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_256_64_kernel.sm80.cu",
"chars": 3415,
"preview": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPO"
},
{
"path": "apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_384_64_kernel.sm80.cu",
"chars": 3415,
"preview": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPO"
},
{
"path": "apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_512_64_kernel.sm80.cu",
"chars": 5354,
"preview": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPO"
},
{
"path": "apex/contrib/csrc/fmha/src/fmha_dgrad_kernel_1xN_reload.h",
"chars": 20570,
"preview": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPO"
},
{
"path": "apex/contrib/csrc/fmha/src/fmha_dgrad_kernel_1xN_reload_nl.h",
"chars": 21267,
"preview": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPO"
},
{
"path": "apex/contrib/csrc/fmha/src/fmha_fill.cu",
"chars": 3373,
"preview": "/******************************************************************************\n * Copyright (c) 2011-2023, NVIDIA CORPO"
},
{
"path": "apex/contrib/csrc/fmha/src/fmha_fprop_fp16_128_64_kernel.sm80.cu",
"chars": 3710,
"preview": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPO"
},
{
"path": "apex/contrib/csrc/fmha/src/fmha_fprop_fp16_256_64_kernel.sm80.cu",
"chars": 3710,
"preview": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPO"
},
{
"path": "apex/contrib/csrc/fmha/src/fmha_fprop_fp16_384_64_kernel.sm80.cu",
"chars": 3710,
"preview": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPO"
},
{
"path": "apex/contrib/csrc/fmha/src/fmha_fprop_fp16_512_64_kernel.sm80.cu",
"chars": 6161,
"preview": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPO"
},
{
"path": "apex/contrib/csrc/fmha/src/fmha_fprop_kernel_1xN.h",
"chars": 18890,
"preview": "/***************************************************************************************************\n * Copyright (c) 20"
},
{
"path": "apex/contrib/csrc/fmha/src/fmha_kernel.h",
"chars": 6743,
"preview": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPO"
},
{
"path": "apex/contrib/csrc/fmha/src/fmha_noloop_reduce.cu",
"chars": 5958,
"preview": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPO"
},
{
"path": "apex/contrib/csrc/fmha/src/fmha_utils.h",
"chars": 4069,
"preview": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPO"
},
{
"path": "apex/contrib/csrc/focal_loss/focal_loss_cuda.cpp",
"chars": 2101,
"preview": "#include <torch/torch.h>\n\n#include <cstdint>\n#include <vector>\n\n// CUDA forward declarations\n\nstd::vector<at::Tensor> fo"
},
{
"path": "apex/contrib/csrc/focal_loss/focal_loss_cuda_kernel.cu",
"chars": 10845,
"preview": "#include <ATen/ATen.h>\n#include <ATen/AccumulateType.h>\n#include <ATen/cuda/CUDAContext.h>\n\n// Use 128-bit vectorization"
},
{
"path": "apex/contrib/csrc/gpu_direct_storage/gds.cpp",
"chars": 5383,
"preview": "// Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.\n\n#include <gds.h>\n\n// torch\n#include <c10/cuda/CUDAGuard"
},
{
"path": "apex/contrib/csrc/gpu_direct_storage/gds.h",
"chars": 809,
"preview": "// Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.\n\n#pragma once\n\n#include <cufile.h>\n#include <torch/torch"
},
{
"path": "apex/contrib/csrc/gpu_direct_storage/gds_pybind.cpp",
"chars": 788,
"preview": "// Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.\n\n#include <gds.h>\n#include <torch/extension.h>\n#include "
},
{
"path": "apex/contrib/csrc/group_norm/group_norm_nhwc.cpp",
"chars": 36001,
"preview": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-I"
},
{
"path": "apex/contrib/csrc/group_norm/group_norm_nhwc.h",
"chars": 7644,
"preview": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-I"
},
{
"path": "apex/contrib/csrc/group_norm/group_norm_nhwc_bwd_one_pass.h",
"chars": 8500,
"preview": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-I"
},
{
"path": "apex/contrib/csrc/group_norm/group_norm_nhwc_bwd_one_pass_kernel.cuh",
"chars": 13071,
"preview": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-I"
},
{
"path": "apex/contrib/csrc/group_norm/group_norm_nhwc_bwd_two_pass.cu",
"chars": 18735,
"preview": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-I"
},
{
"path": "apex/contrib/csrc/group_norm/group_norm_nhwc_fwd_one_pass.h",
"chars": 8124,
"preview": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-I"
},
{
"path": "apex/contrib/csrc/group_norm/group_norm_nhwc_fwd_one_pass_kernel.cuh",
"chars": 7456,
"preview": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-I"
},
{
"path": "apex/contrib/csrc/group_norm/group_norm_nhwc_fwd_two_pass.cu",
"chars": 13511,
"preview": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-I"
},
{
"path": "apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_10.cu",
"chars": 361,
"preview": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-I"
},
{
"path": "apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_112.cu",
"chars": 362,
"preview": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-I"
},
{
"path": "apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_12.cu",
"chars": 361,
"preview": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-I"
},
{
"path": "apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_120.cu",
"chars": 362,
"preview": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-I"
},
{
"path": "apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_128.cu",
"chars": 362,
"preview": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-I"
},
{
"path": "apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_14.cu",
"chars": 361,
"preview": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-I"
},
{
"path": "apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_16.cu",
"chars": 361,
"preview": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-I"
},
{
"path": "apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_160.cu",
"chars": 362,
"preview": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-I"
},
{
"path": "apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_20.cu",
"chars": 361,
"preview": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-I"
},
{
"path": "apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_24.cu",
"chars": 361,
"preview": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-I"
},
{
"path": "apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_26.cu",
"chars": 361,
"preview": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-I"
},
{
"path": "apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_28.cu",
"chars": 361,
"preview": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-I"
},
{
"path": "apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_30.cu",
"chars": 361,
"preview": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-I"
},
{
"path": "apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_32.cu",
"chars": 361,
"preview": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-I"
},
{
"path": "apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_4.cu",
"chars": 360,
"preview": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-I"
},
{
"path": "apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_40.cu",
"chars": 361,
"preview": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-I"
},
{
"path": "apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_42.cu",
"chars": 361,
"preview": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-I"
},
{
"path": "apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_48.cu",
"chars": 361,
"preview": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-I"
},
{
"path": "apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_56.cu",
"chars": 361,
"preview": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-I"
},
{
"path": "apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_60.cu",
"chars": 361,
"preview": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-I"
},
{
"path": "apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_64.cu",
"chars": 361,
"preview": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-I"
},
{
"path": "apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_70.cu",
"chars": 361,
"preview": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-I"
},
{
"path": "apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_8.cu",
"chars": 360,
"preview": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-I"
},
{
"path": "apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_80.cu",
"chars": 361,
"preview": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-I"
},
{
"path": "apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_84.cu",
"chars": 361,
"preview": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-I"
},
{
"path": "apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_96.cu",
"chars": 361,
"preview": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-I"
},
{
"path": "apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_98.cu",
"chars": 361,
"preview": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-I"
},
{
"path": "apex/contrib/csrc/group_norm/group_norm_nhwc_op.cpp",
"chars": 10493,
"preview": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-I"
},
{
"path": "apex/contrib/csrc/group_norm/macros.h",
"chars": 20094,
"preview": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-I"
},
{
"path": "apex/contrib/csrc/group_norm/traits.h",
"chars": 3879,
"preview": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-I"
},
{
"path": "apex/contrib/csrc/group_norm_v2/generate_gn_cuda_inst.py",
"chars": 1515,
"preview": "import pathlib\n\n\nhw_c_list = [\n (8 * 8, 1280),\n (8 * 8, 2560),\n (16 * 16, 640),\n (16 * 16, 1280),\n (16 * "
},
{
"path": "apex/contrib/csrc/group_norm_v2/gn.cpp",
"chars": 7532,
"preview": "#include \"gn.hpp\"\n\n#include <ATen/cuda/CUDAContext.h>\n#include <torch/extension.h>\n\nnamespace group_norm_v2 {\n\ntorch::Te"
},
{
"path": "apex/contrib/csrc/group_norm_v2/gn.hpp",
"chars": 1047,
"preview": "#pragma once\n\n#include <cuda_runtime.h>\n\n#include <cstdint>\n\nnamespace group_norm_v2 {\n\nstruct Meta {\n int64_t red_buff"
},
{
"path": "apex/contrib/csrc/group_norm_v2/gn_cuda.cu",
"chars": 2797,
"preview": "#include <cuda_bf16.h>\n#include <cuda_fp16.h>\n#include <cuda_runtime.h>\n\n#include <cstdio>\n#include <mutex>\n#include <st"
},
{
"path": "apex/contrib/csrc/group_norm_v2/gn_cuda_host_template.cuh",
"chars": 26917,
"preview": "#pragma once\n\n#include <cuda_bf16.h>\n#include <cuda_fp16.h>\n#include <cuda_runtime.h>\n\n#include <cstdio>\n#include <stdex"
},
{
"path": "apex/contrib/csrc/group_norm_v2/gn_cuda_inst_1024_1280.cu",
"chars": 128,
"preview": "#include \"gn_cuda_host_template.cuh\"\n\nnamespace group_norm_v2 {\n\nGN_CUDA_INST_DEFINE(1024, 1280)\n\n} // namespace group_"
},
{
"path": "apex/contrib/csrc/group_norm_v2/gn_cuda_inst_1024_1920.cu",
"chars": 128,
"preview": "#include \"gn_cuda_host_template.cuh\"\n\nnamespace group_norm_v2 {\n\nGN_CUDA_INST_DEFINE(1024, 1920)\n\n} // namespace group_"
},
{
"path": "apex/contrib/csrc/group_norm_v2/gn_cuda_inst_1024_320.cu",
"chars": 127,
"preview": "#include \"gn_cuda_host_template.cuh\"\n\nnamespace group_norm_v2 {\n\nGN_CUDA_INST_DEFINE(1024, 320)\n\n} // namespace group_n"
},
{
"path": "apex/contrib/csrc/group_norm_v2/gn_cuda_inst_1024_640.cu",
"chars": 127,
"preview": "#include \"gn_cuda_host_template.cuh\"\n\nnamespace group_norm_v2 {\n\nGN_CUDA_INST_DEFINE(1024, 640)\n\n} // namespace group_n"
},
{
"path": "apex/contrib/csrc/group_norm_v2/gn_cuda_inst_1024_960.cu",
"chars": 127,
"preview": "#include \"gn_cuda_host_template.cuh\"\n\nnamespace group_norm_v2 {\n\nGN_CUDA_INST_DEFINE(1024, 960)\n\n} // namespace group_n"
},
{
"path": "apex/contrib/csrc/group_norm_v2/gn_cuda_inst_256_1280.cu",
"chars": 127,
"preview": "#include \"gn_cuda_host_template.cuh\"\n\nnamespace group_norm_v2 {\n\nGN_CUDA_INST_DEFINE(256, 1280)\n\n} // namespace group_n"
},
{
"path": "apex/contrib/csrc/group_norm_v2/gn_cuda_inst_256_1920.cu",
"chars": 127,
"preview": "#include \"gn_cuda_host_template.cuh\"\n\nnamespace group_norm_v2 {\n\nGN_CUDA_INST_DEFINE(256, 1920)\n\n} // namespace group_n"
},
{
"path": "apex/contrib/csrc/group_norm_v2/gn_cuda_inst_256_2560.cu",
"chars": 127,
"preview": "#include \"gn_cuda_host_template.cuh\"\n\nnamespace group_norm_v2 {\n\nGN_CUDA_INST_DEFINE(256, 2560)\n\n} // namespace group_n"
},
{
"path": "apex/contrib/csrc/group_norm_v2/gn_cuda_inst_256_640.cu",
"chars": 126,
"preview": "#include \"gn_cuda_host_template.cuh\"\n\nnamespace group_norm_v2 {\n\nGN_CUDA_INST_DEFINE(256, 640)\n\n} // namespace group_no"
},
{
"path": "apex/contrib/csrc/group_norm_v2/gn_cuda_inst_4096_320.cu",
"chars": 127,
"preview": "#include \"gn_cuda_host_template.cuh\"\n\nnamespace group_norm_v2 {\n\nGN_CUDA_INST_DEFINE(4096, 320)\n\n} // namespace group_n"
},
{
"path": "apex/contrib/csrc/group_norm_v2/gn_cuda_inst_4096_640.cu",
"chars": 127,
"preview": "#include \"gn_cuda_host_template.cuh\"\n\nnamespace group_norm_v2 {\n\nGN_CUDA_INST_DEFINE(4096, 640)\n\n} // namespace group_n"
},
{
"path": "apex/contrib/csrc/group_norm_v2/gn_cuda_inst_4096_960.cu",
"chars": 127,
"preview": "#include \"gn_cuda_host_template.cuh\"\n\nnamespace group_norm_v2 {\n\nGN_CUDA_INST_DEFINE(4096, 960)\n\n} // namespace group_n"
},
{
"path": "apex/contrib/csrc/group_norm_v2/gn_cuda_inst_64_1280.cu",
"chars": 126,
"preview": "#include \"gn_cuda_host_template.cuh\"\n\nnamespace group_norm_v2 {\n\nGN_CUDA_INST_DEFINE(64, 1280)\n\n} // namespace group_no"
},
{
"path": "apex/contrib/csrc/group_norm_v2/gn_cuda_inst_64_2560.cu",
"chars": 126,
"preview": "#include \"gn_cuda_host_template.cuh\"\n\nnamespace group_norm_v2 {\n\nGN_CUDA_INST_DEFINE(64, 2560)\n\n} // namespace group_no"
},
{
"path": "apex/contrib/csrc/group_norm_v2/gn_cuda_kernel.cuh",
"chars": 61769,
"preview": "#pragma once\n\n#include <cooperative_groups.h>\n\n#include \"gn_utils.hpp\"\n\nnamespace group_norm_v2 {\n\nnamespace cg = cooper"
},
{
"path": "apex/contrib/csrc/group_norm_v2/gn_dispatch_hw_c.hpp",
"chars": 5920,
"preview": "#pragma once\n\n#define DISPATCH_HW_C(hw, c, HW, C, ...) \\\n [&] "
},
{
"path": "apex/contrib/csrc/group_norm_v2/gn_utils.cpp",
"chars": 528,
"preview": "#include \"gn_utils.hpp\"\n\n#include <mutex>\n#include <vector>\n\nnamespace group_norm_v2 {\n\ncudaDeviceProp const& get_device"
},
{
"path": "apex/contrib/csrc/group_norm_v2/gn_utils.hpp",
"chars": 2962,
"preview": "#pragma once\n\n#include <cuda_runtime.h>\n\n#include <cassert>\n#include <cstdio>\n#include <cstdlib>\n\n#include \"gn.hpp\"\n\n// "
},
{
"path": "apex/contrib/csrc/groupbn/batch_norm.cu",
"chars": 10952,
"preview": "#include <ATen/ATen.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <c10/cuda/CUDACachingAllocator.h>\n#include <cuda.h>\n\n"
},
{
"path": "apex/contrib/csrc/groupbn/batch_norm.h",
"chars": 30609,
"preview": "/*\n * Licensed to the Apache Software Foundation (ASF) under one\n * or more contributor license agreements. See the NOT"
},
{
"path": "apex/contrib/csrc/groupbn/batch_norm_add_relu.cu",
"chars": 11629,
"preview": "#include <ATen/ATen.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <c10/cuda/CUDACachingAllocator.h>\n#include <cuda.h>\n\n"
},
{
"path": "apex/contrib/csrc/groupbn/batch_norm_add_relu.h",
"chars": 27937,
"preview": "/*\n * Licensed to the Apache Software Foundation (ASF) under one\n * or more contributor license agreements. See the NOT"
},
{
"path": "apex/contrib/csrc/groupbn/cuda_utils.h",
"chars": 340,
"preview": "#include <ATen/cuda/CUDAContext.h>\n#ifndef CUDA_UTILS_H\n#define CUDA_UTILS_H\n\nnamespace at {\nnamespace cuda {\n\nnamespace"
},
{
"path": "apex/contrib/csrc/groupbn/interface.cpp",
"chars": 6347,
"preview": "#include <ATen/ATen.h>\n#include <ATen/ArrayRef.h>\n#include <ATen/ScalarType.h>\n#include <pybind11/numpy.h>\n#include <pyb"
},
{
"path": "apex/contrib/csrc/groupbn/ipc.cu",
"chars": 4063,
"preview": "#include <ATen/ATen.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <cuda.h>\n\n#define cudaCheckErrors(msg) "
},
{
"path": "apex/contrib/csrc/groupbn/nhwc_batch_norm_kernel.h",
"chars": 99829,
"preview": "/*\n * Licensed to the Apache Software Foundation (ASF) under one\n * or more contributor license agreements. See the NOT"
},
{
"path": "apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda.cpp",
"chars": 5079,
"preview": "#include <torch/torch.h>\n\n#include <cstdint>\n#include <vector>\n\nvoid index_mul_2d_float_foward_cuda(at::Tensor& out, con"
},
{
"path": "apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda_kernel.cu",
"chars": 17797,
"preview": "#include <ATen/ATen.h>\n#include <ATen/AccumulateType.h>\n#include <ATen/cuda/CUDAContext.h>\n\n#include <ATen/cuda/Atomic.c"
},
{
"path": "apex/contrib/csrc/layer_norm/ln.h",
"chars": 4745,
"preview": "#pragma once\n\n#include <cuda_bf16.h>\n#include <cuda_fp16.h>\n#include <stdint.h>\n#include <stdio.h>\n\n#include <unordered_"
},
{
"path": "apex/contrib/csrc/layer_norm/ln_api.cpp",
"chars": 8743,
"preview": "#include <torch/extension.h>\n\n#include \"ATen/cuda/CUDAContext.h\"\n#include \"ln.h\"\n\n/*\n\nSupported Type combinations:\n\ninpu"
},
{
"path": "apex/contrib/csrc/layer_norm/ln_bwd_kernels.cuh",
"chars": 11486,
"preview": "#pragma once\n\n#include \"ln_utils.cuh\"\n\nnamespace layer_norm {\n\ntemplate <typename Ktraits>\n__global__ __launch_bounds__("
},
{
"path": "apex/contrib/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu",
"chars": 11818,
"preview": "#include \"ln.h\"\n#include \"ln_bwd_kernels.cuh\"\n#include \"ln_kernel_traits.h\"\n#include \"ln_utils.cuh\"\n\nusing namespace lay"
},
{
"path": "apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu",
"chars": 10978,
"preview": "#include \"ln.h\"\n#include \"ln_fwd_kernels.cuh\"\n#include \"ln_kernel_traits.h\"\n#include \"ln_utils.cuh\"\n\nusing namespace lay"
},
{
"path": "apex/contrib/csrc/layer_norm/ln_fwd_kernels.cuh",
"chars": 3379,
"preview": "#pragma once\n\n#include \"ln.h\"\n#include \"ln_utils.cuh\"\n\nnamespace layer_norm {\n\ntemplate <typename Ktraits>\n__global__ __"
},
{
"path": "apex/contrib/csrc/layer_norm/ln_kernel_traits.h",
"chars": 5680,
"preview": "#pragma once\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nname"
},
{
"path": "apex/contrib/csrc/layer_norm/ln_utils.cuh",
"chars": 22062,
"preview": "#pragma once\n\n#include <cuda_bf16.h>\n#include <cuda_fp16.h>\n\n#include <cassert>\n\n#include \"ln.h\"\n\n//////////////////////"
},
{
"path": "apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cuda.cu",
"chars": 4330,
"preview": "#include <ATen/ATen.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <cuda.h>\n#include <cuda_fp16.h>\n#include <cuda_profil"
},
{
"path": "apex/contrib/csrc/multihead_attn/dropout.cuh",
"chars": 9658,
"preview": "#pragma once\n#include <ATen/ATen.h>\n\n#ifdef OLD_GENERATOR_PATH\n#include <ATen/CUDAGeneratorImpl.h>\n#else\n#include <ATen/"
},
{
"path": "apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu",
"chars": 16953,
"preview": "#include <ATen/ATen.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <cuda.h>\n#include <cuda_fp16.h>\n#include <cuda_profil"
},
{
"path": "apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu",
"chars": 20877,
"preview": "#include <ATen/ATen.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <cuda.h>\n#include <cuda_fp16.h>\n#include <cuda_profil"
},
{
"path": "apex/contrib/csrc/multihead_attn/layer_norm.cuh",
"chars": 23984,
"preview": "#pragma once\n#include <ATen/ATen.h>\n#include <cuda.h>\n#include <cuda_runtime.h>\n\n#include <ATen/cuda/DeviceUtils.cuh>\n\nn"
},
{
"path": "apex/contrib/csrc/multihead_attn/masked_softmax_dropout_cuda.cu",
"chars": 4770,
"preview": "#include <ATen/ATen.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <cuda.h>\n#include <cuda_fp16.h>\n#include <cuda_profil"
},
{
"path": "apex/contrib/csrc/multihead_attn/multihead_attn_frontend.cpp",
"chars": 41869,
"preview": "#include <cuda_fp16.h>\n#include <torch/extension.h>\n\n#include <vector>\n\n#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #"
},
{
"path": "apex/contrib/csrc/multihead_attn/philox.cuh",
"chars": 3532,
"preview": "#pragma once\n// Philox CUDA.\n\nnamespace {\n\nclass Philox {\n public:\n __device__ inline Philox(unsigned long long seed, u"
},
{
"path": "apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu",
"chars": 14228,
"preview": "#include <ATen/ATen.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <cuda.h>\n#include <cuda_fp16.h>\n#include <cuda_profil"
},
{
"path": "apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu",
"chars": 14488,
"preview": "#include <ATen/ATen.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <cuda.h>\n#include <cuda_fp16.h>\n#include <cuda_profil"
},
{
"path": "apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu",
"chars": 13827,
"preview": "#include <ATen/ATen.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <cuda.h>\n#include <cuda_fp16.h>\n#include <cuda_profil"
},
{
"path": "apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu",
"chars": 18142,
"preview": "#include <ATen/ATen.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <cuda.h>\n#include <cuda_fp16.h>\n#include <cuda_profil"
},
{
"path": "apex/contrib/csrc/multihead_attn/softmax.cuh",
"chars": 119494,
"preview": "#pragma once\n#include <curand_kernel.h>\n\n#include <ATen/cuda/CUDAGraphsUtils.cuh>\n\n#include \"philox.cuh\"\n\n#ifdef OLD_GEN"
},
{
"path": "apex/contrib/csrc/multihead_attn/strided_batched_gemm.cuh",
"chars": 27253,
"preview": "#pragma once\n#include <cuda.h>\n#include <cuda_fp16.h>\n#include <cuda_profiler_api.h>\n#include <cuda_runtime.h>\n\n#include"
},
{
"path": "apex/contrib/csrc/nccl_allocator/NCCLAllocator.cpp",
"chars": 1834,
"preview": "#include <c10/cuda/CUDACachingAllocator.h>\n#include <c10/util/Exception.h>\n#include <nccl.h>\n#include <torch/csrc/cuda/C"
},
{
"path": "apex/contrib/csrc/nccl_p2p/nccl_p2p.cpp",
"chars": 1469,
"preview": "/**\n * Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved.\n *\n * Licensed under the Apache License, Versio"
},
{
"path": "apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cu",
"chars": 7194,
"preview": "#include <ATen/cuda/CUDAContext.h>\n#include <c10/cuda/CUDACachingAllocator.h>\n#include <torch/extension.h>\n\n#include <ca"
},
{
"path": "apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cuh",
"chars": 1471,
"preview": "/**\n * Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved.\n *\n * Licensed under the Apache License, Versio"
},
{
"path": "apex/contrib/csrc/nccl_p2p/nccl_version.cpp",
"chars": 305,
"preview": "// Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.\n// This file is used to check the version of NCCL detect"
},
{
"path": "apex/contrib/csrc/nccl_p2p/nccl_version_check.cu",
"chars": 249,
"preview": "// Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.\n\n// This file is used to check the version of NCCL detec"
},
{
"path": "apex/contrib/csrc/optimizers/fused_adam_cuda.cpp",
"chars": 6062,
"preview": "#include <torch/extension.h>\n\n// CUDA forward declaration\nvoid fused_strided_check_finite(at::Tensor& overflow_flag, at:"
},
{
"path": "apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu",
"chars": 30746,
"preview": "#include <cuda.h>\n#include <cuda_runtime.h>\n#include <stdio.h>\n\n#include <cmath>\n\n#include \"ATen/ATen.h\"\n#include \"ATen/"
},
{
"path": "apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp",
"chars": 661,
"preview": "#include <torch/extension.h>\n\nvoid multi_tensor_lamb_cuda(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<"
},
{
"path": "apex/contrib/csrc/optimizers/fused_lamb_cuda_kernel.cu",
"chars": 9032,
"preview": "#include <ATen/ATen.h>\n#include <ATen/AccumulateType.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <ATen/cuda/Exception"
},
{
"path": "apex/contrib/csrc/optimizers/multi_tensor_distopt_adam.cpp",
"chars": 2042,
"preview": "#include <torch/extension.h>\n\nvoid multi_tensor_fused_adam_cuda(int chunk_size, at::Tensor noop_flag,\n "
},
{
"path": "apex/contrib/csrc/optimizers/multi_tensor_distopt_adam_kernel.cu",
"chars": 19874,
"preview": "#include <ATen/ATen.h>\n#include <ATen/AccumulateType.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <ATen/cuda/Exception"
},
{
"path": "apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb.cpp",
"chars": 1674,
"preview": "#include <torch/extension.h>\n\nvoid multi_tensor_lamb_compute_update_term_cuda(int chunk_size, at::Tensor noop_flag,\n "
},
{
"path": "apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb_kernel.cu",
"chars": 15595,
"preview": "#include <ATen/ATen.h>\n#include <ATen/AccumulateType.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <ATen/cuda/Exception"
},
{
"path": "apex/contrib/csrc/peer_memory/peer_memory.cpp",
"chars": 1928,
"preview": "/**\n * Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved.\n *\n * Licensed under the Apache License, Versio"
},
{
"path": "apex/contrib/csrc/peer_memory/peer_memory_cuda.cu",
"chars": 31991,
"preview": "#include <ATen/cuda/CUDAContext.h>\n#include <c10/cuda/CUDACachingAllocator.h>\n#include <cuda_runtime_api.h>\n#include <to"
},
{
"path": "apex/contrib/csrc/peer_memory/peer_memory_cuda.cuh",
"chars": 2495,
"preview": "/**\n * Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved.\n *\n * Licensed under the Apache License, Versio"
},
{
"path": "apex/contrib/csrc/transducer/transducer_joint.cpp",
"chars": 2695,
"preview": "#include <ATen/Functions.h>\r\n#include <torch/extension.h>\r\n\r\n#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x \" must be"
},
{
"path": "apex/contrib/csrc/transducer/transducer_joint_kernel.cu",
"chars": 33996,
"preview": "#include <ATen/AccumulateType.h>\r\n#include <cuda.h>\r\n#include <cuda_runtime.h>\r\n#include <curand_kernel.h>\r\n#include <to"
},
{
"path": "apex/contrib/csrc/transducer/transducer_loss.cpp",
"chars": 2713,
"preview": "#include <torch/extension.h>\n\n#include <vector>\n\n#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x \" must be a CUDA tens"
},
{
"path": "apex/contrib/csrc/transducer/transducer_loss_kernel.cu",
"chars": 28552,
"preview": "#include <ATen/ATen.h>\n#include <ATen/AccumulateType.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <cuda.h>\n#include <c"
},
{
"path": "apex/contrib/csrc/xentropy/interface.cpp",
"chars": 2170,
"preview": "#include <torch/extension.h>\n\n#include <string>\n\n// CUDA forward declarations\n\nstd::vector<at::Tensor> softmax_xentropy_"
},
{
"path": "apex/contrib/csrc/xentropy/xentropy_kernel.cu",
"chars": 25037,
"preview": "/**\n * From PyTorch:\n *\n * Copyright (c) 2016- Facebook, Inc (Adam Paszke)\n * Copyright (c) 2014- Fac"
},
{
"path": "apex/contrib/cudnn_gbn/__init__.py",
"chars": 41,
"preview": "from .batch_norm import GroupBatchNorm2d\n"
},
{
"path": "apex/contrib/cudnn_gbn/batch_norm.py",
"chars": 7191,
"preview": "import torch\nfrom torch.nn.modules.batchnorm import _BatchNorm\nfrom torch.nn import functional as F\nfrom torch import Te"
},
{
"path": "apex/contrib/examples/gpu_direct_storage/benchmark_load.py",
"chars": 1978,
"preview": "import timeit\nimport torch\nimport apex.contrib.gpu_direct_storage as gds\n\ndef run_benchmark_torch_load():\n sizes = [2"
},
{
"path": "apex/contrib/examples/gpu_direct_storage/benchmark_save.py",
"chars": 1141,
"preview": "import os\nimport timeit\nimport torch\nimport apex.contrib.gpu_direct_storage as gds\n\ndef run_benchmark(func):\n sizes ="
},
{
"path": "apex/contrib/examples/gpu_direct_storage/example_load.py",
"chars": 292,
"preview": "import torch\nimport apex.contrib.gpu_direct_storage as gds\n\nfor size in [128, 1024, 8192]:\n x = torch.empty(size, dev"
},
{
"path": "apex/contrib/examples/gpu_direct_storage/example_save.py",
"chars": 214,
"preview": "import torch\nimport apex.contrib.gpu_direct_storage as gds\n\nfor size in [128, 1024, 8192]:\n x = torch.linspace(0, 1, "
},
{
"path": "apex/contrib/examples/multihead_attn/func_test_multihead_attn.py",
"chars": 5708,
"preview": "import torch\nimport argparse\n\nfrom apex.contrib.multihead_attn import SelfMultiheadAttn\nfrom apex.contrib.multihead_attn"
},
{
"path": "apex/contrib/examples/multihead_attn/perf_test_multihead_attn.py",
"chars": 6131,
"preview": "import torch\nimport argparse\n\nfrom apex.contrib.multihead_attn import SelfMultiheadAttn\nfrom apex.contrib.multihead_attn"
},
{
"path": "apex/contrib/examples/nccl_allocator/allreduce.py",
"chars": 600,
"preview": "import os\nimport torch\nimport torch.distributed as dist\nimport apex.contrib.nccl_allocator as nccl_allocator\n\nassert os."
},
{
"path": "apex/contrib/examples/nccl_allocator/cache.py",
"chars": 1439,
"preview": "import torch\nimport apex.contrib.nccl_allocator as nccl_allocator\nfrom pynvml.smi import nvidia_smi\n\ndef set_device(dev)"
},
{
"path": "apex/contrib/examples/nccl_allocator/change_cuda_allocator.py",
"chars": 349,
"preview": "import torch\nimport apex.contrib.nccl_allocator as nccl_allocator\n\nnccl_allocator.init()\nnrep = 6\npool = nccl_allocator."
},
{
"path": "apex/contrib/examples/nccl_allocator/toy_ddp.py",
"chars": 1583,
"preview": "import os\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nimport torch.distributed as dist\nfrom torch.nn."
},
{
"path": "apex/contrib/fmha/__init__.py",
"chars": 26,
"preview": "from .fmha import FMHAFun\n"
},
{
"path": "apex/contrib/fmha/fmha.py",
"chars": 4420,
"preview": "###############################################################################\n# Copyright (c) 2011-2021, NVIDIA CORPOR"
},
{
"path": "apex/contrib/focal_loss/__init__.py",
"chars": 266,
"preview": "try:\n import torch\n import focal_loss_cuda\n from .focal_loss import focal_loss\n\n del torch\n del focal_los"
},
{
"path": "apex/contrib/focal_loss/focal_loss.py",
"chars": 1499,
"preview": "import torch\n\nimport focal_loss_cuda\n\n\nclass FocalLoss(torch.autograd.Function):\n @staticmethod\n def forward(\n "
},
{
"path": "apex/contrib/gpu_direct_storage/README.md",
"chars": 602,
"preview": "# APEX GPUDirect Storage\n\nThis module aims to add a PyTorch extension for [GPUDirect Storage](https://developer.nvidia.c"
},
{
"path": "apex/contrib/gpu_direct_storage/__init__.py",
"chars": 735,
"preview": "from _apex_gpu_direct_storage import _GDSFile\nfrom contextlib import contextmanager\n\n\n@contextmanager\ndef GDSFile(filena"
},
{
"path": "apex/contrib/group_norm/__init__.py",
"chars": 26,
"preview": "from .group_norm import *\n"
},
{
"path": "apex/contrib/group_norm/group_norm.py",
"chars": 14376,
"preview": "#!/usr/bin/env python\n# coding: utf-8\n\n#\n# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. A"
},
{
"path": "apex/contrib/groupbn/__init__.py",
"chars": 233,
"preview": "try:\n import torch\n import bnp\n from .batch_norm import BatchNorm2d_NHWC\n\n del torch\n del bnp\n del bat"
},
{
"path": "apex/contrib/groupbn/batch_norm.py",
"chars": 13985,
"preview": "import torch\nimport numpy as np\nfrom torch.nn.modules.batchnorm import _BatchNorm\n\nimport bnp\n\n\nclass bn_NHWC_impl(torch"
},
{
"path": "apex/contrib/index_mul_2d/__init__.py",
"chars": 39,
"preview": "from .index_mul_2d import index_mul_2d\n"
},
{
"path": "apex/contrib/index_mul_2d/index_mul_2d.py",
"chars": 4283,
"preview": "import torch\n\nimport fused_index_mul_2d\n\n\nclass IndexMul2d_(torch.autograd.Function):\n \"\"\"\n Currently only support"
},
{
"path": "apex/contrib/layer_norm/__init__.py",
"chars": 38,
"preview": "from .layer_norm import FastLayerNorm\n"
},
{
"path": "apex/contrib/layer_norm/layer_norm.py",
"chars": 2109,
"preview": "import torch\nfrom torch.nn import init\n\nfrom apex._autocast_utils import _cast_if_autocast_enabled\nimport fast_layer_nor"
},
{
"path": "apex/contrib/multihead_attn/README.md",
"chars": 2267,
"preview": "# Fast Multihead Attention \n\nThis implementation has two main features :\n* A C++ implementation to avoid the CPU overhea"
},
{
"path": "apex/contrib/multihead_attn/__init__.py",
"chars": 176,
"preview": "from .self_multihead_attn import SelfMultiheadAttn\nfrom .encdec_multihead_attn import EncdecMultiheadAttn\nfrom .mask_sof"
},
{
"path": "apex/contrib/multihead_attn/encdec_multihead_attn.py",
"chars": 7727,
"preview": "import math\n\nimport torch\nfrom torch import nn\nfrom torch.nn import Parameter\nimport torch.nn.functional as F\n\nfrom .enc"
},
{
"path": "apex/contrib/multihead_attn/encdec_multihead_attn_func.py",
"chars": 17787,
"preview": "import torch\nimport torch.nn.functional as F\n\n\nclass EncdecAttnFunc(torch.autograd.Function):\n @staticmethod\n def "
}
]
// ... and 219 more files (download for full content)
About this extraction
This page contains the full source code of the NVIDIA/apex GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 419 files (3.6 MB), approximately 965.5k tokens, and a symbol index with 1684 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.
Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.