Repository: NVIDIA/apex Branch: master Commit: ba32a259b7aa Files: 419 Total size: 3.6 MB Directory structure: gitextract_8yaiblk9/ ├── .clang-format ├── .git-blame-ignore-revs ├── .github/ │ └── ISSUE_TEMPLATE/ │ └── bug_report.md ├── .gitignore ├── .gitmodules ├── .nojekyll ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── apex/ │ ├── __init__.py │ ├── _autocast_utils.py │ ├── contrib/ │ │ ├── __init__.py │ │ ├── bottleneck/ │ │ │ ├── __init__.py │ │ │ ├── bottleneck.py │ │ │ ├── halo_exchangers.py │ │ │ └── test.py │ │ ├── clip_grad/ │ │ │ ├── __init__.py │ │ │ └── clip_grad.py │ │ ├── conv_bias_relu/ │ │ │ ├── __init__.py │ │ │ └── conv_bias_relu.py │ │ ├── csrc/ │ │ │ ├── bottleneck/ │ │ │ │ └── bottleneck.cpp │ │ │ ├── conv_bias_relu/ │ │ │ │ └── conv_bias_relu.cpp │ │ │ ├── cudnn_gbn/ │ │ │ │ ├── cudnn_gbn.cpp │ │ │ │ ├── norm_sample.cpp │ │ │ │ └── norm_sample.h │ │ │ ├── fmha/ │ │ │ │ ├── fmha_api.cpp │ │ │ │ └── src/ │ │ │ │ ├── fmha/ │ │ │ │ │ ├── gemm.h │ │ │ │ │ ├── gmem_tile.h │ │ │ │ │ ├── kernel_traits.h │ │ │ │ │ ├── mask.h │ │ │ │ │ ├── smem_tile.h │ │ │ │ │ ├── softmax.h │ │ │ │ │ └── utils.h │ │ │ │ ├── fmha.h │ │ │ │ ├── fmha_dgrad_fp16_128_64_kernel.sm80.cu │ │ │ │ ├── fmha_dgrad_fp16_256_64_kernel.sm80.cu │ │ │ │ ├── fmha_dgrad_fp16_384_64_kernel.sm80.cu │ │ │ │ ├── fmha_dgrad_fp16_512_64_kernel.sm80.cu │ │ │ │ ├── fmha_dgrad_kernel_1xN_reload.h │ │ │ │ ├── fmha_dgrad_kernel_1xN_reload_nl.h │ │ │ │ ├── fmha_fill.cu │ │ │ │ ├── fmha_fprop_fp16_128_64_kernel.sm80.cu │ │ │ │ ├── fmha_fprop_fp16_256_64_kernel.sm80.cu │ │ │ │ ├── fmha_fprop_fp16_384_64_kernel.sm80.cu │ │ │ │ ├── fmha_fprop_fp16_512_64_kernel.sm80.cu │ │ │ │ ├── fmha_fprop_kernel_1xN.h │ │ │ │ ├── fmha_kernel.h │ │ │ │ ├── fmha_noloop_reduce.cu │ │ │ │ └── fmha_utils.h │ │ │ ├── focal_loss/ │ │ │ │ ├── focal_loss_cuda.cpp │ │ │ │ └── focal_loss_cuda_kernel.cu │ │ │ ├── gpu_direct_storage/ │ │ │ │ ├── gds.cpp │ │ │ │ ├── gds.h │ │ │ │ └── gds_pybind.cpp │ │ │ ├── group_norm/ 
│ │ │ │ ├── group_norm_nhwc.cpp │ │ │ │ ├── group_norm_nhwc.h │ │ │ │ ├── group_norm_nhwc_bwd_one_pass.h │ │ │ │ ├── group_norm_nhwc_bwd_one_pass_kernel.cuh │ │ │ │ ├── group_norm_nhwc_bwd_two_pass.cu │ │ │ │ ├── group_norm_nhwc_fwd_one_pass.h │ │ │ │ ├── group_norm_nhwc_fwd_one_pass_kernel.cuh │ │ │ │ ├── group_norm_nhwc_fwd_two_pass.cu │ │ │ │ ├── group_norm_nhwc_one_pass_10.cu │ │ │ │ ├── group_norm_nhwc_one_pass_112.cu │ │ │ │ ├── group_norm_nhwc_one_pass_12.cu │ │ │ │ ├── group_norm_nhwc_one_pass_120.cu │ │ │ │ ├── group_norm_nhwc_one_pass_128.cu │ │ │ │ ├── group_norm_nhwc_one_pass_14.cu │ │ │ │ ├── group_norm_nhwc_one_pass_16.cu │ │ │ │ ├── group_norm_nhwc_one_pass_160.cu │ │ │ │ ├── group_norm_nhwc_one_pass_20.cu │ │ │ │ ├── group_norm_nhwc_one_pass_24.cu │ │ │ │ ├── group_norm_nhwc_one_pass_26.cu │ │ │ │ ├── group_norm_nhwc_one_pass_28.cu │ │ │ │ ├── group_norm_nhwc_one_pass_30.cu │ │ │ │ ├── group_norm_nhwc_one_pass_32.cu │ │ │ │ ├── group_norm_nhwc_one_pass_4.cu │ │ │ │ ├── group_norm_nhwc_one_pass_40.cu │ │ │ │ ├── group_norm_nhwc_one_pass_42.cu │ │ │ │ ├── group_norm_nhwc_one_pass_48.cu │ │ │ │ ├── group_norm_nhwc_one_pass_56.cu │ │ │ │ ├── group_norm_nhwc_one_pass_60.cu │ │ │ │ ├── group_norm_nhwc_one_pass_64.cu │ │ │ │ ├── group_norm_nhwc_one_pass_70.cu │ │ │ │ ├── group_norm_nhwc_one_pass_8.cu │ │ │ │ ├── group_norm_nhwc_one_pass_80.cu │ │ │ │ ├── group_norm_nhwc_one_pass_84.cu │ │ │ │ ├── group_norm_nhwc_one_pass_96.cu │ │ │ │ ├── group_norm_nhwc_one_pass_98.cu │ │ │ │ ├── group_norm_nhwc_op.cpp │ │ │ │ ├── macros.h │ │ │ │ └── traits.h │ │ │ ├── group_norm_v2/ │ │ │ │ ├── generate_gn_cuda_inst.py │ │ │ │ ├── gn.cpp │ │ │ │ ├── gn.hpp │ │ │ │ ├── gn_cuda.cu │ │ │ │ ├── gn_cuda_host_template.cuh │ │ │ │ ├── gn_cuda_inst_1024_1280.cu │ │ │ │ ├── gn_cuda_inst_1024_1920.cu │ │ │ │ ├── gn_cuda_inst_1024_320.cu │ │ │ │ ├── gn_cuda_inst_1024_640.cu │ │ │ │ ├── gn_cuda_inst_1024_960.cu │ │ │ │ ├── gn_cuda_inst_256_1280.cu │ │ │ │ ├── 
gn_cuda_inst_256_1920.cu │ │ │ │ ├── gn_cuda_inst_256_2560.cu │ │ │ │ ├── gn_cuda_inst_256_640.cu │ │ │ │ ├── gn_cuda_inst_4096_320.cu │ │ │ │ ├── gn_cuda_inst_4096_640.cu │ │ │ │ ├── gn_cuda_inst_4096_960.cu │ │ │ │ ├── gn_cuda_inst_64_1280.cu │ │ │ │ ├── gn_cuda_inst_64_2560.cu │ │ │ │ ├── gn_cuda_kernel.cuh │ │ │ │ ├── gn_dispatch_hw_c.hpp │ │ │ │ ├── gn_utils.cpp │ │ │ │ └── gn_utils.hpp │ │ │ ├── groupbn/ │ │ │ │ ├── batch_norm.cu │ │ │ │ ├── batch_norm.h │ │ │ │ ├── batch_norm_add_relu.cu │ │ │ │ ├── batch_norm_add_relu.h │ │ │ │ ├── cuda_utils.h │ │ │ │ ├── interface.cpp │ │ │ │ ├── ipc.cu │ │ │ │ └── nhwc_batch_norm_kernel.h │ │ │ ├── index_mul_2d/ │ │ │ │ ├── index_mul_2d_cuda.cpp │ │ │ │ └── index_mul_2d_cuda_kernel.cu │ │ │ ├── layer_norm/ │ │ │ │ ├── ln.h │ │ │ │ ├── ln_api.cpp │ │ │ │ ├── ln_bwd_kernels.cuh │ │ │ │ ├── ln_bwd_semi_cuda_kernel.cu │ │ │ │ ├── ln_fwd_cuda_kernel.cu │ │ │ │ ├── ln_fwd_kernels.cuh │ │ │ │ ├── ln_kernel_traits.h │ │ │ │ └── ln_utils.cuh │ │ │ ├── multihead_attn/ │ │ │ │ ├── additive_masked_softmax_dropout_cuda.cu │ │ │ │ ├── dropout.cuh │ │ │ │ ├── encdec_multihead_attn_cuda.cu │ │ │ │ ├── encdec_multihead_attn_norm_add_cuda.cu │ │ │ │ ├── layer_norm.cuh │ │ │ │ ├── masked_softmax_dropout_cuda.cu │ │ │ │ ├── multihead_attn_frontend.cpp │ │ │ │ ├── philox.cuh │ │ │ │ ├── self_multihead_attn_bias_additive_mask_cuda.cu │ │ │ │ ├── self_multihead_attn_bias_cuda.cu │ │ │ │ ├── self_multihead_attn_cuda.cu │ │ │ │ ├── self_multihead_attn_norm_add_cuda.cu │ │ │ │ ├── softmax.cuh │ │ │ │ └── strided_batched_gemm.cuh │ │ │ ├── nccl_allocator/ │ │ │ │ └── NCCLAllocator.cpp │ │ │ ├── nccl_p2p/ │ │ │ │ ├── nccl_p2p.cpp │ │ │ │ ├── nccl_p2p_cuda.cu │ │ │ │ ├── nccl_p2p_cuda.cuh │ │ │ │ ├── nccl_version.cpp │ │ │ │ └── nccl_version_check.cu │ │ │ ├── optimizers/ │ │ │ │ ├── fused_adam_cuda.cpp │ │ │ │ ├── fused_adam_cuda_kernel.cu │ │ │ │ ├── fused_lamb_cuda.cpp │ │ │ │ ├── fused_lamb_cuda_kernel.cu │ │ │ │ ├── 
multi_tensor_distopt_adam.cpp │ │ │ │ ├── multi_tensor_distopt_adam_kernel.cu │ │ │ │ ├── multi_tensor_distopt_lamb.cpp │ │ │ │ └── multi_tensor_distopt_lamb_kernel.cu │ │ │ ├── peer_memory/ │ │ │ │ ├── peer_memory.cpp │ │ │ │ ├── peer_memory_cuda.cu │ │ │ │ └── peer_memory_cuda.cuh │ │ │ ├── transducer/ │ │ │ │ ├── transducer_joint.cpp │ │ │ │ ├── transducer_joint_kernel.cu │ │ │ │ ├── transducer_loss.cpp │ │ │ │ └── transducer_loss_kernel.cu │ │ │ └── xentropy/ │ │ │ ├── interface.cpp │ │ │ └── xentropy_kernel.cu │ │ ├── cudnn_gbn/ │ │ │ ├── __init__.py │ │ │ └── batch_norm.py │ │ ├── examples/ │ │ │ ├── gpu_direct_storage/ │ │ │ │ ├── benchmark_load.py │ │ │ │ ├── benchmark_save.py │ │ │ │ ├── example_load.py │ │ │ │ └── example_save.py │ │ │ ├── multihead_attn/ │ │ │ │ ├── func_test_multihead_attn.py │ │ │ │ └── perf_test_multihead_attn.py │ │ │ └── nccl_allocator/ │ │ │ ├── allreduce.py │ │ │ ├── cache.py │ │ │ ├── change_cuda_allocator.py │ │ │ └── toy_ddp.py │ │ ├── fmha/ │ │ │ ├── __init__.py │ │ │ └── fmha.py │ │ ├── focal_loss/ │ │ │ ├── __init__.py │ │ │ └── focal_loss.py │ │ ├── gpu_direct_storage/ │ │ │ ├── README.md │ │ │ └── __init__.py │ │ ├── group_norm/ │ │ │ ├── __init__.py │ │ │ └── group_norm.py │ │ ├── groupbn/ │ │ │ ├── __init__.py │ │ │ └── batch_norm.py │ │ ├── index_mul_2d/ │ │ │ ├── __init__.py │ │ │ └── index_mul_2d.py │ │ ├── layer_norm/ │ │ │ ├── __init__.py │ │ │ └── layer_norm.py │ │ ├── multihead_attn/ │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── encdec_multihead_attn.py │ │ │ ├── encdec_multihead_attn_func.py │ │ │ ├── fast_encdec_multihead_attn_func.py │ │ │ ├── fast_encdec_multihead_attn_norm_add_func.py │ │ │ ├── fast_self_multihead_attn_func.py │ │ │ ├── fast_self_multihead_attn_norm_add_func.py │ │ │ ├── mask_softmax_dropout_func.py │ │ │ ├── self_multihead_attn.py │ │ │ └── self_multihead_attn_func.py │ │ ├── nccl_allocator/ │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ └── nccl_allocator.py │ │ ├── openfold_triton/ 
│ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── _layer_norm_backward_kernels.py │ │ │ ├── _layer_norm_config_ampere.py │ │ │ ├── _layer_norm_config_hopper.py │ │ │ ├── _layer_norm_forward_kernels.py │ │ │ ├── _mha_kernel.py │ │ │ ├── fused_adam_swa.py │ │ │ ├── layer_norm.py │ │ │ └── mha.py │ │ ├── optimizers/ │ │ │ ├── __init__.py │ │ │ ├── distributed_fused_adam.py │ │ │ ├── distributed_fused_lamb.py │ │ │ ├── fp16_optimizer.py │ │ │ ├── fused_adam.py │ │ │ ├── fused_lamb.py │ │ │ └── fused_sgd.py │ │ ├── peer_memory/ │ │ │ ├── __init__.py │ │ │ ├── peer_halo_exchanger_1d.py │ │ │ └── peer_memory.py │ │ ├── sparsity/ │ │ │ ├── COPYRIGHT │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── asp.py │ │ │ ├── permutation_lib.py │ │ │ ├── permutation_search_kernels/ │ │ │ │ ├── CUDA_kernels/ │ │ │ │ │ └── permutation_search_kernels.cu │ │ │ │ ├── __init__.py │ │ │ │ ├── call_permutation_search_kernels.py │ │ │ │ ├── channel_swap.py │ │ │ │ ├── exhaustive_search.py │ │ │ │ └── permutation_utilities.py │ │ │ ├── permutation_tests/ │ │ │ │ ├── README.md │ │ │ │ ├── ablation_studies.sh │ │ │ │ ├── permutation_test.py │ │ │ │ ├── runtime_table.sh │ │ │ │ └── unstructured_study.sh │ │ │ ├── sparse_masklib.py │ │ │ └── test/ │ │ │ ├── checkpointing_test_part1.py │ │ │ ├── checkpointing_test_part2.py │ │ │ ├── checkpointing_test_reference.py │ │ │ ├── test_permutation_application.py │ │ │ └── toy_problem.py │ │ ├── test/ │ │ │ ├── __init__.py │ │ │ ├── bottleneck/ │ │ │ │ ├── __init__.py │ │ │ │ └── test_bottleneck_module.py │ │ │ ├── clip_grad/ │ │ │ │ ├── __init__.py │ │ │ │ └── test_clip_grad.py │ │ │ ├── conv_bias_relu/ │ │ │ │ ├── __init__.py │ │ │ │ └── test_conv_bias_relu.py │ │ │ ├── cudnn_gbn/ │ │ │ │ ├── __init__.py │ │ │ │ └── test_cudnn_gbn_with_two_gpus.py │ │ │ ├── fmha/ │ │ │ │ ├── __init__.py │ │ │ │ └── test_fmha.py │ │ │ ├── focal_loss/ │ │ │ │ ├── __init__.py │ │ │ │ └── test_focal_loss.py │ │ │ ├── fused_dense/ │ │ │ │ └── test_fused_dense.py │ │ │ ├── 
group_norm/ │ │ │ │ ├── __init__.py │ │ │ │ └── test_group_norm.py │ │ │ ├── index_mul_2d/ │ │ │ │ ├── __init__.py │ │ │ │ └── test_index_mul_2d.py │ │ │ ├── layer_norm/ │ │ │ │ ├── __init__.py │ │ │ │ └── test_fast_layer_norm.py │ │ │ ├── multihead_attn/ │ │ │ │ ├── __init__.py │ │ │ │ ├── test_encdec_multihead_attn.py │ │ │ │ ├── test_encdec_multihead_attn_norm_add.py │ │ │ │ ├── test_fast_self_multihead_attn_bias.py │ │ │ │ ├── test_mha_fused_softmax.py │ │ │ │ ├── test_self_multihead_attn.py │ │ │ │ └── test_self_multihead_attn_norm_add.py │ │ │ ├── openfold_triton/ │ │ │ │ ├── test_fused_adam_swa.py │ │ │ │ ├── test_openfold_mha.py │ │ │ │ └── test_sync_triton_auto_tune_cache_across_gpus.py │ │ │ ├── optimizers/ │ │ │ │ ├── __init__.py │ │ │ │ ├── test_dist_adam.py │ │ │ │ └── test_distributed_fused_lamb.py │ │ │ ├── peer_memory/ │ │ │ │ ├── __init__.py │ │ │ │ └── test_peer_halo_exchange_module.py │ │ │ ├── transducer/ │ │ │ │ ├── __init__.py │ │ │ │ ├── test_transducer_joint.py │ │ │ │ └── test_transducer_loss.py │ │ │ └── xentropy/ │ │ │ ├── __init__.py │ │ │ └── test_label_smoothing.py │ │ ├── torchsched/ │ │ │ ├── __init__.py │ │ │ ├── backend.py │ │ │ ├── config.py │ │ │ ├── inductor/ │ │ │ │ ├── __init__.py │ │ │ │ ├── _utils.py │ │ │ │ ├── event.py │ │ │ │ ├── graph.py │ │ │ │ ├── scheduler.py │ │ │ │ └── wrapper.py │ │ │ ├── ops/ │ │ │ │ ├── __init__.py │ │ │ │ └── layer_norm.py │ │ │ └── passes/ │ │ │ ├── __init__.py │ │ │ └── pre_grad_passes.py │ │ ├── transducer/ │ │ │ ├── __init__.py │ │ │ ├── _transducer_ref.py │ │ │ └── transducer.py │ │ └── xentropy/ │ │ ├── __init__.py │ │ └── softmax_xentropy.py │ ├── distributed_testing/ │ │ ├── __init__.py │ │ ├── _ucc_util.py │ │ └── distributed_test_base.py │ ├── fused_dense/ │ │ ├── __init__.py │ │ └── fused_dense.py │ ├── mlp/ │ │ ├── __init__.py │ │ └── mlp.py │ ├── multi_tensor_apply/ │ │ ├── __init__.py │ │ └── multi_tensor_apply.py │ ├── normalization/ │ │ ├── __init__.py │ │ └── fused_layer_norm.py 
│ └── optimizers/ │ ├── __init__.py │ ├── fused_adagrad.py │ ├── fused_adam.py │ ├── fused_lamb.py │ ├── fused_mixed_precision_lamb.py │ ├── fused_novograd.py │ └── fused_sgd.py ├── csrc/ │ ├── amp_C_frontend.cpp │ ├── flatten_unflatten.cpp │ ├── fused_dense.cpp │ ├── fused_dense_cuda.cu │ ├── layer_norm_cuda.cpp │ ├── layer_norm_cuda_kernel.cu │ ├── megatron/ │ │ ├── fused_rotary_positional_embedding.cpp │ │ ├── fused_rotary_positional_embedding.h │ │ ├── fused_rotary_positional_embedding_cuda.cu │ │ ├── fused_weight_gradient_dense.cpp │ │ ├── fused_weight_gradient_dense_16bit_prec_cuda.cu │ │ ├── fused_weight_gradient_dense_cuda.cu │ │ ├── generic_scaled_masked_softmax.cpp │ │ ├── generic_scaled_masked_softmax.h │ │ ├── generic_scaled_masked_softmax_cuda.cu │ │ ├── scaled_masked_softmax.cpp │ │ ├── scaled_masked_softmax.h │ │ ├── scaled_masked_softmax_cuda.cu │ │ ├── scaled_softmax.cpp │ │ ├── scaled_softmax_cuda.cu │ │ ├── scaled_upper_triang_masked_softmax.cpp │ │ ├── scaled_upper_triang_masked_softmax.h │ │ └── scaled_upper_triang_masked_softmax_cuda.cu │ ├── mlp.cpp │ ├── mlp_cuda.cu │ ├── multi_tensor_adagrad.cu │ ├── multi_tensor_adam.cu │ ├── multi_tensor_apply.cuh │ ├── multi_tensor_axpby_kernel.cu │ ├── multi_tensor_l2norm_kernel.cu │ ├── multi_tensor_l2norm_kernel_mp.cu │ ├── multi_tensor_l2norm_scale_kernel.cu │ ├── multi_tensor_lamb.cu │ ├── multi_tensor_lamb_mp.cu │ ├── multi_tensor_lamb_stage_1.cu │ ├── multi_tensor_lamb_stage_2.cu │ ├── multi_tensor_novograd.cu │ ├── multi_tensor_scale_kernel.cu │ ├── multi_tensor_sgd_kernel.cu │ ├── static_switch.h │ ├── syncbn.cpp │ ├── type_shim.h │ ├── update_scale_hysteresis.cu │ └── welford.cu ├── docs/ │ ├── Makefile │ └── source/ │ ├── _static/ │ │ └── css/ │ │ └── pytorch_theme.css │ ├── _templates/ │ │ └── layout.html │ ├── conf.py │ ├── index.rst │ ├── layernorm.rst │ └── optimizers.rst ├── examples/ │ ├── README.md │ ├── dcgan/ │ │ ├── README.md │ │ └── main_amp.py │ ├── docker/ │ │ ├── Dockerfile │ │ 
└── README.md │ ├── imagenet/ │ │ ├── README.md │ │ └── main_amp.py │ └── simple/ │ └── distributed/ │ ├── README.md │ ├── distributed_data_parallel.py │ └── run.sh ├── pyproject.toml ├── requirements.txt ├── requirements_dev.txt ├── setup.py └── tests/ ├── L0/ │ ├── run_fused_layer_norm/ │ │ └── test_fused_layer_norm.py │ ├── run_mlp/ │ │ └── test_mlp.py │ ├── run_optimizers/ │ │ ├── __init__.py │ │ ├── test_adam.py │ │ ├── test_fused_novograd.py │ │ ├── test_fused_optimizer.py │ │ └── test_lamb.py │ └── run_test.py ├── L1/ │ ├── common/ │ │ ├── compare.py │ │ ├── main_amp.py │ │ └── run_test.sh │ ├── cross_product/ │ │ └── run.sh │ └── cross_product_distributed/ │ └── run.sh ├── distributed/ │ ├── DDP/ │ │ ├── ddp_race_condition_test.py │ │ └── run_race_test.sh │ ├── amp_master_params/ │ │ ├── amp_master_params.py │ │ ├── compare.py │ │ └── run.sh │ └── synced_batchnorm/ │ ├── python_single_gpu_unit_test.py │ ├── single_gpu_unit_test.py │ ├── test_batchnorm1d.py │ ├── test_groups.py │ ├── two_gpu_test_different_batch_size.py │ ├── two_gpu_unit_test.py │ └── unit_test.sh └── docker_extension_builds/ └── run.sh ================================================ FILE CONTENTS ================================================ ================================================ FILE: .clang-format ================================================ # Start with a built-in style and modify it BasedOnStyle: Google # Overrides ColumnLimit: 120 ================================================ FILE: .git-blame-ignore-revs ================================================ # Commits to ignore in git-blame # These commits are bulk formatting or refactoring changes that should be skipped when viewing blame history # Add pre-commit and GitHub Actions workflow for it (#1949) 1f20398756f0eeba37d6887a2d3f65e0687ec94f # Remove github actions config of pre-commit in favor of pre-commit ci (#1958) 27e0e8951352d9d58c88b2895cd8f2c752bda963 # Enable Ruff pre-commit hooks (#1957) 
16fadfe71c0d57312351c2d8b056251a0c8ce1ef ================================================ FILE: .github/ISSUE_TEMPLATE/bug_report.md ================================================ --- name: Bug report about: Create a report to help us improve apex title: '' labels: bug assignees: '' --- **Describe the Bug** **Minimal Steps/Code to Reproduce the Bug** **Expected Behavior** **Environment** ================================================ FILE: .gitignore ================================================ apex.egg-info dist build docs/build *~ __pycache__ .vscode # Copied from https://raw.githubusercontent.com/github/gitignore/master/Python.gitignore # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # Distribution / packaging .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ share/python-wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. *.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover *.py,cover .hypothesis/ .pytest_cache/ cover/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py db.sqlite3 db.sqlite3-journal # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder .pybuilder/ target/ # Jupyter Notebook .ipynb_checkpoints # IPython profile_default/ ipython_config.py # pyenv # For a library or package, you might want to ignore these files since the code is # intended to run in multiple environments; otherwise, check them in: # .python-version # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
# However, in case of collaboration, if having platform-specific dependencies or dependencies # having no cross-platform support, pipenv may install dependencies that don't work, or not # install all needed dependencies. #Pipfile.lock # PEP 582; used by e.g. github.com/David-OConnor/pyflow __pypackages__/ # Celery stuff celerybeat-schedule celerybeat.pid # SageMath parsed files *.sage.py # Environments .env .venv env/ venv/ ENV/ env.bak/ venv.bak/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ .dmypy.json dmypy.json # Pyre type checker .pyre/ # pytype static type analyzer .pytype/ # Cython debug symbols cython_debug/ ================================================ FILE: .gitmodules ================================================ [submodule "apex/contrib/csrc/multihead_attn/cutlass"] path = apex/contrib/csrc/multihead_attn/cutlass url = https://github.com/NVIDIA/cutlass.git branch = v1.2.0 [submodule "apex/contrib/csrc/cudnn-frontend"] path = apex/contrib/csrc/cudnn-frontend url = https://github.com/NVIDIA/cudnn-frontend.git ================================================ FILE: .nojekyll ================================================ ================================================ FILE: .pre-commit-config.yaml ================================================ repos: - repo: https://github.com/pre-commit/mirrors-clang-format rev: v22.1.1 # Or pin to your preferred clang-format version hooks: - id: clang-format files: \.(c|h|cpp|hpp|proto|cu|cuh)$ exclude: ^(apex/contrib/csrc/multihead_attn/cutlass|apex/contrib/csrc/cudnn-frontend)/ - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.15.6 hooks: - id: ruff-check args: ["--fix"] - id: ruff-format types_or: [python] exclude: "examples" ================================================ FILE: LICENSE ================================================ All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ================================================ FILE: README.md ================================================ # Introduction This repository holds NVIDIA-maintained utilities to streamline mixed precision and distributed training in Pytorch. Some of the code here will be included in upstream Pytorch eventually. The intent of Apex is to make up-to-date utilities available to users as quickly as possible. # Installation Each [`apex.contrib`](./apex/contrib) module requires one or more install options other than `--cpp_ext` and `--cuda_ext`. 
Note that contrib modules do not necessarily support stable PyTorch releases; some of them may only be compatible with nightly builds. ## Containers NVIDIA PyTorch Containers are available on NGC: https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch. The containers come with all the custom extensions available at the moment. See [the NGC documentation](https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/index.html) for details such as: - how to pull a container - how to run a pulled container - release notes ## From Source To install Apex from source, we recommend using the nightly PyTorch obtainable from https://github.com/pytorch/pytorch. The latest stable release obtainable from https://pytorch.org should also work. We recommend installing [`Ninja`](https://ninja-build.org/) to make compilation faster. ### Linux For performance and full functionality, we recommend installing Apex with CUDA and C++ extensions using environment variables: #### Using Environment Variables (Recommended) ```bash git clone https://github.com/NVIDIA/apex cd apex # Build with core extensions (cpp and cuda) APEX_CPP_EXT=1 APEX_CUDA_EXT=1 pip install -v --no-build-isolation . # To build with additional extensions, specify them with environment variables APEX_CPP_EXT=1 APEX_CUDA_EXT=1 APEX_FAST_MULTIHEAD_ATTN=1 APEX_FUSED_CONV_BIAS_RELU=1 pip install -v --no-build-isolation . # To build all contrib extensions at once APEX_CPP_EXT=1 APEX_CUDA_EXT=1 APEX_ALL_CONTRIB_EXT=1 pip install -v --no-build-isolation . ``` To reduce the build time, parallel building can be enabled: ```bash NVCC_APPEND_FLAGS="--threads 4" APEX_PARALLEL_BUILD=8 APEX_CPP_EXT=1 APEX_CUDA_EXT=1 pip install -v --no-build-isolation . ``` When CPU cores or memory are limited, the `--parallel` option is generally preferred over `--threads`. See [pull#1882](https://github.com/NVIDIA/apex/pull/1882) for more details.
#### Using Command-Line Flags (Legacy Method) The traditional command-line flags are still supported: ```bash # Using pip config-settings (pip >= 23.1) pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./ # For older pip versions pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --global-option="--cpp_ext" --global-option="--cuda_ext" ./ # To build with additional extensions pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--fast_multihead_attn" ./ ``` #### Python-Only Build Apex also supports a Python-only build via: ```bash pip install -v --disable-pip-version-check --no-build-isolation --no-cache-dir ./ ``` A Python-only build omits: - Fused kernels required to use `apex.optimizers.FusedAdam`. - Fused kernels required to use `apex.normalization.FusedLayerNorm` and `apex.normalization.FusedRMSNorm`. - Fused kernels that improve the performance and numerical stability of `apex.parallel.SyncBatchNorm`. - Fused kernels that improve the performance of `apex.parallel.DistributedDataParallel` and `apex.amp`. `DistributedDataParallel`, `amp`, and `SyncBatchNorm` will still be usable, but they may be slower. ### [Experimental] Windows `pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" .` may work if you were able to build PyTorch from source on your system. A Python-only build via `pip install -v --no-cache-dir .` is more likely to work. If you installed PyTorch in a Conda environment, make sure to install Apex in that same environment. ## Custom C++/CUDA Extensions and Install Options If a module's requirements are not met, that module will not be built.
| Module Name | Environment Variable | Install Option | Misc | |---------------|------------------------|------------------|--------| | `apex_C` | `APEX_CPP_EXT=1` | `--cpp_ext` | | | `amp_C` | `APEX_CUDA_EXT=1` | `--cuda_ext` | | | `syncbn` | `APEX_CUDA_EXT=1` | `--cuda_ext` | | | `fused_layer_norm_cuda` | `APEX_CUDA_EXT=1` | `--cuda_ext` | [`apex.normalization`](./apex/normalization) | | `mlp_cuda` | `APEX_CUDA_EXT=1` | `--cuda_ext` | | | `scaled_upper_triang_masked_softmax_cuda` | `APEX_CUDA_EXT=1` | `--cuda_ext` | | | `generic_scaled_masked_softmax_cuda` | `APEX_CUDA_EXT=1` | `--cuda_ext` | | | `scaled_masked_softmax_cuda` | `APEX_CUDA_EXT=1` | `--cuda_ext` | | | `fused_weight_gradient_mlp_cuda` | `APEX_CUDA_EXT=1` | `--cuda_ext` | Requires CUDA>=11 | | `permutation_search_cuda` | `APEX_PERMUTATION_SEARCH=1` | `--permutation_search` | [`apex.contrib.sparsity`](./apex/contrib/sparsity) | | `bnp` | `APEX_BNP=1` | `--bnp` | [`apex.contrib.groupbn`](./apex/contrib/groupbn) | | `xentropy` | `APEX_XENTROPY=1` | `--xentropy` | [`apex.contrib.xentropy`](./apex/contrib/xentropy) | | `focal_loss_cuda` | `APEX_FOCAL_LOSS=1` | `--focal_loss` | [`apex.contrib.focal_loss`](./apex/contrib/focal_loss) | | `fused_index_mul_2d` | `APEX_INDEX_MUL_2D=1` | `--index_mul_2d` | [`apex.contrib.index_mul_2d`](./apex/contrib/index_mul_2d) | | `fused_adam_cuda` | `APEX_DEPRECATED_FUSED_ADAM=1` | `--deprecated_fused_adam` | [`apex.contrib.optimizers`](./apex/contrib/optimizers) | | `fused_lamb_cuda` | `APEX_DEPRECATED_FUSED_LAMB=1` | `--deprecated_fused_lamb` | [`apex.contrib.optimizers`](./apex/contrib/optimizers) | | `fast_layer_norm` | `APEX_FAST_LAYER_NORM=1` | `--fast_layer_norm` | [`apex.contrib.layer_norm`](./apex/contrib/layer_norm). 
different from `fused_layer_norm` | | `fmhalib` | `APEX_FMHA=1` | `--fmha` | [`apex.contrib.fmha`](./apex/contrib/fmha) | | `fast_multihead_attn` | `APEX_FAST_MULTIHEAD_ATTN=1` | `--fast_multihead_attn` | [`apex.contrib.multihead_attn`](./apex/contrib/multihead_attn) | | `transducer_joint_cuda` | `APEX_TRANSDUCER=1` | `--transducer` | [`apex.contrib.transducer`](./apex/contrib/transducer) | | `transducer_loss_cuda` | `APEX_TRANSDUCER=1` | `--transducer` | [`apex.contrib.transducer`](./apex/contrib/transducer) | | `cudnn_gbn_lib` | `APEX_CUDNN_GBN=1` | `--cudnn_gbn` | Requires cuDNN>=8.5, [`apex.contrib.cudnn_gbn`](./apex/contrib/cudnn_gbn) | | `peer_memory_cuda` | `APEX_PEER_MEMORY=1` | `--peer_memory` | [`apex.contrib.peer_memory`](./apex/contrib/peer_memory) | | `nccl_p2p_cuda` | `APEX_NCCL_P2P=1` | `--nccl_p2p` | Requires NCCL >= 2.10, [`apex.contrib.nccl_p2p`](./apex/contrib/nccl_p2p) | | `fast_bottleneck` | `APEX_FAST_BOTTLENECK=1` | `--fast_bottleneck` | Requires `peer_memory_cuda` and `nccl_p2p_cuda`, [`apex.contrib.bottleneck`](./apex/contrib/bottleneck) | | `fused_conv_bias_relu` | `APEX_FUSED_CONV_BIAS_RELU=1` | `--fused_conv_bias_relu` | Requires cuDNN>=8.4, [`apex.contrib.conv_bias_relu`](./apex/contrib/conv_bias_relu) | | `distributed_adam_cuda` | `APEX_DISTRIBUTED_ADAM=1` | `--distributed_adam` | [`apex.contrib.optimizers`](./apex/contrib/optimizers) | | `distributed_lamb_cuda` | `APEX_DISTRIBUTED_LAMB=1` | `--distributed_lamb` | [`apex.contrib.optimizers`](./apex/contrib/optimizers) | | `_apex_nccl_allocator` | `APEX_NCCL_ALLOCATOR=1` | `--nccl_allocator` | Requires NCCL >= 2.19, [`apex.contrib.nccl_allocator`](./apex/contrib/nccl_allocator) | | `_apex_gpu_direct_storage` | `APEX_GPU_DIRECT_STORAGE=1` | `--gpu_direct_storage` | [`apex.contrib.gpu_direct_storage`](./apex/contrib/gpu_direct_storage) | You can also build all contrib extensions at once by setting `APEX_ALL_CONTRIB_EXT=1`. 
================================================ FILE: apex/__init__.py ================================================ import logging import warnings # May help avoid undefined symbol errors https://pytorch.org/cppdocs/notes/faq.html#undefined-symbol-errors-from-pytorch-aten import torch # For optimizers and normalization there is no Python fallback. # Absence of cuda backend is a hard error. # I would like the errors from importing fused_adam_cuda or fused_layer_norm_cuda # to be triggered lazily, because if someone has installed with --cpp_ext and --cuda_ext # so they expect those backends to be available, but for some reason they actually aren't # available (for example because they built improperly in a way that isn't revealed until # load time) the error message is timely and visible. from . import optimizers from . import normalization __all__ = ["optimizers", "normalization"] def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int) -> bool: cudnn_available = torch.backends.cudnn.is_available() cudnn_version = torch.backends.cudnn.version() if cudnn_available else None if not (cudnn_available and (cudnn_version >= required_cudnn_version)): warnings.warn( f"`{global_option}` depends on cuDNN {required_cudnn_version} or later, " f"but {'cuDNN is not available' if not cudnn_available else cudnn_version}" ) return False return True class DeprecatedFeatureWarning(FutureWarning): pass def deprecated_warning(msg: str) -> None: if ( not torch.distributed.is_available or not torch.distributed.is_initialized() or (torch.distributed.is_initialized() and torch.distributed.get_rank() == 0) ): warnings.warn(msg, DeprecatedFeatureWarning) ================================================ FILE: apex/_autocast_utils.py ================================================ from typing import Optional, Sequence import torch __all__ = ["_cast_if_autocast_enabled"] def _get_autocast_dtypes() -> Sequence[torch.dtype]: if torch.cuda.is_bf16_supported(): return 
[torch.half, torch.bfloat16] return [torch.half] def _get_current_dtype(dtype: Optional[torch.dtype] = None) -> torch.dtype: if not torch.is_autocast_enabled(): return torch.float or dtype else: return torch.get_autocast_gpu_dtype() def _cast_if_autocast_enabled(*args): if not torch.is_autocast_enabled(): return args else: return torch.cuda.amp.autocast_mode._cast(args, torch.get_autocast_gpu_dtype()) ================================================ FILE: apex/contrib/__init__.py ================================================ ================================================ FILE: apex/contrib/bottleneck/__init__.py ================================================ from .bottleneck import Bottleneck, SpatialBottleneck from .halo_exchangers import ( HaloExchangerNoComm, HaloExchangerAllGather, HaloExchangerSendRecv, HaloExchangerPeer, ) ================================================ FILE: apex/contrib/bottleneck/bottleneck.py ================================================ import functools as func import torch from torch import nn from apex import check_cudnn_version_and_warn import fast_bottleneck import nccl_p2p_cuda as inc assert check_cudnn_version_and_warn(__name__, 8400) def kaiming_uniform_(tensor, a=0, mode="fan_in", nonlinearity="leaky_relu"): weight_tensor_nchw = tensor nn.init.kaiming_uniform_(weight_tensor_nchw, a=a, mode=mode, nonlinearity=nonlinearity) def compute_scale_bias_one(nhwc, weight, bias, running_mean, running_var, w_scale, w_bias): scale = weight * running_var.rsqrt() bias = bias - running_mean * scale w_scale.copy_(scale) w_bias.copy_(bias) def compute_scale_bias_method(nhwc, args): for arg in args: # arg is tuple of (weight, bias, running_mean, running_var, w_scale, w_bias) compute_scale_bias_one(nhwc, *arg) class FrozenBatchNorm2d(torch.jit.ScriptModule): """ BatchNorm2d where the batch statistics and the affine parameters are fixed """ def __init__(self, n): super(FrozenBatchNorm2d, self).__init__() self.register_buffer("weight", 
torch.ones(n)) self.register_buffer("bias", torch.zeros(n)) self.register_buffer("running_mean", torch.zeros(n)) self.register_buffer("running_var", torch.ones(n)) @torch.jit.script_method def get_scale_bias(self, nhwc): # type: (bool) -> List[torch.Tensor] scale = self.weight * self.running_var.rsqrt() bias = self.bias - self.running_mean * scale if nhwc: scale = scale.reshape(1, 1, 1, -1) bias = bias.reshape(1, 1, 1, -1) else: scale = scale.reshape(1, -1, 1, 1) bias = bias.reshape(1, -1, 1, 1) return scale, bias @torch.jit.script_method def forward(self, x): scale, bias = self.get_scale_bias(False) return x * scale + bias @torch.jit.script def drelu_dscale1(grad_o, output, scale1): relu_mask = output > 0 dx_relu = relu_mask * grad_o g1 = dx_relu * scale1 return g1, dx_relu @torch.jit.script def drelu_dscale2(grad_o, output, scale1, scale2): relu_mask = output > 0 dx_relu = relu_mask * grad_o g1 = dx_relu * scale1 g2 = dx_relu * scale2 return g1, g2 class BottleneckFunction(torch.autograd.Function): @staticmethod def forward(ctx, nhwc, stride_1x1, scale, bias, x, *conv): # TODO: clean up order of tensors args = [x, *conv[0:3], *scale[0:3], *bias[0:3]] ctx.downsample = len(conv) > 3 if ctx.downsample: args.append(conv[3]) args.append(scale[3]) args.append(bias[3]) # weight buffers are always in nhwc while shape can be nhwc or channels_last # here we pass in flag and let c++ handle it # alternatively, we can put all sizes into a fixed format and pass it in outputs = fast_bottleneck.forward(nhwc, stride_1x1, args) ctx.save_for_backward(*(args + outputs)) # save relu outputs for drelu ctx.nhwc = nhwc ctx.stride_1x1 = stride_1x1 return outputs[2] # backward relu is not exposed, MUL with mask used now # only support dgrad @staticmethod def backward(ctx, grad_o): outputs = ctx.saved_tensors[-3:] if ctx.downsample: grad_conv3, grad_conv4 = drelu_dscale2( grad_o, outputs[2], ctx.saved_tensors[6], ctx.saved_tensors[11] ) else: grad_conv3, grad_conv4 = drelu_dscale1(grad_o, 
outputs[2], ctx.saved_tensors[6]) # create input vector for backward t_list = [*ctx.saved_tensors[0:10]] t_list.append(grad_conv3) t_list.append(grad_conv4) # outputs used for wgrad and generating drelu mask t_list.append(outputs[0]) t_list.append(outputs[1]) # in case there is downsample if ctx.downsample: t_list.append(ctx.saved_tensors[10]) grads = fast_bottleneck.backward(ctx.nhwc, ctx.stride_1x1, t_list) return (None, None, None, None, *grads) bottleneck_function = BottleneckFunction.apply def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): """3x3 convolution with padding""" return nn.Conv2d( in_planes, out_planes, kernel_size=3, stride=stride, padding=dilation, groups=groups, bias=False, dilation=dilation, ) def conv1x1(in_planes, out_planes, stride=1): """1x1 convolution""" return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) class Bottleneck(torch.nn.Module): # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) # while original implementation places the stride at the first 1x1 convolution(self.conv1) # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. # This variant is also known as ResNet V1.5 and improves accuracy according to # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 
# here we put it at 1x1 def __init__( self, in_channels, bottleneck_channels, out_channels, stride=1, groups=1, dilation=1, norm_func=None, use_cudnn=False, explicit_nhwc=False, ): super(Bottleneck, self).__init__() if groups != 1: raise RuntimeError("Only support groups == 1") if dilation != 1: raise RuntimeError("Only support dilation == 1") if norm_func == None: norm_func = FrozenBatchNorm2d else: raise RuntimeError("Only support frozen BN now.") if stride != 1 or in_channels != out_channels: self.downsample = nn.Sequential( conv1x1(in_channels, out_channels, stride), norm_func(out_channels), ) else: self.downsample = None # Both self.conv2 and self.downsample layers downsample the input when stride != 1 self.conv1 = conv1x1(in_channels, bottleneck_channels, stride) self.conv2 = conv3x3(bottleneck_channels, bottleneck_channels) self.conv3 = conv1x1(bottleneck_channels, out_channels) self.relu = nn.ReLU(inplace=True) self.stride = stride self.bn1 = norm_func(bottleneck_channels) self.bn2 = norm_func(bottleneck_channels) self.bn3 = norm_func(out_channels) self.w_scale = None self.use_cudnn = use_cudnn # setup conv weights self.w_conv = [self.conv1.weight, self.conv2.weight, self.conv3.weight] if self.downsample is not None: self.w_conv.append(self.downsample[0].weight) # init weight in nchw format before possible transpose for w in self.w_conv: kaiming_uniform_(w, a=1) # TODO: prevent unsupported case usage # support cases # native cudnn # normal yes no # channel_last yes yes # explicit_nhwc no yes self.explicit_nhwc = explicit_nhwc if self.explicit_nhwc: for p in self.parameters(): with torch.no_grad(): p.data = p.data.permute(0, 2, 3, 1).contiguous() return # Returns single callable that recomputes scale and bias for all frozen batch-norms. # This method must be called before cuda graphing. # The callable it returns can be called anytime. # Calling this method will prevent these from being computed every forward call. 
def get_scale_bias_callable(self): self.w_scale, self.w_bias, args = [], [], [] batch_norms = [self.bn1, self.bn2, self.bn3] if self.downsample is not None: batch_norms.append(self.downsample[1]) for bn in batch_norms: s = torch.empty_like(bn.weight) b = torch.empty_like(s) args.append((bn.weight, bn.bias, bn.running_mean, bn.running_var, s, b)) if self.explicit_nhwc: self.w_scale.append(s.reshape(1, 1, 1, -1)) self.w_bias.append(b.reshape(1, 1, 1, -1)) else: self.w_scale.append(s.reshape(1, -1, 1, 1)) self.w_bias.append(b.reshape(1, -1, 1, 1)) return func.partial(compute_scale_bias_method, self.explicit_nhwc, args) def forward(self, x): if self.use_cudnn: if self.w_scale is None: # calculate scale/bias from registered buffers # TODO: make this better s1, b1 = self.bn1.get_scale_bias(self.explicit_nhwc) s2, b2 = self.bn2.get_scale_bias(self.explicit_nhwc) s3, b3 = self.bn3.get_scale_bias(self.explicit_nhwc) w_scale = [s1, s2, s3] w_bias = [b1, b2, b3] if self.downsample is not None: s4, b4 = self.downsample[1].get_scale_bias(self.explicit_nhwc) w_scale.append(s4) w_bias.append(b4) out = bottleneck_function( self.explicit_nhwc, self.stride, w_scale, w_bias, x, *self.w_conv ) else: out = bottleneck_function( self.explicit_nhwc, self.stride, self.w_scale, self.w_bias, x, *self.w_conv, ) return out if self.explicit_nhwc: raise RuntimeError("explicit nhwc with native ops is not supported.") # fallback to native ops identity = x out = self.conv1(x) out = self.bn1(out) out = self.relu(out) out = self.conv2(out) out = self.bn2(out) out = self.relu(out) out = self.conv3(out) out = self.bn3(out) if self.downsample is not None: identity = self.downsample(x) out += identity out = self.relu(out) return out class SpatialBottleneckFunction(torch.autograd.Function): @staticmethod def forward( ctx, spatial_group_size, spatial_group_rank, spatial_communicator, spatial_halo_exchanger, spatial_method, use_delay_kernel, explicit_nhwc, stride_1x1, scale, bias, thresholdTop, 
thresholdBottom, x, *conv, ): if spatial_group_size > 1: stream1 = spatial_halo_exchanger.stream1 stream2 = spatial_halo_exchanger.stream2 stream3 = spatial_halo_exchanger.stream3 # TODO: clean up order of tensors args = [x, *conv[0:3], *scale[0:3], *bias[0:3]] ctx.downsample = len(conv) > 3 if ctx.downsample: args.append(conv[3]) args.append(scale[3]) args.append(bias[3]) # weight buffers are always in explicit_nhwc while shape can be explicit_nhwc or channels_last # here we pass in flag and let c++ handle it # alternatively, we can put all sizes into a fixed format and pass it in outputs = fast_bottleneck.forward_init(explicit_nhwc, stride_1x1, args) fast_bottleneck.forward_out1(explicit_nhwc, stride_1x1, args, outputs) if spatial_group_size > 1: out1 = outputs[0] if explicit_nhwc: N, Hs, W, C = list(out1.shape) memory_format = torch.contiguous_format out1_pad = torch.empty([N, Hs + 2, W, C], dtype=out1.dtype, device="cuda") else: N, C, Hs, W = list(out1.shape) memory_format = ( torch.channels_last if out1.is_contiguous(memory_format=torch.channels_last) else torch.contiguous_format ) out1_pad = torch.empty( [N, C, Hs + 2, W], dtype=out1.dtype, device="cuda", memory_format=memory_format, ) stream1.wait_stream(torch.cuda.current_stream()) if spatial_method != 2: stream3.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(stream1): if explicit_nhwc: top_out1_halo = out1_pad[:, :1, :, :] btm_out1_halo = out1_pad[:, Hs + 1 : Hs + 2, :, :] spatial_halo_exchanger.left_right_halo_exchange( out1[:, :1, :, :], out1[:, Hs - 1 :, :, :], top_out1_halo, btm_out1_halo, ) else: top_out1_halo = out1_pad[:, :, :1, :] btm_out1_halo = out1_pad[:, :, Hs + 1 : Hs + 2, :] spatial_halo_exchanger.left_right_halo_exchange( out1[:, :, :1, :], out1[:, :, Hs - 1 :, :], top_out1_halo, btm_out1_halo, ) if spatial_method == 1: # overlap mid convolution with halo transfer if spatial_group_rank < spatial_group_size - 1: stream2.wait_stream(stream1) with torch.cuda.stream(stream2): if 
explicit_nhwc: btm_fat_halo = torch.empty( (N, 3, W, C), dtype=out1.dtype, device=out1.device ) btm_fat_halo[:, 0:2, :, :].copy_(out1[:, Hs - 2 :, :, :]) btm_fat_halo[:, 2:, :, :].copy_(btm_out1_halo) else: btm_fat_halo = torch.empty( (N, C, 3, W), dtype=out1.dtype, device=out1.device ) btm_fat_halo[:, :, 0:2, :].copy_(out1[:, :, Hs - 2 :, :]) btm_fat_halo[:, :, 2:, :].copy_(btm_out1_halo) btm_out2 = fast_bottleneck.forward_out2_halo( explicit_nhwc, btm_fat_halo, args ) if spatial_group_rank > 0: with torch.cuda.stream(stream1): if explicit_nhwc: top_fat_halo = torch.empty( (N, 3, W, C), dtype=out1.dtype, device=out1.device ) top_fat_halo[:, :1, :, :].copy_(top_out1_halo) top_fat_halo[:, 1:3, :, :].copy_(out1[:, :2, :, :]) else: top_fat_halo = torch.empty( (N, C, 3, W), dtype=out1.dtype, device=out1.device ) top_fat_halo[:, :, :1, :].copy_(top_out1_halo) top_fat_halo[:, :, 1:3, :].copy_(out1[:, :, :2, :]) top_out2 = fast_bottleneck.forward_out2_halo( explicit_nhwc, top_fat_halo, args ) if use_delay_kernel: inc.add_delay(10) elif spatial_method != 2 and spatial_method != 3: assert False, "spatial_method must be 1, 2 or 3" if spatial_group_size <= 1: fast_bottleneck.forward_out2(explicit_nhwc, stride_1x1, args, outputs) elif spatial_method == 1: fast_bottleneck.forward_out2(explicit_nhwc, stride_1x1, args, outputs) with torch.cuda.stream(stream3): if explicit_nhwc: out1_pad[:, 1 : Hs + 1, :, :].copy_(out1) else: out1_pad[:, :, 1 : Hs + 1, :].copy_(out1) elif spatial_method == 2: # wait for halo transfer to finish before doing a full convolution of padded x if explicit_nhwc: out1_pad[:, 1 : Hs + 1, :, :].copy_(out1) else: out1_pad[:, :, 1 : Hs + 1, :].copy_(out1) torch.cuda.current_stream().wait_stream(stream1) fast_bottleneck.forward_out2_pad(explicit_nhwc, stride_1x1, args, outputs, out1_pad) elif spatial_method == 3: fast_bottleneck.forward_out2_mask( explicit_nhwc, stride_1x1, args, outputs, thresholdTop, thresholdBottom ) with torch.cuda.stream(stream3): if 
explicit_nhwc: out1_pad[:, 1 : Hs + 1, :, :].copy_(out1) else: out1_pad[:, :, 1 : Hs + 1, :].copy_(out1) # compute halo cells for outputs[1] (out2) if spatial_group_size > 1: out2 = outputs[1] if explicit_nhwc: top_out2_halo = out2[:, :1, :, :] btm_out2_halo = out2[:, Hs - 1 :, :, :] else: top_out2_halo = out2[:, :, :1, :] btm_out2_halo = out2[:, :, Hs - 1 :, :] if spatial_method == 1: if spatial_group_rank > 0: torch.cuda.current_stream().wait_stream(stream1) top_out2_halo.copy_(top_out2) if spatial_group_rank < spatial_group_size - 1: torch.cuda.current_stream().wait_stream(stream2) btm_out2_halo.copy_(btm_out2) elif spatial_method == 3: # Note # out2 halo correction cannot overlap with anything since it has # to wait for out2_mask to finish, but itself has to finish before # the first kernel of _forward_rest can launch. # At least we can overlap the two halo correction kernels. if spatial_group_rank < spatial_group_size - 1: stream2.wait_stream(stream1) # wait for halo transfers to finish stream2.wait_stream( torch.cuda.current_stream() ) # wait for *_out2_mask to finish with torch.cuda.stream(stream2): w1by3 = args[2][:, 2:3, :, :].clone() btm_out1_halo = btm_out1_halo.clone() btm_out2 = fast_bottleneck.forward_out2_halo_corr( explicit_nhwc, btm_out1_halo, args, w1by3, btm_out2_halo.clone(), ) btm_out2_halo.copy_(btm_out2) if spatial_group_rank > 0: stream1.wait_stream( torch.cuda.current_stream() ) # wait for *_out2_mask to finish with torch.cuda.stream(stream1): w1by3 = args[2][:, :1, :, :].clone() top_out1_halo = top_out1_halo.clone() top_out2 = fast_bottleneck.forward_out2_halo_corr( explicit_nhwc, top_out1_halo, args, w1by3, top_out2_halo.clone(), ) top_out2_halo.copy_(top_out2) if spatial_group_rank < spatial_group_size - 1: torch.cuda.current_stream().wait_stream(stream2) if spatial_group_rank > 0: torch.cuda.current_stream().wait_stream(stream1) fast_bottleneck.forward_rest(explicit_nhwc, stride_1x1, args, outputs) # save halos for backward pass if 
spatial_group_size > 1: if spatial_method != 2: # make sure copy of mid-section of out1 into out1_pad is done before exiting torch.cuda.current_stream().wait_stream(stream3) ctx.save_for_backward( *( args + outputs + [ out1_pad, ] ) ) else: ctx.save_for_backward(*(args + outputs)) # save relu outputs for drelu ctx.explicit_nhwc = explicit_nhwc ctx.stride_1x1 = stride_1x1 ctx.spatial_group_size = spatial_group_size if spatial_group_size > 1: ctx.spatial_group_rank = spatial_group_rank ctx.spatial_halo_exchanger = spatial_halo_exchanger ctx.spatial_method = spatial_method ctx.use_delay_kernel = use_delay_kernel ctx.thresholdTop = thresholdTop ctx.thresholdBottom = thresholdBottom ctx.stream1 = stream1 ctx.stream2 = stream2 ctx.stream3 = stream3 return outputs[2] # backward relu is not exposed, MUL with mask used now # only support dgrad @staticmethod def backward(ctx, grad_o): if ctx.spatial_group_size > 1: out1_pad = ctx.saved_tensors[-1] outputs = ctx.saved_tensors[-4:-1] else: outputs = ctx.saved_tensors[-3:] if ctx.downsample: grad_conv3, grad_conv4 = drelu_dscale2( grad_o, outputs[2], ctx.saved_tensors[6], ctx.saved_tensors[11] ) else: grad_conv3, grad_conv4 = drelu_dscale1(grad_o, outputs[2], ctx.saved_tensors[6]) # create input vector for backward t_list = [*ctx.saved_tensors[0:10]] t_list.append(grad_conv3) t_list.append(grad_conv4) # outputs used for wgrad and generating drelu mask t_list.append(outputs[0]) t_list.append(outputs[1]) # in case there is downsample if ctx.downsample: t_list.append(ctx.saved_tensors[10]) grads = fast_bottleneck.backward_init(ctx.explicit_nhwc, ctx.stride_1x1, t_list) wgrad3_stream = torch.cuda.Stream() wgrad3_stream.wait_stream(torch.cuda.current_stream()) grad_out2 = fast_bottleneck.backward_grad_out2( ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads ) wgrad2_stream = torch.cuda.Stream() wgrad2_stream.wait_stream(torch.cuda.current_stream()) # do halo exchange of grad_out2 here # compute halo cells for grad_out1 if 
ctx.spatial_group_size > 1: if ctx.explicit_nhwc: N, Hs, W, C = list(grad_out2.shape) else: N, C, Hs, W = list(grad_out2.shape) relu1 = t_list[12] ctx.stream1.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(ctx.stream1): top_halo, btm_halo = ctx.spatial_halo_exchanger.left_right_halo_exchange( grad_out2[:, :1, :, :], grad_out2[:, Hs - 1 :, :, :] ) # copy halos to send buffer if ctx.spatial_method == 1 or ctx.spatial_method == 2: # 1 -> halo recompute approach # 2 -> wait for concatenated halos, then do single conv on full input (not implemented yet for bprop) if ctx.spatial_group_rank < ctx.spatial_group_size - 1: ctx.stream2.wait_stream(ctx.stream1) with torch.cuda.stream(ctx.stream2): if ctx.explicit_nhwc: btm_fat_halo = torch.empty( (N, 3, W, C), dtype=grad_out2.dtype, device=grad_out2.device, ) btm_fat_halo[:, :2, :, :].copy_(grad_out2[:, Hs - 2 :, :, :]) btm_fat_halo[:, 2:, :, :].copy_(btm_halo) btm_fat_relu_halo = torch.empty( (N, 3, W, C), dtype=grad_out2.dtype, device=grad_out2.device, ) btm_fat_relu_halo[:, :2, :, :].copy_(relu1[:, Hs - 2 :, :, :]) btm_fat_relu_halo[:, 2:, :, :].zero_() else: btm_fat_halo = torch.empty( (N, C, 3, W), dtype=grad_out2.dtype, device=grad_out2.device, ) btm_fat_halo[:, :, :2, :].copy_(grad_out2[:, :, Hs - 2 :, :]) btm_fat_halo[:, :, 2:, :].copy_(btm_halo) btm_fat_relu_halo = torch.empty( (N, C, 3, W), dtype=grad_out2.dtype, device=grad_out2.device, ) btm_fat_relu_halo[:, :, :2, :].copy_(relu1[:, :, Hs - 2 :, :]) btm_fat_relu_halo[:, :, 2:, :].zero_() btm_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo( ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, btm_fat_halo, btm_fat_relu_halo, ) if ctx.explicit_nhwc: btm_grad_out1_halo = btm_grad_out1_halo[:, 1:2, :, :] else: btm_grad_out1_halo = btm_grad_out1_halo[:, :, 1:2, :] if ctx.spatial_group_rank > 0: with torch.cuda.stream(ctx.stream1): if ctx.explicit_nhwc: top_fat_halo = torch.empty( (N, 3, W, C), dtype=grad_out2.dtype, device=grad_out2.device, ) 
top_fat_halo[:, :1, :, :].copy_(top_halo) top_fat_halo[:, 1:, :, :].copy_(grad_out2[:, :2, :, :]) top_fat_relu_halo = torch.empty( (N, 3, W, C), dtype=grad_out2.dtype, device=grad_out2.device, ) top_fat_relu_halo[:, :1, :, :].zero_() top_fat_relu_halo[:, 1:, :, :].copy_(relu1[:, :2, :, :]) else: top_fat_halo = torch.empty( (N, C, 3, W), dtype=grad_out2.dtype, device=grad_out2.device, ) top_fat_halo[:, :, :1, :].copy_(top_halo) top_fat_halo[:, :, 1:, :].copy_(grad_out2[:, :, :2, :]) top_fat_relu_halo = torch.empty( (N, C, 3, W), dtype=grad_out2.dtype, device=grad_out2.device, ) top_fat_relu_halo[:, :, :1, :].zero_() top_fat_relu_halo[:, :, 1:, :].copy_(relu1[:, :, :2, :]) top_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo( ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, top_fat_halo, top_fat_relu_halo, ) if ctx.explicit_nhwc: top_grad_out1_halo = top_grad_out1_halo[:, 1:2, :, :] else: top_grad_out1_halo = top_grad_out1_halo[:, :, 1:2, :] if ctx.use_delay_kernel: inc.add_delay(10) elif ctx.spatial_method != 3: assert False, "spatial_method must be 1, 2 or 3" # compute grad_out1 for internal cells if ctx.spatial_group_size <= 1 or ctx.spatial_method == 1 or ctx.spatial_method == 2: grad_out1 = fast_bottleneck.backward_grad_out1( ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2 ) elif ctx.spatial_group_size > 1 and ctx.spatial_method == 3: grad_out1 = fast_bottleneck.backward_grad_out1_mask( ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2, ctx.thresholdTop, ctx.thresholdBottom, ) # apply halo cells to grad_out1 if ctx.spatial_group_size > 1: w = t_list[2] z = t_list[4] relu1 = t_list[12] # print("w.shape = %s, z.shape = %s, relu1.shape = %s" % (str(list(w.shape)), str(list(z.shape)), str(list(relu1.shape)))) if ctx.spatial_method == 1 or ctx.spatial_method == 2: if ctx.spatial_group_rank < ctx.spatial_group_size - 1: torch.cuda.current_stream().wait_stream(ctx.stream2) if ctx.explicit_nhwc: grad_out1[:, Hs - 1 :, :, 
:].copy_(btm_grad_out1_halo) else: grad_out1[:, :, Hs - 1 :, :].copy_(btm_grad_out1_halo) # print("ctx.spatial_group_rank = %d, apply grad_out1 btm halo (grad_out1.shape = %s)" % (ctx.spatial_group_rank, str(list(grad_out1.shape)))) if ctx.spatial_group_rank > 0: torch.cuda.current_stream().wait_stream(ctx.stream1) if ctx.explicit_nhwc: grad_out1[:, :1, :, :].copy_(top_grad_out1_halo) else: grad_out1[:, :, :1, :].copy_(top_grad_out1_halo) # print("ctx.spatial_group_rank = %d, apply grad_out1 top halo (grad_out1.shape = %s)" % (ctx.spatial_group_rank, str(list(grad_out1.shape)))) elif ctx.spatial_method == 3: if ctx.spatial_group_rank < ctx.spatial_group_size - 1: if ctx.explicit_nhwc: btm_relu_halo = relu1[:, Hs - 1 :, :, :].clone() btm_grad_out1 = grad_out1[:, Hs - 1 :, :, :] else: btm_relu_halo = relu1[:, :, Hs - 1 :, :].clone() btm_grad_out1 = grad_out1[:, :, Hs - 1 :, :] w1by3 = w[:, :1, :, :].clone() ctx.stream2.wait_stream(ctx.stream1) # wait for halo transfers to finish ctx.stream2.wait_stream( torch.cuda.current_stream() ) # wait for backward_grad_out1_mask to finish before launching halo correction kernel with torch.cuda.stream(ctx.stream2): btm_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo_corr( ctx.explicit_nhwc, ctx.stride_1x1, t_list, w1by3, grads, btm_halo, btm_relu_halo, btm_grad_out1.clone(), ) btm_grad_out1.copy_(btm_grad_out1_halo) if ctx.spatial_group_rank > 0: if ctx.explicit_nhwc: top_relu_halo = relu1[:, :1, :, :].clone() top_grad_out1 = grad_out1[:, :1, :, :] else: top_relu_halo = relu1[:, :, :1, :].clone() top_grad_out1 = grad_out1[:, :, :1, :] w1by3 = w[:, 2:, :, :].clone() ctx.stream1.wait_stream( torch.cuda.current_stream() ) # wait for backward_grad_out1_mask to finish before launching halo correction kernel with torch.cuda.stream(ctx.stream1): top_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo_corr( ctx.explicit_nhwc, ctx.stride_1x1, t_list, w1by3, grads, top_halo, top_relu_halo, top_grad_out1.clone(), ) 
top_grad_out1.copy_(top_grad_out1_halo) if ctx.spatial_group_rank < ctx.spatial_group_size - 1: torch.cuda.current_stream().wait_stream( ctx.stream2 ) # wait for halo correction to finish if ctx.spatial_group_rank > 0: torch.cuda.current_stream().wait_stream(ctx.stream1) wgrad1_stream = torch.cuda.Stream() wgrad1_stream.wait_stream(torch.cuda.current_stream()) fast_bottleneck.backward_rest( ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2, grad_out1 ) with torch.cuda.stream(wgrad3_stream): fast_bottleneck.backward_wgrad3(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads) with torch.cuda.stream(wgrad2_stream): if ctx.spatial_group_size > 1: fast_bottleneck.backward_wgrad2_pad( ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, out1_pad, grad_out2, ) else: fast_bottleneck.backward_wgrad2( ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2 ) with torch.cuda.stream(wgrad1_stream): fast_bottleneck.backward_wgrad1( ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out1 ) torch.cuda.current_stream().wait_stream(wgrad3_stream) torch.cuda.current_stream().wait_stream(wgrad2_stream) torch.cuda.current_stream().wait_stream(wgrad1_stream) return ( None, None, None, None, None, None, None, None, None, None, None, None, *grads, ) spatial_bottleneck_function = SpatialBottleneckFunction.apply class SpatialBottleneck(torch.nn.Module): # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) # while original implementation places the stride at the first 1x1 convolution(self.conv1) # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. # This variant is also known as ResNet V1.5 and improves accuracy according to # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 
# here we put it at 1x1 def __init__( self, in_channels, bottleneck_channels, out_channels, stride=1, groups=1, dilation=1, norm_func=None, use_cudnn=False, explicit_nhwc=False, spatial_parallel_args=None, ): super(SpatialBottleneck, self).__init__() if groups != 1: raise RuntimeError("Only support groups == 1") if dilation != 1: raise RuntimeError("Only support dilation == 1") if norm_func == None: norm_func = FrozenBatchNorm2d else: raise RuntimeError("Only support frozen BN now.") if stride != 1 or in_channels != out_channels: self.downsample = nn.Sequential( conv1x1(in_channels, out_channels, stride), norm_func(out_channels), ) else: self.downsample = None # Both self.conv2 and self.downsample layers downsample the input when stride != 1 self.conv1 = conv1x1(in_channels, bottleneck_channels, stride) self.conv2 = conv3x3(bottleneck_channels, bottleneck_channels) self.conv3 = conv1x1(bottleneck_channels, out_channels) self.relu = nn.ReLU(inplace=True) self.stride = stride self.bn1 = norm_func(bottleneck_channels) self.bn2 = norm_func(bottleneck_channels) self.bn3 = norm_func(out_channels) self.w_scale = None self.use_cudnn = use_cudnn # setup conv weights self.w_conv = [self.conv1.weight, self.conv2.weight, self.conv3.weight] if self.downsample is not None: self.w_conv.append(self.downsample[0].weight) # init weight in nchw format before possible transpose for w in self.w_conv: kaiming_uniform_(w, a=1) self.thresholdTop, self.thresholdBottom = None, None # TODO: prevent unsupported case usage # support cases # native cudnn # normal yes no # channel_last yes yes # explicit_nhwc no yes self.explicit_nhwc = explicit_nhwc if self.explicit_nhwc: for p in self.parameters(): with torch.no_grad(): p.data = p.data.permute(0, 2, 3, 1).contiguous() # spatial communicator if spatial_parallel_args is None: self.spatial_parallel_args = (1, 0, None, None, 0, False) else: self.spatial_parallel_args = spatial_parallel_args return # Returns single callable that recomputes scale 
and bias for all frozen batch-norms. # This method must be called before cuda graphing. # The callable it returns can be called anytime. # Calling this method will prevent these from being computed every forward call. def get_scale_bias_callable(self): self.w_scale, self.w_bias, args = [], [], [] batch_norms = [self.bn1, self.bn2, self.bn3] if self.downsample is not None: batch_norms.append(self.downsample[1]) for bn in batch_norms: s = torch.empty_like(bn.weight) b = torch.empty_like(s) args.append((bn.weight, bn.bias, bn.running_mean, bn.running_var, s, b)) if self.explicit_nhwc: self.w_scale.append(s.reshape(1, 1, 1, -1)) self.w_bias.append(b.reshape(1, 1, 1, -1)) else: self.w_scale.append(s.reshape(1, -1, 1, 1)) self.w_bias.append(b.reshape(1, -1, 1, 1)) return func.partial(compute_scale_bias_method, self.explicit_nhwc, args) def forward(self, x): if self.use_cudnn: if self.thresholdTop is None: spatial_group_size, spatial_group_rank, _, _, _, _ = self.spatial_parallel_args if self.explicit_nhwc: N, H, W, C = list(x.shape) else: N, C, H, W = list(x.shape) self.thresholdTop = torch.tensor( [1 if spatial_group_rank > 0 else 0], dtype=torch.int32, device="cuda", ) self.thresholdBottom = torch.tensor( [H - 2 if spatial_group_rank < spatial_group_size - 1 else H - 1], dtype=torch.int32, device="cuda", ) if self.w_scale is None: # calculate scale/bias from registered buffers # TODO: make this better s1, b1 = self.bn1.get_scale_bias(self.explicit_nhwc) s2, b2 = self.bn2.get_scale_bias(self.explicit_nhwc) s3, b3 = self.bn3.get_scale_bias(self.explicit_nhwc) w_scale = [s1, s2, s3] w_bias = [b1, b2, b3] if self.downsample is not None: s4, b4 = self.downsample[1].get_scale_bias(self.explicit_nhwc) w_scale.append(s4) w_bias.append(b4) out = spatial_bottleneck_function( *self.spatial_parallel_args, self.explicit_nhwc, self.stride, w_scale, w_bias, self.thresholdTop, self.thresholdBottom, x, *self.w_conv, ) else: out = spatial_bottleneck_function( 
*self.spatial_parallel_args, self.explicit_nhwc, self.stride, self.w_scale, self.w_bias, self.thresholdTop, self.thresholdBottom, x, *self.w_conv, ) return out if self.explicit_nhwc: raise RuntimeError("explicit nhwc with native ops is not supported.") # fallback to native ops identity = x out = self.conv1(x) out = self.bn1(out) out = self.relu(out) out = self.conv2(out) out = self.bn2(out) out = self.relu(out) out = self.conv3(out) out = self.bn3(out) if self.downsample is not None: identity = self.downsample(x) out += identity out = self.relu(out) return out ================================================ FILE: apex/contrib/bottleneck/halo_exchangers.py ================================================ import torch import nccl_p2p_cuda as inc import peer_memory_cuda as pm # Communication free halo exchanger. # NB! This halo exchanger does not exchange halos with neighbors as it should, it merely swaps the inputs # NB! This is only useful for performance testing. # NB! Do not use for actual production runs class HaloExchanger(object): def __init__(self, ranks, rank_in_group): self.stream1 = torch.cuda.Stream() self.stream2 = torch.cuda.Stream() self.stream3 = torch.cuda.Stream() self.group_size = len(ranks) self.ranks = ranks self.rank_in_group = rank_in_group self.wrap_around_left_rank_in_group = ( rank_in_group + self.group_size - 1 ) % self.group_size self.wrap_around_right_rank_in_group = (rank_in_group + 1) % self.group_size self.left_rank = ranks[rank_in_group - 1] if rank_in_group > 0 else -1 self.left_zero = True if rank_in_group == 0 else False self.right_rank = ranks[rank_in_group + 1] if rank_in_group < self.group_size - 1 else -1 self.right_zero = True if rank_in_group == self.group_size - 1 else False class HaloExchangerNoComm(HaloExchanger): def __init__(self, ranks, rank_in_group): super(HaloExchangerNoComm, self).__init__(ranks, rank_in_group) def left_right_halo_exchange( self, left_output_halo, right_output_halo, left_input_halo=None, 
right_input_halo=None, ): if left_input_halo is None: return right_output_halo, left_output_halo else: left_input_halo.copy_(right_output_halo) right_input_halo.copy_(left_output_halo) class HaloExchangerAllGather(HaloExchanger): def __init__(self, ranks, rank_in_group, comm): super(HaloExchangerAllGather, self).__init__(ranks, rank_in_group) # self.comm must be NCCL process_group created with torch.distributed.new_group(ranks=ranks) self.comm = comm def left_right_halo_exchange( self, left_output_halo, right_output_halo, left_input_halo=None, right_input_halo=None, ): N, Hh, W, C = list(left_output_halo.shape) send_halos = torch.empty( (N, 2 * Hh, W, C), dtype=left_output_halo.dtype, device=left_output_halo.device, ) send_halos[:, :Hh, :, :].copy_(left_output_halo) send_halos[:, Hh:, :, :].copy_(right_output_halo) all_halos = torch.empty( (N, 2 * Hh * self.group_size, W, C), dtype=left_output_halo.dtype, device=left_output_halo.device, ) all_halos = [ all_halos[:, i * 2 * Hh : (i + 1) * 2 * Hh, :, :] for i in range(self.group_size) ] torch.distributed.all_gather(all_halos, send_halos, group=self.comm, no_copy=True) ag_left_input_halo = all_halos[self.wrap_around_left_rank_in_group][:, Hh:, :, :] ag_right_input_halo = all_halos[self.wrap_around_right_rank_in_group][:, :Hh, :, :] if left_input_halo is None: if self.left_zero: ag_left_input_halo.zero_() if self.right_zero: ag_right_input_halo.zero_() return ag_left_input_halo, ag_right_input_halo else: if self.left_zero: left_input_halo.zero_() else: left_input_halo.copy_(ag_left_input_halo) if self.right_zero: right_input_halo.zero_() else: right_input_halo.copy_(ag_right_input_halo) class HaloExchangerSendRecv(HaloExchanger): def __init__(self, ranks, rank_in_group): super(HaloExchangerSendRecv, self).__init__(ranks, rank_in_group) nccl_id = inc.get_unique_nccl_id(1).cuda() torch.distributed.broadcast(nccl_id, 0) nccl_id = nccl_id.cpu() print("%d :: nccl_id = %s" % (torch.distributed.get_rank(), str(nccl_id))) # 
Create another global nccl communicator in addition to the one created by torch.distributed.init_process_group("nccl") # This is unavoidable because the underlying NCCL communicator torch.distributed creates is a protected variable, hence # it cannot be accessed from another class. # TODO: Figure out a way to avoid creating a second global communicator assert torch.distributed.get_rank() == self.ranks[self.rank_in_group], ( "ranks[%d](%d) != torch.distributed.get_rank()(%d)" % ( self.rank_in_group, self.ranks[self.rank_in_group], torch.distributed.get_rank(), ) ) self.handle = inc.init_nccl_comm( nccl_id, torch.distributed.get_rank(), torch.distributed.get_world_size() ) def left_right_halo_exchange( self, left_output_halo, right_output_halo, left_input_halo=None, right_input_halo=None, ): if left_input_halo is None: left_input_halo, right_input_halo = inc.left_right_halo_exchange( self.handle, self.left_rank, self.right_rank, left_output_halo, right_output_halo, ) return left_input_halo, right_input_halo else: inc.left_right_halo_exchange_inplace( self.handle, self.left_rank, self.right_rank, left_output_halo, right_output_halo, left_input_halo, right_input_halo, ) class HaloExchangerPeer(HaloExchanger): def __init__(self, ranks, rank_in_group, peer_pool, explicit_nhwc, numSM=0): super(HaloExchangerPeer, self).__init__(ranks, rank_in_group) self.diagnostics = False self.explicit_nhwc = explicit_nhwc self.numSM = numSM self.peer_pool = peer_pool def _allocate_peer_tensor(self, halo): # Compute size in bytes # Note: Pad buffer so each CUDA block gets required buffer size size = 4 * halo.numel() * halo.element_size() size_per_block = 128 * 2 * 16 # 128 threads each require two 128b buffers size = (size + size_per_block - 1) // size_per_block * size_per_block # Construct dtype peer buffer with desired size shape = [1, 1, 1, size // halo.element_size()] return self.peer_pool.allocate_peer_tensors(shape, halo.dtype, False, True) def left_right_halo_exchange( self, 
left_output_halo, right_output_halo, left_input_halo=None, right_input_halo=None, ): inplace = False if left_input_halo is None and right_input_halo is None else True if not inplace: left_input_halo = torch.empty_like(right_output_halo) right_input_halo = torch.empty_like(left_output_halo) channels_last = ( left_output_halo.is_contiguous(memory_format=torch.channels_last) and not self.explicit_nhwc ) left_tx = self._allocate_peer_tensor(left_input_halo) right_tx = self._allocate_peer_tensor(right_input_halo) pm.push_pull_halos_1d( self.diagnostics, self.explicit_nhwc, self.numSM, self.rank_in_group, self.left_zero, left_output_halo, left_tx[self.rank_in_group], right_tx[self.wrap_around_left_rank_in_group], left_input_halo, self.right_zero, right_output_halo, right_tx[self.rank_in_group], left_tx[self.wrap_around_right_rank_in_group], right_input_halo, ) if not inplace: return left_input_halo, right_input_halo # Class that combines input volume with halos from neighbors (1d). class HaloPadder: def __init__(self, halo_ex): self.halo_ex = halo_ex self.stream1 = torch.cuda.Stream() self.stream2 = torch.cuda.Stream() def __call__(self, y, half_halo, explicit_nhwc, H_split): channels_last = not explicit_nhwc and y.is_contiguous(memory_format=torch.channels_last) if explicit_nhwc: N, H, W, C = list(y.shape) if H_split: padded_shape = [N, H + 2 * half_halo, W, C] ypad = torch.empty( shape=padded_shape, dtype=y.dtype, device=y.device, memory_format=torch.contiguous_format, ) yleft = ypad[:, :half_halo, :, :] ymid = ypad[:, half_halo : H + half_halo, :, :] yright = ypad[:, H + half_halo : H + 2 * half_halo, :, :] oleft = y[:, :half_halo, :, :] oright = y[:, H - half_halo :, :, :] else: padded_shape = [N, H, W + 2 * half_halo, C] ypad = torch.empty( shape=padded_shape, dtype=y.dtype, device=y.device, memory_format=torch.contiguous_format, ) yleft = ypad[:, :, :half_halo, :] ymid = ypad[:, :, half_halo : W + half_halo, :] yright = ypad[:, :, W + half_halo : W + 2 * half_halo, 
:] oleft = y[:, :, :half_halo, :] oright = y[:, :, W - half_halo :, :] else: N, C, H, W = list(y.shape) if H_split: padded_shape = [N, C, H + 2 * half_halo, W] ypad = torch.empty( shape=padded_shape, dtype=y.dtype, device=y.device, memory_format=torch.channels_last, ) yleft = ypad[:, :, :half_halo, :] ymid = ypad[:, :, half_halo : H + half_halo, :] yright = ypad[:, :, H + half_halo : H + 2 * half_halo, :] oleft = y[:, :, :half_halo, :] oright = y[:, :, H - half_halo :, :] else: padded_shape = [N, C, H, W + 2 * half_halo] ypad = torch.empty( shape=padded_shape, dtype=y.dtype, device=y.device, memory_format=torch.channels_last, ) yleft = ypad[:, :, :, :half_halo] ymid = ypad[:, :, :, half_halo : W + half_halo] yright = ypad[:, :, :, W + half_halo : W + 2 * half_halo] oleft = y[:, :, :, :half_halo] oright = y[:, :, :, W - half_halo :] with torch.cuda.stream(self.stream1): self.halo_ex(oleft, oright, yleft, yright) with torch.cuda.stream(self.stream2): ymid.copy_(y) return ypad def wait(self): current_stream = torch.cuda.current_stream() current_stream.wait_stream(self.stream1) current_stream.wait_stream(self.stream2) ================================================ FILE: apex/contrib/bottleneck/test.py ================================================ import torch from bottleneck import Bottleneck torch.manual_seed(23337) # use True to print layerwise sum for all outputs in reference code path DEBUG = False # True for stride, o_channel in [(1, 32), (1, 128), (2, 32)]: print("testing stride ==", stride, ", in_channel == 32 , out_channel ==", o_channel) a_ = torch.randn(17, 32, 28, 28) a = a_.cuda().half().to(memory_format=torch.channels_last).requires_grad_() model = ( Bottleneck(32, 8, o_channel, stride=stride) .cuda() .half() .to(memory_format=torch.channels_last) ) # test model b = model(a) b.mean().backward() d_grad = a.grad.float() a.grad = None torch.cuda.synchronize() if DEBUG: print("[DEBUG] ref dx :", d_grad.sum().item()) # print wgrad. 
# we don't need to reset since later cpp print before accumulation
        for i, w in enumerate(model.w_conv):
            print("[DEBUG] ref wgrad{} :".format(i + 1), w.grad.sum().item())

    # Snapshot the reference-path weight gradients before switching to the cuDNN path.
    wgrads = []
    for w in model.w_conv:
        wgrads.append(w.grad.float())

    model.use_cudnn = True
    model.zero_grad()
    c = model(a)
    c.mean().backward()

    torch.cuda.synchronize()
    print("comparing native and channels_last:")
    print(
        "max error fprop:",
        (b - c).abs().max().item(),
        "max elem:",
        b.abs().max().item(),
    )
    print(
        "max error dgrad:",
        (d_grad - a.grad.float()).abs().max().item(),
        "max elem:",
        d_grad.abs().max().item(),
    )
    for i, (w, wgrad) in enumerate(zip(model.w_conv, wgrads)):
        print(
            "max error wgrad{}:".format(i + 1),
            (wgrad - w.grad.float()).abs().max().item(),
            "max elem:",
            wgrad.abs().max().item(),
        )

    # Same inputs laid out explicitly as NHWC, fed to an explicit_nhwc model.
    nhwc_a = a_.permute(0, 2, 3, 1).contiguous().cuda().half().requires_grad_()
    nhwc_model = (
        Bottleneck(32, 8, o_channel, stride=stride, explicit_nhwc=True, use_cudnn=True)
        .cuda()
        .half()
    )
    for p, q in zip(model.parameters(), nhwc_model.parameters()):
        # model's storage is already in nhwc, we clone and assign to explicit nhwc model
        q.data.copy_(p.data.permute(0, 2, 3, 1).contiguous())
    for p, q in zip(model.buffers(), nhwc_model.buffers()):
        q.data.copy_(p.data)

    d = nhwc_model(nhwc_a)
    d.mean().backward()
    torch.cuda.synchronize()

    # reset reference to cudnn channels_last permute
    # c_s = c.storage().tolist()
    # d_s = d.storage().tolist()
    # print(max([x-y for x,y in zip(c_s,d_s)]))
    c = c.contiguous(memory_format=torch.contiguous_format).permute(0, 2, 3, 1).contiguous()
    d_grad = a.grad.float().permute(0, 2, 3, 1).contiguous()
    wgrads = []
    for w in model.w_conv:
        wgrads.append(w.grad.float().permute(0, 2, 3, 1).contiguous())

    torch.cuda.synchronize()
    print("comparing nhwc and channels_last:")
    print(
        "max error fprop:",
        (d - c).abs().max().item(),
        "max elem:",
        c.abs().max().item(),
    )
    print(
        "max error dgrad:",
        (d_grad - nhwc_a.grad.float()).abs().max().item(),
        "max elem:",
        d_grad.abs().max().item(),
    )
    for i, (w, wgrad) in enumerate(zip(nhwc_model.w_conv, wgrads)):
        print(
            "max error wgrad{}:".format(i + 1),
            (wgrad - w.grad.float()).abs().max().item(),
            "max elem:",
            wgrad.abs().max().item(),
        )

================================================
FILE: apex/contrib/clip_grad/__init__.py
================================================
from .clip_grad import clip_grad_norm_

================================================
FILE: apex/contrib/clip_grad/clip_grad.py
================================================
from typing import Union, Iterable

import torch

# The fused path needs both the amp_C extension and apex's multi_tensor_applier;
# when either import fails, clip_grad_norm_ falls back to PyTorch's implementation.
_kernel_import_succeeded = False
try:
    import amp_C
    from apex.multi_tensor_apply import multi_tensor_applier

    _kernel_import_succeeded = True
except ImportError:
    _kernel_import_succeeded = False

_tensor_or_tensors = Union[torch.Tensor, Iterable[torch.Tensor]]


def clip_grad_norm_(
    parameters: _tensor_or_tensors,
    max_norm: float,
    norm_type: float = 2.0,
    error_if_nonfinite: bool = False,
) -> torch.Tensor:
    r"""Clips gradient norm of an iterable of parameters.

    The norm is computed over all gradients together, as if they were
    concatenated into a single vector. Gradients are modified in-place.

    This is identical to torch.nn.utils.clip_grad_norm_, except it
    uses a fused CUDA kernel when computing the 2-norm of GPU tensors
    in float32 and float16. It delegates to torch.nn.utils.clip_grad_norm_
    when the fused kernels are unavailable, ``norm_type`` is not 2, or no
    parameter lives on a CUDA device.

    Args:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor that will have gradients normalized
        max_norm (float or int): max norm of the gradients
        norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
            infinity norm.
        error_if_nonfinite (bool): if True, an error is thrown if the total
            norm of the gradients from :attr:`parameters` is ``nan``,
            ``inf``, or ``-inf``. Default: False (will switch to True in the future)

    Returns:
        Total norm of the parameters (viewed as a single vector).
        A CPU ``tensor(0.0)`` when no parameter has a gradient.

    Raises:
        RuntimeError: if ``error_if_nonfinite`` is set and the total norm is
            not finite.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = [p for p in parameters if p.grad is not None]
    max_norm = float(max_norm)
    norm_type = float(norm_type)

    # Trivial case
    if len(parameters) == 0:
        return torch.tensor(0.0)

    # Fallback implementation
    if not (_kernel_import_succeeded and norm_type == 2.0 and any(p.is_cuda for p in parameters)):
        return torch.nn.utils.clip_grad_norm_(
            parameters,
            max_norm,
            norm_type=norm_type,
            error_if_nonfinite=error_if_nonfinite,
        )

    # Find fp32 and fp16 gradients on GPU
    # (only those on the first CUDA parameter's device are batched into the
    # fused kernels; everything else goes through torch.linalg.norm below).
    device = next(p.device for p in parameters if p.is_cuda)
    grads_fp32, grads_fp16, grads_misc = [], [], []
    for p in parameters:
        grad = p.grad.detach()
        if p.dtype == torch.float32 and p.device == device:
            grads_fp32.append(grad)
        elif p.dtype == torch.float16 and p.device == device:
            grads_fp16.append(grad)
        else:
            grads_misc.append(grad)

    # Compute gradient L2 norms
    norms = []
    dummy_overflow_buf = torch.zeros([1], dtype=torch.int32, device=device)
    if grads_fp32:
        norms.append(
            multi_tensor_applier(
                amp_C.multi_tensor_l2norm,
                dummy_overflow_buf,
                [grads_fp32],
                False,
            )[0]
        )
    if grads_fp16:
        norms.append(
            multi_tensor_applier(
                amp_C.multi_tensor_l2norm,
                dummy_overflow_buf,
                [grads_fp16],
                False,
            )[0],
        )
    for g in grads_misc:
        norms.append(torch.linalg.norm(g).unsqueeze(0).to(device))
    # The norm of the per-group norms equals the 2-norm of all gradients concatenated.
    total_norm = torch.linalg.norm(torch.cat(norms))

    # Check for non-finite values
    if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
        raise RuntimeError(
            f"The total norm of order {norm_type} for gradients from "
            "`parameters` is non-finite, so it cannot be clipped. To disable "
            "this error and scale the gradients by the non-finite norm anyway, "
            "set `error_if_nonfinite=False`"
        )

    # Scale gradients
    # (clamped to 1.0 so gradients are only ever scaled down; 1e-6 guards against division by zero)
    clip_coef = max_norm / (total_norm + 1e-6)
    clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
    if grads_fp32:
        multi_tensor_applier(
            amp_C.multi_tensor_scale,
            dummy_overflow_buf,
            [grads_fp32, grads_fp32],
            clip_coef_clamped,
        )
    if grads_fp16:
        multi_tensor_applier(
            amp_C.multi_tensor_scale,
            dummy_overflow_buf,
            [grads_fp16, grads_fp16],
            clip_coef_clamped,
        )
    for g in grads_misc:
        g.mul_(clip_coef_clamped.to(g.device))

    return total_norm

================================================
FILE: apex/contrib/conv_bias_relu/__init__.py
================================================
from .conv_bias_relu import (
    ConvBiasReLU,
    ConvBias,
    ConvBiasMaskReLU,
    ConvFrozenScaleBiasReLU,
)

================================================
FILE: apex/contrib/conv_bias_relu/conv_bias_relu.py
================================================
import torch
from apex import check_cudnn_version_and_warn
import fused_conv_bias_relu

check_cudnn_version_and_warn(__name__, 8400)


# Autograd wrappers around the fused_conv_bias_relu extension. Under autocast the
# inputs are cast to half; every backward returns None for the padding/stride
# arguments, which are not differentiable.
class ConvBiasReLU_(torch.autograd.Function):
    @staticmethod
    @torch.amp.custom_fwd(cast_inputs=torch.half, device_type="cuda")
    def forward(ctx, x, weight, bias, padding, stride):
        outputs = fused_conv_bias_relu.forward([x, weight, bias], padding, stride)
        # The ReLU output is saved; the extension's backward presumably derives the ReLU mask from it.
        ctx.save_for_backward(x, weight, outputs[0])
        ctx.padding = padding
        ctx.stride = stride

        return outputs[0]

    @staticmethod
    @torch.amp.custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        bwd_args = [*ctx.saved_tensors, grad_output]
        padding = ctx.padding
        stride = ctx.stride
        grads = fused_conv_bias_relu.backward(bwd_args, padding, stride)

        return grads[0], grads[1], grads[2], None, None


class ConvBiasMaskReLU_(torch.autograd.Function):
    @staticmethod
    @torch.amp.custom_fwd(cast_inputs=torch.half, device_type="cuda")
    def forward(ctx, x, weight, bias, mask, padding, stride):
        outputs = fused_conv_bias_relu.forward_mask([x, weight, bias, mask], padding, stride)
        ctx.save_for_backward(x, weight, outputs[0])
        ctx.padding = padding
        ctx.stride = stride

        return outputs[0]

    @staticmethod
    @torch.amp.custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        bwd_args = [*ctx.saved_tensors, grad_output]
        padding = ctx.padding
        stride = ctx.stride
        # Reuses the unmasked backward kernel; the mask input receives no gradient.
        grads = fused_conv_bias_relu.backward(bwd_args, padding, stride)

        return grads[0], grads[1], grads[2], None, None, None


class ConvBias_(torch.autograd.Function):
    @staticmethod
    @torch.amp.custom_fwd(cast_inputs=torch.half, device_type="cuda")
    def forward(ctx, x, weight, bias, padding, stride):
        outputs = fused_conv_bias_relu.forward_no_relu([x, weight, bias], padding, stride)
        ctx.save_for_backward(x, weight)
        ctx.padding = padding
        ctx.stride = stride

        return outputs[0]

    @staticmethod
    @torch.amp.custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        bwd_args = [*ctx.saved_tensors, grad_output]
        padding = ctx.padding
        stride = ctx.stride
        grads = fused_conv_bias_relu.backward_no_relu(bwd_args, padding, stride)

        return grads[0], grads[1], grads[2], None, None


class ConvFrozenScaleBiasReLU_(torch.autograd.Function):
    @staticmethod
    @torch.amp.custom_fwd(cast_inputs=torch.half, device_type="cuda")
    def forward(ctx, x, weight, scale, bias, padding, stride):
        output = fused_conv_bias_relu.forward_cscale_cbias_relu(
            [x, weight, scale, bias], padding, stride
        )
        ctx.save_for_backward(x, weight, scale, output)
        ctx.padding = padding
        ctx.stride = stride

        return output

    @staticmethod
    @torch.amp.custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        bwd_args = [*ctx.saved_tensors, grad_output]
        padding = ctx.padding
        stride = ctx.stride
        grads = fused_conv_bias_relu.backward_cscale_cbias_relu(bwd_args, padding, stride)

        # scale and bias are frozen: only x and weight receive gradients.
        return grads[0], grads[1], None, None, None, None


ConvBiasReLU = ConvBiasReLU_.apply
ConvBiasMaskReLU = ConvBiasMaskReLU_.apply
ConvBias = ConvBias_.apply
ConvFrozenScaleBiasReLU = ConvFrozenScaleBiasReLU_.apply

================================================
FILE:
apex/contrib/csrc/bottleneck/bottleneck.cpp
================================================
// NOTE(review): in this copy of the file the #include targets below and several
// template argument lists further down (e.g. `using ... = std::tuple;`) are empty;
// restore them from the upstream source rather than guessing.
#include
#include  // for getcudnnhandle
#include
#include
#include
#include
#include

// Debug-only stdout logging; an empty statement unless DEBUG is defined.
#ifdef DEBUG
#define DEBUG_MSG(str)             \
  do {                             \
    std::cout << str << std::endl; \
  } while (false)
#else
#define DEBUG_MSG(str) \
  do {                 \
  } while (false)
#endif

// Appends `str` to the stream `buf` when DEBUG_CUDNN is defined; otherwise a no-op.
#ifdef DEBUG_CUDNN
#define DEBUG_CUDNN_MSG(buf, str) \
  do {                            \
    buf << str << std::endl;      \
  } while (false)
#else
#define DEBUG_CUDNN_MSG(buf, str) \
  do {                            \
  } while (false)
#endif

// Checks a cudnnStatus_t expression. On failure checkCudnnError prints a message
// and the macro executes `return;` in the *enclosing* function, so it can only be
// used inside functions returning void.
#define checkCudnnErr(...)                                                     \
  do {                                                                         \
    int err = checkCudnnError(__VA_ARGS__, #__VA_ARGS__, __FILE__, __LINE__); \
    if (err) {                                                                 \
      return;                                                                  \
    }                                                                          \
  } while (0)

// Prints a diagnostic to stdout when `code` is non-zero (not CUDNN_STATUS_SUCCESS).
// Returns 1 on error, 0 otherwise.
int checkCudnnError(cudnnStatus_t code, const char* expr, const char* file, int line) {
  if (code) {
    printf("CUDNN error at %s:%d, code=%d (%s) in '%s'\n", file, line, (int)code, cudnnGetErrorString(code), expr);
    return 1;
  }
  return 0;
}

void checkError(cudaError_t code, char const* func, const char* file, const int line, bool abort = true);
#define checkCUDAError(val)                      \
  {                                              \
    checkError((val), #val, __FILE__, __LINE__); \
  }  // in-line regular function

// Reports a failing CUDA runtime call to stderr. When `abort` is true it also
// resets the device and terminates the whole process with the error code.
void checkError(cudaError_t code, char const* func, const char* file, const int line, bool abort) {
  if (code != cudaSuccess) {
    const char* errorMessage = cudaGetErrorString(code);
    fprintf(stderr, "CUDA error returned from \"%s\" at %s:%d, Error code: %d (%s)\n", func, file, line, code,
            errorMessage);
    if (abort) {
      cudaDeviceReset();
      exit(code);
    }
  }
}

// Writes fully packed strides for the nbDims-dimensional shape dimA into strideA.
// CUDNN_TENSOR_NCHW yields row-major strides; any other format is treated as NHWC,
// i.e. dimension 1 (presumably C) is innermost and dimension 0 outermost.
void generateStrides(const int64_t* dimA, int64_t* strideA, int nbDims, cudnnTensorFormat_t filterFormat) {
  // For INT8x4 and INT8x32 we still compute standard strides here to input
  // into the cuDNN functions. We will manually scale by resizeFactor in the cpu ref.
  if (filterFormat == CUDNN_TENSOR_NCHW) {
    strideA[nbDims - 1] = 1;
    for (int64_t d = nbDims - 2; d >= 0; d--) {
      strideA[d] = strideA[d + 1] * dimA[d + 1];
    }
  } else {
    // Here we assume that the format is CUDNN_TENSOR_NHWC
    strideA[1] = 1;
    strideA[nbDims - 1] = strideA[1] * dimA[1];
    for (int64_t d = nbDims - 2; d >= 2; d--) {
      strideA[d] = strideA[d + 1] * dimA[d + 1];
    }
    strideA[0] = strideA[2] * dimA[2];
  }
}

// Effective extent of a dilated filter: (k - 1) * dilation + 1.
int getFwdConvDilatedFilterDim(int filterDim, int dilation) { return ((filterDim - 1) * dilation) + 1; }

// Input extent after symmetric padding on both sides.
int getFwdConvPaddedImageDim(int tensorDim, int pad) { return tensorDim + (2 * pad); }

// Standard forward-convolution output extent (integer division floors for non-negative values).
int getFwdConvOutputDim(int tensorDim, int pad, int filterDim, int stride, int dilation) {
  int p = (getFwdConvPaddedImageDim(tensorDim, pad) - getFwdConvDilatedFilterDim(filterDim, dilation)) / stride + 1;
  return (p);
}

// Indices into the tuple built by create_conv_bias_add_act_descriptors; the order
// matches its element order (ids 'x','y','w','z','b','A','B','C','i','D').
enum {
  X_TENSOR,
  Y_TENSOR,
  W_TENSOR,
  Z_TENSOR,
  B_TENSOR,
  AFTERADD_TENSOR,
  AFTERBIAS_TENSOR,
  AFTERCONV_TENSOR,
  OPTIONAL,
  AFTEROPT_TENSOR,
};

using common_conv_descriptors = std::tuple;

// Builds 4-D NHWC-strided x/y/w tensor descriptors (16-byte aligned, `dataType`)
// plus a 2-D convolution descriptor with FLOAT compute type and symmetric padding.
common_conv_descriptors create_common_descriptors(int64_t* x_dim_padded, int64_t* padA, int64_t* convstrideA,
                                                  int64_t* dilationA, int64_t* w_dim_padded, int64_t* y_dim_padded,
                                                  cudnnDataType_t dataType, cudnnConvolutionMode_t mode) {
  const int convDim = 2;

  int64_t strideA_padded[4];
  int64_t outstrideA_padded[4];
  int64_t filterstrideA_padded[4];

  generateStrides(w_dim_padded, filterstrideA_padded, 4, CUDNN_TENSOR_NHWC);
  generateStrides(x_dim_padded, strideA_padded, 4, CUDNN_TENSOR_NHWC);
  generateStrides(y_dim_padded, outstrideA_padded, 4, CUDNN_TENSOR_NHWC);

  return common_conv_descriptors(
      cudnn_frontend::TensorBuilder()
          .setDim(4, x_dim_padded)
          .setStrides(4, strideA_padded)
          .setId('x')
          .setAlignment(16)
          .setDataType(dataType)
          .build(),
      cudnn_frontend::TensorBuilder()
          .setDim(4, y_dim_padded)
          .setStrides(4, outstrideA_padded)
          .setId('y')
          .setAlignment(16)
          .setDataType(dataType)
          .build(),
      cudnn_frontend::TensorBuilder()
          .setDim(4, w_dim_padded)
          .setStrides(4, filterstrideA_padded)
          .setId('w')
          .setAlignment(16)
          .setDataType(dataType)
          .build(),
      cudnn_frontend::ConvDescBuilder()
          .setDataType(CUDNN_DATA_FLOAT)
          .setMathMode(mode)
          .setNDims(convDim)
          .setStrides(convDim, convstrideA)
          .setPrePadding(convDim, padA)
          .setPostPadding(convDim, padA)
          .setDilation(convDim, dilationA)
          .build());
}

using common_convbias_descriptors = std::tuple;

// Builds the tensor descriptors for the fused conv -> scale -> bias -> [add] -> act
// graph. Non-virtual tensors ('x','y','w','z','b','i') are bound to device pointers
// at execution time; virtual FLOAT intermediates ('A','B','C','D') live only inside
// the fused graph. 'z' and 'b' have shape 1 x C x 1 x 1 with C = y_dim_padded[1].
// NOTE(review): convDim and the pad/stride/dilation parameters are unused here.
common_convbias_descriptors create_conv_bias_add_act_descriptors(int64_t* x_dim_padded, int64_t* padA,
                                                                 int64_t* convstrideA, int64_t* dilationA,
                                                                 int64_t* w_dim_padded, int64_t* y_dim_padded,
                                                                 cudnnDataType_t dataType) {
  const int convDim = 2;

  int64_t b_dim_padded[4];
  b_dim_padded[0] = 1;
  b_dim_padded[1] = y_dim_padded[1];
  b_dim_padded[2] = 1;
  b_dim_padded[3] = 1;

  int64_t x_stride_padded[4];
  int64_t y_stride_padded[4];
  int64_t w_stride_padded[4];
  int64_t b_stride_padded[4];

  generateStrides(w_dim_padded, w_stride_padded, 4, CUDNN_TENSOR_NHWC);
  generateStrides(x_dim_padded, x_stride_padded, 4, CUDNN_TENSOR_NHWC);
  generateStrides(y_dim_padded, y_stride_padded, 4, CUDNN_TENSOR_NHWC);
  generateStrides(b_dim_padded, b_stride_padded, 4, CUDNN_TENSOR_NHWC);

  return common_convbias_descriptors(
      cudnn_frontend::TensorBuilder()
          .setDim(4, x_dim_padded)
          .setStrides(4, x_stride_padded)
          .setId('x')
          .setAlignment(16)
          .setDataType(dataType)
          .build(),
      cudnn_frontend::TensorBuilder()
          .setDim(4, y_dim_padded)
          .setStrides(4, y_stride_padded)
          .setId('y')
          .setAlignment(16)
          .setDataType(dataType)
          .build(),
      cudnn_frontend::TensorBuilder()
          .setDim(4, w_dim_padded)
          .setStrides(4, w_stride_padded)
          .setId('w')
          .setAlignment(16)
          .setDataType(dataType)
          .build(),
      cudnn_frontend::TensorBuilder()
          .setDim(4, b_dim_padded)
          .setStrides(4, b_stride_padded)
          .setId('z')
          .setAlignment(16)
          .setDataType(dataType)
          .build(),
      cudnn_frontend::TensorBuilder()
          .setDim(4, b_dim_padded)
          .setStrides(4, b_stride_padded)
          .setId('b')
          .setAlignment(16)
          .setDataType(dataType)
          .build(),
      cudnn_frontend::TensorBuilder()
          .setDim(4, y_dim_padded)
          .setStrides(4, y_stride_padded)
          .setVirtual()
          .setId('A')  // after add
          .setAlignment(16)
          .setDataType(CUDNN_DATA_FLOAT)
          .build(),
      cudnn_frontend::TensorBuilder()
          .setDim(4, y_dim_padded)
          .setStrides(4, y_stride_padded)
          .setVirtual()
          .setId('B')  // after bias
          .setAlignment(16)
          .setDataType(CUDNN_DATA_FLOAT)
          .build(),
      cudnn_frontend::TensorBuilder()
          .setDim(4, y_dim_padded)
          .setStrides(4, y_stride_padded)
          .setId('C')  // after conv
          .setAlignment(16)
          .setVirtual()
          .setDataType(CUDNN_DATA_FLOAT)
          .build(),
      cudnn_frontend::TensorBuilder()
          .setDim(4, y_dim_padded)
          .setStrides(4, y_stride_padded)
          .setId('i')
          .setAlignment(16)
          .setDataType(dataType)
          .build(),
      cudnn_frontend::TensorBuilder()
          .setDim(4, y_dim_padded)
          .setStrides(4, y_stride_padded)
          .setId('D')  // after optional add
          .setAlignment(16)
          .setVirtual()
          .setDataType(CUDNN_DATA_FLOAT)
          .build());
}

// tensor descriptors used for dgrad
// (indices into the tuple built by create_dconv_descriptors, in its element order)
enum {
  X_OR_DX_TENSOR,
  DY_TENSOR,
  W_OR_DW_TENSOR,
  SCALE_TENSOR,
  RELU_TENSOR,
  AFTER_DCONV_TENSOR,
  AFTER_DRELU_TENSOR,
};

using dconv_descriptors = std::tuple;

// Builds the tensor descriptors for the dgrad graph: x/y/w plus a per-channel 's'
// tensor of shape 1 x C x 1 x 1 (C = x_dim_padded[1]), an x-shaped 'r' tensor, and
// virtual FLOAT intermediates. NOTE(review): convDim and the pad/stride/dilation
// parameters are unused here.
dconv_descriptors create_dconv_descriptors(int64_t* x_dim_padded, int64_t* padA, int64_t* convstrideA,
                                           int64_t* dilationA, int64_t* w_dim_padded, int64_t* y_dim_padded,
                                           cudnnDataType_t dataType) {
  const int convDim = 2;

  int64_t b_dim_padded[4];
  b_dim_padded[0] = 1;
  b_dim_padded[1] = x_dim_padded[1];
  b_dim_padded[2] = 1;
  b_dim_padded[3] = 1;

  int64_t x_stride_padded[4];
  int64_t y_stride_padded[4];
  int64_t w_stride_padded[4];
  int64_t b_stride_padded[4];

  generateStrides(w_dim_padded, w_stride_padded, 4, CUDNN_TENSOR_NHWC);
  generateStrides(x_dim_padded, x_stride_padded, 4, CUDNN_TENSOR_NHWC);
  generateStrides(y_dim_padded, y_stride_padded, 4, CUDNN_TENSOR_NHWC);
  generateStrides(b_dim_padded, b_stride_padded, 4, CUDNN_TENSOR_NHWC);

  return dconv_descriptors(
      cudnn_frontend::TensorBuilder()
          .setDim(4, x_dim_padded)
          .setStrides(4, x_stride_padded)
          .setId('x')
          .setAlignment(16)
          .setDataType(dataType)
          .build(),
      cudnn_frontend::TensorBuilder()
          .setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded) .setId('y') .setAlignment(16) .setDataType(dataType) .build(), cudnn_frontend::TensorBuilder() .setDim(4, w_dim_padded) .setStrides(4, w_stride_padded) .setId('w') .setAlignment(16) .setDataType(dataType) .build(), cudnn_frontend::TensorBuilder() .setDim(4, b_dim_padded) .setStrides(4, b_stride_padded) .setId('s') .setAlignment(16) .setDataType(dataType) .build(), cudnn_frontend::TensorBuilder() .setDim(4, x_dim_padded) .setStrides(4, x_stride_padded) .setId('r') .setAlignment(16) .setDataType(dataType) .build(), cudnn_frontend::TensorBuilder() .setDim(4, x_dim_padded) .setStrides(4, x_stride_padded) .setVirtual() .setId('A') // after dconv .setAlignment(16) .setDataType(CUDNN_DATA_FLOAT) .build(), cudnn_frontend::TensorBuilder() .setDim(4, x_dim_padded) .setStrides(4, x_stride_padded) .setVirtual() .setId('B') // after drelu .setAlignment(16) .setDataType(CUDNN_DATA_FLOAT) .build()); } // create a cache for plan std::unordered_map plan_cache; // TODO: better name std::string getConvFusionString(int64_t* x_dim_padded, int64_t* padA, int64_t* convstrideA, int64_t* dilationA, int64_t* w_dim_padded, cudnnDataType_t dataType, std::string fusion_string) { for (int i = 0; i < 4; i++) { fusion_string += 'X'; fusion_string += std::to_string(x_dim_padded[i]); } for (int i = 0; i < 4; i++) { fusion_string += 'W'; fusion_string += std::to_string(w_dim_padded[i]); } for (int i = 0; i < 2; i++) { fusion_string += 'P'; fusion_string += std::to_string(padA[i]); } for (int i = 0; i < 2; i++) { fusion_string += 'S'; fusion_string += std::to_string(convstrideA[i]); } for (int i = 0; i < 2; i++) { fusion_string += 'D'; fusion_string += std::to_string(dilationA[i]); } fusion_string += 'T'; fusion_string += std::to_string(dataType); return fusion_string; } cudnn_frontend::ExecutionPlan& getOrCreatePlan(cudnnHandle_t handle_, std::stringstream& log_buf, cudnn_frontend::OperationGraph& opGraph, std::string cache_string, bool use_heuristic = true) { auto it 
= plan_cache.find(cache_string); if (it != plan_cache.end()) { DEBUG_CUDNN_MSG(log_buf, "Found plan in cache"); return it->second; } else { if (use_heuristic) { // TODO: confirm which mode to use auto heuristics = cudnn_frontend::EngineHeuristicsBuilder() .setOperationGraph(opGraph) .setHeurMode(CUDNN_HEUR_MODE_INSTANT) .build(); // try 3 times for now as WAR for no heuristic training int max_tries = 3, count = 0; auto& engine_configs = heuristics.getEngineConfig(max_tries); while (true) { try { plan_cache.emplace(cache_string, std::move(cudnn_frontend::ExecutionPlanBuilder() .setHandle(handle_) .setEngineConfig(engine_configs[count], opGraph.getTag()) .build())); break; } catch (cudnn_frontend::cudnnException e) { if (++count == max_tries) throw e; } } } else { DEBUG_CUDNN_MSG(log_buf, "No plan in cache"); // How many engines support this operation graph ? auto total_engines = opGraph.getEngineCount(); DEBUG_CUDNN_MSG(log_buf, opGraph.describe() << " has " << total_engines << " engines."); // We have to randomly pick one engine from [0, total_engines) // Selecting "0" by default auto engine = cudnn_frontend::EngineBuilder().setGlobalEngineIdx(0).setOperationGraph(opGraph).build(); DEBUG_CUDNN_MSG(log_buf, engine.describe()); auto& knobs = engine.getSupportedKnobs(); for (auto it = std::begin(knobs); it != std::end(knobs); ++it) { DEBUG_CUDNN_MSG(log_buf, it->describe()); } if (knobs.begin() != knobs.end()) { DEBUG_CUDNN_MSG(log_buf, "Updated knob choice"); knobs.begin()->setChoice(knobs.begin()->getMinValue() + 1); DEBUG_CUDNN_MSG(log_buf, knobs.begin()->describe()); } // Createmplacee the requisite engine config auto engine_config = cudnn_frontend::EngineConfigBuilder().setEngine(engine).build(); DEBUG_CUDNN_MSG(log_buf, engine_config.describe()); plan_cache.emplace( cache_string, std::move(cudnn_frontend::ExecutionPlanBuilder().setHandle(handle_).setEngineConfig(engine_config).build())); } return plan_cache.find(cache_string)->second; } } void 
run_conv_scale_bias_add_activation(int64_t* x_dim_padded, int64_t* pad, int64_t* convstride, int64_t* dilation, int64_t* w_dim_padded, int64_t* y_dim_padded, cudnnDataType_t dataType, at::Half* devPtrX, at::Half* devPtrW, at::Half* devPtrY, at::Half* devPtrZ, at::Half* devPtrB, at::Half* devPtrI) { cudnnHandle_t handle_ = torch::native::getCudnnHandle(); std::stringstream log_buf; try { int convDim = 2; // Creates the necessary tensor descriptors common_convbias_descriptors tensors = create_conv_bias_add_act_descriptors(x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, dataType); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); // Define the add operation auto scaleDesc = cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_MUL).setMathPrecision(CUDNN_DATA_FLOAT).build(); DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe()); // Define the bias operation auto biasDesc = cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_ADD).setMathPrecision(CUDNN_DATA_FLOAT).build(); DEBUG_CUDNN_MSG(log_buf, biasDesc.describe()); // optional add auto addDesc = cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_ADD).setMathPrecision(CUDNN_DATA_FLOAT).build(); DEBUG_CUDNN_MSG(log_buf, addDesc.describe()); // Define the activation operation auto actDesc = cudnn_frontend::PointWiseDescBuilder() .setMode(CUDNN_POINTWISE_RELU_FWD) .setMathPrecision(CUDNN_DATA_FLOAT) .build(); DEBUG_CUDNN_MSG(log_buf, actDesc.describe()); // Define the convolution problem auto convDesc = 
cudnn_frontend::ConvDescBuilder() .setDataType(CUDNN_DATA_FLOAT) .setMathMode(CUDNN_CROSS_CORRELATION) .setNDims(convDim) .setStrides(convDim, convstride) .setPrePadding(convDim, pad) .setPostPadding(convDim, pad) .setDilation(convDim, dilation) .build(); DEBUG_CUDNN_MSG(log_buf, convDesc.describe()); float alpha = 1.0f; float beta = 0.0f; // Create a convolution Node auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR) .setxDesc(std::get(tensors)) .setwDesc(std::get(tensors)) .setyDesc(std::get(tensors)) .setcDesc(convDesc) .setAlpha(alpha) .setBeta(beta) .build(); DEBUG_CUDNN_MSG(log_buf, conv_op.describe()); // Create a Add Node with scaling parameters. auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) .setxDesc(conv_op.getOutputTensor()) .setbDesc(std::get(tensors)) .setyDesc(std::get(tensors)) .setpwDesc(scaleDesc) .build(); DEBUG_CUDNN_MSG(log_buf, scale_op.describe()); // Create a Bias Node. auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) .setxDesc(scale_op.getOutputTensor()) .setbDesc(std::get(tensors)) .setyDesc(std::get(tensors)) .setpwDesc(biasDesc) .build(); DEBUG_CUDNN_MSG(log_buf, bias_op.describe()); // Create a optional add Node. auto add_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) .setxDesc(bias_op.getOutputTensor()) .setbDesc(std::get(tensors)) .setyDesc(std::get(tensors)) .setpwDesc(addDesc) .build(); DEBUG_CUDNN_MSG(log_buf, add_op.describe()); // Create an Activation Node. auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) .setxDesc(devPtrI ? add_op.getOutputTensor() : bias_op.getOutputTensor()) .setyDesc(std::get(tensors)) .setpwDesc(actDesc) .build(); DEBUG_CUDNN_MSG(log_buf, act_op.describe()); // Create an Operation Graph. 
In this case it is convolution add bias activation std::array ops = {&conv_op, &scale_op, &bias_op, devPtrI ? &add_op : &act_op, &act_op}; auto opGraph = cudnn_frontend::OperationGraphBuilder() .setHandle(handle_) .setOperationGraph(devPtrI ? ops.size() : 4, ops.data()) .build(); // Create string encoding for plan caching auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag()); DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); auto workspace_size = plan.getWorkspaceSize(); DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); void* workspace_ptr = nullptr; auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); if (workspace_size > 0) { workspace_ptr = workspace_tensor.data_ptr(); } void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrB, devPtrI}; int64_t uids[] = {'x', 'y', 'w', 'z', 'b', 'i'}; auto variantPack = cudnn_frontend::VariantPackBuilder() .setWorkspacePointer(workspace_ptr) .setDataPointers(devPtrI ? 6 : 5, data_ptrs) .setUids(devPtrI ? 
6 : 5, uids) .build(); DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); checkCudnnErr(status); cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); } catch (cudnn_frontend::cudnnException e) { std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; } } void run_conv_scale_bias(int64_t* x_dim_padded, int64_t* pad, int64_t* convstride, int64_t* dilation, int64_t* w_dim_padded, int64_t* y_dim_padded, cudnnDataType_t dataType, at::Half* devPtrX, at::Half* devPtrW, at::Half* devPtrY, at::Half* devPtrZ, at::Half* devPtrB) { cudnnHandle_t handle_ = torch::native::getCudnnHandle(); std::stringstream log_buf; try { int convDim = 2; // Creates the necessary tensor descriptors common_convbias_descriptors tensors = create_conv_bias_add_act_descriptors(x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, dataType); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); // Define the add operation auto scaleDesc = cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_MUL).setMathPrecision(CUDNN_DATA_FLOAT).build(); DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe()); // Define the bias operation auto addDesc = cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_ADD).setMathPrecision(CUDNN_DATA_FLOAT).build(); DEBUG_CUDNN_MSG(log_buf, addDesc.describe()); // Define the convolution problem auto convDesc = 
cudnn_frontend::ConvDescBuilder() .setDataType(CUDNN_DATA_FLOAT) .setMathMode(CUDNN_CROSS_CORRELATION) .setNDims(convDim) .setStrides(convDim, convstride) .setPrePadding(convDim, pad) .setPostPadding(convDim, pad) .setDilation(convDim, dilation) .build(); DEBUG_CUDNN_MSG(log_buf, convDesc.describe()); float alpha = 1.0f; float beta = 0.0f; // Create a convolution Node auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR) .setxDesc(std::get(tensors)) .setwDesc(std::get(tensors)) .setyDesc(std::get(tensors)) .setcDesc(convDesc) .setAlpha(alpha) .setBeta(beta) .build(); DEBUG_CUDNN_MSG(log_buf, conv_op.describe()); // Create a Add Node with scaling parameters. auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) .setxDesc(conv_op.getOutputTensor()) .setbDesc(std::get(tensors)) .setyDesc(std::get(tensors)) // TODO: change enum to aftermul .setpwDesc(scaleDesc) .build(); DEBUG_CUDNN_MSG(log_buf, scale_op.describe()); // Create a Bias Node. auto add_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) .setxDesc(scale_op.getOutputTensor()) .setbDesc(std::get(tensors)) .setyDesc(std::get(tensors)) .setpwDesc(addDesc) .build(); DEBUG_CUDNN_MSG(log_buf, add_op.describe()); // Create an Operation Graph. 
In this case it is convolution add bias activation std::array ops = {&conv_op, &scale_op, &add_op}; auto opGraph = cudnn_frontend::OperationGraphBuilder().setHandle(handle_).setOperationGraph(ops.size(), ops.data()).build(); // Create string encoding for plan caching auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag()); DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); auto workspace_size = plan.getWorkspaceSize(); DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); void* workspace_ptr = nullptr; auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); if (workspace_size > 0) { workspace_ptr = workspace_tensor.data_ptr(); } void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrB}; int64_t uids[] = {'x', 'y', 'w', 'z', 'b'}; auto variantPack = cudnn_frontend::VariantPackBuilder() .setWorkspacePointer(workspace_ptr) .setDataPointers(5, data_ptrs) .setUids(5, uids) .build(); DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); checkCudnnErr(status); cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); } catch (cudnn_frontend::cudnnException e) { std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; } } void run_dconv_drelu_dscale(int64_t* x_dim_padded, int64_t* pad, int64_t* convstride, int64_t* dilation, int64_t* w_dim_padded, int64_t* y_dim_padded, cudnnDataType_t dataType, at::Half* devPtrX, at::Half* devPtrW, at::Half* devPtrY, at::Half* devPtrZ, at::Half* devPtrR) { cudnnHandle_t handle_ = torch::native::getCudnnHandle(); std::stringstream log_buf; try { int convDim = 
2; // Creates the necessary tensor descriptors dconv_descriptors tensors = create_dconv_descriptors(x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, dataType); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); // Define the convolution problem auto convDesc = cudnn_frontend::ConvDescBuilder() .setDataType(CUDNN_DATA_FLOAT) .setMathMode(CUDNN_CROSS_CORRELATION) .setNDims(convDim) .setStrides(convDim, convstride) .setPrePadding(convDim, pad) .setPostPadding(convDim, pad) .setDilation(convDim, dilation) .build(); DEBUG_CUDNN_MSG(log_buf, convDesc.describe()); // Define the activation backward operation auto actDesc = cudnn_frontend::PointWiseDescBuilder() .setMode(CUDNN_POINTWISE_RELU_BWD) .setMathPrecision(CUDNN_DATA_FLOAT) .build(); DEBUG_CUDNN_MSG(log_buf, actDesc.describe()); // Define the scale backward operation auto scaleDesc = cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_MUL).setMathPrecision(CUDNN_DATA_FLOAT).build(); DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe()); float alpha = 1.0f; float beta = 0.0f; // Create a convolution Node auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR) .setdxDesc(std::get(tensors)) .setwDesc(std::get(tensors)) .setdyDesc(std::get(tensors)) .setcDesc(convDesc) .setAlpha(alpha) .setBeta(beta) .build(); DEBUG_CUDNN_MSG(log_buf, conv_op.describe()); // TODO: do we need getOutputTensor(), and what it returns in backward case? // Create an relu backward Node. 
auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) .setdyDesc(std::get(tensors)) .setxDesc(std::get(tensors)) .setdxDesc(std::get(tensors)) .setpwDesc(actDesc) .build(); DEBUG_CUDNN_MSG(log_buf, act_op.describe()); // Create a Scale Node. auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) .setxDesc(std::get(tensors)) .setbDesc(std::get(tensors)) .setyDesc(std::get(tensors)) .setpwDesc(scaleDesc) .build(); DEBUG_CUDNN_MSG(log_buf, scale_op.describe()); // Create an Operation Graph. In this case it is convolution add bias activation std::array ops = {&conv_op, &act_op, &scale_op}; auto opGraph = cudnn_frontend::OperationGraphBuilder().setHandle(handle_).setOperationGraph(ops.size(), ops.data()).build(); // Create string encoding for plan caching auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag()); DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); auto workspace_size = plan.getWorkspaceSize(); DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); void* workspace_ptr = nullptr; auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); if (workspace_size > 0) { workspace_ptr = workspace_tensor.data_ptr(); } void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrR}; int64_t uids[] = {'x', 'y', 'w', 's', 'r'}; auto variantPack = cudnn_frontend::VariantPackBuilder() .setWorkspacePointer(workspace_ptr) .setDataPointers(5, data_ptrs) .setUids(5, uids) .build(); DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); checkCudnnErr(status); cudnn_frontend::throw_if([status]() { return 
(status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); } catch (cudnn_frontend::cudnnException e) { std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; } } void run_dconv(int64_t* x_dim_padded, int64_t* pad, int64_t* convstride, int64_t* dilation, int64_t* w_dim_padded, int64_t* y_dim_padded, cudnnDataType_t dataType, at::Half* devPtrX, at::Half* devPtrW, at::Half* devPtrY, cudnnBackendDescriptorType_t mode) { cudnnHandle_t handle_ = torch::native::getCudnnHandle(); std::stringstream log_buf; try { int convDim = 2; // Creates the necessary tensor descriptors dconv_descriptors tensors = create_dconv_descriptors(x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, dataType); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); // Define the convolution problem auto convDesc = cudnn_frontend::ConvDescBuilder() .setDataType(CUDNN_DATA_FLOAT) .setMathMode(CUDNN_CROSS_CORRELATION) .setNDims(convDim) .setStrides(convDim, convstride) .setPrePadding(convDim, pad) .setPostPadding(convDim, pad) .setDilation(convDim, dilation) .build(); DEBUG_CUDNN_MSG(log_buf, convDesc.describe()); float alpha = 1.0f; float beta = 0.0f; // Create a convolution Node // mode should be one of following // CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR // CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR auto conv_op_builder = cudnn_frontend::OperationBuilder(mode); if (mode == CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR) { conv_op_builder.setdxDesc(std::get(tensors)) .setwDesc(std::get(tensors)) .setdyDesc(std::get(tensors)) .setcDesc(convDesc) .setAlpha(alpha) 
.setBeta(beta); } else { conv_op_builder.setxDesc(std::get(tensors)) .setdwDesc(std::get(tensors)) .setdyDesc(std::get(tensors)) .setcDesc(convDesc) .setAlpha(alpha) .setBeta(beta); } auto conv_op = conv_op_builder.build(); DEBUG_CUDNN_MSG(log_buf, conv_op.describe()); // Create an Operation Graph. In this case it is convolution add bias activation std::array ops = {&conv_op}; auto opGraph = cudnn_frontend::OperationGraphBuilder().setHandle(handle_).setOperationGraph(ops.size(), ops.data()).build(); // Create string encoding for plan caching auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag()); DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); auto workspace_size = plan.getWorkspaceSize(); DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); void* workspace_ptr = nullptr; auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); if (workspace_size > 0) { workspace_ptr = workspace_tensor.data_ptr(); } void* data_ptrs[] = {devPtrX, devPtrY, devPtrW}; int64_t uids[] = {'x', 'y', 'w'}; auto variantPack = cudnn_frontend::VariantPackBuilder() .setWorkspacePointer(workspace_ptr) .setDataPointers(3, data_ptrs) .setUids(3, uids) .build(); DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); checkCudnnErr(status); cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); } catch (cudnn_frontend::cudnnException e) { std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; } } void run_dconv_add(int64_t* x_dim_padded, int64_t* pad, int64_t* convstride, int64_t* dilation, int64_t* w_dim_padded, 
int64_t* y_dim_padded, cudnnDataType_t dataType, at::Half* devPtrX, at::Half* devPtrW, at::Half* devPtrY, at::Half* devPtrR) { cudnnHandle_t handle_ = torch::native::getCudnnHandle(); std::stringstream log_buf; try { int convDim = 2; // Creates the necessary tensor descriptors dconv_descriptors tensors = create_dconv_descriptors(x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, dataType); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); // Define the convolution problem auto convDesc = cudnn_frontend::ConvDescBuilder() .setDataType(CUDNN_DATA_FLOAT) .setMathMode(CUDNN_CROSS_CORRELATION) .setNDims(convDim) .setStrides(convDim, convstride) .setPrePadding(convDim, pad) .setPostPadding(convDim, pad) .setDilation(convDim, dilation) .build(); DEBUG_CUDNN_MSG(log_buf, convDesc.describe()); // Define the add backward operation auto addDesc = cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_ADD).setMathPrecision(CUDNN_DATA_FLOAT).build(); DEBUG_CUDNN_MSG(log_buf, addDesc.describe()); float alpha = 1.0f; float beta = 0.0f; // Create a convolution Node auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR) .setdxDesc(std::get(tensors)) .setwDesc(std::get(tensors)) .setdyDesc(std::get(tensors)) .setcDesc(convDesc) .setAlpha(alpha) .setBeta(beta) .build(); DEBUG_CUDNN_MSG(log_buf, conv_op.describe()); // TODO: do we need getOutputTensor(), and what it returns in backward case? // Create add Node. 
auto add_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) .setxDesc(std::get(tensors)) .setbDesc(std::get(tensors)) .setyDesc(std::get(tensors)) .setpwDesc(addDesc) .build(); DEBUG_CUDNN_MSG(log_buf, add_op.describe()); // Create an Operation Graph. In this case it is convolution add bias activation std::array ops = {&conv_op, &add_op}; auto opGraph = cudnn_frontend::OperationGraphBuilder().setHandle(handle_).setOperationGraph(ops.size(), ops.data()).build(); // Create string encoding for plan caching auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag()); DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); auto workspace_size = plan.getWorkspaceSize(); DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); void* workspace_ptr = nullptr; auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); if (workspace_size > 0) { workspace_ptr = workspace_tensor.data_ptr(); } void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrR}; int64_t uids[] = {'x', 'y', 'w', 'r'}; auto variantPack = cudnn_frontend::VariantPackBuilder() .setWorkspacePointer(workspace_ptr) .setDataPointers(4, data_ptrs) .setUids(4, uids) .build(); DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); checkCudnnErr(status); cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); } catch (cudnn_frontend::cudnnException e) { std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; } } // inputs contains x,w,z,b,(i) std::vector bottleneck_forward(bool explicit_nhwc, int stride_1X1, std::vector inputs) 
{ std::cout << std::fixed; // create output vector std::vector outputs; auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast; // setup dimensions int64_t dimA[] = {0, 0, 0, 0}; int64_t filterdimA1[] = {0, 0, 0, 0}; int64_t filterdimA2[] = {0, 0, 0, 0}; int64_t filterdimA3[] = {0, 0, 0, 0}; int64_t filterdimA4[] = {0, 0, 0, 0}; // All dim calculation after this order of n,c,h,w int axis[]{0, 1, 2, 3}; if (explicit_nhwc) { axis[0] = 0; axis[1] = 3; axis[2] = 1; axis[3] = 2; } for (int dim = 0; dim < 4; dim++) { dimA[dim] = inputs[0].size(axis[dim]); filterdimA1[dim] = inputs[1].size(axis[dim]); filterdimA2[dim] = inputs[2].size(axis[dim]); filterdimA3[dim] = inputs[3].size(axis[dim]); } if (stride_1X1 != 1 || filterdimA3[0] != dimA[1]) { for (int dim = 0; dim < 4; dim++) { filterdimA4[dim] = inputs[10].size(axis[dim]); } } // output dim in n,c,h,w used by backend int64_t outdimA1[] = {0, 0, 0, 0}; // Computed Below int64_t outdimA2[] = {0, 0, 0, 0}; // Computed Below int64_t outdimA3[] = {0, 0, 0, 0}; // Computed Below // use these fixed value for test run int64_t padA[] = {0, 0}; int64_t padA1[] = {1, 1}; int64_t dilationA[] = {1, 1}; int64_t convstrideA[] = {1, 1}; int64_t convstride1X1[] = {stride_1X1, stride_1X1}; // compute output from pad/stride/dilation outdimA1[0] = dimA[0]; outdimA1[1] = filterdimA1[0]; for (int dim = 0; dim < 2; dim++) { outdimA1[dim + 2] = getFwdConvOutputDim(dimA[dim + 2], padA[dim], filterdimA1[dim + 2], convstride1X1[dim], dilationA[dim]); } outdimA2[0] = outdimA1[0]; outdimA2[1] = filterdimA2[0]; for (int dim = 0; dim < 2; dim++) { outdimA2[dim + 2] = getFwdConvOutputDim(outdimA1[dim + 2], padA1[dim], filterdimA2[dim + 2], convstrideA[dim], dilationA[dim]); } outdimA3[0] = outdimA2[0]; outdimA3[1] = filterdimA3[0]; for (int dim = 0; dim < 2; dim++) { outdimA3[dim + 2] = getFwdConvOutputDim(outdimA2[dim + 2], padA[dim], filterdimA3[dim + 2], convstrideA[dim], dilationA[dim]); } // Create 
output tensor in the correct shape in pytorch's view int64_t outdim1[] = {0, 0, 0, 0}; int64_t outdim2[] = {0, 0, 0, 0}; int64_t outdim3[] = {0, 0, 0, 0}; if (explicit_nhwc) { axis[0] = 0; axis[1] = 2; axis[2] = 3; axis[3] = 1; } for (int dim = 0; dim < 4; dim++) { outdim1[dim] = outdimA1[axis[dim]]; outdim2[dim] = outdimA2[axis[dim]]; outdim3[dim] = outdimA3[axis[dim]]; } // run at::Half* x = inputs[0].data_ptr(); at::Half* w = inputs[1].data_ptr(); at::Half* z = inputs[4].data_ptr(); at::Half* b = inputs[7].data_ptr(); auto out1 = at::empty(outdim1, inputs[0].type(), output_format); at::Half* y1 = out1.data_ptr(); run_conv_scale_bias_add_activation(dimA, padA, convstride1X1, dilationA, filterdimA1, outdimA1, CUDNN_DATA_HALF, x, w, y1, z, b, nullptr); DEBUG_MSG("[DEBUG] new relu1 : " << out1.to(at::kFloat).sum().item()); w = inputs[2].data_ptr(); z = inputs[5].data_ptr(); b = inputs[8].data_ptr(); auto out2 = at::empty(outdim2, inputs[0].type(), output_format); at::Half* y2 = out2.data_ptr(); run_conv_scale_bias_add_activation(outdimA1, padA1, convstrideA, dilationA, filterdimA2, outdimA2, CUDNN_DATA_HALF, y1, w, y2, z, b, nullptr); DEBUG_MSG("[DEBUG] new relu2 : " << out2.to(at::kFloat).sum().item()); // create output of conv3 auto out3 = at::empty(outdim3, inputs[0].type(), output_format); at::Half* y3 = out3.data_ptr(); // create output of conv4 that may exist auto identity = at::empty_like(out3); at::Half* yi = identity.data_ptr(); if (stride_1X1 != 1 || filterdimA3[0] != dimA[1]) { w = inputs[10].data_ptr(); z = inputs[11].data_ptr(); b = inputs[12].data_ptr(); run_conv_scale_bias(dimA, padA, convstride1X1, dilationA, filterdimA4, outdimA3, CUDNN_DATA_HALF, x, w, yi, z, b); DEBUG_MSG("[DEBUG] new downsample : " << identity.to(at::kFloat).sum().item()); } else { yi = x; } w = inputs[3].data_ptr(); z = inputs[6].data_ptr(); b = inputs[9].data_ptr(); run_conv_scale_bias_add_activation(outdimA2, padA, convstrideA, dilationA, filterdimA3, outdimA3, 
CUDNN_DATA_HALF, y2, w, y3, z, b, yi); DEBUG_MSG("[DEBUG] new relu3 : " << out3.to(at::kFloat).sum().item()); outputs.push_back(out1); outputs.push_back(out2); outputs.push_back(out3); return outputs; } std::vector bottleneck_backward(bool explicit_nhwc, int stride_1X1, std::vector inputs) { bool requires_grad = inputs[0].requires_grad(); std::cout << std::fixed; // create output vector std::vector outputs; auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast; // setup dimensions int64_t dimA[] = {0, 0, 0, 0}; int64_t filterdimA1[] = {0, 0, 0, 0}; int64_t filterdimA2[] = {0, 0, 0, 0}; int64_t filterdimA3[] = {0, 0, 0, 0}; int64_t filterdimA4[] = {0, 0, 0, 0}; // All dim calculation after this order of n,c,h,w int axis[]{0, 1, 2, 3}; if (explicit_nhwc) { axis[0] = 0; axis[1] = 3; axis[2] = 1; axis[3] = 2; } for (int dim = 0; dim < 4; dim++) { dimA[dim] = inputs[0].size(axis[dim]); filterdimA1[dim] = inputs[1].size(axis[dim]); filterdimA2[dim] = inputs[2].size(axis[dim]); filterdimA3[dim] = inputs[3].size(axis[dim]); } if (stride_1X1 != 1 || filterdimA3[0] != dimA[1]) { for (int dim = 0; dim < 4; dim++) { filterdimA4[dim] = inputs[14].size(axis[dim]); } } // output dim in n,c,h,w used by backend int64_t outdimA1[] = {0, 0, 0, 0}; // Computed Below int64_t outdimA2[] = {0, 0, 0, 0}; // Computed Below int64_t outdimA3[] = {0, 0, 0, 0}; // Computed Below // use these fixed value for test run int64_t padA[] = {0, 0}; int64_t padA1[] = {1, 1}; int64_t dilationA[] = {1, 1}; int64_t convstrideA[] = {1, 1}; int64_t convstride1X1[] = {stride_1X1, stride_1X1}; // compute output from pad/stride/dilation outdimA1[0] = dimA[0]; outdimA1[1] = filterdimA1[0]; for (int dim = 0; dim < 2; dim++) { outdimA1[dim + 2] = getFwdConvOutputDim(dimA[dim + 2], padA[dim], filterdimA1[dim + 2], convstride1X1[dim], dilationA[dim]); } outdimA2[0] = outdimA1[0]; outdimA2[1] = filterdimA2[0]; for (int dim = 0; dim < 2; dim++) { outdimA2[dim + 2] = 
getFwdConvOutputDim(outdimA1[dim + 2], padA1[dim], filterdimA2[dim + 2], convstrideA[dim], dilationA[dim]); } outdimA3[0] = outdimA2[0]; outdimA3[1] = filterdimA3[0]; for (int dim = 0; dim < 2; dim++) { outdimA3[dim + 2] = getFwdConvOutputDim(outdimA2[dim + 2], padA[dim], filterdimA3[dim + 2], convstrideA[dim], dilationA[dim]); } // Create output tensor in the correct shape in pytorch's view int64_t outdim1[] = {0, 0, 0, 0}; int64_t outdim2[] = {0, 0, 0, 0}; int64_t outdim3[] = {0, 0, 0, 0}; if (explicit_nhwc) { axis[0] = 0; axis[1] = 2; axis[2] = 3; axis[3] = 1; } for (int dim = 0; dim < 4; dim++) { outdim1[dim] = outdimA1[axis[dim]]; outdim2[dim] = outdimA2[axis[dim]]; outdim3[dim] = outdimA3[axis[dim]]; } // dconv3+drelu2+dscale2 at::Half* conv_in = inputs[13].data_ptr(); at::Half* dy3 = inputs[10].data_ptr(); DEBUG_MSG("[DEBUG] new dconv3 : " << inputs[10].to(at::kFloat).sum().item()); // wgrad auto wgrad3 = at::empty_like(inputs[3]); at::Half* dw3 = wgrad3.data_ptr(); run_dconv(outdimA2, padA, convstrideA, dilationA, filterdimA3, outdimA3, CUDNN_DATA_HALF, conv_in, dw3, dy3, CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR); // dgrad auto grad_out2 = at::empty(outdim2, inputs[0].type(), output_format); at::Half* dy2 = grad_out2.data_ptr(); at::Half* w = inputs[3].data_ptr(); at::Half* z = inputs[5].data_ptr(); at::Half* relu2 = inputs[13].data_ptr(); run_dconv_drelu_dscale(outdimA2, padA, convstrideA, dilationA, filterdimA3, outdimA3, CUDNN_DATA_HALF, dy2, w, dy3, z, relu2); DEBUG_MSG("[DEBUG] new dconv2 : " << grad_out2.to(at::kFloat).sum().item()); // dconv2+drelu1+dscale1 conv_in = inputs[12].data_ptr(); // wgrad auto wgrad2 = at::empty_like(inputs[2]); at::Half* dw2 = wgrad2.data_ptr(); run_dconv(outdimA1, padA1, convstrideA, dilationA, filterdimA2, outdimA2, CUDNN_DATA_HALF, conv_in, dw2, dy2, CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR); // dgrad auto grad_out1 = at::empty(outdim1, inputs[0].type(), output_format); 
at::Half* dy1 = grad_out1.data_ptr(); w = inputs[2].data_ptr(); z = inputs[4].data_ptr(); at::Half* relu1 = inputs[12].data_ptr(); // fused dgrad run_dconv_drelu_dscale(outdimA1, padA1, convstrideA, dilationA, filterdimA2, outdimA2, CUDNN_DATA_HALF, dy1, w, dy2, z, relu1); /* // backward strided conv cannot be fused // if stride == 1 but channel changes, we can fuse here if (stride_1X1 != 1){ // dgrad run_dconv(outdimA1, padA1, convstride1X1, dilationA, filterdimA2, outdimA2, CUDNN_DATA_HALF, dy1, w, dy2, CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR); // mul fused mask grad_out1.mul_(inputs[15]); } else { at::Half* relu1 = inputs[12].data_ptr(); // fused dgrad run_dconv_drelu_dscale(outdimA1, padA1, convstride1X1, dilationA, filterdimA2, outdimA2, CUDNN_DATA_HALF, dy1, w, dy2, z, relu1); } */ DEBUG_MSG("[DEBUG] new dconv1 : " << grad_out1.to(at::kFloat).sum().item()); // create grads of conv4 that may exist auto grad_x_conv4 = at::empty_like(inputs[0]); at::Half* dx_conv4 = grad_x_conv4.data_ptr(); at::Tensor wgrad4; // x used for dconv1 and dconv4 wgrad at::Half* x = inputs[0].data_ptr(); if (stride_1X1 != 1 || filterdimA3[0] != dimA[1]) { w = inputs[14].data_ptr(); at::Half* dy_conv4 = inputs[11].data_ptr(); if (requires_grad) { run_dconv(dimA, padA, convstride1X1, dilationA, filterdimA4, outdimA3, CUDNN_DATA_HALF, dx_conv4, w, dy_conv4, CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR); // we don't print here since we can't hook out this grad in pytorch alone to compare, due to addition with dx // DEBUG_MSG("[DEBUG] new dx_identity : " << grad_x_conv4.to(at::kFloat).sum().item()); } // wgrad wgrad4 = at::empty_like(inputs[14]); at::Half* dw4 = wgrad4.data_ptr(); run_dconv(dimA, padA, convstride1X1, dilationA, filterdimA4, outdimA3, CUDNN_DATA_HALF, x, dw4, dy_conv4, CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR); } else { // if there is no downsample, dx_conv4 is fork of drelu3 dx_conv4 = inputs[11].data_ptr(); } // 
dconv1+add // wgrad auto wgrad1 = at::empty_like(inputs[1]); at::Half* dw1 = wgrad1.data_ptr(); run_dconv(dimA, padA, convstride1X1, dilationA, filterdimA1, outdimA1, CUDNN_DATA_HALF, x, dw1, dy1, CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR); // dgrad w = inputs[1].data_ptr(); auto grad_x = at::empty_like(inputs[0]); at::Half* dx = grad_x.data_ptr(); // backward strided conv cannot be fused // if stride == 1 but channel changes, we can fuse here if (requires_grad) { if (stride_1X1 != 1) { run_dconv(dimA, padA, convstride1X1, dilationA, filterdimA1, outdimA1, CUDNN_DATA_HALF, dx, w, dy1, CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR); // add 2 together grad_x.add_(grad_x_conv4); } else { run_dconv_add(dimA, padA, convstride1X1, dilationA, filterdimA1, outdimA1, CUDNN_DATA_HALF, dx, w, dy1, dx_conv4); } } DEBUG_MSG("[DEBUG] new dx : " << grad_x.to(at::kFloat).sum().item()); DEBUG_MSG("[DEBUG] new wgrad1 : " << wgrad1.to(at::kFloat).sum().item()); DEBUG_MSG("[DEBUG] new wgrad2 : " << wgrad2.to(at::kFloat).sum().item()); DEBUG_MSG("[DEBUG] new wgrad3 : " << wgrad3.to(at::kFloat).sum().item()); outputs.push_back(grad_x); outputs.push_back(wgrad1); outputs.push_back(wgrad2); outputs.push_back(wgrad3); if (stride_1X1 != 1 || filterdimA3[0] != dimA[1]) { DEBUG_MSG("[DEBUG] new wgrad4 : " << wgrad4.to(at::kFloat).sum().item()); outputs.push_back(wgrad4); } return outputs; } namespace { enum { X_TENSOR, Y_TENSOR, W_TENSOR, Z_TENSOR, B_TENSOR, AFTERADD_TENSOR, AFTERBIAS_TENSOR, AFTERCONV_TENSOR, OPTIONAL, AFTEROPT_TENSOR, AFTERACT_TENSOR, GEN_INDEX_TENSOR, MASK_TOP_TENSOR, MASK_BOTTOM_TENSOR, MASK_TENSOR, THRESHOLD_TOP_TENSOR, THRESHOLD_BOTTOM_TENSOR, }; using masked_convbias_descriptors = std::tuple; masked_convbias_descriptors create_conv_bias_add_act_mask_descriptors(int64_t* x_dim_padded, int64_t* padA, int64_t* convstrideA, int64_t* dilationA, int64_t* w_dim_padded, int64_t* y_dim_padded, int64_t* threshold_dim, cudnnDataType_t dataType) 
{ const int convDim = 2; int64_t b_dim_padded[4]; b_dim_padded[0] = 1; b_dim_padded[1] = y_dim_padded[1]; b_dim_padded[2] = 1; b_dim_padded[3] = 1; int64_t x_stride_padded[4]; int64_t y_stride_padded[4]; int64_t w_stride_padded[4]; int64_t b_stride_padded[4]; int64_t threshold_stride[4]; generateStrides(w_dim_padded, w_stride_padded, 4, CUDNN_TENSOR_NHWC); generateStrides(x_dim_padded, x_stride_padded, 4, CUDNN_TENSOR_NHWC); generateStrides(y_dim_padded, y_stride_padded, 4, CUDNN_TENSOR_NHWC); generateStrides(b_dim_padded, b_stride_padded, 4, CUDNN_TENSOR_NHWC); generateStrides(threshold_dim, threshold_stride, 4, CUDNN_TENSOR_NHWC); return masked_convbias_descriptors(cudnn_frontend::TensorBuilder() .setDim(4, x_dim_padded) .setStrides(4, x_stride_padded) .setId('x') .setAlignment(16) .setDataType(dataType) .build(), cudnn_frontend::TensorBuilder() .setDim(4, y_dim_padded) .setStrides(4, y_stride_padded) .setId('y') .setAlignment(16) .setDataType(dataType) .build(), cudnn_frontend::TensorBuilder() .setDim(4, w_dim_padded) .setStrides(4, w_stride_padded) .setId('w') .setAlignment(16) .setDataType(dataType) .build(), cudnn_frontend::TensorBuilder() .setDim(4, b_dim_padded) .setStrides(4, b_stride_padded) .setId('z') .setAlignment(16) .setDataType(dataType) .build(), cudnn_frontend::TensorBuilder() .setDim(4, b_dim_padded) .setStrides(4, b_stride_padded) .setId('b') .setAlignment(16) .setDataType(dataType) .build(), cudnn_frontend::TensorBuilder() .setDim(4, y_dim_padded) .setStrides(4, y_stride_padded) .setVirtual() .setId('A') // after add .setAlignment(16) .setDataType(CUDNN_DATA_FLOAT) .build(), cudnn_frontend::TensorBuilder() .setDim(4, y_dim_padded) .setStrides(4, y_stride_padded) .setVirtual() .setId('B') // after bias .setAlignment(16) .setDataType(CUDNN_DATA_FLOAT) .build(), cudnn_frontend::TensorBuilder() .setDim(4, y_dim_padded) .setStrides(4, y_stride_padded) .setId('C') // after conv .setAlignment(16) .setVirtual() .setDataType(CUDNN_DATA_FLOAT) .build(), 
cudnn_frontend::TensorBuilder() .setDim(4, y_dim_padded) .setStrides(4, y_stride_padded) .setId('i') .setAlignment(16) .setDataType(dataType) .build(), cudnn_frontend::TensorBuilder() .setDim(4, y_dim_padded) .setStrides(4, y_stride_padded) .setId('D') // after optional add .setAlignment(16) .setVirtual() .setDataType(CUDNN_DATA_FLOAT) .build(), cudnn_frontend::TensorBuilder() .setDim(4, y_dim_padded) .setStrides(4, y_stride_padded) .setId('E') // after act for masked .setAlignment(16) .setVirtual() .setDataType(CUDNN_DATA_FLOAT) .build(), cudnn_frontend::TensorBuilder() .setDim(4, y_dim_padded) .setStrides(4, y_stride_padded) .setId('I') // output of the gen index operation .setAlignment(16) .setVirtual() .setDataType(CUDNN_DATA_INT32) .build(), cudnn_frontend::TensorBuilder() .setDim(4, y_dim_padded) .setStrides(4, y_stride_padded) .setId('m') // top half of the mask created after the less than .setAlignment(16) .setVirtual() .setDataType(CUDNN_DATA_BOOLEAN) .build(), cudnn_frontend::TensorBuilder() .setDim(4, y_dim_padded) .setStrides(4, y_stride_padded) .setId('n') // bottom half of the mask .setAlignment(16) .setVirtual() .setDataType(CUDNN_DATA_BOOLEAN) .build(), cudnn_frontend::TensorBuilder() .setDim(4, y_dim_padded) .setStrides(4, y_stride_padded) .setId('M') // OR of the top and bottom masks .setAlignment(16) .setVirtual() .setDataType(CUDNN_DATA_BOOLEAN) .build(), cudnn_frontend::TensorBuilder() .setDim(4, threshold_dim) .setStrides(4, threshold_stride) .setId('t') // threshold for creating the top mask .setAlignment(16) .setDataType(CUDNN_DATA_INT32) .build(), cudnn_frontend::TensorBuilder() .setDim(4, threshold_dim) .setStrides(4, threshold_stride) .setId('u') // threshold for creating the bottom mask .setAlignment(16) .setDataType(CUDNN_DATA_INT32) .build()); } // tensor descriptors used for dgrad enum { X_OR_DX_TENSOR, DY_TENSOR, W_OR_DW_TENSOR, SCALE_TENSOR, RELU_TENSOR, AFTER_DCONV_TENSOR, AFTER_DRELU_TENSOR, DGRAD_INPUT_TENSOR, 
DGRAD_OPTIONAL_TENSOR, DGRAD_GEN_INDEX_TENSOR, DGRAD_MASK_TOP_TENSOR, DGRAD_MASK_BOTTOM_TENSOR, DGRAD_MASK_TENSOR, DGRAD_THRESHOLD_TOP_TENSOR, DGRAD_THRESHOLD_BOTTOM_TENSOR, }; using dconv_add_descriptors = std::tuple; dconv_add_descriptors create_dconv_add_descriptors(int64_t* x_dim_padded, int64_t* padA, int64_t* convstrideA, int64_t* dilationA, int64_t* w_dim_padded, int64_t* y_dim_padded, cudnnDataType_t dataType) { const int convDim = 2; int64_t b_dim_padded[4]; b_dim_padded[0] = 1; b_dim_padded[1] = x_dim_padded[1]; b_dim_padded[2] = 1; b_dim_padded[3] = 1; int64_t x_stride_padded[4]; int64_t y_stride_padded[4]; int64_t w_stride_padded[4]; int64_t b_stride_padded[4]; generateStrides(w_dim_padded, w_stride_padded, 4, CUDNN_TENSOR_NHWC); generateStrides(x_dim_padded, x_stride_padded, 4, CUDNN_TENSOR_NHWC); generateStrides(y_dim_padded, y_stride_padded, 4, CUDNN_TENSOR_NHWC); generateStrides(b_dim_padded, b_stride_padded, 4, CUDNN_TENSOR_NHWC); return dconv_add_descriptors(cudnn_frontend::TensorBuilder() .setDim(4, x_dim_padded) .setStrides(4, x_stride_padded) .setId('x') .setAlignment(16) .setDataType(dataType) .build(), cudnn_frontend::TensorBuilder() .setDim(4, y_dim_padded) .setStrides(4, y_stride_padded) .setId('y') .setAlignment(16) .setDataType(dataType) .build(), cudnn_frontend::TensorBuilder() .setDim(4, w_dim_padded) .setStrides(4, w_stride_padded) .setId('w') .setAlignment(16) .setDataType(dataType) .build(), cudnn_frontend::TensorBuilder() .setDim(4, b_dim_padded) .setStrides(4, b_stride_padded) .setId('s') .setAlignment(16) .setDataType(dataType) .build(), cudnn_frontend::TensorBuilder() .setDim(4, x_dim_padded) .setStrides(4, x_stride_padded) .setId('r') .setAlignment(16) .setDataType(dataType) .build(), cudnn_frontend::TensorBuilder() .setDim(4, x_dim_padded) .setStrides(4, x_stride_padded) .setVirtual() .setId('A') // after dconv .setAlignment(16) .setDataType(CUDNN_DATA_FLOAT) .build(), cudnn_frontend::TensorBuilder() .setDim(4, x_dim_padded) 
.setStrides(4, x_stride_padded) .setVirtual() .setId('B') // after drelu .setAlignment(16) .setDataType(CUDNN_DATA_FLOAT) .build(), cudnn_frontend::TensorBuilder() .setDim(4, y_dim_padded) .setStrides(4, y_stride_padded) .setId('i') .setAlignment(16) .setDataType(dataType) .build(), cudnn_frontend::TensorBuilder() .setDim(4, y_dim_padded) .setStrides(4, y_stride_padded) .setId('D') // after optional add .setAlignment(16) .setVirtual() .setDataType(CUDNN_DATA_FLOAT) .build()); } using dconv_mask_descriptors = std::tuple; dconv_mask_descriptors create_dconv_mask_descriptors(int64_t* x_dim_padded, int64_t* padA, int64_t* convstrideA, int64_t* dilationA, int64_t* w_dim_padded, int64_t* y_dim_padded, int64_t* threshold_dim, cudnnDataType_t dataType) { const int convDim = 2; int64_t b_dim_padded[4]; b_dim_padded[0] = 1; b_dim_padded[1] = x_dim_padded[1]; b_dim_padded[2] = 1; b_dim_padded[3] = 1; int64_t x_stride_padded[4]; int64_t y_stride_padded[4]; int64_t w_stride_padded[4]; int64_t b_stride_padded[4]; int64_t threshold_stride[4]; generateStrides(w_dim_padded, w_stride_padded, 4, CUDNN_TENSOR_NHWC); generateStrides(x_dim_padded, x_stride_padded, 4, CUDNN_TENSOR_NHWC); generateStrides(y_dim_padded, y_stride_padded, 4, CUDNN_TENSOR_NHWC); generateStrides(b_dim_padded, b_stride_padded, 4, CUDNN_TENSOR_NHWC); generateStrides(threshold_dim, threshold_stride, 4, CUDNN_TENSOR_NHWC); return dconv_mask_descriptors(cudnn_frontend::TensorBuilder() .setDim(4, x_dim_padded) .setStrides(4, x_stride_padded) .setId('x') .setAlignment(16) .setDataType(dataType) .build(), cudnn_frontend::TensorBuilder() .setDim(4, y_dim_padded) .setStrides(4, y_stride_padded) .setId('y') .setAlignment(16) .setDataType(dataType) .build(), cudnn_frontend::TensorBuilder() .setDim(4, w_dim_padded) .setStrides(4, w_stride_padded) .setId('w') .setAlignment(16) .setDataType(dataType) .build(), cudnn_frontend::TensorBuilder() .setDim(4, b_dim_padded) .setStrides(4, b_stride_padded) .setId('s') 
.setAlignment(16) .setDataType(dataType) .build(), cudnn_frontend::TensorBuilder() .setDim(4, x_dim_padded) .setStrides(4, x_stride_padded) .setId('r') .setAlignment(16) .setDataType(dataType) .build(), cudnn_frontend::TensorBuilder() .setDim(4, x_dim_padded) .setStrides(4, x_stride_padded) .setVirtual() .setId('A') // after dconv .setAlignment(16) .setDataType(CUDNN_DATA_FLOAT) .build(), cudnn_frontend::TensorBuilder() .setDim(4, x_dim_padded) .setStrides(4, x_stride_padded) .setVirtual() .setId('B') // after drelu .setAlignment(16) .setDataType(CUDNN_DATA_FLOAT) .build(), cudnn_frontend::TensorBuilder() .setDim(4, y_dim_padded) .setStrides(4, y_stride_padded) .setId('i') .setAlignment(16) .setDataType(dataType) .build(), cudnn_frontend::TensorBuilder() .setDim(4, y_dim_padded) .setStrides(4, y_stride_padded) .setId('D') // after optional add .setAlignment(16) .setVirtual() .setDataType(CUDNN_DATA_FLOAT) .build(), cudnn_frontend::TensorBuilder() .setDim(4, y_dim_padded) .setStrides(4, y_stride_padded) .setId('I') // output of the gen index operation .setAlignment(16) .setVirtual() .setDataType(CUDNN_DATA_INT32) .build(), cudnn_frontend::TensorBuilder() .setDim(4, y_dim_padded) .setStrides(4, y_stride_padded) .setId('m') // top half of the mask created after the less than .setAlignment(16) .setVirtual() .setDataType(CUDNN_DATA_BOOLEAN) .build(), cudnn_frontend::TensorBuilder() .setDim(4, y_dim_padded) .setStrides(4, y_stride_padded) .setId('n') // bottom half of the mask .setAlignment(16) .setVirtual() .setDataType(CUDNN_DATA_BOOLEAN) .build(), cudnn_frontend::TensorBuilder() .setDim(4, y_dim_padded) .setStrides(4, y_stride_padded) .setId('M') // OR of the top and bottom masks .setAlignment(16) .setVirtual() .setDataType(CUDNN_DATA_BOOLEAN) .build(), cudnn_frontend::TensorBuilder() .setDim(4, threshold_dim) .setStrides(4, threshold_stride) .setId('t') // threshold for creating the top mask .setAlignment(16) .setDataType(CUDNN_DATA_INT32) .build(), 
cudnn_frontend::TensorBuilder() .setDim(4, threshold_dim) .setStrides(4, threshold_stride) .setId('u') // threshold for creating the bottom mask .setAlignment(16) .setDataType(CUDNN_DATA_INT32) .build()); } void run_conv_add_scale_bias_activation(int64_t* x_dim_padded, int64_t* pad, int64_t* convstride, int64_t* dilation, int64_t* w_dim_padded, int64_t* y_dim_padded, cudnnDataType_t dataType, at::Half* devPtrX, at::Half* devPtrW, at::Half* devPtrY, at::Half* devPtrZ, at::Half* devPtrB, at::Half* devPtrI) { cudnnHandle_t handle_ = torch::native::getCudnnHandle(); std::stringstream log_buf; try { int convDim = 2; // Creates the necessary tensor descriptors common_convbias_descriptors tensors = create_conv_bias_add_act_descriptors(x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, dataType); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); // Define the add operation auto scaleDesc = cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_MUL).setMathPrecision(CUDNN_DATA_FLOAT).build(); DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe()); // Define the bias operation auto biasDesc = cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_ADD).setMathPrecision(CUDNN_DATA_FLOAT).build(); DEBUG_CUDNN_MSG(log_buf, biasDesc.describe()); // optional add auto addDesc = cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_ADD).setMathPrecision(CUDNN_DATA_FLOAT).build(); DEBUG_CUDNN_MSG(log_buf, addDesc.describe()); // Define the 
activation operation auto actDesc = cudnn_frontend::PointWiseDescBuilder() .setMode(CUDNN_POINTWISE_RELU_FWD) .setMathPrecision(CUDNN_DATA_FLOAT) .build(); DEBUG_CUDNN_MSG(log_buf, actDesc.describe()); // Define the convolution problem auto convDesc = cudnn_frontend::ConvDescBuilder() .setDataType(CUDNN_DATA_FLOAT) .setMathMode(CUDNN_CROSS_CORRELATION) .setNDims(convDim) .setStrides(convDim, convstride) .setPrePadding(convDim, pad) .setPostPadding(convDim, pad) .setDilation(convDim, dilation) .build(); DEBUG_CUDNN_MSG(log_buf, convDesc.describe()); float alpha = 1.0f; float beta = 0.0f; // Create a convolution Node auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR) .setxDesc(std::get(tensors)) .setwDesc(std::get(tensors)) .setyDesc(std::get(tensors)) .setcDesc(convDesc) .setAlpha(alpha) .setBeta(beta) .build(); DEBUG_CUDNN_MSG(log_buf, conv_op.describe()); // create an add node. auto add_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) .setxDesc(conv_op.getOutputTensor()) .setbDesc(std::get(tensors)) .setyDesc(std::get(tensors)) .setpwDesc(addDesc) .build(); DEBUG_CUDNN_MSG(log_buf, add_op.describe()); // Create a Add Node with scaling parameters. auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) .setxDesc(add_op.getOutputTensor()) .setbDesc(std::get(tensors)) .setyDesc(std::get(tensors)) .setpwDesc(scaleDesc) .build(); DEBUG_CUDNN_MSG(log_buf, scale_op.describe()); // Create a Bias Node. auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) .setxDesc(scale_op.getOutputTensor()) .setbDesc(std::get(tensors)) .setyDesc(std::get(tensors)) .setpwDesc(biasDesc) .build(); DEBUG_CUDNN_MSG(log_buf, bias_op.describe()); // Create an Activation Node. 
auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) .setxDesc(bias_op.getOutputTensor()) .setyDesc(std::get(tensors)) .setpwDesc(actDesc) .build(); DEBUG_CUDNN_MSG(log_buf, act_op.describe()); // Create an Operation Graph. In this case it is convolution add bias activation std::array ops = {&conv_op, &add_op, &scale_op, &bias_op, &act_op}; auto opGraph = cudnn_frontend::OperationGraphBuilder().setHandle(handle_).setOperationGraph(ops.size(), ops.data()).build(); // Create string encoding for plan caching auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag()); DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); auto workspace_size = plan.getWorkspaceSize(); DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); void* workspace_ptr = nullptr; auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); if (workspace_size > 0) { workspace_ptr = workspace_tensor.data_ptr(); } void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrB, devPtrI}; int64_t uids[] = {'x', 'y', 'w', 'z', 'b', 'i'}; auto variantPack = cudnn_frontend::VariantPackBuilder() .setWorkspacePointer(workspace_ptr) .setDataPointers(6, data_ptrs) .setUids(6, uids) .build(); DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); checkCudnnErr(status); cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); } catch (cudnn_frontend::cudnnException e) { std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; } } void run_conv_scale_bias_add_activation_mask(int64_t* x_dim_padded, int64_t* 
pad, int64_t* convstride, int64_t* dilation, int64_t* w_dim_padded, int64_t* y_dim_padded, int64_t* threshold_dim, cudnnDataType_t dataType, at::Half* devPtrX, at::Half* devPtrW, at::Half* devPtrY, at::Half* devPtrZ, at::Half* devPtrB, at::Half* devPtrI, int* devPtrT, int* devPtrU, int axis) { cudnnHandle_t handle_ = torch::native::getCudnnHandle(); std::stringstream log_buf; try { int convDim = 2; // Creates the necessary tensor descriptors masked_convbias_descriptors tensors = create_conv_bias_add_act_mask_descriptors( x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, threshold_dim, dataType); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); // Define the add operation auto scaleDesc = cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_MUL).setMathPrecision(CUDNN_DATA_FLOAT).build(); DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe()); // Define the bias operation auto biasDesc = cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_ADD).setMathPrecision(CUDNN_DATA_FLOAT).build(); DEBUG_CUDNN_MSG(log_buf, biasDesc.describe()); // optional add auto addDesc = 
cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_ADD).setMathPrecision(CUDNN_DATA_FLOAT).build(); DEBUG_CUDNN_MSG(log_buf, addDesc.describe()); // Define the activation operation auto actDesc = cudnn_frontend::PointWiseDescBuilder() .setMode(CUDNN_POINTWISE_RELU_FWD) .setMathPrecision(CUDNN_DATA_FLOAT) .build(); DEBUG_CUDNN_MSG(log_buf, actDesc.describe()); // Define the convolution problem auto convDesc = cudnn_frontend::ConvDescBuilder() .setDataType(CUDNN_DATA_FLOAT) .setMathMode(CUDNN_CROSS_CORRELATION) .setNDims(convDim) .setStrides(convDim, convstride) .setPrePadding(convDim, pad) .setPostPadding(convDim, pad) .setDilation(convDim, dilation) .build(); DEBUG_CUDNN_MSG(log_buf, convDesc.describe()); // Define the genIndex descriptor auto genIndexDesc = cudnn_frontend::PointWiseDescBuilder() .setMode(CUDNN_POINTWISE_GEN_INDEX) .setMathPrecision(CUDNN_DATA_FLOAT) .setAxis(axis) .build(); DEBUG_CUDNN_MSG(log_buf, genIndexDesc.describe()); // Define the lessThan descriptor auto lessThanDesc = cudnn_frontend::PointWiseDescBuilder() .setMode(CUDNN_POINTWISE_CMP_LT) .setMathPrecision(CUDNN_DATA_FLOAT) .build(); DEBUG_CUDNN_MSG(log_buf, lessThanDesc.describe()); // Define the greaterThan descriptor auto greaterThanDesc = cudnn_frontend::PointWiseDescBuilder() .setMode(CUDNN_POINTWISE_CMP_GT) .setMathPrecision(CUDNN_DATA_FLOAT) .build(); DEBUG_CUDNN_MSG(log_buf, greaterThanDesc.describe()); // Define the logical_or descriptor auto logicalOrDesc = cudnn_frontend::PointWiseDescBuilder() .setMode(CUDNN_POINTWISE_LOGICAL_OR) .setMathPrecision(CUDNN_DATA_BOOLEAN) .build(); DEBUG_CUDNN_MSG(log_buf, logicalOrDesc.describe()); // Define the binary_selection descriptor auto selectionDesc = cudnn_frontend::PointWiseDescBuilder() .setMode(CUDNN_POINTWISE_BINARY_SELECT) .setMathPrecision(CUDNN_DATA_FLOAT) .build(); DEBUG_CUDNN_MSG(log_buf, selectionDesc.describe()); float alpha = 1.0f; float beta = 0.0f; // Create a convolution Node auto conv_op = 
cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR) .setxDesc(std::get(tensors)) .setwDesc(std::get(tensors)) .setyDesc(std::get(tensors)) .setcDesc(convDesc) .setAlpha(alpha) .setBeta(beta) .build(); DEBUG_CUDNN_MSG(log_buf, conv_op.describe()); // Create a Add Node with scaling parameters. auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) .setxDesc(conv_op.getOutputTensor()) .setbDesc(std::get(tensors)) .setyDesc(std::get(tensors)) .setpwDesc(scaleDesc) .build(); DEBUG_CUDNN_MSG(log_buf, scale_op.describe()); // Create a Bias Node. auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) .setxDesc(scale_op.getOutputTensor()) .setbDesc(std::get(tensors)) .setyDesc(std::get(tensors)) .setpwDesc(biasDesc) .build(); DEBUG_CUDNN_MSG(log_buf, bias_op.describe()); // Create a optional add Node. auto add_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) .setxDesc(bias_op.getOutputTensor()) .setbDesc(std::get(tensors)) .setyDesc(std::get(tensors)) .setpwDesc(addDesc) .build(); DEBUG_CUDNN_MSG(log_buf, add_op.describe()); // Create an Activation Node. auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) .setxDesc(devPtrI ? add_op.getOutputTensor() : bias_op.getOutputTensor()) .setyDesc(std::get(tensors)) .setpwDesc(actDesc) .build(); DEBUG_CUDNN_MSG(log_buf, act_op.describe()); // Create a Gen_Index Node. auto genIndex_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) .setxDesc(std::get(tensors)) .setyDesc(std::get(tensors)) .setpwDesc(genIndexDesc) .build(); DEBUG_CUDNN_MSG(log_buf, genIndex_op.describe()); // Create a LessThan Node. 
auto lessThan_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) .setxDesc(std::get(tensors)) .setbDesc(std::get(tensors)) .setyDesc(std::get(tensors)) .setpwDesc(lessThanDesc) .build(); DEBUG_CUDNN_MSG(log_buf, lessThan_op.describe()); // Create a GreaterThan Node. auto greaterThan_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) .setxDesc(std::get(tensors)) .setbDesc(std::get(tensors)) .setyDesc(std::get(tensors)) .setpwDesc(greaterThanDesc) .build(); DEBUG_CUDNN_MSG(log_buf, greaterThan_op.describe()); // Create a LogicalOr Node. auto logicalOr_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) .setxDesc(std::get(tensors)) .setbDesc(std::get(tensors)) .setyDesc(std::get(tensors)) .setpwDesc(logicalOrDesc) .build(); DEBUG_CUDNN_MSG(log_buf, logicalOr_op.describe()); // Create a Binary_Selection Node. auto selection_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) .setxDesc(std::get(tensors)) .setbDesc(std::get(tensors)) .settDesc(std::get(tensors)) .setyDesc(std::get(tensors)) .setpwDesc(selectionDesc) .build(); DEBUG_CUDNN_MSG(log_buf, selection_op.describe()); // Create an Operation Graph. 
In this case it is convolution add bias activation if (devPtrI) { std::array ops = { &conv_op, &scale_op, &bias_op, &add_op, &act_op, &genIndex_op, &lessThan_op, &greaterThan_op, &logicalOr_op, &selection_op}; auto opGraph = cudnn_frontend::OperationGraphBuilder().setHandle(handle_).setOperationGraph(ops.size(), ops.data()).build(); // Create string encoding for plan caching auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag()); DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); auto workspace_size = plan.getWorkspaceSize(); DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); void* workspace_ptr = nullptr; auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); if (workspace_size > 0) { workspace_ptr = workspace_tensor.data_ptr(); } void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrB, devPtrI, devPtrT, devPtrU}; int64_t uids[] = {'x', 'y', 'w', 'z', 'b', 'i', 't', 'u'}; auto variantPack = cudnn_frontend::VariantPackBuilder() .setWorkspacePointer(workspace_ptr) .setDataPointers(8, data_ptrs) .setUids(8, uids) .build(); DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); checkCudnnErr(status); cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); } else { std::array ops = {&conv_op, &scale_op, &bias_op, &act_op, &genIndex_op, &lessThan_op, &greaterThan_op, &logicalOr_op, &selection_op}; auto opGraph = cudnn_frontend::OperationGraphBuilder().setHandle(handle_).setOperationGraph(ops.size(), ops.data()).build(); // Create string encoding for plan caching auto cache_string = 
getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag()); DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); auto workspace_size = plan.getWorkspaceSize(); DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); void* workspace_ptr = nullptr; auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); if (workspace_size > 0) { workspace_ptr = workspace_tensor.data_ptr(); } void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrB, devPtrT, devPtrU}; int64_t uids[] = {'x', 'y', 'w', 'z', 'b', 't', 'u'}; auto variantPack = cudnn_frontend::VariantPackBuilder() .setWorkspacePointer(workspace_ptr) .setDataPointers(7, data_ptrs) .setUids(7, uids) .build(); DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); checkCudnnErr(status); cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); } } catch (cudnn_frontend::cudnnException e) { std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; } } void run_dconv_add_drelu_dscale(int64_t* x_dim_padded, int64_t* pad, int64_t* convstride, int64_t* dilation, int64_t* w_dim_padded, int64_t* y_dim_padded, cudnnDataType_t dataType, at::Half* devPtrX, at::Half* devPtrW, at::Half* devPtrY, at::Half* devPtrZ, at::Half* devPtrR, at::Half* devPtrI) { cudnnHandle_t handle_ = torch::native::getCudnnHandle(); std::stringstream log_buf; try { int convDim = 2; // Creates the necessary tensor descriptors dconv_add_descriptors tensors = create_dconv_add_descriptors(x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, dataType); DEBUG_CUDNN_MSG(log_buf, 
std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); // Define the convolution problem auto convDesc = cudnn_frontend::ConvDescBuilder() .setDataType(CUDNN_DATA_FLOAT) .setMathMode(CUDNN_CROSS_CORRELATION) .setNDims(convDim) .setStrides(convDim, convstride) .setPrePadding(convDim, pad) .setPostPadding(convDim, pad) .setDilation(convDim, dilation) .build(); DEBUG_CUDNN_MSG(log_buf, convDesc.describe()); // optional add auto addDesc = cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_ADD).setMathPrecision(CUDNN_DATA_FLOAT).build(); DEBUG_CUDNN_MSG(log_buf, addDesc.describe()); // Define the activation backward operation auto actDesc = cudnn_frontend::PointWiseDescBuilder() .setMode(CUDNN_POINTWISE_RELU_BWD) .setMathPrecision(CUDNN_DATA_FLOAT) .build(); DEBUG_CUDNN_MSG(log_buf, actDesc.describe()); // Define the scale backward operation auto scaleDesc = cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_MUL).setMathPrecision(CUDNN_DATA_FLOAT).build(); DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe()); float alpha = 1.0f; float beta = 0.0f; // Create a convolution Node auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR) .setdxDesc(std::get(tensors)) .setwDesc(std::get(tensors)) .setdyDesc(std::get(tensors)) .setcDesc(convDesc) .setAlpha(alpha) .setBeta(beta) .build(); DEBUG_CUDNN_MSG(log_buf, conv_op.describe()); // Create add Node. 
auto add_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) .setxDesc(std::get(tensors)) .setbDesc(std::get(tensors)) .setyDesc(std::get(tensors)) .setpwDesc(addDesc) .build(); DEBUG_CUDNN_MSG(log_buf, add_op.describe()); // TODO: do we need getOutputTensor(), and what it returns in backward case? // Create an relu backward Node. auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) .setdyDesc(std::get(tensors)) .setxDesc(std::get(tensors)) .setdxDesc(std::get(tensors)) .setpwDesc(actDesc) .build(); DEBUG_CUDNN_MSG(log_buf, act_op.describe()); // Create a Scale Node. auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) .setxDesc(std::get(tensors)) .setbDesc(std::get(tensors)) .setyDesc(std::get(tensors)) .setpwDesc(scaleDesc) .build(); DEBUG_CUDNN_MSG(log_buf, scale_op.describe()); // Create an Operation Graph. In this case it is convolution add bias activation std::array ops = {&conv_op, &add_op, &act_op, &scale_op}; auto opGraph = cudnn_frontend::OperationGraphBuilder().setHandle(handle_).setOperationGraph(ops.size(), ops.data()).build(); // Create string encoding for plan caching auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag()); DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); auto workspace_size = plan.getWorkspaceSize(); DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); void* workspace_ptr = nullptr; auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); if (workspace_size > 0) { workspace_ptr = workspace_tensor.data_ptr(); } void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrR, devPtrI}; int64_t uids[] = {'x', 'y', 'w', 's', 'r', 'i'}; 
auto variantPack = cudnn_frontend::VariantPackBuilder() .setWorkspacePointer(workspace_ptr) .setDataPointers(6, data_ptrs) .setUids(6, uids) .build(); DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); checkCudnnErr(status); cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); } catch (cudnn_frontend::cudnnException e) { std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; } } void run_dconv_drelu_dscale_mask(int64_t* x_dim_padded, int64_t* pad, int64_t* convstride, int64_t* dilation, int64_t* w_dim_padded, int64_t* y_dim_padded, int64_t* threshold_dim, cudnnDataType_t dataType, at::Half* devPtrX, at::Half* devPtrW, at::Half* devPtrY, at::Half* devPtrZ, at::Half* devPtrR, int* devPtrT, int* devPtrU, int axis) { cudnnHandle_t handle_ = torch::native::getCudnnHandle(); std::stringstream log_buf; try { int convDim = 2; // Creates the necessary tensor descriptors dconv_mask_descriptors tensors = create_dconv_mask_descriptors(x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, threshold_dim, dataType); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); DEBUG_CUDNN_MSG(log_buf, 
std::get(tensors).describe()); // Define the convolution problem auto convDesc = cudnn_frontend::ConvDescBuilder() .setDataType(CUDNN_DATA_FLOAT) .setMathMode(CUDNN_CROSS_CORRELATION) .setNDims(convDim) .setStrides(convDim, convstride) .setPrePadding(convDim, pad) .setPostPadding(convDim, pad) .setDilation(convDim, dilation) .build(); DEBUG_CUDNN_MSG(log_buf, convDesc.describe()); // Define the activation backward operation auto actDesc = cudnn_frontend::PointWiseDescBuilder() .setMode(CUDNN_POINTWISE_RELU_BWD) .setMathPrecision(CUDNN_DATA_FLOAT) .build(); DEBUG_CUDNN_MSG(log_buf, actDesc.describe()); // Define the scale backward operation auto scaleDesc = cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_MUL).setMathPrecision(CUDNN_DATA_FLOAT).build(); DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe()); // Define the genIndex descriptor auto genIndexDesc = cudnn_frontend::PointWiseDescBuilder() .setMode(CUDNN_POINTWISE_GEN_INDEX) .setMathPrecision(CUDNN_DATA_FLOAT) .setAxis(axis) .build(); DEBUG_CUDNN_MSG(log_buf, genIndexDesc.describe()); // Define the lessThan descriptor auto lessThanDesc = cudnn_frontend::PointWiseDescBuilder() .setMode(CUDNN_POINTWISE_CMP_LT) .setMathPrecision(CUDNN_DATA_FLOAT) .build(); DEBUG_CUDNN_MSG(log_buf, lessThanDesc.describe()); // Define the greaterThan descriptor auto greaterThanDesc = cudnn_frontend::PointWiseDescBuilder() .setMode(CUDNN_POINTWISE_CMP_GT) .setMathPrecision(CUDNN_DATA_FLOAT) .build(); DEBUG_CUDNN_MSG(log_buf, greaterThanDesc.describe()); // Define the logical_or descriptor auto logicalOrDesc = cudnn_frontend::PointWiseDescBuilder() .setMode(CUDNN_POINTWISE_LOGICAL_OR) .setMathPrecision(CUDNN_DATA_BOOLEAN) .build(); DEBUG_CUDNN_MSG(log_buf, logicalOrDesc.describe()); // Define the binary_selection descriptor auto selectionDesc = cudnn_frontend::PointWiseDescBuilder() .setMode(CUDNN_POINTWISE_BINARY_SELECT) .setMathPrecision(CUDNN_DATA_FLOAT) .build(); DEBUG_CUDNN_MSG(log_buf, selectionDesc.describe()); float 
alpha = 1.0f; float beta = 0.0f; // Create a convolution Node auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR) .setdxDesc(std::get(tensors)) .setwDesc(std::get(tensors)) .setdyDesc(std::get(tensors)) .setcDesc(convDesc) .setAlpha(alpha) .setBeta(beta) .build(); DEBUG_CUDNN_MSG(log_buf, conv_op.describe()); // TODO: do we need getOutputTensor(), and what it returns in backward case? // Create an relu backward Node. auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) .setdyDesc(std::get(tensors)) .setxDesc(std::get(tensors)) .setdxDesc(std::get(tensors)) .setpwDesc(actDesc) .build(); DEBUG_CUDNN_MSG(log_buf, act_op.describe()); // Create a Scale Node. auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) .setxDesc(std::get(tensors)) .setbDesc(std::get(tensors)) .setyDesc(std::get(tensors)) .setpwDesc(scaleDesc) .build(); DEBUG_CUDNN_MSG(log_buf, scale_op.describe()); // Create a Gen_Index Node. auto genIndex_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) .setxDesc(std::get(tensors)) .setyDesc(std::get(tensors)) .setpwDesc(genIndexDesc) .build(); DEBUG_CUDNN_MSG(log_buf, genIndex_op.describe()); // Create a LessThan Node. auto lessThan_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) .setxDesc(std::get(tensors)) .setbDesc(std::get(tensors)) .setyDesc(std::get(tensors)) .setpwDesc(lessThanDesc) .build(); DEBUG_CUDNN_MSG(log_buf, lessThan_op.describe()); // Create a GreaterThan Node. auto greaterThan_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) .setxDesc(std::get(tensors)) .setbDesc(std::get(tensors)) .setyDesc(std::get(tensors)) .setpwDesc(greaterThanDesc) .build(); DEBUG_CUDNN_MSG(log_buf, greaterThan_op.describe()); // Create a LogicalOr Node. 
auto logicalOr_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) .setxDesc(std::get(tensors)) .setbDesc(std::get(tensors)) .setyDesc(std::get(tensors)) .setpwDesc(logicalOrDesc) .build(); DEBUG_CUDNN_MSG(log_buf, logicalOr_op.describe()); // Create a Binary_Selection Node. auto selection_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) .setxDesc(std::get(tensors)) .setbDesc(std::get(tensors)) .settDesc(std::get(tensors)) .setyDesc(std::get(tensors)) .setpwDesc(selectionDesc) .build(); DEBUG_CUDNN_MSG(log_buf, selection_op.describe()); // Create an Operation Graph. In this case it is convolution add bias activation std::array ops = {&conv_op, &act_op, &scale_op, &genIndex_op, &lessThan_op, &greaterThan_op, &logicalOr_op, &selection_op}; auto opGraph = cudnn_frontend::OperationGraphBuilder().setHandle(handle_).setOperationGraph(ops.size(), ops.data()).build(); // Create string encoding for plan caching auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag()); DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); auto workspace_size = plan.getWorkspaceSize(); DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); void* workspace_ptr = nullptr; auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); if (workspace_size > 0) { workspace_ptr = workspace_tensor.data_ptr(); } void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrR, devPtrT, devPtrU}; int64_t uids[] = {'x', 'y', 'w', 's', 'r', 't', 'u'}; auto variantPack = cudnn_frontend::VariantPackBuilder() .setWorkspacePointer(workspace_ptr) .setDataPointers(7, data_ptrs) .setUids(7, uids) .build(); DEBUG_CUDNN_MSG(log_buf, "variantPack " << 
variantPack.describe()); cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); checkCudnnErr(status); cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); } catch (cudnn_frontend::cudnnException e) { std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; } } struct bottleneck_forward_status { int64_t dimA[4]; int64_t filterdimA1[4]; int64_t filterdimA2[4]; int64_t filterdimA2hh[4]; int64_t filterdimA3[4]; int64_t filterdimA4[4]; int64_t threshdim[4]; int axis[4]; int64_t outdimA0[4]; int64_t outdimA1[4]; int64_t outdimA1b[4]; // out1_pad int64_t outdimA2[4]; int64_t outdimA3[4]; int64_t outdimA4[4]; int64_t padA[2]; int64_t padA1[2]; int64_t padA2[2]; // halo padding int64_t dilationA[2]; int64_t convstrideA[2]; int64_t convstride1X1[2]; int64_t outdim0[4]; // halo input shape int64_t outdim1[4]; int64_t outdim1b[4]; int64_t outdim2[4]; int64_t outdim3[4]; int64_t outdim4[4]; // halo output shape void init(bool explicit_nhwc, int stride_1X1, std::vector inputs) { dimA[0] = dimA[1] = dimA[2] = dimA[3] = 0; filterdimA1[0] = filterdimA1[1] = filterdimA1[2] = filterdimA1[3] = 0; filterdimA2[0] = filterdimA2[1] = filterdimA2[2] = filterdimA2[3] = 0; filterdimA2hh[0] = filterdimA2hh[1] = filterdimA2hh[2] = filterdimA2hh[3] = 0; filterdimA3[0] = filterdimA3[1] = filterdimA3[2] = filterdimA3[3] = 0; filterdimA4[0] = filterdimA4[1] = filterdimA4[2] = filterdimA4[3] = 0; threshdim[0] = threshdim[1] = threshdim[2] = threshdim[3] = 1; // All dim calculation after this order of n,c,h,w if (explicit_nhwc) { axis[0] = 0; axis[1] = 3; axis[2] = 1; axis[3] = 2; } else { axis[0] = 0; axis[1] = 1; axis[2] = 2; axis[3] = 3; } for (int dim = 0; dim < 4; dim++) { dimA[dim] = inputs[0].size(axis[dim]); filterdimA1[dim] = inputs[1].size(axis[dim]); filterdimA2[dim] = inputs[2].size(axis[dim]); filterdimA3[dim] = inputs[3].size(axis[dim]); } if (stride_1X1 != 1 
|| filterdimA3[0] != dimA[1]) { for (int dim = 0; dim < 4; dim++) { filterdimA4[dim] = inputs[10].size(axis[dim]); } } for (int dim = 0; dim < 4; dim++) { if (dim == 2) { filterdimA2hh[dim] = 1; } else { filterdimA2hh[dim] = filterdimA2[dim]; } } // output dim in n,c,h,w used by backend outdimA0[0] = outdimA0[1] = outdimA0[2] = outdimA0[3] = 0; outdimA1[0] = outdimA1[1] = outdimA1[2] = outdimA1[3] = 0; outdimA1b[0] = outdimA1b[1] = outdimA1b[2] = outdimA1b[3] = 0; outdimA2[0] = outdimA2[1] = outdimA2[2] = outdimA2[3] = 0; outdimA3[0] = outdimA3[1] = outdimA3[2] = outdimA3[3] = 0; outdimA4[0] = outdimA4[1] = outdimA4[2] = outdimA4[3] = 0; // use these fixed value for test run padA[0] = 0; padA[1] = 0; padA1[0] = 1; padA1[1] = 1; padA2[0] = 0; padA2[1] = 1; dilationA[0] = 1; dilationA[1] = 1; convstrideA[0] = 1; convstrideA[1] = 1; convstride1X1[0] = stride_1X1; convstride1X1[1] = stride_1X1; // compute output from pad/stride/dilation outdimA1[0] = dimA[0]; outdimA1[1] = filterdimA1[0]; for (int dim = 0; dim < 2; dim++) { outdimA1[dim + 2] = getFwdConvOutputDim(dimA[dim + 2], padA[dim], filterdimA1[dim + 2], convstride1X1[dim], dilationA[dim]); } for (int dim = 0; dim < 4; dim++) { if (dim == 2) { outdimA1b[dim] = outdimA1[dim] + 2; } else { outdimA1b[dim] = outdimA1[dim]; } } outdimA2[0] = outdimA1[0]; outdimA2[1] = filterdimA2[0]; for (int dim = 0; dim < 2; dim++) { outdimA2[dim + 2] = getFwdConvOutputDim(outdimA1[dim + 2], padA1[dim], filterdimA2[dim + 2], convstrideA[dim], dilationA[dim]); } for (int dim = 0; dim < 4; dim++) { if (dim == 2) { outdimA0[dim] = 3; outdimA4[dim] = 1; } else { outdimA0[dim] = outdimA1[dim]; outdimA4[dim] = outdimA2[dim]; } } outdimA3[0] = outdimA2[0]; outdimA3[1] = filterdimA3[0]; for (int dim = 0; dim < 2; dim++) { outdimA3[dim + 2] = getFwdConvOutputDim(outdimA2[dim + 2], padA[dim], filterdimA3[dim + 2], convstrideA[dim], dilationA[dim]); } // Create output tensor in the correct shape in pytorch's view outdim1[0] = outdim1[1] = 
outdim1[2] = outdim1[3] = 0; outdim1b[0] = outdim1b[1] = outdim1b[2] = outdim1b[3] = 0; outdim2[0] = outdim2[1] = outdim2[2] = outdim2[3] = 0; outdim3[0] = outdim3[1] = outdim3[2] = outdim3[3] = 0; if (explicit_nhwc) { axis[0] = 0; axis[1] = 2; axis[2] = 3; axis[3] = 1; } for (int dim = 0; dim < 4; dim++) { outdim0[dim] = outdimA0[axis[dim]]; outdim1[dim] = outdimA1[axis[dim]]; outdim1b[dim] = outdimA1b[axis[dim]]; outdim2[dim] = outdimA2[axis[dim]]; outdim3[dim] = outdimA3[axis[dim]]; outdim4[dim] = outdimA4[axis[dim]]; } } }; bottleneck_forward_status forward_state; } // end of anonymous namespace std::vector bottleneck_forward_init(bool explicit_nhwc, int stride_1X1, std::vector inputs) { // NB! Bottleneck_forward and bottleneck_backward are NOT thread safe method. // NB! We use a global object to store state. forward_state.init(explicit_nhwc, stride_1X1, inputs); // create output vector std::vector outputs; auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast; // printf("outdim1 = // (%d,%d,%d,%d)\n",forward_state.outdim1[0],forward_state.outdim1[1],forward_state.outdim1[2],forward_state.outdim1[3]); auto out1 = at::empty(forward_state.outdim1, inputs[0].type(), output_format); auto out2 = at::empty(forward_state.outdim2, inputs[0].type(), output_format); auto out3 = at::empty(forward_state.outdim3, inputs[0].type(), output_format); outputs.push_back(out1); outputs.push_back(out2); outputs.push_back(out3); return outputs; } // inputs contains x,w,z,b,(i) void bottleneck_forward_out1(bool explicit_nhwc, int stride_1X1, std::vector inputs, std::vector outputs) { std::cout << std::fixed; // run at::Half* x = inputs[0].data_ptr(); at::Half* w = inputs[1].data_ptr(); at::Half* z = inputs[4].data_ptr(); at::Half* b = inputs[7].data_ptr(); auto out1 = outputs[0]; at::Half* y1 = out1.data_ptr(); run_conv_scale_bias_add_activation(forward_state.dimA, forward_state.padA, forward_state.convstride1X1, forward_state.dilationA, 
forward_state.filterdimA1, forward_state.outdimA1, CUDNN_DATA_HALF, x, w, y1, z, b, nullptr); DEBUG_MSG("[DEBUG] new relu1 : " << out1.to(at::kFloat).sum().item()); } // computes halo (top or bottom) from fat halo input. // fat halo input is 3 pixels wide in H. at::Tensor bottleneck_forward_out2_halo(bool explicit_nhwc, at::Tensor fat_halo_y1, std::vector inputs) { auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast; // run at::Half* w = inputs[2].data_ptr(); at::Half* z = inputs[5].data_ptr(); at::Half* b = inputs[8].data_ptr(); at::Half* y1 = fat_halo_y1.data_ptr(); auto halo_y2 = at::empty(forward_state.outdim4, inputs[0].type(), output_format); at::Half* y2 = halo_y2.data_ptr(); run_conv_scale_bias_add_activation(forward_state.outdimA0, forward_state.padA2, forward_state.convstrideA, forward_state.dilationA, forward_state.filterdimA2, forward_state.outdimA4, CUDNN_DATA_HALF, y1, w, y2, z, b, nullptr); return halo_y2; } // compute halo correction term (top or bottom) from slim halo input (N,C,1,W). // slim halo input is 1 pixel wide in H. at::Tensor bottleneck_forward_out2_halo_corr(bool explicit_nhwc, at::Tensor slim_halo_y1, std::vector inputs, at::Tensor w1by3, at::Tensor out2_part_halo) { auto output_format = explicit_nhwc ? 
at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast; // run at::Half* w = w1by3.data_ptr(); // C,C,1,3 at::Half* z = inputs[5].data_ptr(); at::Half* b = inputs[8].data_ptr(); at::Half* y1 = slim_halo_y1.data_ptr(); at::Half* prev_out2 = out2_part_halo.data_ptr(); auto halo_y2 = at::empty(forward_state.outdim4, inputs[0].type(), output_format); at::Half* y2 = halo_y2.data_ptr(); run_conv_add_scale_bias_activation(forward_state.outdimA4, forward_state.padA2, forward_state.convstrideA, forward_state.dilationA, forward_state.filterdimA2hh, forward_state.outdimA4, CUDNN_DATA_HALF, y1, w, y2, z, b, prev_out2); return halo_y2; } void bottleneck_forward_out2(bool explicit_nhwc, int stride_1X1, std::vector inputs, std::vector outputs) { std::cout << std::fixed; // from _out1 method at::Half* x = inputs[0].data_ptr(); auto out1 = outputs[0]; at::Half* y1 = out1.data_ptr(); // run at::Half* w = inputs[2].data_ptr(); at::Half* z = inputs[5].data_ptr(); at::Half* b = inputs[8].data_ptr(); auto out2 = outputs[1]; at::Half* y2 = out2.data_ptr(); // printf("forward_state.outdimA1 = // {%d,%d,%d,%d}\n",forward_state.outdimA1[0],forward_state.outdimA1[1],forward_state.outdimA1[2],forward_state.outdimA1[3]); // printf("forward_state.padA1 = {%d,%d}\n",forward_state.padA1[0],forward_state.padA1[1]); // printf("forward_state.convstrideA = {%d,%d}\n",forward_state.convstrideA[0],forward_state.convstrideA[1]); // printf("forward_state.dilationA = {%d,%d}\n",forward_state.dilationA[0],forward_state.dilationA[1]); // printf("forward_state.filterdimA2 = // {%d,%d,%d,%d}\n",forward_state.filterdimA2[0],forward_state.filterdimA2[1],forward_state.filterdimA2[2],forward_state.filterdimA2[3]); // printf("forward_state.outdimA2 = // {%d,%d,%d,%d}\n",forward_state.outdimA2[0],forward_state.outdimA2[1],forward_state.outdimA2[2],forward_state.outdimA2[3]); run_conv_scale_bias_add_activation(forward_state.outdimA1, forward_state.padA1, forward_state.convstrideA, forward_state.dilationA, 
forward_state.filterdimA2, forward_state.outdimA2, CUDNN_DATA_HALF, y1, w, y2, z, b, nullptr); DEBUG_MSG("[DEBUG] new relu2 : " << out2.to(at::kFloat).sum().item()); } void bottleneck_forward_out2_mask(bool explicit_nhwc, int stride_1X1, std::vector inputs, std::vector outputs, at::Tensor thresholdTop, at::Tensor thresholdBottom) { std::cout << std::fixed; // from _out1 method at::Half* x = inputs[0].data_ptr(); auto out1 = outputs[0]; at::Half* y1 = out1.data_ptr(); // run at::Half* w = inputs[2].data_ptr(); at::Half* z = inputs[5].data_ptr(); at::Half* b = inputs[8].data_ptr(); auto out2 = outputs[1]; at::Half* y2 = out2.data_ptr(); // printf("forward_state.outdimA1 = // {%d,%d,%d,%d}\n",forward_state.outdimA1[0],forward_state.outdimA1[1],forward_state.outdimA1[2],forward_state.outdimA1[3]); // printf("forward_state.padA1 = {%d,%d}\n",forward_state.padA1[0],forward_state.padA1[1]); // printf("forward_state.convstrideA = {%d,%d}\n",forward_state.convstrideA[0],forward_state.convstrideA[1]); // printf("forward_state.dilationA = {%d,%d}\n",forward_state.dilationA[0],forward_state.dilationA[1]); // printf("forward_state.filterdimA2 = // {%d,%d,%d,%d}\n",forward_state.filterdimA2[0],forward_state.filterdimA2[1],forward_state.filterdimA2[2],forward_state.filterdimA2[3]); // printf("forward_state.outdimA2 = // {%d,%d,%d,%d}\n",forward_state.outdimA2[0],forward_state.outdimA2[1],forward_state.outdimA2[2],forward_state.outdimA2[3]); run_conv_scale_bias_add_activation_mask(forward_state.outdimA1, forward_state.padA1, forward_state.convstrideA, forward_state.dilationA, forward_state.filterdimA2, forward_state.outdimA2, forward_state.threshdim, CUDNN_DATA_HALF, y1, w, y2, z, b, nullptr, thresholdTop.data_ptr(), thresholdBottom.data_ptr(), 2); // axis == 1 -> Does this assume explicit NHWC? 
DEBUG_MSG("[DEBUG] new relu2 : " << out2.to(at::kFloat).sum().item()); } void bottleneck_forward_out2_pad(bool explicit_nhwc, int stride_1X1, std::vector inputs, std::vector outputs, at::Tensor out1_pad) { std::cout << std::fixed; // from _out1 method at::Half* x = inputs[0].data_ptr(); auto out1 = outputs[0]; at::Half* y1 = out1_pad.data_ptr(); // run at::Half* w = inputs[2].data_ptr(); at::Half* z = inputs[5].data_ptr(); at::Half* b = inputs[8].data_ptr(); auto out2 = outputs[1]; at::Half* y2 = out2.data_ptr(); // printf("forward_state.outdimA1 = // {%d,%d,%d,%d}\n",forward_state.outdimA1[0],forward_state.outdimA1[1],forward_state.outdimA1[2],forward_state.outdimA1[3]); // printf("forward_state.padA1 = {%d,%d}\n",forward_state.padA1[0],forward_state.padA1[1]); // printf("forward_state.convstrideA = {%d,%d}\n",forward_state.convstrideA[0],forward_state.convstrideA[1]); // printf("forward_state.dilationA = {%d,%d}\n",forward_state.dilationA[0],forward_state.dilationA[1]); // printf("forward_state.filterdimA2 = // {%d,%d,%d,%d}\n",forward_state.filterdimA2[0],forward_state.filterdimA2[1],forward_state.filterdimA2[2],forward_state.filterdimA2[3]); // printf("forward_state.outdimA2 = // {%d,%d,%d,%d}\n",forward_state.outdimA2[0],forward_state.outdimA2[1],forward_state.outdimA2[2],forward_state.outdimA2[3]); run_conv_scale_bias_add_activation(forward_state.outdimA1b, forward_state.padA2, forward_state.convstrideA, forward_state.dilationA, forward_state.filterdimA2, forward_state.outdimA2, CUDNN_DATA_HALF, y1, w, y2, z, b, nullptr); DEBUG_MSG("[DEBUG] new relu2 : " << out2.to(at::kFloat).sum().item()); } void bottleneck_forward_rest(bool explicit_nhwc, int stride_1X1, std::vector inputs, std::vector outputs) { std::cout << std::fixed; // from _out1 method at::Half* x = inputs[0].data_ptr(); // create output of conv3 auto out3 = outputs[2]; at::Half* y3 = out3.data_ptr(); // create output of conv4 that may exist auto identity = at::empty_like(out3); at::Half* yi = 
identity.data_ptr(); at::Half *w, *z, *b; if (stride_1X1 != 1 || forward_state.filterdimA3[0] != forward_state.dimA[1]) { w = inputs[10].data_ptr(); z = inputs[11].data_ptr(); b = inputs[12].data_ptr(); run_conv_scale_bias(forward_state.dimA, forward_state.padA, forward_state.convstride1X1, forward_state.dilationA, forward_state.filterdimA4, forward_state.outdimA3, CUDNN_DATA_HALF, x, w, yi, z, b); DEBUG_MSG("[DEBUG] new downsample : " << identity.to(at::kFloat).sum().item()); } else { yi = x; } auto out2 = outputs[1]; at::Half* y2 = out2.data_ptr(); w = inputs[3].data_ptr(); z = inputs[6].data_ptr(); b = inputs[9].data_ptr(); run_conv_scale_bias_add_activation(forward_state.outdimA2, forward_state.padA, forward_state.convstrideA, forward_state.dilationA, forward_state.filterdimA3, forward_state.outdimA3, CUDNN_DATA_HALF, y2, w, y3, z, b, yi); DEBUG_MSG("[DEBUG] new relu3 : " << out3.to(at::kFloat).sum().item()); } namespace { struct bottleneck_backward_state { int64_t dimA[4]; int64_t filterdimA1[4]; int64_t filterdimA2[4]; int64_t filterdimA3[4]; int64_t filterdimA4[4]; int64_t filterdimA2hh[4]; // Cin,Cout,1,3 int64_t threshdim[4]; int axis[4]; int64_t outdimA1[4]; // grad_out1 int64_t outdimA1b[4]; // out1_pad int64_t outdimA2[4]; // grad_out2 int64_t outdimA3[4]; int64_t outdimA1h[4]; // output: grad_out1 halo (H=3) int64_t outdimA2h[4]; // input : grad_out2 halo cells (H=3) int64_t outdimA1hh[4]; // input: grad_out2 halo (H=1) int64_t outdimA2hh[4]; // input: out1 halo (H=1) int64_t padA[2]; int64_t padA1[2]; int64_t padA2[2]; int64_t dilationA[2]; int64_t convstrideA[2]; int64_t convstride1X1[2]; int64_t filterdim2hh[4]; // Cin,1,3,Cout int64_t outdim1[4]; int64_t outdim1b[4]; int64_t outdim2[4]; int64_t outdim3[4]; int64_t outdim1h[4]; int64_t outdim1hh[4]; void init(bool explicit_nhwc, int stride_1X1, std::vector inputs) { // setup dimensions dimA[0] = dimA[1] = dimA[2] = dimA[3] = 0; filterdimA1[0] = filterdimA1[1] = filterdimA1[2] = filterdimA1[3] = 0; 
filterdimA2[0] = filterdimA2[1] = filterdimA2[2] = filterdimA2[3] = 0; filterdimA3[0] = filterdimA3[1] = filterdimA3[2] = filterdimA3[3] = 0; filterdimA4[0] = filterdimA4[1] = filterdimA4[2] = filterdimA4[3] = 0; filterdimA2hh[0] = filterdimA2hh[1] = filterdimA2hh[2] = filterdimA2hh[3] = 0; threshdim[0] = threshdim[1] = threshdim[2] = threshdim[3] = 1; // All dim calculation after this order of n,c,h,w if (explicit_nhwc) { axis[0] = 0; axis[1] = 3; axis[2] = 1; axis[3] = 2; } else { axis[0] = 0; axis[1] = 1; axis[2] = 2; axis[3] = 3; } for (int dim = 0; dim < 4; dim++) { dimA[dim] = inputs[0].size(axis[dim]); filterdimA1[dim] = inputs[1].size(axis[dim]); filterdimA2[dim] = inputs[2].size(axis[dim]); filterdimA3[dim] = inputs[3].size(axis[dim]); } if (stride_1X1 != 1 || filterdimA3[0] != dimA[1]) { for (int dim = 0; dim < 4; dim++) { filterdimA4[dim] = inputs[14].size(axis[dim]); } } for (int dim = 0; dim < 4; dim++) { if (dim == 2) { filterdimA2hh[dim] = 1; } else { filterdimA2hh[dim] = filterdimA2[dim]; } } // output dim in n,c,h,w used by backend outdimA1[0] = outdimA1[1] = outdimA1[2] = outdimA1[3] = 0; outdimA1b[0] = outdimA1b[1] = outdimA1b[2] = outdimA1b[3] = 0; outdimA2[0] = outdimA2[1] = outdimA2[2] = outdimA2[3] = 0; outdimA3[0] = outdimA3[1] = outdimA3[2] = outdimA3[3] = 0; outdimA1h[0] = outdimA1h[1] = outdimA1h[2] = outdimA1h[3] = 0; outdimA2h[0] = outdimA2h[1] = outdimA2h[2] = outdimA2h[3] = 0; outdimA1hh[0] = outdimA1hh[1] = outdimA1hh[2] = outdimA1hh[3] = 0; outdimA2hh[0] = outdimA2hh[1] = outdimA2hh[2] = outdimA2hh[3] = 0; // use these fixed value for test run padA[0] = 0; padA[1] = 0; padA1[0] = 1; padA1[1] = 1; padA2[0] = 0; padA2[1] = 1; dilationA[0] = 1; dilationA[1] = 1; convstrideA[0] = 1; convstrideA[1] = 1; convstride1X1[0] = stride_1X1; convstride1X1[1] = stride_1X1; // compute output from pad/stride/dilation outdimA1[0] = dimA[0]; outdimA1[1] = filterdimA1[0]; for (int dim = 0; dim < 2; dim++) { outdimA1[dim + 2] = 
getFwdConvOutputDim(dimA[dim + 2], padA[dim], filterdimA1[dim + 2], convstride1X1[dim], dilationA[dim]); } for (int dim = 0; dim < 4; dim++) { if (dim == 2) { outdimA1b[dim] = outdimA1[dim] + 2; } else { outdimA1b[dim] = outdimA1[dim]; } } outdimA2[0] = outdimA1[0]; outdimA2[1] = filterdimA2[0]; for (int dim = 0; dim < 2; dim++) { outdimA2[dim + 2] = getFwdConvOutputDim(outdimA1[dim + 2], padA1[dim], filterdimA2[dim + 2], convstrideA[dim], dilationA[dim]); } outdimA3[0] = outdimA2[0]; outdimA3[1] = filterdimA3[0]; for (int dim = 0; dim < 2; dim++) { outdimA3[dim + 2] = getFwdConvOutputDim(outdimA2[dim + 2], padA[dim], filterdimA3[dim + 2], convstrideA[dim], dilationA[dim]); } for (int dim = 0; dim < 4; dim++) { if (dim == 2) { outdimA1h[dim] = 3; outdimA2h[dim] = 3; outdimA1hh[dim] = 1; outdimA2hh[dim] = 1; } else { outdimA1h[dim] = outdimA1[dim]; outdimA2h[dim] = outdimA2[dim]; outdimA1hh[dim] = outdimA1[dim]; outdimA2hh[dim] = outdimA2[dim]; } } // Create output tensor in the correct shape in pytorch's view outdim1[0] = outdim1[1] = outdim1[2] = outdim1[3] = 0; outdim1b[0] = outdim1b[1] = outdim1b[2] = outdim1b[3] = 0; outdim2[0] = outdim2[1] = outdim2[2] = outdim2[3] = 0; outdim3[0] = outdim3[1] = outdim3[2] = outdim3[3] = 0; outdim1h[0] = outdim1h[1] = outdim1h[2] = outdim1h[3] = 0; outdim1hh[0] = outdim1hh[1] = outdim1hh[2] = outdim1hh[3] = 0; filterdim2hh[0] = filterdim2hh[1] = filterdim2hh[2] = filterdim2hh[3] = 0; if (explicit_nhwc) { axis[0] = 0; axis[1] = 2; axis[2] = 3; axis[3] = 1; } for (int dim = 0; dim < 4; dim++) { outdim1[dim] = outdimA1[axis[dim]]; outdim1b[dim] = outdimA1b[axis[dim]]; outdim2[dim] = outdimA2[axis[dim]]; outdim3[dim] = outdimA3[axis[dim]]; outdim1h[dim] = outdimA1h[axis[dim]]; outdim1hh[dim] = outdimA1hh[axis[dim]]; filterdim2hh[dim] = filterdimA2hh[axis[dim]]; } } }; bottleneck_backward_state backward_state; } // namespace std::vector bottleneck_backward_init(bool explicit_nhwc, int stride_1X1, std::vector inputs) { std::cout << 
std::fixed; backward_state.init(explicit_nhwc, stride_1X1, inputs); // create output vector std::vector outputs; auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast; auto grad_x = at::empty_like(inputs[0]); auto wgrad1 = at::empty_like(inputs[1]); auto wgrad2 = at::empty_like(inputs[2]); auto wgrad3 = at::empty_like(inputs[3]); outputs.push_back(grad_x); outputs.push_back(wgrad1); outputs.push_back(wgrad2); outputs.push_back(wgrad3); if (stride_1X1 != 1 || backward_state.filterdimA3[0] != backward_state.dimA[1]) { auto wgrad4 = at::empty_like(inputs[14]); outputs.push_back(wgrad4); } return outputs; } void bottleneck_backward_wgrad3(bool explicit_nhwc, int stride_1X1, std::vector inputs, std::vector outputs) { // dconv3+drelu2+dscale2 at::Half* conv_in = inputs[13].data_ptr(); at::Half* dy3 = inputs[10].data_ptr(); // wgrad auto wgrad3 = outputs[3]; at::Half* dw3 = wgrad3.data_ptr(); run_dconv(backward_state.outdimA2, backward_state.padA, backward_state.convstrideA, backward_state.dilationA, backward_state.filterdimA3, backward_state.outdimA3, CUDNN_DATA_HALF, conv_in, dw3, dy3, CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR); DEBUG_MSG("[DEBUG] new wgrad3 : " << wgrad3.to(at::kFloat).sum().item()); } at::Tensor bottleneck_backward_grad_out2(bool explicit_nhwc, int stride_1X1, std::vector inputs, std::vector outputs) { bool requires_grad = inputs[0].requires_grad(); std::cout << std::fixed; auto output_format = explicit_nhwc ? 
at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast; // dconv3+drelu2+dscale2 at::Half* conv_in = inputs[13].data_ptr(); at::Half* dy3 = inputs[10].data_ptr(); DEBUG_MSG("[DEBUG] new dconv3 : " << inputs[10].to(at::kFloat).sum().item()); // dgrad auto grad_out2 = at::empty(backward_state.outdim2, inputs[0].type(), output_format); at::Half* dy2 = grad_out2.data_ptr(); at::Half* w = inputs[3].data_ptr(); at::Half* z = inputs[5].data_ptr(); at::Half* relu2 = inputs[13].data_ptr(); run_dconv_drelu_dscale(backward_state.outdimA2, backward_state.padA, backward_state.convstrideA, backward_state.dilationA, backward_state.filterdimA3, backward_state.outdimA3, CUDNN_DATA_HALF, dy2, w, dy3, z, relu2); // do halo exchange of dy2 here DEBUG_MSG("[DEBUG] new dconv2 : " << grad_out2.to(at::kFloat).sum().item()); return grad_out2; } at::Tensor bottleneck_backward_grad_out1(bool explicit_nhwc, int stride_1X1, std::vector inputs, std::vector outputs, at::Tensor grad_out2) { bool requires_grad = inputs[0].requires_grad(); std::cout << std::fixed; auto output_format = explicit_nhwc ? 
at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast; // dgrad at::Half* dy2 = grad_out2.data_ptr(); // dgrad auto grad_out1 = at::empty(backward_state.outdim1, inputs[0].type(), output_format); at::Half* dy1 = grad_out1.data_ptr(); at::Half* w = inputs[2].data_ptr(); at::Half* z = inputs[4].data_ptr(); at::Half* relu1 = inputs[12].data_ptr(); // printf("relu.shape = [%d,%d,%d,%d]\n",inputs[12].size(0),inputs[12].size(1),inputs[12].size(2),inputs[12].size(3)); // fused dgrad // printf("backward_state.outdim1 = // {%d,%d,%d,%d}\n",backward_state.outdim1[0],backward_state.outdim1[1],backward_state.outdim1[2],backward_state.outdim1[3]); run_dconv_drelu_dscale(backward_state.outdimA1, backward_state.padA1, backward_state.convstrideA, backward_state.dilationA, backward_state.filterdimA2, backward_state.outdimA2, CUDNN_DATA_HALF, dy1, w, dy2, z, relu1); return grad_out1; } at::Tensor bottleneck_backward_grad_out1_mask(bool explicit_nhwc, int stride_1X1, std::vector inputs, std::vector outputs, at::Tensor grad_out2, at::Tensor thresholdTop, at::Tensor thresholdBottom) { bool requires_grad = inputs[0].requires_grad(); std::cout << std::fixed; auto output_format = explicit_nhwc ? 
at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast; // dgrad at::Half* dy2 = grad_out2.data_ptr(); // dgrad auto grad_out1 = at::empty(backward_state.outdim1, inputs[0].type(), output_format); at::Half* dy1 = grad_out1.data_ptr(); at::Half* w = inputs[2].data_ptr(); at::Half* z = inputs[4].data_ptr(); at::Half* relu1 = inputs[12].data_ptr(); // printf("relu.shape = [%d,%d,%d,%d]\n",inputs[12].size(0),inputs[12].size(1),inputs[12].size(2),inputs[12].size(3)); // fused dgrad run_dconv_drelu_dscale_mask(backward_state.outdimA1, backward_state.padA1, backward_state.convstrideA, backward_state.dilationA, backward_state.filterdimA2, backward_state.outdimA2, backward_state.threshdim, CUDNN_DATA_HALF, dy1, w, dy2, z, relu1, thresholdTop.data_ptr(), thresholdBottom.data_ptr(), 2); return grad_out1; } // perform backward data 1x3 convolution (grad_out * w_rot180) on grad_out2 input of shape [N,1,W,C] with padding=(0,1) // to produce output of shape [N,1,W,C] at::Tensor bottleneck_backward_grad_out1_halo_corr(bool explicit_nhwc, int stride_1X1, std::vector inputs, at::Tensor w1by3, std::vector outputs, at::Tensor grad_out2_halo, at::Tensor relu1_halo, at::Tensor part_grad_out1) { bool requires_grad = inputs[0].requires_grad(); std::cout << std::fixed; auto output_format = explicit_nhwc ? 
at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast; // dgrad at::Half* dy2h = grad_out2_halo.data_ptr(); // dgrad auto grad_out1_halo = at::empty(backward_state.outdim1hh, inputs[0].type(), output_format); at::Half* dy1h = grad_out1_halo.data_ptr(); // at::Half* w = inputs[2].data_ptr(); // use w1by3 instead, which is a sliced version of inputs[2] at::Half* w = w1by3.data_ptr(); at::Half* z = inputs[4].data_ptr(); at::Half* relu1h = relu1_halo.data_ptr(); at::Half* pdy1h = part_grad_out1.data_ptr(); // printf("relu.shape = [%d,%d,%d,%d]\n",relu1_halo.size(0),relu1_halo.size(1),relu1_halo.size(2),relu1_halo.size(3)); // fused dgrad // printf("backward_state.outdimA1h = // {%d,%d,%d,%d}\n",backward_state.outdimA1h[0],backward_state.outdimA1h[1],backward_state.outdimA1h[2],backward_state.outdimA1h[3]); // printf("backward_state.outdimA2h = // {%d,%d,%d,%d}\n",backward_state.outdimA2h[0],backward_state.outdimA2h[1],backward_state.outdimA2h[2],backward_state.outdimA2h[3]); // printf("backward_state.filterdimA2 = // {%d,%d,%d,%d}\n",backward_state.filterdimA2[0],backward_state.filterdimA2[1],backward_state.filterdimA2[2],backward_state.filterdimA2[3]); run_dconv_add_drelu_dscale(backward_state.outdimA1hh, backward_state.padA2, // 0,1 backward_state.convstrideA, backward_state.dilationA, backward_state.filterdimA2hh, // C,1,3,C backward_state.outdimA2hh, CUDNN_DATA_HALF, dy1h, w, dy2h, z, relu1h, pdy1h); return grad_out1_halo; } // perform backward data 3x3 convolution (grad_out * w_rot180) on grad_out2 input of shape [N,3,W,C] with padding=(1,1) // to produce output of shape [N,3,W,C] at::Tensor bottleneck_backward_grad_out1_halo(bool explicit_nhwc, int stride_1X1, std::vector inputs, std::vector outputs, at::Tensor grad_out2_halo, at::Tensor relu1_halo) { bool requires_grad = inputs[0].requires_grad(); std::cout << std::fixed; auto output_format = explicit_nhwc ? 
at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast; // dgrad at::Half* dy2h = grad_out2_halo.data_ptr(); // dgrad auto grad_out1_halo = at::empty(backward_state.outdim1h, inputs[0].type(), output_format); at::Half* dy1h = grad_out1_halo.data_ptr(); at::Half* w = inputs[2].data_ptr(); at::Half* z = inputs[4].data_ptr(); at::Half* relu1h = relu1_halo.data_ptr(); // printf("relu.shape = [%d,%d,%d,%d]\n",relu1_halo.size(0),relu1_halo.size(1),relu1_halo.size(2),relu1_halo.size(3)); // fused dgrad // printf("backward_state.outdimA1h = // {%d,%d,%d,%d}\n",backward_state.outdimA1h[0],backward_state.outdimA1h[1],backward_state.outdimA1h[2],backward_state.outdimA1h[3]); // printf("backward_state.outdimA2h = // {%d,%d,%d,%d}\n",backward_state.outdimA2h[0],backward_state.outdimA2h[1],backward_state.outdimA2h[2],backward_state.outdimA2h[3]); // printf("backward_state.filterdimA2 = // {%d,%d,%d,%d}\n",backward_state.filterdimA2[0],backward_state.filterdimA2[1],backward_state.filterdimA2[2],backward_state.filterdimA2[3]); run_dconv_drelu_dscale(backward_state.outdimA1h, backward_state.padA1, backward_state.convstrideA, backward_state.dilationA, backward_state.filterdimA2, backward_state.outdimA2h, CUDNN_DATA_HALF, dy1h, w, dy2h, z, relu1h); return grad_out1_halo; } void bottleneck_backward_wgrad2_pad(bool explicit_nhwc, int stride_1X1, std::vector inputs, std::vector outputs, at::Tensor input, at::Tensor grad_out2) { std::cout << std::fixed; auto output_format = explicit_nhwc ? 
at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast; // dgrad at::Half* dy2 = grad_out2.data_ptr(); // dconv2+drelu1+dscale1 at::Half* conv_in = input.data_ptr(); // wgrad auto wgrad2 = outputs[2]; at::Half* dw2 = wgrad2.data_ptr(); // printf("outdimA1b = // (%d,%d,%d,%d)\n",backward_state.outdimA1b[0],backward_state.outdimA1b[1],backward_state.outdimA1b[2],backward_state.outdimA1b[3]); // printf("backward_state.padA2 = {%d,%d}\n",backward_state.padA2[0],backward_state.padA2[1]); run_dconv(backward_state.outdimA1b, // conv_in.shape (including H halos) backward_state.padA2, // 0, 1 backward_state.convstrideA, backward_state.dilationA, backward_state.filterdimA2, // dw2.shape backward_state.outdimA2, // dy2.shape CUDNN_DATA_HALF, conv_in, dw2, dy2, CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR); DEBUG_MSG("[DEBUG] new wgrad2 : " << wgrad2.to(at::kFloat).sum().item()); } void bottleneck_backward_wgrad2(bool explicit_nhwc, int stride_1X1, std::vector inputs, std::vector outputs, at::Tensor grad_out2) { bool requires_grad = inputs[0].requires_grad(); std::cout << std::fixed; auto output_format = explicit_nhwc ? 
at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast; // dgrad at::Half* dy2 = grad_out2.data_ptr(); // dconv2+drelu1+dscale1 at::Half* conv_in = inputs[12].data_ptr(); // wgrad auto wgrad2 = outputs[2]; at::Half* dw2 = wgrad2.data_ptr(); // printf("outdimA1 = // (%d,%d,%d,%d)\n",backward_state.outdimA1[0],backward_state.outdimA1[1],backward_state.outdimA1[2],backward_state.outdimA1[3]); run_dconv(backward_state.outdimA1, backward_state.padA1, backward_state.convstrideA, backward_state.dilationA, backward_state.filterdimA2, backward_state.outdimA2, CUDNN_DATA_HALF, conv_in, dw2, dy2, CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR); DEBUG_MSG("[DEBUG] new wgrad2 : " << wgrad2.to(at::kFloat).sum().item()); } // compute halo cells for input volume of dimension [N,1,W,C] with padding=(0,1) to produce output volume of dimension // [N,1,W,C] input and grad_out2_halo tensors are all of same shape output tensor is of shape [Cin,1,3,Cout] (regular // filter dims are [Cin,3,3,Cout] at::Tensor bottleneck_backward_wgrad2_halo(bool explicit_nhwc, int stride_1X1, std::vector inputs, std::vector outputs, at::Tensor input, at::Tensor grad_out2_halo) { bool requires_grad = inputs[0].requires_grad(); std::cout << std::fixed; auto output_format = explicit_nhwc ? 
at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast; // dgrad at::Half* dy2 = grad_out2_halo.data_ptr(); // dconv2+drelu1+dscale1 at::Half* conv_in = input.data_ptr(); // wgrad auto wgrad2_halo = at::empty(backward_state.filterdim2hh, input.type(), output_format); at::Half* dw2 = wgrad2_halo.data_ptr(); // printf("backward_state.outdimA1hh = // {%d,%d,%d,%d}\n",backward_state.outdimA1hh[0],backward_state.outdimA1hh[1],backward_state.outdimA1hh[2],backward_state.outdimA1hh[3]); // printf("backward_state.outdimA2hh = // {%d,%d,%d,%d}\n",backward_state.outdimA2hh[0],backward_state.outdimA2hh[1],backward_state.outdimA2hh[2],backward_state.outdimA2hh[3]); // printf("backward_state.filterdim2hh = // {%d,%d,%d,%d}\n",backward_state.filterdim2hh[0],backward_state.filterdim2hh[1],backward_state.filterdim2hh[2],backward_state.filterdim2hh[3]); // printf("backward_state.filterdimA2hh = // {%d,%d,%d,%d}\n",backward_state.filterdimA2hh[0],backward_state.filterdimA2hh[1],backward_state.filterdimA2hh[2],backward_state.filterdimA2hh[3]); // printf("backward_state.padA2 = {%d,%d}\n",backward_state.padA2[0],backward_state.padA2[1]); run_dconv(backward_state.outdimA1hh, // N,C,1,W backward_state.padA2, // 0, 1 backward_state.convstrideA, backward_state.dilationA, backward_state.filterdimA2hh, // Cin,Cout,1,3 backward_state.outdimA2hh, // N,C,1,W CUDNN_DATA_HALF, conv_in, dw2, dy2, CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR); return wgrad2_halo; } void bottleneck_backward_wgrad1(bool explicit_nhwc, int stride_1X1, std::vector inputs, std::vector outputs, at::Tensor grad_out1) { at::Half* x = inputs[0].data_ptr(); at::Half* dy1 = grad_out1.data_ptr(); // dconv1+add // wgrad auto wgrad1 = outputs[1]; at::Half* dw1 = wgrad1.data_ptr(); run_dconv(backward_state.dimA, backward_state.padA, backward_state.convstride1X1, backward_state.dilationA, backward_state.filterdimA1, backward_state.outdimA1, CUDNN_DATA_HALF, x, dw1, dy1, 
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR); } void bottleneck_backward_rest(bool explicit_nhwc, int stride_1X1, std::vector inputs, std::vector outputs, at::Tensor grad_out2, at::Tensor grad_out1) { bool requires_grad = inputs[0].requires_grad(); std::cout << std::fixed; auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast; // dgrad at::Half* dy2 = grad_out2.data_ptr(); at::Half* dy1 = grad_out1.data_ptr(); /* // backward strided conv cannot be fused // if stride == 1 but channel changes, we can fuse here if (stride_1X1 != 1){ // dgrad run_dconv(outdimA1, padA1, convstride1X1, dilationA, filterdimA2, outdimA2, CUDNN_DATA_HALF, dy1, w, dy2, CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR); // mul fused mask grad_out1.mul_(inputs[15]); } else { at::Half* relu1 = inputs[12].data_ptr(); // fused dgrad run_dconv_drelu_dscale(outdimA1, padA1, convstride1X1, dilationA, filterdimA2, outdimA2, CUDNN_DATA_HALF, dy1, w, dy2, z, relu1); } */ DEBUG_MSG("[DEBUG] new dconv1 : " << grad_out1.to(at::kFloat).sum().item()); // create grads of conv4 that may exist auto grad_x_conv4 = at::empty_like(inputs[0]); at::Half* dx_conv4 = grad_x_conv4.data_ptr(); at::Tensor wgrad4; // x used for dconv1 and dconv4 wgrad at::Half* x = inputs[0].data_ptr(); at::Half* w = NULL; if (stride_1X1 != 1 || backward_state.filterdimA3[0] != backward_state.dimA[1]) { w = inputs[14].data_ptr(); at::Half* dy_conv4 = inputs[11].data_ptr(); if (requires_grad) { run_dconv(backward_state.dimA, backward_state.padA, backward_state.convstride1X1, backward_state.dilationA, backward_state.filterdimA4, backward_state.outdimA3, CUDNN_DATA_HALF, dx_conv4, w, dy_conv4, CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR); // we don't print here since we can't hook out this grad in pytorch alone to compare, due to addition with dx // DEBUG_MSG("[DEBUG] new dx_identity : " << grad_x_conv4.to(at::kFloat).sum().item()); } // wgrad 
wgrad4 = outputs[4]; at::Half* dw4 = wgrad4.data_ptr(); run_dconv(backward_state.dimA, backward_state.padA, backward_state.convstride1X1, backward_state.dilationA, backward_state.filterdimA4, backward_state.outdimA3, CUDNN_DATA_HALF, x, dw4, dy_conv4, CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR); } else { // if there is no downsample, dx_conv4 is fork of drelu3 dx_conv4 = inputs[11].data_ptr(); } // dgrad w = inputs[1].data_ptr(); auto grad_x = outputs[0]; at::Half* dx = grad_x.data_ptr(); // backward strided conv cannot be fused // if stride == 1 but channel changes, we can fuse here if (requires_grad) { if (stride_1X1 != 1) { run_dconv(backward_state.dimA, backward_state.padA, backward_state.convstride1X1, backward_state.dilationA, backward_state.filterdimA1, backward_state.outdimA1, CUDNN_DATA_HALF, dx, w, dy1, CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR); // add 2 together grad_x.add_(grad_x_conv4); } else { run_dconv_add(backward_state.dimA, backward_state.padA, backward_state.convstride1X1, backward_state.dilationA, backward_state.filterdimA1, backward_state.outdimA1, CUDNN_DATA_HALF, dx, w, dy1, dx_conv4); } } DEBUG_MSG("[DEBUG] new dx : " << grad_x.to(at::kFloat).sum().item()); DEBUG_MSG("[DEBUG] new wgrad1 : " << wgrad1.to(at::kFloat).sum().item()); if (stride_1X1 != 1 || backward_state.filterdimA3[0] != backward_state.dimA[1]) { DEBUG_MSG("[DEBUG] new wgrad4 : " << wgrad4.to(at::kFloat).sum().item()); } } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("forward", &bottleneck_forward, "Bottleneck block forward", py::call_guard()); m.def("backward", &bottleneck_backward, "Bottleneck block backward", py::call_guard()); m.def("forward_init", &bottleneck_forward_init, "Bottleneck block init", py::call_guard()); m.def("forward_out1", &bottleneck_forward_out1, "Bottleneck block forward", py::call_guard()); m.def("forward_out2", &bottleneck_forward_out2, "Bottleneck block forward", py::call_guard()); m.def("forward_out2_mask", 
&bottleneck_forward_out2_mask, "Bottleneck block forward", py::call_guard()); m.def("forward_out2_halo", &bottleneck_forward_out2_halo, "Bottleneck block forward", py::call_guard()); m.def("forward_out2_halo_corr", &bottleneck_forward_out2_halo_corr, "Bottleneck block forward", py::call_guard()); m.def("forward_out2_pad", &bottleneck_forward_out2_pad, "Bottleneck block forward", py::call_guard()); m.def("forward_rest", &bottleneck_forward_rest, "Bottleneck block forward", py::call_guard()); m.def("backward_init", &bottleneck_backward_init, "Bottleneck block backward init", py::call_guard()); m.def("backward_grad_out2", &bottleneck_backward_grad_out2, "Bottleneck block backward", py::call_guard()); m.def("backward_grad_out1", &bottleneck_backward_grad_out1, "Bottleneck block backward", py::call_guard()); m.def("backward_grad_out1_mask", &bottleneck_backward_grad_out1_mask, "Bottleneck block backward", py::call_guard()); m.def("backward_grad_out1_halo", &bottleneck_backward_grad_out1_halo, "Bottleneck block backward", py::call_guard()); m.def("backward_grad_out1_halo_corr", &bottleneck_backward_grad_out1_halo_corr, "Bottleneck block backward", py::call_guard()); m.def("backward_wgrad2_pad", &bottleneck_backward_wgrad2_pad, "Bottleneck block backward", py::call_guard()); m.def("backward_wgrad2", &bottleneck_backward_wgrad2, "Bottleneck block backward", py::call_guard()); m.def("backward_wgrad2_halo", &bottleneck_backward_wgrad2_halo, "Bottleneck block backward", py::call_guard()); m.def("backward_wgrad3", &bottleneck_backward_wgrad3, "Bottleneck block backward", py::call_guard()); m.def("backward_wgrad1", &bottleneck_backward_wgrad1, "Bottleneck block backward", py::call_guard()); m.def("backward_rest", &bottleneck_backward_rest, "Bottleneck block backward", py::call_guard()); } ================================================ FILE: apex/contrib/csrc/conv_bias_relu/conv_bias_relu.cpp ================================================ #include #include // for 
getcudnnhandle #include #include #include #include #include #ifdef DEBUG #define DEBUG_MSG(str) \ do { \ std::cout << str << std::endl; \ } while (false) #else #define DEBUG_MSG(str) \ do { \ } while (false) #endif #ifdef DEBUG_CUDNN #define DEBUG_CUDNN_MSG(buf, str) \ do { \ buf << str << std::endl; \ } while (false) #else #define DEBUG_CUDNN_MSG(buf, str) \ do { \ } while (false) #endif #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(at::MemoryFormat::ChannelsLast), #x " must be contiguous") #define CHECK_INPUT(x) \ CHECK_CUDA(x); \ CHECK_CONTIGUOUS(x) #define checkCudnnErr(...) \ do { \ int err = checkCudnnError(__VA_ARGS__, #__VA_ARGS__, __FILE__, __LINE__); \ if (err) { \ return; \ } \ } while (0) int checkCudnnError(cudnnStatus_t code, const char* expr, const char* file, int line) { if (code) { printf("CUDNN error at %s:%d, code=%d (%s) in '%s'\n", file, line, (int)code, cudnnGetErrorString(code), expr); return 1; } return 0; } void checkError(cudaError_t code, char const* func, const char* file, const int line, bool abort = true); #define checkCUDAError(val) \ { \ checkError((val), #val, __FILE__, __LINE__); \ } // in-line regular function void checkError(cudaError_t code, char const* func, const char* file, const int line, bool abort) { if (code != cudaSuccess) { const char* errorMessage = cudaGetErrorString(code); fprintf(stderr, "CUDA error returned from \"%s\" at %s:%d, Error code: %d (%s)\n", func, file, line, code, errorMessage); if (abort) { cudaDeviceReset(); exit(code); } } } void generateStrides(const int64_t* dimA, int64_t* strideA, int nbDims, cudnnTensorFormat_t filterFormat) { // For INT8x4 and INT8x32 we still compute standard strides here to input // into the cuDNN functions. We will manually scale by resizeFactor in the cpu ref. 
if (filterFormat == CUDNN_TENSOR_NCHW) { strideA[nbDims - 1] = 1; for (int64_t d = nbDims - 2; d >= 0; d--) { strideA[d] = strideA[d + 1] * dimA[d + 1]; } } else { // Here we assume that the format is CUDNN_TENSOR_NHWC strideA[1] = 1; strideA[nbDims - 1] = strideA[1] * dimA[1]; for (int64_t d = nbDims - 2; d >= 2; d--) { strideA[d] = strideA[d + 1] * dimA[d + 1]; } strideA[0] = strideA[2] * dimA[2]; } } int getFwdConvDilatedFilterDim(int filterDim, int dilation) { return ((filterDim - 1) * dilation) + 1; } int getFwdConvPaddedImageDim(int tensorDim, int pad) { return tensorDim + (2 * pad); } int getFwdConvOutputDim(int tensorDim, int pad, int filterDim, int stride, int dilation) { int p = (getFwdConvPaddedImageDim(tensorDim, pad) - getFwdConvDilatedFilterDim(filterDim, dilation)) / stride + 1; return (p); } // create a cache for plan std::unordered_map plan_cache; std::string getConvFusionString(int64_t* x_dim_padded, int64_t* padA, int64_t* convstrideA, int64_t* dilationA, int64_t* w_dim_padded, cudnnDataType_t dataType, std::string fusion_string) { for (int i = 0; i < 4; i++) { fusion_string += 'X'; fusion_string += std::to_string(x_dim_padded[i]); } for (int i = 0; i < 4; i++) { fusion_string += 'W'; fusion_string += std::to_string(w_dim_padded[i]); } for (int i = 0; i < 2; i++) { fusion_string += 'P'; fusion_string += std::to_string(padA[i]); } for (int i = 0; i < 2; i++) { fusion_string += 'S'; fusion_string += std::to_string(convstrideA[i]); } for (int i = 0; i < 2; i++) { fusion_string += 'D'; fusion_string += std::to_string(dilationA[i]); } fusion_string += 'T'; fusion_string += std::to_string(dataType); return fusion_string; } cudnn_frontend::ExecutionPlan& getOrCreatePlan(cudnnHandle_t handle_, std::stringstream& log_buf, cudnn_frontend::OperationGraph& opGraph, std::string cache_string, bool use_heuristic = true) { auto it = plan_cache.find(cache_string); if (it != plan_cache.end()) { DEBUG_CUDNN_MSG(log_buf, "Found plan in cache"); return it->second; } 
else { DEBUG_CUDNN_MSG(log_buf, "No plan in cache"); if (use_heuristic) { // TODO: confirm which mode to use auto heuristics = cudnn_frontend::EngineHeuristicsBuilder() .setOperationGraph(opGraph) .setHeurMode(CUDNN_HEUR_MODE_INSTANT) .build(); auto engine_config_count = heuristics.getEngineConfigCount(); auto& engine_configs = heuristics.getEngineConfig(engine_config_count); for (int64_t count = 0; count < engine_config_count; count++) { try { plan_cache.emplace(cache_string, std::move(cudnn_frontend::ExecutionPlanBuilder() .setHandle(handle_) .setEngineConfig(engine_configs[count], opGraph.getTag()) .build())); break; } catch (cudnn_frontend::cudnnException e) { // Throw exception if all engines failed if (count == (engine_config_count - 1)) { throw e; } else { continue; } } } } else { // How many engines support this operation graph ? auto total_engines = opGraph.getEngineCount(); DEBUG_CUDNN_MSG(log_buf, opGraph.describe() << " has " << total_engines << " engines."); // We have to randomly pick one engine from [0, total_engines) // Selecting "0" by default auto engine = cudnn_frontend::EngineBuilder().setGlobalEngineIdx(0).setOperationGraph(opGraph).build(); DEBUG_CUDNN_MSG(log_buf, engine.describe()); auto& knobs = engine.getSupportedKnobs(); for (auto it = std::begin(knobs); it != std::end(knobs); ++it) { DEBUG_CUDNN_MSG(log_buf, it->describe()); } if (knobs.begin() != knobs.end()) { DEBUG_CUDNN_MSG(log_buf, "Updated knob choice"); knobs.begin()->setChoice(knobs.begin()->getMinValue() + 1); DEBUG_CUDNN_MSG(log_buf, knobs.begin()->describe()); } // Createmplacee the requisite engine config auto engine_config = cudnn_frontend::EngineConfigBuilder().setEngine(engine).build(); DEBUG_CUDNN_MSG(log_buf, engine_config.describe()); plan_cache.emplace( cache_string, std::move(cudnn_frontend::ExecutionPlanBuilder().setHandle(handle_).setEngineConfig(engine_config).build())); } return plan_cache.find(cache_string)->second; } } void run_conv_bias(int64_t* x_dim, int64_t* 
w_dim, int64_t* y_dim, int64_t* conv_pad, int64_t* convstride, int64_t* dilation, cudnnDataType_t dataType, at::Half* devPtrX, at::Half* devPtrW, at::Half* devPtrB, at::Half* devPtrY) { cudnnHandle_t handle_ = torch::native::getCudnnHandle(); std::stringstream log_buf; try { int convDim = 2; float alpha = 1.0f; float beta = 0.0f; int64_t b_dim[] = {1, y_dim[1], 1, 1}; // Creates the necessary tensor descriptors int64_t stride[4]; generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC); auto xTensor = cudnn_frontend::TensorBuilder() .setDim(4, x_dim) .setStrides(4, stride) .setId('x') .setAlignment(16) .setDataType(dataType) .build(); DEBUG_CUDNN_MSG(log_buf, xTensor.describe()); generateStrides(w_dim, stride, 4, CUDNN_TENSOR_NHWC); auto wTensor = cudnn_frontend::TensorBuilder() .setDim(4, w_dim) .setStrides(4, stride) .setId('w') .setAlignment(16) .setDataType(dataType) .build(); DEBUG_CUDNN_MSG(log_buf, wTensor.describe()); generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC); auto afterConvTensor = cudnn_frontend::TensorBuilder() .setDim(4, y_dim) .setStrides(4, stride) .setId('c') .setAlignment(16) .setDataType(CUDNN_DATA_FLOAT) .setVirtual() .build(); DEBUG_CUDNN_MSG(log_buf, afterConvTensor.describe()); generateStrides(b_dim, stride, 4, CUDNN_TENSOR_NHWC); auto bTensor = cudnn_frontend::TensorBuilder() .setDim(4, b_dim) .setStrides(4, stride) .setId('b') .setAlignment(16) .setDataType(dataType) .build(); DEBUG_CUDNN_MSG(log_buf, bTensor.describe()); generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC); auto afterBiasTensor = cudnn_frontend::TensorBuilder() .setDim(4, y_dim) .setStrides(4, stride) .setId('y') .setAlignment(16) .setDataType(dataType) .build(); DEBUG_CUDNN_MSG(log_buf, afterBiasTensor.describe()); // Define the bias operation auto biasDesc = cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_ADD).setMathPrecision(CUDNN_DATA_FLOAT).build(); DEBUG_CUDNN_MSG(log_buf, biasDesc.describe()); // Define the convolution problem auto convDesc = 
cudnn_frontend::ConvDescBuilder() .setDataType(CUDNN_DATA_FLOAT) .setMathMode(CUDNN_CROSS_CORRELATION) .setNDims(convDim) .setStrides(convDim, convstride) .setPrePadding(convDim, conv_pad) .setPostPadding(convDim, conv_pad) .setDilation(convDim, dilation) .build(); DEBUG_CUDNN_MSG(log_buf, convDesc.describe()); // Create a convolution Node auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR) .setxDesc(xTensor) .setwDesc(wTensor) .setyDesc(afterConvTensor) .setcDesc(convDesc) .setAlpha(alpha) .setBeta(beta) .build(); DEBUG_CUDNN_MSG(log_buf, conv_op.describe()); // Create a Bias Node. auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) .setxDesc(conv_op.getOutputTensor()) .setbDesc(bTensor) .setyDesc(afterBiasTensor) .setpwDesc(biasDesc) .build(); DEBUG_CUDNN_MSG(log_buf, bias_op.describe()); // Create an Operation Graph. In this case it is convolution bias activation std::array ops = {&conv_op, &bias_op}; auto opGraph = cudnn_frontend::OperationGraphBuilder().setHandle(handle_).setOperationGraph(2, ops.data()).build(); // Create string encoding for plan caching auto cache_string = getConvFusionString(x_dim, conv_pad, convstride, dilation, w_dim, dataType, opGraph.getTag()); DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); auto workspace_size = plan.getWorkspaceSize(); DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); void* workspace_ptr = nullptr; auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); if (workspace_size > 0) { workspace_ptr = workspace_tensor.data_ptr(); } void* data_ptrs[] = {devPtrX, devPtrW, devPtrB, devPtrY}; int64_t uids[] = {'x', 'w', 'b', 'y'}; auto variantPack = cudnn_frontend::VariantPackBuilder() 
.setWorkspacePointer(workspace_ptr) .setDataPointers(4, data_ptrs) .setUids(4, uids) .build(); DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); checkCudnnErr(status); cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); } catch (cudnn_frontend::cudnnException e) { std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; } } void run_conv_bias_mask_relu(int64_t* x_dim, int64_t* w_dim, int64_t* y_dim, int64_t* conv_pad, int64_t* conv_stride, int64_t* conv_dilation, cudnnDataType_t dataType, at::Half* devPtrX, at::Half* devPtrW, at::Half* devPtrB, int8_t* devPtrM, at::Half* devPtrY) { cudnnHandle_t handle_ = torch::native::getCudnnHandle(); std::stringstream log_buf; try { int conv_dim = 2; float alpha = 1.0f; float beta = 0.0f; int64_t b_dim[] = {1, y_dim[1], 1, 1}; // Creates the necessary tensor descriptors int64_t stride[4]; generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC); auto xTensor = cudnn_frontend::TensorBuilder() .setDim(4, x_dim) .setStrides(4, stride) .setId('x') .setAlignment(16) .setDataType(dataType) .build(); DEBUG_CUDNN_MSG(log_buf, xTensor.describe()); generateStrides(w_dim, stride, 4, CUDNN_TENSOR_NHWC); auto wTensor = cudnn_frontend::TensorBuilder() .setDim(4, w_dim) .setStrides(4, stride) .setId('w') .setAlignment(16) .setDataType(dataType) .build(); DEBUG_CUDNN_MSG(log_buf, wTensor.describe()); generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC); auto mTensor = cudnn_frontend::TensorBuilder() .setDim(4, y_dim) .setStrides(4, stride) .setId('m') .setAlignment(16) .setDataType(CUDNN_DATA_INT8) .build(); DEBUG_CUDNN_MSG(log_buf, wTensor.describe()); generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC); auto afterConvTensor = cudnn_frontend::TensorBuilder() .setDim(4, y_dim) .setStrides(4, stride) .setId('c') .setAlignment(16) 
.setDataType(CUDNN_DATA_FLOAT) .setVirtual() .build(); DEBUG_CUDNN_MSG(log_buf, afterConvTensor.describe()); generateStrides(b_dim, stride, 4, CUDNN_TENSOR_NHWC); auto bTensor = cudnn_frontend::TensorBuilder() .setDim(4, b_dim) .setStrides(4, stride) .setId('b') .setAlignment(16) .setDataType(dataType) .build(); DEBUG_CUDNN_MSG(log_buf, bTensor.describe()); generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC); auto afterBiasTensor = cudnn_frontend::TensorBuilder() .setDim(4, y_dim) .setStrides(4, stride) .setId('B') .setAlignment(16) .setDataType(CUDNN_DATA_FLOAT) .setVirtual() .build(); DEBUG_CUDNN_MSG(log_buf, afterBiasTensor.describe()); generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC); auto afterMaskTensor = cudnn_frontend::TensorBuilder() .setDim(4, y_dim) .setStrides(4, stride) .setId('M') .setAlignment(16) .setDataType(CUDNN_DATA_FLOAT) .setVirtual() .build(); DEBUG_CUDNN_MSG(log_buf, afterBiasTensor.describe()); generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC); auto afterReLUTensor = cudnn_frontend::TensorBuilder() .setDim(4, y_dim) .setStrides(4, stride) .setId('y') .setAlignment(16) .setDataType(dataType) .build(); DEBUG_CUDNN_MSG(log_buf, afterReLUTensor.describe()); // Define the convolution problem auto convDesc = cudnn_frontend::ConvDescBuilder() .setDataType(CUDNN_DATA_FLOAT) .setMathMode(CUDNN_CROSS_CORRELATION) .setNDims(conv_dim) .setStrides(conv_dim, conv_stride) .setPrePadding(conv_dim, conv_pad) .setPostPadding(conv_dim, conv_pad) .setDilation(conv_dim, conv_dilation) .build(); DEBUG_CUDNN_MSG(log_buf, convDesc.describe()); // Define the bias operation auto biasDesc = cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_ADD).setMathPrecision(CUDNN_DATA_FLOAT).build(); DEBUG_CUDNN_MSG(log_buf, biasDesc.describe()); // Define the mask operation auto maskDesc = cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_MUL).setMathPrecision(CUDNN_DATA_FLOAT).build(); // Define the activation operation auto actDesc = 
cudnn_frontend::PointWiseDescBuilder() .setMode(CUDNN_POINTWISE_RELU_FWD) .setMathPrecision(CUDNN_DATA_FLOAT) .build(); DEBUG_CUDNN_MSG(log_buf, actDesc.describe()); // Create a convolution Node auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR) .setxDesc(xTensor) .setwDesc(wTensor) .setyDesc(afterConvTensor) .setcDesc(convDesc) .setAlpha(alpha) .setBeta(beta) .build(); DEBUG_CUDNN_MSG(log_buf, conv_op.describe()); // Create a Bias Node auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) .setxDesc(conv_op.getOutputTensor()) .setbDesc(bTensor) .setyDesc(afterBiasTensor) .setpwDesc(biasDesc) .build(); DEBUG_CUDNN_MSG(log_buf, bias_op.describe()); // create a Mask Node auto mask_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) .setxDesc(bias_op.getOutputTensor()) .setbDesc(mTensor) .setyDesc(afterMaskTensor) .setpwDesc(maskDesc) .build(); DEBUG_CUDNN_MSG(log_buf, mask_op.describe()); // Create an Activation Node auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) .setxDesc(mask_op.getOutputTensor()) .setyDesc(afterReLUTensor) .setpwDesc(actDesc) .build(); DEBUG_CUDNN_MSG(log_buf, act_op.describe()); // Create an Operation Graph. 
In this case it is convolution bias activation std::array ops = {&conv_op, &bias_op, &mask_op, &act_op}; auto opGraph = cudnn_frontend::OperationGraphBuilder().setHandle(handle_).setOperationGraph(4, ops.data()).build(); // Create string encoding for plan caching auto cache_string = getConvFusionString(x_dim, conv_pad, conv_stride, conv_dilation, w_dim, dataType, opGraph.getTag()); DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); auto workspace_size = plan.getWorkspaceSize(); DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); void* workspace_ptr = nullptr; auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); if (workspace_size > 0) { workspace_ptr = workspace_tensor.data_ptr(); } void* data_ptrs[] = {devPtrX, devPtrW, devPtrB, devPtrM, devPtrY}; int64_t uids[] = {'x', 'w', 'b', 'm', 'y'}; auto variantPack = cudnn_frontend::VariantPackBuilder() .setWorkspacePointer(workspace_ptr) .setDataPointers(5, data_ptrs) .setUids(5, uids) .build(); DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); checkCudnnErr(status); cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); } catch (cudnn_frontend::cudnnException e) { std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; } } void run_conv_cscale_cbias_relu(int64_t* x_dim, int64_t* w_dim, int64_t* y_dim, int64_t* conv_pad, int64_t* conv_stride, int64_t* conv_dilation, cudnnDataType_t dataType, at::Half* devPtrX, at::Half* devPtrW, at::Half* devPtrS, at::Half* devPtrB, at::Half* devPtrY) { cudnnHandle_t handle_ = torch::native::getCudnnHandle(); std::stringstream log_buf; try { int conv_dim = 2; float 
alpha = 1.0f; float beta = 0.0f; int64_t s_dim[] = {1, y_dim[1], 1, 1}; int64_t b_dim[] = {1, y_dim[1], 1, 1}; // Creates the necessary tensor descriptors int64_t stride[4]; generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC); auto xTensor = cudnn_frontend::TensorBuilder() .setDim(4, x_dim) .setStrides(4, stride) .setId('x') .setAlignment(16) .setDataType(dataType) .build(); DEBUG_CUDNN_MSG(log_buf, xTensor.describe()); generateStrides(w_dim, stride, 4, CUDNN_TENSOR_NHWC); auto wTensor = cudnn_frontend::TensorBuilder() .setDim(4, w_dim) .setStrides(4, stride) .setId('w') .setAlignment(16) .setDataType(dataType) .build(); DEBUG_CUDNN_MSG(log_buf, wTensor.describe()); generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC); auto afterConvTensor = cudnn_frontend::TensorBuilder() .setDim(4, y_dim) .setStrides(4, stride) .setId('c') .setAlignment(16) .setDataType(CUDNN_DATA_FLOAT) .setVirtual() .build(); DEBUG_CUDNN_MSG(log_buf, afterConvTensor.describe()); generateStrides(s_dim, stride, 4, CUDNN_TENSOR_NHWC); auto sTensor = cudnn_frontend::TensorBuilder() .setDim(4, s_dim) .setStrides(4, stride) .setId('s') .setAlignment(16) .setDataType(dataType) .build(); DEBUG_CUDNN_MSG(log_buf, sTensor.describe()); generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC); auto afterScaleTensor = cudnn_frontend::TensorBuilder() .setDim(4, y_dim) .setStrides(4, stride) .setId('S') .setAlignment(16) .setDataType(CUDNN_DATA_FLOAT) .setVirtual() .build(); DEBUG_CUDNN_MSG(log_buf, afterScaleTensor.describe()); generateStrides(b_dim, stride, 4, CUDNN_TENSOR_NHWC); auto bTensor = cudnn_frontend::TensorBuilder() .setDim(4, b_dim) .setStrides(4, stride) .setId('b') .setAlignment(16) .setDataType(dataType) .build(); DEBUG_CUDNN_MSG(log_buf, bTensor.describe()); generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC); auto afterBiasTensor = cudnn_frontend::TensorBuilder() .setDim(4, y_dim) .setStrides(4, stride) .setId('B') .setAlignment(16) .setDataType(CUDNN_DATA_FLOAT) .setVirtual() .build(); 
DEBUG_CUDNN_MSG(log_buf, afterBiasTensor.describe()); generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC); auto afterReLUTensor = cudnn_frontend::TensorBuilder() .setDim(4, y_dim) .setStrides(4, stride) .setId('y') .setAlignment(16) .setDataType(dataType) .build(); DEBUG_CUDNN_MSG(log_buf, afterReLUTensor.describe()); // Define the convolution problem auto convDesc = cudnn_frontend::ConvDescBuilder() .setDataType(CUDNN_DATA_FLOAT) .setMathMode(CUDNN_CROSS_CORRELATION) .setNDims(conv_dim) .setStrides(conv_dim, conv_stride) .setPrePadding(conv_dim, conv_pad) .setPostPadding(conv_dim, conv_pad) .setDilation(conv_dim, conv_dilation) .build(); DEBUG_CUDNN_MSG(log_buf, convDesc.describe()); // Define the scale operation auto scaleDesc = cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_MUL).setMathPrecision(CUDNN_DATA_FLOAT).build(); DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe()); // Define the bias operation auto biasDesc = cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_ADD).setMathPrecision(CUDNN_DATA_FLOAT).build(); DEBUG_CUDNN_MSG(log_buf, biasDesc.describe()); // Define the activation operation auto actDesc = cudnn_frontend::PointWiseDescBuilder() .setMode(CUDNN_POINTWISE_RELU_FWD) .setMathPrecision(CUDNN_DATA_FLOAT) .build(); DEBUG_CUDNN_MSG(log_buf, actDesc.describe()); // Create a convolution Node auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR) .setxDesc(xTensor) .setwDesc(wTensor) .setyDesc(afterConvTensor) .setcDesc(convDesc) .setAlpha(alpha) .setBeta(beta) .build(); DEBUG_CUDNN_MSG(log_buf, conv_op.describe()); // Create a scale Node. auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) .setxDesc(conv_op.getOutputTensor()) .setbDesc(sTensor) .setyDesc(afterScaleTensor) .setpwDesc(scaleDesc) .build(); DEBUG_CUDNN_MSG(log_buf, scale_op.describe()); // Create a Bias Node. 
auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) .setxDesc(scale_op.getOutputTensor()) .setbDesc(bTensor) .setyDesc(afterBiasTensor) .setpwDesc(biasDesc) .build(); DEBUG_CUDNN_MSG(log_buf, bias_op.describe()); // Create an Activation Node. auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) .setxDesc(bias_op.getOutputTensor()) .setyDesc(afterReLUTensor) .setpwDesc(actDesc) .build(); DEBUG_CUDNN_MSG(log_buf, act_op.describe()); // Create an Operation Graph. In this case it is convolution bias activation std::array ops = {&conv_op, &scale_op, &bias_op, &act_op}; auto opGraph = cudnn_frontend::OperationGraphBuilder().setHandle(handle_).setOperationGraph(ops.size(), ops.data()).build(); // Create string encoding for plan caching auto cache_string = getConvFusionString(x_dim, conv_pad, conv_stride, conv_dilation, w_dim, dataType, opGraph.getTag()); DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); auto workspace_size = plan.getWorkspaceSize(); DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); void* workspace_ptr = nullptr; auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); if (workspace_size > 0) { workspace_ptr = workspace_tensor.data_ptr(); } void* data_ptrs[] = {devPtrX, devPtrW, devPtrS, devPtrB, devPtrY}; int64_t uids[] = {'x', 'w', 's', 'b', 'y'}; auto variantPack = cudnn_frontend::VariantPackBuilder() .setWorkspacePointer(workspace_ptr) .setDataPointers(5, data_ptrs) .setUids(5, uids) .build(); DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); checkCudnnErr(status); cudnn_frontend::throw_if([status]() { return (status != 
CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); } catch (cudnn_frontend::cudnnException e) { std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; } } void run_conv_bias_relu(int64_t* x_dim, int64_t* w_dim, int64_t* y_dim, int64_t* conv_pad, int64_t* conv_stride, int64_t* conv_dilation, cudnnDataType_t dataType, at::Half* devPtrX, at::Half* devPtrW, at::Half* devPtrB, at::Half* devPtrY) { cudnnHandle_t handle_ = torch::native::getCudnnHandle(); std::stringstream log_buf; try { int conv_dim = 2; float alpha = 1.0f; float beta = 0.0f; int64_t b_dim[] = {1, y_dim[1], 1, 1}; // Creates the necessary tensor descriptors int64_t stride[4]; generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC); auto xTensor = cudnn_frontend::TensorBuilder() .setDim(4, x_dim) .setStrides(4, stride) .setId('x') .setAlignment(16) .setDataType(dataType) .build(); DEBUG_CUDNN_MSG(log_buf, xTensor.describe()); generateStrides(w_dim, stride, 4, CUDNN_TENSOR_NHWC); auto wTensor = cudnn_frontend::TensorBuilder() .setDim(4, w_dim) .setStrides(4, stride) .setId('w') .setAlignment(16) .setDataType(dataType) .build(); DEBUG_CUDNN_MSG(log_buf, wTensor.describe()); generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC); auto afterConvTensor = cudnn_frontend::TensorBuilder() .setDim(4, y_dim) .setStrides(4, stride) .setId('c') .setAlignment(16) .setDataType(CUDNN_DATA_FLOAT) .setVirtual() .build(); DEBUG_CUDNN_MSG(log_buf, afterConvTensor.describe()); generateStrides(b_dim, stride, 4, CUDNN_TENSOR_NHWC); auto bTensor = cudnn_frontend::TensorBuilder() .setDim(4, b_dim) .setStrides(4, stride) .setId('b') .setAlignment(16) .setDataType(dataType) .build(); DEBUG_CUDNN_MSG(log_buf, bTensor.describe()); generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC); auto afterBiasTensor = cudnn_frontend::TensorBuilder() .setDim(4, y_dim) .setStrides(4, stride) .setId('B') .setAlignment(16) .setDataType(CUDNN_DATA_FLOAT) .setVirtual() .build(); DEBUG_CUDNN_MSG(log_buf, 
afterBiasTensor.describe()); generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC); auto afterReLUTensor = cudnn_frontend::TensorBuilder() .setDim(4, y_dim) .setStrides(4, stride) .setId('y') .setAlignment(16) .setDataType(dataType) .build(); DEBUG_CUDNN_MSG(log_buf, afterReLUTensor.describe()); // Define the convolution problem auto convDesc = cudnn_frontend::ConvDescBuilder() .setDataType(CUDNN_DATA_FLOAT) .setMathMode(CUDNN_CROSS_CORRELATION) .setNDims(conv_dim) .setStrides(conv_dim, conv_stride) .setPrePadding(conv_dim, conv_pad) .setPostPadding(conv_dim, conv_pad) .setDilation(conv_dim, conv_dilation) .build(); DEBUG_CUDNN_MSG(log_buf, convDesc.describe()); // Define the bias operation auto biasDesc = cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_ADD).setMathPrecision(CUDNN_DATA_FLOAT).build(); DEBUG_CUDNN_MSG(log_buf, biasDesc.describe()); // Define the activation operation auto actDesc = cudnn_frontend::PointWiseDescBuilder() .setMode(CUDNN_POINTWISE_RELU_FWD) .setMathPrecision(CUDNN_DATA_FLOAT) .build(); DEBUG_CUDNN_MSG(log_buf, actDesc.describe()); // Create a convolution Node auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR) .setxDesc(xTensor) .setwDesc(wTensor) .setyDesc(afterConvTensor) .setcDesc(convDesc) .setAlpha(alpha) .setBeta(beta) .build(); DEBUG_CUDNN_MSG(log_buf, conv_op.describe()); // Create a Bias Node. auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) .setxDesc(conv_op.getOutputTensor()) .setbDesc(bTensor) .setyDesc(afterBiasTensor) .setpwDesc(biasDesc) .build(); DEBUG_CUDNN_MSG(log_buf, bias_op.describe()); // Create an Activation Node. auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) .setxDesc(bias_op.getOutputTensor()) .setyDesc(afterReLUTensor) .setpwDesc(actDesc) .build(); DEBUG_CUDNN_MSG(log_buf, act_op.describe()); // Create an Operation Graph. 
In this case it is convolution bias activation std::array ops = {&conv_op, &bias_op, &act_op}; auto opGraph = cudnn_frontend::OperationGraphBuilder().setHandle(handle_).setOperationGraph(3, ops.data()).build(); // Create string encoding for plan caching auto cache_string = getConvFusionString(x_dim, conv_pad, conv_stride, conv_dilation, w_dim, dataType, opGraph.getTag()); DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); auto workspace_size = plan.getWorkspaceSize(); DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); void* workspace_ptr = nullptr; auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); if (workspace_size > 0) { workspace_ptr = workspace_tensor.data_ptr(); } void* data_ptrs[] = {devPtrX, devPtrW, devPtrB, devPtrY}; int64_t uids[] = {'x', 'w', 'b', 'y'}; auto variantPack = cudnn_frontend::VariantPackBuilder() .setWorkspacePointer(workspace_ptr) .setDataPointers(4, data_ptrs) .setUids(4, uids) .build(); DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); checkCudnnErr(status); cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); } catch (cudnn_frontend::cudnnException e) { std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; } } void run_drelu_dscale(int64_t* dy_dim, cudnnDataType_t dataType, at::Half* devPtrDY, at::Half* devPtrR, at::Half* devPtrS, at::Half* devPtrDX) { cudnnHandle_t handle_ = torch::native::getCudnnHandle(); std::stringstream log_buf; try { int convDim = 2; float alpha = 1.0f; float beta = 0.0f; int64_t s_dim[] = {1, dy_dim[1], 1, 1}; // Creates the necessary tensor descriptors int64_t stride[4]; 
generateStrides(dy_dim, stride, 4, CUDNN_TENSOR_NHWC); auto dyTensor = cudnn_frontend::TensorBuilder() .setDim(4, dy_dim) .setStrides(4, stride) .setId('y') .setAlignment(16) .setDataType(dataType) .build(); DEBUG_CUDNN_MSG(log_buf, dyTensor.describe()); generateStrides(dy_dim, stride, 4, CUDNN_TENSOR_NHWC); auto rTensor = cudnn_frontend::TensorBuilder() .setDim(4, dy_dim) .setStrides(4, stride) .setId('r') .setAlignment(16) .setDataType(dataType) .build(); DEBUG_CUDNN_MSG(log_buf, rTensor.describe()); generateStrides(dy_dim, stride, 4, CUDNN_TENSOR_NHWC); auto inActGradTensor = cudnn_frontend::TensorBuilder() .setDim(4, dy_dim) .setStrides(4, stride) .setId('R') .setAlignment(16) .setDataType(CUDNN_DATA_FLOAT) .setVirtual() .build(); DEBUG_CUDNN_MSG(log_buf, inActGradTensor.describe()); generateStrides(s_dim, stride, 4, CUDNN_TENSOR_NHWC); auto scaleTensor = cudnn_frontend::TensorBuilder() .setDim(4, s_dim) .setStrides(4, stride) .setId('s') .setAlignment(16) .setDataType(dataType) .build(); DEBUG_CUDNN_MSG(log_buf, scaleTensor.describe()); generateStrides(dy_dim, stride, 4, CUDNN_TENSOR_NHWC); auto dxTensor = cudnn_frontend::TensorBuilder() .setDim(4, dy_dim) .setStrides(4, stride) .setId('x') .setAlignment(16) .setDataType(dataType) .build(); DEBUG_CUDNN_MSG(log_buf, dxTensor.describe()); // Define the activation backward operation auto actDesc = cudnn_frontend::PointWiseDescBuilder() .setMode(CUDNN_POINTWISE_RELU_BWD) .setMathPrecision(CUDNN_DATA_FLOAT) .build(); DEBUG_CUDNN_MSG(log_buf, actDesc.describe()); // Define the bias backward operation auto scaleDesc = cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_MUL).setMathPrecision(CUDNN_DATA_FLOAT).build(); DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe()); // Create an relu backward Node auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) .setdyDesc(dyTensor) .setxDesc(rTensor) .setdxDesc(inActGradTensor) .setpwDesc(actDesc) .build(); 
DEBUG_CUDNN_MSG(log_buf, act_op.describe()); // Create bias node auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) .setxDesc(inActGradTensor) .setbDesc(scaleTensor) .setyDesc(dxTensor) .setpwDesc(scaleDesc) .build(); DEBUG_CUDNN_MSG(log_buf, scale_op.describe()); // Create an Operation Graph. In this case it is bias only std::array ops = {&act_op, &scale_op}; auto opGraph = cudnn_frontend::OperationGraphBuilder().setHandle(handle_).setOperationGraph(ops.size(), ops.data()).build(); // Create string encoding for plan caching // creating unique dummy values int64_t pad_dummy[] = {40, 40}; int64_t stride_dummy[] = {40, 40}; int64_t dilation_dummy[] = {40, 40}; auto cache_string = getConvFusionString(dy_dim, pad_dummy, stride_dummy, dilation_dummy, s_dim, dataType, opGraph.getTag()); DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); auto workspace_size = plan.getWorkspaceSize(); DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); void* workspace_ptr = nullptr; auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); if (workspace_size > 0) { workspace_ptr = workspace_tensor.data_ptr(); } void* data_ptrs[] = {devPtrDY, devPtrR, devPtrS, devPtrDX}; int64_t uids[] = {'y', 'r', 's', 'x'}; auto variantPack = cudnn_frontend::VariantPackBuilder() .setWorkspacePointer(workspace_ptr) .setDataPointers(4, data_ptrs) .setUids(4, uids) .build(); DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); checkCudnnErr(status); cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); } catch (cudnn_frontend::cudnnException e) { std::cout << log_buf.str() 
<< "[ERROR] Exception " << e.what() << std::endl; } } void run_drelu_dbias(int64_t* dy_dim, cudnnDataType_t dataType, at::Half* devPtrDY, at::Half* devPtrR, at::Half* devPtrDR, float* devPtrDB) { cudnnHandle_t handle_ = torch::native::getCudnnHandle(); std::stringstream log_buf; try { int convDim = 2; float alpha = 1.0f; float beta = 0.0f; int64_t b_dim[] = {1, dy_dim[1], 1, 1}; // Creates the necessary tensor descriptors int64_t stride[4]; generateStrides(dy_dim, stride, 4, CUDNN_TENSOR_NHWC); auto dyTensor = cudnn_frontend::TensorBuilder() .setDim(4, dy_dim) .setStrides(4, stride) .setId('x') .setAlignment(16) .setDataType(dataType) .build(); DEBUG_CUDNN_MSG(log_buf, dyTensor.describe()); generateStrides(dy_dim, stride, 4, CUDNN_TENSOR_NHWC); auto rTensor = cudnn_frontend::TensorBuilder() .setDim(4, dy_dim) .setStrides(4, stride) .setId('r') .setAlignment(16) .setDataType(dataType) .build(); DEBUG_CUDNN_MSG(log_buf, rTensor.describe()); generateStrides(dy_dim, stride, 4, CUDNN_TENSOR_NHWC); auto inActGradTensor = cudnn_frontend::TensorBuilder() .setDim(4, dy_dim) .setStrides(4, stride) .setId('R') .setAlignment(16) .setDataType(dataType) .build(); DEBUG_CUDNN_MSG(log_buf, inActGradTensor.describe()); generateStrides(b_dim, stride, 4, CUDNN_TENSOR_NHWC); auto biasGradTensor = cudnn_frontend::TensorBuilder() .setDim(4, b_dim) .setStrides(4, stride) .setId('y') .setAlignment(16) .setDataType(CUDNN_DATA_FLOAT) .build(); DEBUG_CUDNN_MSG(log_buf, biasGradTensor.describe()); // Define the activation backward operation auto actDesc = cudnn_frontend::PointWiseDescBuilder() .setMode(CUDNN_POINTWISE_RELU_BWD) .setMathPrecision(CUDNN_DATA_FLOAT) .build(); DEBUG_CUDNN_MSG(log_buf, actDesc.describe()); // Define the bias backward operation auto biasDesc = cudnn_frontend::ReductionDescBuilder() .setMathPrecision(CUDNN_DATA_FLOAT) .setReductionOp(CUDNN_REDUCE_TENSOR_ADD) .build(); DEBUG_CUDNN_MSG(log_buf, biasDesc.describe()); // Create an relu backward Node auto act_op = 
cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) .setdyDesc(dyTensor) .setxDesc(rTensor) .setdxDesc(inActGradTensor) .setpwDesc(actDesc) .build(); DEBUG_CUDNN_MSG(log_buf, act_op.describe()); // Create bias node auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR) .setxDesc(inActGradTensor) .setyDesc(biasGradTensor) .setreductionDesc(biasDesc) .build(); DEBUG_CUDNN_MSG(log_buf, bias_op.describe()); // Create an Operation Graph. In this case it is bias only std::array ops = {&act_op, &bias_op}; auto opGraph = cudnn_frontend::OperationGraphBuilder().setHandle(handle_).setOperationGraph(ops.size(), ops.data()).build(); // Create string encoding for plan caching // creating unique dummy values int64_t pad_dummy[] = {20, 20}; int64_t stride_dummy[] = {20, 20}; int64_t dilation_dummy[] = {20, 20}; auto cache_string = getConvFusionString(dy_dim, pad_dummy, stride_dummy, dilation_dummy, b_dim, dataType, opGraph.getTag()); DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); auto workspace_size = plan.getWorkspaceSize(); DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); void* workspace_ptr = nullptr; auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); if (workspace_size > 0) { workspace_ptr = workspace_tensor.data_ptr(); } void* data_ptrs[] = {devPtrDY, devPtrR, devPtrDR, devPtrDB}; int64_t uids[] = {'x', 'r', 'R', 'y'}; auto variantPack = cudnn_frontend::VariantPackBuilder() .setWorkspacePointer(workspace_ptr) .setDataPointers(4, data_ptrs) .setUids(4, uids) .build(); DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); checkCudnnErr(status); 
cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); } catch (cudnn_frontend::cudnnException e) { std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; } } void run_dconv_drelu_dbias(int64_t* x_dim, int64_t* w_dim, int64_t* y_dim, int64_t* pad, int64_t* convstride, int64_t* dilation, cudnnDataType_t dataType, at::Half* devPtrX, at::Half* devPtrW, at::Half* devPtrR, at::Half* devPtrRg, float* devPtrY) { cudnnHandle_t handle_ = torch::native::getCudnnHandle(); std::stringstream log_buf; try { int convDim = 2; float alpha = 1.0f; float beta = 0.0f; int64_t b_dim[] = {1, x_dim[1], 1, 1}; int64_t stride[4]; generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC); auto outConvGradTensor = cudnn_frontend::TensorBuilder() .setDim(4, y_dim) .setStrides(4, stride) .setId('x') .setAlignment(16) .setDataType(dataType) .build(); DEBUG_CUDNN_MSG(log_buf, outConvGradTensor.describe()); generateStrides(w_dim, stride, 4, CUDNN_TENSOR_NHWC); auto wTensor = cudnn_frontend::TensorBuilder() .setDim(4, w_dim) .setStrides(4, stride) .setId('w') .setAlignment(16) .setDataType(dataType) .build(); DEBUG_CUDNN_MSG(log_buf, wTensor.describe()); generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC); auto inConvGradTensor = cudnn_frontend::TensorBuilder() .setDim(4, x_dim) .setStrides(4, stride) .setId('A') .setAlignment(16) .setDataType(CUDNN_DATA_FLOAT) .setVirtual() .build(); DEBUG_CUDNN_MSG(log_buf, inConvGradTensor.describe()); generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC); auto rTensor = cudnn_frontend::TensorBuilder() .setDim(4, x_dim) .setStrides(4, stride) .setId('r') .setAlignment(16) .setDataType(dataType) .build(); DEBUG_CUDNN_MSG(log_buf, rTensor.describe()); generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC); auto inReLUGradTensor = cudnn_frontend::TensorBuilder() .setDim(4, x_dim) .setStrides(4, stride) .setId('R') .setAlignment(16) .setDataType(dataType) .build(); DEBUG_CUDNN_MSG(log_buf, 
inReLUGradTensor.describe()); generateStrides(b_dim, stride, 4, CUDNN_TENSOR_NHWC); auto inBiasGradTensor = cudnn_frontend::TensorBuilder() .setDim(4, b_dim) .setStrides(4, stride) .setId('y') .setAlignment(16) .setDataType(CUDNN_DATA_FLOAT) .build(); DEBUG_CUDNN_MSG(log_buf, inBiasGradTensor.describe()); // Define the convolution problem auto convDesc = cudnn_frontend::ConvDescBuilder() .setDataType(CUDNN_DATA_FLOAT) .setMathMode(CUDNN_CROSS_CORRELATION) .setNDims(convDim) .setStrides(convDim, convstride) .setPrePadding(convDim, pad) .setPostPadding(convDim, pad) .setDilation(convDim, dilation) .build(); DEBUG_CUDNN_MSG(log_buf, convDesc.describe()); // Define the activation backward operation auto actDesc = cudnn_frontend::PointWiseDescBuilder() .setMode(CUDNN_POINTWISE_RELU_BWD) .setMathPrecision(CUDNN_DATA_FLOAT) .build(); DEBUG_CUDNN_MSG(log_buf, actDesc.describe()); // Define the bias backward operation auto biasDesc = cudnn_frontend::ReductionDescBuilder() .setMathPrecision(CUDNN_DATA_FLOAT) .setReductionOp(CUDNN_REDUCE_TENSOR_ADD) .build(); DEBUG_CUDNN_MSG(log_buf, biasDesc.describe()); // Create a convolution Node auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR) .setdyDesc(outConvGradTensor) .setwDesc(wTensor) .setdxDesc(inConvGradTensor) .setcDesc(convDesc) .setAlpha(alpha) .setBeta(beta) .build(); DEBUG_CUDNN_MSG(log_buf, conv_op.describe()); // Create an relu backward Node auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) .setdyDesc(inConvGradTensor) .setxDesc(rTensor) .setdxDesc(inReLUGradTensor) .setpwDesc(actDesc) .build(); DEBUG_CUDNN_MSG(log_buf, act_op.describe()); // Create bias node auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR) .setxDesc(inReLUGradTensor) .setyDesc(inBiasGradTensor) .setreductionDesc(biasDesc) .build(); DEBUG_CUDNN_MSG(log_buf, bias_op.describe()); // Create an Operation Graph. 
In this case it is bias only std::array ops = {&conv_op, &act_op, &bias_op}; auto opGraph = cudnn_frontend::OperationGraphBuilder().setHandle(handle_).setOperationGraph(ops.size(), ops.data()).build(); // Create string encoding for plan caching auto cache_string = getConvFusionString(x_dim, pad, convstride, dilation, w_dim, dataType, opGraph.getTag()); DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); auto workspace_size = plan.getWorkspaceSize(); DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); void* workspace_ptr = nullptr; auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); if (workspace_size > 0) { workspace_ptr = workspace_tensor.data_ptr(); } void* data_ptrs[] = {devPtrX, devPtrW, devPtrR, devPtrRg, devPtrY}; int64_t uids[] = {'x', 'w', 'r', 'R', 'y'}; auto variantPack = cudnn_frontend::VariantPackBuilder() .setWorkspacePointer(workspace_ptr) .setDataPointers(5, data_ptrs) .setUids(5, uids) .build(); DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); checkCudnnErr(status); cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); } catch (cudnn_frontend::cudnnException e) { std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; } } void run_dconv(int64_t* x_dim, int64_t* w_dim, int64_t* y_dim, int64_t* conv_pad, int64_t* conv_stride, int64_t* conv_dilation, cudnnDataType_t dataType, at::Half* devPtrX, at::Half* devPtrW, at::Half* devPtrY, cudnnBackendDescriptorType_t mode) { cudnnHandle_t handle_ = torch::native::getCudnnHandle(); std::stringstream log_buf; try { int conv_dim = 2; float alpha = 1.0f; float beta = 0.0f; // Define the 
convolution problem int64_t stride[4]; generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC); auto xTensor = cudnn_frontend::TensorBuilder() .setDim(4, x_dim) .setStrides(4, stride) .setId('x') .setAlignment(16) .setDataType(dataType) .build(); DEBUG_CUDNN_MSG(log_buf, xTensor.describe()); generateStrides(w_dim, stride, 4, CUDNN_TENSOR_NHWC); auto wTensor = cudnn_frontend::TensorBuilder() .setDim(4, w_dim) .setStrides(4, stride) .setId('w') .setAlignment(16) .setDataType(dataType) .build(); DEBUG_CUDNN_MSG(log_buf, wTensor.describe()); generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC); auto yTensor = cudnn_frontend::TensorBuilder() .setDim(4, y_dim) .setStrides(4, stride) .setId('y') .setAlignment(16) .setDataType(dataType) .build(); DEBUG_CUDNN_MSG(log_buf, yTensor.describe()); // Define the convolution problem auto convDesc = cudnn_frontend::ConvDescBuilder() .setDataType(CUDNN_DATA_FLOAT) .setMathMode(CUDNN_CROSS_CORRELATION) .setNDims(conv_dim) .setStrides(conv_dim, conv_stride) .setPrePadding(conv_dim, conv_pad) .setPostPadding(conv_dim, conv_pad) .setDilation(conv_dim, conv_dilation) .build(); DEBUG_CUDNN_MSG(log_buf, convDesc.describe()); // Create a convolution node // mode should be one of following // CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR // CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR auto conv_op_builder = cudnn_frontend::OperationBuilder(mode); if (mode == CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR) { conv_op_builder.setdxDesc(xTensor).setwDesc(wTensor).setdyDesc(yTensor).setcDesc(convDesc); } else { conv_op_builder.setxDesc(xTensor).setdwDesc(wTensor).setdyDesc(yTensor).setcDesc(convDesc); } auto conv_op = conv_op_builder.setAlpha(alpha).setBeta(beta).build(); DEBUG_CUDNN_MSG(log_buf, conv_op.describe()); // Create an Operation Graph. 
In this case it is convolution add bias activation std::array ops = {&conv_op}; auto opGraph = cudnn_frontend::OperationGraphBuilder().setHandle(handle_).setOperationGraph(ops.size(), ops.data()).build(); // Create string encoding for plan caching auto cache_string = getConvFusionString(x_dim, conv_pad, conv_stride, conv_dilation, w_dim, dataType, opGraph.getTag()); DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); auto workspace_size = plan.getWorkspaceSize(); DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); void* workspace_ptr = nullptr; auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); if (workspace_size > 0) { workspace_ptr = workspace_tensor.data_ptr(); } void* data_ptrs[] = {devPtrX, devPtrW, devPtrY}; int64_t uids[] = {'x', 'w', 'y'}; auto variantPack = cudnn_frontend::VariantPackBuilder() .setWorkspacePointer(workspace_ptr) .setDataPointers(3, data_ptrs) .setUids(3, uids) .build(); DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); checkCudnnErr(status); cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); } catch (cudnn_frontend::cudnnException e) { std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; } } void run_dbias(int64_t* x_dim, cudnnDataType_t dataType, at::Half* devPtrX, float* devPtrY) { cudnnHandle_t handle_ = torch::native::getCudnnHandle(); std::stringstream log_buf; try { int convDim = 2; int64_t b_dim[] = {1, x_dim[1], 1, 1}; int64_t stride[4]; generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC); auto xTensor = cudnn_frontend::TensorBuilder() .setDim(4, x_dim) .setStrides(4, stride) .setId('x') 
.setAlignment(16) .setDataType(dataType) .build(); DEBUG_CUDNN_MSG(log_buf, xTensor.describe()); generateStrides(b_dim, stride, 4, CUDNN_TENSOR_NHWC); auto yTensor = cudnn_frontend::TensorBuilder() .setDim(4, b_dim) .setStrides(4, stride) .setId('y') .setAlignment(16) .setDataType(CUDNN_DATA_FLOAT) .build(); DEBUG_CUDNN_MSG(log_buf, yTensor.describe()); // Define the bias backward operation auto biasDesc = cudnn_frontend::ReductionDescBuilder() .setMathPrecision(CUDNN_DATA_FLOAT) .setReductionOp(CUDNN_REDUCE_TENSOR_ADD) .build(); DEBUG_CUDNN_MSG(log_buf, biasDesc.describe()); // Create bias node auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR) .setxDesc(xTensor) .setyDesc(yTensor) .setreductionDesc(biasDesc) .build(); DEBUG_CUDNN_MSG(log_buf, bias_op.describe()); // Create an Operation Graph. In this case it is bias only std::array ops = {&bias_op}; auto opGraph = cudnn_frontend::OperationGraphBuilder().setHandle(handle_).setOperationGraph(ops.size(), ops.data()).build(); // Create string encoding for plan caching int64_t pad_dummy[] = {10, 10}; int64_t stride_dummy[] = {10, 10}; int64_t dilation_dummy[] = {10, 10}; auto cache_string = getConvFusionString(x_dim, pad_dummy, stride_dummy, dilation_dummy, b_dim, dataType, opGraph.getTag()); DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); auto workspace_size = plan.getWorkspaceSize(); DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); void* workspace_ptr = nullptr; auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); if (workspace_size > 0) { workspace_ptr = workspace_tensor.data_ptr(); } void* data_ptrs[] = {devPtrX, devPtrY}; int64_t uids[] = {'x', 'y'}; auto variantPack = cudnn_frontend::VariantPackBuilder() .setWorkspacePointer(workspace_ptr) 
.setDataPointers(2, data_ptrs) .setUids(2, uids) .build(); DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); checkCudnnErr(status); cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); } catch (cudnn_frontend::cudnnException e) { std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; } } std::vector conv_bias_mask_relu_forward(std::vector inputs, int64_t padding, int64_t stride) { std::cout << std::fixed; // create output vector std::vector outputs; auto output_format = at::MemoryFormat::ChannelsLast; // setup dimensions int64_t x_dim[] = {0, 0, 0, 0}; int64_t w_dim[] = {0, 0, 0, 0}; // All dim calculation after this order of n,c,h,w int axis[] = {0, 1, 2, 3}; for (int dim = 0; dim < 4; dim++) { x_dim[dim] = inputs[0].size(axis[dim]); w_dim[dim] = inputs[1].size(axis[dim]); } // output dim in n,c,h,w used by backend int64_t y_dim[] = {0, 0, 0, 0}; // use these fixed values int64_t conv_pad[] = {padding, padding}; int64_t conv_stride[] = {stride, stride}; int64_t conv_dilation[] = {1, 1}; // compute output from pad/stride/dilation y_dim[0] = x_dim[0]; y_dim[1] = w_dim[0]; for (int dim = 0; dim < 2; dim++) { y_dim[dim + 2] = getFwdConvOutputDim(x_dim[dim + 2], conv_pad[dim], w_dim[dim + 2], conv_stride[dim], conv_dilation[dim]); } // run at::Half* x = inputs[0].data_ptr(); at::Half* w = inputs[1].data_ptr(); at::Half* b = inputs[2].data_ptr(); int8_t* m = inputs[3].data_ptr(); auto out = at::empty(y_dim, inputs[0].type(), output_format); at::Half* y = out.data_ptr(); run_conv_bias_mask_relu(x_dim, w_dim, y_dim, conv_pad, conv_stride, conv_dilation, CUDNN_DATA_HALF, x, w, b, m, y); DEBUG_MSG("[DEBUG] conv-bias-mask-relu : " << y.to(at::kFloat).sum().item()); outputs.push_back(out); return outputs; } at::Tensor conv_cscale_cbias_relu_forward(std::vector inputs, int64_t 
padding, int64_t stride) { std::cout << std::fixed; // setup dimensions int64_t x_dim[] = {0, 0, 0, 0}; int64_t w_dim[] = {0, 0, 0, 0}; // All dim calculation after this order of n,c,h,w int axis[] = {0, 1, 2, 3}; for (int dim = 0; dim < 4; dim++) { x_dim[dim] = inputs[0].size(axis[dim]); w_dim[dim] = inputs[1].size(axis[dim]); } // output dim in n,c,h,w used by backend int64_t y_dim[] = {0, 0, 0, 0}; // use these fixed values int64_t conv_pad[] = {padding, padding}; int64_t conv_stride[] = {stride, stride}; int64_t conv_dilation[] = {1, 1}; // compute output from pad/stride/dilation y_dim[0] = x_dim[0]; y_dim[1] = w_dim[0]; for (int dim = 0; dim < 2; dim++) { y_dim[dim + 2] = getFwdConvOutputDim(x_dim[dim + 2], conv_pad[dim], w_dim[dim + 2], conv_stride[dim], conv_dilation[dim]); } // run at::Half* x = inputs[0].data_ptr(); at::Half* w = inputs[1].data_ptr(); at::Half* s = inputs[2].data_ptr(); at::Half* b = inputs[3].data_ptr(); auto out = at::empty(y_dim, inputs[0].type(), at::MemoryFormat::ChannelsLast); at::Half* y = out.data_ptr(); run_conv_cscale_cbias_relu(x_dim, w_dim, y_dim, conv_pad, conv_stride, conv_dilation, CUDNN_DATA_HALF, x, w, s, b, y); DEBUG_MSG("[DEBUG] conv-cscale-cbias-relu : " << y.to(at::kFloat).sum().item()); return out; } std::vector conv_cscale_cbias_relu_backward(std::vector inputs, int64_t padding, int64_t stride) { bool requires_grad = inputs[0].requires_grad(); for (int i = 0; i <= 4; i++) { CHECK_INPUT(inputs[i]); } std::cout << std::fixed; // create output vector std::vector outputs; auto output_format = at::MemoryFormat::ChannelsLast; // setup dimensions int64_t x_dim[] = {0, 0, 0, 0}; int64_t w_dim[] = {0, 0, 0, 0}; int64_t y_dim[] = {0, 0, 0, 0}; // All dim calculation after this order of n,c,h,w int axis[] = {0, 1, 2, 3}; for (int dim = 0; dim < 4; dim++) { x_dim[dim] = inputs[0].size(axis[dim]); w_dim[dim] = inputs[1].size(axis[dim]); y_dim[dim] = inputs[3].size(axis[dim]); } int64_t b_dim[] = {1, y_dim[1], 1, 1}; int64_t 
conv_pad[] = {padding, padding}; int64_t conv_stride[] = {stride, stride}; int64_t conv_dilation[] = {1, 1}; // run // drelu-dbias at::Half* dy = inputs[4].data_ptr(); at::Half* r = inputs[3].data_ptr(); auto s = inputs[2].data_ptr(); auto dscale = at::empty_like(inputs[4]); at::Half* ds = dscale.data_ptr(); auto options = at::TensorOptions().dtype(at::kFloat).layout(inputs[0].layout()).device(inputs[0].device()).requires_grad(false); run_drelu_dscale(y_dim, CUDNN_DATA_HALF, dy, r, s, ds); // conv wgrad at::Half* x = inputs[0].data_ptr(); auto wgrad = at::empty_like(inputs[1]); at::Half* dw = wgrad.data_ptr(); run_dconv(x_dim, w_dim, y_dim, conv_pad, conv_stride, conv_dilation, CUDNN_DATA_HALF, x, dw, ds, CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR); // conv dgrad at::Half* w = inputs[1].data_ptr(); auto dgrad = at::empty_like(inputs[0]); at::Half* dx = dgrad.data_ptr(); run_dconv(x_dim, w_dim, y_dim, conv_pad, conv_stride, conv_dilation, CUDNN_DATA_HALF, dx, w, ds, CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR); outputs.push_back(dgrad); outputs.push_back(wgrad); return outputs; } std::vector conv_bias_relu_forward(std::vector inputs, int64_t padding, int64_t stride) { std::cout << std::fixed; // create output vector std::vector outputs; auto output_format = at::MemoryFormat::ChannelsLast; // setup dimensions int64_t x_dim[] = {0, 0, 0, 0}; int64_t w_dim[] = {0, 0, 0, 0}; // All dim calculation after this order of n,c,h,w int axis[] = {0, 1, 2, 3}; for (int dim = 0; dim < 4; dim++) { x_dim[dim] = inputs[0].size(axis[dim]); w_dim[dim] = inputs[1].size(axis[dim]); } // output dim in n,c,h,w used by backend int64_t y_dim[] = {0, 0, 0, 0}; // use these fixed values int64_t conv_pad[] = {padding, padding}; int64_t conv_stride[] = {stride, stride}; int64_t conv_dilation[] = {1, 1}; // compute output from pad/stride/dilation y_dim[0] = x_dim[0]; y_dim[1] = w_dim[0]; for (int dim = 0; dim < 2; dim++) { y_dim[dim + 2] = 
getFwdConvOutputDim(x_dim[dim + 2], conv_pad[dim], w_dim[dim + 2], conv_stride[dim], conv_dilation[dim]); } // run at::Half* x = inputs[0].data_ptr(); at::Half* w = inputs[1].data_ptr(); at::Half* b = inputs[2].data_ptr(); auto out = at::empty(y_dim, inputs[0].type(), output_format); at::Half* y = out.data_ptr(); run_conv_bias_relu(x_dim, w_dim, y_dim, conv_pad, conv_stride, conv_dilation, CUDNN_DATA_HALF, x, w, b, y); DEBUG_MSG("[DEBUG] conv-bias-relu : " << y.to(at::kFloat).sum().item()); outputs.push_back(out); return outputs; } std::vector conv_bias_relu_backward(std::vector inputs, int64_t padding, int64_t stride) { bool requires_grad = inputs[0].requires_grad(); for (int i = 0; i <= 3; i++) { CHECK_INPUT(inputs[i]); } std::cout << std::fixed; // create output vector std::vector outputs; auto output_format = at::MemoryFormat::ChannelsLast; // setup dimensions int64_t x_dim[] = {0, 0, 0, 0}; int64_t w_dim[] = {0, 0, 0, 0}; int64_t y_dim[] = {0, 0, 0, 0}; // All dim calculation after this order of n,c,h,w int axis[] = {0, 1, 2, 3}; for (int dim = 0; dim < 4; dim++) { x_dim[dim] = inputs[0].size(axis[dim]); w_dim[dim] = inputs[1].size(axis[dim]); y_dim[dim] = inputs[3].size(axis[dim]); } int64_t b_dim[] = {1, y_dim[1], 1, 1}; int64_t conv_pad[] = {padding, padding}; int64_t conv_stride[] = {stride, stride}; int64_t conv_dilation[] = {1, 1}; // run // drelu-dbias at::Half* dy = inputs[3].data_ptr(); at::Half* r = inputs[2].data_ptr(); auto drelu = at::empty_like(inputs[2]); at::Half* dr = drelu.data_ptr(); auto options = at::TensorOptions().dtype(at::kFloat).layout(inputs[0].layout()).device(inputs[0].device()).requires_grad(false); auto bgrad = at::empty(b_dim, options, output_format); float* db = bgrad.data_ptr(); run_drelu_dbias(y_dim, CUDNN_DATA_HALF, dy, r, dr, db); // conv wgrad at::Half* x = inputs[0].data_ptr(); auto wgrad = at::empty_like(inputs[1]); at::Half* dw = wgrad.data_ptr(); run_dconv(x_dim, w_dim, y_dim, conv_pad, conv_stride, conv_dilation, 
CUDNN_DATA_HALF, x, dw, dr, CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR); // conv dgrad at::Half* w = inputs[1].data_ptr(); auto dgrad = at::empty_like(inputs[0]); at::Half* dx = dgrad.data_ptr(); run_dconv(x_dim, w_dim, y_dim, conv_pad, conv_stride, conv_dilation, CUDNN_DATA_HALF, dx, w, dr, CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR); outputs.push_back(dgrad); outputs.push_back(wgrad); outputs.push_back(bgrad); return outputs; } std::vector conv_bias_forward(std::vector inputs, int64_t padding, int64_t stride) { std::cout << std::fixed; // create output vector std::vector outputs; auto output_format = at::MemoryFormat::ChannelsLast; // setup dimensions int64_t x_dim[] = {0, 0, 0, 0}; int64_t w_dim[] = {0, 0, 0, 0}; // All dim calculation after this order of n,c,h,w int axis[] = {0, 1, 2, 3}; for (int dim = 0; dim < 4; dim++) { x_dim[dim] = inputs[0].size(axis[dim]); w_dim[dim] = inputs[1].size(axis[dim]); } // output dim in n,c,h,w used by backend int64_t y_dim[] = {0, 0, 0, 0}; // use these fixed values int64_t conv_pad[] = {padding, padding}; int64_t conv_stride[] = {stride, stride}; int64_t conv_dilation[] = {1, 1}; // compute output from pad/stride/dilation y_dim[0] = x_dim[0]; y_dim[1] = w_dim[0]; for (int dim = 0; dim < 2; dim++) { y_dim[dim + 2] = getFwdConvOutputDim(x_dim[dim + 2], conv_pad[dim], w_dim[dim + 2], conv_stride[dim], conv_dilation[dim]); } // run at::Half* x = inputs[0].data_ptr(); at::Half* w = inputs[1].data_ptr(); at::Half* b = inputs[2].data_ptr(); auto out = at::empty(y_dim, inputs[0].type(), output_format); at::Half* y = out.data_ptr(); run_conv_bias(x_dim, w_dim, y_dim, conv_pad, conv_stride, conv_dilation, CUDNN_DATA_HALF, x, w, b, y); DEBUG_MSG("[DEBUG] conv-bias : " << y.to(at::kFloat).sum().item()); outputs.push_back(out); return outputs; } std::vector conv_bias_backward(std::vector inputs, int64_t padding, int64_t stride) { bool requires_grad = inputs[0].requires_grad(); for (int i = 0; i <= 2; 
i++) { CHECK_INPUT(inputs[i]); } std::cout << std::fixed; // create output vector std::vector outputs; auto output_format = at::MemoryFormat::ChannelsLast; // setup dimensions int64_t x_dim[] = {0, 0, 0, 0}; int64_t w_dim[] = {0, 0, 0, 0}; int64_t y_dim[] = {0, 0, 0, 0}; // All dim calculation after this order of n,c,h,w int axis[] = {0, 1, 2, 3}; for (int dim = 0; dim < 4; dim++) { x_dim[dim] = inputs[0].size(axis[dim]); w_dim[dim] = inputs[1].size(axis[dim]); y_dim[dim] = inputs[2].size(axis[dim]); } int64_t b_dim[] = {1, y_dim[1], 1, 1}; int64_t conv_pad[] = {padding, padding}; int64_t conv_stride[] = {stride, stride}; int64_t conv_dilation[] = {1, 1}; // run // dbias at::Half* dy = inputs[2].data_ptr(); auto options = at::TensorOptions().dtype(at::kFloat).layout(inputs[0].layout()).device(inputs[0].device()).requires_grad(false); auto bgrad = at::empty(b_dim, options, output_format); float* db = bgrad.data_ptr(); run_dbias(y_dim, CUDNN_DATA_HALF, dy, db); // conv wgrad at::Half* x = inputs[0].data_ptr(); auto wgrad = at::empty_like(inputs[1]); at::Half* dw = wgrad.data_ptr(); run_dconv(x_dim, w_dim, y_dim, conv_pad, conv_stride, conv_dilation, CUDNN_DATA_HALF, x, dw, dy, CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR); // conv dgrad at::Half* w = inputs[1].data_ptr(); auto dgrad = at::empty_like(inputs[0]); at::Half* dx = dgrad.data_ptr(); run_dconv(x_dim, w_dim, y_dim, conv_pad, conv_stride, conv_dilation, CUDNN_DATA_HALF, dx, w, dy, CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR); outputs.push_back(dgrad); outputs.push_back(wgrad); outputs.push_back(bgrad); return outputs; } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("forward", &conv_bias_relu_forward, "Fused Conv-Bias-ReLU forward", py::call_guard()); m.def("backward", &conv_bias_relu_backward, "Fused Conv-Bias-ReLU backward", py::call_guard()); m.def("forward_no_relu", &conv_bias_forward, "Fused Conv-Bias forward", py::call_guard()); m.def("backward_no_relu", 
&conv_bias_backward, "Fused Conv-Bias backward", py::call_guard()); m.def("forward_mask", &conv_bias_mask_relu_forward, "Fused Conv-Bias-Mask-ReLU forward", py::call_guard()); m.def("forward_cscale_cbias_relu", &conv_cscale_cbias_relu_forward, "Fused Conv-(const)Scale-(const)Bias-ReLU", py::call_guard()); m.def("backward_cscale_cbias_relu", &conv_cscale_cbias_relu_backward, "Fused Conv-(const)Scale-(const)Bias-ReLU backward", py::call_guard()); } ================================================ FILE: apex/contrib/csrc/cudnn_gbn/cudnn_gbn.cpp ================================================ #include #include #include #include #include #include "norm_sample.h" // define this enum: enum bn_type { BN_FWD, BN_BWD }; // this is a global variable static std::map, cudnn_frontend::ExecutionPlan> gbn_plan_cache; at::Tensor gbn_forward(const at::Tensor& x, const at::Tensor& scale, const at::Tensor& bias, const at::Tensor& running_mean, const at::Tensor& running_var, const at::Tensor& minibatch_mean, const at::Tensor& minibatch_inv_var, const float momentum, const float epsilon, const int64_t bn_group, const int rank_id, const std::vector& peer_buffers) { int64_t N = x.size(0); int64_t C = x.size(1); int64_t H = x.size(2); int64_t W = x.size(3); int64_t tensorDims[] = {N, C, H, W}; int64_t peerDims[] = {bn_group, 4 * C, 1, 1}; int64_t perChannelDims[] = {1, C, 1, 1}; int64_t epsilonDims[] = {1, 1, 1, 1}; // Allocate output tensor at::Tensor y = at::empty_like(x); std::vector void_peer_buffers; for (int64_t addr : peer_buffers) { void_peer_buffers.push_back((void*)addr); } // we need the peer size for the buffer reset size_t peer_size = 1; for (size_t i = 0; i < 4; ++i) { peer_size *= peerDims[i]; } // sanity check assert(bn_group == void_peer_buffers.size()); // check if plan already exists std::vector fv = {(int64_t)BN_FWD, N, C, H, W, bn_group, (int64_t)CUDNN_DATA_HALF}; if (gbn_plan_cache.find(fv) == gbn_plan_cache.end()) { auto plan = run_batch_norm_forward(tensorDims, 
perChannelDims, epsilonDims, peerDims, CUDNN_DATA_HALF); gbn_plan_cache.emplace(fv, std::move(plan)); } // get plan and handle auto plan = gbn_plan_cache.find(fv)->second; // execute execute_batch_norm_forward(plan, x.data_ptr(), y.data_ptr(), scale.data_ptr(), bias.data_ptr(), running_mean.data_ptr(), running_var.data_ptr(), running_mean.data_ptr(), running_var.data_ptr(), minibatch_mean.data_ptr(), minibatch_inv_var.data_ptr(), void_peer_buffers, static_cast(epsilon), static_cast(momentum), peer_size, rank_id); return y; } std::vector gbn_backward(const at::Tensor& x, const at::Tensor& dy, const at::Tensor& scale, const at::Tensor& minibatch_mean, const at::Tensor& minibatch_inv_var, const float epsilon, const int64_t bn_group, const int rank_id, const std::vector& peer_buffers) { int64_t N = x.size(0); int64_t C = x.size(1); int64_t H = x.size(2); int64_t W = x.size(3); int64_t tensorDims[] = {N, C, H, W}; int64_t peerDims[] = {bn_group, 4 * C, 1, 1}; int64_t perChannelDims[] = {1, C, 1, 1}; int64_t epsilonDims[] = {1, 1, 1, 1}; // Allocate output tensor // outputs at::Tensor x_grad, scale_grad, bias_grad; // Allocate outputs x_grad = at::empty_like(x); scale_grad = at::empty_like(scale); bias_grad = at::empty_like(scale); std::vector void_peer_buffers; for (int64_t addr : peer_buffers) { void_peer_buffers.push_back((void*)addr); } // we need the peer size for the buffer reset size_t peer_size = 1; for (size_t i = 0; i < 4; ++i) { peer_size *= peerDims[i]; } assert(bn_group == void_peer_buffers.size()); std::vector fv = {(int64_t)BN_BWD, N, C, H, W, bn_group, (int64_t)CUDNN_DATA_HALF}; if (gbn_plan_cache.find(fv) == gbn_plan_cache.end()) { auto plan = run_batch_norm_backward(tensorDims, perChannelDims, epsilonDims, peerDims, CUDNN_DATA_HALF); gbn_plan_cache.emplace(fv, std::move(plan)); } // get plan and handle auto plan = gbn_plan_cache.find(fv)->second; // execute execute_batch_norm_backward(plan, x.data_ptr(), dy.data_ptr(), scale.data_ptr(), 
minibatch_mean.data_ptr(), minibatch_inv_var.data_ptr(), void_peer_buffers, x_grad.data_ptr(), scale_grad.data_ptr(), bias_grad.data_ptr(), static_cast(epsilon), peer_size, rank_id); return std::vector{x_grad, scale_grad, bias_grad}; } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("forward", &gbn_forward, "Group batch norm forward", py::call_guard()); m.def("backward", &gbn_backward, "Group batch backward", py::call_guard()); } ================================================ FILE: apex/contrib/csrc/cudnn_gbn/norm_sample.cpp ================================================ /* * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a * copy of this software and associated documentation files (the "Software"), * to deal in the Software without restriction, including without limitation * the rights to use, copy, modify, merge, publish, distribute, sublicense, * and/or sell copies of the Software, and to permit persons to whom the * Software is furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER * DEALINGS IN THE SOFTWARE. 
*/ #include "norm_sample.h" #include // for getcudnnhandle #include #include #include #include "cudnn_backend.h" // some helpers int64_t checkCudaError(cudaError_t code, const char* expr, const char* file, int line) { if (code) { printf("CUDA error at %s:%d, code=%d (%s) in '%s'", file, line, (int)code, cudaGetErrorString(code), expr); return 1; } return 0; } int64_t checkCudnnError(cudnnStatus_t code, const char* expr, const char* file, int line) { if (code) { printf("CUDNN error at %s:%d, code=%d (%s) in '%s'\n", file, line, (int)code, cudnnGetErrorString(code), expr); return 1; } return 0; } bool AllowAll(cudnnBackendDescriptor_t engine_config) { (void)engine_config; return false; } void generateStrides(const int64_t* dimA, int64_t* strideA, int64_t nbDims, cudnnTensorFormat_t filterFormat) { // For INT8x4 and INT8x32 we still compute standard strides here to input // into the cuDNN functions. We will manually scale by resizeFactor in the cpu ref. if (filterFormat == CUDNN_TENSOR_NCHW) { strideA[nbDims - 1] = 1; for (int64_t d = nbDims - 2; d >= 0; d--) { strideA[d] = strideA[d + 1] * dimA[d + 1]; } } else { // Here we assume that the format is CUDNN_TENSOR_NHWC strideA[1] = 1; strideA[nbDims - 1] = strideA[1] * dimA[1]; for (int64_t d = nbDims - 2; d >= 2; d--) { strideA[d] = strideA[d + 1] * dimA[d + 1]; } strideA[0] = strideA[2] * dimA[2]; } } // runtime cudnn_frontend::ExecutionPlan run_batch_norm_forward(int64_t* tensorDims, int64_t* perChannelSum, int64_t* epsilon, int64_t* peerDims, cudnnDataType_t data_type) { // get the cudnn handle cudnnHandle_t handle = torch::native::getCudnnHandle(); // Creates the necessary tensor descriptors int64_t tensor_stride[4]; int64_t stride[4]; int64_t peer_stride[4]; // NHWC format. GenerateStrides() takes care of this. 
Howeever, tensor dims should still be NCHW generateStrides(tensorDims, tensor_stride, (int64_t)4, CUDNN_TENSOR_NHWC); generateStrides(peerDims, peer_stride, (int64_t)4, CUDNN_TENSOR_NHWC); auto tensor_create = [&tensor_stride, &tensorDims](cudnnDataType_t type, int64_t id) { return cudnn_frontend::TensorBuilder() .setDim(4, tensorDims) .setStrides(4, tensor_stride) .setId(id) .setAlignment(16) .setDataType(type) .build(); }; auto peer_tensor_create = [&peer_stride, &tensorDims](cudnnDataType_t type, int64_t id) { return cudnn_frontend::TensorBuilder() .setDim(4, tensorDims) .setStrides(4, peer_stride) .setId(id) .setAlignment(16) .setDataType(type) .build(); }; generateStrides(perChannelSum, stride, (int64_t)4, CUDNN_TENSOR_NHWC); auto per_channel_tensor_create = [&stride, &perChannelSum](cudnnDataType_t type, int64_t id) { return cudnn_frontend::TensorBuilder() .setDim(4, perChannelSum) .setStrides(4, stride) .setId(id) .setAlignment(16) .setDataType(type) .build(); }; auto xTensor = tensor_create(data_type, 100); auto yTensor = tensor_create(data_type, 101); auto scaleTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 102); auto biasTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 103); auto inMeanTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 104); auto inVarTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 105); auto outMeanTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 106); auto outVarTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 107); auto savedMeanTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 108); auto savedInvVarTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 109); int64_t epsilon_stride[4]; generateStrides(epsilon, epsilon_stride, (int64_t)4, CUDNN_TENSOR_NHWC); auto scalar_tensor_create = [&epsilon_stride, &epsilon](cudnnDataType_t type, int64_t id) { return cudnn_frontend::TensorBuilder() .setDim(4, epsilon) .setStrides(4, epsilon_stride) .setId(id) .setAlignment(16) .setDataType(type) .setByValue(true) .build(); 
}; auto epsilonTensor = scalar_tensor_create(CUDNN_DATA_DOUBLE, 110); auto expDecayTensor = scalar_tensor_create(CUDNN_DATA_DOUBLE, 111); // Create the two peer stat tensors. Jump IDs in case we need to add more tensors with UIDs std::vector peerStatTensors; for (size_t i = 112; i < 112 + peerDims[0]; ++i) { peerStatTensors.push_back(peer_tensor_create(CUDNN_DATA_FLOAT, i)); } #if (CUDNN_VERSION >= 8500) // Batch normalization cudnnBackendNormMode_t normalizationMode = CUDNN_BATCH_NORM; // Forward training cudnnBackendNormFwdPhase_t phase = CUDNN_NORM_FWD_TRAINING; // Create a Finalize node auto batch_norm_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_NORM_FORWARD_DESCRIPTOR) .setNormalizationMode(normalizationMode) .setNormFwdPhase(phase) .setxDesc(xTensor) .setScaleAndBias(scaleTensor, biasTensor) .setPrevRunningMeanAndVar(inMeanTensor, inVarTensor) .setNextRunningMeanAndVar(outMeanTensor, outVarTensor) .setSavedMeanAndInvVar(savedMeanTensor, savedInvVarTensor) .setEpsilonTensor(epsilonTensor) .setExpDecayFactorTensor(expDecayTensor) .setPeerStatTensor(peerStatTensors) .setyDesc(yTensor) .build(); std::array ops = {&batch_norm_op}; #else std::array ops = {}; #endif auto opGraph = cudnn_frontend::OperationGraphBuilder().setHandle(handle).setOperationGraph(ops.size(), ops.data()).build(); // std::cout << opGraph.describe() << std::endl; cudnn_frontend::EngineConfigList filtered_configs; auto statuses = cudnn_frontend::get_heuristics_list<2>({"heuristics_instant", "heuristics_fallback"}, opGraph, ::AllowAll, filtered_configs, true); // std::cout << "get_heuristics_list Statuses: "; // for (auto i = 0u ; i < statuses.size(); i++) { // std::cout << cudnn_frontend::to_string(statuses[i]) << " "; // } // std::cout << std::endl; // std::cout << "Filter config list has " << filtered_configs.size() << " configurations " << std::endl; // some verbose printing: // std::cout << "Tensor shape: (" << tensorDims[0] << ", " << tensorDims[1] << ", " << 
tensorDims[2] << ", " << // tensorDims[3] << ")" << std::endl; auto plan_builder = [&filtered_configs, &opGraph, &handle]() { for (auto i = 0u; i < filtered_configs.size(); i++) { try { auto plan = cudnn_frontend::ExecutionPlanBuilder() .setHandle(handle) .setEngineConfig(filtered_configs[i], opGraph.getTag()) .build(); return plan; } catch (cudnn_frontend::cudnnException& e) { continue; } } return cudnn_frontend::ExecutionPlanBuilder() .setHandle(handle) .setEngineConfig(filtered_configs[0], opGraph.getTag()) .build(); }; assert(filtered_configs.size() > 0); auto plan = plan_builder(); return plan; } void execute_batch_norm_forward(cudnn_frontend::ExecutionPlan plan, void* xDevPtr, void* yDevPtr, void* scaledevPtr, void* biasdevPtr, void* in_meandevPtr, void* in_vardevPtr, void* out_meandevPtr, void* out_vardevPtr, void* saved_meandevPtr, void* saved_inv_vardevPtr, const std::vector& peer_devPtrs, double epsilon_val, double exponential_decay_factor, size_t peer_size, int rank_id) { // get handle cudnnHandle_t handle_ = torch::native::getCudnnHandle(); // get stream cudaStream_t stream; cudnnGetStream(handle_, &stream); try { // allocate workspace auto workspace_size = plan.getWorkspaceSize(); auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); void* workPtr = nullptr; if (workspace_size > 0) { workPtr = workspace_tensor.data_ptr(); } // first the data pointers std::vector data_ptrs{ xDevPtr, yDevPtr, scaledevPtr, biasdevPtr, in_meandevPtr, in_vardevPtr, out_meandevPtr, out_vardevPtr, saved_meandevPtr, saved_inv_vardevPtr, &epsilon_val, &exponential_decay_factor}; data_ptrs.insert(data_ptrs.end(), peer_devPtrs.begin(), peer_devPtrs.end()); // then the uids std::vector uids; for (size_t i = 100; i < 100 + data_ptrs.size(); ++i) { uids.push_back(i); } auto variantPack = cudnn_frontend::VariantPackBuilder() .setWorkspacePointer(workPtr) .setDataPointers(data_ptrs.size(), data_ptrs.data()) .setUids(uids.size(), 
uids.data()) .build(); // std::cout << "variantPack " << variantPack.describe() << std::endl; cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); // Reset local communication buffer cudaMemsetAsync(peer_devPtrs[rank_id], 0, peer_size * 4, stream); } catch (cudnn_frontend::cudnnException& e) { struct cudaDeviceProp prop; checkCudaErr(cudaGetDeviceProperties(&prop, 0)); if (prop.major == 8) { std::cout << "[ERROR] Exception " << e.what() << std::endl; assert(false); } } } cudnn_frontend::ExecutionPlan run_batch_norm_backward(int64_t* tensorDims, int64_t* perChannelSum, int64_t* epsilon, int64_t* peerDims, cudnnDataType_t data_type) { // get cudnn handle cudnnHandle_t handle = torch::native::getCudnnHandle(); // Creates the necessary tensor descriptors int64_t tensor_stride[4]; int64_t stride[4]; int64_t peer_stride[4]; // NHWC format. GenerateStrides() takes care of this. 
Howeever, tensor dims should still be NCHW generateStrides(tensorDims, tensor_stride, (int64_t)4, CUDNN_TENSOR_NHWC); generateStrides(peerDims, peer_stride, (int64_t)4, CUDNN_TENSOR_NHWC); auto tensor_create = [&tensor_stride, &tensorDims](cudnnDataType_t type, int64_t id) { return cudnn_frontend::TensorBuilder() .setDim(4, tensorDims) .setStrides(4, tensor_stride) .setId(id) .setAlignment(16) .setDataType(type) .build(); }; auto peer_tensor_create = [&peer_stride, &peerDims](cudnnDataType_t type, int64_t id) { return cudnn_frontend::TensorBuilder() .setDim(4, peerDims) .setStrides(4, peer_stride) .setId(id) .setAlignment(16) .setDataType(type) .build(); }; generateStrides(perChannelSum, stride, (int64_t)4, CUDNN_TENSOR_NHWC); auto per_channel_tensor_create = [&stride, &perChannelSum](cudnnDataType_t type, int64_t id) { return cudnn_frontend::TensorBuilder() .setDim(4, perChannelSum) .setStrides(4, stride) .setId(id) .setAlignment(16) .setDataType(type) .build(); }; auto xTensor = tensor_create(data_type, 100); auto dyTensor = tensor_create(data_type, 101); auto scaleTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 102); auto savedMeanTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 103); auto savedInvVarTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 104); auto dxTensor = tensor_create(data_type, 105); auto dScaleTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 106); auto dBiasTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 107); int64_t epsilon_stride[4]; generateStrides(epsilon, epsilon_stride, (int64_t)4, CUDNN_TENSOR_NHWC); auto scalar_tensor_create = [&epsilon_stride, &epsilon](cudnnDataType_t type, int64_t id) { return cudnn_frontend::TensorBuilder() .setDim(4, epsilon) .setStrides(4, epsilon_stride) .setId(id) .setAlignment(16) .setDataType(type) .setByValue(true) .build(); }; auto epsilonTensor = scalar_tensor_create(CUDNN_DATA_DOUBLE, 108); std::vector peerStatTensors; for (size_t i = 109; i < 109 + peerDims[0]; ++i) { 
peerStatTensors.push_back(peer_tensor_create(CUDNN_DATA_FLOAT, i)); } #if (CUDNN_VERSION >= 8500) // Batch normalization cudnnBackendNormMode_t normalizationMode = CUDNN_BATCH_NORM; // Create a Finalize node auto batch_norm_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_NORM_BACKWARD_DESCRIPTOR) .setNormalizationMode(normalizationMode) .setxDesc(xTensor) .setSavedMeanAndInvVar(savedMeanTensor, savedInvVarTensor) .setdyDesc(dyTensor) .setScale(scaleTensor) .setEpsilonTensor(epsilonTensor) .setDScaleAndDBias(dScaleTensor, dBiasTensor) .setdxDesc(dxTensor) .setPeerStatTensor(peerStatTensors) .build(); std::array ops = {&batch_norm_op}; #else std::array ops = {}; #endif auto opGraph = cudnn_frontend::OperationGraphBuilder().setHandle(handle).setOperationGraph(ops.size(), ops.data()).build(); // std::cout << opGraph.describe() << std::endl; cudnn_frontend::EngineConfigList filtered_configs; auto statuses = cudnn_frontend::get_heuristics_list<2>({"heuristics_instant", "heuristics_fallback"}, opGraph, ::AllowAll, filtered_configs, true); auto plan_builder = [&filtered_configs, &opGraph, &handle]() { for (auto i = 0u; i < filtered_configs.size(); i++) { try { auto plan = cudnn_frontend::ExecutionPlanBuilder() .setHandle(handle) .setEngineConfig(filtered_configs[i], opGraph.getTag()) .build(); return plan; } catch (cudnn_frontend::cudnnException& e) { continue; } } return cudnn_frontend::ExecutionPlanBuilder() .setHandle(handle) .setEngineConfig(filtered_configs[0], opGraph.getTag()) .build(); }; assert(filtered_configs.size() > 0); auto plan = plan_builder(); return plan; } void execute_batch_norm_backward(cudnn_frontend::ExecutionPlan plan, void* xDevPtr, void* dyDevPtr, void* scaledevPtr, void* saved_meandevPtr, void* saved_inv_vardevPtr, const std::vector& peer_devPtrs, void* dxDevPtr, void* dscaledevPtr, void* dbiasdevPtr, double epsilon_val, size_t peer_size, int rank_id) { // get handle cudnnHandle_t handle_ = torch::native::getCudnnHandle(); // get 
stream cudaStream_t stream; cudnnGetStream(handle_, &stream); try { // allocate workspace auto workspace_size = plan.getWorkspaceSize(); auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); void* workPtr = nullptr; if (workspace_size > 0) { workPtr = workspace_tensor.data_ptr(); } // create helper arrays std::vector data_ptrs{xDevPtr, dyDevPtr, scaledevPtr, saved_meandevPtr, saved_inv_vardevPtr, dxDevPtr, dscaledevPtr, dbiasdevPtr, &epsilon_val}; data_ptrs.insert(data_ptrs.end(), peer_devPtrs.begin(), peer_devPtrs.end()); std::vector uids; for (size_t i = 100; i < 100 + data_ptrs.size(); ++i) { uids.push_back(i); } auto variantPack = cudnn_frontend::VariantPackBuilder() .setWorkspacePointer(workPtr) .setDataPointers(data_ptrs.size(), data_ptrs.data()) .setUids(uids.size(), uids.data()) .build(); cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); // Reset local communication buffer cudaMemsetAsync(peer_devPtrs[rank_id], 0, peer_size * 4, stream); } catch (cudnn_frontend::cudnnException& e) { struct cudaDeviceProp prop; checkCudaErr(cudaGetDeviceProperties(&prop, 0)); if (prop.major == 8) { std::cout << "[ERROR] Exception " << e.what() << std::endl; assert(false); } } } ================================================ FILE: apex/contrib/csrc/cudnn_gbn/norm_sample.h ================================================ #pragma once /* * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a * copy of this software and associated documentation files (the "Software"), * to deal in the Software without restriction, including without limitation * the rights to use, copy, modify, merge, publish, distribute, sublicense, * and/or sell copies of the Software, and to permit persons to whom the * Software is furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER * DEALINGS IN THE SOFTWARE. */
#pragma once
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include

// NOTE(review): the targets of the ten #include lines above, and the template arguments of
// the `const std::vector& peer_devPtrs` parameters below, are missing from this copy of the
// header (they appear to have been stripped when the text was extracted). Restore them from
// the upstream file before compiling.

/* some helpers */

// Fills strideA[0..nbDims) with packed strides for dims listed in N, C, H, W order in dimA.
// CUDNN_TENSOR_NCHW gives row-major strides; any other format is treated as NHWC (channels-last).
void generateStrides(const int64_t* dimA, int64_t* strideA, int64_t nbDims, cudnnTensorFormat_t filterFormat);

// If `code` is non-zero, prints the location, the numeric code, its description and the failing
// expression text, and returns 1; returns 0 otherwise. Called by the checkCudaErr macro below.
int64_t checkCudaError(cudaError_t code, const char* expr, const char* file, int line);

// cuDNN counterpart of checkCudaError. Called by the checkCudnnErr macro below.
int64_t checkCudnnError(cudnnStatus_t code, const char* expr, const char* file, int line);

// Evaluates a CUDA runtime call, reports a failure through checkCudaError, then asserts success.
// When NDEBUG is defined the assert is compiled out, so a failure is only printed.
#define checkCudaErr(...)                                                        \
  do {                                                                           \
    int64_t err = checkCudaError(__VA_ARGS__, #__VA_ARGS__, __FILE__, __LINE__); \
    assert(err == 0);                                                            \
  } while (0)

// cuDNN counterpart of checkCudaErr (same NDEBUG caveat).
#define checkCudnnErr(...)                                                        \
  do {                                                                            \
    int64_t err = checkCudnnError(__VA_ARGS__, #__VA_ARGS__, __FILE__, __LINE__); \
    assert(err == 0);                                                             \
  } while (0)

/**
 * @brief Build (but do not run) a cuDNN execution plan for a Group BN forward-training pass,
 *        i.e. batch norm whose statistics are exchanged through peer stat tensors.
 *
 * @param tensorDims dims of the input x and output y, listed in (N, C, H, W) order. The
 *        implementation always generates NHWC (channels-last) strides for them.
 * @param perChannelSum dims, (1, C, 1, 1), of the per-channel float tensors: scale, bias,
 *        running mean/var (in and out) and saved mean / inverse variance.
 * @param epsilon dims, (1, 1, 1, 1), of the by-value double scalars (epsilon and the
 *        exponential decay factor).
 * @param peerDims peerDims[0] peer stat tensors are created and their strides are derived from
 *        peerDims. The caller in cudnn_gbn.cpp passes {bn_group, 4 * C, 1, 1}.
 * @param in_out_data_type element type of x and y.
 * @return the first execution plan, in heuristics order, that builds successfully.
 */
cudnn_frontend::ExecutionPlan run_batch_norm_forward(int64_t* tensorDims, int64_t* perChannelSum, int64_t* epsilon,
                                                     int64_t* peerDims, cudnnDataType_t in_out_data_type);

/**
 * @brief Execute a plan built by run_batch_norm_forward.
 *
 * The pointers are bound to the plan's tensors by position, so the argument order must match
 * the tensor ids assigned in run_batch_norm_forward.
 *
 * @param plan plan returned by run_batch_norm_forward
 * @param xDevPtr input tensor device pointer
 * @param yDevPtr output tensor device pointer
 * @param scaledevPtr input scale device pointer for BN scaling
 * @param biasdevPtr input bias device pointer for BN bias
 * @param in_meandevPtr input running mean device pointer
 * @param in_vardevPtr input running variance device pointer
 * @param out_meandevPtr output running mean device pointer (cudnn_gbn.cpp passes the same
 *        buffer as in_meandevPtr)
 * @param out_vardevPtr output running variance device pointer (cudnn_gbn.cpp passes the same
 *        buffer as in_vardevPtr)
 * @param saved_meandevPtr saved mean device pointer for BN backward
 * @param saved_inv_vardevPtr saved inverse variance device pointer for BN backward
 * @param peer_devPtrs peer stat (communication) buffers, one per peer stat tensor of the plan
 * @param epsilon_val epsilon value as a double, passed to cuDNN by value
 * @param exponential_decay_factor exponential decay factor as a double, passed to cuDNN by value
 *        (cudnn_gbn.cpp passes the BN momentum)
 * @param peer_size number of float elements of this rank's buffer that are zeroed after execution
 * @param rank_id index into peer_devPtrs of this rank's own buffer, which is reset after execution
 **/
void execute_batch_norm_forward(cudnn_frontend::ExecutionPlan plan, void* xDevPtr, void* yDevPtr, void* scaledevPtr,
                                void* biasdevPtr, void* in_meandevPtr, void* in_vardevPtr, void* out_meandevPtr,
                                void* out_vardevPtr, void* saved_meandevPtr, void* saved_inv_vardevPtr,
                                const std::vector& peer_devPtrs, double epsilon_val, double exponential_decay_factor,
                                size_t peer_size, int rank_id);

/**
 * @brief Build (but do not run) a cuDNN execution plan for a Group BN backward pass.
 *
 * @param tensorDims dims of x, dy and dx, listed in (N, C, H, W) order. The implementation always
 *        generates NHWC (channels-last) strides for them.
 * @param perChannelSum dims, (1, C, 1, 1), of the per-channel float tensors: scale, saved mean /
 *        inverse variance, dscale and dbias.
 * @param epsilon dims, (1, 1, 1, 1), of the by-value double epsilon scalar.
 * @param peerDims dims of each peer stat tensor; peerDims[0] of them are created. The caller in
 *        cudnn_gbn.cpp passes {bn_group, 4 * C, 1, 1}.
 * @param data_type element type of x, dy and dx.
 * @return the first execution plan, in heuristics order, that builds successfully.
 */
cudnn_frontend::ExecutionPlan run_batch_norm_backward(int64_t* tensorDims, int64_t* perChannelSum, int64_t* epsilon,
                                                      int64_t* peerDims, cudnnDataType_t data_type);

/**
 * @brief Execute a plan built by run_batch_norm_backward.
 *
 * The pointers are bound to the plan's tensors by position, so the argument order must match
 * the tensor ids assigned in run_batch_norm_backward.
 *
 * @param plan plan returned by run_batch_norm_backward
 * @param xDevPtr forward input tensor device pointer
 * @param dyDevPtr device pointer to the gradient of the forward output
 * @param scaledevPtr scale device pointer
 * @param saved_meandevPtr saved mean from the forward pass
 * @param saved_inv_vardevPtr saved inverse variance from the forward pass
 * @param peer_devPtrs peer stat (communication) buffers, one per peer stat tensor of the plan
 * @param dxDevPtr output: gradient with respect to x
 * @param dscaledevPtr output: gradient with respect to scale
 * @param dbiasdevPtr output: gradient with respect to bias
 * @param epsilon_val epsilon value as a double, passed to cuDNN by value
 * @param peer_size number of float elements of this rank's buffer that are zeroed after execution
 * @param rank_id index into peer_devPtrs of this rank's own buffer, which is reset after execution
 */
void execute_batch_norm_backward(cudnn_frontend::ExecutionPlan plan, void* xDevPtr, void* dyDevPtr, void* scaledevPtr,
                                 void* saved_meandevPtr, void* saved_inv_vardevPtr, const std::vector& peer_devPtrs,
                                 void* dxDevPtr, void* dscaledevPtr, void* dbiasdevPtr, double epsilon_val,
                                 size_t peer_size, int rank_id);
================================================ FILE: apex/contrib/csrc/fmha/fmha_api.cpp ================================================ /****************************************************************************** * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
* * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright * notice, this list of conditions and the following disclaimer. * * Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. * * Neither the name of the NVIDIA CORPORATION nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
* ******************************************************************************/ #include #include #include "fmha.h" extern at::Tensor& mha_fill(at::Tensor& self, const at::Tensor& start_index); void set_params(Fused_multihead_attention_fprop_params& params, // sizes const size_t b, const size_t s, const size_t h, const size_t d, // device pointers void* qkv_packed_d, void* cu_seqlens_d, void* o_packed_d, void* s_d, float p_dropout) { Data_type acc_type = DATA_TYPE_FP32; Data_type data_type = DATA_TYPE_FP16; // Reset the parameters memset(¶ms, 0, sizeof(params)); // Set the pointers and strides. params.qkv_ptr = qkv_packed_d; params.qkv_stride_in_bytes = get_size_in_bytes(h * 3 * d, data_type); params.o_ptr = o_packed_d; params.o_stride_in_bytes = get_size_in_bytes(h * d, data_type); params.cu_seqlens = static_cast(cu_seqlens_d); // S = softmax(P) params.s_ptr = s_d; params.s_stride_in_bytes = get_size_in_bytes(b * h * s, data_type); // Set the dimensions. params.b = b; params.h = h; params.s = s; params.d = d; // Set the different scale values. const float scale_bmm1 = 1.f / sqrtf(d); constexpr float scale_softmax = 1.f; constexpr float scale_bmm2 = 1.f; set_alpha(params.scale_bmm1, scale_bmm1, data_type); set_alpha(params.scale_softmax, scale_softmax, acc_type); set_alpha(params.scale_bmm2, scale_bmm2, data_type); // Set this to probability of keeping an element to simplify things. 
params.p_dropout = 1.f - p_dropout; params.rp_dropout = 1.f / params.p_dropout; TORCH_CHECK(p_dropout < 1.f); set_alpha(params.scale_dropout, params.rp_dropout, data_type); } std::vector mha_fwd( const at::Tensor& qkv, // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i const at::Tensor& cu_seqlens, // b+1 const float p_dropout, const int max_seq_len, const bool is_training, const bool is_nl, const bool zero_tensors, c10::optional gen_) { using namespace torch::indexing; auto dprops = at::cuda::getCurrentDeviceProperties(); TORCH_CHECK((dprops->major == 8 && dprops->minor == 0) || (dprops->major == 9 && dprops->minor == 0) || (dprops->major == 10 && dprops->minor == 0) || (dprops->major == 12 && dprops->minor == 0)); auto stream = at::cuda::getCurrentCUDAStream().stream(); Launch_params launch_params(dprops, stream, is_training, is_nl); int seq_len = 512; auto launch = &run_fmha_fp16_512_64_sm80; if (max_seq_len <= 128) { seq_len = 128; launch = &run_fmha_fp16_128_64_sm80; } else if (max_seq_len <= 256) { seq_len = 256; launch = &run_fmha_fp16_256_64_sm80; } else if (max_seq_len <= 384) { seq_len = 384; launch = &run_fmha_fp16_384_64_sm80; } else if (max_seq_len <= 512) { seq_len = 512; launch = &run_fmha_fp16_512_64_sm80; } else { TORCH_CHECK(false); } TORCH_CHECK(qkv.is_cuda()) TORCH_CHECK(cu_seqlens.is_cuda()) TORCH_CHECK(qkv.is_contiguous()) TORCH_CHECK(cu_seqlens.is_contiguous()) TORCH_CHECK(cu_seqlens.dim() == 1); TORCH_CHECK(qkv.dim() == 4); const auto sizes = qkv.sizes(); TORCH_CHECK(sizes[THREE_DIM] == 3); const int batch_size = cu_seqlens.numel() - 1; const int total = sizes[TOTAL_DIM]; const int num_heads = sizes[H_DIM]; const int head_size = sizes[D_DIM]; TORCH_CHECK(batch_size > 0); TORCH_CHECK(head_size == 64); auto opts = qkv.options(); auto ctx = torch::empty({total, num_heads, head_size}, opts); auto s = torch::empty({batch_size, num_heads, seq_len, seq_len}, opts); if (zero_tensors) { mha_fill(ctx, cu_seqlens.index({Slice(-1, 
None)})); } auto gen = at::get_generator_or_default(gen_, at::cuda::detail::getDefaultCUDAGenerator()); set_params(launch_params.params, batch_size, seq_len, num_heads, head_size, qkv.data_ptr(), cu_seqlens.data_ptr(), ctx.data_ptr(), s.data_ptr(), p_dropout); launch(launch_params, /*configure=*/true); // number of times random will be generated per thread, to offset philox counter in thc random // state int64_t counter_offset = launch_params.elts_per_thread; at::PhiloxCudaState rng_engine_inputs; if (is_training) { // See Note [Acquire lock when using random generators] std::lock_guard lock(gen->mutex_); launch_params.params.philox_args = gen->philox_cuda_state(counter_offset); } launch(launch_params, /*configure=*/false); return {ctx, s}; } std::vector mha_bwd( const at::Tensor& dout, // total x num_heads, x head_size const at::Tensor& qkv, // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i at::Tensor& softmax, // b x h x s x s softmax and dmask - will be overwritten with dP const at::Tensor& cu_seqlens, // b+1 const float p_dropout, // probability to drop const int max_seq_len, // max sequence length to choose the kernel const bool zero_tensors) { using namespace torch::indexing; auto dprops = at::cuda::getCurrentDeviceProperties(); TORCH_CHECK((dprops->major == 8 && dprops->minor == 0) || (dprops->major == 9 && dprops->minor == 0) || (dprops->major == 10 && dprops->minor == 0) || (dprops->major == 12 && dprops->minor == 0)); int seq_len = 512; auto launch = &run_fmha_dgrad_fp16_512_64_sm80; if (max_seq_len <= 128) { seq_len = 128; launch = &run_fmha_dgrad_fp16_128_64_sm80; } else if (max_seq_len <= 256) { seq_len = 256; launch = &run_fmha_dgrad_fp16_256_64_sm80; } else if (max_seq_len <= 384) { seq_len = 384; launch = &run_fmha_dgrad_fp16_384_64_sm80; } else if (max_seq_len <= 512) { seq_len = 512; launch = &run_fmha_dgrad_fp16_512_64_sm80; } else { TORCH_CHECK(false); } auto stream = at::cuda::getCurrentCUDAStream().stream(); 
TORCH_CHECK(qkv.dtype() == torch::kFloat16); TORCH_CHECK(dout.dtype() == torch::kFloat16); TORCH_CHECK(softmax.dtype() == torch::kFloat16); TORCH_CHECK(cu_seqlens.dtype() == torch::kInt32); TORCH_CHECK(qkv.is_cuda()); TORCH_CHECK(cu_seqlens.is_cuda()); TORCH_CHECK(qkv.is_contiguous()); TORCH_CHECK(cu_seqlens.is_contiguous()); TORCH_CHECK(cu_seqlens.dim() == 1); TORCH_CHECK(qkv.dim() == 4); const auto sizes = qkv.sizes(); TORCH_CHECK(sizes[THREE_DIM] == 3); const int batch_size = cu_seqlens.numel() - 1; const int num_heads = sizes[H_DIM]; const int head_size = sizes[D_DIM]; TORCH_CHECK(batch_size > 0); TORCH_CHECK(head_size == 64); auto dqkv = torch::empty_like(qkv); if (zero_tensors) { mha_fill(dqkv, cu_seqlens.index({Slice(-1, None)})); } Fused_multihead_attention_fprop_params params; set_params(params, batch_size, seq_len, num_heads, head_size, qkv.data_ptr(), cu_seqlens.data_ptr(), dout.data_ptr(), // we set o_ptr to dout softmax.data_ptr(), // softmax gets overwritten by dP! p_dropout); // we're re-using these scales Data_type acc_type = DATA_TYPE_FP32; set_alpha(params.scale_bmm1, 1.f, acc_type); set_alpha(params.scale_softmax, 1.f / sqrtf(head_size), acc_type); set_alpha(params.scale_bmm2, 1.f, DATA_TYPE_FP16); params.dqkv_ptr = dqkv.data_ptr(); launch(params, stream); return {dqkv, softmax}; } std::vector mha_bwd_nl( const at::Tensor& dout, // total x num_heads, x head_size const at::Tensor& qkv, // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i at::Tensor& softmax, // b x h x s x s softmax and dmask - will be overwritten with dP const at::Tensor& cu_seqlens, // b+1 const float p_dropout, // probability to drop const int max_seq_len, // max sequence length to choose the kernel const bool zero_tensors) { auto stream = at::cuda::getCurrentCUDAStream().stream(); TORCH_CHECK(qkv.is_cuda()) TORCH_CHECK(cu_seqlens.is_cuda()) TORCH_CHECK(qkv.is_contiguous()) TORCH_CHECK(cu_seqlens.is_contiguous()) TORCH_CHECK(cu_seqlens.dim() == 1); 
TORCH_CHECK(qkv.dim() == 4); const auto sizes = qkv.sizes(); TORCH_CHECK(sizes[THREE_DIM] == 3); const int batch_size = cu_seqlens.numel() - 1; const int total = sizes[TOTAL_DIM]; const int num_heads = sizes[H_DIM]; const int head_size = sizes[D_DIM]; TORCH_CHECK(batch_size > 0); TORCH_CHECK(head_size == 64); int seq_len = 512; auto launch = &run_fmha_dgrad_fp16_512_64_sm80_nl; auto opts = qkv.options(); auto dqkv = torch::empty_like(qkv); if (zero_tensors) { dqkv.zero_(); } int num_chunks = 2; if (batch_size == 1) { num_chunks = 4; } else if (batch_size == 2) { num_chunks = 3; } auto dkv = torch::empty({total, num_chunks, 2, num_heads, head_size}, opts); Fused_multihead_attention_fprop_params params; set_params(params, batch_size, seq_len, num_heads, head_size, qkv.data_ptr(), cu_seqlens.data_ptr(), dout.data_ptr(), // o_ptr = dout softmax.data_ptr(), // softmax gets overwritten by dP! p_dropout); params.dkv_ptr = dkv.data_ptr(); Data_type acc_type = DATA_TYPE_FP32; set_alpha(params.scale_bmm1, 1.f, acc_type); set_alpha(params.scale_softmax, 1.f / sqrtf(head_size), acc_type); set_alpha(params.scale_bmm2, 1.f, DATA_TYPE_FP16); params.dqkv_ptr = dqkv.data_ptr(); launch(params, num_chunks, stream); // SPLIT-K reduction of num_chunks dK, dV parts // The equivalent of the following Pytorch code: // using namespace torch::indexing; // at::Tensor view_out = dqkv.index({Slice(), Slice(1, None, None)}); // torch::sum_out(view_out, dkv, 1); const int hidden_size = num_heads * head_size; fmha_run_noloop_reduce(dqkv.data_ptr(), dkv.data_ptr(), cu_seqlens.data_ptr(), hidden_size, batch_size, total, num_chunks, stream); return {dqkv, softmax, dkv}; } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.doc() = "Fused Multi-head Self-attention for BERT"; m.def("fwd", &mha_fwd, "Forward pass", py::call_guard()); m.def("bwd", &mha_bwd, "Backward pass", py::call_guard()); m.def("bwd_nl", &mha_bwd_nl, "Backward pass (small-batch)", py::call_guard()); } 
================================================ FILE: apex/contrib/csrc/fmha/src/fmha/gemm.h ================================================ /****************************************************************************** * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright * notice, this list of conditions and the following disclaimer. * * Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. * * Neither the name of the NVIDIA CORPORATION nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
*
 ******************************************************************************/
#pragma once

// NOTE(review): the operand of this #include, and the <...> template parameter lists /
// template arguments throughout this header (e.g. `template struct Fragment_base_`,
// `Div_up::VALUE`, `reinterpret_cast(...)`), appear to be missing from this copy of the
// file; they are left exactly as found and should be restored from upstream.
#include

// Integer ceiling division: (m + n - 1) / n, i.e. m / n rounded up for non-negative m and positive n.
#define FMHA_DIV_UP(m, n) (((m) + (n) - 1) / (n))

namespace fmha {

////////////////////////////////////////////////////////////////////////////////////////////////////

// Compile-time description of a fragment: its element type, element count and element
// width, plus the derived number of 32-bit registers, size in bytes and alignment.
template struct Fragment_base_ {
  // The data type.
  using Data_type = Data_type_;
  // default input type
  using Input_type_ = Data_type_;
  // Does it store the array of elements.
  enum { HAS_ELTS = BITS_PER_ELT_ >= 8 };
  // The number of elements.
  enum { NUM_ELTS = NUM_ELTS_ };
  // The size of element in bits.
  enum { BITS_PER_ELT = BITS_PER_ELT_ };
  // The size in bytes of a single register.
  enum { BYTES_PER_REG = 4 };
  // The size in bits.
  enum { BITS_PER_REG = BYTES_PER_REG * 8 };
  // The number of registers needed to store the fragment.
  enum { NUM_REGS = Div_up::VALUE };
  // The size in bytes (as returned by sizeof(Fragment_base<>)).
  enum { SIZE_IN_BYTES = NUM_REGS * BYTES_PER_REG };
  // The alignment: ALIGNMENT_ when forced (> 0), otherwise a value computed by Min.
  enum { ALIGNMENT = ALIGNMENT_ > 0 ? ALIGNMENT_ : Min::VALUE };
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// A fragment of NUM_ELTS_ elements of type Data_type_, stored in an array of 32-bit
// registers (regs_) and viewable either per register or per element.
template <
    // The type of the elements.
    typename Data_type_,
    // The number of elements.
    int NUM_ELTS_,
    // The alignment if you want to force a value -- use 0 otherwise.
    int ALIGNMENT_ = 0,
    // The base class.
    typename Base_ = Fragment_base_ >
struct alignas(static_cast(Base_::ALIGNMENT)) Fragment : public Base_ {
  // The size of a load/store.
  enum { BYTES_PER_LOAD_STORE = Base_::NUM_REGS * sizeof(uint32_t) };

  // Clear the fragment. Using PTX in that code seems to produce better SASS...
  inline __device__ void clear() {
#pragma unroll
    for (int ii = 0; ii < Base_::NUM_REGS; ++ii) {
      asm volatile("mov.u32 %0, 0; \n" : "=r"(this->reg(ii)) :);
    }
  }

  // Immutable access to a register.
  inline __device__ const uint32_t& reg(int ii) const { return this->regs_[ii]; }

  // Mutable access to a register.
  inline __device__ uint32_t& reg(int ii) { return this->regs_[ii]; }

  // The backing storage: NUM_REGS 32-bit words.
  uint32_t regs_[Base_::NUM_REGS];

  // Immutable access to the elements.
  inline __device__ const Data_type_& elt(int ii) const { return reinterpret_cast(&this->regs_[0])[ii]; }

  // Mutable access to the elements.
  inline __device__ Data_type_& elt(int ii) { return reinterpret_cast(&this->regs_[0])[ii]; }

  // Immutable access to the elements with a cast.
  template inline __device__ const Cast_type& elt_as(int ii) const { return reinterpret_cast(&this->regs_[0])[ii]; }

  // Mutable access to the elements with a cast.
  template inline __device__ Cast_type& elt_as(int ii) { return reinterpret_cast(&this->regs_[0])[ii]; }

  // Add another fragment, element by element.
  inline __device__ void add(const Fragment& other) {
#pragma unroll
    for (int ii = 0; ii < NUM_ELTS_; ++ii) {
      this->elt(ii) += other.elt(ii);
    }
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Fragment holding the A operand of Fragment_accumulator::mma.
template struct Fragment_a : public Fragment {};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Fragment holding the B operand of Fragment_accumulator::mma.
template struct Fragment_b : public Fragment {};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Accumulator fragment: mma() accumulates into 8 float elements ("+f" operands) using
// f16 x f16 -> f32 tensor-core MMAs.
struct Fragment_accumulator : public Fragment {
  // The base class.
  using Base = Fragment;

  // Add two fragments.
  template inline __device__ void add(const Other_fragment_& other) {
    for (int ii = 0; ii < Base::NUM_ELTS; ++ii) {
      this->elt(ii) = this->elt(ii) + other.elt(ii);
    }
  }

  // Do the HMMA.
  // Two mma.sync m16n8k16 instructions (row-major A, column-major B) share the same A
  // registers: elts 0-3 accumulate with B regs 0-1 and elts 4-7 with B regs 2-3, so the
  // pair produces 16 x 16 outputs (two n = 8 halves).
  template inline __device__ void mma(const Fragment_a& a, const Fragment_b& b) {
    asm volatile(
        "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \n"
        " {%0, %1, %2, %3}, \n"
        " {%4, %5, %6, %7}, \n"
        " {%8, %9}, \n"
        " {%0, %1, %2, %3}; \n"
        : "+f"(elt(0)), "+f"(elt(1)), "+f"(elt(2)), "+f"(elt(3))
        : "r"(a.reg(0)), "r"(a.reg(1)), "r"(a.reg(2)), "r"(a.reg(3)), "r"(b.reg(0)), "r"(b.reg(1)));
    asm volatile(
        "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \n"
        " {%0, %1, %2, %3}, \n"
        " {%4, %5, %6, %7}, \n"
        " {%8, %9}, \n"
        " {%0, %1, %2, %3}; \n"
        : "+f"(elt(4)), "+f"(elt(5)), "+f"(elt(6)), "+f"(elt(7))
        : "r"(a.reg(0)), "r"(a.reg(1)), "r"(a.reg(2)), "r"(a.reg(3)), "r"(b.reg(2)), "r"(b.reg(3)));
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Clear every fragment of an M x N array of fragments.
template inline __device__ void clear(Fragment (&frag)[M][N]) {
#pragma unroll
  for (int mi = 0; mi < M; ++mi) {
#pragma unroll
    for (int ni = 0; ni < N; ++ni) {
      frag[mi][ni].clear();
    }
  }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Primary template: intentionally empty (no apply()), so only specializations are usable.
template struct Clear_accumulator {};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Specialization whose apply() zeroes an M x N accumulator array via fmha::clear.
// (Its specialization arguments are not visible in this copy.)
template struct Clear_accumulator {
  template static inline __device__ void apply(Acc (&acc)[M][N], bool = false) { fmha::clear(acc); }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Warp-level GEMM: for every (mi, ni), acc[mi][ni].mma(a[mi], b[ni]).
template inline __device__ void gemm(Acc (&acc)[M][N], const A (&a)[M], const B (&b)[N]) {
#pragma unroll
  for (int mi = 0; mi < M; ++mi) {
#pragma unroll
    for (int ni = 0; ni < N; ++ni) {
      acc[mi][ni].mma(a[mi], b[ni]);
    }
  }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Compile-time description of a CTA tile: an M x N x K problem split across
// WARPS_M x WARPS_N x WARPS_K warps of 32 threads.
template <
    // The number of rows in the CTA tile.
    int M_,
    // The number of cols in the CTA tile.
    int N_,
    // The number of elements in the K dimension of the GEMM loop.
    int K_,
    // The number of rows of warps.
    int WARPS_M_,
    // The number of cols of warps.
    int WARPS_N_,
    // The number of warps in the K dimension of the GEMM loop.
    int WARPS_K_>
struct Cta_tile_ {
  enum { M = M_, N = N_, K = K_ };
  // The number of warps.
  enum { WARPS_M = WARPS_M_, WARPS_N = WARPS_N_, WARPS_K = WARPS_K_ };
  // The number of warps per CTA.
  enum { WARPS_PER_CTA = WARPS_M * WARPS_N * WARPS_K };
  // The number of threads per warp.
  enum { THREADS_PER_WARP = 32 };
  // The number of threads per CTA.
  enum { THREADS_PER_CTA = WARPS_PER_CTA * THREADS_PER_WARP };
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// MMA counts for a Cta_tile, built from 16 x 16 x 16 warp-level MMAs.
template struct Hmma_tile {
  // The number of elements computed with a single warp-MMA.
  enum { M_PER_MMA = 16, N_PER_MMA = 16, K_PER_MMA = 16 };

  // The number of elements computed with a single CTA-MMA.
  enum {
    M_PER_MMA_PER_CTA = M_PER_MMA * Cta_tile::WARPS_M,
    N_PER_MMA_PER_CTA = N_PER_MMA * Cta_tile::WARPS_N,
    K_PER_MMA_PER_CTA = K_PER_MMA * Cta_tile::WARPS_K
  };

  // The number of MMAs needed to compute the GEMM.
  enum {
    MMAS_M = Div_up::VALUE,
    MMAS_N = Div_up::VALUE,
    MMAS_K = Div_up::VALUE,
  };

  // The number of elements computed per warp.
  enum {
    M_PER_WARP = MMAS_M * M_PER_MMA,
    N_PER_WARP = MMAS_N * N_PER_MMA,
    K_PER_WARP = MMAS_K * K_PER_MMA,
  };
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Element types of the FMHA GEMMs: 16-bit A/B/C storage with float accumulation and epilogue.
using A_type = uint16_t;
using B_type = uint16_t;
using C_type = uint16_t;
using Accumulator_type = float;
using Epilogue_type = float;

constexpr int BITS_PER_ELEMENT_A = sizeof(A_type) * 8;
constexpr int BITS_PER_ELEMENT_B = sizeof(B_type) * 8;
constexpr int BITS_PER_ELEMENT_C = sizeof(C_type) * 8;

////////////////////////////////////////////////////////////////////////////////////////////////////

// Alias of Cta_tile_ (its template parameter list is not visible in this copy).
template using Cta_tile_extd = Cta_tile_;

////////////////////////////////////////////////////////////////////////////////////////////////////

// Cta_tile_extd with a padded K, per its name; the leading template arguments are not
// visible in this copy.
template using Cta_tile_with_k_with_padding = Cta_tile_extd::VALUE, Cta_tile_::WARPS_M, Cta_tile_::WARPS_N, Cta_tile_::WARPS_K>;

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace fmha
================================================
FILE: apex/contrib/csrc/fmha/src/fmha/gmem_tile.h
================================================
/******************************************************************************
 * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the
 *       names of its contributors may be used to endorse or promote products
 *       derived from this software without specific prior written permission.
* * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * ******************************************************************************/ #pragma once namespace fmha { //////////////////////////////////////////////////////////////////////////////////////////////////// template < // The dimensions of the tile computed by the CTA. typename Cta_tile, // The number of bits per element. int BITS_PER_ELEMENT, // The number of rows of Q, K or V loaded by this tile. int ROWS, // The number of columns. int COLS, // The number of matrics. int NUM_MATS = 3> struct Gmem_tile_qkv { // The size of each LDG. enum { BYTES_PER_LDG = 16 }; // The size of a row in bytes. enum { BYTES_PER_ROW = COLS * BITS_PER_ELEMENT / 8 }; // The number of threads to load a "row" of the matrix. enum { THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_LDG }; // The number of "rows" loaded per LDG. enum { ROWS_PER_LDG = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW }; // The number of LDGs needed to load a chunk of the Q matrix. enum { LDGS = fmha::Div_up::VALUE }; // Ctor. 
template inline __device__ Gmem_tile_qkv(const Params& params, const int qkv_offset, const BInfo& binfo, const int tidx) : params_qkv_stride_in_bytes_(params.qkv_stride_in_bytes), actual_seqlen(binfo.actual_seqlen), qkv_ptr_(reinterpret_cast(params.qkv_ptr)) { // Compute the position in the sequence (within the CTA for the moment). int row = tidx / THREADS_PER_ROW; // Compute the position of the thread in the row. int col = tidx % THREADS_PER_ROW; // Store the row as we need it to disable the loads. row_ = row; // The row offset in the batched GEMM. For each seq element, we store QKV in that order. int64_t row_offset = (int64_t)row * params.qkv_stride_in_bytes; // Add the block index. row_offset += (int64_t)((binfo.sum_s * NUM_MATS + qkv_offset) * binfo.h + binfo.bidh) * BYTES_PER_ROW; // Assemble the final pointer. qkv_ptr_ += row_offset + col * BYTES_PER_LDG; } // Store data to shared memory. template inline __device__ void commit(Smem_tile& smem_tile) { smem_tile.store(fetch_); } // Load data from memory. template inline __device__ void load(Smem_tile& smem_tile) { const void* ptrs[LDGS]; uint32_t preds[LDGS]; #pragma unroll for (int ii = 0; ii < LDGS; ++ii) { ptrs[ii] = qkv_ptr_ + (int64_t)ii * ROWS_PER_LDG * params_qkv_stride_in_bytes_; preds[ii] = ((row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen)); fetch_[ii] = make_uint4(0, 0, 0, 0); } // not packing predicates removes restrictions (e.g. FP16 384, 4 warps) Ldg_functor fct(fetch_, ptrs); #pragma unroll for (int ii = 0; ii < LDGS; ++ii) { fct.load(ii, preds[ii]); } } // Store data to memory. inline __device__ void store(const uint4 (&data)[LDGS]) { #pragma unroll for (int ii = 0; ii < LDGS; ++ii) { char* ptr = qkv_ptr_ + (int64_t)ii * ROWS_PER_LDG * params_qkv_stride_in_bytes_; if ((row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen)) { fmha::stg(ptr, data[ii]); } } } // Move the pointer to the next location. 
inline __device__ void move() { qkv_ptr_ += (int64_t)ROWS * params_qkv_stride_in_bytes_; actual_seqlen -= ROWS; } inline __device__ void move(int steps) { qkv_ptr_ += (int64_t)ROWS * params_qkv_stride_in_bytes_ * steps; actual_seqlen -= ROWS * steps; } // The stride between rows for the QKV matrice. int64_t params_qkv_stride_in_bytes_; // The pointer. char* qkv_ptr_; // The fetch registers. uint4 fetch_[LDGS]; // Keep track of the row the thread is processing as we move the tile. int row_; // The length of the sequence loaded by that memory tile. int actual_seqlen; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct Gmem_tile_o { // The mma tile. using Mma_tile = fmha::Hmma_tile; // The size of each element. enum { BYTES_PER_ELEMENT = 2 }; // The size of a row in bytes. enum { BYTES_PER_ROW = Cta_tile::N * BYTES_PER_ELEMENT }; // The number of threads to store a "row" of the matrix. enum { THREADS_PER_ROW = 16 }; // The size of each STG. enum { BYTES_PER_STG = BYTES_PER_ROW / THREADS_PER_ROW }; // The number of "rows" stored per iteration of the loop. The output of 1 MMA. enum { ROWS = Cta_tile::M }; // The number of "rows" stored per iteration of the loop. The output of 1 MMA. enum { ROWS_PER_LOOP = ROWS <= 64 ? ROWS : (int)Mma_tile::M_PER_MMA_PER_CTA }; // The number of outter loop for the stores. enum { LOOPS = ROWS / ROWS_PER_LOOP }; // The number of "rows" stored per STG. enum { ROWS_PER_STG = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW }; // Do we have to guard against partial writes/reads. enum { HAS_INCOMPLETE_STG = Cta_tile::M % ROWS_PER_STG != 0 }; // The number of STGs needed to store a chunk of the Q matrix. enum { STGS_PER_LOOP = fmha::Div_up::VALUE }; // The number of STGs needed to store a chunk of the Q matrix in total. enum { STGS = STGS_PER_LOOP * LOOPS }; // Ctor. 
template inline __device__ Gmem_tile_o(const Params& params, const BInfo& binfo, int tidx) : params_o_stride_in_bytes_(params.o_stride_in_bytes), actual_seqlen_(binfo.actual_seqlen), o_ptr_(reinterpret_cast(params.o_ptr)) { // Compute the position in the sequence (within the CTA for the moment). int row = tidx / THREADS_PER_ROW; // Compute the position of the thread in the row. int col = tidx % THREADS_PER_ROW; // Store the row as we need it to disable loads. row_ = row; // The row offset in the batched GEMM. int64_t row_offset = (int64_t)row * params.o_stride_in_bytes + binfo.bidx * BYTES_PER_ROW; // Assemble the final pointer. o_ptr_ += row_offset + col * BYTES_PER_STG; // Is that thread active on the last STG? if (HAS_INCOMPLETE_STG) { is_active_for_last_stg_ = row + (STGS - 1) * ROWS_PER_STG < Cta_tile::M; } } // Store data to global memory. inline __device__ void store(const uint4 (&src)[STGS_PER_LOOP], int mi) { #pragma unroll for (int ii = 0; ii < STGS_PER_LOOP; ++ii) { int jj = mi * STGS_PER_LOOP + ii; if (this->row_ + jj * ROWS_PER_STG >= this->actual_seqlen_) { break; } float x = reinterpret_cast(src[ii].x); float y = reinterpret_cast(src[ii].y); float z = reinterpret_cast(src[ii].z); float w = reinterpret_cast(src[ii].w); uint2 out = float4_to_half4(x, y, z, w); if (!HAS_INCOMPLETE_STG || (jj < STGS - 1 || this->is_active_for_last_stg_)) { fmha::stg(this->o_ptr_ + jj * ROWS_PER_STG * this->params_o_stride_in_bytes_, out); } } } // Move the pointer to the next location. inline __device__ void move() { row_ += ROWS; o_ptr_ += (int64_t)ROWS * params_o_stride_in_bytes_; } inline __device__ void move(const int steps) { row_ += ROWS * steps; o_ptr_ += (int64_t)ROWS * params_o_stride_in_bytes_ * steps; } // The stride between rows for the QKV matrice. int64_t params_o_stride_in_bytes_; // The pointer. char* o_ptr_; // Is the thread active for the last STG? int is_active_for_last_stg_; // Keep track of the row to disable loads. 
int row_; // The length of the sequence loaded by that memory tile. int actual_seqlen_; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct Gmem_tile_mma_sd { // The mma tile. using Mma_tile = fmha::Hmma_tile; // Each STG stores 8 elements. enum { BYTES_PER_STG = BYTES_PER_ELEMENT * 8 }; // The number of MMAs in the M dimension. enum { MMAS_M = Mma_tile::MMAS_M }; // The number of MMAs in the N dimension. enum { MMAS_N = Mma_tile::MMAS_N }; // The number of rows computed per MMA per thread block. enum { M_PER_MMA_PER_CTA = Mma_tile::M_PER_MMA_PER_CTA }; // The number of cols computed per MMA per thread block. enum { N_PER_MMA_PER_CTA = Mma_tile::N_PER_MMA_PER_CTA }; // The number of threads per block. enum { THREADS_PER_CTA = Cta_tile::THREADS_PER_CTA }; // The size of each row in bytes. I.e. how many bytes are stored per STG. enum { BYTES_PER_ROW = THREADS_PER_CTA * BYTES_PER_STG }; // The fixed sequence length. enum { SEQLEN = Cta_tile::N }; // The distance between two blocks (in bytes). enum { BLOCK_STRIDE_BYTES = SEQLEN * SEQLEN * BYTES_PER_ELEMENT }; // The distance between elements stored per loop (in bytes). enum { LOOP_STRIDE_BYTES = MMAS_M * MMAS_N * BYTES_PER_ROW }; // The type of elements stored per STG. using Type = typename fmha::Uint_from_size_in_bytes::Type; // Ctor. template inline __device__ Gmem_tile_mma_sd(void* ptr, const Params& params, const int bidb, const int bidh, const int tidx) : ptr_(static_cast(ptr)) { // The block index. size_t bidx = bidb * params.h + bidh; // Set store location for each thread at the beginning of the loop ptr_ += bidx * BLOCK_STRIDE_BYTES + tidx * BYTES_PER_STG; } // Store to global memory. inline __device__ void store(const Type& data, const int mi, const int ni) { size_t offset = (mi * MMAS_N + ni) * BYTES_PER_ROW; fmha::stg(ptr_ + offset, data); } // Load from global memory. 
inline __device__ void load(Type& data, const int mi, const int ni) { size_t offset = (mi * MMAS_N + ni) * BYTES_PER_ROW; fmha::ldg(data, ptr_ + offset); } // Move to the next tile. inline __device__ void move() { ptr_ += LOOP_STRIDE_BYTES; } inline __device__ void move(const int steps) { ptr_ += LOOP_STRIDE_BYTES * steps; } // The pointer in global memory. char* ptr_; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template > struct Gmem_tile_mma_s : public Base { // The number of mmas in the vertical dimension. enum { M = Base::MMAS_M }; // The number of mmas in the horizontal dimension. enum { N = Base::MMAS_N }; // The type of the vectors stored by each STG. using Type = typename Base::Type; // Ctor. template inline __device__ Gmem_tile_mma_s(const Params& params, const Block_info& binfo, const int tidx) : Base(params.s_ptr, params, binfo.bidb, binfo.bidh, tidx) {} // Store to global memory. template inline __device__ void store(const float (&softmax)[2 * M][4 * N], const Mask& mask) { #pragma unroll for (int mi = 0; mi < M; mi++) { #pragma unroll for (int ni = 0; ni < N; ni++) { float tmp00 = softmax[2 * mi + 0][4 * ni + 0]; float tmp01 = softmax[2 * mi + 0][4 * ni + 1]; float tmp02 = softmax[2 * mi + 0][4 * ni + 2]; float tmp03 = softmax[2 * mi + 0][4 * ni + 3]; float tmp10 = softmax[2 * mi + 1][4 * ni + 0]; float tmp11 = softmax[2 * mi + 1][4 * ni + 1]; float tmp12 = softmax[2 * mi + 1][4 * ni + 2]; float tmp13 = softmax[2 * mi + 1][4 * ni + 3]; uint4 dst; dst.x = fmha::float2_to_half2(tmp00, tmp01); dst.y = fmha::float2_to_half2(tmp02, tmp03); dst.z = fmha::float2_to_half2(tmp10, tmp11); dst.w = fmha::float2_to_half2(tmp12, tmp13); if (mask.is_valid(mi, ni, 0, 0)) { Base::store(dst, mi, ni); } } } } // Store to global memory. 
template inline __device__ void store(const Fragment (&frag)[N][M], const Mask& mask) { #pragma unroll for (int mi = 0; mi < M; mi++) { #pragma unroll for (int ni = 0; ni < N; ni++) { uint4 dst; dst.x = frag[ni][mi].reg(0); dst.y = frag[ni][mi].reg(2); dst.z = frag[ni][mi].reg(1); dst.w = frag[ni][mi].reg(3); if (mask.any_valid(mi, ni)) { Base::store(dst, mi, ni); } } } } // Load from global memory. template inline __device__ void load(uint4 (®s)[M][N], const Mask& mask) { #pragma unroll for (int mi = 0; mi < M; mi++) { #pragma unroll for (int ni = 0; ni < N; ni++) { regs[mi][ni] = make_uint4(0, 0, 0, 0); if (mask.any_valid(mi, ni)) { Base::load(regs[mi][ni], mi, ni); } } } } }; //////////////////////////////////////////////////////////////////////////////////////////////////// template < // The dimensions of the tile computed by the CTA. typename Cta_tile, // The base class. typename Base = fmha::Gmem_tile_qkv > struct Gmem_tile_dout : public Base { // Ctor. template inline __device__ Gmem_tile_dout(const Params& params, const BInfo& binfo, int tidx) : Base(params, 0, binfo, tidx) { this->qkv_ptr_ = reinterpret_cast(params.o_ptr); this->params_qkv_stride_in_bytes_ = params.o_stride_in_bytes; // needed for move // Compute the position of the thread in the row. int col = tidx % Base::THREADS_PER_ROW; // The row offset in the batched GEMM. For each seq element, we store O in that order. int64_t row_offset = (int64_t)this->row_ * params.o_stride_in_bytes + binfo.bidx * Base::BYTES_PER_ROW; // Assemble the final pointer. this->qkv_ptr_ += row_offset + col * Base::BYTES_PER_LDG; } }; //////////////////////////////////////////////////////////////////////////////////////////////////// template > struct Gmem_tile_dq : public Base { // Ctor. 
template inline __device__ Gmem_tile_dq(const Params& params, const BInfo& binfo, int tidx) : Base(params, binfo, tidx) { this->o_ptr_ = reinterpret_cast(params.dqkv_ptr); this->params_o_stride_in_bytes_ = params.qkv_stride_in_bytes; // needed for move // Compute the position of the thread in the row. int col = tidx % Base::THREADS_PER_ROW; // The row offset in the batched GEMM. For each seq element, we store O in that order. int64_t row_offset = (int64_t)this->row_ * params.qkv_stride_in_bytes + (binfo.sum_s * 3 * binfo.h + binfo.bidh) * Base::BYTES_PER_ROW; // Assemble the final pointer. this->o_ptr_ += row_offset + col * Base::BYTES_PER_STG; } }; //////////////////////////////////////////////////////////////////////////////////////////////////// } // namespace fmha ================================================ FILE: apex/contrib/csrc/fmha/src/fmha/kernel_traits.h ================================================ /****************************************************************************** * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright * notice, this list of conditions and the following disclaimer. * * Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. * * Neither the name of the NVIDIA CORPORATION nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. 
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/

#pragma once

#include "gmem_tile.h"
#include "smem_tile.h"

////////////////////////////////////////////////////////////////////////////////////////////////////

// Compile-time collection of the tile types and shared-memory sizes for one FMHA kernel
// configuration. FLAGS bits read here:
//   0x08 set   -> K and V share a single shared-memory buffer (SHARE_SMEM_FOR_K_AND_V).
//   0x10 clear -> K is kept in registers (K_IN_REGS).
// NOTE(review): the template parameter list and most template arguments below appear to have
// been stripped from this copy; restore them from upstream before compiling.
template
struct FMHA_kernel_traits {
  // The CTA description for the 1st GEMM.
  using Cta_tile_p = fmha::Cta_tile_extd;
  // The CTA description for the 2nd GEMM.
  using Cta_tile_o = fmha::Cta_tile_extd;

  // Do we use one buffer for K and V.
  enum { SHARE_SMEM_FOR_K_AND_V = (FLAGS & 0x08u) != 0u };
  // Do we keep K in registers. (True when bit 0x10 of FLAGS is clear.)
  enum { K_IN_REGS = (FLAGS & 0x10u) == 0u };

  // The global memory tile to load Q.
  using Gmem_tile_q = fmha::Gmem_tile_qkv;

  // The shared memory tile to swizzle Q.
  using Smem_tile_q = fmha::Smem_tile_a;

  // The global memory tile to load K.
  using Gmem_tile_k = fmha::Gmem_tile_qkv;
  // The shared memory tile to swizzle K.
  using Smem_tile_k = fmha::Smem_tile_b;

  // The global memory tile to load V.
  using Gmem_tile_v = fmha::Gmem_tile_qkv;
  // The shared memory tile to swizzle V.
  using Smem_tile_v = fmha::Smem_tile_v;

  // The global memory tile to store O.
  using Gmem_tile_o = fmha::Gmem_tile_o;
  // The shared memory tile for O.
  using Smem_tile_o = fmha::Smem_tile_o;

  // The global memory tile to load/store S.
  using Gmem_tile_s = fmha::Gmem_tile_mma_s;

  // The shared memory tile to transpose S.
  using Smem_tile_st = fmha::Smem_tile_mma_transposed;

  // The global memory tile built on Gmem_tile_dout (reads through params.o_ptr).
  using Gmem_tile_do = fmha::Gmem_tile_dout;

  // Make sure the number of threads match.
  static_assert((int)Gmem_tile_o::THREADS_PER_ROW == (int)Smem_tile_o::THREADS_PER_ROW, "");

  // The number of threads.
  enum { THREADS = Cta_tile_p::THREADS_PER_CTA };
  // Make sure the number of threads matches both CTAs.
  static_assert((int)THREADS == (int)Cta_tile_o::THREADS_PER_CTA, "");

  // The amount of shared memory needed to load Q and K.
  enum { BYTES_PER_SMEM_QK = Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE };
  // The extra amount of shared memory needed to load V (zero when V reuses the K buffer).
  enum { BYTES_PER_SMEM_V = SHARE_SMEM_FOR_K_AND_V ? 0u : Smem_tile_v::BYTES_PER_TILE };
  // The amount of shared memory needed for Q, K and V.
  enum { BYTES_PER_SMEM_QKV = BYTES_PER_SMEM_QK + BYTES_PER_SMEM_V };
  // The amount of shared memory needed to load Q and store O.
  enum { BYTES_PER_SMEM_QO = Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE };

  // The amount of shared memory needed for Q, K, V and O.
  // NOTE(review): the Max arguments were stripped; presumably Max<BYTES_PER_SMEM_QKV, BYTES_PER_SMEM_QO>.
  enum { BYTES_PER_SMEM = fmha::Max::VALUE };
  // Make sure we have enough shared memory.
  // NOTE(review): if BYTES_PER_SMEM is the Max above, this assert holds by construction.
  static_assert(Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE <= BYTES_PER_SMEM, "");
};

////////////////////////////////////////////////////////////////////////////////////////////////////
================================================
FILE: apex/contrib/csrc/fmha/src/fmha/mask.h
================================================
/******************************************************************************
 * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
* * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright * notice, this list of conditions and the following disclaimer. * * Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. * * Neither the name of the NVIDIA CORPORATION nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
 *
 ******************************************************************************/

#pragma once

namespace fmha {

// Per-thread sequence-length mask over a CTA tile. It decides which fragment elements have a
// column index below blockInfo.actual_seqlen. Only columns are checked; rows are never masked.
// NOTE(review): the template parameter list appears stripped in this copy (Cta_tile is used below).
template
struct Mask {
  using Mma_tile = fmha::Hmma_tile;

  // Ctor.
  //
  // Captures blockInfo.actual_seqlen and derives the (row, col) of this thread's first element
  // within the CTA tile from its warp and lane ids. params is not read.
  template
  __device__ Mask(const Params& params, const BInfo& blockInfo, int tidx) {
    actual_seqlen = blockInfo.actual_seqlen;

    const int warp = tidx / Cta_tile::THREADS_PER_WARP;
    const int lane = tidx % Cta_tile::THREADS_PER_WARP;

    static_assert(Cta_tile::WARPS_K == 1, "");

    // find the warp in the Cta tile
    const int warp_n = (warp / Cta_tile::WARPS_M);
    const int warp_m = (warp % Cta_tile::WARPS_M);
    // decompose warp into 8x4 tile
    const int quad = lane / 4;        // 0..7: row within the 8-row group
    const int tid = (lane % 4) * 2;   // 0, 2, 4, 6: first of a column pair
    row = warp_m * 16 + quad;
    col = warp_n * 16 + tid;
  }

  // True when the column of fragment element (ii, jj) of MMA (mi, ni) is < actual_seqlen.
  // jj maps to column offsets 0, 1, 8, 9 via (jj & 2) * 4 + (jj & 1). The row test is
  // commented out, so mi and ii currently do not affect the result.
  inline __device__ bool is_valid(const int mi, const int ni, const int ii, const int jj) const {
    // ii and jj iterate over the 2x4 fragment
    const bool col_valid = (ni * Mma_tile::N_PER_MMA_PER_CTA + col + (jj & 2) * 4 + (jj & 1)) < actual_seqlen;
    //&& (row + mi * Mma_tile::M_PER_MMA_PER_CTA + ii * 8) < actual_seqlen;
    return col_valid;
    // return row_valid && col_valid;
  }

  // BERT Mask: if upper left is invalid, none are valid
  // (i.e. only element (0, 0) of the fragment is tested).
  inline __device__ bool any_valid(int mi, int ni) const { return is_valid(mi, ni, 0, 0); }

  // Records it * Cta_tile::M + row in row_offset. row_offset is not read by any member here.
  inline __device__ void load(int it) { row_offset = it * Cta_tile::M + row; }

  // Set by load(); uninitialized until load() is called.
  int row_offset;

  // Row and column of this thread's first element within the CTA tile (set in the ctor).
  int row;
  int col;
  // Number of valid sequence positions, copied from blockInfo.actual_seqlen.
  int actual_seqlen;
};

} // namespace fmha
================================================
FILE: apex/contrib/csrc/fmha/src/fmha/smem_tile.h
================================================
/******************************************************************************
 * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. * * Neither the name of the NVIDIA CORPORATION nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * ******************************************************************************/ #pragma once #include #include namespace fmha { //////////////////////////////////////////////////////////////////////////////////////////////////// template < // The description of the tile computed by this CTA. typename Cta_tile, // The number of rows in the 2D shared memory buffer. int M_, // The number of cols. int N_, // The size in bits of each element. int BITS_PER_ELEMENT_, // The number of bytes per STS. int BYTES_PER_STS_ = 16, // The number of buffers. (Used in multistage and double buffer cases.) int BUFFERS_PER_TILE_ = 1, // Do we enable the fast path for LDS.128 and friends. int ENABLE_LDS_FAST_PATH_ = 0, // The number of rows that are used for the XOR swizzling to allow fast STS/LDS. 
int ROWS_PER_XOR_PATTERN_ = 8, // The number of cols that are used for the XOR swizzling to allow fast STS/LDS. int COLS_PER_XOR_PATTERN_ = 1, // Use or not predicates bool USE_PREDICATES_ = true> struct Smem_tile_without_skews { // The size in bits of each element. enum { BITS_PER_ELEMENT = BITS_PER_ELEMENT_ }; // The size in bytes of a single STS. enum { BYTES_PER_STS = BYTES_PER_STS_ }; // The number of elements per STS. enum { ELEMENTS_PER_STS = BYTES_PER_STS * 8 / BITS_PER_ELEMENT }; // To support arbitrary N, we pad some values to a power-of-2. enum { N_WITH_PADDING = Next_power_of_two::VALUE }; // The number of bytes per row without packing of rows. enum { BYTES_PER_ROW_BEFORE_PACKING = N_WITH_PADDING * BITS_PER_ELEMENT / 8 }; // The number of bytes per row -- we want at least 128B per row. enum { BYTES_PER_ROW = Max::VALUE }; // The number of rows in shared memory (two rows may be packed into a single one). enum { ROWS = M_ * BYTES_PER_ROW_BEFORE_PACKING / BYTES_PER_ROW }; // The number of threads per row. enum { THREADS_PER_ROW_UNBOUNDED = BYTES_PER_ROW / BYTES_PER_STS }; // The number of threads per row. enum { THREADS_PER_ROW = Min::VALUE }; // The number of STS per row. enum { STS_PER_ROW = BYTES_PER_ROW / THREADS_PER_ROW / BYTES_PER_STS }; // It must be at least one. static_assert(STS_PER_ROW >= 1, ""); // The number of rows written with a single STS. enum { ROWS_PER_STS = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW }; // Make sure we write to at least one row per STS. Thanks Dr. Obvious ;) static_assert(ROWS_PER_STS >= 1, ""); // The number of STS needed to store all rows. enum { STS_PER_COL = Div_up::VALUE }; // The number of STS in total. enum { STS = STS_PER_COL * STS_PER_ROW }; // The size of one buffer in bytes in shared memory. enum { BYTES_PER_BUFFER = STS * BYTES_PER_STS * Cta_tile::THREADS_PER_CTA }; // The number of buffers. enum { BUFFERS_PER_TILE = BUFFERS_PER_TILE_ }; // The size in bytes of total buffers. 
enum { BYTES_PER_TILE = BYTES_PER_BUFFER * BUFFERS_PER_TILE }; // The boundary for smem_read_offset and smem_write_offset increment. enum { BYTES_PER_TILE_INC_BOUNDARY = BYTES_PER_TILE - BYTES_PER_BUFFER }; // Do we enable the LDS.128 fast path? enum { ENABLE_LDS_FAST_PATH = ENABLE_LDS_FAST_PATH_ }; static_assert(ENABLE_LDS_FAST_PATH == 0); // The number of rows that are used for the XOR swizzling to allow fast STS/LDS. enum { ROWS_PER_XOR_PATTERN = ROWS_PER_XOR_PATTERN_ }; // The number of cols that are used for the XOR swizzling to allow fast STS/LDS. enum { COLS_PER_XOR_PATTERN = COLS_PER_XOR_PATTERN_ * 16 / BYTES_PER_STS }; // Use or not predicates enum { USE_PREDICATES = USE_PREDICATES_ }; // The type of elements that are stored in shared memory by each thread. using Store_type = typename Uint_from_size_in_bytes::Type; // Ctor. inline __device__ Smem_tile_without_skews(void* smem, int tidx) : smem_(__nvvm_get_smem_pointer(smem)) { // The row written by a thread. See doc/mma_smem_layout.xlsx. int smem_write_row = tidx / THREADS_PER_ROW; // The XOR pattern. int smem_write_xor = smem_write_row % ROWS_PER_XOR_PATTERN * COLS_PER_XOR_PATTERN; // Compute the column and apply the XOR pattern. int smem_write_col = (tidx % THREADS_PER_ROW) ^ smem_write_xor; // The offset. this->smem_write_offset_ = smem_write_row * BYTES_PER_ROW + smem_write_col * BYTES_PER_STS; // TODO: Why not merge it with the read offset? this->smem_read_buffer_ = __shfl_sync(0xffffffff, 0, 0); this->smem_write_buffer_ = __shfl_sync(0xffffffff, 0, 0); } // Compute the store pointers. template inline __device__ void compute_store_pointers(uint32_t (&ptrs)[N]) { #pragma unroll for (int ii = 0; ii < N; ++ii) { // Decompose the STS into row/col. int row = ii / STS_PER_ROW; int col = ii % STS_PER_ROW; // Assemble the offset. int offset = smem_write_offset_ + row * ROWS_PER_STS * BYTES_PER_ROW; // Take the column into account. 
if (STS_PER_ROW > 1) { offset += col * THREADS_PER_ROW * BYTES_PER_STS; } // Apply the XOR pattern if needed. if (ROWS_PER_STS < ROWS_PER_XOR_PATTERN) { const int m = row * ROWS_PER_STS % ROWS_PER_XOR_PATTERN; offset ^= m * COLS_PER_XOR_PATTERN * BYTES_PER_STS; } // Assemble the final pointer :) ptrs[ii] = smem_ + offset + smem_write_buffer_; } } inline __device__ void debug_reset() { for (int buffer = 0; buffer < BYTES_PER_TILE; buffer += BYTES_PER_BUFFER) { for (int row = 0; row < ROWS; ++row) { for (int col = 0; col < BYTES_PER_ROW; col += 4) { if (threadIdx.x == 0) { uint32_t val = 0x0; sts(val, smem_ + row * BYTES_PER_ROW + col + buffer); } } } } } // Print the content of the tile (only for debug ;)). inline __device__ void debug_print() const { for (int buffer = 0; buffer < BYTES_PER_TILE; buffer += BYTES_PER_BUFFER) { for (int row = 0; row < ROWS; ++row) { for (int col = 0; col < BYTES_PER_ROW; col += 4) { if (threadIdx.x == 0) { uint32_t val; lds(val, smem_ + row * BYTES_PER_ROW + col + buffer); printf("block=(x=%2d, y=%2d, z=%2d) (smem_=%2d, buffer=%2d, row=%2d, byte=%4d)=0x%08x\n", blockIdx.x, blockIdx.y, blockIdx.z, smem_, buffer, row, col, val); } } } } } // Move the read offset to next buffer. inline __device__ void move_to_next_read_buffer() { if (BUFFERS_PER_TILE > 1 && smem_read_buffer_ >= BYTES_PER_TILE_INC_BOUNDARY) { this->smem_read_buffer_ -= BYTES_PER_TILE_INC_BOUNDARY; } else if (BUFFERS_PER_TILE > 1) { this->smem_read_buffer_ += BYTES_PER_BUFFER; } } // Move the read offset to next buffer. TODO: Remove this member function!!! inline __device__ void move_next_read_buffer() { this->move_to_next_read_buffer(); } // Move the read offset to next N buffer (circular-buffer). inline __device__ void move_to_next_read_buffer(int N) { if (BUFFERS_PER_TILE > 1) { this->smem_read_buffer_ += N * BYTES_PER_BUFFER; this->smem_read_buffer_ -= smem_read_buffer_ >= BYTES_PER_TILE ? 
BYTES_PER_TILE : 0; } } // Move the read offset to next N buffer (circular-buffer). TODO: Remove this member function!!! inline __device__ void move_next_read_buffer(int N) { this->move_to_next_read_buffer(N); } // Move the write offset to next buffer. inline __device__ void move_to_next_write_buffer() { if (BUFFERS_PER_TILE > 1 && smem_write_buffer_ >= BYTES_PER_TILE_INC_BOUNDARY) { this->smem_write_buffer_ -= BYTES_PER_TILE_INC_BOUNDARY; } else if (BUFFERS_PER_TILE > 1) { this->smem_write_buffer_ += BYTES_PER_BUFFER; } } // Move the write offset to next buffer. TODO: Remove that member function! inline __device__ void move_next_write_buffer() { this->move_to_next_write_buffer(); } // Move the read offset. inline __device__ void move_read_offset(int delta) { this->smem_read_offset_ += delta; } // Move the write offset. inline __device__ void move_write_offset(int delta) { this->smem_write_offset_ += delta; } // Store to the tile in shared memory. template inline __device__ void store(const Store_type (&data)[N], uint64_t = 0) { uint32_t smem_ptrs[N]; this->compute_store_pointers(smem_ptrs); sts(smem_ptrs, data); } // Store to the tile in shared memory. template inline __device__ void store(const Store_type (&data)[N], uint32_t (&preds)[M], uint64_t = 0) { uint32_t smem_ptrs[N]; this->compute_store_pointers(smem_ptrs); sts(smem_ptrs, data, preds); } // Store to the tile in shared memory. template inline __device__ void store(const Store_type (&data)[N], uint32_t preds, uint64_t = 0) { this->store(data, preds); } // Store to the tile in shared memory. template inline __device__ void store(const void* (&gmem_ptrs)[N], uint32_t preds, uint64_t = 0) { uint32_t tmp[1] = {preds}; this->store(gmem_ptrs, tmp); } // The shared memory pointer. uint32_t smem_; // The read offset. Reserve 4 offsets if needed. int smem_read_offset_; // The write offset. int smem_write_offset_; // The buffer base offset for read. int smem_read_buffer_; // The buffer base offset for write. 
int smem_write_buffer_; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template < // The dimensions of the tile computed by the CTA. typename Cta_tile, // The layout of the tile. typename Layout, // The size of the STS. int BYTES_PER_STS = 16, // The number of buffers per tile. int BUFFERS_PER_TILE = 1, // Use or not predicates bool USE_PREDICATES = true> struct Smem_tile_a {}; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct Compute_reset_mask { // The potential mask. enum { HALF = MMAS_K_WITH_PADDING / 2 }; // The remainder. enum { MOD = MMAS_K % HALF }; // The final value. enum { VALUE = (MMAS_K == MOD ? 0 : HALF) | Compute_reset_mask::VALUE }; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct Compute_reset_mask<0, MMAS_K_WITH_PADDING> { enum { VALUE = 0 }; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct Compute_reset_mask { enum { VALUE = MMAS_K - 1 }; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct Rows_per_xor_pattern_a { // The size in bits. enum { N_IN_BITS = N * fmha::BITS_PER_ELEMENT_A }; // The number of rows. enum { VALUE = N_IN_BITS <= 256 ? 2 : (N_IN_BITS <= 512 ? 4 : 8) }; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct Rows_per_xor_pattern_row_a : public Rows_per_xor_pattern_a {}; //////////////////////////////////////////////////////////////////////////////////////////////////// template < // The dimensions of the tile computed by the CTA. typename Cta_tile, // The size of the STS. int BYTES_PER_STS, // The number of buffers per tile. int BUFFERS_PER_TILE, // How many rows to use for the XOR pattern to avoid bank conflicts? 
int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_row_a::VALUE> struct Smem_tile_row_a : public Smem_tile_without_skews { // The MMA tile. using Mma_tile = fmha::Hmma_tile; // The base class. using Base = Smem_tile_without_skews; // The fragment. using Fragment = Fragment_a; // When we use padding to reach a power of two, special care has to be taken. using Cta_tile_with_padding = Cta_tile_with_k_with_padding; // The number of MMAs. using Mma_tile_with_padding = fmha::Hmma_tile; // The size of a single LDS in bytes. enum { BYTES_PER_LDS = 16 }; // Ctor. inline __device__ Smem_tile_row_a(void* smem, int tidx) : Base(smem, tidx) { // For documentation on the layout, see doc/mma_smem_layout.xlsx. // The number of warps. const int WARPS_M = Cta_tile::WARPS_M; const int WARPS_N = Cta_tile::WARPS_N; const int WARPS_K = Cta_tile::WARPS_K; static_assert(WARPS_M == 1); static_assert(WARPS_N == 4 || WARPS_N == 8); static_assert(WARPS_K == 1); static_assert(Base::ROWS_PER_XOR_PATTERN == 8); // The row and column read by the thread. int smem_read_row = (tidx & 0x0f); int smem_read_col = (tidx & 0x07); smem_read_col ^= (tidx & 0x10) / 16; // The shared memory offset. this->smem_read_offset_ = smem_read_row * Base::BYTES_PER_ROW + smem_read_col * BYTES_PER_LDS; } // Rewind smem_read_offset for last LDS phase in main loop. inline __device__ void reverse_smem_read_offset(int ki = 0) { // Undo the pointer increment for the next ni. // Should match the load function below for ki = 0. if (Mma_tile_with_padding::MMAS_K >= 2) { this->smem_read_offset_ ^= BYTES_PER_LDS * 2; } } // Load from shared memory. inline __device__ void load(Fragment (&a)[Mma_tile::MMAS_M], int ki) { #pragma unroll for (int mi = 0; mi < Mma_tile::MMAS_M; ++mi) { // Jump by as many matrix rows as needed (a row in smem may pack multiple matrix rows). int offset = mi * Mma_tile::M_PER_MMA_PER_CTA * Base::BYTES_PER_ROW_BEFORE_PACKING; // Load using LDSM.M88.4. 
uint4 tmp; ldsm(tmp, this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset); // Store the value into the fragment. a[mi].reg(0) = tmp.x; a[mi].reg(1) = tmp.y; a[mi].reg(2) = tmp.z; a[mi].reg(3) = tmp.w; } // Move the offset to the next possition. See doc/mma_smem_layout.xlsx. static_assert(Mma_tile_with_padding::MMAS_K < 64, "Not implemented"); if (Mma_tile_with_padding::MMAS_K >= 32 && ki % 16 == 15) { this->smem_read_offset_ ^= 31 * BYTES_PER_LDS * 2; } else if (Mma_tile_with_padding::MMAS_K >= 16 && ki % 8 == 7) { this->smem_read_offset_ ^= 15 * BYTES_PER_LDS * 2; } else if (Mma_tile_with_padding::MMAS_K >= 8 && ki % 4 == 3) { this->smem_read_offset_ ^= 7 * BYTES_PER_LDS * 2; } else if (Mma_tile_with_padding::MMAS_K >= 4 && ki % 2 == 1) { this->smem_read_offset_ ^= 3 * BYTES_PER_LDS * 2; } else if (Mma_tile_with_padding::MMAS_K >= 2) { this->smem_read_offset_ ^= 1 * BYTES_PER_LDS * 2; } } // Reset the read offset. inline __device__ void reset_read_offset() { // The number of MMAs in the K dimension. enum { MMAS_K = Mma_tile::MMAS_K }; // The number of MMAs in the K dimension when we include padding. enum { MMAS_K_WITH_PADDING = Mma_tile_with_padding::MMAS_K }; // Assemble the mask. enum { MASK = Compute_reset_mask::VALUE }; // Reset the read offset. this->smem_read_offset_ ^= MASK * BYTES_PER_LDS * 2; } }; //////////////////////////////////////////////////////////////////////////////////////////////////// template < // The dimensions of the tile computed by the CTA. typename Cta_tile, // The size of the STS. int BYTES_PER_STS, // The number of buffers per tile. int BUFFERS_PER_TILE> struct Smem_tile_a : public Smem_tile_row_a { // The base class. using Base = Smem_tile_row_a; // Ctor. inline __device__ Smem_tile_a(void* smem, int tidx) : Base(smem, tidx) {} }; //////////////////////////////////////////////////////////////////////////////////////////////////// template < // The dimensions of the tile computed by the CTA. 
typename Cta_tile, // The layout of the tile. typename Layout, // The size of the STS. int BYTES_PER_STS = 16, // The number of buffers per tile. int BUFFERS_PER_TILE = 1, // Use or not predicates bool USE_PREDICATES = true> struct Smem_tile_b {}; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct Rows_per_xor_pattern_b { // The size in bits. enum { N_IN_BITS = N * fmha::BITS_PER_ELEMENT_B }; // The number of rows. enum { VALUE = N_IN_BITS <= 256 ? 2 : (N_IN_BITS <= 512 ? 4 : 8) }; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct Rows_per_xor_pattern_col_b : public Rows_per_xor_pattern_b {}; //////////////////////////////////////////////////////////////////////////////////////////////////// template < // The dimensions of the tile computed by the CTA. typename Cta_tile, // The size of the STS. int BYTES_PER_STS, // The number of buffers per tile. int BUFFERS_PER_TILE, // How many rows to use for the XOR pattern to avoid bank conflicts? int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_col_b::VALUE> struct Smem_tile_col_b : public Smem_tile_without_skews { // The MMA tile. using Mma_tile = fmha::Hmma_tile; // The base class. using Base = Smem_tile_without_skews; // The fragment. using Fragment = Fragment_b; // When we use padding to reach a power of two, special care has to be taken. using Cta_tile_with_padding = Cta_tile_with_k_with_padding; // The number of MMAs. using Mma_tile_with_padding = fmha::Hmma_tile; // The size of a single LDS in bytes. enum { BYTES_PER_LDS = 16 }; // The number of STS per thread enum { STS_PER_THREAD_ = Base::ROWS * Base::THREADS_PER_ROW / Cta_tile::THREADS_PER_CTA }; // The number of STS per thread must be at least 1. enum { STS_PER_THREAD = Max<1, STS_PER_THREAD_>::VALUE }; // Ctor. 
inline __device__ Smem_tile_col_b(void* smem, int tidx) : Base(smem, tidx) { // For documentation on the layout, see doc/mma_smem_layout.xlsx. // The number of warps. const int WARPS_M = Cta_tile::WARPS_M; const int WARPS_N = Cta_tile::WARPS_N; const int WARPS_K = Cta_tile::WARPS_K; static_assert(Base::ROWS_PER_XOR_PATTERN == 8); static_assert(WARPS_M == 1); static_assert(WARPS_N == 4 || WARPS_N == 8); static_assert(WARPS_K == 1); // The masks to select the warps. const int WARP_MASK_N = Warp_masks::N; // The divisor for the warps. const int WARP_DIV_N = WARPS_M * 1 * Cta_tile::THREADS_PER_WARP; // The row and column read by the thread. int smem_read_row = (tidx & WARP_MASK_N) / WARP_DIV_N * Mma_tile::N_PER_MMA + (tidx & 0x07) + (tidx & 0x10) / 2; int smem_read_col = (tidx & 0x07); smem_read_col ^= (tidx & 0x08) / 8; // The shared memory offset. this->smem_read_offset_ = smem_read_row * Base::BYTES_PER_ROW + smem_read_col * BYTES_PER_LDS; } // Rewind smem_read_offset for last LDS phase in main loop. inline __device__ void reverse_smem_read_offset(int ki = 0) { // Undo the pointer increment for the next ni. // Should match the load function below for ki = 0. if (Mma_tile_with_padding::MMAS_K >= 2) { this->smem_read_offset_ ^= BYTES_PER_LDS * 2; } } // Load from shared memory. inline __device__ void load(Fragment (&b)[Mma_tile::MMAS_N], int ki) { #pragma unroll for (int ni = 0; ni < Mma_tile::MMAS_N; ++ni) { // Jump by as many matrix rows as needed (a row in smem may pack multiple matrix rows). int offset = ni * Mma_tile::N_PER_MMA_PER_CTA * Base::BYTES_PER_ROW_BEFORE_PACKING; // Load using LDSM.M88.4. uint4 tmp; ldsm(tmp, this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset); // Store the value into the fragment. b[ni].reg(0) = tmp.x; b[ni].reg(1) = tmp.y; b[ni].reg(2) = tmp.z; b[ni].reg(3) = tmp.w; } // Move the offset to the next possition. See doc/mma_smem_layout.xlsx. 
static_assert(Mma_tile_with_padding::MMAS_K < 64, "Not implemented"); if (Mma_tile_with_padding::MMAS_K >= 32 && ki % 16 == 15) { this->smem_read_offset_ ^= 31 * BYTES_PER_LDS * 2; } else if (Mma_tile_with_padding::MMAS_K >= 16 && ki % 8 == 7) { this->smem_read_offset_ ^= 15 * BYTES_PER_LDS * 2; } else if (Mma_tile_with_padding::MMAS_K >= 8 && ki % 4 == 3) { this->smem_read_offset_ ^= 7 * BYTES_PER_LDS * 2; } else if (Mma_tile_with_padding::MMAS_K >= 4 && ki % 2 == 1) { this->smem_read_offset_ ^= 3 * BYTES_PER_LDS * 2; } else if (Mma_tile_with_padding::MMAS_K >= 2) { this->smem_read_offset_ ^= 1 * BYTES_PER_LDS * 2; } } // Reset the read offset. inline __device__ void reset_read_offset() { // The number of MMAs in the K dimension. enum { MMAS_K = Mma_tile::MMAS_K }; // The number of MMAs in the K dimension when we include padding. enum { MMAS_K_WITH_PADDING = Mma_tile_with_padding::MMAS_K }; // Assemble the mask. enum { MASK = Compute_reset_mask::VALUE }; // Reset the read offset. this->smem_read_offset_ ^= MASK * BYTES_PER_LDS * 2; } }; //////////////////////////////////////////////////////////////////////////////////////////////////// template < // The dimensions of the tile computed by the CTA. typename Cta_tile, // The size of the STS. int BYTES_PER_STS, // The number of buffers per tile. int BUFFERS_PER_TILE> struct Smem_tile_b : public Smem_tile_col_b { // The base class. using Base = Smem_tile_col_b; // Ctor. inline __device__ Smem_tile_b(void* smem, int tidx) : Base(smem, tidx) {} }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct Rows_per_xor_pattern_row_b : public Rows_per_xor_pattern_b {}; //////////////////////////////////////////////////////////////////////////////////////////////////// template < // The dimensions of the tile computed by the CTA. typename Cta_tile, // The size of the STS. int BYTES_PER_STS, // The number of buffers per tile. 
int BUFFERS_PER_TILE,
    // How many rows to use for the XOR pattern to avoid bank conflicts?
    int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_row_b::VALUE,
    // How many cols to use for the XOR pattern to avoid bank conflicts?
    int COLS_PER_XOR_PATTERN_ = 1>
// Shared-memory tile for a row-major B operand. Fragments are read with transposed LDSM (ldsmt);
// the constructor static_asserts that this path is usable (16-bit elements, USE_LDSMT) and that
// warps are stacked along M only (WARPS_M in {4, 8}, WARPS_N == WARPS_K == 1).
struct Smem_tile_row_b : public Smem_tile_without_skews {
  // The MMA tile.
  using Mma_tile = fmha::Hmma_tile;
  // The base class.
  using Base = Smem_tile_without_skews;
  // The fragment.
  using Fragment = Fragment_b;

  // Can we use LDSM? No if the data type is 32-bit large.
  enum { USE_LDSMT = fmha::BITS_PER_ELEMENT_B == 16 };
  // The size of a single LDS in bytes.
  enum { BYTES_PER_LDS = USE_LDSMT ? 16 : 4 };
  // The number of elements per LDS.
  enum { ELEMENTS_PER_LDS = BYTES_PER_LDS * 8 / fmha::BITS_PER_ELEMENT_B };

  // The number of STS per thread
  enum { STS_PER_THREAD_ = Base::ROWS * Base::THREADS_PER_ROW / Cta_tile::THREADS_PER_CTA };
  // The number of STS per thread must be at least 1.
  enum { STS_PER_THREAD = Max<1, STS_PER_THREAD_>::VALUE };

  // Ctor. Computes this thread's byte offset (smem_read_offset_) of its first LDSM inside the tile.
  inline __device__ Smem_tile_row_b(void* smem, int tidx) : Base(smem, tidx) {
    // The number of warps.
    const int WARPS_M = Cta_tile::WARPS_M;
    const int WARPS_N = Cta_tile::WARPS_N;
    const int WARPS_K = Cta_tile::WARPS_K;
    static_assert(WARPS_K == 1);
    static_assert(WARPS_M == 4 || WARPS_M == 8);
    static_assert(WARPS_N == 1);

    // The masks to select the warps.
    const int WARP_MASK_N = Warp_masks::N;
    const int WARP_MASK_K = Warp_masks::K;

    // The divisor for the warps.
    const int WARP_DIV_N = WARPS_M * 1 * Cta_tile::THREADS_PER_WARP;
    const int WARP_DIV_K = WARPS_M * WARPS_N * Cta_tile::THREADS_PER_WARP;

    // The row/col read by the thread.
    int smem_read_row, smem_read_col;

    static_assert(USE_LDSMT);
    static_assert(Base::ROWS_PER_XOR_PATTERN == 8);

    // Row: the K-warp slice (MMAS_K * 16 rows each) plus (tidx & 0x07) + (tidx & 0x08), i.e. one
    // of 16 rows; lanes 16..31 revisit the same 16 rows.
    smem_read_row = (tidx & WARP_MASK_K) / WARP_DIV_K * Mma_tile::MMAS_K * 16 + (tidx & 0x07) +
                    (tidx & 0x08);
    // Column (in BYTES_PER_LDS units): starts at (tidx & 0x07), the row index modulo 8, and is
    // XOR-swizzled by the N-warp index (times 2) and by bit 4 of tidx (lanes 16..31 flip bit 0).
    smem_read_col = (tidx & 0x07);
    smem_read_col ^= (tidx & WARP_MASK_N) / WARP_DIV_N * 2 + (tidx & 0x10) / 16;

    // The shared memory offset.
    this->smem_read_offset_ = smem_read_row * Base::BYTES_PER_ROW + smem_read_col * BYTES_PER_LDS;
  }

  // Rewind smem_read_offset for last LDS phase in main loop.
  // Replays the exact per-ni XOR sequence that load() applies; since XOR is its own inverse this
  // undoes that pointer movement. The `ki` argument is not used.
  // NOTE(review): by the same arithmetic, load() already returns smem_read_offset_ to its starting
  // value after a full ni loop (e.g. 32 ^ 96 ^ 32 ^ 96 == 0 bytes for the MMAS_N == 4 case, and the
  // odd-MMAS_N reset for the 64-byte case), so this appears to have no net effect for the
  // configurations handled here -- confirm before relying on it.
  inline __device__ void reverse_smem_read_offset(int ki = 0) {
    // The size of each element in bits.
    const int BITS_PER_ELT = fmha::BITS_PER_ELEMENT_B;
    // The size in bytes of the data needed to compute an MMA per CTA.
    const int BYTES_PER_MMA_PER_CTA = Mma_tile::N_PER_MMA_PER_CTA * BITS_PER_ELT / 8;

#pragma unroll
    for (int ni = 0; ni < Mma_tile::MMAS_N; ++ni) {
      // Undo the pointer increment for the next ni.
      // Should match the load function below for ki = 0.
      if (BYTES_PER_MMA_PER_CTA >= 128) {
        // Nothing to do!
      } else if (BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1) {
        this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA;
      } else if (BYTES_PER_MMA_PER_CTA == 64) {
        // Nothing to do!
      } else if (BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 4) {
        this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 2 == 0 ? 2 : 6);
      } else if (BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 2) {
        this->smem_read_offset_ ^= BYTES_PER_LDS * 2;
      }
    }

    // Reset smem_read_offset for odd MMAS_N > 1 (npo2 kernels)
    if (BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1 && Mma_tile::MMAS_N % 2 == 1) {
      this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA;
    }
  }

  // Load from shared memory.
  // Fills b[0..MMAS_N-1] for K-step ki; each ki advances by ROWS_PER_XOR_PATTERN * 2 rows.
  inline __device__ void load(Fragment (&b)[Mma_tile::MMAS_N], int ki) {
    // The size of each element in bits.
    const int BITS_PER_ELT = fmha::BITS_PER_ELEMENT_B;
    // The size in bytes of the data needed to compute an MMA per CTA.
    const int BYTES_PER_MMA_PER_CTA = Mma_tile::N_PER_MMA_PER_CTA * BITS_PER_ELT / 8;

#pragma unroll
    for (int ni = 0; ni < Mma_tile::MMAS_N; ++ni) {
      // Prepare the offset.
      int offset = ki * Base::ROWS_PER_XOR_PATTERN * 2 * Base::BYTES_PER_ROW;
      if (BYTES_PER_MMA_PER_CTA == 32) {
        offset += this->smem_read_offset_;
      } else if (BYTES_PER_MMA_PER_CTA == 64) {
        offset += this->smem_read_offset_ + (ni / 2) * BYTES_PER_MMA_PER_CTA * 2;
      } else {
        offset += this->smem_read_offset_ + (ni)*BYTES_PER_MMA_PER_CTA;
      }

      // Load the data with transposed LDSM (ldsmt) into four 32-bit registers. The plain-LDS
      // branch is never taken by a constructed tile: the constructor static_asserts USE_LDSMT.
      uint32_t ptr = this->smem_ + this->smem_read_buffer_ + offset;
      uint4 tmp;
      if (USE_LDSMT) {
        ldsmt(tmp, ptr);
      } else {
        lds(tmp.x, (ptr) + 0 * Base::BYTES_PER_ROW);
        lds(tmp.y, (ptr) + 4 * Base::BYTES_PER_ROW);
        lds(tmp.z, (ptr ^ 32) + 0 * Base::BYTES_PER_ROW);
        lds(tmp.w, (ptr ^ 32) + 4 * Base::BYTES_PER_ROW);
      }

      // Store those values in the fragment.
      b[ni].reg(0) = tmp.x;
      b[ni].reg(1) = tmp.y;
      b[ni].reg(2) = tmp.z;
      b[ni].reg(3) = tmp.w;

      // Move the pointer for the next ni. I expect the compiler to not recompute those.
      if (BYTES_PER_MMA_PER_CTA >= 128) {
        // Nothing to do!
      } else if (BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1) {
        this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA;
      } else if (BYTES_PER_MMA_PER_CTA == 64) {
        // Nothing to do!
      } else if (BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 4) {
        this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 2 == 0 ? 2 : 6);
      } else if (BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 2) {
        this->smem_read_offset_ ^= BYTES_PER_LDS * 2;
      }
    }

    // Reset smem_read_offset for odd MMAS_N > 1 (npo2 kernels): an odd number of 64-byte XORs
    // above leaves one XOR outstanding, which this cancels.
    if (BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1 && Mma_tile::MMAS_N % 2 == 1) {
      this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA;
    }
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Smem_tile_b variant that simply inherits Smem_tile_row_b.
template <
    // The dimensions of the tile computed by the CTA.
    typename Cta_tile,
    // The size of the STS.
    int BYTES_PER_STS,
    // The number of buffers per tile.
    int BUFFERS_PER_TILE>
struct Smem_tile_b : public Smem_tile_row_b {
  // The base class.
  using Base = Smem_tile_row_b;

  // Ctor.
inline __device__ Smem_tile_b(void* smem, int tidx) : Base(smem, tidx) {} }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct Smem_tile_v : public fmha::Smem_tile_without_skews { // The base class. using Base = Smem_tile_without_skews; // The MMA tile. using Mma_tile = fmha::Hmma_tile; // The fragment. using Fragment = Fragment_b; // The size of a single LDS in bytes. enum { BYTES_PER_LDS = 16 }; // Ctor. inline __device__ Smem_tile_v(void* smem, int tidx) : Base(smem, tidx) { // The row/col read by the thread. int read_row, read_col; static_assert(Cta_tile::WARPS_M == 1 && Cta_tile::WARPS_N == 1 && (Cta_tile::WARPS_K == 4 || Cta_tile::WARPS_K == 8)); read_row = (tidx & 0xe0) / 2 + (tidx & 0x0f); read_col = (tidx & 0x07); read_col ^= (tidx & 0x10) / 16; // The shared memory offset. this->smem_read_offset_ = read_row * Base::BYTES_PER_ROW + read_col * BYTES_PER_LDS; } // Load from shared memory. inline __device__ void load(Fragment (&b)[Mma_tile::MMAS_N], int ki) { #pragma unroll for (int ni = 0; ni < Mma_tile::MMAS_N; ++ni) { // Jump by 16 * #warps row. int row = ki * 16 * Cta_tile::WARPS_K; // Load the data using LDSM.MT88.2. uint4 tmp; fmha::ldsmt(tmp, this->smem_ + this->smem_read_offset_ + row * Base::BYTES_PER_ROW); b[ni].reg(0) = tmp.x; b[ni].reg(1) = tmp.y; b[ni].reg(2) = tmp.z; b[ni].reg(3) = tmp.w; // Move the pointer for the next ni. I expect the compiler to not recompute those. if (Mma_tile::MMAS_N == 4) { this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 2 == 0 ? 2 : 6); } else { assert(false); // Not implemented! } } } }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct Smem_tile_o { // The MMA tile. using Mma_tile = fmha::Hmma_tile; // The accumulators. using Accumulator = fmha::Fragment_accumulator; // The accumulators. using Data_type = typename Accumulator::Data_type; // The size of each element. 
enum { BYTES_PER_ELEMENT = sizeof(Data_type) }; // The size of each STS. enum { BYTES_PER_STS = 8 }; // The size of each row in shared memory. enum { BYTES_PER_ROW = Cta_tile::N * Cta_tile::WARPS_K * BYTES_PER_ELEMENT }; // The size of each LDS. enum { BYTES_PER_LDS = 16 }; enum { THREADS_PER_ROW = 16 }; // The number of rows. enum { ROWS = Cta_tile::M }; // The number of "rows" to process per loop iteration (in the "epilogue"). enum { ROWS_PER_LOOP = ROWS <= 64 ? ROWS : (int)Mma_tile::M_PER_MMA_PER_CTA }; // The number of outer loops. enum { LOOPS = ROWS / ROWS_PER_LOOP }; // Make sure it matches our expectations. static_assert(LOOPS == 1 || LOOPS == (int)Mma_tile::MMAS_M, ""); // The number of rows loaded per LDS. enum { ROWS_PER_LDS = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW }; // Do we have to guard against partial writes/reads. enum { HAS_INCOMPLETE_LDS = ROWS_PER_LOOP % ROWS_PER_LDS != 0 }; // The total number of LDS per loop. enum { LDS_PER_LOOP = fmha::Div_up::VALUE }; // The amount of shared memory. enum { BYTES_PER_TILE = ROWS_PER_LOOP * BYTES_PER_ROW }; // The write pointer. uint32_t smem_write_, smem_read_; // Is the thread active for the last LDS of the series? int is_active_for_last_lds_; static_assert(BYTES_PER_ROW == 64 * 4 * Cta_tile::WARPS_K); static_assert(LOOPS == 1 || LOOPS == (int)Mma_tile::MMAS_M, ""); // Ctor. inline __device__ Smem_tile_o(void* smem, int tidx) { // Get a 32-bit value for the shared memory address. uint32_t smem_ = __nvvm_get_smem_pointer(smem); static_assert(Cta_tile::WARPS_M == 1 && Cta_tile::WARPS_N == 1 && (Cta_tile::WARPS_K == 4 || Cta_tile::WARPS_K == 8)); int write_row = (tidx & 0x1c) / 4; int write_col = (tidx); // Assemble the write pointer. smem_write_ = smem_ + write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS; // The element read by each thread. int read_row = tidx / THREADS_PER_ROW; int read_col = tidx % THREADS_PER_ROW; // Take the XOR pattern into account for the column. 
read_col ^= 2 * (read_row & 0x7); // Assemble the read pointer. this->smem_read_ = smem_ + read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS; // Is that thread active on the last LDS? if (HAS_INCOMPLETE_LDS) { this->is_active_for_last_lds_ = read_row + (LDS_PER_LOOP - 1) * ROWS_PER_LDS < Cta_tile::M; } } // Load the output fragments. inline __device__ void load(uint4 (&out)[LDS_PER_LOOP]) const { #pragma unroll for (int ii = 0; ii < LDS_PER_LOOP; ++ii) { // Load the elements before the reduction (split-K). uint4 tmp[Cta_tile::WARPS_K]; #pragma unroll for (int jj = 0; jj < Cta_tile::WARPS_K; ++jj) { int imm = ii * ROWS_PER_LDS * BYTES_PER_ROW + jj * Cta_tile::N * BYTES_PER_ELEMENT; if (!HAS_INCOMPLETE_LDS || (ii < LDS_PER_LOOP - 1 || this->is_active_for_last_lds_)) { fmha::lds(tmp[jj], this->smem_read_ + imm); } } // Perform the reduction. out[ii] = tmp[0]; #pragma unroll for (int jj = 1; jj < Cta_tile::WARPS_K; ++jj) { out[ii] = fmha::fadd4(out[ii], tmp[jj]); } } } // Store the accumulators. template inline __device__ void store(const Accumulator (&acc)[M][N], int mi) { enum { M_PER_MMA = Mma_tile::M_PER_MMA_PER_CTA }; #pragma unroll for (int ni = 0; ni < Mma_tile::MMAS_N; ++ni) { // The number of MMAs that are stored per loop iteration. enum { MMAS_M_PER_LOOP = Mma_tile::MMAS_M / LOOPS }; // Store 1st column of the different MMAs. #pragma unroll for (int mj = 0; mj < MMAS_M_PER_LOOP; ++mj) { // Precompute the immediates to jump between rows. int row_0 = (mj * M_PER_MMA + 0) * BYTES_PER_ROW; int row_1 = (mj * M_PER_MMA + 8) * BYTES_PER_ROW; uint2 tmp0, tmp1; tmp0.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(0); tmp0.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(1); tmp1.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(2); tmp1.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(3); // Store. fmha::sts(this->smem_write_ + row_0, tmp0); fmha::sts(this->smem_write_ + row_1, tmp1); } // Swizzle the write pointer using a XOR of 16B. 
this->smem_write_ ^= 32; // Store 2nd column of the different MMAs. #pragma unroll for (int mj = 0; mj < MMAS_M_PER_LOOP; ++mj) { // Precompute the immediates to jump between rows. int row_0 = (mj * M_PER_MMA + 0) * BYTES_PER_ROW; int row_1 = (mj * M_PER_MMA + 8) * BYTES_PER_ROW; uint2 tmp0, tmp1; tmp0.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(4); tmp0.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(5); tmp1.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(6); tmp1.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(7); // Store. fmha::sts(this->smem_write_ + row_0, tmp0); fmha::sts(this->smem_write_ + row_1, tmp1); } // Cancel the previous XOR of 1 + swizzle the write pointer using a XOR of 32B or 64B. this->smem_write_ ^= (ni & 1) ? 7 * 32 : 3 * 32; } } }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct Smem_tile_mma { using Mma_tile = fmha::Hmma_tile; using Fragment = fmha::Fragment_a; enum { COLS = Cta_tile::N }; enum { BYTES_PER_ELT = 2 }; enum { BYTES_PER_STS = 4 }; enum { BYTES_PER_ROW = COLS * BYTES_PER_ELT }; // TODO enum { BYTES_PER_TILE = Cta_tile::M * BYTES_PER_ROW }; enum { WARPS_M = Cta_tile::WARPS_M }; enum { WARPS_N = Cta_tile::WARPS_N }; enum { WARPS_K = Cta_tile::WARPS_K }; static_assert(WARPS_K == 1); inline __device__ Smem_tile_mma(char* smem, int tidx) { smem_ = __nvvm_get_smem_pointer(smem); int write_col, write_row; static_assert(WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8) || (WARPS_M == 4 || WARPS_N == 8) || WARPS_N == 1); if (WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8)) { write_row = (tidx & 0x1c) / 4; write_col = (tidx & 0xe0) / 4 + (tidx & 0x03); } else { write_row = (tidx & 0xe0) / 2 + (tidx & 0x1c) / 4; write_col = (tidx & 0x03); } write_col ^= (write_row & 0x07) * 4; write_offset_ = write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS; } template inline __device__ void store(const uint4 (®s)[M][N]) { static_assert(COLS == Cta_tile::N); for (int mi = 0; mi < M; mi++) { for (int ni 
= 0; ni < N; ni++) { size_t offset = write_offset_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT; fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].x); fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].z); offset ^= 4 * BYTES_PER_STS; fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].y); fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].w); } } } uint32_t smem_; uint32_t write_offset_; uint32_t warp_m; uint32_t warp_n; uint32_t lane; }; template > struct Smem_tile_mma_transposed : public Base { enum { BYTES_PER_LDS = 16 }; enum { BYTES_PER_ROW = Base::BYTES_PER_ROW }; enum { BYTES_PER_ELT = Base::BYTES_PER_ELT }; enum { WARPS_M = Base::WARPS_M }; enum { WARPS_N = Base::WARPS_N }; static_assert(WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8)); using Fragment = typename Base::Fragment; inline __device__ Smem_tile_mma_transposed(char* smem, int tidx) : Base(smem, tidx) { static_assert(WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8)); int read_row, read_col; read_row = (tidx & 0x0f); read_col = (tidx & 0xe0) / 16 + (tidx & 0x1c) / 16; read_col ^= (read_row & 0x07); read_offset_ = read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS; } template inline __device__ void load(Fragment (&frag)[M][N]) { static_assert(Base::COLS == Cta_tile::N); for (int mi = 0; mi < M; mi++) { for (int ni = 0; ni < N; ni++) { size_t offset = read_offset_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT; uint4 dst; fmha::ldsmt(dst, this->smem_ + offset); frag[mi][ni].reg(0) = dst.x; frag[mi][ni].reg(1) = dst.z; // Fragment A regs col major! 
frag[mi][ni].reg(2) = dst.y; frag[mi][ni].reg(3) = dst.w; } } } uint32_t read_offset_; }; template > struct Smem_tile_mma_epilogue : public Base { enum { BYTES_PER_LDS = 16 }; enum { BYTES_PER_ROW = Base::BYTES_PER_ROW }; enum { BYTES_PER_ELT = Base::BYTES_PER_ELT }; enum { THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_LDS }; static_assert(THREADS_PER_ROW * BYTES_PER_LDS == BYTES_PER_ROW); enum { ROWS_PER_LDS = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW }; enum { NUM_LDS = Cta_tile::M / ROWS_PER_LDS }; static_assert(NUM_LDS * ROWS_PER_LDS == Cta_tile::M); enum { WARPS_M = Base::WARPS_M }; enum { WARPS_N = Base::WARPS_N }; static_assert((WARPS_M == 4 || WARPS_N == 8) || WARPS_N == 1); using Acc = fmha::Fragment_accumulator; inline __device__ Smem_tile_mma_epilogue(char* smem, int tidx) : Base(smem, tidx) { const int read_row = tidx / THREADS_PER_ROW; int read_col = tidx % THREADS_PER_ROW; read_col ^= (read_row & 0x07); read_offset_ = read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS; } inline __device__ void load(uint4 (&data)[NUM_LDS]) { for (int ii = 0; ii < NUM_LDS; ii++) { size_t offset = read_offset_ + ii * ROWS_PER_LDS * BYTES_PER_ROW; fmha::lds(data[ii], this->smem_ + offset); } } template inline __device__ void store(const Acc (&acc)[M][N]) { #pragma unroll for (int mi = 0; mi < M; mi++) { #pragma unroll for (int ni = 0; ni < N; ni++) { // 1st row - 4 elements per row. float tmp00 = acc[mi][ni].elt(0); float tmp01 = acc[mi][ni].elt(1); float tmp02 = acc[mi][ni].elt(4); float tmp03 = acc[mi][ni].elt(5); // 2nd row - 4 elements per row. 
float tmp10 = acc[mi][ni].elt(2); float tmp11 = acc[mi][ni].elt(3); float tmp12 = acc[mi][ni].elt(6); float tmp13 = acc[mi][ni].elt(7); uint32_t x = fmha::float2_to_half2(tmp00, tmp01); uint32_t y = fmha::float2_to_half2(tmp02, tmp03); uint32_t z = fmha::float2_to_half2(tmp10, tmp11); uint32_t w = fmha::float2_to_half2(tmp12, tmp13); size_t offset = (this->write_offset_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW; fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, x); fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, z); offset ^= 4 * Base::BYTES_PER_STS; fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, y); fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, w); } } } template inline __device__ void store(const uint4 (®s)[M][N]) { for (int mi = 0; mi < M; mi++) { for (int ni = 0; ni < N; ni++) { size_t offset = (this->write_offset_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW; fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].x); fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].z); offset ^= 4 * Base::BYTES_PER_STS; fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].y); fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].w); } } } uint32_t read_offset_; }; } // namespace fmha ================================================ FILE: apex/contrib/csrc/fmha/src/fmha/softmax.h ================================================ /****************************************************************************** * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright * notice, this list of conditions and the following disclaimer. 
* * Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. * * Neither the name of the NVIDIA CORPORATION nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * ******************************************************************************/ #pragma once namespace fmha { //////////////////////////////////////////////////////////////////////////////////////////////////// struct Sum_ { enum { IS_SUM = 1 }; static inline __device__ float apply(float x, float y) { return x + y; } }; //////////////////////////////////////////////////////////////////////////////////////////////////// struct Max_ { enum { IS_SUM = 0 }; static inline __device__ float apply(float x, float y) { return x > y ? 
x : y; } }; //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ float apply_exp_(float x, float max) { return __expf(x - max); } //////////////////////////////////////////////////////////////////////////////////////////////////// template struct ReadType {}; template <> struct ReadType<4> { using T = float; }; template <> struct ReadType<8> { using T = float2; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct Smem_tile_reduce { // Helper class to distribute MMA tiles reduced over rows per warp over quads. // The Mma tile. using Mma_tile = fmha::Hmma_tile; // The number of MMAs in M/N dimensions. enum { MMAS_M = Mma_tile::MMAS_M }; enum { MMAS_N = Mma_tile::MMAS_N }; enum { WARPS_M = Cta_tile::WARPS_M }; enum { WARPS_N = Cta_tile::WARPS_N }; static constexpr int ROWS = WARPS_M * MMAS_M * 16; static constexpr int COLS = WARPS_N; static_assert(COLS == 4 || COLS == 8); static constexpr int ROWS_PER_XOR_PATTERN = (COLS == 8) ? 4 : 8; static constexpr int BYTES_PER_TILE = ROWS * COLS * sizeof(float); static constexpr int ELTS_PER_TILE = ROWS * COLS; static constexpr int THREADS_PER_GROUP = Kernel_traits::Gmem_tile_o::THREADS_PER_ROW; static_assert(THREADS_PER_GROUP == 16); // DEBUG static constexpr int ROWS_PER_WARP = 32 / THREADS_PER_GROUP; static constexpr int LOOPS = Kernel_traits::Gmem_tile_o::LOOPS; static_assert(LOOPS == 1); using read_t = typename ReadType::T; __device__ inline Smem_tile_reduce(float* smem_, const int tidx) { int lane = tidx % 32; int warp = tidx / 32; int warp_m = warp % WARPS_M; int warp_n = warp / WARPS_M; qid_ = lane % 4; int qp = lane / 4; // Swizzle the column to avoid 2-fold bank conflicts when we have 8 warps. // This won't affect reading as we assume commutative reduction ops. 
const int col = warp_n ^ (qp / ROWS_PER_XOR_PATTERN); smem_write_ = &smem_[warp_m * 16 * MMAS_M * WARPS_N + qp * WARPS_N + col]; smem_read_ = &reinterpret_cast(smem_)[warp_m * 16 * MMAS_M * 4 + qp * 4 + qid_]; } __device__ inline void store(float (&frag)[2 * MMAS_M]) { if (qid_ == 0) { #pragma unroll for (int mi = 0; mi < MMAS_M; mi++) { int offset = mi * 16 * WARPS_N; smem_write_[offset + 0 * 8 * WARPS_N] = frag[mi * 2 + 0]; smem_write_[offset + 1 * 8 * WARPS_N] = frag[mi * 2 + 1]; } } } __device__ inline void load(read_t (&frag)[2 * MMAS_M]) { #pragma unroll for (int mi = 0; mi < MMAS_M; mi++) { int offset = mi * 16 * 4; frag[mi * 2 + 0] = smem_read_[offset + 0 * 8 * 4]; frag[mi * 2 + 1] = smem_read_[offset + 1 * 8 * 4]; } } int qid_; float* smem_write_; read_t* smem_read_; }; template struct Softmax_base { // The Mma tile. using Mma_tile = fmha::Hmma_tile; // The number of MMAs in M/N dimensions. enum { MMAS_M = Mma_tile::MMAS_M }; enum { MMAS_N = Mma_tile::MMAS_N }; // The number of groups of warp such that we have at most 4 warps writing consecutive elements. enum { GROUPS = fmha::Div_up::VALUE }; // The number of elements that we are going to store per row. enum { ELEMENTS_PER_ROW = Cta_tile::WARPS_N / GROUPS }; // The number of rows. enum { ROWS = Cta_tile::M * GROUPS }; // The total number of elements. enum { ELEMENTS = ROWS * ELEMENTS_PER_ROW }; // Ctor. template inline __device__ Softmax_base(const Params& params, void* smem, int bidb, int tidx) : // packed_mask_ptr_(reinterpret_cast(params.packed_mask_ptr)), smem_(reinterpret_cast(smem)), tidx_(tidx) { // Move to the 1st mask loaded by the thread+ tidx; // packed_mask_ptr_ += bidb * params.packed_mask_stride_in_bytes + tidx * sizeof(uint32_t); // Extract the position in the warp. int warp = tidx / Cta_tile::THREADS_PER_WARP; int lane = tidx % Cta_tile::THREADS_PER_WARP; // Decompose the warp index into M and N. 
int warp_m = warp % Cta_tile::WARPS_M; int warp_n = warp / Cta_tile::WARPS_M; // Decompose the warp-n index into group/position-inside-the-group. int warp_g = warp_n / ELEMENTS_PER_ROW; int warp_i = warp_n % ELEMENTS_PER_ROW; // The location written by the threads. int write_row = warp_g * (ROWS / GROUPS) + warp_m * Mma_tile::M_PER_MMA + lane / 4; int write_col = warp_i; // Assemble the write pointer. smem_write_ = &smem_[write_row * ELEMENTS_PER_ROW + write_col]; // Assemble the read pointer. smem_read_ = &smem_[warp_m * Mma_tile::M_PER_MMA + lane / 4]; } template inline __device__ void apply_mask(const Mask& mask) { #pragma unroll for (int mi = 0; mi < MMAS_M; ++mi) { #pragma unroll for (int ii = 0; ii < 2; ++ii) { #pragma unroll for (int ni = 0; ni < MMAS_N; ++ni) { #pragma unroll for (int jj = 0; jj < 4; ++jj) { if (!mask.is_valid(mi, ni, ii, jj)) { elt_[2 * mi + ii][4 * ni + jj] = -INFINITY; } } } } } } // Apply the exp to all the elements. inline __device__ void apply_exp(const float (&max)[MMAS_M * 2]) { #pragma unroll for (int mi = 0; mi < MMAS_M * 2; ++mi) { #pragma unroll for (int ni = 0; ni < MMAS_N * 4; ++ni) { elt_[mi][ni] = apply_exp_(elt_[mi][ni], max[mi]); } } } // Scale all the elements. inline __device__ void scale(const float (&sum)[MMAS_M * 2]) { // Precompute the inverse sum to normalize. Without -use_fast_math, it makes a huge deal. float inv_sum[MMAS_M * 2]; #pragma unroll for (int mi = 0; mi < MMAS_M * 2; ++mi) { inv_sum[mi] = (sum[mi] == 0.f || sum[mi] != sum[mi]) ? 1.f : 1.f / sum[mi]; } // Update the values. #pragma unroll for (int mi = 0; mi < MMAS_M * 2; ++mi) { #pragma unroll for (int ni = 0; ni < MMAS_N * 4; ++ni) { elt_[mi][ni] *= inv_sum[mi]; } } } // The pointer to the mask. const char* packed_mask_ptr_; // Shared memory for the CTA-wide reduction. float *smem_, *smem_write_, *smem_read_; // The current thread index. int tidx_; // The elements. 
float elt_[MMAS_M * 2][MMAS_N * 4]; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct Softmax : public Softmax_base { // The base class. using Base = Softmax_base; // The fragment. using Fragment_a = fmha::Fragment_a; static_assert(Fragment_a::NUM_REGS == 4); enum { WARPS_M = Cta_tile::WARPS_M }; enum { WARPS_N = Cta_tile::WARPS_N }; // The MMAs. enum { MMAS_M = Base::MMAS_M }; enum { MMAS_N = Base::MMAS_N }; // The accumulators. using Accumulator = fmha::Fragment_accumulator; using Accumulator_out = Fragment; static_assert(Accumulator_out::NUM_REGS == 4); static_assert(std::is_same::value); using Smem_tile_red = Smem_tile_reduce; static_assert(Smem_tile_red::ELTS_PER_TILE == Cta_tile::M * WARPS_N); // Ctor. template inline __device__ Softmax(const Params& params, void* smem, int bidb, int tidx) : Base(params, smem, bidb, tidx), params_scale_bmm1_(params.scale_bmm1), smem_sum_(static_cast(smem), tidx), smem_max_(static_cast(smem) + Smem_tile_red::ELTS_PER_TILE, tidx) {} // Pack the data to a fragment for the next GEMM. template inline __device__ void pack(Fragment_a (&dst)[K][M]) const { #pragma unroll for (int mi = 0; mi < M; ++mi) { #pragma unroll for (int ki = 0; ki < K; ++ki) { // 1st row - 4 elements per row. float tmp_00 = this->elt_[2 * mi + 0][4 * ki + 0]; float tmp_01 = this->elt_[2 * mi + 0][4 * ki + 1]; float tmp_02 = this->elt_[2 * mi + 0][4 * ki + 2]; float tmp_03 = this->elt_[2 * mi + 0][4 * ki + 3]; // 2nd row - 4 elements per row. float tmp_10 = this->elt_[2 * mi + 1][4 * ki + 0]; float tmp_11 = this->elt_[2 * mi + 1][4 * ki + 1]; float tmp_12 = this->elt_[2 * mi + 1][4 * ki + 2]; float tmp_13 = this->elt_[2 * mi + 1][4 * ki + 3]; // Pack to 4 registers. 
dst[ki][mi].reg(0) = fmha::float2_to_half2(tmp_00, tmp_01); dst[ki][mi].reg(1) = fmha::float2_to_half2(tmp_10, tmp_11); dst[ki][mi].reg(2) = fmha::float2_to_half2(tmp_02, tmp_03); dst[ki][mi].reg(3) = fmha::float2_to_half2(tmp_12, tmp_13); } } } // Scale FP32 fragments inline __device__ void unpack(const Accumulator (&acc)[MMAS_M][MMAS_N]) { const float scalef = reinterpret_cast(this->params_scale_bmm1_); #pragma unroll for (int mi = 0; mi < MMAS_M; ++mi) { #pragma unroll for (int ni = 0; ni < MMAS_N; ++ni) { // 1st row - 4 elements per row. this->elt_[2 * mi + 0][4 * ni + 0] = acc[mi][ni].elt(0) * scalef; this->elt_[2 * mi + 0][4 * ni + 1] = acc[mi][ni].elt(1) * scalef; this->elt_[2 * mi + 0][4 * ni + 2] = acc[mi][ni].elt(4) * scalef; this->elt_[2 * mi + 0][4 * ni + 3] = acc[mi][ni].elt(5) * scalef; // 2nd row - 4 elements per row. this->elt_[2 * mi + 1][4 * ni + 0] = acc[mi][ni].elt(2) * scalef; this->elt_[2 * mi + 1][4 * ni + 1] = acc[mi][ni].elt(3) * scalef; this->elt_[2 * mi + 1][4 * ni + 2] = acc[mi][ni].elt(6) * scalef; this->elt_[2 * mi + 1][4 * ni + 3] = acc[mi][ni].elt(7) * scalef; } } } // Scale FP32 fragments inline __device__ void unpack_noscale(const Accumulator (&acc)[MMAS_M][MMAS_N]) { #pragma unroll for (int mi = 0; mi < MMAS_M; ++mi) { #pragma unroll for (int ni = 0; ni < MMAS_N; ++ni) { // 1st row - 4 elements per row. this->elt_[2 * mi + 0][4 * ni + 0] = acc[mi][ni].elt(0); this->elt_[2 * mi + 0][4 * ni + 1] = acc[mi][ni].elt(1); this->elt_[2 * mi + 0][4 * ni + 2] = acc[mi][ni].elt(4); this->elt_[2 * mi + 0][4 * ni + 3] = acc[mi][ni].elt(5); // 2nd row - 4 elements per row. 
this->elt_[2 * mi + 1][4 * ni + 0] = acc[mi][ni].elt(2); this->elt_[2 * mi + 1][4 * ni + 1] = acc[mi][ni].elt(3); this->elt_[2 * mi + 1][4 * ni + 2] = acc[mi][ni].elt(6); this->elt_[2 * mi + 1][4 * ni + 3] = acc[mi][ni].elt(7); } } } template __device__ inline void reduce_(float (&frag)[2 * MMAS_M], Operator& op, Smem_tile_red& smem_red) { for (int mi = 0; mi < 2 * MMAS_M; mi++) { frag[mi] = this->elt_[mi][0]; for (int ni = 1; ni < 4 * MMAS_N; ni++) { frag[mi] = op(frag[mi], this->elt_[mi][ni]); } } quad_reduce(frag, frag, op); smem_red.store(frag); __syncthreads(); typename Smem_tile_red::read_t tmp[2 * MMAS_M]; smem_red.load(tmp); quad_allreduce(frag, tmp, op); } __device__ inline void reduce_max(float (&frag)[2 * MMAS_M]) { MaxOp max; reduce_(frag, max, smem_max_); } __device__ inline void reduce_sum(float (&frag)[2 * MMAS_M]) { SumOp sum; reduce_(frag, sum, smem_sum_); } const uint32_t params_scale_bmm1_; Smem_tile_red smem_max_; Smem_tile_red smem_sum_; }; //////////////////////////////////////////////////////////////////////////////////////////////////// } // namespace fmha ================================================ FILE: apex/contrib/csrc/fmha/src/fmha/utils.h ================================================ /****************************************************************************** * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright * notice, this list of conditions and the following disclaimer. * * Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. 
*
 * Neither the name of the NVIDIA CORPORATION nor the
 * names of its contributors may be used to endorse or promote products
 * derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/

#pragma once

// NOTE(review): the header names of the three #include directives below are missing in this copy.
#include
#include
#include

// Returns the 32-bit shared-memory address of a generic pointer into shared memory.
extern "C" __device__ uint32_t __nvvm_get_smem_pointer(void* ptr);

////////////////////////////////////////////////////////////////////////////////////////////////////

namespace fmha {

////////////////////////////////////////////////////////////////////////////////////////////////////

// Layout tags.
struct Row {};
struct Col {};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Compile-time "next power of two". The primary template is empty. The partial specialization
// returning M presumably handles values that are already powers of two (its arguments are not
// visible in this copy). The explicit `false` specializations list the non-power-of-two values
// that are supported; any other value has no VALUE member and fails to compile.
template struct Next_power_of_two {};

template struct Next_power_of_two { enum { VALUE = M }; };
template <> struct Next_power_of_two<3, false> { enum { VALUE = 4 }; };
template <> struct Next_power_of_two<5, false> { enum { VALUE = 8 }; };
template <> struct Next_power_of_two<6, false> { enum { VALUE = 8 }; };
template <> struct Next_power_of_two<7, false> { enum { VALUE = 8 }; };
template <> struct Next_power_of_two<9, false> { enum { VALUE = 16 }; };
template <> struct Next_power_of_two<10, false> { enum { VALUE = 16 }; };
template <> struct Next_power_of_two<11, false> { enum { VALUE = 16 }; };
template <> struct Next_power_of_two<12, false> { enum { VALUE = 16 }; };
template <> struct Next_power_of_two<13, false> { enum { VALUE = 16 }; };
template <> struct Next_power_of_two<14, false> { enum { VALUE = 16 }; };
template <> struct Next_power_of_two<15, false> { enum { VALUE = 16 }; };
template <> struct Next_power_of_two<24, false> { enum { VALUE = 32 }; };
template <> struct Next_power_of_two<48, false> { enum { VALUE = 64 }; };
template <> struct Next_power_of_two<80, false> { enum { VALUE = 128 }; };
template <> struct Next_power_of_two<96, false> { enum { VALUE = 128 }; };
template <> struct Next_power_of_two<112, false> { enum { VALUE = 128 }; };
template <> struct Next_power_of_two<144, false> { enum { VALUE = 256 }; };

////////////////////////////////////////////////////////////////////////////////////////////////////

// Compile-time "previous power of two", same structure as Next_power_of_two; only 3, 5, 6 and 7
// are listed explicitly.
template struct Prev_power_of_two {};

template struct Prev_power_of_two { enum { VALUE = N }; };
template <> struct Prev_power_of_two<3, false> { enum { VALUE = 2 }; };
template <> struct Prev_power_of_two<5, false> { enum { VALUE = 4 }; };
template <> struct Prev_power_of_two<6, false> { enum { VALUE = 4 }; };
template <> struct Prev_power_of_two<7, false> { enum { VALUE = 4 }; };

////////////////////////////////////////////////////////////////////////////////////////////////////

// Integer ceiling division (M + N - 1) / N.
template struct Div_up { enum { VALUE = (M + N - 1) / N }; };

////////////////////////////////////////////////////////////////////////////////////////////////////

// Compile-time maximum of two values.
template struct Max { enum { VALUE = A >= B ? A : B }; };

////////////////////////////////////////////////////////////////////////////////////////////////////

// Compile-time maximum of three values, built from Max.
template struct Max_3 { enum { VALUE = Max::VALUE, C>::VALUE }; };

////////////////////////////////////////////////////////////////////////////////////////////////////

// Compile-time minimum of two values.
template struct Min { enum { VALUE = A <= B ? A : B }; };

////////////////////////////////////////////////////////////////////////////////////////////////////

// Maps a size in bytes to an unsigned type of that size (uint2 / uint4 are CUDA vector types).
template struct Uint_from_size_in_bytes {};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <> struct Uint_from_size_in_bytes<1> { using Type = uint8_t; };

////////////////////////////////////////////////////////////////////////////////////////////////////

template <> struct Uint_from_size_in_bytes<2> { using Type = uint16_t; };

////////////////////////////////////////////////////////////////////////////////////////////////////

template <> struct Uint_from_size_in_bytes<4> { using Type = uint32_t; };

////////////////////////////////////////////////////////////////////////////////////////////////////

template <> struct Uint_from_size_in_bytes<8> { using Type = uint2; };

////////////////////////////////////////////////////////////////////////////////////////////////////

template <> struct Uint_from_size_in_bytes<16> { using Type = uint4; };

////////////////////////////////////////////////////////////////////////////////////////////////////

// Bit masks over the thread index that isolate the M, N and K warp-index bits for a
// (WARPS_M, WARPS_N, WARPS_K) layout. As the table shows, warps are numbered M-fastest, then N,
// then K, with 32 threads per warp (e.g. <2, 2, 2>: M = bit 5, N = bit 6, K = bit 7).
template struct Warp_masks {};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <> struct Warp_masks<8, 1, 1> { enum { M = 0xe0, N = 0x00, K = 0x00 }; };
template <> struct Warp_masks<4, 2, 1> { enum { M = 0x60, N = 0x80, K = 0x00 }; };
template <> struct Warp_masks<4, 1, 2> { enum { M = 0x60, N = 0x00, K = 0x80 }; };
template <> struct Warp_masks<4, 1, 1> { enum { M = 0x60, N = 0x00, K = 0x00 }; };
template <> struct Warp_masks<2, 4, 1> { enum { M = 0x20, N = 0xc0, K = 0x00 }; };
template <> struct Warp_masks<2, 2, 2> { enum { M = 0x20, N = 0x40, K = 0x80 }; };
template <> struct Warp_masks<2, 2, 1> { enum { M = 0x20, N = 0x40, K = 0x00 }; };
template <> struct Warp_masks<2, 1, 2> { enum { M = 0x20, N = 0x00, K = 0x40 }; };
template <> struct Warp_masks<2, 1, 1> { enum { M = 0x20, N = 0x00, K = 0x00
}; }; template <> struct Warp_masks<1, 8, 1> { enum { M = 0x00, N = 0xe0, K = 0x00 }; }; template <> struct Warp_masks<1, 4, 2> { enum { M = 0x00, N = 0x60, K = 0x80 }; }; template <> struct Warp_masks<1, 4, 1> { enum { M = 0x00, N = 0x60, K = 0x00 }; }; template <> struct Warp_masks<1, 2, 2> { enum { M = 0x00, N = 0x20, K = 0x40 }; }; template <> struct Warp_masks<1, 2, 1> { enum { M = 0x00, N = 0x20, K = 0x00 }; }; template <> struct Warp_masks<1, 1, 4> { enum { M = 0x00, N = 0x00, K = 0x60 }; }; template <> struct Warp_masks<1, 1, 2> { enum { M = 0x00, N = 0x00, K = 0x20 }; }; template <> struct Warp_masks<1, 1, 1> { enum { M = 0x00, N = 0x00, K = 0x00 }; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template inline __device__ __host__ T div_up(T m, T n) { return (m + n - 1) / n; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline int clz(int x) { for (int i = 31; i >= 0; --i) { if ((1 << i) & x) { return 31 - i; } } return 32; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline int find_log_2(int x, bool round_up = false) { int a = 31 - clz(x); if (round_up) { a += (x & (x - 1)) ? 
1 : 0; } return a; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint32_t hadd2(uint32_t a, uint32_t b) { uint32_t c; asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); return c; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint32_t hmin2(uint32_t a, uint32_t b) { uint32_t c; asm volatile("min.f16x2 %0, %1, %2;" : "=r"(c) : "r"(a), "r"(b)); return c; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint32_t hmul2(uint32_t a, uint32_t b) { uint32_t c; asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); return c; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint2 hmul4(uint2 a, uint2 b) { uint2 c; c.x = hmul2(a.x, b.x); c.y = hmul2(a.y, b.y); return c; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint4 hmul8(uint4 a, uint4 b) { uint4 c; c.x = hmul2(a.x, b.x); c.y = hmul2(a.y, b.y); c.z = hmul2(a.z, b.z); c.w = hmul2(a.w, b.w); return c; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint4 hmul8(uint32_t a, uint4 b) { uint4 c; c.x = hmul2(a, b.x); c.y = hmul2(a, b.y); c.z = hmul2(a, b.z); c.w = hmul2(a, b.w); return c; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint32_t hrelu2(uint32_t x, uint32_t lb = 0) { uint32_t res; #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 asm volatile("max.f16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(lb)); #else const uint32_t zero = 0u; asm volatile( "{\n" "\t .reg .f16x2 sela;\n" "\t set.gtu.u32.f16x2 sela, %1, %2;\n" "\t and.b32 %0, 
sela, %1;\n" "}\n" : "=r"(res) : "r"(x), "r"(zero)); #endif return res; } static inline __device__ uint32_t habs2(uint32_t x) { uint32_t res; asm volatile("abs.f16x2 %0, %1;\n" : "=r"(res) : "r"(x)); return res; } //////////////////////////////////////////////////////////////////////////////////////////////////// // template static inline __device__ T clamp(T x, T lb, T ub) { return x < lb ? lb : (x > ub ? ub : x); } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint16_t clamp_to_zero(uint16_t x) { uint16_t mask; asm volatile("set.gtu %0, %1, 0;" : "=h"(mask) : "h"(x)); return mask & x; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint16_t float_to_half(float f) { uint16_t h; asm volatile("cvt.rn.f16.f32 %0, %1;" : "=h"(h) : "f"(f)); return h; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint32_t float2_to_half2(float a, float b) { uint32_t c; #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(c) : "f"(b), "f"(a)); #else uint16_t lo = float_to_half(a); uint16_t hi = float_to_half(b); asm volatile("mov.b32 %0, {%1, %2};\n" : "=r"(c) : "h"(lo), "h"(hi)); #endif return c; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint32_t float_to_half2(float a) { return float2_to_half2(a, a); } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint32_t float2_to_half2(const float2& f) { return float2_to_half2(f.x, f.y); } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint2 float4_to_half4(float x, float y, float z, float w) { uint2 d; d.x = 
float2_to_half2(x, y); d.y = float2_to_half2(z, w); return d; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint32_t hfma2(uint32_t a, uint32_t b, uint32_t c) { uint32_t d; asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c)); return d; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint32_t hfma2_relu(uint32_t a, uint32_t b, uint32_t c) { uint32_t d; #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 asm volatile("fma.rn.f16x2.relu %0, %1, %2, %3;" : "=r"(d) : "r"(a), "r"(b), "r"(c)); #else d = hrelu2(hfma2(a, b, c)); #endif return d; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint32_t h0_h0(uint32_t x) { uint32_t y; asm volatile("{.reg .f16 lo, hi; mov.b32 {lo, hi}, %1; mov.b32 %0, {lo, lo};}\n" : "=r"(y) : "r"(x)); return y; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ float h0_to_float(uint32_t h2) { float f; asm volatile( "{\n" ".reg .f16 lo, hi;\n" "mov.b32 {lo, hi}, %1;\n" "cvt.f32.f16 %0, lo;\n" "}\n" : "=f"(f) : "r"(h2)); return f; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint32_t h1_h1(uint32_t x) { uint32_t y; asm volatile("{.reg .f16 lo, hi; mov.b32 {lo, hi}, %1; mov.b32 %0, {hi, hi};}\n" : "=r"(y) : "r"(x)); return y; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint16_t hadd(uint16_t a, uint16_t b) { uint16_t d; asm volatile("add.f16 %0, %1, %2;" : "=h"(d) : "h"(a), "h"(b)); return d; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint32_t 
hadd(uint32_t a, uint32_t b) { return hadd2(a, b); } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint2 hadd4(uint2 a, uint2 b) { uint2 c; c.x = hadd2(a.x, b.x); c.y = hadd2(a.y, b.y); return c; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint2 hadd(uint2 a, uint2 b) { return hadd4(a, b); } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint4 hadd8(uint4 a, uint4 b) { uint4 c; c.x = hadd2(a.x, b.x); c.y = hadd2(a.y, b.y); c.z = hadd2(a.z, b.z); c.w = hadd2(a.w, b.w); return c; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint4 fadd4(uint4 a, uint4 b) { float4 c; c.x = reinterpret_cast(a.x) + reinterpret_cast(b.x); c.y = reinterpret_cast(a.y) + reinterpret_cast(b.y); c.z = reinterpret_cast(a.z) + reinterpret_cast(b.z); c.w = reinterpret_cast(a.w) + reinterpret_cast(b.w); return reinterpret_cast(c); } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint4 hadd(uint4 a, uint4 b) { return hadd8(a, b); } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ float half_to_float(uint16_t h) { float f; asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h)); return f; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ float2 half2_to_float2(uint32_t x) { uint16_t lo, hi; asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(x)); return make_float2(half_to_float(lo), half_to_float(hi)); } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ void 
half2_to_float2(float& x, float& y, uint32_t h) { float2 tmp = half2_to_float2(h); x = tmp.x; y = tmp.y; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint16_t hfma(uint16_t a, uint16_t b, uint16_t c) { uint16_t d; asm volatile("fma.rn.f16 %0, %1, %2, %3;" : "=h"(d) : "h"(a), "h"(b), "h"(c)); return d; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ uint16_t hmul(uint16_t a, uint16_t b) { uint16_t d; asm volatile("mul.f16 %0, %1, %2;" : "=h"(d) : "h"(a), "h"(b)); return d; } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline __device__ float sigmoid(float x) { return 1.f / (1.f + expf(-x)); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void clear(uint16_t& dst) { dst = uint16_t(0); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void clear(uint32_t& dst) { dst = 0u; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void clear(uint2& dst) { dst = make_uint2(0u, 0u); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void clear(uint4& dst) { dst = make_uint4(0u, 0u, 0u, 0u); } //////////////////////////////////////////////////////////////////////////////////////////////////// // // P R E D I C A T E P A C K I N G // //////////////////////////////////////////////////////////////////////////////////////////////////// enum { BYTES_PER_REG = 4, PREDS_PER_BYTE = 4, PREDS_PER_REG = BYTES_PER_REG * PREDS_PER_BYTE }; //////////////////////////////////////////////////////////////////////////////////////////////////// // // G E N E R I C P R E D I C A T E D L D G S T S // 
//////////////////////////////////////////////////////////////////////////////////////////////////// template inline __device__ void load_(Functor& fct, const uint32_t (&preds)[M]) { // The number of complete bytes (where we use all the predicates in a byte). enum { COMPLETE = N / PREDS_PER_BYTE }; // Make sure we did allocate enough predicates. static_assert(Div_up::VALUE <= M, ""); // The remainder. enum { REMAINDER = N - COMPLETE * PREDS_PER_BYTE }; // Make sure we got the math right and the remainder is between 0 and 3. static_assert(REMAINDER >= 0 && REMAINDER <= 3, ""); // The mask to extract the predicates. enum { COMPLETE_MASK = (1 << PREDS_PER_BYTE) - 1 }; // Clear the fetch registers. #pragma unroll for (int ii = 0; ii < N; ++ii) { fct.clear(ii); } // Run complete steps. bool p[PREDS_PER_BYTE]; #pragma unroll for (int ii = 0; ii < COMPLETE; ++ii) { // The predicate. uint32_t reg = preds[ii / BYTES_PER_REG]; // Extract the predicates. #pragma unroll for (int jj = 0; jj < PREDS_PER_BYTE; ++jj) { uint32_t mask = 1u << (ii % BYTES_PER_REG * 8 + jj); p[jj] = (reg & mask) != 0u; } // Issue the loads. #pragma unroll for (int jj = 0; jj < PREDS_PER_BYTE; ++jj) { fct.load(ii * PREDS_PER_BYTE + jj, p[jj]); } } // Skip the rest of the code if we do not have a remainder. if (REMAINDER > 0) { // The mask to extract the predicates. enum { REMAINDER_MASK = (1 << REMAINDER) - 1 }; // The predicate register. uint32_t reg = preds[COMPLETE / BYTES_PER_REG]; // Extract the predicates. #pragma unroll for (int jj = 0; jj < PREDS_PER_BYTE; ++jj) { uint32_t mask = 1u << (COMPLETE % BYTES_PER_REG * 8 + jj); p[jj] = (reg & mask) != 0u; } // Issue the loads. 
#pragma unroll for (int ii = 0; ii < REMAINDER; ++ii) { fct.load(COMPLETE * PREDS_PER_BYTE + ii, p[ii]); } } } //////////////////////////////////////////////////////////////////////////////////////////////////// template inline __device__ void load_(Functor& fct, uint32_t preds) { uint32_t tmp[1] = {preds}; load_(fct, tmp); } //////////////////////////////////////////////////////////////////////////////////////////////////// // // L D G // //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void ldg(uint8_t& dst, const void* ptr) { dst = *reinterpret_cast(ptr); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void ldg(uint16_t& dst, const void* ptr) { dst = *reinterpret_cast(ptr); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void ldg(uint32_t& dst, const void* ptr) { dst = *reinterpret_cast(ptr); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void ldg(uint2& dst, const void* ptr) { dst = *reinterpret_cast(ptr); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void ldg(uint4& dst, const void* ptr) { dst = *reinterpret_cast(ptr); } //////////////////////////////////////////////////////////////////////////////////////////////////// template struct Ldg_functor { // Ctor. inline __device__ Ldg_functor(Data_type (&fetch)[N], const void* (&ptrs)[N]) : fetch_(fetch), ptrs_(ptrs) {} // Clear the element. inline __device__ void clear(int ii) { fmha::clear(fetch_[ii]); } // Trigger the loads. inline __device__ void load(int ii, bool p) { if (p) { ldg(fetch_[ii], ptrs_[ii]); } } // The fetch registers. Data_type (&fetch_)[N]; // The pointers. 
const void* (&ptrs_)[N]; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template inline __device__ void ldg_(Data_type (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) { Ldg_functor fct(fetch, ptrs); load_(fct, preds); } //////////////////////////////////////////////////////////////////////////////////////////////////// template inline __device__ void ldg(uint8_t (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) { ldg_(fetch, ptrs, preds); } //////////////////////////////////////////////////////////////////////////////////////////////////// template inline __device__ void ldg(uint16_t (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) { ldg_(fetch, ptrs, preds); } //////////////////////////////////////////////////////////////////////////////////////////////////// template inline __device__ void ldg(uint32_t (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) { ldg_(fetch, ptrs, preds); } //////////////////////////////////////////////////////////////////////////////////////////////////// template inline __device__ void ldg(uint2 (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) { ldg_(fetch, ptrs, preds); } //////////////////////////////////////////////////////////////////////////////////////////////////// template inline __device__ void ldg(uint4 (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) { ldg_(fetch, ptrs, preds); } //////////////////////////////////////////////////////////////////////////////////////////////////// // // L D S // //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void lds(uint16_t& dst, uint32_t ptr) { asm volatile("ld.shared.b16 %0, [%1];\n" : "=h"(dst) : "r"(ptr)); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void lds(uint32_t& dst, uint32_t ptr) { asm volatile("ld.shared.b32 %0, [%1];\n" 
: "=r"(dst) : "r"(ptr)); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void lds(uint2& dst, uint32_t ptr) { asm volatile("ld.shared.v2.b32 {%0, %1}, [%2];\n" : "=r"(dst.x), "=r"(dst.y) : "r"(ptr)); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void lds(uint4& dst, uint32_t ptr) { asm volatile("ld.shared.v4.b32 {%0, %1, %2, %3}, [%4];\n" : "=r"(dst.x), "=r"(dst.y), "=r"(dst.z), "=r"(dst.w) : "r"(ptr)); } //////////////////////////////////////////////////////////////////////////////////////////////////// // // L D S M // //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void ldsm(uint32_t& dst, uint32_t ptr) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730 asm volatile("ldmatrix.sync.aligned.m8n8.x1.shared.b16 {%0}, [%1];\n" : "=r"(dst) : "r"(ptr)); #endif } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void ldsmt(uint32_t& dst, uint32_t ptr) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730 asm volatile("ldmatrix.sync.aligned.m8n8.x1.trans.shared.b16 {%0}, [%1];\n" : "=r"(dst) : "r"(ptr)); #endif } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void ldsm(uint2& dst, uint32_t ptr) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730 asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0, %1}, [%2];\n" : "=r"(dst.x), "=r"(dst.y) : "r"(ptr)); #endif } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void ldsmt(uint2& dst, uint32_t ptr) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730 asm volatile("ldmatrix.sync.aligned.m8n8.x2.trans.shared.b16 {%0, %1}, [%2];\n" : "=r"(dst.x), "=r"(dst.y) : "r"(ptr)); #endif } 
//////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void ldsm(uint4& dst, uint32_t ptr) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730 asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n" : "=r"(dst.x), "=r"(dst.y), "=r"(dst.z), "=r"(dst.w) : "r"(ptr)); #endif } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void ldsmt(uint4& dst, uint32_t ptr) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730 asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];\n" : "=r"(dst.x), "=r"(dst.y), "=r"(dst.z), "=r"(dst.w) : "r"(ptr)); #endif } //////////////////////////////////////////////////////////////////////////////////////////////////// // // S T G // //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void stg(void* ptr, uint8_t val) { *reinterpret_cast(ptr) = val; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void stg(void* ptr, uint16_t val) { *reinterpret_cast(ptr) = val; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void stg(void* ptr, uint32_t val) { *reinterpret_cast(ptr) = val; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void stg(void* ptr, uint2 val) { *reinterpret_cast(ptr) = val; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void stg(void* ptr, uint4 val) { *reinterpret_cast(ptr) = val; } //////////////////////////////////////////////////////////////////////////////////////////////////// // // S T S // //////////////////////////////////////////////////////////////////////////////////////////////////// 
// sts overload set: st.shared of 2, 4, 8 or 16 bytes to the 32-bit shared-state-space
// address `ptr`.
inline __device__ void sts(uint32_t ptr, uint16_t val) {
  asm volatile("st.shared.b16 [%0], %1;\n" : : "r"(ptr), "h"(val));
}

inline __device__ void sts(uint32_t ptr, uint32_t val) {
  asm volatile("st.shared.b32 [%0], %1;\n" : : "r"(ptr), "r"(val));
}

inline __device__ void sts(uint32_t ptr, uint2 val) {
  asm volatile("st.shared.v2.b32 [%0], {%1, %2};\n" : : "r"(ptr), "r"(val.x), "r"(val.y));
}

inline __device__ void sts(uint32_t ptr, uint4 val) {
  asm volatile("st.shared.v4.b32 [%0], {%1, %2, %3, %4};\n"
               :
               : "r"(ptr), "r"(val.x), "r"(val.y), "r"(val.z), "r"(val.w));
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Unpredicated scatter into shared memory: stores data[ii] at shared address ptrs[ii] for every ii.
template <typename Data_type, int N>
inline __device__ void sts_(uint32_t (&ptrs)[N], const Data_type (&data)[N]) {
#pragma unroll
  for (int ii = 0; ii < N; ++ii) {
    sts(ptrs[ii], data[ii]);
  }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Typed entry points for sts_, one per supported store width.
template <int N>
inline __device__ void sts(uint32_t (&ptrs)[N], const uint16_t (&data)[N]) {
  sts_<uint16_t, N>(ptrs, data);
}

template <int N>
inline __device__ void sts(uint32_t (&ptrs)[N], const uint32_t (&data)[N]) {
  sts_<uint32_t, N>(ptrs, data);
}

template <int N>
inline __device__ void sts(uint32_t (&ptrs)[N], const uint2 (&data)[N]) {
  sts_<uint2, N>(ptrs, data);
}

template <int N>
inline __device__ void sts(uint32_t (&ptrs)[N], const uint4 (&data)[N]) {
  sts_<uint4, N>(ptrs, data);
}
//////////////////////////////////////////////////////////////////////////////////////////////////// template struct MaxOp { __device__ inline T operator()(T const& x, T const& y) { return x > y ? x : y; } }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct SumOp { __device__ inline T operator()(T const& x, T const& y) { return x + y; } }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct Allreduce { static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); template static __device__ inline T run(T x, Operator& op) { constexpr int OFFSET = THREADS / 2; x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); return Allreduce::run(x, op); } }; //////////////////////////////////////////////////////////////////////////////////////////////////// template <> struct Allreduce<2> { template static __device__ inline T run(T x, Operator& op) { x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); return x; } }; //////////////////////////////////////////////////////////////////////////////////////////////////// template __device__ inline void quad_reduce(float (&dst)[M], float (&src)[M], Operator& op) { #pragma unroll for (int mi = 0; mi < M; mi++) { dst[mi] = src[mi]; dst[mi] = op(dst[mi], __shfl_down_sync(uint32_t(-1), dst[mi], 2)); dst[mi] = op(dst[mi], __shfl_down_sync(uint32_t(-1), dst[mi], 1)); } } //////////////////////////////////////////////////////////////////////////////////////////////////// template __device__ inline void quad_reduce(float (&dst)[M], float2 (&src)[M], Operator& op) { float tmp[M]; #pragma unroll for (int mi = 0; mi < M; mi++) { tmp[mi] = op(src[mi].x, src[mi].y); } quad_reduce(dst, tmp, op); } //////////////////////////////////////////////////////////////////////////////////////////////////// template __device__ inline void quad_allreduce(float (&dst)[M], float (&src)[M], Operator& op) { #pragma unroll 
for (int mi = 0; mi < M; mi++) { dst[mi] = src[mi]; dst[mi] = Allreduce<4>::run(dst[mi], op); } } //////////////////////////////////////////////////////////////////////////////////////////////////// template __device__ inline void quad_allreduce(float (&dst)[M], float2 (&src)[M], Operator& op) { float tmp[M]; #pragma unroll for (int mi = 0; mi < M; mi++) { tmp[mi] = op(src[mi].x, src[mi].y); } quad_allreduce(dst, tmp, op); } //////////////////////////////////////////////////////////////////////////////////////////////////// } // namespace fmha ================================================ FILE: apex/contrib/csrc/fmha/src/fmha.h ================================================ /****************************************************************************** * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright * notice, this list of conditions and the following disclaimer. * * Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. * * Neither the name of the NVIDIA CORPORATION nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * ******************************************************************************/ #pragma once #include #include #ifdef OLD_GENERATOR_PATH #include #else #include #endif #include #include constexpr int TOTAL_DIM = 0; constexpr int THREE_DIM = 1; constexpr int H_DIM = 2; constexpr int D_DIM = 3; //////////////////////////////////////////////////////////////////////////////////////////////////// struct Qkv_params { // The QKV matrices. void* __restrict__ qkv_ptr; // The stride between rows of the Q, K and V matrices. size_t qkv_stride_in_bytes; // The number of heads. int h; }; //////////////////////////////////////////////////////////////////////////////////////////////////// struct Fused_multihead_attention_fprop_params : public Qkv_params { // The dQKV matrices. void* __restrict__ dqkv_ptr; // Temporary for dKV. void* __restrict__ dkv_ptr; // The O matrix (output). void* __restrict__ o_ptr; // The stride between rows of O. int64_t o_stride_in_bytes; // The pointer to the S matrix, overwritten by the dP matrix (bwd). void* __restrict__ s_ptr; // The stride between rows of the S matrix. int64_t s_stride_in_bytes; // The dimensions. int b, s, d; // The scaling factors for the kernel. uint32_t scale_bmm1, scale_softmax, scale_bmm2; // array of length b+1 holding starting offset of each sequence. int* __restrict__ cu_seqlens; // The dropout probability (probability of keeping an activation). float p_dropout; // Scale factor of 1 / (1 - p_dropout). 
float rp_dropout; // Scale factor of 1 / (1 - p_dropout), in half2. uint32_t scale_dropout; // Random state. at::PhiloxCudaState philox_args; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct Launch_params { Launch_params(cudaDeviceProp* props_, cudaStream_t stream_, bool is_training_, bool is_nl_) : elts_per_thread(0), props(props_), stream(stream_), is_training(is_training_), is_nl(is_nl_) {} size_t elts_per_thread; cudaDeviceProp* props; cudaStream_t stream; bool is_training; Kernel_params params; int num_full_heads; int num_main_groups; int heads_last_wave; int main_steps; int rest_steps; bool is_nl; }; //////////////////////////////////////////////////////////////////////////////////////////////////// void run_fmha_fp16_128_64_sm80(Launch_params& launch_params, const bool configure); void run_fmha_fp16_256_64_sm80(Launch_params& launch_params, const bool configure); void run_fmha_fp16_384_64_sm80(Launch_params& launch_params, const bool configure); void run_fmha_fp16_512_64_sm80(Launch_params& launch_params, const bool configure); void run_fmha_dgrad_fp16_128_64_sm80(const Fused_multihead_attention_fprop_params& params, cudaStream_t stream); void run_fmha_dgrad_fp16_256_64_sm80(const Fused_multihead_attention_fprop_params& params, cudaStream_t stream); void run_fmha_dgrad_fp16_384_64_sm80(const Fused_multihead_attention_fprop_params& params, cudaStream_t stream); void run_fmha_dgrad_fp16_512_64_sm80(const Fused_multihead_attention_fprop_params& params, cudaStream_t stream); void run_fmha_fp16_512_64_sm80_nl(const Fused_multihead_attention_fprop_params& params, const bool is_training, const int num_chunks, cudaStream_t stream); void run_fmha_dgrad_fp16_512_64_sm80_nl(const Fused_multihead_attention_fprop_params& params, const int num_chunks, cudaStream_t stream); void fmha_run_noloop_reduce(void* out, const void* in, const int* cu_seqlens, const int hidden_size, const int batch_size, const 
int total, const int num_chunks, cudaStream_t stream); ================================================ FILE: apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_128_64_kernel.sm80.cu ================================================ /****************************************************************************** * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright * notice, this list of conditions and the following disclaimer. * * Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. * * Neither the name of the NVIDIA CORPORATION nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
* ******************************************************************************/ #include "fmha.h" #include "fmha_dgrad_kernel_1xN_reload.h" using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>; extern "C" __global__ void fmha_dgrad_fp16_128_64_sm80_kernel(Fused_multihead_attention_fprop_params params) { fmha::compute_dv_1xN(params); fmha::compute_dq_dk_1xN(params); } void run_fmha_dgrad_fp16_128_64_sm80(const Fused_multihead_attention_fprop_params& params, cudaStream_t stream) { constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float); constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE; constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE; constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE; using Smem_tile_s = fmha::Smem_tile_mma_transposed; constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE; static_assert(smem_size_s == 16 * 128 * 2); static_assert(smem_size_o == 16 * 64 * 4 * Kernel_traits::Cta_tile_p::WARPS_N); constexpr int smem_size_dv = smem_size_s + 2 * smem_size_q + smem_size_v + smem_size_softmax; constexpr int smem_size_dq_dk = smem_size_s + smem_size_o + smem_size_q + smem_size_v; constexpr int smem_size = std::max(smem_size_dv, smem_size_dq_dk); if (smem_size >= 48 * 1024) { FMHA_CHECK_CUDA(cudaFuncSetAttribute(fmha_dgrad_fp16_128_64_sm80_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); } dim3 grid(params.h, params.b); fmha_dgrad_fp16_128_64_sm80_kernel<<>>(params); } ================================================ FILE: apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_256_64_kernel.sm80.cu ================================================ /****************************************************************************** * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. 
* * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright * notice, this list of conditions and the following disclaimer. * * Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. * * Neither the name of the NVIDIA CORPORATION nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
* ******************************************************************************/ #include "fmha.h" #include "fmha_dgrad_kernel_1xN_reload.h" using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>; extern "C" __global__ void fmha_dgrad_fp16_256_64_sm80_kernel(Fused_multihead_attention_fprop_params params) { fmha::compute_dv_1xN(params); fmha::compute_dq_dk_1xN(params); } void run_fmha_dgrad_fp16_256_64_sm80(const Fused_multihead_attention_fprop_params& params, cudaStream_t stream) { constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float); constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE; constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE; constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE; using Smem_tile_s = fmha::Smem_tile_mma_transposed; constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE; static_assert(smem_size_s == 16 * 256 * 2); static_assert(smem_size_o == 16 * 64 * 4 * Kernel_traits::Cta_tile_p::WARPS_N); constexpr int smem_size_dv = smem_size_s + 2 * smem_size_q + smem_size_v + smem_size_softmax; constexpr int smem_size_dq_dk = smem_size_s + smem_size_o + smem_size_q + smem_size_v; constexpr int smem_size = std::max(smem_size_dv, smem_size_dq_dk); if (smem_size >= 48 * 1024) { FMHA_CHECK_CUDA(cudaFuncSetAttribute(fmha_dgrad_fp16_256_64_sm80_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); } dim3 grid(params.h, params.b); fmha_dgrad_fp16_256_64_sm80_kernel<<>>(params); } ================================================ FILE: apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_384_64_kernel.sm80.cu ================================================ /****************************************************************************** * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. 
* * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright * notice, this list of conditions and the following disclaimer. * * Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. * * Neither the name of the NVIDIA CORPORATION nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
* ******************************************************************************/ #include "fmha.h" #include "fmha_dgrad_kernel_1xN_reload.h" using Kernel_traits = FMHA_kernel_traits<384, 64, 16, 1, 8, 0x08u>; extern "C" __global__ void fmha_dgrad_fp16_384_64_sm80_kernel(Fused_multihead_attention_fprop_params params) { fmha::compute_dv_1xN(params); fmha::compute_dq_dk_1xN(params); } void run_fmha_dgrad_fp16_384_64_sm80(const Fused_multihead_attention_fprop_params& params, cudaStream_t stream) { constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float); constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE; constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE; constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE; using Smem_tile_s = fmha::Smem_tile_mma_transposed; constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE; static_assert(smem_size_s == 16 * 384 * 2); static_assert(smem_size_o == 16 * 64 * 4 * Kernel_traits::Cta_tile_p::WARPS_N); constexpr int smem_size_dv = smem_size_s + 2 * smem_size_q + smem_size_v + smem_size_softmax; constexpr int smem_size_dq_dk = smem_size_s + smem_size_o + smem_size_q + smem_size_v; constexpr int smem_size = std::max(smem_size_dv, smem_size_dq_dk); if (smem_size >= 48 * 1024) { FMHA_CHECK_CUDA(cudaFuncSetAttribute(fmha_dgrad_fp16_384_64_sm80_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); } dim3 grid(params.h, params.b); fmha_dgrad_fp16_384_64_sm80_kernel<<>>(params); } ================================================ FILE: apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_512_64_kernel.sm80.cu ================================================ /****************************************************************************** * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. 
* * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright * notice, this list of conditions and the following disclaimer. * * Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. * * Neither the name of the NVIDIA CORPORATION nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
* ******************************************************************************/ #include "fmha.h" #include "fmha_dgrad_kernel_1xN_reload.h" #include "fmha_dgrad_kernel_1xN_reload_nl.h" using Kernel_traits = FMHA_kernel_traits<512, 64, 16, 1, 8, 0x08u>; extern "C" __global__ void fmha_dgrad_fp16_512_64_sm80_kernel(Fused_multihead_attention_fprop_params params) { fmha::compute_dv_1xN(params); fmha::compute_dq_dk_1xN(params); } template __global__ void fmha_dgrad_fp16_512_64_sm80_nl_kernel(Fused_multihead_attention_fprop_params params) { fmha::compute_dv_1xN_nl(params); fmha::compute_dq_dk_1xN_nl(params); } void run_fmha_dgrad_fp16_512_64_sm80(const Fused_multihead_attention_fprop_params& params, cudaStream_t stream) { constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float); constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE; constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE; constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE; using Smem_tile_s = fmha::Smem_tile_mma_transposed; constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE; static_assert(smem_size_s == 16 * 512 * 2); static_assert(smem_size_o == 16 * 64 * 4 * Kernel_traits::Cta_tile_p::WARPS_N); constexpr int smem_size_dv = smem_size_s + 2 * smem_size_q + smem_size_v + smem_size_softmax; constexpr int smem_size_dq_dk = smem_size_s + smem_size_o + smem_size_q + smem_size_v; constexpr int smem_size = std::max(smem_size_dv, smem_size_dq_dk); if (smem_size >= 48 * 1024) { FMHA_CHECK_CUDA(cudaFuncSetAttribute(fmha_dgrad_fp16_512_64_sm80_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); } dim3 grid(params.h, params.b); fmha_dgrad_fp16_512_64_sm80_kernel<<>>(params); } void run_fmha_dgrad_fp16_512_64_sm80_nl(const Fused_multihead_attention_fprop_params& params, const int num_chunks, cudaStream_t stream) { constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * 
Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float); constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE; constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE; constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE; using Smem_tile_s = fmha::Smem_tile_mma_transposed; constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE; static_assert(smem_size_s == 16 * 512 * 2); static_assert(smem_size_o == 16 * 64 * 4 * Kernel_traits::Cta_tile_p::WARPS_N); constexpr int smem_size_dv = smem_size_s + 2 * smem_size_q + smem_size_v + smem_size_softmax; constexpr int smem_size_dq_dk = smem_size_s + smem_size_o + smem_size_q + smem_size_v; constexpr int smem_size = std::max(smem_size_dv, smem_size_dq_dk); auto kernel = fmha_dgrad_fp16_512_64_sm80_nl_kernel<2>; if (num_chunks == 2) { kernel = fmha_dgrad_fp16_512_64_sm80_nl_kernel<2>; } else if (num_chunks == 3) { kernel = fmha_dgrad_fp16_512_64_sm80_nl_kernel<3>; } else { assert(false && "Unsupperted number of chunks"); } if (smem_size >= 48 * 1024) { FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); } dim3 grid(params.h, params.b, num_chunks); kernel<<>>(params); FMHA_CHECK_CUDA(cudaPeekAtLastError()); } ================================================ FILE: apex/contrib/csrc/fmha/src/fmha_dgrad_kernel_1xN_reload.h ================================================ /****************************************************************************** * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright * notice, this list of conditions and the following disclaimer. 
*     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the
 *       names of its contributors may be used to endorse or promote products
 *       derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/

#pragma once

#include
#include

#include "fmha_kernel.h"

namespace fmha {

////////////////////////////////////////////////////////////////////////////////////////////////////

// Backward pass, phase 1. One CTA per (head, batch): blockIdx.x = head, blockIdx.y = batch.
//
// Walks the sequence in steps of Cta_tile_p::M query rows, stopping once the step start reaches
// binfo.actual_seqlen. For each step it:
//   1. computes acc_p = dO * V^T (dO is streamed through the "Q" tiles, V through the "K" tiles);
//   2. reads S through gmem_s. The sign bit of each element marks a dropped entry and |S| is the
//      value; a step's S is read before it is overwritten in step 3;
//   3. forms the softmax gradient
//        elt = scale_softmax * (|S| * dP - rowsum(|S| * dP) * |S|),
//      where dP = (dO * V^T) * rp_dropout for kept entries and 0 for dropped ones, and writes elt
//      back through gmem_s, overwriting S;
//   4. accumulates dV += relu(S * scale_dropout)^T * dO, using the transposed S tile held in shared
//      memory. Assuming scale_dropout > 0, the relu zeroes the sign-marked (dropped) entries.
// At the end, dV is swizzled through shared memory and written to slot 2 of params.dqkv_ptr.
template inline __device__ void compute_dv_1xN(const Params& params) {
  // The description of the CTA tile for the 1st batched GEMM.
  using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
  // The description of the CTA tile for the 2nd batched GEMM.
  using Cta_tile_dv = fmha::Cta_tile_extd;

  static_assert(Cta_tile_dv::M == 512 || Cta_tile_dv::M == 384 || Cta_tile_dv::M == 256 || Cta_tile_dv::M == 128);
  static_assert(Cta_tile_dv::N == 64);
  static_assert(Cta_tile_dv::K == 16);

  // The MMA tile for the 1st GEMM.
  using Mma_tile_p = fmha::Hmma_tile;
  // The MMA tile for the 2nd GEMM.
  using Mma_tile_dv = fmha::Hmma_tile;

  // The global memory tile to load Q.
  using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q;
  // The shared memory tile to swizzle Q (holds dO in this function).
  // using Smem_tile_q = typename Kernel_traits::Smem_tile_q;
  using Smem_tile_q = fmha::Smem_tile_a;
  // The shared memory tile to reload Q as fragment b (same shared memory as Smem_tile_q).
  using Smem_tile_qt = fmha::Smem_tile_b;

  // The global memory tile to load K.
  using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k;
  // The shared memory tile to swizzle K.
  using Smem_tile_k = typename Kernel_traits::Smem_tile_k;

  // The global memory tile to load V.
  using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v;
  // The shared memory tile to swizzle V.
  using Smem_tile_v = typename Kernel_traits::Smem_tile_v;

  // The global memory tile to store O.
  using Gmem_tile_o = typename Kernel_traits::Gmem_tile_o;
  // The shared memory tile to swizzle O.
  using Smem_tile_o = typename Kernel_traits::Smem_tile_o;

  // The global memory tile to store dV.
  using Gmem_tile_dv = typename Kernel_traits::Gmem_tile_v;
  // The shared memory tile to swizzle dV.
  using Smem_tile_dv = fmha::Smem_tile_mma_epilogue;
  static_assert(Smem_tile_dv::NUM_LDS == Gmem_tile_dv::LDGS);
  static_assert(Smem_tile_dv::THREADS_PER_ROW == Gmem_tile_dv::THREADS_PER_ROW);

  using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s;
  using Smem_tile_st = typename Kernel_traits::Smem_tile_st;
  using Gmem_tile_do = typename Kernel_traits::Gmem_tile_do;

  // Shared memory.
  extern __shared__ char smem_[];

  // The block index for the batch.
  const int bidb = blockIdx.y;
  // The block index for the head.
  const int bidh = blockIdx.x;
  // The thread index.
  const int tidx = threadIdx.x;

  const BlockInfoPadded binfo(params, bidb, bidh, tidx);
  // Must be uniform across the block, because __syncthreads() follows below.
  if (binfo.stop_early()) return;
  Mask mask(params, binfo, tidx);

  // Allocate the global memory tile loader for Q.
  Gmem_tile_do gmem_q(params, binfo, tidx);  // treating dout as Q
  // Allocate the shared memory tile loader for Q.
  // Shared-memory layout of this phase: [Q (dO) | K (V) | S^T | softmax scratch].
  Smem_tile_q smem_q(&smem_[0], tidx);
  Smem_tile_qt smem_qt(&smem_[0], tidx);
  Smem_tile_st smem_s(&smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE], tidx);

  // Allocate the global memory tile loader for K.
  Gmem_tile_k gmem_k(params, 2, binfo, tidx);  // treating V as K
  // Allocate the shared memory tile loader for K.
  Smem_tile_k smem_k(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx);

  // Trigger the loads for Q.
  gmem_q.load(smem_q);
  // Trigger the loads for K.
  gmem_k.load(smem_k);

  // Commit the data for Q and K to shared memory.
  gmem_q.commit(smem_q);
  gmem_k.commit(smem_k);

  // Make sure the data is in shared memory.
  __syncthreads();

  // Load the fragments for Q. frag_q / frag_qt / frag_k are double-buffered over the MMA k-index.
  typename Smem_tile_q::Fragment frag_q[2][Mma_tile_p::MMAS_M];
  smem_q.load(frag_q[0], 0);

  typename Smem_tile_qt::Fragment frag_qt[2][Mma_tile_dv::MMAS_N];
  static_assert(Smem_tile_qt::Fragment::NUM_REGS == 4);
  static_assert(Mma_tile_dv::MMAS_K == 1);
  smem_qt.load(frag_qt[0], 0);

  // Load the first fragments for K. They are re-read from shared memory in every step (see the
  // ki loop and the end of the outer loop).
  typename Smem_tile_k::Fragment frag_k[2][Mma_tile_p::MMAS_N];
  smem_k.load(frag_k[0], 0);

  enum { BITS_PER_ELT_S = sizeof(fmha::A_type) * 8 };

  Gmem_tile_s gmem_s(params, binfo, tidx);

  // Create the object to do the softmax (its scratch space follows the S^T tile).
  using Softmax = fmha::Softmax;
  Softmax softmax(params, &smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_st::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE], bidb, tidx);

  enum { THREADS_PER_ROW = 32 };
  enum { M = Mma_tile_p::MMAS_M };
  enum { N = Mma_tile_p::MMAS_N };

  // Declare the accumulators for the 2nd gemm.
  fmha::Fragment_accumulator acc_dv[Mma_tile_dv::MMAS_M][Mma_tile_dv::MMAS_N];
  fmha::Clear_accumulator::apply(acc_dv);

  enum { STEPS = Cta_tile_p::N / Cta_tile_p::M };
  // Load over the entire sequence length.
  for (int l = 0; l < STEPS; l++) {
    const int loop = l * Cta_tile_p::M;
    if (loop >= binfo.actual_seqlen) break;

    // Load S. The sign bit marks dropped entries.
    uint4 s_regs[M][N];
    gmem_s.load(s_regs, mask);
    fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];
    fmha::Clear_accumulator::apply(acc_p);

    // acc_p = dO * V^T. The tile names follow the forward pass (P = Q * K^T).
#pragma unroll
    for (int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki) {
      // Trigger the shared-memory loads of the next dO and V fragments.
      smem_q.load(frag_q[ki & 1], ki);
      smem_k.load(frag_k[ki & 1], ki);
      // Do the math for the values already in registers.
      fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]);
    }

    // Store s * dmask to smem for transpose
    smem_s.store(s_regs);

    // Do the final stage of math.
    {
      int ki = Mma_tile_p::MMAS_K;
      fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]);
    }

    // Trigger the load for the next dO tile. It goes into the other buffer, so smem_qt reads of
    // the current buffer later in this step are unaffected.
    if (l < STEPS - 1) {
      smem_q.move_to_next_write_buffer();
      gmem_q.move();
      gmem_q.load(smem_q);
    }

    // Convert from the accumulator type to FP32 for Softmax.
    softmax.unpack(acc_p);

    // Expand the packed half2 S values to FP32 (row 2*mi+{0,1}, columns 4*ni+{0..3}).
    float s_mat[2 * M][4 * N];

#pragma unroll
    for (int mi = 0; mi < M; mi++) {
#pragma unroll
      for (int ni = 0; ni < N; ni++) {
        uint4& dst = s_regs[mi][ni];
        fmha::half2_to_float2(s_mat[2 * mi + 0][4 * ni + 0], s_mat[2 * mi + 0][4 * ni + 1], dst.x);
        fmha::half2_to_float2(s_mat[2 * mi + 0][4 * ni + 2], s_mat[2 * mi + 0][4 * ni + 3], dst.y);
        fmha::half2_to_float2(s_mat[2 * mi + 1][4 * ni + 0], s_mat[2 * mi + 1][4 * ni + 1], dst.z);
        fmha::half2_to_float2(s_mat[2 * mi + 1][4 * ni + 2], s_mat[2 * mi + 1][4 * ni + 3], dst.w);
      }
    }

    // Per element: dP = dropped ? 0 : (dO*V^T) * rp_dropout. Replace s_mat with |S|, then set
    // elt = dP * |S|.
#pragma unroll
    for (int mi = 0; mi < M; mi++) {
#pragma unroll
      for (int ii = 0; ii < 2; ii++) {
#pragma unroll
        for (int ni = 0; ni < N; ni++) {
#pragma unroll
          for (int jj = 0; jj < 4; jj++) {
            float& s_dmask = s_mat[2 * mi + ii][4 * ni + jj];
            const bool drop = reinterpret_cast(s_dmask) & 0x80000000;
            const float d_s = drop ? 0.f : softmax.elt_[2 * mi + ii][4 * ni + jj] * params.rp_dropout;
            s_dmask = fabsf(s_dmask);
            softmax.elt_[2 * mi + ii][4 * ni + jj] = d_s * fabsf(s_dmask);
          }
        }
      }
    }

    // p_sum[r] = sum over the row of dP * |S|, one entry per row 2*mi+ii.
    float p_sum[2 * M];
    softmax.reduce_sum(p_sum);

    // scale_softmax holds float bits (reinterpreted here).
    const float scalef = reinterpret_cast(params.scale_softmax);
    // Softmax backward: elt = scale_softmax * (dP * |S| - rowsum(dP * |S|) * |S|).
#pragma unroll
    for (int mi = 0; mi < M; mi++) {
#pragma unroll
      for (int ii = 0; ii < 2; ii++) {
#pragma unroll
        for (int ni = 0; ni < N; ni++) {
#pragma unroll
          for (int jj = 0; jj < 4; jj++) {
            softmax.elt_[2 * mi + ii][4 * ni + jj] -= p_sum[2 * mi + ii] * (s_mat[2 * mi + ii][4 * ni + jj]);
            softmax.elt_[2 * mi + ii][4 * ni + jj] *= scalef;
          }
        }
      }
    }

    // Reload S transposed as the A operand of the dV GEMM. Scale by the half2 dropout factor, then
    // relu to zero the sign-marked (dropped) entries.
    typename Smem_tile_st::Fragment frag_s[Mma_tile_dv::MMAS_K][Mma_tile_dv::MMAS_M];
    smem_s.load(frag_s);
    for (int ki = 0; ki < Mma_tile_dv::MMAS_K; ki++) {
      for (int mi = 0; mi < Mma_tile_dv::MMAS_M; mi++) {
        for (int ii = 0; ii < Smem_tile_st::Fragment::NUM_REGS; ii++) {
          frag_s[ki][mi].reg(ii) = fmha::hmul2(frag_s[ki][mi].reg(ii), params.scale_dropout);
          frag_s[ki][mi].reg(ii) = fmha::hrelu2(frag_s[ki][mi].reg(ii));
        }
      }
    }

    // Write the softmax gradient back over S in global memory; it is re-read by compute_dq_dk_1xN.
    gmem_s.store(softmax.elt_, mask);
    gmem_s.move();

    // dV += S^T * dO.
#pragma unroll
    for (int ki = 1; ki < Mma_tile_dv::MMAS_K; ++ki) {
      // Trigger the load from shared memory for the next series of Q values.
      smem_qt.load(frag_qt[ki & 1], ki);
      // Do the math for the values already in registers.
      fmha::gemm(acc_dv, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]);
    }

    // Do the final stage of math.
    {
      int ki = Mma_tile_dv::MMAS_K;
      fmha::gemm(acc_dv, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]);
    }

    // Commit the values for Q into shared memory.
    if (l < STEPS - 1) {
      gmem_q.commit(smem_q);
    }

    // Make sure we are reading from the correct buffer.
    smem_q.move_to_next_read_buffer();
    smem_qt.move_to_next_read_buffer();

    // Make sure the data is in shared memory.
    __syncthreads();

    // Preload the first dO, V and dO^T fragments for the next step. On the last step these
    // values are never used.
    smem_q.load(frag_q[0], 0);
    smem_k.load(frag_k[0], 0);
    smem_qt.load(frag_qt[0], 0);

  }  // Outer loop over the sequence length.

  // Epilogue swizzle for dV. It reuses shared memory starting at the (single-buffer) Q-tile offset.
  Smem_tile_dv smem_dv(&smem_[Kernel_traits::Smem_tile_q::BYTES_PER_TILE], tidx);
  smem_dv.store(acc_dv);

  __syncthreads();
  uint4 dv_out[Smem_tile_dv::NUM_LDS];
  smem_dv.load(dv_out);
  // Describe the packed dQKV output so the V-slot (index 2) tile writer can store dV.
  Qkv_params dv_params;
  dv_params.qkv_ptr = params.dqkv_ptr;
  dv_params.qkv_stride_in_bytes = params.qkv_stride_in_bytes;
  dv_params.h = params.h;
  Gmem_tile_dv gmem_dv(dv_params, 2, binfo, tidx);
  gmem_dv.store(dv_out);
}

// Backward pass, phase 2. Same CTA mapping as compute_dv_1xN.
//
// Per step of Cta_tile_p::M query rows it reads the matrix stored in the S buffer (called dP
// here; presumably the softmax gradient compute_dv_1xN wrote back through gmem_s), then:
//   - dQ = dP * K, staged through smem_o and written with Gmem_tile_dq;
//   - dK += dP^T * Q, using the transposed copy of dP in smem_s.
// At the end, dK is swizzled through shared memory and written to slot 1 of params.dqkv_ptr.
//
// NOTE(review): gmem_k.commit(smem_k) below writes at offset Smem_tile_q::BYTES_PER_TILE, the same
// offset compute_dv_1xN uses for its dV epilogue tile. A caller that runs both phases back to back
// in one kernel needs a __syncthreads() between them, so no warp overwrites dV before every warp
// has read it.
template inline __device__ void compute_dq_dk_1xN(const Params& params) {
  // The description of the CTA tile for the 1st batched GEMM.
  using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
  using Cta_tile_o = typename Kernel_traits::Cta_tile_o;
  // The description of the CTA tile for the 2nd batched GEMM.
  using Cta_tile_dk = fmha::Cta_tile_extd;

  static_assert(Cta_tile_dk::M == 512 || Cta_tile_dk::M == 384 || Cta_tile_dk::M == 256 || Cta_tile_dk::M == 128);
  static_assert(Cta_tile_dk::N == 64);
  static_assert(Cta_tile_dk::K == 16);

  // The MMA tile for the 1st GEMM.
  using Mma_tile_p = fmha::Hmma_tile;
  using Mma_tile_o = fmha::Hmma_tile;
  // The MMA tile for the 2nd GEMM.
  using Mma_tile_dk = fmha::Hmma_tile;

  // The global memory tile to load Q.
  using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q;
  // The shared memory tile to swizzle Q.
  using Smem_tile_q = typename Kernel_traits::Smem_tile_q;

  // The global memory tile to load K.
  using Gmem_tile_k = typename Kernel_traits::Gmem_tile_v;
  // The shared memory tile to swizzle K.
  using Smem_tile_k = typename Kernel_traits::Smem_tile_v;  // K is used like V in fprop

  // The global memory tile to load V.
  using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v;
  // The shared memory tile to swizzle V.
  using Smem_tile_v = typename Kernel_traits::Smem_tile_v;

  // The global memory tile to store O (here: dQ).
  // using Gmem_tile_o = typename Kernel_traits::Gmem_tile_o;
  using Gmem_tile_o = fmha::Gmem_tile_dq;
  // The shared memory tile to swizzle O.
  using Smem_tile_o = typename Kernel_traits::Smem_tile_o;

  // The global memory tile to store dK.
  using Gmem_tile_dk = typename Kernel_traits::Gmem_tile_v;
  // The shared memory tile to swizzle dK.
  using Smem_tile_dk = fmha::Smem_tile_mma_epilogue;
  static_assert(Smem_tile_dk::NUM_LDS == Gmem_tile_dk::LDGS);
  static_assert(Smem_tile_dk::THREADS_PER_ROW == Gmem_tile_dk::THREADS_PER_ROW);

  // The shared memory tile to reload Q transposed.
  using Smem_tile_qt = fmha::Smem_tile_b;

  using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s;

  using Smem_tile_st = typename Kernel_traits::Smem_tile_st;

  enum { M = Mma_tile_p::MMAS_M };
  enum { N = Mma_tile_p::MMAS_N };
  static_assert(M == Mma_tile_o::MMAS_M);
  static_assert(N == Mma_tile_o::MMAS_K);
  // Shared memory.
  extern __shared__ char smem_[];

  // The block index for the batch.
  const int bidb = blockIdx.y;
  // The block index for the head.
  const int bidh = blockIdx.x;
  // The thread index.
  const int tidx = threadIdx.x;

  const BlockInfoPadded binfo(params, bidb, bidh, tidx);
  // Must be uniform across the block, because __syncthreads() follows below.
  if (binfo.stop_early()) return;

  Mask mask(params, binfo, tidx);

  // Allocate the global memory tile loader for Q.
  Gmem_tile_q gmem_q(params, 0, binfo, tidx);
  // Allocate the shared memory tile loader for Q.
  // Shared-memory layout of this phase: [Q | K | O staging | dP^T].
  Smem_tile_q smem_q(&smem_[0], tidx);
  Smem_tile_qt smem_qt(&smem_[0], tidx);
  Smem_tile_st smem_s(&smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE], tidx);

  // Allocate the global memory tile loader for K.
  Gmem_tile_k gmem_k(params, 1, binfo, tidx);
  // Allocate the shared memory tile loader for K.
  Smem_tile_k smem_k(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx);

  // Allocate the global memory tile loader for O.
  Gmem_tile_o gmem_o(params, binfo, tidx);
  // Allocate the shared memory tile loader for O. It sits directly after the K tile
  // (offset Q + K bytes) and is followed by the dP^T tile.
  Smem_tile_o smem_o(&smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE], tidx);

  // Trigger the loads for Q.
  gmem_q.load(smem_q);
  // Trigger the loads for K.
  gmem_k.load(smem_k);

  Gmem_tile_s gmem_s(params, binfo, tidx);
  // Load dP
  uint4 s_regs[M][N];
  gmem_s.load(s_regs, mask);
  gmem_s.move();

  // Commit the data for Q and K to shared memory.
  gmem_q.commit(smem_q);
  gmem_k.commit(smem_k);

  // Make sure the data is in shared memory.
  __syncthreads();

  typename Smem_tile_qt::Fragment frag_qt[2][Mma_tile_dk::MMAS_N];
  smem_qt.load(frag_qt[0], 0);
  typename Smem_tile_k::Fragment frag_k[2][Mma_tile_o::MMAS_N];
  smem_k.load(frag_k[0], 0);

  enum { BITS_PER_ELT_S = sizeof(fmha::A_type) * 8 };

  enum { THREADS_PER_ROW = 32 };
  enum { STEPS = Cta_tile_p::N / Cta_tile_p::M };

  // Declare the accumulators for the 2nd gemm.
  fmha::Fragment_accumulator acc_dk[Mma_tile_dk::MMAS_M][Mma_tile_dk::MMAS_N];
  fmha::Clear_accumulator::apply(acc_dk);

  // Load over the entire sequence length.
  for (int l = 0; l < STEPS; l++) {
    const int loop = l * Cta_tile_p::M;
    if (loop >= binfo.actual_seqlen) break;

    // Pack dP as Fragment_a (register-level reshuffle of the loaded uint4 values).
    fmha::Fragment_a frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M];
#pragma unroll
    for (int mi = 0; mi < M; mi++) {
#pragma unroll
      for (int ni = 0; ni < N; ni++) {
        uint4& dst = s_regs[mi][ni];
        frag_p[ni][mi].reg(0) = dst.x;  // row 0, cols 0,1
        frag_p[ni][mi].reg(1) = dst.z;  // row 8, cols 0,1
        frag_p[ni][mi].reg(2) = dst.y;  // row 0, cols 8,9
        frag_p[ni][mi].reg(3) = dst.w;  // row 8, cols 8,9
      }
    }

    // Declare the accumulators for the 1st gemm.
    fmha::Fragment_accumulator acc_o[Mma_tile_o::MMAS_M][Mma_tile_o::MMAS_N];
    fmha::Clear_accumulator::apply(acc_o);

    // dQ = dP * K. The GEMM reuses the forward O = P * V tile shapes, with K in the V role.
#pragma unroll
    for (int ki = 1; ki < Mma_tile_o::MMAS_K; ++ki) {
      // Trigger the shared-memory load of the next K fragment.
      smem_k.load(frag_k[ki & 1], ki);
      // Do the math for the values already in registers.
      fmha::gemm(acc_o, frag_p[ki - 1], frag_k[(ki - 1) & 1]);
    }

    // Do the final stage of math.
    {
      int ki = Mma_tile_o::MMAS_K;
      fmha::gemm(acc_o, frag_p[ki - 1], frag_k[(ki - 1) & 1]);
    }

    // Store dP to smem for transpose
    smem_s.store(s_regs);
    if (l < STEPS - 1) {
      // Load next part of S
      gmem_s.load(s_regs, mask);
      gmem_s.move();
      smem_q.move_to_next_write_buffer();
      gmem_q.move();
      gmem_q.load(smem_q);
    }
    // Write dQ for this step, Gmem_tile_o::LOOPS slices at a time, through the smem_o staging tile.
#pragma unroll
    for (int ii = 0; ii < Gmem_tile_o::LOOPS; ++ii) {
      // Swizzle the elements and do the final reduction.
      smem_o.store(acc_o, ii);

      // Make sure the data is in shared memory.
      __syncthreads();

      // Load from shared memory.
      uint4 out[Gmem_tile_o::STGS_PER_LOOP];
      smem_o.load(out);

      // Make sure the data was read from shared memory.
      if (ii < Gmem_tile_o::LOOPS - 1) {
        __syncthreads();
      }

      // Output the values.
      gmem_o.store(out, ii);
    }

    // Move to the next part of the output.
    gmem_o.move();

    // dK += dP^T * Q.
    typename Smem_tile_st::Fragment frag_s[Mma_tile_dk::MMAS_K][Mma_tile_dk::MMAS_M];
    smem_s.load(frag_s);

#pragma unroll
    for (int ki = 1; ki < Mma_tile_dk::MMAS_K; ++ki) {
      // Trigger the shared-memory load of the next Q^T fragment.
      smem_qt.load(frag_qt[ki & 1], ki);
      // Do the math for the values already in registers.
      fmha::gemm(acc_dk, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]);
    }

    // Do the final stage of math.
    {
      int ki = Mma_tile_dk::MMAS_K;
      fmha::gemm(acc_dk, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]);
    }

    // Commit the values for Q into shared memory.
    if (l < STEPS - 1) {
      gmem_q.commit(smem_q);
    }

    // Make sure the data is in shared memory.
    __syncthreads();

    // Preload the first Q^T and K fragments for the next step. On the last step these values are
    // never used.
    smem_qt.load(frag_qt[0], 0);
    smem_k.load(frag_k[0], 0);

  }  // Outer loop over the sequence length.

  // Epilogue swizzle for dK. It reuses shared memory from offset 0.
  Smem_tile_dk smem_dk(&smem_[0], tidx);
  smem_dk.store(acc_dk);
  __syncthreads();
  uint4 dk_out[Smem_tile_dk::NUM_LDS];
  smem_dk.load(dk_out);
  // Describe the packed dQKV output so the K-slot (index 1) tile writer can store dK.
  Qkv_params dk_params;
  dk_params.qkv_ptr = params.dqkv_ptr;
  dk_params.qkv_stride_in_bytes = params.qkv_stride_in_bytes;
  dk_params.h = params.h;
  Gmem_tile_dk gmem_dk(dk_params, 1, binfo, tidx);
  gmem_dk.store(dk_out);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace fmha

================================================
FILE: apex/contrib/csrc/fmha/src/fmha_dgrad_kernel_1xN_reload_nl.h
================================================
/******************************************************************************
 * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the
 *       names of its contributors may be used to endorse or promote products
 *       derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED.
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/

#pragma once

#include
#include

#include "fmha_kernel.h"

namespace fmha {

////////////////////////////////////////////////////////////////////////////////////////////////////

// Backward pass, chunked ("nl") variant: dV for one CTA, where blockIdx.x is the head,
// blockIdx.y the batch and blockIdx.z the chunk. Noloop_traits decides where this chunk
// starts and how many steps (num_steps_) it has.
//
// Per step:
//   1. acc_p = dout * V^T: dout is loaded through the Q-side tiles ("treating dout as Q")
//      and V through the K-side tiles ("treating V as K").
//   2. Softmax backward against the stored S values. An element whose sign bit is set is
//      treated as dropped (gradient 0); the others are scaled by params.rp_dropout, giving G.
//      With P = |S| this computes scale_softmax * (P .* G - P .* rowsum(P .* G)), which is
//      written back over S through gmem_s.
//   3. acc_dv += S'^T * dout, where S' is the staged S scaled by params.scale_dropout and
//      passed through hrelu2.
// The per-chunk partial dV is written to params.dkv_ptr at slot nl_traits.get_idx_dv().
template inline __device__ void compute_dv_1xN_nl(const Params& params) {
  // The description of the CTA tile for the 1st batched GEMM.
  using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
  // The description of the CTA tile for the 2nd batched GEMM.
  using Cta_tile_dv = fmha::Cta_tile_extd;

  static_assert(Cta_tile_dv::M == 512 || Cta_tile_dv::M == 384 || Cta_tile_dv::M == 256 || Cta_tile_dv::M == 128);
  static_assert(Cta_tile_dv::N == 64);
  static_assert(Cta_tile_dv::K == 16);

  // The MMA tile for the 1st GEMM.
  using Mma_tile_p = fmha::Hmma_tile;
  // The MMA tile for the 2nd GEMM.
  using Mma_tile_dv = fmha::Hmma_tile;

  // The global memory tile to load Q.
  using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q;
  // The shared memory tile to swizzle Q.
  using Smem_tile_q = fmha::Smem_tile_a;
  // The shared memory tile to reload Q as fragment b.
  using Smem_tile_qt = fmha::Smem_tile_b;

  // The global memory tile to load K.
  using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k;
  // The shared memory tile to swizzle K.
  using Smem_tile_k = typename Kernel_traits::Smem_tile_k;

  // The global memory tile to load V.
  using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v;
  // The shared memory tile to swizzle V.
  using Smem_tile_v = typename Kernel_traits::Smem_tile_v;

  // The global memory tile to store dV.
  using Gmem_tile_dv = fmha::Gmem_tile_qkv;
  // The shared memory tile to swizzle dV.
  using Smem_tile_dv = fmha::Smem_tile_mma_epilogue;
  static_assert(Smem_tile_dv::NUM_LDS == Gmem_tile_dv::LDGS);
  static_assert(Smem_tile_dv::THREADS_PER_ROW == Gmem_tile_dv::THREADS_PER_ROW);

  using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s;
  using Smem_tile_st = typename Kernel_traits::Smem_tile_st;

  using Gmem_tile_do = typename Kernel_traits::Gmem_tile_do;

  // Shared memory.
  extern __shared__ char smem_[];

  // The block index for the chunk.
  const int bidc = blockIdx.z;
  // The block index for the batch.
  const int bidb = blockIdx.y;
  // The block index for the head.
  const int bidh = blockIdx.x;
  // The thread index.
  const int tidx = threadIdx.x;

  const BlockInfoPadded binfo(params, bidb, bidh, tidx);
  // NOTE(review): presumably true when this (batch, head) has no work; see BlockInfoPadded.
  if (binfo.stop_early()) return;
  fmha::Mask mask(params, binfo, tidx);

  // Shared memory layout (byte offsets, using each tile's BYTES_PER_TILE):
  //   [0, Q)              dout tile, written as an A operand (smem_q) and re-read as B (smem_qt)
  //   [Q, Q + K)          V tile (smem_k)
  //   [Q + K, Q + K + St) S tile staged for the transposed read (smem_s)
  //   [Q + K + St, ...)   softmax scratch

  // Allocate the global memory tile loader for Q.
  Gmem_tile_do gmem_q(params, binfo, tidx);  // treating dout as Q
  // Allocate the shared memory tile loader for Q.
  Smem_tile_q smem_q(&smem_[0], tidx);
  Smem_tile_qt smem_qt(&smem_[0], tidx);
  Smem_tile_st smem_s(&smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE], tidx);

  // Allocate the global memory tile loader for K.
  Gmem_tile_k gmem_k(params, 2, binfo, tidx);  // treating V as K
  // Allocate the shared memory tile loader for K.
  Smem_tile_k smem_k(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx);

  Gmem_tile_s gmem_s(params, binfo, tidx);

  using Noloop = Noloop_traits;

  Noloop nl_traits(bidc, binfo);
  // Position the dout and S loaders at this CTA's chunk (presumably by skipping the rows of
  // earlier chunks; see Noloop_traits).
  nl_traits.move_all(gmem_q, gmem_s);

  // Trigger the loads for Q.
  gmem_q.load(smem_q);
  // Trigger the loads for K.
  gmem_k.load(smem_k);

  // Commit the data for Q and K to shared memory.
  gmem_q.commit(smem_q);
  gmem_k.commit(smem_k);

  // Make sure the data is in shared memory.
  __syncthreads();

  // Load the fragments for Q.
  typename Smem_tile_q::Fragment frag_q[2][Mma_tile_p::MMAS_M];
  smem_q.load(frag_q[0], 0);

  typename Smem_tile_qt::Fragment frag_qt[2][Mma_tile_dv::MMAS_N];
  static_assert(Smem_tile_qt::Fragment::NUM_REGS == 4);
  static_assert(Mma_tile_dv::MMAS_K == 1);
  smem_qt.load(frag_qt[0], 0);

  // Load the fragments for K. We keep the data in registers during the entire kernel.
  typename Smem_tile_k::Fragment frag_k[2][Mma_tile_p::MMAS_N];
  smem_k.load(frag_k[0], 0);

  enum { BITS_PER_ELT_S = sizeof(fmha::A_type) * 8 };

  // Create the object to do the softmax.
  using Softmax = fmha::Softmax;
  Softmax softmax(params, &smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_st::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE], bidb, tidx);

  enum { THREADS_PER_ROW = 32 };
  enum { M = Mma_tile_p::MMAS_M };
  enum { N = Mma_tile_p::MMAS_N };

  // Declare the accumulators for the 2nd gemm.
  fmha::Fragment_accumulator acc_dv[Mma_tile_dv::MMAS_M][Mma_tile_dv::MMAS_N];
  fmha::Clear_accumulator::apply(acc_dv);

  // Load over the entire sequence length.
  for (int l = 0; l < nl_traits.num_steps_; l++) {
    // S for this step, half2-packed; the sign bit of each element carries the dropout flag.
    uint4 s_regs[M][N];
    gmem_s.load(s_regs, mask);
    fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];
    fmha::Clear_accumulator::apply(acc_p);

    // Do this part of P^T = (Q * K^T)^T.
#pragma unroll
    for (int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki) {
      // Trigger the load from shared memory for the next series of Q values.
      smem_q.load(frag_q[ki & 1], ki);
      smem_k.load(frag_k[ki & 1], ki);
      // Do the math for the values already in registers.
      fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]);
    }

    // Stage S in shared memory so it can be re-read transposed for the dV GEMM.
    smem_s.store(s_regs);

    // Declare the accumulators for the 1st gemm.
    // Do the final stage of math.
    {
      int ki = Mma_tile_p::MMAS_K;
      fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]);
    }

    // Trigger the load for the next Q values. We're using double buffering, so reading qt is safe
    if (l < nl_traits.num_steps_ - 1) {
      smem_q.move_to_next_write_buffer();
      gmem_q.move();
      gmem_q.load(smem_q);
    }

    // Convert from the accumulator type to FP32 for Softmax.
    softmax.unpack(acc_p);

    // Unpack S to FP32. Each uint4 holds a 2 x 4 sub-block: x, y -> row 2*mi; z, w -> row 2*mi+1.
    float s_mat[2 * M][4 * N];
#pragma unroll
    for (int mi = 0; mi < M; mi++) {
#pragma unroll
      for (int ni = 0; ni < N; ni++) {
        uint4& dst = s_regs[mi][ni];
        fmha::half2_to_float2(s_mat[2 * mi + 0][4 * ni + 0], s_mat[2 * mi + 0][4 * ni + 1], dst.x);
        fmha::half2_to_float2(s_mat[2 * mi + 0][4 * ni + 2], s_mat[2 * mi + 0][4 * ni + 3], dst.y);
        fmha::half2_to_float2(s_mat[2 * mi + 1][4 * ni + 0], s_mat[2 * mi + 1][4 * ni + 1], dst.z);
        fmha::half2_to_float2(s_mat[2 * mi + 1][4 * ni + 2], s_mat[2 * mi + 1][4 * ni + 3], dst.w);
      }
    }

    // Mask and rescale the incoming gradient, then multiply by P = |S|.
    // s_mat is overwritten with |S| for the correction step below.
#pragma unroll
    for (int mi = 0; mi < M; mi++) {
#pragma unroll
      for (int ii = 0; ii < 2; ii++) {
#pragma unroll
        for (int ni = 0; ni < N; ni++) {
#pragma unroll
          for (int jj = 0; jj < 4; jj++) {
            float& s_dmask = s_mat[2 * mi + ii][4 * ni + jj];
            // Sign bit set => the element is treated as dropped.
            const bool drop = reinterpret_cast(s_dmask) & 0x80000000;
            const float d_s = drop ? 0.f : softmax.elt_[2 * mi + ii][4 * ni + jj] * params.rp_dropout;
            s_dmask = fabsf(s_dmask);
            softmax.elt_[2 * mi + ii][4 * ni + jj] = d_s * (s_dmask);
          }
        }
      }
    }

    // Per-row sums (indexed by row 2*mi+ii) of the products computed above.
    float p_sum[2 * M];
    softmax.reduce_sum(p_sum);

    // params.scale_softmax is reinterpreted (not numerically converted) to obtain the float scale.
    const float scalef = reinterpret_cast(params.scale_softmax);
#pragma unroll
    for (int mi = 0; mi < M; mi++) {
#pragma unroll
      for (int ii = 0; ii < 2; ii++) {
#pragma unroll
        for (int ni = 0; ni < N; ni++) {
#pragma unroll
          for (int jj = 0; jj < 4; jj++) {
            softmax.elt_[2 * mi + ii][4 * ni + jj] -= p_sum[2 * mi + ii] * (s_mat[2 * mi + ii][4 * ni + jj]);
            softmax.elt_[2 * mi + ii][4 * ni + jj] *= scalef;
          }
        }
      }
    }

    // Re-read the staged S transposed, scale by params.scale_dropout and apply hrelu2
    // (presumably clamping negatives, i.e. the sign-flagged dropped elements, to zero).
    typename Smem_tile_st::Fragment frag_s[Mma_tile_dv::MMAS_K][Mma_tile_dv::MMAS_M];
    smem_s.load(frag_s);
    for (int ki = 0; ki < Mma_tile_dv::MMAS_K; ki++) {
      for (int mi = 0; mi < Mma_tile_dv::MMAS_M; mi++) {
        for (int ii = 0; ii < Smem_tile_st::Fragment::NUM_REGS; ii++) {
          frag_s[ki][mi].reg(ii) = fmha::hmul2(frag_s[ki][mi].reg(ii), params.scale_dropout);
          frag_s[ki][mi].reg(ii) = fmha::hrelu2(frag_s[ki][mi].reg(ii));
        }
      }
    }

    // Write the softmax gradient back over S in global memory, then advance to the next step.
    gmem_s.store(softmax.elt_, mask);
    gmem_s.move();

    static_assert(Mma_tile_dv::MMAS_K == 1);  // DEBUG
#pragma unroll
    for (int ki = 1; ki < Mma_tile_dv::MMAS_K; ++ki) {
      // Trigger the load from shared memory for the next series of Q values.
      smem_qt.load(frag_qt[ki & 1], ki);
      // Do the math for the values already in registers.
      fmha::gemm(acc_dv, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]);
    }

    // Do the final stage of math.
    {
      int ki = Mma_tile_dv::MMAS_K;
      fmha::gemm(acc_dv, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]);
    }
    // Commit the values for Q into shared memory.
    if (l < nl_traits.num_steps_ - 1) {
      gmem_q.commit(smem_q);
    }

    // Make sure we are reading from the correct buffer.
    smem_q.move_to_next_read_buffer();
    smem_qt.move_to_next_read_buffer();

    // Make sure the data is in shared memory.
    __syncthreads();

    // Trigger the loads for the values of Q for the next iteration.
    smem_q.load(frag_q[0], 0);
    smem_k.load(frag_k[0], 0);
    smem_qt.load(frag_qt[0], 0);

  }  // Outer loop over the sequence length.

  // Epilogue for dV = (S * D)' * dout'. We're fully exposed to this!

  // Epilogue swizzle for dV
  Smem_tile_dv smem_dv(&smem_[Kernel_traits::Smem_tile_q::BYTES_PER_TILE], tidx);
  smem_dv.store(acc_dv);

  __syncthreads();
  uint4 dv_out[Smem_tile_dv::NUM_LDS];
  smem_dv.load(dv_out);
  Qkv_params dv_params;
  dv_params.qkv_ptr = params.dkv_ptr;
  // NOTE(review): each dkv row spans h * 2 * CHUNKS * d halves. This is presumably one dK and
  // one dV slot per chunk and head, with get_idx_dv() selecting this chunk's dV slot; confirm
  // in Noloop_traits.
  dv_params.qkv_stride_in_bytes = params.h * 2 * CHUNKS * params.d * sizeof(half);
  dv_params.h = params.h;
  Gmem_tile_dv gmem_dv(dv_params, nl_traits.get_idx_dv(), binfo, tidx);
  gmem_dv.store(dv_out);
}

// Backward pass, chunked ("nl") variant: dQ and the per-chunk partial dK for one CTA, where
// blockIdx.x is the head, blockIdx.y the batch and blockIdx.z the chunk.
//
// dP is read from the S buffer (gmem_s). Per step of this chunk:
//   * dQ = dP * K is staged through smem_o and written through gmem_o (a Gmem_tile_dq);
//   * acc_dk += dP^T * Q, with dP transposed through smem_s and Q re-read as a B operand.
// The partial dK is written to params.dkv_ptr at slot nl_traits.get_idx_dk().
template inline __device__ void compute_dq_dk_1xN_nl(const Params& params) {
  // The description of the CTA tile for the 1st batched GEMM.
  using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
  using Cta_tile_o = typename Kernel_traits::Cta_tile_o;
  // The description of the CTA tile for the 2nd batched GEMM.
  using Cta_tile_dk = fmha::Cta_tile_extd;

  static_assert(Cta_tile_dk::M == 512 || Cta_tile_dk::M == 384 || Cta_tile_dk::M == 256 || Cta_tile_dk::M == 128);
  static_assert(Cta_tile_dk::N == 64);
  static_assert(Cta_tile_dk::K == 16);

  // The MMA tile for the 1st GEMM.
  using Mma_tile_p = fmha::Hmma_tile;
  using Mma_tile_o = fmha::Hmma_tile;
  // The MMA tile for the 2nd GEMM.
  using Mma_tile_dk = fmha::Hmma_tile;

  // The global memory tile to load Q.
  using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q;
  // The shared memory tile to swizzle Q.
  using Smem_tile_q = typename Kernel_traits::Smem_tile_q;

  // The global memory tile to load K.
  using Gmem_tile_k = typename Kernel_traits::Gmem_tile_v;
  // The shared memory tile to swizzle K.
  using Smem_tile_k = typename Kernel_traits::Smem_tile_v;  // K is used like V in fprop

  // The global memory tile to load V.
  using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v;
  // The shared memory tile to swizzle V.
  using Smem_tile_v = typename Kernel_traits::Smem_tile_v;

  // The global memory tile to store O.
  using Gmem_tile_o = Gmem_tile_dq;
  // The shared memory tile to swizzle O.
  using Smem_tile_o = typename Kernel_traits::Smem_tile_o;

  // The global memory tile to store dK.
  using Gmem_tile_dk = fmha::Gmem_tile_qkv;
  // The shared memory tile to swizzle dK.
  using Smem_tile_dk = fmha::Smem_tile_mma_epilogue;
  static_assert(Smem_tile_dk::NUM_LDS == Gmem_tile_dk::LDGS);
  static_assert(Smem_tile_dk::THREADS_PER_ROW == Gmem_tile_dk::THREADS_PER_ROW);

  // The shared memory tile to reload Q transposed.
  using Smem_tile_qt = fmha::Smem_tile_b;

  // The global memory tile to load dP, stored in S
  using Gmem_tile_s = Gmem_tile_mma_s;
  // The shared memory tile to transpose dP.
  using Smem_tile_st = Smem_tile_mma_transposed;

  using Noloop = Noloop_traits;

  enum { M = Mma_tile_p::MMAS_M };
  enum { N = Mma_tile_p::MMAS_N };
  static_assert(M == Mma_tile_o::MMAS_M);
  static_assert(N == Mma_tile_o::MMAS_K);
  // Shared memory.
  extern __shared__ char smem_[];

  // The block index for the chunk.
  const int bidc = blockIdx.z;
  // The block index for the batch.
  const int bidb = blockIdx.y;
  // The block index for the head.
  const int bidh = blockIdx.x;
  // The thread index.
  const int tidx = threadIdx.x;

  const BlockInfoPadded binfo(params, bidb, bidh, tidx);
  if (binfo.stop_early()) return;

  fmha::Mask mask(params, binfo, tidx);

  // Shared memory layout (byte offsets, using each tile's BYTES_PER_TILE):
  //   [0, Q)                  Q tile, read as a B operand (smem_qt)
  //   [Q, Q + K)              K tile (smem_k)
  //   [Q + K, Q + K + O)      dQ staging (smem_o)
  //   [Q + K + O, ...)        dP tile staged for the transposed read (smem_s)

  // Allocate the global memory tile loader for Q.
  Gmem_tile_q gmem_q(params, 0, binfo, tidx);
  // Allocate the shared memory tile loader for Q (as B).
  Smem_tile_qt smem_qt(&smem_[0], tidx);
  // Allocate the global memory tile loader for dP.
  Gmem_tile_s gmem_s(params, binfo, tidx);
  // Allocate the shared memory tile loader for dP.
  Smem_tile_st smem_s(&smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE], tidx);

  // Allocate the global memory tile loader for K.
  Gmem_tile_k gmem_k(params, 1, binfo, tidx);
  // Allocate the shared memory tile loader for K.
  Smem_tile_k smem_k(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx);

  // Allocate the global memory tile loader for O.
  Gmem_tile_o gmem_o(params, binfo, tidx);
  // Allocate the shared memory tile loader for O. We use the same as K so be careful!!!
  // NOTE(review): smem_o starts at Q + K bytes, i.e. after the K tile, so the warning above
  // looks stale. Confirm against the tile sizes before relying on either.
  Smem_tile_o smem_o(&smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE], tidx);

  Noloop nl_traits(bidc, binfo);
  // Position the Q, dQ and dP iterators at this CTA's chunk (presumably by skipping the rows
  // of earlier chunks; see Noloop_traits).
  nl_traits.move_all(gmem_q, gmem_o, gmem_s);

  // Trigger the loads for Q.
  gmem_q.load(smem_qt);
  // Trigger the loads for K.
  gmem_k.load(smem_k);

  uint4 s_regs[M][N];
  gmem_s.load(s_regs, mask);

  // Commit the data for Q and K to shared memory.
  gmem_q.commit(smem_qt);
  gmem_k.commit(smem_k);

  // Make sure the data is in shared memory.
  __syncthreads();

  typename Smem_tile_qt::Fragment frag_qt[2][Mma_tile_dk::MMAS_N];
  smem_qt.load(frag_qt[0], 0);
  typename Smem_tile_k::Fragment frag_k[2][Mma_tile_o::MMAS_N];
  smem_k.load(frag_k[0], 0);

  enum { BITS_PER_ELT_S = sizeof(fmha::A_type) * 8 };

  enum { THREADS_PER_ROW = 32 };

  // Declare the accumulators for the 2nd gemm.
  fmha::Fragment_accumulator acc_dk[Mma_tile_dk::MMAS_M][Mma_tile_dk::MMAS_N];
  fmha::Clear_accumulator::apply(acc_dk);

  // Load over the entire sequence length.
  for (int l = 0; l < nl_traits.num_steps_; l++) {
    // Pack dP as Fragment_a
    fmha::Fragment_a frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M];
#pragma unroll
    for (int mi = 0; mi < M; mi++) {
#pragma unroll
      for (int ni = 0; ni < N; ni++) {
        uint4& dst = s_regs[mi][ni];
        frag_p[ni][mi].reg(0) = dst.x;
        frag_p[ni][mi].reg(1) = dst.z;
        frag_p[ni][mi].reg(2) = dst.y;
        frag_p[ni][mi].reg(3) = dst.w;
      }
    }
    // Stage this step's dP in shared memory for the transposed read in the dK GEMM.
    smem_s.store(s_regs);
    if (l < nl_traits.num_steps_ - 1) {
      // Load next part of S
      gmem_s.move();
      gmem_s.load(s_regs, mask);
      // Trigger the load for the next Q values.
      smem_qt.move_to_next_write_buffer();
      gmem_q.move();
      gmem_q.load(smem_qt);
    }
    // Declare the accumulators for the 1st gemm.
    fmha::Fragment_accumulator acc_o[Mma_tile_o::MMAS_M][Mma_tile_o::MMAS_N];
    fmha::Clear_accumulator::apply(acc_o);

    // Do this part of O = P^T * V^T. Here: dQ = dP x K.
#pragma unroll
    for (int ki = 1; ki < Mma_tile_o::MMAS_K; ++ki) {
      // Trigger the load from shared memory for the next series of K values.
      smem_k.load(frag_k[ki & 1], ki);
      // Do the math for the values already in registers.
      fmha::gemm(acc_o, frag_p[ki - 1], frag_k[(ki - 1) & 1]);
    }

    // Do the final stage of math.
    {
      int ki = Mma_tile_o::MMAS_K;
      fmha::gemm(acc_o, frag_p[ki - 1], frag_k[(ki - 1) & 1]);
    }

    static_assert(Gmem_tile_o::LOOPS == 1);  // DEBUG
    // Loop over MMAS_M.
#pragma unroll
    for (int ii = 0; ii < Gmem_tile_o::LOOPS; ++ii) {
      // Swizzle the elements and do the final reduction.
      smem_o.store(acc_o, ii);

      // Make sure the data is in shared memory.
      __syncthreads();

      // Load from shared memory.
      uint4 out[Gmem_tile_o::STGS_PER_LOOP];
      smem_o.load(out);

      // Make sure the data was read from shared memory.
      if (ii < Gmem_tile_o::LOOPS - 1) {
        __syncthreads();
      }

      // Output the values.
      gmem_o.store(out, ii);
    }

    // Move to the next part of the output.
    gmem_o.move();

    typename Smem_tile_st::Fragment frag_s[Mma_tile_dk::MMAS_K][Mma_tile_dk::MMAS_M];
    smem_s.load(frag_s);

    static_assert(Mma_tile_dk::MMAS_K == 1);  // DEBUG
#pragma unroll
    for (int ki = 1; ki < Mma_tile_dk::MMAS_K; ++ki) {
      // Trigger the load from shared memory for the next series of Q values.
      smem_qt.load(frag_qt[ki & 1], ki);
      // Do the math for the values already in registers.
      fmha::gemm(acc_dk, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]);
    }

    // Do the final stage of math.
    {
      int ki = Mma_tile_dk::MMAS_K;
      fmha::gemm(acc_dk, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]);
    }

    // Commit the values for Q into shared memory.
    // NOTE(review): unlike compute_dv_1xN_nl, no smem_qt.move_to_next_read_buffer() follows
    // this commit. That is only correct if Smem_tile_qt is single-buffered here, or its read
    // pointer otherwise tracks the write buffer. Confirm.
    if (l < nl_traits.num_steps_ - 1) {
      gmem_q.commit(smem_qt);
      __syncthreads();
      // Trigger the loads for the values of Q for the next iteration.
      smem_qt.load(frag_qt[0], 0);
      smem_k.load(frag_k[0], 0);
    }

  }  // Outer loop over the sequence length.

  // Epilogue for dK = dP' * Q. We're fully exposed to this!
  // NOTE(review): the last loop iteration does not end with __syncthreads(). This store at
  // offset 0 relies on smem_dk not overlapping shared memory that other threads may still be
  // reading (smem_o / smem_s). Confirm against the tile sizes.

  // Epilogue swizzle for dK
  Smem_tile_dk smem_dk(&smem_[0], tidx);
  smem_dk.store(acc_dk);

  __syncthreads();
  uint4 dk_out[Smem_tile_dk::NUM_LDS];
  smem_dk.load(dk_out);
  Qkv_params dk_params;
  dk_params.qkv_ptr = params.dkv_ptr;
  dk_params.qkv_stride_in_bytes = params.h * 2 * CHUNKS * params.d * sizeof(half);
  dk_params.h = params.h;
  Gmem_tile_dk gmem_dk(dk_params, nl_traits.get_idx_dk(), binfo, tidx);
  gmem_dk.store(dk_out);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace fmha
================================================
FILE: apex/contrib/csrc/fmha/src/fmha_fill.cu
================================================
/******************************************************************************
 * Copyright (c) 2011-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the
 *       names of its contributors may be used to endorse or promote products
 *       derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED.
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * ******************************************************************************/ #include #include #include constexpr int block_size = 512; constexpr int ctas_per_sm = 4; template __global__ void __launch_bounds__(block_size) mha_fill_kernel(scalar_t* out_tensor, const int32_t* const start_row, const size_t num_rows) { size_t row_stride = gridDim.y * blockDim.x; size_t row_index = blockIdx.x + (size_t)start_row[0]; size_t col_index = blockIdx.y * blockDim.x + threadIdx.x; while (row_index < num_rows) { out_tensor[row_index * row_stride + col_index] = 0; row_index += gridDim.x; } } at::Tensor& mha_fill(at::Tensor& self, const at::Tensor& start_index) { auto max_tokens = self.size(0); auto self_2d = self.view({max_tokens, -1}); auto fcd_size = self_2d.size(1); TORCH_CHECK(self.is_contiguous(), "input not contiguous"); TORCH_CHECK(fcd_size % block_size == 0, "input size not aligned to block size"); const int num_mp = at::cuda::getCurrentDeviceProperties()->multiProcessorCount; uint64_t num_blk_y = (uint64_t)(fcd_size / block_size); uint64_t num_blk_x = (uint64_t)std::ceil(num_mp * ctas_per_sm / num_blk_y); dim3 dim_grid(num_blk_x, num_blk_y); dim3 dim_block(block_size); AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2( at::ScalarType::Half, at::ScalarType::BFloat16, self_2d.scalar_type(), "mha_padding_fill_", [&]() { mha_fill_kernel<<>>( self_2d.data_ptr(), start_index.data_ptr(), max_tokens); C10_CUDA_KERNEL_LAUNCH_CHECK(); }); return self; } 
================================================ FILE: apex/contrib/csrc/fmha/src/fmha_fprop_fp16_128_64_kernel.sm80.cu ================================================ /****************************************************************************** * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright * notice, this list of conditions and the following disclaimer. * * Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. * * Neither the name of the NVIDIA CORPORATION nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
*
 ******************************************************************************/

#include "fmha.h"
#include "fmha_fprop_kernel_1xN.h"

// Kernel configuration for this translation unit. Going by the file and kernel names, this is
// the 128 x 64 variant (sequence length 128, head size 64); see FMHA_kernel_traits for the
// meaning of each argument.
using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;

// Forward-attention kernel entry point: forwards params and the total head count to
// fmha::device_1xN.
template __global__ void fmha_fprop_fp16_128_64_sm80_kernel(Fused_multihead_attention_fprop_params params,
                                                            const int total_heads) {
  fmha::device_1xN(params, total_heads);
}

// Host-side launcher for the 128 x 64 forward kernel.
//
// configure == true: only computes launch_params.elts_per_thread (based on the grid size
// chosen below) and returns without launching.
// configure == false: launches sm_count * ctas_per_sm CTAs and passes b * h as total_heads
// (presumably distributed across those CTAs inside device_1xN).
void run_fmha_fp16_128_64_sm80(Launch_params& launch_params, const bool configure) {
  // Select the training or inference instantiation of the kernel.
  auto kernel = launch_params.is_training ? &fmha_fprop_fp16_128_64_sm80_kernel : &fmha_fprop_fp16_128_64_sm80_kernel;

  constexpr int smem_size = fmha::get_dynamic_smem_size();

  // Dynamic shared memory beyond the 48 KB default requires an explicit per-kernel opt-in.
  if (smem_size >= 48 * 1024) {
    FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
  }

  // Size the grid to the number of CTAs that can be resident on the device at once.
  const int sm_count = launch_params.props->multiProcessorCount;
  int ctas_per_sm;
  FMHA_CHECK_CUDA(
      cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel, Kernel_traits::THREADS, smem_size));
  // NOTE(review): if the occupancy query reports 0, total_ctas is 0 and the ceil-division in
  // the configure branch divides by zero.
  int total_ctas = sm_count * ctas_per_sm;

  const int heads_total = launch_params.params.b * launch_params.params.h;
  if (configure) {
    using Mma_tile_p = fmha::Hmma_tile;
    constexpr size_t STEPS = Kernel_traits::Cta_tile_p::N / Kernel_traits::Cta_tile_p::M;
    constexpr size_t MMAS_M = Mma_tile_p::MMAS_M;
    constexpr size_t MMAS_N = Mma_tile_p::MMAS_N;

    // Ceil-division: the largest number of heads any single CTA processes.
    size_t heads_per_cta = ((heads_total + total_ctas - 1) / total_ctas);
    // NOTE(review): the factor 8 is not explained here. It is presumably the number of values
    // per thread per MMA tile drawn from the dropout RNG; confirm in the kernel.
    size_t elts_per_head = STEPS * MMAS_M * MMAS_N * 8;
    launch_params.elts_per_thread = heads_per_cta * elts_per_head;
    return;
  }

  dim3 grid(total_ctas);
  kernel<<>>(launch_params.params, heads_total);

  FMHA_CHECK_CUDA(cudaPeekAtLastError());
}
================================================
FILE: apex/contrib/csrc/fmha/src/fmha_fprop_fp16_256_64_kernel.sm80.cu
================================================
/******************************************************************************
 * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the
 *       names of its contributors may be used to endorse or promote products
 *       derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/

#include "fmha.h"
#include "fmha_fprop_kernel_1xN.h"

// Kernel configuration for this translation unit. Going by the file and kernel names, this is
// the 256 x 64 variant (sequence length 256, head size 64); see FMHA_kernel_traits for the
// meaning of each argument.
using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;

// Forward-attention kernel entry point: forwards params and the total head count to
// fmha::device_1xN.
template __global__ void fmha_fprop_fp16_256_64_sm80_kernel(Fused_multihead_attention_fprop_params params,
                                                            const int total_heads) {
  fmha::device_1xN(params, total_heads);
}

// Host-side launcher for the 256 x 64 forward kernel.
//
// configure == true: only computes launch_params.elts_per_thread (based on the grid size
// chosen below) and returns without launching.
// configure == false: launches sm_count * ctas_per_sm CTAs and passes b * h as total_heads
// (presumably distributed across those CTAs inside device_1xN).
void run_fmha_fp16_256_64_sm80(Launch_params& launch_params, const bool configure) {
  // Select the training or inference instantiation of the kernel.
  auto kernel = launch_params.is_training ? &fmha_fprop_fp16_256_64_sm80_kernel : &fmha_fprop_fp16_256_64_sm80_kernel;

  constexpr int smem_size = fmha::get_dynamic_smem_size();

  // Dynamic shared memory beyond the 48 KB default requires an explicit per-kernel opt-in.
  if (smem_size >= 48 * 1024) {
    FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
  }

  // Size the grid to the number of CTAs that can be resident on the device at once.
  const int sm_count = launch_params.props->multiProcessorCount;
  int ctas_per_sm;
  FMHA_CHECK_CUDA(
      cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel, Kernel_traits::THREADS, smem_size));
  // NOTE(review): if the occupancy query reports 0, total_ctas is 0 and the ceil-division in
  // the configure branch divides by zero.
  int total_ctas = sm_count * ctas_per_sm;

  const int heads_total = launch_params.params.b * launch_params.params.h;
  if (configure) {
    using Mma_tile_p = fmha::Hmma_tile;
    constexpr size_t STEPS = Kernel_traits::Cta_tile_p::N / Kernel_traits::Cta_tile_p::M;
    constexpr size_t MMAS_M = Mma_tile_p::MMAS_M;
    constexpr size_t MMAS_N = Mma_tile_p::MMAS_N;

    // Ceil-division: the largest number of heads any single CTA processes.
    size_t heads_per_cta = ((heads_total + total_ctas - 1) / total_ctas);
    // NOTE(review): the factor 8 is not explained here. It is presumably the number of values
    // per thread per MMA tile drawn from the dropout RNG; confirm in the kernel.
    size_t elts_per_head = STEPS * MMAS_M * MMAS_N * 8;
    launch_params.elts_per_thread = heads_per_cta * elts_per_head;
    return;
  }

  dim3 grid(total_ctas);
  kernel<<>>(launch_params.params, heads_total);

  FMHA_CHECK_CUDA(cudaPeekAtLastError());
}
================================================
FILE: apex/contrib/csrc/fmha/src/fmha_fprop_fp16_384_64_kernel.sm80.cu
================================================
/******************************************************************************
 * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
*     * Neither the name of the NVIDIA CORPORATION nor the
 *       names of its contributors may be used to endorse or promote products
 *       derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/

#include "fmha.h"
#include "fmha_fprop_kernel_1xN.h"

// Kernel configuration for this translation unit. Going by the file and kernel names, this is
// the 384 x 64 variant (sequence length 384, head size 64); see FMHA_kernel_traits for the
// meaning of each argument.
using Kernel_traits = FMHA_kernel_traits<384, 64, 16, 1, 4, 0x18u>;

// Forward-attention kernel entry point: forwards params and the total head count to
// fmha::device_1xN.
template __global__ void fmha_fprop_fp16_384_64_sm80_kernel(Fused_multihead_attention_fprop_params params,
                                                            const int total_heads) {
  fmha::device_1xN(params, total_heads);
}

// Host-side launcher for the 384 x 64 forward kernel.
//
// configure == true: only computes launch_params.elts_per_thread (based on the grid size
// chosen below) and returns without launching.
// configure == false: launches sm_count * ctas_per_sm CTAs and passes b * h as total_heads
// (presumably distributed across those CTAs inside device_1xN).
void run_fmha_fp16_384_64_sm80(Launch_params& launch_params, const bool configure) {
  // Select the training or inference instantiation of the kernel.
  auto kernel = launch_params.is_training ? &fmha_fprop_fp16_384_64_sm80_kernel : &fmha_fprop_fp16_384_64_sm80_kernel;

  constexpr int smem_size = fmha::get_dynamic_smem_size();

  // Dynamic shared memory beyond the 48 KB default requires an explicit per-kernel opt-in.
  if (smem_size >= 48 * 1024) {
    FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
  }

  // Size the grid to the number of CTAs that can be resident on the device at once.
  const int sm_count = launch_params.props->multiProcessorCount;
  int ctas_per_sm;
  FMHA_CHECK_CUDA(
      cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel, Kernel_traits::THREADS, smem_size));
  // NOTE(review): if the occupancy query reports 0, total_ctas is 0 and the ceil-division in
  // the configure branch divides by zero.
  int total_ctas = sm_count * ctas_per_sm;

  const int heads_total = launch_params.params.b * launch_params.params.h;
  if (configure) {
    using Mma_tile_p = fmha::Hmma_tile;
    constexpr size_t STEPS = Kernel_traits::Cta_tile_p::N / Kernel_traits::Cta_tile_p::M;
    constexpr size_t MMAS_M = Mma_tile_p::MMAS_M;
    constexpr size_t MMAS_N = Mma_tile_p::MMAS_N;

    // Ceil-division: the largest number of heads any single CTA processes.
    size_t heads_per_cta = ((heads_total + total_ctas - 1) / total_ctas);
    // NOTE(review): the factor 8 is not explained here. It is presumably the number of values
    // per thread per MMA tile drawn from the dropout RNG; confirm in the kernel.
    size_t elts_per_head = STEPS * MMAS_M * MMAS_N * 8;
    launch_params.elts_per_thread = heads_per_cta * elts_per_head;
    return;
  }

  dim3 grid(total_ctas);
  kernel<<>>(launch_params.params, heads_total);

  FMHA_CHECK_CUDA(cudaPeekAtLastError());
}
================================================
FILE: apex/contrib/csrc/fmha/src/fmha_fprop_fp16_512_64_kernel.sm80.cu
================================================
/******************************************************************************
 * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
*     * Neither the name of the NVIDIA CORPORATION nor the
 *       names of its contributors may be used to endorse or promote products
 *       derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/

#include "fmha.h"
#include "fmha_fprop_kernel_1xN.h"

// Compile-time configuration of the 512x64 forward kernels.
// NOTE(review): the argument order is presumably <seqlen, head_dim, step, warps_m, warps_n, flags>,
// matching the "512_64" in the kernel names -- confirm against FMHA_kernel_traits.
using Kernel_traits = FMHA_kernel_traits<512, 64, 16, 1, 8, 0x00u>;

// Entry point that processes whole heads: forwards to fmha::device_1xN(params, total_heads).
template __global__ void fmha_fprop_fp16_512_64_sm80_kernel(Fused_multihead_attention_fprop_params params, const int total_heads) {
  fmha::device_1xN(params, total_heads);
}

// Entry point for the split ("nl") work distribution: forwards the host-computed split
// (whole-head waves, per-head groups and their step counts) to the matching fmha::device_1xN.
template __global__ void fmha_fprop_fp16_512_64_sm80_kernel_nl(Fused_multihead_attention_fprop_params params, const int num_full_heads, const int num_main_groups, const int main_group_size, const int main_steps, const int rest_steps) {
  fmha::device_1xN(params, num_full_heads, num_main_groups, main_group_size, main_steps, rest_steps);
}

// Host driver for the whole-head kernel.
// configure == true : stores in launch_params.elts_per_thread the per-head element count times the
//                     (rounded-up) number of heads per CTA, and returns without launching.
// configure == false: launches with a 1-D grid of total_ctas = SM count x max resident CTAs per SM.
void run_fmha_fp16_512_64_sm80_(Launch_params& launch_params, const bool configure) {
  // NOTE(review): both ternary arms read identically in this copy; presumably the training and
  // inference instantiations -- confirm against the real source.
  auto kernel = launch_params.is_training ? &fmha_fprop_fp16_512_64_sm80_kernel : &fmha_fprop_fp16_512_64_sm80_kernel;
  constexpr int smem_size = fmha::get_dynamic_smem_size();
  // Dynamic shared memory of 48 KB or more must be explicitly opted into per kernel.
  if (smem_size >= 48 * 1024) {
    FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
  }
  const int sm_count = launch_params.props->multiProcessorCount;
  int ctas_per_sm;
  FMHA_CHECK_CUDA(
      cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel, Kernel_traits::THREADS, smem_size));
  // NOTE(review): total_ctas is 0 if the occupancy query reports 0 CTAs/SM; the division below
  // would then divide by zero.
  int total_ctas = sm_count * ctas_per_sm;
  const int heads_total = launch_params.params.b * launch_params.params.h;
  if (configure) {
    using Mma_tile_p = fmha::Hmma_tile;
    constexpr size_t STEPS = Kernel_traits::Cta_tile_p::N / Kernel_traits::Cta_tile_p::M;
    constexpr size_t MMAS_M = Mma_tile_p::MMAS_M;
    constexpr size_t MMAS_N = Mma_tile_p::MMAS_N;
    // Heads handled by one CTA, rounded up.
    size_t heads_per_cta = ((heads_total + total_ctas - 1) / total_ctas);
    size_t elts_per_head = STEPS * MMAS_M * MMAS_N * 8;
    launch_params.elts_per_thread = heads_per_cta * elts_per_head;
    return;
  }
  dim3 grid(total_ctas);
  kernel<<>>(launch_params.params, heads_total);
  FMHA_CHECK_CUDA(cudaPeekAtLastError());
}

// Host driver for the split ("nl") kernel.
// configure == true : stores the fmha::work_dist result (num_full_heads, num_main_groups,
//                     heads_last_wave, main_steps, rest_steps, elts_per_thread) in launch_params.
// configure == false: launches with those stored values -- heads_last_wave is passed positionally as
//                     the kernel's main_group_size -- so it relies on a prior configure pass.
void run_fmha_fp16_512_64_sm80_nl_(Launch_params& launch_params, const bool configure) {
  // NOTE(review): both ternary arms read identically in this copy; presumably the training and
  // inference instantiations -- confirm against the real source.
  auto kernel = launch_params.is_training ? &fmha_fprop_fp16_512_64_sm80_kernel_nl : &fmha_fprop_fp16_512_64_sm80_kernel_nl;
  constexpr int smem_size = fmha::get_dynamic_smem_size();
  // Dynamic shared memory of 48 KB or more must be explicitly opted into per kernel.
  if (smem_size >= 48 * 1024) {
    FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
  }
  const int sm_count = launch_params.props->multiProcessorCount;
  int ctas_per_sm;
  FMHA_CHECK_CUDA(
      cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel, Kernel_traits::THREADS, smem_size));
  int total_ctas = sm_count * ctas_per_sm;
  if (configure) {
    const int heads_total = launch_params.params.b * launch_params.params.h;
    std::tie(launch_params.num_full_heads, launch_params.num_main_groups, launch_params.heads_last_wave,
             launch_params.main_steps, launch_params.rest_steps, launch_params.elts_per_thread) =
        fmha::work_dist(total_ctas, heads_total);
    return;
  }
  dim3 grid(total_ctas);
  kernel<<>>(
      launch_params.params, launch_params.num_full_heads, launch_params.num_main_groups, launch_params.heads_last_wave,
      launch_params.main_steps, launch_params.rest_steps);
  FMHA_CHECK_CUDA(cudaPeekAtLastError());
}

// Public entry point: dispatches to the split ("nl") or the whole-head driver on launch_params.is_nl.
void run_fmha_fp16_512_64_sm80(Launch_params& launch_params, const bool configure) {
  if (launch_params.is_nl) {
    run_fmha_fp16_512_64_sm80_nl_(launch_params, configure);
  } else {
    run_fmha_fp16_512_64_sm80_(launch_params, configure);
  }
}

================================================
FILE: apex/contrib/csrc/fmha/src/fmha_fprop_kernel_1xN.h
================================================
/***************************************************************************************************
 * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
*     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the
 *       names of its contributors may be used to endorse or promote products
 *       derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/

#pragma once

#include
#include
#include "fmha_kernel.h"

namespace fmha {

////////////////////////////////////////////////////////////////////////////////////////////////////

// State shared by both Gemm_Q_K variants for the first batched GEMM (P = Q * K^T): the Q and K
// shared-memory tiles, plus a two-deep set of Q register fragments. The operator() loops index the
// fragments with `ki & 1`, so the next Q slice is fetched while the current one is consumed.
template struct Gemm_Q_K_base {
  using Smem_tile_o = typename Kernel_traits::Smem_tile_o;
  using Smem_tile_q = typename Kernel_traits::Smem_tile_q;
  using Smem_tile_k = typename Kernel_traits::Smem_tile_k;
  using Fragment_q = typename Smem_tile_q::Fragment;
  using Fragment_k = typename Smem_tile_k::Fragment;

  // The description of the CTA tile for the 1st batched GEMM.
  using Cta_tile_p = typename Kernel_traits::Cta_tile_p;

  // The MMA tile for the 1st GEMM.
  using Mma_tile_p = fmha::Hmma_tile;

  // Shared-memory scratch handed to the softmax for its reductions:
  // 2 x (Cta_tile_p::M x WARPS_N) floats.
  static constexpr int SMEM_BYTES_SOFTMAX = Cta_tile_p::M * Cta_tile_p::WARPS_N * sizeof(float) * 2;

  __device__ inline Gemm_Q_K_base(char* smem_ptr_q, char* smem_ptr_k, const int tidx)
      : smem_q(smem_ptr_q, tidx), smem_k(smem_ptr_k, tidx) {}

  // Load the first k-slice of Q from shared memory into fragment buffer 0.
  __device__ inline void load_q() { smem_q.load(frag_q[0], 0); }

  // Same as load_q(); used to prime the Q fragments for the next row-block iteration.
  __device__ inline void reload_q() { smem_q.load(frag_q[0], 0); }

  Fragment_q frag_q[2][Mma_tile_p::MMAS_M];
  Smem_tile_q smem_q;
  Smem_tile_k smem_k;
};

// Gemm_Q_K variant that keeps all MMAS_K slices of K resident in registers: load_k() fetches them
// once and reload_k() is a no-op. The O tile and the softmax scratch are therefore overlaid on the
// K/V region right after Q (SMEM_OFFSET_O == Q tile size), since K is held in registers by then.
template struct Gemm_Q_K : public Gemm_Q_K_base {
  using Base = Gemm_Q_K_base;
  using Smem_tile_o = typename Base::Smem_tile_o;
  using Smem_tile_q = typename Base::Smem_tile_q;
  using Smem_tile_k = typename Base::Smem_tile_k;
  using Fragment_k = typename Base::Fragment_k;
  using Mma_tile_p = typename Base::Mma_tile_p;

  enum { SHARE_SMEM_FOR_K_AND_V = Kernel_traits::SHARE_SMEM_FOR_K_AND_V };

  // O starts right after Q, i.e. on top of the K/V region.
  enum { SMEM_OFFSET_O = Smem_tile_q::BYTES_PER_TILE };
  // V reuses K's slot when SHARE_SMEM_FOR_K_AND_V, otherwise it follows K.
  enum { SMEM_OFFSET_V = Smem_tile_q::BYTES_PER_TILE + (SHARE_SMEM_FOR_K_AND_V ? 0 : Smem_tile_k::BYTES_PER_TILE) };

  // Q | K / V
  //   | O | SOFTMAX
  static constexpr int SMEM_BYTES = Smem_tile_q::BYTES_PER_TILE +
                                    std::max((SHARE_SMEM_FOR_K_AND_V ? 1 : 2) * Smem_tile_k::BYTES_PER_TILE,
                                             Smem_tile_o::BYTES_PER_TILE + Base::SMEM_BYTES_SOFTMAX);

  __device__ inline Gemm_Q_K(char* smem_, const int tidx) : Base(smem_, smem_ + Smem_tile_q::BYTES_PER_TILE, tidx) {}

  // Load every k-slice of K into registers.
  __device__ inline void load_k() {
#pragma unroll
    for (int ki = 0; ki < Mma_tile_p::MMAS_K; ++ki) {
      Base::smem_k.load(frag_k[ki], ki);
    }
  }

  // Accumulate P^T = (Q * K^T)^T into acc_p. Q slice ki is fetched from shared memory while the MMA
  // for slice ki - 1 runs; the last slice is consumed after the loop.
  template __device__ inline void operator()(Acc (&acc_p)[M][N]) {
    // Do this part of P^T = (Q * K^T)^T.
#pragma unroll
    for (int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki) {
      // Trigger the load from shared memory for the next series of Q values.
      Base::smem_q.load(Base::frag_q[ki & 1], ki);
      // Do the math for the values already in registers.
      fmha::gemm(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);
    }
    // Do the final stage of math.
    {
      int ki = Mma_tile_p::MMAS_K;
      fmha::gemm(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);
    }
  }

  __device__ inline void reload_k() {
    // Noop.
  }

  Fragment_k frag_k[Mma_tile_p::MMAS_K][Mma_tile_p::MMAS_N];
};

// Gemm_Q_K variant that streams K from shared memory: only two k-slices of K live in registers
// (double-buffered like Q), so K must stay in shared memory. The O tile and the softmax scratch are
// placed after the K/V region instead of on top of it.
template struct Gemm_Q_K : public Gemm_Q_K_base {
  using Base = Gemm_Q_K_base;
  using Smem_tile_o = typename Base::Smem_tile_o;
  using Smem_tile_q = typename Base::Smem_tile_q;
  using Smem_tile_k = typename Base::Smem_tile_k;
  using Smem_tile_v = typename Kernel_traits::Smem_tile_v;
  using Fragment_k = typename Base::Fragment_k;
  using Mma_tile_p = typename Base::Mma_tile_p;
  Fragment_k frag_k[2][Mma_tile_p::MMAS_N];

  enum { SHARE_SMEM_FOR_K_AND_V = Kernel_traits::SHARE_SMEM_FOR_K_AND_V };

  // V reuses K's slot when SHARE_SMEM_FOR_K_AND_V, otherwise it follows K.
  enum { SMEM_OFFSET_V = Smem_tile_q::BYTES_PER_TILE + (SHARE_SMEM_FOR_K_AND_V ? 0 : Smem_tile_k::BYTES_PER_TILE) };
  static_assert(Smem_tile_v::BYTES_PER_TILE == (int)Smem_tile_k::BYTES_PER_TILE);
  enum { SMEM_OFFSET_O = SMEM_OFFSET_V + Smem_tile_v::BYTES_PER_TILE };

  // Q | K/V + O + SOFTMAX
  static constexpr int SMEM_BYTES = Smem_tile_q::BYTES_PER_TILE +
                                    (SHARE_SMEM_FOR_K_AND_V ? 1 : 2) * Smem_tile_k::BYTES_PER_TILE +
                                    Smem_tile_o::BYTES_PER_TILE + Base::SMEM_BYTES_SOFTMAX;

  __device__ inline Gemm_Q_K(char* smem_, const int tidx) : Base(smem_, smem_ + Smem_tile_q::BYTES_PER_TILE, tidx) {}

  // Load the first k-slice of K into fragment buffer 0.
  __device__ inline void load_k() { Base::smem_k.load(frag_k[0], 0); }

  // Accumulate P^T = (Q * K^T)^T into acc_p, fetching Q and K slice ki from shared memory while the
  // MMA for slice ki - 1 runs.
  template __device__ inline void operator()(Acc (&acc_p)[M][N]) {
    // Do this part of P^T = (Q * K^T)^T.
#pragma unroll
    for (int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki) {
      // Trigger the load from shared memory for the next series of Q values.
      Base::smem_q.load(Base::frag_q[ki & 1], ki);
      Base::smem_k.load(frag_k[ki & 1], ki);
      // Do the math for the values already in registers.
      fmha::gemm(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]);
    }
    // Do the final stage of math.
    {
      int ki = Mma_tile_p::MMAS_K;
      fmha::gemm(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]);
    }
  }

  // Re-fetch the first k-slice of K for the next row-block iteration.
  __device__ inline void reload_k() { Base::smem_k.load(frag_k[0], 0); }
};

// Dynamic shared memory, in bytes, required by the forward kernel: the layout total of the selected
// Gemm_Q_K variant.
template constexpr size_t get_dynamic_smem_size() { return Gemm_Q_K::SMEM_BYTES; }

// Forward attention for one (batch bidb, head bidh) pair over `steps` consecutive row-blocks of
// Cta_tile_p::M query rows, starting at row-block `begin`:
//   S = softmax(Q * K^T)  (K pre-scaled by params.scale_bmm1 before it reaches shared memory),
//   O = S * V             (with dropout applied to S when Is_training).
// In training mode the keep/drop decision is encoded in the sign bit of S, which is written to global
// memory through Gmem_tile_s. `ph` supplies the random bytes used for dropout.
template inline __device__ void device_1xN_(const Params& params, const int bidb, const int bidh, const int begin,
                                            const int steps, Prng& ph) {
  // The description of the CTA tile for the 1st batched GEMM.
  using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
  // The description of the CTA tile for the 2nd batched GEMM.
  using Cta_tile_o = typename Kernel_traits::Cta_tile_o;

  // The MMA tile for the 1st GEMM.
  using Mma_tile_p = fmha::Hmma_tile;
  // The MMA tile for the 2nd GEMM.
  using Mma_tile_o = fmha::Hmma_tile;

  // The global memory tile to load Q.
  using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q;

  // The global memory tile to load K.
  using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k;

  // The global memory tile to load V.
  using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v;
  // The shared memory tile to swizzle V.
  using Smem_tile_v = typename Kernel_traits::Smem_tile_v;

  // The global memory tile to store O.
  using Gmem_tile_o = typename Kernel_traits::Gmem_tile_o;
  // The shared memory tile to swizzle O.
  using Smem_tile_o = typename Kernel_traits::Smem_tile_o;

  using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s;

  using Gemm1 = Gemm_Q_K;

  using Softmax = fmha::Softmax;

  // The number of threads per row.
  enum { THREADS_PER_ROW = 32 };

  enum { BITS_PER_ELT_S = sizeof(fmha::A_type) * 8 };

  // Shared memory.
  extern __shared__ char smem_[];

  // The thread index.
  const int tidx = threadIdx.x;

  const BlockInfoPadded binfo(params, bidb, bidh, tidx);
  // Nothing to do for an empty sequence.
  if (binfo.stop_early()) return;

  Gemm1 gemm_q_k(smem_, tidx);
  // Allocate the global memory tile loader for Q.
  Gmem_tile_q gmem_q(params, 0, binfo, tidx);
  // Allocate the global memory tile loader for O.
  Gmem_tile_o gmem_o(params, binfo, tidx);
  // Allocate the global memory tile loader for S.
  Gmem_tile_s gmem_s(params, binfo, tidx);
  // Wind the Q/S/O tiles forward to row-block `begin` (one move() per row-block).
  for (int it = 0; it < begin; it++) {
    gmem_q.move();
    gmem_s.move();
    gmem_o.move();
  }

  fmha::Mask mask(params, binfo, tidx);

  // Allocate the global memory tile loader for K.
  Gmem_tile_k gmem_k(params, 1, binfo, tidx);
  // Allocate the global memory tile loader for V.
  Gmem_tile_v gmem_v(params, 2, binfo, tidx);
  // The base pointer of smem_v;
  char* smem_v_ = &smem_[Gemm1::SMEM_OFFSET_V];

  // Allocate the shared memory tile loader for V. We use the same as K so be careful!!!
  Smem_tile_v smem_v(smem_v_, tidx);

  // Allocate the shared memory tile loader for O. We use the same as K so be careful!!!
  Smem_tile_o smem_o(&smem_[Gemm1::SMEM_OFFSET_O], tidx);

  // Trigger the loads for K.
  gmem_k.load(gemm_q_k.smem_k);
  // Trigger the loads for Q.
  gmem_q.load(gemm_q_k.smem_q);
  // Trigger the loads for V.
  gmem_v.load(smem_v);

  // Pre-scale the fetched K values by params.scale_bmm1 while they are still in the staging
  // registers, so Q * K^T comes out already scaled.
  const uint32_t scale_bmm1 = reinterpret_cast(params.scale_bmm1);
#pragma unroll
  for (int it = 0; it < Gmem_tile_k::LDGS; it++) {
    gmem_k.fetch_[it] = fmha::hmul8(scale_bmm1, gmem_k.fetch_[it]);
  }

  // Commit the data for Q and V to shared memory.
  gmem_q.commit(gemm_q_k.smem_q);
  gmem_v.commit(smem_v);

  // Commit the data for K to shared memory now only if K has its own slot; otherwise K would
  // overwrite V, which has not been read into registers yet.
  if (!Kernel_traits::SHARE_SMEM_FOR_K_AND_V) {
    gmem_k.commit(gemm_q_k.smem_k);
  }

  __syncthreads();

  // Load the fragments for Q.
  gemm_q_k.load_q();

  // Load the fragments for V. We keep the data in registers during the entire kernel.
  typename Smem_tile_v::Fragment frag_v[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_N];
#pragma unroll
  for (int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki) {
    smem_v.load(frag_v[ki], ki);
  }

  // When K and V share a shared-memory slot, K is committed only now, after V is in registers.
  if (Kernel_traits::SHARE_SMEM_FOR_K_AND_V) {
    // Make sure every thread is done reading the V fragments before K overwrites them.
    __syncthreads();
    // Commit the data to shared memory for K.
    gmem_k.commit(gemm_q_k.smem_k);
    // Make sure the data is in shared memory.
    __syncthreads();
  }

  // Load the fragments for K.
  gemm_q_k.load_k();

  // Dropout threshold on one random byte: an element is kept when its byte is <= p_scaled, so
  // params.p_dropout acts as the keep probability here.
  const uint32_t p_scaled = static_cast<uint32_t>(256.f * params.p_dropout);

  // Create the object to do the softmax.
  Softmax softmax(params, &smem_[Gemm1::SMEM_OFFSET_O + Smem_tile_o::BYTES_PER_TILE], bidb, tidx);

  // Load over the entire sequence length.
  for (int l = 0; l < steps; l++) {
    // `begin` and `l` both count row-blocks (cf. the wind-forward loop and mask.load(begin + l)), so the
    // first row of this block is (begin + l) * M. Writing `begin + l * M` would mix units and, for
    // begin > 0, let the CTA run past the end of the sequence.
    if ((begin + l) * Cta_tile_p::M >= binfo.actual_seqlen) break;

    // Declare the accumulators for the 1st gemm.
    fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];
    fmha::Clear_accumulator::apply(acc_p);

    // Do this part of P^T = (Q * K^T)^T.
    gemm_q_k(acc_p);

    // Trigger the load for the next Q values.
    if (l < steps - 1) {
      gemm_q_k.smem_q.move_to_next_write_buffer();
      gmem_q.move();
      gmem_q.load(gemm_q_k.smem_q);
    }

    // Load the mask for that iteration.
    mask.load(begin + l);

    // Convert from the accumulator type to FP32 for Softmax.
    softmax.unpack_noscale(acc_p);

    // Apply the mask.
    softmax.apply_mask(mask);

    if (Kernel_traits::SHARE_SMEM_FOR_K_AND_V && l == 0) {
      // if we share K and V, it could be that V was not fully read yet but we write into smem for reduction
      __syncthreads();
    }
    // Compute the max.
    float p_max[Mma_tile_p::MMAS_M * 2];
    // softmax.template reduce(p_max);
    softmax.reduce_max(p_max);

    // Compute the exponential value.
    softmax.apply_exp(p_max);

    // Compute the sum.
    float p_sum[Mma_tile_p::MMAS_M * 2];
    softmax.reduce_sum(p_sum);

    // Finalize softmax on the accumulators of P^T.
    softmax.scale(p_sum);

    using Frag_p = fmha::Fragment_a;
    Frag_p frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M];
    if (Is_training) {
      auto encode_dropout = [](bool keep, float val) { return keep ? val : -val; };
#pragma unroll
      for (int mi = 0; mi < Mma_tile_p::MMAS_M; mi++) {
#pragma unroll
        for (int ii = 0; ii < 2; ii++) {
#pragma unroll
          for (int ni = 0; ni < Mma_tile_p::MMAS_N / 4; ni++) {
            // Each ph() draw is consumed as 16 random bytes, one per element below.
            uint8_t* rand_arr = (uint8_t*)&ph();
            // We encode the dropout pattern in the sign bit of the non-negative softmax to distinguish from
            // pre-existing zeros
            for (int ind = 0; ind < 16; ind++) {
              softmax.elt_[2 * mi + ii][16 * ni + ind] =
                  encode_dropout(rand_arr[ind] <= p_scaled, softmax.elt_[2 * mi + ii][16 * ni + ind]);
            }
          }
        }
      }
      softmax.pack(frag_p);
      gmem_s.store(frag_p, mask);
      gmem_s.move();
    } else {
      softmax.pack(frag_p);
    }

    // Commit the values for Q into shared memory.
    if (l < steps - 1) {
      gmem_q.commit(gemm_q_k.smem_q);
    }

    if (Is_training) {
#pragma unroll
      for (int ki = 0; ki < Mma_tile_o::MMAS_K; ki++) {
#pragma unroll
        for (int mi = 0; mi < Mma_tile_o::MMAS_M; mi++) {
#pragma unroll
          for (int ii = 0; ii < Frag_p::NUM_REGS; ii++) {
            // "Apply" the dropout: rescale by params.scale_dropout, then the ReLU zeroes the entries
            // whose sign was flipped by encode_dropout above (assuming scale_dropout > 0).
            frag_p[ki][mi].reg(ii) = fmha::hmul2(frag_p[ki][mi].reg(ii), params.scale_dropout);
            frag_p[ki][mi].reg(ii) = fmha::hrelu2(frag_p[ki][mi].reg(ii));
          }
        }
      }
    }

    // Declare the accumulators for the 2nd gemm.
    fmha::Fragment_accumulator acc_o[Mma_tile_o::MMAS_M][Mma_tile_o::MMAS_N];
    fmha::Clear_accumulator::apply(acc_o);

    // Do this part of O = P^T * V^T.
#pragma unroll
    for (int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki) {
      fmha::gemm(acc_o, frag_p[ki], frag_v[ki]);
    }

    // Loop over MMAS_M.
#pragma unroll
    for (int ii = 0; ii < Gmem_tile_o::LOOPS; ++ii) {
      // Swizzle the elements and do the final reduction.
      smem_o.store(acc_o, ii);

      // Make sure the data is in shared memory.
      __syncthreads();

      // Load from shared memory.
      uint4 out[Gmem_tile_o::STGS_PER_LOOP];
      smem_o.load(out);

      // Make sure the data was read from shared memory.
      if (ii < Gmem_tile_o::LOOPS - 1) {
        __syncthreads();
      }

      // Output the values.
      gmem_o.store(out, ii);
    }

    // Move to the next part of the output.
    gmem_o.move();
    gemm_q_k.reload_k();

    // Reload the first Q fragments (committed to shared memory above) for the next iteration.
    if (l < steps - 1) {
      gemm_q_k.reload_q();
    }
  }  // Outer loop over the sequence length.
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Split-work ("nl") forward driver.
//  1. Every CTA first runs num_full_heads whole heads (head index it * gridDim.x + blockIdx.x).
//  2. The next main_group_size heads (starting at num_full_heads * gridDim.x) are split along the
//     sequence: CTAs [0, main_group_size * num_main_groups) each take main_steps row-blocks of one of
//     them (group = blockIdx.x % num_main_groups selects the row-block range).
//  3. The remaining CTAs sweep the trailing rest_steps row-blocks of those heads, striding over heads.
template inline __device__ void device_1xN(const Params& params, const int num_full_heads, const int num_main_groups,
                                           const int main_group_size, const int main_steps, const int rest_steps) {
  constexpr int STEPS = Kernel_traits::Cta_tile_p::N / Kernel_traits::Cta_tile_p::M;
  // One Philox subsequence per thread of the grid. Stride by blockDim.x (threads per CTA): striding by
  // gridDim.x makes threads of neighbouring CTAs share a subsequence whenever blockDim.x > gridDim.x,
  // so they would draw identical dropout bytes.
  const int tidx_global = blockIdx.x * blockDim.x + threadIdx.x;
  auto seeds = at::cuda::philox::unpack(params.philox_args);
  Philox ph(std::get<0>(seeds), tidx_global, std::get<1>(seeds));
  for (int it = 0; it < num_full_heads; it++) {
    const int bidx = it * gridDim.x + blockIdx.x;
    const int bidh = bidx % params.h;
    const int bidb = bidx / params.h;
    fmha::device_1xN_(params, bidb, bidh, 0, STEPS, ph);
    __syncthreads();
  }
  if (main_group_size == 0) return;
  const int head_offset = num_full_heads * gridDim.x;

  if (blockIdx.x < main_group_size * num_main_groups) {
    // process within heads
    const int group = blockIdx.x % num_main_groups;
    const int bidx = blockIdx.x / num_main_groups;
    const int bidh = (head_offset + bidx) % params.h;
    const int bidb = (head_offset + bidx) / params.h;
    const int offset = group * main_steps;
    fmha::device_1xN_(params, bidb, bidh, offset, main_steps, ph);
  } else {
    if (rest_steps == 0) return;
    // process across heads
    const int bidx = blockIdx.x - main_group_size * num_main_groups;
    const int offset = num_main_groups * main_steps;
    const int total_heads = params.b * params.h;
    const int rest_ctas = gridDim.x - main_group_size * num_main_groups;
    for (int it = head_offset + bidx; it < total_heads; it += rest_ctas) {
      const int bidh = it % params.h;
      const int bidb = it / params.h;
      fmha::device_1xN_(params, bidb, bidh, offset, rest_steps, ph);
      __syncthreads();
    }
  }
}
////////////////////////////////////////////////////////////////////////////////////////////////////

// Whole-head forward driver: a grid-stride loop over the flattened head index
// bidx = bidb * params.h + bidh. Each head is processed over its full sequence (STEPS row-blocks).
template inline __device__ void device_1xN(const Params& params, const int total_heads) {
  // One Philox subsequence per thread of the grid. Stride by blockDim.x (threads per CTA): striding by
  // gridDim.x makes threads of neighbouring CTAs share a subsequence whenever blockDim.x > gridDim.x,
  // so they would draw identical dropout bytes.
  const int tidx_global = blockIdx.x * blockDim.x + threadIdx.x;
  auto seeds = at::cuda::philox::unpack(params.philox_args);
  Philox ph(std::get<0>(seeds), tidx_global, std::get<1>(seeds));
  constexpr int STEPS = Kernel_traits::Cta_tile_p::N / Kernel_traits::Cta_tile_p::M;

  for (int bidx = blockIdx.x; bidx < total_heads; bidx += gridDim.x) {
    const int bidh = bidx % params.h;
    const int bidb = bidx / params.h;
    fmha::device_1xN_(params, bidb, bidh, 0, STEPS, ph);
    // The next head reuses the same dynamic shared memory.
    __syncthreads();
  }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace fmha

================================================
FILE: apex/contrib/csrc/fmha/src/fmha_kernel.h
================================================
/******************************************************************************
 * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the
 *       names of its contributors may be used to endorse or promote products
 *       derived from this software without specific prior written permission.
*
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/

#pragma once

#include
#include
#include
#include
#include
#include
#include

namespace fmha {

////////////////////////////////////////////////////////////////////////////////////////////////////

// Per-CTA bookkeeping for one (batch bidb, head bidh) pair of a packed variable-length batch:
// sequence bidb occupies tokens [cu_seqlens[bidb], cu_seqlens[bidb + 1]).
template struct BlockInfoPadded {
  template
  __device__ BlockInfoPadded(const Params& params, const int bidb, const int bidh, const int tidx)
      : bidb(bidb), bidh(bidh), h(params.h) {
    // The block index.
    sum_s = params.cu_seqlens[bidb];
    actual_seqlen = params.cu_seqlens[bidb + 1] - sum_s;
    bidx = sum_s * params.h + bidh;
    tidx_global = (bidb * params.h + bidh) * THREADS_PER_CTA + tidx;
  }

  // True when sequence bidb is empty, i.e. there is no work for this (batch, head).
  __device__ bool stop_early() const { return actual_seqlen == 0; }

  int actual_seqlen;  // Number of tokens in sequence bidb.
  int bidx;           // sum_s * h + bidh.
  int sum_s;          // Index of sequence bidb's first token in the packed batch.
  int bidh;
  int bidb;
  int tidx_global;    // (bidb * h + bidh) * THREADS_PER_CTA + tidx.
  int h;              // Number of heads (params.h).
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Splits the ceil(actual_seqlen / STEP) loop steps of a sequence into CHUNKS contiguous chunks and
// describes chunk `bidc`: it starts at step loop_offset_ and spans num_steps_ steps (0 when the chunk
// lies entirely past the end of the sequence).
template struct Noloop_traits {
  // Interpretation of Cta_tile dims, i.e. Cta_tile_p:
  enum { STEP = Cta_tile::M };
  enum { SEQLEN = Cta_tile::N };

  template inline __device__ Noloop_traits(const int bidc, const Block_info& binfo) : bidc_(bidc) {
    const int seqlen = binfo.actual_seqlen;
    const int steps = (seqlen + STEP - 1) / STEP;
    const int steps_per_chunk = (steps + CHUNKS - 1) / CHUNKS;

    const int step_begin = bidc_ * steps_per_chunk;
    const int step_end = min(steps, (bidc_ + 1) * steps_per_chunk);
    // Clamp to 0 for trailing chunks that start beyond the last step.
    const int actual_steps = max(0, step_end - step_begin);
    loop_offset_ = step_begin;
    num_steps_ = actual_steps;
  }

  // Advance every tile by loop_offset_ steps (one move() per step) so they point at this chunk.
  template inline __device__ void move_all(Tiles&... tiles) const {
    // Pack-expansion idiom: the comma expression calls move() on each tile, left to right.
    using expand_type = int[];
    for (int s = 0; s < loop_offset_; s++) {
      expand_type{(tiles.move(), 0)...};
    }
  }

  // dK/dV output indices are interleaved per chunk: 2 * bidc_ for dK, 2 * bidc_ + 1 for dV.
  inline __device__ int get_idx_dk() const {
    // return bidc_;
    return bidc_ * 2 + 0;
  }

  inline __device__ int get_idx_dv() const {
    // return CHUNKS + bidc_;
    return bidc_ * 2 + 1;
  }

  inline __device__ int offset_loop_count(const int l) {
    // convert loop counter to position in the outer sequence
    return (loop_offset_ + l) * STEP;
  }

  const uint32_t bidc_;
  int loop_offset_;
  int num_steps_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Host-side work distribution for the split ("nl") forward kernel.
//
// heads_total heads are spread over total_ctas CTAs:
//  * num_full_heads  = heads_total / total_ctas whole-head waves that every CTA runs;
//  * heads_last_wave = heads_total % total_ctas leftover heads, each split along the sequence into
//    num_main_groups chunks of main_steps row-blocks;
//  * rest_steps = the row-blocks of each leftover head not covered by the main groups.
// Returns {num_full_heads, num_main_groups, heads_last_wave, main_steps, rest_steps, elts_per_thread};
// elts_per_thread is sized for the larger of main_steps and rest_steps.
// NOTE(review): total_ctas must be > 0 (it is a divisor). Also, when main_steps * num_main_groups
// already reaches STEPS_PER_HEAD before the search loop runs, rest_steps comes out <= 0 (possibly
// negative); callers must tolerate that.
template std::tuple work_dist(const int total_ctas, const int heads_total) {
  constexpr int STEPS_PER_HEAD = Kernel_traits::Cta_tile_p::N / Kernel_traits::Cta_tile_p::M;

  const int num_full_heads = heads_total / total_ctas;
  const int heads_last_wave = heads_total % total_ctas;

  int num_main_groups = 0;
  int main_steps = 0;
  int rest_steps = 0;
  if (heads_last_wave > 0) {
    // Number of CTA groups that process within heads.
    num_main_groups = total_ctas / heads_last_wave;
    // Remaining CTAs that process between heads.
    const int rest_ctas = total_ctas - (heads_last_wave * num_main_groups);
    if (rest_ctas == 0) {
      // We have exactly "num_main_groups" CTAs to process each of the remaining heads.
      main_steps = (STEPS_PER_HEAD + num_main_groups - 1) / num_main_groups;
      num_main_groups = STEPS_PER_HEAD / main_steps;  // Here: main_step > 0
      rest_steps = STEPS_PER_HEAD % main_steps;

    } else {
      // Ideal number of steps if we could load-balance as evenly as possible.
      const int steps_ideal = (heads_last_wave * STEPS_PER_HEAD + total_ctas - 1) / total_ctas;
      // Iterations that a "rest" CTA has to do at most.
      const int max_rest_iters = (heads_last_wave + rest_ctas - 1) / rest_ctas;
      // Find the first step distribution, s.t. the maximum work of the "rest" CTAs is less than the work of the main
      // CTAs.
      main_steps = steps_ideal;
      rest_steps = STEPS_PER_HEAD - main_steps * num_main_groups;
      for (; main_steps * num_main_groups < STEPS_PER_HEAD; main_steps++) {
        rest_steps = STEPS_PER_HEAD - main_steps * num_main_groups;
        const int max_rest_total_steps = rest_steps * max_rest_iters;
        if (max_rest_total_steps < main_steps) break;
      }
      rest_steps = STEPS_PER_HEAD - main_steps * num_main_groups;
    }
  }

  using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
  using Mma_tile_p = fmha::Hmma_tile;

  // Row-blocks processed by the busiest CTA.
  const int max_steps = STEPS_PER_HEAD * num_full_heads + std::max(main_steps, rest_steps);
  const int elts_per_thread_per_step = Mma_tile_p::MMAS_M * Mma_tile_p::MMAS_N * 8;
  const int elts_per_thread = max_steps * elts_per_thread_per_step;

  return {num_full_heads, num_main_groups, heads_last_wave, main_steps, rest_steps, elts_per_thread};
}

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace fmha

================================================
FILE: apex/contrib/csrc/fmha/src/fmha_noloop_reduce.cu
================================================
/******************************************************************************
 * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
* * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright * notice, this list of conditions and the following disclaimer. * * Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. * * Neither the name of the NVIDIA CORPORATION nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
*
 ******************************************************************************/

#include "fmha.h"

// 16-byte vector load.
inline __device__ float4 ldg128(const void* ptr) { return *static_cast(ptr); }

// 16-byte vector store.
inline __device__ void stg128(void* ptr, const float4& data) { *static_cast(ptr) = data; }

// Sums CHUNKS partial dK/dV rows into the K/V part of dQKV (presumably the per-chunk partials of
// the chunked "noloop" backward pass -- confirm with the caller).
//
// One CTA per token row (blockIdx.x). For a valid row (blockIdx.x < cu_seqlens[batch_size]):
//   in : CHUNKS consecutive rows of 2 * HIDDEN_SIZE elements of T;
//   out: the matching dQKV row (3 * HIDDEN_SIZE elements of T). The sum is written starting
//        HIDDEN_SIZE elements in; the first HIDDEN_SIZE elements are not touched.
// Rows at or past cu_seqlens[batch_size] are zero-filled across all 3 * HIDDEN_SIZE elements.
// Accumulation is done in float.
template __global__ __launch_bounds__(THREADS) void fmha_noloop_reduce_kernel(void* __restrict__ out,
                                                                              const void* __restrict__ in,
                                                                              const int* __restrict__ cu_seqlens,
                                                                              const int batch_size) {
  enum { BYTES_PER_LDG = 16 };
  enum { NUM_ELTS = BYTES_PER_LDG / sizeof(T) };

  // One CTA hidden vector for K and V
  enum { BYTES_PER_ROW = HIDDEN_SIZE * sizeof(T) * 2 };
  // The stride in bytes in dQKV
  enum { OUT_STRIDE_BYTES = 3 * HIDDEN_SIZE * sizeof(T) };
  // The offset in bytes in dQKV to the dKV part for non-interleaved heads
  enum { OUT_OFFSET_KV_BYTES = HIDDEN_SIZE * sizeof(T) };

  static_assert(BYTES_PER_ROW == HIDDEN_SIZE * 2 * sizeof(T));

  // Size in bytes of the input tile
  enum { BYTES_PER_TILE = CHUNKS * BYTES_PER_ROW };

  enum { BYTES_PER_CTA = THREADS * BYTES_PER_LDG };

  enum { LDGS = BYTES_PER_ROW / BYTES_PER_CTA };
  static_assert(BYTES_PER_CTA * LDGS == BYTES_PER_ROW);

  union Vec_t {
    float4 raw;
    T elt[NUM_ELTS];
  };

  // ZERO-OUT invalid positions in dQKV
  const int total = cu_seqlens[batch_size];
  if (blockIdx.x >= total) {
    enum { BYTES_PER_QKV_ROW = 3 * HIDDEN_SIZE * sizeof(T) };
    enum { STGS = BYTES_PER_QKV_ROW / BYTES_PER_LDG };

    const float4 zeros = make_float4(0.f, 0.f, 0.f, 0.f);

    char* base_ptr = static_cast(out) + blockIdx.x * OUT_STRIDE_BYTES;

    for (int tidx = threadIdx.x; tidx < STGS; tidx += THREADS) {
      stg128(base_ptr + tidx * BYTES_PER_LDG, zeros);
    }

    return;
  }

  // SETUP
  const int offset_in = blockIdx.x * BYTES_PER_TILE + threadIdx.x * BYTES_PER_LDG;
  const char* ptr_in = static_cast(in) + offset_in;

  const int offset_out = blockIdx.x * OUT_STRIDE_BYTES + threadIdx.x * BYTES_PER_LDG;
  char* ptr_out = static_cast(out) + OUT_OFFSET_KV_BYTES + offset_out;

  // LOAD

  Vec_t local_in[CHUNKS][LDGS];

#pragma unroll
  for (int c = 0; c < CHUNKS; c++) {
#pragma unroll
    for (int l = 0; l < LDGS; l++) {
      int offset = c * BYTES_PER_ROW + l * BYTES_PER_CTA;
      local_in[c][l].raw = ldg128(ptr_in + offset);
    }
  }

  // UNPACK
  float acc[LDGS][NUM_ELTS];

#pragma unroll
  for (int l = 0; l < LDGS; l++) {
#pragma unroll
    for (int e = 0; e < NUM_ELTS; e++) {
      acc[l][e] = float(local_in[0][l].elt[e]);
    }
  }

  // COMPUTE
#pragma unroll
  for (int c = 1; c < CHUNKS; c++) {
#pragma unroll
    for (int l = 0; l < LDGS; l++) {
#pragma unroll
      for (int e = 0; e < NUM_ELTS; e++) {
        acc[l][e] += float(local_in[c][l].elt[e]);
      }
    }
  }

  // PACK
  Vec_t local_out[LDGS];

#pragma unroll
  for (int l = 0; l < LDGS; l++) {
#pragma unroll
    for (int e = 0; e < NUM_ELTS; e++) {
      local_out[l].elt[e] = T(acc[l][e]);
    }
  }

  // STORE
#pragma unroll
  for (int l = 0; l < LDGS; l++) {
    const int offset = l * BYTES_PER_CTA;
    stg128(ptr_out + offset, local_out[l].raw);
  }
}

// Host launcher for fmha_noloop_reduce_kernel, with one CTA per token row (`total` rows).
// Only hidden_size == 1024 with num_chunks in {2, 3} is instantiated. Any other configuration is
// reported on stderr and terminates the process, like FMHA_CHECK_CUDA. It is not left to assert(),
// which is compiled out under NDEBUG and would otherwise skip the reduction silently.
void fmha_run_noloop_reduce(void* out, const void* in, const int* cu_seqlens, const int hidden_size,
                            const int batch_size, const int total, const int num_chunks, cudaStream_t stream) {
  const int blocks = total;

  if (hidden_size == 1024) {
    constexpr int HIDDEN_SIZE = 1024;
    constexpr int THREADS = 256;

    if (num_chunks == 2) {
      fmha_noloop_reduce_kernel <<>>(out, in, cu_seqlens, batch_size);
    } else if (num_chunks == 3) {
      fmha_noloop_reduce_kernel <<>>(out, in, cu_seqlens, batch_size);
    } else {
      fprintf(stderr, "fmha_run_noloop_reduce: Unsupported num_chunks (%d)\n", num_chunks);
      exit(1);
    }
  } else {
    fprintf(stderr, "fmha_run_noloop_reduce: Unsupported hidden_size (%d)\n", hidden_size);
    exit(1);
  }

  FMHA_CHECK_CUDA(cudaPeekAtLastError());
}

================================================
FILE: apex/contrib/csrc/fmha/src/fmha_utils.h
================================================
/******************************************************************************
 * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
* * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright * notice, this list of conditions and the following disclaimer. * * Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. * * Neither the name of the NVIDIA CORPORATION nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
* ******************************************************************************/ #pragma once #include #include #include #include #include //////////////////////////////////////////////////////////////////////////////////////////////////// #define FMHA_CHECK_CUDA(call) \ do { \ cudaError_t status_ = call; \ if (status_ != cudaSuccess) { \ fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \ exit(1); \ } \ } while (0) //////////////////////////////////////////////////////////////////////////////////////////////////// enum Data_type { DATA_TYPE_FP16, DATA_TYPE_FP32, DATA_TYPE_INT32, DATA_TYPE_INT8 }; //////////////////////////////////////////////////////////////////////////////////////////////////// static inline void set_alpha(uint32_t& alpha, float norm, Data_type dtype) { if (dtype == DATA_TYPE_FP16) { half x = __float2half_rn(norm); uint16_t h = reinterpret_cast(x); ushort2 h2 = {h, h}; alpha = reinterpret_cast(h2); } else if (dtype == DATA_TYPE_FP32) { alpha = reinterpret_cast(norm); } else if (dtype == DATA_TYPE_INT32) { int32_t inorm = static_cast(norm); alpha = reinterpret_cast(inorm); } else { assert(false); } } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline size_t get_size_in_bytes(size_t n, Data_type dtype) { switch (dtype) { case DATA_TYPE_FP32: return n * 4; case DATA_TYPE_FP16: return n * 2; case DATA_TYPE_INT32: return n * 4; case DATA_TYPE_INT8: return n; default: assert(false); return 0; } } //////////////////////////////////////////////////////////////////////////////////////////////////// ================================================ FILE: apex/contrib/csrc/focal_loss/focal_loss_cuda.cpp ================================================ #include #include #include // CUDA forward declarations std::vector focal_loss_forward_cuda(const at::Tensor& cls_output, const at::Tensor& cls_targets_at_level, const at::Tensor& num_positives_sum, 
const int64_t num_real_classes, const float alpha, const float gamma, const float smoothing_factor); at::Tensor focal_loss_backward_cuda(const at::Tensor& grad_output, const at::Tensor& partial_grad, const at::Tensor& num_positives_sum); // C++ interface #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") #define CHECK_INPUT(x) \ CHECK_CUDA(x); \ CHECK_CONTIGUOUS(x) std::vector focal_loss_forward(const at::Tensor& cls_output, const at::Tensor& cls_targets_at_level, const at::Tensor& num_positives_sum, const int64_t num_real_classes, const float alpha, const float gamma, const float smoothing_factor) { CHECK_INPUT(cls_output); CHECK_INPUT(cls_targets_at_level); CHECK_INPUT(num_positives_sum); return focal_loss_forward_cuda(cls_output, cls_targets_at_level, num_positives_sum, num_real_classes, alpha, gamma, smoothing_factor); } at::Tensor focal_loss_backward(const at::Tensor& grad_output, const at::Tensor& partial_grad, const at::Tensor& num_positives_sum) { CHECK_INPUT(grad_output); CHECK_INPUT(partial_grad); return focal_loss_backward_cuda(grad_output, partial_grad, num_positives_sum); } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("forward", &focal_loss_forward, "Focal loss calculation forward (CUDA)", py::call_guard()); m.def("backward", &focal_loss_backward, "Focal loss calculation backward (CUDA)", py::call_guard()); } ================================================ FILE: apex/contrib/csrc/focal_loss/focal_loss_cuda_kernel.cu ================================================ #include #include #include // Use 128-bit vectorization typedef uint4 vector_t; #define ASSERT_ALIGNED(DTYPE, PTR) \ TORCH_INTERNAL_ASSERT(is_aligned(PTR), "Tensor " #PTR " is not " #DTYPE " aligned") template bool is_aligned(const void* ptr) noexcept { auto iptr = reinterpret_cast(ptr); return !(iptr % alignof(T)); } template __global__ void 
focal_loss_forward_cuda_kernel(outscalar_t* loss, scalar_t* partial_grad, const scalar_t* __restrict__ cls_output, const labelscalar_t* __restrict__ cls_targets_at_level, const float* __restrict__ num_positives_sum, const int64_t num_examples, const int64_t num_classes, const int64_t num_real_classes, const float alpha, const float gamma, const float smoothing_factor) { extern __shared__ unsigned char shm[]; accscalar_t* loss_shm = reinterpret_cast(shm); loss_shm[threadIdx.x] = 0; accscalar_t loss_acc = 0; accscalar_t one = accscalar_t(1.0); accscalar_t K = accscalar_t(2.0); accscalar_t normalizer = one / static_cast(num_positives_sum[0]); accscalar_t nn_norm, np_norm, pn_norm, pp_norm; // *_norm is used for label smoothing only if (SMOOTHING) { nn_norm = one - smoothing_factor / K; np_norm = smoothing_factor / K; pn_norm = smoothing_factor - smoothing_factor / K; pp_norm = one - smoothing_factor + smoothing_factor / K; } vector_t p_vec, grad_vec; // Accumulate loss on each thread for (int64_t i = (blockIdx.x * blockDim.x + threadIdx.x) * ILP; i < num_examples * num_classes; i += gridDim.x * blockDim.x * ILP) { int64_t idy = i / num_classes; labelscalar_t y = cls_targets_at_level[idy]; int64_t base_yid = i % num_classes; int64_t pos_idx = idy * num_classes + y; p_vec = *(vector_t*)&cls_output[i]; // Vectorized load // Skip ignored matches if (y == -2) { #pragma unroll for (int j = 0; j < ILP; j++) { *((scalar_t*)(&grad_vec) + j) = 0; } *(vector_t*)&partial_grad[i] = grad_vec; continue; } #pragma unroll for (int j = 0; j < ILP; j++) { // Skip the pad classes if (base_yid + j >= num_real_classes) { *((scalar_t*)(&grad_vec) + j) = 0; continue; } accscalar_t p = static_cast(*((scalar_t*)(&p_vec) + j)); accscalar_t exp_np = ::exp(-p); accscalar_t exp_pp = ::exp(p); accscalar_t sigma = one / (one + exp_np); accscalar_t logee = (p >= 0) ? exp_np : exp_pp; accscalar_t addee = (p >= 0) ? 
0 : -p; accscalar_t off_a = addee + ::log(one + logee); // Negative matches accscalar_t base = SMOOTHING ? nn_norm * p : p; accscalar_t off_b = (SMOOTHING ? np_norm : 0) - sigma; accscalar_t coeff_f1 = one - alpha; accscalar_t coeff_f2 = sigma; accscalar_t coeff_b1 = gamma; accscalar_t coeff_b2 = one - sigma; // Positive matches if (y >= 0 && (i + j == pos_idx)) { base = SMOOTHING ? pn_norm * p : 0; off_b = (SMOOTHING ? pp_norm : one) - sigma; coeff_f1 = alpha; coeff_f2 = one - sigma; coeff_b1 = -gamma; coeff_b2 = sigma; } accscalar_t coeff_f = coeff_f1 * ::pow(coeff_f2, gamma); accscalar_t coeff_b = coeff_b1 * coeff_b2; accscalar_t loss_t = coeff_f * (base + off_a); accscalar_t grad = coeff_f * (coeff_b * (base + off_a) - off_b); // Delay the normalize of partial gradient by num_positives_sum to back // propagation because scalar_t reduces precision. Focal loss is very // sensitive to the small gradient. No worry on overflow here since // gradient has relative smaller range than input. loss_acc += loss_t; *((scalar_t*)(&grad_vec) + j) = static_cast(grad); } // This may generate two vectorized stores instead of one *(vector_t*)&partial_grad[i] = grad_vec; } loss_shm[threadIdx.x] = loss_acc; // Intra-CTA reduction __syncthreads(); for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { if (threadIdx.x < s) { loss_shm[threadIdx.x] += loss_shm[threadIdx.x + s]; } __syncthreads(); } // Inter-CTA reduction if (threadIdx.x == 0) { loss_acc = loss_shm[0] * normalizer; atomicAdd(loss, loss_acc); } } template __global__ void focal_loss_backward_cuda_kernel(scalar_t* partial_grad, const outscalar_t* __restrict__ grad_output, const float* __restrict__ num_positives_sum, const uint64_t numel) { int64_t idx = (blockIdx.x * blockDim.x + threadIdx.x) * ILP; accscalar_t normalizer = static_cast(grad_output[0]) / static_cast(num_positives_sum[0]); // The input is enforced to pad to use vector load, thus there's no need to // check whether the last element of ILP can out of bound. 
if (idx >= numel) return; vector_t grad_vec; grad_vec = *(vector_t*)&partial_grad[idx]; #pragma unroll(ILP) for (int i = 0; i < ILP; i++) { auto grad = static_cast(*((scalar_t*)(&grad_vec) + i)); grad *= normalizer; *((scalar_t*)(&grad_vec) + i) = static_cast(grad); } *(vector_t*)&partial_grad[idx] = grad_vec; } std::vector focal_loss_forward_cuda(const at::Tensor& cls_output, const at::Tensor& cls_targets_at_level, const at::Tensor& num_positives_sum, const int64_t num_real_classes, const float alpha, const float gamma, const float smoothing_factor) { // Checks required for correctness TORCH_INTERNAL_ASSERT(cls_output.size(-1) >= num_real_classes, "Incorrect number of real classes."); TORCH_INTERNAL_ASSERT(cls_targets_at_level.scalar_type() == at::kLong, "Invalid label type."); TORCH_INTERNAL_ASSERT((num_positives_sum.numel() == 1) && (num_positives_sum.scalar_type() == at::kFloat), "Expect num_positives_sum to be a float32 tensor with only one element."); TORCH_INTERNAL_ASSERT(cls_output.dim() == cls_targets_at_level.dim() + 1, "Mis-matched dimensions between class output and label."); for (int64_t i = 0; i < cls_targets_at_level.dim(); i++) TORCH_INTERNAL_ASSERT(cls_output.size(i) == cls_targets_at_level.size(i), "Mis-matched shape between class output and label."); // Checks required for better performance const int ILP = sizeof(vector_t) / cls_output.element_size(); ASSERT_ALIGNED(vector_t, cls_output.data_ptr()); TORCH_INTERNAL_ASSERT(cls_output.size(-1) % ILP == 0, "Pad number of classes first to take advantage of vectorized load."); TORCH_INTERNAL_ASSERT(num_real_classes >= ILP, "Too few classes."); int64_t num_classes = cls_output.size(-1); int64_t num_examples = cls_output.numel() / num_classes; at::Tensor loss = at::zeros({}, cls_output.options().dtype(at::kFloat)); // Compute the incompelete gradient during fprop since most of the heavy // functions of bprop are the same as fprop, thus trade memory for compute // helps with focal loss. 
at::Tensor partial_grad = at::empty_like(cls_output); // Set the number of CTAs per SM according to the compute capability. // Each CTA loops on input with stride till the last item. cudaDeviceProp props; cudaGetDeviceProperties(&props, at::cuda::current_device()); int cta_per_sm = 2; if (props.major >= 10) { cta_per_sm = 8; } dim3 block(512); dim3 grid(cta_per_sm * props.multiProcessorCount); // Specialize on label smoothing or not to reduce redundant operations cudaStream_t stream = at::cuda::getCurrentCUDAStream(); if (smoothing_factor == 0.0f) { AT_DISPATCH_FLOATING_TYPES_AND_HALF(cls_output.scalar_type(), "focal_loss_fprop", [&] { using accscalar_t = at::acc_type; using labelscalar_t = int64_t; using outscalar_t = float; const int ILP = sizeof(vector_t) / sizeof(scalar_t); focal_loss_forward_cuda_kernel <<>>( loss.data_ptr(), partial_grad.data_ptr(), cls_output.data_ptr(), cls_targets_at_level.data_ptr(), num_positives_sum.data_ptr(), num_examples, num_classes, num_real_classes, alpha, gamma, smoothing_factor); }); } else { AT_DISPATCH_FLOATING_TYPES_AND_HALF(cls_output.scalar_type(), "focal_loss_fprop", [&] { using accscalar_t = at::acc_type; using labelscalar_t = int64_t; using outscalar_t = float; const int ILP = sizeof(vector_t) / sizeof(scalar_t); focal_loss_forward_cuda_kernel <<>>( loss.data_ptr(), partial_grad.data_ptr(), cls_output.data_ptr(), cls_targets_at_level.data_ptr(), num_positives_sum.data_ptr(), num_examples, num_classes, num_real_classes, alpha, gamma, smoothing_factor); }); } AT_CUDA_CHECK(cudaGetLastError()); return {loss, partial_grad}; } at::Tensor focal_loss_backward_cuda(const at::Tensor& grad_output, const at::Tensor& partial_grad, const at::Tensor& num_positives_sum) { // Each thread process ILP elements const int ILP = sizeof(vector_t) / partial_grad.element_size(); dim3 block(512); dim3 grid((partial_grad.numel() + block.x * ILP - 1) / (block.x * ILP)); cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 
AT_DISPATCH_FLOATING_TYPES_AND_HALF(partial_grad.scalar_type(), "focal_loss_bprop", [&] { using accscalar_t = at::acc_type; using outscalar_t = float; const int ILP = sizeof(vector_t) / sizeof(scalar_t); focal_loss_backward_cuda_kernel <<>>(partial_grad.data_ptr(), grad_output.data_ptr(), num_positives_sum.data_ptr(), partial_grad.numel()); }); AT_CUDA_CHECK(cudaGetLastError()); return partial_grad; } ================================================ FILE: apex/contrib/csrc/gpu_direct_storage/gds.cpp ================================================ // Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. #include // torch #include #include // cuda #include #include // file io #include namespace apex::contrib::gds { // POSIX template ::value, std::nullptr_t>::type = nullptr> std::string cuFileGetErrorString(T status) { status = std::abs(status); return IS_CUFILE_ERR(status) ? std::string(CUFILE_ERRSTR(status)) : std::string(std::strerror(errno)); } // CUfileError_t template ::value, std::nullptr_t>::type = nullptr> std::string cuFileGetErrorString(T status) { std::string errStr = cuFileGetErrorString(static_cast(status.err)); if (IS_CUDA_ERR(status)) errStr.append(".").append(cudaGetErrorString(static_cast(status.cu_err))); return errStr; } File::File() : is_open(false) {}; File::File(const std::string& filename, const std::string& mode) : filename(filename), mode(mode), is_open(false) { open(filename, mode); } File::~File() { if (is_open) { close(); } } void File::open(const std::string& other_filename, const std::string& other_mode) { TORCH_CHECK(is_open == false, "file", filename, "is already open"); if (!filename.empty()) { TORCH_CHECK(other_filename == filename, "file", filename, "is already open with mode", mode); } if (!mode.empty()) { TORCH_CHECK(other_mode == mode, "file", filename, "is already open with mode", mode); } maybe_register = true; // Open the binary file if (mode == "r") { // for reading fd = ::open(filename.c_str(), O_RDONLY | O_DIRECT); 
} else if (mode == "w") { // for writing fd = ::open(filename.c_str(), O_CREAT | O_WRONLY | O_DIRECT, 0664); } else if (mode == "rn") { // for reading fd = ::open(filename.c_str(), O_RDONLY); maybe_register = false; } else if (mode == "wn") { // for writing fd = ::open(filename.c_str(), O_CREAT | O_WRONLY, 0664); maybe_register = false; } else { TORCH_CHECK(false, "only r and w modes are currently supported, but got:", mode); } TORCH_CHECK(fd >= 0, "fcntl cannot open file: ", filename); // Register cuFile handle if (maybe_register) { memset((void*)&cf_descr, 0, sizeof(CUfileDescr_t)); cf_descr.handle.fd = fd; cf_descr.type = CU_FILE_HANDLE_TYPE_OPAQUE_FD; status = cuFileHandleRegister(&cf_handle, &cf_descr); if (status.err != CU_FILE_SUCCESS) { TORCH_CHECK(false, "cuFileHandleRegister failed: ", cuFileGetErrorString(status)); } } is_open = true; } void File::close() { // Deregister cuFile handle and close the file if (is_open) { if (maybe_register) { cuFileHandleDeregister(cf_handle); } ::close(fd); fd = -1; } is_open = false; } void File::load_data(const torch::Tensor& tensor) { TORCH_CHECK(mode == "r", filename, " was opened for read only"); c10::cuda::CUDAGuard gpuGuard(tensor.device()); void* dataPtr = tensor.data_ptr(); const size_t nbytes = tensor.nbytes(); // Read the binary file ssize_t ret = cuFileRead(cf_handle, (void*)dataPtr, nbytes, 0, 0); TORCH_CHECK(ret >= 0, "cuFileWrite failed: ", cuFileGetErrorString(ret)); } void File::save_data(const torch::Tensor& tensor) { TORCH_CHECK(mode == "w", filename, " was opened for write only"); c10::cuda::CUDAGuard gpuGuard(tensor.device()); void* dataPtr = tensor.data_ptr(); const size_t nbytes = tensor.nbytes(); // Register device memory status = cuFileBufRegister(dataPtr, nbytes, 0); TORCH_CHECK(status.err == CU_FILE_SUCCESS, "cuFileBufRegister failed: ", cuFileGetErrorString(status)); // Write device memory contents to the file ssize_t ret = cuFileWrite(cf_handle, dataPtr, nbytes, 0, 0); status = 
cuFileBufDeregister(dataPtr); TORCH_CHECK(ret >= 0, "cuFileWrite failed: ", cuFileGetErrorString(ret)); TORCH_CHECK(status.err == CU_FILE_SUCCESS, "cuFileBufDeregister failed:", cuFileGetErrorString(status)); } // Just for benchmarking purposes void File::load_data_no_gds(const torch::Tensor& tensor) { TORCH_CHECK(mode == "rn", filename, " was opened for read only"); c10::cuda::CUDAGuard gpuGuard(tensor.device()); void* dataPtrCPU = nullptr; void* dataPtr = tensor.data_ptr(); const size_t nbytes = tensor.nbytes(); dataPtrCPU = malloc(nbytes); TORCH_CHECK(dataPtrCPU != nullptr, "malloc failed"); const ssize_t nbytes_read = pread(fd, dataPtrCPU, nbytes, 0); TORCH_CHECK(nbytes_read == nbytes || nbytes_read == 0, "fcntl pread failed"); C10_CUDA_CHECK(cudaMemcpy(dataPtr, dataPtrCPU, nbytes, cudaMemcpyHostToDevice)); free(dataPtrCPU); } void File::save_data_no_gds(const torch::Tensor& tensor) { TORCH_CHECK(mode == "wn", filename, " was opened for write only"); c10::cuda::CUDAGuard gpuGuard(tensor.device()); void* dataPtrCPU = nullptr; void* dataPtr = tensor.data_ptr(); const size_t nbytes = tensor.nbytes(); dataPtrCPU = malloc(nbytes); TORCH_CHECK(dataPtrCPU != nullptr, "malloc failed"); C10_CUDA_CHECK(cudaMemcpy(dataPtrCPU, dataPtr, nbytes, cudaMemcpyDeviceToHost)); const ssize_t nbytes_written = pwrite(fd, dataPtrCPU, nbytes, 0); TORCH_CHECK(nbytes_written == nbytes, "fcntl pwrite failed"); free(dataPtrCPU); } } // namespace apex::contrib::gds ================================================ FILE: apex/contrib/csrc/gpu_direct_storage/gds.h ================================================ // Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
#pragma once #include #include #include namespace apex::contrib::gds { class File { public: File(); File(const std::string& filename, const std::string& mode); ~File(); void open(const std::string& filename, const std::string& mode); void close(); void load_data(const torch::Tensor& tensor); void save_data(const torch::Tensor& tensor); void load_data_no_gds(const torch::Tensor& tensor); void save_data_no_gds(const torch::Tensor& tensor); private: std::string filename; std::string mode; CUfileDescr_t cf_descr; CUfileHandle_t cf_handle; CUfileError_t status; int fd = -1; bool is_open = false; bool maybe_register = true; }; } // namespace apex::contrib::gds ================================================ FILE: apex/contrib/csrc/gpu_direct_storage/gds_pybind.cpp ================================================ // Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. #include #include #include #include // python bindings PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::class_>(m, "_GDSFile") .def(py::init<>()) .def(py::init()) .def("open", &apex::contrib::gds::File::open) .def("close", &apex::contrib::gds::File::close) .def("load_data", &apex::contrib::gds::File::load_data) .def("save_data", &apex::contrib::gds::File::save_data) .def("load_data_no_gds", &apex::contrib::gds::File::load_data_no_gds) .def("save_data_no_gds", &apex::contrib::gds::File::save_data_no_gds); } ================================================ FILE: apex/contrib/csrc/group_norm/group_norm_nhwc.cpp ================================================ /* * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause */ #include #include #include #include #include #include #include #include template float inline unpack(const T& x) { return {}; } template <> float inline unpack(const __half& x) { return __half2float(x); } template <> float inline unpack(const __nv_bfloat16& x) { return __bfloat162float(x); } template <> float inline unpack(const float& x) { return x; } //////////////////////////////////////////////////////////////////////////////////////////////////// template void check_results(const char* name, const T* out, const T* ref, size_t elts, float tol) { // The number of errors. int failed = 0; // The number of infinite value. int infs = 0; // The min/max values. float min_val = FLT_MAX, max_val = -FLT_MAX, max_err = 0.f; // The total sum of error. double sum_err = 0.0; // The case we are checking. printf("\e[1;34mchecking.....................: %s\e[0m\n", name); fflush(stdout); // Iterate over the different values. for (size_t ii = 0; ii < elts; ++ii) { float a = unpack(out[ii]); float b = unpack(ref[ii]); // Compute the absolute norms. float abs_a = fabsf(a); float abs_b = fabsf(b); // Compute the error. float den = abs_a + abs_b; // Is one of the quantities very small? bool is_small = abs_a <= tol || abs_b <= tol || den <= tol; // The error. float err = is_small ? fabsf(a - b) : fabsf(a - b) / den; // Is the result ok? bool ok = !isnan(a) && !isnan(b) && err <= tol; // Print the error. if (!ok && (failed < 10 || err > max_err)) { fprintf(stderr, ">> invalid result for ii=%lu:\n", ii); if (std::is_same::value || std::is_same::value) { // The data. 
fprintf(stderr, ">> found...: 0x%04x (%10.6f)\n", reinterpret_cast(out[ii]), a); fprintf(stderr, ">> expected: 0x%04x (%10.6f)\n", reinterpret_cast(ref[ii]), b); } else if (std::is_same::value) { fprintf(stderr, ">> found...: 0x%08x (%10.6f)\n", reinterpret_cast(a), a); fprintf(stderr, ">> expected: 0x%08x (%10.6f)\n", reinterpret_cast(b), b); } else { fprintf(stderr, "\e[1;34mUnknown type of check_results\e[0m\n"); exit(1); } fprintf(stderr, ">> error...: %.6f\n", err); } // Update the number of failures. failed += ok ? 0 : 1; // Measure min/max errors. min_val = fminf(min_val, a); max_val = fmaxf(max_val, a); max_err = fmaxf(max_err, err); // Accumulate the sum. sum_err = sum_err + (double)err; infs += !isfinite(a); infs += !isfinite(b); } if (!failed && infs < 10) { printf("\e[1;32mcheck........................: OK\e[0m\n"); } else { printf("\e[1;31mcheck........................: FAILED\e[0m\n"); } printf("tested.......................: %lu\n", elts); printf("failures.....................: %d\n", failed); printf("failure rate.................: %.2lf%%\n", (double)failed * 100.0 / (double)elts); printf("infs.........................: %d\n", infs); printf("tolerance....................: %.8f\n", tol); printf("\n"); printf("min. value...................: %.6f\n", min_val); printf("max. value...................: %.6f\n", max_val); printf("max. error...................: %.6f\n", max_err); printf("sum. error...................: %.6lf\n", sum_err); printf("avg. 
error...................: %.6lf\n", sum_err / (double)elts); printf("\n"); } template void check_results(const char* name, const __half* out, const __half* ref, size_t elts, float tol); template void check_results(const char* name, const __nv_bfloat16* out, const __nv_bfloat16* ref, size_t elts, float tol); template void check_results(const char* name, const float* out, const float* ref, size_t elts, float tol); //////////////////////////////////////////////////////////////////////////////////////////////////// static void group_norm_nhwc_bwd_(void* dx_h, float* dgamma_h, float* dbeta_h, const void* dy_h, const void* x_h, const float* gamma_h, const float* beta_h, const float2* sums_h, float epsilon, int n, int h, int w, int c, int groups, bool with_swish, bool use_fp32, bool use_bf16) { // The number of channels in each group. int channels_per_group = c / groups; // The normalization term to compute the means. float rcp_hwc_per_group = 1.f / (float)(h * w * channels_per_group); // The array to compute gamma. float* dgamma = (float*)malloc(c * sizeof(float)); // The array to compute beta. float* dbeta = (float*)malloc(c * sizeof(float)); // Set gamma/beta to 0. memset(dgamma, 0, c * sizeof(float)); memset(dbeta, 0, c * sizeof(float)); // Normalize the activations. for (int ni = 0; ni < n; ++ni) { for (int gi = 0; gi < groups; ++gi) { // The sums from the fwd pass. float2 sums = sums_h[ni * groups + gi]; // The mean of X (computed during the fwd pass -- one value per batch*group). float x_mean = sums.x; // The mean of squares of X (computed during the fwd pass -- one value per batch*group). float x_sq_mean = sums.y; // The variance. float x_var = x_sq_mean - x_mean * x_mean; // The reciprocal of the standard deviation (i.e. 1.f / sqrt(var + epsilon)). float rcp_x_stddev = x_var <= 0.f ? 1.f : 1.f / sqrtf(x_var + epsilon); // TODO: We should store rcp_x_stddev instead of the sums of squares. // The following nested loops compute 2 means. 
float mean_1 = 0.f, mean_2 = 0.f; // Iterate over the activations in the group. for (int hi = 0; hi < h; ++hi) { for (int wi = 0; wi < w; ++wi) { for (int ii = 0; ii < channels_per_group; ++ii) { // The channel. int ci = gi * channels_per_group + ii; // Compute the src/dst offset. size_t offset = (size_t)ni * h * w * c + (size_t)hi * w * c + (size_t)wi * c + (size_t)ci; // Convert the element at that position to float. float x; if (use_fp32) { x = reinterpret_cast(x_h)[offset]; } else if (use_bf16) { x = __bfloat162float(reinterpret_cast(x_h)[offset]); } else { x = __half2float(reinterpret_cast(x_h)[offset]); } // The output. float dy; if (use_fp32) { dy = reinterpret_cast(dy_h)[offset]; } else if (use_bf16) { dy = __bfloat162float(reinterpret_cast(dy_h)[offset]); } else { dy = __half2float(reinterpret_cast(dy_h)[offset]); } // Gamma. float gamma = gamma_h[ci]; // X - X_mean. float x_minus_x_mean = x - x_mean; // Normalize X. float x_norm = x_minus_x_mean * rcp_x_stddev; if (with_swish) { // Beta float beta = beta_h[ci]; float x_gn = x_norm * gamma + beta; float s = sigmoid(x_gn); dy = dy * s * (1.f + x_gn * (1.f - s)); } // Compute the gradient for beta. dbeta[ci] += dy; // Compute the gradient for gamma. dgamma[ci] += dy * x_norm; // The gradient that enters the x_norm node. float dx_norm = dy * gamma; // Accumulators over 2 means mean_1 += x_norm * dx_norm; mean_2 += dx_norm; } // ii } // wi } // hi mean_1 *= rcp_hwc_per_group; mean_2 *= rcp_hwc_per_group; // Iterate over the activations in the group. for (int hi = 0; hi < h; ++hi) { for (int wi = 0; wi < w; ++wi) { for (int ii = 0; ii < channels_per_group; ++ii) { // The channel. int ci = gi * channels_per_group + ii; // Compute the src/dst offset. 
size_t offset = (size_t)ni * h * w * c + (size_t)hi * w * c + (size_t)wi * c + (size_t)ci; float x; if (use_fp32) { x = reinterpret_cast(x_h)[offset]; } else if (use_bf16) { x = __bfloat162float(reinterpret_cast(x_h)[offset]); } else { x = __half2float(reinterpret_cast(x_h)[offset]); } // The output. float dy; if (use_fp32) { dy = reinterpret_cast(dy_h)[offset]; } else if (use_bf16) { dy = __bfloat162float(reinterpret_cast(dy_h)[offset]); } else { dy = __half2float(reinterpret_cast(dy_h)[offset]); } // Gamma. float gamma = gamma_h[ci]; // X - X_mean. float x_minus_x_mean = x - x_mean; // Normalize X. float x_norm = x_minus_x_mean * rcp_x_stddev; if (with_swish) { // Beta float beta = beta_h[ci]; float x_gn = x_norm * gamma + beta; float s = sigmoid(x_gn); dy = dy * s * (1.f + x_gn * (1.f - s)); } // The gradient that enters the x_norm node. float dx_norm = dy * gamma; // Input gradient float dx = (dx_norm - (x_norm * mean_1 + mean_2)) * rcp_x_stddev; // Set the output gradient. if (use_fp32) { reinterpret_cast(dx_h)[offset] = dx; } else if (use_bf16) { reinterpret_cast<__nv_bfloat16*>(dx_h)[offset] = __float2bfloat16_rn(dx); } else { reinterpret_cast<__half*>(dx_h)[offset] = __float2half_rn(dx); } } // ii } // wi } // hi } // gi } // ni // Store gamma/beta. for (int ci = 0; ci < c; ++ci) { dgamma_h[ci] = dgamma[ci]; dbeta_h[ci] = dbeta[ci]; } // Release temporary memory. free(dgamma); free(dbeta); } //////////////////////////////////////////////////////////////////////////////////////////////////// static void group_norm_nhwc_fwd_(void* y_h, const void* x_h, const float* gamma_h, const float* beta_h, float epsilon, int n, int h, int w, int c, int groups, bool with_swish, bool use_fp32, bool use_bf16) { // The number of channels in each group. int channels_per_group = c / groups; // The normalization term to compute the means. float inv_hwcg = 1.f / (float)(h * w * channels_per_group); // Normalize the activations. 
for (int ni = 0; ni < n; ++ni) { for (int gi = 0; gi < groups; ++gi) { // The sums to compute the mean/variance for that group. float sum = 0.f, sum_sq = 0.f; // Iterate over the activations in the group. for (int hi = 0; hi < h; ++hi) { for (int wi = 0; wi < w; ++wi) { for (int ii = 0; ii < channels_per_group; ++ii) { // The channel. int ci = gi * channels_per_group + ii; // Compute the src/dst offset. size_t offset = (size_t)ni * h * w * c + (size_t)hi * w * c + (size_t)wi * c + (size_t)ci; // Convert the element at that position to float. float x; if (use_fp32) { x = reinterpret_cast(x_h)[offset]; } else if (use_bf16) { x = __bfloat162float(reinterpret_cast(x_h)[offset]); } else { x = __half2float(reinterpret_cast(x_h)[offset]); } // Update the sums. sum += x; sum_sq += x * x; } // ii } // wi } // hi // Compute the mean. float mean = sum * inv_hwcg; // Compute the average value for the squares. float mean_sq = sum_sq * inv_hwcg; // Compute the variance. float var = mean_sq - (mean * mean); // Invert the variance. float inv_stddev = var <= 0.f ? 1.f : (1.f / sqrtf(var + epsilon)); // Iterate over the data to normalize the output. for (int hi = 0; hi < h; ++hi) { for (int wi = 0; wi < w; ++wi) { for (int ii = 0; ii < channels_per_group; ++ii) { // The channel. int ci = gi * channels_per_group + ii; // Compute the src/dst offset. size_t offset = (size_t)ni * h * w * c + (size_t)hi * w * c + (size_t)wi * c + (size_t)ci; // Normalize. float x; if (use_fp32) { x = reinterpret_cast(x_h)[offset]; } else if (use_bf16) { x = __bfloat162float(reinterpret_cast(x_h)[offset]); } else { x = __half2float(reinterpret_cast(x_h)[offset]); } float y = (x - mean) * inv_stddev; // Scale with gamma and add beta. y = y * gamma_h[ci] + beta_h[ci]; // Apply swish (if needed). if (with_swish) { y = y * sigmoid(y); } // Store the result. 
if (use_fp32) { reinterpret_cast(y_h)[offset] = y; } else if (use_bf16) { reinterpret_cast<__nv_bfloat16*>(y_h)[offset] = __float2bfloat16_rn(y); } else { reinterpret_cast<__half*>(y_h)[offset] = __float2half_rn(y); } } // ii } // wi } // hi } // gi } // ni } //////////////////////////////////////////////////////////////////////////////////////////////////// template void random_data(T* dst_h, size_t n, bool use_1s, int range = 3) { for (size_t ii = 0; ii < n; ++ii) { float x = 1.f; if (!use_1s) { x = (float)(rand() % range - (range / 2)); } if (std::is_same::value) { dst_h[ii] = __float2half_rn(x); } else if (std::is_same::value) { dst_h[ii] = x; } else if (std::is_same::value) { dst_h[ii] = __float2bfloat16_rn(x); } else { fprintf(stderr, "\e[1;34mUnknown type of random_data\e[0m\n"); exit(1); } } } template void random_data(float* dst_h, size_t n, bool use_1s, int range); template void random_data(__half* dst_h, size_t n, bool use_1s, int range); template void random_data(__nv_bfloat16* dst_h, size_t n, bool use_1s, int range); //////////////////////////////////////////////////////////////////////////////////////////////////// enum class Mode { FWD_INFERENCE, FWD_TRAINING, BWD }; //////////////////////////////////////////////////////////////////////////////////////////////////// int main(int argc, char** argv) { // The tensor size. int n = 2, h = 64, w = 64, c = 320, groups = 32; // The default mode is inference. Mode mode = Mode::FWD_INFERENCE; // The constant epsilon for sqrt(var + epsilon). float epsilon = 1.e-5f; // Do we fuse with the Swish activation function? bool with_swish = false; // Do we use the one-pass kernel? bool use_one_pass = false; // The number of runs to time the code. int runs = 1; // Do we use 1s for the input data. bool use_1s = false; // The tolerance to check the results. float tol = 1.e-3f; // Do we skip the checks? 
bool skip_checks = false; // Do we output csv format only bool csv_output = false; // Use fp32 IO bool use_fp32 = false; // Use bf16 IO bool use_bf16 = false; // Parse the parameters. for (int ii = 1; ii < argc; ++ii) { if (!strcmp(argv[ii], "-1s")) { use_1s = true; } else if (!strcmp(argv[ii], "-bwd")) { mode = Mode::BWD; } else if (!strcmp(argv[ii], "-c") && ++ii < argc) { c = strtol(argv[ii], nullptr, 10); } else if (!strcmp(argv[ii], "-epsilon") && ++ii < argc) { epsilon = (float)strtod(argv[ii], nullptr); } else if (!strcmp(argv[ii], "-fwd")) { mode = Mode::FWD_INFERENCE; } else if (!strcmp(argv[ii], "-fwd-tr")) { mode = Mode::FWD_TRAINING; } else if (!strcmp(argv[ii], "-groups") && ++ii < argc) { groups = strtol(argv[ii], nullptr, 10); } else if (!strcmp(argv[ii], "-h") && ++ii < argc) { h = strtol(argv[ii], nullptr, 10); } else if (!strcmp(argv[ii], "-n") && ++ii < argc) { n = strtol(argv[ii], nullptr, 10); } else if (!strcmp(argv[ii], "-one-pass")) { use_one_pass = true; } else if (!strcmp(argv[ii], "-runs") && ++ii < argc) { runs = strtol(argv[ii], nullptr, 10); } else if (!strcmp(argv[ii], "-skip-checks")) { skip_checks = true; } else if (!strcmp(argv[ii], "-tol") && ++ii < argc) { tol = (float)strtod(argv[ii], nullptr); } else if (!strcmp(argv[ii], "-w") && ++ii < argc) { w = strtol(argv[ii], nullptr, 10); } else if (!strcmp(argv[ii], "-with-swish")) { with_swish = true; } else if (!strcmp(argv[ii], "-csv")) { csv_output = true; } else if (!strcmp(argv[ii], "-fp32")) { use_fp32 = true; } else if (!strcmp(argv[ii], "-bf16")) { use_bf16 = true; } else if (ii < argc) { fprintf(stderr, "Unknown argument: %s\n", argv[ii]); return 1; } else { fprintf(stderr, "Argument %s requires a value\n", argv[ii - 1]); return 1; } } if (use_bf16 && use_fp32) { fprintf(stderr, "Can't use fp32 and bf16 IO at the same time\n"); return 1; } // Header. 
if (!csv_output) { printf("\n"); printf("#######################################################################\n"); printf("# Group Norm NHWC + Swish kernel\n"); printf("# --\n"); printf("# Compiled on %s\n", __DATE__); printf("#######################################################################\n"); printf("\n"); } // GPU info. cudaDeviceProp props; CHECK_CUDA(cudaGetDeviceProperties(&props, 0)); if (!csv_output) { printf("device.......................: %s\n", props.name); printf("cc...........................: %d.%d\n", props.major, props.minor); printf("# of sms.....................: %d\n", props.multiProcessorCount); } // Dram peak bandwidth. float dram_clock = props.memoryClockRate / 1.e6f; float dram_peak = 2.f * dram_clock * props.memoryBusWidth / 8.f; if (!csv_output) { printf("dram clock...................: %.3f GHz\n", dram_clock); printf("dram peak....................: %.3f TB/s\n", dram_peak * 1.e-3f); printf("\n"); } // Output the problem size. if (!csv_output) { printf("n............................: %d\n", n); printf("h............................: %d\n", h); printf("w............................: %d\n", w); printf("c............................: %d\n", c); printf("groups.......................: %d\n", groups); printf("epsilon......................: %f\n", epsilon); printf("with swish...................: %s\n", with_swish ? "true" : "false"); printf("channels per group...........: %d\n", c / groups); if (mode == Mode::BWD) { printf("mode.........................: bwd\n"); } else if (mode == Mode::FWD_INFERENCE) { printf("mode.........................: fwd inference\n"); } else if (mode == Mode::FWD_TRAINING) { printf("mode.........................: fwd training\n"); } else { assert(false); } printf("\n"); } // Compute the SOL. double bytes = 0; int32_t io_bytes = use_fp32 ? 
sizeof(float) : sizeof(__half); if (mode != Mode::BWD) { bytes = (double)n * h * w * c * io_bytes + // src (double)c * 4 + // gamma (double)c * 4 + // beta (double)n * h * w * c * io_bytes; // out } else { bytes = (double)n * h * w * c * io_bytes * 2 + // src, dsrc (double)c * 4 * 2 + // gamma, dgamma (double)c * 4 * 2 + // beta, dbeta (double)n * h * w * c * io_bytes * 1; // dout } double gbytes = bytes * 1.e-9; double dram_sol = gbytes / dram_peak * 1.e3; if (!csv_output) { printf("bytes........................: %.3lfGB\n", gbytes); printf("dram sol.....................: %.6lfms\n", dram_sol); // The number of runs to measure performance. printf("runs.........................: %d\n", runs); printf("\n"); } // The number of elements in the x tensor. The layout is N x H x W x C. size_t x_elts = (size_t)n * h * w * c; // The size of the src in bytes. size_t x_sz = x_elts * io_bytes; // Allocate the src/dst on the host. void* x_h = malloc(x_sz); void* y_h = malloc(x_sz); // Allocate src/dst on the device. void *x_d, *y_d; CHECK_CUDA(cudaMalloc((void**)&x_d, x_sz)); CHECK_CUDA(cudaMalloc((void**)&y_d, x_sz)); // The number of elements in the gamma/beta array. size_t gamma_elts = (size_t)c; // The size of the gamma/beta array in bytes. size_t gamma_sz = gamma_elts * sizeof(float); // Allocate gamma/beta on the host. float* gamma_h = (float*)malloc(gamma_sz); // Allocate gamma/beta on the device. float* gamma_d; CHECK_CUDA(cudaMalloc((void**)&gamma_d, gamma_sz)); // Allocate gamma/beta on the host. float* beta_h = (float*)malloc(gamma_sz); // Allocate gamma/beta on the device. float* beta_d; CHECK_CUDA(cudaMalloc((void**)&beta_d, gamma_sz)); // Allocate the reference on the host (to be computed on the host). void* y_ref_h = nullptr; if (!skip_checks) { y_ref_h = malloc(x_sz); } // Allocate the src/dst on the host for the gradients (bwd). 
void *dx_h = nullptr, *dy_h = nullptr; if (mode == Mode::BWD) { dx_h = malloc(x_sz); dy_h = malloc(x_sz); } // Allocate src/dst on the device. void *dx_d = nullptr, *dy_d = nullptr; if (mode == Mode::BWD) { CHECK_CUDA(cudaMalloc((void**)&dx_d, x_sz)); CHECK_CUDA(cudaMalloc((void**)&dy_d, x_sz)); } // The gradients for gamma and beta on the host. float *dgamma_h = nullptr, *dbeta_h = nullptr; if (mode == Mode::BWD) { dgamma_h = (float*)malloc(gamma_sz); dbeta_h = (float*)malloc(gamma_sz); } // The gradients for gamma and beta on the device. float *dgamma_d = nullptr, *dbeta_d = nullptr; if (mode == Mode::BWD) { CHECK_CUDA(cudaMalloc((void**)&dgamma_d, gamma_sz)); CHECK_CUDA(cudaMalloc((void**)&dbeta_d, gamma_sz)); } // The number of sums for the bwd pass. size_t sums_elts = mode == Mode::FWD_INFERENCE ? 0 : n * groups; // The size needed to store that array. size_t sums_sz = sums_elts * sizeof(float2); // The sums for the bwd pass on the host. float2* sums_h = nullptr; if (sums_sz > 0) { sums_h = (float2*)malloc(sums_sz); } // The sums for the bwd pass on the device. float2* sums_d = nullptr; if (sums_sz > 0) { CHECK_CUDA(cudaMalloc((void**)&sums_d, sums_sz)); } // Allocate the reference on the host (to be computed on the host). void* dx_ref_h = nullptr; if (mode == Mode::BWD && !skip_checks) { dx_ref_h = malloc(x_sz); } // Allocate the reference on the host (to be computed on the host). float *dgamma_ref_h = nullptr, *dbeta_ref_h = nullptr; if (mode == Mode::BWD && !skip_checks) { dgamma_ref_h = (float*)malloc(gamma_sz); dbeta_ref_h = (float*)malloc(gamma_sz); } // Generate random input data for the forward pass. 
if (use_fp32) { random_data(reinterpret_cast(x_h), x_elts, use_1s); } else if (use_bf16) { random_data<__nv_bfloat16>(reinterpret_cast<__nv_bfloat16*>(x_h), x_elts, use_1s); } else { random_data<__half>(reinterpret_cast<__half*>(x_h), x_elts, use_1s); } random_data(gamma_h, gamma_elts, use_1s); random_data(beta_h, gamma_elts, use_1s); // Generate the gradients for the bwd pass. if (mode == Mode::BWD) { if (use_fp32) { random_data(reinterpret_cast(dy_h), x_elts, use_1s); } else if (use_bf16) { random_data<__nv_bfloat16>(reinterpret_cast<__nv_bfloat16*>(dy_h), x_elts, use_1s); } else { random_data<__half>(reinterpret_cast<__half*>(dy_h), x_elts, use_1s); } } // Precompute the sums (from the fwd pass) for bwd. if (mode == Mode::BWD) { // Clear the array of sums (all the elements are set to 0.f). memset(sums_h, 0, sums_sz); // The number of channels in each group. int channels_per_group = c / groups; // Iterate over the different groups. for (int ni = 0; ni < n; ++ni) { for (int gi = 0; gi < groups; ++gi) { for (int hi = 0; hi < h; ++hi) { for (int wi = 0; wi < w; ++wi) { for (int ii = 0; ii < channels_per_group; ++ii) { // The position of the channel. int ci = gi * channels_per_group + ii; // The offset to the element. int64_t offset = (int64_t)ni * h * w * c + hi * w * c + wi * c + ci; // The element in float. float x; if (use_fp32) { x = reinterpret_cast(x_h)[offset]; } else if (use_bf16) { x = __bfloat162float(reinterpret_cast<__nv_bfloat16*>(x_h)[offset]); } else { x = __half2float(reinterpret_cast<__half*>(x_h)[offset]); } // Update the sums (sum of X and sum of squares). sums_h[ni * groups + gi].x += x; sums_h[ni * groups + gi].y += x * x; } } } } } // The normalization term to compute the means. float rcp_hwc_per_group = 1.f / (float)(h * w * channels_per_group); // Normalize the sums. for (int ngi = 0; ngi < n * groups; ++ngi) { sums_h[ngi].x *= rcp_hwc_per_group; sums_h[ngi].y *= rcp_hwc_per_group; } } // Compute the golden reference on the host. 
if (!skip_checks) { if (mode == Mode::BWD) { group_norm_nhwc_bwd_(dx_ref_h, dgamma_ref_h, dbeta_ref_h, dy_h, x_h, gamma_h, beta_h, sums_h, epsilon, n, h, w, c, groups, with_swish, use_fp32, use_bf16); } else { group_norm_nhwc_fwd_(y_ref_h, x_h, gamma_h, beta_h, epsilon, n, h, w, c, groups, with_swish, use_fp32, use_bf16); } } // Copy to the device. CHECK_CUDA(cudaMemcpyAsync(x_d, x_h, x_sz, cudaMemcpyHostToDevice, cudaStreamDefault)); CHECK_CUDA(cudaMemcpyAsync(gamma_d, gamma_h, gamma_sz, cudaMemcpyHostToDevice, cudaStreamDefault)); CHECK_CUDA(cudaMemcpyAsync(beta_d, beta_h, gamma_sz, cudaMemcpyHostToDevice, cudaStreamDefault)); if (mode == Mode::BWD) { CHECK_CUDA(cudaMemcpyAsync(dy_d, dy_h, x_sz, cudaMemcpyHostToDevice, cudaStreamDefault)); // // DEBUG. // printf("sums_h[0] = %8.3f, %8.3f\n", sums_h[0].x, sums_h[0].y); // // END OF DEBUG. CHECK_CUDA(cudaMemcpyAsync(sums_d, sums_h, sums_sz, cudaMemcpyHostToDevice, cudaStreamDefault)); } // Reset the output buffer with garbage to detect invalid results. if (mode == Mode::BWD) { CHECK_CUDA(cudaMemsetAsync(dx_d, 0xdc, x_sz, cudaStreamDefault)); CHECK_CUDA(cudaMemsetAsync(dgamma_d, 0xdc, gamma_sz, cudaStreamDefault)); CHECK_CUDA(cudaMemsetAsync(dbeta_d, 0xdc, gamma_sz, cudaStreamDefault)); } else { CHECK_CUDA(cudaMemsetAsync(y_d, 0xdc, x_sz, cudaStreamDefault)); } // Declare the parameters. Group_norm_nhwc_fwd_params params_fwd; memset(¶ms_fwd, 0, sizeof(params_fwd)); Group_norm_nhwc_bwd_params params_bwd; memset(¶ms_bwd, 0, sizeof(params_bwd)); const auto precision = [&]() -> PrecisionMode { if (use_fp32) { return PrecisionMode::FP32IOFP32W; } else if (use_bf16) { return PrecisionMode::BF16IOFP32W; } else { return PrecisionMode::FP16IOFP32W; } }(); // Initialize the parameters. 
if (mode == Mode::BWD) { params_bwd.dx = dx_d; params_bwd.dgamma = dgamma_d; params_bwd.dbeta = dbeta_d; params_bwd.sums = sums_d; params_bwd.dy = dy_d; params_bwd.x = x_d; params_bwd.gamma = gamma_d; params_bwd.beta = beta_d; params_bwd.epsilon = epsilon; params_bwd.n = n; params_bwd.h = h; params_bwd.w = w; params_bwd.c = c; params_bwd.groups = groups; params_bwd.with_swish = with_swish; params_bwd.precision = precision; } else { params_fwd.y = y_d; params_fwd.sums = sums_d; params_fwd.x = x_d; params_fwd.gamma = gamma_d; params_fwd.beta = beta_d; params_fwd.epsilon = epsilon; params_fwd.n = n; params_fwd.h = h; params_fwd.w = w; params_fwd.c = c; params_fwd.groups = groups; params_fwd.with_swish = with_swish; params_fwd.precision = precision; } // The number of barriers. size_t barriers_elts = 0; // The number of elements in the reduction buffer. size_t red_buffer_elts = 0; // The number of elements in the reduction buffer that must be zeroed. size_t zeroed_red_buffer_elts = 0; // Finalize the parameters. dim3 grid; if (mode == Mode::BWD && use_one_pass) { group_norm_nhwc_bwd_one_pass_setup(params_bwd, barriers_elts, red_buffer_elts, zeroed_red_buffer_elts, grid, props); } else if (mode == Mode::BWD) { group_norm_nhwc_bwd_two_passes_setup(params_bwd, zeroed_red_buffer_elts); } else if (use_one_pass) { group_norm_nhwc_fwd_one_pass_setup(params_fwd, barriers_elts, red_buffer_elts, grid, props); } else { group_norm_nhwc_fwd_two_passes_setup(params_fwd, zeroed_red_buffer_elts); } // The size in bytes for the reduction buffer. size_t red_buffer_sz = red_buffer_elts * sizeof(float); // Allocate on the device. if (red_buffer_sz > 0) { float** ptr = mode == Mode::BWD ? ¶ms_bwd.red_buffer : ¶ms_fwd.red_buffer; CHECK_CUDA(cudaMalloc((void**)ptr, red_buffer_sz)); } // The size of the array of barriers. size_t barriers_sz = barriers_elts * sizeof(int); // The size in bytes for the reduction buffer that must be zeroed. 
size_t zeroed_red_buffer_sz = barriers_sz + zeroed_red_buffer_elts * sizeof(float); // Allocate the buffer if needed. void* zeroed_red_buffer_d_ = nullptr; if (zeroed_red_buffer_sz > 0) { CHECK_CUDA(cudaMalloc((void**)&zeroed_red_buffer_d_, zeroed_red_buffer_sz)); } // The buffer of barriers. DO NOT CALL cudaFree on it!!! int* barriers_d = reinterpret_cast(zeroed_red_buffer_d_); // The zeroed red buffer. DO NOT CALL cudaFree on it!!! float* zeroed_red_buffer_d = reinterpret_cast(&barriers_d[barriers_elts]); // Must be aligned on 4B for floats. It obviously is (unless someone changes the code ;)). assert(reinterpret_cast(zeroed_red_buffer_d) % sizeof(float) == 0); // Set the barriers if needed. if (mode == Mode::BWD) { params_bwd.barriers = barriers_d; params_bwd.zeroed_red_buffer = zeroed_red_buffer_d; } else { params_fwd.barriers = barriers_d; params_fwd.zeroed_red_buffer = zeroed_red_buffer_d; } // Create events to time the reference code. cudaEvent_t start, stop; CHECK_CUDA(cudaEventCreate(&start)); CHECK_CUDA(cudaEventCreate(&stop)); // Time the reference code. CHECK_CUDA(cudaEventRecord(start)); for (int ii = 0; ii < runs; ++ii) { // Clear the zeroed buffer if needed. if (zeroed_red_buffer_sz > 0) { CHECK_CUDA(cudaMemsetAsync(zeroed_red_buffer_d_, 0, zeroed_red_buffer_sz, cudaStreamDefault)); } if (use_one_pass && mode == Mode::BWD) { group_norm_nhwc_bwd_one_pass_run(params_bwd, grid, cudaStreamDefault); } else if (use_one_pass) { group_norm_nhwc_fwd_one_pass_run(params_fwd, grid, cudaStreamDefault); } else if (mode == Mode::BWD) { group_norm_nhwc_bwd_two_passes_sum(params_bwd, cudaStreamDefault); group_norm_nhwc_bwd_two_passes_scale(params_bwd, cudaStreamDefault); } else { group_norm_nhwc_fwd_two_passes_sum(params_fwd, cudaStreamDefault); group_norm_nhwc_fwd_two_passes_scale(params_fwd, cudaStreamDefault); } } CHECK_CUDA(cudaEventRecord(stop)); CHECK_CUDA(cudaDeviceSynchronize()); // Print the runtime. 
float elapsed = 0.f; CHECK_CUDA(cudaEventElapsedTime(&elapsed, start, stop)); if (!csv_output) { printf("elapsed......................: %.3fms\n", elapsed); printf("elapsed per run..............: %.3fms\n", elapsed / (float)runs); printf("efficiency...................: %.3lf%%\n", dram_sol * runs / elapsed * 100.0); printf("\n"); } // Copy the results to the host. if (mode == Mode::BWD) { CHECK_CUDA(cudaMemcpyAsync(dx_h, dx_d, x_sz, cudaMemcpyDeviceToHost, cudaStreamDefault)); CHECK_CUDA(cudaMemcpyAsync(dgamma_h, dgamma_d, gamma_sz, cudaMemcpyDeviceToHost, cudaStreamDefault)); CHECK_CUDA(cudaMemcpyAsync(dbeta_h, dbeta_d, gamma_sz, cudaMemcpyDeviceToHost, cudaStreamDefault)); } else { CHECK_CUDA(cudaMemcpyAsync(y_h, y_d, x_sz, cudaMemcpyDeviceToHost, cudaStreamDefault)); } // Make sure the data has been transferred. CHECK_CUDA(cudaStreamSynchronize(cudaStreamDefault)); // Check the results. if (!csv_output) { if (mode == Mode::BWD && !skip_checks) { if (use_fp32) { check_results("dx", reinterpret_cast(dx_h), reinterpret_cast(dx_ref_h), x_elts, tol); } else if (use_bf16) { check_results<__nv_bfloat16>("dx", reinterpret_cast<__nv_bfloat16*>(dx_h), reinterpret_cast<__nv_bfloat16*>(dx_ref_h), x_elts, tol); } else { check_results<__half>("dx", reinterpret_cast<__half*>(dx_h), reinterpret_cast<__half*>(dx_ref_h), x_elts, tol); } check_results("dgamma", dgamma_h, dgamma_ref_h, gamma_elts, tol); check_results("dbeta", dbeta_h, dbeta_ref_h, gamma_elts, tol); } else if (!skip_checks) { if (use_fp32) { check_results("y", reinterpret_cast(y_h), reinterpret_cast(y_ref_h), x_elts, tol); } else if (use_bf16) { check_results<__nv_bfloat16>("y", reinterpret_cast<__nv_bfloat16*>(y_h), reinterpret_cast<__nv_bfloat16*>(y_ref_h), x_elts, tol); } else { check_results<__half>("y", reinterpret_cast<__half*>(y_h), reinterpret_cast<__half*>(y_ref_h), x_elts, tol); } } } else { printf("%d,%d,%d,%d,%d,%d,%d,%f\n", n, h, w, c, groups, (uint32_t)use_one_pass, (uint32_t)mode, elapsed / 
(float)runs); } // Destroy the cuda events. CHECK_CUDA(cudaEventDestroy(start)); CHECK_CUDA(cudaEventDestroy(stop)); // Release device memory. CHECK_CUDA(cudaFree(x_d)); CHECK_CUDA(cudaFree(y_d)); CHECK_CUDA(cudaFree(gamma_d)); CHECK_CUDA(cudaFree(beta_d)); CHECK_CUDA(cudaFree(dx_d)); CHECK_CUDA(cudaFree(dy_d)); CHECK_CUDA(cudaFree(dgamma_d)); CHECK_CUDA(cudaFree(dbeta_d)); CHECK_CUDA(cudaFree(sums_d)); CHECK_CUDA(cudaFree(zeroed_red_buffer_d_)); CHECK_CUDA(cudaFree(params_bwd.red_buffer)); CHECK_CUDA(cudaFree(params_fwd.red_buffer)); // Release host memory. free(x_h); free(y_h); free(gamma_h); free(beta_h); free(dx_h); free(dy_h); free(dgamma_h); free(dbeta_h); free(sums_h); free(y_ref_h); free(dx_ref_h); free(dgamma_ref_h); free(dbeta_ref_h); // Release the GPU. CHECK_CUDA(cudaDeviceReset()); return 0; } //////////////////////////////////////////////////////////////////////////////////////////////////// ================================================ FILE: apex/contrib/csrc/group_norm/group_norm_nhwc.h ================================================ /* * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
 * SPDX-License-Identifier: BSD-3-Clause
 */

#pragma once

// NOTE(review): the header names of the following #include directives are missing in this copy of
// the file (they appear as bare `#include`); restore them from upstream before compiling.
#include
#include
#include
#include
#include
#include
#include

////////////////////////////////////////////////////////////////////////////////////////////////////

// Evaluates a CUDA runtime call. On any status other than cudaSuccess it prints the file, the line
// and the cudaGetErrorString() message to stderr, then terminates the process with exit(1).
#define CHECK_CUDA(call)                                                                              \
  do {                                                                                                \
    cudaError_t status_ = call;                                                                       \
    if (status_ != cudaSuccess) {                                                                     \
      fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \
      exit(1);                                                                                        \
    }                                                                                                 \
  } while (0)

////////////////////////////////////////////////////////////////////////////////////////////////////

// Integer ceiling division (m + n - 1) / n, i.e. ceil(m / n) for positive m and n.
static inline __device__ __host__ int div_up(int m, int n) { return (m + n - 1) / n; }

////////////////////////////////////////////////////////////////////////////////////////////////////

// Logistic sigmoid 1 / (1 + exp(-x)), usable on host and device (used for the fused Swish).
static inline __device__ __host__ float sigmoid(float x) { return 1.f / (1.f + expf(-x)); }

////////////////////////////////////////////////////////////////////////////////////////////////////

// Barrier step on a global counter. It atomically adds `step` to *barrier with release semantics at
// GPU scope, then spins with acquire loads until *barrier reads exactly `expected`.
static inline __device__ void spin_wait_(int* barrier, int step, int expected) {
  // THE FOLLOWING CODE MUST BE EXECUTED BY A SINGLE THREAD IN THE CTA.

  // Update the global counter. Make sure prior writes are visible.
  asm volatile("red.release.gpu.global.add.s32 [%0], %1;" ::"l"(barrier), "r"(step));

  // Busy wait. We could use found = old + step with old = atomicAdd(...) but it's not faster.
  for (volatile int found = -1; found != expected;) {
    asm volatile("ld.global.acquire.gpu.b32 %0, [%1];" : "=r"(found) : "l"(barrier));
  }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Input type followed by parameter type, e.g. FP16IOFP32W = fp16 input/output with fp32 parameters
// (the test driver in this file set passes fp32 gamma/beta with each of FP32/BF16/FP16 IO).
enum PrecisionMode {
  FP32IOFP16W,
  FP32IOBF16W,
  FP32IOFP32W,
  FP16IOFP16W,
  FP16IOBF16W,
  FP16IOFP32W,
  BF16IOFP16W,
  BF16IOBF16W,
  BF16IOFP32W,
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Partial statistics for one group, combined with Group_sums_op.
struct Group_sums {
  // Is it the 1st element of the group?
  int flag;
  // The sum.
  float sum;
  // The sum of squares.
  float sum_sq;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Combines two Group_sums. When b.flag is set, b starts a new group, so its sums replace a's;
// otherwise the sums are added. The flags are accumulated so the result knows whether any
// group start was seen.
struct Group_sums_op {
  inline __device__ Group_sums operator()(const Group_sums& a, const Group_sums& b) {
    Group_sums dst;
    dst.sum = b.flag ? b.sum : (a.sum + b.sum);
    dst.sum_sq = b.flag ? b.sum_sq : (a.sum_sq + b.sum_sq);
    dst.flag = a.flag + b.flag;
    return dst;
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Parameters for the forward NHWC GroupNorm (+ optional Swish) kernels.
struct Group_norm_nhwc_fwd_params {
  // The output buffer. Layout NHWC.
  void* y;
  // The sums for the bwd pass. Not written if it is a nullptr.
  float2* sums;
  // The input buffer. Layout NHWC.
  const void* x;
  // The gamma scaling factor.
  const void* gamma;
  // The beta term to add in GN.
  const void* beta;
  // The constant epsilon for sqrt(var + epsilon).
  float epsilon;
  // The barriers for the persistent kernel.
  int* barriers;
  // The extra storage for multi-CTA reductions as well as to pass data to the bwd.
  float *red_buffer, *zeroed_red_buffer;

  // The number of instances in the batch.
  int n;
  // The height and width of each activation map. The number of channels.
  int64_t h, w, c, hw, hwc;
  // The number of groups.
  int groups;
  // Do we apply the Swish activation function?
  bool with_swish;

  // Precomputed values and parameters to control the execution of the kernels.

  // The number of batch instances per block.
  int instances_per_block;
  // The number of activations computed per block.
  int acts_per_block;
  // The number of groups in each block.
  int groups_per_block;
  // The number of channels per group = c / groups.
  int channels_per_group;
  // The number of channels per block = groups_per_block * channels_per_group.
  int channels_per_block;
  // The inverse of hwc in floats (to compute mean/var).
  float inv_hwc_per_group;
  // IO precision
  PrecisionMode precision;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// NOTE(review): the test driver passes its *zeroed* reduction-buffer element count through this
// reference; the parameter name `red_buffer_elts` may be stale — confirm against the definition.
void group_norm_nhwc_fwd_two_passes_setup(Group_norm_nhwc_fwd_params&, size_t& red_buffer_elts);

////////////////////////////////////////////////////////////////////////////////////////////////////

void group_norm_nhwc_fwd_two_passes_sum(const Group_norm_nhwc_fwd_params&, cudaStream_t);

////////////////////////////////////////////////////////////////////////////////////////////////////

void group_norm_nhwc_fwd_two_passes_scale(const Group_norm_nhwc_fwd_params&, cudaStream_t);

////////////////////////////////////////////////////////////////////////////////////////////////////

// Parameters for the backward NHWC GroupNorm (+ optional Swish) kernels.
struct Group_norm_nhwc_bwd_params {
  // The output buffer. Layout NHWC.
  void* dx;
  // The output gradient w.r.t. gamma: one value per channel (not an NHWC tensor).
  void* dgamma;
  // The output gradient w.r.t. beta: one value per channel (not an NHWC tensor).
  void* dbeta;
  // The input buffer. Layout NHWC.
  const void* dy;
  // The input buffer. Layout NHWC.
  const void* x;
  // The gamma scaling factor.
  const void* gamma;
  // The beta term to add in GN.
  const void* beta;
  // The statistics from the fwd pass, one float2 per (instance, group): .x is the mean of x and
  // .y the mean of x^2 (this is how the one-pass bwd kernel reads them).
  const float2* sums;
  // The constant epsilon for sqrt(var + epsilon).
  float epsilon;
  // The barriers for the persistent kernel.
  int* barriers;
  // The extra storage for multi-CTA reductions as well as to pass data to the bwd.
  float *red_buffer, *zeroed_red_buffer;

  // The number of instances in the batch.
  int n;
  // The height and width of each activation map. The number of channels.
  int64_t h, w, c, hw, hwc;
  // The number of groups.
  int groups;
  // Do we apply the Swish activation function?
  bool with_swish;

  // Precomputed values and parameters to control the execution of the kernels.

  // The number of batch instances per block.
  int instances_per_block;
  // The number of activations computed per block.
  int acts_per_block;
  // The number of groups in each block.
  int groups_per_block;
  // The number of channels per group = c / groups.
  int channels_per_group;
  // The number of channels per block = groups_per_block * channels_per_group.
  int channels_per_block;
  // The inverse of hwc in floats (to compute mean/var).
  float inv_hwc_per_group;
  // IO precision
  PrecisionMode precision;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// NOTE(review): as for the fwd variant, the driver passes its zeroed reduction-buffer element count
// here — confirm the parameter name against the definition.
void group_norm_nhwc_bwd_two_passes_setup(Group_norm_nhwc_bwd_params&, size_t& red_buffer_elts);

////////////////////////////////////////////////////////////////////////////////////////////////////

void group_norm_nhwc_bwd_two_passes_sum(const Group_norm_nhwc_bwd_params&, cudaStream_t);

////////////////////////////////////////////////////////////////////////////////////////////////////

void group_norm_nhwc_bwd_two_passes_scale(const Group_norm_nhwc_bwd_params&, cudaStream_t);

////////////////////////////////////////////////////////////////////////////////////////////////////

================================================
FILE: apex/contrib/csrc/group_norm/group_norm_nhwc_bwd_one_pass.h
================================================
/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 */

// NOTE(review): the header names of the first two #include directives are missing in this copy of
// the file (they appear as bare `#include`); restore them from upstream before compiling.
#include
#include

#include "group_norm_nhwc.h"
#include "macros.h"
#include "traits.h"

////////////////////////////////////////////////////////////////////////////////////////////////////
//
// B A C K W A R D
//
////////////////////////////////////////////////////////////////////////////////////////////////////

// Expands to one GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH entry per
// channels-per-group value declared below, followed by a fallback block that asserts
// "Not implemented". FUNC_POSTFIX picks the per-configuration entry point (_run or _blocks_per_sm)
// that is stored into `function`. Presumably each entry is an `if (...) { ... } else` arm defined
// in macros.h — TODO confirm.
#define GN_BWD_SELECT(FUNC_POSTFIX, function)                                        \
  GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(4, FUNC_POSTFIX, function)   \
  GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(8, FUNC_POSTFIX, function)   \
  GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(10, FUNC_POSTFIX, function)  \
  GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(12, FUNC_POSTFIX, function)  \
  GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(14, FUNC_POSTFIX, function)  \
  GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(16, FUNC_POSTFIX, function)  \
  GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(20, FUNC_POSTFIX, function)  \
  GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(26, FUNC_POSTFIX, function)  \
  GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(24, FUNC_POSTFIX, function)  \
  GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(28, FUNC_POSTFIX, function)  \
  GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(30, FUNC_POSTFIX, function)  \
  GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(32, FUNC_POSTFIX, function)  \
  GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(40, FUNC_POSTFIX, function)  \
  GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(42, FUNC_POSTFIX, function)  \
  GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(48, FUNC_POSTFIX, function)  \
  GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(56, FUNC_POSTFIX, function)  \
  GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(60, FUNC_POSTFIX, function)  \
  GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(64, FUNC_POSTFIX, function)  \
  GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(70, FUNC_POSTFIX, function)  \
  GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(80, FUNC_POSTFIX, function)  \
  GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(84, FUNC_POSTFIX, function)  \
  GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(96, FUNC_POSTFIX, function)  \
  GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(98, FUNC_POSTFIX, function)  \
  GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(112, FUNC_POSTFIX, function) \
  GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(120, FUNC_POSTFIX, function) \
  GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(128, FUNC_POSTFIX, function) \
  GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(160, FUNC_POSTFIX, function) { \
    assert(false && "Not implemented");                                              \
  }

////////////////////////////////////////////////////////////////////////////////////////////////////

// Selects the kernel launcher (_run) or the occupancy query (_blocks_per_sm) for the current params.
#define GN_BWD_RUNNER_SELECT(function) GN_BWD_SELECT(_run, function)
#define GN_BWD_BLOCKS_PER_SM_SELECT(function) GN_BWD_SELECT(_blocks_per_sm, function)

////////////////////////////////////////////////////////////////////////////////////////////////////

// One declaration per channels-per-group value handled by GN_BWD_SELECT above.
GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 4)
GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 8)
GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 10)
GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 12)
GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 14)
GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 16)
GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 20)
GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 26)
GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 24)
GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 28)
GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 30)
GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 32)
GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 40)
GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 42)
GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 48)
GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 56)
GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 60)
GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 64)
GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 70)
GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 80)
GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 84)
GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 96)
GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 98)
GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 112)
GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 120)
GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 128)
GN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 160)

////////////////////////////////////////////////////////////////////////////////////////////////////

// Prepares `params` and the launch configuration for the one-pass backward kernel.
//
// Fills params.hw, hwc, channels_per_group, inv_hwc_per_group and acts_per_block. It then derives:
//   - grid: x = blocks per activation map, y = number of (instance, group) slices processed
//     concurrently, capped by the occupancy-based maximum grid size.
//   - barriers_elts: the number of int barriers needed.
//   - red_buffer_elts: the size in floats of the cross-block reduction buffer.
//   - zeroed_red_buffer_elts: the size in floats of the buffer that must be zeroed before each launch.
// The element counts are written through the reference parameters.
void group_norm_nhwc_bwd_one_pass_setup(Group_norm_nhwc_bwd_params& params,
                                        size_t& barriers_elts,
                                        size_t& red_buffer_elts,
                                        size_t& zeroed_red_buffer_elts,
                                        dim3& grid,
                                        const cudaDeviceProp& props) {
  // The pre-computed dimensions.
  params.hw = params.h * params.w;
  params.hwc = params.c * params.hw;

  // The number of channels per group.
  params.channels_per_group = params.c / params.groups;
  // The inverse to compute the mean/variance.
  params.inv_hwc_per_group = 1.f / (float)(params.hw * params.channels_per_group);

  // Define how many activations are computed per block.
  if ((params.hw >= 1024 && params.channels_per_group >= 80) ||
      (params.hw >= 256 && params.channels_per_group >= 160)) {
    params.acts_per_block = 8 * 16;
  } else if (params.hw >= 512) {
    params.acts_per_block = 32 * 16;
  } else if (params.hw >= 256) {
    params.acts_per_block = 16 * 16;
  } else if (params.hw >= 128) {
    params.acts_per_block = 8 * 16;
  } else if (params.hw > 0) {
    params.acts_per_block = 8 * 8;
  } else {
    // We should never be here if params are set correctly.
    assert(false);
  }

  // Define the number of blocks per activation map. TODO: Make sure it matches the kernel sizes.
  int blocks_per_slice = div_up(params.hw, params.acts_per_block);

  // Select the kernel.
  using Function_t = int (*)();

  Function_t blocks_per_sm_function;
  GN_BWD_BLOCKS_PER_SM_SELECT(blocks_per_sm_function);

  // The number of blocks that can be run per SM.
  int blocks_per_sm = blocks_per_sm_function();

  // The number of blocks per grid.
  int max_blocks_per_grid = blocks_per_sm * props.multiProcessorCount;

  // Make sure we are safe to run that many blocks
  assert(blocks_per_slice <= max_blocks_per_grid);

  // The number of blocks per slice is the X dimension of the grid.
  grid.x = blocks_per_slice;
  // The Y dimension of the grid is the number of (instance, group) slices processed concurrently:
  // as many as fit next to the X blocks, capped at groups * n.
  grid.y = std::min(max_blocks_per_grid / blocks_per_slice, params.groups * params.n);

  // The number of barriers.
  barriers_elts = blocks_per_slice > 1 ? grid.y * 2 : 0;
  // Add 1 for the final conversion for dgamma/dbeta.
  barriers_elts += 1;

  // The number of elements in the reduction buffer (for the sums and sums of squared).
  if (blocks_per_slice == 1) {
    red_buffer_elts = 0;
  } else {
    // The first 2 is for double-buffering. The 2nd one is for the fact that we have two floats.
    red_buffer_elts = 2 * grid.x * grid.y * 2;
  }

  // The number of elements in the buffer that has to be zeroed (presumably one dgamma and one dbeta
  // accumulator per channel — TODO confirm against the kernel).
  zeroed_red_buffer_elts = params.c * 2;

  // Make sure a group does not span multiple blocks.
  // NOTE(review): this function never assigns params.channels_per_block, so the check only sees the
  // value the caller left in it (0 when the params were zero-filled, which makes it pass trivially).
  assert(params.channels_per_block % params.channels_per_group == 0);
}

// Launches the one-pass backward kernel selected by GN_BWD_RUNNER_SELECT for `params` on `stream`
// using the grid computed by group_norm_nhwc_bwd_one_pass_setup.
inline void group_norm_nhwc_bwd_one_pass_run(const Group_norm_nhwc_bwd_params& params,
                                             const dim3& grid,
                                             cudaStream_t stream) {
  using Function_t = void (*)(const Group_norm_nhwc_bwd_params&, const dim3&, cudaStream_t);

  Function_t runner;
  GN_BWD_RUNNER_SELECT(runner);

  runner(params, grid, stream);
}

================================================
FILE: apex/contrib/csrc/group_norm/group_norm_nhwc_bwd_one_pass_kernel.cuh
================================================
/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 */

#include
#include

#include "group_norm_nhwc.h"
#include "traits.h"

////////////////////////////////////////////////////////////////////////////////////////////////////
//
// B A C K W A R D
//
////////////////////////////////////////////////////////////////////////////////////////////////////

template __global__ __launch_bounds__(THREADS_PER_BLOCK_) void group_norm_nhwc_bwd_one_pass_kernel(
    Group_norm_nhwc_bwd_params params) {
  // The IO traits.
  using Traits = Traits_;
  // The IO traits.
  using IOTraits = typename Traits::IOTraits;
  // The Weights traits.
  using WTraits = typename Traits::WTraits;
  // The IO type
  using IOType = typename IOTraits::Type;
  // The IO doubled type
  using IOType2 = typename IOTraits::Type2;
  // Weights type
  using WType = typename WTraits::Type;
  // Weights doubled type
  using WType2 = typename WTraits::Type2;

  // The number of activations per block.
  constexpr int ACTS_PER_BLOCK = ACTS_PER_BLOCK_;
  // The number of channels per group.
  constexpr int CHANNELS_PER_GROUP = CHANNELS_PER_GROUP_;
  // The number of threads per block.
  constexpr int THREADS_PER_BLOCK = THREADS_PER_BLOCK_;
  // The number of channels per thread (load fp16x2 numbers).
  constexpr int CHANNELS_PER_THREAD = 2;

  // The number of threads needed per activation.
constexpr int THREADS_PER_ACT = CHANNELS_PER_GROUP / CHANNELS_PER_THREAD; // The number of activations that are loaded per loop. constexpr int ACTS_PER_LOOP = THREADS_PER_BLOCK / THREADS_PER_ACT; // The number of rows per thread. constexpr int ACTS_PER_THREAD = (ACTS_PER_BLOCK + ACTS_PER_LOOP - 1) / ACTS_PER_LOOP; // The number of active threads. constexpr int ACTIVE_THREADS = THREADS_PER_BLOCK / THREADS_PER_ACT * THREADS_PER_ACT; // The object in charge of doing the sums for the block. typedef cub::BlockReduce Block_reduce; // Allocate shared memory for Block_reduce. __shared__ typename Block_reduce::TempStorage temp_storage; // Allocate shared memory to store the sums. __shared__ float2 smem_sums; // Allocate shared memory to store the gamma/beta gradients. __shared__ float4 smem_dgamma_dbeta[THREADS_PER_BLOCK]; // Shared memory to store the gradients for gamma and beta. // The first activation loaded by that thread. int hwi = blockIdx.x * params.acts_per_block + threadIdx.x / THREADS_PER_ACT; // The first channel loaded by that thread. int ci = threadIdx.x % THREADS_PER_ACT * CHANNELS_PER_THREAD; // Is it an active thread? const bool is_active = threadIdx.x < ACTIVE_THREADS; // Iterate over the iterms in the batch. for (int ngi = blockIdx.y, step = 0; ngi < params.n * params.groups; ngi += gridDim.y, ++step) { // The instance and the group. TODO: Use fast divmod? int ni = ngi / params.groups; int gi = ngi % params.groups; // The sums from the fwd pass. float2 fwd = params.sums[ngi]; // The mean of X (computed during the fwd pass -- one value per batch*group). float x_mean = fwd.x; // The mean of squares of X (computed during the fwd pass -- one value per batch*group). float x_sq_mean = fwd.y; // The variance. float x_var = x_sq_mean - x_mean * x_mean; // The reciprocal of the standard deviation (i.e. 1.f / sqrt(var + epsilon)). float rcp_x_stddev = x_var <= 0.f ? 
1.f : 1.f / sqrtf(x_var + params.epsilon); // The offset to the first activation loaded by that thread. const int64_t offset = (int64_t)ni * params.hwc + gi * CHANNELS_PER_GROUP + ci; // The pointer to the first activation loaded by that thread. const IOType* x_ptr = &reinterpret_cast(params.x)[offset]; // The pointer to the first gradient loaded by that thread. const IOType* dy_ptr = &reinterpret_cast(params.dy)[offset]; // Load the X and dY into registers. IOType2 x[ACTS_PER_THREAD], dy[ACTS_PER_THREAD]; #pragma unroll for (int ii = 0; ii < ACTS_PER_THREAD; ++ii) { int hwj = hwi + ii * ACTS_PER_LOOP; x[ii] = IOTraits::zero(); dy[ii] = IOTraits::zero(); if (is_active && hwj < params.hw) { x[ii] = *reinterpret_cast(&x_ptr[hwj * params.c]); dy[ii] = *reinterpret_cast(&dy_ptr[hwj * params.c]); } } // Load gamma as well. float2 gamma_f2 = make_float2(0.f, 0.f); float2 beta_f2 = make_float2(0.f, 0.f); if (is_active) { gamma_f2 = WTraits::unpack(*reinterpret_cast( &reinterpret_cast(params.gamma)[gi * CHANNELS_PER_GROUP + ci])); if (params.with_swish) { beta_f2 = WTraits::unpack(*reinterpret_cast( &reinterpret_cast(params.beta)[gi * CHANNELS_PER_GROUP + ci])); } } // Gradients for gamma and beta (for this particular group). float4 dgamma_dbeta = make_float4(0.f, 0.f, 0.f, 0.f); // Accumulated gradients for dgrad calculation. float mean_1 = 0.f, mean_2 = 0.f; // Compute the sum and the sum of squares for each thread. #pragma unroll for (int ii = 0; ii < ACTS_PER_THREAD; ++ii) { // Convert x to float. float2 x_f2 = IOTraits::unpack(x[ii]); // Convert dY to float. float2 dy_f2 = IOTraits::unpack(dy[ii]); // X - X_mean. float x_minus_x_mean_x = x_f2.x - x_mean; float x_minus_x_mean_y = x_f2.y - x_mean; // Normalize X. 
float x_norm_x = x_minus_x_mean_x * rcp_x_stddev; float x_norm_y = x_minus_x_mean_y * rcp_x_stddev; if (params.with_swish) { float x_gn_x = x_norm_x * gamma_f2.x + beta_f2.x; float x_gn_y = x_norm_y * gamma_f2.y + beta_f2.y; float s_x = sigmoid(x_gn_x); float s_y = sigmoid(x_gn_y); dy_f2.x = dy_f2.x * s_x * (1.f + x_gn_x * (1.f - s_x)); dy_f2.y = dy_f2.y * s_y * (1.f + x_gn_y * (1.f - s_y)); } // Update beta. dgamma_dbeta.z += dy_f2.x; dgamma_dbeta.w += dy_f2.y; // Update dgamma. dgamma_dbeta.x += dy_f2.x * x_norm_x; dgamma_dbeta.y += dy_f2.y * x_norm_y; // The gradient that enters the x_norm node. float dx_norm_x = dy_f2.x * gamma_f2.x; float dx_norm_y = dy_f2.y * gamma_f2.y; // Add to the 1st mean. mean_1 += dx_norm_x * x_norm_x; mean_1 += dx_norm_y * x_norm_y; // Add to the 2nd mean. mean_2 += dx_norm_x; mean_2 += dx_norm_y; } // Pack valid gradients. float2 sums = make_float2(0.f, 0.f); if (ACTIVE_THREADS == THREADS_PER_BLOCK || is_active) { sums = make_float2(mean_1, mean_2); } // Store dgamma and dbeta to shared memory. smem_dgamma_dbeta[threadIdx.x] = dgamma_dbeta; // Compute the sums for the block. sums = Block_reduce(temp_storage).Reduce(sums, [](const float2& a, const float2& b) { return make_float2(a.x + b.x, a.y + b.y); }); // Make sure we can read gamma/beta from smemory. Block_reduce uses one syncthread already. __syncthreads(); // Compute gamma/beta for the block. if (threadIdx.x < THREADS_PER_ACT) { for (int ii = 1; ii < ACTS_PER_LOOP; ++ii) { float4 other = smem_dgamma_dbeta[threadIdx.x + ii * THREADS_PER_ACT]; dgamma_dbeta.x += other.x; dgamma_dbeta.y += other.y; dgamma_dbeta.z += other.z; dgamma_dbeta.w += other.w; } } // The position in the channel dimension - 2 channels per thread. int cj = gi * THREADS_PER_ACT + threadIdx.x; // The reduction buffer dfor gamma/dbeta. float* red_buffer_dgamma_dbeta = ¶ms.zeroed_red_buffer[cj]; // The first threads store their gradients for gamma/beta. 
if (threadIdx.x < THREADS_PER_ACT) { atomicAdd(&red_buffer_dgamma_dbeta[0 * params.c / 2], dgamma_dbeta.x); atomicAdd(&red_buffer_dgamma_dbeta[1 * params.c / 2], dgamma_dbeta.y); atomicAdd(&red_buffer_dgamma_dbeta[2 * params.c / 2], dgamma_dbeta.z); atomicAdd(&red_buffer_dgamma_dbeta[3 * params.c / 2], dgamma_dbeta.w); } // The block leader stores to global memory, if needed. if (gridDim.x > 1) { // The index of the buffer. int red_buffer_idx = step & 1; // The barrier. int* barrier = ¶ms.barriers[red_buffer_idx * gridDim.y + blockIdx.y]; // The offset to the reduction buffer. int red_buffer_offset = red_buffer_idx * gridDim.x * gridDim.y * 2; // The reduction buffer. float2* red_buffer = reinterpret_cast(¶ms.red_buffer[red_buffer_offset]); // The offset to the reduction buffer for dgamma/dbeta. // The first thread stores its sums. if (threadIdx.x == 0) { red_buffer[blockIdx.x * gridDim.y + blockIdx.y] = sums; } // Make sure the data is in memory. if (threadIdx.x == 0) { spin_wait_(barrier, (step & 2) ? -1 : 1, (step & 2) ? 0 : gridDim.x); } __syncthreads(); // Update the sums. for (int ii = 0; ii < gridDim.x; ++ii) { if (ii != blockIdx.x && threadIdx.x == 0) { float2 other_sums = red_buffer[ii * gridDim.y + blockIdx.y]; sums.x += other_sums.x; sums.y += other_sums.y; } } } // Store the result for other threads. if (threadIdx.x == 0) { smem_sums = sums; } // Make sure the sums are in shared memory. __syncthreads(); // Read the 1st mean from shared memory. mean_1 = smem_sums.x; // Read the 2nd mean from shared memory. mean_2 = smem_sums.y; mean_1 *= params.inv_hwc_per_group; mean_2 *= params.inv_hwc_per_group; // The pointer to the first activation stored by that thread. IOType* dx_ptr = &reinterpret_cast(params.dx)[offset]; // Iterate over the activations to normalize the activations and store the results. for (int ii = 0; ii < ACTS_PER_THREAD; ++ii) { // Convert x to float. float2 x_f2 = IOTraits::unpack(x[ii]); // Convert dY to float. 
float2 dy_f2 = IOTraits::unpack(dy[ii]); // X - X_mean. float2 x_minus_x_mean_f2; x_minus_x_mean_f2.x = x_f2.x - x_mean; x_minus_x_mean_f2.y = x_f2.y - x_mean; // Normalize X. float2 x_norm_f2; x_norm_f2.x = x_minus_x_mean_f2.x * rcp_x_stddev; x_norm_f2.y = x_minus_x_mean_f2.y * rcp_x_stddev; if (params.with_swish) { float x_gn_x = x_norm_f2.x * gamma_f2.x + beta_f2.x; float x_gn_y = x_norm_f2.y * gamma_f2.y + beta_f2.y; float s_x = sigmoid(x_gn_x); float s_y = sigmoid(x_gn_y); dy_f2.x = dy_f2.x * s_x * (1.f + x_gn_x * (1.f - s_x)); dy_f2.y = dy_f2.y * s_y * (1.f + x_gn_y * (1.f - s_y)); } // The gradient that enters the x_norm node. float2 dx_norm; dx_norm.x = dy_f2.x * gamma_f2.x; dx_norm.y = dy_f2.y * gamma_f2.y; // The gradient along the input path. float2 dx; dx.x = (dx_norm.x - (x_norm_f2.x * mean_1 + mean_2)) * rcp_x_stddev; dx.y = (dx_norm.y - (x_norm_f2.y * mean_1 + mean_2)) * rcp_x_stddev; // Store the scaled values. int hwj = hwi + ii * ACTS_PER_LOOP; if (is_active && hwj < params.hw) { *reinterpret_cast(&dx_ptr[hwj * params.c]) = IOTraits::pack(dx); } } } // The completion barrier. int* barrier = ¶ms.barriers[gridDim.x == 1 ? 0 : gridDim.y * 2]; // Mark the completion of the threadblock. if (threadIdx.x == 0) { asm volatile("red.release.gpu.global.add.s32 [%0], 1;" ::"l"(barrier)); } // Exit if that's not the last thread block. if (blockIdx.x != gridDim.x - 1 || blockIdx.y != gridDim.y - 1) { return; } // Busy wait. We could use found = old + step with old = atomicAdd(...) but it's not faster. if (threadIdx.x == 0) { for (int found = -1; found != gridDim.x * gridDim.y;) { asm volatile("ld.global.acquire.gpu.b32 %0, [%1];" : "=r"(found) : "l"(barrier)); } } __syncthreads(); // The last block converts dgamma and dbeta to half. for (int idx = threadIdx.x; idx < params.c / 2; idx += THREADS_PER_BLOCK) { // Load dgamma. 
float2 dgamma; dgamma.x = params.zeroed_red_buffer[idx + 0 * params.c / 2]; dgamma.y = params.zeroed_red_buffer[idx + 1 * params.c / 2]; // Load dbeta. float2 dbeta; dbeta.x = params.zeroed_red_buffer[idx + 2 * params.c / 2]; dbeta.y = params.zeroed_red_buffer[idx + 3 * params.c / 2]; // Store to global memory. *reinterpret_cast(&reinterpret_cast(params.dgamma)[idx * 2]) = WTraits::pack(dgamma); *reinterpret_cast(&reinterpret_cast(params.dbeta)[idx * 2]) = WTraits::pack(dbeta); } } ////////////////////////////////////////////////////////////////////////////////////////////////// ================================================ FILE: apex/contrib/csrc/group_norm/group_norm_nhwc_bwd_two_pass.cu ================================================ /* * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause */ #include #include #include "group_norm_nhwc.h" #include "macros.h" #include "traits.h" //////////////////////////////////////////////////////////////////////////////////////////////////// // // B A C K W A R D // //////////////////////////////////////////////////////////////////////////////////////////////////// template __global__ void group_norm_nhwc_bwd_sum_kernel(Group_norm_nhwc_bwd_params params) { // The IO traits. using Traits = Traits_; // The IO traits. using IOTraits = typename Traits::IOTraits; // The Weights traits. using WTraits = typename Traits::WTraits; // The IO type using IOType = typename IOTraits::Type; // The IO doubled type using IOType2 = typename IOTraits::Type2; // Weights type using WType = typename WTraits::Type; // Weights doubled type using WType2 = typename WTraits::Type2; // The object in charge of doing the sums for the different blocks. typedef cub::BlockScan Block_scan; // Allocate shared memory for Block_scan. __shared__ typename Block_scan::TempStorage temp_storage; // Allocate shared memory for the groups. 
We could reduce the amount of shared memory reserved. __shared__ float2 smem[THREADS_PER_BLOCK]; // The instance in the batch. int ni = blockIdx.z; // The channel loaded by that thread (2 channels per thread for F16x2). int ci = blockIdx.x * params.channels_per_block + threadIdx.x * 2; // The group that thread works on and the channel in the group (modulus). int gi = ci / params.channels_per_group; // The sums from the fwd pass. float2 fwd = params.sums[ni * params.groups + gi]; // The mean of X (computed during the fwd pass -- one value per batch*group). float x_mean = fwd.x; // The mean of squares of X (computed during the fwd pass -- one value per batch*group). float x_sq_mean = fwd.y; // The variance. float x_var = x_sq_mean - x_mean * x_mean; // The reciprocal of the standard deviation (i.e. 1.f / sqrt(var + epsilon)). float rcp_x_stddev = x_var <= 0.f ? 1.f : 1.f / sqrtf(x_var + params.epsilon); // Load gamma. float2 gamma_f2 = make_float2(0.f, 0.f); float2 beta_f2 = make_float2(0.f, 0.f); if (ci < params.c) { gamma_f2 = WTraits::unpack(*reinterpret_cast(&reinterpret_cast(params.gamma)[ci])); if (params.with_swish) { beta_f2 = WTraits::unpack(*reinterpret_cast(&reinterpret_cast(params.beta)[ci])); } } // The group that thread works on and the channel in the group (modulus). int gj = threadIdx.x * 2 / params.channels_per_group; int cj = threadIdx.x * 2 - params.channels_per_group * gj; // The first activation loaded by that block. int hw_begin = blockIdx.y * params.acts_per_block; // The last activation loaded by that block. int hw_end = min((int64_t)hw_begin + params.acts_per_block, params.hw); // The gradients for gamma/beta. float2 dgamma = make_float2(0.f, 0.f), dbeta = make_float2(0.f, 0.f); // Accumulated gradients for dgrad calculation float mean_1 = 0.f, mean_2 = 0.f; // Iterate over the activations to compute the sums. for (int hwi = hw_begin; hwi < hw_end; ++hwi) { // The offset. 
int64_t offset = (int64_t)ni * params.hwc + hwi * params.c + ci; // Fetch two channels per thread. IOType2 x_v2 = IOTraits::zero(); IOType2 dy_v2 = IOTraits::zero(); if (ci < params.c) { x_v2 = *reinterpret_cast(&reinterpret_cast(params.x)[offset]); dy_v2 = *reinterpret_cast(&reinterpret_cast(params.dy)[offset]); } // Extract the two half values. float2 x_f2 = IOTraits::unpack(x_v2); float2 dy_f2 = IOTraits::unpack(dy_v2); // X - X_mean. float x_minus_x_mean_x = x_f2.x - x_mean; float x_minus_x_mean_y = x_f2.y - x_mean; // Normalize X. float x_norm_x = x_minus_x_mean_x * rcp_x_stddev; float x_norm_y = x_minus_x_mean_y * rcp_x_stddev; if (params.with_swish) { float x_gn_x = x_norm_x * gamma_f2.x + beta_f2.x; float x_gn_y = x_norm_y * gamma_f2.y + beta_f2.y; float s_x = sigmoid(x_gn_x); float s_y = sigmoid(x_gn_y); dy_f2.x = dy_f2.x * s_x * (1.f + x_gn_x * (1.f - s_x)); dy_f2.y = dy_f2.y * s_y * (1.f + x_gn_y * (1.f - s_y)); } // Update beta. dbeta.x += dy_f2.x; dbeta.y += dy_f2.y; // Update dgamma. dgamma.x += dy_f2.x * x_norm_x; dgamma.y += dy_f2.y * x_norm_y; // The gradient that enters the x_norm node. float dx_norm_x = dy_f2.x * gamma_f2.x; float dx_norm_y = dy_f2.y * gamma_f2.y; // Add to the 1st mean. mean_1 += dx_norm_x * x_norm_x; mean_1 += dx_norm_y * x_norm_y; // Add to the 2nd mean. mean_2 += dx_norm_x; mean_2 += dx_norm_y; } // The data for the summations. Group_sums inp{cj == 0 ? 1 : 0, mean_1, mean_2}; // Do the segmented scan. Group_sums out; Block_scan(temp_storage).InclusiveScan(inp, out, Group_sums_op()); // Store the results for the groups in shared memory (to produce coalesced stores later). if (cj == params.channels_per_group - 2 /* 2 channels per thread */) { smem[gj] = make_float2(out.sum, out.sum_sq); } // Make sure the data is in shared memory. __syncthreads(); // The global group index. int gk = blockIdx.x * params.groups_per_block + threadIdx.x; // The first threads (those storing to global memory, load the values). 
float2 sums = smem[threadIdx.x]; // Store to global memory. if (threadIdx.x < params.groups_per_block && gk < params.groups) { atomicAdd(¶ms.zeroed_red_buffer[(2 * ni + 0) * params.groups + gk], sums.x); atomicAdd(¶ms.zeroed_red_buffer[(2 * ni + 1) * params.groups + gk], sums.y); } // The base pointer for the gradients for gamma and beta. float* dgamma_beta_ptr = ¶ms.zeroed_red_buffer[params.n * params.groups * 2]; // The 1st channel in the output tensor. NOTE: Two channels per thread store interleaved. int ck = blockIdx.x * params.channels_per_block + threadIdx.x; // Store dgamma and dbeta as well. if (ck < params.c) { atomicAdd(&dgamma_beta_ptr[0 * params.c + 0 * blockDim.x + ck], dgamma.x); atomicAdd(&dgamma_beta_ptr[0 * params.c + 1 * blockDim.x + ck], dgamma.y); atomicAdd(&dgamma_beta_ptr[1 * params.c + 0 * blockDim.x + ck], dbeta.x); atomicAdd(&dgamma_beta_ptr[1 * params.c + 1 * blockDim.x + ck], dbeta.y); } } //////////////////////////////////////////////////////////////////////////////////////////////////// void group_norm_nhwc_bwd_two_passes_setup(Group_norm_nhwc_bwd_params& params, size_t& zeroed_red_buffer_elts) { // The pre-computed dimensions. params.hw = params.h * params.w; params.hwc = params.c * params.hw; // The number of channels per group. params.channels_per_group = params.c / params.groups; // The inverse to compute the mean/variance. params.inv_hwc_per_group = 1.f / (float)(params.hw * params.channels_per_group); // Define the number of blocks per activation map. That's a simple heuristic. int blocks_per_act_slice = 0; if (params.c >= 1280) { blocks_per_act_slice = 128 / params.n; } else if (params.c >= 640) { blocks_per_act_slice = 256 / params.n; } else { blocks_per_act_slice = 512 / params.n; } // Clamp to at least 1 to avoid divide-by-zero when batch size is large. 
blocks_per_act_slice = max(blocks_per_act_slice, 1); // Make sure we launch blocks per activation is no less than activations blocks_per_act_slice = min(blocks_per_act_slice, div_up(params.hw, params.n)); // Define how many activations are computed per block. params.acts_per_block = div_up(params.hw, blocks_per_act_slice); // The number of channels per block. params.channels_per_block = 320; // Special case to deal with 30 channels per group. if (params.channels_per_block % params.channels_per_group != 0) { params.channels_per_block = 240; } // Special case to deal with 70 channels per group. if (params.c == 2240) { params.channels_per_block = 280; } else if (params.c == 832) { params.channels_per_block = 208; } if (params.c % params.channels_per_block != 0) { if (params.c % 512 == 0 && params.c != 1536 && params.c != 3072 && params.c % 448 != 0) { params.channels_per_block = 512; } else if (params.c % 42 == 0) { params.channels_per_block = 336; } else if (params.c % 384 == 0) { params.channels_per_block = 384; } else if (params.c % 256 == 0 && params.c % 448 != 0 && params.c % 392 != 0) { params.channels_per_block = 256; } else if (params.c % 128 == 0 && params.c % 448 != 0 && params.c % 392 != 0) { params.channels_per_block = 128; } else if (params.c % 448 == 0 && params.c % 392 != 0) { params.channels_per_block = 448; } else if (params.c % 392 == 0) { params.channels_per_block = 392; } } // The number of groups per block. params.groups_per_block = params.channels_per_block / params.channels_per_group; // Make sure the number of channels is a multiple of the number of channels per block. assert(params.c % params.channels_per_block == 0); // Make sure a group does not span multiple blocks. assert(params.channels_per_block % params.channels_per_group == 0); // The number of elements in the reduction buffer (for the sums and sums of squared). 
zeroed_red_buffer_elts = params.n * params.groups * 2 + params.c * 2; } //////////////////////////////////////////////////////////////////////////////////////////////////// void group_norm_nhwc_bwd_two_passes_sum(const Group_norm_nhwc_bwd_params& params, cudaStream_t stream) { // The dimension of the grid. dim3 grid; // The number of blocks to compute all the channels. grid.x = params.c / params.channels_per_block; // The number of blocks to compute all the activations in a given instance. grid.y = div_up(params.hw, params.acts_per_block); // The number of instances. grid.z = params.n; if (params.precision == PrecisionMode::FP16IOFP16W) { CALL_TWO_PASS_KERNEL(group_norm_nhwc_bwd_sum_kernel, Fp16IOFp16W) } else if (params.precision == PrecisionMode::FP16IOBF16W) { CALL_TWO_PASS_KERNEL(group_norm_nhwc_bwd_sum_kernel, Fp16IOBf16W) } else if (params.precision == PrecisionMode::FP16IOFP32W) { CALL_TWO_PASS_KERNEL(group_norm_nhwc_bwd_sum_kernel, Fp16IOFp32W) } else if (params.precision == PrecisionMode::BF16IOFP16W) { CALL_TWO_PASS_KERNEL(group_norm_nhwc_bwd_sum_kernel, Bf16IOFp16W) } else if (params.precision == PrecisionMode::BF16IOBF16W) { CALL_TWO_PASS_KERNEL(group_norm_nhwc_bwd_sum_kernel, Bf16IOBf16W) } else if (params.precision == PrecisionMode::BF16IOFP32W) { CALL_TWO_PASS_KERNEL(group_norm_nhwc_bwd_sum_kernel, Bf16IOFp32W) } else if (params.precision == PrecisionMode::FP32IOFP16W) { CALL_TWO_PASS_KERNEL(group_norm_nhwc_bwd_sum_kernel, Fp32IOFp16W) } else if (params.precision == PrecisionMode::FP32IOBF16W) { CALL_TWO_PASS_KERNEL(group_norm_nhwc_bwd_sum_kernel, Fp32IOBf16W) } else { CALL_TWO_PASS_KERNEL(group_norm_nhwc_bwd_sum_kernel, Fp32IOFp32W) } // Make sure it launched ok. CHECK_CUDA(cudaGetLastError()); } //////////////////////////////////////////////////////////////////////////////////////////////////// template __global__ void group_norm_nhwc_bwd_scale_kernel(Group_norm_nhwc_bwd_params params) { // The IO traits. using Traits = Traits_; // The IO traits. 
using IOTraits = typename Traits::IOTraits; // The Weights traits. using WTraits = typename Traits::WTraits; // The IO type using IOType = typename IOTraits::Type; // The IO doubled type using IOType2 = typename IOTraits::Type2; // Weights type using WType = typename WTraits::Type; // Weights doubled type using WType2 = typename WTraits::Type2; // The instance in the batch. int ni = blockIdx.z; // The channel loaded by that thread (2 channels per thread for F16x2). int ci = blockIdx.x * params.channels_per_block + threadIdx.x * 2; // The group that thread works on and the channel in the group (modulus). int gi = ci / params.channels_per_group; // Load the gradients for the group. float mean_1 = 0.f, mean_2 = 0.f; if (gi < params.groups) { mean_1 = params.zeroed_red_buffer[(2 * ni + 0) * params.groups + gi]; mean_2 = params.zeroed_red_buffer[(2 * ni + 1) * params.groups + gi]; } // The sums from the fwd pass. float2 fwd = params.sums[ni * params.groups + gi]; // The mean of X (computed during the fwd pass -- one value per batch*group). float x_mean = fwd.x; // The mean of squares of X (computed during the fwd pass -- one value per batch*group). float x_sq_mean = fwd.y; // The variance. float x_var = x_sq_mean - x_mean * x_mean; // The reciprocal of the standard deviation (i.e. 1.f / sqrt(var + epsilon)). float rcp_x_stddev = x_var <= 0.f ? 1.f : 1.f / sqrtf(x_var + params.epsilon); // Mutiply by 1/(HWC) to get real mean mean_1 *= params.inv_hwc_per_group; mean_2 *= params.inv_hwc_per_group; // Load gamma. float2 gamma_f2 = make_float2(0.f, 0.f); float2 beta_f2 = make_float2(0.f, 0.f); if (ci < params.c) { gamma_f2 = WTraits::unpack(*reinterpret_cast(&reinterpret_cast(params.gamma)[ci])); if (params.with_swish) { beta_f2 = WTraits::unpack(*reinterpret_cast(&reinterpret_cast(params.beta)[ci])); } } // The first activation loaded by that block. int hw_begin = blockIdx.y * params.acts_per_block; // The last activation loaded by that block. 
int hw_end = min((int64_t)hw_begin + params.acts_per_block, params.hw); // Iterate over the activations to compute the sums. for (int hwi = hw_begin; hwi < hw_end; ++hwi) { // The src/dst offset. int64_t offset = (int64_t)ni * params.hwc + hwi * params.c + ci; // Fetch two channels per thread. IOType2 x_v2 = IOTraits::zero(); IOType2 dy_v2 = IOTraits::zero(); if (ci < params.c) { x_v2 = *reinterpret_cast(&reinterpret_cast(params.x)[offset]); dy_v2 = *reinterpret_cast(&reinterpret_cast(params.dy)[offset]); } // Extract the two half values. float2 x_f2 = IOTraits::unpack(x_v2); float2 dy_f2 = IOTraits::unpack(dy_v2); // X - X_mean. float2 x_minus_x_mean_f2; x_minus_x_mean_f2.x = x_f2.x - x_mean; x_minus_x_mean_f2.y = x_f2.y - x_mean; // Normalize X. float2 x_norm_f2; x_norm_f2.x = x_minus_x_mean_f2.x * rcp_x_stddev; x_norm_f2.y = x_minus_x_mean_f2.y * rcp_x_stddev; if (params.with_swish) { float x_gn_x = x_norm_f2.x * gamma_f2.x + beta_f2.x; float x_gn_y = x_norm_f2.y * gamma_f2.y + beta_f2.y; float s_x = sigmoid(x_gn_x); float s_y = sigmoid(x_gn_y); dy_f2.x = dy_f2.x * s_x * (1.f + x_gn_x * (1.f - s_x)); dy_f2.y = dy_f2.y * s_y * (1.f + x_gn_y * (1.f - s_y)); } // The gradient that enters the x_norm node. float2 dx_norm; dx_norm.x = dy_f2.x * gamma_f2.x; dx_norm.y = dy_f2.y * gamma_f2.y; // The gradient along the input path. float2 dx; dx.x = (dx_norm.x - (x_norm_f2.x * mean_1 + mean_2)) * rcp_x_stddev; dx.y = (dx_norm.y - (x_norm_f2.y * mean_1 + mean_2)) * rcp_x_stddev; // Store the scaled values. if (ci < params.c) { *reinterpret_cast(&reinterpret_cast(params.dx)[offset]) = IOTraits::pack(dx); } } // Load gamma/beta and convert to half. if (blockIdx.y > 0 || blockIdx.z > 0 || ci >= params.c) { return; } // The base pointer for the gradients for gamma and beta. float* dgamma_beta_ptr = ¶ms.zeroed_red_buffer[params.n * params.groups * 2]; // The 1st channel in the output tensor. NOTE: Two channels per thread store interleaved. 
int ck = blockIdx.x * params.channels_per_block + threadIdx.x; // Load the FP32 version of dgamma and dbeta. float2 dgamma, dbeta; if (ck < params.c) { dgamma.x = dgamma_beta_ptr[0 * params.c + 0 * blockDim.x + ck]; dgamma.y = dgamma_beta_ptr[0 * params.c + 1 * blockDim.x + ck]; dbeta.x = dgamma_beta_ptr[1 * params.c + 0 * blockDim.x + ck]; dbeta.y = dgamma_beta_ptr[1 * params.c + 1 * blockDim.x + ck]; // Convert to half2 and store to memory. *reinterpret_cast(&reinterpret_cast(params.dgamma)[ci]) = WTraits::pack(dgamma); *reinterpret_cast(&reinterpret_cast(params.dbeta)[ci]) = WTraits::pack(dbeta); } } //////////////////////////////////////////////////////////////////////////////////////////////////// void group_norm_nhwc_bwd_two_passes_scale(const Group_norm_nhwc_bwd_params& params, cudaStream_t stream) { // The dimension of the grid. dim3 grid; // The number of blocks to compute all the channels. grid.x = params.c / params.channels_per_block; // The number of blocks to compute all the activations in a given instance. grid.y = div_up(params.hw, params.acts_per_block); // The number of instances. 
grid.z = params.n; if (params.precision == PrecisionMode::FP16IOFP16W) { CALL_TWO_PASS_KERNEL(group_norm_nhwc_bwd_scale_kernel, Fp16IOFp16W) } else if (params.precision == PrecisionMode::FP16IOBF16W) { CALL_TWO_PASS_KERNEL(group_norm_nhwc_bwd_scale_kernel, Fp16IOBf16W) } else if (params.precision == PrecisionMode::FP16IOFP32W) { CALL_TWO_PASS_KERNEL(group_norm_nhwc_bwd_scale_kernel, Fp16IOFp32W) } else if (params.precision == PrecisionMode::BF16IOFP16W) { CALL_TWO_PASS_KERNEL(group_norm_nhwc_bwd_scale_kernel, Bf16IOFp16W) } else if (params.precision == PrecisionMode::BF16IOBF16W) { CALL_TWO_PASS_KERNEL(group_norm_nhwc_bwd_scale_kernel, Bf16IOBf16W) } else if (params.precision == PrecisionMode::BF16IOFP32W) { CALL_TWO_PASS_KERNEL(group_norm_nhwc_bwd_scale_kernel, Bf16IOFp32W) } else if (params.precision == PrecisionMode::FP32IOFP16W) { CALL_TWO_PASS_KERNEL(group_norm_nhwc_bwd_scale_kernel, Fp32IOFp16W) } else if (params.precision == PrecisionMode::FP32IOBF16W) { CALL_TWO_PASS_KERNEL(group_norm_nhwc_bwd_scale_kernel, Fp32IOBf16W) } else { CALL_TWO_PASS_KERNEL(group_norm_nhwc_bwd_scale_kernel, Fp32IOFp32W) } // Make sure it launched ok. CHECK_CUDA(cudaGetLastError()); } //////////////////////////////////////////////////////////////////////////////////////////////////// ================================================ FILE: apex/contrib/csrc/group_norm/group_norm_nhwc_fwd_one_pass.h ================================================ /* * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause */ #include #include #include "group_norm_nhwc.h" #include "macros.h" #include "traits.h" //////////////////////////////////////////////////////////////////////////////////////////////////// // // F O R W A R D // //////////////////////////////////////////////////////////////////////////////////////////////////// #define GN_FWD_SELECT(FUNC_POSTFIX, function) \ GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(4, FUNC_POSTFIX, function) \ GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(8, FUNC_POSTFIX, function) \ GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(10, FUNC_POSTFIX, function) \ GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(12, FUNC_POSTFIX, function) \ GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(14, FUNC_POSTFIX, function) \ GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(16, FUNC_POSTFIX, function) \ GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(20, FUNC_POSTFIX, function) \ GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(26, FUNC_POSTFIX, function) \ GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(24, FUNC_POSTFIX, function) \ GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(28, FUNC_POSTFIX, function) \ GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(30, FUNC_POSTFIX, function) \ GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(32, FUNC_POSTFIX, function) \ GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(40, FUNC_POSTFIX, function) \ GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(42, FUNC_POSTFIX, function) \ GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(48, FUNC_POSTFIX, function) \ GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(56, FUNC_POSTFIX, function) \ GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(60, FUNC_POSTFIX, function) \ 
GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(64, FUNC_POSTFIX, function) \ GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(70, FUNC_POSTFIX, function) \ GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(80, FUNC_POSTFIX, function) \ GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(84, FUNC_POSTFIX, function) \ GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(96, FUNC_POSTFIX, function) \ GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(98, FUNC_POSTFIX, function) \ GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(112, FUNC_POSTFIX, function) \ GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(120, FUNC_POSTFIX, function) \ GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(128, FUNC_POSTFIX, function) \ GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(160, FUNC_POSTFIX, function) { \ assert(false && "Not implemented"); \ } //////////////////////////////////////////////////////////////////////////////////////////////////// #define GN_FWD_RUNNER_SELECT(function) GN_FWD_SELECT(_run, function) #define GN_FWD_BLOCKS_PER_SM_SELECT(function) GN_FWD_SELECT(_blocks_per_sm, function) //////////////////////////////////////////////////////////////////////////////////////////////////// GN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 4) GN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 8) GN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 10) GN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 12) GN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 14) GN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 16) GN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 20) GN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 26) GN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 24) GN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 28) GN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 30) GN_FWD_ONE_PASS_DECLARATION(/* 
CHANNELS_PER_GROUP */ 32) GN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 40) GN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 42) GN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 48) GN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 56) GN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 60) GN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 64) GN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 70) GN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 80) GN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 84) GN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 96) GN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 98) GN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 112) GN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 120) GN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 128) GN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 160) //////////////////////////////////////////////////////////////////////////////////////////////////// inline void group_norm_nhwc_fwd_one_pass_setup(Group_norm_nhwc_fwd_params& params, size_t& barriers_elts, size_t& red_buffer_elts, dim3& grid, const cudaDeviceProp& props) { // The pre-computed dimensions. params.hw = params.h * params.w; params.hwc = params.c * params.hw; // The number of channels per group. params.channels_per_group = params.c / params.groups; // The inverse to compute the mean/variance. params.inv_hwc_per_group = 1.f / (float)(params.hw * params.channels_per_group); // Select the kernel. using Function_t = int (*)(); Function_t blocks_per_sm_function; GN_FWD_BLOCKS_PER_SM_SELECT(blocks_per_sm_function); // Define how many activations are computed per block. 
if (params.hw >= 1024 && params.channels_per_group >= 80 || (params.hw >= 256 && params.channels_per_group >= 160)) { params.acts_per_block = 8 * 16; } else if (params.hw >= 512) { params.acts_per_block = 16 * 32; } else if (params.hw >= 256) { params.acts_per_block = 16 * 16; } else if (params.hw >= 128) { params.acts_per_block = 8 * 16; } else if (params.hw > 0) { params.acts_per_block = 8 * 8; } else { // We should never be here if params are set correctly. assert(false); } // Define the number of blocks per activation map. TODO: Make sure it matches the kernel sizes. int blocks_per_slice = div_up(params.hw, params.acts_per_block); // The number of blocks that can be run per SM. int blocks_per_sm = blocks_per_sm_function(); // The number of blocks per grid. int max_blocks_per_grid = blocks_per_sm * props.multiProcessorCount; // Make sure we are safe to run that many blocks assert(blocks_per_slice <= max_blocks_per_grid); // The number of blocks per slice is the X dimension of the grid. grid.x = blocks_per_slice; // The number of groups * is the X dimension of the grid. grid.y = std::min(max_blocks_per_grid / blocks_per_slice, params.groups * params.n); // The number of barriers. barriers_elts = blocks_per_slice > 1 ? grid.y * 2 : 0; // The number of elements in the reduction buffer (for the sums and sums of squared). if (blocks_per_slice == 1) { red_buffer_elts = 0; } else { // The first 2 is for double-buffering. The 2nd one is for the fact that we have two floats. 
red_buffer_elts = 2 * grid.x * grid.y * 2; } } inline void group_norm_nhwc_fwd_one_pass_run(const Group_norm_nhwc_fwd_params& params, const dim3& grid, cudaStream_t stream) { using Function_t = void (*)(const Group_norm_nhwc_fwd_params&, const dim3&, cudaStream_t); Function_t runner; GN_FWD_RUNNER_SELECT(runner); runner(params, grid, stream); } ================================================ FILE: apex/contrib/csrc/group_norm/group_norm_nhwc_fwd_one_pass_kernel.cuh ================================================ /* * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause */ #include #include #include "group_norm_nhwc.h" #include "traits.h" //////////////////////////////////////////////////////////////////////////////////////////////////// // // F O R W A R D // //////////////////////////////////////////////////////////////////////////////////////////////////// template __global__ __launch_bounds__(THREADS_PER_BLOCK_) void group_norm_nhwc_fwd_one_pass_kernel( Group_norm_nhwc_fwd_params params) { // The traits. using Traits = Traits_; // The IO traits. using IOTraits = typename Traits::IOTraits; // The Weights traits. using WTraits = typename Traits::WTraits; // The IO type using IOType = typename IOTraits::Type; // The IO doubled type using IOType2 = typename IOTraits::Type2; // Weights type using WType = typename WTraits::Type; // Weights doubled type using WType2 = typename WTraits::Type2; // The number of activations per block. constexpr int ACTS_PER_BLOCK = ACTS_PER_BLOCK_; // The number of channels per group. constexpr int CHANNELS_PER_GROUP = CHANNELS_PER_GROUP_; // The number of threads per block. constexpr int THREADS_PER_BLOCK = THREADS_PER_BLOCK_; // The number of channels per thread (load fp16x2 numbers). constexpr int CHANNELS_PER_THREAD = 2; // The number of threads needed per activation. 
constexpr int THREADS_PER_ACT = CHANNELS_PER_GROUP / CHANNELS_PER_THREAD; // The number of activations that are loaded per loop. constexpr int ACTS_PER_LOOP = THREADS_PER_BLOCK / THREADS_PER_ACT; // The number of rows per thread. constexpr int ACTS_PER_THREAD = (ACTS_PER_BLOCK + ACTS_PER_LOOP - 1) / ACTS_PER_LOOP; // The number of active threads. constexpr int ACTIVE_THREADS = THREADS_PER_BLOCK / THREADS_PER_ACT * THREADS_PER_ACT; // The object in charge of doing the sums for the block. typedef cub::BlockReduce Block_reduce; // Allocate shared memory for Block_reduce. __shared__ typename Block_reduce::TempStorage temp_storage; // Allocate shared memory to store the sums. __shared__ float2 smem_sums; // The first activation loaded by that thread. int hwi = blockIdx.x * params.acts_per_block + threadIdx.x / THREADS_PER_ACT; // The first channel loaded by that thread. int ci = threadIdx.x % THREADS_PER_ACT * CHANNELS_PER_THREAD; // Is it an active thread? const bool is_active = threadIdx.x < ACTIVE_THREADS; // Iterate over the iterms in the batch. for (int ngi = blockIdx.y, step = 0; ngi < params.n * params.groups; ngi += gridDim.y, ++step) { // The instance and the group. TODO: Use fast divmod? int ni = ngi / params.groups; int gi = ngi % params.groups; // The offset to the first activation loaded by that thread. const int64_t offset = (int64_t)ni * params.hwc + gi * CHANNELS_PER_GROUP + ci; // The pointer to the first activation loaded by that thread. const IOType* x_ptr = &reinterpret_cast(params.x)[offset]; // Load the activations into registers. IOType2 x[ACTS_PER_THREAD]; #pragma unroll for (int ii = 0; ii < ACTS_PER_THREAD; ++ii) { int hwj = hwi + ii * ACTS_PER_LOOP; x[ii] = IOTraits::zero(); if (is_active && hwj < params.hw) { x[ii] = *reinterpret_cast(&x_ptr[hwj * params.c]); } } // Compute the sum and the sum of squares for each thread. 
float2 sums = make_float2(0.f, 0.f); #pragma unroll for (int ii = 0; ii < ACTS_PER_THREAD; ++ii) { float2 f2 = IOTraits::unpack(x[ii]); sums.x += f2.x + f2.y; sums.y += f2.x * f2.x + f2.y * f2.y; } // Clear invalid threads. if (ACTIVE_THREADS < THREADS_PER_BLOCK && !is_active) { sums = make_float2(0.f, 0.f); } // Compute the sums for the block. sums = Block_reduce(temp_storage).Reduce(sums, [](const float2& a, const float2& b) { return make_float2(a.x + b.x, a.y + b.y); }); // The block leader stores to global memory, if needed. if (gridDim.x > 1) { // The index of the buffer (double-buffering). int red_buffer_idx = step & 1; // The barrier. int* barrier = ¶ms.barriers[red_buffer_idx * gridDim.y + blockIdx.y]; // The offset to the reduction buffer. int red_buffer_offset = red_buffer_idx * gridDim.x * gridDim.y * 2; // The reduction buffer. float2* red_buffer = reinterpret_cast(¶ms.red_buffer[red_buffer_offset]); // The first thread stores its sums. if (threadIdx.x == 0) { red_buffer[blockIdx.x * gridDim.y + blockIdx.y] = sums; } // Make sure the data is in memory. if (threadIdx.x == 0) { spin_wait_(barrier, (step & 2) ? -1 : 1, (step & 2) ? 0 : gridDim.x); } __syncthreads(); // Update the sums. for (int ii = 0; ii < gridDim.x; ++ii) { if (ii != blockIdx.x && threadIdx.x == 0) { float2 other_sums = red_buffer[ii * gridDim.y + blockIdx.y]; sums.x += other_sums.x; sums.y += other_sums.y; } } } // Store the result for other threads. if (threadIdx.x == 0) { smem_sums = sums; } // Store the results to global memory as well (for training). if (params.sums != nullptr && blockIdx.x == 0 && threadIdx.x == 0) { sums.x *= params.inv_hwc_per_group; sums.y *= params.inv_hwc_per_group; params.sums[ngi] = sums; } // Make sure the sums are in shared memory. __syncthreads(); // Load gamma/beta. 
float2 gamma_f2 = WTraits::unpack( *reinterpret_cast(&reinterpret_cast(params.gamma)[gi * CHANNELS_PER_GROUP + ci])); float2 beta_f2 = WTraits::unpack( *reinterpret_cast(&reinterpret_cast(params.beta)[gi * CHANNELS_PER_GROUP + ci])); // Compute the mean. float mean = smem_sums.x * params.inv_hwc_per_group; // Compute the variance. float var = smem_sums.y * params.inv_hwc_per_group - (mean * mean); // Compute the inverse of the stddev. float inv_stddev = var <= 0.f ? 1.f : rsqrtf(var + params.epsilon); // The pointer to the first activation stored by that thread. IOType* y_ptr = &reinterpret_cast(params.y)[offset]; // Iterate over the activations to normalize the activations and store the results. for (int ii = 0; ii < ACTS_PER_THREAD; ++ii) { // Extract the two half values. float2 f2 = IOTraits::unpack(x[ii]); // Normalize the channels. f2.x = (f2.x - mean) * inv_stddev; f2.y = (f2.y - mean) * inv_stddev; // Scale by gamma and add beta. f2.x = gamma_f2.x * f2.x + beta_f2.x; f2.y = gamma_f2.y * f2.y + beta_f2.y; // Apply Swish if needed. if (params.with_swish) { f2.x = f2.x * sigmoid(f2.x); f2.y = f2.y * sigmoid(f2.y); } // Store the scaled values. int hwj = hwi + ii * ACTS_PER_LOOP; if (is_active && hwj < params.hw) { *reinterpret_cast(&y_ptr[hwj * params.c]) = IOTraits::pack(f2); } } } } ////////////////////////////////////////////////////////////////////////////////////////////////// ================================================ FILE: apex/contrib/csrc/group_norm/group_norm_nhwc_fwd_two_pass.cu ================================================ /* * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause */ #include #include #include "group_norm_nhwc.h" #include "macros.h" #include "traits.h" //////////////////////////////////////////////////////////////////////////////////////////////////// // // F O R W A R D // //////////////////////////////////////////////////////////////////////////////////////////////////// template __global__ void group_norm_nhwc_fwd_sum_kernel(Group_norm_nhwc_fwd_params params) { // The traits. using Traits = Traits_; // The IO traits. using IOTraits = typename Traits::IOTraits; // The IO type using IOType = typename IOTraits::Type; // The IO doubled type using IOType2 = typename IOTraits::Type2; // The object in charge of doing the sums for the different blocks. typedef cub::BlockScan Block_scan; // Allocate shared memory for Block_scan. __shared__ typename Block_scan::TempStorage temp_storage; // Allocate shared memory for the groups. We could reduce the amount of shared memory reserved. __shared__ float2 smem[THREADS_PER_BLOCK]; // The instance in the batch. int ni = blockIdx.z; // The channel loaded by that thread (2 channels per thread for F16x2). int ci = blockIdx.x * params.channels_per_block + threadIdx.x * 2; // The first activation loaded by that block. int hw_begin = blockIdx.y * params.acts_per_block; // The last activation loaded by that block. int hw_end = min((int64_t)hw_begin + params.acts_per_block, params.hw); // The sums. float sum = 0.f, sum_sq = 0.f; // Iterate over the activations to compute the sums. for (int hwi = hw_begin; hwi < hw_end; ++hwi) { // The offset. int64_t offset = (int64_t)ni * params.hwc + hwi * params.c + ci; // Fetch two channels per thread. IOType2 v2 = IOTraits::zero(); if (ci < params.c) { v2 = *reinterpret_cast(&reinterpret_cast(params.x)[offset]); } // Extract the two values. float2 f2 = IOTraits::unpack(v2); // Update the sum. sum += f2.x + f2.y; // Update the sum of squares. 
sum_sq += f2.x * f2.x + f2.y * f2.y; } // The group that thread works on and the channel in the group (modulus). int gj = threadIdx.x * 2 / params.channels_per_group; int cj = threadIdx.x * 2 - params.channels_per_group * gj; // The data for the summations. Group_sums inp{cj == 0 ? 1 : 0, sum, sum_sq}; // Do the segmented scan. Group_sums out; Block_scan(temp_storage).InclusiveScan(inp, out, Group_sums_op()); // Store the results for the groups in shared memory (to produce coalesced stores later). if (cj == params.channels_per_group - 2 /* 2 channels per thread */) { smem[gj] = make_float2(out.sum, out.sum_sq); } // Make sure the data is in shared memory. __syncthreads(); // The global group index. int gk = blockIdx.x * params.groups_per_block + threadIdx.x; // Threads that have nothing left to do, exit. if (threadIdx.x >= params.groups_per_block || gk >= params.groups) { return; } // The first threads (those storing to global memory, load the values). float2 sums = smem[threadIdx.x]; // Store to global memory. atomicAdd(¶ms.zeroed_red_buffer[(2 * ni + 0) * params.groups + gk], sums.x); atomicAdd(¶ms.zeroed_red_buffer[(2 * ni + 1) * params.groups + gk], sums.y); } //////////////////////////////////////////////////////////////////////////////////////////////////// void group_norm_nhwc_fwd_two_passes_setup(Group_norm_nhwc_fwd_params& params, size_t& zeroed_red_buffer_elts) { // The pre-computed dimensions. params.hw = params.h * params.w; params.hwc = params.c * params.hw; // The number of channels per group. params.channels_per_group = params.c / params.groups; // The inverse to compute the mean/variance. params.inv_hwc_per_group = 1.f / (float)(params.hw * params.channels_per_group); // Define the number of blocks per activation map. That's a simple heuristic. 
int blocks_per_act_slice = 0; if (params.c >= 1280) { blocks_per_act_slice = 128 / params.n; } else if (params.c >= 640) { blocks_per_act_slice = 256 / params.n; } else { blocks_per_act_slice = 512 / params.n; } // Clamp to at least 1 to avoid divide-by-zero when batch size is large. blocks_per_act_slice = max(blocks_per_act_slice, 1); // Make sure we launch blocks per activation is no less than activations blocks_per_act_slice = min(blocks_per_act_slice, div_up(params.hw, params.n)); // Define how many activations are computed per block. params.acts_per_block = div_up(params.hw, blocks_per_act_slice); // The number of channels per block. params.channels_per_block = 320; // Special case to deal with 30 channels per group. if (params.channels_per_block % params.channels_per_group != 0) { params.channels_per_block = 240; } // Special case to deal with 70 channels per group. if (params.c == 2240) { params.channels_per_block = 280; } else if (params.c == 832) { params.channels_per_block = 208; } if (params.c % params.channels_per_block != 0) { if (params.c % 512 == 0 && params.c != 1536 && params.c != 3072 && params.c % 448 != 0) { params.channels_per_block = 512; } else if (params.c % 42 == 0) { params.channels_per_block = 336; } else if (params.c % 384 == 0) { params.channels_per_block = 384; } else if (params.c % 256 == 0 && params.c % 448 != 0 && params.c % 392 != 0) { params.channels_per_block = 256; } else if (params.c % 128 == 0 && params.c % 448 != 0 && params.c % 392 != 0) { params.channels_per_block = 128; } else if (params.c % 448 == 0 && params.c % 392 != 0) { params.channels_per_block = 448; } else if (params.c % 392 == 0) { params.channels_per_block = 392; } } // The number of groups per block. params.groups_per_block = params.channels_per_block / params.channels_per_group; // Make sure the number of channels is a multiple of the number of channels per block. 
assert(params.c % params.channels_per_block == 0); // Make sure a group does not span multiple blocks. assert(params.channels_per_block % params.channels_per_group == 0); // The number of elements in the reduction buffer (for the sums and sums of squared). zeroed_red_buffer_elts = params.n * params.groups * 2; } //////////////////////////////////////////////////////////////////////////////////////////////////// void group_norm_nhwc_fwd_two_passes_sum(const Group_norm_nhwc_fwd_params& params, cudaStream_t stream) { // The dimension of the grid. dim3 grid; // The number of blocks to compute all the channels. grid.x = params.c / params.channels_per_block; // The number of blocks to compute all the activations in a given instance. grid.y = div_up(params.hw, params.acts_per_block); // The number of instances. grid.z = params.n; // Launch the kernel. if (params.precision == PrecisionMode::FP16IOFP16W) { CALL_TWO_PASS_KERNEL(group_norm_nhwc_fwd_sum_kernel, Fp16IOFp16W) } else if (params.precision == PrecisionMode::FP16IOBF16W) { CALL_TWO_PASS_KERNEL(group_norm_nhwc_fwd_sum_kernel, Fp16IOBf16W) } else if (params.precision == PrecisionMode::FP16IOFP32W) { CALL_TWO_PASS_KERNEL(group_norm_nhwc_fwd_sum_kernel, Fp16IOFp32W) } else if (params.precision == PrecisionMode::BF16IOFP16W) { CALL_TWO_PASS_KERNEL(group_norm_nhwc_fwd_sum_kernel, Bf16IOFp16W) } else if (params.precision == PrecisionMode::BF16IOBF16W) { CALL_TWO_PASS_KERNEL(group_norm_nhwc_fwd_sum_kernel, Bf16IOBf16W) } else if (params.precision == PrecisionMode::BF16IOFP32W) { CALL_TWO_PASS_KERNEL(group_norm_nhwc_fwd_sum_kernel, Bf16IOFp32W) } else if (params.precision == PrecisionMode::FP32IOFP16W) { CALL_TWO_PASS_KERNEL(group_norm_nhwc_fwd_sum_kernel, Fp32IOFp16W) } else if (params.precision == PrecisionMode::FP32IOBF16W) { CALL_TWO_PASS_KERNEL(group_norm_nhwc_fwd_sum_kernel, Fp32IOBf16W) } else { CALL_TWO_PASS_KERNEL(group_norm_nhwc_fwd_sum_kernel, Fp32IOFp32W) } // Make sure it launched ok. 
CHECK_CUDA(cudaGetLastError()); } //////////////////////////////////////////////////////////////////////////////////////////////////// template __global__ void group_norm_nhwc_fwd_scale_kernel(Group_norm_nhwc_fwd_params params) { // The traits. using Traits = Traits_; // The IO traits. using IOTraits = typename Traits::IOTraits; // The Weights traits. using WTraits = typename Traits::WTraits; // The IO type using IOType = typename IOTraits::Type; // The IO doubled type using IOType2 = typename IOTraits::Type2; // Weights type using WType = typename WTraits::Type; // Weights doubled type using WType2 = typename WTraits::Type2; // The instance in the batch. int ni = blockIdx.z; // The channel loaded by that thread (2 channels per thread for F16x2). int ci = blockIdx.x * params.channels_per_block + threadIdx.x * 2; // The group that thread works on and the channel in the group (modulus). int gi = ci / params.channels_per_group; // Load the sum and sum of squares for the group. float sum = 0.f, sum_sq = 0.f; if (gi < params.groups) { sum = params.zeroed_red_buffer[(2 * ni + 0) * params.groups + gi]; sum_sq = params.zeroed_red_buffer[(2 * ni + 1) * params.groups + gi]; } // Load gamma/beta. float2 gamma_f2, beta_f2; if (ci < params.c) { gamma_f2 = WTraits::unpack(*reinterpret_cast(&reinterpret_cast(params.gamma)[ci])); beta_f2 = WTraits::unpack(*reinterpret_cast(&reinterpret_cast(params.beta)[ci])); } // Compute the mean. float mean = sum * params.inv_hwc_per_group; // Compute the variance. float var = sum_sq * params.inv_hwc_per_group - (mean * mean); // Compute the inverse of the stddev. float inv_stddev = var <= 0.f ? 1.f : rsqrtf(var + params.epsilon); // The first activation loaded by that block. int hw_begin = blockIdx.y * params.acts_per_block; // The last activation loaded by that block. int hw_end = min((int64_t)hw_begin + params.acts_per_block, params.hw); // Iterate over the activations to compute the sums. 
for (int hwi = hw_begin; hwi < hw_end; ++hwi) { // The src/dst offset. int64_t offset = (int64_t)ni * params.hwc + hwi * params.c + ci; // Fetch two channels per thread. IOType2 v2 = IOTraits::zero(); if (ci < params.c) { v2 = *reinterpret_cast(&reinterpret_cast(params.x)[offset]); } // Extract the two values. float2 f2 = IOTraits::unpack(v2); // Normalize the channels. f2.x = (f2.x - mean) * inv_stddev; f2.y = (f2.y - mean) * inv_stddev; // Scale by gamma and add beta. f2.x = gamma_f2.x * f2.x + beta_f2.x; f2.y = gamma_f2.y * f2.y + beta_f2.y; // Apply Swish if needed. if (params.with_swish) { f2.x = f2.x * sigmoid(f2.x); f2.y = f2.y * sigmoid(f2.y); } // Store the scaled values. if (ci < params.c) { *reinterpret_cast(&reinterpret_cast(params.y)[offset]) = IOTraits::pack(f2); } } // Write the sums if needed. if (params.sums != nullptr && gi < params.groups) { float2 sums; sums.x = sum * params.inv_hwc_per_group; sums.y = sum_sq * params.inv_hwc_per_group; params.sums[ni * params.groups + gi] = sums; } } //////////////////////////////////////////////////////////////////////////////////////////////////// void group_norm_nhwc_fwd_two_passes_scale(const Group_norm_nhwc_fwd_params& params, cudaStream_t stream) { // The dimension of the grid. dim3 grid; // The number of blocks to compute all the channels. grid.x = params.c / params.channels_per_block; // The number of blocks to compute all the activations in a given instance. grid.y = div_up(params.hw, params.acts_per_block); // The number of instances. 
grid.z = params.n; if (params.precision == PrecisionMode::FP16IOFP16W) { CALL_TWO_PASS_KERNEL(group_norm_nhwc_fwd_scale_kernel, Fp16IOFp16W) } else if (params.precision == PrecisionMode::FP16IOBF16W) { CALL_TWO_PASS_KERNEL(group_norm_nhwc_fwd_scale_kernel, Fp16IOBf16W) } else if (params.precision == PrecisionMode::FP16IOFP32W) { CALL_TWO_PASS_KERNEL(group_norm_nhwc_fwd_scale_kernel, Fp16IOFp32W) } else if (params.precision == PrecisionMode::BF16IOFP16W) { CALL_TWO_PASS_KERNEL(group_norm_nhwc_fwd_scale_kernel, Bf16IOFp16W) } else if (params.precision == PrecisionMode::BF16IOBF16W) { CALL_TWO_PASS_KERNEL(group_norm_nhwc_fwd_scale_kernel, Bf16IOBf16W) } else if (params.precision == PrecisionMode::BF16IOFP32W) { CALL_TWO_PASS_KERNEL(group_norm_nhwc_fwd_scale_kernel, Bf16IOFp32W) } else if (params.precision == PrecisionMode::FP32IOFP16W) { CALL_TWO_PASS_KERNEL(group_norm_nhwc_fwd_scale_kernel, Fp32IOFp16W) } else if (params.precision == PrecisionMode::FP32IOBF16W) { CALL_TWO_PASS_KERNEL(group_norm_nhwc_fwd_scale_kernel, Fp32IOBf16W) } else { CALL_TWO_PASS_KERNEL(group_norm_nhwc_fwd_scale_kernel, Fp32IOFp32W) } // Make sure it launched ok. CHECK_CUDA(cudaGetLastError()); } //////////////////////////////////////////////////////////////////////////////////////////////////// ================================================ FILE: apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_10.cu ================================================ /* * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause
 */
#include "group_norm_nhwc_bwd_one_pass_kernel.cuh"
#include "group_norm_nhwc_fwd_one_pass_kernel.cuh"
#include "macros.h"

// 10 channels/group -> 5 threads per activation; 640 / 5 = 128 activations per loop in the forward
// kernel (assuming the macro forwards both values to it).
GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 10, /* THREADS_PER_BLOCK */ 640)

================================================
FILE: apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_112.cu
================================================
/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 */
#include "group_norm_nhwc_bwd_one_pass_kernel.cuh"
#include "group_norm_nhwc_fwd_one_pass_kernel.cuh"
#include "macros.h"

// 112 channels/group -> 56 threads per activation; 448 / 56 = 8 activations per loop in the forward
// kernel (assuming the macro forwards both values to it).
GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 112, /* THREADS_PER_BLOCK */ 448)

================================================
FILE: apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_12.cu
================================================
/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 */
#include "group_norm_nhwc_bwd_one_pass_kernel.cuh"
#include "group_norm_nhwc_fwd_one_pass_kernel.cuh"
#include "macros.h"

// 12 channels/group -> 6 threads per activation; 384 / 6 = 64 activations per loop in the forward
// kernel (assuming the macro forwards both values to it).
GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 12, /* THREADS_PER_BLOCK */ 384)

================================================
FILE: apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_120.cu
================================================
/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 */
#include "group_norm_nhwc_bwd_one_pass_kernel.cuh"
#include "group_norm_nhwc_fwd_one_pass_kernel.cuh"
#include "macros.h"

// 120 channels/group -> 60 threads per activation; 480 / 60 = 8 activations per loop in the forward
// kernel (assuming the macro forwards both values to it).
GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 120, /* THREADS_PER_BLOCK */ 480)

================================================
FILE: apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_128.cu
================================================
/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 */
#include "group_norm_nhwc_bwd_one_pass_kernel.cuh"
#include "group_norm_nhwc_fwd_one_pass_kernel.cuh"
#include "macros.h"

// 128 channels/group -> 64 threads per activation; 512 / 64 = 8 activations per loop in the forward
// kernel (assuming the macro forwards both values to it).
GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 128, /* THREADS_PER_BLOCK */ 512)

================================================
FILE: apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_14.cu
================================================
/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 */
#include "group_norm_nhwc_bwd_one_pass_kernel.cuh"
#include "group_norm_nhwc_fwd_one_pass_kernel.cuh"
#include "macros.h"

// 14 channels/group -> 7 threads per activation; 224 / 7 = 32 activations per loop in the forward
// kernel (assuming the macro forwards both values to it).
GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 14, /* THREADS_PER_BLOCK */ 224)

================================================
FILE: apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_16.cu
================================================
/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 */
#include "group_norm_nhwc_bwd_one_pass_kernel.cuh"
#include "group_norm_nhwc_fwd_one_pass_kernel.cuh"
#include "macros.h"

// 16 channels/group -> 8 threads per activation; 256 / 8 = 32 activations per loop in the forward
// kernel (assuming the macro forwards both values to it).
GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 16, /* THREADS_PER_BLOCK */ 256)

================================================
FILE: apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_160.cu
================================================
/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 */
#include "group_norm_nhwc_bwd_one_pass_kernel.cuh"
#include "group_norm_nhwc_fwd_one_pass_kernel.cuh"
#include "macros.h"

// 160 channels/group -> 80 threads per activation; 640 / 80 = 8 activations per loop in the forward
// kernel (assuming the macro forwards both values to it).
GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 160, /* THREADS_PER_BLOCK */ 640)

================================================
FILE: apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_20.cu
================================================
/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 */
#include "group_norm_nhwc_bwd_one_pass_kernel.cuh"
#include "group_norm_nhwc_fwd_one_pass_kernel.cuh"
#include "macros.h"

// 20 channels/group -> 10 threads per activation; 640 / 10 = 64 activations per loop in the forward
// kernel (assuming the macro forwards both values to it).
GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 20, /* THREADS_PER_BLOCK */ 640)

================================================
FILE: apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_24.cu
================================================
/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 */
#include "group_norm_nhwc_bwd_one_pass_kernel.cuh"
#include "group_norm_nhwc_fwd_one_pass_kernel.cuh"
#include "macros.h"

// 24 channels/group -> 12 threads per activation; 384 / 12 = 32 activations per loop in the forward
// kernel (assuming the macro forwards both values to it).
GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 24, /* THREADS_PER_BLOCK */ 384)

================================================
FILE: apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_26.cu
================================================
/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 */
#include "group_norm_nhwc_bwd_one_pass_kernel.cuh"
#include "group_norm_nhwc_fwd_one_pass_kernel.cuh"
#include "macros.h"

// 26 channels/group -> 13 threads per activation; 416 / 13 = 32 activations per loop in the forward
// kernel (assuming the macro forwards both values to it).
GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 26, /* THREADS_PER_BLOCK */ 416)

================================================
FILE: apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_28.cu
================================================
/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 */
#include "group_norm_nhwc_bwd_one_pass_kernel.cuh"
#include "group_norm_nhwc_fwd_one_pass_kernel.cuh"
#include "macros.h"

// 28 channels/group -> 14 threads per activation; 448 / 14 = 32 activations per loop in the forward
// kernel (assuming the macro forwards both values to it).
GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 28, /* THREADS_PER_BLOCK */ 448)

================================================
FILE: apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_30.cu
================================================
/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 */
#include "group_norm_nhwc_bwd_one_pass_kernel.cuh"
#include "group_norm_nhwc_fwd_one_pass_kernel.cuh"
#include "macros.h"

// 30 channels/group -> 15 threads per activation; 480 / 15 = 32 activations per loop in the forward
// kernel (assuming the macro forwards both values to it).
GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 30, /* THREADS_PER_BLOCK */ 480)

================================================
FILE: apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_32.cu
================================================
/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 */
#include "group_norm_nhwc_bwd_one_pass_kernel.cuh"
#include "group_norm_nhwc_fwd_one_pass_kernel.cuh"
#include "macros.h"

// 32 channels/group -> 16 threads per activation; 512 / 16 = 32 activations per loop in the forward
// kernel (assuming the macro forwards both values to it).
GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 32, /* THREADS_PER_BLOCK */ 512)

================================================
FILE: apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_4.cu
================================================
/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 */
#include "group_norm_nhwc_bwd_one_pass_kernel.cuh"
#include "group_norm_nhwc_fwd_one_pass_kernel.cuh"
#include "macros.h"

// 4 channels/group -> 2 threads per activation; 128 / 2 = 64 activations per loop in the forward
// kernel (assuming the macro forwards both values to it).
GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 4, /* THREADS_PER_BLOCK */ 128)

================================================
FILE: apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_40.cu
================================================
/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 */
#include "group_norm_nhwc_bwd_one_pass_kernel.cuh"
#include "group_norm_nhwc_fwd_one_pass_kernel.cuh"
#include "macros.h"

// 40 channels/group -> 20 threads per activation; 640 / 20 = 32 activations per loop in the forward
// kernel (assuming the macro forwards both values to it).
GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 40, /* THREADS_PER_BLOCK */ 640)

================================================
FILE: apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_42.cu
================================================
/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 */
#include "group_norm_nhwc_bwd_one_pass_kernel.cuh"
#include "group_norm_nhwc_fwd_one_pass_kernel.cuh"
#include "macros.h"

// 42 channels/group -> 21 threads per activation; 672 / 21 = 32 activations per loop in the forward
// kernel (assuming the macro forwards both values to it).
GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 42, /* THREADS_PER_BLOCK */ 672)

================================================
FILE: apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_48.cu
================================================
/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 */
#include "group_norm_nhwc_bwd_one_pass_kernel.cuh"
#include "group_norm_nhwc_fwd_one_pass_kernel.cuh"
#include "macros.h"

// 48 channels/group -> 24 threads per activation; 384 / 24 = 16 activations per loop in the forward
// kernel (assuming the macro forwards both values to it).
GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 48, /* THREADS_PER_BLOCK */ 384)

================================================
FILE: apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_56.cu
================================================
/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 */
#include "group_norm_nhwc_bwd_one_pass_kernel.cuh"
#include "group_norm_nhwc_fwd_one_pass_kernel.cuh"
#include "macros.h"

// 56 channels/group -> 28 threads per activation; 448 / 28 = 16 activations per loop in the forward
// kernel (assuming the macro forwards both values to it).
GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 56, /* THREADS_PER_BLOCK */ 448)

================================================
FILE: apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_60.cu
================================================
/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 */
#include "group_norm_nhwc_bwd_one_pass_kernel.cuh"
#include "group_norm_nhwc_fwd_one_pass_kernel.cuh"
#include "macros.h"

// 60 channels/group -> 30 threads per activation; 480 / 30 = 16 activations per loop in the forward
// kernel (assuming the macro forwards both values to it).
GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 60, /* THREADS_PER_BLOCK */ 480)

================================================
FILE: apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_64.cu
================================================
/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 */
#include "group_norm_nhwc_bwd_one_pass_kernel.cuh"
#include "group_norm_nhwc_fwd_one_pass_kernel.cuh"
#include "macros.h"

// 64 channels/group -> 32 threads per activation; 512 / 32 = 16 activations per loop in the forward
// kernel (assuming the macro forwards both values to it).
GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 64, /* THREADS_PER_BLOCK */ 512)

================================================
FILE: apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_70.cu
================================================
/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 */
#include "group_norm_nhwc_bwd_one_pass_kernel.cuh"
#include "group_norm_nhwc_fwd_one_pass_kernel.cuh"
#include "macros.h"

// 70 channels/group -> 35 threads per activation; 560 / 35 = 16 activations per loop in the forward
// kernel (assuming the macro forwards both values to it).
GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 70, /* THREADS_PER_BLOCK */ 560)

================================================
FILE: apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_8.cu
================================================
/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 */
#include "group_norm_nhwc_bwd_one_pass_kernel.cuh"
#include "group_norm_nhwc_fwd_one_pass_kernel.cuh"
#include "macros.h"

// 8 channels/group -> 4 threads per activation; 128 / 4 = 32 activations per loop in the forward
// kernel (assuming the macro forwards both values to it).
GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 8, /* THREADS_PER_BLOCK */ 128)

================================================
FILE: apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_80.cu
================================================
/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 */
#include "group_norm_nhwc_bwd_one_pass_kernel.cuh"
#include "group_norm_nhwc_fwd_one_pass_kernel.cuh"
#include "macros.h"

// 80 channels/group -> 40 threads per activation; 640 / 40 = 16 activations per loop in the forward
// kernel (assuming the macro forwards both values to it).
GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 80, /* THREADS_PER_BLOCK */ 640)

================================================
FILE: apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_84.cu
================================================
/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause */ #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" #include "group_norm_nhwc_fwd_one_pass_kernel.cuh" #include "macros.h" GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 84, /* THREADS_PER_BLOCK */ 672) ================================================ FILE: apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_96.cu ================================================ /* * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause */ #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" #include "group_norm_nhwc_fwd_one_pass_kernel.cuh" #include "macros.h" GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 96, /* THREADS_PER_BLOCK */ 768) ================================================ FILE: apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_98.cu ================================================ /* * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause */ #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" #include "group_norm_nhwc_fwd_one_pass_kernel.cuh" #include "macros.h" GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 98, /* THREADS_PER_BLOCK */ 392) ================================================ FILE: apex/contrib/csrc/group_norm/group_norm_nhwc_op.cpp ================================================ /* * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause */ #include #include #include "group_norm_nhwc.h" #include "group_norm_nhwc_bwd_one_pass.h" #include "group_norm_nhwc_fwd_one_pass.h" //////////////////////////////////////////////////////////////////////////////////////////////////// #define CHECK_CUDA_STATUS(call) \ do { \ cudaError_t status_ = call; \ if (status_ != cudaSuccess) { \ fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \ exit(1); \ } \ } while (0) #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") #define CHECK_CHANNELS_LAST(x) TORCH_CHECK(x.is_contiguous(at::MemoryFormat::ChannelsLast), #x " must be channels last") #define CHECK_INPUT(x) \ CHECK_CUDA(x); \ CHECK_CONTIGUOUS(x) #define CHECK_NHWC_INPUT(x) \ CHECK_CUDA(x); \ CHECK_CHANNELS_LAST(x) static bool initialized = false; static cudaDeviceProp props; const std::unordered_set supported_c_values = {128, 256, 320, 384, 448, 512, 640, 768, 896, 960, 1024, 1280, 1344, 1536, 1792, 1920, 2048, 2240, 2560, 2688, 3072, 3136, 3584, 4096}; const std::unordered_set supported_groups_values = {16, 32}; std::vector group_norm_fwd(torch::Tensor input, int groups, torch::Tensor weight, torch::Tensor bias, float eps, int passes, bool with_swish = false) { if (!initialized) { CHECK_CUDA_STATUS(cudaGetDeviceProperties(&props, 0)); initialized = true; } CHECK_NHWC_INPUT(input); auto stream = at::cuda::getCurrentCUDAStream(); // Achieve group norm arguments int n = input.size(0); int c = input.size(1); int h = input.size(2); int w = input.size(3); // Check kernel constraints TORCH_CHECK(supported_groups_values.count(groups), "`groups` of {16, 32} are only supported but ", groups, " is passed"); TORCH_CHECK(supported_c_values.count(c), "`c` of ", c, " is not included in supported_c_values"); // Allocate tensors auto options = at::TensorOptions(at::kCUDA); auto output = 
at::empty_like(input, at::MemoryFormat::Preserve); auto sums_d = at::empty({2 * n * groups}, options.dtype(at::kFloat)); // Declare the parameters. Group_norm_nhwc_fwd_params params_fwd; memset(¶ms_fwd, 0, sizeof(params_fwd)); // Initialize the parameters. params_fwd.y = reinterpret_cast(output.data_ptr()); params_fwd.sums = reinterpret_cast(sums_d.data_ptr()); params_fwd.x = const_cast(reinterpret_cast(input.data_ptr())); params_fwd.gamma = const_cast(reinterpret_cast(weight.data_ptr())); params_fwd.beta = const_cast(reinterpret_cast(bias.data_ptr())); params_fwd.epsilon = eps; params_fwd.n = n; params_fwd.h = h; params_fwd.w = w; params_fwd.c = c; params_fwd.groups = groups; params_fwd.with_swish = with_swish; PrecisionMode mode; if (input.dtype() == torch::kFloat32) { if (weight.dtype() == torch::kFloat16) { mode = PrecisionMode::FP32IOFP16W; } else if (weight.dtype() == torch::kBFloat16) { mode = PrecisionMode::FP32IOBF16W; } else { mode = PrecisionMode::FP32IOFP32W; } } else if (input.dtype() == torch::kBFloat16) { if (weight.dtype() == torch::kFloat16) { mode = PrecisionMode::BF16IOFP16W; } else if (weight.dtype() == torch::kBFloat16) { mode = PrecisionMode::BF16IOBF16W; } else { mode = PrecisionMode::BF16IOFP32W; } } else { if (weight.dtype() == torch::kFloat16) { mode = PrecisionMode::FP16IOFP16W; } else if (weight.dtype() == torch::kBFloat16) { mode = PrecisionMode::FP16IOBF16W; } else { mode = PrecisionMode::FP16IOFP32W; } } params_fwd.precision = mode; // The number of barriers. size_t barriers_elts = 0; // The number of elements in the reduction buffer. size_t red_buffer_elts = 0; // The number of elements in the reduction buffer that must be zeroed. size_t zeroed_red_buffer_elts = 0; // Finalize the parameters. dim3 grid; if (passes == 1) { group_norm_nhwc_fwd_one_pass_setup(params_fwd, barriers_elts, red_buffer_elts, grid, props); } else { group_norm_nhwc_fwd_two_passes_setup(params_fwd, zeroed_red_buffer_elts); } // Allocate on the device. 
auto red_buffer = at::empty({red_buffer_elts}, options.dtype(at::kFloat)); params_fwd.red_buffer = red_buffer.data_ptr(); // Allocate the buffer if needed. auto barriers = at::zeros({barriers_elts}, options.dtype(at::kInt)); params_fwd.barriers = barriers.data_ptr(); auto zeroed_red_buffer = at::zeros({zeroed_red_buffer_elts}, options.dtype(at::kFloat)); params_fwd.zeroed_red_buffer = zeroed_red_buffer.data_ptr(); if (passes == 1) { group_norm_nhwc_fwd_one_pass_run(params_fwd, grid, stream); } else { group_norm_nhwc_fwd_two_passes_sum(params_fwd, stream); group_norm_nhwc_fwd_two_passes_scale(params_fwd, stream); } return {output, sums_d}; } std::vector group_norm_bwd(torch::Tensor grad_output, torch::Tensor sums, torch::Tensor input, int groups, torch::Tensor weight, torch::Tensor bias, float eps, int passes, bool with_swish = false) { if (!initialized) { CHECK_CUDA_STATUS(cudaGetDeviceProperties(&props, 0)); initialized = true; } CHECK_NHWC_INPUT(grad_output); auto stream = at::cuda::getCurrentCUDAStream(); // Achieve group norm arguments int n = input.size(0); int c = input.size(1); int h = input.size(2); int w = input.size(3); // Check kernel constraints TORCH_CHECK(supported_groups_values.count(groups), "`groups` of {16, 32} are only supported but ", groups, " is passed"); TORCH_CHECK(supported_c_values.count(c), "`c` of ", c, " is not included in supported_c_values"); // Allocate tensors auto options = at::TensorOptions(at::kCUDA); auto grad_input = at::empty_like(input, at::MemoryFormat::Preserve); auto grad_weight = at::empty_like(weight, at::MemoryFormat::Preserve); auto grad_bias = at::empty_like(bias, at::MemoryFormat::Preserve); auto sums_d = at::empty({2 * n * groups}, options.dtype(at::kFloat)); // Declare the parameters. Group_norm_nhwc_bwd_params params_bwd; memset(¶ms_bwd, 0, sizeof(params_bwd)); // Initialize the parameters. 
params_bwd.dx = reinterpret_cast(grad_input.data_ptr()); params_bwd.dgamma = reinterpret_cast(grad_weight.data_ptr()); params_bwd.dbeta = reinterpret_cast(grad_bias.data_ptr()); params_bwd.sums = const_cast(reinterpret_cast(sums.data_ptr())); params_bwd.dy = const_cast(reinterpret_cast(grad_output.data_ptr())); params_bwd.x = const_cast(reinterpret_cast(input.data_ptr())); ; params_bwd.gamma = const_cast(reinterpret_cast(weight.data_ptr())); params_bwd.beta = const_cast(reinterpret_cast(bias.data_ptr())); ; params_bwd.epsilon = eps; params_bwd.n = n; params_bwd.h = h; params_bwd.w = w; params_bwd.c = c; params_bwd.groups = groups; params_bwd.with_swish = with_swish; PrecisionMode mode; if (input.dtype() == torch::kFloat32) { if (weight.dtype() == torch::kFloat16) { mode = PrecisionMode::FP32IOFP16W; } else if (weight.dtype() == torch::kBFloat16) { mode = PrecisionMode::FP32IOBF16W; } else { mode = PrecisionMode::FP32IOFP32W; } } else if (input.dtype() == torch::kBFloat16) { if (weight.dtype() == torch::kFloat16) { mode = PrecisionMode::BF16IOFP16W; } else if (weight.dtype() == torch::kBFloat16) { mode = PrecisionMode::BF16IOBF16W; } else { mode = PrecisionMode::BF16IOFP32W; } } else { if (weight.dtype() == torch::kFloat16) { mode = PrecisionMode::FP16IOFP16W; } else if (weight.dtype() == torch::kBFloat16) { mode = PrecisionMode::FP16IOBF16W; } else { mode = PrecisionMode::FP16IOFP32W; } } params_bwd.precision = mode; // The number of barriers. size_t barriers_elts = 0; // The number of elements in the reduction buffer. size_t red_buffer_elts = 0; // The number of elements in the reduction buffer that must be zeroed. size_t zeroed_red_buffer_elts = 0; // Finalize the parameters. dim3 grid; if (passes == 1) { group_norm_nhwc_bwd_one_pass_setup(params_bwd, barriers_elts, red_buffer_elts, zeroed_red_buffer_elts, grid, props); } else { group_norm_nhwc_bwd_two_passes_setup(params_bwd, zeroed_red_buffer_elts); } // Allocate on the device. 
auto red_buffer = at::empty({red_buffer_elts}, options.dtype(at::kFloat)); params_bwd.red_buffer = red_buffer.data_ptr(); // Allocate the buffer if needed. auto barriers = at::zeros({barriers_elts}, options.dtype(at::kInt)); params_bwd.barriers = barriers.data_ptr(); auto zeroed_red_buffer = at::zeros({zeroed_red_buffer_elts}, options.dtype(at::kFloat)); params_bwd.zeroed_red_buffer = zeroed_red_buffer.data_ptr(); if (passes == 1) { group_norm_nhwc_bwd_one_pass_run(params_bwd, grid, stream); } else { group_norm_nhwc_bwd_two_passes_sum(params_bwd, stream); group_norm_nhwc_bwd_two_passes_scale(params_bwd, stream); } return {grad_input, grad_weight, grad_bias}; } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("forward", &group_norm_fwd, "NHWC group norm forward", py::call_guard()); m.def("backward", &group_norm_bwd, "NHWC group norm backward", py::call_guard()); } ================================================ FILE: apex/contrib/csrc/group_norm/macros.h ================================================ /* * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
 * SPDX-License-Identifier: BSD-3-Clause
 */

// Expands to the signature of a one-pass launcher:
//   void group_norm_nhwc_<PASS_NAME>_one_pass_<CHANNELS_PER_GROUP>_<ACTS_PER_BLOCK>_<Traits>_run(
//       const Group_norm_nhwc_<PASS_NAME>_params& params, const dim3& grid, cudaStream_t stream)
// THREADS_PER_BLOCK is accepted but not used in the expansion.
#define GN_ONE_PASS_RUN_FUNCTION_NAME(Traits, ACTS_PER_BLOCK, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME) \
  void group_norm_nhwc_##PASS_NAME##_one_pass_##CHANNELS_PER_GROUP##_##ACTS_PER_BLOCK##_##Traits##_run( \
      const Group_norm_nhwc_##PASS_NAME##_params& params, const dim3& grid, cudaStream_t stream)

// Defines a one-pass launcher with the signature above.
// The kernel receives `params` by value: `&params_` serves as a one-element kernel-argument array whose only entry
// points at `params`.
// The launch uses THREADS_PER_BLOCK threads per block and no dynamic shared memory. It is cooperative
// (cudaLaunchCooperativeKernel) when grid.x > 1 and a regular cudaLaunchKernel otherwise.
// NOTE(review): in this copy the kernel's template argument list after `_one_pass_kernel` appears to be missing,
// and `¶ms` / `¶ms_` look like a mis-decoded `&params` / `&params_`; verify against the upstream file.
#define GN_ONE_PASS_RUN_FUNCTION(Traits, ACTS_PER_BLOCK, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME) \
  GN_ONE_PASS_RUN_FUNCTION_NAME(Traits, ACTS_PER_BLOCK, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME) { \
    auto kernel = \
        group_norm_nhwc_##PASS_NAME##_one_pass_kernel; \
 \
    const Group_norm_nhwc_##PASS_NAME##_params* params_ = ¶ms; \
    if (grid.x > 1) { \
      CHECK_CUDA(cudaLaunchCooperativeKernel((const void*)kernel, grid, dim3(THREADS_PER_BLOCK), (void**)¶ms_, 0, \
                                             stream)); \
 \
    } else { \
      CHECK_CUDA(cudaLaunchKernel((const void*)kernel, grid, dim3(THREADS_PER_BLOCK), (void**)¶ms_, 0, stream)); \
    } \
 \
    CHECK_CUDA(cudaGetLastError()); \
  }

//////////////////////////////////////////////////////////////////////////////////////////////////

// Expands to the signature of an occupancy query:
//   int group_norm_nhwc_<PASS_NAME>_one_pass_<CHANNELS_PER_GROUP>_<ACTS_PER_BLOCK>_<Traits>_blocks_per_sm()
// THREADS_PER_BLOCK is accepted but not used in the expansion.
#define GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION_NAME(Traits, ACTS_PER_BLOCK, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, \
                                                PASS_NAME) \
  int group_norm_nhwc_##PASS_NAME##_one_pass_##CHANNELS_PER_GROUP##_##ACTS_PER_BLOCK##_##Traits##_blocks_per_sm()

// Defines the occupancy query. It returns the value cudaOccupancyMaxActiveBlocksPerMultiprocessor reports for the
// kernel instantiation at THREADS_PER_BLOCK threads and 0 bytes of dynamic shared memory.
#define GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION(Traits, ACTS_PER_BLOCK, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME) \
  GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION_NAME(Traits, ACTS_PER_BLOCK, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME) { \
    auto kernel = \
        group_norm_nhwc_##PASS_NAME##_one_pass_kernel; \
 \
    int blocks_per_sm = 0; \
    CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&blocks_per_sm, kernel, THREADS_PER_BLOCK, 0)); \
 \
    CHECK_CUDA(cudaGetLastError()); \
    return blocks_per_sm; \
  }

//////////////////////////////////////////////////////////////////////////////////////////////////

// Applies FUNCTION (one of the *_NAME signature macros or *_FUNCTION definition macros above) for
// ACTS_PER_BLOCK = 512, 256, 128 and 64, terminating each expansion with `;`.
#define GN_ONE_PASS_(FUNCTION, Traits, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME) \
  FUNCTION(Traits, 512, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \
  FUNCTION(Traits, 256, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \
  FUNCTION(Traits, 128, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \
  FUNCTION(Traits, 64, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME);

// Defines the launchers for all nine IO/weight precision traits (Fp32/Fp16/Bf16 IO x Fp16/Bf16/Fp32 weights).
#define GN_ONE_PASS_RUN_DEFINITION(CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME) \
  GN_ONE_PASS_(GN_ONE_PASS_RUN_FUNCTION, Fp32IOFp16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \
  GN_ONE_PASS_(GN_ONE_PASS_RUN_FUNCTION, Fp32IOBf16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \
  GN_ONE_PASS_(GN_ONE_PASS_RUN_FUNCTION, Fp32IOFp32W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \
  GN_ONE_PASS_(GN_ONE_PASS_RUN_FUNCTION, Fp16IOFp16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \
  GN_ONE_PASS_(GN_ONE_PASS_RUN_FUNCTION, Fp16IOBf16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \
  GN_ONE_PASS_(GN_ONE_PASS_RUN_FUNCTION, Fp16IOFp32W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \
  GN_ONE_PASS_(GN_ONE_PASS_RUN_FUNCTION, Bf16IOFp16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \
  GN_ONE_PASS_(GN_ONE_PASS_RUN_FUNCTION, Bf16IOBf16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \
  GN_ONE_PASS_(GN_ONE_PASS_RUN_FUNCTION, Bf16IOFp32W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME);

// Declares (prototype only) the launchers for all nine precision traits.
#define GN_ONE_PASS_RUN_DECLARATION(CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME) \
  GN_ONE_PASS_(GN_ONE_PASS_RUN_FUNCTION_NAME, Fp32IOFp16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \
  GN_ONE_PASS_(GN_ONE_PASS_RUN_FUNCTION_NAME, Fp32IOBf16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \
  GN_ONE_PASS_(GN_ONE_PASS_RUN_FUNCTION_NAME, Fp32IOFp32W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \
  GN_ONE_PASS_(GN_ONE_PASS_RUN_FUNCTION_NAME, Fp16IOFp16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \
  GN_ONE_PASS_(GN_ONE_PASS_RUN_FUNCTION_NAME, Fp16IOBf16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \
  GN_ONE_PASS_(GN_ONE_PASS_RUN_FUNCTION_NAME, Fp16IOFp32W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \
  GN_ONE_PASS_(GN_ONE_PASS_RUN_FUNCTION_NAME, Bf16IOFp16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \
  GN_ONE_PASS_(GN_ONE_PASS_RUN_FUNCTION_NAME, Bf16IOBf16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \
  GN_ONE_PASS_(GN_ONE_PASS_RUN_FUNCTION_NAME, Bf16IOFp32W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME);

// Defines the occupancy queries for all nine precision traits.
#define GN_ONE_PASS_BLOCKS_PER_SM_DEFINITION(CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME) \
  GN_ONE_PASS_(GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION, Fp32IOFp16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \
  GN_ONE_PASS_(GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION, Fp32IOBf16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \
  GN_ONE_PASS_(GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION, Fp32IOFp32W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \
  GN_ONE_PASS_(GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION, Fp16IOFp16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \
  GN_ONE_PASS_(GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION, Fp16IOBf16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \
  GN_ONE_PASS_(GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION, Fp16IOFp32W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \
  GN_ONE_PASS_(GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION, Bf16IOFp16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \
  GN_ONE_PASS_(GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION, Bf16IOBf16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \
  GN_ONE_PASS_(GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION, Bf16IOFp32W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME);

// Declares (prototype only) the occupancy queries for all nine precision traits.
#define GN_ONE_PASS_BLOCKS_PER_SM_DECLARATION(CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME) \
  GN_ONE_PASS_(GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION_NAME, Fp32IOFp16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, \
               PASS_NAME); \
  GN_ONE_PASS_(GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION_NAME, Fp32IOBf16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, \
               PASS_NAME); \
  GN_ONE_PASS_(GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION_NAME, Fp32IOFp32W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, \
               PASS_NAME); \
  GN_ONE_PASS_(GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION_NAME, Fp16IOFp16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, \
               PASS_NAME); \
  GN_ONE_PASS_(GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION_NAME, Fp16IOBf16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, \
               PASS_NAME); \
  GN_ONE_PASS_(GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION_NAME, Fp16IOFp32W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, \
               PASS_NAME); \
  GN_ONE_PASS_(GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION_NAME, Bf16IOFp16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, \
               PASS_NAME); \
  GN_ONE_PASS_(GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION_NAME, Bf16IOBf16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, \
               PASS_NAME); \
  GN_ONE_PASS_(GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION_NAME, Bf16IOFp32W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME);

// Defines both the launchers and the occupancy queries for one (CHANNELS_PER_GROUP, THREADS_PER_BLOCK, pass).
#define GN_ONE_PASS_DEFINITION(CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME) \
  GN_ONE_PASS_RUN_DEFINITION(CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME) \
  GN_ONE_PASS_BLOCKS_PER_SM_DEFINITION(CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME)

// Forward-pass (`fwd`) definitions.
#define GN_FWD_ONE_PASS_DEFINITION(CHANNELS_PER_GROUP, THREADS_PER_BLOCK) \
  GN_ONE_PASS_DEFINITION(CHANNELS_PER_GROUP, THREADS_PER_BLOCK, fwd)

// Backward-pass (`bwd`) definitions.
#define GN_BWD_ONE_PASS_DEFINITION(CHANNELS_PER_GROUP, THREADS_PER_BLOCK) \
  GN_ONE_PASS_DEFINITION(CHANNELS_PER_GROUP, THREADS_PER_BLOCK, bwd)

// Forward and backward definitions; this is what each group_norm_nhwc_one_pass_<CPG>.cu file invokes.
#define GN_FWD_BWD_ONE_PASS_DEFINITION(CHANNELS_PER_GROUP, THREADS_PER_BLOCK) \
  GN_FWD_ONE_PASS_DEFINITION(CHANNELS_PER_GROUP, THREADS_PER_BLOCK) \
  GN_BWD_ONE_PASS_DEFINITION(CHANNELS_PER_GROUP, THREADS_PER_BLOCK)

////////////////////////////////////////////////////////////////////////////////////////////////////

// One link of an if / else-if dispatch chain (expects a variable named `params` in scope). If params.hw >=
// HW_THRESHOLD and both params.channels_per_group and params.precision match, it assigns to `function` the
// one-pass symbol whose name ends in FUNC_POSTFIX (e.g. `_run` or `_blocks_per_sm`, the suffixes defined above).
// The expansion ends in `else`, so the caller must terminate the chain with a final statement.
#define GN_SELECTION_STATEMENT(function, Traits, PRECISION, FUNC_POSTFIX, HW_THRESHOLD, ACTS_PER_BLOCK, \
                               CHANNELS_PER_GROUP, PASS_NAME) \
  if (params.hw >= HW_THRESHOLD && params.channels_per_group == CHANNELS_PER_GROUP && \
      params.precision == PrecisionMode::PRECISION) { \
    function = \
        group_norm_nhwc_##PASS_NAME##_one_pass_##CHANNELS_PER_GROUP##_##ACTS_PER_BLOCK##_##Traits##FUNC_POSTFIX; \
  } else

// Same as GN_SELECTION_STATEMENT, with the extra (compile-time constant) condition CHANNELS_PER_GROUP >= LIMIT_CPG.
#define GN_SELECTION_STATEMENT_WITH_CPG_LIMIT(function, Traits, PRECISION, FUNC_POSTFIX, HW_THRESHOLD, ACTS_PER_BLOCK, \
                                              CHANNELS_PER_GROUP, PASS_NAME, LIMIT_CPG) \
  if (params.hw >= HW_THRESHOLD && params.channels_per_group == CHANNELS_PER_GROUP && \
      params.precision == PrecisionMode::PRECISION && CHANNELS_PER_GROUP >= LIMIT_CPG) { \
    function = \
        group_norm_nhwc_##PASS_NAME##_one_pass_##CHANNELS_PER_GROUP##_##ACTS_PER_BLOCK##_##Traits##FUNC_POSTFIX; \
  } else

// Chooses ACTS_PER_BLOCK for one precision trait. The first matching link wins:
//   hw >= 1024 and CPG >= 80  -> 128
//   hw >= 256  and CPG >= 160 -> 128
//   hw >= 512                 -> 512
//   hw >= 256                 -> 256
//   hw >= 128                 -> 128
//   otherwise                 -> 64
// The expansion still ends in `else`.
#define GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Traits, PRECISION, CHANNELS_PER_GROUP, \
                                                                    FUNC_POSTFIX, function, PASS_NAME) \
  GN_SELECTION_STATEMENT_WITH_CPG_LIMIT(function, Traits, PRECISION, FUNC_POSTFIX, 1024, 128, CHANNELS_PER_GROUP, \
                                        PASS_NAME, 80) \
  GN_SELECTION_STATEMENT_WITH_CPG_LIMIT(function, Traits, PRECISION, FUNC_POSTFIX, 256, 128, CHANNELS_PER_GROUP, \
                                        PASS_NAME, 160) \
  GN_SELECTION_STATEMENT(function, Traits, PRECISION, FUNC_POSTFIX, 512, 512, CHANNELS_PER_GROUP, PASS_NAME) \
  GN_SELECTION_STATEMENT(function, Traits, PRECISION, FUNC_POSTFIX, 256, 256, CHANNELS_PER_GROUP, PASS_NAME) \
  GN_SELECTION_STATEMENT(function, Traits, PRECISION, FUNC_POSTFIX, 128, 128, CHANNELS_PER_GROUP, PASS_NAME) \
  GN_SELECTION_STATEMENT(function, Traits, PRECISION, FUNC_POSTFIX, 0, 64, CHANNELS_PER_GROUP, PASS_NAME)

// Forward-pass dispatch over all nine precision traits for one CHANNELS_PER_GROUP.
#define GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(CHANNELS_PER_GROUP, FUNC_POSTFIX, function) \
  GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Fp32IOFp16W, FP32IOFP16W, CHANNELS_PER_GROUP, \
                                                              FUNC_POSTFIX, function, fwd) \
  GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Fp32IOBf16W, FP32IOBF16W, CHANNELS_PER_GROUP, \
                                                              FUNC_POSTFIX, function, fwd) \
  GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Fp32IOFp32W, FP32IOFP32W, CHANNELS_PER_GROUP, \
                                                              FUNC_POSTFIX, function, fwd) \
  GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Fp16IOFp16W, FP16IOFP16W, CHANNELS_PER_GROUP, \
                                                              FUNC_POSTFIX, function, fwd) \
  GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Fp16IOBf16W, FP16IOBF16W, CHANNELS_PER_GROUP, \
                                                              FUNC_POSTFIX, function, fwd) \
  GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Fp16IOFp32W, FP16IOFP32W, CHANNELS_PER_GROUP, \
                                                              FUNC_POSTFIX, function, fwd) \
  GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Bf16IOFp16W, BF16IOFP16W, CHANNELS_PER_GROUP, \
                                                              FUNC_POSTFIX, function, fwd) \
  GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Bf16IOBf16W, BF16IOBF16W, CHANNELS_PER_GROUP, \
                                                              FUNC_POSTFIX, function, fwd) \
  GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Bf16IOFp32W, BF16IOFP32W, CHANNELS_PER_GROUP, \
                                                              FUNC_POSTFIX, function, fwd)

// Backward-pass dispatch over all nine precision traits for one CHANNELS_PER_GROUP.
#define GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(CHANNELS_PER_GROUP, FUNC_POSTFIX, function) \
  GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Fp32IOFp16W, FP32IOFP16W, CHANNELS_PER_GROUP, \
                                                              FUNC_POSTFIX, function, bwd) \
  GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Fp32IOBf16W, FP32IOBF16W, CHANNELS_PER_GROUP, \
                                                              FUNC_POSTFIX, function, bwd) \
  GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Fp32IOFp32W, FP32IOFP32W, CHANNELS_PER_GROUP, \
                                                              FUNC_POSTFIX, function, bwd) \
  GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Fp16IOFp16W, FP16IOFP16W, CHANNELS_PER_GROUP, \
                                                              FUNC_POSTFIX, function, bwd) \
  GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Fp16IOBf16W, FP16IOBF16W, CHANNELS_PER_GROUP, \
                                                              FUNC_POSTFIX, function, bwd) \
  GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Fp16IOFp32W, FP16IOFP32W, CHANNELS_PER_GROUP, \
                                                              FUNC_POSTFIX, function, bwd) \
  GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Bf16IOFp16W, BF16IOFP16W, CHANNELS_PER_GROUP, \
                                                              FUNC_POSTFIX, function, bwd) \
  GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Bf16IOBf16W, BF16IOBF16W, CHANNELS_PER_GROUP, \
                                                              FUNC_POSTFIX, function, bwd) \
  GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Bf16IOFp32W, BF16IOFP32W, CHANNELS_PER_GROUP, \
                                                              FUNC_POSTFIX, function, bwd)

////////////////////////////////////////////////////////////////////////////////////////////////////

// Declares the launchers and occupancy queries for one CHANNELS_PER_GROUP and pass. THREADS_PER_BLOCK is passed as
// a dummy 0 because the *_NAME signature macros do not use it.
#define GN_ONE_PASS_DECLARATION(CHANNELS_PER_GROUP, PASS_NAME) \
  GN_ONE_PASS_RUN_DECLARATION(CHANNELS_PER_GROUP, /* dummy value */ 0, PASS_NAME) \
  GN_ONE_PASS_BLOCKS_PER_SM_DECLARATION(CHANNELS_PER_GROUP, /* dummy value */ 0, PASS_NAME)

// Forward / backward declaration shorthands.
#define GN_FWD_ONE_PASS_DECLARATION(CHANNELS_PER_GROUP) GN_ONE_PASS_DECLARATION(CHANNELS_PER_GROUP, fwd)
#define GN_BWD_ONE_PASS_DECLARATION(CHANNELS_PER_GROUP) GN_ONE_PASS_DECLARATION(CHANNELS_PER_GROUP, bwd)

////////////////////////////////////////////////////////////////////////////////////////////////////

// Launches the two-pass `Kernel` specialized for params.channels_per_block. Supported values are 320, 280, 208,
// 240, 512, 448, 384, 256, 128, 336 and 392; expects a variable named `params` in scope.
// NOTE(review): the template arguments and the `<<<...>>>` launch configuration appear to have been stripped in
// this copy (shown as `Kernel<<>>`); verify against the upstream file.
// NOTE(review): an unsupported channels_per_block reaches `assert(false)`, which compiles out under NDEBUG, so in
// release builds no kernel is launched and no error is reported.
#define CALL_TWO_PASS_KERNEL(Kernel, Precision) \
  if (params.channels_per_block == 320) { \
    Kernel<<>>(params); \
  } else if (params.channels_per_block == 280) { \
    Kernel<<>>(params); \
  } else if (params.channels_per_block == 208) { \
    Kernel<<>>(params); \
  } else if (params.channels_per_block == 240) { \
    Kernel<<>>(params); \
  } else if (params.channels_per_block == 512) { \
    Kernel<<>>(params); \
  } else if (params.channels_per_block == 448) { \
    Kernel<<>>(params); \
  } else if (params.channels_per_block == 384) { \
    Kernel<<>>(params); \
  } else if (params.channels_per_block == 256) { \
    Kernel<<>>(params); \
  } else if (params.channels_per_block == 128) { \
    Kernel<<>>(params); \
  } else if (params.channels_per_block == 336) { \
    Kernel<<>>(params); \
  } else if (params.channels_per_block == 392) { \
    Kernel<<>>(params); \
  } else { \
    assert(false); \
  }

////////////////////////////////////////////////////////////////////////////////////////////////////
================================================
FILE: apex/contrib/csrc/group_norm/traits.h
================================================
/*
 * SPDX-FileCopyrightText: Copyright (c) 2025
NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause */ #pragma once #include #include #include #include #include #include #include //////////////////////////////////////////////////////////////////////////////////////////////////// struct Fp32 { // Type is float32_t using Type = float; // Doubled type using Type2 = float2; // Unpack input to accumulators type static inline __device__ float2 unpack(const float2& f2) { return f2; } // Pack the accumulators into outputs. static inline __device__ float2 pack(const float2& f2) { return f2; } static inline __device__ float2 zero() { return {0.f, 0.f}; } }; //////////////////////////////////////////////////////////////////////////////////////////////////// struct Fp16 { // Type is __half using Type = __half; // Doubled type using Type2 = __half2; // Unpack input to accumulators type static inline __device__ float2 unpack(const __half2& h2) { // FIXME(nkorobov): __half22float2 makes compilation error in container return {__half2float(h2.x), __half2float(h2.y)}; } // Pack the accumulators into outputs. static inline __device__ __half2 pack(const float2& f2) { // FIXME(nkorobov): __float22half2_rn makes compilation error in container return {__float2half_rn(f2.x), __float2half_rn(f2.y)}; } static inline __device__ __half2 zero() { uint32_t zero = 0; return *reinterpret_cast<__half2*>(&zero); } }; //////////////////////////////////////////////////////////////////////////////////////////////////// struct Bf16 { // Type is __nv_bfloat16 using Type = __nv_bfloat16; // Doubled type using Type2 = __nv_bfloat162; // Unpack input to accumulators type static inline __device__ float2 unpack(const __nv_bfloat162& h2) { // FIXME(nkorobov): __half22float2 makes compilation error in container return {__bfloat162float(h2.x), __bfloat162float(h2.y)}; } // Pack the accumulators into outputs. 
static inline __device__ __nv_bfloat162 pack(const float2& f2) { // FIXME(nkorobov): __float22bfloat162_rn makes compilation error in container return {__float2bfloat16_rn(f2.x), __float2bfloat16_rn(f2.y)}; } static inline __device__ __nv_bfloat162 zero() { uint32_t zero = 0; return *reinterpret_cast<__nv_bfloat162*>(&zero); } }; //////////////////////////////////////////////////////////////////////////////////////////////////// struct Fp32IOFp16W { // IO traits using IOTraits = Fp32; // Weights traits using WTraits = Fp16; }; struct Fp32IOBf16W { // IO traits using IOTraits = Fp32; // Weights traits using WTraits = Bf16; }; struct Fp32IOFp32W { // IO traits using IOTraits = Fp32; // Weights traits using WTraits = Fp32; }; //////////////////////////////////////////////////////////////////////////////////////////////////// struct Fp16IOFp16W { // IO traits using IOTraits = Fp16; // Weights traits using WTraits = Fp16; }; struct Fp16IOBf16W { // IO traits using IOTraits = Fp16; // Weights traits using WTraits = Bf16; }; struct Fp16IOFp32W { // IO traits using IOTraits = Fp16; // Weights traits using WTraits = Fp32; }; //////////////////////////////////////////////////////////////////////////////////////////////////// struct Bf16IOFp16W { // IO traits using IOTraits = Bf16; // Weights traits using WTraits = Fp16; }; struct Bf16IOBf16W { // IO traits using IOTraits = Bf16; // Weights traits using WTraits = Bf16; }; struct Bf16IOFp32W { // IO traits using IOTraits = Bf16; // Weights traits using WTraits = Fp32; }; //////////////////////////////////////////////////////////////////////////////////////////////////// ================================================ FILE: apex/contrib/csrc/group_norm_v2/generate_gn_cuda_inst.py ================================================ import pathlib hw_c_list = [ (8 * 8, 1280), (8 * 8, 2560), (16 * 16, 640), (16 * 16, 1280), (16 * 16, 1920), (16 * 16, 2560), (32 * 32, 320), (32 * 32, 640), (32 * 32, 960), (32 * 32, 1280), (32 * 32, 
1920), (64 * 64, 320), (64 * 64, 640), (64 * 64, 960), ] def run(): src_path = pathlib.Path(__file__).parent.absolute() for f in src_path.glob("gn_cuda_inst_*.cu"): f.unlink() for hw, c in hw_c_list: print(f"GN_CUDA_INST_DEFINE({hw}, {c})") with open(src_path / f"gn_cuda_inst_{hw}_{c}.cu", "w") as f: f.write('#include "gn_cuda_host_template.cuh"\n') f.write("\n") f.write("\n") f.write("namespace group_norm_v2 {\n") f.write("\n") f.write(f"GN_CUDA_INST_DEFINE({hw}, {c})\n") f.write("\n") f.write("} // namespace group_norm_v2\n") with open(src_path / "gn_dispatch_hw_c.hpp", "w") as f: f.write("#pragma once\n") f.write("\n") f.write("#define DISPATCH_HW_C(hw, c, HW, C, ...) [&] { \\\n") for hw, c in hw_c_list: f.write( f" if (hw == {hw} && c == {c}) {{ constexpr int HW = {hw}, C = {c}; return __VA_ARGS__(); }} \\\n" ) f.write( ' throw std::invalid_argument("DISPATCH_HW_C " + std::to_string(hw) + " " + std::to_string(c)); \\\n' ) f.write(" }()\n") if __name__ == "__main__": run() ================================================ FILE: apex/contrib/csrc/group_norm_v2/gn.cpp ================================================ #include "gn.hpp" #include #include namespace group_norm_v2 { /* Group-norm forward, optionally fused with SiLU, for half or bfloat16 x/w/b (w and b must share x's supported dtype). x.size(1) is used as the channel count and x.size(2) * x.size(3) as HW. gn_cuda is called twice: first with meta_only=true to size the float32 reduction buffer and the uint32 barrier, then to launch on the current CUDA stream. If mean_var_out is given it must be float32; its data pointer is forwarded to gn_cuda as mean_var_out. Throws std::invalid_argument on dtype mismatch, unsupported dtypes, or a num_groups that is not a positive divisor of x.size(1). */ torch::Tensor gn(torch::Tensor x, torch::Tensor w, torch::Tensor b, float eps, bool silu, int num_groups, std::optional mean_var_out, int sm_margin) { if (w.dtype() != b.dtype() || (mean_var_out.has_value() && mean_var_out->dtype() != torch::kFloat32)) { throw std::invalid_argument("gn dtype mismatch"); } /* x.size(1) / num_groups is passed as channels_per_group below; reject inputs where that division would truncate (or divide by zero) instead of silently dispatching on a wrong channel count. */ if (num_groups <= 0 || x.size(1) % num_groups != 0) { throw std::invalid_argument("gn num_groups must be positive and divide x.size(1)"); } torch::Tensor out = torch::empty_like(x); float* ptr_mean_var_out = mean_var_out.has_value() ? 
mean_var_out->data_ptr() : nullptr; cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); int device_id = at::cuda::getCurrentCUDAStream().device().index(); group_norm_v2::Meta meta; if (x.dtype() == torch::kHalf && w.dtype() == torch::kHalf) { group_norm_v2::gn_cuda((half*)out.data_ptr(), (half*)x.data_ptr(), (half*)w.data_ptr(), (half*)b.data_ptr(), eps, silu, x.size(0), x.size(2) * x.size(3), num_groups, x.size(1) / num_groups, ptr_mean_var_out, nullptr, nullptr, sm_margin, stream, device_id, &meta, true); } else if (x.dtype() == torch::kBFloat16 && w.dtype() == torch::kBFloat16) { group_norm_v2::gn_cuda((__nv_bfloat16*)out.data_ptr(), (__nv_bfloat16*)x.data_ptr(), (__nv_bfloat16*)w.data_ptr(), (__nv_bfloat16*)b.data_ptr(), eps, silu, x.size(0), x.size(2) * x.size(3), num_groups, x.size(1) / num_groups, ptr_mean_var_out, nullptr, nullptr, sm_margin, stream, device_id, &meta, true); } else { throw std::invalid_argument("gn only supports half or bfloat16 input and weight"); } torch::Tensor red_buffer = torch::empty({meta.red_buffer_size}, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA)); thread_local torch::Tensor barrier; /* The barrier is cached per thread and reused across calls. Reallocate it when it has not been created yet (an undefined tensor has no size(0)), when it lives on a different device than the current stream, or when it is too small. NOTE(review): reuse relies on the kernel leaving the barrier zeroed after each launch, and on calls from one thread not overlapping on different streams; confirm. */ if (!barrier.defined() || barrier.device().index() != device_id || barrier.size(0) < meta.barrier_size) { barrier = torch::zeros({meta.barrier_size}, torch::TensorOptions().dtype(torch::kUInt32).device(torch::kCUDA)); } if (x.dtype() == torch::kHalf && w.dtype() == torch::kHalf) { group_norm_v2::gn_cuda((half*)out.data_ptr(), (half*)x.data_ptr(), (half*)w.data_ptr(), (half*)b.data_ptr(), eps, silu, x.size(0), x.size(2) * x.size(3), num_groups, x.size(1) / num_groups, ptr_mean_var_out, red_buffer.data_ptr(), barrier.data_ptr(), sm_margin, stream, device_id, nullptr, false); } else if (x.dtype() == torch::kBFloat16 && w.dtype() == torch::kBFloat16) { group_norm_v2::gn_cuda((__nv_bfloat16*)out.data_ptr(), (__nv_bfloat16*)x.data_ptr(), (__nv_bfloat16*)w.data_ptr(), (__nv_bfloat16*)b.data_ptr(), eps, silu, x.size(0), x.size(2) * x.size(3), num_groups, x.size(1) / num_groups, 
ptr_mean_var_out, red_buffer.data_ptr(), barrier.data_ptr(), sm_margin, stream, device_id, nullptr, false); } else { throw std::invalid_argument("gn only supports half or bfloat16 input and weight"); } return out; } /* Group-norm backward; returns (grad_input, grad_weight, grad_bias). grad_output must match x's dtype and mean_var must be float32 (presumably the statistics written by gn through mean_var_out; confirm with callers). Uses the same two-call meta/launch pattern and the same channel / HW interpretation as gn. Throws std::invalid_argument on dtype mismatch, unsupported dtypes, or a num_groups that is not a positive divisor of x.size(1). */ auto gn_bwd(torch::Tensor grad_output, torch::Tensor x, torch::Tensor w, torch::Tensor b, torch::Tensor mean_var, float eps, bool silu, int num_groups, int sm_margin) { if (w.dtype() != b.dtype() || x.dtype() != grad_output.dtype() || mean_var.dtype() != torch::kFloat32) { throw std::invalid_argument("gn_bwd dtype mismatch"); } /* Same channel-count validation as the forward pass. */ if (num_groups <= 0 || x.size(1) % num_groups != 0) { throw std::invalid_argument("gn_bwd num_groups must be positive and divide x.size(1)"); } torch::Tensor grad_input = torch::empty_like(x); torch::Tensor grad_weight = torch::empty_like(w); torch::Tensor grad_bias = torch::empty_like(w); cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); int device_id = at::cuda::getCurrentCUDAStream().device().index(); group_norm_v2::Meta meta; if (x.dtype() == torch::kHalf && w.dtype() == torch::kHalf) { group_norm_v2::gn_bwd_cuda((half*)grad_input.data_ptr(), (half*)grad_weight.data_ptr(), (half*)grad_bias.data_ptr(), (half*)grad_output.data_ptr(), (half*)x.data_ptr(), (half*)w.data_ptr(), (half*)b.data_ptr(), mean_var.data_ptr(), eps, silu, x.size(0), x.size(2) * x.size(3), num_groups, x.size(1) / num_groups, nullptr, nullptr, sm_margin, stream, device_id, &meta, true); } else if (x.dtype() == torch::kBFloat16 && w.dtype() == torch::kBFloat16) { group_norm_v2::gn_bwd_cuda((__nv_bfloat16*)grad_input.data_ptr(), (__nv_bfloat16*)grad_weight.data_ptr(), (__nv_bfloat16*)grad_bias.data_ptr(), (__nv_bfloat16*)grad_output.data_ptr(), (__nv_bfloat16*)x.data_ptr(), (__nv_bfloat16*)w.data_ptr(), (__nv_bfloat16*)b.data_ptr(), mean_var.data_ptr(), eps, silu, x.size(0), x.size(2) * x.size(3), num_groups, x.size(1) / num_groups, nullptr, nullptr, sm_margin, stream, device_id, &meta, true); } else { throw std::invalid_argument("gn only supports half or bfloat16 input and weight"); } torch::Tensor red_buffer = torch::empty({meta.red_buffer_size}, 
torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA)); thread_local torch::Tensor barrier; /* Per-thread cached barrier: reallocate when missing, on a different device than the current stream, or too small (see the NOTE(review) assumptions on the forward pass cache; they apply here too). */ if (!barrier.defined() || barrier.device().index() != device_id || barrier.size(0) < meta.barrier_size) { barrier = torch::zeros({meta.barrier_size}, torch::TensorOptions().dtype(torch::kUInt32).device(torch::kCUDA)); } if (x.dtype() == torch::kHalf && w.dtype() == torch::kHalf) { group_norm_v2::gn_bwd_cuda((half*)grad_input.data_ptr(), (half*)grad_weight.data_ptr(), (half*)grad_bias.data_ptr(), (half*)grad_output.data_ptr(), (half*)x.data_ptr(), (half*)w.data_ptr(), (half*)b.data_ptr(), mean_var.data_ptr(), eps, silu, x.size(0), x.size(2) * x.size(3), num_groups, x.size(1) / num_groups, red_buffer.data_ptr(), barrier.data_ptr(), sm_margin, stream, device_id, nullptr, false); } else if (x.dtype() == torch::kBFloat16 && w.dtype() == torch::kBFloat16) { group_norm_v2::gn_bwd_cuda((__nv_bfloat16*)grad_input.data_ptr(), (__nv_bfloat16*)grad_weight.data_ptr(), (__nv_bfloat16*)grad_bias.data_ptr(), (__nv_bfloat16*)grad_output.data_ptr(), (__nv_bfloat16*)x.data_ptr(), (__nv_bfloat16*)w.data_ptr(), (__nv_bfloat16*)b.data_ptr(), mean_var.data_ptr(), eps, silu, x.size(0), x.size(2) * x.size(3), num_groups, x.size(1) / num_groups, red_buffer.data_ptr(), barrier.data_ptr(), sm_margin, stream, device_id, nullptr, false); } else { throw std::invalid_argument("gn only supports half or bfloat16 input and weight"); } return std::make_tuple(grad_input, grad_weight, grad_bias); } } // namespace group_norm_v2 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("gn", &group_norm_v2::gn, py::arg("x"), py::arg("w"), py::arg("b"), py::arg("eps"), py::arg("silu"), py::arg("num_groups"), py::arg("mean_var_out") = py::none(), py::arg("sm_margin") = 0, ""); m.def("gn_bwd", &group_norm_v2::gn_bwd, py::arg("grad_output"), py::arg("x"), py::arg("w"), py::arg("b"), py::arg("mean_var"), py::arg("eps"), py::arg("silu"), py::arg("num_groups"), py::arg("sm_margin") = 0, ""); } ================================================ FILE: 
apex/contrib/csrc/group_norm_v2/gn.hpp ================================================ #pragma once #include #include namespace group_norm_v2 { /* Launch metadata written by gn_cuda / gn_bwd_cuda when meta_ptr is non-null: red_buffer_size and barrier_size are element counts for the float* red_buffer and unsigned* barrier scratch buffers the caller must provide, and the remaining fields report the compile-time tiling chosen for the shape. */ struct Meta { int64_t red_buffer_size; int64_t barrier_size; int BLOCK_DIM_X; int C_PER_BLOCK; int ROWS_PER_BLOCK; int VEC_ELEMS; bool LOAD_TWICE; int BLOCKS_PER_SM; bool HARDWARE_CLUSTER; int wgrad_sync_method; }; /* Forward / backward entry points. With meta_ptr non-null they fill *meta_ptr; with meta_only=true they return before launching anything. Otherwise they launch on the given stream. */ template void gn_cuda(T* out, T* x, T* w, T* b, float eps, bool silu, int64_t n, int64_t hw, int num_groups, int channels_per_group, float* mean_var_out, float* red_buffer, unsigned* barrier, int sm_margin, cudaStream_t stream, int device_id, Meta* meta_ptr, bool meta_only); template void gn_bwd_cuda(T* grad_input, T* grad_weight, T* grad_bias, T* grad_output, T* x, T* w, T* b, float* mean_var, float eps, bool silu, int64_t n, int64_t hw, int num_groups, int channels_per_group, float* red_buffer, unsigned* barrier, int sm_margin, cudaStream_t stream, int device_id, Meta* meta_ptr, bool meta_only); } // namespace group_norm_v2 ================================================ FILE: apex/contrib/csrc/group_norm_v2/gn_cuda.cu ================================================ #include #include #include #include #include #include #include "gn.hpp" #include "gn_dispatch_hw_c.hpp" #include "gn_utils.hpp" #define DISPATCH_NUM_GROUPS_AND_SILU(num_groups, silu, NUM_GROUPS, SILU, ...) 
\ [&] { \ if (num_groups == 16 && silu == true) { \ constexpr int NUM_GROUPS = 16; \ constexpr bool SILU = true; \ return __VA_ARGS__(); \ } \ if (num_groups == 32 && silu == false) { \ constexpr int NUM_GROUPS = 32; \ constexpr bool SILU = false; \ return __VA_ARGS__(); \ } \ throw std::invalid_argument("DISPATCH_NUM_GROUPS_AND_SILU " + std::to_string(num_groups) + " " + \ std::to_string(silu)); \ }() namespace group_norm_v2 { template void gn_cuda_single_shape(GN_CUDA_HOST_PARAMS(T)); template void gn_bwd_cuda_single_shape(GN_BWD_CUDA_HOST_PARAMS(T)); /* Runtime-to-compile-time dispatch: (hw, C = num_groups * channels_per_group) must be one of the DISPATCH_HW_C shapes and (num_groups, silu) must be (16, true) or (32, false); anything else throws std::invalid_argument from the dispatch macros. */ template void gn_cuda(GN_CUDA_HOST_PARAMS(T)) { DISPATCH_HW_C(hw, num_groups * channels_per_group, HW, C, [&] { DISPATCH_NUM_GROUPS_AND_SILU(num_groups, silu, G, SILU, [&] { return gn_cuda_single_shape(GN_CUDA_HOST_ARGS); }); }); } template void gn_bwd_cuda(GN_BWD_CUDA_HOST_PARAMS(T)) { DISPATCH_HW_C(hw, num_groups * channels_per_group, HW, C, [&] { DISPATCH_NUM_GROUPS_AND_SILU(num_groups, silu, G, SILU, [&] { return gn_bwd_cuda_single_shape(GN_BWD_CUDA_HOST_ARGS); }); }); } template void gn_cuda(GN_CUDA_HOST_PARAMS(half)); template void gn_cuda(GN_CUDA_HOST_PARAMS(__nv_bfloat16)); template void gn_bwd_cuda(GN_BWD_CUDA_HOST_PARAMS(half)); template void gn_bwd_cuda(GN_BWD_CUDA_HOST_PARAMS(__nv_bfloat16)); } // namespace group_norm_v2 ================================================ FILE: apex/contrib/csrc/group_norm_v2/gn_cuda_host_template.cuh ================================================ #pragma once #include #include #include #include #include #include "gn_cuda_kernel.cuh" #include "gn_utils.hpp" namespace group_norm_v2 { /* DISPATCH_LOWER_BOUND_N rounds VALUE down to the nearest of 16, 8, 4, 2, 1 and exposes it as the constexpr CONST_NAME; VALUE < 1 throws std::invalid_argument. */ #define DISPATCH_LOWER_BOUND_N(VALUE, CONST_NAME, ...) 
\ [&] { \ if (VALUE >= 16) { \ constexpr int CONST_NAME = 16; \ return __VA_ARGS__(); \ } \ if (VALUE >= 8) { \ constexpr int CONST_NAME = 8; \ return __VA_ARGS__(); \ } \ if (VALUE >= 4) { \ constexpr int CONST_NAME = 4; \ return __VA_ARGS__(); \ } \ if (VALUE >= 2) { \ constexpr int CONST_NAME = 2; \ return __VA_ARGS__(); \ } \ if (VALUE >= 1) { \ constexpr int CONST_NAME = 1; \ return __VA_ARGS__(); \ } \ throw std::invalid_argument("DISPATCH_LOWER_BOUND_N " + std::to_string(VALUE)); \ }() /* Only runtime_cuda_arch == 1000 with at least 148 SMs is dispatched; any other device throws std::invalid_argument. */ #define DISPATCH_CUDA_ARCH_AND_LOWER_BOUND_SM_COUNT(runtime_cuda_arch, sm_count, RUNTIME_CUDA_ARCH, LB_SM_COUNT, ...) \ [&] { \ if (runtime_cuda_arch == 1000 && sm_count >= 148) { \ constexpr int RUNTIME_CUDA_ARCH = 1000, LB_SM_COUNT = 148; \ return __VA_ARGS__(); \ } \ throw std::invalid_argument("DISPATCH_CUDA_ARCH_AND_LOWER_BOUND_SM_COUNT " + std::to_string(runtime_cuda_arch) + \ " " + std::to_string(sm_count)); \ }() /* sm_margin must be exactly 0 or 32. */ #define DISPATCH_SM_MARGIN(VALUE, CONST_NAME, ...) \ [&] { \ if (VALUE == 0) { \ constexpr int CONST_NAME = 0; \ return __VA_ARGS__(); \ } \ if (VALUE == 32) { \ constexpr int CONST_NAME = 32; \ return __VA_ARGS__(); \ } \ throw std::invalid_argument("DISPATCH_SM_MARGIN " + std::to_string(VALUE)); \ }() /* Largest entry of __CUDA_ARCH_LIST__, i.e. the newest architecture this translation unit is compiled for. */ inline constexpr int get_max_cuda_arch() { int cuda_arch_list[] = {__CUDA_ARCH_LIST__}; int max_cuda_arch = -1; for (int cuda_arch_item : cuda_arch_list) { if (cuda_arch_item > max_cuda_arch) { max_cuda_arch = cuda_arch_item; } } return max_cuda_arch; } /* Compile-time tiling heuristic. Returns (BLOCK_DIM_X, C_PER_BLOCK, ROWS_PER_BLOCK, VEC_ELEMS, LOAD_TWICE, BLOCKS_PER_SM, HARDWARE_CLUSTER, wgrad_sync_method); wgrad_sync_method is never reassigned here, so it is always WGRAD_SYNC_UNSPECIFIED. */ template constexpr auto compute_gn_params() { constexpr int C = G * CPG; // Initialize each variable to comply with C++17 int BLOCK_DIM_X = 0; int C_PER_BLOCK = 0; int ROWS_PER_BLOCK = 0; bool LOAD_TWICE = false; int BLOCKS_PER_SM = 0; WgradSyncMethod wgrad_sync_method = WGRAD_SYNC_UNSPECIFIED; // There are two tiling strategies: // - block sync: each block handles a whole group, i.e., a multiple of (G * HW) elements // - virtual cluster sync: each virtual cluster handles a group // Block sync can avoid 
cross-block synchronization latency, but it may cause low occupancy. // Use block sync if the IO size is small, when latency rather than occupancy dominates the kernel running time. // Elements to load for forward pass is `x`, elements to load for backward pass are `x` and `grad_output`, hence there // is a factor of (1 + BWD) if (HW * CPG * (1 + BWD) * sizeof(T) <= 20480) { // Strategy 1: block sync C_PER_BLOCK = CPG; ROWS_PER_BLOCK = HW; BLOCK_DIM_X = lcm(32, C_PER_BLOCK); while (BLOCK_DIM_X < 256) { BLOCK_DIM_X *= 2; } BLOCKS_PER_SM = 1; // The size of registers is 65536 registers * 4 bytes per register. // We have to leave some room for other variables and compiler optimizations, // so we use 36000 as the threshold. LOAD_TWICE = BLOCKS_PER_SM * ROWS_PER_BLOCK * C_PER_BLOCK * (1 + BWD) * sizeof(T) > 36000 * 4; } else { // Strategy 2: virtual cluster sync // A virtual cluster is a group of blocks that are synchronized with each other. // Each group, i.e., a multiple of (G * HW) elements, should be handled on the same virtual cluster. // If the virtual cluster size is supported by the hardware, HARDWARE_CLUSTER is preferred; // otherwise, cooperative groups are used (i.e., PERSISTENT kernels). int c_per_cluster = lcm(128 / (int)sizeof(T), CPG); C_PER_BLOCK = c_per_cluster; BLOCK_DIM_X = C_PER_BLOCK == 320 ? 
320 : 480; // Maximum number of rows that should reside in registers int register_max_rows = 36000 * 4 / (C_PER_BLOCK * (1 + BWD) * sizeof(T)); /* Candidates are ranked by lexicographic tuple comparison; the leading true makes any feasible candidate beat the value-initialized starting tuple. wave_util below is in units of 0.01% of the available block slots and is capped at 9000 in the ranking. */ std::tuple best_candidate{}; BLOCKS_PER_SM = 0; ROWS_PER_BLOCK = 0; for (int blocks_per_sm = 1; blocks_per_sm <= 3; blocks_per_sm++) { for (int rows_per_block = HW; rows_per_block >= 1; rows_per_block /= 2) { int virtual_cluster_size = (HW / rows_per_block) * (c_per_cluster / C_PER_BLOCK); if (virtual_cluster_size > blocks_per_sm * (LB_SM_COUNT - SM_MARGIN)) { continue; } int num_clusters = blocks_per_sm * (LB_SM_COUNT - SM_MARGIN) / virtual_cluster_size; int num_tasks = LB_N * (C / c_per_cluster); int num_waves = up_div(num_tasks, num_clusters); bool load_twice = rows_per_block > register_max_rows / blocks_per_sm; // Wave utilization: the percent of SMs that are used for each wave // For example, SM_COUNT=100 and VIRTUAL_CLUSTER_SIZE=64, // if BLOCKS_PER_SM=1, num_clusters=1, wave_util=64%; // if BLOCKS_PER_SM=2, num_clusters=3, wave_util=96%. // This helps select a good number of BLOCKS_PER_SM int wave_util = 10000 * std::min(num_tasks, num_clusters) * virtual_cluster_size / (blocks_per_sm * (LB_SM_COUNT - SM_MARGIN)); decltype(best_candidate) candidate = { true, !load_twice, // Prefer no load twice !(num_waves >= 2 && blocks_per_sm == 1), // When there are multiple waves, prefer multiple blocks per SM to ensure overlapping -num_waves, // Prefer fewer waves std::min(9000, wave_util), // Prefer high wave utilization -blocks_per_sm, // Prefer fewer blocks per SM in order to reduce threads overhead }; if (candidate > best_candidate) { // Assign each element respectively to comply with C++17 std::get<0>(best_candidate) = std::get<0>(candidate); std::get<1>(best_candidate) = std::get<1>(candidate); std::get<2>(best_candidate) = std::get<2>(candidate); std::get<3>(best_candidate) = std::get<3>(candidate); std::get<4>(best_candidate) = std::get<4>(candidate); std::get<5>(best_candidate) = std::get<5>(candidate); 
static_assert(std::tuple_size::value == 6, "missing assignments"); BLOCKS_PER_SM = blocks_per_sm; ROWS_PER_BLOCK = rows_per_block; } } } /* NOTE(review): if no candidate was feasible, BLOCKS_PER_SM is still 0 here and this division fails constant evaluation, so such a shape does not compile. */ LOAD_TWICE = ROWS_PER_BLOCK > register_max_rows / BLOCKS_PER_SM; } int c_per_cluster = lcm(CPG, C_PER_BLOCK); int virtual_cluster_size = (c_per_cluster / C_PER_BLOCK) * (HW / ROWS_PER_BLOCK); // The occupancy is affected if cluster size is large. // For example, on H100, when gridDim=128 and each block occupies the whole SM, // if cluster is not used, all blocks can be active simultaneously. // if cluster size is 16, not all blocks can be active simultaneously (which can be queried by // cudaOccupancyMaxActiveClusters), // so there will be two waves which impacts efficiency. // When SM_MARGIN is set, no cluster should be used because other kernels may occupy a part of the cluster. /* virtual_cluster_size is at least 1, so this is true only for a virtual cluster of exactly 2 blocks with no SM margin. */ bool HARDWARE_CLUSTER = virtual_cluster_size <= 2 && virtual_cluster_size != 1 && SM_MARGIN == 0; int MAX_VEC_BYTES = 8; // Sometimes 4 or 16 is better, but there is no trivial way to select the best vectorization size. 
int VEC_ELEMS = std::min(gcd(MAX_VEC_BYTES / (int)sizeof(T), C_PER_BLOCK), gcd(MAX_VEC_BYTES / (int)sizeof(T), ROWS_PER_BLOCK * C_PER_BLOCK / BLOCK_DIM_X)); return std::make_tuple(BLOCK_DIM_X, C_PER_BLOCK, ROWS_PER_BLOCK, VEC_ELEMS, LOAD_TWICE, BLOCKS_PER_SM, HARDWARE_CLUSTER, wgrad_sync_method); } // Save compilation time for unused CUDA_ARCHs // For each template argument from DISPATCH_CUDA_ARCH_AND_LOWER_BOUND_SM_COUNT, the kernel is only compiled for the // corresponding CUDA_ARCH template class CompileCondition { public: __host__ __device__ static constexpr bool matches() { #if defined(__CUDA_ARCH__) return __CUDA_ARCH__ == EFFECTIVE_CUDA_ARCH; #else return false; #endif } }; /* Forward host launcher for one compiled (T, HW, C, G, SILU) shape. It re-checks the runtime arguments against the template constants (throwing std::invalid_argument on mismatch or if out aliases x), fills *meta_ptr when non-null, returns early when meta_only, and otherwise launches gn_cuda_kernel through cudaLaunchKernelEx with cluster and/or cooperative launch attributes as configured. */ template void gn_cuda_single_shape(GN_CUDA_HOST_PARAMS(T)) { if (out == x) { throw std::invalid_argument("not __restrict__"); } cudaDeviceProp const& deviceProp = get_device_prop(device_id); int runtime_cuda_arch = deviceProp.major * 100 + deviceProp.minor * 10; int sm_count = deviceProp.multiProcessorCount; DISPATCH_LOWER_BOUND_N(n, LB_N, [&] { DISPATCH_CUDA_ARCH_AND_LOWER_BOUND_SM_COUNT(runtime_cuda_arch, sm_count, RUNTIME_CUDA_ARCH, LB_SM_COUNT, [&] { DISPATCH_SM_MARGIN(sm_margin, SM_MARGIN, [&] { if (hw != HW) { throw std::invalid_argument("wrong HW"); } if (num_groups * channels_per_group != C) { throw std::invalid_argument("wrong C"); } if (num_groups != G) { throw std::invalid_argument("wrong G"); } if (silu != SILU) { throw std::invalid_argument("wrong SILU"); } if (n < LB_N) { throw std::invalid_argument("wrong LB_N"); } if (runtime_cuda_arch != RUNTIME_CUDA_ARCH) { throw std::invalid_argument("wrong RUNTIME_CUDA_ARCH"); } if (sm_count < LB_SM_COUNT) { throw std::invalid_argument("wrong LB_SM_COUNT"); } if (sm_margin != SM_MARGIN) { throw std::invalid_argument("wrong SM_MARGIN"); } constexpr int EFFECTIVE_CUDA_ARCH = std::min(RUNTIME_CUDA_ARCH, get_max_cuda_arch()); // Assume the max CUDA_ARCH is used to generate PTX constexpr int CPG = C / G; constexpr auto params = 
compute_gn_params(); constexpr int BLOCK_DIM_X = std::get<0>(params); constexpr int C_PER_BLOCK = std::get<1>(params); constexpr int ROWS_PER_BLOCK = std::get<2>(params); constexpr int VEC_ELEMS = std::get<3>(params); constexpr bool LOAD_TWICE = std::get<4>(params); constexpr int BLOCKS_PER_SM = std::get<5>(params); constexpr bool HARDWARE_CLUSTER = std::get<6>(params); constexpr int C_PER_CLUSTER = lcm(CPG, C_PER_BLOCK); constexpr int VIRTUAL_CLUSTER_SIZE = (C_PER_CLUSTER / C_PER_BLOCK) * (HW / ROWS_PER_BLOCK); constexpr int NUM_VIRTUAL_CLUSTERS = ((LB_SM_COUNT - SM_MARGIN) * BLOCKS_PER_SM) / VIRTUAL_CLUSTER_SIZE; constexpr bool PERSISTENT = !HARDWARE_CLUSTER && VIRTUAL_CLUSTER_SIZE >= 2; // Only virtual cluster sync (not include hardware cluster sync) requires PERSISTENT kernels if (meta_ptr) { constexpr int MAX_NUM_GROUPS_PER_BLOCK = C_PER_BLOCK % CPG == 0 ? C_PER_BLOCK / CPG : up_div(C_PER_BLOCK - gcd(C_PER_BLOCK, CPG), CPG) + 1; meta_ptr->red_buffer_size = 2 * NUM_VIRTUAL_CLUSTERS * VIRTUAL_CLUSTER_SIZE * MAX_NUM_GROUPS_PER_BLOCK * 2; meta_ptr->barrier_size = NUM_VIRTUAL_CLUSTERS; meta_ptr->BLOCK_DIM_X = BLOCK_DIM_X; meta_ptr->C_PER_BLOCK = C_PER_BLOCK; meta_ptr->ROWS_PER_BLOCK = ROWS_PER_BLOCK; meta_ptr->VEC_ELEMS = VEC_ELEMS; meta_ptr->LOAD_TWICE = LOAD_TWICE; meta_ptr->BLOCKS_PER_SM = BLOCKS_PER_SM; meta_ptr->HARDWARE_CLUSTER = HARDWARE_CLUSTER; meta_ptr->wgrad_sync_method = (int)WGRAD_SYNC_UNSPECIFIED; } if (meta_only) { return; } /* Grid: x = VIRTUAL_CLUSTER_SIZE blocks; y = one entry per (sample, channel-cluster) task, capped at NUM_VIRTUAL_CLUSTERS for persistent (cooperative) launches. */ cudaLaunchConfig_t config = {0}; config.gridDim = dim3( VIRTUAL_CLUSTER_SIZE, PERSISTENT ? 
std::min((int)n * (C / C_PER_CLUSTER), NUM_VIRTUAL_CLUSTERS) : n * (C / C_PER_CLUSTER), 1); config.blockDim = BLOCK_DIM_X; config.stream = stream; cudaLaunchAttribute attribute[2]; if constexpr (HARDWARE_CLUSTER) { attribute[0].id = cudaLaunchAttributeClusterDimension; attribute[0].val.clusterDim.x = VIRTUAL_CLUSTER_SIZE; // Cluster size in X-dimension attribute[0].val.clusterDim.y = 1; attribute[0].val.clusterDim.z = 1; config.attrs = attribute; config.numAttrs++; } if constexpr (PERSISTENT) { attribute[config.numAttrs].id = cudaLaunchAttributeCooperative; attribute[config.numAttrs].val.cooperative = 1; config.attrs = attribute; config.numAttrs++; } auto kernel = &gn_cuda_kernel >; if constexpr (HARDWARE_CLUSTER) { if constexpr (VIRTUAL_CLUSTER_SIZE > 8) { CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeNonPortableClusterSizeAllowed, 1)); } /* active_clusters is written and read under the same condition (cluster fits and PERSISTENT), so it is never read uninitialized; in this forward path HARDWARE_CLUSTER implies !PERSISTENT. */ int max_cluster_size; int active_clusters; CUDA_CHECK(cudaOccupancyMaxPotentialClusterSize(&max_cluster_size, (void*)kernel, &config)); if (VIRTUAL_CLUSTER_SIZE <= max_cluster_size && PERSISTENT) { attribute[0].val.clusterDim.x = VIRTUAL_CLUSTER_SIZE; CUDA_CHECK(cudaOccupancyMaxActiveClusters(&active_clusters, (void*)kernel, &config)); } if (VIRTUAL_CLUSTER_SIZE <= max_cluster_size && (!PERSISTENT || PERSISTENT && NUM_VIRTUAL_CLUSTERS <= active_clusters)) { attribute[0].val.clusterDim.x = VIRTUAL_CLUSTER_SIZE; } else { // Fallback to cooperative groups because hardware cluster cannot be active simultaneously constexpr bool HARDWARE_CLUSTER_NEW = false; constexpr bool PERSISTENT_NEW = !HARDWARE_CLUSTER_NEW && VIRTUAL_CLUSTER_SIZE >= 2; config.gridDim = dim3( VIRTUAL_CLUSTER_SIZE, PERSISTENT_NEW ? 
std::min((int)n * (C / C_PER_CLUSTER), NUM_VIRTUAL_CLUSTERS) : n * (C / C_PER_CLUSTER), 1); config.attrs = nullptr; config.numAttrs = 0; if constexpr (PERSISTENT_NEW) { attribute[config.numAttrs].id = cudaLaunchAttributeCooperative; attribute[config.numAttrs].val.cooperative = 1; config.attrs = attribute; config.numAttrs++; } kernel = &gn_cuda_kernel >; } } CUDA_CHECK(cudaLaunchKernelEx(&config, kernel, out, x, w, b, eps, n, mean_var_out, red_buffer, barrier)); }); }); }); } /* Backward host launcher for one compiled shape. Same validation and meta/launch structure as gn_cuda_single_shape, but PERSISTENT is fixed to true (always a cooperative launch), red_buffer gets an extra wgrad region and the barrier gets C / C_PER_CLUSTER extra words. Throws std::invalid_argument if grad_input aliases grad_output or x. */ template void gn_bwd_cuda_single_shape(GN_BWD_CUDA_HOST_PARAMS(T)) { if (grad_input == grad_output || grad_input == x) { throw std::invalid_argument("not __restrict__"); } cudaDeviceProp const& deviceProp = get_device_prop(device_id); int runtime_cuda_arch = deviceProp.major * 100 + deviceProp.minor * 10; int sm_count = deviceProp.multiProcessorCount; DISPATCH_LOWER_BOUND_N(n, LB_N, [&] { DISPATCH_CUDA_ARCH_AND_LOWER_BOUND_SM_COUNT(runtime_cuda_arch, sm_count, RUNTIME_CUDA_ARCH, LB_SM_COUNT, [&] { DISPATCH_SM_MARGIN(sm_margin, SM_MARGIN, [&] { if (hw != HW) { throw std::invalid_argument("wrong HW"); } if (num_groups * channels_per_group != C) { throw std::invalid_argument("wrong C"); } if (num_groups != G) { throw std::invalid_argument("wrong G"); } if (silu != SILU) { throw std::invalid_argument("wrong SILU"); } if (n < LB_N) { throw std::invalid_argument("wrong LB_N"); } if (runtime_cuda_arch != RUNTIME_CUDA_ARCH) { throw std::invalid_argument("wrong RUNTIME_CUDA_ARCH"); } if (sm_count < LB_SM_COUNT) { throw std::invalid_argument("wrong LB_SM_COUNT"); } if (sm_margin != SM_MARGIN) { throw std::invalid_argument("wrong SM_MARGIN"); } constexpr int EFFECTIVE_CUDA_ARCH = std::min(RUNTIME_CUDA_ARCH, get_max_cuda_arch()); // Assume the max CUDA_ARCH is used to generate PTX constexpr bool REQUIRES_WGRAD = true; constexpr int CPG = C / G; constexpr auto params = compute_gn_params(); constexpr int BLOCK_DIM_X = std::get<0>(params); constexpr int C_PER_BLOCK = std::get<1>(params); constexpr int 
ROWS_PER_BLOCK = std::get<2>(params); constexpr int VEC_ELEMS = std::get<3>(params); constexpr bool LOAD_TWICE = std::get<4>(params); constexpr int BLOCKS_PER_SM = std::get<5>(params); constexpr bool HARDWARE_CLUSTER = std::get<6>(params); constexpr WgradSyncMethod wgrad_sync_method_hint = std::get<7>(params); constexpr int C_PER_CLUSTER = lcm(CPG, C_PER_BLOCK); constexpr int VIRTUAL_CLUSTER_SIZE = (C_PER_CLUSTER / C_PER_BLOCK) * (HW / ROWS_PER_BLOCK); constexpr int NUM_VIRTUAL_CLUSTERS_NOT_ALIGNED = ((LB_SM_COUNT - SM_MARGIN) * BLOCKS_PER_SM) / VIRTUAL_CLUSTER_SIZE; // PERSISTENT is required because wgrad reduction requires synchronization. // TODO: specilize for the case that REQUIRES_WGRAD == false constexpr bool PERSISTENT = true; // Determine whether to align each virtual cluster to a fixed range of channels // If aligned, WGRAD_REUSE_SUM_SYNC_GROUP can be used, then less local wgrad memory is used (leave more room // for compiler // optimizations), and wgrad reduction is more efficient. // However, aligning can cause low occupancy. // There is a trade-off, and the condition to align is `NUM_VIRTUAL_CLUSTERS_NOT_ALIGNED > 2 * (C / // C_PER_CLUSTER)` constexpr WgradSyncMethod wgrad_sync_method = wgrad_sync_method_hint == WGRAD_SYNC_UNSPECIFIED ? NUM_VIRTUAL_CLUSTERS_NOT_ALIGNED > 2 * (C / C_PER_CLUSTER) || NUM_VIRTUAL_CLUSTERS_NOT_ALIGNED % (C / C_PER_CLUSTER) == 0 ? (HARDWARE_CLUSTER ? WGRAD_ARRIVE_AND_WAIT_GROUP : WGRAD_REUSE_SUM_SYNC_GROUP) : WGRAD_REUSE_SUM_SYNC_GRID : wgrad_sync_method_hint; /* The alignment test used above also accepts a cluster count that is already a multiple of C / C_PER_CLUSTER. For the group-based wgrad methods the count is then rounded down to a multiple of C / C_PER_CLUSTER. */ constexpr int NUM_VIRTUAL_CLUSTERS = wgrad_sync_method == WGRAD_ARRIVE_AND_WAIT_GROUP || wgrad_sync_method == WGRAD_REUSE_SUM_SYNC_GROUP ? NUM_VIRTUAL_CLUSTERS_NOT_ALIGNED / (C / C_PER_CLUSTER) * (C / C_PER_CLUSTER) : NUM_VIRTUAL_CLUSTERS_NOT_ALIGNED; if (meta_ptr) { constexpr int MAX_NUM_GROUPS_PER_BLOCK = C_PER_BLOCK % CPG == 0 ? 
C_PER_BLOCK / CPG : up_div(C_PER_BLOCK - gcd(C_PER_BLOCK, CPG), CPG) + 1; meta_ptr->red_buffer_size = 2 * NUM_VIRTUAL_CLUSTERS * VIRTUAL_CLUSTER_SIZE * MAX_NUM_GROUPS_PER_BLOCK * 2 + std::max(n, (int64_t)NUM_VIRTUAL_CLUSTERS / (C / C_PER_CLUSTER)) * (HW / ROWS_PER_BLOCK) * C * 2; meta_ptr->barrier_size = NUM_VIRTUAL_CLUSTERS + C / C_PER_CLUSTER; meta_ptr->BLOCK_DIM_X = BLOCK_DIM_X; meta_ptr->C_PER_BLOCK = C_PER_BLOCK; meta_ptr->ROWS_PER_BLOCK = ROWS_PER_BLOCK; meta_ptr->VEC_ELEMS = VEC_ELEMS; meta_ptr->LOAD_TWICE = LOAD_TWICE; meta_ptr->BLOCKS_PER_SM = BLOCKS_PER_SM; meta_ptr->HARDWARE_CLUSTER = HARDWARE_CLUSTER; meta_ptr->wgrad_sync_method = (int)wgrad_sync_method; } if (meta_only) { return; } cudaLaunchConfig_t config = {0}; config.gridDim = dim3(VIRTUAL_CLUSTER_SIZE, PERSISTENT ? NUM_VIRTUAL_CLUSTERS : n * (C / C_PER_CLUSTER), 1); config.blockDim = BLOCK_DIM_X; config.stream = stream; /* Unlike the forward launcher, the cluster dimension starts at 1 here. It is raised to VIRTUAL_CLUSTER_SIZE below only if the cluster-size and active-cluster checks pass; otherwise it is reset to 1 and a fallback kernel specialization is selected. */ cudaLaunchAttribute attribute[2]; if constexpr (HARDWARE_CLUSTER) { attribute[0].id = cudaLaunchAttributeClusterDimension; attribute[0].val.clusterDim.x = 1; // Cluster size in X-dimension attribute[0].val.clusterDim.y = 1; attribute[0].val.clusterDim.z = 1; config.attrs = attribute; config.numAttrs++; } if constexpr (PERSISTENT) { attribute[config.numAttrs].id = cudaLaunchAttributeCooperative; attribute[config.numAttrs].val.cooperative = 1; config.attrs = attribute; config.numAttrs++; } auto kernel = &gn_bwd_cuda_kernel >; if constexpr (HARDWARE_CLUSTER) { if constexpr (VIRTUAL_CLUSTER_SIZE > 8) { CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeNonPortableClusterSizeAllowed, 1)); } int max_cluster_size; int active_clusters; CUDA_CHECK(cudaOccupancyMaxPotentialClusterSize(&max_cluster_size, (void*)kernel, &config)); if (VIRTUAL_CLUSTER_SIZE <= max_cluster_size && PERSISTENT) { attribute[0].val.clusterDim.x = VIRTUAL_CLUSTER_SIZE; CUDA_CHECK(cudaOccupancyMaxActiveClusters(&active_clusters, (void*)kernel, &config)); } if (VIRTUAL_CLUSTER_SIZE <= max_cluster_size && 
(!PERSISTENT || PERSISTENT && NUM_VIRTUAL_CLUSTERS <= active_clusters)) { attribute[0].val.clusterDim.x = VIRTUAL_CLUSTER_SIZE; } else { // Fallback to cooperative groups for dgrad computation because hardware cluster cannot be active // simultaneously attribute[0].val.clusterDim.x = 1; kernel = &gn_bwd_cuda_kernel >; } } CUDA_CHECK(cudaLaunchKernelEx(&config, kernel, grad_input, grad_weight, grad_bias, grad_output, x, w, b, mean_var, eps, n, red_buffer, barrier)); }); }); }); } /* GN_CUDA_INST_DEFINE explicitly instantiates the forward and backward single-shape launchers for one (HW, C) pair, for half and __nv_bfloat16, each with the two supported (G, SILU) combinations (16, true) and (32, false). */ #define GN_CUDA_INST_DEFINE(HW, C) \ template void gn_cuda_single_shape(GN_CUDA_HOST_PARAMS(half)); \ template void gn_cuda_single_shape(GN_CUDA_HOST_PARAMS(half)); \ template void gn_bwd_cuda_single_shape(GN_BWD_CUDA_HOST_PARAMS(half)); \ template void gn_bwd_cuda_single_shape(GN_BWD_CUDA_HOST_PARAMS(half)); \ template void gn_cuda_single_shape<__nv_bfloat16, HW, C, 16, true>(GN_CUDA_HOST_PARAMS(__nv_bfloat16)); \ template void gn_cuda_single_shape<__nv_bfloat16, HW, C, 32, false>(GN_CUDA_HOST_PARAMS(__nv_bfloat16)); \ template void gn_bwd_cuda_single_shape<__nv_bfloat16, HW, C, 16, true>(GN_BWD_CUDA_HOST_PARAMS(__nv_bfloat16)); \ template void gn_bwd_cuda_single_shape<__nv_bfloat16, HW, C, 32, false>(GN_BWD_CUDA_HOST_PARAMS(__nv_bfloat16)); } // namespace group_norm_v2 ================================================ FILE: apex/contrib/csrc/group_norm_v2/gn_cuda_inst_1024_1280.cu ================================================ #include "gn_cuda_host_template.cuh" namespace group_norm_v2 { GN_CUDA_INST_DEFINE(1024, 1280) } // namespace group_norm_v2 ================================================ FILE: apex/contrib/csrc/group_norm_v2/gn_cuda_inst_1024_1920.cu ================================================ #include "gn_cuda_host_template.cuh" namespace group_norm_v2 { GN_CUDA_INST_DEFINE(1024, 1920) } // namespace group_norm_v2 ================================================ FILE: apex/contrib/csrc/group_norm_v2/gn_cuda_inst_1024_320.cu ================================================ 
#include "gn_cuda_host_template.cuh"

namespace group_norm_v2 {

GN_CUDA_INST_DEFINE(1024, 320)

} // namespace group_norm_v2

================================================
FILE: apex/contrib/csrc/group_norm_v2/gn_cuda_inst_1024_640.cu
================================================
#include "gn_cuda_host_template.cuh"

namespace group_norm_v2 {

GN_CUDA_INST_DEFINE(1024, 640)

} // namespace group_norm_v2

================================================
FILE: apex/contrib/csrc/group_norm_v2/gn_cuda_inst_1024_960.cu
================================================
#include "gn_cuda_host_template.cuh"

namespace group_norm_v2 {

GN_CUDA_INST_DEFINE(1024, 960)

} // namespace group_norm_v2

================================================
FILE: apex/contrib/csrc/group_norm_v2/gn_cuda_inst_256_1280.cu
================================================
#include "gn_cuda_host_template.cuh"

namespace group_norm_v2 {

GN_CUDA_INST_DEFINE(256, 1280)

} // namespace group_norm_v2

================================================
FILE: apex/contrib/csrc/group_norm_v2/gn_cuda_inst_256_1920.cu
================================================
#include "gn_cuda_host_template.cuh"

namespace group_norm_v2 {

GN_CUDA_INST_DEFINE(256, 1920)

} // namespace group_norm_v2

================================================
FILE: apex/contrib/csrc/group_norm_v2/gn_cuda_inst_256_2560.cu
================================================
#include "gn_cuda_host_template.cuh"

namespace group_norm_v2 {

GN_CUDA_INST_DEFINE(256, 2560)

} // namespace group_norm_v2

================================================
FILE: apex/contrib/csrc/group_norm_v2/gn_cuda_inst_256_640.cu
================================================
#include "gn_cuda_host_template.cuh"

namespace group_norm_v2 {

GN_CUDA_INST_DEFINE(256, 640)

} // namespace group_norm_v2

================================================
FILE: apex/contrib/csrc/group_norm_v2/gn_cuda_inst_4096_320.cu
================================================
#include "gn_cuda_host_template.cuh"

namespace group_norm_v2 {

GN_CUDA_INST_DEFINE(4096, 320)

} // namespace group_norm_v2

================================================
FILE: apex/contrib/csrc/group_norm_v2/gn_cuda_inst_4096_640.cu
================================================
#include "gn_cuda_host_template.cuh"

namespace group_norm_v2 {

GN_CUDA_INST_DEFINE(4096, 640)

} // namespace group_norm_v2

================================================
FILE: apex/contrib/csrc/group_norm_v2/gn_cuda_inst_4096_960.cu
================================================
#include "gn_cuda_host_template.cuh"

namespace group_norm_v2 {

GN_CUDA_INST_DEFINE(4096, 960)

} // namespace group_norm_v2

================================================
FILE: apex/contrib/csrc/group_norm_v2/gn_cuda_inst_64_1280.cu
================================================
#include "gn_cuda_host_template.cuh"

namespace group_norm_v2 {

GN_CUDA_INST_DEFINE(64, 1280)

} // namespace group_norm_v2

================================================
FILE: apex/contrib/csrc/group_norm_v2/gn_cuda_inst_64_2560.cu
================================================
#include "gn_cuda_host_template.cuh"

namespace group_norm_v2 {

GN_CUDA_INST_DEFINE(64, 2560)

} // namespace group_norm_v2

================================================
FILE: apex/contrib/csrc/group_norm_v2/gn_cuda_kernel.cuh
================================================
#pragma once

// NOTE(review): the header name of the next #include is empty here; the cg:: / cooperative_groups:: uses below
// suggest <cooperative_groups.h> -- confirm against upstream.
#include
#include "gn_utils.hpp"

namespace group_norm_v2 {

namespace cg = cooperative_groups;

// Ceiling division: the smallest q with q * b >= a (for a >= 0, b > 0).
template <typename T>
inline constexpr T up_div(T a, T b) {
  return (a + b - 1) / b;
}

// Rounds `a` up to the next multiple of `b`.
template <typename T>
inline constexpr T round_up(T a, T b) {
  return up_div(a, b) * b;
}

// Smallest power of two >= x (valid for x <= 2^31).
// x <= 1 returns 1 directly. Decrementing x == 0 would wrap to UINT_MAX and end in the undefined shift 1U << 32.
inline constexpr unsigned round_up_pow2(unsigned x) {
  if (x <= 1) {
    return 1U;
  }
  int log = 0;
  x--;
  while (x) {
    x /= 2;
    log++;
  }
  return 1U << log;
}

// Largest power of two <= x (returns 0 for x == 0).
inline constexpr unsigned round_down_pow2(unsigned x) {
  return round_up_pow2(x + 1) / 2;
}

// Greatest common divisor by Euclid's algorithm, for non-negative a and b.
// gcd(a, 0) == a and gcd(0, 0) == 0.
template <typename T>
inline constexpr T gcd(T a, T b) {
  while (b != 0) {
    T t = b;  // keep the temporary in T: an int would truncate wide or large unsigned operands
    b = a % b;
    a = t;
  }
  return a;
}

// Least common multiple, for non-negative a and b.
// Dividing before multiplying keeps the intermediate in range whenever the result is representable.
// Returns 0 if either argument is 0; this also avoids dividing by gcd(0, 0) == 0.
template <typename T>
inline constexpr T lcm(T a, T b) {
  if (a == 0 || b == 0) {
    return 0;
  }
  return a / gcd(a, b) * b;
}

// Smallest p >= min with gcd(p, x) == 1.
// The search variable is a T, so gcd(p, x) deduces a single T and values wider than int are not truncated.
template <typename T>
inline constexpr T relative_prime(T x, T min) {
  T p = min;
  while (gcd(p, x) != 1) {
    p++;
  }
  return p;
}

// Largest divisor of x that is <= max. Expects max >= 1; the search stops at 1 at the latest.
template <typename T>
inline constexpr T max_divisor(T x, T max) {
  T p = max;
  while (x % p != 0) {
    p--;
  }
  return p;
}

// Lane mask (all 32 lanes) passed to the __shfl_xor_sync calls in this file.
constexpr unsigned FINAL_MASK = 0xffffffff;

// Synchronizes the blocks of one virtual cluster.
// Compile-time parameters used: VIRTUAL_CLUSTER_SIZE, HARDWARE_CLUSTER, PERSISTENT.
//   * VIRTUAL_CLUSTER_SIZE == 1: the cluster is one block, so __syncthreads() is sufficient.
//   * HARDWARE_CLUSTER: uses the hardware cluster barrier, cg::this_cluster().sync().
//   * Otherwise: a software barrier on the global counter barrier[blockIdx.y].
//     - Thread 0 of the block with blockIdx.x == 0 adds 0x80000000 - (VIRTUAL_CLUSTER_SIZE - 1).
//     - Thread 0 of every other block adds 1.
//     - Once VIRTUAL_CLUSTER_SIZE blocks (one of them the master) have arrived, the counter has advanced by exactly
//       0x80000000. Thread 0 spins until cooperative_groups::details::bar_has_flipped reports the change.
//     - The add is atom.add.release.gpu and the poll is ld.acquire.gpu, bracketed by __syncthreads().
//     - PERSISTENT is required (static_assert "potential deadlock"), presumably because spinning is only safe when
//       all blocks of the cluster are resident at the same time.
template __device__ void virtual_cluster_sync(unsigned int* barrier) {
  if constexpr (VIRTUAL_CLUSTER_SIZE == 1) {
    __syncthreads();
  } else if constexpr (HARDWARE_CLUSTER) {
    cg::this_cluster().sync();
  } else {
    static_assert(PERSISTENT, "potential deadlock");
    volatile unsigned int* arrived = &barrier[blockIdx.y];
    __syncthreads();
    if (threadIdx.x == 0) {
      unsigned int expected = VIRTUAL_CLUSTER_SIZE;
      bool gpu_master = blockIdx.x == 0;
      unsigned int nb = 1;
      if (gpu_master) {
        nb = 0x80000000 - (expected - 1);
      }
      unsigned int oldArrive;
      asm volatile("atom.add.release.gpu.u32 %0,[%1],%2;"
                   : "=r"(oldArrive)
                   : _CG_ASM_PTR_CONSTRAINT((unsigned int*)arrived), "r"(nb)
                   : "memory");
      unsigned int current_arrive;
      do {
        asm volatile("ld.acquire.gpu.u32 %0,[%1];"
                     : "=r"(current_arrive)
                     : _CG_ASM_PTR_CONSTRAINT((unsigned int*)arrived)
                     : "memory");
      } while (!cooperative_groups::details::bar_has_flipped(oldArrive, current_arrive));
    }
    __syncthreads();
  }
}

// Arrive half of a split-phase software barrier over NUM_BLOCKS blocks on the global counter barrier[0].
// Compile-time parameters used: NUM_BLOCKS, PERSISTENT.
// It uses the same counting scheme as virtual_cluster_sync:
//   * the caller passing gpu_master == true adds 0x80000000 - (NUM_BLOCKS - 1);
//   * every other caller adds 1;
//   * so exactly one of the NUM_BLOCKS participants must be the master.
// Only thread 0 performs the atomic. It returns the counter value seen before its add, which is the token to pass
// to group_barrier_wait. All other threads return 0.
template __device__ unsigned int group_barrier_arrive(unsigned int* barrier, bool gpu_master) {
  static_assert(PERSISTENT, "potential deadlock");
  volatile unsigned int* arrived = &barrier[0];
  __syncthreads();
  if (threadIdx.x == 0) {
    unsigned int expected = NUM_BLOCKS;
    unsigned int nb = 1;
    if (gpu_master) {
      nb = 0x80000000 - (expected - 1);
    }
    unsigned int oldArrive;
    asm volatile("atom.add.release.gpu.u32 %0,[%1],%2;"
                 : "=r"(oldArrive)
                 : _CG_ASM_PTR_CONSTRAINT((unsigned int*)arrived), "r"(nb)
                 : "memory");
    return oldArrive;
  } else {
    return 0;
  }
}

// Wait half of the split-phase barrier started by group_barrier_arrive.
// Thread 0 polls barrier[0] until bar_has_flipped(oldArrive, current) holds; the whole block then meets at
// __syncthreads(). `oldArrive` is only meaningful on thread 0 (the value group_barrier_arrive returned there).
__device__ inline void group_barrier_wait(unsigned int* barrier, unsigned int oldArrive) {
  volatile unsigned int* arrived = &barrier[0];
  if (threadIdx.x == 0) {
    unsigned int current_arrive;
    do {
// group_barrier_wait (continued): thread 0 re-reads the barrier word with ld.acquire.gpu until it has flipped.
      asm volatile("ld.acquire.gpu.u32 %0,[%1];"
                   : "=r"(current_arrive)
                   : _CG_ASM_PTR_CONSTRAINT((unsigned int*)arrived)
                   : "memory");
    } while (!cooperative_groups::details::bar_has_flipped(oldArrive, current_arrive));
  }
  __syncthreads();
}

// Calculate `n` (batch id) and `c` (channel range id) for each loop
//
// NCScheduler maps a virtual cluster (starting from blockIdx.y) to the (n_loop, c_loop) work items it processes.
//   * n_loop is the sample index.
//   * c_loop selects a C_PER_CLUSTER-wide channel range out of C / C_PER_CLUSTER ranges.
//   * next() is a no-op unless PERSISTENT. When PERSISTENT, it advances by NUM_VIRTUAL_CLUSTERS linear work items.
// NOTE(review): the two definitions below appear to be the specializations selected by the kernels'
// CONSTANT_C_LOOP flag (generic form first, fixed-channel form second) -- confirm against the template parameter list.
template class NCScheduler;

// Generic form: keeps one linear index nc_loop_ = n_loop * (C / C_PER_CLUSTER) + c_loop, so the channel range may
// change between iterations.
template class NCScheduler {
 public:
  __device__ NCScheduler(int64_t n) {
    nc_loop_ = blockIdx.y;
    at_end_ = nc_loop_ >= n * (C / C_PER_CLUSTER);
  }
  __device__ auto get_nc() {
    int64_t n_loop = nc_loop_ / (C / C_PER_CLUSTER);
    int c_loop = nc_loop_ % (C / C_PER_CLUSTER);
    return std::make_tuple(n_loop, c_loop);
  }
  __device__ void next(int64_t n) {
    if constexpr (PERSISTENT) {
      nc_loop_ += NUM_VIRTUAL_CLUSTERS;
      at_end_ = nc_loop_ >= n * (C / C_PER_CLUSTER);
    }
  }
  // Always true when !PERSISTENT.
  __device__ bool at_end(int64_t n) { return !PERSISTENT || at_end_; }

 private:
  int64_t nc_loop_;
  bool at_end_;
};

// Fixed-channel form: c_loop_ is fixed at construction and only n_loop_ advances, by
// NUM_VIRTUAL_CLUSTERS / (C / C_PER_CLUSTER) per step. It visits the same work items as the generic form only when
// NUM_VIRTUAL_CLUSTERS is a multiple of C / C_PER_CLUSTER (the kernels' CONSTANT_C_LOOP condition).
template class NCScheduler {
 public:
  __device__ NCScheduler(int64_t n) {
    n_loop_ = blockIdx.y / (C / C_PER_CLUSTER);
    c_loop_ = blockIdx.y % (C / C_PER_CLUSTER);
  }
  __device__ auto get_nc() { return std::make_tuple(n_loop_, c_loop_); }
  __device__ void next(int64_t n) {
    if constexpr (PERSISTENT) {
      n_loop_ += NUM_VIRTUAL_CLUSTERS / (C / C_PER_CLUSTER);
    }
  }
  // Always true when !PERSISTENT.
  __device__ bool at_end(int64_t n) { return !PERSISTENT || n_loop_ >= n; }

 private:
  int64_t n_loop_;
  int c_loop_;
};

// A CompileCondition whose matches() is always true, so the kernels' `if constexpr (CompileCondition::matches())`
// body is always compiled.
class CompileConditionAlwaysTrue {
 public:
  __device__ static constexpr bool matches() { return true; }
};

// Forward GroupNorm, optionally fused with SiLU (when SILU is set).
//
// Data layout:
//   * x and out are indexed as [(n * HW + hw) * C + c], i.e. channels-last, with C = G * CPG.
//   * w and b are per-channel scale and shift.
//   * Statistics are taken per (sample, group) over HW * CPG elements.
//
// Work decomposition (compile-time):
//   * A virtual cluster (blockIdx.y, scheduled by NCScheduler) covers one sample and C_PER_CLUSTER channels.
//   * It is split into virtual_cluster_dim_x = C_PER_CLUSTER / C_PER_BLOCK blocks along C and
//     virtual_cluster_dim_y = HW / ROWS_PER_BLOCK blocks along HW.
//   * blockIdx.x % VIRTUAL_CLUSTER_SIZE gives a block's position inside its cluster.
//   * Each thread moves VEC_ELEMS consecutive channels per access (type U).
//
// Parameters:
//   out, x        output / input activations (layout above)
//   w, b          per-channel weight and bias
//   eps           added to the variance before rsqrtf
//   n             number of samples along the leading axis
//   mean_var_out  optional (checked for null); receives float2{mean, var} per (sample, group) at [(n * G + g) * 2]
//   red_buffer    global scratch for partial sums, used only when neither HARDWARE_CLUSTER nor
//                 VIRTUAL_CLUSTER_SIZE == 1 applies
//   barrier       counters for the software virtual-cluster barrier (virtual_cluster_sync uses barrier[blockIdx.y])
template __global__ __launch_bounds__(BLOCK_DIM_X, BLOCKS_PER_SM) void gn_cuda_kernel(
    T* __restrict__ out, T const* __restrict__ x, T const* __restrict__ w, T const* __restrict__ b, float eps,
    int64_t n, float* __restrict__ mean_var_out, float* __restrict__ red_buffer, unsigned* __restrict__ barrier) {
  // Procedure Overview
  // 1. Thread sum: read from gmem, write partial sum to smem, store input in registers (if no LOAD_TWICE)
  // 2. Block sum: read from smem, write partial sum to gmem (or distributed shared memory if HARDWARE_CLUSTER is
  //    used)
  // 3. Group sum: read from gmem, write mean&var to smem
  // 4. Scale: read mean&var from smem, read input from gmem (if LOAD_TWICE), write output to gmem
  static_assert(BLOCK_DIM_X % 32 == 0, "warp shuffle error");
  constexpr int C = G * CPG;
  static_assert(C % C_PER_CLUSTER == 0, "cannot divide channels into clusters");
  static_assert(C_PER_CLUSTER % C_PER_BLOCK == 0, "cannot divide a cluster into blocks");
  static_assert(C_PER_CLUSTER % CPG == 0, "no reduce between clusters, would produce incorrect results");
  static_assert(!(C_PER_BLOCK % CPG == 0 && C_PER_CLUSTER != C_PER_BLOCK),
                "inefficient configuration, please reduce C_PER_CLUSTER");
  static_assert(ROWS_PER_BLOCK * C_PER_BLOCK % BLOCK_DIM_X == 0, "cannot divide tile into threads");

  // VEC_ELEMS elements aligned to the vector's full size, so one load or store moves the whole vector.
  struct alignas(VEC_ELEMS * sizeof(T)) U {
    T data[VEC_ELEMS];
  };

  // Converts sum = {sum of x, sum of x^2} over one group into {mean, var}.
  // var = E[x^2] - mean^2, clamped at 0 to absorb rounding error.
  auto compute_mean_var = [&](float2 sum) {
    float mean = sum.x / (HW * CPG);
    float var = std::max(0.f, sum.y / (HW * CPG) - mean * mean);
    return float2{mean, var};
  };

  static_assert(HW % ROWS_PER_BLOCK == 0,
                "HW must be divisible by ROWS_PER_BLOCK to determine the number of blocks on the HW axis");
  // Maximum number of groups that one block's C_PER_BLOCK channels can overlap.
  // When C_PER_BLOCK is not a multiple of CPG, a group may straddle neighbouring blocks.
  constexpr int MAX_NUM_GROUPS_PER_BLOCK = C_PER_BLOCK % CPG == 0
                                               ? C_PER_BLOCK / CPG
                                               : up_div(C_PER_BLOCK - gcd(C_PER_BLOCK, CPG), CPG) + 1;
  constexpr int VIRTUAL_CLUSTER_SIZE = (C_PER_CLUSTER / C_PER_BLOCK) * (HW / ROWS_PER_BLOCK);
  constexpr int virtual_cluster_dim_x = C_PER_CLUSTER / C_PER_BLOCK;
  constexpr int virtual_cluster_dim_y = HW / ROWS_PER_BLOCK;
  // Position of this block inside its virtual cluster: x runs along channels, y runs along HW rows.
  int virtual_block_idx_x = (blockIdx.x % VIRTUAL_CLUSTER_SIZE) % virtual_cluster_dim_x;
  int virtual_block_idx_y = (blockIdx.x % VIRTUAL_CLUSTER_SIZE) / virtual_cluster_dim_x;

  // Presumably lets callers compile the body out for configurations they never launch.
  if constexpr (CompileCondition::matches()) {
    // Double-buffer index (0/1) for the partial-sum slots; flipped at the end of every persistent iteration.
    int step = 0;
    constexpr bool CONSTANT_C_LOOP = PERSISTENT && NUM_VIRTUAL_CLUSTERS % (C / C_PER_CLUSTER) == 0;
    NCScheduler nc_scheduler(n);
    while (true) {  // TODO: unroll the loop
      if constexpr (PERSISTENT) {
        if (nc_scheduler.at_end(n)) {
          break;
        }
      }
      auto [n_loop, c_loop] = nc_scheduler.get_nc();
      // Advancing here means that, for the rest of this iteration, nc_scheduler.at_end(n) is true exactly when the
      // current item is this cluster's last one.
      if constexpr (PERSISTENT) {
        nc_scheduler.next(n);
      }
      static_assert(C_PER_BLOCK % VEC_ELEMS == 0, "cannot vectorize");
      static_assert((BLOCK_DIM_X * VEC_ELEMS) % C_PER_BLOCK == 0,
                    "each block should load one or more C_PER_BLOCK at once");
      // Number of tile rows covered by one block-wide vector access.
      constexpr int ROWS_PER_IO = BLOCK_DIM_X * VEC_ELEMS / C_PER_BLOCK;
      static_assert(ROWS_PER_BLOCK % ROWS_PER_IO == 0, "cannot determine the IO times per batch");
      int block_channel_start = virtual_block_idx_x * C_PER_BLOCK + c_loop * C_PER_CLUSTER;
      int block_group_start = block_channel_start / CPG;
      int thread_channel_start = block_channel_start + threadIdx.x % (C_PER_BLOCK / VEC_ELEMS) * VEC_ELEMS;
      // Register copy of this thread's inputs; the scale phase reuses it unless LOAD_TWICE.
      U frag[ROWS_PER_BLOCK / ROWS_PER_IO];
      // GCD_VEC_CPG is an important constant that determines how many channels can be merged in reduction computation
      // For example, VEC_ELEMS=4 and CPG=10, then GCD_VEC_CPG=2,
      // so we need to store only 2 sums on each thread, and compute only 2 mean&var for each thread.
      constexpr int GCD_VEC_CPG = gcd(VEC_ELEMS, CPG);
      // If each block handles only one group, run warpReduce and store the sum to `sum_per_channel_single_group`;
      // otherwise store (VEC_ELEMS / GCD_VEC_CPG) sums to `sum_per_channel_multi_group`, where `relative_prime` is used
      // for swizzle.
      constexpr bool SINGLE_GROUP_PER_BLOCK = CPG % C_PER_BLOCK == 0;
      [[maybe_unused]] __shared__ float2 sum_per_channel_single_group[BLOCK_DIM_X / 32];
      // The inner extent is the smallest value >= ROWS_PER_IO that is coprime with 16 (= 128 / sizeof(float2)).
      [[maybe_unused]] __shared__ float2 sum_per_channel_multi_group[C_PER_BLOCK / GCD_VEC_CPG][relative_prime(
          128 / (int)sizeof(float2), ROWS_PER_IO)];

      // ---- 1. Thread sum ----
      if constexpr (LOAD_TWICE) {
        // Accumulate {sum, sum of squares} per merged channel slot while streaming the input.
        // The input is read again in the scale phase.
        float2 frag_sum_per_channel[VEC_ELEMS / GCD_VEC_CPG]{};
        for (int j = 0; j < ROWS_PER_BLOCK / ROWS_PER_IO; j++) {
          int64_t input_idx = n_loop * HW * C +
                              (virtual_block_idx_y * ROWS_PER_BLOCK + j * ROWS_PER_IO +
                               threadIdx.x / (C_PER_BLOCK / VEC_ELEMS)) *
                                  C +
                              thread_channel_start;
          U val = *reinterpret_cast(&x[input_idx]);
          for (int i = 0; i < VEC_ELEMS / GCD_VEC_CPG; i++) {
            float2 sum = frag_sum_per_channel[i];
            for (int k = 0; k < GCD_VEC_CPG; k++) {
              sum.x += (float)val.data[i * GCD_VEC_CPG + k];
              sum.y += (float)val.data[i * GCD_VEC_CPG + k] * (float)val.data[i * GCD_VEC_CPG + k];
            }
            frag_sum_per_channel[i] = sum;
          }
        }
        for (int i = 0; i < VEC_ELEMS / GCD_VEC_CPG; i++) {
          if constexpr (SINGLE_GROUP_PER_BLOCK) {
            // The whole block belongs to one group: reduce across the warp and keep one value per warp.
            for (int mask = 16; mask > 0; mask >>= 1) {
              frag_sum_per_channel[i].x += __shfl_xor_sync(FINAL_MASK, frag_sum_per_channel[i].x, mask, 32);
              frag_sum_per_channel[i].y += __shfl_xor_sync(FINAL_MASK, frag_sum_per_channel[i].y, mask, 32);
            }
            static_assert(VEC_ELEMS / GCD_VEC_CPG == 1, "process only one element for each warp");
            if (threadIdx.x % 32 == 0) {
              sum_per_channel_single_group[threadIdx.x / 32] = frag_sum_per_channel[i];
            }
          } else {
            // Row index = slot i * (threads per tile row) + this thread's column; column index = tile row.
            sum_per_channel_multi_group[i * (C_PER_BLOCK / VEC_ELEMS) + threadIdx.x % (C_PER_BLOCK / VEC_ELEMS)]
                                       [threadIdx.x / (C_PER_BLOCK / VEC_ELEMS)] = frag_sum_per_channel[i];
          }
        }
        __syncthreads();
      } else {
        // Load the whole tile into registers once; frag[] is reused by the scale phase.
        for (int j = 0; j < ROWS_PER_BLOCK / ROWS_PER_IO; j++) {
          int64_t input_idx = n_loop * HW * C +
                              (virtual_block_idx_y * ROWS_PER_BLOCK + j * ROWS_PER_IO +
                               threadIdx.x / (C_PER_BLOCK / VEC_ELEMS)) *
                                  C +
                              thread_channel_start;
          frag[j] = *reinterpret_cast(&x[input_idx]);
        }
        for (int i = 0; i < VEC_ELEMS / GCD_VEC_CPG; i++) {
          float2 sum = {0.f, 0.f};
          for (int j = 0; j < ROWS_PER_BLOCK / ROWS_PER_IO; j++) {
            for (int k = 0; k < GCD_VEC_CPG; k++) {
              sum.x += (float)frag[j].data[i * GCD_VEC_CPG + k];
              sum.y += (float)frag[j].data[i * GCD_VEC_CPG + k] * (float)frag[j].data[i * GCD_VEC_CPG + k];
            }
          }
          if constexpr (SINGLE_GROUP_PER_BLOCK) {
            for (int mask = 16; mask > 0; mask >>= 1) {
              sum.x += __shfl_xor_sync(FINAL_MASK, sum.x, mask, 32);
              sum.y += __shfl_xor_sync(FINAL_MASK, sum.y, mask, 32);
            }
            static_assert(VEC_ELEMS / GCD_VEC_CPG == 1, "process only one element for each warp");
            if (threadIdx.x % 32 == 0) {
              sum_per_channel_single_group[threadIdx.x / 32] = sum;
            }
          } else {
            sum_per_channel_multi_group[i * (C_PER_BLOCK / VEC_ELEMS) + threadIdx.x % (C_PER_BLOCK / VEC_ELEMS)]
                                       [threadIdx.x / (C_PER_BLOCK / VEC_ELEMS)] = sum;
          }
        }
        __syncthreads();
      }
      U uw = *reinterpret_cast(&w[thread_channel_start]);
      U ub = *reinterpret_cast(&b[thread_channel_start]);
      // Three cases for the red_buffer:
      // - Block sync (VIRTUAL_CLUSTER_SIZE=1): use shared memory
      // - Virtual cluster sync with HARDWARE_CLUSTER: use distributed shared memory
      // - Virtual cluster sync without HARDWARE_CLUSTER: use global memory, i.e., `red_buffer`
      // In `red_buffer`, float2 slots are laid out as
      // [step][blockIdx.y][virtual_block_idx_x][group within block][virtual_block_idx_y].
      constexpr bool USE_SHARED_RED_BUFFER = HARDWARE_CLUSTER || VIRTUAL_CLUSTER_SIZE == 1;
      // Specialize for the case that each group is handled by only one block
      // For common cases, blockSum produces partial sum and stores it to the red_buffer, and groupSum produces
      // mean&var. For the special case, blockSum produces mean&var directly
      // (MAX_NUM_GROUPS_PER_BLOCK > 1 is possible but not implemented)
      constexpr bool STORE_MEAN_VAR_IN_SHARED_RED_BUFFER = VIRTUAL_CLUSTER_SIZE == 1 && MAX_NUM_GROUPS_PER_BLOCK == 1;
      // Two slots per group, selected by `step`, unless the buffer holds the final mean&var directly.
      [[maybe_unused]] __align__(16) __shared__ float2
          shared_red_buffer[MAX_NUM_GROUPS_PER_BLOCK * (STORE_MEAN_VAR_IN_SHARED_RED_BUFFER ? 1 : 2)];

      // ---- 2. Block sum ----
      // Block sum
      if constexpr (SINGLE_GROUP_PER_BLOCK) {
        // block reduce
        if (threadIdx.x < 32) {
          float2 sum_local_group =
              threadIdx.x < BLOCK_DIM_X / 32 ? sum_per_channel_single_group[threadIdx.x] : float2{0.f, 0.f};
          constexpr int warp_num_pow2 = round_up_pow2(BLOCK_DIM_X / 32);
          for (int mask = warp_num_pow2 / 2; mask > 0; mask >>= 1) {
            sum_local_group.x += __shfl_xor_sync(FINAL_MASK, sum_local_group.x, mask, 32);
            sum_local_group.y += __shfl_xor_sync(FINAL_MASK, sum_local_group.y, mask, 32);
          }
          if (threadIdx.x == 0) {
            if constexpr (USE_SHARED_RED_BUFFER) {
              if constexpr (STORE_MEAN_VAR_IN_SHARED_RED_BUFFER) {
                shared_red_buffer[0] = compute_mean_var(sum_local_group);
              } else {
                shared_red_buffer[step * MAX_NUM_GROUPS_PER_BLOCK + 0] = sum_local_group;
              }
            } else {
              *reinterpret_cast(
                  &red_buffer[((step * gridDim.y + blockIdx.y) * VIRTUAL_CLUSTER_SIZE * MAX_NUM_GROUPS_PER_BLOCK +
                               virtual_block_idx_x * virtual_cluster_dim_y * MAX_NUM_GROUPS_PER_BLOCK +
                               // (threadIdx.x / THREADS_PER_GROUP) * virtual_cluster_dim_y +
                               virtual_block_idx_y) *
                              2]) = sum_local_group;
            }
          }
        }
      } else {
        // The number of threads to calculate the sum of each group (should be a power of 2 for warp reduce)
        constexpr int THREADS_PER_GROUP = std::min(std::min(32U, round_up_pow2(ROWS_PER_IO)),
                                                   round_up_pow2(BLOCK_DIM_X / MAX_NUM_GROUPS_PER_BLOCK / 2 + 1));
        static_assert(BLOCK_DIM_X >= MAX_NUM_GROUPS_PER_BLOCK * THREADS_PER_GROUP, "not enough threads");
        float2 sum_local_group = {0.f, 0.f};
        if (threadIdx.x / THREADS_PER_GROUP < MAX_NUM_GROUPS_PER_BLOCK) {
          int local_group_idx = block_group_start + threadIdx.x / THREADS_PER_GROUP;
          // TODO: map threads to both the CPG loop and the ROWS loop
          for (int local_c_loop = 0; local_c_loop < CPG; local_c_loop += GCD_VEC_CPG) {
            int c = local_group_idx * CPG + local_c_loop;
            // Skip channels of this group that fall outside the block (possible only when groups straddle blocks).
            if (C_PER_BLOCK % CPG == 0 || (c >= block_channel_start && c < block_channel_start + C_PER_BLOCK)) {
              for (int src_thread_tile_y = threadIdx.x % THREADS_PER_GROUP; src_thread_tile_y < ROWS_PER_IO;
                   src_thread_tile_y += THREADS_PER_GROUP) {
                int channel_idx = (c - block_channel_start) / GCD_VEC_CPG;
                // Map channel slot (thread column t, slot i) back to its storage row i * (C_PER_BLOCK / VEC_ELEMS) + t.
                channel_idx = channel_idx % (VEC_ELEMS / GCD_VEC_CPG) * (C_PER_BLOCK / VEC_ELEMS) +
                              channel_idx / (VEC_ELEMS / GCD_VEC_CPG);
                sum_local_group.x += sum_per_channel_multi_group[channel_idx][src_thread_tile_y].x;
                sum_local_group.y += sum_per_channel_multi_group[channel_idx][src_thread_tile_y].y;
              }
            }
          }
        }
        static_assert(32 % THREADS_PER_GROUP == 0, "cannot shuffle");
        for (int mask = THREADS_PER_GROUP / 2; mask > 0; mask >>= 1) {
          sum_local_group.x += __shfl_xor_sync(FINAL_MASK, sum_local_group.x, mask, 32);
          sum_local_group.y += __shfl_xor_sync(FINAL_MASK, sum_local_group.y, mask, 32);
        }
        if (threadIdx.x % THREADS_PER_GROUP == 0 && threadIdx.x / THREADS_PER_GROUP < MAX_NUM_GROUPS_PER_BLOCK) {
          if constexpr (USE_SHARED_RED_BUFFER) {
            static_assert(HARDWARE_CLUSTER || VIRTUAL_CLUSTER_SIZE == 1, "no distributed shared memory");
            if constexpr (STORE_MEAN_VAR_IN_SHARED_RED_BUFFER) {
              shared_red_buffer[threadIdx.x / THREADS_PER_GROUP] = compute_mean_var(sum_local_group);
            } else {
              shared_red_buffer[step * MAX_NUM_GROUPS_PER_BLOCK + threadIdx.x / THREADS_PER_GROUP] = sum_local_group;
            }
          } else {
            *reinterpret_cast(
                &red_buffer[((step * gridDim.y + blockIdx.y) * VIRTUAL_CLUSTER_SIZE * MAX_NUM_GROUPS_PER_BLOCK +
                             virtual_block_idx_x * virtual_cluster_dim_y * MAX_NUM_GROUPS_PER_BLOCK +
                             (threadIdx.x / THREADS_PER_GROUP) * virtual_cluster_dim_y + virtual_block_idx_y) *
                            2]) = sum_local_group;
          }
        }
      }
      // Every block of the virtual cluster has published its partial sums past this point.
      virtual_cluster_sync(barrier);

      // ---- 3. Group sum ----
      // Group sum
      __shared__ float2 mean_var[MAX_NUM_GROUPS_PER_BLOCK];
      if constexpr (!STORE_MEAN_VAR_IN_SHARED_RED_BUFFER) {
        // The number of threads to calculate the sum of each group (should be a power of 2 for warp reduce)
        constexpr int THREADS_PER_GROUP = std::min(std::min(32U, round_up_pow2(virtual_cluster_dim_y)),
                                                   round_up_pow2(BLOCK_DIM_X / MAX_NUM_GROUPS_PER_BLOCK / 2 + 1));
        static_assert(BLOCK_DIM_X >= MAX_NUM_GROUPS_PER_BLOCK * THREADS_PER_GROUP, "not enough threads");
        float2 sum_global_group = {0.f, 0.f};
        if (threadIdx.x / THREADS_PER_GROUP < MAX_NUM_GROUPS_PER_BLOCK) {
          if constexpr (C_PER_BLOCK % CPG == 0) {
            // Special case: no cross-virtual_cluster_dim_x reduction
            // The static_asserts above force C_PER_CLUSTER == C_PER_BLOCK here, so virtual_cluster_dim_x == 1.
            // All partial sums are loaded into `buffer` before any of them is accumulated.
            float2 buffer[up_div(virtual_cluster_dim_y, THREADS_PER_GROUP)];
            for (int i = threadIdx.x % THREADS_PER_GROUP; i < virtual_cluster_dim_y; i += THREADS_PER_GROUP) {
              float2 val;
              if constexpr (USE_SHARED_RED_BUFFER) {
                if constexpr (VIRTUAL_CLUSTER_SIZE == 1) {
                  val = shared_red_buffer[step * MAX_NUM_GROUPS_PER_BLOCK + threadIdx.x / THREADS_PER_GROUP];
                } else {
                  static_assert(HARDWARE_CLUSTER, "no distributed shared memory");
                  // Read the partial sum held in the shared memory of cluster rank
                  // i * virtual_cluster_dim_x + virtual_block_idx_x.
                  float2 const* src_shared_red_buffer = cg::this_cluster().map_shared_rank(
                      shared_red_buffer, i * virtual_cluster_dim_x + virtual_block_idx_x);
                  val = src_shared_red_buffer[step * MAX_NUM_GROUPS_PER_BLOCK + threadIdx.x / THREADS_PER_GROUP];
                }
              } else {
                val = *reinterpret_cast(
                    &red_buffer[((step * gridDim.y + blockIdx.y) * VIRTUAL_CLUSTER_SIZE * MAX_NUM_GROUPS_PER_BLOCK +
                                 virtual_block_idx_x * virtual_cluster_dim_y * MAX_NUM_GROUPS_PER_BLOCK +
                                 (threadIdx.x / THREADS_PER_GROUP) * virtual_cluster_dim_y + i) *
                                2]);
              }
              buffer[i / THREADS_PER_GROUP] = val;
            }
            for (int i = threadIdx.x % THREADS_PER_GROUP; i < virtual_cluster_dim_y; i += THREADS_PER_GROUP) {
              float2 val = buffer[i / THREADS_PER_GROUP];
              sum_global_group.x += val.x;
              sum_global_group.y += val.y;
            }
          } else {
            // Common case: cross-virtual_cluster_dim_x reduction
            // Visit every block i of the cluster. A block contributes its stored slot for this group when the group's
            // index relative to that block's first group lies in [0, MAX_NUM_GROUPS_PER_BLOCK).
            int local_group_idx = block_group_start + threadIdx.x / THREADS_PER_GROUP;
            for (int i = threadIdx.x % THREADS_PER_GROUP; i < VIRTUAL_CLUSTER_SIZE; i += THREADS_PER_GROUP) {
              int src_virtual_block_idx_x = i % virtual_cluster_dim_x;
              int src_block_channel_start = src_virtual_block_idx_x * C_PER_BLOCK + c_loop * C_PER_CLUSTER;
              int src_block_group_start = src_block_channel_start / CPG;
              int relative_group_idx = local_group_idx - src_block_group_start;
              if (0 <= relative_group_idx && relative_group_idx < MAX_NUM_GROUPS_PER_BLOCK) {
                float2 val;
                if constexpr (USE_SHARED_RED_BUFFER) {
                  static_assert(HARDWARE_CLUSTER, "no distributed shared memory");
                  static_assert(VIRTUAL_CLUSTER_SIZE != 1,
                                "layout error: should not add (step * MAX_NUM_GROUPS_PER_BLOCK)");
                  float2 const* src_shared_red_buffer = cg::this_cluster().map_shared_rank(shared_red_buffer, i);
                  val = src_shared_red_buffer[step * MAX_NUM_GROUPS_PER_BLOCK + relative_group_idx];
                } else {
                  val = *reinterpret_cast(
                      &red_buffer[((step * gridDim.y + blockIdx.y) * VIRTUAL_CLUSTER_SIZE * MAX_NUM_GROUPS_PER_BLOCK +
                                   src_virtual_block_idx_x * virtual_cluster_dim_y * MAX_NUM_GROUPS_PER_BLOCK +
                                   relative_group_idx * virtual_cluster_dim_y + i / virtual_cluster_dim_x) *
                                  2]);
                }
                sum_global_group.x += val.x;
                sum_global_group.y += val.y;
              }
            }
          }
        }
        if constexpr (USE_SHARED_RED_BUFFER && VIRTUAL_CLUSTER_SIZE > 1) {
          // Need cluster sync after distributed shared memory access, otherwise behavior is undefined
          // The matching barrier_wait() comes after the scale phase. In persistent mode both happen only on the
          // cluster's last item; earlier items presumably rely on the `step` double buffer plus the next iteration's
          // virtual_cluster_sync -- confirm.
          if constexpr (PERSISTENT) {
            if (nc_scheduler.at_end(n)) {
              cg::this_cluster().barrier_arrive();
            }
          } else {
            cg::this_cluster().barrier_arrive();
          }
        }
        static_assert(32 % THREADS_PER_GROUP == 0, "cannot shuffle");
        for (int mask = THREADS_PER_GROUP / 2; mask > 0; mask >>= 1) {
          sum_global_group.x += __shfl_xor_sync(FINAL_MASK, sum_global_group.x, mask, 32);
          sum_global_group.y += __shfl_xor_sync(FINAL_MASK, sum_global_group.y, mask, 32);
        }
        if (threadIdx.x % THREADS_PER_GROUP == 0 && threadIdx.x / THREADS_PER_GROUP < MAX_NUM_GROUPS_PER_BLOCK) {
          mean_var[threadIdx.x / THREADS_PER_GROUP] = compute_mean_var(sum_global_group);
        }
        __syncthreads();
      }
      // {mean, var} of a group, indexed relative to block_group_start.
      auto get_mean_var = [&](int relative_group_idx) {
        return STORE_MEAN_VAR_IN_SHARED_RED_BUFFER ? shared_red_buffer[relative_group_idx]
                                                   : mean_var[relative_group_idx];
      };
      // Optionally export per-(sample, group) statistics. Only blocks with virtual_block_idx_y == 0 write them.
      if (mean_var_out) {
        static_assert(MAX_NUM_GROUPS_PER_BLOCK <= BLOCK_DIM_X, "need loop");
        if (virtual_block_idx_y == 0 && threadIdx.x < MAX_NUM_GROUPS_PER_BLOCK) {
          int g = block_group_start + threadIdx.x;
          if (C_PER_BLOCK % CPG == 0 || g < G) {
            *reinterpret_cast(&mean_var_out[(n_loop * G + g) * 2]) = get_mean_var(threadIdx.x);
          }
        }
      }

      // ---- 4. Scale ----
      float frag_mean[VEC_ELEMS / GCD_VEC_CPG];
      float frag_var[VEC_ELEMS / GCD_VEC_CPG];
      for (int k = 0; k < VEC_ELEMS; k += GCD_VEC_CPG) {
        frag_mean[k / GCD_VEC_CPG] = get_mean_var((thread_channel_start + k) / CPG - block_group_start).x;
        frag_var[k / GCD_VEC_CPG] = get_mean_var((thread_channel_start + k) / CPG - block_group_start).y;
      }
      // y = (x - mean) * rsqrtf(var + eps) * w + b; with SILU, y is then replaced by y / (1 + exp(-y)).
      for (int j = 0; j < ROWS_PER_BLOCK / ROWS_PER_IO; j++) {
        int64_t input_idx = n_loop * HW * C +
                            (virtual_block_idx_y * ROWS_PER_BLOCK + j * ROWS_PER_IO +
                             threadIdx.x / (C_PER_BLOCK / VEC_ELEMS)) *
                                C +
                            thread_channel_start;
        U val;
        if constexpr (LOAD_TWICE) {
          val = *reinterpret_cast(&x[input_idx]);
        } else {
          val = frag[j];
        }
        for (int k = 0; k < VEC_ELEMS; k++) {
          float f = ((float)val.data[k] - frag_mean[k / GCD_VEC_CPG]) * rsqrtf(frag_var[k / GCD_VEC_CPG] + eps) *
                        (float)uw.data[k] +
                    (float)ub.data[k];
          if constexpr (SILU) f = f / (1.f + expf(-f));
          val.data[k] = f;
        }
        *reinterpret_cast(&out[input_idx]) = val;
      }
      if constexpr (!STORE_MEAN_VAR_IN_SHARED_RED_BUFFER && USE_SHARED_RED_BUFFER && VIRTUAL_CLUSTER_SIZE > 1) {
        // Pairs with the barrier_arrive() issued after the distributed-shared-memory reads in the group-sum phase.
        if constexpr (PERSISTENT) {
          if (nc_scheduler.at_end(n)) {
            cg::this_cluster().barrier_wait();
          }
        } else {
          cg::this_cluster().barrier_wait();
        }
      }
      if constexpr (!PERSISTENT) {
        break;
      }
      step ^= 1;
    }
  }
}

// Selects how the backward kernel (gn_bwd_cuda_kernel, via its wgrad_sync_method parameter) synchronizes before
// reducing the weight/bias gradients.
enum WgradSyncMethod {
  WGRAD_ARRIVE_AND_WAIT_GRID = 0,  // grid arrive after the last virtual cluster sync
  WGRAD_ARRIVE_AND_WAIT_GROUP,     // group arrive after the last virtual cluster sync (a group sync means synchronizing
                                   // all clusters cooperating on the same groups)
  WGRAD_REUSE_SUM_SYNC_GRID,       // grid sync together with the last virtual
cluster sync WGRAD_REUSE_SUM_SYNC_GROUP, // group sync together with the last virtual cluster sync WGRAD_SYNC_AT_LAST, // add a sync at the end of NC loops WGRAD_SYNC_UNSPECIFIED, }; template __global__ __launch_bounds__(BLOCK_DIM_X, BLOCKS_PER_SM) void gn_bwd_cuda_kernel( T* __restrict__ grad_input, T* __restrict__ grad_weight, T* __restrict__ grad_bias, T const* __restrict__ grad_output, T const* __restrict__ x, T const* __restrict__ w, T const* __restrict__ b, float const* __restrict__ mean_var, float eps, int64_t n, float* __restrict__ red_buffer, unsigned* __restrict__ barrier) { // Procedure Overview // 1. Thread sum: read from gmem, write partial sum to smem, store input in registers (if no LOAD_TWICE) // 2. Block sum: read from smem, write partial sum to gmem (or distributed shared memory if HARDWARE_CLUSTER is // used), // write wgrad to gmem at the last loop (at each loop if not CONSTANT_C_LOOP) // 3. Group sum: read from gmem, write mean&var to smem // 4. Scale: read mean&var from smem, read input from gmem (if LOAD_TWICE), write output to gmem // 5. Wgrad sum: read from gmem, write to gmem static_assert(BLOCK_DIM_X % 32 == 0, "warp shuffle error"); constexpr int C = G * CPG; static_assert(C % C_PER_CLUSTER == 0, "cannot divide channels into clusters"); static_assert(C_PER_CLUSTER % C_PER_BLOCK == 0, "cannot divide a cluster into blocks"); static_assert(C_PER_CLUSTER % CPG == 0, "no reduce between clusters, would produce incorrect results"); static_assert(!(C_PER_BLOCK % CPG == 0 && C_PER_CLUSTER != C_PER_BLOCK), "inefficient configuration, please reduce C_PER_CLUSTER"); static_assert(ROWS_PER_BLOCK * C_PER_BLOCK % BLOCK_DIM_X == 0, "cannot divide tile into threads"); struct alignas(VEC_ELEMS * sizeof(T)) U { T data[VEC_ELEMS]; }; // This function computes mean_dyw and mean_xdyw. // The function name is not changed because it has the same logic as the forward pass. 
auto compute_mean_var = [&](float2 sum) { float mean_dyw = sum.x / (HW * CPG); float mean_xdyw = sum.y / (HW * CPG); return float2{mean_dyw, mean_xdyw}; }; static_assert(HW % ROWS_PER_BLOCK == 0, "HW must be divisible by ROWS_PER_BLOCK to determine the number of blocks on the HW axis"); constexpr int MAX_NUM_GROUPS_PER_BLOCK = C_PER_BLOCK % CPG == 0 ? C_PER_BLOCK / CPG : up_div(C_PER_BLOCK - gcd(C_PER_BLOCK, CPG), CPG) + 1; constexpr int VIRTUAL_CLUSTER_SIZE = (C_PER_CLUSTER / C_PER_BLOCK) * (HW / ROWS_PER_BLOCK); constexpr int virtual_cluster_dim_x = C_PER_CLUSTER / C_PER_BLOCK; constexpr int virtual_cluster_dim_y = HW / ROWS_PER_BLOCK; int virtual_block_idx_x = (blockIdx.x % VIRTUAL_CLUSTER_SIZE) % virtual_cluster_dim_x; int virtual_block_idx_y = (blockIdx.x % VIRTUAL_CLUSTER_SIZE) / virtual_cluster_dim_x; if constexpr (CompileCondition::matches()) { int step = 0; constexpr bool CONSTANT_C_LOOP = PERSISTENT && NUM_VIRTUAL_CLUSTERS % (C / C_PER_CLUSTER) == 0; if constexpr (!CONSTANT_C_LOOP) { static_assert(wgrad_sync_method != WGRAD_ARRIVE_AND_WAIT_GROUP && wgrad_sync_method != WGRAD_REUSE_SUM_SYNC_GROUP, "grid sync is required when each block is responsible for multiple channel ranges"); } NCScheduler nc_scheduler( n); // TODO: I don't know why the template specialization with CONSTANT_C_LOOP=true is slower. 
[[maybe_unused]] int virtual_cluster_idx_c = blockIdx.y % (C / C_PER_CLUSTER); [[maybe_unused]] cg::grid_group::arrival_token wgrad_sync_token; [[maybe_unused]] float dw_thread[VEC_ELEMS]; [[maybe_unused]] float db_thread[VEC_ELEMS]; [[maybe_unused]] __shared__ union { float2 dwdb_block_buffer[BLOCK_DIM_X][VEC_ELEMS]; struct { float wgrad_buffer[BLOCK_DIM_X / 32][32]; float bgrad_buffer[BLOCK_DIM_X / 32][32]; } transpose_buffer; } union_smem; if constexpr (REQUIRES_WGRAD && CONSTANT_C_LOOP) { for (int i = 0; i < VEC_ELEMS; i++) { dw_thread[i] = 0.f; db_thread[i] = 0.f; } } float* red_buffer_wgrad = &red_buffer[(2 * NUM_VIRTUAL_CLUSTERS * VIRTUAL_CLUSTER_SIZE * MAX_NUM_GROUPS_PER_BLOCK) * 2]; unsigned* barrier_wgrad = barrier + NUM_VIRTUAL_CLUSTERS; if constexpr (REQUIRES_WGRAD && wgrad_sync_method != WGRAD_SYNC_AT_LAST) { if (nc_scheduler.at_end(n)) { static_assert(PERSISTENT, "persistent is a must for reducing wgrad"); if constexpr (wgrad_sync_method == WGRAD_ARRIVE_AND_WAIT_GRID) { wgrad_sync_token = group_barrier_arrive( barrier_wgrad, blockIdx.x + blockIdx.y == 0); } else if constexpr (wgrad_sync_method == WGRAD_ARRIVE_AND_WAIT_GROUP) { wgrad_sync_token = group_barrier_arrive( barrier_wgrad + virtual_cluster_idx_c, blockIdx.x + blockIdx.y / (C / C_PER_CLUSTER) == 0); } else if constexpr (wgrad_sync_method == WGRAD_REUSE_SUM_SYNC_GRID) { wgrad_sync_token = group_barrier_arrive( barrier_wgrad, blockIdx.x + blockIdx.y == 0); group_barrier_wait(barrier_wgrad, wgrad_sync_token); } else if constexpr (wgrad_sync_method == WGRAD_REUSE_SUM_SYNC_GROUP) { wgrad_sync_token = group_barrier_arrive( barrier_wgrad + virtual_cluster_idx_c, blockIdx.x + blockIdx.y / (C / C_PER_CLUSTER) == 0); group_barrier_wait(barrier_wgrad + virtual_cluster_idx_c, wgrad_sync_token); } } } while (true) { // TODO: unroll the loop if constexpr (PERSISTENT) { if (nc_scheduler.at_end(n)) { break; } } auto [n_loop, c_loop] = nc_scheduler.get_nc(); if constexpr (PERSISTENT) { nc_scheduler.next(n); } 
static_assert(C_PER_BLOCK % VEC_ELEMS == 0, "cannot vectorize"); static_assert((BLOCK_DIM_X * VEC_ELEMS) % C_PER_BLOCK == 0, "each block should load one or more C_PER_BLOCK at once"); constexpr int ROWS_PER_IO = BLOCK_DIM_X * VEC_ELEMS / C_PER_BLOCK; static_assert(ROWS_PER_BLOCK % ROWS_PER_IO == 0, "cannot determine the IO times per batch"); int block_channel_start = virtual_block_idx_x * C_PER_BLOCK + c_loop * C_PER_CLUSTER; int block_group_start = block_channel_start / CPG; int thread_channel_start = block_channel_start + threadIdx.x % (C_PER_BLOCK / VEC_ELEMS) * VEC_ELEMS; U frag_x[ROWS_PER_BLOCK / ROWS_PER_IO]; U frag_dy[ROWS_PER_BLOCK / ROWS_PER_IO]; constexpr int GCD_VEC_CPG = gcd(VEC_ELEMS, CPG); constexpr bool SINGLE_GROUP_PER_BLOCK = CPG % C_PER_BLOCK == 0; [[maybe_unused]] __shared__ float2 sum_per_channel_multi_group[C_PER_BLOCK / GCD_VEC_CPG][relative_prime( 128 / (int)sizeof(float2), ROWS_PER_IO)]; [[maybe_unused]] __shared__ float2 sum_per_channel_single_group[BLOCK_DIM_X / 32]; float frag_mean[VEC_ELEMS / GCD_VEC_CPG]; float frag_var[VEC_ELEMS / GCD_VEC_CPG]; for (int k = 0; k < VEC_ELEMS; k += GCD_VEC_CPG) { float2 value = *reinterpret_cast(&mean_var[(n_loop * G + (thread_channel_start + k) / CPG) * 2]); frag_mean[k / GCD_VEC_CPG] = value.x; frag_var[k / GCD_VEC_CPG] = value.y; } U uw = *reinterpret_cast(&w[thread_channel_start]); U ub; if constexpr (SILU) { ub = *reinterpret_cast(&b[thread_channel_start]); } if constexpr (REQUIRES_WGRAD && !CONSTANT_C_LOOP) { for (int i = 0; i < VEC_ELEMS; i++) { dw_thread[i] = 0.f; db_thread[i] = 0.f; } } if constexpr (LOAD_TWICE) { float2 frag_sum_per_channel[VEC_ELEMS / GCD_VEC_CPG]{}; for (int j = 0; j < ROWS_PER_BLOCK / ROWS_PER_IO; j++) { int64_t input_idx = n_loop * HW * C + (virtual_block_idx_y * ROWS_PER_BLOCK + j * ROWS_PER_IO + threadIdx.x / (C_PER_BLOCK / VEC_ELEMS)) * C + thread_channel_start; U ux = *reinterpret_cast(&x[input_idx]); U udy = *reinterpret_cast(&grad_output[input_idx]); for (int i = 0; i 
< VEC_ELEMS / GCD_VEC_CPG; i++) { float2 sum = frag_sum_per_channel[i]; for (int k = 0; k < GCD_VEC_CPG; k++) { float rnorm = rsqrtf(frag_var[i] + eps); float x_norm = ((float)ux.data[i * GCD_VEC_CPG + k] - frag_mean[i]) * rnorm; // TODO: store rsqrtf in mean_var float grad_gn = udy.data[i * GCD_VEC_CPG + k]; if constexpr (SILU) { float x_gn = x_norm * (float)uw.data[i * GCD_VEC_CPG + k] + (float)ub.data[i * GCD_VEC_CPG + k]; float s = 1.f / (1.f + expf(-x_gn)); grad_gn *= s * (1.f + x_gn * (1.f - s)); } sum.x += grad_gn * (float)uw.data[i * GCD_VEC_CPG + k]; sum.y += x_norm * (grad_gn * (float)uw.data[i * GCD_VEC_CPG + k]); if constexpr (REQUIRES_WGRAD) { dw_thread[i * GCD_VEC_CPG + k] += x_norm * grad_gn; db_thread[i * GCD_VEC_CPG + k] += grad_gn; } } frag_sum_per_channel[i] = sum; } } for (int i = 0; i < VEC_ELEMS / GCD_VEC_CPG; i++) { if constexpr (SINGLE_GROUP_PER_BLOCK) { for (int mask = 16; mask > 0; mask >>= 1) { frag_sum_per_channel[i].x += __shfl_xor_sync(FINAL_MASK, frag_sum_per_channel[i].x, mask, 32); frag_sum_per_channel[i].y += __shfl_xor_sync(FINAL_MASK, frag_sum_per_channel[i].y, mask, 32); } static_assert(VEC_ELEMS / GCD_VEC_CPG == 1, "process only one element for each warp"); if (threadIdx.x % 32 == 0) { sum_per_channel_single_group[threadIdx.x / 32] = frag_sum_per_channel[i]; } } else { sum_per_channel_multi_group[i * (C_PER_BLOCK / VEC_ELEMS) + threadIdx.x % (C_PER_BLOCK / VEC_ELEMS)] [threadIdx.x / (C_PER_BLOCK / VEC_ELEMS)] = frag_sum_per_channel[i]; } } __syncthreads(); } else { for (int j = 0; j < ROWS_PER_BLOCK / ROWS_PER_IO; j++) { int64_t input_idx = n_loop * HW * C + (virtual_block_idx_y * ROWS_PER_BLOCK + j * ROWS_PER_IO + threadIdx.x / (C_PER_BLOCK / VEC_ELEMS)) * C + thread_channel_start; frag_x[j] = *reinterpret_cast(&x[input_idx]); frag_dy[j] = *reinterpret_cast(&grad_output[input_idx]); } for (int i = 0; i < VEC_ELEMS / GCD_VEC_CPG; i++) { float2 sum = {0.f, 0.f}; for (int j = 0; j < ROWS_PER_BLOCK / ROWS_PER_IO; j++) { for (int k 
= 0; k < GCD_VEC_CPG; k++) { float rnorm = rsqrtf(frag_var[i] + eps); float x_norm = ((float)frag_x[j].data[i * GCD_VEC_CPG + k] - frag_mean[i]) * rnorm; // TODO: store rsqrtf in mean_var float grad_gn = frag_dy[j].data[i * GCD_VEC_CPG + k]; if constexpr (SILU) { float x_gn = x_norm * (float)uw.data[i * GCD_VEC_CPG + k] + (float)ub.data[i * GCD_VEC_CPG + k]; float s = 1.f / (1.f + expf(-x_gn)); grad_gn *= s * (1.f + x_gn * (1.f - s)); } sum.x += grad_gn * (float)uw.data[i * GCD_VEC_CPG + k]; sum.y += x_norm * (grad_gn * (float)uw.data[i * GCD_VEC_CPG + k]); if constexpr (REQUIRES_WGRAD) { dw_thread[i * GCD_VEC_CPG + k] += x_norm * grad_gn; db_thread[i * GCD_VEC_CPG + k] += grad_gn; } } } if constexpr (SINGLE_GROUP_PER_BLOCK) { for (int mask = 16; mask > 0; mask >>= 1) { sum.x += __shfl_xor_sync(FINAL_MASK, sum.x, mask, 32); sum.y += __shfl_xor_sync(FINAL_MASK, sum.y, mask, 32); } static_assert(VEC_ELEMS / GCD_VEC_CPG == 1, "process only one element for each warp"); if (threadIdx.x % 32 == 0) { sum_per_channel_single_group[threadIdx.x / 32] = sum; } } else { sum_per_channel_multi_group[i * (C_PER_BLOCK / VEC_ELEMS) + threadIdx.x % (C_PER_BLOCK / VEC_ELEMS)] [threadIdx.x / (C_PER_BLOCK / VEC_ELEMS)] = sum; } } __syncthreads(); } if ((CONSTANT_C_LOOP && nc_scheduler.at_end(n)) || !CONSTANT_C_LOOP) { constexpr int NT_C = max_divisor(C_PER_BLOCK, BLOCK_DIM_X); // Number of threads on the C axis constexpr int NT_R = 1; // std::min(32, (int)round_down_pow2(BLOCK_DIM_X / NT_C)); // Number of threads on the ROWS axis // TODO: swizzle for NT_R for (int i = 0; i < VEC_ELEMS; i++) { union_smem.dwdb_block_buffer[threadIdx.x][i ^ ((threadIdx.x / (16 / VEC_ELEMS)) & (VEC_ELEMS - 1))] = float2{dw_thread[i], db_thread[i]}; } __syncthreads(); static_assert(NT_C * NT_R <= BLOCK_DIM_X, "not enough threads"); static_assert(C_PER_BLOCK % NT_C == 0, "need to loop once more and check c < C_PER_BLOCK"); for (int i = 0; i < C_PER_BLOCK / NT_C; i++) { int c = i * NT_C + threadIdx.x / NT_R; 
float dw_block = 0.f; float db_block = 0.f; if (BLOCK_DIM_X == NT_C * NT_R || threadIdx.x < NT_C * NT_R) { for (int j = threadIdx.x % NT_R; j < ROWS_PER_IO; j += NT_R) { int src_thread = j * (C_PER_BLOCK / VEC_ELEMS) + c / VEC_ELEMS; float2 val = union_smem.dwdb_block_buffer[src_thread][(c % VEC_ELEMS) ^ ((src_thread / (16 / VEC_ELEMS)) & (VEC_ELEMS - 1))]; dw_block += val.x; db_block += val.y; } } static_assert(32 % NT_R == 0, "cannot shuffle"); for (int mask = NT_R / 2; mask > 0; mask >>= 1) { dw_block += __shfl_xor_sync(FINAL_MASK, dw_block, mask, 32); db_block += __shfl_xor_sync(FINAL_MASK, db_block, mask, 32); } if (BLOCK_DIM_X == NT_C * NT_R || threadIdx.x < NT_C * NT_R) { if (threadIdx.x % NT_R == 0) { if constexpr (CONSTANT_C_LOOP) { *reinterpret_cast( &red_buffer_wgrad [((blockIdx.y / (C / C_PER_CLUSTER) * virtual_cluster_dim_y + virtual_block_idx_y) * C + c_loop * C_PER_CLUSTER + virtual_block_idx_x * C_PER_BLOCK + c) * 2]) = float2{dw_block, db_block}; } else { *reinterpret_cast( &red_buffer_wgrad[((n_loop * virtual_cluster_dim_y + virtual_block_idx_y) * C + c_loop * C_PER_CLUSTER + virtual_block_idx_x * C_PER_BLOCK + c) * 2]) = float2{dw_block, db_block}; } } } } } constexpr bool USE_SHARED_RED_BUFFER = HARDWARE_CLUSTER || VIRTUAL_CLUSTER_SIZE == 1; constexpr bool STORE_MEAN_VAR_IN_SHARED_RED_BUFFER = VIRTUAL_CLUSTER_SIZE == 1 && MAX_NUM_GROUPS_PER_BLOCK == 1; // MAX_NUM_GROUPS_PER_BLOCK > 1 is possible but not implemented [[maybe_unused]] __align__(16) __shared__ float2 shared_red_buffer[MAX_NUM_GROUPS_PER_BLOCK * (STORE_MEAN_VAR_IN_SHARED_RED_BUFFER ? 1 : 2)]; // Block sum if constexpr (SINGLE_GROUP_PER_BLOCK) { // block reduce if (threadIdx.x < 32) { float2 sum_local_group = threadIdx.x < BLOCK_DIM_X / 32 ? 
sum_per_channel_single_group[threadIdx.x] : float2{0.f, 0.f}; constexpr int warp_num_pow2 = round_up_pow2(BLOCK_DIM_X / 32); for (int mask = warp_num_pow2 / 2; mask > 0; mask >>= 1) { sum_local_group.x += __shfl_xor_sync(FINAL_MASK, sum_local_group.x, mask, 32); sum_local_group.y += __shfl_xor_sync(FINAL_MASK, sum_local_group.y, mask, 32); } if (threadIdx.x == 0) { if constexpr (USE_SHARED_RED_BUFFER) { if constexpr (STORE_MEAN_VAR_IN_SHARED_RED_BUFFER) { shared_red_buffer[0] = compute_mean_var(sum_local_group); } else { shared_red_buffer[step * MAX_NUM_GROUPS_PER_BLOCK + 0] = sum_local_group; } } else { *reinterpret_cast( &red_buffer[((step * gridDim.y + blockIdx.y) * VIRTUAL_CLUSTER_SIZE * MAX_NUM_GROUPS_PER_BLOCK + virtual_block_idx_x * virtual_cluster_dim_y * MAX_NUM_GROUPS_PER_BLOCK + // (threadIdx.x / THREADS_PER_GROUP) * virtual_cluster_dim_y + virtual_block_idx_y) * 2]) = sum_local_group; } } } } else { // The number of threads to calculate the sum of each group (should be a power of 2 for warp reduce) constexpr int THREADS_PER_GROUP = std::min(std::min(32U, round_up_pow2(ROWS_PER_IO)), round_up_pow2(BLOCK_DIM_X / MAX_NUM_GROUPS_PER_BLOCK / 2 + 1)); static_assert(BLOCK_DIM_X >= MAX_NUM_GROUPS_PER_BLOCK * THREADS_PER_GROUP, "not enough threads"); float2 sum_local_group = {0.f, 0.f}; if (threadIdx.x / THREADS_PER_GROUP < MAX_NUM_GROUPS_PER_BLOCK) { int local_group_idx = block_group_start + threadIdx.x / THREADS_PER_GROUP; // TODO: map threads to both the CPG loop and the ROWS loop for (int local_c_loop = 0; local_c_loop < CPG; local_c_loop += GCD_VEC_CPG) { int c = local_group_idx * CPG + local_c_loop; if (C_PER_BLOCK % CPG == 0 || (c >= block_channel_start && c < block_channel_start + C_PER_BLOCK)) { for (int src_thread_tile_y = threadIdx.x % THREADS_PER_GROUP; src_thread_tile_y < ROWS_PER_IO; src_thread_tile_y += THREADS_PER_GROUP) { int channel_idx = (c - block_channel_start) / GCD_VEC_CPG; channel_idx = channel_idx % (VEC_ELEMS / GCD_VEC_CPG) * 
(C_PER_BLOCK / VEC_ELEMS) + channel_idx / (VEC_ELEMS / GCD_VEC_CPG); sum_local_group.x += sum_per_channel_multi_group[channel_idx][src_thread_tile_y].x; sum_local_group.y += sum_per_channel_multi_group[channel_idx][src_thread_tile_y].y; } } } } static_assert(32 % THREADS_PER_GROUP == 0, "cannot shuffle"); for (int mask = THREADS_PER_GROUP / 2; mask > 0; mask >>= 1) { sum_local_group.x += __shfl_xor_sync(FINAL_MASK, sum_local_group.x, mask, 32); sum_local_group.y += __shfl_xor_sync(FINAL_MASK, sum_local_group.y, mask, 32); } if (threadIdx.x % THREADS_PER_GROUP == 0 && threadIdx.x / THREADS_PER_GROUP < MAX_NUM_GROUPS_PER_BLOCK) { if constexpr (USE_SHARED_RED_BUFFER) { static_assert(HARDWARE_CLUSTER || VIRTUAL_CLUSTER_SIZE == 1, "no distributed shared memory"); if constexpr (STORE_MEAN_VAR_IN_SHARED_RED_BUFFER) { shared_red_buffer[threadIdx.x / THREADS_PER_GROUP] = compute_mean_var(sum_local_group); } else { shared_red_buffer[step * MAX_NUM_GROUPS_PER_BLOCK + threadIdx.x / THREADS_PER_GROUP] = sum_local_group; } } else { *reinterpret_cast( &red_buffer[((step * gridDim.y + blockIdx.y) * VIRTUAL_CLUSTER_SIZE * MAX_NUM_GROUPS_PER_BLOCK + virtual_block_idx_x * virtual_cluster_dim_y * MAX_NUM_GROUPS_PER_BLOCK + (threadIdx.x / THREADS_PER_GROUP) * virtual_cluster_dim_y + virtual_block_idx_y) * 2]) = sum_local_group; } } } if constexpr (REQUIRES_WGRAD && wgrad_sync_method != WGRAD_SYNC_AT_LAST) { if (nc_scheduler.at_end(n)) { static_assert(PERSISTENT, "persistent is a must for reducing wgrad"); if constexpr (wgrad_sync_method == WGRAD_ARRIVE_AND_WAIT_GRID) { virtual_cluster_sync(barrier); wgrad_sync_token = group_barrier_arrive( barrier_wgrad, blockIdx.x + blockIdx.y == 0); } else if constexpr (wgrad_sync_method == WGRAD_ARRIVE_AND_WAIT_GROUP) { virtual_cluster_sync(barrier); wgrad_sync_token = group_barrier_arrive( barrier_wgrad + virtual_cluster_idx_c, blockIdx.x + blockIdx.y / (C / C_PER_CLUSTER) == 0); } else if constexpr (wgrad_sync_method == WGRAD_REUSE_SUM_SYNC_GRID) 
{ static_assert(!HARDWARE_CLUSTER, "Distributed smem sync cannot reuse gmem sync. Use WGRAD_ARRIVE_AND_WAIT_GRID instead."); wgrad_sync_token = group_barrier_arrive( barrier_wgrad, blockIdx.x + blockIdx.y == 0); group_barrier_wait(barrier_wgrad, wgrad_sync_token); } else if constexpr (wgrad_sync_method == WGRAD_REUSE_SUM_SYNC_GROUP) { static_assert(!HARDWARE_CLUSTER, "Distributed smem sync cannot reuse gmem sync. Use WGRAD_ARRIVE_AND_WAIT_GROUP instead."); wgrad_sync_token = group_barrier_arrive( barrier_wgrad + virtual_cluster_idx_c, blockIdx.x + blockIdx.y / (C / C_PER_CLUSTER) == 0); group_barrier_wait(barrier_wgrad + virtual_cluster_idx_c, wgrad_sync_token); } } else { virtual_cluster_sync(barrier); } } else { virtual_cluster_sync(barrier); } // Group sum __shared__ float2 mean_var[MAX_NUM_GROUPS_PER_BLOCK]; if constexpr (!STORE_MEAN_VAR_IN_SHARED_RED_BUFFER) { // The number of threads to calculate the sum of each group (should be a power of 2 for warp reduce) constexpr int THREADS_PER_GROUP = std::min(std::min(32U, round_up_pow2(virtual_cluster_dim_y)), round_up_pow2(BLOCK_DIM_X / MAX_NUM_GROUPS_PER_BLOCK / 2 + 1)); static_assert(BLOCK_DIM_X >= MAX_NUM_GROUPS_PER_BLOCK * THREADS_PER_GROUP, "not enough threads"); float2 sum_global_group = {0.f, 0.f}; if (threadIdx.x / THREADS_PER_GROUP < MAX_NUM_GROUPS_PER_BLOCK) { if constexpr (C_PER_BLOCK % CPG == 0) { // Special case: no cross-virtual_cluster_dim_x reduction float2 buffer[up_div(virtual_cluster_dim_y, THREADS_PER_GROUP)]; for (int i = threadIdx.x % THREADS_PER_GROUP; i < virtual_cluster_dim_y; i += THREADS_PER_GROUP) { float2 val; if constexpr (USE_SHARED_RED_BUFFER) { if constexpr (VIRTUAL_CLUSTER_SIZE == 1) { val = shared_red_buffer[step * MAX_NUM_GROUPS_PER_BLOCK + threadIdx.x / THREADS_PER_GROUP]; } else { static_assert(HARDWARE_CLUSTER, "no distributed shared memory"); float2 const* src_shared_red_buffer = cg::this_cluster().map_shared_rank( shared_red_buffer, i * virtual_cluster_dim_x + 
virtual_block_idx_x); val = src_shared_red_buffer[step * MAX_NUM_GROUPS_PER_BLOCK + threadIdx.x / THREADS_PER_GROUP]; } } else { val = *reinterpret_cast( &red_buffer[((step * gridDim.y + blockIdx.y) * VIRTUAL_CLUSTER_SIZE * MAX_NUM_GROUPS_PER_BLOCK + virtual_block_idx_x * virtual_cluster_dim_y * MAX_NUM_GROUPS_PER_BLOCK + (threadIdx.x / THREADS_PER_GROUP) * virtual_cluster_dim_y + i) * 2]); } buffer[i / THREADS_PER_GROUP] = val; } for (int i = threadIdx.x % THREADS_PER_GROUP; i < virtual_cluster_dim_y; i += THREADS_PER_GROUP) { float2 val = buffer[i / THREADS_PER_GROUP]; sum_global_group.x += val.x; sum_global_group.y += val.y; } } else { // Common case: cross-virtual_cluster_dim_x reduction int local_group_idx = block_group_start + threadIdx.x / THREADS_PER_GROUP; for (int i = threadIdx.x % THREADS_PER_GROUP; i < VIRTUAL_CLUSTER_SIZE; i += THREADS_PER_GROUP) { int src_virtual_block_idx_x = i % virtual_cluster_dim_x; int src_block_channel_start = src_virtual_block_idx_x * C_PER_BLOCK + c_loop * C_PER_CLUSTER; int src_block_group_start = src_block_channel_start / CPG; int relative_group_idx = local_group_idx - src_block_group_start; if (0 <= relative_group_idx && relative_group_idx < MAX_NUM_GROUPS_PER_BLOCK) { float2 val; if constexpr (USE_SHARED_RED_BUFFER) { static_assert(HARDWARE_CLUSTER, "no distributed shared memory"); static_assert(VIRTUAL_CLUSTER_SIZE != 1, "layout error: should not add (step * MAX_NUM_GROUPS_PER_BLOCK)"); float2 const* src_shared_red_buffer = cg::this_cluster().map_shared_rank(shared_red_buffer, i); val = src_shared_red_buffer[step * MAX_NUM_GROUPS_PER_BLOCK + relative_group_idx]; } else { val = *reinterpret_cast( &red_buffer[((step * gridDim.y + blockIdx.y) * VIRTUAL_CLUSTER_SIZE * MAX_NUM_GROUPS_PER_BLOCK + src_virtual_block_idx_x * virtual_cluster_dim_y * MAX_NUM_GROUPS_PER_BLOCK + relative_group_idx * virtual_cluster_dim_y + i / virtual_cluster_dim_x) * 2]); } sum_global_group.x += val.x; sum_global_group.y += val.y; } } } } if 
constexpr (USE_SHARED_RED_BUFFER && VIRTUAL_CLUSTER_SIZE > 1) { // Need cluster sync after distributed shared memory access, otherwise behavior is undefined if constexpr (PERSISTENT) { if (nc_scheduler.at_end(n)) { cg::this_cluster().barrier_arrive(); } } else { cg::this_cluster().barrier_arrive(); } } static_assert(32 % THREADS_PER_GROUP == 0, "cannot shuffle"); for (int mask = THREADS_PER_GROUP / 2; mask > 0; mask >>= 1) { sum_global_group.x += __shfl_xor_sync(FINAL_MASK, sum_global_group.x, mask, 32); sum_global_group.y += __shfl_xor_sync(FINAL_MASK, sum_global_group.y, mask, 32); } if (threadIdx.x % THREADS_PER_GROUP == 0 && threadIdx.x / THREADS_PER_GROUP < MAX_NUM_GROUPS_PER_BLOCK) { mean_var[threadIdx.x / THREADS_PER_GROUP] = compute_mean_var(sum_global_group); } __syncthreads(); } auto get_mean_var = [&](int relative_group_idx) { return STORE_MEAN_VAR_IN_SHARED_RED_BUFFER ? shared_red_buffer[relative_group_idx] : mean_var[relative_group_idx]; }; float frag_dyw[VEC_ELEMS / GCD_VEC_CPG]; float frag_xdyw[VEC_ELEMS / GCD_VEC_CPG]; for (int k = 0; k < VEC_ELEMS; k += GCD_VEC_CPG) { frag_dyw[k / GCD_VEC_CPG] = get_mean_var((thread_channel_start + k) / CPG - block_group_start).x; frag_xdyw[k / GCD_VEC_CPG] = get_mean_var((thread_channel_start + k) / CPG - block_group_start).y; } for (int j = 0; j < ROWS_PER_BLOCK / ROWS_PER_IO; j++) { int64_t input_idx = n_loop * HW * C + (virtual_block_idx_y * ROWS_PER_BLOCK + j * ROWS_PER_IO + threadIdx.x / (C_PER_BLOCK / VEC_ELEMS)) * C + thread_channel_start; U ux; U udy; if constexpr (LOAD_TWICE) { ux = *reinterpret_cast(&x[input_idx]); udy = *reinterpret_cast(&grad_output[input_idx]); } else { ux = frag_x[j]; udy = frag_dy[j]; } U val; for (int k = 0; k < VEC_ELEMS; k++) { float rnorm = rsqrtf(frag_var[k / GCD_VEC_CPG] + eps); float x_norm = ((float)ux.data[k] - frag_mean[k / GCD_VEC_CPG]) * rnorm; // TODO: store rsqrtf in mean_var float grad_gn = udy.data[k]; if constexpr (SILU) { float x_gn = x_norm * (float)uw.data[k] + 
(float)ub.data[k]; float s = 1.f / (1.f + expf(-x_gn)); grad_gn *= s * (1.f + x_gn * (1.f - s)); } val.data[k] = (grad_gn * (float)uw.data[k] - frag_dyw[k / GCD_VEC_CPG] - frag_xdyw[k / GCD_VEC_CPG] * x_norm) * rnorm; } *reinterpret_cast(&grad_input[input_idx]) = val; } if constexpr (!STORE_MEAN_VAR_IN_SHARED_RED_BUFFER && USE_SHARED_RED_BUFFER && VIRTUAL_CLUSTER_SIZE > 1) { if constexpr (PERSISTENT) { if (nc_scheduler.at_end(n)) { cg::this_cluster().barrier_wait(); } } else { cg::this_cluster().barrier_wait(); } } if constexpr (!PERSISTENT) { break; } step ^= 1; } // Wgrad sum if constexpr (REQUIRES_WGRAD) { static_assert(PERSISTENT, "cannot reduce wgrad"); static_assert(C % 32 == 0, "cannot reduce wgrad"); if constexpr (wgrad_sync_method == WGRAD_ARRIVE_AND_WAIT_GRID) { group_barrier_wait(barrier_wgrad, wgrad_sync_token); } else if constexpr (wgrad_sync_method == WGRAD_ARRIVE_AND_WAIT_GROUP) { group_barrier_wait(barrier_wgrad + virtual_cluster_idx_c, wgrad_sync_token); } else if constexpr (wgrad_sync_method == WGRAD_SYNC_AT_LAST) { cg::this_grid().sync(); } // If group sync, map blocks that are responsible for the same range of channels to these channels (named "split // channels"); otherwise, map all blocks to all channels. constexpr bool split_channels = wgrad_sync_method == WGRAD_ARRIVE_AND_WAIT_GROUP || wgrad_sync_method == WGRAD_REUSE_SUM_SYNC_GROUP; for (int c = split_channels ? virtual_cluster_idx_c * C_PER_CLUSTER + 32 * (blockIdx.y / (C / C_PER_CLUSTER) * VIRTUAL_CLUSTER_SIZE + blockIdx.x) : 32 * (blockIdx.y * VIRTUAL_CLUSTER_SIZE + blockIdx.x); split_channels ? c < (virtual_cluster_idx_c + 1) * C_PER_CLUSTER : c < C; c += split_channels ? 32 * (NUM_VIRTUAL_CLUSTERS / (C / C_PER_CLUSTER) * VIRTUAL_CLUSTER_SIZE) : 32 * (NUM_VIRTUAL_CLUSTERS * VIRTUAL_CLUSTER_SIZE)) { int64_t rows = (CONSTANT_C_LOOP ? 
std::min(n, (int64_t)NUM_VIRTUAL_CLUSTERS / (C / C_PER_CLUSTER)) : n) * virtual_cluster_dim_y; float sum_wgrad = 0.f; float sum_bgrad = 0.f; if ((split_channels && (C_PER_CLUSTER % 32 == 0 || c + threadIdx.x % 32 < (virtual_cluster_idx_c + 1) * C_PER_CLUSTER)) || (!split_channels && (C % 32 == 0 || c + threadIdx.x % 32 < C))) { for (int64_t i = threadIdx.x / 32; i < rows; i += BLOCK_DIM_X / 32) { float2 val = *reinterpret_cast(&red_buffer_wgrad[(i * C + c + threadIdx.x % 32) * 2]); sum_wgrad += val.x; sum_bgrad += val.y; } } constexpr int warp_num_pow2 = round_up_pow2(BLOCK_DIM_X / 32); union_smem.transpose_buffer .wgrad_buffer[threadIdx.x / 32][(threadIdx.x % 32) ^ ((threadIdx.x / 32) * (32 / warp_num_pow2))] = sum_wgrad; union_smem.transpose_buffer .bgrad_buffer[threadIdx.x / 32][(threadIdx.x % 32) ^ ((threadIdx.x / 32) * (32 / warp_num_pow2))] = sum_bgrad; __syncthreads(); for (int i = threadIdx.x / warp_num_pow2; i < 32 && ((split_channels && (C_PER_CLUSTER % 32 == 0 || c + i < (virtual_cluster_idx_c + 1) * C_PER_CLUSTER)) || (!split_channels && (C % 32 == 0 || c + i < C))); i += BLOCK_DIM_X / warp_num_pow2) { int j = threadIdx.x % warp_num_pow2; float sum_wgrad = j < BLOCK_DIM_X / 32 ? union_smem.transpose_buffer.wgrad_buffer[j][i ^ (j * (32 / warp_num_pow2))] : 0.f; float sum_bgrad = j < BLOCK_DIM_X / 32 ? 
union_smem.transpose_buffer.bgrad_buffer[j][i ^ (j * (32 / warp_num_pow2))] : 0.f; for (int mask = warp_num_pow2 / 2; mask > 0; mask >>= 1) { sum_wgrad += __shfl_xor_sync((uint64_t(1) << warp_num_pow2) - 1, sum_wgrad, mask, warp_num_pow2); sum_bgrad += __shfl_xor_sync((uint64_t(1) << warp_num_pow2) - 1, sum_bgrad, mask, warp_num_pow2); } if (j == 0) { grad_weight[c + i] = sum_wgrad; grad_bias[c + i] = sum_bgrad; } } __syncthreads(); } } } } } // namespace group_norm_v2 ================================================ FILE: apex/contrib/csrc/group_norm_v2/gn_dispatch_hw_c.hpp ================================================ #pragma once #define DISPATCH_HW_C(hw, c, HW, C, ...) \ [&] { \ if (hw == 64 && c == 1280) { \ constexpr int HW = 64, C = 1280; \ return __VA_ARGS__(); \ } \ if (hw == 64 && c == 2560) { \ constexpr int HW = 64, C = 2560; \ return __VA_ARGS__(); \ } \ if (hw == 256 && c == 640) { \ constexpr int HW = 256, C = 640; \ return __VA_ARGS__(); \ } \ if (hw == 256 && c == 1280) { \ constexpr int HW = 256, C = 1280; \ return __VA_ARGS__(); \ } \ if (hw == 256 && c == 1920) { \ constexpr int HW = 256, C = 1920; \ return __VA_ARGS__(); \ } \ if (hw == 256 && c == 2560) { \ constexpr int HW = 256, C = 2560; \ return __VA_ARGS__(); \ } \ if (hw == 1024 && c == 320) { \ constexpr int HW = 1024, C = 320; \ return __VA_ARGS__(); \ } \ if (hw == 1024 && c == 640) { \ constexpr int HW = 1024, C = 640; \ return __VA_ARGS__(); \ } \ if (hw == 1024 && c == 960) { \ constexpr int HW = 1024, C = 960; \ return __VA_ARGS__(); \ } \ if (hw == 1024 && c == 1280) { \ constexpr int HW = 1024, C = 1280; \ return __VA_ARGS__(); \ } \ if (hw == 1024 && c == 1920) { \ constexpr int HW = 1024, C = 1920; \ return __VA_ARGS__(); \ } \ if (hw == 4096 && c == 320) { \ constexpr int HW = 4096, C = 320; \ return __VA_ARGS__(); \ } \ if (hw == 4096 && c == 640) { \ constexpr int HW = 4096, C = 640; \ return __VA_ARGS__(); \ } \ if (hw == 4096 && c == 960) { \ constexpr int HW = 4096, C 
= 960; \ return __VA_ARGS__(); \ } \ throw std::invalid_argument("DISPATCH_HW_C " + std::to_string(hw) + " " + std::to_string(c)); \ }() ================================================ FILE: apex/contrib/csrc/group_norm_v2/gn_utils.cpp ================================================ #include "gn_utils.hpp" #include #include namespace group_norm_v2 { cudaDeviceProp const& get_device_prop(int device_id) { static std::vector device_props; static std::once_flag flag; std::call_once(flag, [&] { int count; CUDA_CHECK(cudaGetDeviceCount(&count)); device_props.resize(count); for (int i = 0; i < count; i++) { CUDA_CHECK(cudaGetDeviceProperties(&device_props[i], i)); } }); return device_props.at(device_id); } } // namespace group_norm_v2 ================================================ FILE: apex/contrib/csrc/group_norm_v2/gn_utils.hpp ================================================ #pragma once #include #include #include #include #include "gn.hpp" // Definition of CUDA_CHECK macro #define CUDA_CHECK(call) \ do { \ cudaError_t err_ = call; \ if (err_ != cudaSuccess) { \ fprintf(stderr, "CUDA error at %s:%d code=%d(%s) \"%s\" \n", __FILE__, __LINE__, err_, cudaGetErrorString(err_), \ #call); \ exit(EXIT_FAILURE); \ } \ } while (0) #define GN_CUDA_HOST_PARAMS(T) \ T *out, T *x, T *w, T *b, float eps, bool silu, int64_t n, int64_t hw, int num_groups, int channels_per_group, \ float *mean_var_out, float *red_buffer, unsigned *barrier, int sm_margin, cudaStream_t stream, int device_id, \ Meta *meta_ptr, bool meta_only #define GN_BWD_CUDA_HOST_PARAMS(T) \ T *grad_input, T *grad_weight, T *grad_bias, T *grad_output, T *x, T *w, T *b, float *mean_var, float eps, \ bool silu, int64_t n, int64_t hw, int num_groups, int channels_per_group, float *red_buffer, unsigned *barrier, \ int sm_margin, cudaStream_t stream, int device_id, Meta *meta_ptr, bool meta_only #define GN_CUDA_HOST_ARGS \ out, x, w, b, eps, silu, n, hw, num_groups, channels_per_group, mean_var_out, red_buffer, 
barrier, sm_margin, \ stream, device_id, meta_ptr, meta_only #define GN_BWD_CUDA_HOST_ARGS \ grad_input, grad_weight, grad_bias, grad_output, x, w, b, mean_var, eps, silu, n, hw, num_groups, \ channels_per_group, red_buffer, barrier, sm_margin, stream, device_id, meta_ptr, meta_only namespace group_norm_v2 { cudaDeviceProp const& get_device_prop(int device_id); #ifdef __CUDA_ARCH__ template __host__ __device__ inline int print_rank_0(char const* fmt, Ts&&... args) { if (threadIdx.x + threadIdx.y + threadIdx.z == 0 && blockIdx.x + blockIdx.y + blockIdx.z == 0) { return printf(fmt, std::forward(args)...); } return 0; } #endif } // namespace group_norm_v2 ================================================ FILE: apex/contrib/csrc/groupbn/batch_norm.cu ================================================ #include #include #include #include #include "batch_norm.h" #define cudaCheckErrors(msg) \ do { \ cudaError_t __err = cudaGetLastError(); \ if (__err != cudaSuccess) { \ fprintf(stderr, "Fatal error: %s (%s at %s:%d)\n", msg, cudaGetErrorString(__err), __FILE__, __LINE__); \ fprintf(stderr, "*** FAILED - ABORTING\n"); \ exit(1); \ } \ } while (0) static size_t round_up_to_multiple(size_t x, int multiple) { return ((x + multiple - 1) / multiple) * multiple; } struct Workspace { Workspace(size_t size) : size(size), data(NULL) { auto& allocator = *::c10::cuda::CUDACachingAllocator::get(); dataPtr = allocator.allocate(size); data = dataPtr.get(); } Workspace(const Workspace&) = delete; Workspace(Workspace&&) = default; Workspace& operator=(Workspace&&) = default; ~Workspace() = default; size_t size; void* data; c10::DataPtr dataPtr; }; // Return {y} at::Tensor nhwc_bn_fwd_train(const at::Tensor& x, const at::Tensor& scale, const at::Tensor& bias, const at::Tensor& running_mean, const at::Tensor& running_inv_var, const at::Tensor& minibatch_mean, const at::Tensor& minibatch_inv_var, const at::Tensor& ret_cta, const float momentum, const float epsilon, const bool fuse_relu, void* 
my_data, void* pair_data, void* pair_data2, void* pair_data3, const int bn_group, const at::Tensor& magic_tensor, const int occupancy, const int grid_dim_x, const bool coop) { const int N = x.size(0); const int H = x.size(1); const int W = x.size(2); const int C = x.size(3); // generating new magic number and use that for sync int* magic = magic_tensor.data_ptr(); *magic = (*magic + 1) & 0xff; // Allocate output tensor at::Tensor y = at::empty({N, H, W, C}, x.options()); // Create wrapper NhwcBatchNorm* bn = new NhwcBatchNorm(); bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, bn_group); bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W); bn->setConstants(momentum, epsilon); // set pointers within the wrapper bn->setInputOutputPointers(x.data_ptr(), nullptr, y.data_ptr(), nullptr); bn->setWeightPointers({scale.data_ptr(), bias.data_ptr()}, {nullptr, nullptr}); bn->setParameterPointers({running_mean.data_ptr(), running_inv_var.data_ptr()}); // deal with workspace(s) auto workspace_bytes = bn->numWorkspaceBytes(); // We'll create explicit tensors for the first 2 workspace ptrs, then allocate & offset // an allocated workspace for the others size_t total_workspace_bytes = 0; std::vector workspace_offsets; for (auto index = 3; index < workspace_bytes.size(); ++index) { total_workspace_bytes = round_up_to_multiple(total_workspace_bytes, 512); workspace_offsets.push_back(total_workspace_bytes); auto alloc_bytes = workspace_bytes[index]; total_workspace_bytes += alloc_bytes; } // Allocate the workspace Workspace ws(total_workspace_bytes); std::vector workspace; workspace.push_back(minibatch_mean.data_ptr()); workspace.push_back(minibatch_inv_var.data_ptr()); auto stream = at::cuda::getCurrentCUDAStream().stream(); const int retired_cta_bytes = workspace_bytes[2]; void* retired_ctas = ret_cta.data_ptr(); assert(ret_cta.size(0) >= retired_cta_bytes); workspace.push_back(retired_ctas); for (auto index = 3; index < 
workspace_bytes.size(); ++index) { void* ptr = reinterpret_cast(ws.data) + workspace_offsets[index - 3]; workspace.push_back(ptr); } bn->setWorkspacePointers(workspace, workspace_bytes); // Don't fuse in ReLU for now at least bn->fwd(stream, fuse_relu, my_data, pair_data, pair_data2, pair_data3, bn_group, *magic, occupancy, grid_dim_x, coop); return y; } at::Tensor nhwc_bn_fwd_eval(const at::Tensor& x, const at::Tensor& scale, const at::Tensor& bias, const at::Tensor& running_mean, const at::Tensor& running_inv_var, const at::Tensor& ret_cta, const int bn_group, const float momentum, const float epsilon, const bool fuse_relu) { const int N = x.size(0); const int H = x.size(1); const int W = x.size(2); const int C = x.size(3); // Allocate output tensor at::Tensor y = at::empty({N, H, W, C}, x.options()); // Create wrapper NhwcBatchNorm* bn = new NhwcBatchNorm(); bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, bn_group); bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W); bn->setConstants(momentum, epsilon); // set pointers within the wrapper bn->setInputOutputPointers(x.data_ptr(), nullptr, y.data_ptr(), nullptr); bn->setWeightPointers({scale.data_ptr(), bias.data_ptr()}, {nullptr, nullptr}); bn->setParameterPointers({running_mean.data_ptr(), running_inv_var.data_ptr()}); // deal with workspace(s) auto workspace_bytes = bn->numWorkspaceBytes(); // We'll create explicit tensors for the first 2 workspace ptrs, then allocate & offset // an allocated workspace for the others size_t total_workspace_bytes = 0; std::vector workspace_offsets; for (auto index = 3; index < workspace_bytes.size(); ++index) { total_workspace_bytes = round_up_to_multiple(total_workspace_bytes, 512); workspace_offsets.push_back(total_workspace_bytes); auto alloc_bytes = workspace_bytes[index]; total_workspace_bytes += alloc_bytes; } // Allocate the workspace Workspace ws(total_workspace_bytes); std::vector workspace; workspace.push_back(nullptr); 
workspace.push_back(nullptr); auto stream = at::cuda::getCurrentCUDAStream().stream(); const int retired_cta_bytes = workspace_bytes[2]; void* retired_ctas = ret_cta.data_ptr(); assert(ret_cta.size(0) >= retired_cta_bytes); workspace.push_back(retired_ctas); for (auto index = 3; index < workspace_bytes.size(); ++index) { void* ptr = reinterpret_cast(ws.data) + workspace_offsets[index - 3]; workspace.push_back(ptr); } bn->setWorkspacePointers(workspace, workspace_bytes); // Don't fuse in ReLU for now at least bn->fwdInference(stream, fuse_relu); return y; } std::vector nhwc_bn_bwd(const at::Tensor& x, const at::Tensor& dy, const at::Tensor& scale, const at::Tensor& bias, const at::Tensor& running_mean, const at::Tensor& running_inv_var, const at::Tensor& minibatch_mean, const at::Tensor& minibatch_inv_var, const at::Tensor& ret_cta, const float momentum, const float epsilon, const bool fuse_relu, void* my_data, void* pair_data, void* pair_data2, void* pair_data3, const int bn_group, const at::Tensor& magic_tensor, const int occupancy, const int grid_dim_x, const bool coop) { // shape const int N = x.size(0); const int H = x.size(1); const int W = x.size(2); const int C = x.size(3); // generating new magic number and use that for sync int* magic = magic_tensor.data_ptr(); *magic = (*magic + 1) & 0xff; // outputs at::Tensor x_grad, scale_grad, bias_grad; // Allocate outputs x_grad = at::empty_like(x); scale_grad = at::empty_like(scale); bias_grad = at::empty_like(bias); // Create wrapper NhwcBatchNorm* bn = new NhwcBatchNorm(); bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, bn_group); bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W); bn->setConstants(momentum, epsilon); // set pointers within the wrapper bn->setInputOutputPointers(x.data_ptr(), x_grad.data_ptr(), nullptr, dy.data_ptr()); bn->setWeightPointers({scale.data_ptr(), bias.data_ptr()}, {scale_grad.data_ptr(), bias_grad.data_ptr()}); 
bn->setParameterPointers({running_mean.data_ptr(), running_inv_var.data_ptr()}); // deal with workspace(s) auto workspace_bytes = bn->numWorkspaceBytes(); // We'll create explicit tensors for the first 2 workspace ptrs, then allocate & offset // an allocated workspace for the others size_t total_workspace_bytes = 0; std::vector workspace_offsets; for (auto index = 3; index < workspace_bytes.size(); ++index) { total_workspace_bytes = round_up_to_multiple(total_workspace_bytes, 512); workspace_offsets.push_back(total_workspace_bytes); auto alloc_bytes = workspace_bytes[index]; total_workspace_bytes += alloc_bytes; } // Allocate the workspace Workspace ws(total_workspace_bytes); std::vector workspace; workspace.push_back(minibatch_mean.data_ptr()); workspace.push_back(minibatch_inv_var.data_ptr()); auto stream = at::cuda::getCurrentCUDAStream().stream(); const int retired_cta_bytes = workspace_bytes[2]; void* retired_ctas = ret_cta.data_ptr(); assert(ret_cta.size(0) >= retired_cta_bytes); workspace.push_back(retired_ctas); for (auto index = 3; index < workspace_bytes.size(); ++index) { void* ptr = reinterpret_cast(ws.data) + workspace_offsets[index - 3]; workspace.push_back(ptr); } bn->setWorkspacePointers(workspace, workspace_bytes); bn->dgrad(stream, fuse_relu, my_data, pair_data, pair_data2, pair_data3, bn_group, *magic, occupancy, grid_dim_x, coop); return std::vector{x_grad, scale_grad, bias_grad}; } int nhwc_bn_fwd_occupancy() { int device_id = -1; cudaGetDevice(&device_id); // max occupancy supported by the code is 2 return NhwcBatchNorm::smem_driven_fwd_occupancy(device_id, 2); } int nhwc_bn_bwd_occupancy() { int device_id = -1; cudaGetDevice(&device_id); // max occupancy supported by the code is 2 return NhwcBatchNorm::smem_driven_bwd_occupancy(device_id, 2); } ================================================ FILE: apex/contrib/csrc/groupbn/batch_norm.h ================================================ /* * Licensed to the Apache Software Foundation (ASF) 
under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * Copyright (c) 2018 by Contributors * \file nhwc_batch_norm.h * \brief CUDA NHWC Batch Normalization code * \author Shankara Rao Thejaswi Nanditale, Dick Carter, Evgeni Krimer */ #ifndef MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_H_ #define MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_H_ #include #include #include #include #include #include "cuda_utils.h" #include "nhwc_batch_norm_kernel.h" #define VERBOSE_DEFAULT false class NhwcBatchNorm { public: NhwcBatchNorm() { name_ = "nhwc_batchnorm"; createTensorDescriptor(&X_tensor_desc_); createTensorDescriptor(&Y_tensor_desc_); } ~NhwcBatchNorm() { destroyTensorDescriptor(X_tensor_desc_); destroyTensorDescriptor(Y_tensor_desc_); } void die() { std::cerr << "batchnorm not initialized" << std::endl; exit(-1); } void fwd(cudaStream_t stream, bool use_relu, void* my_data, void* pair_data, void* pair_data2, void* pair_data3, const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop); void dgrad(cudaStream_t stream, bool use_relu, void* my_data, void* pair_data, void* pair_data2, void* pair_data3, const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop); void fwdInference(cudaStream_t stream, bool use_relu); dim3 calc_fwd_grid(int* 
loop, const int grid_dim_x); dim3 calc_bwd_grid(int* loop, const int grid_dim_x); void setInputDescriptor(const cudnnTensorFormat_t format, const cudnnDataType_t data_type, int n, int c, int h, int w, int bn_group) { m_ = n * h * w; int m_bn_adjusted = m_ * bn_group; c_ = c; // factor to scale sum of squared errors to get saved variance. Must be 1/nhw. svar_inv_count_ = 1.f / m_bn_adjusted; // factor to scale sum of squared errors to get running variance. Should be 1/(nhw-1). int divisor = m_bn_adjusted - 1; // nhw == 1 is unlikely, but by setting the rvar_inv_count_ == 1.f, we avoid running var infs. rvar_inv_count_ = divisor == 0 ? 1.f : 1.f / divisor; setTensorDescriptor(X_tensor_desc_, format, data_type, n, c, h, w); } void setOutputDescriptor(const cudnnTensorFormat_t format, const cudnnDataType_t data_type, int n, int c, int h, int w) { setTensorDescriptor(Y_tensor_desc_, format, data_type, n, c, h, w); } const std::vector numWorkspaceBytes() const; void setWorkspacePointers(const std::vector& workspace, const std::vector& num_workspace_bytes); void setInputOutputPointers(void* X, void* dX, void* Y, void* dY) { X_ = X; dX_ = dX; Y_ = Y; dY_ = dY; } // Sets the pointers for the scale and weight (in that order) data and derivative buffers. void setWeightPointers(const std::vector& weight_pointers, const std::vector& deriv_pointers) { assert(weight_pointers.size() == 2); assert(deriv_pointers.size() == 2); scale_ = static_cast(weight_pointers[0]); bias_ = static_cast(weight_pointers[1]); dscale_ = static_cast(deriv_pointers[0]); dbias_ = static_cast(deriv_pointers[1]); } // Sets the pointers for the population mean and variance buffers, in that order. 
void setParameterPointers(const std::vector& param_pointers) { assert(param_pointers.size() == 2); population_mean_ = static_cast(param_pointers[0]); population_variance_ = static_cast(param_pointers[1]); } void setConstants(const double exp_avg_factor, const double eps) { exp_avg_factor_ = exp_avg_factor; eps_ = eps; } void processCudnnStatus(const cudnnStatus_t& status, const std::string& string = std::string(), bool verbose = VERBOSE_DEFAULT) { if (status != CUDNN_STATUS_SUCCESS) LOG(FATAL) << string << " " << cudnnGetErrorString(status); else if (verbose) LOG(INFO) << string << " " << cudnnGetErrorString(status); } void checkCudaStatus(const std::string& string = std::string(), bool verbose = VERBOSE_DEFAULT) { cudaError_t status = cudaGetLastError(); if (status != cudaSuccess) LOG(FATAL) << string << " " << cudaGetErrorString(status); else if (verbose) LOG(INFO) << string << " " << cudaGetErrorString(status); } size_t size_retired_ctas(int grid_y) const { // Note that the value of max_grid_y to handle known GPUs is about 160. const int max_grid_y = 1024; if (grid_y > max_grid_y) LOG(INFO) << "GPU capabilities exceeds assumptions."; const int retired_cta_bytes = max_grid_y * 2 * sizeof(int); // Since the region will be initialized once and used for many kernels, // the idea is to return an ample size that will cover all uses. return retired_cta_bytes; } cudnnTensorDescriptor_t X_tensor_desc_ = nullptr; cudnnTensorDescriptor_t Y_tensor_desc_ = nullptr; void* X_ = nullptr; void* dX_ = nullptr; void* Y_ = nullptr; void* dY_ = nullptr; // Learned scale and bias weights. float* scale_ = nullptr; float* dscale_ = nullptr; float* bias_ = nullptr; float* dbias_ = nullptr; // Computed population mean and variance parameters. float* population_mean_ = nullptr; float* population_variance_ = nullptr; // Workspace buffers for minibatch mean and variance (computed in fwd, needed by bwd). 
float* minibatch_mean_ = nullptr; float* minibatch_variance_ = nullptr; int m_ = 0; // Number of values per channel that BN is normalizing. int c_ = 0; // Number of channels over which BN is normalizing. float svar_inv_count_ = 0.f; // factor to scale sum of squared errors to get saved variance float rvar_inv_count_ = 0.f; // factor to scale sum of squared errors to get running variance double exp_avg_factor_ = 0.; double eps_ = 0.; std::string name_; private: void setTensorDescriptor(cudnnTensorDescriptor_t descriptor, cudnnTensorFormat_t format, cudnnDataType_t data_type, int n, int c, int h, int w) { cudnnStatus_t status = CUDNN_STATUS_SUCCESS; status = cudnnSetTensor4dDescriptor(descriptor, format, data_type, n, c, h, w); processCudnnStatus(status, "set tensor descriptor"); } void createTensorDescriptor(cudnnTensorDescriptor_t* descriptor) { cudnnStatus_t status = CUDNN_STATUS_SUCCESS; status = cudnnCreateTensorDescriptor(descriptor); processCudnnStatus(status, "create tensor_descriptor"); } void destroyTensorDescriptor(cudnnTensorDescriptor_t descriptor) { cudnnStatus_t status = CUDNN_STATUS_SUCCESS; status = cudnnDestroyTensorDescriptor(descriptor); processCudnnStatus(status, "destroy tensor_descriptor"); } protected: float* partial_sums_ = nullptr; int* partial_counts_ = nullptr; int* retired_ctas_ = nullptr; void _setFwdParams(NhwcBatchNormFwdParams* params) const; void _setFwdInferenceParams(NhwcBatchNormFwdInferenceParams* params) const; void _setBwdParams(NhwcBatchNormBwdParams* params) const; // @todo: ability to configure these? // Kernel params static const int USE_ONLINE_APPROACH = 1; static const int THREADS_PER_CTA = 512; static const int THREADS_PER_PIXEL = 16; static const int C_ELEMENTS_PER_CTA = 64; static const int ELEMENTS_PER_LDG = C_ELEMENTS_PER_CTA / THREADS_PER_PIXEL; static const int MAX_SMEM_WITHOUT_OPT_IN = 48 * 1024; typedef uint16_t StorageType; // typedef float StorageType; // increasing this to 6 causes spills in fwd kernel! 
static const int PIXELS_PER_THREAD_IN_REGISTERS_FWD = 5; static const int PIXELS_PER_THREAD_IN_REGISTERS_BWD = 3; static const int PIXELS_PER_THREAD_IN_SMEM_FWD = 10; static const int PIXELS_PER_THREAD_IN_SMEM_BWD = 5; static const int PIXELS_PER_THREAD_FWD = PIXELS_PER_THREAD_IN_REGISTERS_FWD + PIXELS_PER_THREAD_IN_SMEM_FWD; static const int PIXELS_PER_THREAD_BWD = PIXELS_PER_THREAD_IN_REGISTERS_BWD + PIXELS_PER_THREAD_IN_SMEM_BWD; static const int PIXELS_PER_THREAD_FWD_INFERENCE = 4; // Derived params static const size_t SMEM_SIZE_FWD = PIXELS_PER_THREAD_IN_SMEM_FWD * THREADS_PER_CTA * ELEMENTS_PER_LDG * sizeof(StorageType); static const size_t SMEM_SIZE_BWD = PIXELS_PER_THREAD_IN_SMEM_BWD * THREADS_PER_CTA * ELEMENTS_PER_LDG * 2 * sizeof(StorageType); static const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL; static const int PIXELS_PER_CTA_FWD = THREADS_PER_CTA / THREADS_PER_PIXEL * PIXELS_PER_THREAD_FWD; static const int PIXELS_PER_CTA_BWD = THREADS_PER_CTA / THREADS_PER_PIXEL * PIXELS_PER_THREAD_BWD; static const int PIXELS_PER_CTA_FWD_INFERENCE = THREADS_PER_CTA / THREADS_PER_PIXEL * PIXELS_PER_THREAD_FWD_INFERENCE; // max grid.y in case of group bn is limited by exchange buffer size static const int MAX_GBN_BLOCK_Y = 256; // Helper function to launch the forward kernel. // We calculate (based on smem usage) the achievable occupancy and make sure we run a kernel // version that was compiled with that occupancy in its launch bounds. This way, we avoid // needless register spills. 
void _fwdKernelLauncher(cudaStream_t stream, NhwcBatchNormFwdParams params, dim3 grid_dim, int outer_loops, bool use_relu, const int occupancy, const bool coop) { #define LAUNCH_FWD_KERNEL(OUTER_LOOPS, USE_RELU, USE_ADD_RELU, COMPILED_FOR_OCCUPANCY, COOP) \ do { \ CHECK(SMEM_SIZE_FWD <= MAX_SMEM_WITHOUT_OPT_IN) << "Nhwc batchnorm kernel smem too big."; \ auto fwd_func = \ nhwc_batch_norm_fwd; \ if (COMPILED_FOR_OCCUPANCY > 1) { \ cudaFuncSetAttribute(fwd_func, cudaFuncAttributePreferredSharedMemoryCarveout, 100); \ checkCudaStatus(name_ + " fwd ser coop kernel (cudaFuncSetAttribute carveout)"); \ } \ void* params_ptr = static_cast(¶ms); \ using FWD_FUNC = decltype(nhwc_batch_norm_fwd); \ if (COOP) { \ cudaLaunchCooperativeKernel(fwd_func, grid_dim, THREADS_PER_CTA, ¶ms_ptr, SMEM_SIZE_FWD, stream); \ } else { \ cudaLaunchKernel(fwd_func, grid_dim, THREADS_PER_CTA, ¶ms_ptr, SMEM_SIZE_FWD, stream); \ } \ checkCudaStatus(name_ + " fwd ser coop kernel"); \ } while (0) // Don't try for an occupancy > 2 as this will squeeze register use and create spills. if (outer_loops == 1 && use_relu) { if (occupancy >= 2) LAUNCH_FWD_KERNEL(1, true, false, 2, coop); else LAUNCH_FWD_KERNEL(1, true, false, 1, coop); } else if (outer_loops == 1 && !use_relu) { if (occupancy >= 2) LAUNCH_FWD_KERNEL(1, false, false, 2, coop); else LAUNCH_FWD_KERNEL(1, false, false, 1, coop); } else if (use_relu) { if (occupancy >= 2) LAUNCH_FWD_KERNEL(0, true, false, 2, coop); else LAUNCH_FWD_KERNEL(0, true, false, 1, coop); } else { if (occupancy >= 2) LAUNCH_FWD_KERNEL(0, false, false, 2, coop); else LAUNCH_FWD_KERNEL(0, false, false, 1, coop); } #undef LAUNCH_FWD_KERNEL } // Helper function to launch the backward kernel. 
void _bwdKernelLauncher(cudaStream_t stream, NhwcBatchNormBwdParams params, dim3 grid_dim, int outer_loops, bool use_relu, const int occupancy, const bool coop) { #define LAUNCH_BWD_KERNEL(OUTER_LOOPS, COMPILED_FOR_OCCUPANCY, COOP) \ do { \ CHECK(SMEM_SIZE_BWD <= MAX_SMEM_WITHOUT_OPT_IN) << "Nhwc batchnorm kernel smem too big."; \ auto bwd_func = nhwc_batch_norm_bwd; \ if (COMPILED_FOR_OCCUPANCY > 1) { \ cudaFuncSetAttribute(bwd_func, cudaFuncAttributePreferredSharedMemoryCarveout, 100); \ checkCudaStatus(name_ + " bwd coop serial kernel (cudaFuncSetAttribute carveout)"); \ } \ void* params_ptr = static_cast(¶ms); \ using BWD_FUNC = \ decltype(nhwc_batch_norm_bwd); \ if (COOP) { \ cudaLaunchCooperativeKernel(bwd_func, grid_dim, THREADS_PER_CTA, ¶ms_ptr, SMEM_SIZE_BWD, stream); \ } else { \ cudaLaunchKernel(bwd_func, grid_dim, THREADS_PER_CTA, ¶ms_ptr, SMEM_SIZE_BWD, stream); \ } \ checkCudaStatus(name_ + " bwd coop serial kernel"); \ } while (0) #define LAUNCH_BWD_RELU_KERNEL(OUTER_LOOPS, COMPILED_FOR_OCCUPANCY, COOP) \ do { \ CHECK(SMEM_SIZE_BWD <= MAX_SMEM_WITHOUT_OPT_IN) << "Nhwc batchnorm kernel smem too big."; \ auto bwd_relu_func = \ nhwc_batch_norm_bwd_relu; \ if (COMPILED_FOR_OCCUPANCY > 1) { \ cudaFuncSetAttribute(bwd_relu_func, cudaFuncAttributePreferredSharedMemoryCarveout, 100); \ checkCudaStatus(name_ + " bwd-relu coop serial kernel (cudaFuncSetAttribute carveout)"); \ } \ void* params_ptr = static_cast(¶ms); \ using BWD_RELU_FUNC = \ decltype(nhwc_batch_norm_bwd_relu); \ if (COOP) { \ cudaLaunchCooperativeKernel(bwd_relu_func, grid_dim, THREADS_PER_CTA, ¶ms_ptr, SMEM_SIZE_BWD, \ stream); \ } else { \ cudaLaunchKernel(bwd_relu_func, grid_dim, THREADS_PER_CTA, ¶ms_ptr, SMEM_SIZE_BWD, stream); \ } \ checkCudaStatus(name_ + " bwd-relu coop serial kernel"); \ } while (0) // Don't try for an occupancy > 2 as this will squeeze register use and create spills. 
if (outer_loops == 1 && use_relu) { if (occupancy >= 2) LAUNCH_BWD_RELU_KERNEL(1, 2, coop); else LAUNCH_BWD_RELU_KERNEL(1, 1, coop); } else if (outer_loops == 1 && !use_relu) { if (occupancy >= 2) LAUNCH_BWD_KERNEL(1, 2, coop); else LAUNCH_BWD_KERNEL(1, 1, coop); } else if (use_relu) { if (occupancy >= 2) LAUNCH_BWD_RELU_KERNEL(0, 2, coop); else LAUNCH_BWD_RELU_KERNEL(0, 1, coop); } else { if (occupancy >= 2) LAUNCH_BWD_KERNEL(0, 2, coop); else LAUNCH_BWD_KERNEL(0, 1, coop); } #undef LAUNCH_BWD_KERNEL } public: // Calculate the expected fwd kernel occupancy, as dictated by shared memory usage. static int smem_driven_fwd_occupancy(int device_id, const int max_cta_per_sm) { using namespace at::cuda::utils; int fwd_reduction_bytes = THREADS_PER_PIXEL * (THREADS_PER_CTA / 32) * ELEMENTS_PER_LDG * sizeof(float); int fwd_smem_bytes = SMEM_SIZE_FWD + fwd_reduction_bytes; int occupancy = MaxSharedMemoryPerMultiprocessor(device_id) / fwd_smem_bytes; return std::min(max_cta_per_sm, occupancy); } // Calculate the expected bwd kernel occupancy, as dictated by shared memory usage. 
static int smem_driven_bwd_occupancy(int device_id, const int max_cta_per_sm) { using namespace at::cuda::utils; int bwd_reduction_bytes = THREADS_PER_PIXEL * (THREADS_PER_CTA / 32) * ELEMENTS_PER_LDG * sizeof(float); int bwd_smem_bytes = SMEM_SIZE_BWD + bwd_reduction_bytes; int occupancy = MaxSharedMemoryPerMultiprocessor(device_id) / bwd_smem_bytes; return std::min(max_cta_per_sm, occupancy); } }; const std::vector NhwcBatchNorm::numWorkspaceBytes() const { assert(c_ > 0); // choose the max memory required between fwd/bwd passes int grid_x_fwd = div_up(m_, PIXELS_PER_CTA_FWD); int grid_x_bwd = div_up(m_, PIXELS_PER_CTA_BWD); int grid_x = max(grid_x_fwd, grid_x_bwd); int grid_y = div_up(c_, C_ELEMENTS_PER_CTA); const size_t num_mean_bytes = c_ * sizeof(float); const size_t num_variance_bytes = num_mean_bytes; const size_t size_sums = grid_y * grid_x * THREADS_PER_PIXEL * ELEMENTS_PER_LDG * 2 * sizeof(float); const size_t size_counts = grid_y * grid_x * sizeof(int); return {num_mean_bytes, num_variance_bytes, size_retired_ctas(grid_y), size_sums, size_counts}; } void NhwcBatchNorm::setWorkspacePointers(const std::vector& workspace, const std::vector& num_workspace_bytes) { assert(workspace.size() == 5); assert(num_workspace_bytes.size() == 5); minibatch_mean_ = static_cast(workspace[0]); minibatch_variance_ = static_cast(workspace[1]); retired_ctas_ = static_cast(workspace[2]); partial_sums_ = static_cast(workspace[3]); partial_counts_ = static_cast(workspace[4]); } void NhwcBatchNorm::_setFwdParams(NhwcBatchNormFwdParams* params) const { params->gmem_src = static_cast(X_); params->gmem_dst = static_cast(Y_); params->gmem_src1 = nullptr; params->gmem_bias = bias_; params->gmem_scale = scale_; params->gmem_running_mean = population_mean_; params->gmem_running_var = population_variance_; params->gmem_saved_mean = minibatch_mean_; params->gmem_saved_var = minibatch_variance_; params->gmem_relu_bitmask = nullptr; params->nhw = m_; params->c = c_; params->svar_inv_count 
= svar_inv_count_; params->rvar_inv_count = rvar_inv_count_; params->gmem_sums = partial_sums_; params->gmem_counts = partial_counts_; params->gmem_retired_ctas = retired_ctas_; params->var_eps = eps_; params->outer_loops = 0; params->exp_avg_factor = static_cast(exp_avg_factor_); params->c_blks = div_up(c_, C_ELEMENTS_PER_CTA); } void NhwcBatchNorm::_setFwdInferenceParams(NhwcBatchNormFwdInferenceParams* params) const { params->gmem_src = static_cast(X_); params->gmem_dst = static_cast(Y_); params->gmem_src1 = nullptr; params->gmem_bias = bias_; params->gmem_scale = scale_; params->gmem_mean = population_mean_; params->gmem_var = population_variance_; params->nhw = m_; params->c = c_; params->var_eps = eps_; } void NhwcBatchNorm::_setBwdParams(NhwcBatchNormBwdParams* params) const { params->gmem_src = static_cast(X_); params->gmem_dy = static_cast(dY_); params->gmem_dst = static_cast(dX_); params->gmem_dst1 = nullptr; params->gmem_relu_bitmask = nullptr; params->gmem_dscale = dscale_; params->gmem_dbias = dbias_; params->gmem_scale = scale_; params->gmem_bias = bias_; params->gmem_saved_mean = minibatch_mean_; params->gmem_saved_var = minibatch_variance_; params->nhw = m_; params->c = c_; params->svar_inv_count = svar_inv_count_; params->gmem_sums = partial_sums_; params->gmem_retired_ctas = retired_ctas_; params->outer_loops = 0; params->c_blks = div_up(c_, C_ELEMENTS_PER_CTA); } void NhwcBatchNorm::fwdInference(cudaStream_t stream, bool use_relu) { bool ptrs_are_set = X_tensor_desc_ != nullptr && Y_tensor_desc_ != nullptr && scale_ != nullptr && bias_ != nullptr // && minibatch_mean_ != nullptr // && minibatch_variance_ != nullptr && population_mean_ != nullptr && population_variance_ != nullptr && X_ != nullptr // && dX_ != nullptr && Y_ != nullptr // && dY_ != nullptr // && dscale_ != nullptr // && dbias_ != nullptr && partial_sums_ != nullptr && partial_counts_ != nullptr; if (!ptrs_are_set) die(); dim3 grid_dim; grid_dim.x = div_up(m_, 
PIXELS_PER_CTA_FWD_INFERENCE); grid_dim.y = div_up(c_, C_ELEMENTS_PER_CTA); // @todo: maybe just move this inside initialize routine? NhwcBatchNormFwdInferenceParams params; _setFwdInferenceParams(¶ms); if (use_relu) { nhwc_batch_norm_fwd_inference <<>>(params); checkCudaStatus(name_ + " fwd_inference-relu kernel"); } else { nhwc_batch_norm_fwd_inference <<>>(params); checkCudaStatus(name_ + " fwd_inference kernel"); } } dim3 NhwcBatchNorm::calc_fwd_grid(int* loop, const int grid_dim_x) { dim3 grid_dim; grid_dim.x = div_up(m_, PIXELS_PER_CTA_FWD); int c_blks = div_up(c_, C_ELEMENTS_PER_CTA); unsigned int max_grid_x = grid_dim_x; if (grid_dim.x <= max_grid_x) { *loop = 1; if (max_grid_x / grid_dim.x > 1) { grid_dim.y = std::min(c_blks, static_cast(max_grid_x / grid_dim.x)); assert(grid_dim.y < MAX_GBN_BLOCK_Y); // FIXME: turn into a loop } else { grid_dim.y = 1; } } else { grid_dim.x = max_grid_x; grid_dim.y = 1; int nhw_in_regs = m_ - PIXELS_PER_THREAD_IN_SMEM_FWD * PIXELS_PER_LDG * grid_dim.x; int pixels_per_iteration = PIXELS_PER_THREAD_IN_REGISTERS_FWD * PIXELS_PER_LDG * grid_dim.x; *loop = div_up(nhw_in_regs, pixels_per_iteration); } return grid_dim; } dim3 NhwcBatchNorm::calc_bwd_grid(int* loop, const int grid_dim_x) { dim3 grid_dim; grid_dim.x = div_up(m_, PIXELS_PER_CTA_BWD); int c_blks = div_up(c_, C_ELEMENTS_PER_CTA); unsigned int max_grid_x = grid_dim_x; if (grid_dim.x <= max_grid_x) { *loop = 1; if (max_grid_x / grid_dim.x > 1) { grid_dim.y = std::min(c_blks, static_cast(max_grid_x / grid_dim.x)); assert(grid_dim.y < MAX_GBN_BLOCK_Y); // FIXME: turn into a loop } else { grid_dim.y = 1; } } else { grid_dim.x = max_grid_x; grid_dim.y = 1; int nhw_in_regs = m_ - PIXELS_PER_THREAD_IN_SMEM_BWD * PIXELS_PER_LDG * grid_dim.x; int pixels_per_iteration = PIXELS_PER_THREAD_IN_REGISTERS_BWD * PIXELS_PER_LDG * grid_dim.x; *loop = div_up(nhw_in_regs, pixels_per_iteration); } return grid_dim; } void NhwcBatchNorm::fwd(cudaStream_t stream, bool use_relu, void* my_data, 
void* pair_data, void* pair_data2, void* pair_data3, const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop) { bool ptrs_are_set = X_tensor_desc_ != nullptr && Y_tensor_desc_ != nullptr && scale_ != nullptr && bias_ != nullptr && minibatch_mean_ != nullptr && minibatch_variance_ != nullptr && population_mean_ != nullptr && population_variance_ != nullptr && X_ != nullptr // && dX_ != nullptr && Y_ != nullptr // && dY_ != nullptr // && dscale_ != nullptr // && dbias_ != nullptr && partial_sums_ != nullptr && partial_counts_ != nullptr && retired_ctas_ != nullptr; if (!ptrs_are_set) die(); // reset of retired_cta_count no longer needed NhwcBatchNormFwdParams params; _setFwdParams(¶ms); params.my_data = my_data; params.pair_datas[0] = pair_data; params.pair_datas[1] = pair_data2; params.pair_datas[2] = pair_data3; params.magic = magic; params.sync_iters = (bn_group == 8) ? 3 : (bn_group >> 1); dim3 grid_dim = calc_fwd_grid(¶ms.outer_loops, grid_dim_x); _fwdKernelLauncher(stream, params, grid_dim, params.outer_loops, use_relu, occupancy, coop); } void NhwcBatchNorm::dgrad(cudaStream_t stream, bool use_relu, void* my_data, void* pair_data, void* pair_data2, void* pair_data3, const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop) { bool ptrs_are_set = X_tensor_desc_ != nullptr && Y_tensor_desc_ != nullptr && scale_ != nullptr && (bias_ != nullptr || !use_relu) && minibatch_mean_ != nullptr && minibatch_variance_ != nullptr // && population_mean_ != nullptr // && population_variance_ != nullptr && X_ != nullptr && dX_ != nullptr // && Y_ != nullptr && dY_ != nullptr && dscale_ != nullptr && dbias_ != nullptr; if (!ptrs_are_set) die(); // reset of retired_cta_count no longer needed NhwcBatchNormBwdParams params; _setBwdParams(¶ms); params.my_data = my_data; params.pair_datas[0] = pair_data; params.pair_datas[1] = pair_data2; params.pair_datas[2] = pair_data3; params.magic = magic; 
params.sync_iters = (bn_group == 8) ? 3 : (bn_group >> 1); params.wgrad_coeff = 1.0 / bn_group; dim3 grid_dim = calc_bwd_grid(¶ms.outer_loops, grid_dim_x); _bwdKernelLauncher(stream, params, grid_dim, params.outer_loops, use_relu, occupancy, coop); } #endif // MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_H_ ================================================ FILE: apex/contrib/csrc/groupbn/batch_norm_add_relu.cu ================================================ #include #include #include #include #include "batch_norm_add_relu.h" // FIXME move the common stuff to common h file #define cudaCheckErrors(msg) \ do { \ cudaError_t __err = cudaGetLastError(); \ if (__err != cudaSuccess) { \ fprintf(stderr, "Fatal error: %s (%s at %s:%d)\n", msg, cudaGetErrorString(__err), __FILE__, __LINE__); \ fprintf(stderr, "*** FAILED - ABORTING\n"); \ exit(1); \ } \ } while (0) static size_t round_up_to_multiple(size_t x, int multiple) { return ((x + multiple - 1) / multiple) * multiple; } struct Workspace { Workspace(size_t size) : size(size), data(NULL) { auto& allocator = *::c10::cuda::CUDACachingAllocator::get(); dataPtr = allocator.allocate(size); data = dataPtr.get(); } Workspace(const Workspace&) = delete; Workspace(Workspace&&) = default; Workspace& operator=(Workspace&&) = default; ~Workspace() = default; size_t size; void* data; c10::DataPtr dataPtr; }; // Return {y} at::Tensor nhwc_bn_addrelu_fwd_train(const at::Tensor& x, const at::Tensor& z, const at::Tensor& scale, const at::Tensor& bias, const at::Tensor& running_mean, const at::Tensor& running_inv_var, const at::Tensor& minibatch_mean, const at::Tensor& minibatch_inv_var, const at::Tensor& bitmask, const at::Tensor& ret_cta, const float momentum, const float epsilon, void* my_data, void* pair_data, void* pair_data2, void* pair_data3, const int bn_group, const at::Tensor& magic_tensor, const int occupancy, const int grid_dim_x, const bool coop) { const int N = x.size(0); const int H = x.size(1); const int W = x.size(2); const 
int C = x.size(3); // generating new magic number and use that for sync int* magic = magic_tensor.data_ptr(); *magic = (*magic + 1) & 0xff; // Allocate output tensor at::Tensor y = at::empty({N, H, W, C}, x.options()); // Create wrapper NhwcBatchNormAddRelu* bn = new NhwcBatchNormAddRelu(); bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, bn_group); bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W); bn->setConstants(momentum, epsilon); // set pointers within the wrapper bn->setInputOutputPointers(x.data_ptr(), nullptr, y.data_ptr(), nullptr, z.data_ptr(), nullptr); bn->setWeightPointers({scale.data_ptr(), bias.data_ptr()}, {nullptr, nullptr}); bn->setParameterPointers({running_mean.data_ptr(), running_inv_var.data_ptr()}); // deal with workspace(s) auto workspace_bytes = bn->numWorkspaceBytes(); // We'll create explicit tensors for the first 2 workspace ptrs, then allocate & offset // an allocated workspace for the others size_t total_workspace_bytes = 0; std::vector workspace_offsets; for (auto index = 4; index < workspace_bytes.size(); ++index) { total_workspace_bytes = round_up_to_multiple(total_workspace_bytes, 512); workspace_offsets.push_back(total_workspace_bytes); auto alloc_bytes = workspace_bytes[index]; total_workspace_bytes += alloc_bytes; } // Allocate the workspace Workspace ws(total_workspace_bytes); std::vector workspace; workspace.push_back(minibatch_mean.data_ptr()); workspace.push_back(minibatch_inv_var.data_ptr()); workspace.push_back(bitmask.data_ptr()); auto stream = at::cuda::getCurrentCUDAStream().stream(); const int retired_cta_bytes = workspace_bytes[3]; void* retired_ctas = ret_cta.data_ptr(); assert(ret_cta.size(0) >= retired_cta_bytes); workspace.push_back(retired_ctas); for (auto index = 4; index < workspace_bytes.size(); ++index) { void* ptr = reinterpret_cast(ws.data) + workspace_offsets[index - 4]; workspace.push_back(ptr); } bn->setWorkspacePointers(workspace, workspace_bytes); // Don't 
fuse in ReLU for now at least bn->fwd(stream, my_data, pair_data, pair_data2, pair_data3, bn_group, *magic, occupancy, grid_dim_x, coop); return y; } at::Tensor nhwc_bn_addrelu_fwd_eval(const at::Tensor& x, const at::Tensor& z, const at::Tensor& scale, const at::Tensor& bias, const at::Tensor& running_mean, const at::Tensor& running_inv_var, const at::Tensor& ret_cta, const int bn_group, const float momentum, const float epsilon) { const int N = x.size(0); const int H = x.size(1); const int W = x.size(2); const int C = x.size(3); // Allocate output tensor at::Tensor y = at::empty({N, H, W, C}, x.options()); // Create wrapper NhwcBatchNormAddRelu* bn = new NhwcBatchNormAddRelu(); bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, bn_group); bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W); bn->setConstants(momentum, epsilon); // set pointers within the wrapper bn->setInputOutputPointers(x.data_ptr(), nullptr, y.data_ptr(), nullptr, z.data_ptr(), nullptr); bn->setWeightPointers({scale.data_ptr(), bias.data_ptr()}, {nullptr, nullptr}); bn->setParameterPointers({running_mean.data_ptr(), running_inv_var.data_ptr()}); // deal with workspace(s) auto workspace_bytes = bn->numWorkspaceBytes(); // We'll create explicit tensors for the first 2 workspace ptrs, then allocate & offset // an allocated workspace for the others size_t total_workspace_bytes = 0; std::vector workspace_offsets; for (auto index = 4; index < workspace_bytes.size(); ++index) { total_workspace_bytes = round_up_to_multiple(total_workspace_bytes, 512); workspace_offsets.push_back(total_workspace_bytes); auto alloc_bytes = workspace_bytes[index]; total_workspace_bytes += alloc_bytes; } // Allocate the workspace Workspace ws(total_workspace_bytes); std::vector workspace; workspace.push_back(nullptr); workspace.push_back(nullptr); workspace.push_back(nullptr); auto stream = at::cuda::getCurrentCUDAStream().stream(); const int retired_cta_bytes = workspace_bytes[3]; 
void* retired_ctas = ret_cta.data_ptr(); assert(ret_cta.size(0) >= retired_cta_bytes); workspace.push_back(retired_ctas); for (auto index = 4; index < workspace_bytes.size(); ++index) { void* ptr = reinterpret_cast(ws.data) + workspace_offsets[index - 4]; workspace.push_back(ptr); } bn->setWorkspacePointers(workspace, workspace_bytes); // Don't fuse in ReLU for now at least bn->fwdInference(stream); return y; } std::vector nhwc_bn_addrelu_bwd(const at::Tensor& x, const at::Tensor& dy, const at::Tensor& scale, const at::Tensor& bias, const at::Tensor& running_mean, const at::Tensor& running_inv_var, const at::Tensor& minibatch_mean, const at::Tensor& minibatch_inv_var, const at::Tensor& bitmask, const at::Tensor& ret_cta, const float momentum, const float epsilon, void* my_data, void* pair_data, void* pair_data2, void* pair_data3, const int bn_group, const at::Tensor& magic_tensor, const int occupancy, const int grid_dim_x, const bool coop) { // shape const int N = x.size(0); const int H = x.size(1); const int W = x.size(2); const int C = x.size(3); // generating new magic number and use that for sync int* magic = magic_tensor.data_ptr(); *magic = (*magic + 1) & 0xff; // outputs at::Tensor x_grad, z_grad, scale_grad, bias_grad; // Allocate outputs x_grad = at::empty_like(x); z_grad = at::empty_like(x); scale_grad = at::empty_like(scale); bias_grad = at::empty_like(bias); // Create wrapper NhwcBatchNormAddRelu* bn = new NhwcBatchNormAddRelu(); bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, bn_group); bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W); bn->setConstants(momentum, epsilon); // set pointers within the wrapper bn->setInputOutputPointers(x.data_ptr(), x_grad.data_ptr(), nullptr, dy.data_ptr(), nullptr, z_grad.data_ptr()); bn->setWeightPointers({scale.data_ptr(), bias.data_ptr()}, {scale_grad.data_ptr(), bias_grad.data_ptr()}); bn->setParameterPointers({running_mean.data_ptr(), running_inv_var.data_ptr()}); // 
deal with workspace(s) auto workspace_bytes = bn->numWorkspaceBytes(); // We'll create explicit tensors for the first 2 workspace ptrs, then allocate & offset // an allocated workspace for the others size_t total_workspace_bytes = 0; std::vector workspace_offsets; for (auto index = 4; index < workspace_bytes.size(); ++index) { total_workspace_bytes = round_up_to_multiple(total_workspace_bytes, 512); workspace_offsets.push_back(total_workspace_bytes); auto alloc_bytes = workspace_bytes[index]; total_workspace_bytes += alloc_bytes; } // Allocate the workspace Workspace ws(total_workspace_bytes); std::vector workspace; workspace.push_back(minibatch_mean.data_ptr()); workspace.push_back(minibatch_inv_var.data_ptr()); workspace.push_back(bitmask.data_ptr()); auto stream = at::cuda::getCurrentCUDAStream().stream(); const int retired_cta_bytes = workspace_bytes[3]; void* retired_ctas = ret_cta.data_ptr(); assert(ret_cta.size(0) >= retired_cta_bytes); workspace.push_back(retired_ctas); for (auto index = 4; index < workspace_bytes.size(); ++index) { void* ptr = reinterpret_cast(ws.data) + workspace_offsets[index - 4]; workspace.push_back(ptr); } bn->setWorkspacePointers(workspace, workspace_bytes); bn->dgrad(stream, my_data, pair_data, pair_data2, pair_data3, bn_group, *magic, occupancy, grid_dim_x, coop); return std::vector{x_grad, z_grad, scale_grad, bias_grad}; } int nhwc_bn_addrelu_fwd_occupancy() { int device_id = -1; cudaGetDevice(&device_id); // max occupancy supported by the code is 2 return NhwcBatchNormAddRelu::smem_driven_fwd_occupancy(device_id, 2); } int nhwc_bn_addrelu_bwd_occupancy() { int device_id = -1; cudaGetDevice(&device_id); // max occupancy supported by the code is 2 return NhwcBatchNormAddRelu::smem_driven_bwd_occupancy(device_id, 2); } ================================================ FILE: apex/contrib/csrc/groupbn/batch_norm_add_relu.h ================================================ /* * Licensed to the Apache Software Foundation (ASF) under one * 
or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * Copyright (c) 2018 by Contributors * \file nhwc_batch_norm_add_relu.h * \brief CUDA NHWC Batch Normalization code with fused addition * \author Shankara Rao Thejaswi Nanditale, Dick Carter, Maxim Milakov, Evgeni Krimer */ #ifndef MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_ADD_RELU_H_ #define MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_ADD_RELU_H_ #include #include #include #include #include #include "cuda_utils.h" #include "nhwc_batch_norm_kernel.h" #define VERBOSE_DEFAULT false class NhwcBatchNormAddRelu { public: NhwcBatchNormAddRelu() { name_ = "nhwc_batchnormaddrelu"; createTensorDescriptor(&X_tensor_desc_); createTensorDescriptor(&Y_tensor_desc_); } ~NhwcBatchNormAddRelu() { destroyTensorDescriptor(X_tensor_desc_); destroyTensorDescriptor(Y_tensor_desc_); } void die() { std::cerr << "batchnormaddrelu not initialized" << std::endl; exit(-1); } void fwd(cudaStream_t stream, void* my_data, void* pair_data, void* pair_data2, void* pair_data3, const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop); void dgrad(cudaStream_t stream, void* my_data, void* pair_data, void* pair_data2, void* pair_data3, const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop); void 
fwdInference(cudaStream_t stream); dim3 calc_fwd_grid(int* loop, const int grid_dim_x); dim3 calc_bwd_grid(int* loop, const int grid_dim_x); void setInputDescriptor(const cudnnTensorFormat_t format, const cudnnDataType_t data_type, int n, int c, int h, int w, int bn_group) { m_ = n * h * w; int m_bn_adjusted = m_ * bn_group; c_ = c; // factor to scale sum of squared errors to get saved variance. Must be 1/nhw. svar_inv_count_ = 1.f / m_bn_adjusted; // factor to scale sum of squared errors to get running variance. Should be 1/(nhw-1). int divisor = m_bn_adjusted - 1; // nhw == 1 is unlikely, but by setting the rvar_inv_count_ == 1.f, we avoid running var infs. rvar_inv_count_ = divisor == 0 ? 1.f : 1.f / divisor; setTensorDescriptor(X_tensor_desc_, format, data_type, n, c, h, w); } void setOutputDescriptor(const cudnnTensorFormat_t format, const cudnnDataType_t data_type, int n, int c, int h, int w) { setTensorDescriptor(Y_tensor_desc_, format, data_type, n, c, h, w); } const std::vector numWorkspaceBytes() const; void setWorkspacePointers(const std::vector& workspace, const std::vector& num_workspace_bytes); void setInputOutputPointers(void* X, void* dX, void* Y, void* dY, void* addend, void* dAddend) { X_ = X; dX_ = dX; Y_ = Y; dY_ = dY; addend_ = addend; dAddend_ = dAddend; } // Sets the pointers for the scale and weight (in that order) data and derivative buffers. void setWeightPointers(const std::vector& weight_pointers, const std::vector& deriv_pointers) { assert(weight_pointers.size() == 2); assert(deriv_pointers.size() == 2); scale_ = static_cast(weight_pointers[0]); bias_ = static_cast(weight_pointers[1]); dscale_ = static_cast(deriv_pointers[0]); dbias_ = static_cast(deriv_pointers[1]); } // Sets the pointers for the population mean and variance buffers, in that order. 
void setParameterPointers(const std::vector& param_pointers) { assert(param_pointers.size() == 2); population_mean_ = static_cast(param_pointers[0]); population_variance_ = static_cast(param_pointers[1]); } void setConstants(const double exp_avg_factor, const double eps) { exp_avg_factor_ = exp_avg_factor; eps_ = eps; } void processCudnnStatus(const cudnnStatus_t& status, const std::string& string = std::string(), bool verbose = VERBOSE_DEFAULT) { if (status != CUDNN_STATUS_SUCCESS) LOG(FATAL) << string << " " << cudnnGetErrorString(status); else if (verbose) LOG(INFO) << string << " " << cudnnGetErrorString(status); } void checkCudaStatus(const std::string& string = std::string(), bool verbose = VERBOSE_DEFAULT) { cudaError_t status = cudaGetLastError(); if (status != cudaSuccess) LOG(FATAL) << string << " " << cudaGetErrorString(status); else if (verbose) LOG(INFO) << string << " " << cudaGetErrorString(status); } size_t size_retired_ctas(int grid_y) const { // Note that the value of max_grid_y to handle known GPUs is about 160. const int max_grid_y = 1024; if (grid_y > max_grid_y) LOG(INFO) << "GPU capabilities exceeds assumptions."; const int retired_cta_bytes = max_grid_y * 2 * sizeof(int); // Since the region will be initialized once and used for many kernels, // the idea is to return an ample size that will cover all uses. return retired_cta_bytes; } cudnnTensorDescriptor_t X_tensor_desc_ = nullptr; cudnnTensorDescriptor_t Y_tensor_desc_ = nullptr; void* X_ = nullptr; void* dX_ = nullptr; void* Y_ = nullptr; void* dY_ = nullptr; void* addend_ = nullptr; void* dAddend_ = nullptr; // Learned scale and bias weights. float* scale_ = nullptr; float* dscale_ = nullptr; float* bias_ = nullptr; float* dbias_ = nullptr; // Computed population mean and variance parameters. float* population_mean_ = nullptr; float* population_variance_ = nullptr; // Workspace buffers for minibatch mean and variance (computed in fwd, needed by bwd). 
float* minibatch_mean_ = nullptr; float* minibatch_variance_ = nullptr; int m_ = 0; // Number of values per channel that BN is normalizing. int c_ = 0; // Number of channels over which BN is normalizing. float svar_inv_count_ = 0.f; // factor to scale sum of squared errors to get saved variance float rvar_inv_count_ = 0.f; // factor to scale sum of squared errors to get running variance double exp_avg_factor_ = 0.; double eps_ = 0.; std::string name_; private: void setTensorDescriptor(cudnnTensorDescriptor_t descriptor, cudnnTensorFormat_t format, cudnnDataType_t data_type, int n, int c, int h, int w) { cudnnStatus_t status = CUDNN_STATUS_SUCCESS; status = cudnnSetTensor4dDescriptor(descriptor, format, data_type, n, c, h, w); processCudnnStatus(status, "set tensor descriptor"); } void createTensorDescriptor(cudnnTensorDescriptor_t* descriptor) { cudnnStatus_t status = CUDNN_STATUS_SUCCESS; status = cudnnCreateTensorDescriptor(descriptor); processCudnnStatus(status, "create tensor_descriptor"); } void destroyTensorDescriptor(cudnnTensorDescriptor_t descriptor) { cudnnStatus_t status = CUDNN_STATUS_SUCCESS; status = cudnnDestroyTensorDescriptor(descriptor); processCudnnStatus(status, "destroy tensor_descriptor"); } protected: float* partial_sums_ = nullptr; int* partial_counts_ = nullptr; int* retired_ctas_ = nullptr; unsigned int* relu_bitmask_ = nullptr; void _setFwdParams(NhwcBatchNormFwdParams* params) const; void _setFwdInferenceParams(NhwcBatchNormFwdInferenceParams* params) const; void _setBwdParams(NhwcBatchNormBwdParams* params) const; // @todo: ability to configure these? 
// Kernel params static const int USE_ONLINE_APPROACH = 1; static const int THREADS_PER_CTA = 512; static const int THREADS_PER_PIXEL = 16; static const int C_ELEMENTS_PER_CTA = 64; static const int ELEMENTS_PER_LDG = C_ELEMENTS_PER_CTA / THREADS_PER_PIXEL; static const int MAX_SMEM_WITHOUT_OPT_IN = 48 * 1024; typedef uint16_t StorageType; // increasing this to 6 causes spills in fwd kernel! static const int PIXELS_PER_THREAD_IN_REGISTERS_FWD = 5; static const int PIXELS_PER_THREAD_IN_REGISTERS_BWD = 3; static const int PIXELS_PER_THREAD_IN_SMEM_FWD = 10; static const int PIXELS_PER_THREAD_IN_SMEM_BWD = 5; static const int PIXELS_PER_THREAD_FWD = PIXELS_PER_THREAD_IN_REGISTERS_FWD + PIXELS_PER_THREAD_IN_SMEM_FWD; static const int PIXELS_PER_THREAD_BWD = PIXELS_PER_THREAD_IN_REGISTERS_BWD + PIXELS_PER_THREAD_IN_SMEM_BWD; static const int PIXELS_PER_THREAD_FWD_INFERENCE = 4; // Derived params static const size_t SMEM_SIZE_FWD = PIXELS_PER_THREAD_IN_SMEM_FWD * THREADS_PER_CTA * ELEMENTS_PER_LDG * sizeof(StorageType); static const size_t SMEM_SIZE_BWD = PIXELS_PER_THREAD_IN_SMEM_BWD * THREADS_PER_CTA * ELEMENTS_PER_LDG * 2 * sizeof(StorageType); static const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL; static const int PIXELS_PER_CTA_FWD = THREADS_PER_CTA / THREADS_PER_PIXEL * PIXELS_PER_THREAD_FWD; static const int PIXELS_PER_CTA_BWD = THREADS_PER_CTA / THREADS_PER_PIXEL * PIXELS_PER_THREAD_BWD; static const int PIXELS_PER_CTA_FWD_INFERENCE = THREADS_PER_CTA / THREADS_PER_PIXEL * PIXELS_PER_THREAD_FWD_INFERENCE; // max grid.y in case of group bn is limited by exchange buffer size static const int MAX_GBN_BLOCK_Y = 256; // Helper function to launch the forward kernel. // We calculate (based on smem usage) the achievable occupancy and make sure we run a kernel // version that was compiled with that occupancy in its launch bounds. This way, we avoid // needless register spills. 
void _fwdKernelLauncher(cudaStream_t stream, NhwcBatchNormFwdParams params, dim3 grid_dim, int outer_loops, const int occupancy, const bool coop) { #define LAUNCH_FWD_KERNEL(OUTER_LOOPS, USE_RELU, USE_ADD_RELU, COMPILED_FOR_OCCUPANCY, COOP) \ do { \ CHECK(SMEM_SIZE_FWD <= MAX_SMEM_WITHOUT_OPT_IN) << "Nhwc batchnormaddrelu kernel smem too big."; \ auto fwd_func = \ nhwc_batch_norm_fwd; \ if (COMPILED_FOR_OCCUPANCY > 1) { \ cudaFuncSetAttribute(fwd_func, cudaFuncAttributePreferredSharedMemoryCarveout, 100); \ checkCudaStatus(name_ + " fwd ser coop kernel (cudaFuncSetAttribute carveout)"); \ } \ void* params_ptr = static_cast(¶ms); \ using FWD_FUNC = decltype(nhwc_batch_norm_fwd); \ if (COOP) { \ cudaLaunchCooperativeKernel(fwd_func, grid_dim, THREADS_PER_CTA, ¶ms_ptr, SMEM_SIZE_FWD, stream); \ } else { \ cudaLaunchKernel(fwd_func, grid_dim, THREADS_PER_CTA, ¶ms_ptr, SMEM_SIZE_FWD, stream); \ } \ checkCudaStatus(name_ + " fwd ser coop kernel"); \ } while (0) // Don't try for an occupancy > 2 as this will squeeze register use and create spills. if (outer_loops == 1) { if (occupancy >= 2) LAUNCH_FWD_KERNEL(1, false, true, 2, coop); else LAUNCH_FWD_KERNEL(1, false, true, 1, coop); } else { if (occupancy >= 2) LAUNCH_FWD_KERNEL(0, false, true, 2, coop); else LAUNCH_FWD_KERNEL(0, false, true, 1, coop); } #undef LAUNCH_FWD_KERNEL } // Helper function to launch the backward kernel. 
void _bwdKernelLauncher(cudaStream_t stream, NhwcBatchNormBwdParams params, dim3 grid_dim, int outer_loops, const int occupancy, const bool coop) { #define LAUNCH_BWD_ADD_RELU_KERNEL(OUTER_LOOPS, COMPILED_FOR_OCCUPANCY, COOP) \ do { \ CHECK(SMEM_SIZE_BWD <= MAX_SMEM_WITHOUT_OPT_IN) << "Nhwc batchnormaddrelu kernel smem too big."; \ auto bwd_add_relu_func = \ nhwc_batch_norm_bwd_add_relu; \ if (COMPILED_FOR_OCCUPANCY > 1) { \ cudaFuncSetAttribute(bwd_add_relu_func, cudaFuncAttributePreferredSharedMemoryCarveout, 100); \ checkCudaStatus(name_ + " bwd-add-relu coop serial kernel (cudaFuncSetAttribute carveout)"); \ } \ void* params_ptr = static_cast(¶ms); \ using BWD_ADD_RELU_FUNC = \ decltype(nhwc_batch_norm_bwd_add_relu); \ if (COOP) { \ cudaLaunchCooperativeKernel(bwd_add_relu_func, grid_dim, THREADS_PER_CTA, ¶ms_ptr, \ SMEM_SIZE_BWD, stream); \ } else { \ cudaLaunchKernel(bwd_add_relu_func, grid_dim, THREADS_PER_CTA, ¶ms_ptr, SMEM_SIZE_BWD, \ stream); \ } \ checkCudaStatus(name_ + " bwd-add-relu coop serial kernel"); \ } while (0) // Don't try for an occupancy > 2 as this will squeeze register use and create spills. if (outer_loops == 1) { if (occupancy >= 2) LAUNCH_BWD_ADD_RELU_KERNEL(1, 2, coop); else LAUNCH_BWD_ADD_RELU_KERNEL(1, 1, coop); } else { if (occupancy >= 2) LAUNCH_BWD_ADD_RELU_KERNEL(0, 2, coop); else LAUNCH_BWD_ADD_RELU_KERNEL(0, 1, coop); } #undef LAUNCH_BWD_KERNEL } public: // Calculate the expected fwd kernel occupancy, as dictated by shared memory usage. static int smem_driven_fwd_occupancy(int device_id, const int max_cta_per_sm) { using namespace at::cuda::utils; int fwd_reduction_bytes = THREADS_PER_PIXEL * (THREADS_PER_CTA / 32) * ELEMENTS_PER_LDG * sizeof(float); int fwd_smem_bytes = SMEM_SIZE_FWD + fwd_reduction_bytes; int occupancy = MaxSharedMemoryPerMultiprocessor(device_id) / fwd_smem_bytes; return std::min(max_cta_per_sm, occupancy); } // Calculate the expected bwd kernel occupancy, as dictated by shared memory usage. 
static int smem_driven_bwd_occupancy(int device_id, const int max_cta_per_sm) { using namespace at::cuda::utils; int bwd_reduction_bytes = THREADS_PER_PIXEL * (THREADS_PER_CTA / 32) * ELEMENTS_PER_LDG * sizeof(float); int bwd_smem_bytes = SMEM_SIZE_BWD + bwd_reduction_bytes; int occupancy = MaxSharedMemoryPerMultiprocessor(device_id) / bwd_smem_bytes; return std::min(max_cta_per_sm, occupancy); } }; const std::vector NhwcBatchNormAddRelu::numWorkspaceBytes() const { assert(c_ > 0); // choose the max memory required between fwd/bwd passes int grid_x_fwd = div_up(m_, PIXELS_PER_CTA_FWD); int grid_x_bwd = div_up(m_, PIXELS_PER_CTA_BWD); int grid_x = max(grid_x_fwd, grid_x_bwd); int grid_y = div_up(c_, C_ELEMENTS_PER_CTA); const size_t num_mean_bytes = c_ * sizeof(float); const size_t num_variance_bytes = num_mean_bytes; int elems_per_group = ((m_ + 31) & ~31) * 2; int group_count = div_up(c_, C_ELEMENTS_PER_CTA); const size_t bitmask_bytes = elems_per_group * group_count * sizeof(unsigned int); const size_t size_sums = grid_y * grid_x * THREADS_PER_PIXEL * ELEMENTS_PER_LDG * 2 * sizeof(float); const size_t size_counts = grid_y * grid_x * sizeof(int); return {num_mean_bytes, num_variance_bytes, bitmask_bytes, size_retired_ctas(grid_y), size_sums, size_counts}; } void NhwcBatchNormAddRelu::setWorkspacePointers(const std::vector& workspace, const std::vector& num_workspace_bytes) { assert(workspace.size() == 6); assert(num_workspace_bytes.size() == 6); minibatch_mean_ = static_cast(workspace[0]); minibatch_variance_ = static_cast(workspace[1]); relu_bitmask_ = static_cast(workspace[2]); retired_ctas_ = static_cast(workspace[3]); partial_sums_ = static_cast(workspace[4]); partial_counts_ = static_cast(workspace[5]); } void NhwcBatchNormAddRelu::_setFwdParams(NhwcBatchNormFwdParams* params) const { params->gmem_src = static_cast(X_); params->gmem_dst = static_cast(Y_); params->gmem_src1 = static_cast(addend_); params->gmem_bias = bias_; params->gmem_scale = scale_; 
params->gmem_running_mean = population_mean_; params->gmem_running_var = population_variance_; params->gmem_saved_mean = minibatch_mean_; params->gmem_saved_var = minibatch_variance_; params->gmem_relu_bitmask = relu_bitmask_; params->nhw = m_; params->c = c_; params->svar_inv_count = svar_inv_count_; params->rvar_inv_count = rvar_inv_count_; params->gmem_sums = partial_sums_; params->gmem_counts = partial_counts_; params->gmem_retired_ctas = retired_ctas_; params->var_eps = eps_; params->outer_loops = 0; params->exp_avg_factor = static_cast(exp_avg_factor_); params->c_blks = div_up(c_, C_ELEMENTS_PER_CTA); } void NhwcBatchNormAddRelu::_setFwdInferenceParams(NhwcBatchNormFwdInferenceParams* params) const { params->gmem_src = static_cast(X_); params->gmem_dst = static_cast(Y_); params->gmem_src1 = static_cast(addend_); params->gmem_bias = bias_; params->gmem_scale = scale_; params->gmem_mean = population_mean_; params->gmem_var = population_variance_; params->nhw = m_; params->c = c_; params->var_eps = eps_; } void NhwcBatchNormAddRelu::_setBwdParams(NhwcBatchNormBwdParams* params) const { params->gmem_src = static_cast(X_); params->gmem_dy = static_cast(dY_); params->gmem_dst = static_cast(dX_); params->gmem_dst1 = static_cast(dAddend_); params->gmem_relu_bitmask = relu_bitmask_; params->gmem_dscale = dscale_; params->gmem_dbias = dbias_; params->gmem_scale = scale_; params->gmem_bias = bias_; params->gmem_saved_mean = minibatch_mean_; params->gmem_saved_var = minibatch_variance_; params->nhw = m_; params->c = c_; params->svar_inv_count = svar_inv_count_; params->gmem_sums = partial_sums_; params->gmem_retired_ctas = retired_ctas_; params->outer_loops = 0; params->c_blks = div_up(c_, C_ELEMENTS_PER_CTA); } void NhwcBatchNormAddRelu::fwdInference(cudaStream_t stream) { bool ptrs_are_set = X_tensor_desc_ != nullptr && Y_tensor_desc_ != nullptr && scale_ != nullptr && bias_ != nullptr // && minibatch_mean_ != nullptr // && minibatch_variance_ != nullptr && 
population_mean_ != nullptr && population_variance_ != nullptr && X_ != nullptr // && dX_ != nullptr && Y_ != nullptr && addend_ != nullptr // && dY_ != nullptr // && dscale_ != nullptr // && dbias_ != nullptr && partial_sums_ != nullptr && partial_counts_ != nullptr; if (!ptrs_are_set) die(); dim3 grid_dim; grid_dim.x = div_up(m_, PIXELS_PER_CTA_FWD_INFERENCE); grid_dim.y = div_up(c_, C_ELEMENTS_PER_CTA); // @todo: maybe just move this inside initialize routine? NhwcBatchNormFwdInferenceParams params; _setFwdInferenceParams(¶ms); nhwc_batch_norm_fwd_inference <<>>(params); checkCudaStatus(name_ + " fwd_inference-relu kernel"); } dim3 NhwcBatchNormAddRelu::calc_fwd_grid(int* loop, const int grid_dim_x) { dim3 grid_dim; grid_dim.x = div_up(m_, PIXELS_PER_CTA_FWD); int c_blks = div_up(c_, C_ELEMENTS_PER_CTA); unsigned int max_grid_x = grid_dim_x; if (grid_dim.x <= max_grid_x) { *loop = 1; if (max_grid_x / grid_dim.x > 1) { grid_dim.y = std::min(c_blks, static_cast(max_grid_x / grid_dim.x)); assert(grid_dim.y < MAX_GBN_BLOCK_Y); // FIXME: turn into a loop } else { grid_dim.y = 1; } } else { grid_dim.x = max_grid_x; grid_dim.y = 1; int nhw_in_regs = m_ - PIXELS_PER_THREAD_IN_SMEM_FWD * PIXELS_PER_LDG * grid_dim.x; int pixels_per_iteration = PIXELS_PER_THREAD_IN_REGISTERS_FWD * PIXELS_PER_LDG * grid_dim.x; *loop = div_up(nhw_in_regs, pixels_per_iteration); } return grid_dim; } dim3 NhwcBatchNormAddRelu::calc_bwd_grid(int* loop, const int grid_dim_x) { dim3 grid_dim; grid_dim.x = div_up(m_, PIXELS_PER_CTA_BWD); int c_blks = div_up(c_, C_ELEMENTS_PER_CTA); unsigned int max_grid_x = grid_dim_x; if (grid_dim.x <= max_grid_x) { *loop = 1; if (max_grid_x / grid_dim.x > 1) { grid_dim.y = std::min(c_blks, static_cast(max_grid_x / grid_dim.x)); assert(grid_dim.y < MAX_GBN_BLOCK_Y); // FIXME: turn into a loop } else { grid_dim.y = 1; } } else { grid_dim.x = max_grid_x; grid_dim.y = 1; int nhw_in_regs = m_ - PIXELS_PER_THREAD_IN_SMEM_BWD * PIXELS_PER_LDG * grid_dim.x; int 
pixels_per_iteration = PIXELS_PER_THREAD_IN_REGISTERS_BWD * PIXELS_PER_LDG * grid_dim.x; *loop = div_up(nhw_in_regs, pixels_per_iteration); } return grid_dim; } void NhwcBatchNormAddRelu::fwd(cudaStream_t stream, void* my_data, void* pair_data, void* pair_data2, void* pair_data3, const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop) { bool ptrs_are_set = X_tensor_desc_ != nullptr && Y_tensor_desc_ != nullptr && scale_ != nullptr && bias_ != nullptr && minibatch_mean_ != nullptr && minibatch_variance_ != nullptr && relu_bitmask_ != nullptr && population_mean_ != nullptr && population_variance_ != nullptr && X_ != nullptr // && dX_ != nullptr && Y_ != nullptr && addend_ != nullptr // && dY_ != nullptr // && dscale_ != nullptr // && dbias_ != nullptr && partial_sums_ != nullptr && partial_counts_ != nullptr && retired_ctas_ != nullptr; if (!ptrs_are_set) die(); // reset of retired_cta_count no longer needed NhwcBatchNormFwdParams params; _setFwdParams(¶ms); params.my_data = my_data; params.pair_datas[0] = pair_data; params.pair_datas[1] = pair_data2; params.pair_datas[2] = pair_data3; params.magic = magic; params.sync_iters = (bn_group == 8) ? 
3 : (bn_group >> 1); dim3 grid_dim = calc_fwd_grid(¶ms.outer_loops, grid_dim_x); _fwdKernelLauncher(stream, params, grid_dim, params.outer_loops, occupancy, coop); } void NhwcBatchNormAddRelu::dgrad(cudaStream_t stream, void* my_data, void* pair_data, void* pair_data2, void* pair_data3, const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop) { bool ptrs_are_set = X_tensor_desc_ != nullptr && Y_tensor_desc_ != nullptr && scale_ != nullptr && bias_ != nullptr && minibatch_mean_ != nullptr && minibatch_variance_ != nullptr && relu_bitmask_ != nullptr // && population_mean_ != nullptr // && population_variance_ != nullptr && X_ != nullptr && dX_ != nullptr // && Y_ != nullptr && dY_ != nullptr && dAddend_ != nullptr && dscale_ != nullptr && dbias_ != nullptr && retired_ctas_ != nullptr; if (!ptrs_are_set) die(); // reset of retired_cta_count no longer needed NhwcBatchNormBwdParams params; _setBwdParams(¶ms); params.my_data = my_data; params.pair_datas[0] = pair_data; params.pair_datas[1] = pair_data2; params.pair_datas[2] = pair_data3; params.magic = magic; params.sync_iters = (bn_group == 8) ? 
3 : (bn_group >> 1); params.wgrad_coeff = 1.0 / bn_group; dim3 grid_dim = calc_bwd_grid(¶ms.outer_loops, grid_dim_x); _bwdKernelLauncher(stream, params, grid_dim, params.outer_loops, occupancy, coop); } #endif // MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_ADD_RELU_H_ ================================================ FILE: apex/contrib/csrc/groupbn/cuda_utils.h ================================================ #include #ifndef CUDA_UTILS_H #define CUDA_UTILS_H namespace at { namespace cuda { namespace utils { static inline int MaxSharedMemoryPerMultiprocessor(int device_id) { return getDeviceProperties(device_id)->sharedMemPerMultiprocessor; } } // namespace utils } // namespace cuda } // namespace at #endif ================================================ FILE: apex/contrib/csrc/groupbn/interface.cpp ================================================ #include #include #include #include #include #include #include #include "ATen/Generator.h" #include "ATen/Scalar.h" #include "ATen/Storage.h" #include "ATen/Tensor.h" namespace py = pybind11; int64_t get_buffer_size(const int bn_sync_steps); void* get_data_ptr(const at::Tensor& data); void* get_remote_data_ptr(const at::Tensor& handle, const int64_t offset); void close_remote_data(const at::Tensor& handle); at::Tensor nhwc_bn_fwd_train(const at::Tensor& x, const at::Tensor& scale, const at::Tensor& bias, const at::Tensor& running_mean, const at::Tensor& running_inv_var, const at::Tensor& minibatch_mean, const at::Tensor& minibatch_inv_var, const at::Tensor& ret_cta, const float momentum, const float epsilon, const bool fuse_relu, void* my_data, void* pair_data, void* pair_data2, void* pair_data3, const int bn_group, const at::Tensor& magic_tensor, const int occupancy, const int grid_dim_x, const bool coop); at::Tensor nhwc_bn_fwd_eval(const at::Tensor& x, const at::Tensor& scale, const at::Tensor& bias, const at::Tensor& running_mean, const at::Tensor& running_inv_var, const at::Tensor& ret_cta, const int bn_group, const 
float momentum, const float epsilon, const bool fuse_relu); std::vector nhwc_bn_bwd(const at::Tensor& x, const at::Tensor& dy, const at::Tensor& scale, const at::Tensor& bias, const at::Tensor& running_mean, const at::Tensor& running_inv_var, const at::Tensor& minibatch_mean, const at::Tensor& minibatch_inv_var, const at::Tensor& ret_cta, const float momentum, const float epsilon, const bool fuse_relu, void* my_data, void* pair_data, void* pair_data2, void* pair_data3, const int bn_group, const at::Tensor& magic_tensor, const int occupancy, const int grid_dim_x, const bool coop); at::Tensor nhwc_bn_addrelu_fwd_train(const at::Tensor& x, const at::Tensor& z, const at::Tensor& scale, const at::Tensor& bias, const at::Tensor& running_mean, const at::Tensor& running_inv_var, const at::Tensor& minibatch_mean, const at::Tensor& minibatch_inv_var, const at::Tensor& bitmask, const at::Tensor& ret_cta, const float momentum, const float epsilon, void* my_data, void* pair_data, void* pair_data2, void* pair_data3, const int bn_group, const at::Tensor& magic_tensor, const int occupancy, const int grid_dim_x, const bool coop); at::Tensor nhwc_bn_addrelu_fwd_eval(const at::Tensor& x, const at::Tensor& z, const at::Tensor& scale, const at::Tensor& bias, const at::Tensor& running_mean, const at::Tensor& running_inv_var, const at::Tensor& ret_cta, const int bn_group, const float momentum, const float epsilon); std::vector nhwc_bn_addrelu_bwd(const at::Tensor& x, const at::Tensor& dy, const at::Tensor& scale, const at::Tensor& bias, const at::Tensor& running_mean, const at::Tensor& running_inv_var, const at::Tensor& minibatch_mean, const at::Tensor& minibatch_inv_var, const at::Tensor& bitmask, const at::Tensor& ret_cta, const float momentum, const float epsilon, void* my_data, void* pair_data, void* pair_data2, void* pair_data3, const int bn_group, const at::Tensor& magic_tensor, const int occupancy, const int grid_dim_x, const bool coop); int nhwc_bn_fwd_occupancy(); int 
nhwc_bn_bwd_occupancy(); int nhwc_bn_addrelu_fwd_occupancy(); int nhwc_bn_addrelu_bwd_occupancy(); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("get_buffer_size", &get_buffer_size, "get_buffer_size", py::call_guard()); m.def("get_data_ptr", &get_data_ptr, "get_data_ptr", py::call_guard()); m.def("get_remote_data_ptr", &get_remote_data_ptr, "get_remote_data_ptr", py::call_guard()); m.def("close_remote_data", &close_remote_data, "close_remote_data", py::call_guard()); m.def("bn_fwd_nhwc", &nhwc_bn_fwd_train, "bn_fwd_nhwc", py::call_guard()); m.def("bn_fwd_eval_nhwc", &nhwc_bn_fwd_eval, "bn_fwd_eval_nhwc", py::call_guard()); m.def("bn_bwd_nhwc", &nhwc_bn_bwd, "bn_bwd_nhwc", py::call_guard()); m.def("bn_fwd_nhwc_occupancy", &nhwc_bn_fwd_occupancy, "bn_fwd_nhwc_occupancy", py::call_guard()); m.def("bn_bwd_nhwc_occupancy", &nhwc_bn_bwd_occupancy, "bn_bwd_nhwc_occupancy", py::call_guard()); m.def("bn_addrelu_fwd_nhwc", &nhwc_bn_addrelu_fwd_train, "bn_addrelu_fwd_nhwc", py::call_guard()); m.def("bn_addrelu_fwd_eval_nhwc", &nhwc_bn_addrelu_fwd_eval, "bn_addrelu_fwd_eval_nhwc", py::call_guard()); m.def("bn_addrelu_bwd_nhwc", &nhwc_bn_addrelu_bwd, "bn_addrelu_bwd_nhwc", py::call_guard()); m.def("bn_addrelu_fwd_nhwc_occupancy", &nhwc_bn_addrelu_fwd_occupancy, "bn_addrelu_fwd_nhwc_occupancy", py::call_guard()); m.def("bn_addrelu_bwd_nhwc_occupancy", &nhwc_bn_addrelu_bwd_occupancy, "bn_addrelu_bwd_nhwc_occupancy", py::call_guard()); } ================================================ FILE: apex/contrib/csrc/groupbn/ipc.cu ================================================ #include #include #include #define cudaCheckErrors(msg) \ do { \ cudaError_t __err = cudaGetLastError(); \ if (__err != cudaSuccess) { \ fprintf(stderr, "Fatal error: %s (%s at %s:%d)\n", msg, cudaGetErrorString(__err), __FILE__, __LINE__); \ fprintf(stderr, "*** FAILED - ABORTING\n"); \ exit(1); \ } \ } while (0) template <> struct std::hash { size_t operator()(const cudaIpcMemHandle_t& handle) const { size_t 
hash = 0; uint8_t* ptr = (uint8_t*)&handle; assert(sizeof(uint8_t) == 1); for (int i = 0; i < sizeof(cudaIpcMemHandle_t); i++) { hash += *ptr; ptr++; } return hash; } }; template <> struct std::equal_to { bool operator()(const cudaIpcMemHandle_t& lhs, const cudaIpcMemHandle_t& rhs) const { return (std::memcmp((void*)&lhs, (void*)&rhs, sizeof(cudaIpcMemHandle_t)) == 0); } }; namespace { namespace gpuipc { // from: src/operator/nn/cudnn/nhwc_batch_norm_kernel.h // The number of threads per pixel. const int THREADS_PER_PIXEL = 16; // The number of elements per ldg. const int ELEMENTS_PER_LDG = 4; // The number of reducing ops, each uses its own space : mean, var, dscale, dbias const int REDUCE_OPS = 4; // Maximum block.y supported - limited due to buffer allocation const int MAX_BLOCK_Y = 256; const int MAX_OFFSET = REDUCE_OPS * MAX_BLOCK_Y; const int BYTES_PER_ELEM = 4; // Buffer size per sync step const int SINGLE_SYNC_BUFFER_BYTES = MAX_OFFSET * THREADS_PER_PIXEL * 2 * ELEMENTS_PER_LDG * BYTES_PER_ELEM; }; // namespace gpuipc class IpcMemHandleRegistry { public: void* getPtr(const cudaIpcMemHandle_t& handle, int64_t offset) { if (registry_.count(handle) == 0) { registry_.insert(std::make_pair(handle, RegistryEntry())); registry_[handle].dev_ptr = ipcOpenMem(handle); } registry_[handle].ref_count++; return (((uint8_t*)registry_[handle].dev_ptr) + offset); } void releasePtr(const cudaIpcMemHandle_t& handle) { if (registry_.count(handle) == 0) { } if (--registry_[handle].ref_count == 0) { ipcCloseMem(registry_[handle].dev_ptr); registry_.erase(handle); } } struct RegistryEntry { void* dev_ptr; int ref_count; RegistryEntry() : dev_ptr(NULL), ref_count(0) {} }; protected: std::unordered_map registry_; void* ipcOpenMem(const cudaIpcMemHandle_t& handle) { void* data; cudaIpcOpenMemHandle(&data, handle, cudaIpcMemLazyEnablePeerAccess); cudaCheckErrors("ipc init"); return data; } void ipcCloseMem(void* dev_ptr) { cudaIpcCloseMemHandle(dev_ptr); cudaCheckErrors("ipc close"); 
} }; } // namespace static IpcMemHandleRegistry ipc_mem_registry; int64_t get_buffer_size(const int bn_sync_steps) { return bn_sync_steps * gpuipc::SINGLE_SYNC_BUFFER_BYTES; } void* get_remote_data_ptr(const at::Tensor& handle, const int64_t offset) { cudaIpcMemHandle_t my_handle; memcpy((unsigned char*)(&my_handle), handle.data_ptr(), sizeof(my_handle)); return ipc_mem_registry.getPtr(my_handle, offset); } void close_remote_data(const at::Tensor& handle) { cudaIpcMemHandle_t my_handle; memcpy((unsigned char*)(&my_handle), handle.data_ptr(), sizeof(my_handle)); ipc_mem_registry.releasePtr(my_handle); } void* get_data_ptr(const at::Tensor& data) { return data.data_ptr(); } ================================================ FILE: apex/contrib/csrc/groupbn/nhwc_batch_norm_kernel.h ================================================ /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! 
* Copyright (c) 2018 by Contributors
 * \file nhwc_batch_norm_kernel.h
 * \brief CUDA NHWC Batch Normalization code
 * \author Shankara Rao Thejaswi Nanditale, Dick Carter, Maxim Milakov, Evgeni Krimer
 */
#ifndef MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_KERNEL_H_
#define MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_KERNEL_H_

// NOTE(review): the #include targets and the template parameter lists of the templates below
// are missing in this copy (every <...> span was stripped); restore from upstream.
#include
#include

// Qualifiers applied to every helper in this header: file-local, inlined, device-only.
#define DEVICE_FUNCTION static inline __device__

// CTA margin used by cooperative launch. Can be overridden by env var NHWC_BATCHNORM_LAUNCH_MARGIN.
#define NHWC_BATCHNORM_LAUNCH_MARGIN_MIN 3
#define NHWC_BATCHNORM_LAUNCH_MARGIN_DEFAULT NHWC_BATCHNORM_LAUNCH_MARGIN_MIN

////////////////////////////////////////////////////////////////////////////////////////////////////

// Register layout for one load of ELEMENTS_PER_LDG values of type T: unpacked, one T per value.
template
struct PackedStorage {
  enum { PACKED_ELEMENTS_PER_LDG = ELEMENTS_PER_LDG };
  typedef T Type;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Packed layout: two values share one 32-bit int, so a load moves ELEMENTS_PER_LDG / 2 ints.
// NOTE(review): this specialization's argument list is missing here; presumably it is the
// 16-bit (uint16_t / fp16) storage case — confirm against upstream.
template
struct PackedStorage {
  enum { PACKED_ELEMENTS_PER_LDG = ELEMENTS_PER_LDG / 2 };
  typedef int Type;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Converts 2*N floats to fp16 (cvt.rn: round to nearest even) and packs each consecutive pair
// into one 32-bit word: src[2i] in the low half, src[2i+1] in the high half of dst[i].
template
DEVICE_FUNCTION void from_float(int (&dst)[N], const float (&src)[2 * N]) {
#pragma unroll
  for (int i = 0; i < N; ++i) {
    uint16_t lo, hi;
    asm volatile("cvt.rn.f16.f32 %0, %1;" : "=h"(lo) : "f"(src[2 * i + 0]));
    asm volatile("cvt.rn.f16.f32 %0, %1;" : "=h"(hi) : "f"(src[2 * i + 1]));
    asm volatile("mov.b32 %0, {%1, %2};" : "=r"(dst[i]) : "h"(lo), "h"(hi));
  }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// fp32 storage overload: element-wise copy.
template
DEVICE_FUNCTION void from_float(float (&dst)[N], const float (&src)[N]) {
#pragma unroll
  for (int i = 0; i < N; ++i) {
    dst[i] = src[i];
  }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Inverse of the packing from_float: splits each 32-bit word into its low/high fp16 halves and
// widens them to float (dst[2i] from the low half, dst[2i+1] from the high half).
template
DEVICE_FUNCTION void to_float(float (&dst)[2 * N], int (&src)[N]) {
#pragma unroll
  for (int i = 0; i < N; ++i) {
    uint16_t lo, hi;
    asm volatile("mov.b32 {%0, %1}, %2;" : "=h"(lo), "=h"(hi) : "r"(src[i]));
    asm volatile("cvt.f32.f16 %0, %1;" : "=f"(dst[2 * i + 0]) : "h"(lo));
    asm volatile("cvt.f32.f16 %0, %1;" : "=f"(dst[2 * i + 1]) : "h"(hi));
  }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// fp32 storage overload: element-wise copy.
template
DEVICE_FUNCTION void to_float(float (&dst)[N], float (&src)[N]) {
#pragma unroll
  for (int i = 0; i < N; ++i) {
    dst[i] = src[i];
  }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Loads one 32-bit word (two 16-bit values) through the read-only data path (__ldg).
DEVICE_FUNCTION void ldg(int (&dst)[1], const uint16_t* gmem) { dst[0] = __ldg((const int*)gmem); }

////////////////////////////////////////////////////////////////////////////////////////////////////

// Same as ldg(int[1]) but as a non-coherent load with the .cs (cache-streaming) hint, for data
// expected to be read once.
DEVICE_FUNCTION void ldg_stream(int (&dst)[1], const uint16_t* gmem) {
  unsigned int tmp;
  asm volatile("ld.global.cs.nc.s32 %0, [%1];" : "=r"(tmp) : "l"((const uint*)gmem));
  dst[0] = tmp;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Loads two 32-bit words (four 16-bit values) as one int2 via __ldg; gmem must be 8-byte aligned.
DEVICE_FUNCTION void ldg(int (&dst)[2], const uint16_t* gmem) {
  int2 tmp = __ldg((const int2*)gmem);
  dst[0] = tmp.x;
  dst[1] = tmp.y;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Vector (v2) variant of ldg_stream(int[1]); gmem must be 8-byte aligned.
DEVICE_FUNCTION void ldg_stream(int (&dst)[2], const uint16_t* gmem) {
  int2 tmp;
  asm volatile("ld.global.cs.nc.v2.s32 {%0,%1}, [%2];" : "=r"(tmp.x), "=r"(tmp.y) : "l"((const int2*)gmem));
  dst[0] = tmp.x;
  dst[1] = tmp.y;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Loads N 16-bit values as N/2 packed words and widens them to N floats. Only the int[1] and
// int[2] word loaders are defined above, so N is expected to be 2 or 4.
template
DEVICE_FUNCTION void ldg(float (&dst)[N], const uint16_t* gmem) {
  int tmp[N / 2];
  ldg(tmp, gmem);
  to_float(dst, tmp);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Streaming (.cs) counterpart of ldg(float[N]).
template
DEVICE_FUNCTION void ldg_stream(float (&dst)[N], const uint16_t* gmem) {
  int tmp[N / 2];
  ldg_stream(tmp, gmem);
  to_float(dst, tmp);
}
//////////////////////////////////////////////////////////////////////////////////////////////////// DEVICE_FUNCTION void stg(uint16_t* gmem, int (&src)[1]) { reinterpret_cast(gmem)[0] = src[0]; } //////////////////////////////////////////////////////////////////////////////////////////////////// DEVICE_FUNCTION void stg_stream(uint16_t* gmem, int (&src)[1]) { unsigned int tmp = src[0]; asm volatile("st.global.cs.s32 [%0], %1;" ::"l"((uint*)gmem), "r"(tmp)); } //////////////////////////////////////////////////////////////////////////////////////////////////// DEVICE_FUNCTION void stg(uint16_t* gmem, int (&src)[2]) { reinterpret_cast(gmem)[0] = make_int2(src[0], src[1]); } //////////////////////////////////////////////////////////////////////////////////////////////////// DEVICE_FUNCTION void stg_stream(uint16_t* gmem, int (&src)[2]) { asm volatile("st.global.cs.v2.s32 [%0], {%1,%2};" ::"l"((uint*)gmem), "r"(src[0]), "r"(src[1])); } //////////////////////////////////////////////////////////////////////////////////////////////////// template DEVICE_FUNCTION void stg(uint16_t* gmem, float (&src)[N]) { int tmp[N / 2]; from_float(tmp, src); stg(gmem, tmp); } //////////////////////////////////////////////////////////////////////////////////////////////////// template DEVICE_FUNCTION void stg_stream(uint16_t* gmem, float (&src)[N]) { int tmp[N / 2]; from_float(tmp, src); stg_stream(gmem, tmp); } //////////////////////////////////////////////////////////////////////////////////////////////////// DEVICE_FUNCTION void read_from_gmem(float (&dst)[2], const float* gmem, int idx) { float2 tmp = __ldg(reinterpret_cast(&gmem[2 * idx])); dst[0] = tmp.x; dst[1] = tmp.y; } //////////////////////////////////////////////////////////////////////////////////////////////////// DEVICE_FUNCTION void read_from_gmem(float (&dst)[4], const float* gmem, int idx) { float4 tmp = __ldg(reinterpret_cast(&gmem[4 * idx])); dst[0] = tmp.x; dst[1] = tmp.y; dst[2] = tmp.z; dst[3] = tmp.w; } 
////////////////////////////////////////////////////////////////////////////////////////////////////

// Vector reads of the idx-th 2- or 4-wide element from an array. The kernels in this file call
// these on __shared__ buffers.
DEVICE_FUNCTION void read_from_smem(float (&x)[2], const float* smem, int idx) {
  float2 tmp = *reinterpret_cast<const float2*>(&smem[2 * idx]);
  x[0] = tmp.x;
  x[1] = tmp.y;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

DEVICE_FUNCTION void read_from_smem(int (&x)[1], const int* smem, int idx) { x[0] = smem[idx]; }

////////////////////////////////////////////////////////////////////////////////////////////////////

DEVICE_FUNCTION void read_from_smem(float (&x)[4], const float* smem, int idx) {
  float4 tmp = *reinterpret_cast<const float4*>(&smem[4 * idx]);
  x[0] = tmp.x;
  x[1] = tmp.y;
  x[2] = tmp.z;
  x[3] = tmp.w;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

DEVICE_FUNCTION void read_from_smem(int (&x)[2], const int* smem, int idx) {
  int2 tmp = *reinterpret_cast<const int2*>(&smem[2 * idx]);
  x[0] = tmp.x;
  x[1] = tmp.y;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Vector writes of src into the idx-th float2 / float4 slot of gmem.
DEVICE_FUNCTION void write_to_gmem(float* gmem, int idx, const float (&src)[2]) {
  reinterpret_cast<float2*>(&gmem[2 * idx])[0] = make_float2(src[0], src[1]);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

DEVICE_FUNCTION void write_to_gmem(float* gmem, int idx, const float (&src)[4]) {
  reinterpret_cast<float4*>(&gmem[4 * idx])[0] = make_float4(src[0], src[1], src[2], src[3]);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Like write_to_gmem(float4) but multiplies every element by coeff before storing.
DEVICE_FUNCTION void scaled_write_to_gmem(float* gmem, int idx, const float (&src)[4], const float coeff) {
  reinterpret_cast<float4*>(&gmem[4 * idx])[0] =
      make_float4(src[0] * coeff, src[1] * coeff, src[2] * coeff, src[3] * coeff);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Vector writes of x into the idx-th 1/2/4-wide slot of smem.
DEVICE_FUNCTION void write_to_smem(float* smem, int idx, const float (&x)[2]) {
  reinterpret_cast<float2*>(&smem[2 * idx])[0] = make_float2(x[0], x[1]);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

DEVICE_FUNCTION void write_to_smem(int* smem, int idx, const int (&x)[1]) { smem[idx] = x[0]; }

////////////////////////////////////////////////////////////////////////////////////////////////////

DEVICE_FUNCTION void write_to_smem(float* smem, int idx, const float (&x)[4]) {
  reinterpret_cast<float4*>(&smem[4 * idx])[0] = make_float4(x[0], x[1], x[2], x[3]);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

DEVICE_FUNCTION void write_to_smem(int* smem, int idx, const int (&x)[2]) {
  reinterpret_cast<int2*>(&smem[2 * idx])[0] = make_int2(x[0], x[1]);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Sets every element of dst to zero.
template <int N>
DEVICE_FUNCTION void zero_array(int (&dst)[N]) {
#pragma unroll
  for (int i = 0; i < N; ++i) {
    dst[i] = 0;
  }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <int N>
DEVICE_FUNCTION void zero_array(float (&dst)[N]) {
#pragma unroll
  for (int i = 0; i < N; ++i) {
    dst[i] = 0.f;
  }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// In place: x[i] += y[i].
template <int N>
DEVICE_FUNCTION void add(float (&x)[N], const float (&y)[N]) {
#pragma unroll
  for (int i = 0; i < N; ++i) {
    x[i] += y[i];
  }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// In place: x[i] *= y[i].
template <int N>
DEVICE_FUNCTION void multiply(float (&x)[N], const float (&y)[N]) {
#pragma unroll
  for (int i = 0; i < N; ++i) {
    x[i] *= y[i];
  }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// In place: x[i] *= scalar.
template <int N>
DEVICE_FUNCTION void scale_(float (&x)[N], float scalar) {
#pragma unroll
  for (int i = 0; i < N; ++i) {
    x[i] *= scalar;
  }
}
////////////////////////////////////////////////////////////////////////////////////////////////////

// In place: x[i] = bias[i] + scale[i] * (x[i] - m1[i]).
template <int N>
DEVICE_FUNCTION void normalize(float (&x)[N], const float (&bias)[N], const float (&scale)[N], const float (&m1)[N]) {
#pragma unroll
  for (int i = 0; i < N; ++i) {
    x[i] = bias[i] + scale[i] * (x[i] - m1[i]);
  }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Returns max(in, 0) for any type that supports `<` and a cast from 0.f.
template <typename Storage>
DEVICE_FUNCTION Storage relu(Storage in) {
  Storage zero = static_cast<Storage>(0.f);
  return (in < zero) ? zero : in;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Applies relu() to every element of x in place.
template <int N>
DEVICE_FUNCTION void relu_activation(float (&x)[N]) {
#pragma unroll
  for (int i = 0; i < N; ++i) {
    x[i] = relu(x[i]);
  }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// CTA-wide sum of x for the 16-threads-per-pixel, 4-elements-per-thread layout, i.e. two pixels per
// warp. On return, threads 0..15 have written the 16 float4 totals to smem[0..15].
//
// When sync_iters > 0, the CTA-local totals are also combined with partial sums from peers.
// For each sync iteration:
//   - CTA 0 writes its totals, interleaved with `magic` flag words, into params_pair_datas[sync_iter]
//     using write-through (st.global.wt) 16-byte stores;
//   - every CTA spins on its own params_my_data buffer until both flag words equal `magic`;
//   - it then adds the peer values.
// `off` selects the slot inside a sync iteration's region (data_total floats per iteration), so it
// must stay below MAX_OFFSET. `nhw` is unused; it only keeps the signature in line with
// parallel_sums().
template <int THREADS_PER_CTA>
DEVICE_FUNCTION void parallel_sums_16x2(float* smem, float (&x)[4], int nhw, void* params_my_data,
                                        void** params_pair_datas, int off, const int magic, const int sync_iters) {
  // The size of a warp.
  const int THREADS_PER_WARP = 32;
  // The number of warps in a CTA.
  const int WARPS_PER_CTA = THREADS_PER_CTA / THREADS_PER_WARP;
  // The number of threads per pixel.
  const int THREADS_PER_PIXEL = 16;
  // The number of elements per ldg.
  const int ELEMENTS_PER_LDG = 4;
  // The number of reducing ops, each uses its own space : mean, var, dscale, dbias
  const int REDUCE_OPS = 4;
  // Maximum block.y supported - limited due to buffer allocation
  const int MAX_BLOCK_Y = 256;
  const int MAX_OFFSET = REDUCE_OPS * MAX_BLOCK_Y;
  // The warp decomposition.
  const int warp_id = threadIdx.x / THREADS_PER_WARP;
  const int lane_id = threadIdx.x % THREADS_PER_WARP;
  // total size of data per sync iter
  const int data_total = MAX_OFFSET * THREADS_PER_PIXEL * ELEMENTS_PER_LDG * 2;

  // Fold the two pixels of the warp together: lane l adds lane (l + 16) mod 32.
#pragma unroll
  for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
    x[i] += __shfl_sync(0xffffffffU, x[i], THREADS_PER_PIXEL + lane_id);
  }

  // The warp leaders, write to SMEM.
  if (lane_id < THREADS_PER_PIXEL) {
    write_to_smem(smem, warp_id * THREADS_PER_PIXEL + lane_id, x);
  }

  // The data is in SMEM. Do the final reduction.
  __syncthreads();

  // Warp 0 does the final reduction. Each half-warp accumulates every other per-warp partial sum
  // from SMEM, then a shuffle folds the two half-warps together.
  if (warp_id == 0) {
    read_from_smem(x, smem, threadIdx.x);

#pragma unroll
    for (int offset = 1; offset < WARPS_PER_CTA / (THREADS_PER_WARP / THREADS_PER_PIXEL); ++offset) {
      float y[ELEMENTS_PER_LDG];
      // Read the partial sums of the next pair of warps.
      read_from_smem(y, smem, threadIdx.x + offset * THREADS_PER_WARP);
      // Compute the updated sum.
      add(x, y);
    }

#pragma unroll
    for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
      x[i] += __shfl_sync(0xffffffffU, x[i], THREADS_PER_PIXEL + lane_id);
    }

    // Make sure the data was read from SMEM.
    __syncwarp();

    // Store the final values.
    if (threadIdx.x < THREADS_PER_PIXEL) {
      // probably could do it earlier, before sync
      for (int sync_iter = 0; sync_iter < sync_iters; ++sync_iter) {
        void* params_pair_data = params_pair_datas[sync_iter];

        // skip the space consumed by previous sync iterations
        const int xbuf_offset = sync_iter * data_total;
        // data starts after flags, but have to skip previous
        const int data_offset =
            xbuf_offset + off * ELEMENTS_PER_LDG * THREADS_PER_PIXEL * 2 + ELEMENTS_PER_LDG * threadIdx.x * 2;

        // after sums for this GPU were computed, let CTA0 broadcast the sum to over GPU
        if (blockIdx.x == 0) {
          volatile float* write_data = &((reinterpret_cast<float*>(params_pair_data))[data_offset]);

          // write the data to memory region to be reflected to other GPU.
          // Each 16-byte store carries two values, each paired with a `magic` flag word.
          asm volatile("st.global.wt.v4.b32 [%0], {%1,%2,%3,%4};" ::"l"(write_data), "f"(x[0]), "r"(magic),
                       "f"(x[2]), "r"(magic));

          asm volatile("st.global.wt.v4.b32 [%0], {%1,%2,%3,%4};" ::"l"(write_data + 4), "f"(x[1]), "r"(magic),
                       "f"(x[3]), "r"(magic));
        }

        // now each CTA (on each GPU) reads the data written by CTA 0 of the other GPU
        volatile float* read_data = &((reinterpret_cast<float*>(params_my_data))[data_offset]);

        float other[4];
        uint32_t other_flag_a, other_flag_b;
        // Spin until both flag words of each 16-byte chunk carry this launch's `magic`.
        do {
          asm volatile("ld.volatile.global.v4.b32 {%0, %1, %2, %3}, [%4];"
                       : "=f"(other[0]), "=r"(other_flag_a), "=f"(other[2]), "=r"(other_flag_b)
                       : "l"(read_data));
        } while ((other_flag_a != magic) || (other_flag_b != magic));

        do {
          asm volatile("ld.volatile.global.v4.b32 {%0, %1, %2, %3}, [%4];"
                       : "=f"(other[1]), "=r"(other_flag_a), "=f"(other[3]), "=r"(other_flag_b)
                       : "l"(read_data + 4));
        } while ((other_flag_a != magic) || (other_flag_b != magic));

        add(x, other);
      }
      // finally, after syncing up and accounting for partial sums from
      // other GPUs as required, write the result

      write_to_smem(smem, threadIdx.x, x);
    }
  }
}
template DEVICE_FUNCTION void parallel_sums_8x4(float* smem, float (&x)[4], int nhw) { // The size of a warp. const int THREADS_PER_WARP = 32; // The number of warps in a CTA. const int WARPS_PER_CTA = THREADS_PER_CTA / THREADS_PER_WARP; // The number of threads per pixel. const int THREADS_PER_PIXEL = 8; // The number of elements per ldg. const int ELEMENTS_PER_LDG = 4; // The warp decomposition. const int warp_id = threadIdx.x / THREADS_PER_WARP; const int lane_id = threadIdx.x % THREADS_PER_WARP; #pragma unroll for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { x[i] += __shfl_sync(0xffffffffU, x[i], THREADS_PER_PIXEL + lane_id); x[i] += __shfl_sync(0xffffffffU, x[i], THREADS_PER_PIXEL * 2 + lane_id); } // The warp leaders, write to SMEM. if (lane_id < THREADS_PER_PIXEL) { write_to_smem(smem, warp_id * THREADS_PER_PIXEL + lane_id, x); } // The data is in SMEM. Do the final reduction. __syncthreads(); // The 1st warp does all the work. // We do the final reduction each half-warp sequentially reduces the final values. if (warp_id == 0) { read_from_smem(x, smem, threadIdx.x); #pragma unroll for (int offset = 1; offset < WARPS_PER_CTA / (THREADS_PER_WARP / THREADS_PER_PIXEL); ++offset) { float y[ELEMENTS_PER_LDG]; // Read the mean and variance from the other pixel. read_from_smem(y, smem, threadIdx.x + offset * THREADS_PER_WARP); // Compute the updated sum. add(x, y); } for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { x[i] += __shfl_sync(0xffffffffU, x[i], THREADS_PER_PIXEL + lane_id); x[i] += __shfl_sync(0xffffffffU, x[i], THREADS_PER_PIXEL * 2 + lane_id); } // Make sure the data was read from SMEM. __syncwarp(); // Store the final values. if (threadIdx.x < THREADS_PER_PIXEL) { write_to_smem(smem, threadIdx.x, x); } } } //////////////////////////////////////////////////////////////////////////////////////////////////// template DEVICE_FUNCTION void parallel_sums(float* smem, float (&x)[ELEMENTS_PER_LDG], int nhw) { // The size of a warp. 
const int THREADS_PER_WARP = 32; // The number of warps in a CTA. const int WARPS_PER_CTA = THREADS_PER_CTA / THREADS_PER_WARP; // The number of pixels computed by a single warp. const int PIXELS_PER_WARP = THREADS_PER_WARP / THREADS_PER_PIXEL; // The position in the warp. const int nhw_in_warp = nhw % PIXELS_PER_WARP; // The C in the warp. const int c_in_warp = threadIdx.x % THREADS_PER_PIXEL; // Store the values to shared memory. write_to_smem(smem, threadIdx.x, x); // Compute the parallel sums. for (int offset = PIXELS_PER_WARP / 2; offset > 0; offset /= 2) { // NOP. __syncwarp(); // Read the running sum from the other thread. float y[ELEMENTS_PER_LDG]; if (nhw_in_warp < offset) { read_from_smem(y, smem, threadIdx.x + offset * THREADS_PER_PIXEL); } // Compute the updated sum. add(x, y); // NOP. __syncwarp(); // Update the sum in SMEM. if (offset > 1 && nhw_in_warp < offset) { write_to_smem(smem, threadIdx.x, x); } } // The warps are done. Do the final reduction at the CTA level. __syncthreads(); // The warp leaders, write to SMEM. const int idx = (threadIdx.x / THREADS_PER_WARP) * THREADS_PER_PIXEL + c_in_warp; if (nhw_in_warp == 0) { write_to_smem(smem, idx, x); } // The data is in SMEM. Do the final reduction. __syncthreads(); // Read the 1st element to prepare the work. if (nhw < WARPS_PER_CTA / 2) { read_from_smem(x, smem, threadIdx.x); } // We have the running mean and running m2. Let's build the mean/var of the CTA. for (int offset = WARPS_PER_CTA / 2; offset > 0; offset /= 2) { // NOP. __syncwarp(); // Read the mean and variance from the other pixel. float y[ELEMENTS_PER_LDG]; if (nhw < offset) { read_from_smem(y, smem, threadIdx.x + offset * THREADS_PER_PIXEL); } // Compute the updated sum. add(x, y); // NOP. __syncwarp(); // Store the mean/var for the different pixels. 
if (nhw < offset) { write_to_smem(smem, threadIdx.x, x); } } } //////////////////////////////////////////////////////////////////////////////////////////////////// template struct ParallelSums { template DEVICE_FUNCTION void dispatch(float* smem, float (&x)[ELEMENTS_PER_LDG], int nhw) { parallel_sums(smem, x, nhw); } }; //////////////////////////////////////////////////////////////////////////////////////////////////// template <> struct ParallelSums<16, 4> { template DEVICE_FUNCTION void dispatch(float* smem, float (&x)[4], int nhw) { parallel_sums_16x2(smem, x, nhw, 0, 0, 0, 0, 0); } template DEVICE_FUNCTION void dispatchX(float* smem, float (&x)[4], int nhw, void* params_my_data, void** params_pair_datas, int off, const int magic, const unsigned int& sync_iters) { parallel_sums_16x2(smem, x, nhw, params_my_data, params_pair_datas, off, magic, sync_iters); } }; template <> struct ParallelSums<8, 4> { template DEVICE_FUNCTION void dispatch(float* smem, float (&x)[4], int nhw) { parallel_sums_8x4(smem, x, nhw); } }; //////////////////////////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////////////////////////// static inline int div_up(int m, int n) { return (m + n - 1) / n; } //////////////////////////////////////////////////////////////////////////////////////////////////// // It is expected that all threads in the CTA enter this function! DEVICE_FUNCTION void inter_block_sync(int* gmem_retired_ctas, int expected_count, bool master) { // Register the CTA. if (threadIdx.x == 0) { // Issue the membar. __threadfence(); // Notify that the CTA is done. int val_to_add = 1; if (master) { val_to_add = -(expected_count - 1); } atomicAdd(gmem_retired_ctas, val_to_add); } // Are all CTAs done? 
if (threadIdx.x == 0) { int retired_ctas = -1; do { __threadfence(); asm volatile("ld.global.cg.b32 %0, [%1];" : "=r"(retired_ctas) : "l"(gmem_retired_ctas)); } while (retired_ctas != 0); } __syncthreads(); } //////////////////////////////////////////////////////////////////////////////////////////////////// struct NhwcBatchNormFwdInferenceParams { // The input/output tensors. uint16_t *gmem_src, *gmem_dst, *gmem_src1; // the final mean and variance as calculated during the training process float *gmem_mean, *gmem_var; // The bias/scale. float *gmem_bias, *gmem_scale; // The dimensions. int nhw, c; // epsilon float var_eps; }; //////////////////////////////////////////////////////////////////////////////////////////////////// // No DESIRED_OCCUPANCY launch bounds needed, as this is not launched cooperatively template __global__ __launch_bounds__(THREADS_PER_CTA) void nhwc_batch_norm_fwd_inference( NhwcBatchNormFwdInferenceParams params) { // The number of pixels loaded in a single LDG. const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL; // The number of C elements per CTA. const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL * ELEMENTS_PER_LDG; // The start position in the NHW dimension where the CTA starts. const int cta_nhw_stride = gridDim.x * PIXELS_PER_LDG; // Compute the NHW coordinate of the thread in the CTA. const int thread_in_cta_nhw = threadIdx.x / THREADS_PER_PIXEL; // thread's starting point in NHW const int thread_nhw = thread_in_cta_nhw + blockIdx.x * PIXELS_PER_LDG; // The position in the C dimension where the CTA starts. const int cta_c = blockIdx.y * C_ELEMENTS_PER_CTA; // Compute the C coordinate of the thread in the CTA. const int thread_in_cta_c = threadIdx.x % THREADS_PER_PIXEL; // Compute the C coordinate of the thread. const int thread_c = cta_c + thread_in_cta_c * ELEMENTS_PER_LDG; // Is the thread working on a valid C dimension? 
const int is_valid_c = thread_c < params.c; float mean[ELEMENTS_PER_LDG], var[ELEMENTS_PER_LDG]; float scale[ELEMENTS_PER_LDG], bias[ELEMENTS_PER_LDG]; zero_array(mean); zero_array(var); zero_array(scale); zero_array(bias); if (is_valid_c) { read_from_gmem(var, ¶ms.gmem_var[cta_c], thread_in_cta_c); read_from_gmem(scale, ¶ms.gmem_scale[cta_c], thread_in_cta_c); read_from_gmem(mean, ¶ms.gmem_mean[cta_c], thread_in_cta_c); read_from_gmem(bias, ¶ms.gmem_bias[cta_c], thread_in_cta_c); } // Update the scale with the stddev and eps. #pragma unroll for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { scale[i] *= rsqrtf(var[i] + params.var_eps); } // The base pointers for reading/writing uint16_t* const gmem_src = ¶ms.gmem_src[thread_c]; uint16_t* const gmem_dst = ¶ms.gmem_dst[thread_c]; const uint16_t* gmem_src1 = nullptr; if (USE_ADD_RELU) { gmem_src1 = ¶ms.gmem_src1[thread_c]; } // apply BN for (int nhw = thread_nhw; nhw < params.nhw; nhw += cta_nhw_stride) { float x_math[ELEMENTS_PER_LDG]; zero_array(x_math); if (is_valid_c) { ldg(x_math, &gmem_src[nhw * params.c]); } // Normalize and apply activation function normalize(x_math, bias, scale, mean); if (USE_ADD_RELU) { float x1_math[ELEMENTS_PER_LDG]; ldg(x1_math, &gmem_src1[nhw * params.c]); add(x_math, x1_math); relu_activation(x_math); } else if (USE_RELU) { relu_activation(x_math); } if (is_valid_c) { stg(&gmem_dst[nhw * params.c], x_math); } } } //////////////////////////////////////////////////////////////////////////////////////////////////// struct NhwcBatchNormFwdParams { // The input/output tensors. uint16_t *gmem_src, *gmem_dst, *gmem_src1; // The bias/scale. float *gmem_bias, *gmem_scale; // running mean/var (refer BN API from cudnn doc) float *gmem_running_mean, *gmem_running_var; // saved mean/var (refer BN API from cudnn doc) float *gmem_saved_mean, *gmem_saved_var; // ReLU bitmask unsigned int* gmem_relu_bitmask; // The dimensions. int nhw, c; // factor to scale sum of squared errors to get saved variance. 
Must be 1/nhw. float svar_inv_count; // factor to scale sum of squared errors to get running variance. Should be 1/nhw or 1/(nhw-1). float rvar_inv_count; // The buffer to do the reduction for mean, stddev and count. float* gmem_sums; // The buffer to count items in the different CTAs. int* gmem_counts; // The counters of retired CTAs. int* gmem_retired_ctas; // The epsilon to apply to the computation of the variance. float var_eps; // outer loop count int outer_loops; // exponential average factor float exp_avg_factor; // number of CTAs along .x dimension int c_blks; void* my_data; void* pair_datas[4]; int magic; int sync_iters; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY) void nhwc_batch_norm_fwd(NhwcBatchNormFwdParams params) { // The number of pixels loaded in a single LDG. const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL; // The number of pixels computed per CTA stored in registers. const int PIXELS_PER_CTA_IN_REGISTERS = PIXELS_PER_THREAD_IN_REGISTERS * PIXELS_PER_LDG; // The number of pixels computed per CTA stored in SMEM. const int PIXELS_PER_CTA_IN_SMEM = PIXELS_PER_THREAD_IN_SMEM * PIXELS_PER_LDG; // The number of C elements per CTA. const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL * ELEMENTS_PER_LDG; // Shared memory to do CTA-wide parallel sums. __shared__ float smem[THREADS_PER_PIXEL * (THREADS_PER_CTA / 32) * ELEMENTS_PER_LDG]; // Compute the NHW coordinate of the thread in the CTA. const int thread_in_cta_nhw = threadIdx.x / THREADS_PER_PIXEL; // The adapter for the storage. typedef PackedStorage PackedStorage_; // The data type for packed storage in SMEM. typedef typename PackedStorage_::Type PackedStorageType; // The number of elements in the packed storage. const int PACKED_ELEMENTS_PER_LDG = PackedStorage_::PACKED_ELEMENTS_PER_LDG; // Registers to keep the data live for the persistent approach. 
PackedStorageType x_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG]; // Shared memory buffer to store the extra pixels. extern __shared__ PackedStorageType smem_storage_packed[]; for (int c_blk_index = blockIdx.y; c_blk_index < params.c_blks; c_blk_index += gridDim.y) { // The position in the NHW dimension where the CTA starts. int cta_nhw_regs = blockIdx.x * PIXELS_PER_CTA_IN_REGISTERS; // The position in the NHW dimension where the CTA starts for the portion in SMEM. int cta_nhw_smem = blockIdx.x * PIXELS_PER_CTA_IN_SMEM; // The position in the C dimension where the CTA starts. const int cta_c = c_blk_index * C_ELEMENTS_PER_CTA; // Compute the C coordinate of the thread in the CTA. const int thread_in_cta_c = threadIdx.x % THREADS_PER_PIXEL; // Compute the C coordinate of the thread. int thread_c = cta_c + thread_in_cta_c * ELEMENTS_PER_LDG; // Is the thread working on a valid C dimension? const int is_valid_c = thread_c < params.c; // Clamp thread_c so that we load from valid locations even if we don't use the value if (!is_valid_c) thread_c = params.c - 4; // Single pass numerically stable algorithm, see: // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Online_algorithm // // n = 0, mean = 0.0, M2 = 0.0 // // for x in data: // n += 1 // delta = x - mean // mean += delta/n // delta2 = x - mean // M2 += delta*delta2 // // if n < 2: // return float('nan') // else: // return M2 / (n - 1) // Register to store the number of elements read so far. float count = 0.f, mean[ELEMENTS_PER_LDG], m2[ELEMENTS_PER_LDG]; #pragma unroll for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { mean[i] = 0.f; m2[i] = 0.f; } // The number of elements loaded by this CTA. int cta_count = 0; // The base pointer to load from. const uint16_t* gmem_src = ¶ms.gmem_src[thread_c]; // outer loops int OUTER_LOOPS = OUTER_LOOPS_ == 1 ? 1 : params.outer_loops; // Load the batch of elements. Compute the mean/var across those elements. 
const int pixels_per_iteration = PIXELS_PER_CTA_IN_REGISTERS * gridDim.x; if (OUTER_LOOPS_ != 1) { // We cannot load everything to store persistently, so let's makes sure registers and // smem are fully utilized, offset is evenly divisible by 32 int offset = (pixels_per_iteration * OUTER_LOOPS + PIXELS_PER_CTA_IN_SMEM * gridDim.x - params.nhw) & ~31; cta_nhw_regs -= offset; cta_nhw_smem -= offset; } #pragma unroll 1 for (int loop_i = 0; loop_i < OUTER_LOOPS; ++loop_i) { // The nhw position. int nhw_regs = cta_nhw_regs + loop_i * pixels_per_iteration; // Update the number of elements loaded by this CTA. TODO: Skip if <= 0!!! cta_count += max(min(nhw_regs + PIXELS_PER_CTA_IN_REGISTERS, params.nhw) - max(nhw_regs, 0), 0); // Load the data and compute the local mean/sum and the variance. if (USE_ONLINE_APPROACH) { // Read the elements from memory. float is_valid[PIXELS_PER_THREAD_IN_REGISTERS]; #pragma unroll for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) { const int idx = nhw_regs + thread_in_cta_nhw + i * PIXELS_PER_LDG; zero_array(x_storage[i]); is_valid[i] = 0.f; if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) { if (loop_i == OUTER_LOOPS - 1) { ldg_stream(x_storage[i], &gmem_src[idx * params.c]); } else { ldg(x_storage[i], &gmem_src[idx * params.c]); } is_valid[i] = 1.f; } } // Do the math. #pragma unroll for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) { // Convert to float. float x_math[ELEMENTS_PER_LDG]; to_float(x_math, x_storage[i]); // Update the count. count += is_valid[i]; // Invert the count. float inv_count = is_valid[i] ? 1.f / count : 0.f; // Update the mean and m2 using deltas. #pragma unroll for (int j = 0; j < ELEMENTS_PER_LDG; ++j) { float delta0 = x_math[j] - mean[j]; mean[j] += delta0 * inv_count; float delta1 = x_math[j] - mean[j]; m2[j] += delta0 * delta1 * is_valid[i]; } } } else { // Read the elements from memory. 
#pragma unroll for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) { const int idx = nhw_regs + thread_in_cta_nhw + i * PIXELS_PER_LDG; zero_array(x_storage[i]); if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) { if (loop_i == OUTER_LOOPS - 1) { ldg_stream(x_storage[i], &gmem_src[idx * params.c]); } else { ldg(x_storage[i], &gmem_src[idx * params.c]); } count += 1.f; } } // Sum the elements in registers. #pragma unroll for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) { // Convert to float. float x_math[ELEMENTS_PER_LDG]; to_float(x_math, x_storage[i]); // Update the mean and m2 using deltas. #pragma unroll for (int j = 0; j < ELEMENTS_PER_LDG; ++j) { mean[j] += x_math[j]; } } // Compute the mean. float inv_count = 1.f / count; #pragma unroll for (int j = 0; j < ELEMENTS_PER_LDG; ++j) { mean[j] *= inv_count; } // Compute the variance. #pragma unroll for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) { // Convert to float. float x_math[ELEMENTS_PER_LDG]; to_float(x_math, x_storage[i]); // Is it a valid pixel? float is_valid = i < static_cast(count) ? 1.f : 0.f; // Update the mean and m2 using deltas. #pragma unroll for (int j = 0; j < ELEMENTS_PER_LDG; ++j) { m2[j] += (x_math[j] - mean[j]) * (x_math[j] - mean[j]) * is_valid; } } } } // The elements to load and store in SMEM. int smem_nhw = OUTER_LOOPS * pixels_per_iteration + cta_nhw_smem; // Load elements from SMEM, update the CTA count. int pixels_in_smem = min(smem_nhw + PIXELS_PER_CTA_IN_SMEM, params.nhw) - max(smem_nhw, 0); if (pixels_in_smem > 0) { cta_count += pixels_in_smem; for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) { const int idx = smem_nhw + thread_in_cta_nhw + i * PIXELS_PER_LDG; float is_pixel_valid = (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) ? 1.f : 0.f; PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG]; ldg_stream(x_storage_local, &gmem_src[(is_pixel_valid ? idx : 0) * params.c]); // The offset to store in SMEM. 
const int offset = i * THREADS_PER_CTA * PACKED_ELEMENTS_PER_LDG; // Store in SMEM. write_to_smem(&smem_storage_packed[offset], threadIdx.x, x_storage_local); // Update the count. count += is_pixel_valid; // Invert the count. float inv_count = is_pixel_valid ? 1.f / count : 0.f; float x_math[ELEMENTS_PER_LDG]; to_float(x_math, x_storage_local); // Update the mean and m2 using deltas. #pragma unroll for (int j = 0; j < ELEMENTS_PER_LDG; ++j) { float delta0 = x_math[j] - mean[j]; mean[j] += delta0 * inv_count; float delta1 = x_math[j] - mean[j]; m2[j] += delta0 * delta1 * is_pixel_valid; } } } // We scale the mean by the number of elements. It brings more stability. float m1[ELEMENTS_PER_LDG]; #pragma unroll for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { m1[i] = mean[i] * count; } // Run the parallel sum accross the CTA to get the local sum. ParallelSums::dispatch(smem, m1, thread_in_cta_nhw); __syncthreads(); // The values in shared memory correspond to the CTA-wide sums. read_from_smem(m1, smem, thread_in_cta_c); __syncthreads(); // Adjust the variance. float inv_cta_count = 1.f / static_cast(cta_count); #pragma unroll for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { float mean_diff = m1[i] * inv_cta_count - mean[i]; m2[i] = m2[i] + mean_diff * mean_diff * count; } // Run the parallel sum accross the CTA to get the local adjusted variance. ParallelSums::dispatch(smem, m2, thread_in_cta_nhw); // The workspace in global memory is distributed across the different CTA. int gmem_sums_offset = c_blk_index * gridDim.x * C_ELEMENTS_PER_CTA * 2; // Write the data for the CTA to global memory. float* gmem_sums = ¶ms.gmem_sums[gmem_sums_offset]; if (threadIdx.x < THREADS_PER_PIXEL) { const int idx = blockIdx.x * THREADS_PER_PIXEL + threadIdx.x; write_to_gmem(&gmem_sums[0], idx, m1); write_to_gmem(&gmem_sums[C_ELEMENTS_PER_CTA * gridDim.x], idx, m2); } // The memory location to store the number of pixels per CTA. 
int* gmem_counts = ¶ms.gmem_counts[c_blk_index * gridDim.x]; if (threadIdx.x == 0) { gmem_counts[blockIdx.x] = cta_count; } // Read the bias and scale. float bias[ELEMENTS_PER_LDG], scale[ELEMENTS_PER_LDG]; if (is_valid_c) { read_from_gmem(bias, ¶ms.gmem_bias[cta_c], thread_in_cta_c); read_from_gmem(scale, ¶ms.gmem_scale[cta_c], thread_in_cta_c); } // The counters to count how many CTAs have retired at this point. // A given cta uses the same counter every other time through the outer loop. int* gmem_retired_ctas = ¶ms.gmem_retired_ctas[c_blk_index % (2 * gridDim.y)]; inter_block_sync(gmem_retired_ctas, gridDim.x, blockIdx.x == 0); // Reset the mean to compute the global mean. #pragma unroll for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { m1[i] = 0.f; } // Build the global mean. #pragma unroll 1 for (int idx = threadIdx.x; idx < THREADS_PER_PIXEL * gridDim.x; idx += THREADS_PER_CTA) { float tmp[ELEMENTS_PER_LDG]; read_from_gmem(tmp, gmem_sums, idx); add(m1, tmp); } if (params.sync_iters > 0) { ParallelSums::dispatchX( smem, m1, thread_in_cta_nhw, params.my_data, params.pair_datas, 4 * c_blk_index + 3, params.magic, params.sync_iters); } else { ParallelSums::dispatch(smem, m1, thread_in_cta_nhw); } __syncthreads(); // The values in shared memory correspond to the CTA-wide sums. read_from_smem(m1, smem, thread_in_cta_c); __syncthreads(); // Normalize the mean. #pragma unroll for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { m1[i] = m1[i] * params.svar_inv_count; } // Reset the variance. #pragma unroll for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { m2[i] = 0.f; } // for add+relu fusion const uint16_t* gmem_src1 = nullptr; if (USE_ADD_RELU) { gmem_src1 = ¶ms.gmem_src1[thread_c]; } // Build the global variance. #pragma unroll 1 for (int idx = threadIdx.x; idx < THREADS_PER_PIXEL * gridDim.x; idx += THREADS_PER_CTA) { // Read the means computed by different CTAs (again). Reuse tmp if we have 1 iteration. 
float tmp_mean[ELEMENTS_PER_LDG], tmp_var[ELEMENTS_PER_LDG]; read_from_gmem(tmp_mean, &gmem_sums[0], idx); read_from_gmem(tmp_var, &gmem_sums[C_ELEMENTS_PER_CTA * gridDim.x], idx); // Read the number of pixels visited by a given CTA. cta_count = __ldg(&gmem_counts[idx / THREADS_PER_PIXEL]); // Compute the diff to update the variance. float mean_diff[ELEMENTS_PER_LDG], inv_cta_count = 1.f / static_cast(cta_count); #pragma unroll for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { mean_diff[i] = m1[i] - tmp_mean[i] * inv_cta_count; } // Update the variance. #pragma unroll for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { m2[i] += tmp_var[i] + mean_diff[i] * mean_diff[i] * static_cast(cta_count); } } if (params.sync_iters > 0) { ParallelSums::dispatchX( smem, m2, thread_in_cta_nhw, params.my_data, params.pair_datas, 4 * c_blk_index + 2, params.magic, params.sync_iters); } else { ParallelSums::dispatch(smem, m2, thread_in_cta_nhw); } __syncthreads(); read_from_smem(m2, smem, thread_in_cta_c); // Finalize the stddev. 
// becasue saved var and running var may have different denominator, we don't do it here // scale_(m2, inv_count); // store the saved mean/var float svarinv[ELEMENTS_PER_LDG]; bool is_valid_for_saving = is_valid_c && blockIdx.x == 0 && thread_in_cta_nhw == 0; #pragma unroll for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { svarinv[i] = rsqrtf(m2[i] * params.svar_inv_count + params.var_eps); } if (is_valid_for_saving) { write_to_gmem(params.gmem_saved_mean, thread_c / ELEMENTS_PER_LDG, m1); write_to_gmem(params.gmem_saved_var, thread_c / ELEMENTS_PER_LDG, svarinv); } // store the running mean/var float rmean[ELEMENTS_PER_LDG], rvar[ELEMENTS_PER_LDG]; zero_array(rmean); zero_array(rvar); if (params.exp_avg_factor != 1.f && is_valid_for_saving) { read_from_gmem(rmean, params.gmem_running_mean, thread_c / ELEMENTS_PER_LDG); read_from_gmem(rvar, params.gmem_running_var, thread_c / ELEMENTS_PER_LDG); } #pragma unroll for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { rmean[i] = (1.f - params.exp_avg_factor) * rmean[i] + params.exp_avg_factor * m1[i]; rvar[i] = (1.f - params.exp_avg_factor) * rvar[i] + params.exp_avg_factor * (m2[i] * params.rvar_inv_count); } if (is_valid_for_saving) { write_to_gmem(params.gmem_running_mean, thread_c / ELEMENTS_PER_LDG, rmean); write_to_gmem(params.gmem_running_var, thread_c / ELEMENTS_PER_LDG, rvar); } // Update the scale with the stddev and eps. multiply(scale, svarinv); // The base pointer to write to. uint16_t* const gmem_dst = ¶ms.gmem_dst[thread_c]; unsigned int* const gmem_relu_bitmask = params.gmem_relu_bitmask + ((params.nhw + 31) & ~31) * 2 * c_blk_index; // Store the elements in registers. #pragma unroll 1 for (int loop_i = OUTER_LOOPS - 1; loop_i >= 0; --loop_i) { // The value for nhw. int out_nhw = cta_nhw_regs + loop_i * pixels_per_iteration; // Normalize the elements and write to memory. 
#pragma unroll for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) { const int idx = out_nhw + thread_in_cta_nhw + i * PIXELS_PER_LDG; const bool is_valid_nhw = static_cast(idx) < static_cast(params.nhw); const bool is_valid = is_valid_nhw && is_valid_c; // Convert to float. float x_math[ELEMENTS_PER_LDG]; to_float(x_math, x_storage[i]); // Normalize and apply activation function normalize(x_math, bias, scale, m1); if (USE_ADD_RELU) { float x1_math[ELEMENTS_PER_LDG]; ldg_stream(x1_math, &gmem_src1[(is_valid ? idx : 0) * params.c]); add(x_math, x1_math); unsigned int relu_mask; int lane_id = threadIdx.x & 31; #pragma unroll for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { bool rectified = x_math[i] < 0.0F; unsigned int local_relu_mask = __ballot_sync(0xFFFFFFFFU, rectified); if (lane_id == i) { // Thread 0 remembers the relu_mask from the first time through this // loop, Thread 1 the next, Thread 2 the next, and Thread 3 the last. relu_mask = local_relu_mask; } if (rectified) { x_math[i] = 0.0F; } } if (is_valid_nhw && (lane_id < ELEMENTS_PER_LDG)) { gmem_relu_bitmask[idx * 2 + lane_id] = relu_mask; } } else if (USE_RELU) { relu_activation(x_math); } // Write back. if (is_valid) { stg_stream(&gmem_dst[idx * params.c], x_math); } } // The next value of nhw. out_nhw -= pixels_per_iteration; // Read the next elements from memory. #pragma unroll for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) { const int idx = out_nhw + thread_in_cta_nhw + i * PIXELS_PER_LDG; if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) { ldg_stream(x_storage[i], &gmem_src[idx * params.c]); } } } // Normalize the elements from SMEM and write them out. if (pixels_in_smem > 0) { #pragma unroll 2 for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) { const int idx = smem_nhw + thread_in_cta_nhw + i * PIXELS_PER_LDG; const bool is_valid_nhw = static_cast(idx) < static_cast(params.nhw); const bool is_valid = is_valid_nhw && is_valid_c; // Read from SMEM. 
// (Continuation of the forward kernel's loop over the pixels staged in SMEM.)
        const int offset = i * THREADS_PER_CTA * PACKED_ELEMENTS_PER_LDG;
        PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG];
        read_from_smem(x_storage_local, &smem_storage_packed[offset], threadIdx.x);
        float x_math[ELEMENTS_PER_LDG];
        to_float(x_math, x_storage_local);
        // Normalize and apply activation function
        normalize(x_math, bias, scale, m1);
        if (USE_ADD_RELU) {
          float x1_math[ELEMENTS_PER_LDG];
          ldg_stream(x1_math, &gmem_src1[(is_valid ? idx : 0) * params.c]);
          add(x_math, x1_math);
          unsigned int relu_mask;
          int lane_id = threadIdx.x & 31;
#pragma unroll
          for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
            bool rectified = x_math[i] < 0.0F;
            unsigned int local_relu_mask = __ballot_sync(0xFFFFFFFFU, rectified);
            if (lane_id == i) {
              relu_mask = local_relu_mask;
            }
            if (rectified) {
              x_math[i] = 0.0F;
            }
          }
          if (is_valid_nhw && (lane_id < ELEMENTS_PER_LDG)) {
            gmem_relu_bitmask[idx * 2 + lane_id] = relu_mask;
          }
        } else if (USE_RELU) {
          relu_activation(x_math);
        }
        // Write back.
        if (is_valid) {
          stg_stream(&gmem_dst[idx * params.c], x_math);
        }
      }
    }
    // We're about to start on the next c-blk. Needed?
    __syncthreads();
  }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Parameters shared by the NHWC batch-norm backward kernels
// (nhwc_batch_norm_bwd, nhwc_batch_norm_bwd_relu, nhwc_batch_norm_bwd_add_relu).
// Element (pixel p, channel ch) of an activation tensor lives at index p * c + ch.
struct NhwcBatchNormBwdParams {
  // The input/output tensors (16-bit storage): x, dy and dx.
  // NOTE(review): gmem_dst1 is not used by the kernels visible here; presumably a second
  // gradient output for the add+ReLU variant — confirm.
  uint16_t *gmem_src, *gmem_dy, *gmem_dst, *gmem_dst1;
  // Outputs: gradients w.r.t. the scale and the bias.
  float *gmem_dscale, *gmem_dbias;
  // The scale and bias.
  float *gmem_scale, *gmem_bias;
  // The mean/inv-var saved from fwd pass (the kernels multiply by the saved value).
  float *gmem_saved_mean, *gmem_saved_var;
  // ReLU bitmask
  unsigned int* gmem_relu_bitmask;
  // The dimensions: number of pixels (nhw) and number of channels (c).
  int nhw, c;
  // factor to scale sum of squared errors to get saved variance.  Must be 1/nhw.
  // The backward kernels use it as the 1/count factor of the dx computation.
  float svar_inv_count;
  // Workspace for the cross-CTA reduction of dscale and dbias.
  float* gmem_sums;
  // The counters of retired CTAs.
  int* gmem_retired_ctas;
  // outer loop count (read when the kernel is instantiated with OUTER_LOOPS_ != 1)
  int outer_loops;
  // Number of channel blocks; the kernels distribute them over the grid's y dimension
  // (c_blk_index = blockIdx.y, blockIdx.y + gridDim.y, ...).
  int c_blks;

  // Extra arguments forwarded to ParallelSums::dispatchX; only used when sync_iters > 0.
  void* my_data;
  void* pair_datas[4];
  int magic;
  int sync_iters;
  // Factor applied to dscale/dbias when they are written out with sync_iters > 0.
  float wgrad_coeff;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// ReLU backward that recomputes the pre-activation value y = x * var_scale + mean_var_scale_bias
// and zeros dy[j] wherever y <= 0. Does nothing when valid_data is false.
template <int N>
DEVICE_FUNCTION void relu_bwd(float (&dy)[N], const float (&x)[N], const float (&mean_var_scale_bias)[N],
                              const float (&var_scale)[N], bool valid_data) {
#pragma unroll
  for (int j = 0; j < N; ++j) {
    float y = (x[j] * var_scale[j]) + mean_var_scale_bias[j];
    if ((y <= 0.f) && valid_data) {
      dy[j] = 0.f;
    }
  }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// ReLU backward from an already-computed value y: zeros dy[j] where y[j] <= 0
// (only when valid_data is true).
template <int N>
DEVICE_FUNCTION void relu_bwd(float (&dy)[N], const float (&y)[N], bool valid_data) {
#pragma unroll
  for (int j = 0; j < N; ++j) {
    if ((y[j] <= 0.f) && valid_data) {
      dy[j] = 0.f;
    }
  }
}

// ReLU backward from per-element flags: zeros dy[j] where rectified[j] is set
// (only when valid_data is true).
template <int N>
DEVICE_FUNCTION void relu_bwd(float (&dy)[N], const bool (&rectified)[N], bool valid_data) {
#pragma unroll
  for (int j = 0; j < N; ++j) {
    if (rectified[j] && valid_data) {
      dy[j] = 0.f;
    }
  }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Same as the recomputing relu_bwd overload, but without a validity guard.
template <int N>
DEVICE_FUNCTION void relu_bwd_for_dx(float (&dy)[N], const float (&x)[N], const float (&mean_var_scale_bias)[N],
                                     const float (&var_scale)[N]) {
#pragma unroll
  for (int j = 0; j < N; ++j) {
    float y = (x[j] * var_scale[j]) + mean_var_scale_bias[j];
    if (y <= 0.f) {
      dy[j] = 0.f;
    }
  }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Same as the y-based relu_bwd overload, but without a validity guard.
template <int N>
DEVICE_FUNCTION void relu_bwd_for_dx(float (&dy)[N], const float (&y)[N]) {
#pragma unroll
  for (int j = 0; j < N; ++j) {
    if (y[j] <= 0.f) {
      dy[j] = 0.f;
    }
  }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// One step of an incremental running-mean accumulation:
//   dbias  <- running mean of dy
//   dscale <- running mean of dy * (x - mean)
// The kernels pass inv_count = 1 / (samples seen so far, including this one), or 0 to leave
// both accumulators unchanged for an invalid sample.
template <int N>
DEVICE_FUNCTION void bwd_update(float (&dscale)[N], float (&dbias)[N], const float (&dy)[N], const float (&x)[N],
                                const float (&mean)[N], float inv_count) {
#pragma unroll
  for (int j = 0; j < N; ++j) {
    float delta0 = dy[j] - dbias[j];
    dbias[j] += delta0 * inv_count;
    delta0 = (dy[j] * (x[j] - mean[j])) - dscale[j];
    dscale[j] += delta0 * inv_count;
  }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Input gradient:
//   dx = var * (dy - dbias * inv_count - dscale * inv_count * (x - mean))
// The backward kernels pre-fold the factors: they pass var = saved inv-var * scale and a dscale
// that has already been multiplied by the saved inv-var twice.
template <int N>
DEVICE_FUNCTION void bwd_dx(float (&dx)[N], const float (&dy)[N], const float (&var)[N], const float (&x)[N],
                            const float (&mean)[N], const float (&dscale)[N], const float (&dbias)[N],
                            float inv_count) {
#pragma unroll
  for (int j = 0; j < N; ++j) {
    float tmp1 = dy[j] - (dbias[j] * inv_count);
    float tmp2 = dscale[j] * inv_count;
    float tmp3 = x[j] - mean[j];
    dx[j] = var[j] * (tmp1 - (tmp2 * tmp3));
  }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Backward pass of NHWC batch norm (no fused activation). For every channel block handled by
// this CTA (c_blk_index = blockIdx.y, blockIdx.y + gridDim.y, ...):
//   1. load x and dy for this CTA's pixels (the last outer-loop iteration stays in registers, an
//      extra tail is staged in SMEM) and accumulate running means of dy and dy * (x - mean);
//   2. reduce within the CTA, publish per-CTA partials to params.gmem_sums, wait for all CTAs
//      along x (inter_block_sync), sum the partials (via ParallelSums::dispatchX when
//      sync_iters > 0);
//   3. write dscale/dbias (CTAs with blockIdx.x == 0) and dx (bwd_dx) to params.gmem_dst.
// NOTE(review): the kernel's template parameter list appears to be stripped in this copy (the body
// uses THREADS_PER_CTA, THREADS_PER_PIXEL, PIXELS_PER_THREAD_IN_REGISTERS, PIXELS_PER_THREAD_IN_SMEM,
// ELEMENTS_PER_LDG, OUTER_LOOPS_ and DESIRED_OCCUPANCY) — restore it from upstream.
template __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY) void nhwc_batch_norm_bwd(NhwcBatchNormBwdParams params) {
  // The number of pixels loaded in a single LDG.
  const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL;
  // The number of pixels computed per CTA stored in registers.
  const int PIXELS_PER_CTA_IN_REGISTERS = PIXELS_PER_THREAD_IN_REGISTERS * PIXELS_PER_LDG;
  // The number of pixels computed per CTA stored in SMEM.
  const int PIXELS_PER_CTA_IN_SMEM = PIXELS_PER_THREAD_IN_SMEM * PIXELS_PER_LDG;
  // The number of C elements per CTA.
  const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL * ELEMENTS_PER_LDG;

  // Shared memory to do CTA-wide parallel sums.
  __shared__ float smem[THREADS_PER_PIXEL * (THREADS_PER_CTA / 32) * ELEMENTS_PER_LDG];

  // The adapter for the storage.
  // NOTE(review): PackedStorage's template arguments appear to be stripped in this copy.
  typedef PackedStorage PackedStorage_;
  // The data type for packed storage in SMEM.
  typedef typename PackedStorage_::Type PackedStorageType;
  // The number of elements in the packed storage.
  const int PACKED_ELEMENTS_PER_LDG = PackedStorage_::PACKED_ELEMENTS_PER_LDG;

  // Registers to keep the data live for the persistent approach.
// nhwc_batch_norm_bwd (continued).
  PackedStorageType x_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG];
  PackedStorageType dy_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG];

  // Shared memory buffer to store the extra pixels.
  extern __shared__ PackedStorageType smem_storage_packed[];

  // Channel blocks are distributed over the grid's y dimension.
  for (int c_blk_index = blockIdx.y; c_blk_index < params.c_blks; c_blk_index += gridDim.y) {
    // The position in the NHW dimension where the CTA starts.
    int cta_nhw_regs = blockIdx.x * PIXELS_PER_CTA_IN_REGISTERS;
    // The position in the NHW dimension where the CTA starts for the portion in SMEM.
    int cta_nhw_smem = blockIdx.x * PIXELS_PER_CTA_IN_SMEM;
    // Compute the NHW coordinate of the thread in the CTA.
    const int thread_in_cta_nhw = threadIdx.x / THREADS_PER_PIXEL;

    // The position in the C dimension where the CTA starts.
    const int cta_c = c_blk_index * C_ELEMENTS_PER_CTA;
    // Compute the C coordinate of the thread in the CTA.
    const int thread_in_cta_c = threadIdx.x % THREADS_PER_PIXEL;
    // Compute the C coordinate of the thread.
    const int thread_c = cta_c + thread_in_cta_c * ELEMENTS_PER_LDG;

    // Is the thread working on a valid C dimension?
    const int is_valid_c = thread_c < params.c;

    // Registers to store the mean used for entire duration
    float mean[ELEMENTS_PER_LDG];
    zero_array(mean);
    if (is_valid_c) {
      read_from_gmem(mean, params.gmem_saved_mean, thread_c / ELEMENTS_PER_LDG);
    }

    // accumulation related registers: `count` is the number of valid pixels accumulated by this
    // thread; dscale/dbias hold running means maintained by bwd_update.
    float count = 0.f, dscale[ELEMENTS_PER_LDG], dbias[ELEMENTS_PER_LDG];
    zero_array(dscale);
    zero_array(dbias);

    // The base pointers to load from.
    const uint16_t* gmem_src = &params.gmem_src[thread_c];
    const uint16_t* gmem_dy = &params.gmem_dy[thread_c];

    // outer loops
    int OUTER_LOOPS = OUTER_LOOPS_ == 1 ? 1 : params.outer_loops;
    // Load the batch of elements. Compute sum across them
    const int pixels_per_iteration = PIXELS_PER_CTA_IN_REGISTERS * gridDim.x;

    if (OUTER_LOOPS_ != 1) {
      // We cannot load everything to store persistently, so let's make sure registers and
      // smem are fully utilized: shift the start so that the register iterations followed by the
      // SMEM portion end exactly at params.nhw. Leading indices may become negative; the unsigned
      // bounds checks below reject them.
      int offset = params.nhw - pixels_per_iteration * OUTER_LOOPS - PIXELS_PER_CTA_IN_SMEM * gridDim.x;
      cta_nhw_regs += offset;
      cta_nhw_smem += offset;
    }

#pragma unroll 1
    for (int loop_i = 0; loop_i < OUTER_LOOPS; ++loop_i) {
      // The nhw position.
      int nhw_regs = cta_nhw_regs + loop_i * pixels_per_iteration;

      // Read the elements from memory.
      float is_valid[PIXELS_PER_THREAD_IN_REGISTERS];
#pragma unroll
      for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
        const int idx = nhw_regs + thread_in_cta_nhw + i * PIXELS_PER_LDG;
        zero_array(x_storage[i]);
        zero_array(dy_storage[i]);
        is_valid[i] = 0.f;
        if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {
          // The last iteration's values remain in x_storage/dy_storage and are reused first by
          // the dx pass below; earlier iterations are re-read there.
          if (loop_i == OUTER_LOOPS - 1) {
            ldg_stream(x_storage[i], &gmem_src[idx * params.c]);
            ldg_stream(dy_storage[i], &gmem_dy[idx * params.c]);
          } else {
            ldg(x_storage[i], &gmem_src[idx * params.c]);
            ldg(dy_storage[i], &gmem_dy[idx * params.c]);
          }
          is_valid[i] = 1.f;
        }
      }

      // Do the math.
#pragma unroll
      for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
        // Convert to float and update
        float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];
        to_float(x_math, x_storage[i]);
        to_float(dy_math, dy_storage[i]);
        // Update the count.
        count += is_valid[i];
        // Invert the count (0 leaves the running means untouched for invalid pixels).
        float inv_count = is_valid[i] ? 1.f / count : 0.f;

        bwd_update(dscale, dbias, dy_math, x_math, mean, inv_count);
      }
    }

    // The elements to load and store in SMEM.
    int smem_nhw = OUTER_LOOPS * pixels_per_iteration + cta_nhw_smem;
    // Load the remaining pixels, stage them in SMEM for the dx pass, and accumulate them.
    int pixels_in_smem = min(PIXELS_PER_CTA_IN_SMEM, params.nhw - smem_nhw);
    if (pixels_in_smem > 0) {
      for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) {
        const int idx = smem_nhw + thread_in_cta_nhw + i * PIXELS_PER_LDG;
        bool is_pixel_valid = (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c);

        PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG], dy_storage_local[PACKED_ELEMENTS_PER_LDG];
        zero_array(x_storage_local);
        zero_array(dy_storage_local);
        if (is_pixel_valid) {
          ldg_stream(x_storage_local, &gmem_src[idx * params.c]);
          ldg_stream(dy_storage_local, &gmem_dy[idx * params.c]);
        }

        // The offset to store in SMEM (x first, dy in the second half of the buffer).
        int offset = i * THREADS_PER_CTA * PACKED_ELEMENTS_PER_LDG;
        // Store in SMEM.
        write_to_smem(&smem_storage_packed[offset], threadIdx.x, x_storage_local);
        offset += PIXELS_PER_THREAD_IN_SMEM * THREADS_PER_CTA * PACKED_ELEMENTS_PER_LDG;
        write_to_smem(&smem_storage_packed[offset], threadIdx.x, dy_storage_local);
        // Update the count.
        count += is_pixel_valid;
        // Invert the count.
        float inv_count = is_pixel_valid ? 1.f / count : 0.f;

        float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];
        to_float(x_math, x_storage_local);
        to_float(dy_math, dy_storage_local);

        bwd_update(dscale, dbias, dy_math, x_math, mean, inv_count);
      }
    }

    // We scale the mean by the number of elements. It brings more stability.
    // (This turns the per-thread running means into per-thread sums.)
#pragma unroll
    for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
      dbias[i] *= count;
      dscale[i] *= count;
    }

    // dscale parallel sum
    ParallelSums::dispatch(smem, dscale, thread_in_cta_nhw);
    __syncthreads();
    // The values in shared memory correspond to the CTA-wide sums.
    read_from_smem(dscale, smem, thread_in_cta_c);
    __syncthreads();

    // dbias parallel sum
    ParallelSums::dispatch(smem, dbias, thread_in_cta_nhw);
    __syncthreads();
    // The values in shared memory correspond to the CTA-wide sums.
    read_from_smem(dbias, smem, thread_in_cta_c);
    __syncthreads();

    // The workspace in global memory is distributed across the different CTA.
    // Per c-block layout: [dscale partials of all CTAs along x][dbias partials of all CTAs along x].
    int gmem_sums_offset = c_blk_index * gridDim.x * C_ELEMENTS_PER_CTA * 2;
    // Write the data for the CTA to global memory.
    float* gmem_sums = &params.gmem_sums[gmem_sums_offset];
    if (threadIdx.x < THREADS_PER_PIXEL) {
      const int idx = blockIdx.x * THREADS_PER_PIXEL + threadIdx.x;
      write_to_gmem(&gmem_sums[0], idx, dscale);
      write_to_gmem(&gmem_sums[C_ELEMENTS_PER_CTA * gridDim.x], idx, dbias);
    }

    // The counters to count how many CTAs have retired at this point.
    // A given cta uses the same counter every other time through the outer loop.
    int* gmem_retired_ctas = &params.gmem_retired_ctas[c_blk_index % (2 * gridDim.y)];
    inter_block_sync(gmem_retired_ctas, gridDim.x, blockIdx.x == 0);

    // Reset the accumulators for global summation
    zero_array(dscale);
    zero_array(dbias);

    // Build the global accumulation
#pragma unroll 1
    for (int idx = threadIdx.x; idx < THREADS_PER_PIXEL * gridDim.x; idx += THREADS_PER_CTA) {
      float tmp1[ELEMENTS_PER_LDG], tmp2[ELEMENTS_PER_LDG];
      read_from_gmem(tmp1, gmem_sums, idx);
      read_from_gmem(tmp2, gmem_sums + C_ELEMENTS_PER_CTA * gridDim.x, idx);

#pragma unroll
      for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
        dscale[i] += tmp1[i];
        dbias[i] += tmp2[i];
      }
    }

    // dscale parallel sum
    if (params.sync_iters > 0) {
      ParallelSums::dispatchX(
          smem, dscale, thread_in_cta_nhw, params.my_data, params.pair_datas, 4 * c_blk_index + 1, params.magic,
          params.sync_iters);
    } else {
      ParallelSums::dispatch(smem, dscale, thread_in_cta_nhw);
    }

    __syncthreads();
    // The values in shared memory correspond to the CTA-wide sums.
    read_from_smem(dscale, smem, thread_in_cta_c);
    __syncthreads();

    // dbias parallel sum
    if (params.sync_iters > 0) {
      ParallelSums::dispatchX(
          smem, dbias, thread_in_cta_nhw, params.my_data, params.pair_datas, 4 * c_blk_index + 0, params.magic,
          params.sync_iters);
    } else {
      ParallelSums::dispatch(smem, dbias, thread_in_cta_nhw);
    }

    __syncthreads();
    // The values in shared memory correspond to the CTA-wide sums.
    read_from_smem(dbias, smem, thread_in_cta_c);

    // inv-var
    float var[ELEMENTS_PER_LDG];
    zero_array(var);
    if (is_valid_c) {
      read_from_gmem(var, params.gmem_saved_var, thread_c / ELEMENTS_PER_LDG);
    }

    // Normalize the dscale.
    multiply(dscale, var);

    // store dscale/dbias (one writer per channel group: CTA x == 0, first pixel row)
    bool is_valid_for_saving = is_valid_c && blockIdx.x == 0 && thread_in_cta_nhw == 0;
    if (is_valid_for_saving) {
      if (params.sync_iters > 0) {
        scaled_write_to_gmem(params.gmem_dscale, thread_c / ELEMENTS_PER_LDG, dscale, params.wgrad_coeff);
        scaled_write_to_gmem(params.gmem_dbias, thread_c / ELEMENTS_PER_LDG, dbias, params.wgrad_coeff);
      } else {
        write_to_gmem(params.gmem_dscale, thread_c / ELEMENTS_PER_LDG, dscale);
        write_to_gmem(params.gmem_dbias, thread_c / ELEMENTS_PER_LDG, dbias);
      }
    }

    // scale
    float scale[ELEMENTS_PER_LDG];
    zero_array(scale);
    if (is_valid_c) {
      read_from_gmem(scale, params.gmem_scale, thread_c / ELEMENTS_PER_LDG);
    }

    // Further normalize the dscale to be used in dx calculation
    multiply(dscale, var);
    // scale the inv-var as well, afterwards
    multiply(var, scale);

    // inverse count
    float inv_count = params.svar_inv_count;

    // The base pointer to write to.
    uint16_t* const gmem_dst = &params.gmem_dst[thread_c];

    // Store the elements in registers.
    // Iterate backwards so the last phase-1 iteration, still held in registers, is consumed first.
#pragma unroll 1
    for (int loop_i = OUTER_LOOPS - 1; loop_i >= 0; --loop_i) {
      // The value for nhw.
      int out_nhw = cta_nhw_regs + loop_i * pixels_per_iteration;

      // Normalize the elements and write to memory.
#pragma unroll
      for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
        // Convert to float.
        float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];
        to_float(x_math, x_storage[i]);
        to_float(dy_math, dy_storage[i]);

        float dx[ELEMENTS_PER_LDG];
        bwd_dx(dx, dy_math, var, x_math, mean, dscale, dbias, inv_count);

        // Write back.
        const int idx = out_nhw + thread_in_cta_nhw + i * PIXELS_PER_LDG;
        if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {
          stg_stream(&gmem_dst[idx * params.c], dx);
        }
      }

      // The next value of nhw.
      out_nhw -= pixels_per_iteration;

      // Read the next elements from memory.
#pragma unroll
      for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
        const int idx = out_nhw + thread_in_cta_nhw + i * PIXELS_PER_LDG;
        if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {
          ldg_stream(x_storage[i], &gmem_src[idx * params.c]);
          ldg_stream(dy_storage[i], &gmem_dy[idx * params.c]);
        }
      }
    }

    // Normalize the elements from SMEM and write them out.
    if (pixels_in_smem > 0) {
      for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) {
        const int idx = smem_nhw + thread_in_cta_nhw + i * PIXELS_PER_LDG;
        const bool is_valid = ((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c;
        if (is_valid) {
          // Read from SMEM.
          int offset = i * THREADS_PER_CTA * PACKED_ELEMENTS_PER_LDG;
          PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG], dy_storage_local[PACKED_ELEMENTS_PER_LDG];
          read_from_smem(x_storage_local, &smem_storage_packed[offset], threadIdx.x);
          offset += PIXELS_PER_THREAD_IN_SMEM * THREADS_PER_CTA * PACKED_ELEMENTS_PER_LDG;
          read_from_smem(dy_storage_local, &smem_storage_packed[offset], threadIdx.x);
          float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];
          to_float(x_math, x_storage_local);
          to_float(dy_math, dy_storage_local);

          float dx[ELEMENTS_PER_LDG];
          bwd_dx(dx, dy_math, var, x_math, mean, dscale, dbias, inv_count);

          // Write back.
          stg_stream(&gmem_dst[idx * params.c], dx);
        }
      }
    }
    // We're about to start on the next c-blk. Needed?
    __syncthreads();
  }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Backward pass of NHWC batch norm fused with a ReLU on its output. Same structure as
// nhwc_batch_norm_bwd, but dy is first zeroed wherever the recomputed pre-activation value
// x * (saved_var * scale) + (bias - mean * saved_var * scale) is <= 0.
// NOTE(review): the kernel's template parameter list appears to be stripped in this copy — restore
// it from upstream.
template __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY) void nhwc_batch_norm_bwd_relu(NhwcBatchNormBwdParams params) {
  // The number of pixels loaded in a single LDG.
  const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL;
  // The number of pixels computed per CTA stored in registers.
// nhwc_batch_norm_bwd_relu (continued).
  const int PIXELS_PER_CTA_IN_REGISTERS = PIXELS_PER_THREAD_IN_REGISTERS * PIXELS_PER_LDG;
  // The number of pixels computed per CTA stored in SMEM.
  const int PIXELS_PER_CTA_IN_SMEM = PIXELS_PER_THREAD_IN_SMEM * PIXELS_PER_LDG;
  // The number of C elements per CTA.
  const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL * ELEMENTS_PER_LDG;

  // Shared memory to do CTA-wide parallel sums.
  __shared__ float smem[THREADS_PER_PIXEL * (THREADS_PER_CTA / 32) * ELEMENTS_PER_LDG];

  // The adapter for the storage.
  // NOTE(review): PackedStorage's template arguments appear to be stripped in this copy.
  typedef PackedStorage PackedStorage_;
  // The data type for packed storage in SMEM.
  typedef typename PackedStorage_::Type PackedStorageType;
  // The number of elements in the packed storage.
  const int PACKED_ELEMENTS_PER_LDG = PackedStorage_::PACKED_ELEMENTS_PER_LDG;

  // Registers to keep the data live for the persistent approach.
  PackedStorageType x_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG];
  PackedStorageType dy_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG];

  // Shared memory buffer to store the extra pixels.
  extern __shared__ PackedStorageType smem_storage_packed[];

  // Channel blocks are distributed over the grid's y dimension.
  for (int c_blk_index = blockIdx.y; c_blk_index < params.c_blks; c_blk_index += gridDim.y) {
    // The position in the NHW dimension where the CTA starts.
    int cta_nhw_regs = blockIdx.x * PIXELS_PER_CTA_IN_REGISTERS;
    // The position in the NHW dimension where the CTA starts for the portion in SMEM.
    int cta_nhw_smem = blockIdx.x * PIXELS_PER_CTA_IN_SMEM;
    // Compute the NHW coordinate of the thread in the CTA.
    const int thread_in_cta_nhw = threadIdx.x / THREADS_PER_PIXEL;

    // The position in the C dimension where the CTA starts.
    const int cta_c = c_blk_index * C_ELEMENTS_PER_CTA;
    // Compute the C coordinate of the thread in the CTA.
    const int thread_in_cta_c = threadIdx.x % THREADS_PER_PIXEL;
    // Compute the C coordinate of the thread.
    const int thread_c = cta_c + thread_in_cta_c * ELEMENTS_PER_LDG;

    // Is the thread working on a valid C dimension?
    const int is_valid_c = thread_c < params.c;

    // Registers to store the mean/var/scale/bias used for the entire duration
    // Register usage optimizations:
    // 1. Can combine bias - (mean * var * scale) into a single register
    // 2. Can combine var * scale into a single register
    float varscale[ELEMENTS_PER_LDG];
    zero_array(varscale);
    if (is_valid_c) {
      read_from_gmem(varscale, params.gmem_saved_var, thread_c / ELEMENTS_PER_LDG);
    }
    float tmp[ELEMENTS_PER_LDG];
    zero_array(tmp);
    if (is_valid_c) {
      read_from_gmem(tmp, params.gmem_scale, thread_c / ELEMENTS_PER_LDG);
    }
    multiply(varscale, tmp);
    float mean[ELEMENTS_PER_LDG];
    zero_array(mean);
    if (is_valid_c) {
      read_from_gmem(mean, params.gmem_saved_mean, thread_c / ELEMENTS_PER_LDG);
    }
    zero_array(tmp);
    if (is_valid_c) {
      read_from_gmem(tmp, params.gmem_bias, thread_c / ELEMENTS_PER_LDG);
    }
    // x * varscale + mean_var_scale_bias == (x - mean) * saved_var * scale + bias; relu_bwd zeros
    // dy wherever that value is <= 0.
    float mean_var_scale_bias[ELEMENTS_PER_LDG];
#pragma unroll
    for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
      mean_var_scale_bias[i] = tmp[i] - (mean[i] * varscale[i]);
    }

    // accumulation related registers: `count` is the number of valid pixels accumulated by this
    // thread; dscale/dbias hold running means maintained by bwd_update.
    float count = 0.f, dscale[ELEMENTS_PER_LDG], dbias[ELEMENTS_PER_LDG];
    zero_array(dscale);
    zero_array(dbias);

    // The base pointers to load from.
    const uint16_t* gmem_src = &params.gmem_src[thread_c];
    const uint16_t* gmem_dy = &params.gmem_dy[thread_c];

    // outer loops
    int OUTER_LOOPS = OUTER_LOOPS_ == 1 ? 1 : params.outer_loops;
    // Load the batch of elements. Compute sum across them
    const int pixels_per_iteration = PIXELS_PER_CTA_IN_REGISTERS * gridDim.x;

    if (OUTER_LOOPS_ != 1) {
      // We cannot load everything to store persistently, so let's make sure registers and
      // smem are fully utilized: shift the start so that the register iterations followed by the
      // SMEM portion end exactly at params.nhw. Leading indices may become negative; the unsigned
      // bounds checks below reject them.
      int offset = params.nhw - pixels_per_iteration * OUTER_LOOPS - PIXELS_PER_CTA_IN_SMEM * gridDim.x;
      cta_nhw_regs += offset;
      cta_nhw_smem += offset;
    }

#pragma unroll 1
    for (int loop_i = 0; loop_i < OUTER_LOOPS; ++loop_i) {
      // The nhw position.
      int nhw_regs = cta_nhw_regs + loop_i * pixels_per_iteration;

      // Read the elements from memory.
      float is_valid[PIXELS_PER_THREAD_IN_REGISTERS];
#pragma unroll
      for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
        const int idx = nhw_regs + thread_in_cta_nhw + i * PIXELS_PER_LDG;
        zero_array(x_storage[i]);
        zero_array(dy_storage[i]);
        is_valid[i] = 0.f;
        if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {
          // The last iteration's values remain in x_storage/dy_storage and are reused first by
          // the dx pass below; earlier iterations are re-read there.
          if (loop_i == OUTER_LOOPS - 1) {
            ldg_stream(x_storage[i], &gmem_src[idx * params.c]);
            ldg_stream(dy_storage[i], &gmem_dy[idx * params.c]);
          } else {
            ldg(x_storage[i], &gmem_src[idx * params.c]);
            ldg(dy_storage[i], &gmem_dy[idx * params.c]);
          }
          is_valid[i] = 1.f;
        }
      }

      // Do the math.
#pragma unroll
      for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
        // Convert to float and update
        float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];
        to_float(x_math, x_storage[i]);
        to_float(dy_math, dy_storage[i]);
        // Update the count.
        count += is_valid[i];
        // Invert the count (0 leaves the running means untouched for invalid pixels).
        float inv_count = is_valid[i] ? 1.f / count : 0.f;

        relu_bwd(dy_math, x_math, mean_var_scale_bias, varscale, is_valid[i]);
        bwd_update(dscale, dbias, dy_math, x_math, mean, inv_count);
      }
    }

    // The elements to load and store in SMEM.
    int smem_nhw = OUTER_LOOPS * pixels_per_iteration + cta_nhw_smem;
    // Load the remaining pixels, stage them in SMEM for the dx pass, and accumulate them.
    int pixels_in_smem = min(PIXELS_PER_CTA_IN_SMEM, params.nhw - smem_nhw);
    if (pixels_in_smem > 0) {
      for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) {
        const int idx = smem_nhw + thread_in_cta_nhw + i * PIXELS_PER_LDG;
        bool is_pixel_valid = (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c);

        PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG], dy_storage_local[PACKED_ELEMENTS_PER_LDG];
        zero_array(x_storage_local);
        zero_array(dy_storage_local);
        if (is_pixel_valid) {
          ldg_stream(x_storage_local, &gmem_src[idx * params.c]);
          ldg_stream(dy_storage_local, &gmem_dy[idx * params.c]);
        }

        // The offset to store in SMEM (x first, dy in the second half of the buffer).
        int offset = i * THREADS_PER_CTA * PACKED_ELEMENTS_PER_LDG;
        // Store in SMEM.
        write_to_smem(&smem_storage_packed[offset], threadIdx.x, x_storage_local);
        offset += PIXELS_PER_THREAD_IN_SMEM * THREADS_PER_CTA * PACKED_ELEMENTS_PER_LDG;
        write_to_smem(&smem_storage_packed[offset], threadIdx.x, dy_storage_local);
        // Update the count.
        count += is_pixel_valid;
        // Invert the count.
        float inv_count = is_pixel_valid ? 1.f / count : 0.f;

        float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];
        to_float(x_math, x_storage_local);
        to_float(dy_math, dy_storage_local);

        relu_bwd(dy_math, x_math, mean_var_scale_bias, varscale, is_pixel_valid);
        bwd_update(dscale, dbias, dy_math, x_math, mean, inv_count);
      }
    }

    // We scale the mean by the number of elements. It brings more stability.
    // (This turns the per-thread running means into per-thread sums.)
#pragma unroll
    for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
      dbias[i] *= count;
      dscale[i] *= count;
    }

    // dscale parallel sum
    ParallelSums::dispatch(smem, dscale, thread_in_cta_nhw);
    __syncthreads();
    // The values in shared memory correspond to the CTA-wide sums.
    read_from_smem(dscale, smem, thread_in_cta_c);
    __syncthreads();

    // dbias parallel sum
    ParallelSums::dispatch(smem, dbias, thread_in_cta_nhw);
    __syncthreads();
    // The values in shared memory correspond to the CTA-wide sums.
    read_from_smem(dbias, smem, thread_in_cta_c);
    __syncthreads();

    // The workspace in global memory is distributed across the different CTA.
    // Per c-block layout: [dscale partials of all CTAs along x][dbias partials of all CTAs along x].
    int gmem_sums_offset = c_blk_index * gridDim.x * C_ELEMENTS_PER_CTA * 2;
    // Write the data for the CTA to global memory.
    float* gmem_sums = &params.gmem_sums[gmem_sums_offset];
    if (threadIdx.x < THREADS_PER_PIXEL) {
      const int idx = blockIdx.x * THREADS_PER_PIXEL + threadIdx.x;
      write_to_gmem(&gmem_sums[0], idx, dscale);
      write_to_gmem(&gmem_sums[C_ELEMENTS_PER_CTA * gridDim.x], idx, dbias);
    }

    // The counters to count how many CTAs have retired at this point.
    // A given cta uses the same counter every other time through the outer loop.
    int* gmem_retired_ctas = &params.gmem_retired_ctas[c_blk_index % (2 * gridDim.y)];
    inter_block_sync(gmem_retired_ctas, gridDim.x, blockIdx.x == 0);

    // Reset the accumulators for global summation
    zero_array(dscale);
    zero_array(dbias);

    // Build the global accumulation
#pragma unroll 1
    for (int idx = threadIdx.x; idx < THREADS_PER_PIXEL * gridDim.x; idx += THREADS_PER_CTA) {
      float tmp1[ELEMENTS_PER_LDG], tmp2[ELEMENTS_PER_LDG];
      read_from_gmem(tmp1, gmem_sums, idx);
      read_from_gmem(tmp2, gmem_sums + C_ELEMENTS_PER_CTA * gridDim.x, idx);

#pragma unroll
      for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
        dscale[i] += tmp1[i];
        dbias[i] += tmp2[i];
      }
    }

    // dscale parallel sum
    if (params.sync_iters > 0) {
      ParallelSums::dispatchX(
          smem, dscale, thread_in_cta_nhw, params.my_data, params.pair_datas, 4 * c_blk_index + 1, params.magic,
          params.sync_iters);
    } else {
      ParallelSums::dispatch(smem, dscale, thread_in_cta_nhw);
    }

    __syncthreads();
    // The values in shared memory correspond to the CTA-wide sums.
    read_from_smem(dscale, smem, thread_in_cta_c);
    __syncthreads();

    // dbias parallel sum
    if (params.sync_iters > 0) {
      ParallelSums::dispatchX(
          smem, dbias, thread_in_cta_nhw, params.my_data, params.pair_datas, 4 * c_blk_index + 0, params.magic,
          params.sync_iters);
    } else {
      ParallelSums::dispatch(smem, dbias, thread_in_cta_nhw);
    }

    __syncthreads();
    // The values in shared memory correspond to the CTA-wide sums.
    read_from_smem(dbias, smem, thread_in_cta_c);

    // Normalize the dscale.
    float var[ELEMENTS_PER_LDG];
    zero_array(var);
    if (is_valid_c) {
      read_from_gmem(var, params.gmem_saved_var, thread_c / ELEMENTS_PER_LDG);
    }
    multiply(dscale, var);

    // store dscale/dbias (one writer per channel group: CTA x == 0, first pixel row)
    bool is_valid_for_saving = is_valid_c && blockIdx.x == 0 && thread_in_cta_nhw == 0;
    if (is_valid_for_saving) {
      if (params.sync_iters > 0) {
        scaled_write_to_gmem(params.gmem_dscale, thread_c / ELEMENTS_PER_LDG, dscale, params.wgrad_coeff);
        scaled_write_to_gmem(params.gmem_dbias, thread_c / ELEMENTS_PER_LDG, dbias, params.wgrad_coeff);
      } else {
        write_to_gmem(params.gmem_dscale, thread_c / ELEMENTS_PER_LDG, dscale);
        write_to_gmem(params.gmem_dbias, thread_c / ELEMENTS_PER_LDG, dbias);
      }
    }

    // Further normalize the dscale to be used in dx calculation
    float scale[ELEMENTS_PER_LDG];
    zero_array(scale);
    if (is_valid_c) {
      read_from_gmem(scale, params.gmem_scale, thread_c / ELEMENTS_PER_LDG);
    }
    multiply(dscale, var);
    // scale the inv-var as well, afterwards (var now equals varscale, so relu_bwd_for_dx below
    // recomputes the same pre-activation value as relu_bwd did in the accumulation pass)
    multiply(var, scale);

    // inverse count
    float inv_count = params.svar_inv_count;

    // The base pointer to write to.
    uint16_t* const gmem_dst = &params.gmem_dst[thread_c];

    // Store the elements in registers.
    // Iterate backwards so the last accumulation iteration, still held in registers, is consumed first.
#pragma unroll 1
    for (int loop_i = OUTER_LOOPS - 1; loop_i >= 0; --loop_i) {
      // The value for nhw.
      int out_nhw = cta_nhw_regs + loop_i * pixels_per_iteration;

      // Normalize the elements and write to memory.
#pragma unroll
      for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
        // Convert to float.
        float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];
        to_float(x_math, x_storage[i]);
        to_float(dy_math, dy_storage[i]);
        relu_bwd_for_dx(dy_math, x_math, mean_var_scale_bias, var);

        float dx[ELEMENTS_PER_LDG];
        bwd_dx(dx, dy_math, var, x_math, mean, dscale, dbias, inv_count);

        // Write back.
        const int idx = out_nhw + thread_in_cta_nhw + i * PIXELS_PER_LDG;
        if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {
          stg_stream(&gmem_dst[idx * params.c], dx);
        }
      }

      // The next value of nhw.
      out_nhw -= pixels_per_iteration;

      // Read the next elements from memory.
#pragma unroll
      for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
        const int idx = out_nhw + thread_in_cta_nhw + i * PIXELS_PER_LDG;
        if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {
          ldg_stream(x_storage[i], &gmem_src[idx * params.c]);
          ldg_stream(dy_storage[i], &gmem_dy[idx * params.c]);
        }
      }
    }

    // Normalize the elements from SMEM and write them out.
    if (pixels_in_smem > 0) {
      for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) {
        const int idx = smem_nhw + thread_in_cta_nhw + i * PIXELS_PER_LDG;
        const bool is_valid = ((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c;
        if (is_valid) {
          // Read from SMEM.
          int offset = i * THREADS_PER_CTA * PACKED_ELEMENTS_PER_LDG;
          PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG], dy_storage_local[PACKED_ELEMENTS_PER_LDG];
          read_from_smem(x_storage_local, &smem_storage_packed[offset], threadIdx.x);
          offset += PIXELS_PER_THREAD_IN_SMEM * THREADS_PER_CTA * PACKED_ELEMENTS_PER_LDG;
          read_from_smem(dy_storage_local, &smem_storage_packed[offset], threadIdx.x);
          float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];
          to_float(x_math, x_storage_local);
          to_float(dy_math, dy_storage_local);
          relu_bwd_for_dx(dy_math, x_math, mean_var_scale_bias, var);

          float dx[ELEMENTS_PER_LDG];
          bwd_dx(dx, dy_math, var, x_math, mean, dscale, dbias, inv_count);

          // Write back.
          stg_stream(&gmem_dst[idx * params.c], dx);
        }
      }
    }
    // We're about to start on the next c-blk. Needed?
    __syncthreads();
  }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY) void nhwc_batch_norm_bwd_add_relu(NhwcBatchNormBwdParams params) {
  // The number of pixels loaded in a single LDG.
  const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL;
  // The number of pixels computed per CTA stored in registers.
  const int PIXELS_PER_CTA_IN_REGISTERS = PIXELS_PER_THREAD_IN_REGISTERS * PIXELS_PER_LDG;
  // The number of pixels computed per CTA stored in SMEM.
  const int PIXELS_PER_CTA_IN_SMEM = PIXELS_PER_THREAD_IN_SMEM * PIXELS_PER_LDG;
  // The number of C elements per CTA.
  const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL * ELEMENTS_PER_LDG;

  // Shared memory to do CTA-wide parallel sums.
  __shared__ float smem[THREADS_PER_PIXEL * (THREADS_PER_CTA / 32) * ELEMENTS_PER_LDG];

  // The adapter for the storage.
  typedef PackedStorage PackedStorage_;
  // The data type for packed storage in SMEM.
  typedef typename PackedStorage_::Type PackedStorageType;
  // The number of elements in the packed storage.
  const int PACKED_ELEMENTS_PER_LDG = PackedStorage_::PACKED_ELEMENTS_PER_LDG;

  // Registers to keep the data live for the persistent approach.
  PackedStorageType x_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG];
  PackedStorageType dy_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG];

  // Shared memory buffer to store the extra pixels.
  extern __shared__ PackedStorageType smem_storage_packed[];

  for (int c_blk_index = blockIdx.y; c_blk_index < params.c_blks; c_blk_index += gridDim.y) {
    // The position in the NHW dimension where the CTA starts.
    int cta_nhw_regs = blockIdx.x * PIXELS_PER_CTA_IN_REGISTERS;
    // The position in the NHW dimension where the CTA starts for the portion in SMEM.
int cta_nhw_smem = blockIdx.x * PIXELS_PER_CTA_IN_SMEM; // Compute the NHW coordinate of the thread in the CTA. const int thread_in_cta_nhw = threadIdx.x / THREADS_PER_PIXEL; // The position in the C dimension where the CTA starts. const int cta_c = c_blk_index * C_ELEMENTS_PER_CTA; // Compute the C coordinate of the thread in the CTA. const int thread_in_cta_c = threadIdx.x % THREADS_PER_PIXEL; // Compute the C coordinate of the thread. const int thread_c = cta_c + thread_in_cta_c * ELEMENTS_PER_LDG; // Is the thread working on a valid C dimension? const int is_valid_c = thread_c < params.c; float mean[ELEMENTS_PER_LDG]; zero_array(mean); if (is_valid_c) { read_from_gmem(mean, params.gmem_saved_mean, thread_c / ELEMENTS_PER_LDG); } // accumulation related registers float count = 0.f, dscale[ELEMENTS_PER_LDG], dbias[ELEMENTS_PER_LDG]; zero_array(dscale); zero_array(dbias); // The number of elements loaded by this CTA. int cta_count = 0; // The base pointers to load from. const uint16_t* gmem_src = ¶ms.gmem_src[thread_c]; const uint16_t* gmem_dy = ¶ms.gmem_dy[thread_c]; uint16_t* gmem_dst1 = ¶ms.gmem_dst1[thread_c]; // outer loops int OUTER_LOOPS = OUTER_LOOPS_ == 1 ? 1 : params.outer_loops; // Load the batch of elements. Compute sum across them const int pixels_per_iteration = PIXELS_PER_CTA_IN_REGISTERS * gridDim.x; if (OUTER_LOOPS_ != 1) { // We cannot load everything to store persistently, so let's makes sure registers and // smem are fully utilized, offset is evenly divisible by 32 int offset = (pixels_per_iteration * OUTER_LOOPS + PIXELS_PER_CTA_IN_SMEM * gridDim.x - params.nhw) & ~31; cta_nhw_regs -= offset; cta_nhw_smem -= offset; } const unsigned int* const gmem_relu_bitmask = params.gmem_relu_bitmask + ((params.nhw + 31) & ~31) * 2 * c_blk_index; #pragma unroll 1 for (int loop_i = 0; loop_i < OUTER_LOOPS; ++loop_i) { // The nhw position. int nhw_regs = cta_nhw_regs + loop_i * pixels_per_iteration; // Update the number of elements loaded by this CTA. 
TODO: Skip if <= 0!!! cta_count += max(0, min(PIXELS_PER_CTA_IN_REGISTERS, params.nhw - nhw_regs)); int lane_id = threadIdx.x & 31; // Read the elements from memory. float is_valid[PIXELS_PER_THREAD_IN_REGISTERS]; unsigned int relu_mask[PIXELS_PER_THREAD_IN_REGISTERS]; #pragma unroll for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) { const int idx = nhw_regs + thread_in_cta_nhw + i * PIXELS_PER_LDG; zero_array(x_storage[i]); zero_array(dy_storage[i]); is_valid[i] = 0.f; const bool is_valid_nhw = static_cast(idx) < static_cast(params.nhw); if (is_valid_nhw) { if (is_valid_c) { if (loop_i == OUTER_LOOPS - 1) { ldg_stream(x_storage[i], &gmem_src[idx * params.c]); ldg_stream(dy_storage[i], &gmem_dy[idx * params.c]); } else { ldg(x_storage[i], &gmem_src[idx * params.c]); ldg(dy_storage[i], &gmem_dy[idx * params.c]); } is_valid[i] = 1.f; } if (lane_id < ELEMENTS_PER_LDG) { relu_mask[i] = gmem_relu_bitmask[idx * 2 + lane_id]; } } } // Do the math. #pragma unroll for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) { const int idx = nhw_regs + thread_in_cta_nhw + i * PIXELS_PER_LDG; // Convert to float and update float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG]; bool rectified[ELEMENTS_PER_LDG]; #pragma unroll for (int j = 0; j < ELEMENTS_PER_LDG; ++j) { rectified[j] = ((__shfl_sync(0xFFFFFFFFU, relu_mask[i], j) & (1U << lane_id)) != 0); } to_float(x_math, x_storage[i]); to_float(dy_math, dy_storage[i]); // Update the count. count += is_valid[i]; // Invert the count. float inv_count = is_valid[i] ? 
1.f / count : 0.f; relu_bwd(dy_math, rectified, is_valid[i]); bwd_update(dscale, dbias, dy_math, x_math, mean, inv_count); // Lastly we need 'dy' only for BN, so store the 'relu-dgrad'ed version from_float(dy_storage[i], dy_math); // dZ for elementwise add if (is_valid[i]) { if (loop_i == OUTER_LOOPS - 1) { stg_stream(&gmem_dst1[idx * params.c], dy_storage[i]); } else { stg(&gmem_dst1[idx * params.c], dy_storage[i]); } } } } // The elements to load and store in SMEM. int smem_nhw = OUTER_LOOPS * pixels_per_iteration + cta_nhw_smem; // Load elements from SMEM, update the CTA count. int pixels_in_smem = min(PIXELS_PER_CTA_IN_SMEM, params.nhw - smem_nhw); if (pixels_in_smem > 0) { cta_count += pixels_in_smem; for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) { const int idx = smem_nhw + thread_in_cta_nhw + i * PIXELS_PER_LDG; const bool is_pixel_valid_nhw = static_cast(idx) < static_cast(params.nhw); const bool is_pixel_valid = is_pixel_valid_nhw && is_valid_c; PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG], dy_storage_local[PACKED_ELEMENTS_PER_LDG]; unsigned int relu_mask; int lane_id = threadIdx.x & 31; zero_array(x_storage_local); zero_array(dy_storage_local); if (is_pixel_valid_nhw) { if (is_valid_c) { ldg_stream(x_storage_local, &gmem_src[idx * params.c]); ldg_stream(dy_storage_local, &gmem_dy[idx * params.c]); } if (lane_id < ELEMENTS_PER_LDG) { relu_mask = gmem_relu_bitmask[idx * 2 + lane_id]; } } bool rectified[ELEMENTS_PER_LDG]; #pragma unroll for (int j = 0; j < ELEMENTS_PER_LDG; ++j) { rectified[j] = ((__shfl_sync(0xFFFFFFFFU, relu_mask, j) & (1U << lane_id)) != 0); } // The offset to store in SMEM. int offset = i * THREADS_PER_CTA * PACKED_ELEMENTS_PER_LDG; // Store in SMEM. write_to_smem(&smem_storage_packed[offset], threadIdx.x, x_storage_local); offset += PIXELS_PER_THREAD_IN_SMEM * THREADS_PER_CTA * PACKED_ELEMENTS_PER_LDG; // Update the count. count += is_pixel_valid; // Invert the count. float inv_count = is_pixel_valid ? 
1.f / count : 0.f; float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG]; to_float(x_math, x_storage_local); to_float(dy_math, dy_storage_local); relu_bwd(dy_math, rectified, is_pixel_valid); bwd_update(dscale, dbias, dy_math, x_math, mean, inv_count); from_float(dy_storage_local, dy_math); // dZ for elementwise add if (is_pixel_valid) { stg_stream(&gmem_dst1[idx * params.c], dy_storage_local); } // only store the 'relu-dgrad'ed version! write_to_smem(&smem_storage_packed[offset], threadIdx.x, dy_storage_local); } } // We scale the mean by the number of elements. It brings more stability. #pragma unroll for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { dbias[i] *= count; dscale[i] *= count; } // dscale parallel sum ParallelSums::dispatch(smem, dscale, thread_in_cta_nhw); __syncthreads(); // The values in shared memory correspond to the CTA-wide sums. read_from_smem(dscale, smem, thread_in_cta_c); __syncthreads(); // dbias parallel sum ParallelSums::dispatch(smem, dbias, thread_in_cta_nhw); __syncthreads(); // The values in shared memory correspond to the CTA-wide sums. read_from_smem(dbias, smem, thread_in_cta_c); __syncthreads(); // The workspace in global memory is distributed across the different CTA. int gmem_sums_offset = c_blk_index * gridDim.x * C_ELEMENTS_PER_CTA * 2; // Write the data for the CTA to global memory. float* gmem_sums = ¶ms.gmem_sums[gmem_sums_offset]; if (threadIdx.x < THREADS_PER_PIXEL) { const int idx = blockIdx.x * THREADS_PER_PIXEL + threadIdx.x; write_to_gmem(&gmem_sums[0], idx, dscale); write_to_gmem(&gmem_sums[C_ELEMENTS_PER_CTA * gridDim.x], idx, dbias); } // The counters to count how many CTAs have retired at this point. // A given cta uses the same counter every other time through the outer loop. 
int* gmem_retired_ctas = ¶ms.gmem_retired_ctas[c_blk_index % (2 * gridDim.y)]; inter_block_sync(gmem_retired_ctas, gridDim.x, blockIdx.x == 0); // Reset the accumulators for global summation zero_array(dscale); zero_array(dbias); // Build the global accumulation #pragma unroll 1 for (int idx = threadIdx.x; idx < THREADS_PER_PIXEL * gridDim.x; idx += THREADS_PER_CTA) { float tmp1[ELEMENTS_PER_LDG], tmp2[ELEMENTS_PER_LDG]; read_from_gmem(tmp1, gmem_sums, idx); read_from_gmem(tmp2, gmem_sums + C_ELEMENTS_PER_CTA * gridDim.x, idx); #pragma unroll for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { dscale[i] += tmp1[i]; dbias[i] += tmp2[i]; } } // dscale parallel sum if (params.sync_iters > 0) { ParallelSums::dispatchX( smem, dscale, thread_in_cta_nhw, params.my_data, params.pair_datas, 4 * c_blk_index + 1, params.magic, params.sync_iters); } else { ParallelSums::dispatch(smem, dscale, thread_in_cta_nhw); } __syncthreads(); // The values in shared memory correspond to the CTA-wide sums. read_from_smem(dscale, smem, thread_in_cta_c); __syncthreads(); // dbias parallel sum if (params.sync_iters > 0) { ParallelSums::dispatchX( smem, dbias, thread_in_cta_nhw, params.my_data, params.pair_datas, 4 * c_blk_index + 0, params.magic, params.sync_iters); } else { ParallelSums::dispatch(smem, dbias, thread_in_cta_nhw); } __syncthreads(); // The values in shared memory correspond to the CTA-wide sums. read_from_smem(dbias, smem, thread_in_cta_c); // Normalize the dscale. 
float var[ELEMENTS_PER_LDG]; zero_array(var); if (is_valid_c) { read_from_gmem(var, params.gmem_saved_var, thread_c / ELEMENTS_PER_LDG); } multiply(dscale, var); // store dscale/dbias bool is_valid_for_saving = is_valid_c && blockIdx.x == 0 && thread_in_cta_nhw == 0; if (is_valid_for_saving) { if (params.sync_iters > 0) { scaled_write_to_gmem(params.gmem_dscale, thread_c / ELEMENTS_PER_LDG, dscale, params.wgrad_coeff); scaled_write_to_gmem(params.gmem_dbias, thread_c / ELEMENTS_PER_LDG, dbias, params.wgrad_coeff); } else { write_to_gmem(params.gmem_dscale, thread_c / ELEMENTS_PER_LDG, dscale); write_to_gmem(params.gmem_dbias, thread_c / ELEMENTS_PER_LDG, dbias); } } // Further normalize the dscale to be used in dx calculation float scale[ELEMENTS_PER_LDG]; zero_array(scale); if (is_valid_c) { read_from_gmem(scale, params.gmem_scale, thread_c / ELEMENTS_PER_LDG); } multiply(dscale, var); // scale the inv-var as well, afterwards multiply(var, scale); // inverse count float inv_count = params.svar_inv_count; // The base pointer to write to. uint16_t* const gmem_dst = ¶ms.gmem_dst[thread_c]; // Store the elements in registers. #pragma unroll 1 for (int loop_i = OUTER_LOOPS - 1; loop_i >= 0; --loop_i) { // The value for nhw. int out_nhw = cta_nhw_regs + loop_i * pixels_per_iteration; // Normalize the elements and write to memory. #pragma unroll for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) { const int idx = out_nhw + thread_in_cta_nhw + i * PIXELS_PER_LDG; const bool is_valid = ((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c; // Convert to float. float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG]; to_float(x_math, x_storage[i]); to_float(dy_math, dy_storage[i]); float dx[ELEMENTS_PER_LDG]; bwd_dx(dx, dy_math, var, x_math, mean, dscale, dbias, inv_count); // Write back. if (is_valid) { stg_stream(&gmem_dst[idx * params.c], dx); } } // The next value of nhw. out_nhw -= pixels_per_iteration; // Read the next elements from memory. 
#pragma unroll for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) { const int idx = out_nhw + thread_in_cta_nhw + i * PIXELS_PER_LDG; float y[ELEMENTS_PER_LDG]; zero_array(y); if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) { ldg_stream(x_storage[i], &gmem_src[idx * params.c]); ldg_stream(dy_storage[i], &gmem_dst1[idx * params.c]); } } } // Normalize the elements from SMEM and write them out. if (pixels_in_smem > 0) { for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) { const int idx = smem_nhw + thread_in_cta_nhw + i * PIXELS_PER_LDG; const bool is_valid = ((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c; if (is_valid) { // Read from SMEM. int offset = i * THREADS_PER_CTA * PACKED_ELEMENTS_PER_LDG; PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG], dy_storage_local[PACKED_ELEMENTS_PER_LDG]; read_from_smem(x_storage_local, &smem_storage_packed[offset], threadIdx.x); offset += PIXELS_PER_THREAD_IN_SMEM * THREADS_PER_CTA * PACKED_ELEMENTS_PER_LDG; read_from_smem(dy_storage_local, &smem_storage_packed[offset], threadIdx.x); float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG]; to_float(x_math, x_storage_local); to_float(dy_math, dy_storage_local); float dx[ELEMENTS_PER_LDG]; bwd_dx(dx, dy_math, var, x_math, mean, dscale, dbias, inv_count); // Write back. stg_stream(&gmem_dst[idx * params.c], dx); } } } // We're about to start on the next c-blk. Needed? 
__syncthreads();
  }
}

#endif  // MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_KERNEL_H_
================================================
FILE: apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda.cpp
================================================
#include
#include
#include

// Host-side CUDA launchers, defined in index_mul_2d_cuda_kernel.cu.
// NOTE(review): the "foward" spelling is part of the cross-file symbol names,
// so it cannot be corrected here alone.
void index_mul_2d_float_foward_cuda(at::Tensor& out, const at::Tensor& in1, const at::Tensor& in2,
                                    const at::Tensor& idx1);

void index_mul_2d_float_backward_cuda(at::Tensor& grad_in1, at::Tensor& grad_in2, const at::Tensor& grad_out,
                                      const at::Tensor& in1, const at::Tensor& in2, const at::Tensor& idx1);

void index_mul_2d_float_backward_backward_cuda(at::Tensor& grad_grad_out, at::Tensor& grad_in1,
                                               at::Tensor& grad_in2, const at::Tensor& grad_out,
                                               const at::Tensor& grad_grad_in1, const at::Tensor& grad_grad_in2,
                                               const at::Tensor& in1, const at::Tensor& in2,
                                               const at::Tensor& idx1);

void index_mul_2d_half_foward_cuda(at::Tensor& out, const at::Tensor& in1, const at::Tensor& in2,
                                   const at::Tensor& idx1);

void index_mul_2d_half_backward_cuda(at::Tensor& grad_in1, at::Tensor& grad_in2, const at::Tensor& grad_out,
                                     const at::Tensor& in1, const at::Tensor& in2, const at::Tensor& idx1);

void index_mul_2d_half_backward_backward_cuda(at::Tensor& grad_grad_out, at::Tensor& grad_in1,
                                              at::Tensor& grad_in2, const at::Tensor& grad_out,
                                              const at::Tensor& grad_grad_in1, const at::Tensor& grad_grad_in2,
                                              const at::Tensor& in1, const at::Tensor& in2,
                                              const at::Tensor& idx1);

// Input-validation helpers.
// NOTE(review): none of the wrappers below invoke these macros. Device and
// contiguity are therefore not checked here, although the kernels index the
// buffers as dense rows. Confirm that the Python caller guarantees this.
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
  CHECK_CUDA(x);       \
  CHECK_CONTIGUOUS(x)

// Thin forwarding wrappers exposed to Python below. Each one passes its
// arguments unchanged to the matching *_cuda launcher. Outputs are written
// in place into the non-const tensor arguments.

// fp32 forward: fills `out` from in1 / in2 / idx1.
void index_mul_2d_float_forward(at::Tensor& out, const at::Tensor& in1, const at::Tensor& in2,
                                const at::Tensor& idx1) {
  return index_mul_2d_float_foward_cuda(out, in1, in2, idx1);
}

// fp32 backward: fills grad_in1 / grad_in2 from grad_out.
void index_mul_2d_float_backward(at::Tensor& grad_in1, at::Tensor& grad_in2, const at::Tensor& grad_out,
                                 const at::Tensor& in1, const at::Tensor& in2,
                                 const at::Tensor& idx1) {
  return index_mul_2d_float_backward_cuda(grad_in1, grad_in2, grad_out, in1, in2, idx1);
}

// fp32 double backward.
// NOTE(review): "backwrad" is a typo in this C++ symbol only; the Python-facing
// name registered below is spelled correctly.
void index_mul_2d_float_backwrad_backward(at::Tensor& grad_grad_out, at::Tensor& grad_in1, at::Tensor& grad_in2,
                                          const at::Tensor& grad_out, const at::Tensor& grad_grad_in1,
                                          const at::Tensor& grad_grad_in2, const at::Tensor& in1,
                                          const at::Tensor& in2, const at::Tensor& idx1) {
  return index_mul_2d_float_backward_backward_cuda(grad_grad_out, grad_in1, grad_in2, grad_out, grad_grad_in1,
                                                   grad_grad_in2, in1, in2, idx1);
}

// fp16 forward.
void index_mul_2d_half_forward(at::Tensor& out, const at::Tensor& in1, const at::Tensor& in2,
                               const at::Tensor& idx1) {
  return index_mul_2d_half_foward_cuda(out, in1, in2, idx1);
}

// fp16 backward.
void index_mul_2d_half_backward(at::Tensor& grad_in1, at::Tensor& grad_in2, const at::Tensor& grad_out,
                                const at::Tensor& in1, const at::Tensor& in2, const at::Tensor& idx1) {
  return index_mul_2d_half_backward_cuda(grad_in1, grad_in2, grad_out, in1, in2, idx1);
}

// fp16 double backward (same "backwrad" symbol typo as the fp32 variant).
void index_mul_2d_half_backwrad_backward(at::Tensor& grad_grad_out, at::Tensor& grad_in1, at::Tensor& grad_in2,
                                         const at::Tensor& grad_out, const at::Tensor& grad_grad_in1,
                                         const at::Tensor& grad_grad_in2, const at::Tensor& in1,
                                         const at::Tensor& in2, const at::Tensor& idx1) {
  return index_mul_2d_half_backward_backward_cuda(grad_grad_out, grad_in1, grad_in2, grad_out, grad_grad_in1,
                                                  grad_grad_in2, in1, in2, idx1);
}

// Python module: the strings passed to m.def are the names callable from Python.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("float_forward", &index_mul_2d_float_forward, "index mul float calculation forward (CUDA)",
        py::call_guard());
  m.def("float_backward", &index_mul_2d_float_backward, "index mul float calculation backward (CUDA)",
        py::call_guard());
  m.def("float_backward_backward", &index_mul_2d_float_backwrad_backward,
        "index mul float calculation backward backward (CUDA)", py::call_guard());
  m.def("half_forward", &index_mul_2d_half_forward, "index mul half calculation forward (CUDA)",
        py::call_guard());
  m.def("half_backward", &index_mul_2d_half_backward,
"index mul half calculation backward (CUDA)", py::call_guard());
  m.def("half_backward_backward", &index_mul_2d_half_backwrad_backward,
        "index mul half calculation backward backward (CUDA)", py::call_guard());
}
================================================
FILE: apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda_kernel.cu
================================================
#include
#include
#include
#include

// index_mul_2d: row-gathered elementwise product and its first/second derivatives.
//
// For every row `row` of `in2`, the forward op is
//   out[row, :] = in1[idx1[row], :] * in2[row, :]
//
// Tensors addressed by `row` share in2's layout:
//   out, in2, grad_out, grad_in2, grad_grad_in2, grad_grad_out.
// Tensors addressed by idx1[row] share in1's layout:
//   in1, grad_in1, grad_grad_in1.
//
// All index arithmetic (`row * fea_dim + i`) assumes densely packed row-major
// buffers. The launchers below do not verify this.
// NOTE(review): confirm the callers pass contiguous tensors.
//
// Launch geometry: threadIdx.y selects the row inside a block, and
// threadIdx.x walks the feature dimension.

namespace {

// Block shape for the float4 fast path used when fea_dim == 64.
// 16 x-threads each own one float4 (16 * 4 = 64 features); 16 rows per block.
constexpr int kDim64ThreadsX = 16;
constexpr int kDim64ThreadsY = 16;

// Block shape for the generic path: 32 x-threads stride over the features,
// 8 rows per block.
constexpr int kGenericThreadsX = 32;
constexpr int kGenericThreadsY = 8;

// Row handled by the calling thread, computed in 64 bits.
// The dim64 kernels multiply it by 64. In 32-bit int that product overflowed
// once the row count exceeded 2^25.
__device__ __forceinline__ int64_t index_mul_2d_row() {
  return static_cast<int64_t>(blockIdx.x) * blockDim.y + threadIdx.y;
}

// Grid size that gives each of `rows` rows its own threadIdx.y lane.
inline int index_mul_2d_num_blocks(int64_t rows, int rows_per_block) {
  return static_cast<int>((rows + rows_per_block - 1) / rows_per_block);
}

}  // namespace

// ---------------------------------------------------------------------------------------------
// Forward kernels: out[row, :] = in1[idx1[row], :] * in2[row, :]
// ---------------------------------------------------------------------------------------------

// fp32, fea_dim == 64. Uses float4 loads/stores, so the base pointers must be
// 16-byte aligned (NOTE(review): not checked by the launcher).
__global__ void index_mul_2d_float_dim64(float* out, const float* in1, const float* in2,
                                         const int64_t* idx1, const int64_t size) {
  constexpr int64_t fea_dim = 64;
  const int64_t row = index_mul_2d_row();
  if (row < size) {
    // Offsets are in float4 units; threadIdx.x picks the 4-feature chunk of the row.
    const int64_t vec_idx1 = (idx1[row] * fea_dim) / 4 + threadIdx.x;
    const int64_t vec_idx2 = (row * fea_dim) / 4 + threadIdx.x;
    const float4 src1 = reinterpret_cast<const float4*>(in1)[vec_idx1];
    const float4 src2 = reinterpret_cast<const float4*>(in2)[vec_idx2];
    float4 res;
    res.x = src1.x * src2.x;
    res.y = src1.y * src2.y;
    res.z = src1.z * src2.z;
    res.w = src1.w * src2.w;
    reinterpret_cast<float4*>(out)[vec_idx2] = res;
  }
}

// fp32, any fea_dim.
__global__ void index_mul_2d_float(float* out, const float* in1, const float* in2, const int64_t* idx1,
                                   const int64_t size, const int64_t fea_dim) {
  const int64_t row = index_mul_2d_row();
  if (row < size) {
    const int64_t base1 = idx1[row] * fea_dim;
    const int64_t base2 = row * fea_dim;
    for (int64_t i = threadIdx.x; i < fea_dim; i += blockDim.x) {
      out[base2 + i] = in1[base1 + i] * in2[base2 + i];
    }
  }
}

// fp16, any fea_dim. Multiplies in fp32 and rounds the product once to half.
__global__ void index_mul_2d_half(at::Half* out, const at::Half* in1, const at::Half* in2, const int64_t* idx1,
                                  const int64_t size, const int64_t fea_dim) {
  const int64_t row = index_mul_2d_row();
  if (row < size) {
    const int64_t base1 = idx1[row] * fea_dim;
    const int64_t base2 = row * fea_dim;
    for (int64_t i = threadIdx.x; i < fea_dim; i += blockDim.x) {
      out[base2 + i] = at::Half(static_cast<float>(in1[base1 + i]) * static_cast<float>(in2[base2 + i]));
    }
  }
}

// ---------------------------------------------------------------------------------------------
// Backward kernels:
//   grad_in2[row, :]       = grad_out[row, :] * in1[idx1[row], :]
//   grad_in1[idx1[row], :] += grad_out[row, :] * in2[row, :]
// The grad_in1 update is atomic because several rows may share one idx1 entry.
// The kernels only add into grad_in1, so the caller presumably zero-fills it first (TODO confirm).
// ---------------------------------------------------------------------------------------------

// fp32, fea_dim == 64 (float4 path; same alignment requirement as the forward).
__global__ void index_mul_2d_grad_float_dim64(float* grad_in1, float* grad_in2, const float* grad_out,
                                              const float* in1, const float* in2, const int64_t* idx1,
                                              const int64_t size) {
  constexpr int64_t fea_dim = 64;
  const int64_t row = index_mul_2d_row();
  if (row < size) {
    const int64_t vec_idx1 = (idx1[row] * fea_dim) / 4 + threadIdx.x;
    const int64_t vec_idx2 = (row * fea_dim) / 4 + threadIdx.x;
    const float4 src_grad_out = reinterpret_cast<const float4*>(grad_out)[vec_idx2];
    const float4 src_in1 = reinterpret_cast<const float4*>(in1)[vec_idx1];
    const float4 src_in2 = reinterpret_cast<const float4*>(in2)[vec_idx2];

    // Element (not float4) offset of this thread's 4 features inside grad_in1.
    const int64_t grad_in1_base_idx = idx1[row] * fea_dim + threadIdx.x * 4;
    gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 0, src_grad_out.x * src_in2.x);
    gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 1, src_grad_out.y * src_in2.y);
    gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 2, src_grad_out.z * src_in2.z);
    gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 3, src_grad_out.w * src_in2.w);

    float4 dst_grad_in2;
    dst_grad_in2.x = src_grad_out.x * src_in1.x;
    dst_grad_in2.y = src_grad_out.y * src_in1.y;
    dst_grad_in2.z = src_grad_out.z * src_in1.z;
    dst_grad_in2.w = src_grad_out.w * src_in1.w;
    reinterpret_cast<float4*>(grad_in2)[vec_idx2] = dst_grad_in2;
  }
}

// fp32, any fea_dim.
__global__ void index_mul_2d_grad_float(float* grad_in1, float* grad_in2, const float* grad_out, const float* in1,
                                        const float* in2, const int64_t* idx1, const int64_t size,
                                        const int64_t fea_dim) {
  const int64_t row = index_mul_2d_row();
  if (row < size) {
    const int64_t base1 = idx1[row] * fea_dim;
    const int64_t base2 = row * fea_dim;
    for (int64_t i = threadIdx.x; i < fea_dim; i += blockDim.x) {
      const float src_in1 = in1[base1 + i];
      const float src_in2 = in2[base2 + i];
      const float src_grad_out = grad_out[base2 + i];
      grad_in2[base2 + i] = src_grad_out * src_in1;
      gpuAtomicAdd(grad_in1 + base1 + i, src_grad_out * src_in2);
    }
  }
}

// fp16, any fea_dim (math in fp32, results rounded to half).
__global__ void index_mul_2d_grad_half(at::Half* grad_in1, at::Half* grad_in2, const at::Half* grad_out,
                                       const at::Half* in1, const at::Half* in2, const int64_t* idx1,
                                       const int64_t size, const int64_t fea_dim) {
  const int64_t row = index_mul_2d_row();
  if (row < size) {
    const int64_t base1 = idx1[row] * fea_dim;
    const int64_t base2 = row * fea_dim;
    for (int64_t i = threadIdx.x; i < fea_dim; i += blockDim.x) {
      const float src_in1 = static_cast<float>(in1[base1 + i]);
      const float src_in2 = static_cast<float>(in2[base2 + i]);
      const float src_grad_out = static_cast<float>(grad_out[base2 + i]);
      grad_in2[base2 + i] = at::Half(src_grad_out * src_in1);
      gpuAtomicAdd(grad_in1 + base1 + i, at::Half(src_grad_out * src_in2));
    }
  }
}

// ---------------------------------------------------------------------------------------------
// Double-backward kernels:
//   grad_grad_out[row, :]  = gg_in1[idx1[row], :] * in2[row, :] + gg_in2[row, :] * in1[idx1[row], :]
//   grad_in2[row, :]       = gg_in1[idx1[row], :] * grad_out[row, :]
//   grad_in1[idx1[row], :] += gg_in2[row, :] * grad_out[row, :]   (atomic, as above)
// ---------------------------------------------------------------------------------------------

// fp32, fea_dim == 64 (float4 path).
__global__ void index_mul_2d_grad_grad_float_dim64(float* grad_grad_out, float* grad_in1, float* grad_in2,
                                                   const float* grad_out, const float* grad_grad_in1,
                                                   const float* grad_grad_in2, const float* in1, const float* in2,
                                                   const int64_t* idx1, const int64_t size) {
  constexpr int64_t fea_dim = 64;
  const int64_t row = index_mul_2d_row();
  if (row < size) {
    const int64_t vec_idx1 = (idx1[row] * fea_dim) / 4 + threadIdx.x;
    const int64_t vec_idx2 = (row * fea_dim) / 4 + threadIdx.x;
    const float4 src_grad_grad_in1 = reinterpret_cast<const float4*>(grad_grad_in1)[vec_idx1];
    const float4 src_in1 = reinterpret_cast<const float4*>(in1)[vec_idx1];
    const float4 src_grad_grad_in2 = reinterpret_cast<const float4*>(grad_grad_in2)[vec_idx2];
    const float4 src_in2 = reinterpret_cast<const float4*>(in2)[vec_idx2];

    float4 dst_grad_grad_out;
    dst_grad_grad_out.x = src_grad_grad_in1.x * src_in2.x + src_grad_grad_in2.x * src_in1.x;
    dst_grad_grad_out.y = src_grad_grad_in1.y * src_in2.y + src_grad_grad_in2.y * src_in1.y;
    dst_grad_grad_out.z = src_grad_grad_in1.z * src_in2.z + src_grad_grad_in2.z * src_in1.z;
    dst_grad_grad_out.w = src_grad_grad_in1.w * src_in2.w + src_grad_grad_in2.w * src_in1.w;
    reinterpret_cast<float4*>(grad_grad_out)[vec_idx2] = dst_grad_grad_out;

    // grad_out is read only after grad_grad_out is stored, matching the original ordering.
    const float4 src_grad_out = reinterpret_cast<const float4*>(grad_out)[vec_idx2];
    const int64_t grad_in1_base_idx = idx1[row] * fea_dim + threadIdx.x * 4;
    gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 0, src_grad_grad_in2.x * src_grad_out.x);
    gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 1, src_grad_grad_in2.y * src_grad_out.y);
    gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 2, src_grad_grad_in2.z * src_grad_out.z);
    gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 3, src_grad_grad_in2.w * src_grad_out.w);

    float4 dst_grad_in2;
    dst_grad_in2.x = src_grad_grad_in1.x * src_grad_out.x;
    dst_grad_in2.y = src_grad_grad_in1.y * src_grad_out.y;
    dst_grad_in2.z = src_grad_grad_in1.z * src_grad_out.z;
    dst_grad_in2.w = src_grad_grad_in1.w * src_grad_out.w;
    reinterpret_cast<float4*>(grad_in2)[vec_idx2] = dst_grad_in2;
  }
}

// fp32, any fea_dim.
__global__ void index_mul_2d_grad_grad_float(float* grad_grad_out, float* grad_in1, float* grad_in2,
                                             const float* grad_out, const float* grad_grad_in1,
                                             const float* grad_grad_in2, const float* in1, const float* in2,
                                             const int64_t* idx1, const int64_t size, const int64_t fea_dim) {
  const int64_t row = index_mul_2d_row();
  if (row < size) {
    const int64_t base1 = idx1[row] * fea_dim;
    const int64_t base2 = row * fea_dim;
    for (int64_t i = threadIdx.x; i < fea_dim; i += blockDim.x) {
      const float src_grad_grad_in1 = grad_grad_in1[base1 + i];
      const float src_grad_grad_in2 = grad_grad_in2[base2 + i];
      const float src_in1 = in1[base1 + i];
      const float src_in2 = in2[base2 + i];
      const float src_grad_out = grad_out[base2 + i];
      grad_grad_out[base2 + i] = src_grad_grad_in1 * src_in2 + src_grad_grad_in2 * src_in1;
      grad_in2[base2 + i] = src_grad_grad_in1 * src_grad_out;
      gpuAtomicAdd(grad_in1 + base1 + i, src_grad_grad_in2 * src_grad_out);
    }
  }
}

// fp16, any fea_dim (math in fp32, results rounded to half).
__global__ void index_mul_2d_grad_grad_half(at::Half* grad_grad_out, at::Half* grad_in1, at::Half* grad_in2,
                                            const at::Half* grad_out, const at::Half* grad_grad_in1,
                                            const at::Half* grad_grad_in2, const at::Half* in1,
                                            const at::Half* in2, const int64_t* idx1, const int64_t size,
                                            const int64_t fea_dim) {
  const int64_t row = index_mul_2d_row();
  if (row < size) {
    const int64_t base1 = idx1[row] * fea_dim;
    const int64_t base2 = row * fea_dim;
    for (int64_t i = threadIdx.x; i < fea_dim; i += blockDim.x) {
      const float src_grad_grad_in1 = static_cast<float>(grad_grad_in1[base1 + i]);
      const float src_grad_grad_in2 = static_cast<float>(grad_grad_in2[base2 + i]);
      const float src_in1 = static_cast<float>(in1[base1 + i]);
      const float src_in2 = static_cast<float>(in2[base2 + i]);
      const float src_grad_out = static_cast<float>(grad_out[base2 + i]);
      grad_grad_out[base2 + i] = at::Half(src_grad_grad_in1 * src_in2 + src_grad_grad_in2 * src_in1);
      grad_in2[base2 + i] = at::Half(src_grad_grad_in1 * src_grad_out);
      gpuAtomicAdd(grad_in1 + base1 + i, at::Half(src_grad_grad_in2 * src_grad_out));
    }
  }
}

// ---------------------------------------------------------------------------------------------
// Host launchers.
//
// `size` = in2.size(0) rows and `fea_dim` = in2.size(1) columns. The kernels read
// idx1[row] for every row < size, so idx1 must hold at least `size` entries.
//
// Every launcher:
//   - enqueues on the current CUDA stream;
//   - returns early for an empty batch. `size` comes from Tensor::size and is
//     never negative; the relevant edge case is 0, where a zero-block grid
//     would be rejected as an invalid launch configuration;
//   - checks for launch errors with AT_CUDA_CHECK on every path.
// ---------------------------------------------------------------------------------------------

void index_mul_2d_float_foward_cuda(at::Tensor& out, const at::Tensor& in1, const at::Tensor& in2,
                                    const at::Tensor& idx1) {
  const int64_t size = in2.size(0);
  const int64_t fea_dim = in2.size(1);
  if (size <= 0) {
    return;
  }

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  if (fea_dim == 64) {
    const dim3 block(kDim64ThreadsX, kDim64ThreadsY);
    const int grid = index_mul_2d_num_blocks(size, kDim64ThreadsY);
    index_mul_2d_float_dim64<<<grid, block, 0, stream>>>(out.data_ptr<float>(), in1.data_ptr<float>(),
                                                         in2.data_ptr<float>(), idx1.data_ptr<int64_t>(), size);
  } else {
    const dim3 block(kGenericThreadsX, kGenericThreadsY);
    const int grid = index_mul_2d_num_blocks(size, kGenericThreadsY);
    index_mul_2d_float<<<grid, block, 0, stream>>>(out.data_ptr<float>(), in1.data_ptr<float>(),
                                                   in2.data_ptr<float>(), idx1.data_ptr<int64_t>(), size,
                                                   fea_dim);
  }

  AT_CUDA_CHECK(cudaGetLastError());
}

void index_mul_2d_float_backward_cuda(at::Tensor& grad_in1, at::Tensor& grad_in2, const at::Tensor& grad_out,
                                      const at::Tensor& in1, const at::Tensor& in2, const at::Tensor& idx1) {
  const int64_t size = in2.size(0);
  const int64_t fea_dim = in2.size(1);
  if (size <= 0) {
    return;
  }

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  if (fea_dim == 64) {
    const dim3 block(kDim64ThreadsX, kDim64ThreadsY);
    const int grid = index_mul_2d_num_blocks(size, kDim64ThreadsY);
    index_mul_2d_grad_float_dim64<<<grid, block, 0, stream>>>(
        grad_in1.data_ptr<float>(), grad_in2.data_ptr<float>(), grad_out.data_ptr<float>(), in1.data_ptr<float>(),
        in2.data_ptr<float>(), idx1.data_ptr<int64_t>(), size);
  } else {
    const dim3 block(kGenericThreadsX, kGenericThreadsY);
    const int grid = index_mul_2d_num_blocks(size, kGenericThreadsY);
    index_mul_2d_grad_float<<<grid, block, 0, stream>>>(grad_in1.data_ptr<float>(), grad_in2.data_ptr<float>(),
                                                        grad_out.data_ptr<float>(), in1.data_ptr<float>(),
                                                        in2.data_ptr<float>(), idx1.data_ptr<int64_t>(), size,
                                                        fea_dim);
  }

  // Previously only the dim64 branch was checked.
  AT_CUDA_CHECK(cudaGetLastError());
}

void index_mul_2d_float_backward_backward_cuda(at::Tensor& grad_grad_out, at::Tensor& grad_in1,
                                               at::Tensor& grad_in2, const at::Tensor& grad_out,
                                               const at::Tensor& grad_grad_in1, const at::Tensor& grad_grad_in2,
                                               const at::Tensor& in1, const at::Tensor& in2,
                                               const at::Tensor& idx1) {
  const int64_t size = in2.size(0);
  const int64_t fea_dim = in2.size(1);
  if (size <= 0) {
    return;
  }

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  if (fea_dim == 64) {
    const dim3 block(kDim64ThreadsX, kDim64ThreadsY);
    const int grid = index_mul_2d_num_blocks(size, kDim64ThreadsY);
    index_mul_2d_grad_grad_float_dim64<<<grid, block, 0, stream>>>(
        grad_grad_out.data_ptr<float>(), grad_in1.data_ptr<float>(), grad_in2.data_ptr<float>(),
        grad_out.data_ptr<float>(), grad_grad_in1.data_ptr<float>(), grad_grad_in2.data_ptr<float>(),
        in1.data_ptr<float>(), in2.data_ptr<float>(), idx1.data_ptr<int64_t>(), size);
  } else {
    const dim3 block(kGenericThreadsX, kGenericThreadsY);
    const int grid = index_mul_2d_num_blocks(size, kGenericThreadsY);
    index_mul_2d_grad_grad_float<<<grid, block, 0, stream>>>(
        grad_grad_out.data_ptr<float>(), grad_in1.data_ptr<float>(), grad_in2.data_ptr<float>(),
        grad_out.data_ptr<float>(), grad_grad_in1.data_ptr<float>(), grad_grad_in2.data_ptr<float>(),
        in1.data_ptr<float>(), in2.data_ptr<float>(), idx1.data_ptr<int64_t>(), size, fea_dim);
  }

  AT_CUDA_CHECK(cudaGetLastError());
}

void index_mul_2d_half_foward_cuda(at::Tensor& out, const at::Tensor& in1, const at::Tensor& in2,
                                   const at::Tensor& idx1) {
  const int64_t size = in2.size(0);
  const int64_t fea_dim = in2.size(1);
  if (size <= 0) {
    return;
  }

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  const dim3 block(kGenericThreadsX, kGenericThreadsY);
  const int grid = index_mul_2d_num_blocks(size, kGenericThreadsY);
  index_mul_2d_half<<<grid, block, 0, stream>>>(out.data_ptr<at::Half>(), in1.data_ptr<at::Half>(),
                                                in2.data_ptr<at::Half>(), idx1.data_ptr<int64_t>(), size,
                                                fea_dim);

  AT_CUDA_CHECK(cudaGetLastError());
}

void index_mul_2d_half_backward_cuda(at::Tensor& grad_in1, at::Tensor& grad_in2, const at::Tensor& grad_out,
                                     const at::Tensor& in1, const at::Tensor& in2, const at::Tensor& idx1) {
  const int64_t size = in2.size(0);
  const int64_t fea_dim = in2.size(1);
  if (size <= 0) {
    return;
  }

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  const dim3 block(kGenericThreadsX, kGenericThreadsY);
  const int grid = index_mul_2d_num_blocks(size, kGenericThreadsY);
  index_mul_2d_grad_half<<<grid, block, 0, stream>>>(grad_in1.data_ptr<at::Half>(), grad_in2.data_ptr<at::Half>(),
                                                     grad_out.data_ptr<at::Half>(), in1.data_ptr<at::Half>(),
                                                     in2.data_ptr<at::Half>(), idx1.data_ptr<int64_t>(), size,
                                                     fea_dim);

  // This launcher previously had no launch-error check at all.
  AT_CUDA_CHECK(cudaGetLastError());
}

void index_mul_2d_half_backward_backward_cuda(at::Tensor& grad_grad_out, at::Tensor& grad_in1,
                                              at::Tensor& grad_in2, const at::Tensor& grad_out,
                                              const at::Tensor& grad_grad_in1, const at::Tensor& grad_grad_in2,
                                              const at::Tensor& in1, const at::Tensor& in2,
                                              const at::Tensor& idx1) {
  const int64_t size = in2.size(0);
  const int64_t fea_dim = in2.size(1);
  if (size <= 0) {
    return;
  }

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  const dim3 block(kGenericThreadsX, kGenericThreadsY);
  const int grid = index_mul_2d_num_blocks(size, kGenericThreadsY);
  index_mul_2d_grad_grad_half<<<grid, block, 0, stream>>>(
      grad_grad_out.data_ptr<at::Half>(), grad_in1.data_ptr<at::Half>(), grad_in2.data_ptr<at::Half>(),
      grad_out.data_ptr<at::Half>(), grad_grad_in1.data_ptr<at::Half>(), grad_grad_in2.data_ptr<at::Half>(),
      in1.data_ptr<at::Half>(), in2.data_ptr<at::Half>(), idx1.data_ptr<int64_t>(), size, fea_dim);

  AT_CUDA_CHECK(cudaGetLastError());
}
================================================
FILE: apex/contrib/csrc/layer_norm/ln.h
================================================
#pragma once
#include
#include
#include
#include
#include
namespace layer_norm {
////////////////////////////////////////////////////////////////////////////////////////////////////
template struct LaunchParams {
  size_t workspace_bytes;
  size_t barrier_size;
  cudaDeviceProp* props;
  cudaStream_t stream;
  Params params;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct FwdParams {
  FwdParams()
      : ctas_per_col(0),
        rows(0),
        cols(0),
        x(nullptr),
        z(nullptr),
        mu(nullptr),
        rs(nullptr),
        gamma(nullptr),
        beta(nullptr),
        workspace(nullptr),
        barrier(nullptr),
        epsilon(0.f) {}
  // For Multi-CTA, number of different CTA groups. Otherwise same as gridDim.x.
  int ctas_per_col;
  // Input is interpreted as matrix. We normalize across columns.
  int rows;
  int cols;
  // Common data pointers.
  void* x;
  void* z;
  void* mu;
  void* rs;
  void* gamma;
  void* beta;
  // Multi-CTA workspace in gmem.
  void* workspace;
  // Multi-CTA sync barriers in gmem.
  int* barrier;
  // Output of LN FWD.
float epsilon; }; //////////////////////////////////////////////////////////////////////////////////////////////////// struct BwdParams : public FwdParams { BwdParams() : FwdParams(), dz(nullptr), dbeta_part(nullptr), dgamma_part(nullptr), dx(nullptr), dbeta(nullptr), dgamma(nullptr) {} // Input: gradient wrt. LN FWD output. void* dz; // Workspace for Wgrad pre-reduction. void* dbeta_part; void* dgamma_part; // Output: Dgrad. void* dx; // Output: Wgrad. void* dbeta; void* dgamma; }; //////////////////////////////////////////////////////////////////////////////////////////////////// using FwdFunction = std::function&, const bool)>; using BwdFunction = std::function&, const bool)>; using FunctionKey = uint64_t; using FwdRegistry = std::unordered_map; using BwdRegistry = std::unordered_map; extern FwdRegistry FWD_FUNCS; extern BwdRegistry BWD_FUNCS; //////////////////////////////////////////////////////////////////////////////////////////////////// using fp32 = float; using fp16 = half; using bf16 = nv_bfloat16; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct TypeId {}; template <> struct TypeId { constexpr static uint32_t Value = 0; }; template <> struct TypeId { constexpr static uint32_t Value = 1; }; template <> struct TypeId { constexpr static uint32_t Value = 2; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct Type2Key { constexpr static uint32_t Value = TypeId::Value << S; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct WeightType2Key : public Type2Key {}; template struct InputType2Key : public Type2Key {}; template struct OutputType2Key : public Type2Key {}; template struct ComputeType2Key : public Type2Key {}; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct Types2Key { constexpr 
static uint32_t Value = WeightType2Key::Value | InputType2Key::Value | OutputType2Key::Value | ComputeType2Key::Value; constexpr static inline uint64_t get(const uint64_t hidden_size) { constexpr uint64_t type_key = Value; return (type_key << 32) | hidden_size; } }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct FwdRegistrar { FwdRegistrar(FwdFunction f) { uint64_t key = Types2Key::get(HIDDEN_SIZE); FWD_FUNCS.insert({key, f}); } }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct BwdRegistrar { BwdRegistrar(BwdFunction f) { uint64_t key = Types2Key::get(HIDDEN_SIZE); BWD_FUNCS.insert({key, f}); } }; //////////////////////////////////////////////////////////////////////////////////////////////////// } // namespace layer_norm ================================================ FILE: apex/contrib/csrc/layer_norm/ln_api.cpp ================================================ #include #include "ATen/cuda/CUDAContext.h" #include "ln.h" /* Supported Type combinations: input compute weights output ======================================= fp32 fp32 fp32 fp32 fp16 fp32 fp16 fp16 bf16 fp32 bf16 bf16 fp32 fp32 fp16 fp16 fp32 fp32 bf16 bf16 Remarks: Output type = Weight type Compute always in FP32 */ namespace layer_norm { // Create registries and provide runtime versions of config hash functions. 
FwdRegistry FWD_FUNCS; BwdRegistry BWD_FUNCS; //////////////////////////////////////////////////////////////////////////////////////////////////// uint32_t get_type_id(torch::Dtype dtype) { if (dtype == torch::kFloat16) { return TypeId::Value; } else if (dtype == torch::kBFloat16) { return TypeId::Value; } else if (dtype == torch::kFloat32) { return TypeId::Value; } else { TORCH_CHECK(false, "Type not supported: ", dtype); } } //////////////////////////////////////////////////////////////////////////////////////////////////// uint64_t get_key(torch::Dtype wtype, torch::Dtype itype, torch::Dtype otype, torch::Dtype ctype, uint64_t hidden_size) { using namespace layer_norm; uint64_t type_key = get_type_id(wtype) | (get_type_id(itype) << 2) | (get_type_id(otype) << 4) | (get_type_id(ctype) << 6); uint64_t launcher_key = (type_key << 32) | hidden_size; return launcher_key; } } // namespace layer_norm //////////////////////////////////////////////////////////////////////////////////////////////////// layer_norm::FwdFunction& get_fwd_launcher(torch::Dtype wtype, torch::Dtype itype, torch::Dtype otype, torch::Dtype ctype, uint32_t hidden_size) { auto iter = layer_norm::FWD_FUNCS.find(layer_norm::get_key(wtype, itype, otype, ctype, hidden_size)); if (iter != layer_norm::FWD_FUNCS.end()) { return iter->second; } else { TORCH_CHECK(false, "FWD: Unsupported hidden_size or types: ", hidden_size, wtype, itype, otype, ctype); } } //////////////////////////////////////////////////////////////////////////////////////////////////// layer_norm::BwdFunction& get_bwd_launcher(torch::Dtype wtype, torch::Dtype itype, torch::Dtype otype, torch::Dtype ctype, uint32_t hidden_size) { auto iter = layer_norm::BWD_FUNCS.find(layer_norm::get_key(wtype, itype, otype, ctype, hidden_size)); if (iter != layer_norm::BWD_FUNCS.end()) { return iter->second; } else { TORCH_CHECK(false, "BWD: Unsupported hidden_size or types: ", hidden_size, wtype, itype, otype, ctype); } } 
//////////////////////////////////////////////////////////////////////////////////////////////////// /* ln_fwd: LayerNorm forward over each row of a 2-D input x (rows x cols, with cols == gamma.numel()). Returns {z, mu, rsigma}: z has the weight dtype; mu and rsigma are per-row fp32 statistics. The registered launcher is called twice: once with configure_params = true to fill launch_params, then with false to run the kernel. */ std::vector ln_fwd(const at::Tensor& x, // BxSxhidden_size const at::Tensor& gamma, // hidden_size const at::Tensor& beta, // hidden_size const float epsilon) { auto itype = x.scalar_type(); auto wtype = gamma.scalar_type(); auto otype = wtype; auto ctype = torch::kFloat32; TORCH_CHECK(beta.scalar_type() == wtype); TORCH_CHECK(x.is_cuda()); TORCH_CHECK(gamma.is_cuda()); TORCH_CHECK(beta.is_cuda()); TORCH_CHECK(x.is_contiguous()); /* The kernel indexes gamma and beta as dense vectors through raw data pointers. */ TORCH_CHECK(gamma.is_contiguous()); TORCH_CHECK(beta.is_contiguous()); auto sizes = x.sizes(); TORCH_CHECK(sizes.size() == 2); const int rows = sizes[0]; const int cols = sizes[1]; auto hidden_size = gamma.numel(); TORCH_CHECK(gamma.sizes() == beta.sizes()); TORCH_CHECK(hidden_size == cols); TORCH_CHECK(epsilon >= 0.f); auto opts = x.options(); auto z = torch::empty(sizes, opts.dtype(otype)); auto mu = torch::empty({rows}, opts.dtype(ctype)); auto rsigma = torch::empty({rows}, opts.dtype(ctype)); layer_norm::LaunchParams launch_params; launch_params.props = at::cuda::getCurrentDeviceProperties(); launch_params.stream = at::cuda::getCurrentCUDAStream().stream(); // Request the kernel launcher. auto launcher = get_fwd_launcher(wtype, itype, otype, ctype, hidden_size); // Query the kernel-specific launch parameters. launcher(launch_params, true); at::Tensor workspace, barrier; // Set the kernel runtime parameters.
layer_norm::FwdParams& params = launch_params.params; params.rows = rows; params.cols = cols; params.z = z.data_ptr(); params.mu = mu.data_ptr(); params.rs = rsigma.data_ptr(); params.gamma = gamma.data_ptr(); params.beta = beta.data_ptr(); params.x = x.data_ptr(); params.epsilon = epsilon; if (launch_params.barrier_size > 0) { auto options = x.options(); barrier = torch::zeros(launch_params.barrier_size, options.dtype(torch::kInt32)); workspace = torch::empty(launch_params.workspace_bytes, options.dtype(torch::kChar)); params.workspace = workspace.data_ptr(); params.barrier = barrier.data_ptr(); } // Launch the kernel. launcher(launch_params, false); return {z, mu, rsigma}; } //////////////////////////////////////////////////////////////////////////////////////////////////// /* ln_bwd: LayerNorm backward for a 2-D dz. With memory_efficient = true, x_or_z is the forward output z and beta_ must be given; otherwise x_or_z is the input x and mu_ must be given. Returns {dx, dgamma, dbeta, dgamma_part, dbeta_part}; the *_part tensors hold per-CTA fp32 partial sums of shape (ctas_per_col, hidden_size). */ std::vector ln_bwd(const at::Tensor& dz, // BxSxhidden_size const at::Tensor& x_or_z, // BxSxhidden_size c10::optional& mu_, // BxS, FP32! const at::Tensor& rsigma, // BxS, FP32! const at::Tensor& gamma, // hidden_size c10::optional& beta_, // hidden_size bool memory_efficient) { auto itype = x_or_z.scalar_type(); auto wtype = gamma.scalar_type(); auto otype = wtype; auto ctype = torch::kFloat32; TORCH_CHECK(dz.dtype() == otype); TORCH_CHECK(rsigma.dtype() == ctype); if (mu_.has_value()) { TORCH_CHECK(mu_.value().dtype() == ctype); } TORCH_CHECK(x_or_z.is_cuda()); TORCH_CHECK(dz.is_cuda()); TORCH_CHECK(rsigma.is_cuda()); TORCH_CHECK(gamma.is_cuda()); if (beta_.has_value()) { TORCH_CHECK(beta_.value().is_cuda()); TORCH_CHECK(beta_.value().dtype() == wtype); TORCH_CHECK(beta_.value().is_contiguous()); } TORCH_CHECK(x_or_z.is_contiguous()); TORCH_CHECK(dz.is_contiguous()); TORCH_CHECK(gamma.is_contiguous()); /* Validate the optional that the selected mode dereferences below, instead of failing later with std::bad_optional_access. */ if (memory_efficient) { TORCH_CHECK(beta_.has_value(), "ln_bwd: beta is required when memory_efficient is true"); } else { TORCH_CHECK(mu_.has_value(), "ln_bwd: mu is required when memory_efficient is false"); } auto sizes = x_or_z.sizes(); TORCH_CHECK(sizes.size() == 2); TORCH_CHECK(dz.sizes() == sizes); auto rows = sizes[0]; auto cols = sizes[1]; /* The kernel reads rsigma[row] for every row, and mu[row] when not memory-efficient. */ TORCH_CHECK(rsigma.numel() == rows); if (!memory_efficient) { TORCH_CHECK(mu_.value().numel() == rows); } auto hidden_size = gamma.numel(); TORCH_CHECK(gamma.numel() == cols); if (beta_.has_value()) { TORCH_CHECK(beta_.value().numel() == cols); } auto options = x_or_z.options(); auto dx = torch::empty_like(x_or_z); auto dgamma =
torch::empty_like(gamma); auto dbeta = torch::empty_like(gamma); layer_norm::LaunchParams launch_params; launch_params.stream = at::cuda::getCurrentCUDAStream().stream(); launch_params.props = at::cuda::getCurrentDeviceProperties(); auto launcher = get_bwd_launcher(wtype, itype, otype, ctype, hidden_size); launcher(launch_params, true); auto dgamma_part = torch::empty({launch_params.params.ctas_per_col, hidden_size}, options.dtype(ctype)); auto dbeta_part = torch::empty({launch_params.params.ctas_per_col, hidden_size}, options.dtype(ctype)); at::Tensor workspace, barrier; layer_norm::BwdParams& params = launch_params.params; params.rows = rows; params.cols = cols; if (memory_efficient) { params.z = x_or_z.data_ptr(); params.beta = beta_.value().data_ptr(); } else { params.x = x_or_z.data_ptr(); params.mu = mu_.value().data_ptr(); } params.rs = rsigma.data_ptr(); params.gamma = gamma.data_ptr(); params.dz = dz.data_ptr(); params.dx = dx.data_ptr(); params.dbeta = dbeta.data_ptr(); params.dgamma = dgamma.data_ptr(); params.dbeta_part = dbeta_part.data_ptr(); params.dgamma_part = dgamma_part.data_ptr(); if (launch_params.barrier_size > 0) { // TODO Any way to avoid this?
barrier = torch::zeros(launch_params.barrier_size, options.dtype(torch::kInt32)); workspace = torch::empty(launch_params.workspace_bytes, options.dtype(torch::kChar)); params.workspace = workspace.data_ptr(); params.barrier = barrier.data_ptr(); } launcher(launch_params, false); return {dx, dgamma, dbeta, dgamma_part, dbeta_part}; } //////////////////////////////////////////////////////////////////////////////////////////////////// PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.doc() = "CUDA LayerNorm"; m.def("ln_fwd", &ln_fwd, "Run LayerNorm forward kernel", py::call_guard()); m.def("ln_bwd", &ln_bwd, "Run LayerNorm backward kernel", py::call_guard()); } ================================================ FILE: apex/contrib/csrc/layer_norm/ln_bwd_kernels.cuh ================================================ #pragma once #include "ln_utils.cuh" namespace layer_norm { /* ln_bwd_kernel: first pass of the LayerNorm backward. Each CTA walks rows r, r + ctas_per_col * ROWS_PER_CTA, ... It recomputes the normalized value y either as rs * (x - mu) or, when params.z is set (memory-efficient mode), as (z - beta) / gamma. It writes dx = rs * (dy - (mean(dy * y) * y + mean(dy))) with dy = dz * gamma, and stores per-CTA partial sums of dz * y into params.dgamma_part and of dz into params.dbeta_part. NOTE(review): the memory-efficient path divides by gamma, so zero gamma entries produce inf/nan. */ template __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_bwd_kernel(layer_norm::BwdParams params) { enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA }; enum { WARPS_M = Ktraits::WARPS_M }; enum { WARPS_N = Ktraits::WARPS_N }; enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW }; enum { COLS = Ktraits::COLS }; enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW }; enum { LDGS = Ktraits::LDGS }; enum { NUM_ELTS = Ktraits::ELTS_PER_LDG }; enum { THREADS_PER_WARP = Ktraits::THREADS_PER_WARP }; enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW }; using compute_t = typename Ktraits::compute_t; using index_t = typename Ktraits::index_t; using Ivec = typename Ktraits::Ivec; using Ovec = typename Ktraits::Ovec; using Wvec = typename Ktraits::Wvec; using Cvec = typename Ktraits::Cvec; using Reducer = typename Ktraits::Reducer; using reduce_t = typename Reducer::Type; extern __shared__ char smem_[]; const index_t tidx = threadIdx.x; const index_t bidn = blockIdx.x % CTAS_PER_ROW; const index_t bidm = blockIdx.x / CTAS_PER_ROW; const index_t lane = tidx % THREADS_PER_WARP; const index_t warp = tidx / THREADS_PER_WARP;
const index_t warp_m = warp / Ktraits::WARPS_N; const index_t warp_n = warp % Ktraits::WARPS_N; const index_t tid_r = warp_n * THREADS_PER_WARP + lane; const index_t r = bidm * Ktraits::ROWS_PER_CTA + warp_m; const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane; static_assert(COLS == THREADS_PER_ROW * LDGS * NUM_ELTS * CTAS_PER_ROW); Cvec dzy_sum[LDGS]; Cvec dz_sum[LDGS]; memset(dzy_sum, 0, sizeof(dzy_sum)); memset(dz_sum, 0, sizeof(dz_sum)); compute_t* smem_wgrad = reinterpret_cast(smem_); char* smem_dgrad = smem_ + Ktraits::SMEM_BYTES_WGRAD; Reducer reducer(params, bidm, bidn, warp_m, warp_n, lane, smem_dgrad); Sum sum; constexpr float rn = 1.f / float(COLS); Wvec gamma[LDGS]; Wvec beta[LDGS]; index_t idx = c; #pragma unroll for (int it = 0; it < LDGS; it++) { gamma[it].load_from(params.gamma, idx); if (params.z != nullptr) { beta[it].load_from(params.beta, idx); } idx += Ktraits::VEC_COLS_PER_LDG; } // TODO if ROWS_PER_CTA does not divide rows, we might get divergence in the // last blocks with syncthreads! // grid stride over rows #pragma unroll 1 for (int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA) { const compute_t mu_r = params.z == nullptr ?
static_cast(params.mu)[row] : 0.f; const compute_t rs_r = static_cast(params.rs)[row]; Ivec x_or_z[LDGS]; Ovec dz[LDGS]; index_t idx = row * Ktraits::VEC_COLS + c; #pragma unroll for (int it = 0; it < LDGS; it++) { dz[it].load_from(params.dz, idx); if (params.z != nullptr) { x_or_z[it].load_from(params.z, idx); } else { x_or_z[it].load_from(params.x, idx); } idx += Ktraits::VEC_COLS_PER_LDG; } compute_t dy[LDGS * NUM_ELTS]; compute_t y[LDGS * NUM_ELTS]; compute_t mdy_local = 0.f; compute_t mdyy_local = 0.f; #pragma unroll for (int it = 0; it < LDGS; it++) { #pragma unroll for (int jt = 0; jt < NUM_ELTS; jt++) { compute_t gamma_tmp = compute_t(gamma[it].data.elt[jt]); compute_t beta_tmp = compute_t(beta[it].data.elt[jt]); compute_t x_or_z_tmp = compute_t(x_or_z[it].data.elt[jt]); compute_t y_tmp = params.z != nullptr ? (x_or_z_tmp - beta_tmp) / gamma_tmp : rs_r * (x_or_z_tmp - mu_r); compute_t dy_tmp = compute_t(dz[it].data.elt[jt]) * gamma_tmp; compute_t dz_tmp = dz[it].data.elt[jt]; mdy_local += dy_tmp; mdyy_local += dy_tmp * y_tmp; dy[it * NUM_ELTS + jt] = dy_tmp; y[it * NUM_ELTS + jt] = y_tmp; dzy_sum[it].data.elt[jt] += dz_tmp * y_tmp; dz_sum[it].data.elt[jt] += dz_tmp; } } reduce_t result = reducer.allreduce({mdy_local, mdyy_local}, sum); mdy_local = layer_norm::Get<0>::of(result) * rn; mdyy_local = layer_norm::Get<1>::of(result) * rn; Ivec dx[LDGS]; idx = row * Ktraits::VEC_COLS + c; #pragma unroll for (int it = 0; it < LDGS; it++) { #pragma unroll for (int jt = 0; jt < NUM_ELTS; jt++) { compute_t dy_tmp = dy[it * NUM_ELTS + jt]; compute_t y_tmp = y[it * NUM_ELTS + jt]; compute_t dx_tmp = rs_r * (dy_tmp - (mdyy_local * y_tmp + mdy_local)); dx[it].data.elt[jt] = dx_tmp; } dx[it].store_to(params.dx, idx); idx += Ktraits::VEC_COLS_PER_LDG; } } // end: grid stride loop if (WARPS_M == 1) { idx = r * Ktraits::VEC_COLS + c; #pragma unroll for (int it = 0; it < LDGS; it++) { dz_sum[it].store_to(params.dbeta_part, idx); dzy_sum[it].store_to(params.dgamma_part, idx);
idx += Ktraits::VEC_COLS_PER_LDG; } } else { static_assert(WARPS_M == 1 || Ktraits::CTAS_PER_ROW == 1, "Multiple rows per CTA not supported for Multi-CTA."); // Finalize reduction of part dgamma and dbeta for this CTA // by reducing over the rows held across the WARPS_M warps // Assumption: blockSize divides hidden size. enum { NUM_RES = COLS / Ktraits::THREADS_PER_CTA }; static_assert(NUM_RES * Ktraits::THREADS_PER_CTA == COLS, ""); idx = warp_m * Ktraits::VEC_COLS + tid_r; #pragma unroll for (int it = 0; it < LDGS; it++) { dz_sum[it].store_to(smem_wgrad, idx); idx += THREADS_PER_ROW; } __syncthreads(); compute_t cta_dz_sum[NUM_RES]; memset(cta_dz_sum, 0, sizeof(compute_t) * NUM_RES); for (int it = 0; it < ROWS_PER_CTA; it++) { for (int jt = 0; jt < NUM_RES; jt++) { cta_dz_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA]; } } __syncthreads(); idx = warp_m * Ktraits::VEC_COLS + tid_r; #pragma unroll for (int it = 0; it < LDGS; it++) { dzy_sum[it].store_to(smem_wgrad, idx); idx += THREADS_PER_ROW; } __syncthreads(); compute_t cta_dzy_sum[NUM_RES]; memset(cta_dzy_sum, 0, sizeof(compute_t) * NUM_RES); for (int it = 0; it < ROWS_PER_CTA; it++) { for (int jt = 0; jt < NUM_RES; jt++) { cta_dzy_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA]; } } compute_t* dgamma_part = static_cast(params.dgamma_part) + bidm * COLS + tidx; for (int jt = 0; jt < NUM_RES; jt++) { *dgamma_part = cta_dzy_sum[jt]; dgamma_part += Ktraits::THREADS_PER_CTA; } compute_t* dbeta_part = static_cast(params.dbeta_part) + bidm * COLS + tidx; for (int jt = 0; jt < NUM_RES; jt++) { *dbeta_part = cta_dz_sum[jt]; dbeta_part += Ktraits::THREADS_PER_CTA; } } } /* ln_bwd_finalize_kernel: second pass of the backward. For each column it sums the params.ctas_per_col rows of dgamma_part / dbeta_part: warps stride over those rows, results are transposed through shared memory and combined with the reducer's allreduce. The converted values are then written to params.dgamma / params.dbeta using 2-wide stores. */ template __global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA) void ln_bwd_finalize_kernel(BwdParams params) { using compute_t = typename Kernel_traits::compute_t; using weight_t = typename Kernel_traits::weight_t; using index_t = typename Kernel_traits::index_t; using Reducer = typename
Kernel_traits::Reducer; using reduce_t = typename Reducer::Type; Sum sum; enum { NUM_ELT = Kernel_traits::ELTS_PER_LDG }; enum { THREADS_PER_WARP = Kernel_traits::THREADS_PER_WARP }; __shared__ char smem_[Kernel_traits::SMEM_BYTES_PER_CTA]; constexpr uint32_t bidm = 0; const uint32_t bidn = blockIdx.x; const uint32_t tidx = threadIdx.x; const uint32_t warp = tidx / THREADS_PER_WARP; const uint32_t lane = tidx % THREADS_PER_WARP; Reducer reducer(params, bidm, bidn, 0, 0, lane, smem_); const uint32_t c = bidn * THREADS_PER_WARP + lane; const uint32_t c_out = bidn * THREADS_PER_WARP / 2 + lane; constexpr uint32_t COL_STRIDE = Kernel_traits::CTAS * THREADS_PER_WARP; for (uint32_t col = c, col_out = c_out; col < Kernel_traits::COLS; col += COL_STRIDE, col_out += COL_STRIDE / 2) { // Each thread sums over NUM_ELT columns. Vec dbeta_local, dgamma_local; memset(&dgamma_local, 0, sizeof(dgamma_local)); memset(&dbeta_local, 0, sizeof(dbeta_local)); for (uint32_t row = warp; row < params.ctas_per_col; row += Kernel_traits::ROWS_PER_CTA) { index_t idx = row * Kernel_traits::COLS + col; Vec dbeta_part, dgamma_part; dbeta_part.load_from(params.dbeta_part, idx); dgamma_part.load_from(params.dgamma_part, idx); #pragma unroll for (int it = 0; it < NUM_ELT; it++) { dgamma_local.data.elt[it] += dgamma_part.data.elt[it]; dbeta_local.data.elt[it] += dbeta_part.data.elt[it]; } } void* smem_gamma = smem_; void* smem_beta = &smem_[Kernel_traits::SMEM_BYTES_TRANSPOSE]; const int write_row = warp; const int write_col = lane ^ write_row; const int write_idx = write_row * THREADS_PER_WARP + write_col; dgamma_local.store_to(smem_gamma, write_idx); dbeta_local.store_to(smem_beta, write_idx); __syncthreads(); // It would be probably safe to reuse the first row of smem_beta and smem_gamma void* smem_gamma_out = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE]; void* smem_beta_out = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE + Kernel_traits::SMEM_BYTES_OUTPUT]; // More than one iter iff
ROWS_PER_CTA < 32. for (int w = warp; w < THREADS_PER_WARP; w += Kernel_traits::ROWS_PER_CTA) { const int read_row = lane; const int read_col = w ^ read_row; const int read_idx = read_row * THREADS_PER_WARP + read_col; memset(&dbeta_local, 0, sizeof(dbeta_local)); memset(&dgamma_local, 0, sizeof(dgamma_local)); // Load beta and gamma transposed if (read_row < Kernel_traits::ROWS_PER_CTA) { dbeta_local.load_from(smem_beta, read_idx); dgamma_local.load_from(smem_gamma, read_idx); } // Call reducer on the loaded value(s) and convert. #pragma unroll for (int it = 0; it < NUM_ELT; it++) { compute_t b_i = dbeta_local.data.elt[it]; compute_t g_i = dgamma_local.data.elt[it]; b_i = reducer.allreduce(b_i, sum); g_i = reducer.allreduce(g_i, sum); dgamma_local.data.elt[it] = g_i; dbeta_local.data.elt[it] = b_i; } // Leader stores the result at the current column. if (lane == 0) { dgamma_local.store_to(smem_gamma_out, w); dbeta_local.store_to(smem_beta_out, w); } } // All writes done. __syncthreads(); // Pack and store: 2-wide stores with half the threads.
if (warp == Kernel_traits::ROWS_PER_CTA - 1 && lane < THREADS_PER_WARP / 2) { using src_t = typename TypeToVec2::Type; using dst_t = typename TypeToVec2::Type; Vec dbeta_vec2, dgamma_vec2; Vec dbeta_out2, dgamma_out2; dgamma_vec2.load_from(smem_gamma_out, lane); dbeta_vec2.load_from(smem_beta_out, lane); #pragma unroll for (int it = 0; it < NUM_ELT; it++) { dgamma_out2.data.elt[it] = Converter::convert(dgamma_vec2.data.elt[it]); dbeta_out2.data.elt[it] = Converter::convert(dbeta_vec2.data.elt[it]); } dgamma_out2.store_to(params.dgamma, col_out); dbeta_out2.store_to(params.dbeta, col_out); } } } } // namespace layer_norm ================================================ FILE: apex/contrib/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu ================================================ #include "ln.h" #include "ln_bwd_kernels.cuh" #include "ln_kernel_traits.h" #include "ln_utils.cuh" using namespace layer_norm; /* Backward launcher. With configure_params = true it only fills launch_params: ctas_per_col = SM count * occupancy / CTAS_PER_ROW, plus barrier and workspace sizes when a row spans several CTAs (CTAS_PER_ROW > 1); nothing is launched. Otherwise it launches ln_bwd_kernel (cooperatively when CTAS_PER_ROW > 1) and then ln_bwd_finalize_kernel on launch_params.stream. CUDA failures are routed through CHECK_CUDA (defined elsewhere, presumably ln_utils.cuh). */ template void launch_(LaunchParams& launch_params, const bool configure_params) { using Kernel_traits = Kernel_traits; auto kernel = &ln_bwd_kernel; if (configure_params) { int ctas_per_sm = 0; /* A failed occupancy query would otherwise leave ctas_per_sm unset and size the grid from garbage. */ CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor( &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES)); launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW; launch_params.barrier_size = 0; launch_params.workspace_bytes = 0; if (Kernel_traits::CTAS_PER_ROW > 1) { launch_params.barrier_size = 2 * launch_params.params.ctas_per_col; launch_params.workspace_bytes = launch_params.params.ctas_per_col * Kernel_traits::WARPS_M * Kernel_traits::CTAS_PER_ROW * sizeof(typename Kernel_traits::reduce_t) * 2; } return; } if (Kernel_traits::SMEM_BYTES >= 48 * 1024) { CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES)); } auto stream = launch_params.stream; auto ctas_per_col = launch_params.params.ctas_per_col; if
(Kernel_traits::CTAS_PER_ROW == 1) { kernel<<>>(launch_params.params); } else { dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col); dim3 block(Kernel_traits::THREADS_PER_CTA); void* params_ = (void*)&launch_params.params; /* cudaLaunchCooperativeKernel takes kernel arguments as an array of pointers; &params_ is that one-element array. Its status was previously ignored, so a failed launch left dx/dgamma/dbeta unwritten without any error. */ CHECK_CUDA(cudaLaunchCooperativeKernel((void*)kernel, grid, block, (void**)&params_, Kernel_traits::SMEM_BYTES, stream)); } using Kernel_traits_f = layer_norm::Kernel_traits_finalize; auto kernel_f = &layer_norm::ln_bwd_finalize_kernel; kernel_f<<>>(launch_params.params); /* Surface launch-configuration errors (e.g. an empty grid) from the triple-chevron launches above. */ CHECK_CUDA(cudaGetLastError()); } // Create backward launch function and register. Macro signature: // HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL REGISTER_BWD_LAUNCHER(768, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); REGISTER_BWD_LAUNCHER(768, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); REGISTER_BWD_LAUNCHER(768, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); REGISTER_BWD_LAUNCHER(768, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); REGISTER_BWD_LAUNCHER(768, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); REGISTER_BWD_LAUNCHER(1024, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); REGISTER_BWD_LAUNCHER(1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); REGISTER_BWD_LAUNCHER(1024, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); REGISTER_BWD_LAUNCHER(1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); REGISTER_BWD_LAUNCHER(1024, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); REGISTER_BWD_LAUNCHER(1536, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(1536, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4); REGISTER_BWD_LAUNCHER(1536, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(1536, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4); REGISTER_BWD_LAUNCHER(1536, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(2048, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(2048, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(2048, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(2048, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(2048, bf16,
fp32, bf16, fp32, 1, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(2304, fp32, fp32, fp32, fp32, 1, 1, 4, 8, 4); REGISTER_BWD_LAUNCHER(2304, fp16, fp16, fp16, fp32, 1, 1, 4, 4, 4); REGISTER_BWD_LAUNCHER(2304, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4); REGISTER_BWD_LAUNCHER(2304, bf16, bf16, bf16, fp32, 1, 1, 4, 4, 4); REGISTER_BWD_LAUNCHER(2304, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4); REGISTER_BWD_LAUNCHER(3072, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(3072, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(3072, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(3072, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(3072, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(3840, fp32, fp32, fp32, fp32, 1, 1, 4, 8, 4); REGISTER_BWD_LAUNCHER(3840, fp16, fp16, fp16, fp32, 1, 1, 4, 4, 4); REGISTER_BWD_LAUNCHER(3840, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4); REGISTER_BWD_LAUNCHER(3840, bf16, bf16, bf16, fp32, 1, 1, 4, 4, 4); REGISTER_BWD_LAUNCHER(3840, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4); REGISTER_BWD_LAUNCHER(4096, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(4096, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(4096, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(5120, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(5120, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(5120, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(5120, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(5120, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(6144, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4); REGISTER_BWD_LAUNCHER(6144, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4); REGISTER_BWD_LAUNCHER(6144, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4); REGISTER_BWD_LAUNCHER(6144, bf16, bf16,
bf16, fp32, 1, 1, 8, 16, 4); REGISTER_BWD_LAUNCHER(6144, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4); REGISTER_BWD_LAUNCHER(8192, fp32, fp32, fp32, fp32, 2, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(8192, fp16, fp16, fp16, fp32, 2, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(8192, fp16, fp32, fp16, fp32, 2, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(8192, bf16, bf16, bf16, fp32, 2, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(8192, bf16, fp32, bf16, fp32, 2, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(10240, fp32, fp32, fp32, fp32, 2, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(10240, fp16, fp16, fp16, fp32, 2, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(10240, fp16, fp32, fp16, fp32, 2, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(10240, bf16, bf16, bf16, fp32, 2, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(10240, bf16, fp32, bf16, fp32, 2, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(12288, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(12288, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(12288, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(12288, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(12288, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(12800, fp32, fp32, fp32, fp32, 5, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(12800, fp16, fp16, fp16, fp32, 5, 1, 4, 8, 4); REGISTER_BWD_LAUNCHER(12800, fp16, fp32, fp16, fp32, 5, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(12800, bf16, bf16, bf16, fp32, 5, 1, 4, 8, 4); REGISTER_BWD_LAUNCHER(12800, bf16, fp32, bf16, fp32, 5, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(14336, fp32, fp32, fp32, fp32, 4, 1, 4, 8, 4); REGISTER_BWD_LAUNCHER(14336, fp16, fp16, fp16, fp32, 4, 1, 4, 8, 4); REGISTER_BWD_LAUNCHER(14336, fp16, fp32, fp16, fp32, 4, 1, 4, 8, 4); REGISTER_BWD_LAUNCHER(14336, bf16, bf16, bf16, fp32, 4, 1, 4, 8, 4); REGISTER_BWD_LAUNCHER(14336, bf16, fp32, bf16, fp32, 4, 1, 4, 8, 4); REGISTER_BWD_LAUNCHER(15360, fp32, fp32, fp32, fp32, 4, 1, 4, 8, 4); REGISTER_BWD_LAUNCHER(15360, fp16, fp16, fp16, fp32, 4, 1, 4, 4, 4);
REGISTER_BWD_LAUNCHER(15360, fp16, fp32, fp16, fp32, 4, 1, 4, 8, 4); REGISTER_BWD_LAUNCHER(15360, bf16, bf16, bf16, fp32, 4, 1, 4, 4, 4); REGISTER_BWD_LAUNCHER(15360, bf16, fp32, bf16, fp32, 4, 1, 4, 8, 4); REGISTER_BWD_LAUNCHER(16384, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(16384, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(16384, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(16384, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(16384, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(18432, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(18432, fp16, fp16, fp16, fp32, 4, 1, 4, 8, 4); REGISTER_BWD_LAUNCHER(18432, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(18432, bf16, bf16, bf16, fp32, 4, 1, 4, 8, 4); REGISTER_BWD_LAUNCHER(18432, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(20480, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(20480, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(20480, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(20480, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(20480, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(24576, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4); REGISTER_BWD_LAUNCHER(24576, fp16, fp16, fp16, fp32, 4, 1, 8, 16, 4); REGISTER_BWD_LAUNCHER(24576, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4); REGISTER_BWD_LAUNCHER(24576, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4); REGISTER_BWD_LAUNCHER(24576, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4); REGISTER_BWD_LAUNCHER(25600, fp32, fp32, fp32, fp32, 5, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(25600, fp16, fp16, fp16, fp32, 5, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(25600, fp16, fp32, fp16, fp32, 5, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(25600, bf16, bf16, bf16, fp32, 5, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(25600, bf16, fp32, bf16, fp32, 5, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(30720, fp32, fp32,
fp32, fp32, 4, 1, 8, 8, 4); REGISTER_BWD_LAUNCHER(30720, fp16, fp16, fp16, fp32, 4, 1, 8, 4, 4); REGISTER_BWD_LAUNCHER(30720, fp16, fp32, fp16, fp32, 4, 1, 8, 8, 4); REGISTER_BWD_LAUNCHER(30720, bf16, bf16, bf16, fp32, 4, 1, 8, 4, 4); REGISTER_BWD_LAUNCHER(30720, bf16, fp32, bf16, fp32, 4, 1, 8, 8, 4); REGISTER_BWD_LAUNCHER(32768, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4); REGISTER_BWD_LAUNCHER(32768, fp16, fp16, fp16, fp32, 4, 1, 8, 16, 4); REGISTER_BWD_LAUNCHER(32768, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4); REGISTER_BWD_LAUNCHER(32768, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4); REGISTER_BWD_LAUNCHER(32768, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4); REGISTER_BWD_LAUNCHER(40960, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4); REGISTER_BWD_LAUNCHER(40960, fp16, fp16, fp16, fp32, 4, 1, 8, 16, 4); REGISTER_BWD_LAUNCHER(40960, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4); REGISTER_BWD_LAUNCHER(40960, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4); REGISTER_BWD_LAUNCHER(40960, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4); REGISTER_BWD_LAUNCHER(49152, fp32, fp32, fp32, fp32, 8, 1, 8, 16, 4); REGISTER_BWD_LAUNCHER(49152, fp16, fp16, fp16, fp32, 8, 1, 8, 16, 4); REGISTER_BWD_LAUNCHER(49152, fp16, fp32, fp16, fp32, 8, 1, 8, 16, 4); REGISTER_BWD_LAUNCHER(49152, bf16, bf16, bf16, fp32, 8, 1, 8, 16, 4); REGISTER_BWD_LAUNCHER(49152, bf16, fp32, bf16, fp32, 8, 1, 8, 16, 4); REGISTER_BWD_LAUNCHER(65536, fp32, fp32, fp32, fp32, 8, 1, 8, 16, 4); REGISTER_BWD_LAUNCHER(65536, fp16, fp16, fp16, fp32, 8, 1, 8, 16, 4); REGISTER_BWD_LAUNCHER(65536, fp16, fp32, fp16, fp32, 8, 1, 8, 16, 4); REGISTER_BWD_LAUNCHER(65536, bf16, bf16, bf16, fp32, 8, 1, 8, 16, 4); REGISTER_BWD_LAUNCHER(65536, bf16, fp32, bf16, fp32, 8, 1, 8, 16, 4); ================================================ FILE: apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu ================================================ #include "ln.h" #include "ln_fwd_kernels.cuh" #include "ln_kernel_traits.h" #include "ln_utils.cuh" using namespace layer_norm; /* Forward launcher. With configure_params = true it only fills launch_params: ctas_per_col = SM count * occupancy / CTAS_PER_ROW, plus barrier and workspace sizes when a row spans several CTAs (CTAS_PER_ROW > 1); nothing is launched. Otherwise it launches ln_fwd_kernel on launch_params.stream, cooperatively when CTAS_PER_ROW > 1. CUDA failures are routed through CHECK_CUDA (defined elsewhere, presumably ln_utils.cuh). */ template
void launch_(LaunchParams& launch_params, const bool configure_params) { using Kernel_traits = Kernel_traits; auto kernel = &ln_fwd_kernel; if (configure_params) { int ctas_per_sm = 0; /* A failed occupancy query would otherwise leave ctas_per_sm unset and size the grid from garbage. */ CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor( &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD)); launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW; launch_params.barrier_size = 0; launch_params.workspace_bytes = 0; if (Kernel_traits::CTAS_PER_ROW > 1) { launch_params.barrier_size = 2 * launch_params.params.ctas_per_col; launch_params.workspace_bytes = launch_params.params.ctas_per_col * Kernel_traits::WARPS_M * Kernel_traits::CTAS_PER_ROW * sizeof(typename Kernel_traits::Stats::stats_t) * 2; } return; } if (Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024) { CHECK_CUDA( cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES_FWD)); } auto stream = launch_params.stream; auto ctas_per_col = launch_params.params.ctas_per_col; if (Kernel_traits::CTAS_PER_ROW == 1) { kernel<<>>( launch_params.params); } else { dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col); dim3 block(Kernel_traits::THREADS_PER_CTA); void* params_ = (void*)&launch_params.params; /* cudaLaunchCooperativeKernel takes kernel arguments as an array of pointers; &params_ is that one-element array. Its status was previously ignored, so a failed launch left z/mu/rsigma unwritten without any error. */ CHECK_CUDA(cudaLaunchCooperativeKernel((void*)kernel, grid, block, (void**)&params_, Kernel_traits::SMEM_BYTES_FWD, stream)); } /* Surface launch-configuration errors (e.g. an empty grid) from the triple-chevron launch above. */ CHECK_CUDA(cudaGetLastError()); } // Create forward launch function and register.
Macro signature: // HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG REGISTER_FWD_LAUNCHER(768, fp32, fp32, fp32, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER(768, fp16, fp16, fp16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER(768, fp16, fp32, fp16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER(768, bf16, bf16, bf16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER(768, bf16, fp32, bf16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER(1024, fp32, fp32, fp32, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER(1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER(1024, fp16, fp32, fp16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER(1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER(1024, bf16, fp32, bf16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER(1536, fp32, fp32, fp32, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER(1536, fp16, fp16, fp16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER(1536, fp16, fp32, fp16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER(1536, bf16, bf16, bf16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER(1536, bf16, fp32, bf16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER(2048, fp32, fp32, fp32, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER(2048, fp16, fp16, fp16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER(2048, fp16, fp32, fp16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER(2048, bf16, bf16, bf16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER(2048, bf16, fp32, bf16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER(2304, fp32, fp32, fp32, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER(2304, fp16, fp16, fp16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER(2304, fp16, fp32, fp16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER(2304, bf16, bf16, bf16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER(2304, bf16, fp32, bf16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER(3072, fp32, fp32, fp32, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER(3072, fp16, fp16, fp16, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER(3072, fp16, fp32, fp16, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER(3072, bf16, bf16, bf16, 
fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER(3072, bf16, fp32, bf16, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER(3840, fp32, fp32, fp32, fp32, 1, 1, 4, 4); REGISTER_FWD_LAUNCHER(3840, fp16, fp16, fp16, fp32, 1, 1, 4, 4); REGISTER_FWD_LAUNCHER(3840, fp16, fp32, fp16, fp32, 1, 1, 4, 4); REGISTER_FWD_LAUNCHER(3840, bf16, bf16, bf16, fp32, 1, 1, 4, 4); REGISTER_FWD_LAUNCHER(3840, bf16, fp32, bf16, fp32, 1, 1, 4, 4); REGISTER_FWD_LAUNCHER(4096, fp32, fp32, fp32, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER(4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER(4096, fp16, fp32, fp16, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER(4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER(4096, bf16, fp32, bf16, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER(5120, fp32, fp32, fp32, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER(5120, fp16, fp16, fp16, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER(5120, fp16, fp32, fp16, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER(5120, bf16, bf16, bf16, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER(5120, bf16, fp32, bf16, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER(6144, fp32, fp32, fp32, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER(6144, fp16, fp16, fp16, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER(6144, fp16, fp32, fp16, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER(6144, bf16, bf16, bf16, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER(6144, bf16, fp32, bf16, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER(8192, fp32, fp32, fp32, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER(8192, fp16, fp16, fp16, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER(8192, fp16, fp32, fp16, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER(8192, bf16, bf16, bf16, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER(8192, bf16, fp32, bf16, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER(10240, fp32, fp32, fp32, fp32, 2, 1, 4, 16); REGISTER_FWD_LAUNCHER(10240, fp16, fp16, fp16, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER(10240, fp16, fp32, fp16, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER(10240, bf16, bf16, bf16, fp32, 1, 1, 4, 16); 
// Forward registrations (continued). Each line instantiates launch_ for one hidden size and type
// combination and registers it. Arguments follow the macro signature
//   HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG.
// Every entry must satisfy Kernel_traits' static_asserts. In particular:
//   - the vectorized column count HIDDEN_SIZE / (BYTES_PER_LDG / sizeof(ITYPE)) must be a multiple of
//     CTAS_PER_ROW * WARPS_N * 32;
//   - WARPS_M == 1 || CTAS_PER_ROW == 1.
REGISTER_FWD_LAUNCHER(10240, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(12288, fp32, fp32, fp32, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(12288, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(12288, fp16, fp32, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(12288, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(12288, bf16, fp32, bf16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(12800, fp32, fp32, fp32, fp32, 2, 1, 4, 4);
REGISTER_FWD_LAUNCHER(12800, fp16, fp16, fp16, fp32, 2, 1, 4, 4);
REGISTER_FWD_LAUNCHER(12800, fp16, fp32, fp16, fp32, 2, 1, 4, 4);
REGISTER_FWD_LAUNCHER(12800, bf16, bf16, bf16, fp32, 2, 1, 4, 4);
REGISTER_FWD_LAUNCHER(12800, bf16, fp32, bf16, fp32, 2, 1, 4, 4);
REGISTER_FWD_LAUNCHER(14336, fp32, fp32, fp32, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(14336, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(14336, fp16, fp32, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(14336, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(14336, bf16, fp32, bf16, fp32, 2, 1, 4, 8);
REGISTER_FWD_LAUNCHER(15360, fp32, fp32, fp32, fp32, 2, 1, 4, 8);
REGISTER_FWD_LAUNCHER(15360, fp16, fp16, fp16, fp32, 2, 1, 4, 8);
REGISTER_FWD_LAUNCHER(15360, fp16, fp32, fp16, fp32, 2, 1, 4, 8);
REGISTER_FWD_LAUNCHER(15360, bf16, bf16, bf16, fp32, 2, 1, 4, 8);
REGISTER_FWD_LAUNCHER(15360, bf16, fp32, bf16, fp32, 2, 1, 4, 8);
REGISTER_FWD_LAUNCHER(16384, fp32, fp32, fp32, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(16384, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(16384, fp16, fp32, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(16384, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(16384, bf16, fp32, bf16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(18432, fp32, fp32, fp32, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(18432, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(18432, fp16, fp32, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(18432, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(18432, bf16, fp32, bf16, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(20480, fp32, fp32, fp32, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(20480, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(20480, fp16, fp32, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(20480, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(20480, bf16, fp32, bf16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(24576, fp32, fp32, fp32, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(24576, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(24576, fp16, fp32, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(24576, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(24576, bf16, fp32, bf16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(25600, fp32, fp32, fp32, fp32, 4, 1, 4, 4);
REGISTER_FWD_LAUNCHER(25600, fp16, fp16, fp16, fp32, 2, 1, 4, 8);
REGISTER_FWD_LAUNCHER(25600, fp16, fp32, fp16, fp32, 4, 1, 4, 4);
REGISTER_FWD_LAUNCHER(25600, bf16, bf16, bf16, fp32, 2, 1, 4, 8);
REGISTER_FWD_LAUNCHER(25600, bf16, fp32, bf16, fp32, 4, 1, 4, 4);
REGISTER_FWD_LAUNCHER(30720, fp32, fp32, fp32, fp32, 4, 1, 4, 4);
REGISTER_FWD_LAUNCHER(30720, fp16, fp16, fp16, fp32, 4, 1, 4, 4);
REGISTER_FWD_LAUNCHER(30720, fp16, fp32, fp16, fp32, 4, 1, 4, 4);
REGISTER_FWD_LAUNCHER(30720, bf16, bf16, bf16, fp32, 4, 1, 4, 4);
REGISTER_FWD_LAUNCHER(30720, bf16, fp32, bf16, fp32, 4, 1, 4, 4);
REGISTER_FWD_LAUNCHER(32768, fp32, fp32, fp32, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(32768, fp16, fp16, fp16, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(32768, fp16, fp32, fp16, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(32768, bf16, bf16, bf16, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(32768, bf16, fp32, bf16, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(40960, fp32, fp32, fp32, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(40960, fp16, fp16, fp16, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(40960, fp16, fp32, fp16, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(40960, bf16, bf16, bf16, fp32, 4, 1, 4, 16);
// Forward registrations (continued): the largest hidden sizes.
REGISTER_FWD_LAUNCHER(40960, bf16, fp32, bf16, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(49152, fp32, fp32, fp32, fp32, 8, 1, 4, 16);
REGISTER_FWD_LAUNCHER(49152, fp16, fp16, fp16, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(49152, fp16, fp32, fp16, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(49152, bf16, bf16, bf16, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(49152, bf16, fp32, bf16, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(65536, fp32, fp32, fp32, fp32, 8, 1, 4, 16);
REGISTER_FWD_LAUNCHER(65536, fp16, fp16, fp16, fp32, 8, 1, 4, 16);
REGISTER_FWD_LAUNCHER(65536, fp16, fp32, fp16, fp32, 8, 1, 4, 16);
REGISTER_FWD_LAUNCHER(65536, bf16, bf16, bf16, fp32, 8, 1, 4, 16);
REGISTER_FWD_LAUNCHER(65536, bf16, fp32, bf16, fp32, 8, 1, 4, 16);
================================================ FILE: apex/contrib/csrc/layer_norm/ln_fwd_kernels.cuh ================================================
#pragma once

#include "ln.h"
#include "ln_utils.cuh"

namespace layer_norm {

// LayerNorm forward kernel. For every row of params.x (Ktraits::COLS elements) it:
//   - computes the row statistics mu and m2 through Stats,
//   - computes rs = rsqrt(rn * m2 + epsilon) with rn = 1 / COLS. m2 is presumably the sum of squared
//     deviations, so rn * m2 is the biased variance (TODO confirm against Stats),
//   - writes z = gamma * (x - mu) * rs + beta to params.z,
//   - stores mu and rs for the row in params.mu / params.rs.
//
// Work decomposition:
//   - blockIdx.x = bidm * CTAS_PER_ROW + bidn. Each row is split across CTAS_PER_ROW CTAs (bidn), and
//     bidm selects a group of ROWS_PER_CTA rows.
//   - Inside a CTA, warp_m selects the row and (warp_n, lane) select this thread's first vector column c.
//   - A thread touches LDGS vectors per row, spaced VEC_COLS_PER_LDG vectors apart.
//   - Rows are revisited with stride params.ctas_per_col * ROWS_PER_CTA until params.rows is reached.
template
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_kernel(FwdParams params) {
  enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA };
  enum { WARPS_N = Ktraits::WARPS_N };
  enum { WARPS_M = Ktraits::WARPS_M };
  enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW };
  enum { VEC_COLS_PER_LDG = Ktraits::VEC_COLS_PER_LDG };
  enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW };
  enum { LDGS = Ktraits::LDGS };
  enum { NUM_ELTS = Ktraits::NUM_ELTS };
  enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW };
  enum { THREADS_PER_WARP = Ktraits::THREADS_PER_WARP };

  using output_t = typename Ktraits::output_t;
  using index_t = typename Ktraits::index_t;
  using compute_t = typename Ktraits::compute_t;
  using Ivec = typename Ktraits::Ivec;
  using Ovec = typename Ktraits::Ovec;
  using Wvec = typename Ktraits::Wvec;
  using Cvec = typename Ktraits::Cvec;

  using Stats = typename Ktraits::Stats;
  using stats_t = typename Stats::stats_t;

  // Dynamic shared memory (Ktraits::SMEM_BYTES_FWD bytes at launch) used by Stats for reductions.
  extern __shared__ char smem_[];

  const index_t tidx = threadIdx.x;
  const index_t bidn = blockIdx.x % CTAS_PER_ROW;
  const index_t bidm = blockIdx.x / CTAS_PER_ROW;
  const index_t lane = tidx % THREADS_PER_WARP;
  const index_t warp = tidx / THREADS_PER_WARP;
  const index_t warp_m = warp / WARPS_N;
  const index_t warp_n = warp % WARPS_N;

  // First row handled by this warp, and this thread's first vector column within a row.
  const index_t r = bidm * ROWS_PER_CTA + warp_m;
  const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane;

  // Row-statistics helper. When CTAS_PER_ROW > 1 it also merges partial results across CTAs.
  Stats stats(params, bidm, bidn, warp_m, warp_n, lane, smem_);

  compute_t* mu_ptr = static_cast(params.mu);
  compute_t* rs_ptr = static_cast(params.rs);

  // gamma/beta depend only on the column, so they are loaded once and reused for every row.
  Wvec gamma[LDGS];
  Wvec beta[LDGS];
  index_t idx = c;
#pragma unroll
  for (int it = 0; it < LDGS; it++) {
    gamma[it].load_from(params.gamma, idx);
    beta[it].load_from(params.beta, idx);
    idx += VEC_COLS_PER_LDG;
  }

  // 1 / (number of elements per row).
  constexpr compute_t rn = 1.f / compute_t(Ktraits::COLS);

  for (int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA) {
    Ivec x[LDGS];
    // NOTE(review): this `idx` shadows the outer one used for gamma/beta. Harmless, but -Wshadow will flag it.
    index_t idx = row * Ktraits::VEC_COLS + c;
    // This thread's input elements, converted to compute precision.
    compute_t xf[LDGS * NUM_ELTS];
#pragma unroll
    for (int it = 0; it < LDGS; it++) {
      x[it].load_from(params.x, idx);
#pragma unroll
      for (int jt = 0; jt < NUM_ELTS; jt++) {
        compute_t x_ij = compute_t(x[it].data.elt[jt]);
        xf[it * NUM_ELTS + jt] = x_ij;
      }
      idx += VEC_COLS_PER_LDG;
    }

    stats_t s = stats.compute(xf, rn);

    compute_t mu = layer_norm::Get<0>::of(s);
    compute_t m2 = layer_norm::Get<1>::of(s);

    // Exactly one thread per row (first CTA, first warp column, lane 0) writes the saved statistics.
    if (bidn == 0 && warp_n == 0 && lane == 0) {
      mu_ptr[row] = mu;
    }

    compute_t rs = rsqrtf(rn * m2 + params.epsilon);

    if (bidn == 0 && warp_n == 0 && lane == 0) {
      rs_ptr[row] = rs;
    }

    Ovec z[LDGS];
    idx = row * Ktraits::VEC_COLS + c;
#pragma unroll
    for (int it = 0; it < LDGS; it++) {
#pragma unroll
      for (int jt = 0; jt < NUM_ELTS; jt++) {
        // NOTE(review): the normalized value, gamma and beta are all converted to output_t before the
        // affine transform, so for fp16/bf16 outputs it is evaluated in reduced precision rather than in
        // compute_t. Confirm this is intended.
        output_t y_ij = output_t(rs * (xf[it * NUM_ELTS + jt] - mu));
        output_t g_ij = gamma[it].data.elt[jt];
        output_t b_ij = beta[it].data.elt[jt];
        z[it].data.elt[jt] = (g_ij * y_ij + b_ij);
      }
      z[it].store_to(params.z, idx);
      idx += VEC_COLS_PER_LDG;
    }
  }
}
} // namespace layer_norm ================================================ FILE:
apex/contrib/csrc/layer_norm/ln_kernel_traits.h ================================================ #pragma once //////////////////////////////////////////////////////////////////////////////////////////////////// namespace layer_norm { template struct Kernel_traits_base { using weight_t = weight_t_; using input_t = input_t_; using output_t = output_t_; using compute_t = compute_t_; using index_t = index_t_; enum { HIDDEN_SIZE = HIDDEN_SIZE_ }; enum { THREADS_PER_CTA = THREADS_PER_CTA_ }; enum { THREADS_PER_WARP = 32 }; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template > struct Kernel_traits_finalize : public Base { enum { ROWS_PER_CTA = Base::THREADS_PER_CTA / Base::THREADS_PER_WARP }; static_assert((int)ROWS_PER_CTA <= (int)Base::THREADS_PER_WARP); // Bytes per global load from the input. enum { BYTES_PER_LDG = BYTES_PER_LDG_ }; // Number of elements fetched by a global load. enum { ELTS_PER_LDG = BYTES_PER_LDG / sizeof(compute_t_) }; // Bytes per global store of the weights. enum { BYTES_PER_STG = ELTS_PER_LDG * sizeof(weight_t_) }; static_assert(sizeof(BYTES_PER_LDG) == 4, "Conflict-free smem transpose only implemented for 4B compute type!"); static_assert(Base::THREADS_PER_CTA == ROWS_PER_CTA * Base::THREADS_PER_WARP, "We assume one warp per row!"); // The total number of BYTES_PER_LDG-wide words in a hidden vector. enum { COLS = HIDDEN_SIZE_ * sizeof(compute_t_) / BYTES_PER_LDG }; static_assert(COLS * BYTES_PER_LDG == HIDDEN_SIZE_ * sizeof(compute_t_)); // Shared memory size to transpose the CTA result. enum { SMEM_BYTES_TRANSPOSE = Base::THREADS_PER_CTA * BYTES_PER_LDG }; // Shared memory size to coalsece the CTA result. enum { SMEM_BYTES_OUTPUT = Base::THREADS_PER_WARP * BYTES_PER_LDG }; // Shared memory requirement per CTA. enum { SMEM_BYTES_PER_CTA = 2 * SMEM_BYTES_TRANSPOSE + 2 * SMEM_BYTES_OUTPUT }; // The type of the reducer. 
using Reducer = layer_norm::Reducer; // Condition for the whole CTA to participate in syncthreads. static_assert(COLS % Base::THREADS_PER_WARP == 0); enum { CTAS = COLS / Base::THREADS_PER_WARP }; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template > struct Kernel_traits : public Base { using input_t = typename Base::input_t; using weight_t = typename Base::weight_t; using compute_t = typename Base::compute_t; using output_t = typename Base::output_t; using index_t = typename Base::index_t; enum { CTAS_PER_ROW = CTAS_PER_ROW_ }; enum { WARPS_M = WARPS_M_ }; enum { WARPS_N = WARPS_N_ }; enum { COLS = HIDDEN_SIZE_ }; enum { HIDDEN_SIZE = HIDDEN_SIZE_ }; enum { BYTES_PER_LDG = BYTES_PER_LDG_ }; enum { NUM_ELTS = BYTES_PER_LDG / sizeof(input_t) }; enum { THREADS_PER_ROW = WARPS_N * THREADS_PER_WARP }; enum { THREADS_PER_CTA = WARPS_M * THREADS_PER_ROW }; enum { ROWS_PER_CTA = WARPS_M }; enum { BYTES_PER_ROW = COLS * sizeof(input_t) }; enum { BYTES_PER_ROW_PER_CTA = THREADS_PER_ROW * BYTES_PER_LDG }; // Multi-row per CTA not supported for multi-CTA => no smem for WGRAD needed enum { SMEM_BYTES_WGRAD = CTAS_PER_ROW > 1 ? 0 : ROWS_PER_CTA * COLS * sizeof(compute_t) }; static_assert(WARPS_M == 1 || CTAS_PER_ROW == 1); using reduce_t = typename layer_norm::TypeToVec2::Type; using Reducer = layer_norm::Reducer; enum { SMEM_BYTES_DGRAD = Reducer::SMEM_BYTES }; enum { SMEM_BYTES = SMEM_BYTES_DGRAD + SMEM_BYTES_WGRAD }; using Ivec = layer_norm::Vec; using Ovec = layer_norm::Vec; using Wvec = layer_norm::Vec; using Cvec = layer_norm::Vec; enum { ELTS_PER_LDG = BYTES_PER_LDG / sizeof(input_t) }; // Assume that each thread can handle the same number of elements in the output and weights as in the input. static_assert(sizeof(input_t) >= sizeof(output_t)); static_assert(sizeof(input_t) >= sizeof(weight_t)); // The number of columns fetched per load from input: one per thread. 
enum { VEC_COLS_PER_LDG = CTAS_PER_ROW * THREADS_PER_ROW }; // The total number of vectorized loads/stores per hidden vector. enum { VEC_COLS = COLS / ELTS_PER_LDG }; // The number of loads per thread for the input. enum { LDGS = VEC_COLS / VEC_COLS_PER_LDG }; static_assert(LDGS * VEC_COLS_PER_LDG == VEC_COLS); // static_assert(LDGS * BYTES_PER_ROW_PER_CTA * CTAS_PER_ROW == BYTES_PER_ROW, ""); using Stats = layer_norm::Stats; enum { SMEM_BYTES_FWD = Stats::SMEM_BYTES }; }; //////////////////////////////////////////////////////////////////////////////////////////////////// } // namespace layer_norm ================================================ FILE: apex/contrib/csrc/layer_norm/ln_utils.cuh ================================================ #pragma once #include #include #include #include "ln.h" //////////////////////////////////////////////////////////////////////////////////////////////////// constexpr uint32_t THREADS_PER_WARP = 32; //////////////////////////////////////////////////////////////////////////////////////////////////// inline void check_cuda_(cudaError_t status, const char* file, int line) { if (status != cudaSuccess) { fprintf(stderr, "CUDA Error: %s %s %d\n", cudaGetErrorString(status), file, line); exit(status); } } //////////////////////////////////////////////////////////////////////////////////////////////////// #define CHECK_CUDA(ans) \ { \ check_cuda_((ans), __FILE__, __LINE__); \ } //////////////////////////////////////////////////////////////////////////////////////////////////// #define DIVUP(x, y) (((x) + ((y) - 1)) / (y)) //////////////////////////////////////////////////////////////////////////////////////////////////// #define REGISTER_FWD_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG) \ void ln_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE(LaunchParams& launch_params, \ const bool configure_params) { \ launch_( \ launch_params, configure_params); \ } \ static FwdRegistrar 
\ reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ ln_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE) //////////////////////////////////////////////////////////////////////////////////////////////////// #define REGISTER_BWD_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, \ BYTES_PER_LDG_FINALIZE) \ void ln_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE(LaunchParams& launch_params, \ const bool configure_params) { \ launch_(launch_params, configure_params); \ } \ static BwdRegistrar \ reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ ln_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE) //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ float2 operator+(const float2& a, const float2& b) { return {a.x + b.x, a.y + b.y}; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void operator+=(float2& a, const float2& b) { a.x += b.x; a.y += b.y; } //////////////////////////////////////////////////////////////////////////////////////////////////// template struct Sum { inline __device__ Sum() {} inline __device__ T operator()(const T& a, const T& b) { return a + b; } }; //////////////////////////////////////////////////////////////////////////////////////////////////// template inline __device__ T warp_shuffle_xor(const T& x, uint32_t idx) { return __shfl_xor_sync(uint32_t(-1), x, idx); } template <> inline __device__ float2 warp_shuffle_xor(const float2& x, uint32_t idx) { return {warp_shuffle_xor(x.x, idx), warp_shuffle_xor(x.y, idx)}; } template inline __device__ T warp_shuffle_down(const T& x, uint32_t idx) { return __shfl_down_sync(uint32_t(-1), x, idx); } template <> inline __device__ float2 warp_shuffle_down(const float2& x, uint32_t idx) { return {warp_shuffle_down(x.x, idx), warp_shuffle_down(x.y, idx)}; } 
//////////////////////////////////////////////////////////////////////////////////////////////////// namespace layer_norm { //////////////////////////////////////////////////////////////////////////////////////////////////// struct uint16 { uint4 u; uint4 v; uint4 s; uint4 t; }; //////////////////////////////////////////////////////////////////////////////////////////////////// struct uint8 { uint4 u; uint4 v; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct BytesToType {}; template <> struct BytesToType<64> { using Type = uint16; static_assert(sizeof(Type) == 64); }; template <> struct BytesToType<32> { using Type = uint8; static_assert(sizeof(Type) == 32); }; template <> struct BytesToType<16> { using Type = uint4; static_assert(sizeof(Type) == 16); }; template <> struct BytesToType<8> { using Type = uint64_t; static_assert(sizeof(Type) == 8); }; template <> struct BytesToType<4> { using Type = uint32_t; static_assert(sizeof(Type) == 4); }; template <> struct BytesToType<2> { using Type = uint16_t; static_assert(sizeof(Type) == 2); }; template <> struct BytesToType<1> { using Type = uint8_t; static_assert(sizeof(Type) == 1); }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct TypeToVec2 {}; template <> struct TypeToVec2 { using Type = float2; }; template <> struct TypeToVec2 { using Type = half2; }; template <> struct TypeToVec2 { using Type = nv_bfloat162; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct Get { template static inline __device__ R of(const T& vec); }; template <> template inline __device__ R Get<0>::of(const T& vec) { return vec.x; } template <> template inline __device__ R Get<1>::of(const T& vec) { return vec.y; } template <> template inline __device__ R Get<2>::of(const T& vec) { return vec.z; } template <> template inline __device__ R 
Get<3>::of(const T& vec) { return vec.w; } //////////////////////////////////////////////////////////////////////////////////////////////////// template struct Converter { static inline __device__ Dst convert(const Src& from) { return Dst(from); } }; template <> struct Converter { static inline __device__ half2 convert(const float2& x) { return __float22half2_rn(x); } }; template <> struct Converter { static inline __device__ nv_bfloat162 convert(const float2& x) { #if __CUDA_ARCH__ >= 800 return __float22bfloat162_rn(x); #else union { nv_bfloat162 raw; nv_bfloat16 x; nv_bfloat16 y; } tmp; tmp.x = __float2bfloat16_rn(x.x); tmp.y = __float2bfloat16_rn(x.y); return tmp.raw; #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct Zeros { static inline __device__ T get() { return T(0.f); } }; template <> struct Zeros { static inline __device__ float2 get() { return make_float2(0.f, 0.f); } }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct Vec { enum { BYTES = NUM_ELT * sizeof(Elt_type) }; using Vec_type = typename BytesToType::Type; using Alias_type = union { Vec_type vec; Elt_type elt[NUM_ELT]; }; Alias_type data; template inline __device__ void to(Vec& other) { #pragma unroll for (int it = 0; it < NUM_ELT; it++) { other.data.elt[it] = S(this->data.elt[it]); } } template inline __device__ void assign(const Op& op) { #pragma unroll for (int it = 0; it < NUM_ELT; it++) { this->data.elt[it] = op(it); } } inline __device__ void load_from(const void* base_ptr, const size_t idx) { this->data.vec = static_cast(base_ptr)[idx]; } inline __device__ void store_to(void* base_ptr, const size_t idx) { static_cast(base_ptr)[idx] = this->data.vec; } }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct InterCTASync { template inline __device__ InterCTASync(Params& params, 
uint32_t bidm, uint32_t bidn) : phase_counter_(0), b0_(params.barrier + bidm) // The barrier for this group of CTAs. , b1_(params.barrier + bidm + params.ctas_per_col) // The barrier for this group of CTAs. { // BARRIERS ARE ASSUMED TO BE INITIALIZED TO 0! } inline __device__ void spin_wait_(int* barrier, int step, int expected) { asm volatile("red.release.gpu.global.add.s32 [%0], %1;" ::"l"(barrier), "r"(step)); for (int found = -1; found != expected;) { asm volatile("ld.global.acquire.gpu.b32 %0, [%1];" : "=r"(found) : "l"(barrier)); } } inline __device__ void sync() { // ALL THREADS MUST ENTER! // We switch barrier every iteration. int* barrier = phase_counter_ & 0x1 ? b1_ : b0_; // We decrement every other iteration. bool dec = phase_counter_ & 0x2; int step = dec ? -1 : 1; int expected = dec ? 0 : CTAS_PER_ROW; // There are only 4 phases: up/down for b0/b1. phase_counter_ = (phase_counter_ + 1) & 0x3; if (threadIdx.x == 0) { spin_wait_(barrier, step, expected); } // CTA waits for thread 0 __syncthreads(); } int phase_counter_; int* b0_; int* b1_; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct Reducer : public Reducer { using InterCTASync = InterCTASync; using Base = Reducer; using Type = typename Base::Type; enum { SMEM_BYTES = Base::SMEM_BYTES }; enum { WS_BARRIER_BYTES = 2 * sizeof(int) }; enum { WS_DATA_BYTES = WARPS_M * CTAS_PER_ROW * sizeof(T) }; // size of the barriers + temporary result per CTA (multiply with CTAS_PER_ROW to get total) enum { WORKSPACE_BYTES_PER_GROUP = Base::WORKSPACE_BYTES_PER_GROUP + WS_BARRIER_BYTES + WS_DATA_BYTES }; template inline __device__ Reducer(Params& params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void* smem) : Base(params, bidm, bidn, warp_m, warp_n, lane, smem), inter_cta_(params, bidm, bidn), bidn_(bidn) // CTA id within the group. 
, w0_(static_cast(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW), w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW) {} template inline __device__ T allreduce(T data, Op& op) { data = Base::reduce(data, op); // We switch workspace every iteration. T* workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_; // Warp leaders 0 hold the CTA-local results. if (this->warp_n_ == 0 && this->lane_ == 0) { workspace[bidn_] = data; } inter_cta_.sync(); static_assert(CTAS_PER_ROW <= 32); T total = Zeros::get(); if (this->lane_ < CTAS_PER_ROW) { total = workspace[this->lane_]; } total = Reducer::allreduce_(total, op); return total; } InterCTASync inter_cta_; T* w0_; T* w1_; int bidn_; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct Reducer { using Type = T; enum { SMEM_BYTES = 0 }; enum { WORKSPACE_BYTES_PER_GROUP = 0 }; enum { THREADS_PER_WARP = 32 }; template inline __device__ Reducer(Params& params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void* smem) : warp_n_(warp_n), lane_(lane) {} template static inline __device__ T allreduce_(T data, Op& op) { #pragma unroll for (int it = 1; it < THREADS_PER_WARP; it *= 2) { data = op(data, warp_shuffle_xor(data, it)); } return data; } template inline __device__ T allreduce(T data, Op& op) { return allreduce_(data, op); } template inline __device__ T reduce(T data, Op& op) { // only lane 0 holds the result! 
#pragma unroll for (int it = THREADS_PER_WARP / 2; it > 0; it /= 2) { data = op(data, warp_shuffle_down(data, it)); } return data; } int warp_n_; int lane_; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct Reducer : public Reducer { using Base = Reducer; using Type = T; enum { SMEM_BYTES = Base::SMEM_BYTES + WARPS_M * WARPS_N * sizeof(T) * 2 }; enum { WORKSPACE_BYTES_PER_GROUP = 0 }; enum { THREADS_PER_WARP = 32 }; template inline __device__ Reducer(Params& params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void* smem) : Base(params, bidm, bidn, warp_m, warp_n, lane, smem), use0_(true) { smem0_ = &static_cast(smem)[warp_m * WARPS_N]; smem1_ = smem0_ + WARPS_M * WARPS_N; } template inline __device__ T allreduce(T data, Op& op) { T* smem = use0_ ? smem0_ : smem1_; use0_ = !use0_; data = Base::reduce(data, op); if (this->lane_ == 0) { smem[this->warp_n_] = data; } __syncthreads(); T out = Zeros::get(); #pragma unroll for (int it = 0; it < WARPS_N; it++) { out = op(out, smem[it]); } return out; } template inline __device__ T reduce(T data, Op& op) { T* smem = use0_ ? smem0_ : smem1_; use0_ = !use0_; // only intra-CTA group leader holds the result! 
data = Base::reduce(data, op); if (this->lane_ == 0) { smem[this->warp_n_] = data; } __syncthreads(); T out = Zeros::get(); if (this->warp_n_ == 0 && this->lane_ == 0) { #pragma unroll for (int it = 0; it < WARPS_N; it++) { out = op(out, smem[it]); } } return out; } T* smem0_; T* smem1_; bool use0_; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template inline __device__ void warp_chan_upd_dynamic(T& m_a, T& m2_a, T& n_a, int num_active) { // Assume at least leftmost is valid and init: step = next_pow2(num_active) / 2 (might get NaN otherwise) int highest_bit_set = (8 * sizeof(num_active)) - __clz(num_active - 1); #pragma unroll for (int step = (1 << (highest_bit_set - 1)); step > 0; step /= 2) { // Exchange T n_b = warp_shuffle_down(n_a, step); T m_b = warp_shuffle_down(m_a, step); T m2_b = warp_shuffle_down(m2_a, step); // Update const T n_ab = n_a + n_b; // We can handle one of them being 0, not both. const T rn_ab = 1.f / n_ab; // Might have different n per thread, otherwise this would simplify :( const T delta = m_a - m_b; const float m2_ab = m2_a + m2_b + delta * delta * n_a * n_b * rn_ab; const float m_ab = (n_a * m_a + n_b * m_b) * rn_ab; n_a = n_ab; m_a = m_ab; m2_a = m2_ab; } // Intra-warp broadcast (only lane 0 has valid stats). m_a = __shfl_sync(uint32_t(-1), m_a, 0); m2_a = __shfl_sync(uint32_t(-1), m2_a, 0); } //////////////////////////////////////////////////////////////////////////////////////////////////// template struct Stats { // This could be done generically with the Reducer. But then we would have to exchange 3 instead of 2 fields. 
using InterCTASync = InterCTASync; using BlockStats = Stats; using stats_t = typename BlockStats::stats_t; enum { SMEM_BYTES = BlockStats::SMEM_BYTES }; template inline __device__ Stats(Params& params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void* smem) : inter_cta_(params, bidm, bidn), block_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem), bidn_(bidn) // CTA id within the group. , w0_(static_cast(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW), w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW), warp_n_(warp_n), lane_(lane) {} template inline __device__ stats_t compute(const T (&elts)[N], const T rn) { constexpr T ELTS_PER_ROW_PER_CTA = N * WARPS_N * THREADS_PER_WARP; // TODO rn is not really needed here.. constexpr T block_rn = 1.f / T(ELTS_PER_ROW_PER_CTA); stats_t block_stats = block_stats_.compute(elts, block_rn); stats_t* workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_; if (warp_n_ == 0 && lane_ == 0) { workspace[bidn_] = block_stats; } // Wait for all CTAS_PER_ROW CTAS in the group to have written their result. inter_cta_.sync(); T n = Zeros::get(); T m = Zeros::get(); T m2 = Zeros::get(); // Assume CTA group size in N less than 32, such that we can finalize with a single warp. static_assert(CTAS_PER_ROW <= 32); // Every warp does the final reduction locally. 
if (lane_ < CTAS_PER_ROW) { stats_t result = workspace[lane_]; n = ELTS_PER_ROW_PER_CTA; m = layer_norm::Get<0>::of(result); m2 = layer_norm::Get<1>::of(result); } warp_chan_upd_dynamic(m, m2, n, CTAS_PER_ROW); return {m, m2}; } InterCTASync inter_cta_; BlockStats block_stats_; stats_t* w0_; stats_t* w1_; int bidn_; int warp_n_; int lane_; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct Stats { using WarpStats = Stats; using stats_t = typename WarpStats::stats_t; enum { SMEM_BYTES = WARPS_M * WARPS_N * sizeof(stats_t) * 2 }; template inline __device__ Stats(Params& params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void* smem) : warp_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem), use0_(true) { smem0_ = static_cast(smem) + warp_m * WARPS_N; smem1_ = smem0_ + WARPS_M * WARPS_N; } template inline __device__ stats_t compute(const T (&elts)[N], const T rn) { stats_t* smem = use0_ ? 
smem0_ : smem1_; use0_ = !use0_; // Compute warp local for all WARPS_N constexpr T warp_rn = 1.f / T(N * THREADS_PER_WARP); stats_t warp_stats = warp_stats_.compute(elts, warp_rn); // Each warp warp leader stores its stats const auto warp_n = warp_stats_.reducer_.warp_n_; const auto lane = warp_stats_.reducer_.lane_; if (lane == 0) { smem[warp_n] = warp_stats; } __syncthreads(); T n = Zeros::get(); T m = Zeros::get(); T m2 = Zeros::get(); // Assume that there are less than 32 warps, such that we can finalize with a single warp static_assert(WARPS_N <= 32); if (lane < WARPS_N) { stats_t result = smem[lane]; n = N * THREADS_PER_WARP; m = layer_norm::Get<0>::of(result); m2 = layer_norm::Get<1>::of(result); } warp_chan_upd_dynamic(m, m2, n, WARPS_N); return {m, m2}; } WarpStats warp_stats_; stats_t* smem0_; stats_t* smem1_; bool use0_; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct Stats { using stats_t = typename TypeToVec2::Type; // The simple Warp reducer. 
using Reducer = Reducer; enum { SMEM_BYTES = 0 }; template inline __device__ Stats(Params& params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void* smem) : reducer_(params, bidm, bidn, warp_m, warp_n, lane, smem) {} template inline __device__ stats_t compute(const T (&elts)[N], const T rn) { auto sum = Sum(); T m = Zeros::get(); #pragma unroll for (int it = 0; it < N; it++) { m += elts[it]; } m = reducer_.allreduce(m, sum) * rn; T m2 = Zeros::get(); #pragma unroll for (int it = 0; it < N; it++) { T diff = (elts[it] - m); m2 += diff * diff; } m2 = reducer_.allreduce(m2, sum); return {m, m2}; } Reducer reducer_; };
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace layer_norm
================================================ FILE: apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cuda.cu ================================================ #include #include #include #include #include #include #include #include #include #include #include "dropout.cuh" #include "softmax.cuh"
// symbol to be automatically resolved by PyTorch libs
namespace multihead_attn {
namespace fused_softmax {
namespace additive_mask_softmax_dropout {

// Forward of the fused additive-mask softmax + dropout.
//
// `input` is processed as attn_batches * q_seq_len rows of k_seq_len (== q_seq_len) elements, and
// the softmax is taken over each row.
// - When `pad_mask` is non-null, the masked dispatcher (dispatch_additive_masked_softmax) is used.
// - When `is_training`, at::_fused_dropout is applied to the softmax output with keep probability
//   1 - dropout_prob.
//
// Returns {dropout_results, dropout_mask, softmax_results}.
// Throws (TORCH_CHECK) if the softmax dispatcher reports that it could not handle the shape.
// NOTE(review): when !is_training, dropout_results and dropout_mask are returned as the
// uninitialized torch::empty tensors allocated below. Confirm callers never read them in that case.
std::vector fwd_cuda(bool is_training, int heads, torch::Tensor const& input, const half* pad_mask, float dropout_prob) {
  const int attn_batches = input.size(0);
  const int sequences = attn_batches / heads;
  const int q_seq_len = input.size(1);
  const int k_seq_len = q_seq_len;
  // const int dropout_elems = attn_batches * q_seq_len * k_seq_len;

  // There is no reason to use more than one stream as every kernel is
  // sequentially dependent
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);

  // Softmax output plus dropout output/mask (Note: in training the dropout tensors are replaced by
  // the ones generated by ATen library code)
  auto act_options = input.options().requires_grad(false);
  auto mask_options = act_options.dtype(torch::kUInt8);

  torch::Tensor softmax_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_mask = torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);

  // Raw pointers consumed by the softmax kernels.
  void* input_ptr = static_cast(input.data_ptr());
  void* softmax_results_ptr = static_cast(softmax_results.data_ptr());

  // Padded Softmax
  bool softmax_success = false;
  if (pad_mask == nullptr) {
    softmax_success = dispatch_softmax(reinterpret_cast(softmax_results_ptr), reinterpret_cast(input_ptr),
                                       k_seq_len, k_seq_len, attn_batches * q_seq_len);
  } else {
    softmax_success = dispatch_additive_masked_softmax(
        reinterpret_cast(softmax_results_ptr), reinterpret_cast(input_ptr), pad_mask, k_seq_len, k_seq_len,
        attn_batches * q_seq_len, attn_batches * q_seq_len / sequences);
  }
  // A false return means the dispatcher did not handle this configuration. Fail loudly instead of
  // returning softmax_results that were (presumably) never written.
  TORCH_CHECK(softmax_success, "additive_mask_softmax_dropout fwd: fused softmax dispatch failed for k_seq_len = ",
              k_seq_len);

  if (is_training) {
    // use at:: function so that C++ version generates the same random mask as
    // python version
    auto dropout_tuple = at::_fused_dropout(softmax_results, 1.0f - dropout_prob);
    dropout_results = std::get<0>(dropout_tuple);
    dropout_mask = std::get<1>(dropout_tuple);
  }

  return {dropout_results, dropout_mask, softmax_results};
}

// Backward of fwd_cuda, computed in place in `output_grads` by a single
// dispatch_masked_scale_softmax_backward_stream call on the current stream:
//   - the incoming gradient is masked with `dropout_mask` and scaled by 1 / (1 - dropout_prob);
//   - it is then propagated through the softmax using the saved `softmax_results`.
// Returns `output_grads`, which now holds the input gradient. `heads` is not used here.
torch::Tensor bwd_cuda(int heads, torch::Tensor const& output_grads, torch::Tensor const& softmax_results,
                       torch::Tensor const& dropout_mask, float dropout_prob) {
  const int attn_batches = output_grads.size(0);
  const int q_seq_len = output_grads.size(1);
  const int k_seq_len = q_seq_len;
  // const int dropout_elems = attn_batches * q_seq_len * k_seq_len;

  // TODO: Streams can be used in Backprop but I haven't added more than one
  // in my first attempt to create the code
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);

  // Output Tensor Allocations
  // torch::Tensor input_grads = torch::empty_like(output_grads);

  // Apply Dropout Mask and Scale by Dropout Probability
  // Softmax Grad
  dispatch_masked_scale_softmax_backward_stream(
      static_cast(output_grads.data_ptr()), static_cast(output_grads.data_ptr()),
      reinterpret_cast(softmax_results.data_ptr()), static_cast(dropout_mask.data_ptr()),
      1.0 / (1.0 - dropout_prob), k_seq_len, k_seq_len, attn_batches * q_seq_len, stream);

  // backward pass is completely in-place
  return output_grads;
}

} // namespace additive_mask_softmax_dropout
} // namespace fused_softmax
} // namespace multihead_attn
================================================ FILE: apex/contrib/csrc/multihead_attn/dropout.cuh ================================================ #pragma once #include #ifdef OLD_GENERATOR_PATH #include #else #include #endif #include #include
namespace {
constexpr int UNROLL = 4;
} // namespace
template __global__ void apex_fused_dropout_kernel(scalar_t const* inputs, scalar_t* outputs, uint8_t* mask, IndexType totalElements, accscalar_t p, std::pair seeds) { accscalar_t pinv = accscalar_t(1) / p; IndexType idx = blockIdx.x * blockDim.x + threadIdx.x; curandStatePhilox4_32_10_t state; curand_init(seeds.first, idx, seeds.second, &state); IndexType rounded_size = ((totalElements - 1) / (blockDim.x * gridDim.x * UNROLL) + 1) * blockDim.x * gridDim.x * UNROLL; for (IndexType linearIndex = idx; linearIndex < rounded_size; linearIndex += gridDim.x * blockDim.x * UNROLL) { float4 rand = curand_uniform4(&state); scalar_t src[UNROLL]; rand.x = rand.x <= p; rand.y = rand.y <= p; rand.z = rand.z <= p; rand.w = rand.w <= p; for (int ii = 0; ii < UNROLL; ii++) { IndexType li = linearIndex + blockDim.x * gridDim.x * ii; if (li < totalElements) { src[ii] = inputs[li]; } } for (int ii = 0; ii < UNROLL; ii++) { IndexType li = linearIndex + blockDim.x * gridDim.x * ii; if (li < totalElements) { outputs[li] = src[ii] * (&rand.x)[ii] *
pinv; mask[li] = (uint8_t)(&rand.x)[ii]; } } __syncthreads(); } }

// Fused dropout + residual add:
//   keep_i = (uniform sample <= p)              (`p` is the keep probability)
//   outputs[i] = add_inputs[i] + inputs[i] * keep_i / p
//   mask[i] = keep_i
// Loop layout: each iteration, a thread draws one float4 of Philox uniforms and handles UNROLL
// elements spaced blockDim.x * gridDim.x apart.
// rounded_size is totalElements rounded up to a multiple of blockDim.x * gridDim.x * UNROLL. Every
// thread therefore runs the same number of iterations, so the whole block reaches the
// __syncthreads() inside the loop.
template __global__ void apex_dropout_add_kernel(scalar_t const* inputs, scalar_t const* add_inputs, scalar_t* outputs, uint8_t* mask, IndexType totalElements, accscalar_t p, std::pair seeds) {
  accscalar_t pinv = accscalar_t(1) / p;
  IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;
  curandStatePhilox4_32_10_t state;
  curand_init(seeds.first, idx, seeds.second, &state);
  IndexType rounded_size = ((totalElements - 1) / (blockDim.x * gridDim.x * UNROLL) + 1) * blockDim.x * gridDim.x * UNROLL;
  for (IndexType linearIndex = idx; linearIndex < rounded_size; linearIndex += gridDim.x * blockDim.x * UNROLL) {
    float4 rand = curand_uniform4(&state);
    scalar_t src[UNROLL];
    scalar_t add_src[UNROLL];
    // Turn each uniform sample into a 1.0f (keep) / 0.0f (drop) flag.
    rand.x = rand.x <= p;
    rand.y = rand.y <= p;
    rand.z = rand.z <= p;
    rand.w = rand.w <= p;
    for (int ii = 0; ii < UNROLL; ii++) {
      IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
      if (li < totalElements) {
        src[ii] = inputs[li];
        add_src[ii] = add_inputs[li];
      }
    }
    for (int ii = 0; ii < UNROLL; ii++) {
      IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
      if (li < totalElements) {
        // (&rand.x)[ii] reads rand.x, .y, .z, .w for ii = 0..3 (UNROLL == 4).
        accscalar_t int1 = src[ii] * (&rand.x)[ii] * pinv;
        outputs[li] = static_cast(static_cast(add_src[ii]) + int1);
        mask[li] = (uint8_t)(&rand.x)[ii];
      }
    }
    __syncthreads();
  }
}

// Element-wise outputs[i] = inputs[i] + add_inputs[i]. Same loop layout as
// apex_dropout_add_kernel, without the random draw.
template __global__ void apex_add_kernel(scalar_t const* inputs, scalar_t const* add_inputs, scalar_t* outputs, IndexType totalElements) {
  IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;
  IndexType rounded_size = ((totalElements - 1) / (blockDim.x * gridDim.x * UNROLL) + 1) * blockDim.x * gridDim.x * UNROLL;
  for (IndexType linearIndex = idx; linearIndex < rounded_size; linearIndex += gridDim.x * blockDim.x * UNROLL) {
    scalar_t src[UNROLL];
    scalar_t add_src[UNROLL];
    for (int ii = 0; ii < UNROLL; ii++) {
      IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
      if (li < totalElements) {
        src[ii] = inputs[li];
        add_src[ii] = add_inputs[li];
      }
    }
    for (int ii = 0; ii < UNROLL; ii++) {
      IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
      if (li < totalElements) {
        outputs[li] = src[ii] + add_src[ii];
      }
    }
    __syncthreads();
  }
}

// Element-wise outputs[i] = inputs[i] * scale * mask[i]. This is the dropout backward when `mask`
// holds the 0/1 keep flags and `scale` = 1 / keep probability. Same loop layout as the kernels
// above.
template __global__ void apex_masked_scale_kernel(scalar_t const* inputs, scalar_t* outputs, uint8_t const* mask, IndexType totalElements, accscalar_t scale) {
  IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;
  IndexType rounded_size = ((totalElements - 1) / (blockDim.x * gridDim.x * UNROLL) + 1) * blockDim.x * gridDim.x * UNROLL;
  for (IndexType linearIndex = idx; linearIndex < rounded_size; linearIndex += gridDim.x * blockDim.x * UNROLL) {
    scalar_t src[UNROLL];
    scalar_t msk[UNROLL];
    for (int ii = 0; ii < UNROLL; ii++) {
      IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
      if (li < totalElements) {
        src[ii] = static_cast(inputs[li]);
        msk[ii] = static_cast(mask[li]);
      }
    }
    for (int ii = 0; ii < UNROLL; ii++) {
      IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
      if (li < totalElements) {
        outputs[li] = static_cast(src[ii]) * scale * static_cast(msk[ii]);
      }
    }
  }
}

// Host launcher for apex_fused_dropout_kernel.
// - Writes dropout(inputs) with keep probability `p` to `outputs`, and the keep flags to `mask`.
// - Uses 256-thread blocks. The grid is capped at (maxThreadsPerMultiProcessor / 256) blocks per SM
//   times the SM count; the kernel's loop covers the remaining elements.
// - counter_offset Philox values per thread are reserved from the default CUDA generator while
//   holding its mutex.
// - Does nothing when totalElements == 0.
template void apex_fused_dropout_cuda(scalar_t const* inputs, scalar_t* outputs, uint8_t* mask, IndexType totalElements, accscalar_t p) {
  // Empty input: grid.x would be 0. The counter_offset computation below would then divide by zero
  // (block_size * grid.x == 0), and the launch would use an invalid empty grid.
  if (totalElements == 0) {
    return;
  }
  auto gen = at::cuda::detail::getDefaultCUDAGenerator();
  int block_size = 256;
  dim3 dim_block(block_size);
  dim3 grid((totalElements + block_size - 1) / block_size);
  unsigned int blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor / block_size;
  grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm, grid.x);

  // number of times random will be generated per thread, to offset philox
  // counter in the random state
  int64_t counter_offset = ((totalElements - 1) / (block_size * grid.x * UNROLL) + 1) * UNROLL;
  std::pair rng_engine_inputs;
  {
    // See Note [Acquire lock when using random generators]
    std::lock_guard lock(gen.mutex());
    rng_engine_inputs = at::check_generator(gen)->philox_engine_inputs(counter_offset);
  }
  apex_fused_dropout_kernel<<>>(
      inputs, outputs, mask, totalElements, p, rng_engine_inputs);
  C10_CUDA_CHECK(cudaGetLastError());
}

// Host launcher for apex_dropout_add_kernel. Same grid sizing and RNG reservation as
// apex_fused_dropout_cuda. Does nothing when totalElements == 0.
template void apex_dropout_add_cuda(scalar_t const* inputs, scalar_t const* add_inputs, scalar_t* outputs, uint8_t* mask, IndexType totalElements, accscalar_t p) {
  // Empty input: avoid the division by zero in counter_offset and the empty-grid launch.
  if (totalElements == 0) {
    return;
  }
  auto gen = at::cuda::detail::getDefaultCUDAGenerator();
  int block_size = 256;
  dim3 dim_block(block_size);
  dim3 grid((totalElements + block_size - 1) / block_size);
  unsigned int blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor / block_size;
  grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm, grid.x);

  // number of times random will be generated per thread, to offset philox
  // counter in the random state
  int64_t counter_offset = ((totalElements - 1) / (block_size * grid.x * UNROLL) + 1) * UNROLL;
  std::pair rng_engine_inputs;
  {
    // See Note [Acquire lock when using random generators]
    std::lock_guard lock(gen.mutex());
    rng_engine_inputs = at::check_generator(gen)->philox_engine_inputs(counter_offset);
  }

  apex_dropout_add_kernel<<>>(
      inputs, add_inputs, outputs, mask, totalElements, p, rng_engine_inputs);
  C10_CUDA_CHECK(cudaGetLastError());
}

// Host launcher for apex_add_kernel. Same grid sizing as the dropout launchers, no RNG.
// Does nothing when totalElements == 0.
template void apex_add_cuda(scalar_t const* inputs, scalar_t const* add_inputs, scalar_t* outputs, IndexType totalElements) {
  // Empty input: grid.x would be 0, which is not a valid launch configuration.
  if (totalElements == 0) {
    return;
  }
  int block_size = 256;
  dim3 dim_block(block_size);
  dim3 grid((totalElements + block_size - 1) / block_size);
  unsigned int blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor / block_size;
  grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm, grid.x);
  apex_add_kernel <<>>(inputs, add_inputs, outputs, totalElements);
  C10_CUDA_CHECK(cudaGetLastError());
}

// Host launcher for apex_masked_scale_kernel. Same grid sizing as apex_add_cuda.
// Does nothing when totalElements == 0.
template void apex_masked_scale_cuda(scalar_t const* inputs, scalar_t* outputs, uint8_t const* mask, IndexType totalElements, accscalar_t scale) {
  // Empty input: grid.x would be 0, which is not a valid launch configuration.
  if (totalElements == 0) {
    return;
  }
  int block_size = 256;
  dim3 dim_block(block_size);
  dim3 grid((totalElements + block_size - 1) / block_size);
  unsigned int blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor / block_size;
  grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm, grid.x);
  apex_masked_scale_kernel <<>>(inputs, outputs, mask, totalElements, scale);
  C10_CUDA_CHECK(cudaGetLastError());
}
================================================ FILE: apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu ================================================ #include #include #include #include #include #include #include #include #include #include #include "dropout.cuh" #include "softmax.cuh" #include "strided_batched_gemm.cuh"
namespace multihead_attn {
namespace encdec {
namespace cublas_gemmex {
std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, torch::Tensor const& inputs_q, torch::Tensor const& inputs_kv, torch::Tensor const& input_weights_q, torch::Tensor const& input_weights_kv, torch::Tensor const& output_weights, const uint8_t* pad_mask, float dropout_prob) { const int embed_dim = inputs_q.size(2); const int sequences = inputs_q.size(1); const int q_seq_len = inputs_q.size(0); const int k_seq_len = inputs_kv.size(0); const int batches_q = sequences * q_seq_len; const int batches_kv = sequences * k_seq_len; const int head_dim = embed_dim / heads; const int output_lin_q_dim = embed_dim; const int output_lin_kv_dim = 2 * embed_dim; const int attn_batches = heads * sequences; const int lead_dim_q = attn_batches * head_dim; const int lead_dim_kv = attn_batches * 2 * head_dim; const int batch_stride_q = head_dim; const int batch_stride_kv = 2 * head_dim; const int dropout_elems = attn_batches * q_seq_len * k_seq_len; const float alpha = 1.0; const float beta = 0.0; const float scale = 1.0 / sqrt(static_cast(head_dim)); // There is no reason to use more than one stream as every kernel is // sequentially dependent cublasHandle_t handle =
at::cuda::getCurrentCUDABlasHandle(); cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); cublasSetStream(handle, stream); // 3 Intermediate Results + Output (Note: dropout intermediates are generated // by ATen library code) auto act_options = inputs_q.options().requires_grad(false); auto mask_options = act_options.dtype(torch::kUInt8); torch::Tensor input_lin_q_results = torch::empty({q_seq_len, sequences, output_lin_q_dim}, act_options); torch::Tensor input_lin_kv_results = torch::empty({k_seq_len, sequences, output_lin_kv_dim}, act_options); torch::Tensor softmax_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); torch::Tensor dropout_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); torch::Tensor dropout_mask = torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options); torch::Tensor matmul2_results = torch::empty({q_seq_len, attn_batches, head_dim}, act_options); torch::Tensor outputs = torch::empty_like(inputs_q, act_options); // Input Linear Results Pointers to Q, K, and V of interviewed activations void* q_lin_results_ptr = static_cast(input_lin_q_results.data_ptr()); void* k_lin_results_ptr = static_cast(input_lin_kv_results.data_ptr()); void* v_lin_results_ptr = static_cast(static_cast(input_lin_kv_results.data_ptr()) + head_dim); // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax) void* softmax_results_ptr = static_cast(softmax_results.data_ptr()); char a_layout_t{'t'}; char a_layout_n{'n'}; char b_layout_n{'n'}; TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); // Input Linear Q Fwd TORCH_CUDABLAS_CHECK(cublasGemmEx( handle, CUBLAS_OP_T, CUBLAS_OP_N, output_lin_q_dim, batches_q, embed_dim, static_cast(&alpha), static_cast(input_weights_q.data_ptr()), CUDA_R_16F, embed_dim, static_cast(inputs_q.data_ptr()), CUDA_R_16F, embed_dim, static_cast(&beta), q_lin_results_ptr, CUDA_R_16F, output_lin_q_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); // Input 
Linear KV Fwd TORCH_CUDABLAS_CHECK(cublasGemmEx( handle, CUBLAS_OP_T, CUBLAS_OP_N, output_lin_kv_dim, batches_kv, embed_dim, static_cast(&alpha), static_cast(input_weights_kv.data_ptr()), CUDA_R_16F, embed_dim, static_cast(inputs_kv.data_ptr()), CUDA_R_16F, embed_dim, static_cast(&beta), k_lin_results_ptr, CUDA_R_16F, output_lin_kv_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size) gemm_switch_fp32accum(a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, scale, static_cast(k_lin_results_ptr), lead_dim_kv, batch_stride_kv, static_cast(q_lin_results_ptr), lead_dim_q, batch_stride_q, beta, static_cast(softmax_results_ptr), k_seq_len, k_seq_len * q_seq_len, attn_batches); // Padded Softmax bool softmax_success = false; if (pad_mask == nullptr) { softmax_success = dispatch_softmax(reinterpret_cast(softmax_results_ptr), reinterpret_cast(softmax_results_ptr), k_seq_len, k_seq_len, attn_batches * q_seq_len); } else { if (use_time_mask) { softmax_success = dispatch_time_masked_softmax( reinterpret_cast(softmax_results_ptr), reinterpret_cast(softmax_results_ptr), pad_mask, k_seq_len, k_seq_len, attn_batches * q_seq_len, q_seq_len); } else { softmax_success = dispatch_masked_softmax( reinterpret_cast(softmax_results_ptr), reinterpret_cast(softmax_results_ptr), pad_mask, k_seq_len, k_seq_len, attn_batches * q_seq_len, attn_batches * q_seq_len / sequences); } } assert(softmax_success); if (is_training) { apex_fused_dropout_cuda( static_cast(softmax_results.data_ptr()), static_cast(dropout_results.data_ptr()), static_cast(dropout_mask.data_ptr()), dropout_elems, (1.0f - dropout_prob)); } // Matmul2 gemm_switch_fp32accum(a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len, alpha, static_cast(v_lin_results_ptr), lead_dim_kv, batch_stride_kv, (is_training) ? 
static_cast(dropout_results.data_ptr()) : static_cast(softmax_results.data_ptr()), k_seq_len, k_seq_len * q_seq_len, beta, static_cast(matmul2_results.data_ptr()), head_dim * attn_batches, head_dim, attn_batches); // Output Linear TORCH_CUDABLAS_CHECK(cublasGemmEx( handle, CUBLAS_OP_T, CUBLAS_OP_N, embed_dim, batches_q, embed_dim, static_cast(&alpha), static_cast(output_weights.data_ptr()), CUDA_R_16F, embed_dim, static_cast(matmul2_results.data_ptr()), CUDA_R_16F, embed_dim, static_cast(&beta), static_cast(outputs.data_ptr()), CUDA_R_16F, embed_dim, CUDA_R_32F, // CUBLAS_GEMM_ALGO1_TENSOR_OP)); CUBLAS_GEMM_DEFAULT_TENSOR_OP)); TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); return {input_lin_q_results, input_lin_kv_results, softmax_results, dropout_results, dropout_mask, matmul2_results, outputs}; } std::vector bwd_cuda(int heads, torch::Tensor const& output_grads, torch::Tensor const& matmul2_results, torch::Tensor const& dropout_results, torch::Tensor const& softmax_results, torch::Tensor const& input_lin_q_results, torch::Tensor const& input_lin_kv_results, torch::Tensor const& inputs_q, torch::Tensor const& inputs_kv, torch::Tensor const& input_weights_q, torch::Tensor const& input_weights_kv, torch::Tensor const& output_weights, torch::Tensor const& dropout_mask, float dropout_prob) { const int embed_dim = inputs_q.size(2); const int sequences = inputs_q.size(1); const int q_seq_len = inputs_q.size(0); const int k_seq_len = inputs_kv.size(0); const int batches_q = sequences * q_seq_len; const int batches_kv = sequences * k_seq_len; const int head_dim = embed_dim / heads; const int output_lin_q_dim = embed_dim; const int output_lin_kv_dim = 2 * embed_dim; const int attn_batches = heads * sequences; const int lead_dim_q = attn_batches * head_dim; const int lead_dim_kv = attn_batches * 2 * head_dim; const int batch_stride_q = head_dim; const int batch_stride_kv = 2 * head_dim; const int dropout_elems = attn_batches * q_seq_len * k_seq_len; 
const float alpha = 1.0; const float beta = 0.0; const float scale = 1.0 / sqrt(static_cast(head_dim)); // TODO: Streams can be used in Backprop but I haven't added more than one // in my first attempt to create the code cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); cublasSetStream(handle, stream); // Output Tensor Allocations torch::Tensor input_q_grads = torch::empty_like(inputs_q); torch::Tensor input_kv_grads = torch::empty_like(inputs_kv); torch::Tensor input_weight_q_grads = torch::empty_like(input_weights_q); torch::Tensor input_weight_kv_grads = torch::empty_like(input_weights_kv); torch::Tensor output_weight_grads = torch::empty_like(output_weights); // Intermediate Tensor Allocations at::Tensor output_lin_grads = torch::empty_like(matmul2_results); at::Tensor matmul2_grads = torch::empty_like(dropout_results); at::Tensor input_lin_q_output_grads = torch::empty_like(input_lin_q_results); at::Tensor input_lin_kv_output_grads = torch::empty_like(input_lin_kv_results); auto q_lin_results_ptr = static_cast(input_lin_q_results.data_ptr()); auto k_lin_results_ptr = static_cast(input_lin_kv_results.data_ptr()); auto v_lin_results_ptr = static_cast(input_lin_kv_results.data_ptr()) + head_dim; auto q_lin_grads_ptr = static_cast(input_lin_q_output_grads.data_ptr()); auto k_lin_grads_ptr = static_cast(input_lin_kv_output_grads.data_ptr()); auto v_lin_grads_ptr = static_cast(input_lin_kv_output_grads.data_ptr()) + head_dim; char a_layout_n{'n'}; char a_layout_t{'t'}; char b_layout_n{'n'}; char b_layout_t{'t'}; TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); // Output Linear Dgrad TORCH_CUDABLAS_CHECK(cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches_q, embed_dim, static_cast(&alpha), static_cast(output_weights.data_ptr()), CUDA_R_16F, embed_dim, static_cast(output_grads.data_ptr()), CUDA_R_16F, embed_dim, static_cast(&beta), 
static_cast(output_lin_grads.data_ptr()), CUDA_R_16F, embed_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); // Output Linear Wgrad TORCH_CUDABLAS_CHECK(cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, embed_dim, batches_q, static_cast(&alpha), static_cast(matmul2_results.data_ptr()), CUDA_R_16F, embed_dim, static_cast(output_grads.data_ptr()), CUDA_R_16F, embed_dim, static_cast(&beta), static_cast(output_weight_grads.data_ptr()), CUDA_R_16F, embed_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); // MatMul2 Dgrad1 gemm_switch_fp32accum(a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, alpha, static_cast(v_lin_results_ptr), lead_dim_kv, batch_stride_kv, static_cast(output_lin_grads.data_ptr()), head_dim * attn_batches, head_dim, beta, static_cast(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len, attn_batches); // Matmul2 Dgrad2 gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, q_seq_len, alpha, static_cast(output_lin_grads.data_ptr()), head_dim * attn_batches, head_dim, static_cast(dropout_results.data_ptr()), k_seq_len, k_seq_len * q_seq_len, beta, v_lin_grads_ptr, lead_dim_kv, batch_stride_kv, attn_batches); // Apply Dropout Mask and Scale by Dropout Probability apex_masked_scale_cuda( static_cast(matmul2_grads.data_ptr()), static_cast(matmul2_grads.data_ptr()), static_cast(dropout_mask.data_ptr()), dropout_elems, (1.0 / (1.0 - dropout_prob))); // Softmax Grad bool softmax_success = false; softmax_success = dispatch_softmax_backward( static_cast(matmul2_grads.data_ptr()), static_cast(matmul2_grads.data_ptr()), reinterpret_cast(softmax_results.data_ptr()), k_seq_len, k_seq_len, attn_batches * q_seq_len); assert(softmax_success); // Matmul1 Dgrad1 gemm_switch_fp32accum(a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len, scale, k_lin_results_ptr, lead_dim_kv, batch_stride_kv, static_cast(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len, beta, q_lin_grads_ptr, lead_dim_q, batch_stride_q, attn_batches); // 
Matmul1 Dgrad2 gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, q_seq_len, scale, q_lin_results_ptr, lead_dim_q, batch_stride_q, static_cast(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len, beta, k_lin_grads_ptr, lead_dim_kv, batch_stride_kv, attn_batches); // Input Linear Q Dgrad TORCH_CUDABLAS_CHECK(cublasGemmEx( handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches_q, output_lin_q_dim, static_cast(&alpha), static_cast(input_weights_q.data_ptr()), CUDA_R_16F, embed_dim, static_cast(q_lin_grads_ptr), CUDA_R_16F, output_lin_q_dim, static_cast(&beta), static_cast(input_q_grads.data_ptr()), CUDA_R_16F, embed_dim, CUDA_R_32F, // CUBLAS_GEMM_ALGO10_TENSOR_OP)); CUBLAS_GEMM_DEFAULT_TENSOR_OP)); // Input Linear Q Wgrad TORCH_CUDABLAS_CHECK(cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, output_lin_q_dim, batches_q, static_cast(&alpha), static_cast(inputs_q.data_ptr()), CUDA_R_16F, embed_dim, static_cast(q_lin_grads_ptr), CUDA_R_16F, output_lin_q_dim, static_cast(&beta), static_cast(input_weight_q_grads.data_ptr()), CUDA_R_16F, embed_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); // Input Linear KV Dgrad TORCH_CUDABLAS_CHECK(cublasGemmEx( handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches_kv, output_lin_kv_dim, static_cast(&alpha), static_cast(input_weights_kv.data_ptr()), CUDA_R_16F, embed_dim, static_cast(k_lin_grads_ptr), CUDA_R_16F, output_lin_kv_dim, static_cast(&beta), static_cast(input_kv_grads.data_ptr()), CUDA_R_16F, embed_dim, CUDA_R_32F, // CUBLAS_GEMM_ALGO10_TENSOR_OP)); CUBLAS_GEMM_DEFAULT_TENSOR_OP)); // Input Linear KV Wgrad TORCH_CUDABLAS_CHECK(cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, output_lin_kv_dim, batches_kv, static_cast(&alpha), static_cast(inputs_kv.data_ptr()), CUDA_R_16F, embed_dim, static_cast(k_lin_grads_ptr), CUDA_R_16F, output_lin_kv_dim, static_cast(&beta), static_cast(input_weight_kv_grads.data_ptr()), CUDA_R_16F, embed_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); 
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); return {input_q_grads, input_kv_grads, input_weight_q_grads, input_weight_kv_grads, output_weight_grads}; } } // end namespace cublas_gemmex } // end namespace encdec } // end namespace multihead_attn ================================================ FILE: apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu ================================================ #include #include #include #include #include #include #include #include #include #include #include "dropout.cuh" #include "layer_norm.cuh" #include "softmax.cuh" #include "strided_batched_gemm.cuh" namespace multihead_attn { namespace encdec_norm_add { namespace cublas_gemmex { std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, torch::Tensor const& inputs_q, torch::Tensor const& inputs_kv, torch::Tensor const& lyr_nrm_gamma_weights, torch::Tensor const& lyr_nrm_beta_weights, torch::Tensor const& input_weights_q, torch::Tensor const& input_weights_kv, torch::Tensor const& output_weights, const uint8_t* pad_mask, float dropout_prob) { const int embed_dim = inputs_q.size(2); const int sequences = inputs_q.size(1); const int q_seq_len = inputs_q.size(0); const int k_seq_len = inputs_kv.size(0); const int batches_q = sequences * q_seq_len; const int batches_kv = sequences * k_seq_len; const int total_tokens_q = batches_q * embed_dim; const int head_dim = embed_dim / heads; const int output_lin_q_dim = embed_dim; const int output_lin_kv_dim = 2 * embed_dim; const int attn_batches = heads * sequences; const int lead_dim_q = attn_batches * head_dim; const int lead_dim_kv = attn_batches * 2 * head_dim; const int batch_stride_q = head_dim; const int batch_stride_kv = 2 * head_dim; const int dropout_elems = attn_batches * q_seq_len * k_seq_len; const float alpha = 1.0; const float beta = 0.0; const float scale = 1.0 / sqrt(static_cast(head_dim)); // There is no reason to use more than one stream as every kernel is // 
sequentially dependent cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); cublasSetStream(handle, stream); // 3 Intermediate Results + Output (Note: dropout intermediates are generated // by ATen library code) auto act_options = inputs_q.options().requires_grad(false); auto lyr_nrm_options = act_options.dtype(torch::kFloat32); auto mask_options = act_options.dtype(torch::kUInt8); torch::Tensor lyr_nrm_mean = torch::empty({batches_q}, lyr_nrm_options); torch::Tensor lyr_nrm_invvar = torch::empty({batches_q}, lyr_nrm_options); torch::Tensor lyr_nrm_results = torch::empty_like(inputs_q, act_options); torch::Tensor input_lin_q_results = torch::empty({q_seq_len, sequences, output_lin_q_dim}, act_options); torch::Tensor input_lin_kv_results = torch::empty({k_seq_len, sequences, output_lin_kv_dim}, act_options); torch::Tensor softmax_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); torch::Tensor dropout_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); torch::Tensor dropout_mask = torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options); torch::Tensor matmul2_results = torch::empty({q_seq_len, attn_batches, head_dim}, act_options); torch::Tensor output_lin_results = torch::empty_like(inputs_q, act_options); torch::Tensor dropout_add_mask = torch::empty_like(inputs_q, mask_options); torch::Tensor outputs = torch::empty_like(inputs_q, act_options); // Input Linear Results Pointers to Q, K, and V of interviewed activations void* q_lin_results_ptr = static_cast(input_lin_q_results.data_ptr()); void* k_lin_results_ptr = static_cast(input_lin_kv_results.data_ptr()); void* v_lin_results_ptr = static_cast(static_cast(input_lin_kv_results.data_ptr()) + head_dim); // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax) void* softmax_results_ptr = static_cast(softmax_results.data_ptr()); char a_layout_t{'t'}; char a_layout_n{'n'}; char 
b_layout_n{'n'}; TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); // Layer Norm HostApplyLayerNorm( static_cast(lyr_nrm_results.data_ptr()), static_cast(lyr_nrm_mean.data_ptr()), static_cast(lyr_nrm_invvar.data_ptr()), static_cast(inputs_q.data_ptr()), static_cast(batches_q), // n1 static_cast(embed_dim), // n2 1.0e-5, static_cast(lyr_nrm_gamma_weights.data_ptr()), static_cast(lyr_nrm_beta_weights.data_ptr())); // Input Linear Q Fwd TORCH_CUDABLAS_CHECK(cublasGemmEx( handle, CUBLAS_OP_T, CUBLAS_OP_N, output_lin_q_dim, batches_q, embed_dim, static_cast(&alpha), static_cast(input_weights_q.data_ptr()), CUDA_R_16F, embed_dim, // static_cast(inputs_q.data_ptr()), static_cast(lyr_nrm_results.data_ptr()), CUDA_R_16F, embed_dim, static_cast(&beta), q_lin_results_ptr, CUDA_R_16F, output_lin_q_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); // Input Linear KV Fwd TORCH_CUDABLAS_CHECK(cublasGemmEx( handle, CUBLAS_OP_T, CUBLAS_OP_N, output_lin_kv_dim, batches_kv, embed_dim, static_cast(&alpha), static_cast(input_weights_kv.data_ptr()), CUDA_R_16F, embed_dim, static_cast(inputs_kv.data_ptr()), CUDA_R_16F, embed_dim, static_cast(&beta), k_lin_results_ptr, CUDA_R_16F, output_lin_kv_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size) gemm_switch_fp32accum(a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, scale, static_cast(k_lin_results_ptr), lead_dim_kv, batch_stride_kv, static_cast(q_lin_results_ptr), lead_dim_q, batch_stride_q, beta, static_cast(softmax_results_ptr), k_seq_len, k_seq_len * q_seq_len, attn_batches); // Padded Softmax bool softmax_success = false; if (pad_mask == nullptr) { softmax_success = dispatch_softmax(reinterpret_cast(softmax_results_ptr), reinterpret_cast(softmax_results_ptr), k_seq_len, k_seq_len, attn_batches * q_seq_len); } else { if (use_time_mask) { softmax_success = dispatch_time_masked_softmax( reinterpret_cast(softmax_results_ptr), 
reinterpret_cast(softmax_results_ptr), pad_mask, k_seq_len, k_seq_len, attn_batches * q_seq_len, q_seq_len); } else { softmax_success = dispatch_masked_softmax( reinterpret_cast(softmax_results_ptr), reinterpret_cast(softmax_results_ptr), pad_mask, k_seq_len, k_seq_len, attn_batches * q_seq_len, attn_batches * q_seq_len / sequences); } } assert(softmax_success); if (is_training) { apex_fused_dropout_cuda( static_cast(softmax_results.data_ptr()), static_cast(dropout_results.data_ptr()), static_cast(dropout_mask.data_ptr()), dropout_elems, (1.0f - dropout_prob)); } // Matmul2 gemm_switch_fp32accum(a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len, alpha, static_cast(v_lin_results_ptr), lead_dim_kv, batch_stride_kv, (is_training) ? static_cast(dropout_results.data_ptr()) : static_cast(softmax_results.data_ptr()), // static_cast(dropout_results.data_ptr()), k_seq_len, k_seq_len * q_seq_len, beta, static_cast(matmul2_results.data_ptr()), head_dim * attn_batches, head_dim, attn_batches); // Output Linear TORCH_CUDABLAS_CHECK(cublasGemmEx( handle, CUBLAS_OP_T, CUBLAS_OP_N, embed_dim, batches_q, embed_dim, static_cast(&alpha), static_cast(output_weights.data_ptr()), CUDA_R_16F, embed_dim, static_cast(matmul2_results.data_ptr()), CUDA_R_16F, embed_dim, static_cast(&beta), static_cast(output_lin_results.data_ptr()), CUDA_R_16F, embed_dim, CUDA_R_32F, // CUBLAS_GEMM_ALGO1_TENSOR_OP)); CUBLAS_GEMM_DEFAULT_TENSOR_OP)); // End-of-block Dropout-Add if (is_training) { apex_dropout_add_cuda( static_cast(output_lin_results.data_ptr()), static_cast(inputs_q.data_ptr()), static_cast(outputs.data_ptr()), static_cast(dropout_add_mask.data_ptr()), total_tokens_q, (1.0f - dropout_prob)); } else { apex_add_cuda(static_cast(output_lin_results.data_ptr()), static_cast(inputs_q.data_ptr()), static_cast(outputs.data_ptr()), total_tokens_q); } TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); return {lyr_nrm_results, lyr_nrm_mean, lyr_nrm_invvar, input_lin_q_results, 
input_lin_kv_results, softmax_results, dropout_results, dropout_mask, matmul2_results, dropout_add_mask, outputs}; } std::vector bwd_cuda(int heads, torch::Tensor const& output_grads, torch::Tensor const& matmul2_results, torch::Tensor const& dropout_results, torch::Tensor const& softmax_results, torch::Tensor const& input_lin_q_results, torch::Tensor const& input_lin_kv_results, torch::Tensor const& lyr_nrm_results, torch::Tensor const& lyr_nrm_mean, torch::Tensor const& lyr_nrm_invvar, torch::Tensor const& inputs_q, torch::Tensor const& inputs_kv, torch::Tensor const& lyr_nrm_gamma_weights, torch::Tensor const& lyr_nrm_beta_weights, torch::Tensor const& input_weights_q, torch::Tensor const& input_weights_kv, torch::Tensor const& output_weights, torch::Tensor const& dropout_mask, torch::Tensor const& dropout_add_mask, float dropout_prob) { const int embed_dim = inputs_q.size(2); const int sequences = inputs_q.size(1); const int q_seq_len = inputs_q.size(0); const int k_seq_len = inputs_kv.size(0); const int batches_q = sequences * q_seq_len; const int batches_kv = sequences * k_seq_len; const int total_tokens_q = batches_q * embed_dim; const int head_dim = embed_dim / heads; const int output_lin_q_dim = embed_dim; const int output_lin_kv_dim = 2 * embed_dim; const int attn_batches = heads * sequences; const int lead_dim_q = attn_batches * head_dim; const int lead_dim_kv = attn_batches * 2 * head_dim; const int batch_stride_q = head_dim; const int batch_stride_kv = 2 * head_dim; const int dropout_elems = attn_batches * q_seq_len * k_seq_len; const float alpha = 1.0; const float beta = 0.0; const float scale = 1.0 / sqrt(static_cast(head_dim)); // TODO: Streams can be used in Backprop but I haven't added more than one // in my first attempt to create the code cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); cublasSetStream(handle, stream); // Output Tensor Allocations torch::Tensor 
input_q_grads = torch::empty_like(inputs_q); torch::Tensor input_kv_grads = torch::empty_like(inputs_kv); torch::Tensor lyr_nrm_gamma_grads = torch::empty_like(lyr_nrm_gamma_weights); torch::Tensor lyr_nrm_beta_grads = torch::empty_like(lyr_nrm_beta_weights); torch::Tensor input_weight_q_grads = torch::empty_like(input_weights_q); torch::Tensor input_weight_kv_grads = torch::empty_like(input_weights_kv); torch::Tensor output_weight_grads = torch::empty_like(output_weights); // Intermediate Tensor Allocations at::Tensor dropout_add_grads = torch::empty_like(output_grads); at::Tensor output_lin_grads = torch::empty_like(matmul2_results); at::Tensor matmul2_grads = torch::empty_like(dropout_results); at::Tensor input_lin_q_output_grads = torch::empty_like(input_lin_q_results); at::Tensor input_lin_kv_output_grads = torch::empty_like(input_lin_kv_results); at::Tensor input_lin_q_grads = torch::empty_like(inputs_q); auto q_lin_results_ptr = static_cast(input_lin_q_results.data_ptr()); auto k_lin_results_ptr = static_cast(input_lin_kv_results.data_ptr()); auto v_lin_results_ptr = static_cast(input_lin_kv_results.data_ptr()) + head_dim; auto q_lin_grads_ptr = static_cast(input_lin_q_output_grads.data_ptr()); auto k_lin_grads_ptr = static_cast(input_lin_kv_output_grads.data_ptr()); auto v_lin_grads_ptr = static_cast(input_lin_kv_output_grads.data_ptr()) + head_dim; char a_layout_n{'n'}; char a_layout_t{'t'}; char b_layout_n{'n'}; char b_layout_t{'t'}; TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); // Dropout Add Backward apex_masked_scale_cuda( static_cast(output_grads.data_ptr()), static_cast(dropout_add_grads.data_ptr()), static_cast(dropout_add_mask.data_ptr()), total_tokens_q, (1.0 / (1.0 - dropout_prob))); // Output Linear Dgrad TORCH_CUDABLAS_CHECK(cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches_q, embed_dim, static_cast(&alpha), static_cast(output_weights.data_ptr()), CUDA_R_16F, embed_dim, 
static_cast(dropout_add_grads.data_ptr()), CUDA_R_16F, embed_dim, static_cast(&beta), static_cast(output_lin_grads.data_ptr()), CUDA_R_16F, embed_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); // Output Linear Wgrad TORCH_CUDABLAS_CHECK(cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, embed_dim, batches_q, static_cast(&alpha), static_cast(matmul2_results.data_ptr()), CUDA_R_16F, embed_dim, static_cast(dropout_add_grads.data_ptr()), CUDA_R_16F, embed_dim, static_cast(&beta), static_cast(output_weight_grads.data_ptr()), CUDA_R_16F, embed_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); // MatMul2 Dgrad1 gemm_switch_fp32accum(a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, alpha, static_cast(v_lin_results_ptr), lead_dim_kv, batch_stride_kv, static_cast(output_lin_grads.data_ptr()), head_dim * attn_batches, head_dim, beta, static_cast(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len, attn_batches); // Matmul2 Dgrad2 gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, q_seq_len, alpha, static_cast(output_lin_grads.data_ptr()), head_dim * attn_batches, head_dim, static_cast(dropout_results.data_ptr()), k_seq_len, k_seq_len * q_seq_len, beta, v_lin_grads_ptr, lead_dim_kv, batch_stride_kv, attn_batches); // Apply Dropout Mask and Scale by Dropout Probability apex_masked_scale_cuda( static_cast(matmul2_grads.data_ptr()), static_cast(matmul2_grads.data_ptr()), static_cast(dropout_mask.data_ptr()), dropout_elems, (1.0 / (1.0 - dropout_prob))); // Softmax Grad bool softmax_success = false; softmax_success = dispatch_softmax_backward( static_cast(matmul2_grads.data_ptr()), static_cast(matmul2_grads.data_ptr()), reinterpret_cast(softmax_results.data_ptr()), k_seq_len, k_seq_len, attn_batches * q_seq_len); assert(softmax_success); // Matmul1 Dgrad1 gemm_switch_fp32accum(a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len, scale, k_lin_results_ptr, lead_dim_kv, batch_stride_kv, static_cast(matmul2_grads.data_ptr()), k_seq_len, 
k_seq_len * q_seq_len, beta, q_lin_grads_ptr, lead_dim_q, batch_stride_q, attn_batches); // Matmul1 Dgrad2 gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, q_seq_len, scale, q_lin_results_ptr, lead_dim_q, batch_stride_q, static_cast(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len, beta, k_lin_grads_ptr, lead_dim_kv, batch_stride_kv, attn_batches); // Input Linear Q Dgrad TORCH_CUDABLAS_CHECK(cublasGemmEx( handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches_q, output_lin_q_dim, static_cast(&alpha), static_cast(input_weights_q.data_ptr()), CUDA_R_16F, embed_dim, static_cast(q_lin_grads_ptr), CUDA_R_16F, output_lin_q_dim, static_cast(&beta), // static_cast(input_q_grads.data_ptr()), static_cast(input_lin_q_grads.data_ptr()), CUDA_R_16F, embed_dim, CUDA_R_32F, // CUBLAS_GEMM_ALGO10_TENSOR_OP)); CUBLAS_GEMM_DEFAULT_TENSOR_OP)); // Input Linear Q Wgrad TORCH_CUDABLAS_CHECK(cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, output_lin_q_dim, batches_q, static_cast(&alpha), static_cast(inputs_q.data_ptr()), CUDA_R_16F, embed_dim, static_cast(q_lin_grads_ptr), CUDA_R_16F, output_lin_q_dim, static_cast(&beta), static_cast(input_weight_q_grads.data_ptr()), CUDA_R_16F, embed_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); // Input Linear KV Dgrad TORCH_CUDABLAS_CHECK(cublasGemmEx( handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches_kv, output_lin_kv_dim, static_cast(&alpha), static_cast(input_weights_kv.data_ptr()), CUDA_R_16F, embed_dim, static_cast(k_lin_grads_ptr), CUDA_R_16F, output_lin_kv_dim, static_cast(&beta), static_cast(input_kv_grads.data_ptr()), CUDA_R_16F, embed_dim, CUDA_R_32F, // CUBLAS_GEMM_ALGO10_TENSOR_OP)); CUBLAS_GEMM_DEFAULT_TENSOR_OP)); // Input Linear KV Wgrad TORCH_CUDABLAS_CHECK(cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, output_lin_kv_dim, batches_kv, static_cast(&alpha), static_cast(inputs_kv.data_ptr()), CUDA_R_16F, embed_dim, static_cast(k_lin_grads_ptr), CUDA_R_16F, output_lin_kv_dim, 
static_cast(&beta), static_cast(input_weight_kv_grads.data_ptr()), CUDA_R_16F, embed_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); // Fused Layer Norm Bwd with Residual Add HostLayerNormGradient( static_cast(input_lin_q_grads.data_ptr()), static_cast(output_grads.data_ptr()), static_cast(lyr_nrm_mean.data_ptr()), static_cast(lyr_nrm_invvar.data_ptr()), inputs_q, static_cast(batches_q), // n1 static_cast(embed_dim), // n2 static_cast(lyr_nrm_gamma_weights.data_ptr()), static_cast(lyr_nrm_beta_weights.data_ptr()), 1.0e-5, static_cast(input_q_grads.data_ptr()), static_cast(lyr_nrm_gamma_grads.data_ptr()), static_cast(lyr_nrm_beta_grads.data_ptr())); TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); return {input_q_grads, input_kv_grads, lyr_nrm_gamma_grads, lyr_nrm_beta_grads, input_weight_q_grads, input_weight_kv_grads, output_weight_grads}; } } // end namespace cublas_gemmex } // end namespace encdec_norm_add } // end namespace multihead_attn ================================================ FILE: apex/contrib/csrc/multihead_attn/layer_norm.cuh ================================================ #pragma once #include #include #include #include namespace { template __device__ void cuWelfordOnlineSum(const U curr, U& mu, U& sigma2, U& count) { count = count + U(1); U delta = curr - mu; U lmean = mu + delta / count; mu = lmean; U delta2 = curr - lmean; sigma2 = sigma2 + delta * delta2; } template __device__ void cuChanOnlineSum(const U muB, const U sigma2B, const U countB, U& mu, U& sigma2, U& count) { U delta = muB - mu; U nA = count; U nB = countB; count = count + countB; U nX = count; if (nX > U(0)) { nA = nA / nX; nB = nB / nX; mu = nA * mu + nB * muB; sigma2 = sigma2 + sigma2B + delta * delta * nA * nB * nX; } else { mu = U(0); sigma2 = U(0); } } template __device__ void cuWelfordMuSigma2(const T* __restrict__ vals, const int n1, const int n2, const int i1, U& mu, U& sigma2, U* buf) { // Assumptions: // 1) blockDim.x == warpSize // 2) Tensor is 
contiguous // 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available. // // compute variance and mean over n2 U count = U(0); mu = U(0); sigma2 = U(0); if (i1 < n1) { // one warp normalizes one n1 index, // synchronization is implicit // initialize with standard Welford algorithm const int numx = blockDim.x * blockDim.y; const int thrx = threadIdx.x + threadIdx.y * blockDim.x; const T* lvals = vals + i1 * n2; int l = 4 * thrx; for (; l + 3 < n2; l += 4 * numx) { for (int k = 0; k < 4; ++k) { U curr = static_cast(lvals[l + k]); cuWelfordOnlineSum(curr, mu, sigma2, count); } } for (; l < n2; ++l) { U curr = static_cast(lvals[l]); cuWelfordOnlineSum(curr, mu, sigma2, count); } // intra-warp reductions for (int l = 0; l <= 4; ++l) { int srcLaneB = (threadIdx.x + (1 << l)) & 31; U muB = WARP_SHFL(mu, srcLaneB); U countB = WARP_SHFL(count, srcLaneB); U sigma2B = WARP_SHFL(sigma2, srcLaneB); cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count); } // threadIdx.x == 0 has correct values for each warp // inter-warp reductions if (blockDim.y > 1) { U* ubuf = (U*)buf; U* ibuf = (U*)(ubuf + blockDim.y); for (int offset = blockDim.y / 2; offset > 0; offset /= 2) { // upper half of warps write to shared if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2 * offset) { const int wrt_y = threadIdx.y - offset; ubuf[2 * wrt_y] = mu; ubuf[2 * wrt_y + 1] = sigma2; ibuf[wrt_y] = count; } __syncthreads(); // lower half merges if (threadIdx.x == 0 && threadIdx.y < offset) { U muB = ubuf[2 * threadIdx.y]; U sigma2B = ubuf[2 * threadIdx.y + 1]; U countB = ibuf[threadIdx.y]; cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count); } __syncthreads(); } // threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values if (threadIdx.x == 0 && threadIdx.y == 0) { ubuf[0] = mu; ubuf[1] = sigma2; } __syncthreads(); mu = ubuf[0]; sigma2 = ubuf[1] / U(n2); // don't care about final value of count, we know count == n2 } else { mu = WARP_SHFL(mu, 0); sigma2 = 
WARP_SHFL(sigma2 / U(n2), 0); } } } template <> __device__ void cuWelfordMuSigma2(const at::Half* __restrict__ vals, const int n1, const int n2, const int i1, float& mu, float& sigma2, float* buf) { // Assumptions: // 1) blockDim.x == warpSize // 2) Tensor is contiguous // 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available. // // compute variance and mean over n2 float count = 0.0f; mu = float(0); sigma2 = float(0); if (i1 < n1) { // one warp normalizes one n1 index, // synchronization is implicit // initialize with standard Welford algorithm const int numx = blockDim.x * blockDim.y; const int thrx = threadIdx.x + threadIdx.y * blockDim.x; const at::Half* lvals = vals + i1 * n2; int l = 8 * thrx; if ((((size_t)lvals) & 3) != 0) { // 16 bit alignment // first thread consumes first point if (thrx == 0) { float curr = static_cast(lvals[0]); cuWelfordOnlineSum(curr, mu, sigma2, count); } ++l; } // at this point, lvals[l] are 32 bit aligned for all threads. for (; l + 7 < n2; l += 8 * numx) { for (int k = 0; k < 8; k += 2) { float2 curr = __half22float2(*((__half2*)(lvals + l + k))); cuWelfordOnlineSum(curr.x, mu, sigma2, count); cuWelfordOnlineSum(curr.y, mu, sigma2, count); } } for (; l < n2; ++l) { float curr = static_cast(lvals[l]); cuWelfordOnlineSum(curr, mu, sigma2, count); } // intra-warp reductions for (int l = 0; l <= 4; ++l) { int srcLaneB = (threadIdx.x + (1 << l)) & 31; float muB = WARP_SHFL(mu, srcLaneB); float countB = WARP_SHFL(count, srcLaneB); float sigma2B = WARP_SHFL(sigma2, srcLaneB); cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count); } // threadIdx.x == 0 has correct values for each warp // inter-warp reductions if (blockDim.y > 1) { float* ubuf = (float*)buf; float* ibuf = (float*)(ubuf + blockDim.y); for (int offset = blockDim.y / 2; offset > 0; offset /= 2) { // upper half of warps write to shared if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2 * offset) { const int wrt_y = threadIdx.y - offset; 
ubuf[2 * wrt_y] = mu; ubuf[2 * wrt_y + 1] = sigma2; ibuf[wrt_y] = count; } __syncthreads(); // lower half merges if (threadIdx.x == 0 && threadIdx.y < offset) { float muB = ubuf[2 * threadIdx.y]; float sigma2B = ubuf[2 * threadIdx.y + 1]; float countB = ibuf[threadIdx.y]; cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count); } __syncthreads(); } // threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values if (threadIdx.x == 0 && threadIdx.y == 0) { ubuf[0] = mu; ubuf[1] = sigma2; } __syncthreads(); mu = ubuf[0]; sigma2 = ubuf[1] / float(n2); // don't care about final value of count, we know count == n2 } else { mu = WARP_SHFL(mu, 0); sigma2 = WARP_SHFL(sigma2 / float(n2), 0); } } } template __device__ U rsqrt(U v) { return U(1) / sqrt(v); } template <> __device__ float rsqrt(float v) { return rsqrtf(v); } template <> __device__ double rsqrt(double v) { return rsqrt(v); } // This is the un-specialized struct. Note that we prevent instantiation of // this struct by putting an undefined symbol in the function body so it won't // compile. 
// template // struct SharedMemory // { // // Ensure that we won't compile any un-specialized types // __device__ T *getPointer() // { // extern __device__ void error(void); // error(); // return NULL; // } // }; // https://github.com/NVIDIA/apex/issues/246 template struct SharedMemory; template <> struct SharedMemory { __device__ float* getPointer() { extern __shared__ float s_float[]; return s_float; } }; template <> struct SharedMemory { __device__ double* getPointer() { extern __shared__ double s_double[]; return s_double; } }; template __global__ void cuApplyLayerNorm(T* __restrict__ output_vals, U* __restrict__ mean, U* __restrict__ invvar, const T* __restrict__ vals, const int n1, const int n2, const U epsilon, const T* __restrict__ gamma, const T* __restrict__ beta) { // Assumptions: // 1) blockDim.x == warpSize // 2) Tensors are contiguous // for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) { SharedMemory shared; U* buf = shared.getPointer(); U mu, sigma2; cuWelfordMuSigma2(vals, n1, n2, i1, mu, sigma2, buf); const T* lvals = vals + i1 * n2; T* ovals = output_vals + i1 * n2; U c_invvar = rsqrt(sigma2 + epsilon); const int numx = blockDim.x * blockDim.y; const int thrx = threadIdx.x + threadIdx.y * blockDim.x; if (gamma != NULL && beta != NULL) { for (int i = thrx; i < n2; i += numx) { U curr = static_cast(lvals[i]); ovals[i] = gamma[i] * static_cast(c_invvar * (curr - mu)) + beta[i]; } } else { for (int i = thrx; i < n2; i += numx) { U curr = static_cast(lvals[i]); ovals[i] = static_cast(c_invvar * (curr - mu)); } } if (threadIdx.x == 0 && threadIdx.y == 0) { mean[i1] = mu; invvar[i1] = c_invvar; } } } template __device__ void cuLoadWriteStridedInputs(const int i1_block, const int thr_load_row_off, const int thr_load_col_off, const int i2_off, const int row_stride, U* warp_buf1, U* warp_buf2, const T* input, const T* dout, const int i1_end, const int n2, const U* __restrict__ mean, const U* __restrict__ invvar) { int i1 = i1_block + thr_load_row_off; 
if (i1 < i1_end) { U curr_mean = mean[i1]; U curr_invvar = invvar[i1]; for (int k = 0; k < blockDim.y; ++k) { int i2 = i2_off + k; int load_idx = i1 * n2 + i2; int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k; if (i2 < n2) { U curr_input = static_cast(input[load_idx]); U curr_dout = static_cast(dout[load_idx]); warp_buf1[write_idx] = curr_dout; warp_buf2[write_idx] = curr_dout * (curr_input - curr_mean) * curr_invvar; } else { warp_buf1[write_idx] = U(0); warp_buf2[write_idx] = U(0); } } } else { for (int k = 0; k < blockDim.y; ++k) { int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k; warp_buf1[write_idx] = U(0); warp_buf2[write_idx] = U(0); } } } template __device__ void cuLoadAddStridedInputs(const int i1_block, const int thr_load_row_off, const int thr_load_col_off, const int i2_off, const int row_stride, U* warp_buf1, U* warp_buf2, const T* input, const T* dout, const int i1_end, const int n2, const U* __restrict__ mean, const U* __restrict__ invvar) { int i1 = i1_block + thr_load_row_off; if (i1 < i1_end) { U curr_mean = mean[i1]; U curr_invvar = invvar[i1]; for (int k = 0; k < blockDim.y; ++k) { int i2 = i2_off + k; int load_idx = i1 * n2 + i2; int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k; if (i2 < n2) { U curr_input = static_cast(input[load_idx]); U curr_dout = static_cast(dout[load_idx]); warp_buf1[write_idx] += curr_dout; warp_buf2[write_idx] += curr_dout * (curr_input - curr_mean) * curr_invvar; } } } } template __global__ void cuComputePartGradGammaBeta(const T* __restrict__ dout, const T* __restrict__ input, const int n1, const int n2, const U* __restrict__ mean, const U* __restrict__ invvar, U epsilon, U* part_grad_gamma, U* part_grad_beta) { const int numsegs_n1 = (n1 + blockDim.y * blockDim.y - 1) / (blockDim.y * blockDim.y); const int segs_per_block = (numsegs_n1 + gridDim.y - 1) / gridDim.y; const int i1_beg = blockIdx.y * segs_per_block * blockDim.y * blockDim.y; const int 
i1_beg_plus_one = (blockIdx.y + 1) * segs_per_block * blockDim.y * blockDim.y; const int i1_end = i1_beg_plus_one < n1 ? i1_beg_plus_one : n1; const int row_stride = blockDim.x + 1; const int thr_load_col_off = (threadIdx.x * blockDim.y) & (blockDim.x - 1); const int thr_load_row_off = (threadIdx.x * blockDim.y) / blockDim.x + threadIdx.y * blockDim.y; const int i2_off = blockIdx.x * blockDim.x + thr_load_col_off; SharedMemory shared; U* buf = shared.getPointer(); // buf has at least blockDim.x * blockDim.y * // blockDim.y + (blockDim.y - // 1)*(blockDim.x/blockDim.y) elements U* warp_buf1 = (U*)buf; U* warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride; // compute partial sums from strided inputs // do this to increase number of loads in flight cuLoadWriteStridedInputs(i1_beg, thr_load_row_off, thr_load_col_off, i2_off, row_stride, warp_buf1, warp_buf2, input, dout, i1_end, n2, mean, invvar); for (int i1_block = i1_beg + blockDim.y * blockDim.y; i1_block < i1_end; i1_block += blockDim.y * blockDim.y) { cuLoadAddStridedInputs(i1_block, thr_load_row_off, thr_load_col_off, i2_off, row_stride, warp_buf1, warp_buf2, input, dout, i1_end, n2, mean, invvar); } __syncthreads(); // inter-warp reductions // sum within each warp U acc1 = U(0); U acc2 = U(0); for (int k = 0; k < blockDim.y; ++k) { int row1 = threadIdx.y + k * blockDim.y; int idx1 = row1 * row_stride + threadIdx.x; acc1 += warp_buf1[idx1]; acc2 += warp_buf2[idx1]; } warp_buf1[threadIdx.y * row_stride + threadIdx.x] = acc1; warp_buf2[threadIdx.y * row_stride + threadIdx.x] = acc2; __syncthreads(); // sum all warps for (int offset = blockDim.y / 2; offset > 1; offset /= 2) { if (threadIdx.y < offset) { int row1 = threadIdx.y; int row2 = threadIdx.y + offset; int idx1 = row1 * row_stride + threadIdx.x; int idx2 = row2 * row_stride + threadIdx.x; warp_buf1[idx1] += warp_buf1[idx2]; warp_buf2[idx1] += warp_buf2[idx2]; } __syncthreads(); } int i2 = blockIdx.x * blockDim.x + threadIdx.x; if (threadIdx.y == 0 
&& i2 < n2) { int row1 = threadIdx.y; int row2 = threadIdx.y + 1; int idx1 = row1 * row_stride + threadIdx.x; int idx2 = row2 * row_stride + threadIdx.x; part_grad_beta[blockIdx.y * n2 + i2] = warp_buf1[idx1] + warp_buf1[idx2]; part_grad_gamma[blockIdx.y * n2 + i2] = warp_buf2[idx1] + warp_buf2[idx2]; } } template __global__ void cuComputeGradGammaBeta(const U* part_grad_gamma, const U* part_grad_beta, const int part_size, const int n1, const int n2, T* grad_gamma, T* grad_beta) { // sum partial gradients for gamma and beta SharedMemory shared; U* buf = shared.getPointer(); int i2 = blockIdx.x * blockDim.x + threadIdx.x; if (i2 < n2) { // each warp does sequential reductions until reduced part_size is num_warps int num_warp_reductions = part_size / blockDim.y; U sum_gamma = U(0); U sum_beta = U(0); const U* part_grad_gamma_ptr = part_grad_gamma + threadIdx.y * num_warp_reductions * n2 + i2; const U* part_grad_beta_ptr = part_grad_beta + threadIdx.y * num_warp_reductions * n2 + i2; for (int warp_offset = 0; warp_offset < num_warp_reductions; ++warp_offset) { sum_gamma += part_grad_gamma_ptr[warp_offset * n2]; sum_beta += part_grad_beta_ptr[warp_offset * n2]; } // inter-warp reductions const int nbsize3 = blockDim.x * blockDim.y / 2; for (int offset = blockDim.y / 2; offset >= 1; offset /= 2) { // top half write to shared memory if (threadIdx.y >= offset && threadIdx.y < 2 * offset) { const int write_idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x; buf[write_idx] = sum_gamma; buf[write_idx + nbsize3] = sum_beta; } __syncthreads(); // bottom half sums if (threadIdx.y < offset) { const int read_idx = threadIdx.y * blockDim.x + threadIdx.x; sum_gamma += buf[read_idx]; sum_beta += buf[read_idx + nbsize3]; } __syncthreads(); } // write out fully summed gradients if (threadIdx.y == 0) { grad_gamma[i2] = sum_gamma; grad_beta[i2] = sum_beta; } } } template __global__ void cuComputeGradInput(const T* __restrict__ dout, const T* __restrict__ dout_resid, const T* 
__restrict__ input, const int n1, const int n2, const U* __restrict__ mean, const U* __restrict__ invvar, U epsilon, const T* gamma, T* grad_input) { for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) { U sum_loss1 = U(0); U sum_loss2 = U(0); const U c_mean = mean[i1]; const U c_invvar = invvar[i1]; const T* k_input = input + i1 * n2; const T* k_dout = dout + i1 * n2; const T* k_dout_resid = dout_resid + i1 * n2; const int numx = blockDim.x * blockDim.y; const int thrx = threadIdx.x + threadIdx.y * blockDim.x; if (gamma != NULL) { int l = 4 * thrx; for (; l + 3 < n2; l += 4 * numx) { for (int k = 0; k < 4; ++k) { const U c_h = static_cast(k_input[l + k]); const U c_loss = static_cast(k_dout[l + k]); sum_loss1 += c_loss * static_cast(gamma[l + k]); sum_loss2 += c_loss * static_cast(gamma[l + k]) * (c_h - c_mean) * c_invvar; } } for (; l < n2; ++l) { const U c_h = static_cast(k_input[l]); const U c_loss = static_cast(k_dout[l]); sum_loss1 += c_loss * static_cast(gamma[l]); sum_loss2 += c_loss * static_cast(gamma[l]) * (c_h - c_mean) * c_invvar; } } else { int l = 4 * thrx; for (; l + 3 < n2; l += 4 * numx) { for (int k = 0; k < 4; ++k) { const U c_h = static_cast(k_input[l + k]); const U c_loss = static_cast(k_dout[l + k]); sum_loss1 += c_loss; sum_loss2 += c_loss * (c_h - c_mean) * c_invvar; } } for (; l < n2; ++l) { const U c_h = static_cast(k_input[l]); const U c_loss = static_cast(k_dout[l]); sum_loss1 += c_loss; sum_loss2 += c_loss * (c_h - c_mean) * c_invvar; } } // intra-warp reductions for (int mask = blockDim.x / 2; mask > 0; mask /= 2) { sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask); sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask); } // inter-warp reductions if (blockDim.y > 1) { SharedMemory shared; U* buf = shared.getPointer(); for (int offset = blockDim.y / 2; offset > 0; offset /= 2) { // upper half of warps write to shared if (threadIdx.y >= offset && threadIdx.y < 2 * offset) { const int wrt_i = (threadIdx.y - offset) * blockDim.x + threadIdx.x; buf[2 * 
wrt_i] = sum_loss1; buf[2 * wrt_i + 1] = sum_loss2; } __syncthreads(); // lower half merges if (threadIdx.y < offset) { const int read_i = threadIdx.y * blockDim.x + threadIdx.x; sum_loss1 += buf[2 * read_i]; sum_loss2 += buf[2 * read_i + 1]; } __syncthreads(); } if (threadIdx.y == 0) { buf[2 * threadIdx.x] = sum_loss1; buf[2 * threadIdx.x + 1] = sum_loss2; } __syncthreads(); if (threadIdx.y != 0) { sum_loss1 = buf[2 * threadIdx.x]; sum_loss2 = buf[2 * threadIdx.x + 1]; } } // all threads now have the two sums over l U fH = (U)n2; U term1 = (U(1) / fH) * c_invvar; T* k_grad_input = grad_input + i1 * n2; if (gamma != NULL) { for (int l = thrx; l < n2; l += numx) { const U c_h = static_cast(k_input[l]); const U c_loss = static_cast(k_dout[l]); const T c_resid = static_cast(k_dout_resid[l]); U f_grad_input = fH * c_loss * static_cast(gamma[l]); f_grad_input -= sum_loss1; f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2; f_grad_input *= term1; k_grad_input[l] = static_cast(f_grad_input) + c_resid; } } else { for (int l = thrx; l < n2; l += numx) { const U c_h = static_cast(k_input[l]); const U c_loss = static_cast(k_dout[l]); const T c_resid = static_cast(k_dout_resid[l]); U f_grad_input = fH * c_loss; f_grad_input -= sum_loss1; f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2; f_grad_input *= term1; k_grad_input[l] = static_cast(f_grad_input) + c_resid; } } } } template void HostApplyLayerNorm(T* output, U* mean, U* invvar, const T* input, int n1, int n2, double epsilon, const T* gamma, const T* beta) { auto stream = at::cuda::getCurrentCUDAStream().stream(); const dim3 threads(32, 4, 1); const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1); int nshared = threads.y > 1 ? 
threads.y * sizeof(U) + (threads.y / 2) * sizeof(U) : 0; cuApplyLayerNorm<<>>(output, mean, invvar, input, n1, n2, U(epsilon), gamma, beta); } template void HostLayerNormGradient(const T* dout, const T* dout_resid, const U* mean, const U* invvar, const at::Tensor& input, int n1, int n2, const T* gamma, const T* beta, double epsilon, T* grad_input, T* grad_gamma, T* grad_beta) { auto stream = at::cuda::getCurrentCUDAStream().stream(); if (gamma != NULL && beta != NULL) { // compute grad_gamma(j) and grad_beta(j) const int part_size = 16; const dim3 threads2(32, 4, 1); const dim3 blocks2((n2 + threads2.x - 1) / threads2.x, part_size, 1); const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1); const int nshared2_b = threads2.x * threads2.y * sizeof(U); const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b; at::Tensor part_grad_gamma = at::empty( {part_size, n2}, input.options().dtype(input.scalar_type() == at::ScalarType::Half ? at::ScalarType::Float : input.scalar_type())); at::Tensor part_grad_beta = at::empty_like(part_grad_gamma); cuComputePartGradGammaBeta<<>>( dout, static_cast(input.data_ptr()), n1, n2, mean, invvar, U(epsilon), static_cast(part_grad_gamma.data_ptr()), static_cast(part_grad_beta.data_ptr())); const dim3 threads3(32, 8, 1); const dim3 blocks3((n2 + threads2.x - 1) / threads2.x, 1, 1); const int nshared3 = threads3.x * threads3.y * sizeof(U); cuComputeGradGammaBeta<<>>(static_cast(part_grad_gamma.data_ptr()), static_cast(part_grad_beta.data_ptr()), part_size, n1, n2, grad_gamma, grad_beta); } // compute grad_input const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1); const dim3 threads1(32, 4, 1); int nshared = threads1.y > 1 ? 
threads1.y * threads1.x * sizeof(U) : 0; cuComputeGradInput<<>>(dout, dout_resid, static_cast(input.data_ptr()), n1, n2, mean, invvar, U(epsilon), gamma, grad_input); } } // namespace ================================================ FILE: apex/contrib/csrc/multihead_attn/masked_softmax_dropout_cuda.cu ================================================ #include #include #include #include #include #include #include #include #include #include #include "dropout.cuh" #include "softmax.cuh" namespace multihead_attn { namespace fused_softmax { namespace mask_softmax_dropout { std::vector fwd_cuda(bool is_training, int heads, torch::Tensor const& input, const uint8_t* pad_mask, float dropout_prob) { const int attn_batches = input.size(0); const int sequences = attn_batches / heads; const int q_seq_len = input.size(1); const int k_seq_len = q_seq_len; const int dropout_elems = attn_batches * q_seq_len * k_seq_len; // There is no reason to use more than one stream as every kernel is // sequentially dependent cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); cublasSetStream(handle, stream); // 3 Intermediate Results + Output (Note: dropout intermediates are generated // by ATen library code) auto act_options = input.options().requires_grad(false); auto mask_options = act_options.dtype(torch::kUInt8); torch::Tensor softmax_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); torch::Tensor dropout_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); torch::Tensor dropout_mask = torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options); // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax) void* input_ptr = static_cast(input.data_ptr()); void* softmax_results_ptr = static_cast(softmax_results.data_ptr()); // Padded Softmax bool softmax_success = false; if (pad_mask == nullptr) { softmax_success = 
dispatch_softmax(reinterpret_cast(softmax_results_ptr), reinterpret_cast(input_ptr), k_seq_len, k_seq_len, attn_batches * q_seq_len); } else { softmax_success = dispatch_masked_softmax( reinterpret_cast(softmax_results_ptr), reinterpret_cast(input_ptr), pad_mask, k_seq_len, k_seq_len, attn_batches * q_seq_len, attn_batches * q_seq_len / sequences); } if (is_training) { // use at:: function so that C++ version generates the same random mask as // python version auto dropout_tuple = at::_fused_dropout(softmax_results, 1.0f - dropout_prob); dropout_results = std::get<0>(dropout_tuple); dropout_mask = std::get<1>(dropout_tuple); } // Matmul2 return {dropout_results, dropout_mask, softmax_results}; } torch::Tensor bwd_cuda(int heads, torch::Tensor const& output_grads, torch::Tensor const& softmax_results, torch::Tensor const& dropout_mask, const uint8_t* padding_mask, float dropout_prob) { const int attn_batches = output_grads.size(0); const int q_seq_len = output_grads.size(1); const int k_seq_len = q_seq_len; const int dropout_elems = attn_batches * q_seq_len * k_seq_len; // TODO: Streams can be used in Backprop but I haven't added more than one // in my first attempt to create the code cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); cublasSetStream(handle, stream); // Output Tensor Allocations // torch::Tensor input_grads = torch::empty_like(output_grads); // Apply Dropout Mask and Scale by Dropout Probability // Softmax Grad if (padding_mask == nullptr) { dispatch_masked_scale_softmax_backward_stream( static_cast(output_grads.data_ptr()), static_cast(output_grads.data_ptr()), reinterpret_cast(softmax_results.data_ptr()), static_cast(dropout_mask.data_ptr()), 1.0 / (1.0 - dropout_prob), k_seq_len, k_seq_len, attn_batches * q_seq_len, stream); } else { dispatch_masked_scale_softmax_backward_masked_out_stream( static_cast(output_grads.data_ptr()), static_cast(output_grads.data_ptr()), 
reinterpret_cast(softmax_results.data_ptr()), static_cast(dropout_mask.data_ptr()), static_cast(padding_mask), 1.0 / (1.0 - dropout_prob), k_seq_len, k_seq_len, attn_batches * q_seq_len, heads, stream); } // backward pass is completely in-place return output_grads; } } // namespace mask_softmax_dropout } // namespace fused_softmax } // namespace multihead_attn ================================================ FILE: apex/contrib/csrc/multihead_attn/multihead_attn_frontend.cpp ================================================ #include #include #include #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") #define CHECK_INPUT(x) \ CHECK_CUDA(x); \ CHECK_CONTIGUOUS(x) namespace multihead_attn { namespace fused_softmax { namespace additive_mask_softmax_dropout { std::vector fwd_cuda(bool is_training, int heads, torch::Tensor const& input, const half* pad_mask, float dropout_prob); torch::Tensor bwd_cuda(int heads, torch::Tensor const& output_grads, torch::Tensor const& softmax_results, torch::Tensor const& dropout_mask, float dropout_prob); std::vector fwd(bool use_mask, bool is_training, int heads, torch::Tensor const& input, torch::Tensor const& pad_mask, float dropout_prob) { TORCH_CHECK(input.dim() == 3, "expected 3D tensor"); TORCH_CHECK(input.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); if (use_mask) { TORCH_CHECK(pad_mask.dim() == 2, "expected 2D tensor"); TORCH_CHECK(pad_mask.scalar_type() == at::ScalarType::Half, "Only BYTE is supported"); } return fwd_cuda(is_training, heads, input, use_mask ? 
static_cast(pad_mask.data_ptr()) : nullptr, dropout_prob); } torch::Tensor bwd(bool use_mask, int heads, torch::Tensor const& output_grads, torch::Tensor const& softmax_results, torch::Tensor const& dropout_mask, float dropout_prob) { TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor"); TORCH_CHECK(softmax_results.dim() == 3, "expected 3D tensor"); TORCH_CHECK(dropout_mask.dim() == 3, "expected 3D tensor"); TORCH_CHECK(output_grads.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(softmax_results.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); // TORCH_CHECK(dropout_mask.scalar_type() == at::ScalarType::Byte, // "Only BYTE is supported"); return bwd_cuda(heads, output_grads, softmax_results, dropout_mask, dropout_prob); } } // namespace additive_mask_softmax_dropout namespace mask_softmax_dropout { std::vector fwd_cuda(bool is_training, int heads, torch::Tensor const& input, const uint8_t* pad_mask, float dropout_prob); torch::Tensor bwd_cuda(int heads, torch::Tensor const& output_grads, torch::Tensor const& softmax_results, torch::Tensor const& dropout_mask, const uint8_t* padding_mask, float dropout_prob); std::vector fwd(bool use_mask, bool is_training, int heads, torch::Tensor const& input, torch::Tensor const& pad_mask, float dropout_prob) { TORCH_CHECK(input.dim() == 3, "expected 3D tensor"); TORCH_CHECK(input.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); if (use_mask) { TORCH_CHECK(pad_mask.dim() == 2, "expected 2D tensor"); TORCH_CHECK(pad_mask.scalar_type() == at::ScalarType::Byte, "Only BYTE is supported"); } return fwd_cuda(is_training, heads, input, use_mask ? 
static_cast(pad_mask.data_ptr()) : nullptr, dropout_prob); } torch::Tensor bwd(bool use_mask, int heads, torch::Tensor const& output_grads, torch::Tensor const& softmax_results, torch::Tensor const& dropout_mask, torch::Tensor const& padding_mask, float dropout_prob) { TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor"); TORCH_CHECK(softmax_results.dim() == 3, "expected 3D tensor"); TORCH_CHECK(dropout_mask.dim() == 3, "expected 3D tensor"); TORCH_CHECK(output_grads.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(softmax_results.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); // TORCH_CHECK(dropout_mask.scalar_type() == at::ScalarType::Byte, // "Only BYTE is supported"); return bwd_cuda(heads, output_grads, softmax_results, dropout_mask, use_mask ? static_cast(padding_mask.data_ptr()) : nullptr, dropout_prob); } } // end namespace mask_softmax_dropout } // end namespace fused_softmax namespace encdec { namespace cublas_gemmex { std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, torch::Tensor const& inputs_q, torch::Tensor const& inputs_kv, torch::Tensor const& input_weights_q, torch::Tensor const& input_weights_kv, torch::Tensor const& output_weights, const uint8_t* pad_mask, float dropout_prob); std::vector bwd_cuda(int heads, torch::Tensor const& output_grads, torch::Tensor const& matmul2_results, torch::Tensor const& dropout_results, torch::Tensor const& softmax_results, torch::Tensor const& input_lin_q_results, torch::Tensor const& input_lin_kv_results, torch::Tensor const& inputs_q, torch::Tensor const& inputs_kv, torch::Tensor const& input_weights_q, torch::Tensor const& input_weights_kv, torch::Tensor const& output_weights, torch::Tensor const& dropout_mask, float dropout_prob); std::vector fwd(bool use_mask, bool use_time_mask, bool is_training, int heads, torch::Tensor const& inputs_q, torch::Tensor const& inputs_kv, torch::Tensor const& input_weights_q, torch::Tensor const& 
input_weights_kv, torch::Tensor const& output_weights, torch::Tensor const& pad_mask, float dropout_prob) { TORCH_CHECK(inputs_q.dim() == 3, "expected 3D tensor"); TORCH_CHECK(inputs_kv.dim() == 3, "expected 3D tensor"); TORCH_CHECK(input_weights_q.dim() == 2, "expected 2D tensor"); TORCH_CHECK(input_weights_kv.dim() == 2, "expected 2D tensor"); TORCH_CHECK(output_weights.dim() == 2, "expected 2D tensor"); TORCH_CHECK(inputs_q.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(inputs_kv.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(input_weights_q.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(input_weights_kv.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(output_weights.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); if (use_mask) { TORCH_CHECK(pad_mask.dim() == 2, "expected 2D tensor"); TORCH_CHECK(pad_mask.scalar_type() == at::ScalarType::Byte, "Only BYTE is supported"); } return fwd_cuda(use_time_mask, is_training, heads, inputs_q, inputs_kv, input_weights_q, input_weights_kv, output_weights, use_mask ? 
static_cast(pad_mask.data_ptr()) : nullptr, dropout_prob); } std::vector bwd(int heads, torch::Tensor const& output_grads, torch::Tensor const& matmul2_results, torch::Tensor const& dropout_results, torch::Tensor const& softmax_results, torch::Tensor const& input_lin_q_results, torch::Tensor const& input_lin_kv_results, torch::Tensor const& inputs_q, torch::Tensor const& inputs_kv, torch::Tensor const& input_weights_q, torch::Tensor const& input_weights_kv, torch::Tensor const& output_weights, torch::Tensor const& dropout_mask, float dropout_prob) { TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor"); TORCH_CHECK(matmul2_results.dim() == 3, "expected 3D tensor"); TORCH_CHECK(dropout_results.dim() == 3, "expected 3D tensor"); TORCH_CHECK(softmax_results.dim() == 3, "expected 3D tensor"); TORCH_CHECK(input_lin_q_results.dim() == 3, "expected 3D tensor"); TORCH_CHECK(input_lin_kv_results.dim() == 3, "expected 3D tensor"); TORCH_CHECK(inputs_q.dim() == 3, "expected 3D tensor"); TORCH_CHECK(inputs_kv.dim() == 3, "expected 3D tensor"); TORCH_CHECK(input_weights_q.dim() == 2, "expected 2D tensor"); TORCH_CHECK(input_weights_kv.dim() == 2, "expected 2D tensor"); TORCH_CHECK(output_weights.dim() == 2, "expected 2D tensor"); TORCH_CHECK(dropout_mask.dim() == 3, "expected 3D tensor"); TORCH_CHECK(output_grads.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(matmul2_results.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(dropout_results.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(softmax_results.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(input_lin_q_results.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(input_lin_kv_results.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(inputs_q.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(inputs_kv.scalar_type() == 
at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(input_weights_q.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(input_weights_kv.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(output_weights.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(dropout_mask.scalar_type() == at::ScalarType::Byte, "Only BYTE is supported"); return bwd_cuda(heads, output_grads, matmul2_results, dropout_results, softmax_results, input_lin_q_results, input_lin_kv_results, inputs_q, inputs_kv, input_weights_q, input_weights_kv, output_weights, dropout_mask, dropout_prob); } } // end namespace cublas_gemmex } // end namespace encdec namespace encdec_norm_add { namespace cublas_gemmex { std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, torch::Tensor const& inputs_q, torch::Tensor const& inputs_kv, torch::Tensor const& lyr_nrm_gamma_weights, torch::Tensor const& lyr_nrm_beta_weights, torch::Tensor const& input_weights_q, torch::Tensor const& input_weights_kv, torch::Tensor const& output_weights, const uint8_t* pad_mask, float dropout_prob); std::vector bwd_cuda(int heads, torch::Tensor const& output_grads, torch::Tensor const& matmul2_results, torch::Tensor const& dropout_results, torch::Tensor const& softmax_results, torch::Tensor const& input_lin_q_results, torch::Tensor const& input_lin_kv_results, torch::Tensor const& lyr_nrm_results, torch::Tensor const& lyr_nrm_mean, torch::Tensor const& lyr_nrm_invvar, torch::Tensor const& inputs_q, torch::Tensor const& inputs_kv, torch::Tensor const& lyr_nrm_gamma_weights, torch::Tensor const& lyr_nrm_beta_weights, torch::Tensor const& input_weights_q, torch::Tensor const& input_weights_kv, torch::Tensor const& output_weights, torch::Tensor const& dropout_mask, torch::Tensor const& dropout_add_mask, float dropout_prob); std::vector fwd(bool use_mask, bool use_time_mask, bool is_training, int heads, torch::Tensor const& inputs_q, 
torch::Tensor const& inputs_kv, torch::Tensor const& lyr_nrm_gamma_weights, torch::Tensor const& lyr_nrm_beta_weights, torch::Tensor const& input_weights_q, torch::Tensor const& input_weights_kv, torch::Tensor const& output_weights, torch::Tensor const& pad_mask, float dropout_prob) { TORCH_CHECK(inputs_q.dim() == 3, "expected 3D tensor"); TORCH_CHECK(inputs_kv.dim() == 3, "expected 3D tensor"); TORCH_CHECK(lyr_nrm_gamma_weights.dim() == 1, "expected 1D tensor"); TORCH_CHECK(lyr_nrm_beta_weights.dim() == 1, "expected 1D tensor"); TORCH_CHECK(input_weights_q.dim() == 2, "expected 2D tensor"); TORCH_CHECK(input_weights_kv.dim() == 2, "expected 2D tensor"); TORCH_CHECK(output_weights.dim() == 2, "expected 2D tensor"); TORCH_CHECK(inputs_q.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(inputs_kv.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(lyr_nrm_gamma_weights.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(lyr_nrm_beta_weights.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(input_weights_q.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(input_weights_kv.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(output_weights.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); if (use_mask) { TORCH_CHECK(pad_mask.dim() == 2, "expected 2D tensor"); TORCH_CHECK(pad_mask.scalar_type() == at::ScalarType::Byte, "Only BYTE is supported"); } return fwd_cuda(use_time_mask, is_training, heads, inputs_q, inputs_kv, lyr_nrm_gamma_weights, lyr_nrm_beta_weights, input_weights_q, input_weights_kv, output_weights, use_mask ? 
static_cast(pad_mask.data_ptr()) : nullptr, dropout_prob); } std::vector bwd(int heads, torch::Tensor const& output_grads, torch::Tensor const& matmul2_results, torch::Tensor const& dropout_results, torch::Tensor const& softmax_results, torch::Tensor const& input_lin_q_results, torch::Tensor const& input_lin_kv_results, torch::Tensor const& lyr_nrm_results, torch::Tensor const& lyr_nrm_mean, torch::Tensor const& lyr_nrm_invvar, torch::Tensor const& inputs_q, torch::Tensor const& inputs_kv, torch::Tensor const& lyr_nrm_gamma_weights, torch::Tensor const& lyr_nrm_beta_weights, torch::Tensor const& input_weights_q, torch::Tensor const& input_weights_kv, torch::Tensor const& output_weights, torch::Tensor const& dropout_mask, torch::Tensor const& dropout_add_mask, float dropout_prob) { TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor"); TORCH_CHECK(matmul2_results.dim() == 3, "expected 3D tensor"); TORCH_CHECK(dropout_results.dim() == 3, "expected 3D tensor"); TORCH_CHECK(softmax_results.dim() == 3, "expected 3D tensor"); TORCH_CHECK(input_lin_q_results.dim() == 3, "expected 3D tensor"); TORCH_CHECK(input_lin_kv_results.dim() == 3, "expected 3D tensor"); TORCH_CHECK(lyr_nrm_results.dim() == 3, "expected 3D tensor"); TORCH_CHECK(lyr_nrm_mean.dim() == 1, "expected 1D tensor"); TORCH_CHECK(lyr_nrm_invvar.dim() == 1, "expected 1D tensor"); TORCH_CHECK(inputs_q.dim() == 3, "expected 3D tensor"); TORCH_CHECK(inputs_kv.dim() == 3, "expected 3D tensor"); TORCH_CHECK(lyr_nrm_gamma_weights.dim() == 1, "expected 1D tensor"); TORCH_CHECK(lyr_nrm_beta_weights.dim() == 1, "expected 1D tensor"); TORCH_CHECK(input_weights_q.dim() == 2, "expected 2D tensor"); TORCH_CHECK(input_weights_kv.dim() == 2, "expected 2D tensor"); TORCH_CHECK(output_weights.dim() == 2, "expected 2D tensor"); TORCH_CHECK(dropout_mask.dim() == 3, "expected 3D tensor"); TORCH_CHECK(dropout_add_mask.dim() == 3, "expected 3D tensor"); TORCH_CHECK(output_grads.scalar_type() == at::ScalarType::Half, "Only HALF 
is supported"); TORCH_CHECK(matmul2_results.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(dropout_results.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(softmax_results.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(input_lin_q_results.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(input_lin_kv_results.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(lyr_nrm_results.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(lyr_nrm_mean.scalar_type() == at::ScalarType::Float, "Only FLOAT is supported"); TORCH_CHECK(lyr_nrm_invvar.scalar_type() == at::ScalarType::Float, "Only FLOAT is supported"); TORCH_CHECK(inputs_q.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(inputs_kv.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(lyr_nrm_gamma_weights.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(lyr_nrm_beta_weights.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(input_weights_q.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(input_weights_kv.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(output_weights.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(dropout_mask.scalar_type() == at::ScalarType::Byte, "Only BYTE is supported"); TORCH_CHECK(dropout_add_mask.scalar_type() == at::ScalarType::Byte, "Only BYTE is supported"); return bwd_cuda(heads, output_grads, matmul2_results, dropout_results, softmax_results, input_lin_q_results, input_lin_kv_results, lyr_nrm_results, lyr_nrm_mean, lyr_nrm_invvar, inputs_q, inputs_kv, lyr_nrm_gamma_weights, lyr_nrm_beta_weights, input_weights_q, input_weights_kv, output_weights, dropout_mask, dropout_add_mask, dropout_prob); } } // end namespace cublas_gemmex } 
// end namespace encdec_norm_add namespace self { namespace cublas_gemmex { std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, torch::Tensor const& inputs, torch::Tensor const& input_weights, torch::Tensor const& output_weights, const uint8_t* pad_mask, float dropout_prob); std::vector bwd_cuda(int heads, torch::Tensor const& output_grads, torch::Tensor const& matmul2_results, torch::Tensor const& dropout_results, torch::Tensor const& softmax_results, torch::Tensor const& input_lin_results, torch::Tensor const& inputs, torch::Tensor const& input_weights, torch::Tensor const& output_weights, torch::Tensor const& dropout_mask, float dropout_prob); std::vector fwd(bool use_mask, bool use_time_mask, bool is_training, int heads, torch::Tensor const& inputs, torch::Tensor const& input_weights, torch::Tensor const& output_weights, torch::Tensor const& pad_mask, float dropout_prob) { TORCH_CHECK(inputs.dim() == 3, "expected 3D tensor"); TORCH_CHECK(input_weights.dim() == 2, "expected 2D tensor"); TORCH_CHECK(output_weights.dim() == 2, "expected 2D tensor"); TORCH_CHECK(inputs.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(input_weights.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(output_weights.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); if (use_mask) { TORCH_CHECK(pad_mask.dim() == 2, "expected 2D tensor"); TORCH_CHECK(pad_mask.scalar_type() == at::ScalarType::Byte, "Only BYTE is supported"); } return fwd_cuda(use_time_mask, is_training, heads, inputs, input_weights, output_weights, use_mask ? 
static_cast(pad_mask.data_ptr()) : nullptr, dropout_prob); } std::vector bwd(int heads, torch::Tensor const& output_grads, torch::Tensor const& matmul2_results, torch::Tensor const& dropout_results, torch::Tensor const& softmax_results, torch::Tensor const& input_lin_results, torch::Tensor const& inputs, torch::Tensor const& input_weights, torch::Tensor const& output_weights, torch::Tensor const& dropout_mask, float dropout_prob) { TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor"); TORCH_CHECK(matmul2_results.dim() == 3, "expected 3D tensor"); TORCH_CHECK(dropout_results.dim() == 3, "expected 3D tensor"); TORCH_CHECK(softmax_results.dim() == 3, "expected 3D tensor"); TORCH_CHECK(input_lin_results.dim() == 3, "expected 3D tensor"); TORCH_CHECK(inputs.dim() == 3, "expected 3D tensor"); TORCH_CHECK(input_weights.dim() == 2, "expected 2D tensor"); TORCH_CHECK(output_weights.dim() == 2, "expected 2D tensor"); TORCH_CHECK(dropout_mask.dim() == 3, "expected 3D tensor"); TORCH_CHECK(output_grads.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(matmul2_results.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(dropout_results.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(softmax_results.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(input_lin_results.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(inputs.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(input_weights.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(output_weights.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(dropout_mask.scalar_type() == at::ScalarType::Byte, "Only BYTE is supported"); return bwd_cuda(heads, output_grads, matmul2_results, dropout_results, softmax_results, input_lin_results, inputs, input_weights, output_weights, dropout_mask, dropout_prob); } } // 
end namespace cublas_gemmex } // end namespace self namespace self_bias { namespace cublas_gemmex { std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, torch::Tensor const& inputs, torch::Tensor const& input_weights, torch::Tensor const& output_weights, torch::Tensor const& input_biases, torch::Tensor const& output_biases, const uint8_t* pad_mask, float dropout_prob); std::vector bwd_cuda(int heads, torch::Tensor const& output_grads, torch::Tensor const& matmul2_results, torch::Tensor const& dropout_results, torch::Tensor const& softmax_results, torch::Tensor const& input_lin_results, torch::Tensor const& inputs, torch::Tensor const& input_weights, torch::Tensor const& output_weights, // torch::Tensor const& input_biases, // torch::Tensor const& output_biases, torch::Tensor const& dropout_mask, float dropout_prob); std::vector fwd(bool use_mask, bool use_time_mask, bool is_training, int heads, torch::Tensor const& inputs, torch::Tensor const& input_weights, torch::Tensor const& output_weights, torch::Tensor const& input_biases, torch::Tensor const& output_biases, torch::Tensor const& pad_mask, float dropout_prob) { TORCH_CHECK(inputs.dim() == 3, "expected 3D tensor"); TORCH_CHECK(input_weights.dim() == 2, "expected 2D tensor"); TORCH_CHECK(output_weights.dim() == 2, "expected 2D tensor"); TORCH_CHECK(inputs.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(input_weights.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(output_weights.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); if (use_mask) { TORCH_CHECK(pad_mask.dim() == 2, "expected 2D tensor"); TORCH_CHECK(pad_mask.scalar_type() == at::ScalarType::Byte, "Only BYTE is supported"); } return fwd_cuda(use_time_mask, is_training, heads, inputs, input_weights, output_weights, input_biases, output_biases, use_mask ? 
static_cast(pad_mask.data_ptr()) : nullptr, dropout_prob); } std::vector bwd(int heads, torch::Tensor const& output_grads, torch::Tensor const& matmul2_results, torch::Tensor const& dropout_results, torch::Tensor const& softmax_results, torch::Tensor const& input_lin_results, torch::Tensor const& inputs, torch::Tensor const& input_weights, torch::Tensor const& output_weights, torch::Tensor const& dropout_mask, float dropout_prob) { TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor"); TORCH_CHECK(matmul2_results.dim() == 3, "expected 3D tensor"); TORCH_CHECK(dropout_results.dim() == 3, "expected 3D tensor"); TORCH_CHECK(softmax_results.dim() == 3, "expected 3D tensor"); TORCH_CHECK(input_lin_results.dim() == 3, "expected 3D tensor"); TORCH_CHECK(inputs.dim() == 3, "expected 3D tensor"); TORCH_CHECK(input_weights.dim() == 2, "expected 2D tensor"); TORCH_CHECK(output_weights.dim() == 2, "expected 2D tensor"); TORCH_CHECK(dropout_mask.dim() == 3, "expected 3D tensor"); TORCH_CHECK(output_grads.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(matmul2_results.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(dropout_results.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(softmax_results.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(input_lin_results.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(inputs.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(input_weights.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(output_weights.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(dropout_mask.scalar_type() == at::ScalarType::Byte, "Only BYTE is supported"); return bwd_cuda(heads, output_grads, matmul2_results, dropout_results, softmax_results, input_lin_results, inputs, input_weights, output_weights, dropout_mask, dropout_prob); } } // 
end namespace cublas_gemmex } // namespace self_bias namespace self_bias_additive_mask { namespace cublas_gemmex { std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, torch::Tensor const& inputs, torch::Tensor const& input_weights, torch::Tensor const& output_weights, torch::Tensor const& input_biases, torch::Tensor const& output_biases, const half* pad_mask, float dropout_prob); std::vector bwd_cuda(int heads, torch::Tensor const& output_grads, torch::Tensor const& matmul2_results, torch::Tensor const& dropout_results, // torch::Tensor const& softmax_results, torch::Tensor const& bmm1_results, torch::Tensor const& pad_mask, torch::Tensor const& input_lin_results, torch::Tensor const& inputs, torch::Tensor const& input_weights, torch::Tensor const& output_weights, // torch::Tensor const& input_biases, // torch::Tensor const& output_biases, torch::Tensor const& dropout_mask, float dropout_prob); std::vector fwd(bool use_mask, bool use_time_mask, bool is_training, int heads, torch::Tensor const& inputs, torch::Tensor const& input_weights, torch::Tensor const& output_weights, torch::Tensor const& input_biases, torch::Tensor const& output_biases, torch::Tensor const& pad_mask, float dropout_prob) { TORCH_CHECK(inputs.dim() == 3, "expected 3D tensor"); TORCH_CHECK(input_weights.dim() == 2, "expected 2D tensor"); TORCH_CHECK(output_weights.dim() == 2, "expected 2D tensor"); TORCH_CHECK(inputs.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(input_weights.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(output_weights.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(use_mask, "no mask is not supported"); if (use_mask) { TORCH_CHECK(pad_mask.dim() == 2, "expected 2D tensor"); TORCH_CHECK(pad_mask.scalar_type() == at::ScalarType::Half, "Only Half is supported"); } return fwd_cuda(use_time_mask, is_training, heads, inputs, input_weights, output_weights, input_biases, 
output_biases, use_mask ? static_cast(pad_mask.data_ptr()) : nullptr, dropout_prob); } std::vector bwd(int heads, torch::Tensor const& output_grads, torch::Tensor const& matmul2_results, torch::Tensor const& dropout_results, torch::Tensor const& bmm1_results, torch::Tensor const& pad_mask, torch::Tensor const& input_lin_results, torch::Tensor const& inputs, torch::Tensor const& input_weights, torch::Tensor const& output_weights, torch::Tensor const& dropout_mask, float dropout_prob) { TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor"); TORCH_CHECK(matmul2_results.dim() == 3, "expected 3D tensor"); TORCH_CHECK(dropout_results.dim() == 3, "expected 3D tensor"); TORCH_CHECK(input_lin_results.dim() == 3, "expected 3D tensor"); TORCH_CHECK(inputs.dim() == 3, "expected 3D tensor"); TORCH_CHECK(input_weights.dim() == 2, "expected 2D tensor"); TORCH_CHECK(output_weights.dim() == 2, "expected 2D tensor"); TORCH_CHECK(dropout_mask.dim() == 3, "expected 3D tensor"); TORCH_CHECK(output_grads.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(matmul2_results.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(dropout_results.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(input_lin_results.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(inputs.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(input_weights.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(output_weights.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(dropout_mask.scalar_type() == at::ScalarType::Byte, "Only BYTE is supported"); return bwd_cuda(heads, output_grads, matmul2_results, dropout_results, bmm1_results, pad_mask, input_lin_results, inputs, input_weights, output_weights, dropout_mask, dropout_prob); } } // end namespace cublas_gemmex } // namespace self_bias_additive_mask namespace self_norm_add { 
namespace cublas_gemmex { std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, torch::Tensor const& inputs, torch::Tensor const& lyr_nrm_gamma_weights, torch::Tensor const& lyr_nrm_beta_weights, torch::Tensor const& input_weights, torch::Tensor const& output_weights, const uint8_t* pad_mask, float dropout_prob); std::vector bwd_cuda(int heads, torch::Tensor const& output_grads, torch::Tensor const& matmul2_results, torch::Tensor const& dropout_results, torch::Tensor const& softmax_results, torch::Tensor const& input_lin_results, torch::Tensor const& lyr_nrm_results, torch::Tensor const& lyr_nrm_mean, torch::Tensor const& lyr_nrm_invvar, torch::Tensor const& inputs, torch::Tensor const& lyr_nrm_gamma_weights, torch::Tensor const& lyr_nrm_beta_weights, torch::Tensor const& input_weights, torch::Tensor const& output_weights, torch::Tensor const& dropout_mask, torch::Tensor const& dropout_add_mask, float dropout_prob); std::vector fwd(bool use_mask, bool use_time_mask, bool is_training, int heads, torch::Tensor const& inputs, torch::Tensor const& lyr_nrm_gamma_weights, torch::Tensor const& lyr_nrm_beta_weights, torch::Tensor const& input_weights, torch::Tensor const& output_weights, torch::Tensor const& pad_mask, float dropout_prob) { TORCH_CHECK(inputs.dim() == 3, "expected 3D tensor"); TORCH_CHECK(lyr_nrm_gamma_weights.dim() == 1, "expected 1D tensor"); TORCH_CHECK(lyr_nrm_beta_weights.dim() == 1, "expected 1D tensor"); TORCH_CHECK(input_weights.dim() == 2, "expected 2D tensor"); TORCH_CHECK(output_weights.dim() == 2, "expected 2D tensor"); TORCH_CHECK(inputs.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(lyr_nrm_gamma_weights.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(lyr_nrm_beta_weights.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(input_weights.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(output_weights.scalar_type() 
== at::ScalarType::Half, "Only HALF is supported"); if (use_mask) { TORCH_CHECK(pad_mask.dim() == 2, "expected 2D tensor"); TORCH_CHECK(pad_mask.scalar_type() == at::ScalarType::Byte, "Only BYTE is supported"); } return fwd_cuda(use_time_mask, is_training, heads, inputs, lyr_nrm_gamma_weights, lyr_nrm_beta_weights, input_weights, output_weights, use_mask ? static_cast(pad_mask.data_ptr()) : nullptr, dropout_prob); } std::vector bwd(int heads, torch::Tensor const& output_grads, torch::Tensor const& matmul2_results, torch::Tensor const& dropout_results, torch::Tensor const& softmax_results, torch::Tensor const& input_lin_results, torch::Tensor const& lyr_nrm_results, torch::Tensor const& lyr_nrm_mean, torch::Tensor const& lyr_nrm_invvar, torch::Tensor const& inputs, torch::Tensor const& lyr_nrm_gamma_weights, torch::Tensor const& lyr_nrm_beta_weights, torch::Tensor const& input_weights, torch::Tensor const& output_weights, torch::Tensor const& dropout_mask, torch::Tensor const& dropout_add_mask, float dropout_prob) { TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor"); TORCH_CHECK(matmul2_results.dim() == 3, "expected 3D tensor"); TORCH_CHECK(dropout_results.dim() == 3, "expected 3D tensor"); TORCH_CHECK(softmax_results.dim() == 3, "expected 3D tensor"); TORCH_CHECK(input_lin_results.dim() == 3, "expected 3D tensor"); TORCH_CHECK(lyr_nrm_results.dim() == 3, "expected 3D tensor"); TORCH_CHECK(lyr_nrm_mean.dim() == 1, "expected 1D tensor"); TORCH_CHECK(lyr_nrm_invvar.dim() == 1, "expected 1D tensor"); TORCH_CHECK(inputs.dim() == 3, "expected 3D tensor"); TORCH_CHECK(lyr_nrm_gamma_weights.dim() == 1, "expected 1D tensor"); TORCH_CHECK(lyr_nrm_beta_weights.dim() == 1, "expected 1D tensor"); TORCH_CHECK(input_weights.dim() == 2, "expected 2D tensor"); TORCH_CHECK(output_weights.dim() == 2, "expected 2D tensor"); TORCH_CHECK(dropout_mask.dim() == 3, "expected 3D tensor"); TORCH_CHECK(dropout_add_mask.dim() == 3, "expected 3D tensor"); 
TORCH_CHECK(output_grads.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(matmul2_results.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(dropout_results.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(softmax_results.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(input_lin_results.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(lyr_nrm_results.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(lyr_nrm_mean.scalar_type() == at::ScalarType::Float, "Only FLOAT is supported"); TORCH_CHECK(lyr_nrm_invvar.scalar_type() == at::ScalarType::Float, "Only FLOAT is supported"); TORCH_CHECK(inputs.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(lyr_nrm_gamma_weights.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(lyr_nrm_beta_weights.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(input_weights.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(output_weights.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); TORCH_CHECK(dropout_mask.scalar_type() == at::ScalarType::Byte, "Only BYTE is supported"); TORCH_CHECK(dropout_add_mask.scalar_type() == at::ScalarType::Byte, "Only BYTE is supported"); return bwd_cuda(heads, output_grads, matmul2_results, dropout_results, softmax_results, input_lin_results, lyr_nrm_results, lyr_nrm_mean, lyr_nrm_invvar, inputs, lyr_nrm_gamma_weights, lyr_nrm_beta_weights, input_weights, output_weights, dropout_mask, dropout_add_mask, dropout_prob); } } // end namespace cublas_gemmex } // end namespace self_norm_add } // end namespace multihead_attn PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("additive_mask_softmax_dropout_forward", &multihead_attn::fused_softmax::additive_mask_softmax_dropout::fwd, "Self Multihead Attention masked softmax dropout 
-- Forward.", py::call_guard()); m.def("additive_mask_softmax_dropout_backward", &multihead_attn::fused_softmax::additive_mask_softmax_dropout::bwd, "Self Multihead Attention masked softmax dropout -- Backward.", py::call_guard()); m.def("mask_softmax_dropout_forward", &multihead_attn::fused_softmax::mask_softmax_dropout::fwd, "Self Multihead Attention masked softmax dropout -- Forward.", py::call_guard()); m.def("mask_softmax_dropout_backward", &multihead_attn::fused_softmax::mask_softmax_dropout::bwd, "Self Multihead Attention masked softmax dropout -- Backward.", py::call_guard()); m.def("encdec_multihead_attn_forward", &multihead_attn::encdec::cublas_gemmex::fwd, "Encdec Multihead Attention Forward.", py::call_guard()); m.def("encdec_multihead_attn_backward", &multihead_attn::encdec::cublas_gemmex::bwd, "Encdec Multihead Attention Backward.", py::call_guard()); m.def("encdec_multihead_attn_norm_add_forward", &multihead_attn::encdec_norm_add::cublas_gemmex::fwd, "Encdec Multihead Attention Plus Layer Norm and Residual Add Forward.", py::call_guard()); m.def("encdec_multihead_attn_norm_add_backward", &multihead_attn::encdec_norm_add::cublas_gemmex::bwd, "Encdec Multihead Attention Plus Layer Norm and Residual Add Backward.", py::call_guard()); m.def("self_attn_forward", &multihead_attn::self::cublas_gemmex::fwd, "Self Multihead Attention Forward.", py::call_guard()); m.def("self_attn_backward", &multihead_attn::self::cublas_gemmex::bwd, "Self Multihead Attention Backward.", py::call_guard()); m.def("self_attn_bias_forward", &multihead_attn::self_bias::cublas_gemmex::fwd, "Self Multihead Attention with Bias -- Forward.", py::call_guard()); m.def("self_attn_bias_backward", &multihead_attn::self_bias::cublas_gemmex::bwd, "Self Multihead Attention with Bias -- Backward.", py::call_guard()); m.def("self_attn_bias_additive_mask_forward", &multihead_attn::self_bias_additive_mask::cublas_gemmex::fwd, "Self Multihead Attention with Bias -- Forward.", py::call_guard()); 
m.def("self_attn_bias_additive_mask_backward", &multihead_attn::self_bias_additive_mask::cublas_gemmex::bwd, "Self Multihead Attention with Bias -- Backward.", py::call_guard()); m.def("self_attn_norm_add_forward", &multihead_attn::self_norm_add::cublas_gemmex::fwd, "Self Multihead Attention Plus Layer Norm and Residual Add Forward.", py::call_guard()); m.def("self_attn_norm_add_backward", &multihead_attn::self_norm_add::cublas_gemmex::bwd, "Self Multihead Attention Plus Layer Norm and Residual Add Backward.", py::call_guard()); } #undef CHECK_CUDA #undef CHECK_CONTIGUOUS #undef CHECK_INPUT ================================================ FILE: apex/contrib/csrc/multihead_attn/philox.cuh ================================================ #pragma once // Philox CUDA. namespace { class Philox { public: __device__ inline Philox(unsigned long long seed, unsigned long long subsequence, unsigned long long offset) : STATE(0) { // key.x = (unsigned int)seed; // key.y = (unsigned int)(seed >> 32); // counter = make_uint4(0, 0, 0, 0); // counter.z = (unsigned int)(subsequence); // counter.w = (unsigned int)(subsequence >> 32); // STATE = 0; // incr_n(offset / 4); key = reinterpret_cast(seed); ull2* tmp = reinterpret_cast(&counter); tmp->x = offset / 4; tmp->y = subsequence; } __device__ inline uint4 operator()() { if (STATE == 0) { uint4 counter_ = counter; uint2 key_ = key; // 7-round philox for (int i = 0; i < 6; i++) { counter_ = single_round(counter_, key_); key_.x += (kPhilox10A); key_.y += (kPhilox10B); } output = single_round(counter_, key_); incr(); } // return a float4 directly // unsigned long ret; // switch(STATE) { // case 0: ret = output.x; break; // case 1: ret = output.y; break; // case 2: ret = output.z; break; // case 3: ret = output.w; break; //} // STATE = (STATE + 1) % 4; return output; } private: struct ull2 { uint64_t x; uint64_t y; }; uint4 counter; uint4 output; uint2 key; unsigned int STATE; __device__ inline void incr_n(unsigned long long n) { 
unsigned int nlo = (unsigned int)(n); unsigned int nhi = (unsigned int)(n >> 32); counter.x += nlo; if (counter.x < nlo) nhi++; counter.y += nhi; if (nhi <= counter.y) return; if (++counter.z) return; ++counter.w; } __device__ uint4 incr128(uint4 ctr) { uint4 res; asm("add.cc.u32 %0, %4, %8;\n\t" "addc.cc.u32 %1, %5, %9;\n\t" "addc.cc.u32 %2, %6, %10;\n\t" "addc.u32 %3, %7, %11;\n\t" : "=r"(res.x), "=r"(res.y), "=r"(res.z), "=r"(res.w) : "r"(ctr.x), "r"(ctr.y), "r"(ctr.z), "r"(ctr.w), "n"(1), "n"(0), "n"(0), "n"(0)); return res; } __device__ inline void incr() { counter = incr128(counter); } __device__ unsigned int mulhilo32(unsigned int a, unsigned int b, unsigned int* result_high) { *result_high = __umulhi(a, b); return a * b; } __device__ uint2 mulhilo32_v2(unsigned int a, unsigned int b) { uint2* res; unsigned long long tmp; asm("mul.wide.u32 %0, %1, %2;\n\t" : "=l"(tmp) : "r"(a), "r"(b)); res = (uint2*)(&tmp); return *res; } __device__ inline uint4 single_round(uint4 ctr, uint2 key) { // unsigned int hi0; // unsigned int hi1; // unsigned int lo0 = mulhilo32(kPhiloxSA, ctr.x, &hi0); // unsigned int lo1 = mulhilo32(kPhiloxSB, ctr.z, &hi1); // uint4 ret = {hi1 ^ ctr.y ^ key.x, lo1, hi0 ^ ctr.w ^ key.y, lo0}; uint2 res0 = mulhilo32_v2(kPhiloxSA, ctr.x); uint2 res1 = mulhilo32_v2(kPhiloxSB, ctr.z); uint4 ret = {res1.y ^ ctr.y ^ key.x, res1.x, res0.y ^ ctr.w ^ key.y, res0.x}; return ret; } static const unsigned long kPhilox10A = 0x9E3779B9; static const unsigned long kPhilox10B = 0xBB67AE85; static const unsigned long kPhiloxSA = 0xD2511F53; static const unsigned long kPhiloxSB = 0xCD9E8D57; }; // Inverse of 2^32. 
constexpr float M_RAN_INVM32 = 2.3283064e-10f;

// Converts four raw 32-bit Philox outputs to floats by scaling each by
// M_RAN_INVM32 (2^-32).
__device__ __inline__ float4 uniform4(uint4 x) {
  return make_float4(x.x * M_RAN_INVM32, x.y * M_RAN_INVM32, x.z * M_RAN_INVM32, x.w * M_RAN_INVM32);
}

} // namespace

================================================
FILE: apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu
================================================
// NOTE(review): in this copy the <...> arguments of #include, std::vector,
// static_cast/reinterpret_cast (and possibly the dispatch_* templates) appear
// to have been stripped; restore them from upstream before building.
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include "dropout.cuh"
#include "softmax.cuh"
#include "strided_batched_gemm.cuh"

namespace multihead_attn {
namespace self_bias_additive_mask {
namespace cublas_gemmex {

// Forward pass of self-attention with fused input/output biases and an
// additive padding mask (pad_mask is a half pointer).
// inputs is read as [q_seq_len, sequences, embed_dim] (sizes 0, 1, 2 below).
// Returns {input_lin_results, bmm1_results, dropout_results, dropout_mask,
// matmul2_results, outputs}. When !is_training, dropout_results holds the
// plain softmax output and dropout_mask is returned unwritten (torch::empty).
std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, torch::Tensor const& inputs,
                     torch::Tensor const& input_weights, torch::Tensor const& output_weights,
                     torch::Tensor const& input_biases, torch::Tensor const& output_biases, const half* pad_mask,
                     float dropout_prob) {
  const int embed_dim = inputs.size(2);
  const int sequences = inputs.size(1);
  const int q_seq_len = inputs.size(0);
  const int k_seq_len = q_seq_len;
  const int batches = sequences * q_seq_len;
  const int head_dim = embed_dim / heads;
  const int output_lin_dim = 3 * embed_dim;
  const int attn_batches = heads * sequences;
  const int lead_dim = attn_batches * 3 * head_dim;
  const int batch_stride = 3 * head_dim;
  [[maybe_unused]] const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
  const float alpha = 1.0;
  const float beta_zero = 0.0;
  const float beta_one = 1.0;
  const float scale = 1.0 / sqrt(static_cast(head_dim));

  // There is no reason to use more than one stream as every kernel is
  // sequentially dependent
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);

  // 3 Intermediate Results + Output (Note: in this variant dropout is fused
  // into dispatch_additive_masked_softmax_dropout below, not done by ATen)
  auto act_options = inputs.options().requires_grad(false);
  auto mask_options = act_options.dtype(torch::kUInt8);

  torch::Tensor input_lin_results = torch::empty({q_seq_len, sequences, output_lin_dim}, act_options);
  torch::Tensor bmm1_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_mask = torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
  torch::Tensor matmul2_results = torch::empty({q_seq_len, attn_batches, head_dim}, act_options);
  torch::Tensor outputs = torch::empty_like(inputs, act_options);

  // Input Linear Results Pointers to Q, K, and V of interleaved activations:
  // per attention batch, Q, K and V are adjacent head_dim-wide slices (hence
  // the +head_dim / +2*head_dim offsets and batch_stride = 3 * head_dim).
  void* q_lin_results_ptr = static_cast(input_lin_results.data_ptr());
  void* k_lin_results_ptr = static_cast(static_cast(input_lin_results.data_ptr()) + head_dim);
  void* v_lin_results_ptr = static_cast(static_cast(input_lin_results.data_ptr()) + 2 * head_dim);

  // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
  void* bmm1_results_ptr = static_cast(bmm1_results.data_ptr());
  void* dropout_results_ptr = static_cast(dropout_results.data_ptr());

  char a_layout_t{'t'};
  char a_layout_n{'n'};
  char b_layout_n{'n'};

  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
  // Input Linear Fwd
  // The bias is copied into the output first and the GEMM accumulates onto it
  // (beta_one), fusing the bias add.
  input_lin_results.copy_(input_biases);
  TORCH_CUDABLAS_CHECK(cublasGemmEx(handle, CUBLAS_OP_T, CUBLAS_OP_N, output_lin_dim, batches, embed_dim,
                                    static_cast(&alpha), static_cast(input_weights.data_ptr()), CUDA_R_16F,
                                    embed_dim, static_cast(inputs.data_ptr()), CUDA_R_16F, embed_dim,
                                    static_cast(&beta_one), q_lin_results_ptr, CUDA_R_16F, output_lin_dim,
                                    CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));

  // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
  gemm_switch_fp32accum(a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, scale,
                        static_cast(k_lin_results_ptr), lead_dim, batch_stride, static_cast(q_lin_results_ptr),
                        lead_dim, batch_stride, beta_zero, static_cast(bmm1_results_ptr), k_seq_len,
                        k_seq_len * q_seq_len, attn_batches);

  // Padded Softmax
  // NOTE(review): softmax_success is never checked, so a failed dispatch goes
  // unnoticed.
  [[maybe_unused]] bool softmax_success = false;
  if (is_training) {
    // Fused additive-mask softmax + dropout; receives output pointers for
    // both dropout_results and dropout_mask. (is_training is always true in
    // this branch, so the ternary below is redundant.)
    softmax_success = dispatch_additive_masked_softmax_dropout(
        reinterpret_cast(dropout_results_ptr),
        (is_training) ? reinterpret_cast(dropout_mask.data_ptr()) : nullptr,
        reinterpret_cast(bmm1_results_ptr), pad_mask, attn_batches * q_seq_len * q_seq_len, k_seq_len,
        k_seq_len, attn_batches * q_seq_len, attn_batches * q_seq_len / sequences, 1.0f - dropout_prob,
        stream);
  } else {
    softmax_success = dispatch_additive_masked_softmax(
        reinterpret_cast(dropout_results_ptr), // this is actually softmax results, but
                                               // making it consistent for the next function
        reinterpret_cast(bmm1_results_ptr), pad_mask, k_seq_len, k_seq_len, attn_batches * q_seq_len,
        attn_batches * q_seq_len / sequences);
  }

  // Matmul2
  gemm_switch_fp32accum(a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len, alpha,
                        static_cast(v_lin_results_ptr), lead_dim, batch_stride,
                        static_cast(dropout_results.data_ptr()), k_seq_len, k_seq_len * q_seq_len, beta_zero,
                        static_cast(matmul2_results.data_ptr()), head_dim * attn_batches, head_dim,
                        attn_batches);

  // Same bias-fusion trick as the input linear: seed outputs with the bias.
  outputs.copy_(output_biases);

  // Output Linear
  TORCH_CUDABLAS_CHECK(cublasGemmEx(handle, CUBLAS_OP_T, CUBLAS_OP_N, embed_dim, batches, embed_dim,
                                    static_cast(&alpha), static_cast(output_weights.data_ptr()), CUDA_R_16F,
                                    embed_dim, static_cast(matmul2_results.data_ptr()), CUDA_R_16F, embed_dim,
                                    static_cast(&beta_one), static_cast(outputs.data_ptr()), CUDA_R_16F,
                                    embed_dim, CUDA_R_32F,
                                    // CUBLAS_GEMM_ALGO1_TENSOR_OP));
                                    CUBLAS_GEMM_DEFAULT_TENSOR_OP));
  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));

  return {input_lin_results, bmm1_results, dropout_results, dropout_mask, matmul2_results, outputs};
}

// Backward pass matching fwd_cuda. Takes the pre-softmax scores
// (bmm1_results) and pad_mask rather than softmax results.
// Returns {input_grads, input_weight_grads, output_weight_grads,
// input_bias_grads, output_bias_grads}.
std::vector bwd_cuda(int heads, torch::Tensor const& output_grads, torch::Tensor const& matmul2_results,
                     torch::Tensor const& dropout_results, torch::Tensor const& bmm1_results,
                     torch::Tensor const& pad_mask, torch::Tensor const& input_lin_results,
                     torch::Tensor const& inputs, torch::Tensor const& input_weights,
                     torch::Tensor const& output_weights, torch::Tensor const& dropout_mask, float dropout_prob) {
  const int embed_dim = inputs.size(2);
  const int sequences = inputs.size(1);
  const int q_seq_len = inputs.size(0);
  const int k_seq_len = q_seq_len;
  const int batches = sequences * q_seq_len;
  const int head_dim = embed_dim / heads;
  const int output_lin_dim = 3 * embed_dim;
  const int attn_batches = heads * sequences;
  const int lead_dim = attn_batches * 3 * head_dim;
  const int batch_stride = 3 * head_dim;
  [[maybe_unused]] const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
  const float alpha = 1.0;
  const float beta = 0.0;
  const float scale = 1.0 / sqrt(static_cast(head_dim));

  // TODO: Streams can be used in Backprop but I haven't added more than one
  // in my first attempt to create the code
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);

  // Output Tensor Allocations
  torch::Tensor input_grads = torch::empty_like(inputs);
  torch::Tensor input_weight_grads = torch::empty_like(input_weights);
  torch::Tensor output_weight_grads = torch::empty_like(output_weights);
  // Intermediate Tensor Allocations
  at::Tensor output_lin_grads = torch::empty_like(matmul2_results);
  at::Tensor matmul2_grads = torch::empty_like(dropout_results);
  at::Tensor input_lin_output_grads = torch::empty_like(input_lin_results);

  // Same interleaved Q/K/V addressing as in fwd_cuda, for activations and grads.
  auto q_lin_results_ptr = static_cast(input_lin_results.data_ptr());
  auto k_lin_results_ptr = static_cast(input_lin_results.data_ptr()) + head_dim;
  auto v_lin_results_ptr = static_cast(input_lin_results.data_ptr()) + 2 * head_dim;

  auto q_lin_grads_ptr = static_cast(input_lin_output_grads.data_ptr());
  auto k_lin_grads_ptr = static_cast(input_lin_output_grads.data_ptr()) + head_dim;
  auto v_lin_grads_ptr = static_cast(input_lin_output_grads.data_ptr()) + 2 * head_dim;

  char a_layout_n{'n'};
  char a_layout_t{'t'};
  char b_layout_n{'n'};
  char b_layout_t{'t'};

  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));

  // Output Linear Dgrad
  TORCH_CUDABLAS_CHECK(cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches, embed_dim,
                                    static_cast(&alpha), static_cast(output_weights.data_ptr()), CUDA_R_16F,
                                    embed_dim, static_cast(output_grads.data_ptr()), CUDA_R_16F, embed_dim,
                                    static_cast(&beta), static_cast(output_lin_grads.data_ptr()), CUDA_R_16F,
                                    embed_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));

  // Output Linear Wgrad
  TORCH_CUDABLAS_CHECK(cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, embed_dim, batches,
                                    static_cast(&alpha), static_cast(matmul2_results.data_ptr()), CUDA_R_16F,
                                    embed_dim, static_cast(output_grads.data_ptr()), CUDA_R_16F, embed_dim,
                                    static_cast(&beta), static_cast(output_weight_grads.data_ptr()), CUDA_R_16F,
                                    embed_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));

  // Output bias grad: output_grads summed over every row of the flattened
  // [-1, embed_dim] view.
  auto output_bias_grads = output_grads.view({-1, embed_dim}).sum(0, false);

  // MatMul2 Dgrad1
  gemm_switch_fp32accum(a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, alpha,
                        static_cast(v_lin_results_ptr), lead_dim, batch_stride,
                        static_cast(output_lin_grads.data_ptr()), head_dim * attn_batches, head_dim, beta,
                        static_cast(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len, attn_batches);

  // Matmul2 Dgrad2
  gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, q_seq_len, alpha,
                        static_cast(output_lin_grads.data_ptr()), head_dim * attn_batches, head_dim,
                        static_cast(dropout_results.data_ptr()), k_seq_len, k_seq_len * q_seq_len, beta,
                        v_lin_grads_ptr, lead_dim, batch_stride, attn_batches);

  // Apply Dropout Mask and Scale by Dropout Probability
  // Softmax Grad
  // Runs in place on matmul2_grads (same buffer as input and output). Given the
  // _recompute suffix, presumably re-derives the softmax from bmm1_results and
  // pad_mask — TODO confirm in softmax.cuh.
  dispatch_masked_scale_softmax_backward_recompute(
      static_cast(matmul2_grads.data_ptr()), static_cast(matmul2_grads.data_ptr()),
      reinterpret_cast(bmm1_results.data_ptr()), reinterpret_cast(pad_mask.data_ptr()),
      static_cast(dropout_mask.data_ptr()), 1.0 / (1.0 - dropout_prob), k_seq_len, k_seq_len,
      attn_batches * q_seq_len / sequences, attn_batches * q_seq_len, stream);

  // Matmul1 Dgrad1
  gemm_switch_fp32accum(a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len, scale, k_lin_results_ptr,
                        lead_dim, batch_stride, static_cast(matmul2_grads.data_ptr()), k_seq_len,
                        k_seq_len * q_seq_len, beta, q_lin_grads_ptr, lead_dim, batch_stride, attn_batches);

  // Matmul1 Dgrad2
  gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, q_seq_len, scale, q_lin_results_ptr,
                        lead_dim, batch_stride, static_cast(matmul2_grads.data_ptr()), k_seq_len,
                        k_seq_len * q_seq_len, beta, k_lin_grads_ptr, lead_dim, batch_stride, attn_batches);

  // Input Linear Dgrad
  TORCH_CUDABLAS_CHECK(cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches, output_lin_dim,
                                    static_cast(&alpha), static_cast(input_weights.data_ptr()), CUDA_R_16F,
                                    embed_dim, static_cast(input_lin_output_grads.data_ptr()),
                                    // static_cast(q_lin_grads_ptr),
                                    CUDA_R_16F, output_lin_dim, static_cast(&beta),
                                    static_cast(input_grads.data_ptr()), CUDA_R_16F, embed_dim, CUDA_R_32F,
                                    // CUBLAS_GEMM_ALGO10_TENSOR_OP));
                                    CUBLAS_GEMM_DEFAULT_TENSOR_OP));

  // Input Linear Wgrad
  TORCH_CUDABLAS_CHECK(cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, output_lin_dim, batches,
                                    static_cast(&alpha), static_cast(inputs.data_ptr()), CUDA_R_16F, embed_dim,
                                    static_cast(q_lin_grads_ptr), CUDA_R_16F, output_lin_dim, static_cast(&beta),
                                    static_cast(input_weight_grads.data_ptr()), CUDA_R_16F, embed_dim, CUDA_R_32F,
                                    CUBLAS_GEMM_DEFAULT_TENSOR_OP));

  // Input bias grad: column sums of the packed Q/K/V gradients.
  auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));

  return {input_grads, input_weight_grads, output_weight_grads, input_bias_grads, output_bias_grads};
}

} // end namespace cublas_gemmex
} // namespace self_bias_additive_mask
} // end namespace multihead_attn

================================================
FILE:
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu
================================================
// NOTE(review): in this copy the <...> arguments of #include, std::vector and
// static_cast/reinterpret_cast appear to have been stripped; restore them from
// upstream before building.
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include "dropout.cuh"
#include "softmax.cuh"
#include "strided_batched_gemm.cuh"

namespace multihead_attn {
namespace self_bias {
namespace cublas_gemmex {

// Forward pass of self-attention with fused input/output biases and an
// optional uint8 pad_mask (nullptr => unmasked softmax; use_time_mask selects
// the time-masked softmax kernel).
// inputs is read as [q_seq_len, sequences, embed_dim] (sizes 0, 1, 2 below).
// Returns {input_lin_results, softmax_results, dropout_results, dropout_mask,
// matmul2_results, outputs}. When !is_training, Matmul2 consumes
// softmax_results and dropout_results/dropout_mask are returned unwritten.
std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, torch::Tensor const& inputs,
                     torch::Tensor const& input_weights, torch::Tensor const& output_weights,
                     torch::Tensor const& input_biases, torch::Tensor const& output_biases, const uint8_t* pad_mask,
                     float dropout_prob) {
  const int embed_dim = inputs.size(2);
  const int sequences = inputs.size(1);
  const int q_seq_len = inputs.size(0);
  const int k_seq_len = q_seq_len;
  const int batches = sequences * q_seq_len;
  const int head_dim = embed_dim / heads;
  const int output_lin_dim = 3 * embed_dim;
  const int attn_batches = heads * sequences;
  const int lead_dim = attn_batches * 3 * head_dim;
  const int batch_stride = 3 * head_dim;
  [[maybe_unused]] const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
  const float alpha = 1.0;
  const float beta_zero = 0.0;
  const float beta_one = 1.0;
  const float scale = 1.0 / sqrt(static_cast(head_dim));

  // There is no reason to use more than one stream as every kernel is
  // sequentially dependent
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);

  // 3 Intermediate Results + Output (Note: dropout intermediates are generated
  // by ATen library code)
  auto act_options = inputs.options().requires_grad(false);
  auto mask_options = act_options.dtype(torch::kUInt8);

  torch::Tensor input_lin_results = torch::empty({q_seq_len, sequences, output_lin_dim}, act_options);
  torch::Tensor softmax_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  // NOTE(review): when is_training, these two buffers are replaced by the
  // tensors returned from at::_fused_dropout below, so this allocation is
  // unused on that path.
  torch::Tensor dropout_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_mask = torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
  torch::Tensor matmul2_results = torch::empty({q_seq_len, attn_batches, head_dim}, act_options);
  torch::Tensor outputs = torch::empty_like(inputs, act_options);

  // Input Linear Results Pointers to Q, K, and V of interleaved activations:
  // per attention batch, Q, K and V are adjacent head_dim-wide slices (hence
  // the +head_dim / +2*head_dim offsets and batch_stride = 3 * head_dim).
  void* q_lin_results_ptr = static_cast(input_lin_results.data_ptr());
  void* k_lin_results_ptr = static_cast(static_cast(input_lin_results.data_ptr()) + head_dim);
  void* v_lin_results_ptr = static_cast(static_cast(input_lin_results.data_ptr()) + 2 * head_dim);

  // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
  void* softmax_results_ptr = static_cast(softmax_results.data_ptr());

  char a_layout_t{'t'};
  char a_layout_n{'n'};
  char b_layout_n{'n'};

  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
  // Input Linear Fwd
  // The bias is copied into the output first and the GEMM accumulates onto it
  // (beta_one), fusing the bias add.
  input_lin_results.copy_(input_biases);
  TORCH_CUDABLAS_CHECK(cublasGemmEx(handle, CUBLAS_OP_T, CUBLAS_OP_N, output_lin_dim, batches, embed_dim,
                                    static_cast(&alpha), static_cast(input_weights.data_ptr()), CUDA_R_16F,
                                    embed_dim, static_cast(inputs.data_ptr()), CUDA_R_16F, embed_dim,
                                    static_cast(&beta_one), q_lin_results_ptr, CUDA_R_16F, output_lin_dim,
                                    CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));

  // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
  gemm_switch_fp32accum(a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, scale,
                        static_cast(k_lin_results_ptr), lead_dim, batch_stride, static_cast(q_lin_results_ptr),
                        lead_dim, batch_stride, beta_zero, static_cast(softmax_results_ptr), k_seq_len,
                        k_seq_len * q_seq_len, attn_batches);

  // Padded Softmax (computed in place on softmax_results)
  // NOTE(review): softmax_success is never checked, so a failed dispatch goes
  // unnoticed.
  [[maybe_unused]] bool softmax_success = false;
  if (pad_mask == nullptr) {
    softmax_success = dispatch_softmax(reinterpret_cast(softmax_results_ptr),
                                       reinterpret_cast(softmax_results_ptr), k_seq_len, k_seq_len,
                                       attn_batches * q_seq_len);
  } else {
    if (use_time_mask) {
      softmax_success = dispatch_time_masked_softmax(reinterpret_cast(softmax_results_ptr),
                                                     reinterpret_cast(softmax_results_ptr), pad_mask, k_seq_len,
                                                     k_seq_len, attn_batches * q_seq_len, q_seq_len);
    } else {
      softmax_success = dispatch_masked_softmax(reinterpret_cast(softmax_results_ptr),
                                                reinterpret_cast(softmax_results_ptr), pad_mask, k_seq_len,
                                                k_seq_len, attn_batches * q_seq_len,
                                                attn_batches * q_seq_len / sequences);
    }
  }

  if (is_training) {
    // use at:: function so that C++ version generates the same random mask as
    // python version
    auto dropout_tuple = at::_fused_dropout(softmax_results, 1.0f - dropout_prob);
    dropout_results = std::get<0>(dropout_tuple);
    dropout_mask = std::get<1>(dropout_tuple);
  }

  // Matmul2
  gemm_switch_fp32accum(a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len, alpha,
                        static_cast(v_lin_results_ptr), lead_dim, batch_stride,
                        (is_training) ? static_cast(dropout_results.data_ptr())
                                      : static_cast(softmax_results.data_ptr()),
                        k_seq_len, k_seq_len * q_seq_len, beta_zero, static_cast(matmul2_results.data_ptr()),
                        head_dim * attn_batches, head_dim, attn_batches);

  // Same bias-fusion trick as the input linear: seed outputs with the bias.
  outputs.copy_(output_biases);

  // Output Linear
  TORCH_CUDABLAS_CHECK(cublasGemmEx(handle, CUBLAS_OP_T, CUBLAS_OP_N, embed_dim, batches, embed_dim,
                                    static_cast(&alpha), static_cast(output_weights.data_ptr()), CUDA_R_16F,
                                    embed_dim, static_cast(matmul2_results.data_ptr()), CUDA_R_16F, embed_dim,
                                    static_cast(&beta_one), static_cast(outputs.data_ptr()), CUDA_R_16F,
                                    embed_dim, CUDA_R_32F,
                                    // CUBLAS_GEMM_ALGO1_TENSOR_OP));
                                    CUBLAS_GEMM_DEFAULT_TENSOR_OP));
  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));

  return {input_lin_results, softmax_results, dropout_results, dropout_mask, matmul2_results, outputs};
}

// Backward pass matching fwd_cuda; uses the saved softmax_results and the
// dropout_mask. Returns {input_grads, input_weight_grads,
// output_weight_grads, input_bias_grads, output_bias_grads}.
std::vector bwd_cuda(int heads, torch::Tensor const& output_grads, torch::Tensor const& matmul2_results,
                     torch::Tensor const& dropout_results, torch::Tensor const& softmax_results,
                     torch::Tensor const& input_lin_results, torch::Tensor const& inputs,
                     torch::Tensor const& input_weights, torch::Tensor const& output_weights,
                     torch::Tensor const& dropout_mask, float dropout_prob) {
  const int embed_dim = inputs.size(2);
  const int sequences = inputs.size(1);
  const int q_seq_len = inputs.size(0);
  const int k_seq_len = q_seq_len;
  const int batches = sequences * q_seq_len;
  const int head_dim = embed_dim / heads;
  const int output_lin_dim = 3 * embed_dim;
  const int attn_batches = heads * sequences;
  const int lead_dim = attn_batches * 3 * head_dim;
  const int batch_stride = 3 * head_dim;
  [[maybe_unused]] const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
  const float alpha = 1.0;
  const float beta = 0.0;
  const float scale = 1.0 / sqrt(static_cast(head_dim));

  // TODO: Streams can be used in Backprop but I haven't added more than one
  // in my first attempt to create the code
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);

  // Output Tensor Allocations
  torch::Tensor input_grads = torch::empty_like(inputs);
  torch::Tensor input_weight_grads = torch::empty_like(input_weights);
  torch::Tensor output_weight_grads = torch::empty_like(output_weights);
  // Intermediate Tensor Allocations
  at::Tensor output_lin_grads = torch::empty_like(matmul2_results);
  at::Tensor matmul2_grads = torch::empty_like(dropout_results);
  at::Tensor input_lin_output_grads = torch::empty_like(input_lin_results);

  // Same interleaved Q/K/V addressing as in fwd_cuda, for activations and grads.
  auto q_lin_results_ptr = static_cast(input_lin_results.data_ptr());
  auto k_lin_results_ptr = static_cast(input_lin_results.data_ptr()) + head_dim;
  auto v_lin_results_ptr = static_cast(input_lin_results.data_ptr()) + 2 * head_dim;

  auto q_lin_grads_ptr = static_cast(input_lin_output_grads.data_ptr());
  auto k_lin_grads_ptr = static_cast(input_lin_output_grads.data_ptr()) + head_dim;
  auto v_lin_grads_ptr = static_cast(input_lin_output_grads.data_ptr()) + 2 * head_dim;

  char a_layout_n{'n'};
  char a_layout_t{'t'};
  char b_layout_n{'n'};
  char b_layout_t{'t'};

  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));

  // Output Linear Dgrad
  TORCH_CUDABLAS_CHECK(cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches, embed_dim,
                                    static_cast(&alpha), static_cast(output_weights.data_ptr()), CUDA_R_16F,
                                    embed_dim, static_cast(output_grads.data_ptr()), CUDA_R_16F, embed_dim,
                                    static_cast(&beta), static_cast(output_lin_grads.data_ptr()), CUDA_R_16F,
                                    embed_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));

  // Output Linear Wgrad
  TORCH_CUDABLAS_CHECK(cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, embed_dim, batches,
                                    static_cast(&alpha), static_cast(matmul2_results.data_ptr()), CUDA_R_16F,
                                    embed_dim, static_cast(output_grads.data_ptr()), CUDA_R_16F, embed_dim,
                                    static_cast(&beta), static_cast(output_weight_grads.data_ptr()), CUDA_R_16F,
                                    embed_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));

  // Output bias grad: output_grads summed over every row of the flattened
  // [-1, embed_dim] view.
  auto output_bias_grads = output_grads.view({-1, embed_dim}).sum(0, false);

  // MatMul2 Dgrad1
  gemm_switch_fp32accum(a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, alpha,
                        static_cast(v_lin_results_ptr), lead_dim, batch_stride,
                        static_cast(output_lin_grads.data_ptr()), head_dim * attn_batches, head_dim, beta,
                        static_cast(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len, attn_batches);

  // Matmul2 Dgrad2
  gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, q_seq_len, alpha,
                        static_cast(output_lin_grads.data_ptr()), head_dim * attn_batches, head_dim,
                        static_cast(dropout_results.data_ptr()), k_seq_len, k_seq_len * q_seq_len, beta,
                        v_lin_grads_ptr, lead_dim, batch_stride, attn_batches);

  // Apply Dropout Mask and Scale by Dropout Probability
  // Softmax Grad (in place on matmul2_grads, using the saved softmax_results)
  dispatch_masked_scale_softmax_backward_stream(
      static_cast(matmul2_grads.data_ptr()), static_cast(matmul2_grads.data_ptr()),
      reinterpret_cast(softmax_results.data_ptr()), static_cast(dropout_mask.data_ptr()),
      1.0 / (1.0 - dropout_prob), k_seq_len, k_seq_len, attn_batches * q_seq_len, stream);

  // Matmul1 Dgrad1
  gemm_switch_fp32accum(a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len, scale, k_lin_results_ptr,
                        lead_dim, batch_stride, static_cast(matmul2_grads.data_ptr()), k_seq_len,
                        k_seq_len * q_seq_len, beta, q_lin_grads_ptr, lead_dim, batch_stride, attn_batches);

  // Matmul1 Dgrad2
  gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, q_seq_len, scale, q_lin_results_ptr,
                        lead_dim, batch_stride, static_cast(matmul2_grads.data_ptr()), k_seq_len,
                        k_seq_len * q_seq_len, beta, k_lin_grads_ptr, lead_dim, batch_stride, attn_batches);

  // Input Linear Dgrad
  TORCH_CUDABLAS_CHECK(cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches, output_lin_dim,
                                    static_cast(&alpha), static_cast(input_weights.data_ptr()), CUDA_R_16F,
                                    embed_dim, static_cast(input_lin_output_grads.data_ptr()),
                                    // static_cast(q_lin_grads_ptr),
                                    CUDA_R_16F, output_lin_dim, static_cast(&beta),
                                    static_cast(input_grads.data_ptr()), CUDA_R_16F, embed_dim, CUDA_R_32F,
                                    // CUBLAS_GEMM_ALGO10_TENSOR_OP));
                                    CUBLAS_GEMM_DEFAULT_TENSOR_OP));

  // Input Linear Wgrad
  TORCH_CUDABLAS_CHECK(cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, output_lin_dim, batches,
                                    static_cast(&alpha), static_cast(inputs.data_ptr()), CUDA_R_16F, embed_dim,
                                    static_cast(q_lin_grads_ptr), CUDA_R_16F, output_lin_dim, static_cast(&beta),
                                    static_cast(input_weight_grads.data_ptr()), CUDA_R_16F, embed_dim, CUDA_R_32F,
                                    CUBLAS_GEMM_DEFAULT_TENSOR_OP));

  // Input bias grad: column sums of the packed Q/K/V gradients.
  auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));

  return {input_grads, input_weight_grads, output_weight_grads, input_bias_grads, output_bias_grads};
}

} // end namespace cublas_gemmex
} // namespace self_bias
} // end namespace multihead_attn

================================================
FILE: apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu
================================================
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include "dropout.cuh"
#include "softmax.cuh"
#include "strided_batched_gemm.cuh"

namespace multihead_attn {
namespace self {
namespace cublas_gemmex {

// Forward pass of fused self-attention (cuBLAS GemmEx path).
//
// Steps: Input Linear (a single GEMM producing Q, K and V interleaved per
// head), MatMul1 (Q/K dot products scaled by 1/sqrt(head_dim)), in-place
// softmax (optionally padding- or time-masked), dropout when is_training,
// MatMul2 (attention probabilities x V), Output Linear.
//
// Sizes are read from `inputs`: size(0) = q_seq_len, size(1) = sequences,
// size(2) = embed_dim; head_dim = embed_dim / heads (integer division).
// GEMM operands are declared CUDA_R_16F with CUDA_R_32F compute, so fp16
// tensors are presumably required -- TODO confirm at the caller.
//
// Returns {input_lin_results, softmax_results, dropout_results, dropout_mask,
//          matmul2_results, outputs}.
std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads,
                     torch::Tensor const& inputs,
                     torch::Tensor const& input_weights,
                     torch::Tensor const& output_weights,
                     const uint8_t* pad_mask, float dropout_prob) {
  const int embed_dim = inputs.size(2);
  const int sequences = inputs.size(1);
  const int q_seq_len = inputs.size(0);
  const int k_seq_len = q_seq_len;
  const int batches = sequences * q_seq_len;
  const int head_dim = embed_dim / heads;
  const int output_lin_dim = 3 * embed_dim;
  const int attn_batches = heads * sequences;
  const int lead_dim = attn_batches * 3 * head_dim;
  const int batch_stride = 3 * head_dim;
  const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
  const float alpha = 1.0;
  const float beta = 0.0;
  const float scale = 1.0 / sqrt(static_cast(head_dim));

  // There is no reason to use more than one stream as every kernel is
  // sequentially dependent
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  // NOTE(review): unlike the cuBLAS calls below, this status is not wrapped
  // in TORCH_CUDABLAS_CHECK.
  cublasSetStream(handle, stream);

  // 3 Intermediate Results + Output (Note: dropout intermediates are generated
  // by ATen library code)
  // NOTE(review): in this function dropout is produced by
  // apex_fused_dropout_cuda, so the ATen remark above looks stale.
  auto act_options = inputs.options().requires_grad(false);
  auto mask_options = act_options.dtype(torch::kUInt8);

  torch::Tensor input_lin_results =
      torch::empty({q_seq_len, sequences, output_lin_dim}, act_options);
  torch::Tensor softmax_results =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  // dropout_results / dropout_mask are only written when is_training; in
  // inference they are returned uninitialized (torch::empty).
  torch::Tensor dropout_results =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_mask =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
  torch::Tensor matmul2_results =
      torch::empty({q_seq_len, attn_batches, head_dim}, act_options);
  torch::Tensor outputs = torch::empty_like(inputs, act_options);

  // Input Linear Results Pointers to Q, K, and V of interleaved activations:
  // within each head's 3 * head_dim slice (batch_stride), K starts head_dim
  // and V 2 * head_dim elements after Q.
  void* q_lin_results_ptr = static_cast(input_lin_results.data_ptr());
  void* k_lin_results_ptr = static_cast(
      static_cast(input_lin_results.data_ptr()) + head_dim);
  void* v_lin_results_ptr = static_cast(
      static_cast(input_lin_results.data_ptr()) + 2 * head_dim);

  // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
  void* softmax_results_ptr = static_cast(softmax_results.data_ptr());

  char a_layout_t{'t'};
  char a_layout_n{'n'};
  char b_layout_n{'n'};

  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));

  // Input Linear Fwd
  TORCH_CUDABLAS_CHECK(cublasGemmEx(
      handle, CUBLAS_OP_T, CUBLAS_OP_N, output_lin_dim, batches, embed_dim,
      static_cast(&alpha), static_cast(input_weights.data_ptr()), CUDA_R_16F,
      embed_dim, static_cast(inputs.data_ptr()), CUDA_R_16F, embed_dim,
      static_cast(&beta), q_lin_results_ptr, CUDA_R_16F, output_lin_dim,
      CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));

  // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
  gemm_switch_fp32accum(a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim,
                        scale, static_cast(k_lin_results_ptr), lead_dim,
                        batch_stride, static_cast(q_lin_results_ptr), lead_dim,
                        batch_stride, beta, static_cast(softmax_results_ptr),
                        k_seq_len, k_seq_len * q_seq_len, attn_batches);

  // Padded Softmax (in place on the MatMul1 scores)
  bool softmax_success = false;
  if (pad_mask == nullptr) {
    softmax_success = dispatch_softmax(
        reinterpret_cast(softmax_results_ptr),
        reinterpret_cast(softmax_results_ptr), k_seq_len, k_seq_len,
        attn_batches * q_seq_len);
  } else {
    if (use_time_mask) {
      softmax_success = dispatch_time_masked_softmax(
          reinterpret_cast(softmax_results_ptr),
          reinterpret_cast(softmax_results_ptr), pad_mask, k_seq_len,
          k_seq_len, attn_batches * q_seq_len, q_seq_len);
    } else {
      softmax_success = dispatch_masked_softmax(
          reinterpret_cast(softmax_results_ptr),
          reinterpret_cast(softmax_results_ptr), pad_mask, k_seq_len,
          k_seq_len, attn_batches * q_seq_len,
          attn_batches * q_seq_len / sequences);
    }
  }
  // NOTE(review): assert() is compiled out under NDEBUG. dispatch_softmax
  // returns false without launching when k_seq_len > 1024, so in release
  // builds the unnormalized MatMul1 scores would silently reach MatMul2.
  // Consider TORCH_CHECK instead.
  assert(softmax_success);

  if (is_training) {
    apex_fused_dropout_cuda(
        static_cast(softmax_results.data_ptr()),
        static_cast(dropout_results.data_ptr()),
        static_cast(dropout_mask.data_ptr()), dropout_elems,
        (1.0f - dropout_prob));
  }

  // Matmul2 (reads the dropout output when training, the softmax otherwise)
  gemm_switch_fp32accum(
      a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len, alpha,
      static_cast(v_lin_results_ptr), lead_dim, batch_stride,
      (is_training) ? static_cast(dropout_results.data_ptr())
                    : static_cast(softmax_results.data_ptr()),
      k_seq_len, k_seq_len * q_seq_len, beta,
      static_cast(matmul2_results.data_ptr()), head_dim * attn_batches,
      head_dim, attn_batches);

  // Output Linear
  TORCH_CUDABLAS_CHECK(cublasGemmEx(
      handle, CUBLAS_OP_T, CUBLAS_OP_N, embed_dim, batches, embed_dim,
      static_cast(&alpha), static_cast(output_weights.data_ptr()), CUDA_R_16F,
      embed_dim, static_cast(matmul2_results.data_ptr()), CUDA_R_16F,
      embed_dim, static_cast(&beta), static_cast(outputs.data_ptr()),
      CUDA_R_16F, embed_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));

  // Restore default math mode. NOTE(review): if a check above throws, this
  // reset is skipped and the handle keeps CUBLAS_TENSOR_OP_MATH.
  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));

  return {input_lin_results, softmax_results, dropout_results,
          dropout_mask,      matmul2_results, outputs};
}

// Backward pass for fwd_cuda above. In order: Output Linear dgrad/wgrad,
// MatMul2 dgrads, dropout-mask scaling, softmax backward, MatMul1 dgrads
// (scaled by 1/sqrt(head_dim)), Input Linear dgrad/wgrad.
//
// dropout_results and dropout_mask are used unconditionally, so this assumes
// the forward ran with is_training == true (otherwise they are uninitialized).
//
// Returns {input_grads, input_weight_grads, output_weight_grads}.
std::vector bwd_cuda(int heads, torch::Tensor const& output_grads,
                     torch::Tensor const& matmul2_results,
                     torch::Tensor const& dropout_results,
                     torch::Tensor const& softmax_results,
                     torch::Tensor const& input_lin_results,
                     torch::Tensor const& inputs,
                     torch::Tensor const& input_weights,
                     torch::Tensor const& output_weights,
                     torch::Tensor const& dropout_mask, float dropout_prob) {
  const int embed_dim = inputs.size(2);
  const int sequences = inputs.size(1);
  const int q_seq_len = inputs.size(0);
  const int k_seq_len = q_seq_len;
  const int batches = sequences * q_seq_len;
  const int head_dim = embed_dim / heads;
  const int output_lin_dim = 3 * embed_dim;
  const int attn_batches = heads * sequences;
  const int lead_dim = attn_batches * 3 * head_dim;
  const int batch_stride = 3 * head_dim;
  const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
  const float alpha = 1.0;
  const float beta = 0.0;
  const float scale = 1.0 / sqrt(static_cast(head_dim));

  // TODO: Streams can be used in Backprop but I haven't added more than one
  // in my first attempt to create the code
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);

  // Output Tensor Allocations
  torch::Tensor input_grads = torch::empty_like(inputs);
  torch::Tensor input_weight_grads = torch::empty_like(input_weights);
  torch::Tensor output_weight_grads = torch::empty_like(output_weights);
  // Intermediate Tensor Allocations
  at::Tensor output_lin_grads = torch::empty_like(matmul2_results);
  at::Tensor matmul2_grads = torch::empty_like(dropout_results);
  at::Tensor input_lin_output_grads = torch::empty_like(input_lin_results);

  // Q/K/V views into the interleaved forward activations and their grads.
  auto q_lin_results_ptr = static_cast(input_lin_results.data_ptr());
  auto k_lin_results_ptr =
      static_cast(input_lin_results.data_ptr()) + head_dim;
  auto v_lin_results_ptr =
      static_cast(input_lin_results.data_ptr()) + 2 * head_dim;

  auto q_lin_grads_ptr = static_cast(input_lin_output_grads.data_ptr());
  auto k_lin_grads_ptr =
      static_cast(input_lin_output_grads.data_ptr()) + head_dim;
  auto v_lin_grads_ptr =
      static_cast(input_lin_output_grads.data_ptr()) + 2 * head_dim;

  char a_layout_n{'n'};
  char a_layout_t{'t'};
  char b_layout_n{'n'};
  char b_layout_t{'t'};

  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));

  // Output Linear Dgrad
  TORCH_CUDABLAS_CHECK(cublasGemmEx(
      handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches, embed_dim,
      static_cast(&alpha), static_cast(output_weights.data_ptr()), CUDA_R_16F,
      embed_dim, static_cast(output_grads.data_ptr()), CUDA_R_16F, embed_dim,
      static_cast(&beta), static_cast(output_lin_grads.data_ptr()),
      CUDA_R_16F, embed_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));

  // Output Linear Wgrad
  TORCH_CUDABLAS_CHECK(cublasGemmEx(
      handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, embed_dim, batches,
      static_cast(&alpha), static_cast(matmul2_results.data_ptr()),
      CUDA_R_16F, embed_dim, static_cast(output_grads.data_ptr()), CUDA_R_16F,
      embed_dim, static_cast(&beta),
      static_cast(output_weight_grads.data_ptr()), CUDA_R_16F, embed_dim,
      CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));

  // MatMul2 Dgrad1
  gemm_switch_fp32accum(a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim,
                        alpha, static_cast(v_lin_results_ptr), lead_dim,
                        batch_stride, static_cast(output_lin_grads.data_ptr()),
                        head_dim * attn_batches, head_dim, beta,
                        static_cast(matmul2_grads.data_ptr()), k_seq_len,
                        k_seq_len * q_seq_len, attn_batches);

  // Matmul2 Dgrad2
  gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, q_seq_len,
                        alpha, static_cast(output_lin_grads.data_ptr()),
                        head_dim * attn_batches, head_dim,
                        static_cast(dropout_results.data_ptr()), k_seq_len,
                        k_seq_len * q_seq_len, beta, v_lin_grads_ptr, lead_dim,
                        batch_stride, attn_batches);

  // Apply Dropout Mask and Scale by Dropout Probability
  // NOTE(review): dropout_prob == 1.0 makes this scale 1.0 / 0.0 (infinite).
  apex_masked_scale_cuda(
      static_cast(matmul2_grads.data_ptr()),
      static_cast(matmul2_grads.data_ptr()),
      static_cast(dropout_mask.data_ptr()), dropout_elems,
      (1.0 / (1.0 - dropout_prob)));

  // Softmax Grad
  bool softmax_success = false;
  softmax_success = dispatch_softmax_backward(
      static_cast(matmul2_grads.data_ptr()),
      static_cast(matmul2_grads.data_ptr()),
      reinterpret_cast(softmax_results.data_ptr()), k_seq_len, k_seq_len,
      attn_batches * q_seq_len);
  // NOTE(review): compiled out under NDEBUG; consider TORCH_CHECK.
  assert(softmax_success);

  // Matmul1 Dgrad1
  gemm_switch_fp32accum(a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len,
                        scale, k_lin_results_ptr, lead_dim, batch_stride,
                        static_cast(matmul2_grads.data_ptr()), k_seq_len,
                        k_seq_len * q_seq_len, beta, q_lin_grads_ptr, lead_dim,
                        batch_stride, attn_batches);

  // Matmul1 Dgrad2
  gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, q_seq_len,
                        scale, q_lin_results_ptr, lead_dim, batch_stride,
                        static_cast(matmul2_grads.data_ptr()), k_seq_len,
                        k_seq_len * q_seq_len, beta, k_lin_grads_ptr, lead_dim,
                        batch_stride, attn_batches);

  // Input Linear Dgrad. q_lin_grads_ptr is the base of input_lin_output_grads,
  // so with leading dimension output_lin_dim this consumes the full
  // interleaved Q/K/V gradient.
  TORCH_CUDABLAS_CHECK(cublasGemmEx(
      handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches, output_lin_dim,
      static_cast(&alpha), static_cast(input_weights.data_ptr()), CUDA_R_16F,
      embed_dim, static_cast(q_lin_grads_ptr), CUDA_R_16F, output_lin_dim,
      static_cast(&beta), static_cast(input_grads.data_ptr()), CUDA_R_16F,
      embed_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));

  // Input Linear Wgrad
  TORCH_CUDABLAS_CHECK(cublasGemmEx(
      handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, output_lin_dim, batches,
      static_cast(&alpha), static_cast(inputs.data_ptr()), CUDA_R_16F,
      embed_dim, static_cast(q_lin_grads_ptr), CUDA_R_16F, output_lin_dim,
      static_cast(&beta), static_cast(input_weight_grads.data_ptr()),
      CUDA_R_16F, embed_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));

  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));

  return {input_grads, input_weight_grads, output_weight_grads};
}

} // end namespace cublas_gemmex
} // end namespace self
} // end namespace multihead_attn
================================================
FILE: apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu
================================================
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include "dropout.cuh"
#include "layer_norm.cuh"
#include "softmax.cuh"
#include "strided_batched_gemm.cuh"

namespace multihead_attn {
namespace self_norm_add {
namespace cublas_gemmex {

// Forward pass of self-attention with a leading layer norm and a trailing
// residual dropout-add (cuBLAS GemmEx path).
//
// Steps:
//   1. Layer norm of `inputs`, with eps 1.0e-5 and fp32 mean/invvar.
//   2. Input Linear on the normalized activations.
//   3. MatMul1.
//   4. In-place softmax, optionally masked.
//   5. Attention dropout when is_training.
//   6. MatMul2.
//   7. Output Linear into output_lin_results.
//   8. End-of-block add with `inputs`. When is_training this uses
//      apex_dropout_add_cuda, which presumably applies dropout before the
//      add; otherwise it uses apex_add_cuda.
//
// Returns {lyr_nrm_results, lyr_nrm_mean, lyr_nrm_invvar, input_lin_results,
//          softmax_results, dropout_results, dropout_mask, matmul2_results,
//          dropout_add_mask, outputs}.
std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads,
                     torch::Tensor const& inputs,
                     torch::Tensor const& lyr_nrm_gamma_weights,
                     torch::Tensor const& lyr_nrm_beta_weights,
                     torch::Tensor const& input_weights,
                     torch::Tensor const& output_weights,
                     const uint8_t* pad_mask, float dropout_prob) {
  const int embed_dim = inputs.size(2);
  const int sequences = inputs.size(1);
  const int q_seq_len = inputs.size(0);
  const int k_seq_len = q_seq_len;
  const int batches = sequences * q_seq_len;
  const int total_tokens = batches * embed_dim;
  const int head_dim = embed_dim / heads;
  const int output_lin_dim = 3 * embed_dim;
  const int attn_batches = heads * sequences;
  const int lead_dim = attn_batches * 3 * head_dim;
  const int batch_stride = 3 * head_dim;
  const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
  const float alpha = 1.0;
  const float beta = 0.0;
  const float scale = 1.0 / sqrt(static_cast(head_dim));

  // There is no reason to use more than one stream as every kernel is
  // sequentially dependent
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);

  // 3 Intermediate Results + Output (Note: dropout intermediates are generated
  // by ATen library code)
  // NOTE(review): dropout here comes from apex_fused_dropout_cuda /
  // apex_dropout_add_cuda, so the ATen remark above looks stale.
  auto act_options = inputs.options().requires_grad(false);
  auto lyr_nrm_options = act_options.dtype(torch::kFloat32);
  auto mask_options = act_options.dtype(torch::kUInt8);

  torch::Tensor lyr_nrm_mean = torch::empty({batches}, lyr_nrm_options);
  torch::Tensor lyr_nrm_invvar = torch::empty({batches}, lyr_nrm_options);
  torch::Tensor lyr_nrm_results = torch::empty_like(inputs, act_options);

  torch::Tensor input_lin_results =
      torch::empty({q_seq_len, sequences, output_lin_dim}, act_options);
  torch::Tensor softmax_results =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  // dropout_results, dropout_mask and dropout_add_mask are only written when
  // is_training; in inference they are returned uninitialized.
  torch::Tensor dropout_results =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_mask =
      torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
  torch::Tensor matmul2_results =
      torch::empty({q_seq_len, attn_batches, head_dim}, act_options);
  torch::Tensor output_lin_results = torch::empty_like(inputs, act_options);
  torch::Tensor dropout_add_mask = torch::empty_like(inputs, mask_options);
  torch::Tensor outputs = torch::empty_like(inputs, act_options);

  // Input Linear Results Pointers to Q, K, and V of interleaved activations
  void* q_lin_results_ptr = static_cast(input_lin_results.data_ptr());
  void* k_lin_results_ptr = static_cast(
      static_cast(input_lin_results.data_ptr()) + head_dim);
  void* v_lin_results_ptr = static_cast(
      static_cast(input_lin_results.data_ptr()) + 2 * head_dim);

  // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
  void* softmax_results_ptr = static_cast(softmax_results.data_ptr());

  char a_layout_t{'t'};
  char a_layout_n{'n'};
  char b_layout_n{'n'};

  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));

  // Layer Norm
  HostApplyLayerNorm(
      static_cast(lyr_nrm_results.data_ptr()),
      static_cast(lyr_nrm_mean.data_ptr()),
      static_cast(lyr_nrm_invvar.data_ptr()),
      static_cast(inputs.data_ptr()),
      static_cast(batches),   // n1
      static_cast(embed_dim), // n2
      1.0e-5, static_cast(lyr_nrm_gamma_weights.data_ptr()),
      static_cast(lyr_nrm_beta_weights.data_ptr()));

  // Input Linear Fwd (on the layer-normalized activations)
  TORCH_CUDABLAS_CHECK(cublasGemmEx(
      handle, CUBLAS_OP_T, CUBLAS_OP_N, output_lin_dim, batches, embed_dim,
      static_cast(&alpha), static_cast(input_weights.data_ptr()), CUDA_R_16F,
      embed_dim,
      // static_cast(inputs.data_ptr()),
      static_cast(lyr_nrm_results.data_ptr()), CUDA_R_16F, embed_dim,
      static_cast(&beta), q_lin_results_ptr, CUDA_R_16F, output_lin_dim,
      CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));

  // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
  gemm_switch_fp32accum(a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim,
                        scale, static_cast(k_lin_results_ptr), lead_dim,
                        batch_stride, static_cast(q_lin_results_ptr), lead_dim,
                        batch_stride, beta, static_cast(softmax_results_ptr),
                        k_seq_len, k_seq_len * q_seq_len, attn_batches);

  // Padded Softmax
  bool softmax_success = false;
  if (pad_mask == nullptr) {
    softmax_success = dispatch_softmax(
        reinterpret_cast(softmax_results_ptr),
        reinterpret_cast(softmax_results_ptr), k_seq_len, k_seq_len,
        attn_batches * q_seq_len);
  } else {
    if (use_time_mask) {
      softmax_success = dispatch_time_masked_softmax(
          reinterpret_cast(softmax_results_ptr),
          reinterpret_cast(softmax_results_ptr), pad_mask, k_seq_len,
          k_seq_len, attn_batches * q_seq_len, q_seq_len);
    } else {
      softmax_success = dispatch_masked_softmax(
          reinterpret_cast(softmax_results_ptr),
          reinterpret_cast(softmax_results_ptr), pad_mask, k_seq_len,
          k_seq_len, attn_batches * q_seq_len,
          attn_batches * q_seq_len / sequences);
    }
  }
  // NOTE(review): assert() is compiled out under NDEBUG. dispatch_softmax
  // returns false without launching when k_seq_len > 1024. Consider
  // TORCH_CHECK.
  assert(softmax_success);

  if (is_training) {
    apex_fused_dropout_cuda(
        static_cast(softmax_results.data_ptr()),
        static_cast(dropout_results.data_ptr()),
        static_cast(dropout_mask.data_ptr()), dropout_elems,
        (1.0f - dropout_prob));
  }

  // Matmul2
  gemm_switch_fp32accum(
      a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len, alpha,
      static_cast(v_lin_results_ptr), lead_dim, batch_stride,
      (is_training) ? static_cast(dropout_results.data_ptr())
                    : static_cast(softmax_results.data_ptr()),
      // static_cast(dropout_results.data_ptr()),
      k_seq_len, k_seq_len * q_seq_len, beta,
      static_cast(matmul2_results.data_ptr()), head_dim * attn_batches,
      head_dim, attn_batches);

  // Output Linear
  TORCH_CUDABLAS_CHECK(cublasGemmEx(
      handle, CUBLAS_OP_T, CUBLAS_OP_N, embed_dim, batches, embed_dim,
      static_cast(&alpha), static_cast(output_weights.data_ptr()), CUDA_R_16F,
      embed_dim, static_cast(matmul2_results.data_ptr()), CUDA_R_16F,
      embed_dim, static_cast(&beta),
      static_cast(output_lin_results.data_ptr()), CUDA_R_16F, embed_dim,
      CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));

  // End-of-block Dropout-Add (residual connection with `inputs`)
  if (is_training) {
    apex_dropout_add_cuda(
        static_cast(output_lin_results.data_ptr()),
        static_cast(inputs.data_ptr()),
        static_cast(outputs.data_ptr()),
        static_cast(dropout_add_mask.data_ptr()), total_tokens,
        (1.0f - dropout_prob));
  } else {
    apex_add_cuda(static_cast(output_lin_results.data_ptr()),
                  static_cast(inputs.data_ptr()),
                  static_cast(outputs.data_ptr()), total_tokens);
  }

  // NOTE(review): skipped if a check above throws, leaving tensor-op math on.
  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));

  return {lyr_nrm_results, lyr_nrm_mean,     lyr_nrm_invvar, input_lin_results,
          softmax_results, dropout_results,  dropout_mask,   matmul2_results,
          dropout_add_mask, outputs};
}

// Backward pass for the norm-add fwd_cuda above. In order:
//   1. Dropout-add backward, masking output_grads with dropout_add_mask.
//   2. Output Linear dgrad/wgrad.
//   3. MatMul2 dgrads.
//   4. Attention-dropout scaling.
//   5. Softmax backward.
//   6. MatMul1 dgrads.
//   7. Input Linear dgrad into input_lin_grads, and wgrad against
//      lyr_nrm_results.
//   8. HostLayerNormGradient, which per its comment also handles the
//      residual add; output_grads is passed to it alongside input_lin_grads.
//
// dropout_mask, dropout_add_mask and dropout_results are used unconditionally,
// so this assumes the forward ran with is_training == true.
//
// Returns {input_grads, lyr_nrm_gamma_grads, lyr_nrm_beta_grads,
//          input_weight_grads, output_weight_grads}.
std::vector bwd_cuda(int heads, torch::Tensor const& output_grads,
                     torch::Tensor const& matmul2_results,
                     torch::Tensor const& dropout_results,
                     torch::Tensor const& softmax_results,
                     torch::Tensor const& input_lin_results,
                     torch::Tensor const& lyr_nrm_results,
                     torch::Tensor const& lyr_nrm_mean,
                     torch::Tensor const& lyr_nrm_invvar,
                     torch::Tensor const& inputs,
                     torch::Tensor const& lyr_nrm_gamma_weights,
                     torch::Tensor const& lyr_nrm_beta_weights,
                     torch::Tensor const& input_weights,
                     torch::Tensor const& output_weights,
                     torch::Tensor const& dropout_mask,
                     torch::Tensor const& dropout_add_mask,
                     float dropout_prob) {
  const int embed_dim = inputs.size(2);
  const int sequences = inputs.size(1);
  const int q_seq_len = inputs.size(0);
  const int k_seq_len = q_seq_len;
  const int batches = sequences * q_seq_len;
  const int total_tokens = batches * embed_dim;
  const int head_dim = embed_dim / heads;
  const int output_lin_dim = 3 * embed_dim;
  const int attn_batches = heads * sequences;
  const int lead_dim = attn_batches * 3 * head_dim;
  const int batch_stride = 3 * head_dim;
  const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
  const float alpha = 1.0;
  const float beta = 0.0;
  const float scale = 1.0 / sqrt(static_cast(head_dim));

  // TODO: Streams can be used in Backprop but I haven't added more than one
  // in my first attempt to create the code
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);

  // Output Tensor Allocations
  torch::Tensor input_grads = torch::empty_like(inputs);
  torch::Tensor lyr_nrm_gamma_grads = torch::empty_like(lyr_nrm_gamma_weights);
  torch::Tensor lyr_nrm_beta_grads = torch::empty_like(lyr_nrm_beta_weights);
  torch::Tensor input_weight_grads = torch::empty_like(input_weights);
  torch::Tensor output_weight_grads = torch::empty_like(output_weights);
  // Intermediate Tensor Allocations
  torch::Tensor dropout_add_grads = torch::empty_like(output_grads);
  torch::Tensor output_lin_grads = torch::empty_like(matmul2_results);
  torch::Tensor matmul2_grads = torch::empty_like(dropout_results);
  torch::Tensor input_lin_output_grads = torch::empty_like(input_lin_results);
  torch::Tensor input_lin_grads = torch::empty_like(inputs);

  // Q/K/V views into the interleaved forward activations and their grads.
  auto q_lin_results_ptr = static_cast(input_lin_results.data_ptr());
  auto k_lin_results_ptr =
      static_cast(input_lin_results.data_ptr()) + head_dim;
  auto v_lin_results_ptr =
      static_cast(input_lin_results.data_ptr()) + 2 * head_dim;

  auto q_lin_grads_ptr = static_cast(input_lin_output_grads.data_ptr());
  auto k_lin_grads_ptr =
      static_cast(input_lin_output_grads.data_ptr()) + head_dim;
  auto v_lin_grads_ptr =
      static_cast(input_lin_output_grads.data_ptr()) + 2 * head_dim;

  char a_layout_n{'n'};
  char a_layout_t{'t'};
  char b_layout_n{'n'};
  char b_layout_t{'t'};

  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));

  // Dropout Add Backward
  // NOTE(review): dropout_prob == 1.0 makes this scale 1.0 / 0.0 (infinite).
  apex_masked_scale_cuda(
      static_cast(output_grads.data_ptr()),
      static_cast(dropout_add_grads.data_ptr()),
      static_cast(dropout_add_mask.data_ptr()), total_tokens,
      (1.0 / (1.0 - dropout_prob)));

  // Output Linear Dgrad
  TORCH_CUDABLAS_CHECK(cublasGemmEx(
      handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches, embed_dim,
      static_cast(&alpha), static_cast(output_weights.data_ptr()), CUDA_R_16F,
      embed_dim, static_cast(dropout_add_grads.data_ptr()), CUDA_R_16F,
      embed_dim, static_cast(&beta),
      static_cast(output_lin_grads.data_ptr()), CUDA_R_16F, embed_dim,
      CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));

  // Output Linear Wgrad
  TORCH_CUDABLAS_CHECK(cublasGemmEx(
      handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, embed_dim, batches,
      static_cast(&alpha), static_cast(matmul2_results.data_ptr()),
      CUDA_R_16F, embed_dim, static_cast(dropout_add_grads.data_ptr()),
      CUDA_R_16F, embed_dim, static_cast(&beta),
      static_cast(output_weight_grads.data_ptr()), CUDA_R_16F, embed_dim,
      CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));

  // MatMul2 Dgrad1
  gemm_switch_fp32accum(a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim,
                        alpha, static_cast(v_lin_results_ptr), lead_dim,
                        batch_stride, static_cast(output_lin_grads.data_ptr()),
                        head_dim * attn_batches, head_dim, beta,
                        static_cast(matmul2_grads.data_ptr()), k_seq_len,
                        k_seq_len * q_seq_len, attn_batches);

  // Matmul2 Dgrad2
  gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, q_seq_len,
                        alpha, static_cast(output_lin_grads.data_ptr()),
                        head_dim * attn_batches, head_dim,
                        static_cast(dropout_results.data_ptr()), k_seq_len,
                        k_seq_len * q_seq_len, beta, v_lin_grads_ptr, lead_dim,
                        batch_stride, attn_batches);

  // Apply Dropout Mask and Scale by Dropout Probability
  apex_masked_scale_cuda(
      static_cast(matmul2_grads.data_ptr()),
      static_cast(matmul2_grads.data_ptr()),
      static_cast(dropout_mask.data_ptr()), dropout_elems,
      (1.0 / (1.0 - dropout_prob)));

  // Softmax Grad
  bool softmax_success = false;
  softmax_success = dispatch_softmax_backward(
      static_cast(matmul2_grads.data_ptr()),
      static_cast(matmul2_grads.data_ptr()),
      reinterpret_cast(softmax_results.data_ptr()), k_seq_len, k_seq_len,
      attn_batches * q_seq_len);
  // NOTE(review): compiled out under NDEBUG; consider TORCH_CHECK.
  assert(softmax_success);

  // Matmul1 Dgrad1
  gemm_switch_fp32accum(a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len,
                        scale, k_lin_results_ptr, lead_dim, batch_stride,
                        static_cast(matmul2_grads.data_ptr()), k_seq_len,
                        k_seq_len * q_seq_len, beta, q_lin_grads_ptr, lead_dim,
                        batch_stride, attn_batches);

  // Matmul1 Dgrad2
  gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, q_seq_len,
                        scale, q_lin_results_ptr, lead_dim, batch_stride,
                        static_cast(matmul2_grads.data_ptr()), k_seq_len,
                        k_seq_len * q_seq_len, beta, k_lin_grads_ptr, lead_dim,
                        batch_stride, attn_batches);

  // Input Linear Dgrad (into input_lin_grads; the layer-norm backward below
  // turns it into input_grads)
  TORCH_CUDABLAS_CHECK(cublasGemmEx(
      handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches, output_lin_dim,
      static_cast(&alpha), static_cast(input_weights.data_ptr()), CUDA_R_16F,
      embed_dim, static_cast(q_lin_grads_ptr), CUDA_R_16F, output_lin_dim,
      static_cast(&beta),
      // static_cast(input_grads.data_ptr()),
      static_cast(input_lin_grads.data_ptr()), CUDA_R_16F, embed_dim,
      CUDA_R_32F,
      // CUBLAS_GEMM_ALGO10_TENSOR_OP));
      CUBLAS_GEMM_DEFAULT_TENSOR_OP));

  // Input Linear Wgrad (the forward GEMM consumed lyr_nrm_results)
  TORCH_CUDABLAS_CHECK(cublasGemmEx(
      handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, output_lin_dim, batches,
      static_cast(&alpha),
      // static_cast(inputs.data_ptr()),
      static_cast(lyr_nrm_results.data_ptr()), CUDA_R_16F, embed_dim,
      static_cast(q_lin_grads_ptr), CUDA_R_16F, output_lin_dim,
      static_cast(&beta), static_cast(input_weight_grads.data_ptr()),
      CUDA_R_16F, embed_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));

  // Fused Layer Norm Bwd with Residual Add
  HostLayerNormGradient(
      static_cast(input_lin_grads.data_ptr()),
      static_cast(output_grads.data_ptr()),
      static_cast(lyr_nrm_mean.data_ptr()),
      static_cast(lyr_nrm_invvar.data_ptr()), inputs,
      static_cast(batches),   // n1
      static_cast(embed_dim), // n2
      static_cast(lyr_nrm_gamma_weights.data_ptr()),
      static_cast(lyr_nrm_beta_weights.data_ptr()), 1.0e-5,
      static_cast(input_grads.data_ptr()),
      static_cast(lyr_nrm_gamma_grads.data_ptr()),
      static_cast(lyr_nrm_beta_grads.data_ptr()));

  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));

  return {input_grads, lyr_nrm_gamma_grads, lyr_nrm_beta_grads,
          input_weight_grads, output_weight_grads};
}

} // end namespace cublas_gemmex
} // end namespace self_norm_add
} // end namespace multihead_attn
================================================
FILE: apex/contrib/csrc/multihead_attn/softmax.cuh
================================================
#pragma once
#include
#include
#include "philox.cuh"

#ifdef OLD_GENERATOR_PATH
#include
#else
#include
#endif

#include
#include
#include
#include
#include
#include

namespace {

// Element-copy / masking helpers. The primary templates are declared without
// a body here; the specializations below supply the implementations.
template
__device__ __inline__ void copy_vector(Datatype* dst, const Datatype* src);

template
__device__ __inline__ void apply_mask(Datatype* dst, Datatype value,
                                      const uint8_t* src);

template
__device__ __inline__ void apply_additive_mask(Datatype* dst,
                                               const Datatype* additive_mask);

// Copies one __half.
template <>
__device__ __inline__ void copy_vector<__half, 1>(__half* dst,
                                                  const __half* src) {
  *dst = *src;
}

// Copies one float. NOTE(review): this specialization's template arguments
// are not shown in this text.
template <>
__device__ __inline__ void copy_vector(float* dst, const float* src) {
  *dst = *src;
}

// Copies four __half values through a single float2 access, so dst and src
// must be 8-byte aligned.
template <>
__device__ __inline__ void copy_vector<__half, 4>(__half* dst,
                                                  const __half* src) {
  *((float2*)dst) = *((float2*)src);
}

// Copies one byte. NOTE(review): template arguments not shown in this text.
template <>
__device__ __inline__ void copy_vector(uint8_t* dst, const uint8_t* src) {
  *dst = *src;
}

// Copies four bytes through a single half2 access, so dst and src must be
// 4-byte aligned. NOTE(review): template arguments not shown in this text;
// presumably the 4-wide uint8_t case.
template <>
__device__ __inline__ void copy_vector(uint8_t* dst, const uint8_t* src) {
  *((half2*)dst) = *((half2*)src);
}

// Overwrites *dst with `value` where the byte mask equals 1.
template <>
__device__ __inline__ void apply_mask<__half, 1>(__half* dst, __half value,
                                                 const uint8_t* src) {
  if (*src == 1) {
    *dst = value;
  }
}

// Adds an additive mask to one element.
template <>
__device__ __inline__ void apply_additive_mask<__half, 1>(
    __half* dst, const __half* additive_mask) {
  *dst += *additive_mask;
}

// Adds an additive mask to four consecutive elements.
template <>
__device__ __inline__ void apply_additive_mask<__half, 4>(
    __half* dst, const __half* additive_mask) {
  *dst += *additive_mask;
  *(dst + 1) += *(additive_mask + 1);
  *(dst + 2) += *(additive_mask + 2);
  *(dst + 3) += *(additive_mask + 3);
}

////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Warp Softmax forward
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////

// WARP_BATCH number of batches.
// WARP_ITERATIONS The number of iterations required for one warp to iterate
// over all data. WARP_SIZE number of elements working on a single batch, has to
// be a power of two. ELEMENTS_PER_LDG_STG has to be 1.
//
// Each warp handles up to WARP_BATCH consecutive rows. threadIdx.x is the
// lane and threadIdx.y is the warp within the block. Lane l owns the elements
// l + it * WARP_SIZE. Slots beyond element_count are padded with -infinity.
// The row max and row sum are combined across lanes with __shfl_xor_sync, and
// exp(x - max) / sum is stored for in-range elements of in-range rows only.
//
// NOTE(review): the warp's first row is located with `stride`
// (first_batch * stride), but later rows in the warp use i * element_count.
// The two agree only when stride == element_count. Every dispatch_softmax
// call in this text passes k_seq_len for both, but i * stride looks intended.
// NOTE(review): dst is typed input_t* and src output_t*, the reverse of
// dispatch_softmax's (output_t* dst, const input_t* src).
template
__global__ void softmax_warp_forward(input_t* dst, const output_t* src,
                                     int batch_size, int stride,
                                     int element_count) {
  assert(ELEMENTS_PER_LDG_STG == 1);

  int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;

  // batch_size might not be a multiple of WARP_BATCH. Check how
  // many batches have to be computed within this WARP.
  int local_batches = batch_size - first_batch;
  if (local_batches > WARP_BATCH) local_batches = WARP_BATCH;

  // there might be multiple batches per warp. compute the index within the
  // batch
  int local_idx = threadIdx.x;

  src += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
  dst += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;

  // load data from global memory
  input_t elements_input[WARP_BATCH][WARP_ITERATIONS];
  for (int i = 0; i < WARP_BATCH; ++i) {
    int batch_element_count = (i >= local_batches) ? 0 : element_count;
    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
#pragma unroll
      for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
        elements_input[i][it + element] =
            -std::numeric_limits::infinity();
      }
      if (element_index < batch_element_count) {
        copy_vector(&elements_input[i][it],
                    src + i * element_count + it * WARP_SIZE);
      }
    }
  }

  // convert input_t to acc_t
  acc_t elements[WARP_BATCH][WARP_ITERATIONS];
  for (int i = 0; i < WARP_BATCH; ++i) {
    for (int it = 0; it < WARP_ITERATIONS; ++it) {
      elements[i][it] = elements_input[i][it];
    }
  }

  constexpr uint32_t FULL_MASK = 0xffffffff;

  // compute local max_value
  // take the max_value of the first element to avoid one max call
  acc_t max_value[WARP_BATCH];
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    max_value[i] = elements[i][0];
  }
#pragma unroll
  for (int it = 1; it < WARP_ITERATIONS; ++it) {
    for (int i = 0; i < WARP_BATCH; ++i) {
      max_value[i] =
          (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
    }
  }

  // reduction max_value (xor-shuffle across the WARP_SIZE lanes)
  // NOTE(review): val is float rather than acc_t.
#pragma unroll
  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
    float val[WARP_BATCH];
#pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
      val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);
    }
#pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
      max_value[i] = max_value[i] > val[i] ? max_value[i] : val[i];
    }
  }

  // compute local sum
  acc_t sum[WARP_BATCH]{0.0f};
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    for (int it = 0; it < WARP_ITERATIONS; ++it) {
      // elements[i][it] = expf(elements[i][it] - max_value[i]);
      elements[i][it] = std::exp(elements[i][it] - max_value[i]);
      sum[i] += elements[i][it];
    }
  }

  // reduction sum
#pragma unroll
  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
      sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
    }
  }

  // store result
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    if (i >= local_batches) break;
#pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
      if (element_index < element_count) {
        // dst[i * element_count + it * WARP_SIZE] = elements[i][it] / sum[i];
        output_t out[ELEMENTS_PER_LDG_STG];
        for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
          out[element] = elements[i][it + element] / sum[i];
        }
        copy_vector(dst + i * element_count + it * WARP_SIZE, out);
      } else {
        break;
      }
    }
  }
}

// WARP_BATCH number of batches.
// WARP_ITERATIONS The number of iterations required for one warp to iterate
// over all data. WARP_SIZE number of elements working on a single batch, has to
// be a power of two. ELEMENTS_PER_LDG_STG has to be 1.
// Function-pointer type shared by the plain (unmasked) warp softmax forward
// kernels; warp_softmax_kernel selects one and dispatch_softmax launches it.
template using softmax_forward_func = void (*)(input_t* dst, const output_t* src, int batch_size, int stride,
                                               int element_count);

// Selects the softmax_warp_forward instantiation for rows of at most
// 2^log2_elements elements and reports the launch geometry it expects.
//
//   log2_elements    ceil(log2(row length)), as computed by the dispatcher.
//   warp_size        [out] lanes cooperating on one row: min(2^log2_elements, 32).
//   batches_per_warp [out] rows processed per warp: 2 when 2^log2_elements <= 128,
//                    otherwise 1.
//   kernel           [out] selected kernel; not written when false is returned.
//
// Returns false when log2_elements > 10, i.e. for rows longer than 1024 elements.
// The WARP_BATCH / WARP_SIZE template arguments of each selected instantiation
// are expected to match batches_per_warp / warp_size computed here, because the
// dispatcher sizes the launch grid from these outputs.
template bool warp_softmax_kernel(int log2_elements, int& warp_size, int& batches_per_warp,
                                  softmax_forward_func& kernel) {
  // determine size of a warp
  const int next_power_of_two = 1 << log2_elements;
  warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;

  // determine how many batches (rows) a warp should process.
  batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

  // One instantiation per power-of-two row length (noted in each case comment).
  switch (log2_elements) {
    case 0:  // 1
      kernel = &softmax_warp_forward;
      break;
    case 1:  // 2
      kernel = &softmax_warp_forward;
      break;
    case 2:  // 4
      kernel = &softmax_warp_forward;
      break;
    case 3:  // 8
      kernel = &softmax_warp_forward;
      break;
    case 4:  // 16
      kernel = &softmax_warp_forward;
      break;
    case 5:  // 32
      kernel = &softmax_warp_forward;
      break;
    case 6:  // 64
      kernel = &softmax_warp_forward;
      break;
    case 7:  // 128
      kernel = &softmax_warp_forward;
      break;
    case 8:  // 256
      kernel = &softmax_warp_forward;
      break;
    case 9:  // 512
      kernel = &softmax_warp_forward;
      break;
    case 10:  // 1024
      kernel = &softmax_warp_forward;
      break;
    default:
      return false;
  }
  return true;
}

// Launches the plain warp softmax forward over batch_count rows.
//
//   dst, src                output and input buffers.
//   softmax_elements        valid elements per row (the softmax length).
//   softmax_elements_stride forwarded to the kernel as its `stride` argument.
//   batch_count             number of rows.
//
// Returns true if there is nothing to do (softmax_elements == 0) or the kernel
// was launched. Returns false if softmax_elements > 1024 or no kernel matches.
// Launch errors are not checked here.
template bool dispatch_softmax(output_t* dst, const input_t* src, int softmax_elements, int softmax_elements_stride,
                               int batch_count) {
  if (softmax_elements == 0) {
    return true;
  } else if (softmax_elements <= 1024) {
    // compute function index. there's a function for each power of two size up
    // to 1024.
    int log2_elements = 0;
    while ((1 << log2_elements) < softmax_elements) ++log2_elements;

    softmax_forward_func kernel;
    int warp_size, batches_per_warp;
    if (!warp_softmax_kernel(log2_elements, warp_size, batches_per_warp, kernel)) {
      return false;
    }

    // use 128 threads per block to maximize gpu utilization
    constexpr int threads_per_block = 128;
    // compute warps per block.
    int warps_per_block = (threads_per_block / warp_size);

    // compute launch size: a block covers warps_per_block * batches_per_warp
    // rows, so round the block count up to cover every row.
    int batches_per_block = warps_per_block * batches_per_warp;
    int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
    // threadIdx.x = lane within a row group, threadIdx.y = warp within the block.
    dim3 threads(warp_size, warps_per_block, 1);
    // launch
    kernel<<>>(dst, src, batch_count, softmax_elements_stride, softmax_elements);
    return true;
  }
  return false;
}

// Fused forward of softmax(src + pad_mask) followed by dropout. This is the
// vectorized variant: each lane loads and stores ELEMENTS_PER_LDG_STG
// (asserted == 4) contiguous elements per step.
// warp_additive_masked_softmax_dropout_kernel selects it when element_count is
// a multiple of 4.
//
//   dst              [out] keep * softmax / p for every valid element.
//   dropout_mask     [out] one byte per valid element: 1 = kept, 0 = dropped.
//   src              input scores.
//   pad_mask         additive mask; mask row (first_batch + i) / pad_batch_stride
//                    is shared by pad_batch_stride consecutive input rows.
//   batch_size       number of rows.
//   stride           row stride used to locate each warp's first row and the
//                    mask rows.
//   element_count    valid elements per row.
//   philox_args      RNG seed/offset. The thread's global index `tid` is passed
//                    as the second Philox argument (presumably the subsequence).
//   p                probability of KEEPING an element (not of dropping it).
//
// NOTE(review): rows inside a warp are addressed as first_batch * stride +
// i * element_count, i.e. with element_count rather than stride. This is only
// consistent when stride == element_count, which callers presumably guarantee
// (confirm).
// NOTE(review): curr_mask is declared `const half*` while pad_mask is
// `const input_t*`, so this presumably only instantiates with input_t == half.
template __global__ void additive_masked_softmax_dropout_warp_forward_vec4(output_t* dst, uint8_t* dropout_mask,
                                                                           const input_t* src,
                                                                           const input_t* pad_mask, int batch_size,
                                                                           int stride, int element_count,
                                                                           int pad_batch_stride,
                                                                           at::PhiloxCudaState philox_args, float p) {
  assert(ELEMENTS_PER_LDG_STG == 4);
  int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
  // global thread index; used to give every thread its own random stream
  int tid = blockIdx.x * blockDim.x * blockDim.y + threadIdx.y * blockDim.x + threadIdx.x;
  // scale applied to kept elements (p is the keep probability)
  acc_t pinv = acc_t(1) / p;

  // batch_size might not be a multiple of WARP_BATCH. Check how
  // many batches have to be computed within this WARP.
  int local_batches = batch_size - first_batch;
  if (local_batches > WARP_BATCH) local_batches = WARP_BATCH;

  // there might be multiple batches per warp. compute the index within the
  // batch
  int local_idx = threadIdx.x;

  // Vectorized path: the dispatcher only picks this kernel when element_count
  // is a multiple of 4.
  input_t elements_input[WARP_BATCH][WARP_ITERATIONS];
  int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
  src += thread_offset;
  dst += thread_offset;
  dropout_mask += thread_offset;

  // load data from global memory
  for (int i = 0; i < WARP_BATCH; ++i) {
    // rows past local_batches load nothing and keep the fill value below
    int batch_element_count = (i >= local_batches) ? 0 : element_count;
    int pad_thread_offset = ((first_batch + i) / pad_batch_stride) * stride + ELEMENTS_PER_LDG_STG * local_idx;
    const half* curr_mask = pad_mask + pad_thread_offset;
    #pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
      #pragma unroll
      for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
        // masking_value is a large negative value; out-of-range slots keep it so
        // they contribute ~0 after exp
        elements_input[i][it + element] = -10000;
      }
      if (element_index < batch_element_count) {
        int itr_jmp = it * WARP_SIZE;
        int itr_idx = i * element_count + itr_jmp;
        copy_vector(&elements_input[i][it], src + itr_idx);
        apply_additive_mask(&elements_input[i][it], curr_mask + itr_jmp);  //(__half)-std::numeric_limits::infinity()
      }
    }
  }

  // convert input_t to acc_t
  acc_t elements[WARP_BATCH][WARP_ITERATIONS];
  #pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    #pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; ++it) {
      elements[i][it] = elements_input[i][it];
    }
  }

  constexpr uint32_t FULL_MASK = 0xffffffff;

  // compute local max_value
  // take the max_value of the first element to avoid one max call
  acc_t max_value[WARP_BATCH];
  #pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    max_value[i] = elements[i][0];
  }

  #pragma unroll
  for (int it = 1; it < WARP_ITERATIONS; ++it) {
    for (int i = 0; i < WARP_BATCH; ++i) {
      max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
    }
  }

  // reduction max_value: xor-shuffle butterfly; width = WARP_SIZE keeps the
  // exchange within the lanes that share these rows, and afterwards every lane
  // holds the row maximum
  #pragma unroll
  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
    float val[WARP_BATCH];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
      val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);
    }
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
      max_value[i] = max_value[i] > val[i] ? max_value[i] : val[i];
    }
  }

  // compute local sum of exp(x - max) (subtracting the max for stability)
  acc_t sum[WARP_BATCH]{0.0f};
  #pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    for (int it = 0; it < WARP_ITERATIONS; ++it) {
      elements[i][it] = std::exp(elements[i][it] - max_value[i]);
      sum[i] += elements[i][it];
    }
  }

  // reduction sum (same butterfly pattern as the max reduction)
  #pragma unroll
  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
      sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
    }
  }

  // Dropout keep flags: one uniform4 draw per 4-element group of valid data.
  auto seeds = at::cuda::philox::unpack(philox_args);
  Philox ph(std::get<0>(seeds), tid, std::get<1>(seeds));
  uint8_t rands[WARP_BATCH][WARP_ITERATIONS];
  float4 rand_num;
  #pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    if (i >= local_batches) break;
    #pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
      if (element_index < element_count) {
        rand_num = uniform4(ph());
        // (rand <= p) is already 0/1; comparing it with 0.5 yields the same flag
        rands[i][it] = (rand_num.x <= p) > 0.5;
        rands[i][it + 1] = (rand_num.y <= p) > 0.5;
        rands[i][it + 2] = (rand_num.z <= p) > 0.5;
        rands[i][it + 3] = (rand_num.w <= p) > 0.5;
        copy_vector(dropout_mask + i * element_count + it * WARP_SIZE, &rands[i][it]);
      }
    }
  }

  // store result: keep * softmax / p
  #pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    if (i >= local_batches) break;
    #pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
      if (element_index < element_count) {
        output_t out[ELEMENTS_PER_LDG_STG];
        #pragma unroll
        for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
          out[element] = rands[i][it + element] * (pinv * (elements[i][it + element] / sum[i]));
        }
        copy_vector(dst + i * element_count + it * WARP_SIZE, out);
      } else {
        break;
      }
    }
  }
}

// Fused forward of softmax(src + pad_mask) followed by dropout. This is the
// scalar variant (ELEMENTS_PER_LDG_STG asserted == 1). Its parameters and
// outputs match additive_masked_softmax_dropout_warp_forward_vec4.
// Each lane handles one element per step and draws one curand_uniform value
// per valid element. The curandStatePhilox4_32_10_t is initialised with
// (seed, tid, offset).
//
// NOTE(review): rows inside a warp are addressed as first_batch * stride +
// i * element_count. This is only consistent when stride == element_count
// (confirm).
template __global__ void additive_masked_softmax_dropout_warp_forward(output_t* dst, uint8_t* dropout_mask,
                                                                      const input_t* src, const input_t* pad_mask,
                                                                      int batch_size, int stride, int element_count,
                                                                      int pad_batch_stride,
                                                                      at::PhiloxCudaState philox_args, float p) {
  assert(ELEMENTS_PER_LDG_STG == 1);
  int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
  // global thread index; used as the curand subsequence
  int tid = blockIdx.x * blockDim.x * blockDim.y + threadIdx.y * blockDim.x + threadIdx.x;
  // scale applied to kept elements (p is the keep probability)
  acc_t pinv = acc_t(1) / p;

  // batch_size might not be a multiple of WARP_BATCH. Check how
  // many batches have to be computed within this WARP.
  int local_batches = batch_size - first_batch;
  if (local_batches > WARP_BATCH) local_batches = WARP_BATCH;

  // there might be multiple batches per warp. compute the index within the
  // batch
  int local_idx = threadIdx.x;

  // Scalar (non-vectorized) path, one element per lane per step.
  input_t elements_input[WARP_BATCH][WARP_ITERATIONS];
  int thread_offset = first_batch * stride + local_idx;
  src += thread_offset;
  dst += thread_offset;
  dropout_mask += thread_offset;

  // load data from global memory
  for (int i = 0; i < WARP_BATCH; ++i) {
    int batch_element_count = (i >= local_batches) ? 0 : element_count;
    int pad_thread_offset = ((first_batch + i) / pad_batch_stride) * stride + local_idx;
    const half* curr_mask = pad_mask + pad_thread_offset;
    #pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; it += 1) {
      int element_index = local_idx + it * WARP_SIZE;
      #pragma unroll
      for (int element = 0; element < 1; ++element) {
        // masking_value is a large negative value
        elements_input[i][it + element] = -10000;
      }
      if (element_index < batch_element_count) {
        int itr_jmp = it * WARP_SIZE;
        int itr_idx = i * element_count + itr_jmp;
        copy_vector(&elements_input[i][it], src + itr_idx);
        apply_additive_mask(&elements_input[i][it], curr_mask + itr_jmp);
      }
    }
  }

  // convert input_t to acc_t
  acc_t elements[WARP_BATCH][WARP_ITERATIONS];
  #pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    #pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; ++it) {
      elements[i][it] = elements_input[i][it];
    }
  }

  constexpr uint32_t FULL_MASK = 0xffffffff;

  // compute local max_value
  // take the max_value of the first element to avoid one max call
  acc_t max_value[WARP_BATCH];
  #pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    max_value[i] = elements[i][0];
  }

  #pragma unroll
  for (int it = 1; it < WARP_ITERATIONS; ++it) {
    for (int i = 0; i < WARP_BATCH; ++i) {
      max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
    }
  }

  // reduction max_value (xor-shuffle butterfly over WARP_SIZE lanes)
  #pragma unroll
  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
    float val[WARP_BATCH];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
      val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);
    }
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
      max_value[i] = max_value[i] > val[i] ? max_value[i] : val[i];
    }
  }

  // compute local sum
  acc_t sum[WARP_BATCH]{0.0f};
  #pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    for (int it = 0; it < WARP_ITERATIONS; ++it) {
      elements[i][it] = std::exp(elements[i][it] - max_value[i]);
      sum[i] += elements[i][it];
    }
  }

  // reduction sum
  #pragma unroll
  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
      sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
    }
  }

  curandStatePhilox4_32_10_t state;
  auto seeds = at::cuda::philox::unpack(philox_args);
  curand_init(std::get<0>(seeds), tid, std::get<1>(seeds), &state);

  // store result
  #pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    if (i >= local_batches) break;
    #pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; it += 1) {
      int element_index = local_idx + it * WARP_SIZE;
      if (element_index < element_count) {
        output_t out[1];
        acc_t softmax_out[1];
        uint8_t dropout_mask_temp[1];
        // generate a vector of random numbers here
        float rand = curand_uniform(&state);
        float* rand_ptr = (float*)(&rand);
        #pragma unroll
        for (int element = 0; element < 1; ++element) {
          softmax_out[element] = (elements[i][it + element] / sum[i]);
          // rand_ptr is reused to hold the 0.0f/1.0f keep flag
          rand_ptr[element] = rand_ptr[element] <= p;
          out[element] = rand_ptr[element] * pinv * softmax_out[element];
          dropout_mask_temp[element] = rand_ptr[element] > 0.5;  // just to distinguish 0.0f and 1.0f
        }
        copy_vector(dst + i * element_count + it * WARP_SIZE, out);
        copy_vector(dropout_mask + i * element_count + it * WARP_SIZE, dropout_mask_temp);
      } else {
        break;
      }
    }
  }
}

// Function-pointer type shared by both additive-masked softmax+dropout forward
// kernels: the scalar variant (ELEMENTS_PER_LDG_STG == 1) and the _vec4 variant
// (ELEMENTS_PER_LDG_STG == 4). Kernel template parameters:
// WARP_BATCH number of batches (rows) per warp.
// WARP_ITERATIONS The number of iterations required for one warp to iterate
// over all data. WARP_SIZE number of elements working on a single batch, has to
// be a power of two.
template using additive_masked_softmax_dropout_forward_func = void (*)(output_t* dst, uint8_t* dropout_mask,
                                                                       const input_t* src, const input_t* pad_mask,
                                                                       int batch_size, int stride, int element_count,
                                                                       int pad_batch_stride,
                                                                       at::PhiloxCudaState philox_args, float p);

// Selects the additive-masked softmax+dropout kernel for rows of at most
// 2^log2_elements elements and reports its launch geometry.
//
//   element_count    valid elements per row. For log2_elements >= 7 the _vec4
//                    kernel is chosen when element_count % 4 == 0; shorter rows
//                    always use the scalar kernel.
//   warp_size        [out] min(2^log2_elements, 32).
//   batches_per_warp [out] 2 when 2^log2_elements <= 128, otherwise 1.
//   kernel           [out] selected kernel; not written when false is returned.
//
// Returns false when log2_elements > 11 (rows longer than 2048 elements).
template bool warp_additive_masked_softmax_dropout_kernel(
    int element_count, int log2_elements, int& warp_size, int& batches_per_warp,
    additive_masked_softmax_dropout_forward_func& kernel) {
  // determine size of a warp
  const int next_power_of_two = 1 << log2_elements;
  warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;

  // determine how many batches a warp should process.
  batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
  // the 4-wide vectorized loads/stores need rows that are a multiple of 4 long
  bool flag_vec4 = (element_count % 4 == 0);
  switch (log2_elements) {
    case 0:  // 1
      kernel = &additive_masked_softmax_dropout_warp_forward;
      break;
    case 1:  // 2
      kernel = &additive_masked_softmax_dropout_warp_forward;
      break;
    case 2:  // 4
      kernel = &additive_masked_softmax_dropout_warp_forward;
      break;
    case 3:  // 8
      kernel = &additive_masked_softmax_dropout_warp_forward;
      break;
    case 4:  // 16
      kernel = &additive_masked_softmax_dropout_warp_forward;
      break;
    case 5:  // 32
      kernel = &additive_masked_softmax_dropout_warp_forward;
      break;
    case 6:  // 64
      kernel = &additive_masked_softmax_dropout_warp_forward;
      break;
    case 7:  // 128
      if (flag_vec4)
        kernel = &additive_masked_softmax_dropout_warp_forward_vec4;
      else
        kernel = &additive_masked_softmax_dropout_warp_forward;
      break;
    case 8:  // 256
      if (flag_vec4)
        kernel = &additive_masked_softmax_dropout_warp_forward_vec4;
      else
        kernel = &additive_masked_softmax_dropout_warp_forward;
      break;
    case 9:  // 512
      if (flag_vec4)
        kernel = &additive_masked_softmax_dropout_warp_forward_vec4;
      else
        kernel = &additive_masked_softmax_dropout_warp_forward;
      break;
    case 10:  // 1024
      if (flag_vec4)
        kernel = &additive_masked_softmax_dropout_warp_forward_vec4;
      else
        kernel = &additive_masked_softmax_dropout_warp_forward;
      break;
    case 11:  // 2048
      if (flag_vec4)
        kernel = &additive_masked_softmax_dropout_warp_forward_vec4;
      else
        kernel = &additive_masked_softmax_dropout_warp_forward;
      break;
    default:
      return false;
  }
  return true;
}

// Launches the fused softmax(src + pad_mask) + dropout forward.
//
//   dst                     [out] keep * softmax / p.
//   dropout_mask            [out] one byte per element: 1 = kept, 0 = dropped.
//   totalElements           only used to size the per-thread Philox offset
//                           reservation (presumably batch_count * softmax_elements;
//                           confirm).
//   softmax_elements        valid elements per row; at most 2048.
//   softmax_elements_stride forwarded to the kernel as `stride`.
//   batch_count             number of rows.
//   pad_batch_stride        consecutive rows sharing one mask row.
//   p                       keep probability.
//   streamid                presumably the stream the kernel is launched on (confirm).
//
// Returns true on launch or when softmax_elements == 0, false otherwise.
//
// NOTE(review): counter_offset averages totalElements over every launched
// thread, including warps with no rows in a partially filled grid. On the
// scalar path, an active thread draws one value per element it owns
// (up to batches_per_warp * 2^log2_elements / warp_size). For small batch
// counts the reservation can therefore be smaller than what a thread consumes,
// which would let successive calls reuse random numbers. Confirm against
// philox_cuda_state semantics.
template bool dispatch_additive_masked_softmax_dropout(output_t* dst, uint8_t* dropout_mask, const input_t* src,
                                                       const input_t* pad_mask, int totalElements,
                                                       int softmax_elements, int softmax_elements_stride,
                                                       int batch_count, int pad_batch_stride, float p,
                                                       cudaStream_t streamid)  // p is the probability to keep, not drop
{
  if (softmax_elements == 0) {
    return true;
  } else if (softmax_elements <= 2048) {
    // compute function index. there's a function for each power of two size up
    // to 2048.
    int log2_elements = 0;
    while ((1 << log2_elements) < softmax_elements) ++log2_elements;

    additive_masked_softmax_dropout_forward_func kernel;
    int warp_size, batches_per_warp;
    if (!warp_additive_masked_softmax_dropout_kernel(softmax_elements, log2_elements, warp_size, batches_per_warp,
                                                     kernel)) {
      return false;
    }

    // use 128 threads per block to maximize gpu utilization
    constexpr int threads_per_block = 128;
    // compute warps per block.
    int warps_per_block = (threads_per_block / warp_size);
    int batches_per_block = warps_per_block * batches_per_warp;
    int blocks = (batch_count + batches_per_block - 1) / batches_per_block;

    // Reserve counter_offset Philox offsets per thread from the default CUDA
    // generator; the generator mutex serialises this with other generator users.
    c10::optional gen_;
    auto gen = at::get_generator_or_default(gen_, at::cuda::detail::getDefaultCUDAGenerator());
    int64_t counter_offset = (totalElements / (blocks * threads_per_block) + 1);
    at::PhiloxCudaState rng_engine_inputs;
    {
      std::lock_guard lock(gen->mutex_);
      rng_engine_inputs = gen->philox_cuda_state(counter_offset);
    }

    // compute launch size
    dim3 threads(warp_size, warps_per_block, 1);
    // launch
    kernel<<>>(dst, dropout_mask, src, pad_mask, batch_count, softmax_elements_stride, softmax_elements,
               pad_batch_stride, rng_engine_inputs, p);
    return true;
  }
  return false;
}

// Forward softmax(src + pad_mask) without dropout. This is a scalar kernel
// (ELEMENTS_PER_LDG_STG asserted == 1). Template parameters:
// WARP_BATCH number of batches (rows) per warp.
// WARP_ITERATIONS The number of iterations required for one warp to iterate
// over all data. WARP_SIZE number of elements working on a single batch, has to
// be a power of two.
//
// Notes:
// - The parameter types here are dst: input_t* and src: const output_t*.
// - The additive mask row is (first_batch + i) / pad_batch_stride, read through
//   a `const half*`.
// - Out-of-range slots are preloaded with -10000 so they contribute ~0 after exp.
//
// NOTE(review): rows inside a warp are addressed as first_batch * stride +
// i * element_count. This is only consistent when stride == element_count
// (confirm).
template __global__ void additive_masked_softmax_warp_forward(input_t* dst, const output_t* src,
                                                              const input_t* pad_mask, int batch_size, int stride,
                                                              int element_count, int pad_batch_stride) {
  assert(ELEMENTS_PER_LDG_STG == 1);
  int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;

  // batch_size might not be a multiple of WARP_BATCH. Check how
  // many batches have to be computed within this WARP.
  int local_batches = batch_size - first_batch;
  if (local_batches > WARP_BATCH) local_batches = WARP_BATCH;

  // there might be multiple batches per warp. compute the index within the
  // batch
  int local_idx = threadIdx.x;

  int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
  src += thread_offset;
  dst += thread_offset;

  // load data from global memory
  input_t elements_input[WARP_BATCH][WARP_ITERATIONS];
  for (int i = 0; i < WARP_BATCH; ++i) {
    int batch_element_count = (i >= local_batches) ? 0 : element_count;
    int pad_thread_offset = ((first_batch + i) / pad_batch_stride) * stride + ELEMENTS_PER_LDG_STG * local_idx;
    const half* curr_mask = pad_mask + pad_thread_offset;
    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
      #pragma unroll
      for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
        // masking_value is a large negative value
        elements_input[i][it + element] = -10000;
      }
      if (element_index < batch_element_count) {
        int itr_jmp = it * WARP_SIZE;
        int itr_idx = i * element_count + itr_jmp;
        copy_vector(&elements_input[i][it], src + itr_idx);
        // apply_mask(&elements_input[i][it],
        //            (__half)-std::numeric_limits::infinity(),
        //            curr_mask + itr_jmp);
        // The additive mask is added directly. Only slot `it` is touched,
        // which is sufficient because ELEMENTS_PER_LDG_STG == 1.
        elements_input[i][it] += *(curr_mask + itr_jmp);
      }
    }
  }

  // convert input_t to acc_t
  acc_t elements[WARP_BATCH][WARP_ITERATIONS];
  for (int i = 0; i < WARP_BATCH; ++i) {
    for (int it = 0; it < WARP_ITERATIONS; ++it) {
      elements[i][it] = elements_input[i][it];
    }
  }

  constexpr uint32_t FULL_MASK = 0xffffffff;

  // compute local max_value
  // take the max_value of the first element to avoid one max call
  acc_t max_value[WARP_BATCH];
  #pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    max_value[i] = elements[i][0];
  }

  #pragma unroll
  for (int it = 1; it < WARP_ITERATIONS; ++it) {
    for (int i = 0; i < WARP_BATCH; ++i) {
      max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
    }
  }

  // reduction max_value (xor-shuffle butterfly over WARP_SIZE lanes)
  #pragma unroll
  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
    float val[WARP_BATCH];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
      val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);
    }
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
      max_value[i] = max_value[i] > val[i] ? max_value[i] : val[i];
    }
  }

  // compute local sum
  acc_t sum[WARP_BATCH]{0.0f};
  #pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    for (int it = 0; it < WARP_ITERATIONS; ++it) {
      // elements[i][it] = expf(elements[i][it] - max_value[i]);
      elements[i][it] = std::exp(elements[i][it] - max_value[i]);
      sum[i] += elements[i][it];
    }
  }

  // reduction sum
  #pragma unroll
  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
      sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
    }
  }

  // store result
  #pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    if (i >= local_batches) break;
    #pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
      if (element_index < element_count) {
        // dst[i * element_count + it * WARP_SIZE] = elements[i][it] / sum[i];
        output_t out[ELEMENTS_PER_LDG_STG];
        for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
          out[element] = elements[i][it + element] / sum[i];
        }
        copy_vector(dst + i * element_count + it * WARP_SIZE, out);
      } else {
        break;
      }
    }
  }
}

// WARP_BATCH number of batches.
// WARP_ITERATIONS The number of iterations required for one warp to iterate
// over all data. WARP_SIZE number of elements working on a single batch, has to
// be a power of two. ELEMENTS_PER_LDG_STG has to be 1.
// Function-pointer type for the additive_masked_softmax_warp_forward
// instantiations. pad_mask is typed `const half*` here, while the dispatchers
// below accept `const input_t*`. They therefore presumably require
// input_t == half.
template using additive_masked_softmax_forward_func = void (*)(input_t* dst, const output_t* src, const half* pad_mask,
                                                               int batch_size, int stride, int element_count,
                                                               int pad_batch_stride);

// Selects the additive_masked_softmax_warp_forward instantiation for rows of
// at most 2^log2_elements elements and reports its launch geometry.
//
//   warp_size        [out] lanes per row: min(2^log2_elements, 32).
//   batches_per_warp [out] rows per warp: 2 when 2^log2_elements <= 128, else 1.
//   kernel           [out] selected kernel; not written when false is returned.
//
// Returns false when log2_elements > 10 (rows longer than 1024 elements).
template bool warp_additive_masked_softmax_kernel(int log2_elements, int& warp_size, int& batches_per_warp,
                                                  additive_masked_softmax_forward_func& kernel) {
  // determine size of a warp
  const int next_power_of_two = 1 << log2_elements;
  warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;

  // determine how many batches a warp should process.
  batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

  switch (log2_elements) {
    case 0:  // 1
      kernel = &additive_masked_softmax_warp_forward;
      break;
    case 1:  // 2
      kernel = &additive_masked_softmax_warp_forward;
      break;
    case 2:  // 4
      kernel = &additive_masked_softmax_warp_forward;
      break;
    case 3:  // 8
      kernel = &additive_masked_softmax_warp_forward;
      break;
    case 4:  // 16
      kernel = &additive_masked_softmax_warp_forward;
      break;
    case 5:  // 32
      kernel = &additive_masked_softmax_warp_forward;
      break;
    case 6:  // 64
      kernel = &additive_masked_softmax_warp_forward;
      break;
    case 7:  // 128
      kernel = &additive_masked_softmax_warp_forward;
      break;
    case 8:  // 256
      kernel = &additive_masked_softmax_warp_forward;
      break;
    case 9:  // 512
      kernel = &additive_masked_softmax_warp_forward;
      break;
    case 10:  // 1024
      kernel = &additive_masked_softmax_warp_forward;
      break;
    default:
      return false;
  }
  return true;
}

// Launches the additive-masked softmax forward (no dropout) over batch_count rows.
//
//   softmax_elements        valid elements per row; at most 1024.
//   softmax_elements_stride forwarded to the kernel as `stride`.
//   batch_count             number of rows.
//   pad_batch_stride        number of consecutive rows sharing one mask row.
//
// Returns true on launch or when softmax_elements == 0, false otherwise.
template bool dispatch_additive_masked_softmax(output_t* dst, const input_t* src, const input_t* pad_mask,
                                               int softmax_elements, int softmax_elements_stride, int batch_count,
                                               int pad_batch_stride) {
  if (softmax_elements == 0) {
    return true;
  } else if (softmax_elements <= 1024) {
    // compute function index. there's a function for each power of two size up
    // to 1024.
    int log2_elements = 0;
    while ((1 << log2_elements) < softmax_elements) ++log2_elements;

    additive_masked_softmax_forward_func kernel;
    int warp_size, batches_per_warp;
    if (!warp_additive_masked_softmax_kernel(log2_elements, warp_size, batches_per_warp, kernel)) {
      return false;
    }

    // use 128 threads per block to maximize gpu utilization
    constexpr int threads_per_block = 128;
    // compute warps per block.
    int warps_per_block = (threads_per_block / warp_size);

    // compute launch size (enough blocks to cover every row)
    int batches_per_block = warps_per_block * batches_per_warp;
    int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
    dim3 threads(warp_size, warps_per_block, 1);
    // launch
    kernel<<>>(
        dst, src, pad_mask, batch_count, softmax_elements_stride, softmax_elements, pad_batch_stride);
    return true;
  }
  return false;
}

// Same contract as dispatch_additive_masked_softmax, but also takes `streamid`.
// That is presumably the stream the kernel is launched on (confirm).
template bool dispatch_additive_masked_softmax_stream(output_t* dst, const input_t* src, const input_t* pad_mask,
                                                      int softmax_elements, int softmax_elements_stride,
                                                      int batch_count, int pad_batch_stride, cudaStream_t streamid) {
  if (softmax_elements == 0) {
    return true;
  } else if (softmax_elements <= 1024) {
    // compute function index. there's a function for each power of two size up
    // to 1024.
    int log2_elements = 0;
    while ((1 << log2_elements) < softmax_elements) ++log2_elements;

    additive_masked_softmax_forward_func kernel;
    int warp_size, batches_per_warp;
    if (!warp_additive_masked_softmax_kernel(log2_elements, warp_size, batches_per_warp, kernel)) {
      return false;
    }

    // use 128 threads per block to maximize gpu utilization
    constexpr int threads_per_block = 128;
    // compute warps per block.
    int warps_per_block = (threads_per_block / warp_size);

    // compute launch size (enough blocks to cover every row)
    int batches_per_block = warps_per_block * batches_per_warp;
    int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
    dim3 threads(warp_size, warps_per_block, 1);
    // launch
    kernel<<>>(dst, src, pad_mask, batch_count, softmax_elements_stride, softmax_elements, pad_batch_stride);
    return true;
  }
  return false;
}

// Forward softmax over src with a uint8_t pad mask. Template parameters:
// WARP_BATCH number of batches.
// WARP_ITERATIONS The number of iterations required for one warp to iterate
// over all data. WARP_SIZE number of elements working on a single batch, has to
// be a power of two. ELEMENTS_PER_LDG_STG has to be 1.
//
// Behaviour:
// - Out-of-range slots start at -infinity.
// - Valid slots pass through apply_mask with a -infinity fill value; apply_mask
//   is defined elsewhere and presumably writes that value where the mask byte
//   is set.
// - If every slot of a row ends up -infinity, max_value is -infinity and
//   exp(-inf - (-inf)) is NaN, so such a row comes out as NaN.
//
// NOTE(review): rows inside a warp are addressed as first_batch * stride +
// i * element_count. This is only consistent when stride == element_count
// (confirm).
template __global__ void masked_softmax_warp_forward(input_t* dst, const output_t* src, const uint8_t* pad_mask,
                                                     int batch_size, int stride, int element_count,
                                                     int pad_batch_stride) {
  assert(ELEMENTS_PER_LDG_STG == 1);
  int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;

  // batch_size might not be a multiple of WARP_BATCH. Check how
  // many batches have to be computed within this WARP.
  int local_batches = batch_size - first_batch;
  if (local_batches > WARP_BATCH) local_batches = WARP_BATCH;

  // there might be multiple batches per warp. compute the index within the
  // batch
  int local_idx = threadIdx.x;

  int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
  src += thread_offset;
  dst += thread_offset;

  // load data from global memory
  input_t elements_input[WARP_BATCH][WARP_ITERATIONS];
  for (int i = 0; i < WARP_BATCH; ++i) {
    int batch_element_count = (i >= local_batches) ? 0 : element_count;
    // one mask row is shared by pad_batch_stride consecutive input rows
    int pad_thread_offset = ((first_batch + i) / pad_batch_stride) * stride + ELEMENTS_PER_LDG_STG * local_idx;
    const uint8_t* curr_mask = pad_mask + pad_thread_offset;
    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
      #pragma unroll
      for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
        elements_input[i][it + element] = -std::numeric_limits::infinity();
      }
      if (element_index < batch_element_count) {
        int itr_jmp = it * WARP_SIZE;
        int itr_idx = i * element_count + itr_jmp;
        copy_vector(&elements_input[i][it], src + itr_idx);
        apply_mask(
            &elements_input[i][it], __float2half(-std::numeric_limits::infinity()), curr_mask + itr_jmp);
      }
    }
  }

  // convert input_t to acc_t
  acc_t elements[WARP_BATCH][WARP_ITERATIONS];
  for (int i = 0; i < WARP_BATCH; ++i) {
    for (int it = 0; it < WARP_ITERATIONS; ++it) {
      elements[i][it] = elements_input[i][it];
    }
  }

  constexpr uint32_t FULL_MASK = 0xffffffff;

  // compute local max_value
  // take the max_value of the first element to avoid one max call
  acc_t max_value[WARP_BATCH];
  #pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    max_value[i] = elements[i][0];
  }

  #pragma unroll
  for (int it = 1; it < WARP_ITERATIONS; ++it) {
    for (int i = 0; i < WARP_BATCH; ++i) {
      max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
    }
  }

  // reduction max_value (xor-shuffle butterfly over WARP_SIZE lanes)
  #pragma unroll
  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
    float val[WARP_BATCH];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
      val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);
    }
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
      max_value[i] = max_value[i] > val[i] ? max_value[i] : val[i];
    }
  }

  // compute local sum
  acc_t sum[WARP_BATCH]{0.0f};
  #pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    for (int it = 0; it < WARP_ITERATIONS; ++it) {
      // elements[i][it] = expf(elements[i][it] - max_value[i]);
      elements[i][it] = std::exp(elements[i][it] - max_value[i]);
      sum[i] += elements[i][it];
    }
  }

  // reduction sum
  #pragma unroll
  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
      sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
    }
  }

  // store result
  #pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    if (i >= local_batches) break;
    #pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
      if (element_index < element_count) {
        // dst[i * element_count + it * WARP_SIZE] = elements[i][it] / sum[i];
        output_t out[ELEMENTS_PER_LDG_STG];
        for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
          out[element] = elements[i][it + element] / sum[i];
        }
        copy_vector(dst + i * element_count + it * WARP_SIZE, out);
      } else {
        break;
      }
    }
  }
}

// Function-pointer type for the masked_softmax_warp_forward instantiations.
// WARP_BATCH number of batches.
// WARP_ITERATIONS The number of iterations required for one warp to iterate
// over all data. WARP_SIZE number of elements working on a single batch, has to
// be a power of two. ELEMENTS_PER_LDG_STG has to be 1.
template using masked_softmax_forward_func = void (*)(input_t* dst, const output_t* src, const uint8_t* pad_mask,
                                                      int batch_size, int stride, int element_count,
                                                      int pad_batch_stride);

// Selects the masked_softmax_warp_forward instantiation for rows of at most
// 2^log2_elements elements and reports its launch geometry.
//
//   warp_size        [out] min(2^log2_elements, 32).
//   batches_per_warp [out] 2 when 2^log2_elements <= 128, otherwise 1.
//   kernel           [out] selected kernel; not written when false is returned.
//
// Returns false when log2_elements > 10 (rows longer than 1024 elements).
template bool warp_masked_softmax_kernel(int log2_elements, int& warp_size, int& batches_per_warp,
                                         masked_softmax_forward_func& kernel) {
  // determine size of a warp
  const int next_power_of_two = 1 << log2_elements;
  warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;

  // determine how many batches a warp should process.
  batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

  switch (log2_elements) {
    case 0:  // 1
      kernel = &masked_softmax_warp_forward;
      break;
    case 1:  // 2
      kernel = &masked_softmax_warp_forward;
      break;
    case 2:  // 4
      kernel = &masked_softmax_warp_forward;
      break;
    case 3:  // 8
      kernel = &masked_softmax_warp_forward;
      break;
    case 4:  // 16
      kernel = &masked_softmax_warp_forward;
      break;
    case 5:  // 32
      kernel = &masked_softmax_warp_forward;
      break;
    case 6:  // 64
      kernel = &masked_softmax_warp_forward;
      break;
    case 7:  // 128
      kernel = &masked_softmax_warp_forward;
      break;
    case 8:  // 256
      kernel = &masked_softmax_warp_forward;
      break;
    case 9:  // 512
      kernel = &masked_softmax_warp_forward;
      break;
    case 10:  // 1024
      kernel = &masked_softmax_warp_forward;
      break;
    default:
      return false;
  }
  return true;
}

// Launches the uint8_t-masked softmax forward over batch_count rows.
//
//   softmax_elements        valid elements per row; at most 1024.
//   softmax_elements_stride forwarded to the kernel as `stride`.
//   pad_batch_stride        number of consecutive rows sharing one mask row.
//
// Returns true on launch or when softmax_elements == 0, false otherwise.
template bool dispatch_masked_softmax(output_t* dst, const input_t* src, const uint8_t* pad_mask,
                                      int softmax_elements, int softmax_elements_stride, int batch_count,
                                      int pad_batch_stride) {
  if (softmax_elements == 0) {
    return true;
  } else if (softmax_elements <= 1024) {
    // compute function index. there's a function for each power of two size up
    // to 1024.
    int log2_elements = 0;
    while ((1 << log2_elements) < softmax_elements) ++log2_elements;

    masked_softmax_forward_func kernel;
    int warp_size, batches_per_warp;
    if (!warp_masked_softmax_kernel(log2_elements, warp_size, batches_per_warp, kernel)) {
      return false;
    }

    // use 128 threads per block to maximize gpu utilization
    constexpr int threads_per_block = 128;
    // compute warps per block.
    int warps_per_block = (threads_per_block / warp_size);

    // compute launch size (enough blocks to cover every row)
    int batches_per_block = warps_per_block * batches_per_warp;
    int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
    dim3 threads(warp_size, warps_per_block, 1);
    // launch
    kernel<<>>(
        dst, src, pad_mask, batch_count, softmax_elements_stride, softmax_elements, pad_batch_stride);
    return true;
  }
  return false;
}

// WARP_BATCH number of batches.
// WARP_ITERATOINS The number of iterations required for one warp to iterate // over all data. WARP_SIZE number of elements working on a single batch, has to // be a power of two. ELEMENTS_PER_LDG_STG has to be 1. template __global__ void time_masked_softmax_warp_forward(input_t* dst, const output_t* src, const uint8_t* pad_mask, int batch_size, int stride, int element_count, int mod_seq_len) { assert(ELEMENTS_PER_LDG_STG == 1); int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; // batch_size might not be a multiple of WARP_BATCH. Check how // many batches have to computed within this WARP. int local_batches = batch_size - first_batch; if (local_batches > WARP_BATCH) local_batches = WARP_BATCH; // there might be multiple batches per warp. compute the index within the // batch int local_idx = threadIdx.x; int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; src += thread_offset; dst += thread_offset; // load data from global memory input_t elements_input[WARP_BATCH][WARP_ITERATIONS]; for (int i = 0; i < WARP_BATCH; ++i) { int batch_element_count = (i >= local_batches) ? 
0 : element_count; int pad_thread_offset = ((first_batch + i) % mod_seq_len) * stride + ELEMENTS_PER_LDG_STG * local_idx; const uint8_t* curr_mask = pad_mask + pad_thread_offset; for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; #pragma unroll for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { elements_input[i][it + element] = -std::numeric_limits::infinity(); } if (element_index < batch_element_count) { int itr_jmp = it * WARP_SIZE; int itr_idx = i * element_count + itr_jmp; copy_vector(&elements_input[i][it], src + itr_idx); apply_mask( &elements_input[i][it], __float2half(-std::numeric_limits::infinity()), curr_mask + itr_jmp); } } } // convert input_t to acc_t acc_t elements[WARP_BATCH][WARP_ITERATIONS]; for (int i = 0; i < WARP_BATCH; ++i) { for (int it = 0; it < WARP_ITERATIONS; ++it) { elements[i][it] = elements_input[i][it]; } } constexpr uint32_t FULL_MASK = 0xffffffff; // compute local max_value // take the max_value of the first element to avoid one max call acc_t max_value[WARP_BATCH]; #pragma unroll for (int i = 0; i < WARP_BATCH; ++i) { max_value[i] = elements[i][0]; } #pragma unroll for (int it = 1; it < WARP_ITERATIONS; ++it) { for (int i = 0; i < WARP_BATCH; ++i) { max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; } } // reduction max_value #pragma unroll for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { float val[WARP_BATCH]; #pragma unroll for (int i = 0; i < WARP_BATCH; ++i) { val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE); } #pragma unroll for (int i = 0; i < WARP_BATCH; ++i) { max_value[i] = max_value[i] > val[i] ? 
max_value[i] : val[i];
    }
  }
  // compute local sum
  acc_t sum[WARP_BATCH]{0.0f};
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    for (int it = 0; it < WARP_ITERATIONS; ++it) {
      // elements[i][it] = expf(elements[i][it] - max_value[i]);
      elements[i][it] = std::exp(elements[i][it] - max_value[i]);
      sum[i] += elements[i][it];
    }
  }
  // reduction sum
#pragma unroll
  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
      sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
    }
  }
  // store result
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    if (i >= local_batches) break;
#pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
      if (element_index < element_count) {
        // dst[i * element_count + it * WARP_SIZE] = elements[i][it] / sum[i];
        output_t out[ELEMENTS_PER_LDG_STG];
        for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
          out[element] = elements[i][it + element] / sum[i];
        }
        copy_vector(dst + i * element_count + it * WARP_SIZE, out);
      } else {
        break;
      }
    }
  }
}

// NOTE(review): in this copy of the file the template parameter/argument
// lists appear to have been stripped (e.g. `template using ...`,
// `kernel<<>>(...)`, `copy_vector(...)` with no `<...>`), so the code below
// cannot compile as shown. Restore them from upstream before building.

// Function-pointer type for a time-masked softmax forward kernel.
// Kernel template parameters (not visible in this copy):
//   WARP_BATCH            number of batches (rows) handled by one warp.
//   WARP_ITERATIONS       number of iterations required for one warp to
//                         iterate over all data of a row.
//   WARP_SIZE             number of lanes working on a single row; has to be
//                         a power of two.
//   ELEMENTS_PER_LDG_STG  has to be 1.
// NOTE(review): here `dst` is input_t* and `src` is const output_t*, the
// reverse of dispatch_time_masked_softmax(output_t* dst, const input_t* src).
// That only lines up if input_t == output_t or the alias is instantiated with
// its type parameters swapped -- confirm against the stripped template lists.
template using time_masked_softmax_forward_func = void (*)(input_t* dst, const output_t* src, const uint8_t* pad_mask,
                                                           int batch_size, int stride, int element_count,
                                                           int mod_seq_len);

// Chooses the launch geometry and kernel instantiation for rows of
// 2^log2_elements elements.
//   warp_size        [out] lanes per row: min(2^log2_elements, 32).
//   batches_per_warp [out] rows per warp: 2 if 2^log2_elements <= 128, else 1.
//   kernel           [out] forward kernel to launch.
// Returns true on success. Returns false when log2_elements is outside 0..10;
// in that case warp_size/batches_per_warp are already written but kernel is not.
template bool warp_time_masked_softmax_kernel(int log2_elements, int& warp_size, int& batches_per_warp,
                                              time_masked_softmax_forward_func& kernel) {
  // determine size of a warp
  const int next_power_of_two = 1 << log2_elements;
  warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;

  // determine how many batches a warp should process.
  batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

  // Each case presumably instantiates time_masked_softmax_warp_forward for its
  // log2_elements; the template arguments are not visible in this copy.
  switch (log2_elements) {
    case 0:  // 1
      kernel = &time_masked_softmax_warp_forward;
      break;
    case 1:  // 2
      kernel = &time_masked_softmax_warp_forward;
      break;
    case 2:  // 4
      kernel = &time_masked_softmax_warp_forward;
      break;
    case 3:  // 8
      kernel = &time_masked_softmax_warp_forward;
      break;
    case 4:  // 16
      kernel = &time_masked_softmax_warp_forward;
      break;
    case 5:  // 32
      kernel = &time_masked_softmax_warp_forward;
      break;
    case 6:  // 64
      kernel = &time_masked_softmax_warp_forward;
      break;
    case 7:  // 128
      kernel = &time_masked_softmax_warp_forward;
      break;
    case 8:  // 256
      kernel = &time_masked_softmax_warp_forward;
      break;
    case 9:  // 512
      kernel = &time_masked_softmax_warp_forward;
      break;
    case 10:  // 1024
      kernel = &time_masked_softmax_warp_forward;
      break;
    default:
      return false;
  }
  return true;
}

// Host-side launcher for the time-masked softmax forward kernel.
// Returns true if there is nothing to do (softmax_elements == 0) or the kernel
// was launched. Returns false if softmax_elements > 1024 or no kernel
// instantiation exists for the row length.
// Grid: ceil(batch_count / batches_per_block) blocks of
// (warp_size, 128 / warp_size) threads. The remaining launch configuration
// (shared memory, stream) is elided in this copy.
template bool dispatch_time_masked_softmax(output_t* dst, const input_t* src, const uint8_t* pad_mask,
                                           int softmax_elements, int softmax_elements_stride, int batch_count,
                                           int mod_seq_len) {
  if (softmax_elements == 0) {
    return true;
  } else if (softmax_elements <= 1024) {
    // compute function index. there's a function for each power of two size up
    // to 1024. (log2_elements = ceil(log2(softmax_elements)))
    int log2_elements = 0;
    while ((1 << log2_elements) < softmax_elements) ++log2_elements;

    time_masked_softmax_forward_func kernel;
    int warp_size, batches_per_warp;
    if (!warp_time_masked_softmax_kernel(log2_elements, warp_size, batches_per_warp, kernel)) {
      return false;
    }

    // use 128 threads per block to maximize gpu utilization
    constexpr int threads_per_block = 128;
    // compute warps per block.
    int warps_per_block = (threads_per_block / warp_size);

    // compute launch size
    int batches_per_block = warps_per_block * batches_per_warp;
    int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
    dim3 threads(warp_size, warps_per_block, 1);

    // launch
    kernel<<>>(dst, src, pad_mask, batch_count, softmax_elements_stride, softmax_elements, mod_seq_len);
    return true;
  }
  return false;
}

// Returns the smallest n such that (1 << n) >= value; returns 0 for value <= 1.
// Callers in this section pass at most 2048.
// NOTE(review): non-inline, non-template definition -- if this file is a
// header included by several translation units this may cause duplicate
// symbol link errors; consider `inline`. Confirm how the file is used.
int log2_ceil_native(int value) {
  int log2_value = 0;
  while ((1 << log2_value) < value) ++log2_value;
  return log2_value;
}

// XOR warp shuffle wrapper: __shfl_xor_sync(mask, ...) on CUDA >= 9.0,
// otherwise the legacy __shfl_xor (mask is ignored there).
template __device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize,
                                                          unsigned int mask = 0xffffffff) {
#if CUDA_VERSION >= 9000
  return __shfl_xor_sync(mask, value, laneMask, width);
#else
  return __shfl_xor(value, laneMask, width);
#endif
}

// In-place sum of each of the WARP_BATCH accumulators across WARP_SIZE lanes
// using an XOR-shuffle butterfly. After it returns, every participating lane
// holds the full sum. (WARP_BATCH / WARP_SIZE / acc_t are template parameters
// whose list is not visible in this copy.)
template __device__ __forceinline__ void warp_reduce_sum(acc_t* sum) {
#pragma unroll
  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
      acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE);
      sum[i] = sum[i] + b;
    }
  }
}

////////////////////////////////////////////////////////////////////////////////
// Warp softmax backward functions as fused variants of
// at::softmax_backward_data function
////////////////////////////////////////////////////////////////////////////////

// softmax backward data function is taken from native pytorch, elementwise mul
// is fused in the epilogue, as well as masking and scaling for fusing dropout.
//
// Per row (element_count valid values; one warp handles WARP_BATCH rows):
//   g_j         = (input_t)(mask_j * grad_j * scale) * output_j
//   gradInput_j = g_j - output_j * sum_k g_k          (is_log_softmax false)
//   gradInput_j = g_j - exp(output_j) * sum_k g_k     (is_log_softmax true)
// gradInput_j is written as 0 wherever the selected pad_mask byte equals 1.
template __global__ void masked_scale_softmax_warp_backward_masked_dgrad(output_t* gradInput, const input_t* grad,
                                                                        const input_t* output, const uint8_t* mask,
                                                                        const uint8_t* pad_mask, acc_t scale,
                                                                        int batch_size, int stride, int element_count,
                                                                        int heads) {
  // WARP_SIZE and WARP_BATCH must match the warp_size and batches_per_warp
  // values computed by the host-side dispatch_masked_scale_softmax_backward_*
  // launchers.
  constexpr int next_power_of_two = 1 << log2_elements;
  constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
  constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
  constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;

  int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;

  // batch_size might not be a multiple of WARP_BATCH. Check how
  // many batches have to computed within this WARP.
  int local_batches = batch_size - first_batch;
  if (local_batches > WARP_BATCH) local_batches = WARP_BATCH;

  // there might be multiple batches per warp. compute the index within the
  // batch
  int local_idx = threadIdx.x % WARP_SIZE;

  // the first element to process by the current thread
  // NOTE(review): rows are then addressed with i * element_count while this
  // offset uses stride, which assumes stride == element_count -- confirm at
  // call sites.
  int thread_offset = first_batch * stride + local_idx;
  grad += thread_offset;
  output += thread_offset;
  gradInput += thread_offset;
  mask += thread_offset;

  // The nested loops over WARP_BATCH and then WARP_ITERATIONS can be simplified
  // to one loop, but I think doing so would obfuscate the logic of the
  // algorithm, thus I chose to keep the nested loops. This should have no
  // impact on performance because the loops are unrolled anyway.

  // load data from global memory
  // Out-of-range slots (and rows past local_batches) are zero-filled so they do
  // not contribute to the row sum.
  acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS];
  acc_t output_reg[WARP_BATCH][WARP_ITERATIONS];
  for (int i = 0; i < WARP_BATCH; ++i) {
    int batch_element_count = (i >= local_batches) ? 0 : element_count;
    for (int it = 0; it < WARP_ITERATIONS; ++it) {
      int element_index = local_idx + it * WARP_SIZE;
      if (element_index < batch_element_count) {
        // The dropout-scaled gradient is rounded to input_t before being
        // multiplied by the softmax output.
        grad_reg[i][it] = (input_t)((acc_t)mask[i * element_count + it * WARP_SIZE] *
                                    (acc_t)grad[i * element_count + it * WARP_SIZE] * (acc_t)scale) *
                          output[i * element_count + it * WARP_SIZE];
        output_reg[i][it] = output[i * element_count + it * WARP_SIZE];
      } else {
        grad_reg[i][it] = acc_t(0);
        output_reg[i][it] = acc_t(0);
      }
    }
  }

  // per-lane partial row sums of grad_reg, then reduced across the warp
  acc_t sum[WARP_BATCH];
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    sum[i] = grad_reg[i][0];
#pragma unroll
    for (int it = 1; it < WARP_ITERATIONS; ++it) {
      sum[i] += grad_reg[i][it];
    }
  }
  warp_reduce_sum(sum);

  // store result
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    if (i >= local_batches) break;
#pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; ++it) {
      int element_index = local_idx + it * WARP_SIZE;
      if (element_index < element_count) {
        // compute gradients
        // pad_mask_ind keeps the column (total_ind % element_count) and picks
        // pad-mask row total_ind / (heads * element_count * element_count).
        // Presumably the data is [batch, heads, element_count, element_count]
        // and pad_mask is [batch, element_count] -- TODO confirm.
        // NOTE(review): total_ind is an int and may overflow on very large
        // tensors.
        int total_ind = thread_offset + i * element_count + it * WARP_SIZE;
        int pad_mask_ind =
            element_count * (total_ind / (heads * element_count * element_count)) + total_ind % element_count;
        // pad_mask_element is 0 exactly when the pad_mask byte is 1
        uint8_t pad_mask_element = 1 - pad_mask[pad_mask_ind];
        if (pad_mask_element == 0)
          gradInput[i * element_count + it * WARP_SIZE] = 0;
        else {
          if (is_log_softmax) {
            gradInput[i * element_count + it * WARP_SIZE] = (grad_reg[i][it] - std::exp(output_reg[i][it]) * sum[i]);
          } else {
            gradInput[i * element_count + it * WARP_SIZE] = (grad_reg[i][it] - output_reg[i][it] * sum[i]);
          }
        }
      }
    }
  }
}

// Host-side launcher for masked_scale_softmax_warp_backward_masked_dgrad.
// Requires 0 <= softmax_elements <= 1024 (TORCH_INTERNAL_ASSERT); does nothing
// for 0. Grid: ceil(batch_count / batches_per_block) blocks of
// (warp_size, 128 / warp_size) threads. The rest of the launch configuration
// is elided in this copy.
template void dispatch_masked_scale_softmax_backward_masked_out(output_t* grad_input, const input_t* grad,
                                                               const input_t* output, const uint8_t* mask,
                                                               const uint8_t* pad_mask, acc_t scale,
                                                               int softmax_elements, int softmax_elements_stride,
                                                               int batch_count, int heads) {
  TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 1024);
  if (softmax_elements == 0) {
    return;
  } else {
    int log2_elements = log2_ceil_native(softmax_elements);
    const int next_power_of_two = 1 << log2_elements;

    // This value must match the WARP_SIZE constexpr value computed inside
    // masked_scale_softmax_warp_backward_masked_dgrad.
    int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;

    // This value must match the WARP_BATCH constexpr value computed inside
    // masked_scale_softmax_warp_backward_masked_dgrad.
    int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

    // use 128 threads per block to maximize gpu utilization
    constexpr int threads_per_block = 128;

    int warps_per_block = (threads_per_block / warp_size);
    int batches_per_block = warps_per_block * batches_per_warp;
    int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
    dim3 threads(warp_size, warps_per_block, 1);
    // Launch code would be more elegant if C++ supported FOR CONSTEXPR
    switch (log2_elements) {
      case 0:  // 1
        masked_scale_softmax_warp_backward_masked_dgrad
            <<>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride,
                 softmax_elements, heads);
        break;
      case 1:  // 2
        masked_scale_softmax_warp_backward_masked_dgrad
            <<>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride,
                 softmax_elements, heads);
        break;
      case 2:  // 4
        masked_scale_softmax_warp_backward_masked_dgrad
            <<>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride,
                 softmax_elements, heads);
        break;
      case 3:  // 8
        masked_scale_softmax_warp_backward_masked_dgrad
            <<>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride,
                 softmax_elements, heads);
        break;
      case 4:  // 16
        masked_scale_softmax_warp_backward_masked_dgrad
            <<>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride,
                 softmax_elements, heads);
        break;
      case 5:  // 32
        masked_scale_softmax_warp_backward_masked_dgrad
            <<>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride,
                 softmax_elements, heads);
        break;
      case 6:  // 64
        masked_scale_softmax_warp_backward_masked_dgrad
            <<>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride,
                 softmax_elements, heads);
        break;
      case 7:  // 128
        masked_scale_softmax_warp_backward_masked_dgrad
            <<>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride,
                 softmax_elements, heads);
        break;
      case 8:  // 256
        masked_scale_softmax_warp_backward_masked_dgrad
            <<>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride,
                 softmax_elements, heads);
        break;
      case 9:  // 512
        masked_scale_softmax_warp_backward_masked_dgrad
            <<>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride,
                 softmax_elements, heads);
        break;
      case 10:  // 1024
        masked_scale_softmax_warp_backward_masked_dgrad
            <<>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride,
                 softmax_elements, heads);
        break;
      default:
        break;
    }
  }
}

// Same as dispatch_masked_scale_softmax_backward_masked_out, but takes an
// explicit cudaStream_t. Presumably the kernel is launched on streamid; the
// launch configuration is elided in this copy -- confirm.
template void dispatch_masked_scale_softmax_backward_masked_out_stream(output_t* grad_input, const input_t* grad,
                                                                      const input_t* output, const uint8_t* mask,
                                                                      const uint8_t* pad_mask, acc_t scale,
                                                                      int softmax_elements,
                                                                      int softmax_elements_stride, int batch_count,
                                                                      int heads, cudaStream_t streamid) {
  TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 1024);
  if (softmax_elements == 0) {
    return;
  } else {
    int log2_elements = log2_ceil_native(softmax_elements);
    const int next_power_of_two = 1 << log2_elements;

    // This value must match the WARP_SIZE constexpr value computed inside
    // masked_scale_softmax_warp_backward_masked_dgrad.
    int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;

    // This value must match the WARP_BATCH constexpr value computed inside
    // masked_scale_softmax_warp_backward_masked_dgrad.
    int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

    // use 128 threads per block to maximize gpu utilization
    constexpr int threads_per_block = 128;

    int warps_per_block = (threads_per_block / warp_size);
    int batches_per_block = warps_per_block * batches_per_warp;
    int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
    dim3 threads(warp_size, warps_per_block, 1);
    // Launch code would be more elegant if C++ supported FOR CONSTEXPR
    switch (log2_elements) {
      case 0:  // 1
        masked_scale_softmax_warp_backward_masked_dgrad
            <<>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride,
                 softmax_elements, heads);
        break;
      case 1:  // 2
        masked_scale_softmax_warp_backward_masked_dgrad
            <<>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride,
                 softmax_elements, heads);
        break;
      case 2:  // 4
        masked_scale_softmax_warp_backward_masked_dgrad
            <<>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride,
                 softmax_elements, heads);
        break;
      case 3:  // 8
        masked_scale_softmax_warp_backward_masked_dgrad
            <<>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride,
                 softmax_elements, heads);
        break;
      case 4:  // 16
        masked_scale_softmax_warp_backward_masked_dgrad
            <<>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride,
                 softmax_elements, heads);
        break;
      case 5:  // 32
        masked_scale_softmax_warp_backward_masked_dgrad
            <<>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride,
                 softmax_elements, heads);
        break;
      case 6:  // 64
        masked_scale_softmax_warp_backward_masked_dgrad
            <<>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride,
                 softmax_elements, heads);
        break;
      case 7:  // 128
        masked_scale_softmax_warp_backward_masked_dgrad
            <<>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride,
                 softmax_elements, heads);
        break;
      case 8:  // 256
        masked_scale_softmax_warp_backward_masked_dgrad
            <<>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride,
                 softmax_elements, heads);
        break;
      case 9:  // 512
        masked_scale_softmax_warp_backward_masked_dgrad
            <<>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride,
                 softmax_elements, heads);
        break;
      case 10:  // 1024
        masked_scale_softmax_warp_backward_masked_dgrad
            <<>>(grad_input, grad, output, mask, pad_mask, scale, batch_count, softmax_elements_stride,
                 softmax_elements, heads);
        break;
      default:
        break;
    }
  }
}

// Softmax backward fused with the dropout epilogue (no padding mask).
// Per row: g_j = (input_t)(mask_j * grad_j * scale) * output_j and
//   gradInput_j = g_j - output_j * sum_k g_k          (is_log_softmax false)
//   gradInput_j = g_j - exp(output_j) * sum_k g_k     (is_log_softmax true)
template __global__ void masked_scale_softmax_warp_backward(output_t* gradInput, const input_t* grad,
                                                           const input_t* output, const uint8_t* mask, acc_t scale,
                                                           int batch_size, int stride, int element_count) {
  // WARP_SIZE and WARP_BATCH must match the warp_size and batches_per_warp
  // values computed by dispatch_masked_scale_softmax_backward_stream.
  constexpr int next_power_of_two = 1 << log2_elements;
  constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
  constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
  constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;

  int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;

  // batch_size might not be a multiple of WARP_BATCH. Check how
  // many batches have to computed within this WARP.
  int local_batches = batch_size - first_batch;
  if (local_batches > WARP_BATCH) local_batches = WARP_BATCH;

  // there might be multiple batches per warp. compute the index within the
  // batch
  int local_idx = threadIdx.x % WARP_SIZE;

  // the first element to process by the current thread
  // NOTE(review): rows are then addressed with i * element_count while this
  // offset uses stride, which assumes stride == element_count -- confirm.
  int thread_offset = first_batch * stride + local_idx;
  grad += thread_offset;
  output += thread_offset;
  gradInput += thread_offset;
  mask += thread_offset;

  // The nested loops over WARP_BATCH and then WARP_ITERATIONS can be simplified
  // to one loop, but I think doing so would obfuscate the logic of the
  // algorithm, thus I chose to keep the nested loops. This should have no
  // impact on performance because the loops are unrolled anyway.

  // load data from global memory (out-of-range slots are zero-filled)
  acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS];
  acc_t output_reg[WARP_BATCH][WARP_ITERATIONS];
  for (int i = 0; i < WARP_BATCH; ++i) {
    int batch_element_count = (i >= local_batches) ? 0 : element_count;
    for (int it = 0; it < WARP_ITERATIONS; ++it) {
      int element_index = local_idx + it * WARP_SIZE;
      if (element_index < batch_element_count) {
        // dropout-scaled gradient is rounded to input_t, then times output
        grad_reg[i][it] = (input_t)((acc_t)mask[i * element_count + it * WARP_SIZE] *
                                    (acc_t)grad[i * element_count + it * WARP_SIZE] * (acc_t)scale) *
                          output[i * element_count + it * WARP_SIZE];
        output_reg[i][it] = output[i * element_count + it * WARP_SIZE];
      } else {
        grad_reg[i][it] = acc_t(0);
        output_reg[i][it] = acc_t(0);
      }
    }
  }

  // row sums of grad_reg, reduced across the warp
  acc_t sum[WARP_BATCH];
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    sum[i] = grad_reg[i][0];
#pragma unroll
    for (int it = 1; it < WARP_ITERATIONS; ++it) {
      sum[i] += grad_reg[i][it];
    }
  }
  warp_reduce_sum(sum);

  // store result
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    if (i >= local_batches) break;
#pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; ++it) {
      int element_index = local_idx + it * WARP_SIZE;
      if (element_index < element_count) {
        // compute gradients
        if (is_log_softmax) {
          gradInput[i * element_count + it * WARP_SIZE] = (grad_reg[i][it] - std::exp(output_reg[i][it]) * sum[i]);
        } else {
          gradInput[i * element_count + it * WARP_SIZE] = (grad_reg[i][it] - output_reg[i][it] * sum[i]);
        }
      }
    }
  }
}

// Softmax backward that recomputes the softmax instead of reading it:
//   1. loads softmax_input and calls apply_additive_mask with the pad-mask row
//      (presumably adds it); out-of-range slots are pre-filled with -10000;
//   2. loads the dropout mask and grad and forms mask * grad * scale;
//   3. recomputes y = exp(x - max) / sum(exp(x - max)) with warp shuffles;
//   4. g = (mask * grad * scale) * y, gradInput = g - y * sum(g)
//      (or g - exp(y) * sum(g) when is_log_softmax).
// Row r uses pad-mask row r / pad_batch_stride, so one pad-mask row is shared
// by pad_batch_stride consecutive rows.
template __global__ void masked_scale_softmax_warp_backward_recompute(output_t* gradInput, const input_t* grad,
                                                                     const input_t* softmax_input,
                                                                     const input_t* pad_mask, const uint8_t* mask,
                                                                     acc_t scale, int batch_size, int stride,
                                                                     int pad_batch_stride, int element_count) {
  int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;

  // batch_size might not be a multiple of WARP_BATCH. Check how
  // many batches have to computed within this WARP.
  int local_batches = batch_size - first_batch;
  if (local_batches > WARP_BATCH) local_batches = WARP_BATCH;

  // there might be multiple batches per warp. compute the index within the
  // batch
  int local_idx = threadIdx.x % WARP_SIZE;
  // vectorize if a row length is multiple of 4
  // NOTE(review): `==` binds tighter than `&`, so this is
  // element_count & (3 == 0), i.e. always 0. The value is unused in this
  // kernel (the host selector computes its own flag); intended form would be
  // (element_count & 3) == 0.
  int flag_vec4 = element_count & 3 == 0;
  acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS];
  input_t elements_input[WARP_BATCH][WARP_ITERATIONS];

  // the first element to process by the current thread
  int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;

  grad += thread_offset;
  softmax_input += thread_offset;
  gradInput += thread_offset;
  mask += thread_offset;

  // The nested loops over WARP_BATCH and then WARP_ITERATIONS can be simplified
  // to one loop, but I think doing so would obfuscate the logic of the
  // algorithm, thus I chose to keep the nested loops. This should have no
  // impact on performance because the loops are unrolled anyway.

  // load data from global memory
  for (int i = 0; i < WARP_BATCH; ++i) {
    int batch_element_count = (i >= local_batches) ? 0 : element_count;
    int pad_thread_offset = ((first_batch + i) / pad_batch_stride) * stride + ELEMENTS_PER_LDG_STG * local_idx;
    const input_t* curr_mask = pad_mask + pad_thread_offset;
#pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;

#pragma unroll
      for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
        // masking_value is a large negative value
        elements_input[i][it + element] = -10000;
        grad_reg[i][it + element] = acc_t(0);
      }

      if (element_index < batch_element_count) {
        int itr_jmp = it * WARP_SIZE;
        int itr_idx = i * element_count + itr_jmp;
        copy_vector(&elements_input[i][it], softmax_input + itr_idx);
        apply_additive_mask(
            &elements_input[i][it],
            curr_mask + itr_jmp);  //(__half)-std::numeric_limits::infinity()
        uint8_t mask_temp[ELEMENTS_PER_LDG_STG];
        input_t grad_temp[ELEMENTS_PER_LDG_STG];
        copy_vector(&mask_temp[0], mask + itr_idx);
        copy_vector(&grad_temp[0], grad + itr_idx);
#pragma unroll
        for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
          grad_reg[i][it + element] =
              ((acc_t)mask_temp[element] * (acc_t)grad_temp[element] * (acc_t)scale);
        }
      }
    }
  }
  // load data from global memory
  // convert input_t to acc_t
  // TODO : remove this, input is already acc_t type in register
  acc_t elements[WARP_BATCH][WARP_ITERATIONS];
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
#pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; ++it) {
      elements[i][it] = elements_input[i][it];
    }
  }

  constexpr uint32_t FULL_MASK = 0xffffffff;

  // compute local max_value

  // take the max_value of the first element to avoid one max call
  acc_t max_value[WARP_BATCH];
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    max_value[i] = elements[i][0];
  }

#pragma unroll
  for (int it = 1; it < WARP_ITERATIONS; ++it) {
    for (int i = 0; i < WARP_BATCH; ++i) {
      max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
    }
  }

  // reduction max_value (XOR-shuffle butterfly; every lane ends with the row max)
  // NOTE(review): the shuffled value is held in float rather than acc_t; this
  // is lossless only when acc_t is float -- confirm.
#pragma unroll
  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
    float val[WARP_BATCH];
#pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
      val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);
    }
#pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
      max_value[i] = max_value[i] > val[i] ? max_value[i] : val[i];
    }
  }

  // compute local sum
  acc_t sum[WARP_BATCH]{0.0f};

#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    for (int it = 0; it < WARP_ITERATIONS; ++it) {
      // elements[i][it] = expf(elements[i][it] - max_value[i]);
      elements[i][it] = std::exp(elements[i][it] - max_value[i]);
      sum[i] += elements[i][it];
    }
  }

  // reduction sum
#pragma unroll
  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
      sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
    }
  }

  // normalize to the softmax y and form g = grad_reg * y (rows past
  // local_batches are skipped; their grad_reg is already 0)
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    if (i >= local_batches) break;
#pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; it++) {
      elements[i][it] = elements[i][it] / sum[i];
      grad_reg[i][it] = grad_reg[i][it] * elements[i][it];
    }
  }

  // row sums of g, reduced across the warp
  acc_t grad_sum[WARP_BATCH];
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    grad_sum[i] = grad_reg[i][0];
#pragma unroll
    for (int it = 1; it < WARP_ITERATIONS; ++it) {
      grad_sum[i] += grad_reg[i][it];
    }
  }
  warp_reduce_sum(grad_sum);

  // store result
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    if (i >= local_batches) break;
#pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
      if (element_index < element_count) {
        // compute gradients
        output_t grad_input_reg[ELEMENTS_PER_LDG_STG];
#pragma unroll
        for (int element = 0; element < ELEMENTS_PER_LDG_STG; element++) {
          // NOTE(review): elements already holds the normalized softmax y and
          // grad_reg already includes the factor y, so this is_log_softmax
          // branch does not match the usual log-softmax gradient; presumably
          // only is_log_softmax == false is instantiated -- confirm.
          if (is_log_softmax) {
            grad_input_reg[element] =
                (grad_reg[i][it + element] - std::exp(elements[i][it + element]) * grad_sum[i]);
          } else {
            grad_input_reg[element] = (grad_reg[i][it + element] - elements[i][it + element] * grad_sum[i]);
          }
        }
        copy_vector(gradInput + i * element_count + it * WARP_SIZE, grad_input_reg);
      }
    }
  }
}

// Function-pointer type matching masked_scale_softmax_warp_backward_recompute.
template using masked_scale_softmax_warp_backward_recompute_func = void (*)(
    output_t* gradInput, const input_t* grad, const input_t* softmax_input, const input_t* pad_mask,
    const uint8_t* mask, acc_t scale, int batch_size, int stride, int pad_batch_stride, int element_count);

// Picks launch geometry and the recompute-kernel instantiation for rows of
// 2^log2_elements (log2_elements 0..11, i.e. up to 2048 elements).
//   warp_size        [out] min(2^log2_elements, 32).
//   batches_per_warp [out] 2 if 2^log2_elements <= 128, else 1.
//   kernel           [out] kernel to launch.
// For 256..2048 elements a separate instantiation is chosen when
// element_count % 4 == 0 (presumably the 4-wide vectorized variant; the two
// branches look identical here because the template arguments were stripped).
// Returns false for log2_elements outside 0..11.
template bool masked_scale_softmax_warp_backward_recompute_kernel(
    int element_count, int log2_elements, int& warp_size, int& batches_per_warp,
    masked_scale_softmax_warp_backward_recompute_func& kernel) {
  // determine size of a warp
  const int next_power_of_two = 1 << log2_elements;
  warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;

  // determine how many batches a warp should process.
  batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
  bool flag_vec4 = (element_count % 4 == 0);
  switch (log2_elements) {
    case 0:  // 1
      kernel = &masked_scale_softmax_warp_backward_recompute;
      break;
    case 1:  // 2
      kernel = &masked_scale_softmax_warp_backward_recompute;
      break;
    case 2:  // 4
      kernel = &masked_scale_softmax_warp_backward_recompute;
      break;
    case 3:  // 8
      kernel = &masked_scale_softmax_warp_backward_recompute;
      break;
    case 4:  // 16
      kernel = &masked_scale_softmax_warp_backward_recompute;
      break;
    case 5:  // 32
      kernel = &masked_scale_softmax_warp_backward_recompute;
      break;
    case 6:  // 64
      kernel = &masked_scale_softmax_warp_backward_recompute;
      break;
    case 7:  // 128
      kernel = &masked_scale_softmax_warp_backward_recompute;
      break;
    case 8:  // 256
      if (flag_vec4)
        kernel = &masked_scale_softmax_warp_backward_recompute;
      else
        kernel = &masked_scale_softmax_warp_backward_recompute;
      break;
    case 9:  // 512
      if (flag_vec4)
        kernel = &masked_scale_softmax_warp_backward_recompute;
      else
        kernel = &masked_scale_softmax_warp_backward_recompute;
      break;
    case 10:  // 1024
      if (flag_vec4)
        kernel = &masked_scale_softmax_warp_backward_recompute;
      else
        kernel = &masked_scale_softmax_warp_backward_recompute;
      break;
    case 11:  // 2048
      if (flag_vec4)
        kernel = &masked_scale_softmax_warp_backward_recompute;
      else
        kernel = &masked_scale_softmax_warp_backward_recompute;
      break;
    default:
      return false;
  }
  return true;
}

// Host-side launcher for masked_scale_softmax_warp_backward_recompute.
// Returns true if softmax_elements == 0 or the kernel was launched; false if
// softmax_elements > 2048 or no instantiation exists. streamid is presumably
// the launch stream (launch configuration elided in this copy -- confirm).
template bool dispatch_masked_scale_softmax_backward_recompute(
    output_t* grad_input, const input_t* grad, const input_t* softmax_input, const input_t* pad_mask,
    const uint8_t* mask, acc_t scale, int softmax_elements, int softmax_elements_stride, int pad_batch_stride,
    int batch_count, cudaStream_t streamid) {
  if (softmax_elements == 0) {
    return true;
  } else if (softmax_elements <= 2048) {
    // compute function index. there's a function for each power of two size up
    // to 2048.
    int log2_elements = 0;
    while ((1 << log2_elements) < softmax_elements) ++log2_elements;

    masked_scale_softmax_warp_backward_recompute_func kernel;
    int warp_size, batches_per_warp;
    if (!masked_scale_softmax_warp_backward_recompute_kernel(
            softmax_elements, log2_elements, warp_size, batches_per_warp, kernel)) {
      return false;
    }

    // use 128 threads per block to maximize gpu utilization
    constexpr int threads_per_block = 128;
    // compute warps per block.
    int warps_per_block = (threads_per_block / warp_size);
    int batches_per_block = warps_per_block * batches_per_warp;
    int blocks = (batch_count + batches_per_block - 1) / batches_per_block;

    // compute launch size
    dim3 threads(warp_size, warps_per_block, 1);

    // launch
    kernel<<>>(grad_input, grad, softmax_input, pad_mask, mask, scale, batch_count, softmax_elements_stride,
               pad_batch_stride, softmax_elements);
    return true;
  }
  return false;
}

// Host-side launcher for masked_scale_softmax_warp_backward.
// Requires 0 <= softmax_elements <= 1024 (TORCH_INTERNAL_ASSERT); does nothing
// for 0. streamid is presumably the launch stream (launch configuration elided
// in this copy -- confirm).
template void dispatch_masked_scale_softmax_backward_stream(output_t* grad_input, const input_t* grad,
                                                           const input_t* output, const uint8_t* mask, acc_t scale,
                                                           int softmax_elements, int softmax_elements_stride,
                                                           int batch_count, cudaStream_t streamid) {
  TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 1024);
  if (softmax_elements == 0) {
    return;
  } else {
    int log2_elements = log2_ceil_native(softmax_elements);
    const int next_power_of_two = 1 << log2_elements;

    // This value must match the WARP_SIZE constexpr value computed inside
    // masked_scale_softmax_warp_backward.
    int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;

    // This value must match the WARP_BATCH constexpr value computed inside
    // masked_scale_softmax_warp_backward.
    int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

    // use 128 threads per block to maximize gpu utilization
    constexpr int threads_per_block = 128;

    int warps_per_block = (threads_per_block / warp_size);
    int batches_per_block = warps_per_block * batches_per_warp;
    int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
    dim3 threads(warp_size, warps_per_block, 1);
    // Launch code would be more elegant if C++ supported FOR CONSTEXPR
    switch (log2_elements) {
      case 0:  // 1
        masked_scale_softmax_warp_backward
            <<>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
        break;
      case 1:  // 2
        masked_scale_softmax_warp_backward
            <<>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
        break;
      case 2:  // 4
        masked_scale_softmax_warp_backward
            <<>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
        break;
      case 3:  // 8
        masked_scale_softmax_warp_backward
            <<>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
        break;
      case 4:  // 16
        masked_scale_softmax_warp_backward
            <<>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
        break;
      case 5:  // 32
        masked_scale_softmax_warp_backward
            <<>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
        break;
      case 6:  // 64
        masked_scale_softmax_warp_backward
            <<>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
        break;
      case 7:  // 128
        masked_scale_softmax_warp_backward
            <<>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
        break;
      case 8:  // 256
        masked_scale_softmax_warp_backward
            <<>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
        break;
      case 9:  // 512
        masked_scale_softmax_warp_backward
            <<>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
        break;
      case 10:  // 1024
        masked_scale_softmax_warp_backward
            <<>>(grad_input, grad, output, mask, scale, batch_count, softmax_elements_stride, softmax_elements);
        break;
      default:
        break;
    }
  }
}

// elementwise multiplication called in at::softmax_backward_data is fused
// inside the softmax dgrad kernel; as a result of fusion, the intermediate
// multiplication result is kept in registers as acc_t (presumably fp32)
// instead of being rounded to fp16.
// Per row: g_j = grad_j * output_j and
//   gradInput_j = g_j - output_j * sum_k g_k          (is_log_softmax false)
//   gradInput_j = g_j - exp(output_j) * sum_k g_k     (is_log_softmax true)
template __global__ void softmax_warp_backward_fused_native(output_t* gradInput, const input_t* grad,
                                                           const input_t* output, int batch_size, int stride,
                                                           int element_count) {
  // WARP_SIZE and WARP_BATCH must match the warp_size and batches_per_warp
  // values computed by the host-side dispatcher.
  constexpr int next_power_of_two = 1 << log2_elements;
  constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
  constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
  constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;

  int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;

  // batch_size might not be a multiple of WARP_BATCH. Check how
  // many batches have to computed within this WARP.
  int local_batches = batch_size - first_batch;
  if (local_batches > WARP_BATCH) local_batches = WARP_BATCH;

  // there might be multiple batches per warp. compute the index within the
  // batch
  int local_idx = threadIdx.x % WARP_SIZE;

  // the first element to process by the current thread
  // NOTE(review): rows are then addressed with i * element_count while this
  // offset uses stride, which assumes stride == element_count -- confirm.
  int thread_offset = first_batch * stride + local_idx;
  grad += thread_offset;
  output += thread_offset;
  gradInput += thread_offset;

  // The nested loops over WARP_BATCH and then WARP_ITERATIONS can be simplified
  // to one loop, but I think doing so would obfuscate the logic of the
  // algorithm, thus I chose to keep the nested loops. This should have no
  // impact on performance because the loops are unrolled anyway.

  // load data from global memory (out-of-range slots are zero-filled)
  acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS];
  acc_t output_reg[WARP_BATCH][WARP_ITERATIONS];
  for (int i = 0; i < WARP_BATCH; ++i) {
    int batch_element_count = (i >= local_batches) ? 0 : element_count;
    for (int it = 0; it < WARP_ITERATIONS; ++it) {
      int element_index = local_idx + it * WARP_SIZE;
      if (element_index < batch_element_count) {
        grad_reg[i][it] = grad[i * element_count + it * WARP_SIZE] * output[i * element_count + it * WARP_SIZE];
        output_reg[i][it] = output[i * element_count + it * WARP_SIZE];
      } else {
        grad_reg[i][it] = acc_t(0);
        output_reg[i][it] = acc_t(0);
      }
    }
  }

  // row sums of grad * output, reduced across the warp
  acc_t sum[WARP_BATCH];
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    sum[i] = grad_reg[i][0];  //* output_reg[i][0];
#pragma unroll
    for (int it = 1; it < WARP_ITERATIONS; ++it) {
      sum[i] += grad_reg[i][it];  // * output_reg[i][it];
    }
  }
  warp_reduce_sum(sum);

  // store result
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    if (i >= local_batches) break;
#pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; ++it) {
      int element_index = local_idx + it * WARP_SIZE;
      if (element_index < element_count) {
        // compute gradients
        if (is_log_softmax) {
          gradInput[i * element_count + it * WARP_SIZE] = (grad_reg[i][it] - std::exp(output_reg[i][it]) * sum[i]);
        } else {
          gradInput[i * element_count + it * WARP_SIZE] = (grad_reg[i][it] - output_reg[i][it] * sum[i]);
        }
      }
    }
  }
}

template void dispatch_softmax_backward_fused_native(output_t* grad_input, const input_t* grad, const input_t* output,
                                                    int softmax_elements, int softmax_elements_stride,
                                                    int batch_count) {
  TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 1024);
  if (softmax_elements == 0) {
    return;
  } else {
    int log2_elements = log2_ceil_native(softmax_elements);
    const int next_power_of_two = 1 << log2_elements;

    // This value must match the WARP_SIZE constexpr value computed inside
    // softmax_warp_backward.
    int warp_size = (next_power_of_two < C10_WARP_SIZE) ?
next_power_of_two : C10_WARP_SIZE; // This value must match the WARP_BATCH constexpr value computed inside // softmax_warp_backward. int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; // use 128 threads per block to maximize gpu utilization constexpr int threads_per_block = 128; int warps_per_block = (threads_per_block / warp_size); int batches_per_block = warps_per_block * batches_per_warp; int blocks = (batch_count + batches_per_block - 1) / batches_per_block; dim3 threads(warp_size, warps_per_block, 1); // Launch code would be more elegant if C++ supported FOR CONSTEXPR switch (log2_elements) { case 0: // 1 softmax_warp_backward_fused_native <<>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements); break; case 1: // 2 softmax_warp_backward_fused_native <<>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements); break; case 2: // 4 softmax_warp_backward_fused_native <<>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements); break; case 3: // 8 softmax_warp_backward_fused_native <<>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements); break; case 4: // 16 softmax_warp_backward_fused_native <<>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements); break; case 5: // 32 softmax_warp_backward_fused_native <<>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements); break; case 6: // 64 softmax_warp_backward_fused_native <<>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements); break; case 7: // 128 softmax_warp_backward_fused_native <<>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements); break; case 8: // 256 softmax_warp_backward_fused_native <<>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements); break; case 9: // 512 softmax_warp_backward_fused_native <<>>(grad_input, grad, output, 
batch_count, softmax_elements_stride, softmax_elements); break; case 10: // 1024 softmax_warp_backward_fused_native <<>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements); break; default: break; } } } //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // Warp softmax backward //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// template __global__ void softmax_warp_backward(__half* gradInput, const __half* grad, const __half* output, int batch_size, int stride, int element_count) { int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; // batch_size might not be a multiple of WARP_BATCH. Check how // many batches have to computed within this WARP. int local_batches = batch_size - first_batch; if (local_batches > WARP_BATCH) local_batches = WARP_BATCH; // there might be multiple batches per warp. compute the index within the // batch int local_idx = threadIdx.x; // the first element to process by the current thread int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; grad += thread_offset; output += thread_offset; gradInput += thread_offset; // load data from global memory input_t grad_reg_input[WARP_BATCH][WARP_ITERATIONS] = {0.0f}; input_t output_reg_input[WARP_BATCH][WARP_ITERATIONS] = {0.0f}; #pragma unroll for (int i = 0; i < WARP_BATCH; ++i) { int batch_element_count = (i >= local_batches) ? 
0 : element_count; #pragma unroll for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; if (element_index < batch_element_count) { copy_vector(&grad_reg_input[i][it], grad + i * element_count + it * WARP_SIZE); copy_vector(&output_reg_input[i][it], output + i * element_count + it * WARP_SIZE); } } } // convert half to floating point acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS]; acc_t output_reg[WARP_BATCH][WARP_ITERATIONS]; #pragma unroll for (int i = 0; i < WARP_BATCH; ++i) { for (int it = 0; it < WARP_ITERATIONS; ++it) { grad_reg[i][it] = grad_reg_input[i][it]; output_reg[i][it] = output_reg_input[i][it]; } } // compute thread local sum acc_t sum[WARP_BATCH] = {0}; #pragma unroll for (int it = 0; it < WARP_ITERATIONS; ++it) { for (int i = 0; i < WARP_BATCH; ++i) { sum[i] += grad_reg[i][it] * output_reg[i][it]; } } // reduction sum constexpr uint32_t FULL_MASK = 0xffffffff; #pragma unroll for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { #pragma unroll for (int i = 0; i < WARP_BATCH; ++i) { sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE); } } // store result #pragma unroll for (int i = 0; i < WARP_BATCH; ++i) { if (i >= local_batches) break; #pragma unroll for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; if (element_index < element_count) { // compute gradients output_t out[ELEMENTS_PER_LDG_STG]; for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { out[element] = (output_reg[i][it + element] * (grad_reg[i][it + element] - sum[i])); } // store them in global memory copy_vector(gradInput + i * element_count + it * WARP_SIZE, out); } } } } // WARP_BATCH number of batches. // WARP_ITERATOINS The number of iterations required for one warp to iterate // over all data. WARP_SIZE number of elements working on a single batch, has to // be a power of two. 
// ELEMENTS_PER_LDG_STG has to be 1.
//
// NOTE(review): in this copy of the source every `<...>` span appears to have
// been stripped: template parameter/argument lists (e.g. `template using`,
// `&softmax_warp_backward;`), kernel launch configurations (`kernel<<>>`) and
// `#include` targets. The code below is not compilable as-is; TODO restore
// those spans from upstream before building.

// Function-pointer type of the non-masked warp softmax backward kernels that
// warp_softmax_backward_kernel selects and the dispatchers below launch.
template using softmax_backward_func = void (*)(output_t* gradInput, const input_t* grad, const input_t* output,
                                                int batch_size, int stride, int element_count);

// Chooses the softmax_warp_backward instantiation for rows of up to
// 2^log2_elements elements and reports the launch geometry the caller must use.
//   log2_elements    - smallest k with 2^k >= row length; only 0..10
//                      (rows of up to 1024 elements) are handled.
//   warp_size        - [out] threads per warp: min(2^log2_elements, 32).
//   batches_per_warp - [out] rows handled by one warp: 2 if 2^log2_elements <= 128, else 1.
//   kernel           - [out] selected kernel; left unassigned when false is returned.
// Returns false when log2_elements is outside 0..10.
// The per-case template arguments (missing in this copy) must agree with
// warp_size and batches_per_warp: the kernel derives its first row from
// blockDim.y and WARP_BATCH, while the dispatchers size the grid from these values.
template bool warp_softmax_backward_kernel(int log2_elements, int& warp_size, int& batches_per_warp,
                                           softmax_backward_func& kernel) {
  // determine size of a warp
  const int next_power_of_two = 1 << log2_elements;
  warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;

  // determine how many batches a warp should process.
  batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

  switch (log2_elements) {
    case 0:  // 1
      kernel = &softmax_warp_backward;
      break;
    case 1:  // 2
      kernel = &softmax_warp_backward;
      break;
    case 2:  // 4
      kernel = &softmax_warp_backward;
      break;
    case 3:  // 8
      kernel = &softmax_warp_backward;
      break;
    case 4:  // 16
      kernel = &softmax_warp_backward;
      break;
    case 5:  // 32
      kernel = &softmax_warp_backward;
      break;
    case 6:  // 64
      kernel = &softmax_warp_backward;
      break;
    case 7:  // 128
      kernel = &softmax_warp_backward;
      break;
    case 8:  // 256
      kernel = &softmax_warp_backward;
      break;
    case 9:  // 512
      kernel = &softmax_warp_backward;
      break;
    case 10:  // 1024
      kernel = &softmax_warp_backward;
      break;
    default:
      // unsupported row length for this warp-level path
      return false;
  }
  return true;
}

// Launches the warp softmax backward kernel on batch_count rows.
//   grad_input              - gradient buffer written by the kernel.
//   grad, output            - incoming gradient and the forward softmax output.
//   softmax_elements        - elements per row (the reduction length).
//   softmax_elements_stride - distance between consecutive rows; forwarded as `stride`.
//   batch_count             - number of rows.
// Returns true when there is nothing to do (softmax_elements == 0) or a launch
// was issued (the launch itself is not error-checked here). Returns false when
// softmax_elements > 1024 or no kernel could be selected.
template bool dispatch_softmax_backward(output_t* grad_input, const input_t* grad, const input_t* output,
                                        int softmax_elements, int softmax_elements_stride, int batch_count) {
  if (softmax_elements == 0) {
    return true;
  } else if (softmax_elements <= 1024) {
    // compute function index. there's a function for each power of two size up
    // to 1024.
    // log2_elements becomes the smallest k with 2^k >= softmax_elements.
    int log2_elements = 0;
    while ((1 << log2_elements) < softmax_elements) ++log2_elements;

    softmax_backward_func kernel;
    int warp_size, batches_per_warp;
    if (!warp_softmax_backward_kernel(log2_elements, warp_size, batches_per_warp, kernel)) {
      return false;
    }

    // use 128 threads per block to maximize gpu utilization
    constexpr int threads_per_block = 128;
    // compute warps per block.
    int warps_per_block = (threads_per_block / warp_size);

    // compute launch size: enough blocks to cover batch_count rows (rounded up)
    int batches_per_block = warps_per_block * batches_per_warp;
    int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
    dim3 threads(warp_size, warps_per_block, 1);

    // launch
    kernel<<>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements);
    return true;
  }
  return false;
}

// Same as dispatch_softmax_backward, but takes an explicit CUDA stream.
// NOTE(review): the launch configuration is missing in this copy; presumably it
// launches on streamid. Confirm against upstream.
template bool dispatch_softmax_backward_stream(output_t* grad_input, const input_t* grad, const input_t* output,
                                               int softmax_elements, int softmax_elements_stride, int batch_count,
                                               cudaStream_t streamid) {
  if (softmax_elements == 0) {
    return true;
  } else if (softmax_elements <= 1024) {
    // compute function index. there's a function for each power of two size up
    // to 1024.
    // log2_elements becomes the smallest k with 2^k >= softmax_elements.
    int log2_elements = 0;
    while ((1 << log2_elements) < softmax_elements) ++log2_elements;

    softmax_backward_func kernel;
    int warp_size, batches_per_warp;
    if (!warp_softmax_backward_kernel(log2_elements, warp_size, batches_per_warp, kernel)) {
      return false;
    }

    // use 128 threads per block to maximize gpu utilization
    constexpr int threads_per_block = 128;
    // compute warps per block.
    int warps_per_block = (threads_per_block / warp_size);

    // compute launch size: enough blocks to cover batch_count rows (rounded up)
    int batches_per_block = warps_per_block * batches_per_warp;
    int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
    dim3 threads(warp_size, warps_per_block, 1);

    // launch
    kernel<<>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements);
    return true;
  }
  return false;
}

// Masked variant of the warp softmax backward kernel. For each row it computes
//   gradInput = output * (grad - sum(grad * output))
// (the sum is taken over the row) and passes each result vector through
// apply_mask with pad_mask before storing it.
//   pad_mask         - byte mask; the mask row for batch row r starts at
//                      (r / pad_batch_stride) * stride, so each group of
//                      pad_batch_stride consecutive rows shares one mask row.
//   batch_size       - number of rows.
//   stride           - distance between rows (used for the warp's first row
//                      and for the mask row).
//   element_count    - elements per row.
// Each warp handles up to WARP_BATCH rows. Each thread moves ELEMENTS_PER_LDG_STG
// elements per copy_vector call and holds WARP_ITERATIONS slots per row.
// NOTE(review): the warp's first row is located with first_batch * stride, but
// the later rows within the warp and the store offsets use i * element_count.
// The two agree only when stride == element_count. Confirm callers guarantee it.
template __global__ void masked_softmax_warp_backward(__half* gradInput, const __half* grad, const __half* output,
                                                      const uint8_t* pad_mask, int batch_size, int stride,
                                                      int element_count, int pad_batch_stride) {
  int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;

  // batch_size might not be a multiple of WARP_BATCH. Check how
  // many batches have to be computed within this WARP.
  int local_batches = batch_size - first_batch;
  if (local_batches > WARP_BATCH) local_batches = WARP_BATCH;

  // there might be multiple batches per warp. compute the index within the
  // batch
  int local_idx = threadIdx.x;

  // the first element to process by the current thread
  int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
  grad += thread_offset;
  output += thread_offset;
  gradInput += thread_offset;

  // load data from global memory
  // (zero-initialized, so slots past the row end or past local_batches
  // contribute nothing to the sums below)
  input_t grad_reg_input[WARP_BATCH][WARP_ITERATIONS] = {0.0f};
  input_t output_reg_input[WARP_BATCH][WARP_ITERATIONS] = {0.0f};
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    int batch_element_count = (i >= local_batches) ? 0 : element_count;
#pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
      if (element_index < batch_element_count) {
        copy_vector(&grad_reg_input[i][it], grad + i * element_count + it * WARP_SIZE);
        copy_vector(&output_reg_input[i][it], output + i * element_count + it * WARP_SIZE);
      }
    }
  }

  // convert half to floating point
  acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS];
  acc_t output_reg[WARP_BATCH][WARP_ITERATIONS];
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    for (int it = 0; it < WARP_ITERATIONS; ++it) {
      grad_reg[i][it] = grad_reg_input[i][it];
      output_reg[i][it] = output_reg_input[i][it];
    }
  }

  // compute thread local sum of grad * output for each row
  acc_t sum[WARP_BATCH] = {0};
#pragma unroll
  for (int it = 0; it < WARP_ITERATIONS; ++it) {
    for (int i = 0; i < WARP_BATCH; ++i) {
      sum[i] += grad_reg[i][it] * output_reg[i][it];
    }
  }

  // reduction sum: xor-butterfly over WARP_SIZE lanes; afterwards every lane
  // holds the full row sum
  constexpr uint32_t FULL_MASK = 0xffffffff;
#pragma unroll
  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
      sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
    }
  }

  // store result
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    if (i >= local_batches) break;
    // one mask row per group of pad_batch_stride consecutive rows
    int pad_thread_offset = ((first_batch + i) / pad_batch_stride) * stride + ELEMENTS_PER_LDG_STG * local_idx;
    const uint8_t* curr_mask = pad_mask + pad_thread_offset;
#pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
      if (element_index < element_count) {
        // compute gradients
        output_t out[ELEMENTS_PER_LDG_STG];
        for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
          out[element] = (output_reg[i][it + element] * (grad_reg[i][it + element] - sum[i]));
        }
        // store them in global memory
        int itr_jmp = it * WARP_SIZE;
        int itr_idx = i * element_count + itr_jmp;
        // It is kind of unfortunate this has to be here to zero something out
        // that is close to zero in the first place
        // (apply_mask is defined elsewhere; called here with fill value 0.0)
        apply_mask(&out[0], 0.0, curr_mask + itr_jmp);
        copy_vector(gradInput + itr_idx, out);
      }
    }
  }
}

// WARP_BATCH number of batches.
// WARP_ITERATIONS The number of iterations required for one warp to iterate
// over all data. WARP_SIZE number of elements working on a single batch, has to
// be a power of two. ELEMENTS_PER_LDG_STG has to be 1.

// Function-pointer type of the masked warp softmax backward kernels that
// warp_masked_softmax_backward_kernel selects.
template using masked_softmax_backward_func =
    void (*)(output_t* gradInput, const input_t* grad, const input_t* output, const uint8_t* pad_mask,
             int batch_size, int stride, int element_count, int pad_batch_stride);

// Masked counterpart of warp_softmax_backward_kernel. It computes the same
// warp_size and batches_per_warp, selects a masked_softmax_warp_backward
// instantiation for log2_elements in 0..10, and returns false (kernel left
// unassigned) otherwise.
template bool warp_masked_softmax_backward_kernel(int log2_elements, int& warp_size, int& batches_per_warp,
                                                  masked_softmax_backward_func& kernel) {
  // determine size of a warp
  const int next_power_of_two = 1 << log2_elements;
  warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;

  // determine how many batches a warp should process.
  batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

  switch (log2_elements) {
    case 0:  // 1
      kernel = &masked_softmax_warp_backward;
      break;
    case 1:  // 2
      kernel = &masked_softmax_warp_backward;
      break;
    case 2:  // 4
      kernel = &masked_softmax_warp_backward;
      break;
    case 3:  // 8
      kernel = &masked_softmax_warp_backward;
      break;
    case 4:  // 16
      kernel = &masked_softmax_warp_backward;
      break;
    case 5:  // 32
      kernel = &masked_softmax_warp_backward;
      break;
    case 6:  // 64
      kernel = &masked_softmax_warp_backward;
      break;
    case 7:  // 128
      kernel = &masked_softmax_warp_backward;
      break;
    case 8:  // 256
      kernel = &masked_softmax_warp_backward;
      break;
    case 9:  // 512
      kernel = &masked_softmax_warp_backward;
      break;
    case 10:  // 1024
      kernel = &masked_softmax_warp_backward;
      break;
    default:
      // unsupported row length for this warp-level path
      return false;
  }
  return true;
}

// Launches the masked warp softmax backward kernel on batch_count rows.
// Parameters match dispatch_softmax_backward, plus pad_mask and
// pad_batch_stride, which are forwarded unchanged to the kernel.
// Returns true for softmax_elements == 0 (no launch) or after issuing the
// launch (not error-checked here). Returns false for softmax_elements > 1024 or
// when no kernel could be selected.
template bool dispatch_masked_softmax_backward(output_t* grad_input, const input_t* grad, const input_t* output,
                                               const uint8_t* pad_mask, int softmax_elements,
                                               int softmax_elements_stride, int batch_count, int pad_batch_stride) {
  if (softmax_elements == 0) {
    return true;
  } else if (softmax_elements <= 1024) {
    // compute function index. there's a function for each power of two size up
    // to 1024.
    // log2_elements becomes the smallest k with 2^k >= softmax_elements.
    int log2_elements = 0;
    while ((1 << log2_elements) < softmax_elements) ++log2_elements;

    masked_softmax_backward_func kernel;
    int warp_size, batches_per_warp;
    if (!warp_masked_softmax_backward_kernel(log2_elements, warp_size, batches_per_warp, kernel)) {
      return false;
    }

    // use 128 threads per block to maximize gpu utilization
    constexpr int threads_per_block = 128;
    // compute warps per block.
    int warps_per_block = (threads_per_block / warp_size);

    // compute launch size: enough blocks to cover batch_count rows (rounded up)
    int batches_per_block = warps_per_block * batches_per_warp;
    int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
    dim3 threads(warp_size, warps_per_block, 1);

    // launch
    kernel<<>>(
        grad_input, grad, output, pad_mask, batch_count, softmax_elements_stride, softmax_elements, pad_batch_stride);
    return true;
  }
  return false;
}
}  // namespace

================================================ FILE: apex/contrib/csrc/multihead_attn/strided_batched_gemm.cuh ================================================
#pragma once
#include
#include
#include
#include
#include
#include
// #include
#include
#include
#include
#include
#include
#include
#include
#include
#include

namespace {

// Maps a BLAS-style transpose code to the cuBLAS operation enum:
// 't' -> CUBLAS_OP_T, 'n' -> CUBLAS_OP_N, 'c' -> CUBLAS_OP_C.
// Only lowercase codes are accepted; any other value fails the TORCH_CHECK
// (the trailing return only gives that path a return value).
cublasOperation_t convertTransToCublasOperation(char trans) {
  if (trans == 't')
    return CUBLAS_OP_T;
  else if (trans == 'n')
    return CUBLAS_OP_N;
  else if (trans == 'c')
    return CUBLAS_OP_C;
  else {
    TORCH_CHECK(false, "trans must be one of: t, n, c");
    return CUBLAS_OP_T;
  }
}

// Strided-batched GEMM through cublasGemmStridedBatchedEx. A, B and C are
// CUDA_R_16F, the compute type is CUDA_R_32F, and alpha/beta are passed as
// float. It runs on ATen's current cuBLAS handle after binding that handle to
// ATen's current CUDA stream.
//   transa, transb - 't' / 'n' / 'c' (converted by convertTransToCublasOperation).
//   m, n, k, lda, ldb, ldc, batchCount - narrowed to int with C-style casts and
//                    no range check here; callers must keep them within int range.
//   strideA/B/C    - forwarded unchanged as the cuBLAS batch strides.
//   algo           - cuBLAS algorithm selector (default CUBLAS_GEMM_DEFAULT_TENSOR_OP).
// NOTE(review): the status returned by cublasSetStream is not checked, unlike
// the GEMM call itself, which is wrapped in TORCH_CUDABLAS_CHECK.
void CublasStridedBatchedGemm(char transa, char transb, long m, long n, long k, float alpha, const half* a, long lda,
                              long strideA, const half* b, long ldb, long strideB, float beta, half* c, long ldc,
                              long strideC, long batchCount, cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP) {
  cublasOperation_t opa = convertTransToCublasOperation(transa);
  cublasOperation_t opb = convertTransToCublasOperation(transb);

  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);
  // float copies so the scalars match the CUDA_R_32F compute type
  float fAlpha = alpha;
  float fBeta = beta;
  TORCH_CUDABLAS_CHECK(cublasGemmStridedBatchedEx(
      handle, opa, opb, (int)m, (int)n, (int)k, (void*)&fAlpha, a, CUDA_R_16F, (int)lda, strideA, b, CUDA_R_16F,
      (int)ldb, strideB, (void*)&fBeta, c, CUDA_R_16F, (int)ldc, strideC, (int)batchCount, CUDA_R_32F, algo));
}

}  // namespace

// TODO(mkozuki): Make use of the int template parameters or
discard them. template void CutlassGemm_FP32Accum(cudaStream_t stream, long m, long n, long k, float alpha, const half* a, long lda, long long int batch_stride_A, const half* b, long ldb, long long int batch_stride_B, float beta, half* c, long ldc, long long int batch_stride_C, long batch_count) { using Gemm = cutlass::gemm::device::GemmBatched< /* Element type of A matrix */ half, /* Layout of A matrix */ LayoutA, /* Element type of B matrix */ half, /* Layout of B matrix */ LayoutB, /* Element type of C matrix */ half, /* Layout of C matrix */ cutlass::layout::ColumnMajor, /* Element Accumulator*/ float>; Gemm gemm_op; cutlass::Status status = gemm_op({{static_cast(m), static_cast(n), static_cast(k)}, {a, lda}, batch_stride_A, {b, ldb}, batch_stride_B, {c, ldc}, batch_stride_C, {c, ldc}, batch_stride_C, {alpha, beta}, static_cast(batch_count)}, nullptr, stream); C10_CUDA_CHECK(status != cutlass::Status::kSuccess ? cudaErrorUnknown : cudaSuccess); } namespace { void gemm_switch_fp32accum(char transa, char transb, long m, long n, long k, float alpha, const half* a, long lda, long strideA, const half* b, long ldb, long strideB, float beta, half* c, long ldc, long strideC, long batchCount) { auto stream = c10::cuda::getCurrentCUDAStream(); // printf("GEMM -> %c%c M: %i N: %i K: %i Alpha: %f Beta: %f\n", (transa == // 't' ? 'T' : 'N'), (transb =='t' ? 
'T' : 'N'), m, n, k, alpha, beta); if ((transa == 't') && (transb == 'n')) { if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { CublasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount, CUBLAS_GEMM_ALGO0_TENSOR_OP); } else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x3)) { CutlassGemm_FP32Accum( stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x1)) { CutlassGemm_FP32Accum( stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x7)) { CutlassGemm_FP32Accum( stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x3)) { CutlassGemm_FP32Accum( stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x1)) { CutlassGemm_FP32Accum( stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x7)) { CutlassGemm_FP32Accum( stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x3)) { CutlassGemm_FP32Accum( stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x1)) { CutlassGemm_FP32Accum( stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x7)) { CutlassGemm_FP32Accum( stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x3)) { CutlassGemm_FP32Accum( stream, m, n, k, alpha, a, lda, strideA, b, ldb, 
strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x1)) { CutlassGemm_FP32Accum( stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x7)) { CutlassGemm_FP32Accum( stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x3)) { CutlassGemm_FP32Accum( stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x1)) { CutlassGemm_FP32Accum( stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x7)) { CutlassGemm_FP32Accum( stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x3)) { CutlassGemm_FP32Accum( stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x1)) { CutlassGemm_FP32Accum( stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x7)) { CutlassGemm_FP32Accum( stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x3)) { CutlassGemm_FP32Accum( stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x1)) { CutlassGemm_FP32Accum( stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x7)) { CutlassGemm_FP32Accum( stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 
0x1) && !(ldb & 0x3) && !(ldc & 0x3)) { CutlassGemm_FP32Accum( stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x1)) { CutlassGemm_FP32Accum( stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x7)) { CutlassGemm_FP32Accum( stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x3)) { CutlassGemm_FP32Accum( stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x1)) { CutlassGemm_FP32Accum( stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else { CublasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } } else if ((transa == 'n') && (transb == 'n')) { if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { CublasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount, CUBLAS_GEMM_ALGO0_TENSOR_OP); } else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x3)) { CutlassGemm_FP32Accum( stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x1)) { CutlassGemm_FP32Accum( stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x7)) { CutlassGemm_FP32Accum( stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x3)) { CutlassGemm_FP32Accum( stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x7) && !(ldb & 0x3) 
&& !(ldc & 0x1)) { CutlassGemm_FP32Accum( stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x7)) { CutlassGemm_FP32Accum( stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x3)) { CutlassGemm_FP32Accum( stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x1)) { CutlassGemm_FP32Accum( stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x7)) { CutlassGemm_FP32Accum( stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x3)) { CutlassGemm_FP32Accum( stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x1)) { CutlassGemm_FP32Accum( stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x7)) { CutlassGemm_FP32Accum( stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x3)) { CutlassGemm_FP32Accum( stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x1)) { CutlassGemm_FP32Accum( stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x7)) { CutlassGemm_FP32Accum( stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x3)) { CutlassGemm_FP32Accum( stream, m, n, k, 
alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x1)) { CutlassGemm_FP32Accum( stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x7)) { CutlassGemm_FP32Accum( stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x3)) { CutlassGemm_FP32Accum( stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x1)) { CutlassGemm_FP32Accum( stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x7)) { CutlassGemm_FP32Accum( stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x3)) { CutlassGemm_FP32Accum( stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x1)) { CutlassGemm_FP32Accum( stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x7)) { CutlassGemm_FP32Accum( stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x3)) { CutlassGemm_FP32Accum( stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x1)) { CutlassGemm_FP32Accum( stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else { CublasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } } else if ((transa == 
'n') && (transb == 't')) { if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { CublasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount, CUBLAS_GEMM_ALGO0_TENSOR_OP); } else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x3)) { CutlassGemm_FP32Accum( stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x1)) { CutlassGemm_FP32Accum( stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x7)) { CutlassGemm_FP32Accum( stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x3)) { CutlassGemm_FP32Accum( stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x1)) { CutlassGemm_FP32Accum( stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x7)) { CutlassGemm_FP32Accum( stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x3)) { CutlassGemm_FP32Accum( stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x1)) { CutlassGemm_FP32Accum( stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x7)) { CutlassGemm_FP32Accum( stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x3)) { CutlassGemm_FP32Accum( stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } 
else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x1)) { CutlassGemm_FP32Accum( stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x7)) { CutlassGemm_FP32Accum( stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x3)) { CutlassGemm_FP32Accum( stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x7)) { CutlassGemm_FP32Accum( stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x3)) { CutlassGemm_FP32Accum( stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x1)) { CutlassGemm_FP32Accum( stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x7)) { CutlassGemm_FP32Accum( stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x3)) { CutlassGemm_FP32Accum( stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x1)) { CutlassGemm_FP32Accum( stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x7)) { CutlassGemm_FP32Accum( stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x3)) { CutlassGemm_FP32Accum( stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x1)) { 
CutlassGemm_FP32Accum( stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x7)) { CutlassGemm_FP32Accum( stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x3)) { CutlassGemm_FP32Accum( stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x1)) { CutlassGemm_FP32Accum( stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } else { CublasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } } else { TORCH_CHECK(false, "TransA and TransB are invalid"); } } void adjustLdLevel3(char transa, char transb, int64_t m, int64_t n, int64_t k, int64_t* lda, int64_t* ldb, int64_t* ldc) { int transa_ = ((transa == 't') || (transa == 'T')); int transb_ = ((transb == 't') || (transb == 'T')); // Note: leading dimensions generally are checked that they are > 0 and at // least as big the result requires (even if the value won't be used). 
if (n <= 1) *ldc = std::max(m, 1); if (transa_) { if (m <= 1) *lda = std::max(k, 1); } else { if (k <= 1) *lda = std::max(m, 1); } if (transb_) { if (k <= 1) *ldb = std::max(n, 1); } else { if (n <= 1) *ldb = std::max(k, 1); } } void HgemmStridedBatched(char transa, char transb, long m, long n, long k, float alpha, const half* a, long lda, long strideA, const half* b, long ldb, long strideB, float beta, half* c, long ldc, long strideC, long batchCount) { if ((m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX) || (lda >= INT_MAX) || (ldb >= INT_MAX) || (ldc >= INT_MAX) || (batchCount >= INT_MAX)) { TORCH_CHECK(false, "Cublas_SgemmStridedBatched only supports m, n, k, lda, ldb, ldc, " "batchCount" "with the bound [val] <= %d", INT_MAX); } adjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc); gemm_switch_fp32accum(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } } // namespace ================================================ FILE: apex/contrib/csrc/nccl_allocator/NCCLAllocator.cpp ================================================ #include #include #include #include #include #define NCCL_CHECK(cmd) \ do { \ ncclResult_t result = cmd; \ if (result != ncclSuccess) { \ std::string err = "NCCL error in: " + std::string(__FILE__) + ":" + std::to_string(__LINE__) + ", " + \ std::string(ncclGetErrorString(result)); \ TORCH_CHECK(false, err); \ } \ } while (0) void* nccl_alloc_plug(size_t size, int device, void* stream) { void* ptr; NCCL_CHECK(ncclMemAlloc(&ptr, size)); return ptr; } void nccl_free_plug(void* ptr, std::size_t size, int device, void* stream) { NCCL_CHECK(ncclMemFree(ptr)); } std::shared_ptr nccl_allocator; void maybe_init() { if (!nccl_allocator) { nccl_allocator = std::make_shared(nccl_alloc_plug, nccl_free_plug); } } std::shared_ptr get_nccl_allocator() { maybe_init(); return nccl_allocator; } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("get_nccl_allocator", []() { return get_nccl_allocator(); }); 
}; ================================================ FILE: apex/contrib/csrc/nccl_p2p/nccl_p2p.cpp ================================================ /** * Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "nccl_p2p_cuda.cuh" PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("get_unique_nccl_id", &apex::contrib::nccl_p2p::get_unique_nccl_id, "get_unique_nccl_id", py::call_guard()); m.def("init_nccl_comm", &apex::contrib::nccl_p2p::init_nccl_comm, "init_nccl_comm", py::call_guard()); m.def("left_right_halo_exchange_inplace", &apex::contrib::nccl_p2p::left_right_halo_exchange_inplace, "left_right_halo_exchange_inplace", py::call_guard()); m.def("left_right_halo_exchange", &apex::contrib::nccl_p2p::left_right_halo_exchange, "left_right_halo_exchange", py::call_guard()); m.def("add_delay", &apex::contrib::nccl_p2p::add_delay, "add_delay", py::call_guard()); } ================================================ FILE: apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cu ================================================ #include #include #include #include #include #include #include #include "nccl.h" /* * This file implements a crude but effective mechanism for exchanging data between tensors owned by different ranks * on the same machine using NCCL point-to-point send/receive operations (ncclSend/ncclRecv).
*/
// File-local helpers: a busy-wait kernel, a wrapper that owns one NCCL
// communicator, and a registry (`mo`) that owns every wrapper created through
// init_nccl_comm and hands out integer handles to them.
namespace {
// Busy-waits on a single thread (blockIdx.x == 0 && threadIdx.x == 0; every
// other thread does nothing) until `elapsed` reaches `delay`, then stores the
// number of loop iterations in *counter.
// The 1e9 / CLOCKS_PER_SEC scaling suggests `delay` is meant in nanoseconds.
// NOTE(review): device-side clock() counts SM clock cycles rather than
// CLOCKS_PER_SEC ticks, so one unit of `delay` is probably not one real
// nanosecond -- confirm the intended unit.
__global__ void AddDelay_kernel(const int delay, int* counter) {
  if (blockIdx.x == 0 && threadIdx.x == 0) {
    // waste time while doing something compiler can't predict, thus preventing it from optimizing away this code.
    int new_counter = 0;
    double elapsed = 0;
    clock_t start = clock();
    do {
      clock_t now = clock();
      elapsed = (double)(now - start) * 1e9 / CLOCKS_PER_SEC;
      ++new_counter;
    } while (elapsed < (double)delay);
    // Writing the count to memory gives the loop an observable result.
    *counter = new_counter;
  }
}

// Owns a single NCCL communicator together with the rank and world size it was
// created with; the destructor calls ncclCommDestroy on it.
// NOTE(review): copy construction/assignment are implicitly generated, so a
// copy would destroy the same communicator twice. In this file instances are
// only created with `new` and held by pointer in ManagedObjects.
class NcclCommWrapper {
 private:
  ncclComm_t comm;
  int rank, world_size;

  // Maps the tensor's ATen scalar type to the matching NCCL datatype.
  // NOTE(review): any other type -- including Bool, which the
  // AT_DISPATCH_ALL_TYPES_AND3 calls below list -- reaches `default`, which
  // only asserts. With NDEBUG the function then falls off the end without
  // returning a value (undefined behavior); a TORCH_CHECK would fail loudly.
  ncclDataType_t get_nccl_type(at::Tensor input) {
    switch (input.scalar_type()) {
      case at::ScalarType::Half:
        return ncclFloat16;
      case at::ScalarType::Float:
        return ncclFloat32;
      case at::ScalarType::Double:
        return ncclFloat64;
      case at::ScalarType::Byte:
        return ncclUint8;
      case at::ScalarType::Char:
        return ncclInt8;
      case at::ScalarType::Int:
        return ncclInt32;
      case at::ScalarType::Long:
        return ncclInt64;
      case at::ScalarType::BFloat16:
        return ncclBfloat16;
      default:
        assert(false);
    }
  }

 public:
  // Creates an empty wrapper whose communicator handle is zeroed.
  // NOTE(review): the destructor still calls ncclCommDestroy on that zeroed
  // handle.
  NcclCommWrapper() {
    memset(&comm, 0, sizeof(ncclComm_t));
    rank = 0;
    world_size = 0;
  }
  // Initializes `comm` as rank `my_rank` of `num_ranks` for the NCCL unique
  // id `id`.
  // NOTE(review): the ncclResult_t returned by ncclCommInitRank is ignored, so
  // an initialization failure is not reported here.
  NcclCommWrapper(ncclUniqueId id, int my_rank, int num_ranks) {
    ncclCommInitRank(&comm, num_ranks, id, my_rank);
    rank = my_rank;
    world_size = num_ranks;
  }
  // Prints a trace line and releases the communicator.
  ~NcclCommWrapper() {
    printf("ncclCommDestroy()\n");
    ncclCommDestroy(comm);
  }

  // Exchanges halos with both neighbors between ncclGroupStart/ncclGroupEnd,
  // on the current CUDA stream:
  //   - left_output_halo is sent to left_rank, left_input_halo is received
  //     from left_rank;
  //   - right_output_halo is sent to right_rank, right_input_halo is received
  //     from right_rank.
  // A negative neighbor rank means there is no neighbor on that side: the
  // matching input halo is zero-filled and nothing is sent or received there.
  // The NCCL datatype is derived from left_output_halo only.
  // NOTE(review): return codes of the NCCL calls are ignored, and the
  // element-count check is an assert (compiled out under NDEBUG).
  void left_right_halo_exchange_inplace(int left_rank, int right_rank, at::Tensor left_output_halo,
                                        at::Tensor right_output_halo, at::Tensor left_input_halo,
                                        at::Tensor right_input_halo) {
    auto stream = at::cuda::getCurrentCUDAStream();
    ncclGroupStart();
    ncclDataType_t ncclType = get_nccl_type(left_output_halo);
    bool left_zero = (left_rank < 0);
    bool right_zero = (right_rank < 0);
    size_t left_n = torch::numel(left_output_halo);
    size_t right_n = torch::numel(right_output_halo);
    assert(left_n > 0 && left_n == right_n);
    if (left_zero) {
      left_input_halo.zero_();
    } else {
      AT_DISPATCH_ALL_TYPES_AND3(
          at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half, left_output_halo.scalar_type(),
          "left_halo_exch", [&]() {
            // send left (to my_rank - 1)
            ncclSend(left_output_halo.data_ptr(), left_n, ncclType, left_rank, comm, stream);
            // receive left (from my_rank - 1)
            ncclRecv(left_input_halo.data_ptr(), right_n, ncclType, left_rank, comm, stream);
          });
    }
    if (right_zero) {
      right_input_halo.zero_();
    } else {
      AT_DISPATCH_ALL_TYPES_AND3(
          at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half, right_output_halo.scalar_type(),
          "right_halo_exch", [&]() {
            // send right (to my_rank + 1 )
            ncclSend(right_output_halo.data_ptr(), right_n, ncclType, right_rank, comm, stream);
            // receive right (from my_rank + 1)
            ncclRecv(right_input_halo.data_ptr(), left_n, ncclType, right_rank, comm, stream);
          });
    }
    ncclGroupEnd();
  }

  // Allocating variant of left_right_halo_exchange_inplace.
  // right_input_halo is allocated like left_output_halo and left_input_halo
  // like right_output_halo. Returns {left_input_halo, right_input_halo}.
  std::vector left_right_halo_exchange(int left_rank, int right_rank, at::Tensor left_output_halo,
                                       at::Tensor right_output_halo) {
    // after halo exchange:
    // left_output_halo of rank+1 ends up in right_input_halo of rank
    // right_output_halo of rank-1 ends up in left_input_halo of rank
    auto right_input_halo = torch::empty_like(left_output_halo);
    auto left_input_halo = torch::empty_like(right_output_halo);
    left_right_halo_exchange_inplace(left_rank, right_rank, left_output_halo, right_output_halo, left_input_halo,
                                     right_input_halo);
    return {left_input_halo, right_input_halo};
  }
};

// Registry that owns every NcclCommWrapper passed to add_comm.
// A handle is the wrapper's index in _nccl_comms, and the destructor deletes
// all registered wrappers.
class ManagedObjects {
 public:
  ManagedObjects() {}
  ~ManagedObjects() {
    for (auto it = _nccl_comms.begin(); it != _nccl_comms.end(); ++it) {
      delete *it;
    }
  }
  // Takes ownership of `comm` and returns its handle (its index).
  int add_comm(NcclCommWrapper* comm) {
    int handle = _nccl_comms.size();
    _nccl_comms.push_back(comm);
    return handle;
  }
  // Returns the wrapper registered under `handle`.
  // NOTE(review): the range check is an assert only, so an invalid handle
  // indexes out of bounds when built with NDEBUG.
  NcclCommWrapper& get_comm(int handle) {
    assert(handle >= 0 && handle < _nccl_comms.size());
    return *_nccl_comms[handle];
  }

 private:
  std::vector _nccl_comms;
};

// File-scope registry used by the nccl_p2p entry points below.
class ManagedObjects mo;
}  // end anonymous namespace

namespace apex {
namespace contrib {
namespace nccl_p2p {

// Returns a CPU uint8 tensor of shape {n, sizeof(ncclUniqueId)}. Row i holds
// the bytes of a freshly generated NCCL unique id.
// NOTE(review): the `id` generated at the top of the function is never used
// because the loop declares its own `id`, so ncclGetUniqueId runs n + 1 times.
at::Tensor get_unique_nccl_id(int n) {
  ncclUniqueId id;
  ncclGetUniqueId(&id);
  auto id_tensor = torch::empty({n, (int)sizeof(ncclUniqueId)},
                                torch::dtype(torch::kUInt8).device(torch::kCPU).requires_grad(false));
  auto id_ptr = id_tensor.data_ptr();
  size_t offset = 0;
  for (int i = 0; i < n; ++i) {
    ncclUniqueId id;
    ncclGetUniqueId(&id);
    memcpy(id_ptr + offset, &id, sizeof(ncclUniqueId));
    offset += sizeof(ncclUniqueId);
  }
  return id_tensor;
}

// Builds an ncclUniqueId from the first sizeof(ncclUniqueId) bytes of
// `unique_nccl_id` (presumably one row of get_unique_nccl_id's output --
// confirm with callers). It then creates a communicator for rank `my_rank`
// of `num_ranks`, registers it in `mo`, and returns its handle.
// NOTE(review): the tensor's byte size, dtype and device are not validated
// before the memcpy; a tensor smaller than sizeof(ncclUniqueId) bytes would
// be over-read.
int init_nccl_comm(at::Tensor unique_nccl_id, int my_rank, int num_ranks) {
  ncclUniqueId id;
  auto unique_nccl_id_ptr = unique_nccl_id.data_ptr();
  memcpy(&id, unique_nccl_id_ptr, sizeof(ncclUniqueId));
  NcclCommWrapper* comm = new NcclCommWrapper(id, my_rank, num_ranks);
  int handle = mo.add_comm(comm);
  // `mo` now owns the wrapper; this only clears the local pointer.
  comm = 0L;
  return handle;
}

// Looks up the communicator for `handle` and performs the in-place halo
// exchange with it.
void left_right_halo_exchange_inplace(int handle, int left_rank, int right_rank, at::Tensor left_output_halo,
                                      at::Tensor right_output_halo, at::Tensor left_input_halo,
                                      at::Tensor right_input_halo) {
  class NcclCommWrapper& communicator = mo.get_comm(handle);
  return communicator.left_right_halo_exchange_inplace(left_rank, right_rank, left_output_halo, right_output_halo,
                                                       left_input_halo, right_input_halo);
}

// Looks up the communicator for `handle` and returns the newly allocated
// {left_input_halo, right_input_halo}.
std::vector left_right_halo_exchange(int handle, int left_rank, int right_rank, at::Tensor left_output_halo,
                                     at::Tensor right_output_halo) {
  class NcclCommWrapper& communicator = mo.get_comm(handle);
  return communicator.left_right_halo_exchange(left_rank, right_rank, left_output_halo, right_output_halo);
}

// Launches AddDelay_kernel with a single thread on the current CUDA stream.
// The intent is presumably to hold back later work queued on that stream.
// The scratch tensor `t` only receives the kernel's iteration count.
void add_delay(int delay) {
  auto stream = at::cuda::getCurrentCUDAStream();
  auto t = torch::empty({1}, torch::dtype(torch::kInt32).device(torch::kCUDA));
  AddDelay_kernel<<<1, 1, 0, stream>>>(delay, t.data_ptr());
}

}  // namespace nccl_p2p
}  // namespace contrib
}  // namespace apex
================================================
FILE: apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cuh
================================================
/**
 * Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once
#include
#ifndef _nccl_p2p_h_
#define _nccl_p2p_h_
// Public entry points implemented in nccl_p2p_cuda.cu. Communicators are
// referred to by the integer handle returned from init_nccl_comm.
namespace apex {
namespace contrib {
namespace nccl_p2p {
at::Tensor get_unique_nccl_id(int n);
int init_nccl_comm(at::Tensor unique_nccl_id, int my_rank, int num_ranks);
void left_right_halo_exchange_inplace(int handle, int left_rank, int right_rank, at::Tensor left_output_halo,
                                      at::Tensor right_output_halo, at::Tensor left_input_halo,
                                      at::Tensor right_input_halo);
std::vector left_right_halo_exchange(int handle, int left_rank, int right_rank, at::Tensor left_output_halo,
                                     at::Tensor right_output_halo);
void add_delay(int delay);
}  // namespace nccl_p2p
}  // namespace contrib
}  // namespace apex
#endif
================================================
FILE: apex/contrib/csrc/nccl_p2p/nccl_version.cpp
================================================
// Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
// This file is used to check the version of NCCL detected.
#include
#include
// Defined in nccl_version_check.cu; returns {NCCL_MAJOR, NCCL_MINOR} as ints.
std::tuple get_nccl_version();
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("get_nccl_version", &get_nccl_version); }
================================================
FILE: apex/contrib/csrc/nccl_p2p/nccl_version_check.cu
================================================
// Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
// This file is used to check the version of NCCL detected.
#include #include std::tuple get_nccl_version() { return {int(NCCL_MAJOR), int(NCCL_MINOR)}; } ================================================ FILE: apex/contrib/csrc/optimizers/fused_adam_cuda.cpp ================================================ #include // CUDA forward declaration void fused_strided_check_finite(at::Tensor& overflow_flag, at::Tensor& p_copy, int stride, int clear_overflow_first); void fused_adam_cuda(at::Tensor& p, at::Tensor& p_copy, at::Tensor& m, at::Tensor& v, at::Tensor& g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay); void fused_reversible_adam_cuda(at::Tensor& p, at::Tensor& p_copy, at::Tensor& m, at::Tensor& v, at::Tensor& g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay); void fused_maybe_adam_undo_cuda(at::Tensor& overflow_flag, at::Tensor& p, at::Tensor& m, at::Tensor& v, at::Tensor& g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay); void fused_adam_cuda_mt(int chunk_size, at::Tensor overflow_flag, std::vector> tensor_lists, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay); void maybe_cast_cuda(at::Tensor& overflow_flag, at::Tensor& p_in, at::Tensor& p_out); void maybe_cast_cuda_mt(int chunk_size, at::Tensor overflow_flag, std::vector> tensor_lists); #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") #define CHECK_INPUT(x) \ CHECK_CUDA(x); \ CHECK_CONTIGUOUS(x) // C++ interface void strided_check_finite(at::Tensor& overflow_flag, at::Tensor& p_copy, int stride, int clear_overflow_first) { CHECK_INPUT(p_copy); fused_strided_check_finite(overflow_flag, p_copy, stride, clear_overflow_first); } void adam(at::Tensor& p, at::Tensor& p_copy, 
at::Tensor& m, at::Tensor& v, at::Tensor& g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay) { CHECK_INPUT(p); if (p_copy.numel() > 0) CHECK_INPUT(p_copy); CHECK_INPUT(m); CHECK_INPUT(v); CHECK_INPUT(g); int64_t num_elem = p.numel(); TORCH_CHECK(m.numel() == num_elem, "number of elements in m and p tensors should be equal"); TORCH_CHECK(v.numel() == num_elem, "number of elements in v and p tensors should be equal"); TORCH_CHECK(g.numel() == num_elem, "number of elements in g and p tensors should be equal"); TORCH_CHECK(p_copy.numel() == num_elem || p_copy.numel() == 0, "number of elements in p_copy and p tensors should be equal, or p_copy should be empty"); fused_adam_cuda(p, p_copy, m, v, g, lr, beta1, beta2, eps, grad_scale, step, mode, bias_correction, decay); } void reversible_adam(at::Tensor& p, at::Tensor& p_copy, at::Tensor& m, at::Tensor& v, at::Tensor& g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay) { CHECK_INPUT(p); if (p_copy.numel() > 0) CHECK_INPUT(p_copy); CHECK_INPUT(m); CHECK_INPUT(v); CHECK_INPUT(g); int64_t num_elem = p.numel(); TORCH_CHECK(m.numel() == num_elem, "number of elements in m and p tensors should be equal"); TORCH_CHECK(v.numel() == num_elem, "number of elements in v and p tensors should be equal"); TORCH_CHECK(g.numel() == num_elem, "number of elements in g and p tensors should be equal"); TORCH_CHECK(p_copy.numel() == num_elem || p_copy.numel() == 0, "number of elements in p_copy and p tensors should be equal, or p_copy should be empty"); fused_reversible_adam_cuda(p, p_copy, m, v, g, lr, beta1, beta2, eps, grad_scale, step, mode, bias_correction, decay); } void maybe_adam_undo(at::Tensor& overflow_flag, at::Tensor& p, at::Tensor& m, at::Tensor& v, at::Tensor& g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay) { 
CHECK_INPUT(p); CHECK_INPUT(m); CHECK_INPUT(v); CHECK_INPUT(g); int64_t num_elem = p.numel(); TORCH_CHECK(m.numel() == num_elem, "number of elements in m and p tensors should be equal"); TORCH_CHECK(v.numel() == num_elem, "number of elements in v and p tensors should be equal"); TORCH_CHECK(g.numel() == num_elem, "number of elements in g and p tensors should be equal"); fused_maybe_adam_undo_cuda(overflow_flag, p, m, v, g, lr, beta1, beta2, eps, grad_scale, step, mode, bias_correction, decay); } void maybe_cast(at::Tensor& overflow_flag, at::Tensor& p_in, at::Tensor& p_out) { CHECK_INPUT(p_in); CHECK_INPUT(p_out); int64_t num_elem = p_in.numel(); TORCH_CHECK(p_out.numel() == num_elem, "number of elements in p_in and p_out should be equal"); maybe_cast_cuda(overflow_flag, p_in, p_out); } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("strided_check_finite", &strided_check_finite, "Strided finite check.", py::call_guard()); m.def("adam", &adam, "Adam optimized CUDA implementation.", py::call_guard()); m.def("reversible_adam", &reversible_adam, "Reversible Adam optimized CUDA implementation.", py::call_guard()); m.def("adam_mt", &fused_adam_cuda_mt, "Multi tensor Adam optimized CUDA implementation.", py::call_guard()); m.def("maybe_adam_undo", &maybe_adam_undo, "Undo function for Adam optimized CUDA implementation.", py::call_guard()); m.def("maybe_cast", &maybe_cast, "Unpack byte tensor containing e5m2 floats.", py::call_guard()); m.def("maybe_cast_mt", &maybe_cast_cuda_mt, "Unpack byte tensor containing e5m2 floats.", py::call_guard()); } ================================================ FILE: apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu ================================================ #include #include #include #include #include "ATen/ATen.h" #include "ATen/TensorUtils.h" #include "ATen/cuda/CUDAContext.h" #include "ATen/cuda/detail/IndexUtils.cuh" // #include "ATen/Type.h" #include "ATen/AccumulateType.h" #include "multi_tensor_apply.cuh" #define 
BLOCK_SIZE 512 #define ILP 4 template __device__ __forceinline__ bool is_aligned(T* p) { return ((uint64_t)p) % (ILP * sizeof(T)) == 0; } template __device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset) { typedef typename std::aligned_storage::type LT; ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset]; } #include "type_shim.h" typedef enum { ADAM_MODE_0 = 0, // eps under square root ADAM_MODE_1 = 1 // eps outside square root } adamMode_t; template __global__ void adam_cuda_kernel(T* __restrict__ p, GRAD_T* __restrict__ p_copy, // For mixed precision training, pass NULL if not needed T* __restrict__ m, T* __restrict__ v, const GRAD_T* __restrict__ g, const float b1, const float b2, const float eps, const float grad_scale, const float step_size, const size_t tsize, adamMode_t mode, const float decay) { // Assuming 2D grids and 2D blocks const int blockId = gridDim.x * blockIdx.y + blockIdx.x; const int threadsPerBlock = blockDim.x * blockDim.y; const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x; const int i = (blockId * threadsPerBlock + threadIdInBlock); const int totThreads = gridDim.x * gridDim.y * threadsPerBlock; for (int j = i; j < tsize; j += totThreads) { T scaled_grad = g[j] / grad_scale; m[j] = b1 * m[j] + (1 - b1) * scaled_grad; v[j] = b2 * v[j] + (1 - b2) * scaled_grad * scaled_grad; float denom; if (mode == ADAM_MODE_0) denom = sqrtf(v[j] + eps); else // Mode 1 denom = sqrtf(v[j]) + eps; float update = (m[j] / denom) + (decay * p[j]); p[j] = p[j] - (step_size * update); if (p_copy != NULL) p_copy[j] = (GRAD_T)p[j]; } } template struct AdamFunctor { __device__ __forceinline__ void operator()(int chunk_size, volatile int* noop_gmem, TensorListMetadata& tl, const float b1, const float b2, const float eps, const float grad_scale, const float step_size, adamMode_t mode, const float decay) { int tensor_loc = tl.block_to_tensor[blockIdx.x]; int chunk_idx = tl.block_to_chunk[blockIdx.x]; int n = 
tl.sizes[tensor_loc]; T* p = (T*)tl.addresses[0][tensor_loc]; p += chunk_idx * chunk_size; T* m = (T*)tl.addresses[1][tensor_loc]; m += chunk_idx * chunk_size; T* v = (T*)tl.addresses[2][tensor_loc]; v += chunk_idx * chunk_size; GRAD_T* g = (GRAD_T*)tl.addresses[3][tensor_loc]; g += chunk_idx * chunk_size; GRAD_T* p_copy = NULL; if (DEPTH == 5) { p_copy = (GRAD_T*)tl.addresses[4][tensor_loc]; p_copy += chunk_idx * chunk_size; } n -= chunk_idx * chunk_size; T incoming_p[ILP]; T incoming_m[ILP]; T incoming_v[ILP]; T incoming_g[ILP]; // to make things simple, we put aligned case in a different code path if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(p) && is_aligned(m) && is_aligned(v) && is_aligned(g) && is_aligned(p_copy)) { for (int i_start = threadIdx.x; i_start * ILP < n && i_start * ILP < chunk_size; i_start += blockDim.x) { // load GRAD_T tmp_g[ILP]; load_store(incoming_p, p, 0, i_start); load_store(incoming_m, m, 0, i_start); load_store(incoming_v, v, 0, i_start); load_store(tmp_g, g, 0, i_start); #pragma unroll for (int ii = 0; ii < ILP; ii++) { incoming_g[ii] = static_cast(tmp_g[ii]); T scaled_grad = incoming_g[ii] / grad_scale; incoming_m[ii] = b1 * incoming_m[ii] + (1 - b1) * scaled_grad; incoming_v[ii] = b2 * incoming_v[ii] + (1 - b2) * scaled_grad * scaled_grad; float denom; if (mode == ADAM_MODE_0) denom = sqrtf(incoming_v[ii] + eps); else // Mode 1 denom = sqrtf(incoming_v[ii]) + eps; float update = (incoming_m[ii] / denom) + (decay * incoming_p[ii]); incoming_p[ii] = incoming_p[ii] - (step_size * update); if (DEPTH == 5) tmp_g[ii] = static_cast(incoming_p[ii]); } load_store(p, incoming_p, i_start, 0); load_store(m, incoming_m, i_start, 0); load_store(v, incoming_v, i_start, 0); if (DEPTH == 5) load_store(p_copy, tmp_g, i_start, 0); } } else { for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) { #pragma unroll for (int ii = 0; ii < ILP; ii++) { incoming_p[ii] = 0; incoming_m[ii] = 0; incoming_v[ii] = 0; 
incoming_g[ii] = 0; int i = i_start + threadIdx.x + ii * blockDim.x; if (i < n && i < chunk_size) { incoming_p[ii] = p[i]; incoming_m[ii] = m[i]; incoming_v[ii] = v[i]; incoming_g[ii] = static_cast(g[i]); } } // note for clarification to future michael: // From a pure memory dependency perspective, there's likely no point unrolling // the write loop, since writes just fire off once their LDGs arrive. // Put another way, the STGs are dependent on the LDGs, but not on each other. // There is still compute ILP benefit from unrolling the loop though. #pragma unroll for (int ii = 0; ii < ILP; ii++) { int j = i_start + threadIdx.x + ii * blockDim.x; if (j < n && j < chunk_size) { T scaled_grad = incoming_g[ii] / grad_scale; m[j] = b1 * incoming_m[ii] + (1 - b1) * scaled_grad; v[j] = b2 * incoming_v[ii] + (1 - b2) * scaled_grad * scaled_grad; float denom; if (mode == ADAM_MODE_0) denom = sqrtf(v[j] + eps); else // Mode 1 denom = sqrtf(v[j]) + eps; float update = (m[j] / denom) + (decay * incoming_p[ii]); p[j] = incoming_p[ii] - (step_size * update); if (DEPTH == 5) p_copy[j] = (GRAD_T)p[j]; } } } } } }; void fused_adam_cuda(at::Tensor& p, at::Tensor& p_copy, at::Tensor& m, at::Tensor& v, at::Tensor& g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay) { // using namespace at; // Get tensor size int tsize = p.numel(); // Determine #threads and #blocks const int threadsPerBlock = 512; const dim3 blocks((tsize + threadsPerBlock - 1) / threadsPerBlock); TORCH_CHECK(at::cuda::detail::canUse32BitIndexMath(p), "parameter tensor is too large to be indexed with int32"); // Constants float step_size = 0; if (bias_correction == 1) { const float bias_correction1 = 1 - std::pow(beta1, step); const float bias_correction2 = 1 - std::pow(beta2, step); step_size = lr * std::sqrt(bias_correction2) / bias_correction1; } else { step_size = lr; } cudaStream_t stream = at::cuda::getCurrentCUDAStream(); if (g.scalar_type() == 
at::ScalarType::Half) { // all other values should be fp32 for half gradients TORCH_CHECK(p.scalar_type() == at::ScalarType::Float, "expected parameter to be of float type"); // dispatch is done on the gradient type using namespace at; // prevents "toString is undefined" errors DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, "adam_cuda_kernel", using accscalar_t = at::acc_type; adam_cuda_kernel<<>>( p.data_ptr(), p_copy.numel() ? p_copy.data_ptr() : NULL, m.data_ptr(), v.data_ptr(), g.data_ptr(), beta1, beta2, eps, grad_scale, step_size, tsize, (adamMode_t)mode, decay);); } else { using namespace at; DISPATCH_DOUBLE_AND_FLOAT(g.scalar_type(), 0, "adam_cuda_kernel", adam_cuda_kernel<<>>( p.data_ptr(), NULL, // don't output p_copy for fp32, it's wasted write m.data_ptr(), v.data_ptr(), g.data_ptr(), beta1, beta2, eps, grad_scale, step_size, tsize, (adamMode_t)mode, decay);); } C10_CUDA_CHECK(cudaGetLastError()); } void fused_adam_cuda_mt(int chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, // p, m, v, g, p_copy float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay) { // Constants float step_size = 0; if (bias_correction == 1) { const float bias_correction1 = 1 - std::pow(beta1, step); const float bias_correction2 = 1 - std::pow(beta2, step); step_size = lr * std::sqrt(bias_correction2) / bias_correction1; } else { step_size = lr; } cudaStream_t stream = at::cuda::getCurrentCUDAStream(); size_t tl_sz = tensor_lists.size(); TORCH_CHECK(tl_sz == 4 || tl_sz == 5, "expected tensor lists of size 4 or 5"); if (tensor_lists[3][0].scalar_type() == at::ScalarType::Half) { // alher values should be fp32 for half gradients TORCH_CHECK(tensor_lists[0][0].scalar_type() == at::ScalarType::Float, "expected parameter to be of float type"); // dich is done on the gradient type if (tl_sz == 5) { DISPATCH_FLOAT_AND_HALF(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel", using accscalar_t = at::acc_type; 
multi_tensor_apply<5>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, AdamFunctor<5, accscalar_t, scalar_t_0>(), beta1, beta2, eps, grad_scale, step_size, (adamMode_t)mode, decay);); } else { DISPATCH_FLOAT_AND_HALF(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel", using accscalar_t = at::acc_type; multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, AdamFunctor<4, accscalar_t, scalar_t_0>(), beta1, beta2, eps, grad_scale, step_size, (adamMode_t)mode, decay);); } } else { if (tl_sz == 5) { DISPATCH_DOUBLE_AND_FLOAT(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel", multi_tensor_apply<5>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, AdamFunctor<5, scalar_t_0, scalar_t_0>(), beta1, beta2, eps, grad_scale, step_size, (adamMode_t)mode, decay);); } else { DISPATCH_DOUBLE_AND_FLOAT(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel", multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, AdamFunctor<4, scalar_t_0, scalar_t_0>(), beta1, beta2, eps, grad_scale, step_size, (adamMode_t)mode, decay);); } } C10_CUDA_CHECK(cudaGetLastError()); } template __device__ void convert(const FROM_T vi, TO_T& vo) { vo = static_cast(vi); } template <> __device__ void convert(const float vi, uint8_t& vo) { union S { float as_float; int as_int; }; S s; s.as_float = vi; s.as_int = s.as_int & 0xFF800000; union T { at::Half as_half; uint8_t as_byte[2]; }; T t; t.as_half = static_cast(vi + s.as_float / 8.0f); vo = t.as_byte[1]; } template <> __device__ void convert(const uint8_t vi, float& vo) { union T { at::Half as_half; uint8_t as_byte[2]; }; T t; t.as_byte[0] = 0; t.as_byte[1] = vi; vo = static_cast(t.as_half); } template <> __device__ void convert(const at::Half vi, uint8_t& vo) { union S { float as_float; int as_int; }; S s; s.as_float = static_cast(vi); s.as_int = s.as_int & 0xFF800000; union T { at::Half as_half; uint8_t as_byte[2]; }; T t; t.as_half = static_cast(vi + s.as_float / 8.0f); vo = t.as_byte[1]; } template <> 
__device__ void convert(const uint8_t vi, at::Half& vo) { union T { at::Half as_half; uint8_t as_byte[2]; }; T t; t.as_byte[0] = 0; t.as_byte[1] = vi; vo = t.as_half; } template __global__ void strided_check_finite_cuda_kernel(volatile int* noop_gmem, GRAD_T* __restrict__ p_copy, const size_t tsize, int stride, int clear_overflow_first) { // Assuming 2D grids and 2D blocks const int blockId = gridDim.x * blockIdx.y + blockIdx.x; const int threadsPerBlock = blockDim.x * blockDim.y; const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x; const int i = (blockId * threadsPerBlock + threadIdInBlock) * stride; const int totThreads = gridDim.x * gridDim.y * threadsPerBlock * stride; if (clear_overflow_first) { if (i == 0) { *noop_gmem = 0; } __syncthreads(); } for (int j = i; j < tsize; j += totThreads) { GRAD_T pi = p_copy[j]; if (!isfinite(pi)) { *noop_gmem = 1; } } } template <> __global__ void strided_check_finite_cuda_kernel(volatile int* noop_gmem, uint8_t* __restrict__ p_copy, const size_t tsize, int stride, int clear_overflow_first) { // Assuming 2D grids and 2D blocks const int blockId = gridDim.x * blockIdx.y + blockIdx.x; const int threadsPerBlock = blockDim.x * blockDim.y; const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x; const int i = (blockId * threadsPerBlock + threadIdInBlock) * stride; const int totThreads = gridDim.x * gridDim.y * threadsPerBlock * stride; if (clear_overflow_first) { if (i == 0) { *noop_gmem = 0; } __syncthreads(); } for (int j = i; j < tsize; j += totThreads) { at::Half pi; convert(p_copy[j], pi); if (!isfinite(pi)) { *noop_gmem = 1; } } } template __global__ void maybe_cast_kernel(volatile int* overflow_flag, const FROM_T* p_in, TO_T* p_out, const size_t tsize) { if (overflow_flag && *overflow_flag != 0) return; // Assuming 2D grids and 2D blocks const int blockId = gridDim.x * blockIdx.y + blockIdx.x; const int threadsPerBlock = blockDim.x * blockDim.y; const int threadIdInBlock = threadIdx.y * blockDim.x 
+ threadIdx.x; const int i = (blockId * threadsPerBlock + threadIdInBlock); const int totThreads = gridDim.x * gridDim.y * threadsPerBlock; FROM_T pi[ILP]; TO_T po[ILP]; for (int j_start = 0; j_start < tsize; j_start += totThreads * ILP) { #pragma unroll for (int ii = 0; ii < ILP; ii++) { pi[ii] = 0; int j = j_start + i + totThreads * ii; if (j < tsize) { pi[ii] = p_in[j]; } } #pragma unroll for (int ii = 0; ii < ILP; ii++) { convert(pi[ii], po[ii]); } #pragma unroll for (int ii = 0; ii < ILP; ii++) { int j = j_start + i + totThreads * ii; if (j < tsize) { p_out[j] = po[ii]; } } } } template __global__ void reversible_adam_cuda_kernel( T* __restrict__ p, REDU_T* __restrict__ p_copy, // For mixed precision training, pass NULL if not needed T* __restrict__ m, T* __restrict__ v, const GRAD_T* __restrict__ g, const float b1, const float b2, const float eps, const float grad_scale, const float step_size, const size_t tsize, adamMode_t mode, const float decay) { // Assuming 2D grids and 2D blocks const int blockId = gridDim.x * blockIdx.y + blockIdx.x; const int threadsPerBlock = blockDim.x * blockDim.y; const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x; const int i = (blockId * threadsPerBlock + threadIdInBlock); const int totThreads = gridDim.x * gridDim.y * threadsPerBlock; T mi[ILP]; T vi[ILP]; T pi[ILP]; T gi[ILP]; bool overflow = false; for (int j_start = 0; j_start < tsize; j_start += totThreads * ILP) { #pragma unroll for (int ii = 0; ii < ILP; ii++) { mi[ii] = T(0); vi[ii] = T(0); pi[ii] = T(0); gi[ii] = GRAD_T(0); int j = j_start + i + totThreads * ii; if (j < tsize) { pi[ii] = p[j]; mi[ii] = m[j]; vi[ii] = v[j]; gi[ii] = static_cast(g[j]); } } #pragma unroll for (int ii = 0; ii < ILP; ii++) { T scaled_grad = gi[ii] / grad_scale; if (isfinite(scaled_grad)) { mi[ii] = b1 * mi[ii] + (1 - b1) * scaled_grad; vi[ii] = b2 * vi[ii] + (1 - b2) * scaled_grad * scaled_grad; float denom; if (mode == ADAM_MODE_0) denom = sqrtf(vi[ii] + eps); else // Mode 1 
denom = sqrtf(vi[ii]) + eps; float update = (mi[ii] / denom) + (decay * pi[ii]); pi[ii] = pi[ii] - (step_size * update); } else { overflow = true; } } #pragma unroll for (int ii = 0; ii < ILP; ii++) { int j = j_start + i + totThreads * ii; if (j < tsize) { m[j] = mi[ii]; v[j] = vi[ii]; p[j] = pi[ii]; if (p_copy != NULL) { convert(pi[ii], p_copy[j]); } } } } if (p_copy != NULL) { __syncthreads(); if (overflow) { convert(float(INFINITY), p_copy[0]); } } } template __global__ void maybe_adam_undo_cuda_kernel(volatile int* overflow_flag, T* __restrict__ p, T* __restrict__ m, T* __restrict__ v, const GRAD_T* __restrict__ g, const float b1, const float b2, const float eps, const float grad_scale, const float step_size, const size_t tsize, adamMode_t mode, const float decay) { // NB! Skip undo kernel when overflow flag is NOT set if (overflow_flag && *overflow_flag == 0) return; // Assuming 2D grids and 2D blocks const int blockId = gridDim.x * blockIdx.y + blockIdx.x; const int threadsPerBlock = blockDim.x * blockDim.y; const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x; const int i = (blockId * threadsPerBlock + threadIdInBlock); const int totThreads = gridDim.x * gridDim.y * threadsPerBlock; T mi[ILP]; T vi[ILP]; T pi[ILP]; T gi[ILP]; for (int j_start = 0; j_start < tsize; j_start += totThreads * ILP) { #pragma unroll for (int ii = 0; ii < ILP; ii++) { mi[ii] = T(0); vi[ii] = T(0); pi[ii] = T(0); gi[ii] = GRAD_T(0); int j = j_start + i * ILP; if (j < tsize) { pi[ii] = p[j]; mi[ii] = m[j]; vi[ii] = v[j]; gi[ii] = static_cast(g[j]); } } #pragma unroll for (int ii = 0; ii < ILP; ii++) { T scaled_grad = gi[ii] / grad_scale; if (isfinite(scaled_grad)) { float denom; if (mode == ADAM_MODE_0) denom = sqrtf(vi[ii] + eps); else // Mode 1 denom = sqrtf(vi[ii]) + eps; pi[ii] = (pi[ii] + step_size * (mi[ii] / denom)) / (1.0f - step_size * decay); mi[ii] = (mi[ii] - (1 - b1) * scaled_grad) / b1; vi[ii] = (vi[ii] - (1 - b2) * scaled_grad * scaled_grad) / b2; // Make 
sure round off errors don't create (small) negative value. // This can happen if we have to revert the very first step. vi[ii] = vi[ii] >= 0.0f ? vi[ii] : 0.0f; } } #pragma unroll for (int ii = 0; ii < ILP; ii++) { int j = j_start + i * ILP; if (j < tsize) { m[j] = mi[ii]; v[j] = vi[ii]; p[j] = pi[ii]; } } } } template struct MaybeCastFunctor { __device__ __forceinline__ void operator()(int chunk_size, volatile int* overflow_flag, TensorListMetadata& tl) { if (overflow_flag && *overflow_flag != 0) return; int tensor_loc = tl.block_to_tensor[blockIdx.x]; int chunk_idx = tl.block_to_chunk[blockIdx.x]; int n = tl.sizes[tensor_loc]; FROM_T* p_in = (FROM_T*)tl.addresses[0][tensor_loc]; p_in += chunk_idx * chunk_size; TO_T* p_out = (TO_T*)tl.addresses[1][tensor_loc]; p_out += chunk_idx * chunk_size; n -= chunk_idx * chunk_size; int dim = chunk_size < n ? chunk_size : n; FROM_T pi[ILP]; TO_T po[ILP]; for (int j_start = 0; j_start < dim; j_start += blockDim.x * ILP) { #pragma unroll for (int ii = 0; ii < ILP; ii++) { pi[ii] = FROM_T(0); int j = j_start + threadIdx.x + ii * blockDim.x; if (j < dim) { pi[ii] = p_in[j]; } } #pragma unroll for (int ii = 0; ii < ILP; ii++) { convert(pi[ii], po[ii]); } #pragma unroll for (int ii = 0; ii < ILP; ii++) { int j = j_start + threadIdx.x + ii * blockDim.x; if (j < dim) { p_out[j] = po[ii]; } } } } }; void fused_strided_check_finite(at::Tensor& overflow_flag, at::Tensor& p_copy, int stride, int clear_overflow_first) { // Get tensor size int tsize = p_copy.numel(); int niter = (tsize + stride - 1) / stride; // Determine #threads and #blocks const int threadsPerBlock = 512; // In order to avoid race condition, blocks must be 1 when clear_overflow_first flag is set. const dim3 blocks(clear_overflow_first ? 
1 : (niter + threadsPerBlock - 1) / threadsPerBlock); TORCH_CHECK(at::cuda::detail::canUse32BitIndexMath(p_copy), "parameter tensor is too large to be indexed with int32"); cudaStream_t stream = at::cuda::getCurrentCUDAStream(); using namespace at; // prevents "toString is undefined" errors DISPATCH_FLOAT_HALF_AND_BYTE( p_copy.scalar_type(), 0, "check_finite_cuda_kernel", strided_check_finite_cuda_kernel<<>>( overflow_flag.data_ptr(), p_copy.data_ptr(), tsize, stride, clear_overflow_first);); C10_CUDA_CHECK(cudaGetLastError()); } void fused_reversible_adam_cuda(at::Tensor& p, at::Tensor& p_copy, at::Tensor& m, at::Tensor& v, at::Tensor& g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay) { // using namespace at; // Get tensor size int tsize = p.numel(); // Determine #threads and #blocks const int threadsPerBlock = 512; const dim3 blocks((tsize + threadsPerBlock - 1) / threadsPerBlock); TORCH_CHECK(at::cuda::detail::canUse32BitIndexMath(p), "parameter tensor is too large to be indexed with int32"); // Constants float step_size = 0; if (bias_correction == 1) { const float bias_correction1 = 1 - std::pow(beta1, step); const float bias_correction2 = 1 - std::pow(beta2, step); step_size = lr * std::sqrt(bias_correction2) / bias_correction1; } else { step_size = lr; } cudaStream_t stream = at::cuda::getCurrentCUDAStream(); if (g.scalar_type() == at::ScalarType::Half) { // all other values should be fp32 for half gradients TORCH_CHECK(p.scalar_type() == at::ScalarType::Float, "expected parameter to be of float type"); // dispatch is done on the gradient type using namespace at; // prevents "toString is undefined" errors if (p_copy.numel() == 0 || p_copy.scalar_type() == g.scalar_type()) { DISPATCH_FLOAT_AND_HALF( g.scalar_type(), 0, "adam_cuda_kernel", using accscalar_t = at::acc_type; reversible_adam_cuda_kernel<<>>( p.data_ptr(), p_copy.numel() ? 
p_copy.data_ptr() : NULL, m.data_ptr(), v.data_ptr(), g.data_ptr(), beta1, beta2, eps, grad_scale, step_size, tsize, (adamMode_t)mode, decay);); } else { TORCH_CHECK(p_copy.scalar_type() == at::ScalarType::Byte, "expected parameter to be of byte type"); DISPATCH_FLOAT_AND_HALF( g.scalar_type(), 0, "adam_cuda_e5m2_kernel", using accscalar_t = at::acc_type; reversible_adam_cuda_kernel<<>>( p.data_ptr(), p_copy.data_ptr(), m.data_ptr(), v.data_ptr(), g.data_ptr(), beta1, beta2, eps, grad_scale, step_size, tsize, (adamMode_t)mode, decay);); } } else { using namespace at; DISPATCH_DOUBLE_AND_FLOAT(g.scalar_type(), 0, "adam_cuda_kernel", reversible_adam_cuda_kernel <<>>( p.data_ptr(), NULL, // don't output p_copy for fp32, it's wasted write m.data_ptr(), v.data_ptr(), g.data_ptr(), beta1, beta2, eps, grad_scale, step_size, tsize, (adamMode_t)mode, decay);); } C10_CUDA_CHECK(cudaGetLastError()); } void maybe_cast_cuda(at::Tensor& overflow_flag, at::Tensor& p_in, at::Tensor& p_out) { // Get tensor size int tsize = p_in.numel(); TORCH_CHECK(tsize == p_out.numel(), "p_in.numel() must equal p_out.numel()"); // Determine #threads and #blocks const int threadsPerBlock = 512; const dim3 blocks((tsize + threadsPerBlock - 1) / threadsPerBlock); TORCH_CHECK(at::cuda::detail::canUse32BitIndexMath(p_in), "parameter tensor is too large to be indexed with int32"); // Constants cudaStream_t stream = at::cuda::getCurrentCUDAStream(); DISPATCH_FLOAT_HALF_AND_BYTE(p_in.scalar_type(), 0, "maybe_cast_cuda" DISPATCH_FLOAT_HALF_AND_BYTE( p_out.scalar_type(), 1, "maybe_cast_cuda", maybe_cast_kernel<<>>( overflow_flag.numel() ? 
overflow_flag.data_ptr() : NULL, p_in.data_ptr(), p_out.data_ptr(), tsize);)) C10_CUDA_CHECK(cudaGetLastError()); } void maybe_cast_cuda_mt(int chunk_size, at::Tensor overflow_flag, std::vector> tensor_lists) // p_in, p_out { // Constants cudaStream_t stream = at::cuda::getCurrentCUDAStream(); size_t tl_sz = tensor_lists.size(); TORCH_CHECK(tl_sz == 2, "expected tensor lists of size 2"); DISPATCH_FLOAT_HALF_AND_BYTE( tensor_lists[0][0].scalar_type(), 0, "maybe_cast_cuda_mt_kernel", DISPATCH_FLOAT_HALF_AND_BYTE(tensor_lists[1][0].scalar_type(), 1, "maybe_cast_cuda_mt_kernel", multi_tensor_apply<2>(BLOCK_SIZE, chunk_size, overflow_flag, tensor_lists, MaybeCastFunctor<2, scalar_t_0, scalar_t_1>());)) C10_CUDA_CHECK(cudaGetLastError()); } void fused_maybe_adam_undo_cuda(at::Tensor& overflow_flag, at::Tensor& p, at::Tensor& m, at::Tensor& v, at::Tensor& g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay) { // Get tensor size int tsize = p.numel(); // Determine #threads and #blocks const int threadsPerBlock = 512; const dim3 blocks((tsize + threadsPerBlock - 1) / threadsPerBlock); TORCH_CHECK(at::cuda::detail::canUse32BitIndexMath(p), "parameter tensor is too large to be indexed with int32"); // Constants float step_size = 0; if (bias_correction == 1) { const float bias_correction1 = 1 - std::pow(beta1, step); const float bias_correction2 = 1 - std::pow(beta2, step); step_size = lr * std::sqrt(bias_correction2) / bias_correction1; } else { step_size = lr; } cudaStream_t stream = at::cuda::getCurrentCUDAStream(); if (g.scalar_type() == at::ScalarType::Half) { // all other values should be fp32 for half gradients TORCH_CHECK(p.scalar_type() == at::ScalarType::Float, "expected parameter to be of float type"); // dispatch is done on the gradient type using namespace at; // prevents "toString is undefined" errors DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, "adam_cuda_kernel", using accscalar_t = 
at::acc_type; maybe_adam_undo_cuda_kernel <<>>( overflow_flag.numel() ? overflow_flag.data_ptr() : NULL, p.data_ptr(), m.data_ptr(), v.data_ptr(), g.data_ptr(), beta1, beta2, eps, grad_scale, step_size, tsize, (adamMode_t)mode, decay);); } else { using namespace at; DISPATCH_DOUBLE_AND_FLOAT( g.scalar_type(), 0, "adam_cuda_kernel", maybe_adam_undo_cuda_kernel<<>>( overflow_flag.numel() ? overflow_flag.data_ptr() : NULL, p.data_ptr(), m.data_ptr(), v.data_ptr(), g.data_ptr(), beta1, beta2, eps, grad_scale, step_size, tsize, (adamMode_t)mode, decay);); } C10_CUDA_CHECK(cudaGetLastError()); } ================================================ FILE: apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp ================================================ #include void multi_tensor_lamb_cuda(int chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, const float lr, const float beta1, const float beta2, const float epsilon, const int step, const int bias_correction, const float weight_decay, const int grad_averaging, const int mode, const float global_grad_norm, const float max_grad_norm); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("lamb", &multi_tensor_lamb_cuda, "Computes and apply update for LAMB optimizer", py::call_guard()); } ================================================ FILE: apex/contrib/csrc/optimizers/fused_lamb_cuda_kernel.cu ================================================ #include #include #include #include // Another possibility: // #include #include #include "multi_tensor_apply.cuh" #include "type_shim.h" #define BLOCK_SIZE 512 #define ILP 4 typedef enum { MOMENT_MODE_0 = 0, // L2 regularization mode MOMENT_MODE_1 = 1 // Decoupled weight decay mode } adamMode_t; std::tuple multi_tensor_l2norm_cuda(int chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, at::optional per_tensor_python); using MATH_T = float; template struct LAMBStage1Functor { __device__ __forceinline__ void operator()(int chunk_size, volatile int* noop_gmem, 
TensorListMetadata<4>& tl, const float beta1, const float beta2, const float beta3, const float beta1_correction, const float beta2_correction, const float epsilon, adamMode_t mode, const float decay, const float global_grad_norm, const float max_global_grad_norm) { // I'd like this kernel to propagate infs/nans. // if(*noop_gmem == 1) // return; int tensor_loc = tl.block_to_tensor[blockIdx.x]; int chunk_idx = tl.block_to_chunk[blockIdx.x]; int n = tl.sizes[tensor_loc]; float clipped_global_grad_norm = global_grad_norm > max_global_grad_norm ? global_grad_norm / max_global_grad_norm : 1.0f; T* g = (T*)tl.addresses[0][tensor_loc]; g += chunk_idx * chunk_size; T* p = (T*)tl.addresses[1][tensor_loc]; p += chunk_idx * chunk_size; T* m = (T*)tl.addresses[2][tensor_loc]; m += chunk_idx * chunk_size; T* v = (T*)tl.addresses[3][tensor_loc]; v += chunk_idx * chunk_size; n -= chunk_idx * chunk_size; // see note in multi_tensor_scale_kernel.cu for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) { MATH_T r_g[ILP]; MATH_T r_p[ILP]; MATH_T r_m[ILP]; MATH_T r_v[ILP]; #pragma unroll for (int ii = 0; ii < ILP; ii++) { int i = i_start + threadIdx.x + ii * blockDim.x; if (i < n && i < chunk_size) { r_g[ii] = g[i]; // special ?optimization? 
for lamb stage 1 if (decay == 0) { r_p[ii] = MATH_T(0); } else { r_p[ii] = p[i]; } r_m[ii] = m[i]; r_v[ii] = v[i]; } else { r_g[ii] = MATH_T(0); r_p[ii] = MATH_T(0); r_m[ii] = MATH_T(0); r_v[ii] = MATH_T(0); } } #pragma unroll for (int ii = 0; ii < ILP; ii++) { if (mode == MOMENT_MODE_0) { MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm; // L2 on scaled grad scaled_grad = scaled_grad + decay * r_p[ii]; r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad; r_v[ii] = r_v[ii] * beta2 + (1 - beta2) * scaled_grad * scaled_grad; MATH_T next_m_unbiased = r_m[ii] / beta1_correction; MATH_T next_v_unbiased = r_v[ii] / beta2_correction; MATH_T denom = sqrtf(next_v_unbiased) + epsilon; r_p[ii] = next_m_unbiased / denom; } else { MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm; r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad; r_v[ii] = r_v[ii] * beta2 + (1 - beta2) * scaled_grad * scaled_grad; MATH_T next_m_unbiased = r_m[ii] / beta1_correction; MATH_T next_v_unbiased = r_v[ii] / beta2_correction; MATH_T denom = sqrtf(next_v_unbiased) + epsilon; r_p[ii] = (next_m_unbiased / denom) + (decay * r_p[ii]); } } #pragma unroll for (int ii = 0; ii < ILP; ii++) { int i = i_start + threadIdx.x + ii * blockDim.x; if (i < n && i < chunk_size) { g[i] = r_p[ii]; m[i] = r_m[ii]; v[i] = r_v[ii]; } } } } }; // Step 2 reads in 'update' value and per-tensor param_norm and update_norm. // It computes new parameter value. template struct LAMBStage2Functor { __device__ __forceinline__ void operator()(int chunk_size, volatile int* noop_gmem, TensorListMetadata<2>& tl, const float* per_tensor_param_norm, const float* per_tensor_update_norm, const float learning_rate, const float decay) { // I'd like this kernel to propagate infs/nans. 
// if(*noop_gmem == 1) // return; int tensor_loc = tl.block_to_tensor[blockIdx.x]; int tensor_num = tl.start_tensor_this_launch + tensor_loc; int chunk_idx = tl.block_to_chunk[blockIdx.x]; int n = tl.sizes[tensor_loc]; MATH_T ratio = learning_rate; // apply adaptive learning rate to parameters with non-zero weight decay if (decay != 0.0) { float param_norm = per_tensor_param_norm[tensor_num]; float update_norm = per_tensor_update_norm[tensor_num]; ratio = (update_norm != 0.0f && param_norm != 0.0f) ? learning_rate * (param_norm / update_norm) : learning_rate; } T* update = (T*)tl.addresses[0][tensor_loc]; update += chunk_idx * chunk_size; T* p = (T*)tl.addresses[1][tensor_loc]; p += chunk_idx * chunk_size; n -= chunk_idx * chunk_size; for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) { MATH_T r_p[ILP]; MATH_T r_update[ILP]; #pragma unroll for (int ii = 0; ii < ILP; ii++) { int i = i_start + threadIdx.x + ii * blockDim.x; if (i < n && i < chunk_size) { r_p[ii] = p[i]; r_update[ii] = update[i]; } } #pragma unroll for (int ii = 0; ii < ILP; ii++) { r_p[ii] = r_p[ii] - (ratio * r_update[ii]); } #pragma unroll for (int ii = 0; ii < ILP; ii++) { int i = i_start + threadIdx.x + ii * blockDim.x; if (i < n && i < chunk_size) { p[i] = r_p[ii]; } } } } }; void multi_tensor_lamb_cuda(int chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, const float lr, const float beta1, const float beta2, const float epsilon, const int step, const int bias_correction, const float weight_decay, const int grad_averaging, const int mode, const float global_grad_norm, const float max_grad_norm) { using namespace at; // Master weight and 32bit momentum(potentially changing) is not handled by this // So we assume every tensor are all in the same type // Handle bias correction mode float bias_correction1 = 1.0f, bias_correction2 = 1.0f; if (bias_correction == 1) { bias_correction1 = 1 - std::pow(beta1, step); bias_correction2 = 1 - std::pow(beta2, 
step); } // Handle grad averaging mode float beta3 = 1.0f; if (grad_averaging == 1) beta3 = 1 - beta1; std::vector> grad_list(tensor_lists.begin(), tensor_lists.begin() + 1); std::vector> param_list(tensor_lists.begin() + 1, tensor_lists.begin() + 2); // Compute per tensor param norm auto param_norm_tuple = multi_tensor_l2norm_cuda(chunk_size, noop_flag, param_list, true); // We now in-place modify grad to store update before compute its norm // Generally this is not a issue since people modify grad in step() method all the time // We can also grab list of empty tensor to avoid this, but I'd like to save space/cpu code DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1", multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, LAMBStage1Functor(), beta1, beta2, beta3, // 1-beta1 or 1 depends on averaging mode bias_correction1, bias_correction2, epsilon, (adamMode_t)mode, weight_decay, global_grad_norm, max_grad_norm);) // Compute update norms auto update_norm_tuple = multi_tensor_l2norm_cuda(chunk_size, noop_flag, grad_list, true); std::vector> grad_param_list(tensor_lists.begin(), tensor_lists.begin() + 2); DISPATCH_FLOAT_AND_HALF( tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2", multi_tensor_apply<2>(BLOCK_SIZE, chunk_size, noop_flag, grad_param_list, LAMBStage2Functor(), std::get<1>(param_norm_tuple).data_ptr(), std::get<1>(update_norm_tuple).data_ptr(), lr, weight_decay);) AT_CUDA_CHECK(cudaGetLastError()); } ================================================ FILE: apex/contrib/csrc/optimizers/multi_tensor_distopt_adam.cpp ================================================ #include void multi_tensor_fused_adam_cuda(int chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, at::Tensor grad_scale, float lr, float beta1, float beta2, float eps, int step, int mode, int bias_correction, float weight_decay); void multi_tensor_fused_adam_capturable_cuda(int chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, 
at::Tensor grad_scale, at::Tensor lr, float beta1, float beta2, float eps, at::Tensor step, int mode, int bias_correction, float weight_decay); void multi_tensor_fused_adam_with_param_remainders_cuda(int chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, at::Tensor grad_scale, float lr, float beta1, float beta2, float eps, int step, int mode, int bias_correction, float weight_decay); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("multi_tensor_fused_adam", &multi_tensor_fused_adam_cuda, "CUDA kernels for multi-tensor Adam, " "with param copy", py::call_guard()); m.def("multi_tensor_fused_adam_capturable", &multi_tensor_fused_adam_capturable_cuda, "CUDA kernels for multi-tensor Adam, " "with param copy, capturable for CUDA graph", py::call_guard()); m.def("multi_tensor_fused_adam_with_param_remainders", &multi_tensor_fused_adam_with_param_remainders_cuda, "CUDA kernel for multi-tensor Adam, " "with stored param remainders and param copy", py::call_guard()); } ================================================ FILE: apex/contrib/csrc/optimizers/multi_tensor_distopt_adam_kernel.cu ================================================ #include #include #include #include // Another possibility: // #include #include #include #include "multi_tensor_apply.cuh" #include "type_shim.h" #define BLOCK_SIZE 512 #define ILP 4 template __device__ __forceinline__ bool is_aligned(const T* p) { return ((uint64_t)p) % (ILP * sizeof(T)) == 0; } template __device__ __forceinline__ void load_store(T* dst, const T* src, int dst_offset = 0, int src_offset = 0) { typedef typename std::aligned_storage::type LT; ((LT*)dst)[dst_offset] = ((const LT*)src)[src_offset]; } // (1-t)*x + t*y // Note: Named _lerp to avoid ambiguity with std::lerp under C++20. 
__device__ __forceinline__ float _lerp(float t, float x, float y) { // See https://developer.nvidia.com/blog/lerp-faster-cuda/ return fma(t, y, fma(-t, x, x)); } typedef enum { ADAM_MODE_0 = 0, // L2 regularization mode ADAM_MODE_1 = 1 // Decoupled weight decay mode(AdamW) } adamMode_t; /* Multi-tensor Adam * * Updates params in-place and outputs a copy with a desired datatype. */ template struct DistAdamFunctor { // Vectorized local compute __device__ __forceinline__ static void local_step(T p[ILP], T m[ILP], T v[ILP], const GRAD_T g[ILP], const float grad_scale, const float beta1, const float beta2, const float beta1_correction, const float beta2_correction, const float eps, const float lr, adamMode_t mode, const float weight_decay) { if (mode == ADAM_MODE_0) { // L2 #pragma unroll for (int ii = 0; ii < ILP; ii++) { float scaled_grad = (g[ii] * grad_scale) + (weight_decay * p[ii]); float next_m = _lerp(beta1, scaled_grad, m[ii]); float next_v = _lerp(beta2, scaled_grad * scaled_grad, v[ii]); float next_m_unbiased = next_m / beta1_correction; float next_v_unbiased = next_v / beta2_correction; float denom = sqrtf(next_v_unbiased) + eps; float update = next_m_unbiased / denom; m[ii] = next_m; v[ii] = next_v; p[ii] -= lr * update; } } else { // weight decay #pragma unroll for (int ii = 0; ii < ILP; ii++) { float scaled_grad = g[ii] * grad_scale; float next_m = _lerp(beta1, scaled_grad, m[ii]); float next_v = _lerp(beta2, scaled_grad * scaled_grad, v[ii]); float next_m_unbiased = next_m / beta1_correction; float next_v_unbiased = next_v / beta2_correction; float denom = sqrtf(next_v_unbiased) + eps; float update = (next_m_unbiased / denom) + (weight_decay * p[ii]); m[ii] = next_m; v[ii] = next_v; p[ii] -= lr * update; } } } __device__ __forceinline__ void operator()(int chunk_size, volatile int* noop_gmem, TensorListMetadata<5>& tl, const float* grad_scale_ptr, const float beta1, const float beta2, const float beta1_correction, const float beta2_correction, const 
float eps, const float lr, adamMode_t mode, const float weight_decay) const { int tensor_loc = tl.block_to_tensor[blockIdx.x]; int chunk_idx = tl.block_to_chunk[blockIdx.x]; int n = tl.sizes[tensor_loc]; const float grad_scale = *grad_scale_ptr; T* p_in = (T*)tl.addresses[0][tensor_loc]; p_in += chunk_idx * chunk_size; T* m = (T*)tl.addresses[1][tensor_loc]; m += chunk_idx * chunk_size; T* v = (T*)tl.addresses[2][tensor_loc]; v += chunk_idx * chunk_size; const GRAD_T* g = (GRAD_T*)tl.addresses[3][tensor_loc]; g += chunk_idx * chunk_size; PARAM_OUT_T* p_out = (PARAM_OUT_T*)tl.addresses[4][tensor_loc]; p_out += chunk_idx * chunk_size; n -= chunk_idx * chunk_size; n = chunk_size < n ? chunk_size : n; const bool aligned = (n % ILP == 0 && is_aligned(p_in) && is_aligned(m) && is_aligned(v) && is_aligned(g) && is_aligned(p_out)); for (int i_start = threadIdx.x * ILP; i_start < n; i_start += blockDim.x * ILP) { T local_p[ILP]; T local_m[ILP]; T local_v[ILP]; GRAD_T local_g[ILP]; PARAM_OUT_T local_p_out[ILP]; // Load if (aligned) { load_store(local_p, p_in + i_start); load_store(local_m, m + i_start); load_store(local_v, v + i_start); load_store(local_g, g + i_start); } else { #pragma unroll for (int ii = 0, i = i_start; ii < ILP; ii++, i++) { if (i < n) { local_p[ii] = p_in[i]; local_m[ii] = m[i]; local_v[ii] = v[i]; local_g[ii] = g[i]; } else { local_p[ii] = 0; local_m[ii] = 0; local_v[ii] = 0; local_g[ii] = 0; } } } // Local compute local_step(local_p, local_m, local_v, local_g, grad_scale, beta1, beta2, beta1_correction, beta2_correction, eps, lr, mode, weight_decay); #pragma unroll for (int ii = 0; ii < ILP; ii++) { local_p_out[ii] = static_cast(local_p[ii]); } // Store if (aligned) { load_store(p_in + i_start, local_p); load_store(m + i_start, local_m); load_store(v + i_start, local_v); load_store(p_out + i_start, local_p_out); } else { #pragma unroll for (int ii = 0, i = i_start; ii < ILP; ii++, i++) { if (i < n) { p_in[i] = local_p[ii]; m[i] = local_m[ii]; v[i] = 
local_v[ii]; p_out[i] = local_p_out[ii]; } } } } } }; /* Multi-tensor Adam with CUDA Graph Support * * Updates params in-place and outputs a copy with a desired datatype. */ template struct DistAdamCapturableFunctor { // Vectorized local compute __device__ __forceinline__ static void local_step(T p[ILP], T m[ILP], T v[ILP], const GRAD_T g[ILP], const float grad_scale, const float beta1, const float beta2, const float beta1_correction, const float beta2_correction, const float eps, const float lr, adamMode_t mode, const float weight_decay) { if (mode == ADAM_MODE_0) { // L2 #pragma unroll for (int ii = 0; ii < ILP; ii++) { float scaled_grad = (g[ii] * grad_scale) + (weight_decay * p[ii]); float next_m = _lerp(beta1, scaled_grad, m[ii]); float next_v = _lerp(beta2, scaled_grad * scaled_grad, v[ii]); float next_m_unbiased = next_m / beta1_correction; float next_v_unbiased = next_v / beta2_correction; float denom = sqrtf(next_v_unbiased) + eps; float update = next_m_unbiased / denom; m[ii] = next_m; v[ii] = next_v; p[ii] -= lr * update; } } else { // weight decay #pragma unroll for (int ii = 0; ii < ILP; ii++) { float scaled_grad = g[ii] * grad_scale; float next_m = _lerp(beta1, scaled_grad, m[ii]); float next_v = _lerp(beta2, scaled_grad * scaled_grad, v[ii]); float next_m_unbiased = next_m / beta1_correction; float next_v_unbiased = next_v / beta2_correction; float denom = sqrtf(next_v_unbiased) + eps; float update = (next_m_unbiased / denom) + (weight_decay * p[ii]); m[ii] = next_m; v[ii] = next_v; p[ii] -= lr * update; } } } __device__ __forceinline__ void operator()(int chunk_size, volatile int* noop_gmem, TensorListMetadata<5>& tl, const float* grad_scale_ptr, const float beta1, const float beta2, const int* step, const int bias_correction, const float eps, const float* lr, adamMode_t mode, const float weight_decay) const { assert(noop_gmem); assert(grad_scale_ptr); assert(step); assert(lr); if (*noop_gmem == 1) return; float beta1_correction = 1.0f, 
beta2_correction = 1.0f; if (bias_correction == 1) { beta1_correction = 1 - pow(beta1, *step); beta2_correction = 1 - pow(beta2, *step); } int tensor_loc = tl.block_to_tensor[blockIdx.x]; int chunk_idx = tl.block_to_chunk[blockIdx.x]; int n = tl.sizes[tensor_loc]; const float grad_scale = *grad_scale_ptr; T* p_in = (T*)tl.addresses[0][tensor_loc]; p_in += chunk_idx * chunk_size; T* m = (T*)tl.addresses[1][tensor_loc]; m += chunk_idx * chunk_size; T* v = (T*)tl.addresses[2][tensor_loc]; v += chunk_idx * chunk_size; const GRAD_T* g = (GRAD_T*)tl.addresses[3][tensor_loc]; g += chunk_idx * chunk_size; PARAM_OUT_T* p_out = (PARAM_OUT_T*)tl.addresses[4][tensor_loc]; p_out += chunk_idx * chunk_size; n -= chunk_idx * chunk_size; n = chunk_size < n ? chunk_size : n; const bool aligned = (n % ILP == 0 && is_aligned(p_in) && is_aligned(m) && is_aligned(v) && is_aligned(g) && is_aligned(p_out)); for (int i_start = threadIdx.x * ILP; i_start < n; i_start += blockDim.x * ILP) { T local_p[ILP]; T local_m[ILP]; T local_v[ILP]; GRAD_T local_g[ILP]; PARAM_OUT_T local_p_out[ILP]; // Load if (aligned) { load_store(local_p, p_in + i_start); load_store(local_m, m + i_start); load_store(local_v, v + i_start); load_store(local_g, g + i_start); } else { #pragma unroll for (int ii = 0, i = i_start; ii < ILP; ii++, i++) { if (i < n) { local_p[ii] = p_in[i]; local_m[ii] = m[i]; local_v[ii] = v[i]; local_g[ii] = g[i]; } else { local_p[ii] = 0; local_m[ii] = 0; local_v[ii] = 0; local_g[ii] = 0; } } } // Local compute local_step(local_p, local_m, local_v, local_g, grad_scale, beta1, beta2, beta1_correction, beta2_correction, eps, *lr, mode, weight_decay); #pragma unroll for (int ii = 0; ii < ILP; ii++) { local_p_out[ii] = static_cast(local_p[ii]); } // Store if (aligned) { load_store(p_in + i_start, local_p); load_store(m + i_start, local_m); load_store(v + i_start, local_v); load_store(p_out + i_start, local_p_out); } else { #pragma unroll for (int ii = 0, i = i_start; ii < ILP; ii++, i++) { if 
(i < n) { p_in[i] = local_p[ii]; m[i] = local_m[ii]; v[i] = local_v[ii]; p_out[i] = local_p_out[ii]; } } } } } }; /* Functor for multi-tensor Adam with implicit main params * * If params are BF16 and optimizer state is FP32, it is not necessary * to store FP32 main params. Instead, store 16-bit param remainder * and combine with BF16 param to reconstruct the FP32 main param. */ template struct DistAdamWithParamRemaindersFunctor { __device__ __forceinline__ void operator()(int chunk_size, volatile int* noop_gmem, TensorListMetadata<6>& tl, const float* grad_scale_ptr, const float beta1, const float beta2, const float beta1_correction, const float beta2_correction, const float eps, const float lr, adamMode_t mode, const float weight_decay) const { int tensor_loc = tl.block_to_tensor[blockIdx.x]; int chunk_idx = tl.block_to_chunk[blockIdx.x]; int n = tl.sizes[tensor_loc]; const float grad_scale = *grad_scale_ptr; int16_t* p_in = (int16_t*)tl.addresses[0][tensor_loc]; p_in += chunk_idx * chunk_size; int16_t* p_rem = (int16_t*)tl.addresses[1][tensor_loc]; p_rem += chunk_idx * chunk_size; float* m = (float*)tl.addresses[2][tensor_loc]; m += chunk_idx * chunk_size; float* v = (float*)tl.addresses[3][tensor_loc]; v += chunk_idx * chunk_size; const GRAD_T* g = (GRAD_T*)tl.addresses[4][tensor_loc]; g += chunk_idx * chunk_size; int16_t* p_out = (int16_t*)tl.addresses[5][tensor_loc]; p_out += chunk_idx * chunk_size; n -= chunk_idx * chunk_size; n = chunk_size < n ? 
chunk_size : n; const bool aligned = (n % ILP == 0 && is_aligned(p_in) && is_aligned(p_rem) && is_aligned(m) && is_aligned(v) && is_aligned(g) && is_aligned(p_out)); for (int i_start = threadIdx.x * ILP; i_start < n; i_start += blockDim.x * ILP) { union fp32_or_int162 { float fp32; int16_t int16[2]; }; fp32_or_int162 local_p[ILP]; int16_t local_p_bf16[ILP]; int16_t local_p_rem[ILP]; float local_m[ILP]; float local_v[ILP]; GRAD_T local_g[ILP]; // Load if (aligned) { load_store(local_p_bf16, p_in + i_start); load_store(local_p_rem, p_rem + i_start); load_store(local_m, m + i_start); load_store(local_v, v + i_start); load_store(local_g, g + i_start); } else { #pragma unroll for (int ii = 0, i = i_start; ii < ILP; ii++, i++) { if (i < n) { local_p_bf16[ii] = p_in[i]; local_p_rem[ii] = p_rem[i]; local_m[ii] = m[i]; local_v[ii] = v[i]; local_g[ii] = g[i]; } else { local_p_bf16[ii] = 0; local_p_rem[ii] = 0; local_m[ii] = 0; local_v[ii] = 0; local_g[ii] = 0; } } } // Reconstruct FP32 params #pragma unroll for (int ii = 0; ii < ILP; ii++) { if (local_p_rem[ii] < 0) local_p_bf16[ii]--; // Undo rounding local_p[ii].int16[1] = local_p_bf16[ii]; local_p[ii].int16[0] = local_p_rem[ii]; } // Local compute using LocalFunctor = DistAdamFunctor; LocalFunctor::local_step(reinterpret_cast(local_p), local_m, local_v, local_g, grad_scale, beta1, beta2, beta1_correction, beta2_correction, eps, lr, mode, weight_decay); // Split into BF16 params (rounded-to-nearest) and remainders #pragma unroll for (int ii = 0; ii < ILP; ii++) { local_p_bf16[ii] = local_p[ii].int16[1]; local_p_rem[ii] = local_p[ii].int16[0]; if (local_p_rem[ii] < 0) local_p_bf16[ii]++; // Round up } // Store if (aligned) { load_store(p_rem + i_start, local_p_rem); load_store(m + i_start, local_m); load_store(v + i_start, local_v); load_store(p_out + i_start, local_p_bf16); } else { #pragma unroll for (int ii = 0, i = i_start; ii < ILP; ii++, i++) { if (i < n) { p_rem[i] = local_p_rem[ii]; m[i] = local_m[ii]; v[i] = 
local_v[ii]; p_out[i] = local_p_bf16[ii]; } } } } } }; void multi_tensor_fused_adam_cuda(int chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, // p_in, m, v, g, p_out at::Tensor grad_scale, float lr, float beta1, float beta2, float eps, int step, int mode, int bias_correction, float weight_decay) { using namespace at; // Expect p_in, m, v, g, p_out size_t tl_sz = tensor_lists.size(); TORCH_CHECK(tl_sz == 5, "expected tensor lists of size 5"); const auto p_in_type = tensor_lists[0][0].scalar_type(); const auto g_type = tensor_lists[3][0].scalar_type(); const auto p_out_type = tensor_lists[4][0].scalar_type(); float beta1_correction = 1.0f, beta2_correction = 1.0f; if (bias_correction == 1) { beta1_correction = 1 - std::pow(beta1, step); beta2_correction = 1 - std::pow(beta2, step); } DISPATCH_FLOAT_HALF_AND_BFLOAT( p_in_type, 0, "dist_adam_cuda_kernel", DISPATCH_FLOAT_HALF_AND_BFLOAT( g_type, 1, "dist_adam_cuda_kernel", DISPATCH_FLOAT_HALF_AND_BFLOAT( p_out_type, 2, "dist_adam_cuda_kernel", multi_tensor_apply<5>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, DistAdamFunctor(), grad_scale.data_ptr(), beta1, beta2, beta1_correction, beta2_correction, eps, lr, (adamMode_t)mode, weight_decay);))); C10_CUDA_CHECK(cudaGetLastError()); } void multi_tensor_fused_adam_capturable_cuda(int chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, // p_in, m, v, g, p_out at::Tensor grad_scale, at::Tensor lr, float beta1, float beta2, float eps, at::Tensor step, int mode, int bias_correction, float weight_decay) { using namespace at; // Expect p_in, m, v, g, p_out size_t tl_sz = tensor_lists.size(); TORCH_CHECK(tl_sz == 5, "expected tensor lists of size 5"); const auto p_in_type = tensor_lists[0][0].scalar_type(); const auto g_type = tensor_lists[3][0].scalar_type(); const auto p_out_type = tensor_lists[4][0].scalar_type(); DISPATCH_FLOAT_HALF_AND_BFLOAT( p_in_type, 0, "dist_adam_capturable_cuda_kernel", DISPATCH_FLOAT_HALF_AND_BFLOAT( g_type, 1, 
"dist_adam_capturable_cuda_kernel", DISPATCH_FLOAT_HALF_AND_BFLOAT( p_out_type, 2, "dist_adam_capturable_cuda_kernel", multi_tensor_apply<5>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, DistAdamCapturableFunctor(), grad_scale.data_ptr(), beta1, beta2, step.data_ptr(), bias_correction, eps, lr.data_ptr(), (adamMode_t)mode, weight_decay);))); C10_CUDA_CHECK(cudaGetLastError()); } void multi_tensor_fused_adam_with_param_remainders_cuda( int chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, // p_in, p_rem, m, v, g, p_out at::Tensor grad_scale, float lr, float beta1, float beta2, float eps, int step, int mode, int bias_correction, float weight_decay) { using namespace at; // Expect p_in, p_rem, m, v, g, p_out size_t tl_sz = tensor_lists.size(); TORCH_CHECK(tl_sz == 6, "expected tensor lists of size 6"); const auto g_type = tensor_lists[4][0].scalar_type(); float beta1_correction = 1.0f, beta2_correction = 1.0f; if (bias_correction == 1) { beta1_correction = 1 - std::pow(beta1, step); beta2_correction = 1 - std::pow(beta2, step); } DISPATCH_FLOAT_HALF_AND_BFLOAT( g_type, 0, "dist_adam_with_param_remainders_cuda_kernel", multi_tensor_apply<6>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, DistAdamWithParamRemaindersFunctor(), grad_scale.data_ptr(), beta1, beta2, beta1_correction, beta2_correction, eps, lr, (adamMode_t)mode, weight_decay);); C10_CUDA_CHECK(cudaGetLastError()); } ================================================ FILE: apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb.cpp ================================================ #include void multi_tensor_lamb_compute_update_term_cuda(int chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, at::Tensor per_tensor_beta1, at::Tensor per_tensor_beta2, at::Tensor per_tensor_beta3, at::Tensor per_tensor_bias_correction, at::Tensor step, at::Tensor per_tensor_epsilon, const int mode, at::Tensor per_tensor_decay, at::Tensor global_scale, at::Tensor global_grad_norm, const float max_grad_norm); 
// Forward declaration; defined in multi_tensor_distopt_lamb_kernel.cu.
void multi_tensor_lamb_update_weights_cuda(int chunk_size, at::Tensor noop_flag, std::vector> tensor_lists,
                                           at::Tensor per_tensor_param_norm, at::Tensor per_tensor_update_norm,
                                           at::Tensor update_norm_offset, at::Tensor learning_rate,
                                           at::Tensor per_tensor_decay, at::Tensor global_grad_norm, bool use_nvlamb);

// Python bindings for the two LAMB stages used by the distributed optimizer.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("multi_tensor_lamb_compute_update_term", &multi_tensor_lamb_compute_update_term_cuda,
        "Computes update term for LAMB optimizer", py::call_guard());
  m.def("multi_tensor_lamb_update_weights", &multi_tensor_lamb_update_weights_cuda,
        "Applies update term for LAMB optimizer", py::call_guard());
}
================================================ FILE: apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb_kernel.cu ================================================
#include
#include
#include
#include
// Another possibility:
// #include
#include
#include "multi_tensor_apply.cuh"
#include "type_shim.h"

// Block size passed to multi_tensor_apply by the launchers in this file.
#define BLOCK_SIZE 512
// Elements processed per thread per step (the vector width used by load_store).
#define ILP 4

// True when p is aligned to ILP * sizeof(T) bytes, i.e. when an ILP-element
// vector access through load_store is safe.
template __device__ __forceinline__ bool is_aligned(T* p) { return ((uint64_t)p) % (ILP * sizeof(T)) == 0; }

// Copies one ILP-element group as a single aligned_storage value LT:
// group dst_offset of dst <- group src_offset of src. Offsets count groups, not
// elements -- assumes LT spans ILP * sizeof(T) bytes (matching is_aligned);
// TODO confirm against the aligned_storage arguments.
template __device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset) {
  typedef typename std::aligned_storage::type LT;
  ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
}

// Generic element conversion: a plain static_cast.
template __device__ void convert(const FROM_T vi, TO_T& vo) { vo = static_cast(vi); }

// float -> 8-bit code: keeps the high byte of the fp16 encoding of vi.
// s is vi with its mantissa bits cleared (sign and exponent only), so adding
// s / 8 adds half a unit of the 2 mantissa bits that survive in the high byte
// before the low byte is dropped -- approximately round-to-nearest.
// Reading as_byte[1] as the high byte assumes a little-endian layout.
template <> __device__ void convert(const float vi, uint8_t& vo) {
  union S {
    float as_float;
    int as_int;
  };
  S s;
  s.as_float = vi;
  s.as_int = s.as_int & 0xFF800000;
  union T {
    at::Half as_half;
    uint8_t as_byte[2];
  };
  T t;
  t.as_half = static_cast(vi + s.as_float / 8.0f);
  vo = t.as_byte[1];
}

// 8-bit code -> float: places the byte in the high half of an fp16 value (low
// byte zero) and widens it.
template <> __device__ void convert(const uint8_t vi, float& vo) {
  union T {
    at::Half as_half;
    uint8_t as_byte[2];
  };
  T t;
  t.as_byte[0] = 0;
  t.as_byte[1] = vi;
  vo = static_cast(t.as_half);
}

// Half -> 8-bit code; same rounding and byte selection as the float overload.
template <> __device__ void convert(const at::Half vi, uint8_t& vo) {
  union S {
    float as_float;
    int as_int;
  };
  S s;
  s.as_float = static_cast(vi);
  s.as_int = s.as_int & 0xFF800000;
  union T {
    at::Half as_half;
    uint8_t as_byte[2];
  };
  T t;
  t.as_half = static_cast(vi + s.as_float / 8.0f);
  vo = t.as_byte[1];
}

// 8-bit code -> Half: the byte becomes the high byte of the fp16 value.
template <> __device__ void convert(const uint8_t vi, at::Half& vo) {
  union T {
    at::Half as_half;
    uint8_t as_byte[2];
  };
  T t;
  t.as_byte[0] = 0;
  t.as_byte[1] = vi;
  vo = t.as_half;
}

typedef enum {
  MOMENT_MODE_0 = 0,  // L2 regularization mode
  MOMENT_MODE_1 = 1   // Decoupled weight decay mode
} adamMode_t;

// Stage 1 of distributed LAMB. Per element it reads g (list 0), p (list 1,
// only when the tensor's decay is non-zero), m (list 2) and v (list 3).
// It updates m and v in place and writes the Adam-style update term to u
// (list 4). The per-tensor trust ratio is applied afterwards by
// DistOptLAMBStage2Functor.
template struct DistOptLAMBStage1Functor {
  __device__ __forceinline__ void operator()(int chunk_size, volatile int* noop_gmem, TensorListMetadata<5>& tl,
                                             const MATH_T* per_tensor_beta1, const MATH_T* per_tensor_beta2,
                                             const MATH_T* per_tensor_beta3, const int* per_tensor_bias_correction,
                                             const int* step, const MATH_T* per_tensor_epsilon, adamMode_t mode,
                                             const MATH_T* per_tensor_decay, const MATH_T* global_scale,
                                             const MATH_T* global_grad_norm, const float max_grad_norm) {
    // I'd like this kernel to propagate infs/nans.
if (*noop_gmem == 1) return; int tensor_loc = tl.block_to_tensor[blockIdx.x]; int tensor_num = tl.start_tensor_this_launch + tensor_loc; int chunk_idx = tl.block_to_chunk[blockIdx.x]; int n = tl.sizes[tensor_loc]; float combined_scale = *global_scale; if (max_grad_norm > 0) { combined_scale = max_grad_norm / (*global_grad_norm / *global_scale + 1e-6); combined_scale = *global_scale / std::min((float)1.0, combined_scale); } MATH_T beta1 = per_tensor_beta1[tensor_num]; MATH_T beta2 = per_tensor_beta2[tensor_num]; MATH_T beta3 = 1 - beta1; MATH_T beta1_correction, beta2_correction; if (per_tensor_bias_correction[tensor_num] == 1) { beta1_correction = 1 - pow(beta1, *step); beta2_correction = 1 - pow(beta2, *step); } else { beta1_correction = (MATH_T)1.0; beta2_correction = (MATH_T)1.0; } MATH_T epsilon = per_tensor_epsilon[tensor_num]; MATH_T decay = per_tensor_decay[tensor_num]; GRAD_T* g = (GRAD_T*)tl.addresses[0][tensor_loc]; g += chunk_idx * chunk_size; T* p = (T*)tl.addresses[1][tensor_loc]; p += chunk_idx * chunk_size; T* m = (T*)tl.addresses[2][tensor_loc]; m += chunk_idx * chunk_size; T* v = (T*)tl.addresses[3][tensor_loc]; v += chunk_idx * chunk_size; MATH_T* u = (MATH_T*)tl.addresses[4][tensor_loc]; u += chunk_idx * chunk_size; n -= chunk_idx * chunk_size; MATH_T r_g[ILP]; MATH_T r_p[ILP]; MATH_T r_m[ILP]; MATH_T r_v[ILP]; // to make things simple, we put aligned case in a different code path if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(g) && is_aligned(p) && is_aligned(m) && is_aligned(v)) { GRAD_T l_g[ILP]; T l_p[ILP]; T l_m[ILP]; T l_v[ILP]; for (int i_start = threadIdx.x; i_start * ILP < n && i_start * ILP < chunk_size; i_start += blockDim.x) { // load load_store(l_g, g, 0, i_start); if (decay != 0) load_store(l_p, p, 0, i_start); load_store(l_m, m, 0, i_start); load_store(l_v, v, 0, i_start); // unpack #pragma unroll for (int ii = 0; ii < ILP; ii++) { r_g[ii] = l_g[ii]; if (decay == 0) { r_p[ii] = MATH_T(0); } else { r_p[ii] = l_p[ii]; } 
// Stage 1 aligned path (continued): finish unpacking, update moments, store.
          r_m[ii] = l_m[ii];
          r_v[ii] = l_v[ii];
        }
#pragma unroll
        for (int ii = 0; ii < ILP; ii++) {
          if (mode == MOMENT_MODE_0) {
            MATH_T scaled_grad = r_g[ii] / combined_scale;
            // L2 on scaled grad
            scaled_grad = scaled_grad + decay * r_p[ii];
            r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
            r_v[ii] = r_v[ii] * beta2 + (1 - beta2) * scaled_grad * scaled_grad;
            MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
            MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
            MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
            r_p[ii] = next_m_unbiased / denom;
          } else {
            // Decoupled weight decay: decay * p is added to the update term
            // instead of to the gradient.
            MATH_T scaled_grad = r_g[ii] / combined_scale;
            r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
            r_v[ii] = r_v[ii] * beta2 + (1 - beta2) * scaled_grad * scaled_grad;
            MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
            MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
            MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
            r_p[ii] = (next_m_unbiased / denom) + (decay * r_p[ii]);
          }
        }
        // r_p now holds the update term; repack the new moments for storage.
#pragma unroll
        for (int ii = 0; ii < ILP; ii++) {
          l_m[ii] = r_m[ii];
          l_v[ii] = r_v[ii];
        }
        // store
        load_store(u, r_p, i_start, 0);
        load_store(m, l_m, i_start, 0);
        load_store(v, l_v, i_start, 0);
      }
    } else {
      // see note in multi_tensor_scale_kernel.cu
      // Scalar fallback: lane ii of each thread handles element
      // i_start + threadIdx.x + ii * blockDim.x; out-of-range lanes are
      // zero-filled and never stored.
      for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {
        MATH_T r_g[ILP];
        MATH_T r_p[ILP];
        MATH_T r_m[ILP];
        MATH_T r_v[ILP];
#pragma unroll
        for (int ii = 0; ii < ILP; ii++) {
          int i = i_start + threadIdx.x + ii * blockDim.x;
          if (i < n && i < chunk_size) {
            r_g[ii] = g[i];
            // special ?optimization? for lamb stage 1
            // (p is only ever multiplied by decay, so it is not read when decay == 0.)
            if (decay == 0) {
              r_p[ii] = MATH_T(0);
            } else {
              r_p[ii] = p[i];
            }
            r_m[ii] = m[i];
            r_v[ii] = v[i];
          } else {
            r_g[ii] = MATH_T(0);
            r_p[ii] = MATH_T(0);
            r_m[ii] = MATH_T(0);
            r_v[ii] = MATH_T(0);
          }
        }
#pragma unroll
        for (int ii = 0; ii < ILP; ii++) {
          if (mode == MOMENT_MODE_0) {
            MATH_T scaled_grad = r_g[ii] / combined_scale;
            // L2 on scaled grad
            scaled_grad = scaled_grad + decay * r_p[ii];
            r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
            r_v[ii] = r_v[ii] * beta2 + (1 - beta2) * scaled_grad * scaled_grad;
            MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
            MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
            MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
            r_p[ii] = next_m_unbiased / denom;
          } else {
            MATH_T scaled_grad = r_g[ii] / combined_scale;
            r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
            r_v[ii] = r_v[ii] * beta2 + (1 - beta2) * scaled_grad * scaled_grad;
            MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
            MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
            MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
            r_p[ii] = (next_m_unbiased / denom) + (decay * r_p[ii]);
          }
        }
#pragma unroll
        for (int ii = 0; ii < ILP; ii++) {
          int i = i_start + threadIdx.x + ii * blockDim.x;
          if (i < n && i < chunk_size) {
            u[i] = r_p[ii];
            m[i] = r_m[ii];
            v[i] = r_v[ii];
          }
        }
      }
    }
  }
};

// Step 2 reads in 'update' value and per-tensor param_norm and update_norm.
// It computes new parameter value.
// Lists: 0 = update (MATH_T, written by stage 1), 1 = p (updated in place),
// 2 = p_copy (receives convert(new p)).
template struct DistOptLAMBStage2Functor {
  __device__ __forceinline__ void operator()(int chunk_size, volatile int* noop_gmem, TensorListMetadata<3>& tl,
                                             const MATH_T* per_tensor_param_norm,
                                             const MATH_T* per_tensor_update_norm, const long* update_norm_offset,
                                             const MATH_T* learning_rate, const MATH_T* per_tensor_decay,
                                             const MATH_T* global_grad_norm, bool use_nvlamb) {
    // I'd like this kernel to propagate infs/nans.
if (*noop_gmem == 1) return; int tensor_loc = tl.block_to_tensor[blockIdx.x]; int tensor_num = tl.start_tensor_this_launch + tensor_loc; int chunk_idx = tl.block_to_chunk[blockIdx.x]; int n = tl.sizes[tensor_loc]; MATH_T decay = per_tensor_decay[tensor_num]; MATH_T ratio = *learning_rate; // nvlamb: apply adaptive learning rate to all parameters // otherwise, only apply to those with non-zero weight decay if (use_nvlamb || (decay != (MATH_T)0.0)) { MATH_T param_norm = per_tensor_param_norm[tensor_num]; MATH_T update_norm = per_tensor_update_norm[update_norm_offset[tensor_num]]; ratio = (update_norm != 0.0 && param_norm != 0.0) ? (*learning_rate) * (param_norm / update_norm) : (*learning_rate); } MATH_T* update = (MATH_T*)tl.addresses[0][tensor_loc]; update += chunk_idx * chunk_size; T* p = (T*)tl.addresses[1][tensor_loc]; p += chunk_idx * chunk_size; GRAD_T* p_copy = (GRAD_T*)tl.addresses[2][tensor_loc]; p_copy += chunk_idx * chunk_size; n -= chunk_idx * chunk_size; // to make things simple, we put aligned case in a different code path if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(p) && is_aligned(update)) { T r_p[ILP]; MATH_T r_update[ILP]; GRAD_T r_p_copy[ILP]; for (int i_start = threadIdx.x; i_start * ILP < n && i_start * ILP < chunk_size; i_start += blockDim.x) { // load load_store(r_p, p, 0, i_start); load_store(r_update, update, 0, i_start); #pragma unroll for (int ii = 0; ii < ILP; ii++) { r_p[ii] = static_cast(r_p[ii]) - (ratio * r_update[ii]); convert(r_p[ii], r_p_copy[ii]); } load_store(p, r_p, i_start, 0); load_store(p_copy, r_p_copy, i_start, 0); } } else { for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) { MATH_T r_p[ILP]; MATH_T r_update[ILP]; #pragma unroll for (int ii = 0; ii < ILP; ii++) { int i = i_start + threadIdx.x + ii * blockDim.x; if (i < n && i < chunk_size) { r_p[ii] = p[i]; r_update[ii] = update[i]; } } #pragma unroll for (int ii = 0; ii < ILP; ii++) { r_p[ii] = r_p[ii] - (ratio * 
r_update[ii]);
        }
        // Write the new parameter and its converted copy for in-range lanes only.
#pragma unroll
        for (int ii = 0; ii < ILP; ii++) {
          int i = i_start + threadIdx.x + ii * blockDim.x;
          if (i < n && i < chunk_size) {
            p[i] = r_p[ii];
            convert(r_p[ii], p_copy[i]);
          }
        }
      }
    }
  }
};

// Launches DistOptLAMBStage1Functor over five tensor lists (g, p, m, v, u --
// the order the functor reads them). The nested dispatch switches on the dtypes
// of element [0] of lists 1, 0 and 4 (p, g, u).
// NOTE(review): tensor_lists[k][0] is read without checking that the lists are
// non-empty -- confirm callers never pass empty lists.
void multi_tensor_lamb_compute_update_term_cuda(int chunk_size, at::Tensor noop_flag,
                                                std::vector> tensor_lists, at::Tensor per_tensor_beta1,
                                                at::Tensor per_tensor_beta2, at::Tensor per_tensor_beta3,
                                                at::Tensor per_tensor_bias_correction, at::Tensor step,
                                                at::Tensor per_tensor_epsilon, const int mode,
                                                at::Tensor per_tensor_decay, at::Tensor global_scale,
                                                at::Tensor global_grad_norm, const float max_grad_norm) {
  using namespace at;

  DISPATCH_FLOAT_AND_HALF(
      tensor_lists[1][0].scalar_type(), 0, "lamb_stage_1",
      DISPATCH_FLOAT_AND_HALF(
          tensor_lists[0][0].scalar_type(), 1, "lamb_stage_1",
          DISPATCH_FLOAT_AND_HALF(
              tensor_lists[4][0].scalar_type(), 2, "lamb_stage_1",
              multi_tensor_apply<5>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, DistOptLAMBStage1Functor(),
                                    per_tensor_beta1.data_ptr(), per_tensor_beta2.data_ptr(),
                                    per_tensor_beta3.data_ptr(), per_tensor_bias_correction.data_ptr(),
                                    step.data_ptr(), per_tensor_epsilon.data_ptr(), (adamMode_t)mode,
                                    per_tensor_decay.data_ptr(), global_scale.data_ptr(),
                                    global_grad_norm.data_ptr(), max_grad_norm);)))

  AT_CUDA_CHECK(cudaGetLastError());
}

// Launches DistOptLAMBStage2Functor over three tensor lists (update, p, p_copy).
// Dispatches on the dtypes of element [0] of lists 1, 2 and 0 (p, p_copy,
// update); p_copy goes through DISPATCH_FLOAT_HALF_AND_BYTE, presumably so the
// uint8_t convert specializations can be used for it.
void multi_tensor_lamb_update_weights_cuda(int chunk_size, at::Tensor noop_flag, std::vector> tensor_lists,
                                           at::Tensor per_tensor_param_norm, at::Tensor per_tensor_update_norm,
                                           at::Tensor update_norm_offset, at::Tensor learning_rate,
                                           at::Tensor per_tensor_decay, at::Tensor global_grad_norm, bool use_nvlamb) {
  using namespace at;

  DISPATCH_FLOAT_AND_HALF(
      tensor_lists[1][0].scalar_type(), 0, "lamb_stage_2",
      DISPATCH_FLOAT_HALF_AND_BYTE(
          tensor_lists[2][0].scalar_type(), 1, "lamb_stage_2",
          DISPATCH_FLOAT_AND_HALF(
              tensor_lists[0][0].scalar_type(), 2, "lamb_stage_2",
              multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, DistOptLAMBStage2Functor(),
                                    per_tensor_param_norm.data_ptr(), per_tensor_update_norm.data_ptr(),
                                    update_norm_offset.data_ptr(), learning_rate.data_ptr(),
                                    per_tensor_decay.data_ptr(), global_grad_norm.data_ptr(), use_nvlamb);)))

  AT_CUDA_CHECK(cudaGetLastError());
}
================================================ FILE: apex/contrib/csrc/peer_memory/peer_memory.cpp ================================================
/**
 * Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "peer_memory_cuda.cuh"

// Python bindings for the raw peer-memory helpers and the 1-D halo exchange.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("allocate_raw", &apex::contrib::peer_memory::allocate_raw, "allocate_raw", py::call_guard());
  m.def("free_raw", &apex::contrib::peer_memory::free_raw, "free_raw", py::call_guard());
  m.def("zero", &apex::contrib::peer_memory::zero, "zero", py::call_guard());
  m.def("get_raw_ipc_address", &apex::contrib::peer_memory::get_raw_ipc_address, "get_raw_ipc_address",
        py::call_guard());
  m.def("get_raw_peers", &apex::contrib::peer_memory::get_raw_peers, "get_raw_peers", py::call_guard());
  m.def("blob_view_half", &apex::contrib::peer_memory::blob_view_half, "blob_view_half", py::call_guard());
  m.def("blob_view_float", &apex::contrib::peer_memory::blob_view_float, "blob_view_float", py::call_guard());
  m.def("blob_view_int", &apex::contrib::peer_memory::blob_view_int, "blob_view_int", py::call_guard());
  m.def("push_pull_halos_1d", &apex::contrib::peer_memory::push_pull_halos_1d, "push_pull_halos_1d",
py::call_guard()); } ================================================ FILE: apex/contrib/csrc/peer_memory/peer_memory_cuda.cu ================================================ #include #include #include #include #include #include #include #include "nccl.h" #define CUDACHECK(cmd) \ do { \ cudaError_t err = cmd; \ if (err != cudaSuccess) { \ char hostname[1024]; \ gethostname(hostname, 1024); \ printf("%s: CUDA failure %s:%d '%s'\n", hostname, __FILE__, __LINE__, cudaGetErrorString(err)); \ } \ } while (0) namespace { constexpr int THREADS_PER_CTA = 128; /* Basic deleter function for from_blob function. void deleter(void* ptr) { printf("deleter(ptr=%p)\n",ptr); cudaFree(ptr); } */ template at::Tensor blob_view(T* raw_ptr, std::vector shape, const at::TensorOptions& options, bool channels_last) { size_t size = 1; std::vector strides(shape.size()); if (channels_last) { assert(shape.size() == 4); strides[0] = shape[1] * shape[2] * shape[3]; strides[1] = 1; strides[2] = shape[1] * shape[3]; strides[3] = shape[1]; } else { int idx = strides.size(); for (auto it = shape.rbegin(); it != shape.rend(); ++it) { strides[--idx] = size; size *= *it; } } size *= sizeof(T); // TODO: Implement dynamic reuse of pooled peer memory. // We provide no deleter function because all peer memory allocations are static in this implementation. 
return torch::from_blob((void*)raw_ptr, shape, strides, 0L, options); } void tensor_shape(at::Tensor t, bool explicit_nhwc, int& N, int& C, int& H, int& W) { if (t.dim() == 3) { N = 1; if (explicit_nhwc) { C = t.size(2); H = t.size(0); W = t.size(1); } else { C = t.size(0); H = t.size(1); W = t.size(2); } } else if (t.dim() == 4) { if (explicit_nhwc) { N = t.size(0); C = t.size(3); H = t.size(1); W = t.size(2); } else { N = t.size(0); C = t.size(1); H = t.size(2); W = t.size(3); } } else { printf("%s;%d - t.dim() must be either 3 or 4 (was %d)\n", __FILE__, __LINE__, int(t.dim())); assert(t.dim() == 3 || t.dim() == 4); } } void tensor_strides(at::Tensor t, bool explicit_nhwc, int& stride_N, int& stride_C, int& stride_H, int& stride_W) { if (t.dim() == 3) { if (explicit_nhwc) { stride_C = t.stride(2); stride_H = t.stride(0); stride_W = t.stride(1); } else { stride_C = t.stride(0); stride_H = t.stride(1); stride_W = t.stride(2); } stride_N = t.size(0) * t.size(1) * t.size(2); } else if (t.dim() == 4) { if (explicit_nhwc) { stride_N = t.stride(0); stride_C = t.stride(3); stride_H = t.stride(1); stride_W = t.stride(2); } else { stride_N = t.stride(0); stride_C = t.stride(1); stride_H = t.stride(2); stride_W = t.stride(3); } } else { printf("%s;%d - t.dim() must be either 3 or 4 (was %d)\n", __FILE__, __LINE__, t.dim()); assert(t.dim() == 3 || t.dim() == 4); } } template inline __device__ void __zero(T* dst) { *dst = T(0); } inline __device__ void __zero(int2* dst) { *dst = {0, 0}; } template inline __device__ void zero_tensor(const int dim0, const int dim1, const int dim2, T* __restrict__ data, const int data_stride0, const int data_stride1, const int data_stride2, const int thread_id, const int block_id, const int num_blocks) { const int global_id = thread_id + block_id * THREADS_PER_CTA; const int num_threads = num_blocks * THREADS_PER_CTA; const int count = dim0 * dim1 * dim2; for (int i = global_id; i < count; i += num_threads) { int offset; if (contiguous) { 
// zero_tensor (continued): flat offset when contiguous, else strided (j0, j1, j2) offset.
      offset = i;
    } else {
      const int j2 = i % dim2;
      const int k = i / dim2;
      const int j1 = k % dim1;
      const int j0 = k / dim1;
      offset = j0 * data_stride0 + j1 * data_stride1 + j2 * data_stride2;
    }
    __zero(data + offset);
  }
}

// Exchanges one halo tensor with a remote GPU through peer memory.
//
// Each thread owns two 16-byte flit slots (int4) in local_peer, where it
// receives, and the slots at the same offsets in remote_peer, where it sends.
// In every iteration each thread:
//  1. picks the active slot from the status bits;
//  2. sends its data_in element as the flit payload (sizeof(T) <= 12), with the
//     communication bit and the current status in the last 4 bytes;
//  3. spins on its local slot until the remote flit's communication bit is set;
//  4. writes the received payload to data_out and resets the slot with the
//     status bit inverted, so the next iteration uses the other slot.
// Threads with no element (has_data == false) still exchange a flit with an
// unset payload, so every slot's handshake completes each iteration.
template inline __device__ void push_pull_tensor(const int dim0, const int dim1, const int dim2,
                                                 const T* __restrict__ data_in, const int data_in_stride0,
                                                 const int data_in_stride1, const int data_in_stride2,
                                                 T* __restrict__ data_out, const int data_out_stride0,
                                                 const int data_out_stride1, const int data_out_stride2,
                                                 int4* local_peer, int4* remote_peer, const int thread_id,
                                                 const int block_id, const int num_blocks) {
  // 128b=16B NVLink flit
  // Note: Use last 4B as a semaphore
  static_assert(sizeof(T) <= 12);
  union Flit {
    T payload;
    uint uints[4];
  };
  // Communication bit indicates whether flit has been received from
  // a remote GPU
  constexpr uint communication_mask = 1 << 0;
  // Status bit is used to choose the active peer buffer in an
  // alternating double buffer scheme. We use buffer 1 if the bits
  // match, use buffer 2 if the bits differ, and invert the bit
  // after finishing with a buffer.
  constexpr uint status_mask = 1 << 1;

  // Split peer memory into two sets of buffers
  // Note: Each block owns a THREADS_PER_CTA*2*16B chunk of peer
  // memory
  const int peer_offset1 = block_id * THREADS_PER_CTA * 2 + thread_id;
  const int peer_offset2 = peer_offset1 + THREADS_PER_CTA;
  volatile int* local_peer1 = reinterpret_cast(local_peer + peer_offset1);
  volatile int* local_peer2 = reinterpret_cast(local_peer + peer_offset2);
  volatile int* remote_peer1 = reinterpret_cast(remote_peer + peer_offset1);
  volatile int* remote_peer2 = reinterpret_cast(remote_peer + peer_offset2);

  // Iterate through tensor entries
  // (all threads of a block advance together, so i0 is shared per block)
  const int num_threads = num_blocks * THREADS_PER_CTA;
  const int count = dim0 * dim1 * dim2;
  for (int i0 = block_id * THREADS_PER_CTA; i0 < count; i0 += num_threads) {
    const int i = i0 + thread_id;
    const bool has_data = i < count;

    // Calculate buffer positions
    int data_in_offset, data_out_offset;
    if (contiguous) {
      data_in_offset = i;
      data_out_offset = i;
    } else {
      const int j2 = i % dim2;
      const int k = i / dim2;
      const int j1 = k % dim1;
      const int j0 = k / dim1;
      data_in_offset = j0 * data_in_stride0 + j1 * data_in_stride1 + j2 * data_in_stride2;
      data_out_offset = j0 * data_out_stride0 + j1 * data_out_stride1 + j2 * data_out_stride2;
    }

    // Determine which peer memory buffer to use
    // Note: The status bit is not affected by asynchronous
    // communication from the remote GPU.
    Flit local_message1, local_message2;
    asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];"
                 : "=r"(local_message1.uints[0]), "=r"(local_message1.uints[1]), "=r"(local_message1.uints[2]),
                   "=r"(local_message1.uints[3])
                 : "l"(local_peer1)
                 : "memory");
    asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];"
                 : "=r"(local_message2.uints[0]), "=r"(local_message2.uints[1]), "=r"(local_message2.uints[2]),
                   "=r"(local_message2.uints[3])
                 : "l"(local_peer2)
                 : "memory");
    const uint status1 = local_message1.uints[3] & status_mask;
    const uint status2 = local_message2.uints[3] & status_mask;
    const bool peer1_is_active = (status1 ^ status2) == 0;
    volatile int* ox = peer1_is_active ? remote_peer1 : remote_peer2;
    volatile int* ix = peer1_is_active ? local_peer1 : local_peer2;
    const uint status = peer1_is_active ? status1 : status2;
    Flit recv_message = peer1_is_active ? local_message1 : local_message2;

    // Send flit to remote GPU
    // Note: Set communication bit and keep status bit
    Flit send_message;
    if (has_data) {
      send_message.payload = data_in[data_in_offset];
    }
    send_message.uints[3] = communication_mask | status;
    asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" ::"l"(ox), "r"(send_message.uints[0]),
                 "r"(send_message.uints[1]), "r"(send_message.uints[2]), "r"(send_message.uints[3])
                 : "memory");

    // Receive flit from peer
    // (re-reads the local slot until the remote side has set the communication bit)
    while ((recv_message.uints[3] & communication_mask) == 0) {
      asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];"
                   : "=r"(recv_message.uints[0]), "=r"(recv_message.uints[1]), "=r"(recv_message.uints[2]),
                     "=r"(recv_message.uints[3])
                   : "l"(ix)
                   : "memory");
    }
    if (has_data) {
      data_out[data_out_offset] = recv_message.payload;
    }

    // Reset semaphore
    // Note: Clear communication bit and invert status bit
    uint flag = ~status & status_mask;
    asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" ::"l"(ix), "n"(0), "n"(0), "n"(0), "r"(flag)
                 : "memory");
    // System-scope fence only between iterations (skipped after the last one);
    // presumably orders this iteration's peer-memory stores before the slots
    // are reused -- confirm.
    if (i0 + num_threads < count) {
      __threadfence_system();
    }
  }
}

// 1-D halo-exchange kernel. The grid is split into two halves of
// gridDim.x / 2 blocks: one half handles the top halo and the other the
// bottom halo. top_first selects which half takes the top: the blocks with
// blockIdx.x < gridDim.x / 2 when true, the upper half otherwise. Each half
// either zero-fills its output halo (when the top_zero / btm_zero flag is set;
// these are compile-time flags judging by the launcher) or exchanges it with
// the neighbour via push_pull_tensor.
template
#if __CUDA_ARCH__ >= 700
__launch_bounds__(THREADS_PER_CTA)
#endif
__global__ void push_pull_halos_1d_kernel(
    // top halo,
    T* toh, int toh_stride0, int toh_stride1, int toh_stride2,  // top output halo (local)
    const T* tih, int tih_stride0, int tih_stride1, int tih_stride2,  // top input halo (local)
    int4* tox,  // top output transfer buffer (remote peer)
    int4* tix,  // top input transfer buffer (local peer)
    // btm halo
    T* boh, int boh_stride0, int boh_stride1, int boh_stride2,  // btm output halo (local)
    const T* bih, int bih_stride0, int bih_stride1, int bih_stride2,  // btm input halo (local)
    int4* box,  // btm output transfer buffer (remote peer)
    int4* bix,  // btm input transfer buffer (local peer)
    // dimensions
    int dim0, int dim1, int dim2,
    bool top_first  // whether to launch communicate top halo first
) {
  const int num_blocks_side = gridDim.x / 2;
  const int block_id_side = (blockIdx.x < num_blocks_side ? blockIdx.x : blockIdx.x - num_blocks_side);
  const bool in_top_block = top_first == (blockIdx.x < num_blocks_side);
  if (in_top_block) {
    if (top_zero) {
      zero_tensor(dim0, dim1, dim2, toh, toh_stride0, toh_stride1, toh_stride2, threadIdx.x, block_id_side,
                  num_blocks_side);
    } else {
      push_pull_tensor(dim0, dim1, dim2, tih, tih_stride0, tih_stride1, tih_stride2, toh, toh_stride0, toh_stride1,
                       toh_stride2, tix, tox, threadIdx.x, block_id_side, num_blocks_side);
    }
  } else {
    if (btm_zero) {
      zero_tensor(dim0, dim1, dim2, boh, boh_stride0, boh_stride1, boh_stride2, threadIdx.x, block_id_side,
                  num_blocks_side);
    } else {
      push_pull_tensor(dim0, dim1, dim2, bih, bih_stride0, bih_stride1, bih_stride2, boh, boh_stride0, boh_stride1,
                       boh_stride2, bix, box, threadIdx.x, block_id_side, num_blocks_side);
    }
  }
}

// Thread 0 of block 0 busy-waits and stores the number of loop iterations it
// spun into *counter; every other thread does nothing.
__global__ void delay_kernel(int delay_nanoseconds, int* counter) {
  if (blockIdx.x == 0 && threadIdx.x == 0) {
    // waste time while doing something compiler can't predict, thus preventing it from optimizing away this code.
int new_counter = 0; double elapsed = 0; clock_t start = clock(); do { clock_t now = clock(); elapsed = (double)(now - start) * 1e9 / CLOCKS_PER_SEC; ++new_counter; } while (elapsed < (double)delay_nanoseconds); *counter = new_counter; } } } // namespace namespace apex { namespace contrib { namespace peer_memory { int64_t allocate_raw(int64_t size) { float* ptr = 0L; cudaMalloc(&ptr, size); cudaMemset(ptr, 0, size); return (int64_t)ptr; } void free_raw(int64_t raw) { cudaFree((void*)raw); } void zero(int64_t raw, int64_t size) { cudaMemset((void*)raw, 0, size); } at::Tensor get_raw_ipc_address(int64_t raw) { cudaIpcMemHandle_t mem_handle; CUDACHECK(cudaIpcGetMemHandle(&mem_handle, (void*)raw)); const int n = sizeof(cudaIpcMemHandle_t); auto address_tensor = torch::empty({n}, torch::dtype(torch::kUInt8)); auto address_tensor_p = address_tensor.data_ptr(); memcpy(address_tensor_p, (uint8_t*)&mem_handle, n); return address_tensor; } std::vector get_raw_peers(at::Tensor ipc_addresses, int peer_rank, int64_t raw) { int peer_group_size = ipc_addresses.size(0); std::vector results(peer_group_size); for (int i = 0; i < peer_group_size; ++i) { if (i != peer_rank) { cudaIpcMemHandle_t mem_handle; memcpy(&mem_handle, ipc_addresses.index({i}).data_ptr(), sizeof(cudaIpcMemHandle_t)); void* p = 0L; CUDACHECK(cudaIpcOpenMemHandle((void**)&p, mem_handle, cudaIpcMemLazyEnablePeerAccess)); results[i] = (int64_t)p; } else { results[i] = (int64_t)raw; } } return results; } at::Tensor blob_view_half(int64_t raw, std::vector shape, bool channels_last) { return blob_view((at::Half*)raw, shape, torch::dtype(torch::kFloat16).device(torch::kCUDA), channels_last); } at::Tensor blob_view_float(int64_t raw, std::vector shape, bool channels_last) { return blob_view((float*)raw, shape, torch::dtype(torch::kFloat32).device(torch::kCUDA), channels_last); } at::Tensor blob_view_int(int64_t raw, std::vector shape, bool channels_last) { return blob_view((int*)raw, shape, 
torch::dtype(torch::kInt32).device(torch::kCUDA), channels_last); } void push_pull_halos_1d( bool diagnostics, bool explicit_nhwc, int numSM, // number of SMs to use (zero corresponds to all SMs) int rank, // rank in spatial parallel group bool top_zero, // if top halo should be zeroed at::Tensor top_in_halo, // top input halo buffer (in local device memory, sent to top neighbor) at::Tensor top_in_transfer, // top input transfer buffer (in local peer memory) at::Tensor top_out_transfer, // top output transfer buffer (in top neighbor peer memory) at::Tensor top_out_halo, // top output halo buffer (in local device memory, received from top neighbor) bool btm_zero, // if btm halo should be zeroed at::Tensor btm_in_halo, // btm input halo buffer (in local device memory, sent to btm neighbor) at::Tensor btm_in_transfer, // btm input transfer buffer (in local peer memory) at::Tensor btm_out_transfer, // btm output transfer buffer (in btm neighbor peer memory) at::Tensor btm_out_halo // btm output halo buffer (in local device memory, received from btm neighbor) ) { // basic checks of inputs TORCH_CHECK(!(top_zero && btm_zero)); TORCH_CHECK(top_in_halo.is_cuda()); TORCH_CHECK(top_out_transfer.is_cuda()); TORCH_CHECK(top_in_transfer.is_cuda()); TORCH_CHECK(top_out_halo.is_cuda()); TORCH_CHECK(btm_in_halo.is_cuda()); TORCH_CHECK(btm_out_transfer.is_cuda()); TORCH_CHECK(btm_in_transfer.is_cuda()); TORCH_CHECK(btm_out_halo.is_cuda()); // tensor shapes int tih_N, tih_C, tih_H, tih_W; tensor_shape(top_in_halo, explicit_nhwc, tih_N, tih_C, tih_H, tih_W); int toh_N, toh_C, toh_H, toh_W; tensor_shape(top_out_halo, explicit_nhwc, toh_N, toh_C, toh_H, toh_W); int bih_N, bih_C, bih_H, bih_W; tensor_shape(btm_in_halo, explicit_nhwc, bih_N, bih_C, bih_H, bih_W); int boh_N, boh_C, boh_H, boh_W; tensor_shape(btm_out_halo, explicit_nhwc, boh_N, boh_C, boh_H, boh_W); TORCH_CHECK(toh_N == tih_N && tih_N == boh_N && boh_N == bih_N && toh_C == tih_C && tih_C == boh_C && boh_C == bih_C && toh_H 
== tih_H && tih_H == boh_H && boh_H == bih_H && toh_W == tih_W && tih_W == boh_W && boh_W == bih_W); int NN = toh_N, NC = toh_C, NH = toh_H, NW = toh_W; if (diagnostics) { printf("rank %d: NN=%d, NC=%d, NH=%d, NW=%d\n", rank, NN, NC, NH, NW); } TORCH_CHECK(NN == 1); // tensor strides int tih_stride_N, tih_stride_C, tih_stride_H, tih_stride_W; tensor_strides(top_in_halo, explicit_nhwc, tih_stride_N, tih_stride_C, tih_stride_H, tih_stride_W); int toh_stride_N, toh_stride_C, toh_stride_H, toh_stride_W; tensor_strides(top_out_halo, explicit_nhwc, toh_stride_N, toh_stride_C, toh_stride_H, toh_stride_W); int bih_stride_N, bih_stride_C, bih_stride_H, bih_stride_W; tensor_strides(btm_in_halo, explicit_nhwc, bih_stride_N, bih_stride_C, bih_stride_H, bih_stride_W); int boh_stride_N, boh_stride_C, boh_stride_H, boh_stride_W; tensor_strides(btm_out_halo, explicit_nhwc, boh_stride_N, boh_stride_C, boh_stride_H, boh_stride_W); if (diagnostics) { printf("rank %d: tih_stride :: N=%d, C=%d, H=%d, W=%d\n", rank, tih_stride_N, tih_stride_C, tih_stride_H, tih_stride_W); printf("rank %d: toh_stride :: N=%d, C=%d, H=%d, W=%d\n", rank, toh_stride_N, toh_stride_C, toh_stride_H, toh_stride_W); printf("rank %d: bih_stride :: N=%d, C=%d, H=%d, W=%d\n", rank, bih_stride_N, bih_stride_C, bih_stride_H, bih_stride_W); printf("rank %d: boh_stride :: N=%d, C=%d, H=%d, W=%d\n", rank, boh_stride_N, boh_stride_C, boh_stride_H, boh_stride_W); } // determine if nhwc bool is_nhwc = (toh_stride_C == 1); if (diagnostics) { printf("rank %d: is_nhwc = %s\n", rank, is_nhwc ? 
"true" : "false"); } // determine if contiguous bool contiguous = true; if ((NN - 1) * toh_stride_N + (NC - 1) * toh_stride_C + (NH - 1) * toh_stride_H + (NW - 1) * toh_stride_W != NN * NC * NH * NW - 1) { contiguous = false; } if ((NN - 1) * boh_stride_N + (NC - 1) * boh_stride_C + (NH - 1) * boh_stride_H + (NW - 1) * boh_stride_W != NN * NC * NH * NW - 1) { contiguous = false; } if (!top_zero) { if (toh_stride_N != tih_stride_N || toh_stride_C != tih_stride_C || toh_stride_H != tih_stride_H || toh_stride_W != tih_stride_W) { contiguous = false; } } if (!btm_zero) { if (boh_stride_N != bih_stride_N || boh_stride_C != bih_stride_C || boh_stride_H != bih_stride_H || boh_stride_W != bih_stride_W) { contiguous = false; } } if (diagnostics) { printf("rank %d: contiguous = %s\n", rank, contiguous ? "true" : "false"); } // determine whether to communicate top halo first bool top_first = rank % 2 != 0; if (diagnostics) { printf("rank %d: top_first = %s\n", rank, top_first ? "true" : "false"); } // peer memory buffers int tox_size = top_out_transfer.numel() * top_out_transfer.element_size(); int tix_size = top_in_transfer.numel() * top_in_transfer.element_size(); int box_size = btm_out_transfer.numel() * btm_out_transfer.element_size(); int bix_size = btm_in_transfer.numel() * btm_in_transfer.element_size(); if (!top_zero) { TORCH_CHECK(top_out_transfer.is_contiguous()); TORCH_CHECK(top_in_transfer.is_contiguous()); TORCH_CHECK(tox_size == tix_size); } if (!btm_zero) { TORCH_CHECK(btm_out_transfer.is_contiguous()); TORCH_CHECK(btm_in_transfer.is_contiguous()); TORCH_CHECK(box_size == bix_size); } // figure out launch parameters int device; cudaGetDevice(&device); cudaDeviceProp prop; cudaGetDeviceProperties(&prop, device); if (numSM <= 0 || numSM > prop.multiProcessorCount) { numSM = prop.multiProcessorCount; } auto current_stream = at::cuda::getCurrentCUDAStream(); dim3 block(THREADS_PER_CTA, 1, 1); // helper macros to launch templated kernel #define 
LAUNCH_PUSH_PULL_HALO_KERNEL_BASE(T, CONTIGUOUS, TOP_ZERO, BTM_ZERO, KERNEL_ARGS, NUM_ELEMENTS) \ do { \ /* kernel configuration */ \ int numBlocksPerSm; \ cudaOccupancyMaxActiveBlocksPerMultiprocessor( \ &numBlocksPerSm, push_pull_halos_1d_kernel, THREADS_PER_CTA, 0); \ dim3 grid(numSM * numBlocksPerSm, 1, 1); \ if (grid.x % 2 != 0) { \ /* require even number of blocks (half for top, half for bottom) */ \ grid.x -= 1; \ } \ if ((grid.x / 2) * THREADS_PER_CTA > NUM_ELEMENTS) { \ /* only need enough blocks to cover top and bottom halo elements */ \ grid.x = 2 * ((NUM_ELEMENTS + THREADS_PER_CTA - 1) / THREADS_PER_CTA); \ } \ if (!TOP_ZERO) { \ /* require 2*128b=32B peer memory per thread */ \ if ((grid.x / 2) * THREADS_PER_CTA * 32 > tox_size) { \ grid.x = 2 * (tox_size / (THREADS_PER_CTA * 32)); \ } \ } \ if (!BTM_ZERO) { \ /* require 2*128b=32B peer memory per thread */ \ if ((grid.x / 2) * THREADS_PER_CTA * 32 > box_size) { \ grid.x = 2 * (box_size / (THREADS_PER_CTA * 32)); \ } \ } \ TORCH_CHECK(grid.x >= 2); \ \ /* launch kernel */ \ cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel, grid, block, \ KERNEL_ARGS, 0, current_stream); \ } while (false) #define LAUNCH_PUSH_PULL_HALO_KERNEL(T, CONTIGUOUS, KERNEL_ARGS, NUM_ELEMENTS) \ do { \ if (top_zero) { \ LAUNCH_PUSH_PULL_HALO_KERNEL_BASE(T, CONTIGUOUS, true, false, KERNEL_ARGS, NUM_ELEMENTS); \ } else if (btm_zero) { \ LAUNCH_PUSH_PULL_HALO_KERNEL_BASE(T, CONTIGUOUS, false, true, KERNEL_ARGS, NUM_ELEMENTS); \ } else { \ LAUNCH_PUSH_PULL_HALO_KERNEL_BASE(T, CONTIGUOUS, false, false, KERNEL_ARGS, NUM_ELEMENTS); \ } \ } while (false) AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, top_out_halo.scalar_type(), "push_pull_halos_1d_kernel", [&] { if (diagnostics) { printf("rank %d: size(scalar_t) = %ld\n", rank, sizeof(scalar_t)); } scalar_t* toh_p = top_out_halo.data_ptr(); scalar_t* tih_p = top_in_halo.data_ptr(); int4* tox_p = reinterpret_cast(top_out_transfer.data_ptr()); int4* tix_p = 
reinterpret_cast(top_in_transfer.data_ptr()); scalar_t* boh_p = btm_out_halo.data_ptr(); scalar_t* bih_p = btm_in_halo.data_ptr(); int4* box_p = reinterpret_cast(btm_out_transfer.data_ptr()); int4* bix_p = reinterpret_cast(btm_in_transfer.data_ptr()); if (diagnostics) printf("rank %d: choosing halo exchange kernel\n", rank); // do int2 vector loads if channel count permits if (contiguous && (NN * NH * NW * NC * sizeof(scalar_t)) % sizeof(int2) == 0) { // can do contiguous int2 transfers if (diagnostics) { } toh_stride_N = toh_stride_H = toh_stride_W = toh_stride_C = 1; tih_stride_N = tih_stride_H = tih_stride_W = tih_stride_C = 1; boh_stride_N = boh_stride_H = boh_stride_W = boh_stride_C = 1; bih_stride_N = bih_stride_H = bih_stride_W = bih_stride_C = 1; NC = (NN * NH * NW * NC * sizeof(scalar_t)) / sizeof(int2); NN = NH = NW = 1; if (diagnostics) { printf("rank %d: launching contiguous int2 halo exchange kernel\n", rank); printf("rank %d: NC=%d, NH=%d, NW=%d\n", rank, NC, NH, NW); } void* kernel_args[] = {(int2**)&toh_p, &toh_stride_H, &toh_stride_W, &toh_stride_C, (int2**)&tih_p, &tih_stride_H, &tih_stride_W, &tih_stride_C, &tox_p, &tix_p, (int2**)&boh_p, &boh_stride_H, &boh_stride_W, &boh_stride_C, (int2**)&bih_p, &bih_stride_H, &bih_stride_W, &bih_stride_C, &box_p, &bix_p, &NH, &NW, &NC, &top_first}; int num_elem = NN * NH * NW * NC; LAUNCH_PUSH_PULL_HALO_KERNEL(int2, true, kernel_args, num_elem); } else if (is_nhwc && (NC * sizeof(scalar_t)) % sizeof(int2) == 0) { // can do strided int2 transfers int divisor = sizeof(int2) / sizeof(scalar_t); if (diagnostics) { printf("rank %d: launching strided int2 halo exchange kernel\n", rank); } toh_stride_N /= divisor; toh_stride_H /= divisor; toh_stride_W /= divisor; tih_stride_N /= divisor; tih_stride_H /= divisor; tih_stride_W /= divisor; boh_stride_N /= divisor; boh_stride_H /= divisor; boh_stride_W /= divisor; bih_stride_N /= divisor; bih_stride_H /= divisor; bih_stride_W /= divisor; NC /= divisor; if (diagnostics) 
{ printf("rank %d: divisor=%d\n", rank, divisor); printf("rank %d: tih_stride :: N=%d, C=%d, H=%d, W=%d\n", rank, tih_stride_N, tih_stride_C, tih_stride_H, tih_stride_W); printf("rank %d: toh_stride :: N=%d, C=%d, H=%d, W=%d\n", rank, toh_stride_N, toh_stride_C, toh_stride_H, toh_stride_W); printf("rank %d: bih_stride :: N=%d, C=%d, H=%d, W=%d\n", rank, bih_stride_N, bih_stride_C, bih_stride_H, bih_stride_W); printf("rank %d: boh_stride :: N=%d, C=%d, H=%d, W=%d\n", rank, boh_stride_N, boh_stride_C, boh_stride_H, boh_stride_W); printf("rank %d: NC=%d, NH=%d, NW=%d\n", rank, NC, NH, NW); } void* kernel_args[] = {(int2**)&toh_p, &toh_stride_H, &toh_stride_W, &toh_stride_C, (int2**)&tih_p, &tih_stride_H, &tih_stride_W, &tih_stride_C, &tox_p, &tix_p, (int2**)&boh_p, &boh_stride_H, &boh_stride_W, &boh_stride_C, (int2**)&bih_p, &bih_stride_H, &bih_stride_W, &bih_stride_C, &box_p, &bix_p, &NH, &NW, &NC, &top_first}; int num_elem = NH * NW * NC; LAUNCH_PUSH_PULL_HALO_KERNEL(int2, false, kernel_args, num_elem); } else { // cannot do int2 transfers if (diagnostics) { printf("rank %d: launching non-int2 halo exchange kernel\n", rank); } int num_elem = NC * NH * NW; if (is_nhwc) { void* kernel_args[] = {&toh_p, &toh_stride_H, &toh_stride_W, &toh_stride_C, &tih_p, &tih_stride_H, &tih_stride_W, &tih_stride_C, &tox_p, &tix_p, &boh_p, &boh_stride_H, &boh_stride_W, &boh_stride_C, &bih_p, &bih_stride_H, &bih_stride_W, &bih_stride_C, &box_p, &bix_p, &NH, &NW, &NC, &top_first}; LAUNCH_PUSH_PULL_HALO_KERNEL(scalar_t, false, kernel_args, num_elem); } else { void* kernel_args[] = {&toh_p, &toh_stride_C, &toh_stride_H, &toh_stride_W, &tih_p, &tih_stride_C, &tih_stride_H, &tih_stride_W, &tox_p, &tix_p, &boh_p, &boh_stride_C, &boh_stride_H, &boh_stride_W, &bih_p, &bih_stride_C, &bih_stride_H, &bih_stride_W, &box_p, &bix_p, &NC, &NH, &NW, &top_first}; LAUNCH_PUSH_PULL_HALO_KERNEL(scalar_t, false, kernel_args, num_elem); } } }); #undef LAUNCH_PUSH_PULL_HALO_KERNEL_BASE #undef 
LAUNCH_PUSH_PULL_HALO_KERNEL } } // namespace peer_memory } // namespace contrib } // namespace apex ================================================ FILE: apex/contrib/csrc/peer_memory/peer_memory_cuda.cuh ================================================ /** * Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #pragma once #include #ifndef _peer_memory_h_ #define _peer_memory_h_ namespace apex { namespace contrib { namespace peer_memory { int64_t allocate_raw(int64_t size); void free_raw(int64_t raw); void zero(int64_t raw, int64_t size); at::Tensor get_raw_ipc_address(int64_t raw); std::vector get_raw_peers(at::Tensor ipc_addresses, int peer_rank, int64_t raw); at::Tensor blob_view_half(int64_t raw, std::vector shape, bool channels_last); at::Tensor blob_view_float(int64_t raw, std::vector shape, bool channels_last); at::Tensor blob_view_int(int64_t raw, std::vector shape, bool channels_last); void push_pull_halos_1d( bool diagnostics, bool explicit_nhwc, int numSM, // number of SMs to use int peer_rank, // rank in spatial parallel group bool top_zero, // if top halo should be zeroed at::Tensor top_out_halo, // top output halo buffer (in local device memory, received from top neighbor) at::Tensor top_inp_transfer, // top input transfer buffer (in local peer memory) at::Tensor top_out_transfer, // top output transfer buffer (in top neighbor peer memory) at::Tensor top_inp_halo, // top input halo buffer (in local 
device memory, sent to top neighbor) bool btm_zero, // if btm halo should be zeroed at::Tensor btm_out_halo, // btm output halo buffer (in local device memory, received from btm neighbor) at::Tensor btm_inp_transfer, // btm input transfer buffer (in local peer memory) at::Tensor btm_out_transfer, // btm output transfer buffer (in btm neighbor peer memory) at::Tensor btm_inp_halo // btm input halo buffer (in local device memory, sent to btm neighbor) ); } // namespace peer_memory } // namespace contrib } // namespace apex #endif ================================================ FILE: apex/contrib/csrc/transducer/transducer_joint.cpp ================================================ #include #include #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") #define CHECK_INPUT(x) \ CHECK_CUDA(x); \ CHECK_CONTIGUOUS(x) std::vector transducer_joint_cuda_forward(torch::Tensor f, torch::Tensor g, torch::Tensor fLen, torch::Tensor gLen, torch::Tensor batchOffset, int64_t packedBatch, int opt, bool packOutput, bool relu, bool dropout, float dropoutProb, int tileSize); std::vector transducer_joint_cuda_backward(std::vector in, torch::Tensor fLen, torch::Tensor gLen, torch::Tensor batchOffset, int maxFLen, int maxGLen, bool packOutput, float scale); std::vector transducer_joint_forward(torch::Tensor f, torch::Tensor g, torch::Tensor fLen, torch::Tensor gLen, torch::Tensor batchOffset, int64_t packedBatch, int opt, bool packOutput, bool relu, bool dropout, float dropoutProb, int tileSize) { CHECK_INPUT(f); CHECK_INPUT(g); CHECK_INPUT(fLen); CHECK_INPUT(gLen); if (packOutput) CHECK_INPUT(batchOffset); return transducer_joint_cuda_forward(f, g, fLen, gLen, batchOffset, packedBatch, opt, packOutput, relu, dropout, dropoutProb, tileSize); } std::vector transducer_joint_backward(std::vector in, torch::Tensor fLen, torch::Tensor gLen, torch::Tensor batchOffset, int maxFLen, int 
maxGLen, bool packOutput, float scale) { for (auto t : in) { CHECK_INPUT(t); } CHECK_INPUT(fLen); CHECK_INPUT(gLen); if (packOutput) CHECK_INPUT(batchOffset); return transducer_joint_cuda_backward(in, fLen, gLen, batchOffset, maxFLen, maxGLen, packOutput, scale); } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("forward", &transducer_joint_forward, "transducer joint forward (CUDA)", py::call_guard()); m.def("backward", &transducer_joint_backward, "transducer joint backward (CUDA)", py::call_guard()); } ================================================ FILE: apex/contrib/csrc/transducer/transducer_joint_kernel.cu ================================================ #include #include #include #include #include #ifdef OLD_GENERATOR_PATH #include #else #include #endif #include #include #include #include "philox.cuh" // Warp reduce kernels to reduce N groups of data into N numbers, where N = warpSize / width. // width should be a power of 2 and should be less than warpSize. template __device__ __forceinline__ scalar_t warpReduce(scalar_t x, int width = C10_WARP_SIZE) { for (unsigned offset = width / 2; offset > 0; offset /= 2) { x += __shfl_down_sync(0xffffffff, x, offset, width); } return x; } inline int largestPowerOfTwo(int x) { int y = 1; while (y <= x) y <<= 1; return y >> 1; } /* Figure out vectorization type for masks. Similar to how PyTorch figures out acc_t here: aten/src/ATen/AccumulateType.h */ template struct MaskVecType {}; template <> struct MaskVecType<1> { using type = uint8_t; }; template <> struct MaskVecType<2> { using type = uint16_t; }; template <> struct MaskVecType<4> { using type = uint32_t; }; template using mvec_type = typename MaskVecType::type; // Helper class to calculate pointer offset that can be shared by different flavors of kernels. // For fwd, batch offset and stride are different for packing and non-packing mode. 
struct OffsetCalFwd {
  __device__ __forceinline__ OffsetCalFwd(int64_t batch, const int64_t* batchOffset, int64_t maxFLen, int64_t maxGLen,
                                          int64_t gLen, int64_t hiddenSize, bool packOutput)
      : batch(batch),
        batchOffset(batchOffset),
        maxFLen(maxFLen),
        maxGLen(maxGLen),
        gLen(gLen),
        hiddenSize(hiddenSize),
        packOutput(packOutput) {}

  int64_t batch;               // batch index this calculator serves
  const int64_t* batchOffset;  // packed mode: batchOffset[b - 1] is the first packed row of batch b (b = 0 starts at 0)
  int64_t maxFLen;             // t extent of the non-packed [B, maxFLen, maxGLen, H] output
  int64_t maxGLen;             // u extent of the non-packed output
  int64_t gLen;                // this batch's u length; the t-row length in packed mode
  int64_t hiddenSize;          // H
  bool packOutput;             // whether the output is packed as [B_packed, H]

  // Element offset of this batch's first output element.
  __device__ __forceinline__ int64_t getBatchOffset() {
    return packOutput ? ((batch == 0) ? 0 : batchOffset[batch - 1]) * hiddenSize
                      : batch * maxFLen * maxGLen * hiddenSize;
  }

  // Element distance between consecutive t rows of this batch's output.
  __device__ __forceinline__ int64_t getStrideF() { return packOutput ? gLen * hiddenSize : maxGLen * hiddenSize; }
};

// Helper class to calculate pointer offset that can be shared by different flavors of kernels
// For bwd, batch offset and stride are different for packing and non-packing mode.
// The reduction is done for two input tensors. Therefore, generating two sets of offsets
// according to bwdFasterDim can lead to a unified implementation in the actual kernel.
//
// "X" is the dimension of the gradient being produced and "Y" the dimension summed over:
// bwdFasterDim == false -> X = t (gradient of f), Y = u;
// bwdFasterDim == true  -> X = u (gradient of g), Y = t.
struct OffsetCalBwd {
  __device__ __forceinline__ OffsetCalBwd(int64_t batch, const int64_t* batchOffset, const int* fLen, const int* gLen,
                                          int64_t maxFLen, int64_t maxGLen, int64_t hiddenSize, bool packOutput,
                                          bool bwdFasterDim)
      // Listed in declaration order, which is the order members are actually initialized in
      // (avoids -Wreorder).
      : batch(batch),
        batchOffset(batchOffset),
        fLen(fLen),
        gLen(gLen),
        maxFLen(maxFLen),
        maxGLen(maxGLen),
        hiddenSize(hiddenSize),
        packOutput(packOutput),
        bwdFasterDim(bwdFasterDim) {}

  int64_t batch;
  const int64_t* batchOffset;
  const int* fLen;
  const int* gLen;
  int64_t maxFLen;
  int64_t maxGLen;
  int64_t hiddenSize;
  bool packOutput;
  bool bwdFasterDim;  // whether doing bwd on the faster moving dimension

  // Element offset of this batch's first element in the (possibly packed) joint gradient.
  __device__ __forceinline__ int64_t getBatchOffset() {
    return packOutput ? ((batch == 0) ? 0 : batchOffset[batch - 1]) * hiddenSize
                      : batch * maxFLen * maxGLen * hiddenSize;
  }

  // Padded extent of X (rows of the produced input gradient).
  __device__ __forceinline__ int64_t getMaxXLen() { return bwdFasterDim ? maxGLen : maxFLen; }

  // This batch's valid X length.
  __device__ __forceinline__ auto getMyXLen() -> decltype(gLen[batch]) { return bwdFasterDim ? gLen[batch] : fLen[batch]; }

  // This batch's valid Y length (number of terms summed per output element).
  __device__ __forceinline__ auto getMyYLen() -> decltype(gLen[batch]) { return bwdFasterDim ? fLen[batch] : gLen[batch]; }

  // Element distance between consecutive X positions in the joint gradient.
  __device__ __forceinline__ int64_t getStrideX() {
    return bwdFasterDim ? hiddenSize : ((packOutput ? gLen[batch] : maxGLen) * hiddenSize);
  }

  // Element distance between consecutive Y positions in the joint gradient.
  __device__ __forceinline__ int64_t getStrideY() {
    return bwdFasterDim ? ((packOutput ? gLen[batch] : maxGLen) * hiddenSize) : hiddenSize;
  }
};

// Vanilla transducer joint forward kernel
// Detail of this joint function can be found in:
// [1] Sequence Transduction with Recurrent Neural Networks.
// f is a tensor of shape [batch, T, H]
// g is a tensor of shape [batch, U, H]
// the transducer joint does
// sum = f.unsqueeze(dim=2) + g.unsqueeze(dim=1)
// The resultant tensor is of shape [batch, T, U, H]
// Each thread block is working on one "batch" of data in the output tensor, [batch, t, u, :]
// This joint function can optionally pack the output where the output tensor with a shape of
// [B, T, U, H] is packed into [B_packed, H].
// Don't-care region (t >= fLen) or (u >= gLen) is removed.
// To enable packing, the starting offset for each batch need to be specified with batchOffset.
// Vanilla joint forward kernel: one thread block per output row (batch, t, u).
// Grid mapping (from the indexing below): blockIdx.z -> batch, blockIdx.y -> t, blockIdx.x -> u.
// The block's threads walk the hidden dimension with stride blockDim.x.
// OffsetCal supplies the output batch offset and t-stride, which differ between packed and
// non-packed output (presumably OffsetCalFwd — confirm at the launch site).
template __global__ void transducer_joint_forward(const scalar_t* f, const scalar_t* g, const int* fLen,
                                                  const int* gLen, const int64_t* batchOffset, int64_t maxFLen,
                                                  int64_t maxGLen, int64_t hiddenSize, bool packOutput,
                                                  scalar_t* sum) {
  const int batch = blockIdx.z;
  const int t = blockIdx.y;
  const int u = blockIdx.x;
  const auto myFLen = fLen[batch];
  const auto myGLen = gLen[batch];

  OffsetCal offsetCal(batch, batchOffset, maxFLen, maxGLen, myGLen, hiddenSize, packOutput);
  const auto myBatchOffset = offsetCal.getBatchOffset();
  const auto strideF = offsetCal.getStrideF();

  // f and g are indexed as dense [batch, maxFLen, H] and [batch, maxGLen, H] buffers.
  scalar_t const* myF = f + batch * maxFLen * hiddenSize + t * hiddenSize;
  scalar_t const* myG = g + batch * maxGLen * hiddenSize + u * hiddenSize;
  scalar_t* mySum = sum + myBatchOffset + t * strideF + u * hiddenSize;

  if (t < myFLen and u < myGLen) {
    // Valid (t, u): sum[h] = f[t, h] + g[u, h].
#pragma unroll
    for (int h = threadIdx.x; h < hiddenSize; h += blockDim.x) {
      if (h < hiddenSize) {  // always true given the loop condition
        mySum[h] = myF[h] + myG[h];
      }
    }
  } else if (packOutput == false and t < maxFLen and u < maxGLen) {
    // Need to write finite data to don't-care region because we instantiate the result tensor
    // with torch::empty for performance reasons. Even though it is don't-care region, the
    // contents need to be finite, otherwise could lead to NaN in WGRAD.
    // In packing mode, this write is no longer necessary as we remove the don't-care region
    // from the output.
    // Picking -1 (over 0) here for ease of testing.
#pragma unroll
    for (int h = threadIdx.x; h < hiddenSize; h += blockDim.x) {
      if (h < hiddenSize) {
        mySum[h] = -1;
      }
    }
  }
}

/*
Tiled version of the joint forward kernel
Detail of this joint function can be found in:
[1] Sequence Transduction with Recurrent Neural Networks.

f is a tensor of shape [batch, T, H]
g is a tensor of shape [batch, U, H]
the transducer joint does
sum = f.unsqueeze(dim=2) + g.unsqueeze(dim=1)
The resultant tensor is of shape [batch, T, U, H]
Each thread is working on a tile of the shape of tileF x tileG in the result tensor.
The input for the tile is first loaded in the register and is reused tileG and tileF times.

This joint function can optionally pack the output where the output tensor with a shape of
[B, T, U, H] is packed into [B_packed, H].
Don't-care region (t >= fLen) or (u >= gLen) is removed.
To enable packing, the starting offset for each batch need to be specified with batchOffset.

Optionally this joint function performs ReLU and/or dropout on the joint output, which is
controlled by arguments relu and dropout, respectively. philoxArgs is argument used for generating
pseudorandom number. When at least one of operations in ReLU and dropout is activated, the joint
function is a masked operation, which is controlled by the template argument masked. In this case,
masks are saved to backward.

Grid mapping (from the indexing below): blockIdx.z -> batch; blockIdx.y -> first t of the tile
(blockIdx.y * tileF); blockIdx.x encodes both the u tile (blockIdx.x / hiddenBlock) and the chunk of
hiddenPerBlock hidden units (blockIdx.x % hiddenBlock). threadIdx.x is the hidden index h within
that chunk. Dropout keeps an element when its uniform random number is < p and scales kept
values by 1 / p.
*/
template __global__ void transducer_joint_tiled_forward(const scalar_t* f, const scalar_t* g, const int* fLen,
                                                        const int* gLen, const int64_t* batchOffset, int64_t maxFLen,
                                                        int64_t maxGLen, int64_t hiddenSize, int64_t hiddenPerBlock,
                                                        bool packOutput, bool relu, bool dropout, float p,
                                                        at::PhiloxCudaState philoxArgs, scalar_t* sum, uint8_t* mask) {
  static_assert(U == 4, "U has to be 4, as random numbers are generated in batch of 4");

  const int batch = blockIdx.z;
  const int t = blockIdx.y * tileF;
  const int hiddenBlock = (hiddenSize + hiddenPerBlock - 1) / hiddenPerBlock;
  const int u = blockIdx.x / hiddenBlock * tileG;
  const int hOffset = (blockIdx.x % hiddenBlock) * hiddenPerBlock;
  const int h = threadIdx.x;
  const auto myFLen = fLen[batch];
  const auto myGLen = gLen[batch];

  OffsetCal offsetCal(batch, batchOffset, maxFLen, maxGLen, myGLen, hiddenSize, packOutput);
  const auto myBatchOffset = offsetCal.getBatchOffset();
  const auto strideF = offsetCal.getStrideF();

  scalar_t const* myF = f + batch * maxFLen * hiddenSize + t * hiddenSize + hOffset;
  scalar_t const* myG = g + batch * maxGLen * hiddenSize + u * hiddenSize + hOffset;
  scalar_t* mySum = sum + myBatchOffset + t * strideF + u * hiddenSize + hOffset;
  // The mask uses the same offsets as sum; it is only dereferenced when masked is true.
  uint8_t* myMask = mask + myBatchOffset + t * strideF + u * hiddenSize + hOffset;

  // The following code is only needed for dropout. We try to bypass them as much as possible.
  auto seeds = masked ? at::cuda::philox::unpack(philoxArgs) : std::make_tuple(static_cast(0), static_cast(0));
  // Linear id of this thread across the whole grid, distinct per thread; passed to the Philox generator.
  uint64_t tid = masked ? (static_cast(blockIdx.z) * gridDim.y * gridDim.x + blockIdx.y * gridDim.x + blockIdx.x) * blockDim.x + threadIdx.x : 0;
  Philox ph(std::get<0>(seeds), tid, std::get<1>(seeds));
  scalar_t scale = masked ? ((p == 0) ? 0 : 1 / p) : 0;
  // Keep decisions for the current group of U tile elements; refilled every U elements when dropout is on.
  bool dropoutMask[U];

  if (t < myFLen and u < myGLen and hOffset + h < hiddenSize) {
    // register buffers for tiled input reuse
    scalar_t fBuffer[tileF], gBuffer[tileG];
    for (int i = 0; i < tileF; ++i) {
      if (t + i < myFLen) fBuffer[i] = myF[i * hiddenSize + h];
    }
    for (int j = 0; j < tileG; ++j) {
      if (u + j < myGLen) gBuffer[j] = myG[j * hiddenSize + h];
    }
#pragma unroll
    for (int i = 0; i < tileF; ++i) {
      if (t + i < myFLen) {
#pragma unroll
        for (int j = 0; j < tileG; ++j) {
          int idx = i * tileG + j;
          if (masked and dropout and idx % U == 0) {
            // For performance, generate 4 random numbers in one shot
            // auto rand4 = curand_uniform4(&state);
            auto rand4 = uniform4(ph());
            dropoutMask[0] = rand4.x < p;
            dropoutMask[1] = rand4.y < p;
            dropoutMask[2] = rand4.z < p;
            dropoutMask[3] = rand4.w < p;
          }
          if (u + j < myGLen) {
            scalar_t out = fBuffer[i] + gBuffer[j];
            if (masked) {
              // Apply ReLU here when relu is True
              bool localMask = relu ? (out > 0) : 1;
              localMask = dropout ? localMask & dropoutMask[idx % U] : localMask;
              out = dropout ? out * localMask * scale : out * localMask;
              myMask[i * strideF + j * hiddenSize + h] = static_cast(localMask);
            }
            mySum[i * strideF + j * hiddenSize + h] = out;
          } else if (packOutput == false and u + j < maxGLen)
            mySum[i * strideF + j * hiddenSize + h] = -1;
        }
      } else if (packOutput == false and t + i < maxFLen) {
        // Again need to write finite data to don't-care region
#pragma unroll
        for (int j = 0; j < tileG; ++j) {
          if (u + j < maxGLen) mySum[i * strideF + j * hiddenSize + h] = -1;
        }
      }
    }
  } else if (packOutput == false and t < maxFLen and u < maxGLen and hOffset + h < hiddenSize) {
    // Only need to ensure the finity in normal mode
#pragma unroll
    for (int i = 0; i < tileF; ++i) {
      if (t + i < maxFLen) {
#pragma unroll
        for (int j = 0; j < tileG; ++j) {
          if (u + j < maxGLen) mySum[i * strideF + j * hiddenSize + h] = -1;
        }
      }
    }
  }
}

/*
Bwd operation (reduction) on one input tensor. Since the operation performed for the two input
tensors are exactly the same, only one kernel is needed, and the different indexing offsets
and strides are handled by OffsetCalBwd.

When packing is enabled in the fwd op, unpacking is needed to restore the gradients in a
non-packed form.

When ReLU and/or dropout are performed in the fwd pass, this operation becomes a masked operation,
and mask contains the mask information.

Block layout (from the indexing below): blockIdx.z -> batch; blockIdx.y - yBlockOffset -> x, the
row of the input gradient being produced; blockIdx.x -> a chunk of C10_WARP_SIZE hidden units.
threadIdx.x is the lane and threadIdx.y the warp index. Dynamic shared memory must hold
numWarp * C10_WARP_SIZE acc_t values.
*/
template __device__ void transducer_joint_single_backward(const scalar_t* grad, const uint8_t* mask, const int* fLen,
                                                          const int* gLen, const int64_t* batchOffset, int64_t maxFLen,
                                                          int64_t maxGLen, int64_t hiddenSize, bool packOutput,
                                                          bool bwdFasterDim,  // whether bwd on the faster moving dimension (u)
                                                          float scale, scalar_t* inGrad, int yBlockOffset = 0) {
  const int batch = blockIdx.z;
  // For the second input tensor, this offset need to be subtracted because the first yBlockOffset
  // sets of thread blocks are for the first input tensor.
  const int x = blockIdx.y - yBlockOffset;
  const int hOffset = blockIdx.x * C10_WARP_SIZE;
  const int wid = threadIdx.y;
  const int lid = threadIdx.x;
  const int numWarp = blockDim.y;
  extern __shared__ char smem8[];
  auto smem = reinterpret_cast(smem8);

  OffsetCal offsetCal(batch, batchOffset, fLen, gLen, maxFLen, maxGLen, hiddenSize, packOutput, bwdFasterDim);
  const auto maxXLen = offsetCal.getMaxXLen();
  const auto myXLen = offsetCal.getMyXLen();
  const auto myYLen = offsetCal.getMyYLen();
  // The input gradient is written as a dense [batch, maxXLen, H] buffer.
  scalar_t* myInGrad = inGrad + batch * maxXLen * hiddenSize + x * hiddenSize + hOffset;

  if (x < myXLen) {
    const auto myBatchOffset = offsetCal.getBatchOffset();
    const auto strideX = offsetCal.getStrideX();
    const auto strideY = offsetCal.getStrideY();
    const scalar_t* myGrad = grad + myBatchOffset + x * strideX + hOffset;
    const uint8_t* myMask = masked ? mask + myBatchOffset + x * strideX + hOffset : nullptr;

    // Each warp reduces numYPerWarp "y" first
    acc_t warpSum = 0;
    auto numYPerWarp = (myYLen + numWarp - 1) / numWarp;
#pragma unroll
    for (int warpY = 0; warpY < numYPerWarp; ++warpY) {
      auto y = wid * numYPerWarp + warpY;
      if (y < myYLen and (hOffset + lid) < hiddenSize)
        // The `else` below pairs with `if (masked)`, not with the bounds check above.
        if (masked)
          warpSum += static_cast(myGrad[y * strideY + lid]) * myMask[y * strideY + lid] * scale;
        else
          warpSum += myGrad[y * strideY + lid];
    }

    // transpose partial sum in SMEM and reduce further using warpReduce
    // smem is written as a [C10_WARP_SIZE][numWarp] matrix: entry (lid, wid) is warp wid's partial
    // sum for hidden unit hOffset + lid. Each warp then reads C10_WARP_SIZE consecutive entries,
    // i.e. numWarp partials for each of C10_WARP_SIZE / numWarp hidden units, and warpReduce with
    // width numWarp folds each group into its first lane.
    smem[lid * numWarp + wid] = warpSum;
    __syncthreads();
    auto sum = smem[wid * C10_WARP_SIZE + lid];
    sum = warpReduce(sum, numWarp);

    // a a b b c c d d
    // a a b b c c d d
    // a a b b c c d d
    // a a b b c c d d
    // example of 4 warps (a, b, c, d) with 8 threads per warp
    // Each warp need 8 / 4 = 2 threads to write the results.
    if (hOffset + wid * C10_WARP_SIZE / numWarp + lid / numWarp < hiddenSize) {
      if (lid % numWarp == 0) {
        myInGrad[wid * C10_WARP_SIZE / numWarp + lid / numWarp] = sum;
      }
    }
  } else if (wid == 0 and hOffset + lid < hiddenSize) {
    // Need to ensure the grad is zero for don't care region
    myInGrad[lid] = 0;
  }
}

/*
Actual bwd (reduction) kernel get launched.
Call transducer_joint_single_backward twice on two input tensors.
The two bwd ops are launched together, the first op uses blockIdx.y < maxFLen, and the second op
uses the rest.
When ReLU and/or dropout are performed in the fwd pass, this operation becomes a masked operation,
and mask contains the mask information.
*/
template __global__ void transducer_joint_combined_backward(const scalar_t* grad, const uint8_t* mask, const int* fLen,
                                                            const int* gLen, const int64_t* batchOffset, int64_t maxFLen,
                                                            int64_t maxGLen, int64_t hiddenSize, bool packOutput,
                                                            float scale, scalar_t* fGrad, scalar_t* gGrad) {
  if (blockIdx.y < maxFLen) {
    // fGrad: rows are t, summed over u (bwdFasterDim = false).
    transducer_joint_single_backward(
        grad, mask, fLen, gLen, batchOffset, maxFLen, maxGLen, hiddenSize, packOutput, false, scale, fGrad);
  } else {
    // gGrad: rows are u, summed over t (bwdFasterDim = true); yBlockOffset = maxFLen re-bases blockIdx.y.
    transducer_joint_single_backward(
        grad, mask, fLen, gLen, batchOffset, maxFLen, maxGLen, hiddenSize, packOutput, true, scale, gGrad, maxFLen);
  }
}

/*
Vectorized version of transducer_joint_single_backward
Doing exact same operation as transducer_joint_single_backward except the load and store are
vectorized.
When packing is enabled in the fwd op, unpacking is needed to restore the gradients in a
non-packed form.
When ReLU and/or dropout are performed in the fwd pass, this operation becomes a masked operation,
and mask contains the mask information.
*/
template __device__ void transducer_joint_single_vec_backward(const scalar_t* grad, const uint8_t* mask,
                                                              const int* fLen, const int* gLen,
                                                              const int64_t* batchOffset, int64_t maxFLen,
                                                              int64_t maxGLen, int64_t hiddenSize, bool packOutput,
                                                              bool bwdFasterDim, float scale, scalar_t* inGrad,
                                                              int yBlockOffset = 0) {
  // Same block layout as transducer_joint_single_backward, except each block covers
  // C10_WARP_SIZE * V hidden units and each lane handles V consecutive units.
  const int batch = blockIdx.z;
  const int x = blockIdx.y - yBlockOffset;
  const int hOffset = blockIdx.x * C10_WARP_SIZE * V;
  const int wid = threadIdx.y;
  const int lid = threadIdx.x;
  const int numWarp = blockDim.y;

  // Figure out the vectorization type for mask
  using mvec_t = mvec_type;

  OffsetCal offsetCal(batch, batchOffset, fLen, gLen, maxFLen, maxGLen, hiddenSize, packOutput, bwdFasterDim);
  const auto maxXLen = offsetCal.getMaxXLen();
  const auto myXLen = offsetCal.getMyXLen();
  const auto myYLen = offsetCal.getMyYLen();
  scalar_t* myInGrad = inGrad + batch * maxXLen * hiddenSize + x * hiddenSize + hOffset;
  extern __shared__ char smem8[];
  auto smem = reinterpret_cast(smem8);

  // Per-lane registers: V partial sums, plus staging for one vector of grad, mask and output.
  acc_t warpSum[V];
  scalar_t inBuffer[V];
  uint8_t maskBuffer[V];
  scalar_t outBuffer[V];

  auto myInGradVec = reinterpret_cast(myInGrad);
  auto outBufferVec = reinterpret_cast(outBuffer);

  if (x < myXLen) {
    const auto myBatchOffset = offsetCal.getBatchOffset();
    const auto strideX = offsetCal.getStrideX();
    const auto strideY = offsetCal.getStrideY();
    const scalar_t* myGrad = grad + myBatchOffset + x * strideX + hOffset;
    const uint8_t* myMask = masked ? mask + myBatchOffset + x * strideX + hOffset : nullptr;

    for (int i = 0; i < V; ++i) warpSum[i] = 0;

    // Each warp reduces numYPerWarp "y" first
    auto numYPerWarp = (myYLen + numWarp - 1) / numWarp;
    for (int warpY = 0; warpY < numYPerWarp; ++warpY) {
      auto y = wid * numYPerWarp + warpY;
      auto myGradVec = reinterpret_cast(myGrad + y * strideY);
      auto myMaskVec = masked ? reinterpret_cast(myMask + y * strideY) : nullptr;
      auto inBufferVec = reinterpret_cast(inBuffer);
      auto maskBufferVec = reinterpret_cast(maskBuffer);

      if (hOffset + lid * V < hiddenSize and y < myYLen) {
        *inBufferVec = myGradVec[lid];  // vectorized load
        if (masked) {
          *maskBufferVec = myMaskVec[lid];
#pragma unroll
          for (int i = 0; i < V; ++i) warpSum[i] += static_cast(inBuffer[i]) * maskBuffer[i] * scale;
        } else {
#pragma unroll
          for (int i = 0; i < V; ++i) warpSum[i] += inBuffer[i];
        }
      }
    }

    // transpose partial sum in SMEM and reduce further using warpReduce
    // This is done once per vector component i. The trailing barrier keeps the next iteration's
    // write from overwriting smem while other warps are still reading it. The branch condition
    // depends only on wid, so every lane of a warp takes the same path around the shuffle.
    for (int i = 0; i < V; ++i) {
      smem[lid * numWarp + wid] = warpSum[i];
      __syncthreads();
      auto sum = smem[wid * C10_WARP_SIZE + lid];
      if (hOffset + (wid * C10_WARP_SIZE / numWarp) * V < hiddenSize) {
        sum = warpReduce(sum, numWarp);
        if (lid % numWarp == 0) {
          outBuffer[i] = sum;
        }
      }
      __syncthreads();
    }

    // a a b b c c d d
    // a a b b c c d d
    // a a b b c c d d
    // a a b b c c d d
    // example of 4 warps (a, b, c, d) with 8 threads per warp
    // Each warp need 8 / 4 = 2 threads to write the results.
    if (lid % numWarp == 0 and hOffset + (wid * C10_WARP_SIZE / numWarp + lid / numWarp) * V < hiddenSize)
      myInGradVec[wid * C10_WARP_SIZE / numWarp + lid / numWarp] = *outBufferVec;
  } else if (wid == 0 and hOffset + lid * V < hiddenSize) {
    // Need to ensure the grad is zero for don't care region
    // (one vector store per lane clears V consecutive hidden units)
    myInGradVec[lid] = 0;
  }
}

/*
Vectorized version of transducer_joint_combined_backward
Call transducer_joint_single_vec_backward twice on two input tensors.
The two bwd ops are launched together, the first op uses blockIdx.y < maxFLen, and the second op
uses the rest.
When ReLU and/or dropout are performed in the fwd pass, this operation becomes a masked operation,
and mask contains the mask information.
*/ template __global__ void transducer_joint_combined_vec_backward(const scalar_t* grad, const uint8_t* mask, const int* fLen, const int* gLen, const int64_t* batchOffset, int64_t maxFLen, int64_t maxGLen, int64_t hiddenSize, bool packOutput, float scale, scalar_t* fGrad, scalar_t* gGrad) { if (blockIdx.y < maxFLen) { transducer_joint_single_vec_backward( grad, mask, fLen, gLen, batchOffset, maxFLen, maxGLen, hiddenSize, packOutput, false, scale, fGrad); } else { transducer_joint_single_vec_backward( grad, mask, fLen, gLen, batchOffset, maxFLen, maxGLen, hiddenSize, packOutput, true, scale, gGrad, maxFLen); } } std::vector transducer_joint_cuda_forward(torch::Tensor f, torch::Tensor g, torch::Tensor fLen, torch::Tensor gLen, torch::Tensor batchOffset, int64_t packedBatch, int opt, bool packOutput, bool relu, bool dropout, float dropoutProb, int tileSize) { auto tensorOpt = f.options(); auto dtype = f.scalar_type(); const auto batchSize = f.size(0); const auto maxFLen = f.size(1); const auto maxGLen = g.size(1); const auto hiddenSize = f.size(2); bool masked = dropout or relu; int64_t* batchOffsetPtr = nullptr; torch::Tensor sum, mask; auto maskOpt = tensorOpt.dtype(torch::kUInt8); if (!packOutput) { sum = torch::empty({batchSize, maxFLen, maxGLen, hiddenSize}, tensorOpt); batchOffsetPtr = nullptr; if (masked) mask = torch::empty({batchSize, maxFLen, maxGLen, hiddenSize}, maskOpt); } else { sum = torch::empty({packedBatch, hiddenSize}, tensorOpt); batchOffsetPtr = batchOffset.data_ptr(); if (masked) mask = torch::empty({packedBatch, hiddenSize}, maskOpt); } uint8_t* maskPtr = masked ? 
mask.data_ptr() : nullptr; cudaStream_t stream = at::cuda::getCurrentCUDAStream(); TORCH_CHECK(opt == 0 or opt == 1, "Got an invalid optimization level ", opt); // Simple heuristics const int numThread = std::min(128, (static_cast(hiddenSize) + C10_WARP_SIZE - 1) / C10_WARP_SIZE * C10_WARP_SIZE); if (opt == 0) { // vanilla kernel const int threads = numThread; const dim3 blocks(maxGLen, maxFLen, batchSize); AT_DISPATCH_FLOATING_TYPES_AND_HALF( dtype, "transducer_joint_forward", ([&] { transducer_joint_forward<<>>( f.data_ptr(), g.data_ptr(), fLen.data_ptr(), gLen.data_ptr(), batchOffsetPtr, maxFLen, maxGLen, hiddenSize, packOutput, sum.data_ptr()); })); } if (opt == 1) { // tiled version. For simplicity, assume tileF == tileG, even though the kernel can // support more general cases. const int threads = numThread; const int hiddenPerBlock = numThread; const int hiddenBlock = (hiddenSize + hiddenPerBlock - 1) / hiddenPerBlock; const dim3 blocks((maxGLen + tileSize - 1) / tileSize * hiddenBlock, (maxFLen + tileSize - 1) / tileSize, batchSize); TORCH_CHECK(tileSize == 1 or tileSize == 2 or tileSize == 4, "Expected tileSize to be in [1, 2, 4], but got ", tileSize); at::PhiloxCudaState rng_engine_inputs; if (masked) { // set up PRG when the input is masked. rng_engine_inputs will be used as a space filler // for non-masked calls. // Therefore no need to initialize. c10::optional gen_; auto gen = at::get_generator_or_default(gen_, at::cuda::detail::getDefaultCUDAGenerator()); // counterOffset records how many cuRAND calls each thread makes. For a tiled kernel, // each thread processes tileF * tileG output elements. 
int64_t counterOffset = tileSize * tileSize; { std::lock_guard lock(gen->mutex_); rng_engine_inputs = gen->philox_cuda_state(counterOffset); } } AT_DISPATCH_FLOATING_TYPES_AND_HALF( dtype, "transducer_joint_forward", ([&] { void (*kernel)(const scalar_t*, const scalar_t*, const int*, const int*, const int64_t*, int64_t, int64_t, int64_t, int64_t, bool, bool, bool, float, at::PhiloxCudaState, scalar_t*, uint8_t*); if (masked) { switch (tileSize) { case 2: kernel = &transducer_joint_tiled_forward; break; case 4: kernel = &transducer_joint_tiled_forward; break; } } else { switch (tileSize) { case 1: kernel = &transducer_joint_tiled_forward; break; case 2: kernel = &transducer_joint_tiled_forward; break; case 4: kernel = &transducer_joint_tiled_forward; break; } } kernel<<>>(f.data_ptr(), g.data_ptr(), fLen.data_ptr(), gLen.data_ptr(), batchOffsetPtr, maxFLen, maxGLen, hiddenSize, hiddenPerBlock, packOutput, relu, dropout, 1.0f - dropoutProb, rng_engine_inputs, sum.data_ptr(), maskPtr); })); } C10_CUDA_CHECK(cudaGetLastError()); if (masked) return {sum, mask}; else return {sum}; } std::vector transducer_joint_cuda_backward(std::vector in, torch::Tensor fLen, torch::Tensor gLen, torch::Tensor batchOffset, int maxFLen, int maxGLen, bool packOutput, float scale) { auto grad = in[0]; bool masked = (in.size() == 2); uint8_t* maskPtr = masked ? in[1].data_ptr() : nullptr; auto tensorOpt = grad.options(); auto dtype = grad.scalar_type(); const int batchSize = fLen.size(0); const int hiddenSize = grad.size(-1); const auto deviceProperties = at::cuda::getCurrentDeviceProperties(); const int maxNumWarp = deviceProperties->maxThreadsPerBlock / C10_WARP_SIZE; torch::Tensor fGrad = torch::empty({batchSize, maxFLen, hiddenSize}, tensorOpt); torch::Tensor gGrad = torch::empty({batchSize, maxGLen, hiddenSize}, tensorOpt); int64_t* batchOffsetPtr = (!packOutput) ? 
nullptr : batchOffset.data_ptr(); // The number "y" I would like each thread to work on const int workPerThread = 32; // Since the bwd for f and g have the same thread block size, we need to use the max of the two. int numWarp = largestPowerOfTwo((std::max(maxFLen, maxGLen) + workPerThread - 1) / workPerThread); // Would like to have at least 2 warps numWarp = std::max(2, numWarp); // cap on the maximum number of warps allowed numWarp = std::min(maxNumWarp, numWarp); // Need smem for transposing the partial sum. The partial sum is in a matrix of the shape // numWarp x warpSize const int smemSize = numWarp * C10_WARP_SIZE; const dim3 threads(C10_WARP_SIZE, numWarp, 1); AT_DISPATCH_FLOATING_TYPES_AND_HALF( dtype, "transducer_joint_cuda_backward_kernel", ([&] { auto gradPtr = grad.data_ptr(); auto fLenPtr = fLen.data_ptr(); auto gLenPtr = gLen.data_ptr(); auto fGradPtr = fGrad.data_ptr(); auto gGradPtr = gGrad.data_ptr(); // resolve the acc_t type using acc_t = at::acc_type; using vec_t = uint64_t; constexpr int vectFactor = sizeof(vec_t) / sizeof(scalar_t); constexpr int vecAlignment = std::alignment_of::value; // if all input and output tensors meet the alignment requirement bool memAlign = (reinterpret_cast(gradPtr) % vecAlignment == 0) and (reinterpret_cast(fGradPtr) % vecAlignment == 0) and (reinterpret_cast(gGradPtr) % vecAlignment == 0); if (vectFactor > 1 and hiddenSize % vectFactor == 0 and memAlign) { // If vectorization helps and the alignment requirement is met, use the vectorized // kernel. For simplicity, hiddenSize needs to be a multiple vecFactor. 
const dim3 blocks((hiddenSize + C10_WARP_SIZE * vectFactor - 1) / (C10_WARP_SIZE * vectFactor), maxFLen + maxGLen, batchSize); if (masked) { transducer_joint_combined_vec_backward <<>>(gradPtr, maskPtr, fLenPtr, gLenPtr, batchOffsetPtr, maxFLen, maxGLen, hiddenSize, packOutput, scale, fGradPtr, gGradPtr); } else { transducer_joint_combined_vec_backward <<>>(gradPtr, maskPtr, fLenPtr, gLenPtr, batchOffsetPtr, maxFLen, maxGLen, hiddenSize, packOutput, scale, fGradPtr, gGradPtr); } } else { const dim3 blocks((hiddenSize + C10_WARP_SIZE - 1) / C10_WARP_SIZE, maxFLen + maxGLen, batchSize); if (masked) { transducer_joint_combined_backward <<>>(gradPtr, maskPtr, fLenPtr, gLenPtr, batchOffsetPtr, maxFLen, maxGLen, hiddenSize, packOutput, scale, fGradPtr, gGradPtr); } else { transducer_joint_combined_backward <<>>(gradPtr, maskPtr, fLenPtr, gLenPtr, batchOffsetPtr, maxFLen, maxGLen, hiddenSize, packOutput, scale, fGradPtr, gGradPtr); } } })); return {fGrad, gGrad}; } ================================================ FILE: apex/contrib/csrc/transducer/transducer_loss.cpp ================================================ #include #include #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") #define CHECK_INPUT(x) \ CHECK_CUDA(x); \ CHECK_CONTIGUOUS(x) std::vector transducer_loss_cuda_forward(torch::Tensor x, torch::Tensor label, torch::Tensor audLen, torch::Tensor txtLen, torch::Tensor batchOffset, int maxFLen, int blankIdx, int opt, bool packedInput); torch::Tensor transducer_loss_cuda_backward(torch::Tensor x, torch::Tensor lossGrad, torch::Tensor alpha, torch::Tensor beta, torch::Tensor audLen, torch::Tensor txtLen, torch::Tensor label, torch::Tensor batchOffset, int maxFLen, int blankIdx, int opt, bool fuseSoftmaxBackward, bool packedInput); std::vector transducer_loss_forward(torch::Tensor x, torch::Tensor label, torch::Tensor fLen, torch::Tensor yLen, torch::Tensor 
batchOffset, int maxFLen, int blankIdx, int opt, bool packedInput) { CHECK_INPUT(x); CHECK_INPUT(label); CHECK_INPUT(fLen); CHECK_INPUT(yLen); if (packedInput) CHECK_INPUT(batchOffset); return transducer_loss_cuda_forward(x, label, fLen, yLen, batchOffset, maxFLen, blankIdx, opt, packedInput); } torch::Tensor transducer_loss_backward(torch::Tensor x, torch::Tensor lossGrad, torch::Tensor alpha, torch::Tensor beta, torch::Tensor fLen, torch::Tensor yLen, torch::Tensor label, torch::Tensor batchOffset, int maxFLen, int blankIdx, int opt, bool fuseSoftmaxBackward, bool packedInput) { CHECK_INPUT(x); CHECK_INPUT(label); CHECK_INPUT(lossGrad); CHECK_INPUT(alpha); CHECK_INPUT(beta); CHECK_INPUT(fLen); CHECK_INPUT(yLen); if (packedInput) CHECK_INPUT(batchOffset); return transducer_loss_cuda_backward(x, lossGrad, alpha, beta, fLen, yLen, label, batchOffset, maxFLen, blankIdx, opt, fuseSoftmaxBackward, packedInput); } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("forward", &transducer_loss_forward, "transducer loss forward (CUDA)", py::call_guard()); m.def("backward", &transducer_loss_backward, "transducer loss backward (CUDA)", py::call_guard()); } ================================================ FILE: apex/contrib/csrc/transducer/transducer_loss_kernel.cu ================================================ #include #include #include #include #include #include #include template __device__ __forceinline__ scalar_t logSumExp(scalar_t a, scalar_t b) { // standard log-sum-exp trick is used here to provide better numerical stability return (a >= b) ? a + std::log1p(exp(b - a)) : b + std::log1p(exp(a - b)); } // Vanilla transducer loss function (i.e. forward-backward algorithm) // Detail of this loss function can be found in: // [1] Sequence Transduction with Recurrent Neural Networks. // Forward (alpha) and backward (beta) path are launched together. 
Input is assumed to be converted // into log scale by the preceding log_softmax layer // Diagonal wavefront advancing usually used in dynamic programming is leveraged here. // alpha and beta are of acc_t type, as they are essentially accumulators. // This loss function supports packed input where a tensor of shape [B, T, U, H] is packed into // [B_packed, H]. // Don't-care region (t > audLen) or (u > txtLen) is removed. // To support the packed input, the starting offsets for each batch need to be specified with // batchOffset. template __global__ void transducer_loss_forward(const scalar_t* x, const int* label, const int* audLen, const int* txtLen, const int64_t* batchOffset, int64_t dictSize, // 64-bit indexing for data tensor int64_t blankIdx, int64_t maxFLen, int64_t maxGLen, bool packedInput, acc_t* alpha, acc_t* beta, scalar_t* loss) { const int batch = blockIdx.y; const int tid = threadIdx.x; const auto myFLen = audLen[batch]; // Note that start of the sentence is added as 1 here const auto myGLen = txtLen[batch] + 1; const auto myLabel = label + batch * (maxGLen - 1); const int64_t myBatchOffset = packedInput ? (batch == 0 ? 0 : batchOffset[batch - 1]) : batch * maxFLen * maxGLen; const int64_t myStrideT = packedInput ? 
myGLen : maxGLen; const scalar_t* myX = x + myBatchOffset * dictSize; int u = tid; if (blockIdx.x == 0) { // alpha path acc_t* myAlpha = alpha + batch * maxFLen * maxGLen; if (u == 0) myAlpha[0] = 0; __syncthreads(); for (int64_t step = 1; step < myFLen + myGLen - 1; ++step) { // Move along the diagonal wavefront to leverage available parallelism for (u = tid; u < myGLen; u += blockDim.x) { int64_t t = step - u; if (t >= 0 and t < myFLen and u >= 0 and u < myGLen) { // Eq(16) in [1] if (u == 0) { // alpha(t, u) = alpha(t-1, u) * null(t-1, u) myAlpha[t * maxGLen + u] = myAlpha[(t - 1) * maxGLen] + myX[((t - 1) * myStrideT) * dictSize + blankIdx]; } else if (t == 0) { // alpha(t, u-1) = alpha(t, u-1) * y(t, u-1) myAlpha[u] = myAlpha[u - 1] + myX[(u - 1) * dictSize + myLabel[u - 1]]; } else { // alpha(t, u) = alpha(t-1, u) * null(t-1, u) + alpha(t, u-1) * y(t, u-1) acc_t current = myAlpha[(t - 1) * maxGLen + u] + myX[((t - 1) * myStrideT + u) * dictSize + blankIdx]; acc_t next = myAlpha[t * maxGLen + u - 1] + myX[(t * myStrideT + u - 1) * dictSize + myLabel[u - 1]]; myAlpha[t * maxGLen + u] = logSumExp(next, current); } } } __syncthreads(); } } else if (blockIdx.x == 1) { // beta path acc_t* myBeta = beta + batch * maxFLen * maxGLen; if (u == 0) { myBeta[(myFLen - 1) * maxGLen + myGLen - 1] = myX[((myFLen - 1) * myStrideT + myGLen - 1) * dictSize + blankIdx]; } __syncthreads(); for (int64_t step = myFLen + myGLen - 3; step >= 0; --step) { for (u = tid; u < myGLen; u += blockDim.x) { int64_t t = step - u; if (t >= 0 and t < myFLen and u >= 0 and u < myGLen) { // Eq(18) in [1] if (u == myGLen - 1) { // beta(t, u) = beta(t+1, u) * null(t, u) myBeta[t * maxGLen + u] = myBeta[(t + 1) * maxGLen + u] + myX[(t * myStrideT + u) * dictSize + blankIdx]; } else if (t == myFLen - 1) { // beta(t, u) = beta(t, u+1) * y(t, u) myBeta[t * maxGLen + u] = myBeta[t * maxGLen + u + 1] + myX[(t * myStrideT + u) * dictSize + myLabel[u]]; } else { // beta(t, u) = beta(t+1, u)*null(t, u) + 
beta(t, u+1)*y(t, u) acc_t current = myBeta[(t + 1) * maxGLen + u] + myX[(t * myStrideT + u) * dictSize + blankIdx]; acc_t next = myBeta[t * maxGLen + u + 1] + myX[(t * myStrideT + u) * dictSize + myLabel[u]]; myBeta[t * maxGLen + u] = logSumExp(next, current); } } } __syncthreads(); } if (tid == 0) loss[batch] = -myBeta[0]; } } // transudcer loss function (i.e. forward-backward algorithm) with batch loading optimization. // Compared to the vanilla version, there are two optimizations: // 1. load x in batch through loop unrolling to reduce the latency. // 2. Use registers and shared memory to hold alpha and beta values passed from one step the next. // For simplicity, this kernel currently only supports U <= maxThread, which should be the common // case. For cases where U > maxThread, the vanilla kernel is used as a fallback option. // Detail of this loss function can be found in: // [1] Sequence Transduction with Recurrent Neural Networks. // Forward (alpha) and backward (beta) path are launched together. Input is assumed to be converted // into log scale by the preceding log_softmax layer // Diagonal wavefront advancing usually used in dynamic programming is leveraged here. // alpha and beta are of acc_t type, as they are essentially accumulators. // This loss function supports packed input where a tensor of shape [B, T, U, H] is packed into // [B_packed, H]. // Don't-care region (t > audLen) or (u > txtLen) is removed. // To support the packed input, the starting offsets for each batch need to be specified with // batchOffset. 
template __global__ void transducer_loss_batch_load_forward(const scalar_t* x, const int* label, const int* audLen,
                                                            const int* txtLen, const int64_t* batchOffset,
                                                            int64_t dictSize, int64_t blankIdx, int64_t maxFLen,
                                                            int64_t maxGLen, bool packedInput, acc_t* alpha,
                                                            acc_t* beta, scalar_t* loss) {
  const int batch = blockIdx.y;
  // Each thread owns exactly one label position u (there is no loop over u in this kernel).
  int u = threadIdx.x;
  const auto myFLen = audLen[batch];
  const auto myGLen = txtLen[batch] + 1;
  const int64_t myBatchOffset =
      packedInput ? (batch == 0 ? 0 : batchOffset[batch - 1]) : batch * maxFLen * maxGLen;
  const int64_t myStrideT = packedInput ? myGLen : maxGLen;
  const scalar_t* myX = x + myBatchOffset * dictSize;
  // Per-thread staging buffers holding x values for up to batchLdSize consecutive diagonals.
  scalar_t next[batchLdSize], current[batchLdSize];
  // Dynamic shared memory: two buffers of maxGLen acc_t values each, indexed by u.
  extern __shared__ char smem8[];
  auto smem = reinterpret_cast(smem8);

  if (blockIdx.x == 0) {
    // alpha path
    acc_t* myAlpha = alpha + batch * maxFLen * maxGLen;
    // two SMEM regions for double buffering read and write data to avoid data race
    acc_t* const sharedAlpha[2] = {smem, smem + maxGLen};

    sharedAlpha[0][u] = 0;
    __syncthreads();

    if (u == 0) myAlpha[0] = 0;

    auto myAlphaLabel = (u == 0) ? 0 : label[batch * (maxGLen - 1) + u - 1];
    // register used to pass value to the next step for the same thread
    acc_t prvStepAlpha = 0;
    for (int64_t step = 1; step < myFLen + myGLen - 1 + batchLdSize; step += batchLdSize) {
      // Move along the diagonal wavefront to leverage available parallelism
      // Batch loading X through loop unrolling
#pragma unroll
      for (int i = 0; i < batchLdSize; ++i) {
        if (step + i < myFLen + myGLen - 1) {
          // index computing
          int64_t t = step + i - u;
          int64_t currentId = ((t - 1) * myStrideT + u) * dictSize + blankIdx;
          int64_t nextId = (t * myStrideT + u - 1) * dictSize + myAlphaLabel;
          // main loading loop
          if (t >= 0 and t < myFLen and u >= 0 and u < myGLen) {
            if (u == 0) {
              current[i] = myX[currentId];
            } else if (t == 0) {
              next[i] = myX[nextId];
            } else {
              current[i] = myX[currentId];
              next[i] = myX[nextId];
            }
          }
        }
      }
      // main computing loop
      for (int i = 0; i < batchLdSize; ++i) {
        // swap the pointer for double buffering
        auto sharedAlphaRd = sharedAlpha[(step + i - 1) % 2];
        auto sharedAlphaWr = sharedAlpha[(step + i) % 2];
        if (step + i < myFLen + myGLen - 1) {
          int64_t t = step + i - u;
          if (t >= 0 and t < myFLen and u >= 0 and u < myGLen) {
            // Eq(16) in [1]
            if (u == 0)
              prvStepAlpha = prvStepAlpha + current[i];
            else if (t == 0)
              prvStepAlpha = sharedAlphaRd[u - 1] + next[i];
            else
              prvStepAlpha = logSumExp(prvStepAlpha + current[i], sharedAlphaRd[u - 1] + next[i]);
            sharedAlphaWr[u] = prvStepAlpha;
            myAlpha[t * maxGLen + u] = prvStepAlpha;
          }
        }
        __syncthreads();
      }
    }
  } else if (blockIdx.x == 1) {
    // beta path
    acc_t* myBeta = beta + batch * maxFLen * maxGLen;
    // two SMEM regions for double buffering read and write data to avoid data race
    acc_t* const sharedBeta[2] = {smem, smem + maxGLen};
    sharedBeta[0][u] = myX[((myFLen - 1) * myStrideT + myGLen - 1) * dictSize + blankIdx];
    __syncthreads();

    auto myBetaLabel = (u == maxGLen - 1) ? 0 : label[batch * (maxGLen - 1) + u];
    // register used to pass value to the next step for the same thread
    acc_t prvStepBeta = myX[((myFLen - 1) * myStrideT + myGLen - 1) * dictSize + blankIdx];
    if (u == 0) myBeta[(myFLen - 1) * maxGLen + myGLen - 1] = prvStepBeta;

    for (int64_t step = 1; step < myFLen + myGLen - 1; step += batchLdSize) {
      // Move along the diagonal wavefront to leverage available parallelism
      // Batch loading X
#pragma unroll
      for (int i = 0; i < batchLdSize; ++i) {
        if (step + i < myFLen + myGLen - 1) {
          // index computing
          int64_t t = myFLen + myGLen - (step + i) - 2 - u;
          int64_t currentId = (t * myStrideT + u) * dictSize + blankIdx;
          int64_t nextId = (t * myStrideT + u) * dictSize + myBetaLabel;
          // main loading loop
          if (t >= 0 and t < myFLen and u >= 0 and u < myGLen) {
            if (u == myGLen - 1) {
              current[i] = myX[currentId];
            } else if (t == myFLen - 1) {
              next[i] = myX[nextId];
            } else {
              current[i] = myX[currentId];
              next[i] = myX[nextId];
            }
          }
        }
      }
      // main computing loop
      for (int i = 0; i < batchLdSize; ++i) {
        // swap the pointer for double buffering
        auto sharedBetaRd = sharedBeta[(step + i - 1) % 2];
        auto sharedBetaWr = sharedBeta[(step + i) % 2];
        if (step + i < myFLen + myGLen - 1) {
          int64_t t = myFLen + myGLen - (step + i) - 2 - u;
          if (t >= 0 and t < myFLen and u >= 0 and u < myGLen) {
            // Eq(18) in [1]
            if (u == myGLen - 1)
              prvStepBeta = prvStepBeta + current[i];
            else if (t == myFLen - 1)
              prvStepBeta = sharedBetaRd[u + 1] + next[i];
            else
              prvStepBeta = logSumExp(prvStepBeta + current[i], sharedBetaRd[u + 1] + next[i]);
            sharedBetaWr[u] = prvStepBeta;
            myBeta[t * maxGLen + u] = prvStepBeta;
          }
        }
        __syncthreads();
      }
    }
    if (u == 0) loss[batch] = -prvStepBeta;
  }
}

// Vanilla transducer loss backward operation.
// Detail of this loss function can be found in:
// [1] Sequence Transduction with Recurrent Neural Networks.
// For this backward kernel, bwd op for the preceding softmax is assumed to be handled elsewhere,
// hence only Eq(20) in [1] is implemented in this kernel.
// Each thread block works on [batch, t, :, :] of data. Each thread works on a specific u at a time
// Since only gradients for the correct token and null token need to be updated, gradients at other
// locations are initialized to 0.
// To support the packed input, the starting offsets for each batch need to be specified with
// batchOffset.
template __global__ void transducer_loss_backward(const scalar_t* x, const scalar_t* lossGrad, const int* audLen,
                                                  const int* txtLen, const int* label, const acc_t* alpha,
                                                  const acc_t* beta, const int64_t* batchOffset, int64_t dictSize,
                                                  int64_t blankIdx, int64_t maxFLen, int64_t maxGLen,
                                                  bool packedInput, scalar_t* xGrad) {
  const int tid = threadIdx.x;
  const int t = blockIdx.x;
  const int batch = blockIdx.y;
  const int64_t myFLen = audLen[batch];
  const int64_t myGLen = txtLen[batch] + 1;
  const int64_t myBatchOffset =
      packedInput ? (batch == 0 ? 0 : batchOffset[batch - 1]) : batch * maxFLen * maxGLen;
  const int64_t myStrideT = packedInput ? myGLen : maxGLen;
  auto myX = x + (myBatchOffset + t * myStrideT) * dictSize;
  auto myAlpha = alpha + batch * maxFLen * maxGLen;
  auto myBeta = beta + batch * maxFLen * maxGLen;
  auto myXGrad = xGrad + (myBatchOffset + t * myStrideT) * dictSize;
  auto myLabel = label + batch * (maxGLen - 1);

  int64_t u = tid;
  while (t < myFLen and u < myGLen) {
    // Do the update
    // loss = -ln(Pr(y*|x))
    acc_t grad = std::log(lossGrad[batch]) + myAlpha[t * maxGLen + u] - myBeta[0];
    if (u != myGLen - 1)
      myXGrad[u * dictSize + myLabel[u]] =
          -std::exp(grad + myBeta[t * maxGLen + u + 1] + myX[u * dictSize + myLabel[u]]);
    if (t == myFLen - 1 and u == myGLen - 1)
      myXGrad[u * dictSize + blankIdx] = -std::exp(grad + myX[u * dictSize + blankIdx]);
    else if (t != myFLen - 1)
      myXGrad[u * dictSize + blankIdx] =
          -std::exp(grad + myBeta[(t + 1) * maxGLen + u] + myX[u * dictSize + blankIdx]);
    u += blockDim.x;
  }
}

// Fused transducer loss backward operation.
// Detail of this loss function can be found in:
// [1] Sequence Transduction with Recurrent Neural Networks.
// The bwd op of the preceding softmax layer is fused in this kernel.
// Each thread block works on [batch, t, u, :] of data. Each thread works on a specific h at a time
// To support the packed input, the starting offsets for each batch need to be specified with
// batchOffset.
template __global__ void transducer_loss_fused_backward(const scalar_t* x, const scalar_t* lossGrad,
                                                        const int* audLen, const int* txtLen, const int* label,
                                                        const acc_t* alpha, const acc_t* beta,
                                                        const int64_t* batchOffset, int64_t dictSize,
                                                        int64_t blankIdx, int64_t maxFLen, int64_t maxGLen,
                                                        bool packedInput, scalar_t* xGrad) {
  const int tid = threadIdx.x;
  const int u = blockIdx.x;
  const int t = blockIdx.y;
  const int batch = blockIdx.z;
  const int64_t myFLen = audLen[batch];
  const int64_t myGLen = txtLen[batch] + 1;
  const int64_t myBatchOffset =
      packedInput ? (batch == 0 ? 0 : batchOffset[batch - 1]) : batch * maxFLen * maxGLen;
  const int64_t myStrideT = packedInput ? myGLen : maxGLen;

  __shared__ acc_t commonFactor, myBetaTU, myBetaTUp1, myBetaTp1U, myLabelShared;
  auto myXGrad = xGrad + (myBatchOffset + t * myStrideT + u) * dictSize;

  if (t < myFLen and u < myGLen) {
    auto myX = x + (myBatchOffset + t * myStrideT + u) * dictSize;
    auto myAlpha = alpha + batch * maxFLen * maxGLen;
    auto myBeta = beta + batch * maxFLen * maxGLen;
    auto myLabel = label + batch * (maxGLen - 1);

    // load and store shared variables in SMEM
    // beta(t+1, u) is only used when t is not the last frame, and beta(t, u+1) / label[u] only when
    // u is not the last label position (see the branches in the loop below). Guarding the loads
    // keeps them in bounds: unguarded, they index one element past the end of the beta and label
    // buffers for the last batch element when myFLen == maxFLen or myGLen == maxGLen.
    // This matches the guards in transducer_loss_fused_vec_backward.
    if (tid == 0) {
      commonFactor = std::log(lossGrad[batch]) + myAlpha[t * maxGLen + u] - myBeta[0];
      myBetaTU = myBeta[t * maxGLen + u];
      if (t != myFLen - 1) myBetaTp1U = myBeta[(t + 1) * maxGLen + u];
      if (u != myGLen - 1) {
        myBetaTUp1 = myBeta[t * maxGLen + u + 1];
        myLabelShared = myLabel[u];
      }
    }

    __syncthreads();

    for (int64_t h = tid; h < dictSize; h += blockDim.x) {
      // Do the update
      acc_t grad = commonFactor + myX[h];  // loss = -ln(Pr(y*|x))
      acc_t myGrad = std::exp(grad + myBetaTU);
      // `u != myGLen - 1` is tested first, so myLabelShared / myBetaTUp1 are only read when loaded.
      if (u != myGLen - 1 and h == myLabelShared) {
        myGrad -= std::exp(grad + myBetaTUp1);
      } else if (h == blankIdx) {
        if (t == myFLen - 1 and u == myGLen - 1)
          myGrad -= std::exp(grad);
        else if (t != myFLen - 1)
          myGrad -= std::exp(grad + myBetaTp1U);
      }
      myXGrad[h] = myGrad;
    }
  } else if (!packedInput) {
    // In non-pack mode, need to make sure the gradients for don't-care regions are zero.
    for (int64_t h = tid; h < dictSize; h += blockDim.x) {
      myXGrad[h] = 0;
    }
  }
}

// Vectorized version of fused transducer loss backward operation.
// Detail of this loss function can be found in:
// [1] Sequence Transduction with Recurrent Neural Networks.
// The bwd op of the preceding softmax layer is fused in this kernel.
// Each thread block works on [batch, t, u, :] of data. Each thread works on a specific h at a time
// To support the packed input, the starting offsets for each batch need to be specified with
// batchOffset.
template __global__ void transducer_loss_fused_vec_backward(const scalar_t* x, const scalar_t* lossGrad,
                                                            const int* audLen, const int* txtLen, const int* label,
                                                            const acc_t* alpha, const acc_t* beta,
                                                            const int64_t* batchOffset, int64_t dictSize,
                                                            int64_t blankIdx, int64_t maxFLen, int64_t maxGLen,
                                                            bool packedInput, scalar_t* xGrad) {
  // One thread block per (u, t, batch) = (blockIdx.x, blockIdx.y, blockIdx.z); the threads of a block
  // stride over the dictionary dimension, V elements per thread per iteration.
  const int tid = threadIdx.x;
  const int u = blockIdx.x;
  const int t = blockIdx.y;
  const int batch = blockIdx.z;
  const int64_t myFLen = audLen[batch];
  const int64_t myGLen = txtLen[batch] + 1;
  const int64_t myBatchOffset =
      packedInput ? (batch == 0 ? 0 : batchOffset[batch - 1]) : batch * maxFLen * maxGLen;
  const int64_t myStrideT = packedInput ? myGLen : maxGLen;

  __shared__ acc_t commonFactor, myBetaTU, myBetaTUp1, myBetaTp1U, myLabelShared;
  auto myXGrad = xGrad + (myBatchOffset + t * myStrideT + u) * dictSize;
  auto myX = x + (myBatchOffset + t * myStrideT + u) * dictSize;
  auto myAlpha = alpha + batch * maxFLen * maxGLen;
  auto myBeta = beta + batch * maxFLen * maxGLen;
  auto myLabel = label + batch * (maxGLen - 1);

  // Variables for vectorization: V scalar_t values are moved with one vector-typed load/store.
  // Presumably V == sizeof(vec_t) / sizeof(scalar_t) as chosen in transducer_loss_cuda_backward,
  // which only selects this kernel when dictSize is a multiple of it and x/xGrad are aligned — confirm.
  scalar_t myXBuffer[V], myXGradBuffer[V];
  auto myXVec = reinterpret_cast(myX);
  auto myXGradVec = reinterpret_cast(myXGrad);
  auto myXBufferVec = reinterpret_cast(myXBuffer);
  auto myXGradBufferVec = reinterpret_cast(myXGradBuffer);
  if (t < myFLen and u < myGLen) {
    // load and store shared variables in SMEM
    // beta(t+1, u) is only needed when t is not the last frame, and beta(t, u+1) / label[u] only when
    // u is not the last label position, so those loads are guarded.
    if (tid == 0) {
      commonFactor = std::log(lossGrad[batch]) + myAlpha[t * maxGLen + u] - myBeta[0];
      myBetaTU = myBeta[t * maxGLen + u];
      if (t != myFLen - 1) myBetaTp1U = myBeta[(t + 1) * maxGLen + u];
      if (u != myGLen - 1) {
        myBetaTUp1 = myBeta[t * maxGLen + u + 1];
        myLabelShared = myLabel[u];
      }
    }

    __syncthreads();

#pragma unroll
    for (int64_t h0 = tid * V; h0 < dictSize; h0 += blockDim.x * V) {
      // Load myX in a vector form
      *myXBufferVec = myXVec[h0 / V];
      // Do the update for a vector of input
#pragma unroll
      for (int i = 0; i < V; ++i) {
        auto h = h0 + i;
        acc_t grad = commonFactor + myXBuffer[i];  // loss = -ln(Pr(y*|x))
        acc_t myGrad = std::exp(grad + myBetaTU);
        if (u != myGLen - 1 and h == myLabelShared) {
          myGrad -= std::exp(grad + myBetaTUp1);
        } else if (h == blankIdx) {
          if (t == myFLen - 1 and u == myGLen - 1)
            myGrad -= std::exp(grad);
          else if (t != myFLen - 1)
            myGrad -= std::exp(grad + myBetaTp1U);
        }
        myXGradBuffer[i] = myGrad;
      }

      // Store myXGrad in a vector form
      myXGradVec[h0 / V] = *myXGradBufferVec;
    }
  } else if (!packedInput) {
    // In non-pack mode, need to make sure the gradients for don't-care regions are zero.
    for (int64_t h0 = tid * V; h0 < dictSize; h0 += blockDim.x * V) {
      myXGradVec[h0 / V] = 0;
    }
  }
}

// Host launcher for the transducer loss forward pass.
// Allocates alpha and beta as [batchSize, maxFLen, maxGLen] in the accumulation type and loss as
// [batchSize] in x's dtype, launches one of the forward kernels, and returns {alpha, beta, loss}.
// maxGLen is label.size(1) + 1 and dictSize is x.size(-1).
std::vector transducer_loss_cuda_forward(torch::Tensor x, torch::Tensor label, torch::Tensor audLen,
                                         torch::Tensor txtLen, torch::Tensor batchOffset, int maxFLen, int blankIdx,
                                         int opt, bool packedInput) {
  auto scalarType = x.scalar_type();
  auto tensorOpt = x.options();
  const int batchSize = label.size(0);
  const int maxGLen = label.size(1) + 1;
  const int dictSize = x.size(-1);

  TORCH_CHECK(blankIdx >= 0 and blankIdx < dictSize, "Expected blank index to be in the range of 0 to ",
              dictSize - 1, ", but got ", blankIdx);
  TORCH_CHECK(opt == -1 or opt == 0 or opt == 1, "Got an invalid optimization level ", opt);

  // The data type of alpha and beta will be resolved at dispatch time,
  // hence defined here and assigned later
  torch::Tensor alpha;
  torch::Tensor beta;
  torch::Tensor loss = torch::empty({batchSize}, tensorOpt);
  const auto deviceProperties = at::cuda::getCurrentDeviceProperties();
  const auto maxThreadPerBlock = deviceProperties->maxThreadsPerBlock;
  const auto maxSmemPerBlock = deviceProperties->sharedMemPerBlock;
  const auto batchOffsetPtr = packedInput ? batchOffset.data_ptr() : nullptr;
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      scalarType, "transducer_loss_cuda_forward", ([&] {
        // resolve accumulation type
        using acc_t = at::acc_type;
        auto accType = c10::CppTypeToScalarType::value;
        auto accTensorOpt = tensorOpt.dtype(accType);
        alpha = torch::empty({batchSize, maxFLen, maxGLen}, accTensorOpt);
        beta = torch::empty({batchSize, maxFLen, maxGLen}, accTensorOpt);

        // decide what kernel to launch based on the problem size
        // if the required SMEM size or number threads exceeds the limit, fall back to the vanilla
        // kernel.
        // optFallBack: 0 = vanilla kernel, 1 = batch-load kernel. opt == -1 picks 1 when allowed.
        const auto smemSize = 2 * maxGLen * sizeof(acc_t);
        const auto optFallBack =
            (maxGLen > maxThreadPerBlock or smemSize > maxSmemPerBlock) ? 0 : (opt == -1) ? 1 : opt;
        const int threads = std::min(maxThreadPerBlock, maxGLen);
        // gridDim.x == 2: block 0 computes alpha, block 1 computes beta (and the loss).
        const dim3 blocks(2, batchSize, 1);

        if (optFallBack == 0)
          transducer_loss_forward<<>>(
              x.data_ptr(), label.data_ptr(), audLen.data_ptr(), txtLen.data_ptr(), batchOffsetPtr, dictSize,
              blankIdx, maxFLen, maxGLen, packedInput, alpha.data_ptr(), beta.data_ptr(), loss.data_ptr());
        else if (optFallBack == 1)
          transducer_loss_batch_load_forward<<>>(
              x.data_ptr(), label.data_ptr(), audLen.data_ptr(), txtLen.data_ptr(), batchOffsetPtr, dictSize,
              blankIdx, maxFLen, maxGLen, packedInput, alpha.data_ptr(), beta.data_ptr(), loss.data_ptr());
      }));
  C10_CUDA_CHECK(cudaGetLastError());

  return {alpha, beta, loss};
}

// Host launcher for the transducer loss backward pass; returns the gradient w.r.t. x (same shape as x).
// With fuseSoftmaxBackward, one block per (u, t, batch) writes every element of xGrad (allocated with
// empty_like). Otherwise one block per (t, batch) writes only the label/blank entries into a zero-filled
// xGrad.
torch::Tensor transducer_loss_cuda_backward(torch::Tensor x, torch::Tensor lossGrad, torch::Tensor alpha,
                                            torch::Tensor beta, torch::Tensor audLen, torch::Tensor txtLen,
                                            torch::Tensor label, torch::Tensor batchOffset, int maxFLen, int blankIdx,
                                            int opt, bool fuseSoftmaxBackward, bool packedInput) {
  auto dtype = x.scalar_type();
  torch::Tensor xGrad;
  const int batchSize = label.size(0);
  const int maxGLen = label.size(1) + 1;
  const int dictSize = x.size(-1);
  const auto deviceProperties = at::cuda::getCurrentDeviceProperties();
  const int maxThreadPerBlock = deviceProperties->maxThreadsPerBlock;
  const int warpSize = deviceProperties->warpSize;
  const auto batchOffsetPtr = packedInput ? batchOffset.data_ptr() : nullptr;
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  if (fuseSoftmaxBackward) {
    // alloc empty tensors for performance, hence need to ensure zeros are written to
    // don't-care region in the kernel.
    xGrad = torch::empty_like(x);

    // Would like each thread to work on 4 hidden units
    const int workPerThread = 4;
    // Don't want to have more than 128 threads per thread block
    const int maxThreadPerElmt = std::min(128, maxThreadPerBlock);
    const int threads =
        std::min(maxThreadPerElmt, std::max(warpSize, (dictSize + workPerThread - 1) / workPerThread));
    const dim3 blocks(maxGLen, maxFLen, batchSize);

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
        dtype, "transducer_loss_cuda_backward", ([&] {
          using vec_t = uint64_t;
          using acc_t = at::acc_type;
          constexpr int vectFactor = sizeof(vec_t) / sizeof(scalar_t);
          constexpr int vecAlignment = std::alignment_of::value;
          // if all input and output tensors meet the alignment requirement
          bool memAlign = reinterpret_cast(x.data_ptr()) % vecAlignment == 0 and
                          reinterpret_cast(xGrad.data_ptr()) % vecAlignment == 0;

          if (vectFactor > 1 and dictSize % vectFactor == 0 and memAlign) {
            transducer_loss_fused_vec_backward<<>>(
                x.data_ptr(), lossGrad.data_ptr(), audLen.data_ptr(), txtLen.data_ptr(), label.data_ptr(),
                alpha.data_ptr(), beta.data_ptr(), batchOffsetPtr, dictSize, blankIdx, maxFLen, maxGLen,
                packedInput, xGrad.data_ptr());
          } else {
            transducer_loss_fused_backward<<>>(
                x.data_ptr(), lossGrad.data_ptr(), audLen.data_ptr(), txtLen.data_ptr(), label.data_ptr(),
                alpha.data_ptr(), beta.data_ptr(), batchOffsetPtr, dictSize, blankIdx, maxFLen, maxGLen,
                packedInput, xGrad.data_ptr());
          }
        }));
  } else {
    // for non-fused kernel, the gradients that need to be written are very sparse, hence initialize
    // the tensor with all zeros.
    xGrad = torch::zeros_like(x);
    // don't launch more threads than needed.
    const int threads = std::min(maxThreadPerBlock, maxGLen);
    const dim3 blocks(maxFLen, batchSize);
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(dtype, "transducer_loss_cuda_backward", ([&] {
                                          using acc_t = at::acc_type;
                                          transducer_loss_backward<<>>(
                                              x.data_ptr(), lossGrad.data_ptr(), audLen.data_ptr(),
                                              txtLen.data_ptr(), label.data_ptr(), alpha.data_ptr(),
                                              beta.data_ptr(), batchOffsetPtr, dictSize, blankIdx, maxFLen,
                                              maxGLen, packedInput, xGrad.data_ptr());
                                        }));
  }
  C10_CUDA_CHECK(cudaGetLastError());

  return xGrad;
}
================================================
FILE: apex/contrib/csrc/xentropy/interface.cpp
================================================
#include
#include

// CUDA forward declarations
std::vector softmax_xentropy_cuda(const at::Tensor& input, const at::Tensor& labels, const float smoothing,
                                  const bool half_to_float);

at::Tensor softmax_xentropy_backward_cuda(const at::Tensor& grad_loss, const at::Tensor& logits,
                                          const at::Tensor& max_log_sum_exp, const at::Tensor& labels,
                                          const float smoothing);

// C++ interface

#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
// NOTE: expands to two statements; only use it as a full statement, never as an unbraced if/else body.
#define CHECK_INPUT(x) \
  CHECK_CUDA(x);       \
  CHECK_CONTIGUOUS(x)

// Forward entry point. `input` is only required to be a CUDA tensor (no contiguity check);
// `labels` must be a contiguous CUDA tensor.
std::vector softmax_xentropy_forward(const at::Tensor& input, const at::Tensor& labels, const float smoothing,
                                     const bool half_to_float) {
  CHECK_CUDA(input);
  CHECK_INPUT(labels);

  return softmax_xentropy_cuda(input, labels, smoothing, half_to_float);
}

// Backward entry point. grad_loss and logits are only required to be CUDA tensors;
// max_log_sum_exp and labels must also be contiguous.
at::Tensor softmax_xentropy_backward(const at::Tensor& grad_loss, const at::Tensor& logits,
                                     const at::Tensor& max_log_sum_exp, const at::Tensor& labels,
                                     const float smoothing) {
  CHECK_CUDA(grad_loss);
  CHECK_CUDA(logits);
  CHECK_INPUT(max_log_sum_exp);
  CHECK_INPUT(labels);

  return softmax_xentropy_backward_cuda(grad_loss, logits, max_log_sum_exp, labels, smoothing);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &softmax_xentropy_forward, "Softmax cross entropy loss with label smoothing forward (CUDA)",
        py::call_guard());
  m.def("backward", &softmax_xentropy_backward, "Softmax cross entropy loss with label smoothing backward (CUDA)",
        py::call_guard());
  // ref: https://pybind11.readthedocs.io/en/stable/basics.html#exporting-variables
  // __version__ is XENTROPY_VER when that macro is defined at build time, otherwise an empty string.
  py::object version = py::cast(
#ifdef XENTROPY_VER
      XENTROPY_VER
#else
      std::string{}
#endif
  );
  m.attr("__version__") = version;
}
================================================
FILE: apex/contrib/csrc/xentropy/xentropy_kernel.cu
================================================
/**
 * From PyTorch:
 *
 * Copyright (c) 2016- Facebook, Inc (Adam Paszke)
 * Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
 * Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
 * Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
 * Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
 * Copyright (c) 2011-2013 NYU (Clement Farabet)
 * Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
 * Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
 * Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
 *
 * From Caffe2:
 *
 * Copyright (c) 2016-present, Facebook Inc. All rights reserved.
 *
 * All contributions by Facebook:
 * Copyright (c) 2016 Facebook Inc.
 *
 * All contributions by Google:
 * Copyright (c) 2015 Google Inc.
 * All rights reserved.
 *
 * All contributions by Yangqing Jia:
 * Copyright (c) 2015 Yangqing Jia
 * All rights reserved.
 *
 * All contributions from Caffe:
 * Copyright(c) 2013, 2014, 2015, the respective contributors
 * All rights reserved.
 *
 * All other contributions:
 * Copyright(c) 2015, 2016 the respective contributors
 * All rights reserved.
 *
 * Caffe2 uses a copyright model similar to Caffe: each contributor holds
 * copyright over their contributions to Caffe2. The project versioning records
 * all such contribution and copyright details. If a contributor wants to further
 * mark their specific copyright on a particular contribution, they should
 * indicate their copyright solely in the commit message of the change when it is
 * committed.
 *
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 *
 * 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
 *    and IDIAP Research Institute nor the names of its contributors may be
 *    used to endorse or promote products derived from this software without
 *    specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 */
#include
#include
#include
#include

#include "type_shim.h"

#define ALIGN_BYTES 16

using Tensor = at::Tensor;
using TensorList = at::TensorList;
using ScalarType = at::ScalarType;
using at::acc_type;

// Log-softmax forward epilogue: maps input to input - logsum, where logsum is either
// max_input + log(sum) or a precomputed max_log_sum_exp.
template struct LogSoftMaxForwardEpilogue {
  __device__ __forceinline__ LogSoftMaxForwardEpilogue(AccumT max_input, AccumT sum)
      : logsum(max_input + std::log(sum)) {}

  __device__ __forceinline__ LogSoftMaxForwardEpilogue(AccumT max_log_sum_exp) : logsum(max_log_sum_exp) {}

  __device__ __forceinline__ OutT operator()(T input) const { return static_cast(input - logsum); }

  const AccumT logsum;
};

// Log-softmax backward epilogue: gradOutput - exp(output) * sum.
template struct LogSoftMaxBackwardEpilogue {
  __device__ __forceinline__ LogSoftMaxBackwardEpilogue(AccumT sum) : sum(sum) {}

  __device__ __forceinline__ T operator()(OutT gradOutput, OutT output) const {
    return static_cast(gradOutput - std::exp(static_cast(output)) * sum);
  }

  const AccumT sum;
};

const int max_threads = 1024;

// Returns a 1-D block size: the smallest power of two that is not below
// min(dim_size / ILP, max_threads) / 2, raised to at least 32.
inline dim3 SoftMax_getBlockSize(int ILP, uint64_t dim_size) {
  uint64_t block_size = 1;
  uint64_t max_block_size = std::min(dim_size / ILP, static_cast(max_threads));
  while (block_size < (max_block_size / 2)) block_size *= 2;
  // Launch at least a single warp - the kernel assumes that.
  block_size = std::max(block_size, static_cast(32));
  return dim3(block_size);
}

template struct Add {
  __device__ __forceinline__ T operator()(T a, T b) const { return a + b; }
};

template struct Max {
  __device__ __forceinline__ T operator()(T a, T b) const { return a < b ? b : a; }
};

////////////////////////////////////////////////////////////////////////////////
// Regular kernel (fast when dim_size is large; requires inner_size == 1)
////////////////////////////////////////////////////////////////////////////////

template struct MaxFloat {
  __device__ __forceinline__ AccumT operator()(AccumT max, T v) const { return ::max(max, (AccumT)v); }
};

template struct AddFloat {
  __device__ __forceinline__ AccumT operator()(AccumT sum, T v) const { return sum + v; }
};

// Accumulates sum + exp(v - max_k), where max_k is fixed at construction.
template struct SumExpFloat {
  __device__ __forceinline__ SumExpFloat(AccumT v) : max_k(v) {}

  __device__ __forceinline__ AccumT operator()(AccumT sum, T v) const { return sum + std::exp(v - max_k); }

  const AccumT max_k;
};

template