Repository: karpathy/llm.c Branch: master Commit: f1e2ace65149 Files: 102 Total size: 1.2 MB Directory structure: gitextract_pp0afoji/ ├── .github/ │ └── workflows/ │ ├── ci.yml │ ├── ci_gpu.yml │ └── ci_tests.yml ├── .gitignore ├── LICENSE ├── Makefile ├── README.md ├── dev/ │ ├── cpu/ │ │ └── matmul_forward.c │ ├── cuda/ │ │ ├── Makefile │ │ ├── README.md │ │ ├── adamw.cu │ │ ├── attention_backward.cu │ │ ├── attention_forward.cu │ │ ├── benchmark_on_modal.py │ │ ├── classifier_fused.cu │ │ ├── common.h │ │ ├── crossentropy_forward.cu │ │ ├── crossentropy_softmax_backward.cu │ │ ├── encoder_backward.cu │ │ ├── encoder_forward.cu │ │ ├── fused_residual_forward.cu │ │ ├── gelu_backward.cu │ │ ├── gelu_forward.cu │ │ ├── global_norm.cu │ │ ├── layernorm_backward.cu │ │ ├── layernorm_forward.cu │ │ ├── matmul_backward.cu │ │ ├── matmul_backward_bias.cu │ │ ├── matmul_forward.cu │ │ ├── nccl_all_reduce.cu │ │ ├── permute.cu │ │ ├── residual_forward.cu │ │ ├── softmax_forward.cu │ │ └── trimat_forward.cu │ ├── data/ │ │ ├── README.md │ │ ├── data_common.py │ │ ├── edu_fineweb.sh │ │ ├── fineweb.py │ │ ├── fineweb.sh │ │ ├── hellaswag.py │ │ ├── mmlu.py │ │ ├── tinyshakespeare.py │ │ └── tinystories.py │ ├── download_starter_pack.sh │ ├── eval/ │ │ ├── README.md │ │ ├── export_hf.py │ │ ├── run_eval.sh │ │ └── summarize_eval.py │ ├── loss_checker_ci.py │ ├── test/ │ │ ├── Makefile │ │ ├── device_file_io.cu │ │ ├── test_dataloader.c │ │ └── test_outlier_detector.c │ ├── unistd.h │ └── vislog.ipynb ├── doc/ │ └── layernorm/ │ ├── layernorm.c │ ├── layernorm.md │ └── layernorm.py ├── llmc/ │ ├── adamw.cuh │ ├── attention.cuh │ ├── cublas_common.h │ ├── cuda_common.h │ ├── cuda_utils.cuh │ ├── cudnn_att.cpp │ ├── cudnn_att.h │ ├── dataloader.h │ ├── encoder.cuh │ ├── fused_classifier.cuh │ ├── gelu.cuh │ ├── global_norm.cuh │ ├── layernorm.cuh │ ├── logger.h │ ├── matmul.cuh │ ├── mfu.h │ ├── outlier_detector.h │ ├── rand.h │ ├── sampler.h │ ├── schedulers.h │ ├── 
tokenizer.h │ ├── utils.h │ └── zero.cuh ├── profile_gpt2.cu ├── profile_gpt2cu.py ├── requirements.txt ├── scripts/ │ ├── README.md │ ├── multi_node/ │ │ ├── run_gpt2_124M_fs.sbatch │ │ ├── run_gpt2_124M_mpi.sh │ │ └── run_gpt2_124M_tcp.sbatch │ ├── pyrun_gpt2_124M.sh │ ├── run_gpt2_124M.sh │ ├── run_gpt2_1558M.sh │ ├── run_gpt2_350M.sh │ ├── run_gpt2_774M.sh │ └── run_gpt3_125M.sh ├── test_gpt2.c ├── test_gpt2.cu ├── test_gpt2_fp32.cu ├── train_gpt2.c ├── train_gpt2.cu ├── train_gpt2.py ├── train_gpt2_fp32.cu └── train_llama3.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .github/workflows/ci.yml ================================================ name: Build and test on: create: workflow_dispatch: push: branches: - master pull_request: branches: - master jobs: build-and-test-cpu: strategy: matrix: os: [ubuntu-latest, macos-latest, windows-latest] runs-on: ${{ matrix.os }} steps: - name: Checkout code uses: actions/checkout@v4 - name: Install OpenMP if: matrix.os != 'windows-latest' run: | if [ "${{ runner.os }}" == "Linux" ]; then sudo apt-get update && sudo apt-get install -y libomp-dev elif [ "${{ runner.os }}" == "macOS" ]; then brew install libomp fi - name: Install dependencies run: pip install -r requirements.txt - name: Run preprocessing run: python dev/data/tinyshakespeare.py - name: Train model run: python train_gpt2.py --device=cpu - name: Download Win32 Make.exe if: matrix.os == 'windows-latest' run: | $wc = New-Object System.Net.WebClient $url = 'https://github.com/maweil/MakeForWindows/releases/download/v4.4.1/make-bin-win64.zip' $output = './make-bin-win64.zip' $wc.DownloadFile($url, $output) - name: Unzip Win32 Makefile if: matrix.os == 'windows-latest' run: | unzip make-bin-win64.zip - name: Compile training and testing program if: matrix.os != 'windows-latest' run: make test_gpt2 train_gpt2 - name: Compile training and testing 
program for Windows if: matrix.os == 'windows-latest' shell: cmd run: | call "C:\\Program Files\\Microsoft Visual Studio\\2022\\Enterprise\\VC\\Auxiliary\\Build\\vcvars64.bat" make-4.4.1\dist\make WIN_CI_BUILD=1 test_gpt2 train_gpt2 - name: Execute testing program (With OpenMP) if: matrix.os != 'windows-latest' run: OMP_NUM_THREADS=8 ./test_gpt2 - name: Execute Windows testing program (With OpenMP) if: matrix.os == 'windows-latest' shell: cmd run: | copy test_gpt2 test_gpt2.exe test_gpt2.exe - name: Compile training and testing program without OpenMP if: matrix.os != 'windows-latest' run: NO_OMP=1 make test_gpt2 train_gpt2 - name: Execute testing program (No OpenMP) if: matrix.os != 'windows-latest' run: ./test_gpt2 build-cuda-windows: runs-on: windows-latest steps: - name: Checkout code uses: actions/checkout@v4 - name: Download Win32 Make.exe run: | $wc = New-Object System.Net.WebClient $url = 'https://github.com/maweil/MakeForWindows/releases/download/v4.4.1/make-bin-win64.zip' $output = './make-bin-win64.zip' $wc.DownloadFile($url, $output) - name: Unzip Win32 Makefile run: | unzip make-bin-win64.zip - name: Install Cuda Toolkit 12.4 on Windows run: | mkdir -p "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" choco install unzip -y curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_cudart/windows-x86_64/cuda_cudart-windows-x86_64-12.4.127-archive.zip" curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvcc/windows-x86_64/cuda_nvcc-windows-x86_64-12.4.131-archive.zip" curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvrtc/windows-x86_64/cuda_nvrtc-windows-x86_64-12.4.127-archive.zip" curl -O "https://developer.download.nvidia.com/compute/cuda/redist/libcublas/windows-x86_64/libcublas-windows-x86_64-12.4.5.8-archive.zip" curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvtx/windows-x86_64/cuda_nvtx-windows-x86_64-12.4.127-archive.zip" curl -O 
"https://developer.download.nvidia.com/compute/cuda/redist/cuda_profiler_api/windows-x86_64/cuda_profiler_api-windows-x86_64-12.4.127-archive.zip" curl -O "https://developer.download.nvidia.com/compute/cuda/redist/visual_studio_integration/windows-x86_64/visual_studio_integration-windows-x86_64-12.4.127-archive.zip" curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvprof/windows-x86_64/cuda_nvprof-windows-x86_64-12.4.127-archive.zip" curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_cccl/windows-x86_64/cuda_cccl-windows-x86_64-12.4.127-archive.zip" unzip '*.zip' -d "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\cuda_cudart-windows-x86_64-12.4.127-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" /E /I /H /Y xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\cuda_nvcc-windows-x86_64-12.4.131-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" /E /I /H /Y xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\cuda_nvrtc-windows-x86_64-12.4.127-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" /E /I /H /Y xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\libcublas-windows-x86_64-12.4.5.8-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" /E /I /H /Y xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\cuda_nvtx-windows-x86_64-12.4.127-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" /E /I /H /Y xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\cuda_profiler_api-windows-x86_64-12.4.127-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" /E /I /H /Y xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\visual_studio_integration-windows-x86_64-12.4.127-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" /E /I /H /Y xcopy "C:\Program Files\NVIDIA GPU 
Computing Toolkit\CUDA\v12.4\cuda_nvprof-windows-x86_64-12.4.127-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" /E /I /H /Y xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\cuda_cccl-windows-x86_64-12.4.127-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" /E /I /H /Y # Default installation path for CUDA Toolkit is C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4 - name: Add Path run: | echo "C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v12.4\\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append echo "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\libnvvp" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append echo "CUDA_PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8 echo "CUDA_PATH_V12_4=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8 - name: Build Cuda targets shell: cmd working-directory: ${{ github.workspace }} run: | call "C:\\Program Files\\Microsoft Visual Studio\\2022\\Enterprise\\VC\\Auxiliary\\Build\\vcvars64.bat" make-4.4.1\dist\make -j WIN_CI_BUILD=1 train_gpt2fp32cu test_gpt2fp32cu test_gpt2cu train_gpt2cu profile_gpt2cu build-ubuntu20-04: runs-on: ubuntu-20.04 container: image: nvidia/cuda:11.8.0-cudnn8-devel-ubuntu20.04 steps: - name: Checkout code uses: actions/checkout@v4 - name: System Info run: | nvcc --version g++ --version - name: Install cudnn frontend run: | apt-get update && apt-get install -y git git clone https://github.com/NVIDIA/cudnn-frontend.git - name: Build FP32 checkpoint run: make train_gpt2fp32cu test_gpt2fp32cu - name: Build FP32 precision run: PRECISION=FP32 make train_gpt2cu test_gpt2cu profile_gpt2cu - name: Build with CUDNN run: PRECISION=BF16 USE_CUDNN=1 make train_gpt2cu test_gpt2cu profile_gpt2cu build-cuda-fp32: runs-on: ubuntu-latest container: image: 
nvidia/cuda:12.4.1-devel-ubuntu22.04 steps: - name: Checkout code uses: actions/checkout@v4 - name: Build FP32 checkpoint run: make train_gpt2fp32cu test_gpt2fp32cu - name: Build FP32 precision run: PRECISION=FP32 make train_gpt2cu test_gpt2cu profile_gpt2cu build-cuda-bf16: runs-on: ubuntu-latest container: image: nvidia/cuda:12.4.1-devel-ubuntu22.04 steps: - name: Checkout code uses: actions/checkout@v4 - name: Build project run: PRECISION=BF16 make test_gpt2cu train_gpt2cu profile_gpt2cu build-cuda-fp16: runs-on: ubuntu-latest container: image: nvidia/cuda:12.4.1-devel-ubuntu22.04 steps: - name: Checkout code uses: actions/checkout@v4 - name: Build project run: PRECISION=FP16 make test_gpt2cu train_gpt2cu profile_gpt2cu build-cuda-kernels: runs-on: ubuntu-latest container: image: nvidia/cuda:12.4.1-devel-ubuntu22.04 steps: - name: Checkout code uses: actions/checkout@v4 - name: Install OpenMP and OpenMPI run: apt-get update && apt-get install -y libomp-dev libopenmpi-dev - name: Build project run: make -j4 -C dev/cuda ================================================ FILE: .github/workflows/ci_gpu.yml ================================================ name: GPU Builds and Tests on: create: workflow_dispatch: push: branches: - master pull_request: branches: - master jobs: build-and-test-gpu: runs-on: ubicloud-gpu-standard-1-latest steps: - name: Checkout code uses: actions/checkout@v4 - name: Install OpenMP run: sudo apt-get update && sudo apt-get install -y libomp-dev - name: Install dependencies run: pip install -r requirements.txt - name: Run preprocessing run: python dev/data/tinyshakespeare.py - name: Train model run: python train_gpt2.py - name: Compile training and testing program run: make test_gpt2cu train_gpt2cu test_gpt2fp32cu train_gpt2fp32cu - name: Train model (With OpenMP) run: OMP_NUM_THREADS=8 ./train_gpt2cu - name: Train model (FP32) with gpt2_124M.bin run: | PRECISION=FP32 make train_gpt2cu ./train_gpt2cu -b 1 -t 64 -d 256 -l 0.0001 -v 200 -s 200 
-a 1 -x 10 -r 0 -f 0 -e "gpt2_124M.bin" - name: Test for percent loss differential for FP32 run: | PRECISION=FP32 make train_gpt2cu ./train_gpt2cu -b 1 -t 64 -d 256 -l 0.0001 -v 200 -s 200 -a 1 -x 10 -r 0 -f 0 -e "gpt2_124M.bin" > train_gpt2cu_fp32_precision.txt python dev/loss_checker_ci.py -f train_gpt2cu_fp32_precision.txt -s 20 -e 28 -a 5.0 - name: Build FP32 precision run: PRECISION=FP32 make test_gpt2cu profile_gpt2cu - name: Run default run: ./test_gpt2cu - name: Run no recompute GeLU run: ./test_gpt2cu -r 0 - name: Run recompute LN run: ./test_gpt2cu -r 2 - name: Build BF16 precision run: PRECISION=BF16 make train_gpt2cu test_gpt2cu profile_gpt2cu - name: Run default run: ./test_gpt2cu - name: Run no recompute GeLU run: ./test_gpt2cu -r 0 - name: Run no master weights run: ./test_gpt2cu -w 0 - name: Run recompute LN run: ./test_gpt2cu -r 2 - name: Train model fp32 (With OpenMP) run: OMP_NUM_THREADS=8 ./train_gpt2fp32cu - name: Execute testing program (With OpenMP) run: OMP_NUM_THREADS=8 ./test_gpt2cu - name: Execute testing program fp32 (With OpenMP) run: OMP_NUM_THREADS=8 ./test_gpt2fp32cu - name: Compile training and testing program without OpenMP run: NO_OMP=1 make test_gpt2cu train_gpt2cu test_gpt2fp32cu train_gpt2fp32cu - name: Train model (No OpenMP) run: NO_OMP=1 ./train_gpt2cu - name: Train model fp32 (No OpenMP) run: NO_OMP=1 ./train_gpt2fp32cu - name: Execute testing program (No OpenMP) run: ./test_gpt2cu -b 32 - name: Execute testing program fp32 (No OpenMP) run: ./test_gpt2fp32cu - name: Install cuDNN-frontend run: git clone https://github.com/NVIDIA/cudnn-frontend.git - name: Build with cuDNN run: USE_CUDNN=1 make test_gpt2cu train_gpt2cu test_gpt2fp32cu train_gpt2fp32cu - name: Train model with cuDNN run: ./train_gpt2cu - name: Train model fp32 with cuDNN run: ./train_gpt2fp32cu - name: Execute testing program with cuDNN run: ./test_gpt2cu - name: Execute testing program fp32 with cuDNN run: ./test_gpt2fp32cu unit-tests-gpu: runs-on: 
ubicloud-gpu-standard-1-latest steps: - name: Checkout code uses: actions/checkout@v4 - name: Test Device<->File IO run: cd dev/test && nvcc -o device_file_io device_file_io.cu && ./device_file_io ================================================ FILE: .github/workflows/ci_tests.yml ================================================ name: Unit, Static and other Tests on: create: workflow_dispatch: push: branches: - master pull_request: branches: - master jobs: dataloader_test: runs-on: ubuntu-latest steps: - name: Checkout code uses: actions/checkout@v4 - name: test the dataloader without / with sanitize address run: | cd dev/test make PRECISION=BF16 test_dataloader ./test_dataloader make clean make PRECISION=BF16 TEST_CFLAGS="-fsanitize=address -fno-omit-frame-pointer" test_dataloader ./test_dataloader ptx_and_sass_files: runs-on: ubuntu-latest container: image: nvidia/cuda:12.4.1-devel-ubuntu22.04 steps: - name: Checkout code uses: actions/checkout@v4 - name: Install OpenMP and OpenMPI run: apt-get update && apt-get install -y libomp-dev libopenmpi-dev - name: Generate ptx/sass files and upload them to persistent storage run: | mkdir -p dev/cuda/ptx_sass_logs make train_gpt2cu cuobjdump --dump-ptx train_gpt2cu > dev/cuda/train_gpt2cu.ptx cuobjdump --dump-sass train_gpt2cu > dev/cuda/train_gpt2cu.sass cd dev/cuda make -j all_ptx make -j all_sass cp *.ptx ptx_sass_logs/ cp *.sass ptx_sass_logs/ ls ptx_sass_logs/ - name: Generate ptx/sass files for A100 and upload them to persistent storage run: | mkdir -p dev/cuda/ptx_sass_logs_A100 make train_gpt2cu GPU_COMPUTE_CAPABILITY=80 cuobjdump --dump-ptx train_gpt2cu > dev/cuda/train_gpt2cu.ptx cuobjdump --dump-sass train_gpt2cu > dev/cuda/train_gpt2cu.sass cd dev/cuda make -j GPU_COMPUTE_CAPABILITY=80 all_ptx make -j GPU_COMPUTE_CAPABILITY=80 all_sass cp *.ptx ptx_sass_logs_A100/ cp *.sass ptx_sass_logs_A100/ ls ptx_sass_logs_A100/ - name: Generate ptx/sass files for H100 and upload them to persistent storage run: | mkdir -p 
dev/cuda/ptx_sass_logs_H100 make train_gpt2cu GPU_COMPUTE_CAPABILITY=90 cuobjdump --dump-ptx train_gpt2cu > dev/cuda/train_gpt2cu.ptx cuobjdump --dump-sass train_gpt2cu > dev/cuda/train_gpt2cu.sass cd dev/cuda make -j GPU_COMPUTE_CAPABILITY=90 all_ptx make -j GPU_COMPUTE_CAPABILITY=90 all_sass cp *.ptx ptx_sass_logs_H100/ cp *.sass ptx_sass_logs_H100/ ls ptx_sass_logs_H100/ - name: Upload ptx/sass files uses: actions/upload-artifact@v4 with: name: ptx_sass_files path: dev/cuda/ptx_sass_logs/ retention-days: 30 # days to retain - name: Upload ptx/sass files for A100 uses: actions/upload-artifact@v4 with: name: ptx_sass_files_A100 path: dev/cuda/ptx_sass_logs_A100/ retention-days: 30 # days to retain - name: Upload ptx/sass files for H100 uses: actions/upload-artifact@v4 with: name: ptx_sass_files_H100 path: dev/cuda/ptx_sass_logs_H100/ retention-days: 30 # days to retain ================================================ FILE: .gitignore ================================================ # dot files and such .vscode .venv # .bin files generated by Python *.bin # data directories dev/data/__pycache__/ dev/data/fineweb10B/ dev/data/hellaswag/ dev/data/mmlu/ dev/data/tinyshakespeare/ dev/data/tinystories/ # binaries test_gpt2 test_gpt2cu test_gpt2fp32cu train_gpt2 train_gpt2cu train_gpt2fp32cu profile_gpt2cu dev/cuda/*_forward dev/cuda/*_backward dev/cuda/classifier_fused dev/cuda/adamw dev/cuda/matmul_backward_bias dev/cuda/nccl_all_reduce dev/cuda/global_norm *.obj *.exe *.o # log files *.log ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) 2024 Andrej Karpathy Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 
Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: Makefile ================================================ CC ?= clang CFLAGS = -Ofast -Wno-unused-result -Wno-ignored-pragmas -Wno-unknown-attributes LDFLAGS = LDLIBS = -lm INCLUDES = CFLAGS_COND = -march=native # Find nvcc SHELL_UNAME = $(shell uname) REMOVE_FILES = rm -f OUTPUT_FILE = -o $@ CUDA_OUTPUT_FILE = -o $@ # Default O3 CPU optimization level for NVCC (0 for fastest compile time) FORCE_NVCC_O ?= 3 # NVCC flags # -t=0 is short for --threads, 0 = number of CPUs on the machine NVCC_FLAGS = --threads=0 -t=0 --use_fast_math -std=c++17 -O$(FORCE_NVCC_O) NVCC_LDFLAGS = -lcublas -lcublasLt NVCC_INCLUDES = NVCC_LDLIBS = NCLL_INCUDES = NVCC_CUDNN = # By default we don't build with cudnn because it blows up compile time from a few seconds to ~minute USE_CUDNN ?= 0 # We will place .o files in the `build` directory (create it if it doesn't exist) BUILD_DIR = build ifeq ($(OS), Windows_NT) $(shell if not exist $(BUILD_DIR) mkdir $(BUILD_DIR)) REMOVE_BUILD_OBJECT_FILES := del $(BUILD_DIR)\*.obj else $(shell mkdir -p $(BUILD_DIR)) REMOVE_BUILD_OBJECT_FILES := rm -f $(BUILD_DIR)/*.o endif # Function to check if a file exists in the PATH ifneq ($(OS), Windows_NT) define file_exists_in_path $(which $(1) 2>/dev/null) endef else 
define file_exists_in_path $(shell where $(1) 2>nul) endef endif ifneq ($(CI),true) # if not in CI, then use the GPU query ifndef GPU_COMPUTE_CAPABILITY # set to defaults if: make GPU_COMPUTE_CAPABILITY= ifneq ($(call file_exists_in_path, nvidia-smi),) # Get the compute capabilities of all GPUs # Remove decimal points, sort numerically in ascending order, and select the first (lowest) value GPU_COMPUTE_CAPABILITY=$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader | sed 's/\.//g' | sort -n | head -n 1) GPU_COMPUTE_CAPABILITY := $(strip $(GPU_COMPUTE_CAPABILITY)) endif endif endif # set to defaults if - make GPU_COMPUTE_CAPABILITY= otherwise use the compute capability detected above ifneq ($(GPU_COMPUTE_CAPABILITY),) NVCC_FLAGS += --generate-code arch=compute_$(GPU_COMPUTE_CAPABILITY),code=[compute_$(GPU_COMPUTE_CAPABILITY),sm_$(GPU_COMPUTE_CAPABILITY)] endif # autodect a lot of various supports on current platform $(info ---------------------------------------------) ifneq ($(OS), Windows_NT) NVCC := $(shell which nvcc 2>/dev/null) NVCC_LDFLAGS += -lnvidia-ml # Function to test if the compiler accepts a given flag. 
define check_and_add_flag $(eval FLAG_SUPPORTED := $(shell printf "int main() { return 0; }\n" | $(CC) $(1) -x c - -o /dev/null 2>/dev/null && echo 'yes')) ifeq ($(FLAG_SUPPORTED),yes) CFLAGS += $(1) endif endef # Check each flag and add it if supported $(foreach flag,$(CFLAGS_COND),$(eval $(call check_and_add_flag,$(flag)))) else CFLAGS := REMOVE_FILES = del *.exe,*.obj,*.lib,*.exp,*.pdb && del SHELL_UNAME := Windows ifneq ($(shell where nvcc 2> nul),"") NVCC := nvcc else NVCC := endif CC := cl CFLAGS = /Idev /Zi /nologo /W4 /WX- /diagnostics:column /sdl /O2 /Oi /Ot /GL /D _DEBUG /D _CONSOLE /D _UNICODE /D UNICODE /Gm- /EHsc /MD /GS /Gy /fp:fast /Zc:wchar_t /Zc:forScope /Zc:inline /permissive- \ /external:W3 /Gd /TP /wd4996 /Fd$@.pdb /FC /openmp:llvm LDFLAGS := LDLIBS := INCLUDES := NVCC_FLAGS += -I"dev" ifeq ($(WIN_CI_BUILD),1) $(info Windows CI build) OUTPUT_FILE = /link /OUT:$@ CUDA_OUTPUT_FILE = -o $@ else $(info Windows local build) OUTPUT_FILE = /link /OUT:$@ && copy /Y $@ $@.exe CUDA_OUTPUT_FILE = -o $@ && copy /Y $@.exe $@ endif endif # Check and include cudnn if available # You can override the path to cudnn frontend by setting CUDNN_FRONTEND_PATH on the make command line # By default, we look for it in HOME/cudnn-frontend/include and ./cudnn-frontend/include # Refer to the README for cuDNN install instructions ifeq ($(USE_CUDNN), 1) ifeq ($(SHELL_UNAME), Linux) ifeq ($(shell [ -d $(HOME)/cudnn-frontend/include ] && echo "exists"), exists) $(info ✓ cuDNN found, will run with flash-attention) CUDNN_FRONTEND_PATH ?= $(HOME)/cudnn-frontend/include else ifeq ($(shell [ -d cudnn-frontend/include ] && echo "exists"), exists) $(info ✓ cuDNN found, will run with flash-attention) CUDNN_FRONTEND_PATH ?= cudnn-frontend/include else $(error ✗ cuDNN not found. 
See the README for install instructions and the Makefile for hard-coded paths) endif NVCC_INCLUDES += -I$(CUDNN_FRONTEND_PATH) NVCC_LDFLAGS += -lcudnn NVCC_FLAGS += -DENABLE_CUDNN NVCC_CUDNN = $(BUILD_DIR)/cudnn_att.o else ifneq ($(OS), Windows_NT) $(info → cuDNN is not supported on MAC OS right now) else $(info ✓ Windows cuDNN found, will run with flash-attention) ifeq ($(shell if exist "$(HOMEDRIVE)$(HOMEPATH)\cudnn-frontend\include" (echo exists)),exists) CUDNN_FRONTEND_PATH ?= $(HOMEDRIVE)$(HOMEPATH)\cudnn-frontend\include #override on command line if different location else ifeq ($(shell if exist "cudnn-frontend\include" (echo exists)),exists) CUDNN_FRONTEND_PATH ?= cudnn-frontend\include #override on command line if different location else $(error ✗ cuDNN not found. See the README for install instructions and the Makefile for hard-coded paths) endif CUDNN_INCLUDE_PATH ?= -I"C:\Program Files\NVIDIA\CUDNN\v9.1\include\12.4" CUDNN_FRONTEND_PATH += $(CUDNN_INCLUDE_PATH) NVCC_FLAGS += --std c++20 -Xcompiler "/std:c++20" -Xcompiler "/EHsc /W0 /nologo /Ox /FS" -maxrregcount=0 --machine 64 NVCC_CUDNN = $(BUILD_DIR)\cudnn_att.obj NVCC_INCLUDES += -I$(CUDNN_FRONTEND_PATH) NVCC_LDFLAGS += -L"C:\Program Files\NVIDIA\CUDNN\v9.1\lib\12.4\x64" -lcudnn NVCC_FLAGS += -DENABLE_CUDNN endif endif else $(info → cuDNN is manually disabled by default, run make with `USE_CUDNN=1` to try to enable) endif # Check if OpenMP is available # This is done by attempting to compile an empty file with OpenMP flags # OpenMP makes the code a lot faster so I advise installing it # e.g. on MacOS: brew install libomp # e.g. 
on Ubuntu: sudo apt-get install libomp-dev # later, run the program by prepending the number of threads, e.g.: OMP_NUM_THREADS=8 ./gpt2 # First, check if NO_OMP is set to 1, if not, proceed with the OpenMP checks ifeq ($(NO_OMP), 1) $(info OpenMP is manually disabled) else ifneq ($(OS), Windows_NT) # Detect if running on macOS or Linux ifeq ($(SHELL_UNAME), Darwin) # Check for Homebrew's libomp installation in different common directories ifeq ($(shell [ -d /opt/homebrew/opt/libomp/lib ] && echo "exists"), exists) # macOS with Homebrew on ARM (Apple Silicon) CFLAGS += -Xclang -fopenmp -DOMP LDFLAGS += -L/opt/homebrew/opt/libomp/lib LDLIBS += -lomp INCLUDES += -I/opt/homebrew/opt/libomp/include $(info ✓ OpenMP found) else ifeq ($(shell [ -d /usr/local/opt/libomp/lib ] && echo "exists"), exists) # macOS with Homebrew on Intel CFLAGS += -Xclang -fopenmp -DOMP LDFLAGS += -L/usr/local/opt/libomp/lib LDLIBS += -lomp INCLUDES += -I/usr/local/opt/libomp/include $(info ✓ OpenMP found) else $(info ✗ OpenMP not found) endif else # Check for OpenMP support in GCC or Clang on Linux ifeq ($(shell echo | $(CC) -fopenmp -x c -E - > /dev/null 2>&1; echo $$?), 0) CFLAGS += -fopenmp -DOMP LDLIBS += -lgomp $(info ✓ OpenMP found) else $(info ✗ OpenMP not found) endif endif endif endif # Check if NCCL is available, include if so, for multi-GPU training ifeq ($(NO_MULTI_GPU), 1) $(info → Multi-GPU (NCCL) is manually disabled) else ifneq ($(OS), Windows_NT) # Detect if running on macOS or Linux ifeq ($(SHELL_UNAME), Darwin) $(info ✗ Multi-GPU on CUDA on Darwin is not supported, skipping NCCL support) else ifeq ($(shell dpkg -l | grep -q nccl && echo "exists"), exists) $(info ✓ NCCL found, OK to train with multiple GPUs) NVCC_FLAGS += -DMULTI_GPU NVCC_LDLIBS += -lnccl else $(info ✗ NCCL is not found, disabling multi-GPU support) $(info ---> On Linux you can try install NCCL with `sudo apt install libnccl2 libnccl-dev`) endif endif endif # Attempt to find and include OpenMPI on the system 
OPENMPI_DIR ?= /usr/lib/x86_64-linux-gnu/openmpi OPENMPI_LIB_PATH = $(OPENMPI_DIR)/lib/ OPENMPI_INCLUDE_PATH = $(OPENMPI_DIR)/include/ ifeq ($(NO_USE_MPI), 1) $(info → MPI is manually disabled) else ifeq ($(shell [ -d $(OPENMPI_LIB_PATH) ] && [ -d $(OPENMPI_INCLUDE_PATH) ] && echo "exists"), exists) $(info ✓ MPI enabled) NVCC_INCLUDES += -I$(OPENMPI_INCLUDE_PATH) NVCC_LDFLAGS += -L$(OPENMPI_LIB_PATH) NVCC_LDLIBS += -lmpi NVCC_FLAGS += -DUSE_MPI else $(info ✗ MPI not found) endif # Precision settings, default to bf16 but ability to override PRECISION ?= BF16 VALID_PRECISIONS := FP32 FP16 BF16 ifeq ($(filter $(PRECISION),$(VALID_PRECISIONS)),) $(error Invalid precision $(PRECISION), valid precisions are $(VALID_PRECISIONS)) endif ifeq ($(PRECISION), FP32) PFLAGS = -DENABLE_FP32 else ifeq ($(PRECISION), FP16) PFLAGS = -DENABLE_FP16 else PFLAGS = -DENABLE_BF16 endif # PHONY means these targets will always be executed .PHONY: all train_gpt2 test_gpt2 train_gpt2cu test_gpt2cu train_gpt2fp32cu test_gpt2fp32cu profile_gpt2cu # Add targets TARGETS = train_gpt2 test_gpt2 # Conditional inclusion of CUDA targets ifeq ($(NVCC),) $(info ✗ nvcc not found, skipping GPU/CUDA builds) else $(info ✓ nvcc found, including GPU/CUDA support) TARGETS += train_gpt2cu test_gpt2cu train_gpt2fp32cu test_gpt2fp32cu $(NVCC_CUDNN) endif $(info ---------------------------------------------) all: $(TARGETS) train_gpt2: train_gpt2.c $(CC) $(CFLAGS) $(INCLUDES) $(LDFLAGS) $^ $(LDLIBS) $(OUTPUT_FILE) test_gpt2: test_gpt2.c $(CC) $(CFLAGS) $(INCLUDES) $(LDFLAGS) $^ $(LDLIBS) $(OUTPUT_FILE) $(NVCC_CUDNN): llmc/cudnn_att.cpp $(NVCC) -c $(NVCC_FLAGS) $(PFLAGS) $^ $(NVCC_INCLUDES) -o $@ train_gpt2cu: train_gpt2.cu $(NVCC_CUDNN) $(NVCC) $(NVCC_FLAGS) $(PFLAGS) $^ $(NVCC_LDFLAGS) $(NVCC_INCLUDES) $(NVCC_LDLIBS) $(CUDA_OUTPUT_FILE) train_gpt2fp32cu: train_gpt2_fp32.cu $(NVCC) $(NVCC_FLAGS) $^ $(NVCC_LDFLAGS) $(NVCC_INCLUDES) $(NVCC_LDLIBS) $(CUDA_OUTPUT_FILE) test_gpt2cu: test_gpt2.cu $(NVCC_CUDNN) $(NVCC) 
$(NVCC_FLAGS) $(PFLAGS) $^ $(NVCC_LDFLAGS) $(NVCC_INCLUDES) $(NVCC_LDLIBS) $(CUDA_OUTPUT_FILE) test_gpt2fp32cu: test_gpt2_fp32.cu $(NVCC) $(NVCC_FLAGS) $^ $(NVCC_LDFLAGS) $(NVCC_INCLUDES) $(NVCC_LDLIBS) $(CUDA_OUTPUT_FILE) profile_gpt2cu: profile_gpt2.cu $(NVCC_CUDNN) $(NVCC) $(NVCC_FLAGS) $(PFLAGS) -lineinfo $^ $(NVCC_LDFLAGS) $(NVCC_INCLUDES) $(NVCC_LDLIBS) $(CUDA_OUTPUT_FILE) clean: $(REMOVE_FILES) $(TARGETS) $(REMOVE_BUILD_OBJECT_FILES) ================================================ FILE: README.md ================================================ # llm.c LLMs in simple, pure C/CUDA with no need for 245MB of PyTorch or 107MB of cPython. Current focus is on pretraining, in particular reproducing the [GPT-2](https://github.com/openai/gpt-2) and [GPT-3](https://arxiv.org/abs/2005.14165) miniseries, along with a parallel PyTorch reference implementation in [train_gpt2.py](train_gpt2.py). You'll recognize this file as a slightly tweaked [nanoGPT](https://github.com/karpathy/nanoGPT), an earlier project of mine. Currently, llm.c is a bit faster than PyTorch Nightly (by about 7%). In addition to the bleeding edge mainline code in [train_gpt2.cu](train_gpt2.cu), we have a simple reference CPU fp32 implementation in ~1,000 lines of clean code in one file [train_gpt2.c](train_gpt2.c). I'd like this repo to only maintain C and CUDA code. Ports to other languages or repos are very welcome, but should be done in separate repos, and I am happy to link to them below in the "notable forks" section. Developer coordination happens in the [Discussions](https://github.com/karpathy/llm.c/discussions) and on Discord, either the `#llmc` channel on the [Zero to Hero](https://discord.gg/3zy8kqD9Cp) channel, or on `#llmdotc` on [GPU MODE](https://discord.gg/gpumode) Discord. ## quick start The best introduction to the llm.c repo today is reproducing the GPT-2 (124M) model. [Discussion #481](https://github.com/karpathy/llm.c/discussions/481) steps through this in detail. 
We can reproduce other models from the GPT-2 and GPT-3 series in both llm.c and in the parallel implementation of PyTorch. Have a look at the [scripts README](scripts/README.md). debugging tip: when you run the `make` command to build the binary, modify it by replacing `-O3` with `-g` so you can step through the code in your favorite IDE (e.g. vscode). ## quick start (1 GPU, fp32 only) If you won't be training on multiple nodes, aren't interested in mixed precision, and are interested in learning CUDA, the fp32 (legacy) files might be of interest to you. These are files that were "checkpointed" early in the history of llm.c and frozen in time. They are simpler, more portable, and possibly easier to understand. Run the 1 GPU, fp32 code like this: ```bash chmod u+x ./dev/download_starter_pack.sh ./dev/download_starter_pack.sh make train_gpt2fp32cu ./train_gpt2fp32cu ``` The download_starter_pack.sh script is a quick & easy way to get started and it downloads a bunch of .bin files that help get you off the ground. These contain: 1) the GPT-2 124M model saved in fp32 and in bfloat16, 2) a "debug state" used in unit testing (a small batch of data, and target activations and gradients), 3) the GPT-2 tokenizer, and 4) the tokenized [tinyshakespeare](https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt) dataset. Alternatively, instead of running the .sh script, you can re-create these artifacts manually as follows: ```bash pip install -r requirements.txt python dev/data/tinyshakespeare.py python train_gpt2.py ``` ## quick start (CPU) The "I am so GPU poor that I don't even have one GPU" section. You can still enjoy seeing llm.c train! But you won't go too far. Just like the fp32 version above, the CPU version is an even earlier checkpoint in the history of llm.c, back when it was just a simple reference implementation in C. 
For example, instead of training from scratch, you can finetune a GPT-2 small (124M) to output Shakespeare-like text: ```bash chmod u+x ./dev/download_starter_pack.sh ./dev/download_starter_pack.sh make train_gpt2 OMP_NUM_THREADS=8 ./train_gpt2 ``` If you'd prefer to avoid running the starter pack script, then as mentioned in the previous section you can reproduce the exact same .bin files and artifacts by running `python dev/data/tinyshakespeare.py` and then `python train_gpt2.py`. The above lines (1) download an already tokenized [tinyshakespeare](https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt) dataset and download the GPT-2 (124M) weights, (2) initialize from them in C and train for 40 steps on tinyshakespeare with AdamW (using batch size 4, context length only 64), evaluate validation loss, and sample some text. Honestly, unless you have a beefy CPU (and can crank up the number of OMP threads in the launch command), you're not going to get that far on CPU training LLMs, but it might be a good demo/reference. The output looks like this on my MacBook Pro (Apple Silicon M3 Max): ``` [GPT-2] max_seq_len: 1024 vocab_size: 50257 num_layers: 12 num_heads: 12 channels: 768 num_parameters: 124439808 train dataset num_batches: 1192 val dataset num_batches: 128 num_activations: 73323776 val loss 5.252026 step 0: train loss 5.356189 (took 1452.121000 ms) step 1: train loss 4.301069 (took 1288.673000 ms) step 2: train loss 4.623322 (took 1369.394000 ms) step 3: train loss 4.600470 (took 1290.761000 ms) ... (truncated) ... step 39: train loss 3.970751 (took 1323.779000 ms) val loss 4.107781 generating: --- Come Running Away, Greater conquer With the Imperial blood the heaviest host of the gods into this wondrous world beyond.
I will not back thee, for how sweet after birth Netflix against repounder, will not flourish against the earlocks of Allay --- ``` ## datasets The data files inside `/dev/data/(dataset).py` are responsible for downloading, tokenizing and saving the tokens to .bin files, readable easily from C. So for example when you run: ```bash python dev/data/tinyshakespeare.py ``` We download and tokenize the [tinyshakespeare](https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt) dataset. The output of this looks like this: ``` writing 32,768 tokens to ./dev/data/tinyshakespeare/tiny_shakespeare_val.bin writing 305,260 tokens to ./dev/data/tinyshakespeare/tiny_shakespeare_train.bin ``` The .bin files contain a short header (1024 bytes) and then a stream of tokens in uint16, indicating the token ids with the GPT-2 tokenizer. More datasets are available in `/dev/data`. ## test I am also attaching a simple unit test for making sure our C code agrees with the PyTorch code. On the CPU as an example, compile and run with: ```bash make test_gpt2 ./test_gpt2 ``` This now loads the `gpt2_124M_debug_state.bin` file that gets written by train_gpt2.py, runs a forward pass, compares the logits and loss with the PyTorch reference implementation, then it does 10 iterations of training with Adam and makes sure the losses match PyTorch. To test the GPU version we run: ```bash # fp32 test (cudnn not supported) make test_gpt2cu PRECISION=FP32 && ./test_gpt2cu # mixed precision cudnn test make test_gpt2cu USE_CUDNN=1 && ./test_gpt2cu ``` This tests both the fp32 path and the mixed precision path. The test should pass and print `overall okay: 1`. ## tutorial I attached a very small tutorial here, in [doc/layernorm/layernorm.md](doc/layernorm/layernorm.md). It's a simple, step-by-step guide to implementing a single layer of the GPT-2 model, the layernorm layer. This is a good starting point to understand how the layers are implemented in C. **flash attention**. 
As of May 1, 2024 we use the Flash Attention from cuDNN. Because cuDNN bloats the compile time from a few seconds to ~a minute and this code path is right now very new, this is disabled by default. You can enable it by compiling like this: ```bash make train_gpt2cu USE_CUDNN=1 ``` This will try to compile with cudnn and run it. You have to have cuDNN installed on your system. The [cuDNN installation instructions](https://developer.nvidia.com/cudnn) with apt-get will grab the default set of cuDNN packages. For a minimal setup, the cuDNN dev package is sufficient, e.g. on Ubuntu 22.04 for CUDA 12.x: ```bash wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb sudo dpkg -i cuda-keyring_1.1-1_all.deb sudo apt-get update sudo apt-get -y install libcudnn9-dev-cuda-12 ``` On top of this you need the [cuDNN frontend](https://github.com/NVIDIA/cudnn-frontend/tree/main), but this is just header files. Simply clone the repo to your disk. The Makefile currently looks for it in either your home directory or the current directory. If you have put it elsewhere, add `CUDNN_FRONTEND_PATH=/path/to/your/cudnn-frontend/include` to the `make` command-line. ## multi-GPU training Make sure you install MPI and NCCL, e.g. on Linux: ```bash sudo apt install openmpi-bin openmpi-doc libopenmpi-dev ``` For NCCL follow the instructions from the [official website](https://developer.nvidia.com/nccl/nccl-download) (e.g. network installer) and then: ```bash make train_gpt2cu mpirun -np <number of GPUs> ./train_gpt2cu ``` or simply run one of our scripts under `./scripts/`. ## multi-node training Make sure you've installed `NCCL` following the instructions from the [multi-GPU](#multi-gpu-training) section. There are 3 ways we currently support that allow you to run multi-node training: 1) Use OpenMPI to exchange nccl id and initialize NCCL. See e.g. `./scripts/multi_node/run_gpt2_124M_mpi.sh` script for details. 2) Use shared file system to init NCCL.
See `./scripts/multi_node/run_gpt2_124M_fs.sbatch` script for details. 3) Use TCP sockets to init NCCL. See `./scripts/multi_node/run_gpt2_124M_tcp.sbatch` script for details. Note: * If you're running in a slurm environment and your slurm doesn't support PMIx (which we assume will be a common situation given that `slurm-wlm` dropped PMIx support) you will have to use FS (2) or TCP (3) approach. To test whether your slurm supports PMIx run: `srun --mpi=list` and see whether you get `pmix` in the output. * If you don't have slurm set up, you can kick off a multi-node run using `mpirun` - MPI (1). None of these 3 methods is superior, we just offer you options so that you can run in your specific environment. ## experiments / sweeps Just as an example process to sweep learning rates on a machine with 4 GPUs on TinyStories. Run a shell script `sweep.sh` (after you of course `chmod u+x sweep.sh`): ```bash #!/bin/bash learning_rates=(3e-5 1e-4 3e-4 1e-3) for i in {0..3}; do export CUDA_VISIBLE_DEVICES=$i screen -dmS "tr$i" bash -c "./train_gpt2cu -i data/TinyStories -v 250 -s 250 -g 144 -l ${learning_rates[$i]} -o stories$i.log" done # you can bring these down with # screen -ls | grep -E "tr[0-3]" | cut -d. -f1 | xargs -I {} screen -X -S {} quit ``` This example opens up 4 screen sessions and runs the four commands with different LRs. This writes the log files `stories$i.log` with all the losses, which you can plot as you wish in Python. A quick example of how to parse and plot these logfiles is in [dev/vislog.ipynb](dev/vislog.ipynb). ## repo A few more words on what I want this repo to be: First, I want `llm.c` to be a place for education. E.g. our `dev/cuda` folder is a place for a library of kernels for all the layers that are manually hand-written and very well documented, starting from very simple kernels all the way to more complex / faster kernels. If you have a new kernel with various different tradeoffs, please feel free to contribute it here. 
That said, I also want `llm.c` to be very fast too, even practically useful to train networks. E.g. to start, we should be able to reproduce the big GPT-2 (1.6B) training run. This requires that we incorporate whatever fastest kernels there are, including the use of libraries such as cuBLAS, cuBLASLt, CUTLASS, cuDNN, etc. I also think doing so serves an educational purpose to establish an expert upper bound, and a unit of measurement, e.g. you could say that your manually written kernels are 80% of cuBLAS speed, etc. Then you can choose to do a super fast run, or you can choose to "drag and drop" whatever manual kernels you wish to use, and run with those. However, as a constraint, I want to keep the mainline `llm.c` in the root folder simple and readable. If there is a PR that e.g. improves performance by 2% but it "costs" 500 lines of complex C code, and maybe an exotic 3rd party dependency, I may reject the PR because the complexity is not worth it. As a concrete example - making cuBLAS for matmuls the default in the root training loop is a no-brainer: it makes the mainline code much faster, it is a single line of interpretable code, and it is a very common dependency. On the side of this, we can have manual implementations that can compete with cuBLAS in `dev/cuda`. Lastly, I will be a lot more sensitive to complexity in the root folder of the project, which contains the main / default files of the project. In comparison, the `dev/` folder is a bit more of a scratch space for us to develop a library of kernels or classes and share useful or related or educational code, and some of this code could be ok to be (locally) complex. 
## notable forks - AMD support - [llm.c](https://github.com/anthonix/llm.c) by @[anthonix](https://github.com/anthonix): support for AMD devices, such as the 7900 XTX - C# - [llm.cs](https://github.com/azret/llm.cs) by @[azret](https://github.com/azret): a C# port of this project - [Llm.cs](https://github.com/nietras/Llm.cs) by @[nietras](https://github.com/nietras): a C# port of this project with focus on easy to get started on any platform. Clone and run ✅ - CUDA C++ - [llm.cpp](https://github.com/gevtushenko/llm.c) by @[gevtushenko](https://github.com/gevtushenko): a port of this project using the [CUDA C++ Core Libraries](https://github.com/NVIDIA/cccl) - A presentation this fork was covered in [this lecture](https://www.youtube.com/watch?v=WiB_3Csfj_Q) in the [GPU MODE Discord Server](https://discord.gg/cudamode) - C++/CUDA - [llm.cpp](https://github.com/zhangpiu/llm.cpp/tree/master/llmcpp) by @[zhangpiu](https://github.com/zhangpiu): a port of this project using the [Eigen](https://gitlab.com/libeigen/eigen), supporting CPU/CUDA. - WebGPU C++ - [gpu.cpp](https://github.com/AnswerDotAI/gpu.cpp) by @[austinvhuang](https://github.com/austinvhuang): a library for portable GPU compute in C++ using native WebGPU. Aims to be a general-purpose library, but also porting llm.c kernels to WGSL. 
- C++ - [llm.cpp](https://github.com/GaoYusong/llm.cpp) by @[GaoYusong](https://github.com/GaoYusong): a port of this project featuring a C++ single-header [tinytorch.hpp](https://github.com/GaoYusong/llm.cpp/blob/main/tinytorch.hpp) library - Go - [llm.go](https://github.com/joshcarp/llm.go) by @[joshcarp](https://github.com/joshcarp): a Go port of this project - Java - [llm.java](https://github.com/harryjackson/llm.java) by @[harryjackson](https://github.com/harryjackson): a Java port of this project - Metal - [llm.metal](https://github.com/regrettable-username/llm.metal) by @[regrettable-username](https://github.com/regrettable-username): LLM training in simple, raw C/Metal Shading Language - Mojo - [llm.🔥](https://github.com/dorjeduck/llm.mojo) by @[dorjeduck](https://github.com/dorjeduck): a Mojo port of this project - OpenCL - [llm.c](https://github.com/krrishnarraj/llm.c) by @[krrishnarraj](https://github.com/krrishnarraj): an OpenCL port of this project - Rust - [llm.rs](https://github.com/yijunyu/llm.rs) by @[Yijun Yu](https://github.com/yijunyu): a Rust rewrite with the aim to have same performance - [llm.rs](https://github.com/ToJen/llm.rs) by @[ToJen](https://github.com/ToJen): a Rust port of this project - Swift - [llm.swift](https://github.com/otabuzzman/llm.swift) by @[otabuzzman](https://github.com/otabuzzman): a Swift port of this project - Zig - [llm.zig](https://github.com/Saimirbaci/llm.zig) by @[saimirbaci](https://github.com/Saimirbaci): a Zig port of this project - Habana Gaudi2 - [llm.tpc](https://github.com/abhilash1910/llm.tpc) by @[abhilash1910](https://github.com/abhilash1910): a Habana Gaudi2 port of this project - Nim - [llm.nim](https://github.com/Vindaar/llm.nim) by @[Vindaar](https://github.com/Vindaar): a Nim port of this project ## discussions Ways of organizing development: - Experiencing a concrete issue with the repo? Use [Issues](https://github.com/karpathy/llm.c/issues). - Have some code to contribute? 
Open a [PR](https://github.com/karpathy/llm.c/pulls) - Chat about the repo, ask questions, etc.? Look at [Discussions](https://github.com/karpathy/llm.c/discussions). - Something faster? I created a new `#llmc` channel on my [Zero to Hero Discord channel](https://discord.gg/3zy8kqD9Cp). ## license MIT ================================================ FILE: dev/cpu/matmul_forward.c ================================================ /* CPU Kernels for matmul forward pass. */ // Compile Examples: // // MSVC: cl.exe /O2 /fp:fast /Qvec-report:2 /I. /I ..\..\dev matmul_forward.c // cl.exe /O2 /fp:fast /Qvec-report:2 /arch:AVX /I. /I ..\..\dev matmul_forward.c // cl.exe /O2 /fp:fast /Qvec-report:2 /arch:AVX2 /I. /I ..\..\dev matmul_forward.c // #include #include #include #include #include // ---------------------------------------------------------------------------- // CPU code reference void matmul_forward_cpu(float* out, const float* inp, const float* weight, const float* bias, int B, int T, int C, int OC) { // OC is short for "output channels" // inp is (B,T,C), weight is (OC, C), bias is (OC) // out will be (B,T,OC) for (int b = 0; b < B; b++) { for (int t = 0; t < T; t++) { float* out_bt = out + b * T * OC + t * OC; const float* inp_bt = inp + b * T * C + t * C; for (int o = 0; o < OC; o++) { float val = (bias != NULL) ? 
bias[o] : 0.0f; const float* wrow = weight + o*C; for (int i = 0; i < C; i++) { val += inp_bt[i] * wrow[i]; } out_bt[o] = val; } } } } void matmul_forward_ngc92(float* out, const float* inp, const float* weight, const float* bias, int B, int T, int C, int OC) { // most of the running time is spent here and in matmul_backward // OC is short for "output channels" // inp is (B,T,C), weight is (OC, C), bias is (OC) // out will be (B,T,OC) // make sure the tiled loop will be correct, otherwise, fallback to slow version #define LOOP_UNROLL 8 if (B * T % LOOP_UNROLL != 0) { printf("MUST BE A MULTIPLE OF 8"); // FIXME return; } // collapse the B and T loops into one and turn it into a strided loop. // then we can tile the inner loop, and reuse the loaded weight LOOP_UNROLL many times // for significant speed-ups. for (int obt = 0; obt < B * T; obt += LOOP_UNROLL) { for (int o = 0; o < OC; o++) { // keep LOOP_UNROLL many results in register, initialized by the bias term. float result[LOOP_UNROLL]; for (int ibt = 0; ibt < LOOP_UNROLL; ++ibt) { result[ibt] = (bias != NULL) ? bias[o] : 0.0f; } // inner loops. Because we do LOOP_UNROLL steps of inner bt, we can cache // the value of weight[i + o * C] and reuse it. 
// we compile with -Ofast, so the compiler will turn the inner loop into a bunch of FMAs for (int i = 0; i < C; i++) { float w = weight[i + o * C]; for (int ibt = 0; ibt < LOOP_UNROLL; ++ibt) { int bt = obt + ibt; result[ibt] += inp[bt * C + i] * w; } } // write back results to main memory for (int ibt = 0; ibt < LOOP_UNROLL; ++ibt) { int bt = obt + ibt; out[bt * OC + o] = result[ibt]; } } } } #define NUM_KERNELS 2 void matmul_forward(int kernel_num, float* out, const float* inp, const float* weight, const float* bias, int B, int T, int C, int OC) { switch (kernel_num) { case 0: matmul_forward_cpu(out, inp, weight, bias, B, T, C, OC); break; case 1: matmul_forward_ngc92(out, inp, weight, bias, B, T, C, OC); break; default: printf("Invalid kernel number\n"); exit(1); } } void validate_results_cpu(const float* device_result, const float* cpu_reference, const char* name, int num_elements, float tolerance); float* make_random_float(size_t N); int main(int argc, char **argv) { srand(0); int B = 8; int T = 1024; int C = 768; int OC = 768 * 4; // expansion of 4, e.g. 
in the MLP int RUNS = 4; // number of times to run a kernel for benchmarks srand(137); float* out = make_random_float(B * T * OC); float* inp = make_random_float(B * T * C); float* weight = make_random_float(OC * C); float* bias = make_random_float(OC); float* grad_out = make_random_float(B * T * OC); float* grad_inp = make_random_float(B * T * C); float* grad_weight = make_random_float(OC * C); float* grad_bias = make_random_float(OC); printf("> Calculating reference\n"); matmul_forward_cpu(out, inp, weight, bias, B, T, C, OC); for (int kernel_num = 0; kernel_num < NUM_KERNELS; kernel_num++) { printf("> Verifying kernel #%d\n", kernel_num); srand(137); float* kernel_out = make_random_float(B * T * OC); float* kernel_inp = make_random_float(B * T * C); float* kernel_weight = make_random_float(OC * C); float* kernel_bias = make_random_float(OC); matmul_forward(kernel_num, kernel_out, kernel_inp, kernel_weight, kernel_bias, B, T, C, OC); validate_results_cpu(kernel_out, out, "out", B * T * OC, 1e-5); free(kernel_out); free(kernel_inp); free(kernel_weight); free(kernel_bias); } printf("All kernels passed! 
Starting benchmarks.\n\n"); for (int kernel_num = 0; kernel_num < NUM_KERNELS; kernel_num++) { printf("> Running kernel #%d\n", kernel_num); struct timespec start, end; clock_gettime(CLOCK_MONOTONIC, &start); for (int i = 0; i < RUNS; i++) { matmul_forward(kernel_num, out, inp, weight, bias, B, T, C, OC); } clock_gettime(CLOCK_MONOTONIC, &end); double time_elapsed_s = (end.tv_sec - start.tv_sec) + (end.tv_nsec - start.tv_nsec) / 1e9; printf("> Kernel #%d, (took %f ms)\n", kernel_num, time_elapsed_s * 1000); } // free memory free(out); free(inp); free(weight); free(bias); free(grad_out); free(grad_inp); free(grad_weight); free(grad_bias); return 0; } float* make_random_float(size_t N) { float* arr = (float*)malloc(N * sizeof(float)); for (size_t i = 0; i < N; i++) { arr[i] = ((float)rand() / RAND_MAX) * 2.0 - 1.0; // range -1..1 } return arr; } void validate_results_cpu(const float* kernel_result, const float* cpu_reference, const char* name, int num_elements, float tolerance) { int nfaults = 0; for (int i = 0; i < num_elements; i++) { // print the first few comparisons if (i < 5) { printf("%f %f\n", cpu_reference[i], kernel_result[i]); } float t_eff = tolerance + fabs(cpu_reference[i]); // ensure correctness for all elements. if (fabs(cpu_reference[i] - kernel_result[i]) > t_eff) { printf("Mismatch of %s at %d: CPU_ref: %f vs CPU_new: %f\n", name, i, cpu_reference[i], kernel_result[i]); nfaults++; if (nfaults >= 10) { exit(EXIT_FAILURE); } } } if (nfaults > 0) { exit(EXIT_FAILURE); } printf("OK\n"); } ================================================ FILE: dev/cuda/Makefile ================================================ # Makefile for building dev/cuda kernels # Collects all the make commands in one file but each file also # has the compile and run commands in the header comments section. # Find nvcc (NVIDIA CUDA compiler) NVCC := $(shell which nvcc 2>/dev/null) ifeq ($(NVCC),) $(error nvcc not found.) 
endif ifneq ($(CI),true) # if not in CI, then use the GPU query ifndef GPU_COMPUTE_CAPABILITY # set to defaults if: make GPU_COMPUTE_CAPABILITY= GPU_COMPUTE_CAPABILITY = $(shell __nvcc_device_query) # assume if NVCC is present, then this likely is too GPU_COMPUTE_CAPABILITY := $(strip $(GPU_COMPUTE_CAPABILITY)) endif endif # Compiler flags ifeq ($(GPU_COMPUTE_CAPABILITY),) # set to defaults if: make GPU_COMPUTE_CAPABILITY= CFLAGS = -O3 --use_fast_math else CFLAGS = -O3 --use_fast_math --generate-code arch=compute_$(GPU_COMPUTE_CAPABILITY),code=[compute_$(GPU_COMPUTE_CAPABILITY),sm_$(GPU_COMPUTE_CAPABILITY)] endif NVCCFLAGS = -lcublas -lcublasLt -std=c++17 MPI_PATHS = -I/usr/lib/x86_64-linux-gnu/openmpi/include -L/usr/lib/x86_64-linux-gnu/openmpi/lib/ # Default rule for our CUDA files %: %.cu $(NVCC) $(CFLAGS) $(NVCCFLAGS) $< -o $@ # Build all targets TARGETS = adamw attention_backward attention_forward classifier_fused crossentropy_forward crossentropy_softmax_backward encoder_backward encoder_forward gelu_backward gelu_forward layernorm_backward layernorm_forward matmul_backward matmul_backward_bias matmul_forward nccl_all_reduce residual_forward softmax_forward trimat_forward fused_residual_forward global_norm permute all: $(TARGETS) all_ptx: $(TARGETS:%=%.ptx) all_sass: $(TARGETS:%=%.sass) # Individual targets: forward pass attention_forward: attention_forward.cu classifier_fused: classifier_fused.cu crossentropy_forward: crossentropy_forward.cu encoder_forward: encoder_forward.cu gelu_forward: gelu_forward.cu layernorm_forward: layernorm_forward.cu fused_residual_forward: fused_residual_forward.cu residual_forward: residual_forward.cu softmax_forward: softmax_forward.cu trimat_forward: trimat_forward.cu # matmul fwd/bwd also uses OpenMP (optionally) and cuBLASLt libs matmul_forward: matmul_forward.cu $(NVCC) $(CFLAGS) $(NVCCFLAGS) -Xcompiler -fopenmp matmul_forward.cu -o matmul_forward # Individual targets: backward pass attention_backward: 
attention_backward.cu crossentropy_softmax_backward: crossentropy_softmax_backward.cu encoder_backward: encoder_backward.cu gelu_backward: gelu_backward.cu layernorm_backward: layernorm_backward.cu matmul_backward_bias: matmul_backward_bias.cu matmul_backward: matmul_backward.cu $(NVCC) $(CFLAGS) $(NVCCFLAGS) -Xcompiler -fopenmp matmul_backward.cu -o matmul_backward # Update kernels adamw: adamw.cu global_norm: global_norm.cu permute: permute.cu # NCCL communication kernels nccl_all_reduce: nccl_all_reduce.cu $(NVCC) -lmpi -lnccl $(NVCCFLAGS) $(MPI_PATHS) nccl_all_reduce.cu -o nccl_all_reduce # Generate PTX using cuobjdump %.ptx: % cuobjdump --dump-ptx $< > $@ # Generate SASS using cuobjdump %.sass: % cuobjdump --dump-sass $< > $@ # Run all targets run_all: all @for target in $(TARGETS); do \ echo "\n========================================"; \ echo "Running $$target ..."; \ echo "========================================\n"; \ ./$$target; \ done # Clean up clean: rm -f $(TARGETS) *.ptx *.sass ================================================ FILE: dev/cuda/README.md ================================================ # dev/cuda This directory is scratch space for developing various versions of the needed CUDA kernels. Each file develops a kernel, and usually multiple versions of that kernel that could have different running times and of different code or time complexity. See the top of each file for how to compile and run the kernel. Alternatively, the commands are also all grouped in the `Makefile` in this directory for convenience. For example, we can look at the top of `layernorm_forward.cu` to build the forward pass kernels for the LayerNorm: ```bash nvcc -O3 --use_fast_math -lcublas -lcublasLt layernorm_forward.cu -o layernorm_forward ``` or simply ```bash make layernorm_forward ``` The comments at the top then document the different versions of this kernel available, usually these are in increasing complexity and decreasing running times. 
For example, inspecting the comments at the top of the file, we can run the most naive kernel as: ```bash ./layernorm_forward 1 ``` You'll see that this first forwards the reference code on the CPU, then it runs kernel 1 on the GPU, compares the results to check for correctness, and then runs a number of configurations of this kernel (most often and most notably the block size), to time the kernel in these launch configurations. We can then run one of the faster kernels (kernel 4) instead: ```bash ./layernorm_forward 4 ``` You'll see that this matches all the CPU results but runs much much faster. The typical process from here on is we copy-paste the kernel that ran fastest, adjust it manually (e.g. to hardcode the best block size) and drop it into the training code file, e.g. `train_gpt2.cu`. To add a new version of a kernel, add the kernel to the corresponding file and adjust the docs. To add a new kernel, add the new file and adjust the Makefile. Run `make clean` to clean up binaries from your directory. If you do not have a GPU or are having trouble with CUDA dependencies, you can run the benchmarks on the [Modal platform](http://modal.com). For example, to run the benchmark for the attention forward pass on an A100 GPU with 80GB of memory, you can run the following command: ```bash GPU_MEM=80 modal run benchmark_on_modal.py --compile-command "nvcc -O3 --use_fast_math attention_forward.cu -o attention_forward -lcublas" --run-command "./attention_forward 1" ``` ================================================ FILE: dev/cuda/adamw.cu ================================================ /* Kernels for the AdamW optimizer.
References:
* https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html
* https://github.com/nvidia/apex/blob/master/csrc/multi_tensor_adam.cu

Compile example:
nvcc -lcublas -lcublasLt adamw.cu -o adamw
nvcc -O3 --use_fast_math -lcublas -lcublasLt adamw.cu -o adamw

./adamw

TODO(general):
amsgrad=True

TODO(perf):
dtype
thread coarsening/ILP
*/

#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <time.h>
#include <cuda_runtime.h>
#include "common.h"

// ----------------------------------------------------------------------------
// CPU code reference

// One AdamW step on the host; updates params_memory, m_memory and v_memory in place.
// t is the step count used for bias correction; t must be >= 1, since with t == 0
// the factors (1 - beta^t) are 0 and m_hat / v_hat divide by zero.
void adamw_cpu(float* params_memory, const float* grads_memory, float* m_memory, float* v_memory, int t, long num_parameters, float learning_rate=1e-3, float beta1=0.9, float beta2=0.999, float eps=1e-8, float weight_decay=0.0) {
    // adapted from: train_gpt2.c
    // long index to match num_parameters (an int counter would overflow past INT_MAX)
    for (long i = 0; i < num_parameters; i++) {
        float param = params_memory[i];
        float grad = grads_memory[i];

        // update the first moment (momentum)
        float m = beta1 * m_memory[i] + (1.0f - beta1) * grad;
        // update the second moment (RMSprop)
        float v = beta2 * v_memory[i] + (1.0f - beta2) * grad * grad;
        // bias-correct both moments
        float m_hat = m / (1.0f - powf(beta1, t));
        float v_hat = v / (1.0f - powf(beta2, t));

        // update
        m_memory[i] = m;
        v_memory[i] = v;
        params_memory[i] -= learning_rate * (m_hat / (sqrtf(v_hat) + eps) + weight_decay * param);
    }
}

// ----------------------------------------------------------------------------
// GPU kernels

// utility functions

// Implements linear interpolation using only two floating-point operations (as opposed to three in a naive implementation).
// Reference: https://developer.nvidia.com/blog/lerp-faster-cuda
// Returns start + weight * (end - start).
__device__ inline float lerp(float start, float end, float weight) {
    return fma(weight, end, fma(-weight, start, start));
}

// naive fused kernel: one thread per parameter
__global__ void adamw_kernel1(float* params_memory, const float* grads_memory, float* m_memory, float* v_memory, long num_parameters,
                              float learning_rate, float beta1, float beta2, float beta1_correction, float beta2_correction, float eps, float weight_decay) {
    // widen before multiplying so the global index cannot wrap for very large tensors
    long i = (long)blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= num_parameters) return;  // guard
    // update the first moment (momentum)
    m_memory[i] = beta1 * m_memory[i] + (1.0f - beta1) * grads_memory[i];
    // update the second moment (RMSprop)
    v_memory[i] = beta2 * v_memory[i] + (1.0f - beta2) * grads_memory[i] * grads_memory[i];
    float m_hat = m_memory[i] / beta1_correction;
    float v_hat = v_memory[i] / beta2_correction;
    params_memory[i] -= learning_rate * (m_hat / (sqrtf(v_hat) + eps) + weight_decay * params_memory[i]);
}

// Slightly more optimized AdamW kernel by:
// * loading data that is accessed more than once into registers,
// * using optimized linear interpolation for the moment updates.
__global__ void adamw_kernel2(float* params_memory, const float* grads_memory, float* m_memory, float* v_memory, long num_parameters,
                              float learning_rate, float beta1, float beta2, float beta1_correction, float beta2_correction, float eps, float weight_decay) {
    long i = (long)blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= num_parameters) return;  // guard
    float grad = grads_memory[i];
    float m = m_memory[i];
    float v = v_memory[i];
    // update the first moment (momentum): lerp(grad, m, beta1) == beta1*m + (1-beta1)*grad
    m = lerp(grad, m, beta1);
    m_memory[i] = m;
    // update the second moment (RMSprop)
    v = lerp(grad * grad, v, beta2);
    v_memory[i] = v;
    m /= beta1_correction;  // m_hat
    v /= beta2_correction;  // v_hat
    params_memory[i] -= learning_rate * (m / (sqrtf(v) + eps) + weight_decay * params_memory[i]);
}

// ----------------------------------------------------------------------------
// kernel launcher

// version 1: naive dispatch to naive kernel
void adamw_dispatch1(float* params_memory, const float* grads_memory, float* m_memory, float* v_memory, long num_parameters,
                     float learning_rate, float beta1, float beta2, float beta1_correction, float beta2_correction, float eps, float weight_decay) {
    unsigned int block_size = 512;
    unsigned int num_blocks = ceil_div(num_parameters, (long) block_size);
    adamw_kernel1<<<num_blocks, block_size>>>(params_memory, grads_memory, m_memory, v_memory, num_parameters,
                                              learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, weight_decay);
    cudaCheck(cudaGetLastError());
}

// version 2: naive dispatch to slightly optimized kernel
void adamw_dispatch2(float* params_memory, const float* grads_memory, float* m_memory, float* v_memory, long num_parameters,
                     float learning_rate, float beta1, float beta2, float beta1_correction, float beta2_correction, float eps, float weight_decay) {
    unsigned int block_size = 512;
    unsigned int num_blocks = ceil_div(num_parameters, (long) block_size);
    adamw_kernel2<<<num_blocks, block_size>>>(params_memory, grads_memory, m_memory, v_memory, num_parameters,
                                              learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, weight_decay);
    cudaCheck(cudaGetLastError());
}

// Run one AdamW step on device buffers with kernel `kernel_num` (1 or 2); exits on an unknown number.
void adamw(int kernel_num,
           float* params_memory, const float* grads_memory, float* m_memory, float* v_memory, int t, long num_parameters,
           float learning_rate=1e-3, float beta1=0.9, float beta2=0.999, float eps=1e-8, float weight_decay=0.0) {
    // calculate the m_hat and v_hat correction terms once as they are the same for every param/thread
    float beta1_correction = 1.0f - powf(beta1, t);
    float beta2_correction = 1.0f - powf(beta2, t);
    switch (kernel_num) {
        case 1:
            adamw_dispatch1(params_memory, grads_memory, m_memory, v_memory, num_parameters,
                            learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, weight_decay);
            break;
        case 2:
            adamw_dispatch2(params_memory, grads_memory, m_memory, v_memory, num_parameters,
                            learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, weight_decay);
            break;
        default:
            printf("Invalid kernel number\n");
            exit(1);
    }
}

// ----------------------------------------------------------------------------

int main(int argc, char **argv) {
    setup_main();

    const long num_parameters = 1048576;
    const int t = 10;

    const float learning_rate = 1e-3f;
    const float beta1 = 0.9f;
    const float beta2 = 0.999f;
    const float eps = 1e-8f;
    const float weight_decay = 0.0f;

    // create random data on host (to be used for the CPU reference implementation)
    float* params_memory = make_random_float(num_parameters);
    float* grads_memory = make_random_float(num_parameters);
    float* m_memory = make_random_float(num_parameters);
    float* v_memory = make_random_float_01(num_parameters);

    // move to GPU
    float* d_params_memory;
    float* d_grads_memory;
    float* d_m_memory;
    float* d_v_memory;
    cudaCheck(cudaMalloc(&d_params_memory, num_parameters * sizeof(float)));
    cudaCheck(cudaMalloc(&d_grads_memory, num_parameters * sizeof(float)));
    cudaCheck(cudaMalloc(&d_m_memory, num_parameters * sizeof(float)));
    cudaCheck(cudaMalloc(&d_v_memory, num_parameters * sizeof(float)));
    cudaCheck(cudaMemcpy(d_params_memory, params_memory, num_parameters * sizeof(float), cudaMemcpyHostToDevice));
    cudaCheck(cudaMemcpy(d_grads_memory, grads_memory, num_parameters * sizeof(float), cudaMemcpyHostToDevice));
    cudaCheck(cudaMemcpy(d_m_memory, m_memory, num_parameters * sizeof(float), cudaMemcpyHostToDevice));
    cudaCheck(cudaMemcpy(d_v_memory, v_memory, num_parameters * sizeof(float), cudaMemcpyHostToDevice));

    // read kernel_num from command line
    int kernel_num = 1;
    if (argc > 1) {
        kernel_num = atoi(argv[1]);
    }
    printf("Using kernel %d\n", kernel_num);

    // calculate the CPU reference (using default hyperparams)
    clock_t start = clock();
    adamw_cpu(params_memory, grads_memory, m_memory, v_memory, t, num_parameters);
    clock_t end = clock();
    // TODO: measure runtime with multiple runs
    // convert clock ticks to milliseconds: the value is printed below with an "ms" unit
    // (previously it was in seconds, off by 1000x)
    double elapsed_time_cpu = 1000.0 * (double)(end - start) / CLOCKS_PER_SEC;

    // calculate the GPU version (using default hyperparams)
    adamw(kernel_num, d_params_memory, d_grads_memory, d_m_memory, d_v_memory, t, num_parameters);

    // compare
    printf("Checking correctness...\n");
    printf("parameters:\n");
    validate_result(d_params_memory, params_memory, "params_memory", num_parameters);
    printf("first moment:\n");
    validate_result(d_m_memory, m_memory, "m_memory", num_parameters);
    printf("second moment:\n");
    validate_result(d_v_memory, v_memory, "v_memory", num_parameters);
    printf("All results match.\n\n");

    // now benchmark the kernel
    int repeat_times = 1000;
    float elapsed_time = benchmark_kernel(repeat_times, adamw, kernel_num,
                                          d_params_memory, d_grads_memory, d_m_memory, d_v_memory, t, num_parameters,
                                          learning_rate, beta1, beta2, eps, weight_decay);
    printf("time gpu %.4f ms\n", elapsed_time);
    printf("time cpu %.4f ms\n", elapsed_time_cpu);

    // cleanup
    free(params_memory);
    free(grads_memory);
    free(m_memory);
    free(v_memory);
    cudaCheck(cudaFree(d_params_memory));
    cudaCheck(cudaFree(d_grads_memory));
    cudaCheck(cudaFree(d_m_memory));
    cudaCheck(cudaFree(d_v_memory));

    return 0;
}
================================================ FILE: dev/cuda/attention_backward.cu ================================================
/*
Kernels for attention backward pass.

Compile example:
nvcc -O3 --use_fast_math -lcublas -lcublasLt attention_backward.cu -o attention_backward

version 1 is a naive first version
OMP_NUM_THREADS=32 ./attention_backward 1

version 2 ensures better load-balancing by having independent threads for each batch and attention head
OMP_NUM_THREADS=32 ./attention_backward 2

version 3 uses a full warp to calculate each result (instead of a thread), which enables coalesced memory access
OMP_NUM_THREADS=32 ./attention_backward 3

version 4 improves data reuse in registers by doing 8 values of t3 in one warp.
OMP_NUM_THREADS=32 ./attention_backward 4

version 5 reduces the amount of non-fp32 instructions needed by avoiding ifs
OMP_NUM_THREADS=32 ./attention_backward 5

version 6 tiles the (t2, t3) loops of one softmax row across a whole thread block, staging values in shared memory
OMP_NUM_THREADS=32 ./attention_backward 6

version 7 simplifies the softmax-backward math to a single block-wide sum per row
OMP_NUM_THREADS=32 ./attention_backward 7

version 8 is version 7 plus streaming loads/stores, reversed block order and several rows per block
OMP_NUM_THREADS=32 ./attention_backward 8

All versions share the permute/unpermute kernels and cuBLAS matmuls; they differ only in the
kernel used for the backward pass through the autoregressive softmax.
*/

// NOTE(review): the header names inside <...> (and other angle-bracket template arguments in this
// file) appear to have been lost in this copy; the tokens below are left exactly as found.
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include "common.h"

// ----------------------------------------------------------------------------
// CPU code reference

/*
NOTE:
This version of attention_forward is modified to be consistent with the
attention_forward GPU kernel in the following small but important way:
- preatt is only QUERY @ KEY, without the scale
- the scale instead moved and fused into the softmax
- the full preatt matrix is materialized, even the parts that get masked out
    - this doesn't actually change anything due to masking, but it lets us
      easily compare to the GPU version, which also does the full, dense sgemm
In this way we'll be able to make sure that preatt and att agree CPU vs GPU
*/
void attention_forward_cpu(float* out, float* preatt, float* att,
                           float* inp,
                           int B, int T, int C, int NH) {
    // input is (B, T, 3C) holding the query, key, value (Q, K, V) vectors
    // preatt, att are (B, NH, T, T). NH = number of heads, T = sequence length
    // that holds the pre-attention and post-attention scores (used in backward)
    // output is (B, T, C)
    // attention is the only layer that mixes information across time
    // every other operation is applied at every (b,t) position independently
    // (and of course, no layer mixes information across batch)

    int C3 = C*3;
    int hs = C / NH; // head size
    float scale = 1.0 / sqrtf(hs);

    // every (b, t, h) triple writes disjoint rows of preatt/att and a disjoint slice of out,
    // so all three loops can be collapsed into one OpenMP parallel loop
    #pragma omp parallel for collapse(3)
    for (int b = 0; b < B; b++) {
        for (int t = 0; t < T; t++) {
            for (int h = 0; h < NH; h++) {
                float* query_t = inp + b * T * C3 + t * C3 + h * hs;
                float* preatt_bth = preatt + b*NH*T*T + h*T*T + t*T;
                float* att_bth = att + b*NH*T*T + h*T*T + t*T;

                // pass 1: calculate query dot key and maxval
                // (maxval is taken over the full row, masked positions included; it is only a
                // stabilizing offset, so any value works mathematically)
                float maxval = -FLT_MAX;
                for (int t2 = 0; t2 < T; t2++) { // used to be t2 <= t
                    float* key_t2 = inp + b * T * C3 + t2 * C3 + h * hs + C; // +C because it's key

                    // (query_t) dot (key_t2)
                    float val = 0.0f;
                    for (int i = 0; i < hs; i++) {
                        val += query_t[i] * key_t2[i];
                    }
                    if (val > maxval) {
                        maxval = val;
                    }

                    preatt_bth[t2] = val;
                }

                // pass 2: calculate the exp and keep track of sum
                // maxval is being calculated and subtracted only for numerical stability
                float expsum = 0.0f;
                for (int t2 = 0; t2 <= t; t2++) {
                    float expv = expf(scale * (preatt_bth[t2] - maxval));
                    expsum += expv;
                    att_bth[t2] = expv;
                }
                // guard against every causal exp underflowing to 0
                float expsum_inv = expsum == 0.0f ? 0.0f : 1.0f / expsum;

                // pass 3: normalize to get the softmax
                for (int t2 = 0; t2 < T; t2++) {
                    if (t2 <= t) {
                        att_bth[t2] *= expsum_inv;
                    } else {
                        // causal attention mask. not strictly necessary to set to zero here
                        // only doing this explicitly for debugging and checking to PyTorch
                        att_bth[t2] = 0.0f;
                    }
                }

                // pass 4: accumulate weighted values into the output of attention
                float* out_bth = out + b * T * C + t * C + h * hs;
                for (int i = 0; i < hs; i++) { out_bth[i] = 0.0f; }
                for (int t2 = 0; t2 <= t; t2++) {
                    float* value_t2 = inp + b * T * C3 + t2 * C3 + h * hs + C*2; // +C*2 because it's value
                    float att_btht2 = att_bth[t2];
                    for (int i = 0; i < hs; i++) {
                        out_bth[i] += att_btht2 * value_t2[i];
                    }
                }
            }
        }
    }
}

// NOTE: Also contains the re-shuffling of the exact position of "scale"
// and when it is applied (after preatt, not "during" preatt)
// also, full matrices are materialized, even the parts that get masked out
// All outputs (dinp, datt, dpreatt) are accumulated with +=, so the caller has to zero them first.
// There is no OpenMP here: different t iterations add into the same dvalue_t2 / dkey_t2 entries.
void attention_backward_cpu(float* dinp, float* dpreatt, float* datt,
                            float* dout, float* inp, float* att,
                            int B, int T, int C, int NH) {
    // inp/dinp are (B, T, 3C) Q,K,V
    // att/datt/dpreatt are (B, NH, T, T)
    // dout is (B, T, C)
    int C3 = C*3;
    int hs = C / NH; // head size
    float scale = 1.0 / sqrtf(hs);

    for (int b = 0; b < B; b++) {
        for (int t = 0; t < T; t++) {
            for (int h = 0; h < NH; h++) {
                float* att_bth = att + b*NH*T*T + h*T*T + t*T;
                float* datt_bth = datt + b*NH*T*T + h*T*T + t*T;
                float* dpreatt_bth = dpreatt + b*NH*T*T + h*T*T + t*T;
                float* dquery_t = dinp + b * T * C3 + t * C3 + h * hs;
                float* query_t = inp + b * T * C3 + t * C3 + h * hs;

                // backward pass 4, through the value accumulation
                float* dout_bth = dout + b * T * C + t * C + h * hs;
                for (int t2 = 0; t2 < T; t2++) { // ADJUSTED! this was t2 <= t (see note on function)
                    float* value_t2 = inp + b * T * C3 + t2 * C3 + h * hs + C*2; // +C*2 because it's value
                    float* dvalue_t2 = dinp + b * T * C3 + t2 * C3 + h * hs + C*2;
                    for (int i = 0; i < hs; i++) {
                        // in the forward pass this was:
                        // out_bth[i] += att_bth[t2] * value_t2[i];
                        // so now we have:
                        datt_bth[t2] += value_t2[i] * dout_bth[i];
                        dvalue_t2[i] += att_bth[t2] * dout_bth[i];
                    }
                }

                // backward pass 2 & 3, the softmax
                // note that softmax (like e.g. tanh) doesn't need the input (preatt) to backward
                for (int t2 = 0; t2 <= t; t2++) {
                    for (int t3 = 0; t3 <= t; t3++) {
                        float indicator = t2 == t3 ? 1.0f : 0.0f;
                        float local_derivative = att_bth[t2] * (indicator - att_bth[t3]);
                        dpreatt_bth[t3] += scale * local_derivative * datt_bth[t2];
                    }
                }

                // backward pass 1, the query @ key matmul
                for (int t2 = 0; t2 <= t; t2++) {
                    float* key_t2 = inp + b * T * C3 + t2 * C3 + h * hs + C; // +C because it's key
                    float* dkey_t2 = dinp + b * T * C3 + t2 * C3 + h * hs + C; // +C because it's key
                    for (int i = 0; i < hs; i++) {
                        // in the forward pass this was:
                        // preatt_bth[t2] += query_t[i] * key_t2[i]
                        // so now we have:
                        dquery_t[i] += key_t2[i] * dpreatt_bth[t2];
                        dkey_t2[i] += query_t[i] * dpreatt_bth[t2];
                    }
                }
            }
        }
    }
}

// ----------------------------------------------------------------------------
// GPU kernels

// the forward pass that is the sequence [permute, sgemm, softmax, sgemm, unpermute]

// One thread per output element: idx enumerates the (B, NH, N, d) layout shared by q, k and v.
__global__ void permute_kernel(float* q, float* k, float* v,
                               const float* inp,
                               int B, int N, int NH, int d) {
    // okay so now, this kernel wants Q,K,V to all be of shape (B, NH, N, d)
    // but instead, we have a single tensor QKV (inp) of shape (B, N, 3, NH, d)
    int idx = blockIdx.x * blockDim.x + threadIdx.x;

    // Q[b][nh_][n][d_] = inp[b][n][0][nh_][d_]

    if (idx < B * NH * N * d) {
        int b = idx / (NH * N * d);
        int rest = idx % (NH * N * d);
        int nh_ = rest / (N * d);
        rest = rest % (N * d);
        int n = rest / d;
        int d_ = rest % d;

        int inp_idx = (b * N * 3 * NH * d) + (n * 3 * NH * d) + (0 * NH * d) + (nh_ * d) + d_;

        q[idx] = inp[inp_idx];
        k[idx] = inp[inp_idx + NH * d];       // K sits one (NH * d) slice after Q
        v[idx] = inp[inp_idx + 2 * (NH * d)]; // V sits two slices after Q
    }
}

// Backward of permute_kernel: scatters dq, dk, dv (each (B, NH, N, d)) back into the
// (B, N, 3, NH, d) layout of dinp. Every dinp element is touched by exactly one thread,
// and the gradient is accumulated (+=) into whatever dinp already holds.
__global__ void permute_kernel_backward(float* dinp,
                                        const float* dq, const float* dk, const float* dv,
                                        int B, int N, int NH, int d) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < B * NH * N * d) {
        int b = idx / (NH * N * d);
        int rest = idx % (NH * N * d);
        int nh_ = rest / (N * d);
        rest = rest % (N * d);
        int n = rest / d;
        int d_ = rest % d;

        int inp_idx = (b * N * 3 * NH * d) + (n * 3 * NH * d) + (0 * NH * d) + (nh_ * d) + d_;
        dinp[inp_idx] += dq[idx];
        dinp[inp_idx + NH * d] += dk[idx];
        dinp[inp_idx + 2 * (NH * d)] += dv[idx];
    }
}

__global__ void unpermute_kernel(const float* inp, float *out, int B, int N, int NH, int d) {
    // inp has shape (B, nh, N, d) but we need to unpermute it to (B, N, nh, d) in out
    // (idx enumerates inp; other_idx is the matching position in out)
    int idx = blockIdx.x * blockDim.x + threadIdx.x;

    // out[b][n][nh_][d_] <- inp[b][nh_][n][d_]
    if (idx < B * NH * N * d) {
        int b = idx / (NH * N * d);
        int rest = idx % (NH * N * d);
        int nh_ = rest / (N * d);
        rest = rest % (N * d);
        int n = rest / d;
        int d_ = rest % d;

        int other_idx = (b * NH * N * d) + (n * NH * d) + (nh_ * d) + d_;
        out[other_idx] = inp[idx];
    }
}

// Backward of unpermute_kernel: gathers dout (B, N, NH, d) into dinp (B, NH, N, d),
// accumulating (+=) into dinp.
__global__ void unpermute_kernel_backward(float* dinp, const float *dout, int B, int N, int NH, int d) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < B * NH * N * d) {
        int b = idx / (NH * N * d);
        int rest = idx % (NH * N * d);
        int nh_ = rest / (N * d);
        rest = rest % (N * d);
        int n = rest / d;
        int d_ = rest % d;

        int other_idx = (b * NH * N * d) + (n * NH * d) + (nh_ * d) + d_;
        dinp[idx] += dout[other_idx];
    }
}

// Access component `index` of a float4 as if it were an array of four floats
// (callers in this file use index 0..3).
__device__ float& vec_at(float4& vec, int index) {
    return reinterpret_cast(&vec)[index];
}

__device__ float vec_at(const float4& vec, int index) {
    return reinterpret_cast(&vec)[index];
}

__global__ void softmax_forward_kernel5(float* out, float inv_temperature, const float* inp, int N, int T) {
    // inp, out shape: (N, T, T), where N = B * NH
    // fuses the multiplication by scale inside attention
    // directly autoregressive, so we only compute the lower triangular part
    // uses the online softmax algorithm
    // one warp per row: idx below is the global warp index and selects row idx of the (N * T, T) view
    // T % 4 == 0 keeps every row start float4-aligned (assuming inp itself is), which the
    // vectorized reads below rely on
    assert(T % 4 == 0);
    namespace cg = cooperative_groups;
    cg::thread_block block = cg::this_thread_block();
    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
    int idx = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank();
    if(idx >= N * T) {
        return;
    }
    int own_pos = idx % T;      // last causal column of this row
    int pos_by_4 = own_pos / 4; // number of complete float4 groups before own_pos's group

    // one row of inp, i.e. inp[idx, :] of shape (T,)
    const float* x = inp + idx * T;

    // not INF, so we don't get NaNs accidentally when subtracting two values.
    float maxval = -FLT_MAX;
    float sumval = 0.0f;

    // vectorized part: float4 groups [0, pos_by_4), strided across the warp. Each lane keeps a
    // running max and rescales its running sum whenever that max grows (online softmax).
    const float4* x_vec = reinterpret_cast(x);
    for (int i = warp.thread_rank(); i < pos_by_4; i += warp.size()) {
        float4 v = x_vec[i];
        float old_maxval = maxval;
        for(int k = 0; k < 4; ++k) {
            maxval = fmaxf(maxval, vec_at(v, k));
        }
        sumval *= expf(inv_temperature * (old_maxval - maxval));

        for(int k = 0; k < 4; ++k) {
            sumval += expf(inv_temperature * (vec_at(v, k) - maxval));
        }
    }

    // tail: the remaining 1..4 elements [4*pos_by_4, own_pos] are handled by lanes 0..3
    if(4*pos_by_4 + warp.thread_rank() <= own_pos) {
        float old_maxval = maxval;
        maxval = fmaxf(maxval, x[4*pos_by_4 + warp.thread_rank()]);
        sumval *= expf(inv_temperature * (old_maxval - maxval));
        sumval += expf(inv_temperature * (x[4*pos_by_4 + warp.thread_rank()] - maxval));
    }

    // combine the per-lane (max, sum) pairs: rescale each lane's sum to the row max, then add
    float global_maxval = cg::reduce(warp, maxval, cg::greater{});
    sumval *= expf(inv_temperature * (maxval - global_maxval));

    float sum = cg::reduce(warp, sumval, cg::plus{});
    float norm = 1.f / sum;

    // divide the whole row by the sum
    // (only columns <= own_pos are written; the upper-triangular part of out is not touched)
    for (int i = warp.thread_rank(); i <= own_pos; i += warp.size()) {
        // recalculation is faster than doing the round-trip through memory.
        // __ldcs/__stcs are the cache-streaming load/store intrinsics
        float ev = expf(inv_temperature * (__ldcs(x + i) - global_maxval));
        __stcs(out + idx * T + i, ev * norm);
    }
}

// naive kernel to backward through an autoregressive softmax, just to get correctness
// one thread per column t3; each thread loops over every (b, h) matrix and every row t >= t3
// (rows t < t3 have no causal entry in column t3). Results are stored (=), not accumulated.
__global__ void softmax_autoregressive_backward_kernel1(float* dpreatt, const float* datt, const float* att,
                                                     int B, int T, int C, int NH) {
    // dpreatt, datt, att are all (B, NH, T, T)
    int t3 = blockIdx.x * blockDim.x + threadIdx.x;
    if (t3 < T) {
        int hs = C / NH; // head size
        float scale = 1.0f / sqrtf(hs);
        for (int b = 0; b < B; b++) {
            for (int h = 0; h < NH; h++) {
                for (int t = t3; t < T; t++) {
                    const float* att_bth = att + b*NH*T*T + h*T*T + t*T;
                    const float* datt_bth = datt + b*NH*T*T + h*T*T + t*T;
                    float* dpreatt_bth = dpreatt + b*NH*T*T + h*T*T + t*T;
                    float accum = 0.0f;
                    for (int t2 = 0; t2 <= t; t2++) {
                        float indicator = t2 == t3 ? 1.0f : 0.0f;
                        float local_derivative = att_bth[t2] * (indicator - att_bth[t3]);
                        accum += scale * local_derivative * datt_bth[t2];
                    }
                    dpreatt_bth[t3] = accum;
                }
            }
        }
    }
}

// parallelize across t,b,h
// (concretely: one thread per column t3 via blockIdx.x/threadIdx.x, and blockIdx.y selects the
// (b, h) matrix; each thread still loops over all rows t >= t3)
__global__ void softmax_autoregressive_backward_kernel2(float* dpreatt, const float* datt, const float* att,
                                                     int B, int T, int C, int NH) {
    int t3 = blockIdx.x * blockDim.x + threadIdx.x;
    int idx = blockIdx.y * T * T;
    if (t3 >= T) { return; }
    int hs = C / NH; // head size
    float scale = 1.0f / sqrtf(hs);
    for (int t = t3; t < T; t++) {
        float result = 0.0;
        const float* att_bth = att + idx + t*T;
        const float* datt_bth = datt + idx + t*T;
        float* dpreatt_bth = dpreatt + idx + t*T;
        for (int t2 = 0; t2 <= t; t2++) {
            float indicator = t2 == t3 ? 1.0f : 0.0f;
            float local_derivative = att_bth[t2] * (indicator - att_bth[t3]);
            result += scale * local_derivative * datt_bth[t2];
        }
        dpreatt_bth[t3] = result;
    }
}

// parallelize across t,b,h
// (one warp per column t3, warps numbered across blockIdx.x; blockIdx.y selects the (b, h) matrix.
// The lanes split the t2 sum and combine it with a warp reduction.)
__global__ void softmax_autoregressive_backward_kernel3(float* dpreatt, const float* datt, const float* att,
                                                     int B, int T, int C, int NH) {
    namespace cg = cooperative_groups;
    cg::thread_block block = cg::this_thread_block();
    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
    int t3 = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank();

    int idx = blockIdx.y * T * T;
    if (t3 >= T) { return; }  // t3 is the same for all lanes, so the whole warp leaves together
    int hs = C / NH; // head size
    float scale = 1.0f / sqrtf(hs);
    for (int t = t3; t < T; t++) {
        float result = 0.0;
        const float* att_bth = att + idx + t*T;
        const float* datt_bth = datt + idx + t*T;
        float* dpreatt_bth = dpreatt + idx + t*T;
        const float att_at_t3 = att_bth[t3];

        for (int t2 = warp.thread_rank(); t2 <= t; t2 += warp.size()) {
            float indicator = t2 == t3 ? 1.0f : 0.0f;
            float local_derivative = att_bth[t2] * (indicator - att_at_t3);
            result += local_derivative * datt_bth[t2];
        }

        result = cg::reduce(warp, result, cg::plus());
        if(warp.thread_rank() == 0) {
            // scale is applied once, after the reduction
            dpreatt_bth[t3] = scale * result;
        }
    }
}

// one warp per group of UNROLL consecutive columns [t3, t3 + UNROLL); blockIdx.y selects the (b, h) matrix
__global__ void softmax_autoregressive_backward_kernel4(float* __restrict__ dpreatt, const float* __restrict__ datt, const float* __restrict__ att,
                                                     int B, int T, int C, int NH) {
    constexpr int UNROLL = 8;
    namespace cg = cooperative_groups;
    cg::thread_block block = cg::this_thread_block();
    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
    int t3 = UNROLL * (blockIdx.x * warp.meta_group_size() + warp.meta_group_rank());

    int idx = blockIdx.y * T * T;
    if (t3 >= T) { return; }
    int hs = C / NH; // head size
    float scale = 1.0f / sqrtf(hs);

    // the innermost loop combines different values of t2 with different values of t.
    // by handling [t3, t3 + UNROLL) in one thread, we get much better memory reuse:
    // any t3/t-dependent value can be loaded once before the t2 loop.
    // within the t2 loop, we can combine each loaded value with each of the UNROLL
    // pre-loaded values, thus cutting memory reads by a factor of ~UNROLL.

    // one iteration of this loop has to handle all columns [t3, t3 + UNROLL), even for the
    // first rows t where some of them lie beyond t (outside the causal row);
    // this may lead to some invalid indices; therefore, we have several
    // early-outs in the iteration over k below.
    for (int t = t3; t < T; t++) {
        float result[UNROLL] = {};
        const float* att_bth = att + idx + t * T;
        const float* datt_bth = datt + idx + t * T;
        float* dpreatt_bth = dpreatt + idx + t * T;

        float att_at_t3[UNROLL];
        for(int k = 0; k < UNROLL; ++k) {
            if (t < t3 + k) continue;
            att_at_t3[k] = att_bth[t3 + k];
        }

        for (int t2 = warp.thread_rank(); t2 <= t; t2 += warp.size()) {
            float att_t2 = att_bth[t2];
            float datt_t2 = datt_bth[t2];
            for(int k = 0; k < UNROLL; ++k) {
                if (t < t3 + k) continue;
                float indicator = t2 == (t3 + k) ? 1.0f : 0.0f;
                float local_derivative = att_t2 * (indicator - att_at_t3[k]);
                result[k] += local_derivative * datt_t2;
            }
        }

        for(int k = 0; k < UNROLL; ++k) {
            result[k] = cg::reduce(warp, result[k], cg::plus());
        }
        // lane k stores the result for column t3 + k.
        // NOTE(review): unlike kernel 5, this store is not guarded by t >= t3 + k, so for columns
        // beyond t it stores scale * 0 into the masked part of the row; it also assumes
        // T % UNROLL == 0, otherwise t3 + k can run past the end of the row.
        if (warp.thread_rank() < UNROLL) {
            dpreatt_bth[t3 + warp.thread_rank()] = scale * result[warp.thread_rank()];
        }
    }
}

// same work split as kernel 4, but without per-element branches in the hot loop
__global__ void softmax_autoregressive_backward_kernel5(float* __restrict__ dpreatt, const float* __restrict__ datt, const float* __restrict__ att,
                                                     int B, int T, int C, int NH) {
    constexpr int UNROLL = 8;
    namespace cg = cooperative_groups;
    cg::thread_block block = cg::this_thread_block();
    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
    int t3 = UNROLL * (blockIdx.x * warp.meta_group_size() + warp.meta_group_rank());

    int idx = blockIdx.y * T * T;
    if (t3 >= T) { return; }
    int hs = C / NH; // head size
    float scale = 1.0f / sqrtf(hs);
    for (int t = t3; t < T; t++) {
        float result[UNROLL] = {};
        const float* att_bth = att + idx + t * T;
        const float* datt_bth = datt + idx + t * T;
        float* dpreatt_bth = dpreatt + idx + t * T;

        float att_at_t3[UNROLL];
        for(int k = 0; k < UNROLL; ++k) {
            // if t < t3+k, we're out of bounds.
            // in that case, we don't care what we read, because later on,
            // we won't write the corresponding result. So just clip to
            // make sure this is a valid (in-bounds) memory access.
            att_at_t3[k] = att_bth[min(t, t3 + k)];
        }

        // the code below is actually just a for loop; except,
        // we have to do something special in one iteration in
        // the middle, and an if turned out to have significant
        // performance impact.
        // so we split the loop in three parts. Ugly, but effective.

        // the beginning/end loop does the same thing, so we write the code
        // just once in a lambda. In this step, we're guaranteed that
        // indicator == 0
        auto loop_step = [&](int t2){
            float p = att_bth[t2] * datt_bth[t2];
            for (int k = 0; k < UNROLL; ++k) {
                result[k] -= p * att_at_t3[k];
            }
        };

        // Now the actual loop.
        {
            // declare the loop iterator. Needs to be kept across the
            // three different parts, so it's not a local variable in
            // the for loop.
            int t2 = warp.thread_rank();

            // first part, as long as t2 < t3, indicator == 0
            for (; t2 < t3; t2 += warp.size()) {
                loop_step(t2);
            }

            // because k <= warp.size() (==32), the event that t3+k == t2
            // has to happen at this particular step.
            static_assert(UNROLL <= 32, "UNROLL is too large, this won't produce correct results.");
            if (t2 <= t) {
                float att_t2 = att_bth[t2];
                float datt_t2 = datt_bth[t2];
                float p = att_t2 * datt_t2;
                for (int k = 0; k < UNROLL; ++k) {
                    float indicator = t2 == (t3 + k) ? 1.0f : 0.0f;
                    result[k] += p * (indicator - att_at_t3[k]);
                }
                t2 += warp.size();
            }

            // rest of the loop, indicator == 0 again
            for (; t2 <= t; t2 += warp.size()) {
                loop_step(t2);
            }
        }

        for(int k = 0; k < UNROLL; ++k) {
            result[k] = cg::reduce(warp, result[k], cg::plus());
        }

        // when storing, we need to check that this is actually a valid result.
        // here, warp.thread_rank() corresponds to `k` in the previous loops.
        if (warp.thread_rank() < UNROLL && t >= t3 + warp.thread_rank()) {
            dpreatt_bth[t3 + warp.thread_rank()] = scale * result[warp.thread_rank()];
        }
    }
}

// I want `BlockSize` to be statically known to the compiler, thus we get a template here.
// This kernel takes a step back, and looks at the original CPU code again. We have some simple outer loops
// that are independent, (b, t, h), and then the inner loops over (t2, t3) where we're combining elements -- this is
// where we can reuse data and be more efficient
// => handle b, t, h through block indices; each block does all the work for the (t2, t3) loop cooperatively.
// Now we have two nested loops, and in the inner instruction, we combine indexing from both => this calls for
// loop tiling, and lifting some of the memory ops out of the loop.
// We're in luck here; if we tile so that t3 is the outer loop, we can get a single write op per result, AND also cache
// the t2-indexed part of the computation, which is the problematic one because it contains a multiplication that now we
// do not have to repeat over and over.
// => do an outer t3 loop where each thread gets one t3 index. Then, do an outer t2 loop in steps of BlockSize, and
// prepare BlockSize many elements for the inner loop. Here, each thread calculates one element and stores it in shmem.
// Then, in the inner t2 loop, each thread reads *all* the elements previously stored and does its computations.
// This way, we do 3*BlockSize loads, but BlockSize^2 computation steps => This kernel is now entirely compute bound.
// To fix up the compute issues, as above, we replace ifs in memory reading with min, and also split the inner loop
// into a large region where we don't have to calculate the indicator, and a small, costly region where we do.
template __global__ void __launch_bounds__(BlockSize) softmax_autoregressive_backward_kernel6(float* dpreatt, const float* datt, const float* att,
                                                                                            int B, int T, int C, int NH) {
    // One block handles one row t (blockIdx.x) of one (T, T) matrix (blockIdx.y) and computes
    //   dpreatt[t][t3] = scale * sum_{t2 <= t} att[t][t2] * datt[t][t2] * (indicator(t2 == t3) - att[t][t3])
    // for every causal column t3 <= t, tiling both t3 and t2 in chunks of BlockSize.
    namespace cg = cooperative_groups;
    cg::thread_block block = cg::this_thread_block();
    // products att[t][t2] * datt[t][t2] for the current t2 tile, shared by the whole block
    __shared__ float att_bth_s[BlockSize];

    int idx = blockIdx.y;
    int t = blockIdx.x;

    att += idx * T * T;
    datt += idx * T * T;
    dpreatt += idx * T * T;

    int hs = C / NH; // head size
    float scale = 1.0f / sqrtf(hs);
    const float* att_bth = att + t * T;
    const float* datt_bth = datt + t * T;
    float* dpreatt_bth = dpreatt + t * T;

    int block_steps = ceil_div(t+1, BlockSize);
    // very important: This loop condition needs to be the same for all threads.
    // even if a thread later on is not going to do any work, it needs to participate in the
    // data loading process!
    for (int t3f = 0; t3f < block_steps; ++t3f) {
        int t3 = t3f * BlockSize + block.thread_rank();
        float acc = 0.f;
        // In the last t3 tile some threads have t3 > t. They still run the loop below (they take
        // part in the block.sync() calls and fill att_bth_s), but their column lies outside the
        // causal row: clamp the read (like kernel 5 does) so it never leaves the row / allocation.
        float at3 = att_bth[min(t3, t)];
        for (int t2b = 0; t2b <= t; t2b += BlockSize) {
            int end = min(t + 1 - t2b, BlockSize);
            // the previous tile must be fully consumed before it is overwritten
            block.sync();
            {
                int t2i = block.thread_rank();
                int t2 = min(t, t2b + t2i);
                att_bth_s[t2i] = att_bth[t2] * datt_bth[t2];
            }
            // the tile must be complete before anyone reads it
            block.sync();
            if(t3f * BlockSize == t2b) {
                // diagonal tile: the only tile in which t2 == t3 can happen
                for (int t2i = 0; t2i < end; t2i++) {
                    int t2 = t2b + t2i;
                    float indicator = t2 == t3 ? 1.0f : 0.0f;
                    acc += att_bth_s[t2i] * (indicator - at3);
                }
            } else {
                for (int t2i = 0; t2i < end; t2i++) {
                    acc += att_bth_s[t2i] * (0.f - at3);
                }
            }
        }
        // Only causal columns t3 <= t get a result (as in the CPU reference and kernels 5, 7, 8).
        // This also prevents writing past the end of the row when T is not a multiple of BlockSize.
        if (t3 <= t) {
            dpreatt_bth[t3] = scale * acc;
        }
    }
}

// Actually disentangling the loops and simplifying the resulting math gives us this pretty nice kernel.
template __global__ void softmax_autoregressive_backward_kernel7(float* dpreatt, const float* datt, const float* att,
                                                              int B, int T, int C, float scale) {
    // With S = sum_{t2 <= t} att[t][t2] * datt[t][t2], the softmax backward collapses to
    //   dpreatt[t][t3] = scale * att[t][t3] * (datt[t][t3] - S)    for t3 <= t.
    // One block per row t (blockIdx.x) of matrix blockIdx.y: reduce S across the block, then write the row.
    namespace cg = cooperative_groups;
    cg::thread_block block = cg::this_thread_block();
    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
    // one partial sum per warp; slots of warps that don't exist must stay 0 because every warp
    // later reduces all 32 slots
    __shared__ float block_acc[32];

    int idx = blockIdx.y;
    int t = blockIdx.x;

    att += idx * T * T;
    datt += idx * T * T;
    dpreatt += idx * T * T;

    const float* att_bth = att + t * T;
    const float* datt_bth = datt + t * T;
    float* dpreatt_bth = dpreatt + t * T;

    if(warp.meta_group_rank() == 0) {
        block_acc[warp.thread_rank()] = 0;
    }
    // Warp 0's zeroing and the other warps' partial-sum stores below target the same shared slots.
    // Without a barrier, a late warp 0 could overwrite an already-stored partial sum with 0.
    block.sync();

    float local_sum = 0;
    for(int t2 = block.thread_rank(); t2 <= t; t2 += BlockSize) {
        local_sum += att_bth[t2] * datt_bth[t2];
    }

    block_acc[warp.meta_group_rank()] = cg::reduce(warp, local_sum, cg::plus{});
    block.sync();
    // every warp redundantly reduces the per-warp partials, so each thread ends up holding S
    local_sum = cg::reduce(warp, block_acc[warp.thread_rank()], cg::plus{});

    for (int t3 = block.thread_rank(); t3 <= t; t3 += BlockSize) {
        float acc = att_bth[t3] * (datt_bth[t3] - local_sum);
        dpreatt_bth[t3] = scale * acc;
    }
}

// The slightly less pretty version of kernel 7.
// Adding in all the dirty tricks that can give us a few more percent
//  - streaming memory access instructions
//  - reordering blocks to prevent tail effect
//  - multiple values of T per block
template __global__ void softmax_autoregressive_backward_kernel8(float* dpreatt, const float* datt, const float* att,
                                                              int B, int T, int C, float scale) {
    // Same math as kernel 7, dpreatt[t][t3] = scale * att[t][t3] * (datt[t][t3] - S_t) with
    // S_t = sum_{t2 <= t} att[t][t2] * datt[t][t2], but each block handles T_per_block consecutive rows.
    namespace cg = cooperative_groups;
    constexpr int T_per_block = 4;
    cg::thread_block block = cg::this_thread_block();
    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
    // one partial sum per warp; slots of warps that don't exist must stay 0 since every warp reduces all 32
    __shared__ float block_acc[32];

    int idx = blockIdx.y;
    // go through blocks in reverse order, so the slowest block starts first
    int t0 = T - 1 - T_per_block*blockIdx.x;

    att += idx * T * T;
    datt += idx * T * T;
    dpreatt += idx * T * T;

    if (warp.meta_group_rank() == 0) {
        block_acc[warp.thread_rank()] = 0;
    }
    // make the zeroing visible before any warp stores its partial sum into the same slots
    block.sync();

    for(int to = 0; to < T_per_block; ++to) {
        int t = t0 - to;
        // t depends only on blockIdx.x, so either the whole block returns here or nobody does
        if(t < 0) return;
        const float* att_bth = att + t * T;
        const float* datt_bth = datt + t * T;
        float* dpreatt_bth = dpreatt + t * T;

        float local_sum = 0;
        for (int t2 = block.thread_rank(); t2 <= t; t2 += BlockSize) {
            local_sum += att_bth[t2] * datt_bth[t2];
        }

        block_acc[warp.meta_group_rank()] = cg::reduce(warp, local_sum, cg::plus{});
        block.sync();
        local_sum = cg::reduce(warp, block_acc[warp.thread_rank()], cg::plus{});
        // block_acc is reused for the next row: every warp must be done reading it before any
        // warp stores its next partial sum
        block.sync();

        for (int t3 = block.thread_rank(); t3 <= t; t3 += BlockSize) {
            // don't touch the cache. Some parts will still be here from the previous loop, and
            // we want to exploit those.
            float acc = __ldcs(att_bth + t3) * (__ldcs(datt_bth + t3) - local_sum);
            __stcs(dpreatt_bth + t3, scale * acc);
        }
    }
}

// ----------------------------------------------------------------------------
// kernel launchers

// attention forward pass kernel
void attention_forward(float* out, float* vaccum, float* qkvr, float* preatt, float* att,
                       const float* inp,
                       int B, int T, int C, int NH,
                       const int block_size) {
    // inp is (B, T, 3C) QKV
    // preatt, att are (B, NH, T, T)
    // output is (B, T, C)
    int HS = C / NH; // head size

    // permute and separate inp from (B, T, 3, NH, HS) to 3X (B, NH, T, HS)
    float *q, *k, *v;
    q = qkvr + 0 * B * T * C;
    k = qkvr + 1 * B * T * C;
    v = qkvr + 2 * B * T * C;
    int total_threads = B * NH * T * HS;
    int num_blocks = ceil_div(total_threads, block_size);
    permute_kernel<<>>(q, k, v, inp, B, T, NH, HS);
    cudaCheck(cudaGetLastError());

    // batched matrix multiply with cuBLAS
    const float alpha = 1.0f;
    const float beta = 0.0f;
    cublasCheck(cublasSgemmStridedBatched(cublas_handle,
                                          CUBLAS_OP_T, CUBLAS_OP_N,
                                          T, T, HS,
                                          &alpha,
                                          k, HS, T * HS,
                                          q, HS, T * HS,
                                          &beta,
                                          preatt, T, T * T,
                                          B * NH));

    // the 1/sqrt(HS) scale is not applied to preatt here: it is passed to the softmax kernel as
    // inv_temperature and fused into the exponent there
    float scale = 1.0 / sqrtf(HS);
    int softmax_block_size = 256;
    // one warp (32 threads) per row of the (B * NH * T, T) preatt matrix
    int grid_size = ceil_div(B * NH * T * 32, softmax_block_size);
    softmax_forward_kernel5<<>>(att, scale, preatt, B * NH, T);
    cudaCheck(cudaGetLastError());

    // new approach: first cuBLAS another batched matmul
    // vaccum = att @ v # (B, nh, T, T) @ (B, nh, T, hs) -> (B, nh, T, hs)
    cublasCheck(cublasSgemmStridedBatched(cublas_handle,
                                          CUBLAS_OP_N, CUBLAS_OP_N,
                                          HS, T, T,
                                          &alpha,
                                          v, HS, T * HS,
                                          att, T, T * T,
                                          &beta,
                                          vaccum, HS, T * HS,
                                          B * NH));

    // now unpermute
    // y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
    num_blocks = ceil_div(B * T * C, block_size);
    unpermute_kernel<<>>(vaccum, out, B, T, NH, HS);
    cudaCheck(cudaGetLastError());
}

// kernel 1: one thread per column t3; the kernel loops over all (b, h) matrices itself
void launch_softmax_1(float* dpreatt, float* datt, const float* att, int B, int T, int C, int NH, int block_size) {
    int num_blocks = ceil_div(T, block_size);
    softmax_autoregressive_backward_kernel1<<>>(dpreatt, datt, att, B, T, C, NH);
}

// kernel 2: one thread per column t3 (per (b, h) matrix)
void launch_softmax_2(float* dpreatt, float* datt, const float* att, int B, int T, int C, int NH, int block_size) {
    int num_blocks = ceil_div(T, block_size);
    softmax_autoregressive_backward_kernel2<<>>(dpreatt, datt, att, B, T, C, NH);
}

// kernel 3: one warp (32 threads) per column t3
void launch_softmax_3(float* dpreatt, float* datt, const float* att, int B, int T, int C, int NH, int block_size) {
    int num_blocks = ceil_div(32*T, block_size);
    softmax_autoregressive_backward_kernel3<<>>(dpreatt, datt, att, B, T, C, NH);
}

// kernel 4: one warp per 8 columns (UNROLL = 8), i.e. 32/8 threads per column
void launch_softmax_4(float* dpreatt, float* datt, const float* att, int B, int T, int C, int NH, int block_size) {
    int num_blocks = ceil_div(32/8*T, block_size);
    softmax_autoregressive_backward_kernel4<<>>(dpreatt, datt, att, B, T, C, NH);
}

// kernel 5: same work split as kernel 4
void launch_softmax_5(float* dpreatt, float* datt, const float* att, int B, int T, int C, int NH, int block_size) {
    int num_blocks = ceil_div(32/8*T, block_size);
    softmax_autoregressive_backward_kernel5<<>>(dpreatt, datt, att, B, T, C, NH);
}

// Kernels 6-8 need the block size as a compile-time constant: map the runtime block_size onto an
// integral_constant and hand it to `launch`.
template void dispatch_launch(Launcher&& launch, int block_size) {
    switch(block_size) {
        case 32:
            return launch(std::integral_constant{});
        case 64:
            return launch(std::integral_constant{});
        case 128:
            return launch(std::integral_constant{});
        case 256:
            return launch(std::integral_constant{});
        case 512:
            return launch(std::integral_constant{});
        case 1024:
            return launch(std::integral_constant{});
        default:
            // an assert alone compiles away under NDEBUG and would silently skip the launch
            printf("Invalid block size %d\n", block_size);
            exit(1);
    }
}

void launch_softmax_6(float* dpreatt, float* datt, const float* att, int B, int T, int C, int NH, int block_size) {
    auto launch = [&](auto int_const) {
        softmax_autoregressive_backward_kernel6<<>>(dpreatt, datt, att, B, T, C, NH);
    };
    dispatch_launch(launch, block_size);
}

void launch_softmax_7(float* dpreatt, float* datt, const float* att, int B, int T, int C, int NH, int block_size) {
    int hs = C / NH; // head size
    float scale = 1.0f / sqrtf(hs);
    auto launch = [&](auto int_const) {
        // compile-time copy of the block size (deliberately shadows the runtime parameter)
        constexpr int block_size = int_const.value;
        softmax_autoregressive_backward_kernel7<<>>
                (dpreatt, datt, att, B, T, C, scale);
    };
    dispatch_launch(launch, block_size);
}

void launch_softmax_8(float* dpreatt, float* datt, const float* att, int B, int T, int C, int NH, int block_size) {
    int hs = C / NH; // head size
    float scale = 1.0f / sqrtf(hs);
    auto launch = [&](auto int_const) {
        // compile-time copy of the block size (deliberately shadows the runtime parameter)
        constexpr int block_size = int_const.value;
        softmax_autoregressive_backward_kernel8<<>>
                (dpreatt, datt, att, B, T, C, scale);
    };
    dispatch_launch(launch, block_size);
}

// the sequence of transformations in this compound op is:
// inp (B,T,3C) -> qkvr (3 x (B,NH,T,HS)) -> preatt (B,NH,T,T) -> att (B,NH,T,T) -> vaccum (B,NH,T,HS) -> out (B,T,C)
// All cuBLAS calls use beta = 1, so datt, dv, dq and dk (and dinp via permute_kernel_backward) are
// accumulated into; presumably the caller zeroes these buffers beforehand - TODO confirm.
// NOTE(review): the dq/dk matmuls read the full (T, T) dpreatt, including entries above the
// diagonal that kernels 5-8 never write; these presumably have to be zero on entry - TODO confirm.
template void attention_backward1(float* dinp, float* dqkvr, float* dpreatt, float* datt, float* dvaccum,
                                  const float* dout,
                                  const float* inp, const float* qkvr, const float* preatt, const float* att, const float* vaccum,
                                  int B, int T, int C, int NH,
                                  SoftmaxKernel softmax_autoregressive_backward,
                                  const int block_size) {
    int HS = C / NH; // head size
    const float alpha = 1.0f;
    const float beta = 1.0f; // note beta = 1.0f so that we accumulate gradients (+=)
    // unpack convenience pointers into q, k, v
    const float *q, *k, *v;
    q = qkvr + 0 * B * T * C;
    k = qkvr + 1 * B * T * C;
    v = qkvr + 2 * B * T * C;
    float *dq, *dk, *dv;
    dq = dqkvr + 0 * B * T * C;
    dk = dqkvr + 1 * B * T * C;
    dv = dqkvr + 2 * B * T * C;

    // backward through the unpermute operation
    int num_blocks = ceil_div(B * T * C, block_size);
    unpermute_kernel_backward<<>>(dvaccum, dout, B, T, NH, HS);
    cudaCheck(cudaGetLastError());

    // backward into datt
    cublasCheck(cublasSgemmStridedBatched(cublas_handle,
                                          CUBLAS_OP_T, CUBLAS_OP_N,
                                          T, T, HS,
                                          &alpha,
                                          v, HS, T * HS,
                                          dvaccum, HS, T * HS,
                                          &beta,
                                          datt, T, T * T,
                                          B * NH));

    // backward into dv
    cublasCheck(cublasSgemmStridedBatched(cublas_handle,
                                          CUBLAS_OP_N, CUBLAS_OP_T,
                                          HS, T, T,
                                          &alpha,
                                          dvaccum, HS, T * HS,
                                          att, T, T * T,
                                          &beta,
                                          dv, HS, T * HS,
                                          B * NH));

    // backward into preatt
    softmax_autoregressive_backward(dpreatt, datt, att, B, T, C, NH, block_size);
    cudaCheck(cudaGetLastError());

    // backward into q
    cublasCheck(cublasSgemmStridedBatched(cublas_handle,
                                          CUBLAS_OP_N, CUBLAS_OP_N,
                                          HS, T, T,
                                          &alpha,
                                          k, HS, T * HS,
                                          dpreatt, T, T * T,
                                          &beta,
                                          dq, HS, T * HS,
                                          B * NH));

    // backward into k
    cublasCheck(cublasSgemmStridedBatched(cublas_handle,
                                          CUBLAS_OP_N, CUBLAS_OP_T,
                                          HS, T, T,
                                          &alpha,
                                          q, HS, T * HS,
                                          dpreatt, T, T * T,
                                          &beta,
                                          dk, HS, T * HS,
                                          B * NH));

    // backward into inp
    num_blocks = ceil_div(B * NH * T * HS, block_size);
    permute_kernel_backward<<>>(dinp, dq, dk, dv, B, T, NH, HS);
    cudaCheck(cudaGetLastError());
}

// kernel version dispatch
// all versions share attention_backward1 and differ only in the softmax-backward launcher
void attention_backward(int kernel_num,
                        float* dinp, float* dqkvr, float* dpreatt, float* datt, float* dvaccum,
                        const float* dout,
                        const float* inp, const float* qkvr, const float* preatt, const float* att, const float* vaccum,
                        int B, int T, int C, int NH,
                        const int block_size) {
    switch (kernel_num) {
        case 1:
            attention_backward1(dinp, dqkvr, dpreatt, datt, dvaccum, dout, inp, qkvr, preatt, att, vaccum, B, T, C, NH,
                                launch_softmax_1, block_size);
            break;
        case 2:
            attention_backward1(dinp, dqkvr, dpreatt, datt, dvaccum, dout, inp, qkvr, preatt, att, vaccum, B, T, C, NH,
                                launch_softmax_2, block_size);
            break;
        case 3:
            attention_backward1(dinp, dqkvr, dpreatt, datt, dvaccum, dout, inp, qkvr, preatt, att, vaccum, B, T, C, NH,
                                launch_softmax_3, block_size);
            break;
        case 4:
            attention_backward1(dinp, dqkvr, dpreatt, datt, dvaccum, dout, inp, qkvr, preatt, att, vaccum, B, T, C, NH,
                                launch_softmax_4, block_size);
            break;
        case 5:
            attention_backward1(dinp, dqkvr, dpreatt, datt, dvaccum, dout, inp, qkvr, preatt, att, vaccum, B, T, C, NH,
                                launch_softmax_5, block_size);
            break;
        case 6:
            attention_backward1(dinp, dqkvr, dpreatt, datt, dvaccum, dout, inp, qkvr, preatt, att, vaccum, B, T, C, NH,
                                launch_softmax_6, block_size);
            break;
        case 7:
            attention_backward1(dinp, dqkvr, dpreatt, datt, dvaccum, dout, inp, qkvr, preatt, att, vaccum, B, T, C, NH,
                                launch_softmax_7, block_size);
            break;
        case 8:
            attention_backward1(dinp, dqkvr, dpreatt, datt, dvaccum, dout, inp, qkvr, preatt, att, vaccum, B, T, C, NH,
                                launch_softmax_8, block_size);
            break;
        default:
            printf("Invalid kernel number\n");
            exit(1);
    }
}

// ----------------------------------------------------------------------------

int main(int argc, char **argv) {
    setup_main();

    // hyperparameters
    int B = 4;
    int T = 1024;
    int C = 768;
    int NH = 12;

    // read kernel_num from command line
    int kernel_num = 1;
    if (argc > 1) {
        kernel_num = atoi(argv[1]);
    }
    printf("Using kernel %d\n", kernel_num);

    // create the host memory for the forward pass
    float* inp = make_random_float(B * T * 3 * C);
    float* qkvr = (float*)malloc(B * T * 3 * C * sizeof(float));
    float* preatt = (float*)malloc(B * NH * T * T * sizeof(float));
    float* att = (float*)malloc(B * NH * T * T * sizeof(float));
    float* vaccum = (float*)malloc(B * T * C * sizeof(float));
    float* out = (float*)malloc(B * T * C * sizeof(float));

    // execute the forward pass on the CPU
    attention_forward_cpu(out, preatt, att, inp, B, T, C, NH);

    // create device memory for the forward pass
    float *d_inp, *d_qkvr, *d_preatt, *d_att, *d_vaccum, *d_out;
    cudaCheck(cudaMalloc(&d_inp, B * T * 3 * C * sizeof(float)));
    cudaCheck(cudaMalloc(&d_qkvr, B * T * 3 * C * sizeof(float)));
    cudaCheck(cudaMalloc(&d_preatt, B * NH * T * T * sizeof(float)));
    cudaCheck(cudaMalloc(&d_att, B * NH * T * T * sizeof(float)));
    cudaCheck(cudaMalloc(&d_vaccum, B * T * C * sizeof(float)));
    cudaCheck(cudaMalloc(&d_out, B * T * C * sizeof(float)));
    // copy over the input
    cudaCheck(cudaMemcpy(d_inp, inp, B * T * 3 * C * sizeof(float), cudaMemcpyHostToDevice));

    // execute the forward pass on the GPU
    const int block_size = 256;
    attention_forward(d_out, d_vaccum, d_qkvr, d_preatt, d_att, d_inp, B, T, C, NH, block_size);

    // check that preatt, att, and out match between the CPU and GPU versions
    printf("Checking the forward pass CPU <-> GPU...\n");
    printf("[preatt]\n");
validate_result(d_preatt, preatt, "preatt", B * T * C, 5e-3f); printf("[att]\n"); validate_result(d_att, att, "att", B * T * C, 1e-3f); printf("[out]\n"); validate_result(d_out, out, "out", B * T * C, 1e-3f); // set up the memory for the backward pass float* dout = make_random_float(B * T * C); // the gradients on the output float* dinp = make_zeros_float(B * T * 3 * C); // zeros for all else, to += into float* dpreatt = make_zeros_float(B * NH * T * T); float* datt = make_zeros_float(B * NH * T * T); // call backward() on the CPU to get our reference gradients attention_backward_cpu(dinp, dpreatt, datt, dout, inp, att, B, T, C, NH); // create device memory for the backward pass float *d_dinp, *d_dqkvr, *d_dpreatt, *d_datt, *d_dvaccum, *d_dout; cudaCheck(cudaMalloc(&d_dinp, B * T * 3 * C * sizeof(float))); cudaCheck(cudaMalloc(&d_dqkvr, B * T * 3 * C * sizeof(float))); cudaCheck(cudaMalloc(&d_dpreatt, B * NH * T * T * sizeof(float))); cudaCheck(cudaMalloc(&d_datt, B * NH * T * T * sizeof(float))); cudaCheck(cudaMalloc(&d_dvaccum, B * T * C * sizeof(float))); cudaCheck(cudaMalloc(&d_dout, B * T * C * sizeof(float))); // copy over the dout gradients that starts the backprop chain cudaCheck(cudaMemcpy(d_dout, dout, B * T * C * sizeof(float), cudaMemcpyHostToDevice)); // memset all the other memory to zeros, to += into cudaCheck(cudaMemset(d_dinp, 0, B * T * 3 * C * sizeof(float))); cudaCheck(cudaMemset(d_dqkvr, 0, B * T * 3 * C * sizeof(float))); cudaCheck(cudaMemset(d_dpreatt, 0, B * NH * T * T * sizeof(float))); cudaCheck(cudaMemset(d_datt, 0, B * NH * T * T * sizeof(float))); cudaCheck(cudaMemset(d_dvaccum, 0, B * T * C * sizeof(float))); // call backward() on the GPU attention_backward(kernel_num, d_dinp, d_dqkvr, d_dpreatt, d_datt, d_dvaccum, d_dout, d_inp, d_qkvr, d_preatt, d_att, d_vaccum, B, T, C, NH, block_size); // check that the gradients match between the CPU and GPU versions // note that we will only check the correctness at [att, preatt, inp] // the 
gradients at qkvr and vaccum will remain unchecked, but are // assumed to be correct if the other gradients are correct printf("Checking the backward pass CPU <-> GPU...\n"); printf("[datt]\n"); validate_result(d_datt, datt, "datt", B * NH * T * T, 5e-3f); printf("[dpreatt]\n"); validate_result(d_dpreatt, dpreatt, "dpreatt", B * NH * T * T, 1e-3f); printf("[dinp]\n"); validate_result(d_dinp, dinp, "dinp", B * T * 3 * C, 1e-3f); // also let's manually step through the gradients here float* h_dinp = (float*)malloc(B * T * 3 * C * sizeof(float)); cudaCheck(cudaMemcpy(h_dinp, d_dinp, B * T * 3 * C * sizeof(float), cudaMemcpyDeviceToHost)); int num_match = 0; int num_no_match = 0; int num_zero_grad = 0; int HS = C / NH; for (int i = 0; i < B * T * 3 * C; i++) { // the dimensions of inp are (B, T, 3, NH, HS) // where B = batch, T = time, 3 = qkv, NH = num heads, HS = head size // unpack the individual b,t,qkvix,h,c indices int ix = i; int c = ix % HS; ix /= HS; int h = ix % NH; ix /= NH; int qkvix = ix % 3; ix /= 3; int t = ix % T; ix /= T; int b = ix; float diff = fabs(dinp[i] - h_dinp[i]); // attempt to index at random if (b == 1 && t == 5 && c == 23 && h == 2) { printf("ix %5d [b=%4d, t=%4d, qkv=%4d, nh=%4d, hs=%4d]: ref: %f gpu: %f\n", i, b, t, qkvix, h, c, dinp[i], h_dinp[i]); } if (diff > 1e-4f) { num_no_match++; } else { num_match++; } if (dinp[i] == 0.0f) { num_zero_grad++; } } printf("Number of matching gradients: %d (%.2f%% of total)\n", num_match, 100*(float)num_match / (B * T * 3 * C)); printf("Number of non-matching gradients: %d (%.2f%% of total)\n", num_no_match, 100*(float)num_no_match / (B * T * 3 * C)); printf("Number of gradients that are exactly zero: %d (%.2f%% of total)\n", num_zero_grad, 100*(float)num_zero_grad / (B * T * 3 * C)); // final verdict printf("All results match. 
Starting benchmarks.\n\n"); // benchmark speed of the kernel int block_sizes[] = {32, 64, 128, 256, 512, 1024}; for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) { int block_size = block_sizes[j]; int repeat_times = 10; float elapsed_time = benchmark_kernel(repeat_times, attention_backward, kernel_num, d_dinp, d_dqkvr, d_dpreatt, d_datt, d_dvaccum, d_dout, d_inp, d_qkvr, d_preatt, d_att, d_vaccum, B, T, C, NH, block_size); printf("block_size %4d | time %f ms\n", block_size, elapsed_time); } // free memory free(inp); free(qkvr); free(preatt); free(att); free(vaccum); free(out); free(dout); free(dinp); free(dpreatt); free(datt); free(h_dinp); cudaCheck(cudaFree(d_inp)); cudaCheck(cudaFree(d_qkvr)); cudaCheck(cudaFree(d_preatt)); cudaCheck(cudaFree(d_att)); cudaCheck(cudaFree(d_vaccum)); cudaCheck(cudaFree(d_out)); cudaCheck(cudaFree(d_dinp)); cudaCheck(cudaFree(d_dqkvr)); cudaCheck(cudaFree(d_dpreatt)); cudaCheck(cudaFree(d_datt)); cudaCheck(cudaFree(d_dvaccum)); cudaCheck(cudaFree(d_dout)); cublasDestroy(cublas_handle); return 0; } ================================================ FILE: dev/cuda/attention_forward.cu ================================================ /* Kernels for attention forward pass. 
If you do not have CUDNN, you can remove ENABLE_CUDNN to run the other kernels.
See the README for cuDNN install instructions.

Compile example with cuDNN:
nvcc -I/PATH/TO/cudnn-frontend/include -DENABLE_CUDNN -O3 --use_fast_math -lcublas -lcublasLt -lcudnn attention_forward.cu -o attention_forward

Compile example without cuDNN:
nvcc -O3 --use_fast_math -lcublas -lcublasLt attention_forward.cu -o attention_forward

version 1 is a naive port from CPU code to kernel, parallelized over batch, time, and heads only
./attention_forward 1

version 2 is a naive implementation of flash attention, adapted from
https://github.com/tspeterkim/flash-attention-minimal
and with help from
https://github.com/leloykun/flash-hyperbolic-attention-minimal
Sadly, this flash attention version seems to be about 3X slower than the naive version.
./attention_forward 2

version 3 is a cuBLAS + softmax version, similar to the PyTorch implementation.
cuBLAS is used both to calculate QK^T and the final weighted sum;
the softmax is calculated using a custom, efficient kernel as well.
This turns out to be ~20X faster than (1). Nice.
./attention_forward 3

version 4 is a further optimized kernel that fuses the scale operation,
uses a directly autoregressive softmax, and uses the online softmax algorithm.
./attention_forward 4

version 5 is an FP16 version of kernel 4
./attention_forward 5

version 6 is kernel 5 skipping (un)permute (unrealistic but useful comparison point)

version 10 is using cuDNN Flash Attention using FP16 or BF16, see:
https://github.com/NVIDIA/cudnn-frontend/blob/main/docs/operations/Attention.md
./attention_forward 10

version 11 is kernel 10 skipping FP16/FP32 conversions (full FP16/BF16 network)
./attention_forward 11
*/

//#define ENABLE_CUDNN // can be enabled via nvcc "-DENABLE_CUDNN"

// NOTE(review): the header names of the nine #include directives below are
// missing in this copy (the <...> text appears to have been stripped during
// extraction). Restore them from the upstream source before compiling.
#include
#include
#include
#include
#include
#include
#include
#include
#include

// ENABLE_BF16 is defined before common.h is included; presumably it selects the
// BF16 variant of floatX / CUBLAS_LOWP there -- TODO confirm in common.h.
#define ENABLE_BF16
#include "common.h"

// ----------------------------------------------------------------------------
// CUDA & cuDNN setup

static bool first_run_validation = true; // always run e.g. permute on 1st run

#ifdef ENABLE_CUDNN
// NOTE(review): header name stripped here as well; the cudnn_frontend namespace
// used on the next line indicates which header belongs here -- restore it.
#include
namespace fe = cudnn_frontend;

// the 16-bit data type handed to cuDNN follows the low-precision cuBLAS type
#if CUBLAS_LOWP == CUDA_R_16BF
#define CUDNN_16BIT fe::DataType_t::BFLOAT16
#else
#define CUDNN_16BIT fe::DataType_t::HALF
#endif

static cudnnHandle_t cudnn_handle;
static size_t cudnn_workspace_size = 0; // dynamically allocated as needed (up to 256MiB!)
static void* cudnn_workspace = NULL;
// error-check helpers for the cuDNN path; the argument is parenthesized so that
// compound expressions are cast as a whole
#define checkCudaErr(err) assert((int)(err) == 0);
#define checkCudnnErr(err) assert((int)(err) == 0);
#endif // ENABLE_CUDNN

// ----------------------------------------------------------------------------
// CPU code reference

// Reference causal self-attention on the CPU.
//   inp    : (B, T, 3C) with Q, K, V at channel offsets 0, C, 2C
//   preatt : (B, NH, T, T) scaled Q.K scores; entries above the diagonal = -INFINITY
//   att    : (B, NH, T, T) softmax of preatt; entries above the diagonal = 0
//   out    : (B, T, C) attention-weighted sum of V, heads side by side
void attention_forward_cpu(float* out, float* preatt, float* att,
                           const float* inp,
                           int B, int T, int C, int NH) {
    int C3 = C*3;
    int hs = C / NH; // head size
    float scale = 1.0 / sqrtf(hs);

    for (int b = 0; b < B; b++) {
        for (int t = 0; t < T; t++) {
            for (int h = 0; h < NH; h++) {
                const float* query_t = inp + b * T * C3 + t * C3 + h * hs;
                float* preatt_bth = preatt + b*NH*T*T + h*T*T + t*T;
                float* att_bth = att + b*NH*T*T + h*T*T + t*T;

                // pass 1: calculate query dot key and maxval
                float maxval = -FLT_MAX;
                for (int t2 = 0; t2 <= t; t2++) {
                    const float* key_t2 = inp + b * T * C3 + t2 * C3 + h * hs + C; // +C because it's key

                    // (query_t) dot (key_t2)
                    float val = 0.0f;
                    for (int i = 0; i < hs; i++) {
                        val += query_t[i] * key_t2[i];
                    }
                    val *= scale;
                    if (val > maxval) {
                        maxval = val;
                    }

                    preatt_bth[t2] = val;
                }
                // pad with -INFINITY outside of autoregressive region for debugging comparisons
                for (int t2 = t+1; t2 < T; t2++) {
                    preatt_bth[t2] = -INFINITY;
                }

                // pass 2: calculate the exp and keep track of sum
                float expsum = 0.0f;
                for (int t2 = 0; t2 <= t; t2++) {
                    float expv = expf(preatt_bth[t2] - maxval);
                    expsum += expv;
                    att_bth[t2] = expv;
                }
                float expsum_inv = expsum == 0.0f ? 0.0f : 1.0f / expsum;

                // pass 3: normalize to get the softmax
                for (int t2 = 0; t2 < T; t2++) {
                    if (t2 <= t) {
                        att_bth[t2] *= expsum_inv;
                    } else {
                        // causal attention mask. not strictly necessary to set to zero here
                        // only doing this explicitly for debugging and checking to PyTorch
                        att_bth[t2] = 0.0f;
                    }
                }

                // pass 4: accumulate weighted values into the output of attention
                float* out_bth = out + b * T * C + t * C + h * hs;
                for (int i = 0; i < hs; i++) { out_bth[i] = 0.0f; }
                for (int t2 = 0; t2 <= t; t2++) {
                    const float* value_t2 = inp + b * T * C3 + t2 * C3 + h * hs + C*2; // +C*2 because it's value
                    float att_btht2 = att_bth[t2];
                    for (int i = 0; i < hs; i++) {
                        out_bth[i] += att_btht2 * value_t2[i];
                    }
                }
            }
        }
    }
}

// ----------------------------------------------------------------------------
// GPU kernels

// One thread per preatt element (b, h, t, t2): scaled dot product of query t
// and key t2, or -INFINITY above the causal diagonal.
__global__ void attention_query_key_kernel1(float* preatt, const float* inp,
                                           int B, int T, int C, int NH) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int total_threads = B * NH * T * T;

    if (idx < total_threads) {
        int t2 = idx % T;
        int t = (idx / T) % T;
        if (t2 > t) {
            // autoregressive mask
            preatt[idx] = -INFINITY;
            return;
        }
        int h = (idx / (T * T)) % NH;
        int b = idx / (NH * T * T);

        int C3 = C*3;
        int hs = C / NH; // head size
        const float* query_t = inp + b * T * C3 + t * C3 + h * hs;
        const float* key_t2 = inp + b * T * C3 + t2 * C3 + h * hs + C; // +C because it's key

        // (query_t) dot (key_t2)
        float val = 0.0f;
        for (int i = 0; i < hs; i++) {
            val += query_t[i] * key_t2[i];
        }
        val *= 1.0 / sqrtf(hs);

        preatt[idx] = val;
    }
}

// One thread per (b, t, h) row: causal softmax of preatt row t into att;
// entries past t are written as 0.
__global__ void attention_softmax_kernel1(float* att, const float* preatt,
                                         int B, int T, int NH) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int total_threads = B * T * NH;

    if (idx < total_threads) {
        int h = idx % NH;
        int t = (idx / NH) % T;
        int b = idx / (NH * T);

        const float* preatt_bth = preatt + b*NH*T*T + h*T*T + t*T;
        float* att_bth = att + b*NH*T*T + h*T*T + t*T;

        // find maxval
        float maxval = -FLT_MAX;
        for (int t2 = 0; t2 <= t; t2++) {
            if (preatt_bth[t2] > maxval) {
                maxval = preatt_bth[t2];
            }
        }

        // calculate the exp and keep track of sum
        float expsum = 0.0f;
        for (int t2 = 0; t2 <= t; t2++) {
            float expv = expf(preatt_bth[t2] - maxval);
            expsum += expv;
            att_bth[t2] = expv;
        }
        float expsum_inv = expsum == 0.0f ? 0.0f : 1.0f / expsum;

        // normalize to get the softmax
        for (int t2 = 0; t2 < T; t2++) {
            if (t2 <= t) {
                att_bth[t2] *= expsum_inv;
            } else {
                // causal attention mask. not strictly necessary to set to zero here
                // only doing this explicitly for debugging and checking to PyTorch
                att_bth[t2] = 0.0f;
            }
        }
    }
}

// warp-level reduction for finding the maximum value
// (shuffle-down tree; the full result is only guaranteed in lane 0)
__device__ float warpReduceMax(float val) {
    for (int offset = 16; offset > 0; offset /= 2) {
        val = fmaxf(val, __shfl_down_sync(0xFFFFFFFF, val, offset));
    }
    return val;
}

__global__ void softmax_forward_kernel4(float* out, const float* inp, int N, int C) {
    // out is (N, C) just like inp. Each row of inp will get softmaxed.
    // same as kernel3, but can handle any block size (multiple of 32)
    // each row of C elements is handled by block_size threads
    // furthermore, each block_size threads get executed in warps of 32 threads

    // special reduction operations warpReduceMax/warpReduceSum are used for intra-warp reductions
    // shared memory is used for inter-warp reduction
    extern __shared__ float shared[];
    int idx = blockIdx.x;
    int tid = threadIdx.x;
    int warpId = threadIdx.x / 32; // warp index within a block
    int laneId = threadIdx.x % 32; // thread index within a warp

    // the number of warps per block. recall that blockDim.x is block_size
    int warpsPerBlock = blockDim.x / 32;

    // shared[] must be allocated to have 2 * warpsPerBlock elements
    // first half for max values, the second half for sum values
    float* maxvals = shared;
    float* sumvals = &shared[warpsPerBlock];

    // one row of inp, i.e. inp[idx, :] of shape (C,)
    const float* x = inp + idx * C;

    // first, thread coarsening by directly accessing global memory in series
    float maxval = -INFINITY;
    for (int i = tid; i < C; i += blockDim.x) {
        maxval = fmaxf(maxval, x[i]);
    }
    // now within-warp reductions for maxval
    maxval = warpReduceMax(maxval);

    // the 0th thread of each warp writes the maxval of that warp to shared memory
    if (laneId == 0) maxvals[warpId] = maxval;
    __syncthreads();

    // now the 0th thread reduces the maxvals in shared memory, i.e. across warps
    if (tid == 0) {
        float val = maxvals[tid];
        for (int i = 1; i < warpsPerBlock; i++) {
            val = fmaxf(val, maxvals[i]);
        }
        // store the final max in the first position
        maxvals[0] = val;
    }
    __syncthreads();
    // broadcast the max to all threads
    float offset = maxvals[0];

    // compute expf and write the result to global memory
    for (int i = tid; i < C; i += blockDim.x) {
        // subtract max for numerical stability
        out[idx * C + i] = expf(x[i] - offset);
    }

    // okay now we calculated exp(x - max(x))
    // step 2: sum all the values and divide by the sum

    // thread coarsening for sum
    x = out + idx * C;
    float sumval = 0.0f;
    for (int i = tid; i < C; i += blockDim.x) {
        sumval += x[i];
    }
    // within-warp reduction for sumval
    sumval = warpReduceSum(sumval);

    // write sumval to shared memory
    if (laneId == 0) sumvals[warpId] = sumval;
    __syncthreads();

    // inter-thread reduction of sum
    if (tid == 0) {
        float val = sumvals[tid];
        for (int i = 1; i < warpsPerBlock; ++i) {
            val += sumvals[i];
        }
        sumvals[0] = val;
    }
    __syncthreads();
    // broadcast the sum to all threads
    float sum = sumvals[0];

    // divide the whole row by the sum
    for (int i = tid; i < C; i += blockDim.x) {
        out[idx * C + i] = x[i] / sum;
    }
}

// Element access into a float4 by index 0..3 (x, y, z, w), mutable and const
// overloads. The template arguments of reinterpret_cast follow the return types.
__device__ float& vec_at(float4& vec, int index) {
    return reinterpret_cast<float*>(&vec)[index];
}

__device__ float vec_at(const float4& vec, int index) {
    return reinterpret_cast<const float*>(&vec)[index];
}

__global__ void softmax_forward_kernel5(float* out, float inv_temperature, const float* inp, int N, int T) {
    // inp, out shape: (N, T, T), where N = B * NH
    // fuses the multiplication by scale inside attention
    // directly autoregressive, so we only compute the lower triangular part
    // uses the online softmax algorithm
    // one warp handles one row (idx); rows are read as float4, hence T % 4 == 0
    assert(T % 4  == 0);
    namespace cg = cooperative_groups;
    cg::thread_block block = cg::this_thread_block();
    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
    int idx = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank();
    if(idx >= N * T) {
        return;
    }
    int own_pos = idx % T;
    int pos_by_4 = own_pos / 4;

    // one row of inp, i.e. inp[idx, :] of shape (T,)
    const float* x = inp + idx * T;

    // not INF, so we don't get NaNs accidentally when subtracting two values.
    float maxval = -FLT_MAX;
    float sumval = 0.0f;

    const float4* x_vec = reinterpret_cast<const float4*>(x);
    for (int i = warp.thread_rank(); i < pos_by_4; i += warp.size()) {
        float4 v = x_vec[i];
        float old_maxval = maxval;
        for(int k = 0; k < 4; ++k) {
            maxval = fmaxf(maxval, vec_at(v, k));
        }
        // rescale the running sum to the new max (online softmax)
        sumval *= expf(inv_temperature * (old_maxval - maxval));
        for(int k = 0; k < 4; ++k) {
            sumval += expf(inv_temperature * (vec_at(v, k) - maxval));
        }
    }

    // the final, partially filled group of 4 (positions 4*pos_by_4 .. own_pos)
    if(4*pos_by_4 + warp.thread_rank() <= own_pos) {
        float old_maxval = maxval;
        maxval = fmaxf(maxval, x[4*pos_by_4 + warp.thread_rank()]);
        sumval *= expf(inv_temperature * (old_maxval - maxval));
        sumval += expf(inv_temperature * (x[4*pos_by_4 + warp.thread_rank()] - maxval));
    }

    float global_maxval = cg::reduce(warp, maxval, cg::greater<float>{});
    sumval *= expf(inv_temperature * (maxval - global_maxval));

    float sum = cg::reduce(warp, sumval, cg::plus<float>{});
    float norm = 1.f / sum;

    // divide the whole row by the sum
    for (int i = warp.thread_rank(); i <= own_pos; i += warp.size()) {
        // recalculation is faster than doing the round-trip through memory.
        float ev = expf(inv_temperature * (__ldcs(x + i) - global_maxval));
        __stcs(out + idx * T + i, ev * norm);
    }
}

// One thread per (b, t, h): out[b, t, h*hs:(h+1)*hs] = sum_{t2 <= t} att[t2] * V[t2].
__global__ void attention_value_kernel1(float* out, const float* att, const float* inp,
                                       int B, int T, int C, int NH) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int total_threads = B * T * NH;

    if (idx < total_threads) {
        int h = idx % NH;
        int t = (idx / NH) % T;
        int b = idx / (NH * T);

        int C3 = C*3;
        int hs = C / NH; // head size

        float* out_bth = out + b * T * C + t * C + h * hs;
        const float* att_bth = att + b*NH*T*T + h*T*T + t*T;

        for (int i = 0; i < hs; i++) { out_bth[i] = 0.0f; }
        for (int t2 = 0; t2 <= t; t2++) {
            const float* value_t2 = inp + b * T * C3 + t2 * C3 + h * hs + C*2; // +C*2 because it's value
            float att_btht2 = att_bth[t2];
            for (int i = 0; i < hs; i++) {
                out_bth[i] += att_btht2 * value_t2[i];
            }
        }
    }
}

// Minimal flash-attention forward with causal masking. Grid = (batch, head),
// one thread per query row of a tile. Q, K, V, O are (B, NH, N, d); l and m hold
// the running softmax denominators/maxima per row and must be pre-initialized.
// NOTE(review): K/V tile loads have no bounds guard, so N is presumably expected
// to be a multiple of Bc -- confirm before using other sequence lengths.
__global__ void attention_forward_kernel2(
    const float* Q,
    const float* K,
    const float* V,
    const int N,
    const int d,
    const int Tc,
    const int Tr,
    const int Bc,
    const int Br,
    const float softmax_scale,
    float* l,
    float* m,
    float* O
) {
    int tx = threadIdx.x;
    int bx = blockIdx.x; int by = blockIdx.y;  // batch and head index

    // Offset into Q,K,V,O,l,m - different for each batch and head
    int qkv_offset = (bx * gridDim.y * N * d) + (by * N * d);  // gridDim.y = nh
    int lm_offset = (bx * gridDim.y * N) + (by * N);  // offset for l and m

    // Define SRAM for Q,K,V,S
    extern __shared__ float sram[];
    int tile_size = Bc * d;  // size of Qi, Kj, Vj
    float* Qi = sram;
    float* Kj = &sram[tile_size];
    float* Vj = &sram[tile_size * 2];
    float* S = &sram[tile_size * 3];

    for (int j = 0; j < Tc; j++) {

        // Load Kj, Vj to SRAM
        for (int x = 0; x < d; x++) {
            Kj[(tx * d) + x] = K[qkv_offset + (tile_size * j) + (tx * d) + x];
            Vj[(tx * d) + x] = V[qkv_offset + (tile_size * j) + (tx * d) + x];
        }
        __syncthreads();  // such that the inner loop can use the correct Kj, Vj

        for (int i = 0; i < Tr; i++)  {
            // if past the end of the sequence, break
            if (i * Br + tx >= N) {
                break;
            }

            // Load Qi to SRAM, l and m to registers
            for (int x = 0; x < d; x++) {
                Qi[(tx * d) + x] = Q[qkv_offset + (tile_size * i) + (tx * d) + x];
            }
            float row_m_prev = m[lm_offset + (Br * i) + tx];
            float row_l_prev = l[lm_offset + (Br * i) + tx];

            // S = QK^T, row_m = rowmax(S)
            // S[tx][y] = Sum_{x = 0}^{d-1} {Qi[tx][x] * Kj[y][x]}
            // row_m = Max_{y = 0}^{Bc-1} S[tx][y]
            // with causal masking
            float row_m = -INFINITY;
            for (int y = 0; y < Bc; y++) {
                if (j * Bc + y >= N) {
                    break;
                }
                float sum = 0;
                for (int x = 0; x < d; x++) {
                    sum += Qi[(tx * d) + x] * Kj[(y * d) + x];
                }
                sum *= softmax_scale;
                if (i * Br + tx < j * Bc + y)
                    sum = -INFINITY;
                S[(Bc * tx) + y] = sum;

                if (sum > row_m)
                    row_m = sum;
            }

            // implement softmax with causal masking
            // P = exp(S - row_m), row_l = rowsum(P)
            // P[tx][y] = exp(S[tx][y] - row_m)
            float row_l = 0;
            for (int y = 0; y < Bc; y++) {
                if (j * Bc + y >= N) {
                    break;
                }
                if (i * Br + tx < j * Bc + y)
                    S[(Bc * tx) + y] = 0;
                else
                    S[(Bc * tx) + y] = __expf(S[(Bc * tx) + y] - row_m);
                row_l += S[(Bc * tx) + y];
            }

            // Compute new m and l
            float row_m_new = max(row_m_prev, row_m);
            float row_l_new = (__expf(row_m_prev - row_m_new) * row_l_prev) + (__expf(row_m - row_m_new) * row_l);

            // Write O, l, m to HBM
            for (int x = 0; x < d; x++) {
                float pv = 0;  // Pij * Vj
                for (int y = 0; y < Bc; y++) {
                    if (j * Bc + y >= N) {
                        break;
                    }
                    pv += S[(Bc * tx) + y] * Vj[(y * d) + x];
                }
                O[qkv_offset + (tile_size * i) + (tx * d) + x] = (1 / row_l_new)
                    * ((row_l_prev * __expf(row_m_prev - row_m_new) * O[qkv_offset + (tile_size * i) + (tx * d) + x])
                    + (__expf(row_m - row_m_new) * pv));
            }
            m[lm_offset + (Br * i) + tx] = row_m_new;
            l[lm_offset + (Br * i) + tx] = row_l_new;
        }
        __syncthreads();  // otherwise, thread can use the wrong Kj, Vj in inner loop
    }
}

// Split packed QKV (B, N, 3, NH, d) into separate q, k, v tensors of shape (B, NH, N, d).
__global__ void permute_kernel(float* q, float* k, float* v,
                               const float* inp,
                               int B, int N, int NH, int d) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;

    // Q[b][nh_][n][d_] = inp[b][n][0][nh_][d_]

    if (idx < B * NH * N * d) {
        int b = idx / (NH * N * d);
        int rest = idx % (NH * N * d);
        int nh_ = rest / (N * d);
        rest = rest % (N * d);
        int n = rest / d;
        int d_ = rest % d;

        int inp_idx =
            (b * N * 3 * NH * d)
            +   (n * 3 * NH * d)
            +       (0 * NH * d)
            +          (nh_ * d)
            +                d_;

        q[idx] = inp[inp_idx];
        k[idx] = inp[inp_idx + NH * d];
        v[idx] = inp[inp_idx + 2 * (NH * d)];
    }
}

// Inverse layout change: (B, NH, N, d) -> (B, N, NH, d).
__global__ void unpermute_kernel(const float* inp, float *out, int B, int N, int NH, int d) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;

    // out[b][n][nh_][d_] <- inp[b][nh_][n][d_]
    if (idx < B * NH * N * d) {
        int b = idx / (NH * N * d);
        int rest = idx % (NH * N * d);
        int nh_ = rest / (N * d);
        rest = rest % (N * d);
        int n = rest / d;
        int d_ = rest % d;

        int other_idx = (b * NH * N * d) + (n * NH * d) + (nh_ * d) + d_;
        out[other_idx] = inp[idx];
    }
}

// Scales the pre-softmax attention scores by `scale` in place and sets the
// future (key > query) positions to -INFINITY. Each (b, h) slice is a T x T
// matrix whose row is the query position and column the key position, matching
// the preatt layout produced by the cuBLAS calls and the CPU reference.
__global__ void scale_kernel(float* inp, float scale, int B, int NH, int T) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < B * NH * T * T) {
        // NH*T*T is a multiple of T*T, so this equals the offset within the (b, h) slice
        int rest = idx % (T * T);
        int t_query = rest / T;
        int t_key = rest % T;
        if (t_key > t_query) {
            inp[idx] = -INFINITY;
        } else {
            inp[idx] *= scale;
        }
    }
}

// direct translation of the CPU kernel. Each warp handles one (b, h, t) combination.
// The important changes compared to the CPU version:
//  - each inner loop is handled by a warp
//  - don't write non-autoregressive parts
//  - reordered the last loops so that we can do all writing in the outer loop.
// Fused causal attention, one warp per (b, h, t):
//   inp    : (B, T, 3C) with Q, K, V at channel offsets 0, C, 2C
//   preatt : (B, NH, T, T) -- only positions t2 <= t are written
//   att    : (B, NH, T, T) -- only positions t2 <= t are written
//   out    : (B, T, C)
// Grid: x covers T in groups of (blockDim.x / 32) warps, y = head, z = batch.
__global__ void attention_forward_fused1(float* out, float* preatt, float* att,
                                         const float* inp,
                                         int B, int T, int C, int NH) {
    int C3 = C*3;
    int hs = C / NH; // head size
    float scale = 1.0 / sqrtf(hs);

    namespace cg = cooperative_groups;
    cg::thread_block block = cg::this_thread_block();
    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
    int t = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank();
    int h = blockIdx.y;
    int b = blockIdx.z;

    // t is uniform across the warp, so the whole warp leaves together and the
    // warp.sync() calls below always see all 32 lanes
    if(t >= T) return;

    const float* query_t = inp + b * T * C3 + t * C3 + h * hs;
    float* preatt_bth = preatt + b*NH*T*T + h*T*T + t*T;
    float* att_bth = att + b*NH*T*T + h*T*T + t*T;

    // pass 1: calculate query dot key and maxval
    float maxval = -INFINITY;
    for (int t2 = 0; t2 <= t; t2++) {
        const float* key_t2 = inp + b * T * C3 + t2 * C3 + h * hs + C; // +C because it's key

        // (query_t) dot (key_t2)
        float val = 0.0f;
        for (int i = warp.thread_rank(); i < hs; i += warp.size()) {
            val += query_t[i] * key_t2[i];
        }
        val = cg::reduce(warp, val, cg::plus<float>{});
        val *= scale;
        maxval = max(maxval, val);
        if(warp.thread_rank() == 0) {
            preatt_bth[t2] = val;
        }
    }
    // preatt_bth was written by lane 0 only, but is read by every lane below;
    // __syncwarp-style barrier makes those global writes visible warp-wide
    warp.sync();

    // pass 2: calculate the exp and keep track of sum
    float expsum = 0.0f;
    for (int t2 = warp.thread_rank(); t2 <= t; t2 += warp.size()) {
        float expv = expf(preatt_bth[t2] - maxval);
        expsum += expv;
    }

    expsum = cg::reduce(warp, expsum, cg::plus<float>{});

    float expsum_inv = expsum == 0.0f ? 0.0f : 1.0f / expsum;

    // pass 3: normalize to get the softmax is combined with the next loop to reduce memory round-trips
    for (int t2 = warp.thread_rank(); t2 <= t; t2 += warp.size()) {
        att_bth[t2] = expf(preatt_bth[t2] - maxval) * expsum_inv;
    }
    // each att_bth[t2] was written by lane (t2 % 32), but pass 4 has every lane
    // read all of them: synchronize (with memory ordering) before reading
    warp.sync();

    // pass 4: accumulate weighted values into the output of attention
    float* out_bth = out + b * T * C + t * C + h * hs;
    for (int i = warp.thread_rank(); i < hs; i += warp.size()) {
        float o = 0.f;
        for (int t2 = 0; t2 <= t; t2++) {
            const float* value_t2 = inp + b * T * C3 + t2 * C3 + h * hs + C * 2; // +C*2 because it's value
            float att_btht2 = att_bth[t2];
            o += att_btht2 * value_t2[i];
        }
        out_bth[i] = o;
    }
}

// ----------------------------------------------------------------------------
// kernel launcher

// Version 1: naive three-kernel pipeline (scores, softmax, weighted values).
void attention_forward1(float* out, float* preatt, float* att,
                       const float* inp,
                       int B, int T, int C, int NH,
                       const int block_size) {
    // attention calculation
    int total_threads = B * NH * T * T;
    int num_blocks = ceil_div(total_threads, block_size);
    attention_query_key_kernel1<<<num_blocks, block_size>>>(preatt, inp, B, T, C, NH);
    // softmax and value accumulation
    total_threads = B * T * NH;
    num_blocks = ceil_div(total_threads, block_size);
    attention_softmax_kernel1<<<num_blocks, block_size>>>(att, preatt, B, T, NH);
    attention_value_kernel1<<<num_blocks, block_size>>>(out, att, inp, B, T, C, NH);
}

// Version 2: naive flash attention (see attention_forward_kernel2).
void attention_forward2(float* out,
                       const float* inp,
                       int B, int T, int C, int NH,
                       const int block_size) {
    // TODO there should be no mallocs inside any of these functions!
    // not fixing this because we don't intend to use attention_forward2,
    // it seems to be way too slow as is

    // these are hardcoded to 32 for now
    const int Bc = 32;
    const int Br = 32;
    // renaming these to be consistent with the kernel
    // const int B = B;
    const int nh = NH;
    const int N = T;
    const int d = C / NH;
    // more
    const int Tc = ceil((float) N / Bc);
    const int Tr = ceil((float) N / Br);
    const float softmax_scale = 1.0 / sqrt(d);
    // create some temporary memory
    float* l;
    float* m;
    cudaCheck(cudaMalloc(&l, B * nh * N * sizeof(float)));
    cudaCheck(cudaMalloc(&m, B * nh * N * sizeof(float)));
    cudaCheck(cudaMemset(l, 0, B * nh * N * sizeof(float)));
    // NOTE: cudaMemset fills BYTES with (unsigned char)value. -10000.0f converts to
    // the int -10000, whose low byte is 0xF0, so every float in m becomes
    // 0xF0F0F0F0 (about -5.96e29), not -10000.0f. It still serves as a
    // "very negative" initial running max.
    cudaCheck(cudaMemset(m, -10000.0f, B * nh * N * sizeof(float)));

    // calculate SRAM size needed per block, ensure we have enough shared memory
    int col_tile_size = Bc * d;  // size of Kj, Vj
    int row_tile_size = Br * d;  // size of Qi
    const int sram_size =
        (2 * col_tile_size * sizeof(float))  // SRAM size for Kj, Vj
        + (row_tile_size * sizeof(float))  // SRAM size for Qi
        + (Bc * Br * sizeof(float));  // SRAM size for S
    int max_sram_size;
    cudaDeviceGetAttribute(&max_sram_size, cudaDevAttrMaxSharedMemoryPerBlock, 0);
    if (sram_size > max_sram_size) {
        printf("Max shared memory: %d, requested shared memory: %d \n", max_sram_size, sram_size);
        printf("SRAM size exceeds maximum shared memory per block\n");
        printf("Try decreasing col_tile_size or row_tile_size further\n");
        exit(1);
    }

    // grid and block dims
    dim3 grid_dim(B, nh);  // batch_size x num_heads
    dim3 block_dim(Br);  // Br threads per block

    // okay so now, this kernel wants Q,K,V to all be of shape (B, nh, N, d)
    // but instead, we have a single tensor QKV (inp) of shape (B, N, 3, nh, d)
    // so we have to permute the tensor using a kernel with block_size
    float *q, *k, *v;
    cudaCheck(cudaMalloc(&q, B * T * C * sizeof(float)));
    cudaCheck(cudaMalloc(&k, B * T * C * sizeof(float)));
    cudaCheck(cudaMalloc(&v, B * T * C * sizeof(float)));
    int total_threads = B * N * nh * d;
    int num_blocks = ceil_div(total_threads, block_size);
    permute_kernel<<<num_blocks, block_size>>>(q, k, v, inp, B, N, nh, d);

    // now actually call the flash attention kernel
    attention_forward_kernel2<<<grid_dim, block_dim, sram_size>>>(
        q, k, v,
        N, d, Tc, Tr, Bc, Br, softmax_scale,
        l, m, out
    );

    // out has shape (B, nh, N, d) but we need to unpermute it to (B, N, nh, d)
    unpermute_kernel<<<num_blocks, block_size>>>(out, q, B, N, nh, d);
    cudaCheck(cudaMemcpy(out, q, B * T * C * sizeof(float), cudaMemcpyDeviceToDevice));

    // free memory
    cudaCheck(cudaFree(l));
    cudaCheck(cudaFree(m));
    cudaCheck(cudaFree(q));
    cudaCheck(cudaFree(k));
    cudaCheck(cudaFree(v));
}

// Version 3: permute + cuBLAS QK^T + scale/mask + block softmax + cuBLAS att@V + unpermute.
//   inp is (B, T, 3C) QKV; preatt, att are (B, NH, T, T); output is (B, T, C)
void attention_forward3(float* out, float* vaccum, float* qkvr, float* preatt, float* att,
                       const float* inp,
                       int B, int T, int C, int NH,
                       const int block_size) {
    int HS = C / NH; // head size

    // permute and separate inp from (B, T, 3, NH, HS) to 3X (B, NH, T, HS)
    float *q, *k, *v;
    q = qkvr + 0 * B * T * C;
    k = qkvr + 1 * B * T * C;
    v = qkvr + 2 * B * T * C;
    int total_threads = B * NH * T * HS;
    int num_blocks = ceil_div(total_threads, block_size);
    permute_kernel<<<num_blocks, block_size>>>(q, k, v, inp, B, T, NH, HS);

    // batched matrix multiply with cuBLAS
    const float alpha = 1.0f;
    const float beta = 0.0f;
    cublasCheck(cublasSgemmStridedBatched(cublas_handle,
                                          CUBLAS_OP_T, CUBLAS_OP_N,
                                          T, T, HS,
                                          &alpha,
                                          k, HS, T * HS,
                                          q, HS, T * HS,
                                          &beta,
                                          preatt, T, T * T,
                                          B * NH));

    // multiply all elements of preatt elementwise by scale
    float scale = 1.0f / sqrtf(HS);
    total_threads = B * NH * T * T;
    num_blocks = ceil_div(total_threads, block_size);
    scale_kernel<<<num_blocks, block_size>>>(preatt, scale, B, NH, T);

    // softmax. preatt is (B, NH, T, T) but we view it as (B * NH * T, T) and use the softmax kernel
    int softmax_block_size = 256;
    int grid_size = B * NH * T;
    size_t shared_mem_size = 2 * softmax_block_size / 32 * sizeof(float);
    softmax_forward_kernel4<<<grid_size, softmax_block_size, shared_mem_size>>>(att, preatt, B * NH * T, T);

    // new approach: first cuBLAS another batched matmul
    // y = att @ v # (B, nh, T, T) @ (B, nh, T, hs) -> (B, nh, T, hs)
    cublasCheck(cublasSgemmStridedBatched(cublas_handle,
                                          CUBLAS_OP_N, CUBLAS_OP_N,
                                          HS, T, T,
                                          &alpha,
                                          v, HS, T * HS,
                                          att, T, T * T,
                                          &beta,
                                          vaccum, HS, T * HS,
                                          B * NH));

    // now unpermute
    // y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
    num_blocks = ceil_div(B * T * C, block_size);
    unpermute_kernel<<<num_blocks, block_size>>>(vaccum, out, B, T, NH, HS);
}

// Version 4: like version 3, but the scale and causal mask are fused into the
// online softmax kernel (softmax_forward_kernel5, one warp per row).
void attention_forward4(float* out, float* vaccum, float* qkvr, float* preatt, float* att,
                       const float* inp,
                       int B, int T, int C, int NH,
                       const int block_size) {
    int HS = C / NH; // head size

    // permute and separate inp from (B, T, 3, NH, HS) to 3X (B, NH, T, HS)
    float *q, *k, *v;
    q = qkvr + 0 * B * T * C;
    k = qkvr + 1 * B * T * C;
    v = qkvr + 2 * B * T * C;
    int total_threads = B * NH * T * HS;
    int num_blocks = ceil_div(total_threads, block_size);
    permute_kernel<<<num_blocks, block_size>>>(q, k, v, inp, B, T, NH, HS);

    // batched matrix multiply with cuBLAS
    const float alpha = 1.0f;
    const float beta = 0.0f;
    cublasCheck(cublasSgemmStridedBatched(cublas_handle,
                                          CUBLAS_OP_T, CUBLAS_OP_N,
                                          T, T, HS,
                                          &alpha,
                                          k, HS, T * HS,
                                          q, HS, T * HS,
                                          &beta,
                                          preatt, T, T * T,
                                          B * NH));

    // multiply all elements of preatt elementwise by scale
    float scale = 1.0 / sqrtf(HS);
    int softmax_block_size = 256;
    // one 32-thread warp per row of (B * NH * T) rows
    int grid_size = ceil_div(B * NH * T * 32, softmax_block_size);
    softmax_forward_kernel5<<<grid_size, softmax_block_size>>>(att, scale, preatt, B * NH, T);

    // new approach: first cuBLAS another batched matmul
    // y = att @ v # (B, nh, T, T) @ (B, nh, T, hs) -> (B, nh, T, hs)
    cublasCheck(cublasSgemmStridedBatched(cublas_handle,
                                          CUBLAS_OP_N, CUBLAS_OP_N,
                                          HS, T, T,
                                          &alpha,
                                          v, HS, T * HS,
                                          att, T, T * T,
                                          &beta,
                                          vaccum, HS, T * HS,
                                          B * NH));

    // now unpermute
    // y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
    num_blocks = ceil_div(B * T * C, block_size);
    unpermute_kernel<<<num_blocks, block_size>>>(vaccum, out, B, T, NH, HS);
}

__global__ void softmax_forward_kernel5_lowp(floatX* out, float inv_temperature,
                                             const floatX* inp, int N, int T) {
    // inp, out shape: (N, T, T), where N = B * NH
    // fuses the multiplication by scale inside attention
    // directly autoregressive, so we only compute the lower triangular part
    // uses the online softmax algorithm
    // same as softmax_forward_kernel5, but reads/writes floatX and accumulates in float
    assert(T % 4  == 0);
    namespace cg = cooperative_groups;
    cg::thread_block block = cg::this_thread_block();
    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
    int idx = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank();
    if(idx >= N * T) {
        return;
    }
    int own_pos = idx % T;
    int pos_by_4 = own_pos / 4;

    // one row of inp, i.e. inp[idx, :] of shape (T,)
    const floatX* x = inp + idx * T;

    // not INF, so we don't get NaNs accidentally when subtracting two values.
    float maxval = -FLT_MAX;
    float sumval = 0.0f;

    // Same thing but without float4, one at a time
    for (int i = warp.thread_rank(); i < pos_by_4; i += warp.size()) {
        float old_maxval = maxval;
        for(int k = 0; k < 4; ++k) {
            maxval = fmaxf(maxval, (float)x[4*i + k]);
        }
        sumval *= expf(inv_temperature * (old_maxval - maxval));
        for(int k = 0; k < 4; ++k) {
            sumval += expf(inv_temperature * ((float)x[4*i + k] - maxval));
        }
    }

    if(4*pos_by_4 + warp.thread_rank() <= own_pos) {
        float old_maxval = maxval;
        maxval = fmaxf(maxval, (float)x[4*pos_by_4 + warp.thread_rank()]);
        sumval *= expf(inv_temperature * (old_maxval - maxval));
        sumval += expf(inv_temperature * ((float)x[4*pos_by_4 + warp.thread_rank()] - maxval));
    }

    float global_maxval = cg::reduce(warp, maxval, cg::greater<float>{});
    sumval *= expf(inv_temperature * (maxval - global_maxval));

    float sum = cg::reduce(warp, sumval, cg::plus<float>{});
    float norm = 1.f / sum;

    // divide the whole row by the sum
    for (int i = warp.thread_rank(); i <= own_pos; i += warp.size()) {
        // recalculation is faster than doing the round-trip through memory.
        float ev = expf(inv_temperature * ((float)__ldcs(x + i) - global_maxval));
        __stcs(out + idx * T + i, (floatX)(ev * norm));
    }
}

// Low-precision permute: float QKV (B, N, 3, NH, d) -> floatX q, k, v of shape (B, NH, N, d).
__global__ void permute_kernel_lowp(floatX* q, floatX* k, floatX* v,
                                    const float* inp,
                                    int B, int N, int NH, int d) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;

    // Q[b][nh_][n][d_] = inp[b][n][0][nh_][d_]

    if (idx < B * NH * N * d) {
        int b = idx / (NH * N * d);
        int rest = idx % (NH * N * d);
        int nh_ = rest / (N * d);
        rest = rest % (N * d);
        int n = rest / d;
        int d_ = rest % d;

        int inp_idx =
            (b * N * 3 * NH * d)
            +   (n * 3 * NH * d)
            +       (0 * NH * d)
            +          (nh_ * d)
            +                d_;

        q[idx] = (floatX)inp[inp_idx];
        k[idx] = (floatX)inp[inp_idx + NH * d];
        v[idx] = (floatX)inp[inp_idx + 2 * (NH * d)];
    }
}

// Low-precision unpermute: floatX (B, NH, N, d) -> float (B, N, NH, d).
__global__ void unpermute_kernel_lowp(const floatX* inp, float *out, int B, int N, int NH, int d) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;

    // out[b][n][nh_][d_] <- inp[b][nh_][n][d_]
    if (idx < B * NH * N * d) {
        int b = idx / (NH * N * d);
        int rest = idx % (NH * N * d);
        int nh_ = rest / (N * d);
        rest = rest % (N * d);
        int n = rest / d;
        int d_ = rest % d;

        int other_idx = (b * NH * N * d) + (n * NH * d) + (nh_ * d) + d_;
        out[other_idx] = (float)inp[idx];
    }
}

void attention_forward5(float* out, floatX* vaccum, floatX* qkvr, floatX* preatt, floatX* att, const float* inp, int B, int T, int C, int NH, const int block_size, bool skip_permute=false) { // FP16 version of kernel 4 (with permute/unpermute doing FP32<->FP16) // That permute can be skipped on perf runs to analyse its performance impact // inp is (B, T, 3C) QKV // preatt, att are (B, NH, T, T) // output is (B, T, C) // permute and separate inp from (B, T, 3, NH, HS) to 3X (B, NH, T, HS) int HS = C / NH; // head size floatX *q = qkvr + 0 * B * T * C; floatX *k = qkvr + 1 * B * 
T * C; floatX* v = qkvr + 2 * B * T * C; int total_threads = B * NH * T * HS; int num_blocks = ceil_div(total_threads, block_size); if (!skip_permute || first_run_validation) { permute_kernel_lowp<<>>(q, k, v, inp, B, T, NH, HS); } // IMPORTANT: alpha/beta are FP32 for CUBLAS_COMPUTE_32F even if FP16 inputs/outputs // But need FP16 scale for CUBLAS_COMPUTE_16F (no errors otherwise, just garbage results *sigh*) const float alpha = 1.0f; const float beta = 0.0f; const floatX alpha_lowp = (floatX)alpha; const floatX beta_lowp = (floatX)beta; void* alpha_ptr = CUBLAS_LOWP_COMPUTE == CUBLAS_COMPUTE_16F ? (void*)&alpha_lowp : (void*)α void* beta_ptr = CUBLAS_LOWP_COMPUTE == CUBLAS_COMPUTE_16F ? (void*)&beta_lowp : (void*)β // batched matrix multiply with cuBLAS cublasCheck(cublasGemmStridedBatchedEx(cublas_handle, CUBLAS_OP_T, CUBLAS_OP_N, T, T, HS, alpha_ptr, k, CUBLAS_LOWP, HS, T * HS, q, CUBLAS_LOWP, HS, T * HS, beta_ptr, preatt, CUBLAS_LOWP, T, T * T, B * NH, CUBLAS_LOWP_COMPUTE, CUBLAS_GEMM_DEFAULT)); // multiply all elements of preatt elementwise by scale float scale = 1.0f / sqrtf(HS); int softmax_block_size = 256; int grid_size = ceil_div(B * NH * T * 32, softmax_block_size); softmax_forward_kernel5_lowp<<>>(att, scale, preatt, B * NH, T); // new approach: first cuBLAS another batched matmul // y = att @ v # (B, nh, T, T) @ (B, nh, T, hs) -> (B, nh, T, hs) cublasCheck(cublasGemmStridedBatchedEx(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, HS, T, T, alpha_ptr, v, CUBLAS_LOWP, HS, T * HS, att, CUBLAS_LOWP, T, T * T, beta_ptr, vaccum, CUBLAS_LOWP, HS, T * HS, B * NH, CUBLAS_LOWP_COMPUTE, CUBLAS_GEMM_DEFAULT)); // now unpermute // y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side num_blocks = ceil_div(B * T * C, block_size); if(!skip_permute || first_run_validation) { unpermute_kernel_lowp<<>>(vaccum, out, B, T, NH, HS); } } #ifdef ENABLE_CUDNN using graph_tensors_fwd = std::tuple, std::shared_ptr, // Q, std::shared_ptr, // K, 
std::shared_ptr, // V, std::shared_ptr, // Attn_scale, std::shared_ptr, // O std::shared_ptr>; // Stats // Need a cache because graph->build_operation_graph() is slow but everything else seems fast using cache_type_fwd = std::unordered_map; // Loosely based on cuDNN frontend samples functions and massively simplified template auto lookup_cache_or_build_graph_fwd(Args... args) { static cache_type_fwd user_maintained_cache_fwd; auto [B, H, T, HS, is_inference_only] = std::make_tuple(args...); auto graph = std::make_shared(); graph->set_io_data_type(CUDNN_16BIT) .set_intermediate_data_type(fe::DataType_t::FLOAT) .set_compute_data_type(fe::DataType_t::FLOAT); // QKV is (B, T, 3, NH, HS) which cuDNN can handle directly without an external permute auto Q = graph->tensor(fe::graph::Tensor_attributes() .set_name("Q") .set_dim({B, H, T, HS}) .set_stride({3 * H * HS * T, HS, 3 * H * HS, 1})); auto K = graph->tensor(fe::graph::Tensor_attributes() .set_name("K") .set_dim({B, H, T, HS}) .set_stride({3 * H * HS * T, HS, 3 * H * HS, 1})); auto V = graph->tensor(fe::graph::Tensor_attributes() .set_name("V") .set_dim({B, H, T, HS}) .set_stride({3 * H * HS * T, HS, 3 * H * HS, 1})); auto attn_scale = graph->tensor(fe::graph::Tensor_attributes() .set_name("attn_scale") .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_is_pass_by_value(true) .set_data_type(fe::DataType_t::FLOAT)); auto sdpa_options = fe::graph::SDPA_attributes().set_name("flash_attention"); sdpa_options.set_is_inference(is_inference_only); sdpa_options.set_attn_scale(attn_scale); sdpa_options.set_causal_mask(true); // Create the graph operation and get the output tensors back auto [O, stats] = graph->sdpa(Q, K, V, sdpa_options); // Output is (B, T, NH, HS) BF16/FP16 and stats for backward pass is (B, NH, T) FP32 O->set_output(true).set_dim({B, H, T, HS}).set_stride({H * HS * T, HS, H * HS, 1}); assert(stats == nullptr || is_inference_only == false); if (is_inference_only == false) { 
stats->set_output(true).set_data_type(fe::DataType_t::FLOAT) .set_dim({B, H, T, 1}) .set_stride({H * T, T, 1, 1}); } assert(graph->validate().is_good()); auto key = graph->key(); auto it = user_maintained_cache_fwd.find(key); if (it != user_maintained_cache_fwd.end()) { return it->second; } // Build the operation graph and execution part (this is the VERY SLOW PART) assert(graph->build_operation_graph(cudnn_handle).is_good()); auto plans = graph->create_execution_plans({fe::HeurMode_t::A}); assert(graph->check_support(cudnn_handle).is_good()); assert(graph->build_plans(cudnn_handle).is_good()); auto tuple = std::make_tuple(graph, Q, K, V, attn_scale, O, stats); user_maintained_cache_fwd.insert({key, tuple}); return tuple; } // Used on first run only so we can validate against the CPU results __global__ void fp32_to_lowp_kernel(floatX* out, const float* inp) { int idx = blockIdx.x * blockDim.x + threadIdx.x; out[idx] = (floatX)inp[idx]; } __global__ void lowp_to_fp32_kernel(const floatX* inp, float *out) { int idx = blockIdx.x * blockDim.x + threadIdx.x; out[idx] = (float)inp[idx]; } void attention_forward_cudnn(floatX* out, // output: (B, T, NH, HS) float* stats, // output for backward pass: (B, NH, T) floatX* inp, // input: (B, T, 3, NH, HS) QKV float* in_fp32, // fp32 input float* out_fp32, // fp32 output for validation int B, int T, int C, int NH) { static bool first_run_validation = true; int HS = C / NH; // number of features per head bool is_inference_only = (stats == nullptr); // Convert from FP32 to FP16/BF16 on 1st run to get correct results const int block_size = 64; // smallest full occupancy block size on modern GPUs if (first_run_validation) { int total_threads = B * T * C * 3; assert(total_threads % block_size == 0); int num_blocks = total_threads / block_size; fp32_to_lowp_kernel<<>>(inp, in_fp32); } // Get graph and tensors from cache (or generate it on first use) auto [graph, Q, K, V, attn_scale, O, softmax_stats] = 
lookup_cache_or_build_graph_fwd(B, NH, T, HS, is_inference_only); // Prepare all the tensor pointers for executing the graph void* devPtrQ = inp; void* devPtrK = (inp + C); void* devPtrV = (inp + 2 * C); float attn_scale_cpu = 1.0 / sqrtf(HS); void* devPtrO = out; // Build variant pack std::unordered_map, void*> variant_pack = { {Q, devPtrQ}, {K, devPtrK}, {V, devPtrV}, {attn_scale, &attn_scale_cpu}, {O, devPtrO}}; // Add the stats tensor unless we are only doing inference (only needed for backward pass) if (is_inference_only == false) { variant_pack[softmax_stats] = stats; } // Reallocate the workspace if the required size is greater than the current workspace // By default, cuDNN uses up to 256MiB of workspace, so we don't want to just allocate the maximum if (graph->get_workspace_size() > cudnn_workspace_size) { if (cudnn_workspace_size > 0) { cudaCheck(cudaFree(cudnn_workspace)); } cudnn_workspace_size = graph->get_workspace_size(); cudaCheck(cudaMalloc(&cudnn_workspace, cudnn_workspace_size)); } // Execute graph assert(graph->execute(cudnn_handle, variant_pack, cudnn_workspace).is_good()); cudaCheck(cudaGetLastError()); // Optionally convert back from FP16/BF16 to FP32 if (first_run_validation) { int total_threads = B * T * C; assert(total_threads % block_size == 0); int num_blocks = total_threads / block_size; lowp_to_fp32_kernel<<>>(out, out_fp32); } cudaCheck(cudaGetLastError()); first_run_validation = false; } #endif // ENABLE_CUDNN // kernel version dispatch void attention_forward(int kernel_num, float* out, float* stats, float* vaccum, float* qkvr, float* preatt, float* att, float* inp, int B, int T, int C, int NH, const int block_size) { switch (kernel_num) { case 1: attention_forward1(out, preatt, att, inp, B, T, C, NH, block_size); break; case 2: attention_forward2(out, inp, B, T, C, NH, block_size); break; case 3: attention_forward3(out, vaccum, qkvr, preatt, att, inp, B, T, C, NH, block_size); break; case 4: attention_forward4(out, vaccum, qkvr, 
preatt, att, inp, B, T, C, NH, block_size); break; case 5: attention_forward5(out, (floatX*)vaccum, (floatX*)qkvr, (floatX*)preatt, (floatX*)att, inp, B, T, C, NH, block_size, false); break; case 6: // skip permutes for perf passes (to analyse perf as if in/out were truly 16-bit) attention_forward5(out, (floatX*)vaccum, (floatX*)qkvr, (floatX*)preatt, (floatX*)att, inp, B, T, C, NH, block_size, true); break; #ifdef ENABLE_CUDNN case 10: // note: validation only cares about out, which is out_fp32 of the function // inp is hackily converted to FP16 into qkvr only on the first run // similarly, vaccum is converted to FP32 into out only on the first run attention_forward_cudnn((floatX*)vaccum, stats, (floatX*)qkvr, inp, out, B, T, C, NH); break; #endif default: printf("Invalid kernel number\n"); exit(1); } } // ---------------------------------------------------------------------------- int main(int argc, char **argv) { setup_main(); int B = 8; int T = 1024; int C = 768; int NH = 12; int deviceIdx = 0; cudaCheck(cudaSetDevice(deviceIdx)); cudaDeviceProp deviceProp; cudaGetDeviceProperties(&deviceProp, deviceIdx); // setup cuBLAS (and cuDNN if needed) cublasCreate(&cublas_handle); int enable_tf32 = deviceProp.major >= 8 ? 1 : 0; printf("enable_tf32: %d\n", enable_tf32); cublasMath_t cublas_math_mode = enable_tf32 ? 
CUBLAS_TF32_TENSOR_OP_MATH : CUBLAS_DEFAULT_MATH; cublasCheck(cublasSetMathMode(cublas_handle, cublas_math_mode)); #ifdef ENABLE_CUDNN checkCudnnErr(cudnnCreate(&cudnn_handle)); #endif // create host memory of random numbers float* out = (float*)malloc(B * T * C * sizeof(float)); float* preatt = (float*)malloc(B * NH * T * T * sizeof(float)); float* att = (float*)malloc(B * NH * T * T * sizeof(float)); //float* inp = make_random_float(B * T * 3 * C, 10.0f); float* inp = make_random_float(B * T * 3 * C); // move to GPU float* d_out; float* d_stats; // for cuDNN float* d_vaccum; float* d_qkvr; float* d_preatt; float* d_att; float* d_inp; cudaCheck(cudaMalloc(&d_out, B * T * C * sizeof(float))); cudaCheck(cudaMalloc(&d_stats, B * NH * T * sizeof(float))); cudaCheck(cudaMalloc(&d_vaccum, B * T * C * sizeof(float))); cudaCheck(cudaMalloc(&d_qkvr, B * T * 3 * C * sizeof(float))); cudaCheck(cudaMalloc(&d_preatt, B * NH * T * T * sizeof(float))); cudaCheck(cudaMalloc(&d_att, B * NH * T * T * sizeof(float))); cudaCheck(cudaMalloc(&d_inp, B * T * 3 * C * sizeof(float))); cudaCheck(cudaMemcpy(d_inp, inp, B * T * 3 * C * sizeof(float), cudaMemcpyHostToDevice)); // read kernel_num from command line int kernel_num = 1; if (argc > 1) { kernel_num = atoi(argv[1]); } printf("Using kernel %d\n", kernel_num); int block_sizes[] = {32, 64, 128, 256, 512}; // Lower accuracy requirements for FP16 (1e-4f also too much for TF32 on kernels 3 & 4) float accuracy_threshold = (kernel_num <= 4) ? 1e-3f : 1e-2f; // first check the correctness of the kernel attention_forward_cpu(out, preatt, att, inp, B, T, C, NH); for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) { int block_size = block_sizes[j]; printf("Checking block size %d.\n", block_size); attention_forward(kernel_num, d_out, d_stats, d_vaccum, d_qkvr, d_preatt, d_att, d_inp, B, T, C, NH, block_size); // all kernels should produce the correct output out // todo - make accuracy threshold dynamic and depend on FP16 vs FP32? 
validate_result(d_out, out, "out", B * T * C, accuracy_threshold); // but as for preatt and att, things get a bit more complicated: if (kernel_num != 2 && kernel_num < 5) { // kernel 2 (knowingly) fails att/preatt because it uses a different algorithm // that estimates the softmax online and never materializes preatt/att validate_result(d_att, att, "att", B * NH * T * T, accuracy_threshold); } if (kernel_num != 2 && kernel_num < 4) { // kernel 4 (knowingly) fails preatt because it fuses the scale normalization // into the softmax, so preatt is off by 1.0f / sqrt(HS) // but att and out (checked below) should match. validate_result(d_preatt, preatt, "preatt", B * NH * T * T, accuracy_threshold); } } printf("All results match. Starting benchmarks.\n\n"); first_run_validation = false; // benchmark speed of the kernel for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) { int block_size = block_sizes[j]; int repeat_times = 100; float elapsed_time = benchmark_kernel(repeat_times, attention_forward, kernel_num, d_out, d_stats, d_vaccum, d_qkvr, d_preatt, d_att, d_inp, B, T, C, NH, block_size); printf("block_size %4d | time %f ms\n", block_size, elapsed_time); } // free memory free(out); free(preatt); free(att); free(inp); cudaCheck(cudaFree(d_out)); cudaCheck(cudaFree(d_vaccum)); cudaCheck(cudaFree(d_qkvr)); cudaCheck(cudaFree(d_preatt)); cudaCheck(cudaFree(d_att)); cudaCheck(cudaFree(d_inp)); cudaCheck(cudaFree(d_stats)); cublasDestroy(cublas_handle); #ifdef ENABLE_CUDNN cudnnDestroy(cudnn_handle); if (cudnn_workspace_size > 0) { cudaCheck(cudaFree(cudnn_workspace)); } #endif return 0; } ================================================ FILE: dev/cuda/benchmark_on_modal.py ================================================ """ Script for running benchmarks on the Modal platform. This is useful for folks who do not have access to expensive GPUs locally. 
Example usage for CUDA kernels: GPU_MEM=80 modal run benchmark_on_modal.py \ --compile-command "nvcc -O3 --use_fast_math attention_forward.cu -o attention_forward -lcublas" \ --run-command "./attention_forward 1" Or, if you want to use cuDNN etc. (for example, to train the GPT-2 model with cuDNN), use: GPU_MEM=80 modal run dev/cuda/benchmark_on_modal.py \ --compile-command "make train_gpt2cu USE_CUDNN=1" --run-command "./train_gpt2cu -i dev/data/tinyshakespeare/tiny_shakespeare_train.bin -j dev/data/tinyshakespeare/tiny_shakespeare_val.bin -v 250 -s 250 -g 144 -f shakespeare.log -b 4" For profiling with Nsight Systems: GPU_MEM=80 modal run dev/cuda/benchmark_on_modal.py \ --compile-command "make train_gpt2cu USE_CUDNN=1" \ --run-command "nsys profile --cuda-graph-trace=graph --python-backtrace=cuda --cuda-memory-usage=true \ ./train_gpt2cu -i dev/data/tinyshakespeare/tiny_shakespeare_train.bin \ -j dev/data/tinyshakespeare/tiny_shakespeare_val.bin -v 250 -s 250 -g 144 -f shakespeare.log -b 4" For more nsys profiling details and command options, see: https://docs.nvidia.com/nsight-systems/2024.2/UserGuide/ To view the report in a GUI, download the NVIDIA Nsight Systems GUI (it runs on all major operating systems, so you can install it locally). NOTE: There is currently a bug when profiling with Nsight Systems that prints an 'unrecognized GPU UUID' error on the command line, but it does not actually interfere with model training and validation.
The report (that you download) is still generated and can be viewed from Nsight Systems """ import subprocess import os import sys import datetime import modal from modal import Image, Stub GPU_NAME_TO_MODAL_CLASS_MAP = { "H100": modal.gpu.H100, "A100": modal.gpu.A100, "A10G": modal.gpu.A10G, } N_GPUS = int(os.environ.get("N_GPUS", 1)) GPU_MEM = int(os.environ.get("GPU_MEM", 40)) GPU_NAME = os.environ.get("GPU_NAME", "A100") GPU_CONFIG = GPU_NAME_TO_MODAL_CLASS_MAP[GPU_NAME](count=N_GPUS, size=str(GPU_MEM) + 'GB') APP_NAME = "llm.c benchmark run" image = ( Image.from_registry("totallyvyom/cuda-env:latest-2") .pip_install("huggingface_hub==0.20.3", "hf-transfer==0.1.5") .env( dict( HUGGINGFACE_HUB_CACHE="/pretrained", HF_HUB_ENABLE_HF_TRANSFER="1", TQDM_DISABLE="true", ) ) .run_commands( "wget -q https://github.com/Kitware/CMake/releases/download/v3.28.1/cmake-3.28.1-Linux-x86_64.sh", "bash cmake-3.28.1-Linux-x86_64.sh --skip-license --prefix=/usr/local", "rm cmake-3.28.1-Linux-x86_64.sh", "ln -s /usr/local/bin/cmake /usr/bin/cmake",) .run_commands( "apt-get install -y --allow-change-held-packages libcudnn8 libcudnn8-dev", "apt-get install -y openmpi-bin openmpi-doc libopenmpi-dev kmod sudo", "git clone https://github.com/NVIDIA/cudnn-frontend.git /root/cudnn-frontend", "cd /root/cudnn-frontend && mkdir build && cd build && cmake .. 
&& make" ) .run_commands( "wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-ubuntu2204.pin && \ mv cuda-ubuntu2204.pin /etc/apt/preferences.d/cuda-repository-pin-600 && \ apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/3bf863cc.pub && \ add-apt-repository \"deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/ /\" && \ apt-get update" ).run_commands( "apt-get install -y nsight-systems-2023.3.3" ) ) stub = modal.App(APP_NAME) def execute_command(command: str): command_args = command.split(" ") print(f"{command_args = }") subprocess.run(command_args, stdout=sys.stdout, stderr=subprocess.STDOUT) @stub.function( gpu=GPU_CONFIG, image=image, allow_concurrent_inputs=4, container_idle_timeout=900, mounts=[modal.Mount.from_local_dir("./", remote_path="/root/")], # Instead of 'cuda-env' put your volume name that you create from 'modal volume create {volume-name}' # This enables the profiling reports to be saved on the volume that you can download by using: # 'modal volume get {volume-name} {/output_file_name} # For example right now, when profiling using this command "nsys profile --trace=cuda,nvtx --cuda-graph-trace=graph --python-backtrace=cuda --cuda-memory-usage=true" you would get your report # using in a directory in your volume, where the name contains the timestamp unique id. 
# This script will generate a "report1_{timestamp} folder in volume" # and you can download it with 'modal volume get {volume-name} report1_{timestamp} volumes={"/cuda-env": modal.Volume.from_name("cuda-env")}, ) def run_benchmark(compile_command: str, run_command: str): execute_command("pwd") execute_command("ls") execute_command(compile_command) execute_command(run_command) # Use this section if you want to profile using nsight system and install the reports on your volume to be locally downloaded timestamp = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") execute_command("mkdir report1_" + timestamp) execute_command("mv /root/report1.nsys-rep /root/report1_" + timestamp + "/") execute_command("mv /root/report1.qdstrm /root/report1_" + timestamp + "/") execute_command("mv /root/report1_" + timestamp + "/" + " /cuda-env/") return None @stub.local_entrypoint() def inference_main(compile_command: str, run_command: str): results = run_benchmark.remote(compile_command, run_command) return results ================================================ FILE: dev/cuda/classifier_fused.cu ================================================ /* Kernels for fused forward/backward classifier part This fuses softmax, crossentropy, and logit gradients into a single pass, so we don't have to write unnecessary (B, T, V) tensors. Such an operation is only possible if `dloss` can be known beforehand, which doesn't seem like much of a restriction: In pretraining, it is just a constant 1/batch_size tensor, for fine-tuning we might zero out the input prompt, but that is known in advance. 
Compile example: nvcc -O3 --use_fast_math -lcublas -lcublasLt classifier_fused.cu -o classifier_fused ./classifier_fused 1 ./classifier_fused 2 ./classifier_fused 3 ./classifier_fused 4 */ #include #include #include #include #include #include #include "common.h" // todo - this file does not properly support anything but FP32 // kernel 5 can be run in fp16/bf16 to test performance, but the outputs will be wrong #if defined(ENABLE_BF16) typedef __nv_bfloat16 floatX; #elif defined(ENABLE_FP16) typedef half floatX; #else typedef float floatX; #endif typedef Packed128 x128; // ---------------------------------------------------------------------------- // CPU code reference void softmax_forward_cpu(float* out, const float* inp, int N, int C) { // inp is (N, C) // out is (N, C), each row of inp will get softmaxed for (int64_t i = 0; i < N; i++) { const float* inp_row = inp + i * C; float* out_row = out + i * C; float maxval = -INFINITY; for (int j = 0; j < C; j++) { if (inp_row[j] > maxval) { maxval = inp_row[j]; } } double sum = 0.0; for (int j = 0; j < C; j++) { out_row[j] = expf(inp_row[j] - maxval); sum += out_row[j]; } for (int j = 0; j < C; j++) { out_row[j] /= sum; } } } void crossentropy_forward_cpu(float* losses, const float* probs, const int* targets, int B, int T, int V) { // output: losses is (B,T) of the individual losses at each position // input: probs are (B,T,V) of the probabilities // input: targets is (B,T) of integers giving the correct index in logits for (int64_t bt = 0; bt < B * T; bt++) { // loss = -log(probs[target]) const float* probs_bt = probs + bt * V; int ix = targets[bt]; losses[bt] = -logf(probs_bt[ix]); } } void crossentropy_softmax_backward_cpu(float* dlogits, const float* dlosses, const float* probs, const int* targets, int B, int T, int V) { // backwards through both softmax and crossentropy for (int64_t bt = 0; bt < B * T; bt++) { float* dlogits_bt = dlogits + bt * V; const float* probs_bt = probs + bt * V; float dloss = dlosses[bt]; 
int ix = targets[bt]; for (int i = 0; i < V; i++) { float p = probs_bt[i]; float indicator = i == ix ? 1.0f : 0.0f; dlogits_bt[i] = (p - indicator) * dloss; } } } // ---------------------------------------------------- // Kernel Utils // warp-level reduction for finding the maximum value __device__ float warpReduceMax(float val) { for (int offset = 16; offset > 0; offset /= 2) { val = fmaxf(val, __shfl_xor_sync(0xFFFFFFFF, val, offset)); } return val; } // ---------------------------------------------------------------------------- // GPU kernels struct SoftmaxParams { float Scale; float Offset; }; namespace cg = cooperative_groups; __device__ SoftmaxParams prepare_softmax(cg::thread_block_tile<32>& warp, int64_t idx, const float* inp, int V, int P) { // this warp (of 32) threads processes one row of inp, i.e. inp[idx, :] of shape (V,) // note that inp is actually (B * T, P) but we only use the first V elements // this function then calculates: // 1) the max value to subtract for numerical stability and // 2) the sum normalization factor const float* x = inp + idx * P; // thread coarsening loop, where the 32 threads serially process all V elements // thread_rank() is in [0, 31], warp.size() is 32 float maxval = -INFINITY; float sumval = 0.0f; for (int i = warp.thread_rank(); i < V; i += warp.size()) { float v = x[i]; float old_maxval = maxval; // online softmax recurrence from "Online normalizer calculation for softmax" paper maxval = fmaxf(maxval, v); sumval *= expf((old_maxval - maxval)); sumval += expf(v - maxval); } // warp-level reduction to get the maxval across the 32 threads float global_maxval = cg::reduce(warp, maxval, cg::greater{}); // all 32 threads do a final shift of the sum considering the global max in this row sumval *= expf((maxval - global_maxval)); // warp-level reduction to get the sumval across the 32 threads float global_sumval = cg::reduce(warp, sumval, cg::plus{}); // the final normalization factor float norm = 1.0f / global_sumval; return 
SoftmaxParams{norm, global_maxval}; } __global__ void fused_classifier_kernel1(float* dlogits, float* losses, const float* logits, const float* dlosses, const int* targets, int B, int T, int V, int P) { namespace cg = cooperative_groups; cg::thread_block block = cg::this_thread_block(); cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block); // example: B = 4, T = 1024, block_size = 128 => we'd have grid_size = 1024 // each block of 4 warps is in charge of 4 rows of the input, one warp per row // meta_group_size is the number of warps per block (e.g. 4) // meta_group_rank is the index of the warp in the block (e.g. 0, 1, 2, 3) int64_t idx = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank(); if (idx >= B * T) { // there are B * T rows in the input return; } int b = idx / T; int t = idx % T; // calculate the offset (maxval) and scale (sumval) for the softmax SoftmaxParams sp = prepare_softmax(warp, idx, logits, V, P); // in each row (handled by one warp), thread 0 calculates the loss // calculate the probability needed for the loss and update losses if(warp.thread_rank() == 0) { int ix = targets[b * T + t]; float prob = expf(logits[idx * P + ix] - sp.Offset) * sp.Scale; losses[b * T + t] = -logf(prob); } // finally all threads calculate the gradients // prob is only materialized here temporarily and in registers, never // as a full tensor that gets written to global memory for (int i = warp.thread_rank(); i < V; i += warp.size()) { float prob = expf(logits[idx * P + i] - sp.Offset) * sp.Scale; float* dlogits_bt = dlogits + b * T * P + t * P; float dloss = dlosses[b * T + t]; int ix = targets[b * T + t]; float indicator = i == ix ? 1.0f : 0.0f; dlogits_bt[i] = (prob - indicator) * dloss; } } __device__ float vec_at(const float4& vec, int index) { return reinterpret_cast(&vec)[index]; } __device__ SoftmaxParams prepare_softmax_blockwide(cg::thread_block_tile<32>& warp, int64_t idx, const float* inp, int V, int P) { // one row of inp, i.e. 
inp[idx, :] of shape (V,) // float4 to get 128-bit loads and memory level parallelism const float4* x_vec4 = reinterpret_cast(inp + idx * P); float thread_maxval = -INFINITY; float thread_sumval = 0.0f; // do the loop in reverse to maximise probability of L2 cache hits // so even small L2s get some hits on the 2nd read of the same thread for (int i = ceil_div(V, 4) + threadIdx.x - blockDim.x; i >= 0; i -= blockDim.x) { float4 v4 = x_vec4[i]; #pragma unroll for(int k = 0; k < 4; k++) { if (i*4+k >= V) { // bounds checking against real V continue; } float old_maxval = thread_maxval; thread_maxval = fmaxf(thread_maxval, vec_at(v4, k)); thread_sumval *= expf(old_maxval - thread_maxval); thread_sumval += expf(vec_at(v4, k) - thread_maxval); } } // two reductions of up to 1024 threads: // 1) inside warp (shuffle), 2) cross-warp (shared memory), 3) inside warp (shuffle) // this results in much cleaner assembly than a multi-warp cg::reduce __shared__ float shared_maxval[32]; __shared__ float shared_sumval[32]; int num_warps = blockDim.x / 32; int warp_id = threadIdx.x / 32; int lane_id = threadIdx.x % 32; // reduce maxval within each warp float warp_maxval = cg::reduce(warp, thread_maxval, cg::greater{}); // thread 0 in each warp writes to shared memory if (lane_id == 0) { shared_maxval[warp_id] = warp_maxval; } __syncthreads(); // each thread now loads the maxval across previous warps // if the thread is "out of range" of data, use -FLT_MAX as the maxval warp_maxval = (lane_id < num_warps) ? 
shared_maxval[lane_id] : -FLT_MAX; // now reduce the maxval among the warp threads float block_maxval = cg::reduce(warp, warp_maxval, cg::greater{}); // each thread uses maxval to scale sumval to avoid numerical instability / overflow thread_sumval *= expf(thread_maxval - block_maxval); // (warp-level) reduce sumval, thread 0 in each warp saves result in shared memory float warp_sumval = cg::reduce(warp, thread_sumval, cg::plus{}); if (lane_id == 0) { shared_sumval[warp_id] = warp_sumval; } __syncthreads(); // same strategy, now reduce sumval across warps warp_sumval = (lane_id < num_warps) ? shared_sumval[lane_id] : 0.0f; float block_sumval = cg::reduce(warp, warp_sumval, cg::plus{}); // return the softmax parameters return SoftmaxParams{1.f / block_sumval, block_maxval}; } // Fused forward and backward pass for classifier including softmax, and logit gradients // Writes to both probs (only for debugging) and dlogits (only for training) are optional // N.B.: We may want to reuse the logits memory for dlogits, so they should *not* be __restrict__! __global__ void fused_classifier_kernel2(float* dlogits, float* losses, float* probs, const float* logits, const float* dlosses, const int* targets, int B, int T, int V, int P) { namespace cg = cooperative_groups; cg::thread_block block = cg::this_thread_block(); cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block); int64_t idx = blockIdx.x; int ix = targets[idx]; // softmax (reading B * T * V, same logits read again below, hopefully still in cache) SoftmaxParams sp = prepare_softmax_blockwide(warp, idx, logits, V, P); // calculate the probability needed for the loss and update (single-threaded) if(threadIdx.x == 0) { float prob = expf(logits[idx * P + ix] - sp.Offset) * sp.Scale; losses[idx] = -logf(prob); } // very sensible default for dlosses is 1/(B*T), which is the uniform loss float dloss = dlosses != NULL ? 
dlosses[idx] : 1.0f / (B*T); // calculate the gradients directly, saves bandwidth from probs during training // but also supports writing probs for inference-only and debugging const float4* logits_vec4 = reinterpret_cast(logits + idx * P); for (int i = threadIdx.x; i < ceil_div(V, 4); i += blockDim.x) { // this is the 2nd read of logits after the one in prepare_softmax2 // this data will never be needed again, so we reduce cache persistence float4 v4 = __ldcs(&logits_vec4[i]); #pragma unroll for(int k = 0; k < 4; ++k) { int element = i*4 + k; float prob = expf(vec_at(v4, k) - sp.Offset) * sp.Scale; prob = (element < V) ? prob : 0.0f; // bounds checking against real V // this kernel is DRAM limited so cost of inner branch is ~zero if (probs != NULL) { probs[idx * P + element] = prob; } if (dlogits != NULL) { float indicator = element == ix ? 1.0f : 0.0f; dlogits[idx * P + element] = (prob - indicator) * dloss; } } } } __device__ SoftmaxParams prepare_softmax_blockwide_nofloat4(cg::thread_block_tile<32>& warp, int64_t idx, const float* inp, int V, int P) { // same but not float4 // one row of inp, i.e. 
// inp[idx, :] of shape (V,)
    const float* x = inp + idx * P;
    float thread_maxval = -INFINITY;
    float thread_sumval = 0.0f;
    // do the loop in reverse to maximise probability of L2 cache hits
    // so even small L2s get some hits on the 2nd read of the same thread
    for (int i = V + threadIdx.x - blockDim.x; i >= 0; i -= blockDim.x) {
        float v = x[i];
        // online update: when the running max grows, rescale the running sum to the new max
        float old_maxval = thread_maxval;
        thread_maxval = fmaxf(thread_maxval, v);
        thread_sumval *= expf(old_maxval - thread_maxval);
        thread_sumval += expf(v - thread_maxval);
    }

    // two reductions of up to 1024 threads:
    // 1) inside warp (shuffle), 2) cross-warp (shared memory), 3) inside warp (shuffle)
    // this results in much cleaner assembly than a multi-warp cg::reduce
    __shared__ float shared_maxval[32];
    __shared__ float shared_sumval[32];
    int num_warps = blockDim.x / 32;
    int warp_id = threadIdx.x / 32;
    int lane_id = threadIdx.x % 32;

    // reduce maxval within each warp
    float warp_maxval = cg::reduce(warp, thread_maxval, cg::greater{});
    // thread 0 in each warp writes to shared memory
    if (lane_id == 0) { shared_maxval[warp_id] = warp_maxval; }
    __syncthreads();
    // each thread now loads the maxval across previous warps
    // if the thread is "out of range" of data, use -FLT_MAX as the maxval
    warp_maxval = (lane_id < num_warps) ? shared_maxval[lane_id] : -FLT_MAX;
    // now reduce the maxval among the warp threads
    float block_maxval = cg::reduce(warp, warp_maxval, cg::greater{});
    // each thread uses maxval to scale sumval to avoid numerical instability / overflow
    thread_sumval *= expf(thread_maxval - block_maxval);
    // (warp-level) reduce sumval, thread 0 in each warp saves result in shared memory
    float warp_sumval = cg::reduce(warp, thread_sumval, cg::plus{});
    if (lane_id == 0) { shared_sumval[warp_id] = warp_sumval; }
    __syncthreads();
    // same strategy, now reduce sumval across warps
    warp_sumval = (lane_id < num_warps) ? shared_sumval[lane_id] : 0.0f;
    float block_sumval = cg::reduce(warp, warp_sumval, cg::plus{});
    // return the softmax parameters
    return SoftmaxParams{1.f / block_sumval, block_maxval};
}

// same as 2 but not using float4
// Only the V real entries of row idx are read and written (no padding entries are touched).
// NOTE(review): if dlogits aliases logits, thread 0's read of logits[idx * P + ix] for the loss is
// not separated by a barrier from other threads' dlogits writes in the loop below.
__global__ void fused_classifier_kernel3(float* dlogits, float* losses, float* probs,
                                         const float* logits, const float* dlosses, const int* targets,
                                         int B, int T, int V, int P) {
    namespace cg = cooperative_groups;
    cg::thread_block block = cg::this_thread_block();
    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
    int64_t idx = blockIdx.x;
    int ix = targets[idx];

    // softmax (reading B * T * V, same logits read again below, hopefully still in cache)
    SoftmaxParams sp = prepare_softmax_blockwide_nofloat4(warp, idx, logits, V, P);

    // calculate the probability needed for the loss and update (single-threaded)
    if(threadIdx.x == 0) {
        float prob = expf(logits[idx * P + ix] - sp.Offset) * sp.Scale;
        losses[idx] = -logf(prob);
    }

    // very sensible default for dlosses is 1/(B*T), which is the uniform loss
    float dloss = dlosses != NULL ? dlosses[idx] : 1.0f / (B*T);

    // calculate the gradients directly, saves bandwidth from probs during training
    // but also supports writing probs for inference-only and debugging
    const float* logits_vec = logits + idx * P;
    for (int i = threadIdx.x; i < V; i += blockDim.x) {
        // this is the 2nd read of logits after the one in prepare_softmax2
        // this data will never be needed again, so we reduce cache persistence
        float v = __ldcs(&logits_vec[i]);
        float prob = expf(v - sp.Offset) * sp.Scale;
        if (probs != NULL) {
            probs[idx * P + i] = prob;
        }
        if (dlogits != NULL) {
            float indicator = (i == ix) ? 1.0f : 0.0f;
            dlogits[idx * P + i] = (prob - indicator) * dloss;
        }
    }
}

// floatX variant: each thread loads x128::size elements at a time (load128cs) and skips lanes
// whose index is >= V; returns SoftmaxParams{1 / sum(exp(x - max)), max} for row idx.
__device__ SoftmaxParams prepare_softmax_blockwide2(int64_t idx, const floatX* inp, int V, int P) {
    // one row of inp, i.e.
// inp[idx, :] of shape (V,)
    const floatX* x = inp + idx * P;
    float thread_maxval = -INFINITY;
    float thread_sumval = 0.0f;
    // do the loop in reverse to maximise probability of L2 cache hits
    // so even small L2s get some hits on the 2nd read of the same thread
    for (int i = ceil_div(V, x128::size) + threadIdx.x - blockDim.x; i >= 0; i -= blockDim.x) {
        x128 packed_x = load128cs(x + i * x128::size); // load and do not keep in cache
        for(int k = 0; k < packed_x.size; ++k) {
            if (i*x128::size+k >= V) {  // bounds checking against real V
                continue;
            }
            float v = (float)packed_x[k];
            float old_maxval = thread_maxval;
            thread_maxval = fmaxf(thread_maxval, v);
            thread_sumval *= expf(old_maxval - thread_maxval);
            thread_sumval += expf(v - thread_maxval);
        }
    }

    // two reductions of up to 1024 threads:
    // 1) inside warp (shuffle), 2) cross-warp (shared memory), 3) inside warp (shuffle)
    // this results in much cleaner assembly than a multi-warp cg::reduce
    __shared__ float shared_maxval[32];
    __shared__ float shared_sumval[32];
    int num_warps = blockDim.x / 32;
    int warp_id = threadIdx.x / 32;
    int lane_id = threadIdx.x % 32;

    // reduce maxval within each warp
    float warp_maxval = warpReduceMax(thread_maxval);
    // thread 0 in each warp writes to shared memory
    if (lane_id == 0) { shared_maxval[warp_id] = warp_maxval; }
    __syncthreads();
    // each thread now loads the maxval across previous warps
    // if the thread is "out of range" of data, use -FLT_MAX as the maxval
    warp_maxval = (lane_id < num_warps) ? shared_maxval[lane_id] : -FLT_MAX;
    // now reduce the maxval among the warp threads
    float block_maxval = warpReduceMax(warp_maxval);
    // each thread uses maxval to scale sumval to avoid numerical instability / overflow
    thread_sumval *= expf(thread_maxval - block_maxval);
    // (warp-level) reduce sumval, thread 0 in each warp saves result in shared memory
    float warp_sumval = warpReduceSum(thread_sumval); //cg::reduce(warp, thread_sumval, cg::plus{});
    if (lane_id == 0) { shared_sumval[warp_id] = warp_sumval; }
    __syncthreads();
    // same strategy, now reduce sumval across warps
    warp_sumval = (lane_id < num_warps) ? shared_sumval[lane_id] : 0.0f;
    float block_sumval = warpReduceSum(warp_sumval); //cg::reduce(warp, warp_sumval, cg::plus{});
    // return the softmax parameters
    return SoftmaxParams{1.f / block_sumval, block_maxval};
}

// same as 2 but using x128
// NOTE(review): unlike kernels 2/3, dlogits is stored unconditionally (no NULL check). Lanes with
// element >= V are skipped without being assigned, and Packed128's defaulted constructor leaves the
// payload uninitialized, so the stores write unspecified values into the padding columns
// [V, ceil_div(V, x128::size) * x128::size) of each row.
// NOTE(review): if dlogits aliases logits, thread 0's read of logits[idx * P + ix] for the loss is
// not separated by a barrier from other threads' dlogits stores below.
__global__ void fused_classifier_kernel4(floatX* dlogits, floatX* losses, floatX* probs,
                                         const floatX* logits, const floatX* dlosses, const int* targets,
                                         int B, int T, int V, int P) {
    int64_t idx = blockIdx.x;
    int ix = targets[idx];

    // softmax (reading B * T * V, same logits read again below, hopefully still in cache)
    SoftmaxParams sp = prepare_softmax_blockwide2(idx, logits, V, P);

    // calculate the probability needed for the loss and update (single-threaded)
    if(threadIdx.x == 0) {
        float prob = expf((float)logits[idx * P + ix] - sp.Offset) * sp.Scale;
        losses[idx] = -logf(prob);
    }

    // very sensible default for dlosses is 1/(B*T), which is the uniform loss
    float dloss = dlosses != NULL ? (float)dlosses[idx] : 1.0f / (B*T);

    // calculate the gradients directly, saves bandwidth from probs during training
    // but also supports writing probs for inference-only and debugging
    const floatX* logits_vec = logits + idx * P;
    for (int i = threadIdx.x; i < ceil_div(V , x128::size); i += blockDim.x) {
        // this is the 2nd read of logits after the one in prepare_softmax2
        // this data will never be needed again, so we reduce cache persistence
        x128 packed_logits_vec = load128cs(logits_vec + i * x128::size); // load and do not keep in cache
        x128 packed_probs;
        x128 packed_dlogits;
        for(int k = 0; k < packed_logits_vec.size; ++k) {
            int element = i*packed_logits_vec.size + k;
            if (element >= V) {  // bounds checking against real V
                continue;
            }
            float v = packed_logits_vec[k];
            float prob = expf(v - sp.Offset) * sp.Scale;
            packed_probs[k] = prob;
            float indicator = (element == ix) ? 1.0f : 0.0f;
            packed_dlogits[k] = (prob - indicator) * dloss;
        }
        // Note: missing .cs hint hurts our performance due to cache thrashing, fixed in kernel5
        store128(dlogits + idx * P + i * packed_logits_vec.size, packed_dlogits);
        if (probs != NULL) {
            store128(probs + idx * P + i * packed_logits_vec.size, packed_probs);
        }
    }
}

// x128 variant used by kernel5: the partial x128 chunk at the end of the row is processed
// element-wise first, so the main loop over full chunks needs no bounds checks.
__device__ SoftmaxParams prepare_softmax_blockwide3(int64_t idx, const floatX* inp, int V, int P) {
    // same but not float4
    // one row of inp, i.e.
// inp[idx, :] of shape (V,)
    const floatX* x = inp + idx * P;
    float thread_maxval = -INFINITY;
    float thread_sumval = 0.0f;
    int i = (V+x128::size-1)/x128::size + threadIdx.x - blockDim.x;

    // special-case loop to handle the unaligned elements at the end of the array
    // this lets us skip the bounds check in the main loop below, which improves performance
    while ((i+1)*x128::size > V) {
        for(int k = 0; k < x128::size; ++k) {
            if (i*x128::size+k >= V) {
                break; // bounds checking against real V (rather than padded P)
            }
            float v = (float)x[i*x128::size+k];
            float old_maxval = thread_maxval;
            thread_maxval = fmaxf(thread_maxval, v);
            thread_sumval *= expf((old_maxval - thread_maxval));
            thread_sumval += expf(v - thread_maxval);
        }
        i -= blockDim.x;
    }

    // main loop for the bulk of the iterations (no bounds checking required!)
    for (; i >= 0; i -= blockDim.x) {
        x128 packed_x = load128(x + i * x128::size); // load and keep in cache until fused_classifier loop
        for(int k = 0; k < x128::size; ++k) {
            float v = (float)packed_x[k];
            float old_maxval = thread_maxval;
            thread_maxval = fmaxf(thread_maxval, v);
            thread_sumval *= expf((old_maxval - thread_maxval));
            thread_sumval += expf(v - thread_maxval);
        }
    }

    // Block Max Reduction -> Maths -> Block Sum Reduction
    float block_maxval = blockReduce(thread_maxval, false, -FLT_MAX);
    thread_sumval *= expf(thread_maxval - block_maxval);
    float block_sumval = blockReduce(thread_sumval);

    // return the softmax parameters
    return SoftmaxParams{1.f / block_sumval, block_maxval};
}

// will _update_ logits to logit gradients
// uses template to decide whether to write logits and probs
// split both loops in "multiple-of-x128-size" and "bounds-checked remainder" parts
//
// One thread block per row idx = blockIdx.x (row stride P, V real entries). Writes
// losses[idx] = -logf(prob[targets[idx]]); when WriteLogits, dlogits[idx, i] =
// (prob[i] - (i == targets[idx])) * dloss for i < V; when WriteProbs, probs[idx, i] = prob[i].
// dlogits is allowed to alias logits (in-place update), so every element of the row must be
// read before it is overwritten, and each element is owned by exactly one thread.
template __global__ void __launch_bounds__(1024, MAX_1024_THREADS_BLOCKS) fused_classifier_kernel5(floatX* dlogits, floatX* losses, floatX* probs,
                                         const floatX* logits, const floatX* dlosses, const int* targets,
                                         int B, int T, int V, int P) {
    int64_t idx = blockIdx.x;
    int ix = targets[idx];

    // softmax (reading B * T * V, same logits read again below, hopefully still in cache)
    SoftmaxParams sp = prepare_softmax_blockwide3(idx, logits, V, P);

    // calculate the probability needed for the loss and update (single-threaded)
    if(threadIdx.x == 0) {
        float prob = expf((float)logits[idx * P + ix] - sp.Offset) * sp.Scale;
        losses[idx] = (floatX)(-logf(prob));
    }
    // Nothing after prepare_softmax_blockwide3's final reduction synchronizes the block, so when
    // dlogits aliases logits another thread could overwrite logits[idx * P + ix] with its gradient
    // before thread 0 has read it above. Wait for thread 0 before any gradient is stored.
    __syncthreads();

    // very sensible default for dlosses is 1/(B*T), which is the uniform loss
    float dloss = (dlosses != NULL) ? (float)dlosses[idx] : 1.0f / (B*T);

    // calculate the gradients directly, saves bandwidth from probs during training
    // but also supports writing probs for inference-only and debugging
    const floatX* logits_vec = logits + idx * P;
    for (int i = threadIdx.x; i < V/x128::size; i += blockDim.x) {
        // this is the 2nd read of logits after the one in prepare_softmax2
        // it will be overwritten by the logits gradients which is when we reduce cache persistence
        x128 packed_logits_vec = load128(logits_vec + i * x128::size); // rely on cs of store128cs
        x128 packed_probs;
        for(int k = 0; k < x128::size; ++k) {
            int element = i*x128::size + k;
            float prob = expf((float)packed_logits_vec[k] - sp.Offset) * sp.Scale;
            packed_probs[k] = (floatX)prob;
            float indicator = (element == ix) ? 1.0f : 0.0f;
            packed_logits_vec[k] = (floatX)((prob - indicator) * dloss);
        }
        if (WriteLogits){
            // reduce cache persistence for the overwritten logits
            // to maximise probability that logits remain in cache between prepare_softmax and here
            store128cs(dlogits + idx * P + i * x128::size, packed_logits_vec);
        }
        if (WriteProbs) {
            store128(probs + idx * P + i * x128::size, packed_probs);
        }
    }

    // handle remaining elements after the last multiple of x128::size
    // e.g. if V = 8003, and x128::size = 8, we need to handle the last 3 elements
    // (the mask trick is valid because x128::size = 16 / sizeof(floatX) is a power of two)
    int unaligned_start = V & ~(x128::size - 1); // round down to multiple of x128::size
    // stride by blockDim.x so that each remainder element is owned by exactly one thread; a unit
    // stride makes several threads process the same elements, which is redundant work and, when
    // dlogits aliases logits, lets a thread read a logit another thread has already overwritten
    for (int i = threadIdx.x + unaligned_start; i < V; i += blockDim.x) {
        float prob = expf((float)logits_vec[i] - sp.Offset) * sp.Scale;
        float indicator = (i == ix) ? 1.0f : 0.0f;
        float dlogit = (prob - indicator) * dloss;
        if (WriteLogits){
            __stcs(dlogits + idx * P + i, (floatX)dlogit);
        }
        if (WriteProbs) {
            probs[idx * P + i] = (floatX)prob;
        }
    }
}

// ----------------------------------------------------------------------------
// kernel launcher

void fused_classifier1(float* dlogits, float* losses,
                      const float* logits, const float* dlosses, const int* targets,
                      int B, int T, int V, int P, int block_size) {
    const int N = B * T; // total number of rows in the input
    // how many rows of the input can each block of threads process?
    // e.g. in block_size=128, 4 rows get handled by 4 warps (of 32 threads each)
    const int rows_per_block = block_size / 32;
    const int grid_size = N / rows_per_block; // total number of blocks needed
    // NOTE(review): the integer division drops the last N % rows_per_block rows. That is fine for
    // the sizes main() uses (N = 8192, block_size a power of two >= 32); confirm before changing them.
    fused_classifier_kernel1<<>>(dlogits, losses, logits, dlosses, targets, B, T, V, P);
    cudaCheck(cudaGetLastError());
}

// one thread block per row (grid_size = B * T); probs are not written (NULL is passed)
void fused_classifier2(float* dlogits, float* losses,
                      const float* logits, const float* dlosses, const int* targets,
                      int B, int T, int V, int P, int block_size) {
    const int N = B * T;
    const int grid_size = N;
    fused_classifier_kernel2<<>>(dlogits, losses, NULL, logits, dlosses, targets, B, T, V, P);
    cudaCheck(cudaGetLastError());
}

// one thread block per row (grid_size = B * T); probs are not written (NULL is passed)
void fused_classifier3(float* dlogits, float* losses,
                      const float* logits, const float* dlosses, const int* targets,
                      int B, int T, int V, int P, int block_size) {
    const int N = B * T;
    const int grid_size = N;
    fused_classifier_kernel3<<>>(dlogits, losses, NULL, logits, dlosses, targets, B, T, V, P);
    cudaCheck(cudaGetLastError());
}

void fused_classifier4(float* dlogits, float* losses, const float* logits, const float* dlosses, const int* targets, int B, int T, int V, int
P, int block_size) {
    const int N = B * T;
    const int grid_size = N;
    fused_classifier_kernel4<<>>((floatX*)dlogits, (floatX*)losses, NULL, (floatX*)logits, (floatX*)dlosses, targets, B, T, V, P);
    cudaCheck(cudaGetLastError());
}

// one thread block per row (grid_size = B * T); probs are not written (NULL is passed)
void fused_classifier5(float* dlogits, float* losses,
                      const float* logits, const float* dlosses, const int* targets,
                      int B, int T, int V, int P, int block_size) {
    const int N = B * T;
    const int grid_size = N;
    fused_classifier_kernel5<<>>((floatX*)dlogits, (floatX*)losses, NULL, (floatX*)logits, (floatX*)dlosses, targets, B, T, V, P);
    cudaCheck(cudaGetLastError());
}

// Dispatch to the launcher selected by kernel_num (1..5); any other value prints an error and
// terminates the process with exit(1).
void fused_classifier(int kernel_num, float* dlogits, float* losses,
                      const float* logits, const float* dlosses, const int* targets,
                      int B, int T, int V, int P, int block_size) {
    switch (kernel_num) {
        case 1:
            fused_classifier1(dlogits, losses, logits, dlosses, targets, B, T, V, P, block_size);
            break;
        case 2:
            fused_classifier2(dlogits, losses, logits, dlosses, targets, B, T, V, P, block_size);
            break;
        case 3:
            fused_classifier3(dlogits, losses, logits, dlosses, targets, B, T, V, P, block_size);
            break;
        case 4:
            fused_classifier4(dlogits, losses, logits, dlosses, targets, B, T, V, P, block_size);
            break;
        case 5:
            fused_classifier5(dlogits, losses, logits, dlosses, targets, B, T, V, P, block_size);
            break;
        default:
            printf("Invalid kernel number\n");
            exit(1);
    }
}

// ----------------------------------------------------------------------------

// Test driver.
// - Builds random logits, boosting 3 random entries per row by 20x.
// - Computes reference losses and dlogits on the CPU.
// - Checks the kernel selected by argv[1] (default 1) at every block size.
//   The check is skipped for kernel_num >= 4 in BF16/FP16 builds.
// - Benchmarks the kernel at every block size.
int main(int argc, char **argv) {
    srand(0);

    int64_t B = 8;              // batch size
    int64_t T = 1024;           // sequence length
    int64_t V = 50257;          // vocab size
    int64_t P = (V + 63) & ~63; // padded vocab size, up to nearest multiple of 64

    int deviceIdx = 0;
    cudaCheck(cudaSetDevice(deviceIdx));

    // create host memory of random numbers
    float* logits = make_random_float(B * T * V);
    float* probs = make_random_float_01(B * T * V);
    float* dlogits = (float*)malloc(B * T * V * sizeof(float));
    float* losses = (float*)malloc(B * T * sizeof(float));
    float* dlosses = make_random_float(B * T);
    int* targets = make_random_int(B * T, V);
    // make the input less uniformly random: Otherwise, all probabilities will be basically zero,
    // and the tests are not actually meaningful.
    int* outliers = make_random_int(B * T * 3, V);
    for(int k = 0; k < 3; ++k) {
        for(int j = 0; j < B * T; ++j) {
            logits[j * V + outliers[j*3 + k]] *= 20;
        }
    }

    // move to GPU
    int *d_targets;
    float *d_logits, *d_losses;
    float *d_dlogits, *d_dlosses, *d_dlogits_no_pad;
    cudaCheck(cudaMalloc(&d_dlogits, B * T * P * sizeof(float)));
    cudaCheck(cudaMalloc(&d_logits, B * T * P * sizeof(float)));
    cudaCheck(cudaMalloc(&d_dlogits_no_pad, B * T * V * sizeof(float)));
    cudaCheck(cudaMalloc(&d_targets, B * T * sizeof(int)));
    cudaCheck(cudaMalloc(&d_losses, B * T * sizeof(float)));
    cudaCheck(cudaMalloc(&d_dlosses, B * T * sizeof(float)));

    // move to GPU
    // First fill every byte of the padded logits with 0xff; as a float32 this bit pattern is a NaN,
    // presumably so that reads of the padding are noticeable. Then copy the V real columns of each
    // row into the first V of the P padded columns.
    cudaCheck(cudaMemset(d_logits, 0xff, B * T * P * sizeof(float)));
    cudaCheck(cudaMemcpy2D(d_logits, P * sizeof(float), logits, V * sizeof(float), V * sizeof(float), B * T, cudaMemcpyHostToDevice));
    cudaCheck(cudaMemcpy(d_dlosses, dlosses, B * T * sizeof(float), cudaMemcpyHostToDevice));
    cudaCheck(cudaMemcpy(d_targets, targets, B * T * sizeof(int), cudaMemcpyHostToDevice));

    // read kernel_num from command line
    int kernel_num = 1;
    if (argc > 1) {
        kernel_num = atoi(argv[1]);
    }
    printf("Using kernel %d\n", kernel_num);

    // define block sizes we'll use in correctness and timing
    int block_sizes[] = {32, 64, 128, 256, 512, 1024};

    // first check the correctness of the kernel
    softmax_forward_cpu(probs, logits, B * T, V);
    crossentropy_forward_cpu(losses, probs, targets, B, T, V);
    crossentropy_softmax_backward_cpu(dlogits, dlosses, probs, targets, B, T, V);

#if defined(ENABLE_BF16) || defined(ENABLE_FP16)
    if (kernel_num < 4) // kernel 4/5 + BF16 is only for testing performance, it doesn't do the format conversions yet etc...
#endif
    {
        // check the kernel at different block sizes
        for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) {
            int block_size = block_sizes[j];
            printf("Checking block size %d.\n", block_size);
            fused_classifier(kernel_num, d_dlogits, d_losses, d_logits, d_dlosses, d_targets, B, T, V, P, block_size);
            validate_result(d_losses, losses, "losses", B * T, 1e-4f);
            // undo the padding before we can check for correctness
            cudaCheck(cudaMemcpy2D(d_dlogits_no_pad, V * sizeof(float), d_dlogits, P * sizeof(float), V * sizeof(float), B * T, cudaMemcpyDeviceToDevice));
            validate_result(d_dlogits_no_pad, dlogits, "dlogits", B * T * V, 1e-4f);
        }
        printf("All results match. Starting benchmarks.\n\n");
    }

    for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) {
        int block_size = block_sizes[j];

        int repeat_times = 1000;
        float elapsed_time = benchmark_kernel(repeat_times, fused_classifier,
                                              kernel_num, d_dlogits, d_losses, d_logits, d_dlosses, d_targets,
                                              B, T, V, P, block_size);

        printf("block_size %4d | time %f ms\n", block_size, elapsed_time);
    }

    // free memory
    free(logits);
    free(probs);
    free(dlogits);
    free(losses);
    free(dlosses);
    free(targets);
    free(outliers);
    cudaCheck(cudaFree(d_dlogits));
    cudaCheck(cudaFree(d_losses));
    cudaCheck(cudaFree(d_logits));
    cudaCheck(cudaFree(d_dlosses));
    cudaCheck(cudaFree(d_targets));
    cudaCheck(cudaFree(d_dlogits_no_pad));

    return 0;
}
================================================
FILE: dev/cuda/common.h
================================================
#include
#include
#include
#include
#include
#include

#define WARP_SIZE 32U
extern cudaDeviceProp deviceProp;

// integer ceiling division: smallest q with q * divisor >= dividend (for positive operands)
template __host__ __device__ T ceil_div(T dividend, T divisor) {
    return (dividend + divisor-1) / divisor;
}

// xor-shuffle (butterfly) sum over the 32 lanes of a warp; with the full mask every lane
// ends up holding the total. Requires all 32 lanes to participate.
__device__ float warpReduceSum(float val) {
    for (int offset = 16; offset > 0; offset /= 2) {
        val += __shfl_xor_sync(0xFFFFFFFF, val, offset);
    }
    return val;
}

// requires all 32 threads in the warp to be active, but should work for any block size
// uses non-dynamic shared memory so every
// call increases shared memory requirements by 128 bytes
// the fact it's unique shared memory allows us to avoid an extra __syncthreads() call at the end
// but if called inside a loop, the shared memory will be implicitly reused, so set final_sync to 1
// NOTE(review): num_warps = blockDim.x / WARP_SIZE rounds down, so blockDim.x is effectively
// required to be a multiple of WARP_SIZE; a trailing partial warp's value would be ignored.
using reduction_func_t = float (*) (float);
template __device__ inline float blockReduce(float val, bool final_sync, float out_of_bounds) {
    // two reductions of up to 1024 threads:
    // 1) inside warp (shuffle), 2) cross-warp (shared memory), 3) inside warp (shuffle)
    __shared__ float shared_val[WARP_SIZE];
    const int lane_id = threadIdx.x % WARP_SIZE;
    const int warp_id = threadIdx.x / WARP_SIZE;
    const int num_warps = blockDim.x / WARP_SIZE;

    float warp_val = warp_reduction(val);
    if (lane_id == 0) { shared_val[warp_id] = warp_val; }
    __syncthreads();
    // lanes without a warp result contribute out_of_bounds, which should be neutral for
    // warp_reduction (the default overload below passes 0.0f)
    warp_val = (lane_id < num_warps) ? shared_val[lane_id] : out_of_bounds;
    float block_val = warp_reduction(warp_val);

    if (final_sync) {
        __syncthreads(); // only needed in loops when effectively reusing shared memory etc.
    }
    return block_val;
}

// Helper function to call blockReduce with default arguments
template __device__ inline float blockReduce(float val) {
    return blockReduce(val, false, 0.0f);
}

// ----------------------------------------------------------------------------
// checking utils

// CUDA error checking: print the error with file/line and exit on anything but cudaSuccess
void cuda_check(cudaError_t error, const char *file, int line) {
    if (error != cudaSuccess) {
        printf("[CUDA ERROR] at file %s:%d:\n%s\n", file, line,
               cudaGetErrorString(error));
        exit(EXIT_FAILURE);
    }
};
#define cudaCheck(err) (cuda_check(err, __FILE__, __LINE__))

// cuBLAS error checking: print the status code with file/line and exit on failure
void cublasCheck(cublasStatus_t status, const char *file, int line)
{
    if (status != CUBLAS_STATUS_SUCCESS) {
        printf("[cuBLAS ERROR]: %d %s %d\n", status, file, line);
        exit(EXIT_FAILURE);
    }
}
#define cublasCheck(status) { cublasCheck((status), __FILE__, __LINE__); }

// ----------------------------------------------------------------------------
// cuBLAS setup
// these will be initialized by setup_main

// cuBLAS workspace. Hardcoding to 32MiB but only Hopper needs 32, for others 4 is OK
static size_t cublaslt_workspace_size = 32 * 1024 * 1024;
static void* cublaslt_workspace = NULL;
static cublasComputeType_t cublas_compute_type;
cublasHandle_t cublas_handle;
cublasLtHandle_t cublaslt_handle;
int cuda_arch_major = 0;
int cuda_arch_minor = 0;
int cuda_num_SMs = 0; // for persistent threads where we want 1 threadblock per SM
int cuda_threads_per_SM = 0;    // needed to calculate how many blocks to launch to fill up the GPU

// ----------------------------------------------------------------------------
// to make sure that 2 blocks fit on A100/H100 to maximise latency tolerance
#if __CUDA_ARCH__ == 800 || __CUDA_ARCH__ >= 900
#define MAX_1024_THREADS_BLOCKS 2
#else
#define MAX_1024_THREADS_BLOCKS 1
#endif

// ----------------------------------------------------------------------------
// Packed128 data structure, which forces the compiler to use 128-bit loads/stores
// on GPUs that support them (the LDG.128 and STS.128 instructions)
// This is a bit similar to the use of float4 in the case of 32-bit floats, but
// supports arbitrary precision.
template struct alignas(16) Packed128 {
    // Note: = default implicitly generates a __device__ function, but explicitly
    // adding __device__ causes a lot of warnings.
    // (the defaulted constructor leaves payload uninitialized)
    Packed128() = default;
    __device__ explicit Packed128(int4 bits) {
        static_assert(sizeof(bits) == sizeof(payload), "Size mismatch.");
        memcpy(&payload, &bits, sizeof(bits));
    }

    // a Packed128 with every element set to value
    __device__ static Packed128 constant(ElementType value) {
        Packed128 result;
        for(int k = 0; k < size; ++k) {
            result.payload[k] = value;
        }
        return result;
    }
    __device__ static Packed128 zeros() {
        return constant(0);
    }
    __device__ static Packed128 ones() {
        return constant(1);
    }

    __device__ ElementType& operator[](int index) {
        return payload[index];
    }
    __device__ const ElementType& operator[](int index) const {
        return payload[index];
    }
    // raw 128-bit view of the payload, used by the store helpers
    __device__ int4 get_bits() const {
        int4 bits;
        static_assert(sizeof(bits) == sizeof(payload), "Size mismatch.");
        memcpy(&bits, &payload, sizeof(bits));
        return bits;
    }
    // e.g. sizeof(int4) is 16 (4 X 4 bytes), sizeof(bfloat16) = 2, so size = 8
    // so in the case where ElementType = bfloat16, we store 8 elements in one Packed128
    static constexpr const int size = sizeof(int4) / sizeof(ElementType);
    ElementType payload[size];
};

// short-form typedef
typedef Packed128 f128;

// load a Packed128 from an aligned memory address
template __device__ Packed128 load128(const ElementType* address) {
    return Packed128{*reinterpret_cast(address)};
}
// load a Packed128 from an aligned memory address with streaming cache hint
template __device__ Packed128 load128cs(const ElementType* address) {
    return Packed128{__ldcs(reinterpret_cast(address))};
}
// store a Packed128 to an aligned memory address
template __device__ void store128(ElementType* target, Packed128 value) {
    *reinterpret_cast(target) = value.get_bits();
}
// store a Packed128 to an aligned memory address with streaming cache hint
template __device__ void store128cs(ElementType* target, Packed128 value) {
    __stcs(reinterpret_cast(target), value.get_bits());
}
// store a Packed128 to an aligned memory address while caching in L2 but bypassing L1
template __device__ void store128cg(ElementType* target, Packed128 value) {
__stcg(reinterpret_cast(target), value.get_bits());
}

// ----------------------------------------------------------------------------
// reduced/mixed precision utilities

#if defined(ENABLE_BF16)

typedef __nv_bfloat16 floatX;
typedef __nv_bfloat16 floatN;
#define CUBLAS_LOWP CUDA_R_16BF // CUDA_R_16F or CUDA_R_16BF (or CUDA_R_32F)
// CUBLAS_COMPUTE_32F or CUBLAS_COMPUTE_16F (for CUDA_R_16F only, potentially slower?!)
#define CUBLAS_LOWP_COMPUTE CUBLAS_COMPUTE_32F

#elif defined(ENABLE_FP16)

typedef half floatX;
typedef half floatN;

#else

typedef float floatX;
typedef float floatN;
#endif

typedef Packed128 x128;

// older nvcc does not provide __ldcs and __stcs for bfloat16, despite these actually just being unsigned shorts.
// we need to be careful here to only define our own versions if none already exist, otherwise the compiler will
// complain.
// If not, you easily get "no viable overload" (for sm52) and "function already exists" (sm_80)
#if defined(ENABLE_BF16) && (__CUDACC_VER_MAJOR__ < 12) && !((__CUDA_ARCH__ >= 800) || !defined(__CUDA_ARCH__))
__device__ floatX __ldcs(const floatX* address) {
    unsigned short bf = __ldcs(reinterpret_cast(address));
    return __nv_bfloat16_raw{bf};
}

__device__ void __stcs(floatX* address, floatX value) {
    __stcs(reinterpret_cast(address), ((__nv_bfloat16_raw)value).x);
}
#endif

// ----------------------------------------------------------------------------
// random utils

// N floats uniform in [0, 1]; caller frees
float* make_random_float_01(size_t N) {
    float* arr = (float*)malloc(N * sizeof(float));
    for (size_t i = 0; i < N; i++) {
        arr[i] = ((float)rand() / RAND_MAX); // range 0..1
    }
    return arr;
}

// N floats uniform in [-1, 1]; caller frees
float* make_random_float(size_t N) {
    float* arr = (float*)malloc(N * sizeof(float));
    for (size_t i = 0; i < N; i++) {
        arr[i] = ((float)rand() / RAND_MAX) * 2.0 - 1.0; // range -1..1
    }
    return arr;
}

// N ints in [0, V); caller frees
int* make_random_int(size_t N, int V) {
    int* arr = (int*)malloc(N * sizeof(int));
    for (size_t i = 0; i < N; i++) {
        arr[i] = rand() % V; // range 0..V-1
    }
    return arr;
}

float* make_zeros_float(size_t N) {
    float* arr = (float*)malloc(N * sizeof(float));
    memset(arr, 0, N * sizeof(float)); // all zero
    return arr;
}

float* make_ones_float(size_t N) {
    float* arr = (float*)malloc(N * sizeof(float));
    for (size_t i = 0; i < N; i++) {
        arr[i] = 1.0f;
    }
    return arr;
}

// ----------------------------------------------------------------------------
// testing and benchmarking utils

// Convert count host floats to TargetType in a temporary host buffer and copy them to d_ptr.
// Returns the cudaMemcpy status, or cudaErrorMemoryAllocation if the host buffer cannot be
// allocated, so callers can wrap the call in cudaCheck.
template [[nodiscard]] cudaError_t memcpy_convert(TargetType* d_ptr, float* h_ptr, size_t count) {
    // copy from host to device with data type conversion.
    TargetType* converted = (TargetType*)malloc(count * sizeof(TargetType));
    if (converted == NULL && count > 0) {
        // report the host allocation failure through the same error channel as cudaMemcpy
        return cudaErrorMemoryAllocation;
    }
    // size_t index to match count (an int index can overflow and compares signed vs unsigned)
    for (size_t i = 0; i < count; i++) {
        converted[i] = (TargetType)h_ptr[i];
    }

    cudaError_t status = cudaMemcpy(d_ptr, converted, count * sizeof(TargetType), cudaMemcpyHostToDevice);
    free(converted);

    // instead of checking the status at cudaMemcpy, we return it from here. This way, we
    // still need to use our checking macro, and get better line info as to where the error
    // happened.
    return status;
}

// Seed rand(), select device 0, record its SM count / threads per SM / compute capability in the
// globals above, create the cuBLAS and cuBLASLt handles, allocate the cuBLASLt workspace, and
// pick the TF32 compute type / math mode on compute capability >= 8.
void setup_main() {
    srand(0);   // determinism

    // set up the device
    int deviceIdx = 0;
    cudaCheck(cudaSetDevice(deviceIdx));
    // NOTE(review): this local deviceProp shadows the `extern cudaDeviceProp deviceProp` declared
    // at the top of this header, so that global is not filled in here.
    cudaDeviceProp deviceProp;
    cudaCheck(cudaGetDeviceProperties(&deviceProp, deviceIdx));
    cuda_num_SMs = deviceProp.multiProcessorCount;
    cuda_threads_per_SM = deviceProp.maxThreadsPerMultiProcessor;
    cuda_arch_major = deviceProp.major;
    cuda_arch_minor = deviceProp.minor;

    // setup cuBLAS and cuBLASLt
    cublasCheck(cublasCreate(&cublas_handle));
    cublasCheck(cublasLtCreate(&cublaslt_handle));
    cudaCheck(cudaMalloc(&cublaslt_workspace, cublaslt_workspace_size));

    // TF32 precision is equivalent to torch.set_float32_matmul_precision('high')
    int enable_tf32 = cuda_arch_major >= 8 ? 1 : 0;
    // TODO implement common CLI for all tests/benchmarks
    // if (override_enable_tf32 == 0) { enable_tf32 = 0; } // force to zero via arg
    cublas_compute_type = enable_tf32 ? CUBLAS_COMPUTE_32F_FAST_TF32 : CUBLAS_COMPUTE_32F;
    cublasMath_t cublas_math_mode = enable_tf32 ?
CUBLAS_TF32_TENSOR_OP_MATH : CUBLAS_DEFAULT_MATH;
    cublasCheck(cublasSetMathMode(cublas_handle, cublas_math_mode));
}

// Copy num_elements values of type D from device_result back to the host and compare them with
// cpu_reference. Non-finite reference entries are skipped. An element fails when
// |ref - gpu| > tolerance + |ref| * epsilon, with epsilon = FLT_EPSILON (0.079 under ENABLE_BF16).
// Prints the first 5 comparisons and up to 10 mismatches; exits with EXIT_FAILURE on any mismatch.
template void validate_result(D* device_result, const T* cpu_reference, const char* name, std::size_t num_elements, T tolerance=1e-4) {
    D* out_gpu = (D*)malloc(num_elements * sizeof(D));
    cudaCheck(cudaMemcpy(out_gpu, device_result, num_elements * sizeof(D), cudaMemcpyDeviceToHost));
    int nfaults = 0;
#ifndef ENABLE_BF16
    float epsilon = FLT_EPSILON;
#else
    float epsilon = 0.079;
#endif
    // std::size_t index to match num_elements (an int index overflows past INT_MAX elements)
    for (std::size_t i = 0; i < num_elements; i++) {
        // Skip masked elements
        if(!isfinite(cpu_reference[i]))
            continue;

        // print the first few comparisons
        if (i < 5) {
            printf("%f %f\n", cpu_reference[i], (T)out_gpu[i]);
        }
        // effective tolerance is based on expected rounding error (epsilon),
        // plus any specified additional tolerance
        float t_eff = tolerance + fabs(cpu_reference[i]) * epsilon;
        // ensure correctness for all elements.
        if (fabs(cpu_reference[i] - (T)out_gpu[i]) > t_eff) {
            printf("Mismatch of %s at %zu: CPU_ref: %f vs GPU: %f\n", name, i, cpu_reference[i], (T)out_gpu[i]);
            nfaults ++;
            if (nfaults >= 10) {
                free(out_gpu);
                exit(EXIT_FAILURE);
            }
        }
    }

    if (nfaults > 0) {
        free(out_gpu);
        exit(EXIT_FAILURE);
    }

    free(out_gpu);
}

// Time kernel(kernel_args...) over `repeats` launches and return the mean elapsed time per launch
// in milliseconds (the unit cudaEventElapsedTime reports). Before every launch an l2CacheSize-byte
// buffer is memset, intended to evict the previous run's data from L2.
template float benchmark_kernel(int repeats, Kernel kernel, KernelArgs&&... kernel_args) {
    cudaEvent_t start, stop;
    // prepare buffer to scrub L2 cache between benchmarks
    // just memset a large dummy array, recommended by
    // https://stackoverflow.com/questions/31429377/how-can-i-clear-flush-the-l2-cache-and-the-tlb-of-a-gpu
    // and apparently used in nvbench.
    int deviceIdx = 0;
    cudaCheck(cudaSetDevice(deviceIdx));
    cudaDeviceProp deviceProp;
    cudaCheck(cudaGetDeviceProperties(&deviceProp, deviceIdx));
    void* flush_buffer;
    cudaCheck(cudaMalloc(&flush_buffer, deviceProp.l2CacheSize));

    cudaCheck(cudaEventCreate(&start));
    cudaCheck(cudaEventCreate(&stop));
    float elapsed_time = 0.f;
    for (int i = 0; i < repeats; i++) {
        // clear L2
        cudaCheck(cudaMemset(flush_buffer, 0, deviceProp.l2CacheSize));
        // now we can start recording the timing of the kernel
        cudaCheck(cudaEventRecord(start, nullptr));
        kernel(std::forward(kernel_args)...);
        cudaCheck(cudaEventRecord(stop, nullptr));
        cudaCheck(cudaEventSynchronize(start));
        cudaCheck(cudaEventSynchronize(stop));
        float single_call;
        cudaCheck(cudaEventElapsedTime(&single_call, start, stop));
        elapsed_time += single_call;
    }

    cudaCheck(cudaFree(flush_buffer));
    // release the timing events; previously they were leaked on every call
    cudaCheck(cudaEventDestroy(start));
    cudaCheck(cudaEventDestroy(stop));

    return elapsed_time / repeats;
}
================================================
FILE: dev/cuda/crossentropy_forward.cu
================================================
/*
Kernels for crossentropy forward pass.
Compile example:
nvcc -O3 --use_fast_math -lcublas -lcublasLt crossentropy_forward.cu -o crossentropy_forward

version 1 is a straight-forward port from CPU code to kernel, parallel over B,T
./crossentropy_forward 1
*/

// NOTE(review): the header names of the bare #include lines below, and the
// <<<...>>> launch configurations of the kernel launchers further down, are
// missing in this copy of the file; restore them from the original source
// (the launchers compute grid_size and take block_size) before compiling.
#include
#include
#include
#include "common.h"

// ----------------------------------------------------------------------------
// CPU code reference

// Host reference: losses[b,t] = -log(probs[b,t,targets[b,t]]).
void crossentropy_forward_cpu(float* losses,
                              const float* probs, const int* targets,
                              int B, int T, int V) {
    // output: losses is (B,T) of the individual losses at each position
    // input: probs are (B,T,V) of the probabilities
    // input: targets is (B,T) of integers giving the correct index in logits
    for (int b = 0; b < B; b++) {
        for (int t = 0; t < T; t++) {
            // loss = -log(probs[target])
            const float* probs_bt = probs + b * T * V + t * V;
            int ix = targets[b * T + t];
            losses[b * T + t] = -logf(probs_bt[ix]);
        }
    }
}

// ----------------------------------------------------------------------------
// GPU kernels

// One thread per (b,t) position; threads with i >= B*T do nothing.
__global__ void crossentropy_forward_kernel1(float* losses,
                                             const float* probs, const int* targets,
                                             int B, int T, int V) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < B * T) {
        int b = i / T;
        int t = i % T;
        const float* probs_bt = probs + b * T * V + t * V;
        int ix = targets[b * T + t];
        losses[b * T + t] = -logf(probs_bt[ix]);
    }
}

// ----------------------------------------------------------------------------
// kernel launcher

// Computes grid_size = ceil_div(B*T, block_size) for a one-thread-per-position launch
// of kernel 1, then checks for launch errors.
void crossentropy_forward1(float* losses,
                           const float* probs, const int* targets,
                           int B, int T, int V,
                           const int block_size) {
    const int N = B * T;
    const int grid_size = ceil_div(N, block_size);
    crossentropy_forward_kernel1<<>>(losses, probs, targets, B, T, V);
    cudaCheck(cudaGetLastError());
}

// kernel version dispatch
// Any kernel_num other than 1 prints an error and exits the process.
void crossentropy_forward(int kernel_num,
                          float* losses,
                          const float* probs, const int* targets,
                          int B, int T, int V,
                          const int block_size) {
    switch (kernel_num) {
        case 1:
            crossentropy_forward1(losses, probs, targets, B, T, V, block_size);
            break;
        default:
            printf("Invalid kernel number\n");
            exit(1);
    }
}

// ----------------------------------------------------------------------------

// Checks the selected kernel against the CPU reference at every block size,
// then benchmarks it at every block size.
int main(int argc, char **argv) {
    srand(0);

    int B = 8;
    int T = 1024;
    int V = 50257;

    int deviceIdx = 0;
    cudaCheck(cudaSetDevice(deviceIdx));

    // create host memory of random numbers
    float* out = (float*)malloc(B * T * sizeof(float));
    float* probs = make_random_float_01(B * T * V);
    int* targets = make_random_int(B * T, V);

    // move to GPU
    float* d_out;
    float* d_probs;
    int* d_targets;
    cudaCheck(cudaMalloc(&d_out, B * T * sizeof(float)));
    cudaCheck(cudaMalloc(&d_probs, B * T * V * sizeof(float)));
    cudaCheck(cudaMalloc(&d_targets, B * T * sizeof(int)));
    cudaCheck(cudaMemcpy(d_probs, probs, B * T * V * sizeof(float), cudaMemcpyHostToDevice));
    cudaCheck(cudaMemcpy(d_targets, targets, B * T * sizeof(int), cudaMemcpyHostToDevice));

    // read kernel_num from command line
    int kernel_num = 1;
    if (argc > 1) {
        kernel_num = atoi(argv[1]);
    }
    printf("Using kernel %d\n", kernel_num);

    // first check the correctness of the kernel
    crossentropy_forward_cpu(out, probs, targets, B, T, V);

    // time the kernel at different block sizes
    int block_sizes[] = {32, 64, 128, 256, 512, 1024};

    for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) {
        int block_size = block_sizes[j];
        printf("Checking block size %d.\n", block_size);
        crossentropy_forward(kernel_num, d_out, d_probs, d_targets, B, T, V, block_size);
        validate_result(d_out, out, "out", B * T, 1e-5f);
    }

    printf("All results match. Starting benchmarks.\n\n");

    for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) {
        int block_size = block_sizes[j];

        int repeat_times = 1000;
        float elapsed_time = benchmark_kernel(repeat_times, crossentropy_forward,
                                              kernel_num, d_out, d_probs, d_targets,
                                              B, T, V, block_size);

        // elapsed_time is printed as ms, so ms * 1e6 / (B*T) is ns per token
        printf("block_size %4d | time %.4f ms | per token %.2f ns\n", block_size, elapsed_time, elapsed_time * 1'000'000 / (B*T));
    }

    // free memory
    free(out);
    free(probs);
    free(targets);
    cudaCheck(cudaFree(d_out));
    cudaCheck(cudaFree(d_probs));
    cudaCheck(cudaFree(d_targets));

    return 0;
}
================================================ FILE: dev/cuda/crossentropy_softmax_backward.cu ================================================
/*
Kernels for the fused softmax + crossentropy backward pass.

Compile example:
nvcc -O3 --use_fast_math -lcublas -lcublasLt crossentropy_softmax_backward.cu -o crossentropy_softmax_backward

version 1 is a straight-forward port from CPU code to kernel, parallel over B,T,V
./crossentropy_softmax_backward 1
*/

// NOTE(review): header names of the bare #include lines and the <<<...>>> launch
// configuration below are missing in this copy; restore them before compiling.
#include
#include
#include
#include "common.h"

// ----------------------------------------------------------------------------
// CPU code reference

// Host reference. Accumulates (+=) into dlogits:
// dlogits[b,t,i] += (probs[b,t,i] - (i == targets[b,t] ? 1 : 0)) * dlosses[b,t]
void crossentropy_softmax_backward_cpu(float* dlogits,
                           const float* dlosses, const float* probs, const int* targets,
                           int B, int T, int V) {
    // backwards through both softmax and crossentropy
    for (int b = 0; b < B; b++) {
        for (int t = 0; t < T; t++) {
            float* dlogits_bt = dlogits + b * T * V + t * V;
            const float* probs_bt = probs + b * T * V + t * V;
            float dloss = dlosses[b * T + t];
            int ix = targets[b * T + t];
            for (int i = 0; i < V; i++) {
                float p = probs_bt[i];
                float indicator = i == ix ? 1.0f : 0.0f;
                dlogits_bt[i] += (p - indicator) * dloss;
            }
        }
    }
}

// ----------------------------------------------------------------------------
// GPU kernels

// naive kernel that just parallelizes over B,T,V
// One thread per dlogits element; accumulates (+=) exactly like the CPU reference.
__global__ void crossentropy_softmax_backward_kernel1(float* dlogits,
                           const float* dlosses, const float* probs, const int* targets,
                           int B, int T, int V) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < B * T * V) {
        int b = i / (T * V);
        int t = (i / V) % T;
        int v = i % V;
        float* dlogits_bt = dlogits + b * T * V + t * V;
        const float* probs_bt = probs + b * T * V + t * V;
        float dloss = dlosses[b * T + t];
        int ix = targets[b * T + t];
        float p = probs_bt[v];
        float indicator = v == ix ? 1.0f : 0.0f;
        dlogits_bt[v] += (p - indicator) * dloss;
    }
}

// ----------------------------------------------------------------------------
// kernel launcher

// Computes grid_size = ceil_div(B*T*V, block_size) for a one-thread-per-element launch.
void crossentropy_softmax_backward1(float* dlogits,
                           const float* dlosses, const float* probs, const int* targets,
                           int B, int T, int V,
                           const int block_size) {
    const int N = B * T * V;
    const int grid_size = ceil_div(N, block_size);
    crossentropy_softmax_backward_kernel1<<>>(dlogits, dlosses, probs, targets, B, T, V);
    cudaCheck(cudaGetLastError());
}

// kernel version dispatch
// Any kernel_num other than 1 prints an error and exits the process.
void crossentropy_softmax_backward(int kernel_num,
                           float* dlogits,
                           const float* dlosses, const float* probs, const int* targets,
                           int B, int T, int V,
                           const int block_size) {
    switch (kernel_num) {
        case 1:
            crossentropy_softmax_backward1(dlogits, dlosses, probs, targets, B, T, V, block_size);
            break;
        default:
            printf("Invalid kernel number\n");
            exit(1);
    }
}

// ----------------------------------------------------------------------------

// Checks the selected kernel against the CPU reference at every block size,
// then benchmarks it at every block size.
int main(int argc, char **argv) {
    srand(0);

    int B = 8;
    int T = 1024;
    int V = 50257;

    int deviceIdx = 0;
    cudaCheck(cudaSetDevice(deviceIdx));

    // create host memory of random numbers
    float* probs = make_random_float_01(B * T * V);
    int* targets = make_random_int(B * T, V);
    float* dlosses = make_random_float(B * T);
    float* dlogits = make_zeros_float(B * T * V);

    // move to GPU
    float* d_probs;
    int* d_targets;
    float* d_dlosses;
    float* d_dlogits;
    cudaCheck(cudaMalloc(&d_probs, B * T * V * sizeof(float)));
    cudaCheck(cudaMalloc(&d_targets, B * T * sizeof(int)));
    cudaCheck(cudaMalloc(&d_dlosses, B * T * sizeof(float)));
    cudaCheck(cudaMalloc(&d_dlogits, B * T * V * sizeof(float)));
    cudaCheck(cudaMemcpy(d_probs, probs, B * T * V * sizeof(float), cudaMemcpyHostToDevice));
    cudaCheck(cudaMemcpy(d_targets, targets, B * T * sizeof(int), cudaMemcpyHostToDevice));
    cudaCheck(cudaMemcpy(d_dlosses, dlosses, B * T * sizeof(float), cudaMemcpyHostToDevice));

    // read kernel_num from command line
    int kernel_num = 1;
    if (argc > 1) {
        kernel_num = atoi(argv[1]);
    }
    printf("Using kernel %d\n", kernel_num);

    // first check the correctness of the kernel
    crossentropy_softmax_backward_cpu(dlogits, dlosses, probs, targets, B, T, V);

    // time the kernel at different block sizes
    int block_sizes[] = {32, 64, 128, 256, 512, 1024};

    for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) {
        int block_size = block_sizes[j];
        // the kernel accumulates into dlogits, so reset it before every check
        cudaCheck(cudaMemset(d_dlogits, 0, B * T * V * sizeof(float)));
        printf("Checking block size %d.\n", block_size);
        crossentropy_softmax_backward(kernel_num, d_dlogits, d_dlosses, d_probs, d_targets, B, T, V, block_size);
        validate_result(d_dlogits, dlogits, "dlogits", B * T * V, 1e-5f);
    }

    printf("All results match. Starting benchmarks.\n\n");

    for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) {
        int block_size = block_sizes[j];

        int repeat_times = 100;
        // d_dlogits keeps accumulating across repeats; only timing is measured here
        float elapsed_time = benchmark_kernel(repeat_times, crossentropy_softmax_backward,
                                              kernel_num, d_dlogits, d_dlosses, d_probs, d_targets,
                                              B, T, V, block_size);

        // elapsed_time is printed as ms, so ms * 1e3 / (B*T) is µs per token
        printf("block_size %4d | time %.4f ms | per token %.2f µs\n", block_size, elapsed_time, elapsed_time * 1'000 / (B*T));
    }

    // free memory
    free(probs);
    free(targets);
    free(dlosses);
    free(dlogits);
    cudaCheck(cudaFree(d_probs));
    cudaCheck(cudaFree(d_targets));
    cudaCheck(cudaFree(d_dlosses));
    cudaCheck(cudaFree(d_dlogits));

    return 0;
}
================================================ FILE: dev/cuda/encoder_backward.cu ================================================
/*
Kernels for the positional encoder backward pass in GPT-2.

Compile example:
nvcc -O3 --use_fast_math -lcublas -lcublasLt encoder_backward.cu -o encoder_backward

version 1 is naive port from CPU code to kernel
parallelizes over B,T,C, uses atomics to add to dwte, dwpe
./encoder_backward 1

version 2 is another naive port
parallelizes over C, loops over B,T; much slower than version 1
./encoder_backward 2
*/

// NOTE(review): header names of the bare #include lines and the <<<...>>> launch
// configurations below are missing in this copy; restore them before compiling.
#include
#include
#include
#include "common.h"

// ----------------------------------------------------------------------------
// CPU code reference

// GPT-2 positional encoder backward pass
// Accumulates dout[b,t,:] into dwte[inp[b,t],:] and into dwpe[t,:].
void encoder_backward_cpu(float* dwte, float* dwpe,
                          float* dout, int* inp,
                          int B, int T, int C) {
    for (int b = 0; b < B; b++) {
        for (int t = 0; t < T; t++) {
            float* dout_bt = dout + b * T * C + t * C;
            int ix = inp[b * T + t];
            float* dwte_ix = dwte + ix * C;
            float* dwpe_t = dwpe + t * C;
            for (int i = 0; i < C; i++) {
                float d = dout_bt[i];
                dwte_ix[i] += d;
                dwpe_t[i] += d;
            }
        }
    }
}

// ----------------------------------------------------------------------------
// GPU kernels

// naive implementation with atomics
// One thread per (b,t,c). Atomics are required because threads with different b but
// the same t hit the same dwpe address, and repeated token ids hit the same dwte row.
__global__ void encoder_backward_kernel1(float* dwte, float* dwpe,
                                        const float* dout, const int* inp,
                                        int B, int T, int C) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int N = B * T * C;
    if (idx < N) {
        int bt = idx / C;
        int b = bt / T;
        int t = bt % T;
        int c = idx % C;

        int ix = inp[b * T + t];

        const float* dout_btc = dout + b * T * C + t * C + c;
        float* dwte_ix = dwte + ix * C + c;
        float* dwpe_tc = dwpe + t * C + c;

        atomicAdd(dwte_ix, *dout_btc);
        atomicAdd(dwpe_tc, *dout_btc);
    }
}

// naive implementation that parallelizes over C and loops over B,T
// but it gets rid of atomics
// (each thread only ever touches column c of dwte and dwpe, so no two threads write
// the same address)
__global__ void encoder_backward_kernel2(float* dwte, float* dwpe,
                                        const float* dout, const int* inp,
                                        int B, int T, int C) {
    int c = blockIdx.x * blockDim.x + threadIdx.x;
    if (c >= C) { return; } // guard
    int BT = B * T;
    for (int i = 0; i < BT; i++) {
        int t = i % T;
        int ix = inp[i];
        float dout_btc = dout[i * C + c];
        dwte[ix * C + c] += dout_btc;
        dwpe[t * C + c] += dout_btc;
    }
}

// ----------------------------------------------------------------------------
// kernel launcher

// Computes grid_size = ceil_div(B*T*C, block_size): one thread per (b,t,c).
void encoder_backward1(float* dwte, float* dwpe,
                    const float* dout, const int* inp,
                    int B, int T, int C,
                    const int block_size) {
    const int N = B * T * C;
    const int grid_size = ceil_div(N, block_size);
    encoder_backward_kernel1<<>>(dwte, dwpe, dout, inp, B, T, C);
    cudaCheck(cudaGetLastError());
}

// Computes grid_size = ceil_div(C, block_size): one thread per channel c.
void encoder_backward2(float* dwte, float* dwpe,
                    const float* dout, const int* inp,
                    int B, int T, int C,
                    const int block_size) {
    const int grid_size = ceil_div(C, block_size);
    encoder_backward_kernel2<<>>(dwte, dwpe, dout, inp, B, T, C);
    cudaCheck(cudaGetLastError());
}

// kernel version dispatch
// Any kernel_num other than 1 or 2 prints an error and exits the process.
void encoder_backward(int kernel_num,
                     float* dwte, float* dwpe,
                    const float* dout, const int* inp,
                    int B, int T, int C,
                    const int block_size) {
    switch (kernel_num) {
        case 1:
            encoder_backward1(dwte, dwpe, dout, inp, B, T, C, block_size);
            break;
        case 2:
            encoder_backward2(dwte, dwpe, dout, inp, B, T, C, block_size);
            break;
        default:
            printf("Invalid kernel number\n");
            exit(1);
    }
}

//
// ----------------------------------------------------------------------------

// Checks the selected encoder backward kernel against the CPU reference at every
// block size, then benchmarks it at every block size.
int main(int argc, char **argv) {
    srand(0);

    int B = 8;
    int T = 1024;
    int C = 768;
    int V = 50257;

    int deviceIdx = 0;
    cudaCheck(cudaSetDevice(deviceIdx));

    // create host memory of random numbers
    float* dout = make_random_float(B * T * C);
    int* inp = make_random_int(B * T, V);
    float* dwte = make_zeros_float(V * C);
    float* dwpe = make_zeros_float(T * C);

    // move to GPU
    float* d_dout;
    int* d_inp;
    float* d_dwte;
    float* d_dwpe;
    cudaCheck(cudaMalloc(&d_dout, B * T * C * sizeof(float)));
    cudaCheck(cudaMalloc(&d_inp, B * T * sizeof(int)));
    cudaCheck(cudaMalloc(&d_dwte, V * C * sizeof(float)));
    cudaCheck(cudaMalloc(&d_dwpe, T * C * sizeof(float)));
    cudaCheck(cudaMemcpy(d_dout, dout, B * T * C * sizeof(float), cudaMemcpyHostToDevice));
    cudaCheck(cudaMemcpy(d_inp, inp, B * T * sizeof(int), cudaMemcpyHostToDevice));

    // read kernel_num from command line
    int kernel_num = 1;
    if (argc > 1) {
        kernel_num = atoi(argv[1]);
    }
    printf("Using kernel %d\n", kernel_num);

    // first check the correctness of the kernel
    encoder_backward_cpu(dwte, dwpe, dout, inp, B, T, C);

    // time the kernel at different block sizes
    int block_sizes[] = {32, 64, 128, 256, 512, 1024};

    for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) {
        int block_size = block_sizes[j];
        // the encoder_backward kernels accumulate into dwte/dwpe, so reset them first
        cudaCheck(cudaMemset(d_dwte, 0, V * C * sizeof(float)));
        cudaCheck(cudaMemset(d_dwpe, 0, T * C * sizeof(float)));
        printf("Checking block size %d.\n", block_size);
        encoder_backward(kernel_num, d_dwte, d_dwpe, d_dout, d_inp, B, T, C, block_size);
        validate_result(d_dwte, dwte, "dwte", V * C, 1e-5f);
        validate_result(d_dwpe, dwpe, "dwpe", T * C, 1e-5f);
    }

    printf("All results match. Starting benchmarks.\n\n");

    for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) {
        int block_size = block_sizes[j];

        int repeat_times = 1000;
        float elapsed_time = benchmark_kernel(repeat_times, encoder_backward,
                                              kernel_num, d_dwte, d_dwpe, d_dout, d_inp,
                                              B, T, C, block_size);

        printf("block_size %4d | time %.4f ms\n", block_size, elapsed_time);
    }

    // free memory
    free(dout);
    free(inp);
    free(dwte);
    free(dwpe);
    // checked like every other CUDA runtime call in these benchmarks, so a sticky
    // error from an earlier kernel is reported instead of silently ignored
    cudaCheck(cudaFree(d_dout));
    cudaCheck(cudaFree(d_inp));
    cudaCheck(cudaFree(d_dwte));
    cudaCheck(cudaFree(d_dwpe));

    return 0;
}
================================================ FILE: dev/cuda/encoder_forward.cu ================================================
/*
Kernels for the positional encoder forward pass in GPT-2.

Compile example:
nvcc -O3 --use_fast_math -lcublas -lcublasLt encoder_forward.cu -o encoder_forward

version 1 is naive port from CPU code to kernel: parallelizes over B,T, loops over C
./encoder_forward 1

version 2 is more optimized, parallelizes over all of B,T,C
./encoder_forward 2

version 3 is like version 2 but uses 128-bit packed (x128) reads/writes
./encoder_forward 3
*/

// NOTE(review): the header names of the bare #include lines below are missing in
// this copy of the file; restore them from the original source before compiling.
#include
#include
#include
#include
#define ENABLE_BF16
#include "common.h"

// ----------------------------------------------------------------------------
// CPU code reference

// GPT-2 positional encoder forward pass
// out[b,t,:] = wte[inp[b,t],:] + wpe[t,:]
void encoder_forward_cpu(float* out,
                   const int* inp, const float* wte, const float* wpe,
                   int B, int T, int C) {
    for (int b = 0; b < B; b++) {
        for (int t = 0; t < T; t++) {
            float* out_bt = out + b * T * C + t * C;
            int ix = inp[b * T + t];
            const float* wte_ix = wte + ix * C;
            const float* wpe_t = wpe + t * C;
            for (int i = 0; i < C; i++) {
                out_bt[i] = wte_ix[i] + wpe_t[i];
            }
        }
    }
}

// ----------------------------------------------------------------------------
// GPU kernels

// naive implementation into kernel, parallelize over B,T, loop over C
__global__ void encoder_forward_kernel1(floatX* out,
                               const int* inp, const floatX* wte, const floatX* wpe,
                               int B, int T, int C) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int N = B * T;

    if (idx < N) {
        int b = idx / T;
        int t = idx % T;
        floatX* out_bt = out + b * T * C + t * C;
        int ix = inp[b * T + t];
        const floatX* wte_ix = wte + ix * C;
        const floatX* wpe_t = wpe + t * C;
        for (int i = 0; i < C; i++) {
            out_bt[i] = (floatX)((float)wte_ix[i] + (float)wpe_t[i]);
        }
    }
}

// optimized implementation: parallelize over all of B,T,C
__global__ void encoder_forward_kernel2(floatX* out,
                               const int* inp, const floatX* wte, const floatX* wpe,
                               int B, int T, int C) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int N = B * T * C;

    if (idx < N) {
        int bt = idx / C;
        int b = bt / T;
        int t = bt % T;
        int c = idx % C;

        int ix = inp[b * T + t];

        floatX* out_btc = out + b * T * C + t * C + c;
        const floatX* wte_ix = wte + ix * C + c;
        const floatX* wpe_tc = wpe + t * C + c;
        *out_btc = (floatX)((float)*wte_ix + (float)*wpe_tc);
    }
}

// like kernel 2, but each thread handles x128::size consecutive channels with one
// 128-bit load per input and one 128-bit store. Only correct when C is a multiple of
// x128::size: idx is checked against N only, so a packet must never straddle two rows.
__global__ void encoder_forward_kernel3(floatX* out,
                               const int* inp, const floatX* wte, const floatX* wpe,
                               int B, int T, int C) {
    int idx = (blockIdx.x * blockDim.x + threadIdx.x) * x128::size;
    int N = B * T * C;
    if (idx < N) {
        int bt = idx / C;
        int b = bt / T;
        int t = bt % T;
        int c = idx % C;

        int ix = inp[b * T + t];

        floatX* out_btc = out + b * T * C + t * C + c;
        const floatX* wte_ix = wte + ix * C + c;
        const floatX* wpe_tc = wpe + t * C + c;

        // named *_packed so they no longer shadow the wte/wpe kernel parameters
        x128 packed_out;
        x128 wte_packed = load128cs(wte_ix);
        x128 wpe_packed = load128cs(wpe_tc);
        #pragma unroll
        for (int k = 0; k < x128::size; k++) {
            packed_out[k] = (floatX)((float)wte_packed[k] + (float)wpe_packed[k]);
        }
        store128(out_btc, packed_out);
    }
}

// ----------------------------------------------------------------------------
// kernel launcher

// one thread per (b,t) token
void encoder_forward1(floatX* out,
                     const int* inp, const floatX* wte, const floatX* wpe,
                     int B, int T, int C,
                     const int block_size) {
    const int N = B * T;
    const int grid_size = ceil_div(N, block_size);
    encoder_forward_kernel1<<<grid_size, block_size>>>(out, inp, wte, wpe, B, T, C);
    cudaCheck(cudaGetLastError());
}

// one thread per (b,t,c) element
void encoder_forward2(floatX* out,
                     const int* inp, const floatX* wte, const floatX* wpe,
                     int B, int T, int C,
                     const int block_size) {
    const int N = B * T * C;
    const int grid_size = ceil_div(N, block_size);
    encoder_forward_kernel2<<<grid_size, block_size>>>(out, inp, wte, wpe, B, T, C);
    cudaCheck(cudaGetLastError());
}

// one thread per x128::size elements
void encoder_forward3(floatX* out,
                     const int* inp, const floatX* wte, const floatX* wpe,
                     int B, int T, int C,
                     const int block_size) {
    const int N = B * T * C;
    const int grid_size = ceil_div(N, (int)(block_size * x128::size));
    encoder_forward_kernel3<<<grid_size, block_size>>>(out, inp, wte, wpe, B, T, C);
    cudaCheck(cudaGetLastError());
}

// kernel version dispatch
// Any kernel_num other than 1, 2 or 3 prints an error and exits the process.
void encoder_forward(int kernel_num,
                     floatX* out,
                     const int* inp, const floatX* wte, const floatX* wpe,
                     int B, int T, int C,
                     const int block_size) {
    switch (kernel_num) {
        case 1:
            encoder_forward1(out, inp, wte, wpe, B, T, C, block_size);
            break;
        case 2:
            encoder_forward2(out, inp, wte, wpe, B, T, C, block_size);
            break;
        case 3:
            encoder_forward3(out, inp, wte, wpe, B, T, C, block_size);
            break;
        default:
            printf("Invalid kernel number\n");
            exit(1);
    }
}

// ----------------------------------------------------------------------------

// Checks the selected encoder forward kernel against the CPU reference at every
// block size, then benchmarks it and prints an estimated bandwidth.
int main(int argc, char **argv) {
    setup_main();

    int B = 8;
    int T = 1024;
    int C = 768;
    int V = 50257;

    int deviceIdx = 0;
    cudaCheck(cudaSetDevice(deviceIdx));

    // create host memory of random numbers
    float* out = (float*)malloc(B * T * C * sizeof(float));
    int* inp = make_random_int(B * T, V);
    float* wte = make_random_float(V * C);
    float* wpe = make_random_float(T * C);

    // move to GPU
    floatX* d_out;
    int* d_inp;
    floatX* d_wte;
    floatX* d_wpe;
    cudaCheck(cudaMalloc(&d_out, B * T * C * sizeof(floatX)));
    cudaCheck(cudaMalloc(&d_inp, B * T * sizeof(int)));
    cudaCheck(cudaMalloc(&d_wte, V * C * sizeof(floatX)));
    cudaCheck(cudaMalloc(&d_wpe, T * C * sizeof(floatX)));
    cudaCheck(cudaMemcpy(d_inp, inp, B * T * sizeof(int), cudaMemcpyHostToDevice));
    // host weights are float; memcpy_convert converts them to floatX while copying
    cudaCheck(memcpy_convert(d_wte, wte, V * C));
    cudaCheck(memcpy_convert(d_wpe, wpe, T * C));

    // read kernel_num from command line
    int kernel_num = 2;
    if (argc > 1) {
        kernel_num = atoi(argv[1]);
    }
    printf("Using kernel %d\n", kernel_num);

    // first check the correctness of the kernel
    encoder_forward_cpu(out, inp, wte, wpe, B, T, C);

    // time the kernel at different block sizes
    int block_sizes[] = {32, 64, 128, 256, 512, 1024};

    for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) {
        int block_size = block_sizes[j];
        printf("Checking block size %d.\n", block_size);
        encoder_forward(kernel_num, d_out, d_inp, d_wte, d_wpe, B, T, C, block_size);
#if !defined(ENABLE_BF16) && !defined(ENABLE_FP16)
        float tol = 1e-5;
#else
        float tol = 1e-2f;
#endif
        validate_result(d_out, out, "out", B * T * C, tol);
    }

    printf("All results match. Starting benchmarks.\n\n");

    for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) {
        int block_size = block_sizes[j];

        int repeat_times = 1000;
        float elapsed_time = benchmark_kernel(repeat_times, encoder_forward,
                                              kernel_num, d_out, d_inp, d_wte, d_wpe, B, T, C, block_size
                                              );

        // napkin math: estimate the memory bandwidth achieved
        // for each (B,T,C) output element, we do 3 reads and 1 write, 4 bytes each
        // and e.g. A100 40GB PCIe is advertised at 1,555GB/s
        // NOTE(review): ENABLE_BF16 is defined at the top of this file, so floatX is
        // presumably 2 bytes wide and this estimate likely overstates the achieved
        // bandwidth — confirm against the floatX definition in common.h.
        long memory_ops = B * T * C * 4 * 4;
        float memory_bandwidth = memory_ops / elapsed_time / 1e6;

        printf("block_size %4d | time %.4f ms | bandwidth %.2f GB/s\n", block_size, elapsed_time, memory_bandwidth);
    }

    // free memory
    free(out);
    free(inp);
    free(wte);
    free(wpe);
    cudaCheck(cudaFree(d_out));
    cudaCheck(cudaFree(d_inp));
    cudaCheck(cudaFree(d_wte));
    cudaCheck(cudaFree(d_wpe));

    return 0;
}
================================================ FILE: dev/cuda/fused_residual_forward.cu ================================================
/*
Kernels for residual forward pass fused with layernorm

Compile example:
nvcc -O3 --use_fast_math -lcublas -lcublasLt fused_residual_forward.cu -o fused_residual_forward

version 1 is naive port from CPU code to kernel
./fused_residual_forward 1
version 2 packs input into 128 bit memory reads
./fused_residual_forward 2

NOTE(review): this list looks stale. The file defines fused kernels 2 through 6, and
fused_residual_forward2 below is described as a "naive fusion", not as using 128-bit
reads. Update this list against the kernel dispatch function.
*/

// NOTE(review): the header names of the bare #include lines below are missing in
// this copy of the file; restore them from the original source before compiling.
#include
#include
#include "assert.h"
#include
#define ENABLE_BF16
#include "common.h"

// ----------------------------------------------------------------------------
// CPU code reference lol

void residual_forward_cpu(float* out, const float* inp1, const float* inp2, int N) {
    for (int i = 0; i < N; i++) {
        out[i] = inp1[i] + inp2[i];
    }
}

void layernorm_forward_cpu(float* out, float* mean, float* rstd,
                       const float* inp, const float* weight, const float* bias,
                       int B, int T, int C) {
    float eps = 1e-5f;
    for (int b = 0; b < B; b++) {
        for (int t = 0; t < T; t++) {
            // seek to the input position inp[b,t,:]
            const float* x = inp + b * T * C + t * C;
            // calculate the mean
            float m = 0.0f;
            for (int i = 0; i < C; i++) {
                m += x[i];
            }
            m = m/C;
            // calculate the variance (without any bias correction)
            float v = 0.0f;
            for (int i = 0; i < C; i++) {
                float xshift = x[i] - m;
                v += xshift * xshift;
            }
            v = v/C;
            // calculate the rstd
            float s = 1.0f / sqrtf(v + eps);
            // seek to the output position in out[b,t,:]
            float* out_bt = out + b * T * C + t * C;
            for (int i = 0; i < C; i++) {
                float n = (s * (x[i] - m)); // normalized output
                float o = n * weight[i] + bias[i]; // scale and shift it
                out_bt[i] = o; // write
            }
            // cache the mean and rstd for the backward pass later
            mean[b * T + t] = m;
            rstd[b * T + t] = s;
        }
    }
}

// ----------------------------------------------------------------------------
// GPU kernels

// elementwise ops are nice and ez
__global__ void residual_forward_kernel1(floatX* out, const floatX* inp1, const floatX* inp2, int N) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < N) {
        out[idx] = (floatX)((float)inp1[idx] + (float)inp2[idx]);
    }
}

// naive drag and drop implementation into kernel, parallelize over B,T, loop over C
__global__ void layernorm_forward_kernel1(floatX* out, floatX* mean, floatX* rstd,
                                 const floatX* inp, const floatX* weight, const floatX* bias,
                                 int N, int C) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    float eps = 1e-5f;

    if (idx < N) {
        // seek to the input position inp[idx,:]
        const floatX* x = inp + idx * C;
        // calculate the mean
        float m = 0.0f;
        for (int i = 0; i < C; i++) {
            m += (float)x[i];
        }
        m = m / C;
        // calculate the variance (without any bias correction)
        float v = 0.0f;
        for (int i = 0; i < C; i++) {
            float xshift = (float)x[i] - m;
            v += xshift * xshift;
        }
        v = v / C;
        // calculate the rstd
        float s = 1.0f / sqrtf(v + eps);
        // seek to the output position in out[idx,:]
        floatX* out_idx = out + idx * C;
        for (int i = 0; i < C; i++) {
            float n = (s * ((float)x[i] - m)); // normalized output
            float o = n * (float)weight[i] + (float)bias[i]; // scale and shift it
            out_idx[i] = o; // write
        }
        // cache the mean and rstd for the backward pass later
        mean[idx] = m;
        rstd[idx] = s;
    }
}

// naive fusion; uncoalesced access pattern leads to terrible performance
// One thread per token: residual = inp1 + inp2, then layernorm(residual) into normed.
__global__ void fused_residual_forward2(floatX* residual, floatX* normed, floatX* mean, floatX* rstd,
                                        const floatX* inp1, const floatX* inp2,
                                        const floatX* weight, const floatX* bias,
                                        int N, int C) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    // valid tokens are 0..N-1; the former `idx > N` check let the thread with
    // idx == N read and write one token past the end of every buffer
    if(idx >= N) return;

    // adjust pointers to current token
    residual += C * idx;
    normed += C * idx;
    inp1 += C * idx;
    inp2 += C * idx;

    float eps = 1e-5f;

    float m = 0.0f;
    for(int c = 0; c < C; ++c) {
        float out = (float)inp1[c] + (float)inp2[c];
        m += out;
        residual[c] = (floatX)out;
    }

    m = m / C;
    float v = 0.0f;
    for (int c = 0; c < C; c++) {
        float xshift = (float)residual[c] - m;
        v += xshift * xshift;
    }
    v = v / C;

    // calculate the rstd
    float s = 1.0f / sqrtf(v + eps);
    for (int c = 0; c < C; c++) {
        float n = (s * ((float)residual[c] - m)); // normalized output
        float o = n * (float)weight[c] + (float)bias[c]; // scale and shift it
        normed[c] = (floatX)o; // write
    }
    // cache the mean and rstd for the backward pass later
    mean[idx] = (floatX)m;
    rstd[idx] = (floatX)s;
}

// handle one token per warp for coalesced access
__global__ void fused_residual_forward3(floatX* residual, floatX* normed, floatX* mean, floatX* rstd,
                                        const floatX* inp1, const floatX* inp2,
                                        const floatX* weight, const floatX* bias,
                                        int N, int C) {
    constexpr const int WarpSize = 32;
    assert(blockDim.x == WarpSize);
    int idx = blockIdx.x * blockDim.y + threadIdx.y;
    // valid tokens are 0..N-1 (the former `idx > N` let warp idx == N run out of
    // bounds). idx depends only on threadIdx.y, so a whole warp exits together and
    // warpReduceSum below always runs on a complete warp.
    if(idx >= N) return;

    // adjust pointers to current token
    residual += C * idx;
    normed += C * idx;
    inp1 += C * idx;
    inp2 += C * idx;

    float eps = 1e-5f;
    float m = 0.0f;
    for(int c = threadIdx.x; c < C; c += WarpSize) {
        float out = (float)inp1[c] + (float)inp2[c];
        m += out;
        residual[c] = out;
    }

    m = warpReduceSum(m);

    m = m / C;
    float v = 0.0f;
    for(int c = threadIdx.x; c < C; c += WarpSize) {
        float xshift = (float)residual[c] - m;
        v += xshift * xshift;
    }

    v = warpReduceSum(v);
    v = v / C;

    // calculate the rstd
    float s = 1.0f / sqrtf(v + eps);
    for(int c = threadIdx.x; c < C; c += WarpSize) {
        float n = (s * ((float)residual[c] - m)); // normalized output
        float o = n * (float)weight[c] + (float)bias[c]; // scale and shift it
        normed[c] = o; // write
    }
    // cache the mean and rstd for the backward pass later
    if(threadIdx.x == 0) {
        mean[idx] = m;
        rstd[idx] = s;
    }
}

// vectorized loading, single pass
// stats, streaming access and zigzag loop
// One warp per token (blockDim.x must be 32; blockDim.y tokens per block).
// Pass 1 writes residual = inp1 + inp2 with 128-bit loads/stores and accumulates the
// sum and sum of squares. Pass 2 walks the same chunks in reverse ("zigzag") to write
// normed = layernorm(residual) * weight + bias.
__global__ void fused_residual_forward_kernel4(floatX* residual, floatX* normed, floatX* mean, floatX* rstd,
                                               const floatX* inp1, const floatX* inp2,
                                               const floatX* weight, const floatX* bias,
                                               int N, int C) {
    // x128 is the packed type that kernels 5 and 6 in this file also use unqualified,
    // so no local alias is needed here
    constexpr const int WarpSize = 32;
    assert(blockDim.x == WarpSize);
    int idx = blockIdx.x * blockDim.y + threadIdx.y;
    // valid tokens are 0..N-1; the former `idx > N` let warp idx == N run out of bounds.
    // idx depends only on threadIdx.y, so the whole warp exits together and the
    // warpReduceSum calls below always run on a complete warp.
    if(idx >= N) return;

    // adjust pointers to current token
    residual += C * idx;
    normed += C * idx;
    inp1 += C * idx;
    inp2 += C * idx;

    const float eps = 1e-5f;
    float sum = 0.0f;
    float sum_sq = 0.0f;
    int c = threadIdx.x * x128::size;
    for(; c < C; c += WarpSize * x128::size) {
        const x128 in1 = load128cs(inp1 + c);
        const x128 in2 = load128cs(inp2 + c);
        x128 out;
        for(int k = 0; k < x128::size; ++k) {
            out[k] = (floatX)((float)in1[k] + (float)in2[k]);
            sum += (float)out[k];
            sum_sq += (float)out[k] * (float)out[k];
        }
        store128(residual + c, out);
    }

    sum = warpReduceSum(sum);
    sum_sq = warpReduceSum(sum_sq);

    float m = sum / C;
    // E[x^2] - E[x]^2 can come out negative through floating-point cancellation;
    // clamp at zero so v + eps stays positive and rsqrtf cannot return NaN
    float v = fmaxf(sum_sq / C - m * m, 0.0f);
    float s = rsqrtf(v + eps);

    // zigzag: step back to this thread's last chunk, then walk its chunks in reverse
    c -= WarpSize * x128::size;
    for(; c >= 0; c -= WarpSize * x128::size) {
        const x128 res = load128cs(residual + c);
        const x128 w = load128(weight + c);
        const x128 b = load128(bias + c);
        x128 out;
        for(int k = 0; k < x128::size; ++k) {
            float n = s * ((float)res[k] - m); // normalized output
            float o = n * (float)w[k] + (float)b[k]; // scale and shift it
            out[k] = o;
        }
        store128cs(normed + c, out);
    }

    // cache the mean and rstd for the backward pass later
    if(threadIdx.x == 0) {
        mean[idx] = m;
        rstd[idx] = s;
    }
}

// what do you want in shared memory? EVERYTHING!
// thus, we no longer require zigzag loops and can do the numerically more stable variance estimation
// needs special attention in the kernel launcher to ensure we have enough smem.
// Kernel 5: one warp per token (asserted blockDim.x == 32), blockDim.y tokens per block.
// weight and bias are staged once per block in dynamic shared memory, and every warp also keeps
// its token's residual row in shared memory so the layernorm passes read it from there instead of
// global memory.
// dynamic shared memory layout, in floatX elements: [weight: C][bias: C][residual row: C] * blockDim.y
// (the launcher requests exactly (2 + block_y) * C * sizeof(floatX) bytes).
__global__ void fused_residual_forward_kernel5(floatX* residual, floatX* normed, floatX* mean, floatX* rstd,
                                               const floatX* inp1, const floatX* inp2,
                                               const floatX* weight, const floatX* bias,
                                               int N, int C) {
    constexpr const int WarpSize = 32;
    assert(blockDim.x == WarpSize);

    // load weights and biases into shared memory
    // do this before we allow any threads to exit!
    extern __shared__ char params[];
    // load128/store128 sometimes generated multiple instructions when the types here were floatX*, so
    // let's keep everything as x128
    x128* s_weight = reinterpret_cast<x128*>(params);
    x128* s_bias = reinterpret_cast<x128*>(params) + (C / x128::size);
    x128* s_res = reinterpret_cast<x128*>(params) + ((2 + threadIdx.y) * C / x128::size);

    int sidx = (threadIdx.x + WarpSize * threadIdx.y) * x128::size;
    for(int i = sidx; i < C; i += blockDim.y * WarpSize * x128::size) {
        s_weight[i/x128::size] = load128(weight + i);
        s_bias[i/x128::size] = load128(bias + i);
    }
    __syncthreads();

    int idx = blockIdx.x * blockDim.y + threadIdx.y;
    // valid token indices are 0..N-1. The grid is rounded up (ceil_div), so the last block can have
    // spare warps; idx == N is already out of range and must not touch memory.
    if(idx >= N) return;

    // adjust pointers to current token
    residual += C * idx;
    normed += C * idx;
    inp1 += C * idx;
    inp2 += C * idx;

    const float eps = 1e-5f;
    float sum = 0.0f;
    for(int c = threadIdx.x * x128::size; c < C; c += WarpSize * x128::size) {
        const x128 in1 = load128cs(inp1 + c);
        const x128 in2 = load128cs(inp2 + c);
        x128 out;
        for(int k = 0; k < x128::size; ++k) {
            out[k] = (floatX)((float)in1[k] + (float)in2[k]);
            sum += (float)out[k];
        }
        store128cs(residual + c, out);
        s_res[c / x128::size] = out;
    }

    sum = warpReduceSum(sum);
    float m = sum / C;
    float v = 0.f;

    for(int c = threadIdx.x * x128::size; c < C; c += WarpSize * x128::size) {
        const x128 res = s_res[c / x128::size];
        for(int k = 0; k < x128::size; ++k) {
            v += ((float)res[k] - m) * ((float)res[k] - m);
        }
    }

    v = warpReduceSum(v) / C;
    float s = rsqrtf(v + eps);

    for(int c = threadIdx.x * x128::size; c < C; c += WarpSize * x128::size) {
        const x128 res = s_res[c / x128::size];
        const x128 w = s_weight[c / x128::size];
        const x128 b = s_bias[c / x128::size];
        x128 out;
        for(int k = 0; k < x128::size; ++k) {
            float n = s * ((float)res[k] - m); // normalized output
            float o = n * (float)w[k] + (float)b[k]; // scale and shift it
            out[k] = o;
        }

        store128cs(normed + c, out);
    }
    // cache the mean and rstd for the backward pass later
    if(threadIdx.x == 0) {
        mean[idx] = m;
        rstd[idx] = s;
    }
}

// using multiple warps per token, and keep threads persistent, so we never have to reload weights and biases
// if we had one warp per token, though, this would require us to use a huge amount of shared memory. Therefore,
// we use multiple warps per token; but generally we cannot use the entire block, because that would give too
// little work per warp to be effective (each warp processes 256 bfloat16 elements, so for C=768 more than 3 warps
// will just mean idle). Therefore, we add a z dimension, where warps with different z handle different tokens.
// all this makes the launcher logic more complicated :(
// dynamic shared memory layout: [weight: C floatX][bias: C floatX][residual row: C floatX] * blockDim.z,
// followed by 32 floats of s_mean per z sub-block, then 32 floats of s_var per z sub-block.
__global__ void fused_residual_forward_kernel6(floatX* residual, floatX* normed, floatX* mean, floatX* rstd,
                                               const floatX* inp1, const floatX* inp2,
                                               const floatX* weight, const floatX* bias,
                                               int N, int C) {
    constexpr const int WarpSize = 32;
    assert(blockDim.x == WarpSize);

    // load weights and biases into shared memory
    // do this before we allow any threads to exit!
    extern __shared__ char params[];
    // load128/store128 sometimes generated multiple instructions when the types here were floatX*, so
    // let's keep everything as x128
    // weights and biases are shared among all tokens
    x128* s_weight = reinterpret_cast<x128*>(params);
    x128* s_bias = reinterpret_cast<x128*>(params + C * sizeof(floatX));
    // residual output (input to layernorm) is independent for each sub-block indicated by threadIdx.z
    x128* s_res = reinterpret_cast<x128*>(params + (2 + threadIdx.z) * C * sizeof(floatX));
    // similarly, each sub-block needs its own reduction buffers
    float* s_mean = reinterpret_cast<float*>(params + (2 + blockDim.z) * C * sizeof(floatX) + threadIdx.z * 32 * sizeof(float));
    float* s_var = reinterpret_cast<float*>(params + (2 + blockDim.z) * C * sizeof(floatX) + 32 * sizeof(float) * (blockDim.z + threadIdx.z));

    int cidx = (threadIdx.x + WarpSize * threadIdx.y) * x128::size;
    int step = blockDim.y * WarpSize * x128::size;

    for(int c = cidx; c < C; c += step) {
        s_weight[c / x128::size] = load128(weight + c);
        s_bias[c / x128::size] = load128(bias + c);
    }
    // the block-level reductions will cause sync before the first time we read these
    // => no syncthreads needed here

    // loop over all tokens. The loop body contains __syncthreads(), so the trip count must be the same
    // for every thread of the block: iterate over a block-uniform base index, and let a z sub-block whose
    // token falls past N in the final iteration still take part in the barriers while skipping all
    // global-memory loads and stores.
    const int tokens_per_step = gridDim.x * blockDim.z;
    for(int base = blockIdx.x * blockDim.z; base < N; base += tokens_per_step) {
        const int tidx = base + threadIdx.z;
        const bool active = tidx < N; // uniform within a warp (depends on threadIdx.z only)

        const float eps = 1e-5f;
        float sum = 0.0f;
        if(active) {
            // adjust pointers to current token
            floatX* residual_bt = residual + C * tidx;
            const floatX* inp1_bt = inp1 + C * tidx;
            const floatX* inp2_bt = inp2 + C * tidx;
            for (int c = cidx; c < C; c += step) {
                const x128 in1 = load128cs(inp1_bt + c);
                const x128 in2 = load128cs(inp2_bt + c);
                x128 out;
                for (int k = 0; k < x128::size; ++k) {
                    out[k] = (float) in1[k] + (float) in2[k];
                    sum += (float) out[k];
                }
                store128cs(residual_bt + c, out);
                s_res[c / x128::size] = out;
            }
        }

        sum = warpReduceSum(sum);
        if(threadIdx.x == 0) {
            s_mean[threadIdx.y] = sum;
        }
        __syncthreads();
        float m = warpReduceSum(threadIdx.x < blockDim.y ? s_mean[threadIdx.x] : 0.f) / C;
        // normally, we'd syncthread here to make sure that no warp is already at the next
        // iteration of the loop, messing with s_mean. The fact that we interleave s_mean and s_var means
        // we don't need these additional syncs.
        float v = 0.f;
        if(active) {
            for (int c = cidx; c < C; c += step) {
                const x128 res = s_res[c / x128::size];
                for (int k = 0; k < x128::size; ++k) {
                    v += ((float) res[k] - m) * ((float) res[k] - m);
                }
            }
        }

        v = warpReduceSum(v);
        if(threadIdx.x == 0) {
            s_var[threadIdx.y] = v;
        }
        __syncthreads();
        v = warpReduceSum(threadIdx.x < blockDim.y ? s_var[threadIdx.x] : 0.f) / C;
        float s = rsqrtf(v + eps);

        if(active) {
            floatX* normed_bt = normed + C * tidx;
            for (int c = cidx; c < C; c += step) {
                const x128 res = s_res[c / x128::size];
                const x128 w = s_weight[c / x128::size];
                const x128 b = s_bias[c / x128::size];
                x128 out;
                for (int k = 0; k < x128::size; ++k) {
                    float n = s * ((float) res[k] - m); // normalized output
                    float o = n * (float) w[k] + (float) b[k]; // scale and shift it
                    out[k] = o;
                }

                store128(normed_bt + c, out);
            }
            // cache the mean and rstd for the backward pass later
            if (threadIdx.x == 0 && threadIdx.y == 0) {
                mean[tidx] = m;
                rstd[tidx] = s;
            }
        }
    }
}

// ----------------------------------------------------------------------------
// kernel launcher

// unfused baseline: residual add, then a separate layernorm kernel
void fused_residual_forward1(floatX* residual, floatX* normed, floatX* mean, floatX* rstd,
                             const floatX* inp1, const floatX* inp2,
                             const floatX* weight, const floatX* bias,
                             int N, int C, const int block_size) {
    const int grid_size_resid = ceil_div(N * C, block_size);
    residual_forward_kernel1<<<grid_size_resid, block_size>>>(residual, inp1, inp2, N*C);
    cudaCheck(cudaGetLastError());
    const int grid_size_ln = ceil_div(N, block_size);
    layernorm_forward_kernel1<<<grid_size_ln, block_size>>>(normed, mean, rstd, residual, weight, bias, N, C);
    cudaCheck(cudaGetLastError());
}

void fused_residual_forward2(floatX* residual, floatX* normed, floatX* mean, floatX* rstd,
                             const floatX* inp1, const floatX* inp2,
                             const floatX* weight, const floatX* bias,
                             int N, int C, const int block_size) {
    const int grid_size = ceil_div(N, (int)(block_size));
    fused_residual_forward2<<<grid_size, block_size>>>(residual, normed, mean, rstd, inp1, inp2, weight, bias, N, C);
    cudaCheck(cudaGetLastError());
}

// one warp per token: blockDim = (32, block_size / 32)
void fused_residual_forward3(floatX* residual, floatX* normed, floatX* mean, floatX* rstd,
                             const floatX* inp1, const floatX* inp2,
                             const floatX* weight, const floatX* bias,
                             int N, int C, const int block_size) {
    int block_y = block_size / 32;
    const int grid_size = ceil_div(N, block_y);
    fused_residual_forward3<<<grid_size, dim3(32, block_y)>>>(residual, normed, mean, rstd, inp1, inp2, weight, bias, N, C);
    cudaCheck(cudaGetLastError());
}

void fused_residual_forward4(floatX* residual, floatX* normed, floatX* mean, floatX* rstd,
                             const floatX* inp1, const floatX* inp2,
                             const floatX* weight, const floatX* bias,
                             int N, int C, const int block_size) {
    int block_y = block_size / 32;
    const int grid_size = ceil_div(N, block_y);
    fused_residual_forward_kernel4<<<grid_size, dim3(32, block_y)>>>(residual, normed, mean, rstd, inp1, inp2, weight, bias, N, C);
    cudaCheck(cudaGetLastError());
}

void fused_residual_forward5(floatX* residual, floatX* normed, floatX* mean, floatX* rstd,
                             const floatX* inp1, const floatX* inp2,
                             const floatX* weight, const floatX* bias,
                             int N, int C, const int block_size) {
    int block_y = block_size / 32;
    const int grid_size = ceil_div(N, block_y);
    // weight + bias + one residual row per warp (see kernel5's shared memory layout)
    size_t smem = (2 + block_y) * C * sizeof(floatX);

    // in order to use more than 48 KiB of smem, need to call cudaFuncSetAttribute
    // this may fail, in which case we fall back to the smem free implementation.
    cudaCheck(cudaGetLastError());
    auto status = cudaFuncSetAttribute(fused_residual_forward_kernel5, cudaFuncAttributeMaxDynamicSharedMemorySize, smem);
    cudaGetLastError(); // clear the sticky error a failed attribute call would leave behind
    if(status == cudaSuccess) {
        fused_residual_forward_kernel5<<<grid_size, dim3(32, block_y), smem>>>(residual, normed, mean, rstd, inp1, inp2,
                                                                               weight, bias, N, C);
    } else {
        fused_residual_forward_kernel4<<<grid_size, dim3(32, block_y)>>>(residual, normed, mean, rstd, inp1, inp2,
                                                                         weight, bias, N, C);
    }
    cudaCheck(cudaGetLastError());
}

void fused_residual_forward6(floatX* residual, floatX* normed, floatX* mean, floatX* rstd,
                             const floatX* inp1, const floatX* inp2,
                             const floatX* weight, const floatX* bias,
                             int N, int C, const int block_size) {
    int warps_per_token = max(1, C / x128::size / 32);
    int total_warps = block_size / 32;
    int block_z = max(1, total_warps / warps_per_token);
    int block_y = max(1, total_warps / block_z);
    // weight + bias + one residual row per z sub-block, plus 2 * 32 floats of reduction buffers per sub-block
    size_t smem = (2 + block_z) * C * sizeof(floatX) + 64 * sizeof(float) * block_z;

    // in order to use more than 48 KiB of smem, need to call cudaFuncSetAttribute
    // this may fail, in which case we fall back to the smem free implementation.
    cudaCheck(cudaGetLastError());
    auto status = cudaFuncSetAttribute(fused_residual_forward_kernel6, cudaFuncAttributeMaxDynamicSharedMemorySize, smem);
    cudaGetLastError(); // clear the sticky error a failed attribute call would leave behind
    if(status == cudaSuccess) {
        // persistent kernel: launch enough blocks to fill the device; kernel6 loops over the tokens
        const int num_blocks = max(1, cuda_threads_per_SM * cuda_num_SMs / block_size);
        fused_residual_forward_kernel6<<<num_blocks, dim3(32, block_y, block_z), smem>>>(residual, normed, mean, rstd, inp1, inp2,
                                                                                          weight, bias, N, C);
    } else {
        const int grid_size = ceil_div(N, total_warps);
        fused_residual_forward_kernel4<<<grid_size, dim3(32, total_warps)>>>(residual, normed, mean, rstd, inp1, inp2,
                                                                             weight, bias, N, C);
    }
    cudaCheck(cudaGetLastError());
}

// kernel version dispatch
void fused_residual_forward(int kernel_num, floatX* residual, floatX* normed, floatX* mean, floatX* rstd,
                            const floatX* inp1, const floatX* inp2,
                            const floatX* weight, const floatX* bias,
                            int N, int C, const int block_size) {
    switch (kernel_num) {
        case 1:
            fused_residual_forward1(residual, normed, mean, rstd, inp1, inp2, weight, bias, N, C, block_size);
            break;
        case 2:
            fused_residual_forward2(residual, normed, mean, rstd, inp1, inp2, weight, bias, N, C, block_size);
            break;
        case 3:
            fused_residual_forward3(residual, normed, mean, rstd, inp1, inp2, weight, bias, N, C, block_size);
            break;
        case 4:
            fused_residual_forward4(residual, normed, mean, rstd, inp1, inp2, weight, bias, N, C, block_size);
            break;
        case 5:
            fused_residual_forward5(residual, normed, mean, rstd, inp1, inp2, weight, bias, N, C, block_size);
            break;
        case 6:
            fused_residual_forward6(residual, normed, mean, rstd, inp1, inp2, weight, bias, N, C, block_size);
            break;
        default:
            printf("Invalid kernel number\n");
            exit(1);
    }
}

// ----------------------------------------------------------------------------

int main(int argc, const char **argv) {
    setup_main();

    int B = 8;
    int T = 1024;
    int C = 768;

    // read kernel_num from command line
    int kernel_num = 1;
    if (argc > 1) {
        kernel_num = atoi(argv[1]);
    }
    printf("Using kernel %d\n", kernel_num);

    // create host memory of random numbers
    float* residual = (float*)malloc(B * T * C * sizeof(float));
    float* normed = (float*)malloc(B * T * C * sizeof(float));
    float* inp1 = make_random_float(B * T * C);
    float* inp2 = make_random_float(B * T * C);
    float* mean = (float*)malloc(B * T * sizeof(float));
    float* rstd = (float*)malloc(B * T * sizeof(float));
    float* weight = make_random_float(C);
    float* bias = make_random_float(C);

    // move to GPU (all device buffers are floatX, so size them with sizeof(floatX))
    floatX* d_residual;
    floatX* d_normed;
    floatX* d_inp1;
    floatX* d_inp2;
    floatX* d_mean;
    floatX* d_rstd;
    floatX* d_weight;
    floatX* d_bias;
    cudaCheck(cudaMalloc(&d_residual, B * T * C * sizeof(floatX)));
    cudaCheck(cudaMalloc(&d_normed, B * T * C * sizeof(floatX)));
    cudaCheck(cudaMalloc(&d_inp1, B * T * C * sizeof(floatX)));
    cudaCheck(cudaMalloc(&d_inp2, B * T * C * sizeof(floatX)));
    cudaCheck(cudaMalloc(&d_mean, B * T * sizeof(floatX)));
    cudaCheck(cudaMalloc(&d_rstd, B * T * sizeof(floatX)));
    cudaCheck(cudaMalloc(&d_weight, C * sizeof(floatX)));
    cudaCheck(cudaMalloc(&d_bias, C * sizeof(floatX)));
    cudaCheck(memcpy_convert(d_inp1, inp1, B * T * C));
    cudaCheck(memcpy_convert(d_inp2, inp2, B * T * C));
    cudaCheck(memcpy_convert(d_weight, weight, C));
    cudaCheck(memcpy_convert(d_bias, bias, C));

    // first check the correctness of the kernel
    residual_forward_cpu(residual, inp1, inp2, B * T * C);
    layernorm_forward_cpu(normed, mean, rstd, residual, weight, bias, B, T, C);

    // time the kernel at different block sizes
    int block_sizes[] = {32, 64, 128, 256, 512, 1024};

    for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) {
        int block_size = block_sizes[j];
        printf("Checking block size %d.\n", block_size);
        cudaCheck(cudaMemset(d_residual, 0, B * T * C * sizeof(floatX)));
        fused_residual_forward(kernel_num, d_residual, d_normed, d_mean, d_rstd, d_inp1, d_inp2, d_weight, d_bias,
                               B*T, C, block_size);
        float tol = std::is_same_v<floatX, float> ? 1e-5 : 5e-2;
        validate_result(d_residual, residual, "residual", B * T * C, tol);
        validate_result(d_mean, mean, "mean", B * T, tol);
        validate_result(d_rstd, rstd, "rstd", B * T, tol);
        validate_result(d_normed, normed, "normed", B * T * C, tol);
    }

    printf("All results match. Starting benchmarks.\n\n");

    for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) {
        int block_size = block_sizes[j];

        int repeat_times = 1000;
        float elapsed_time = benchmark_kernel(repeat_times, fused_residual_forward, kernel_num,
                                              d_residual, d_normed, d_mean, d_rstd, d_inp1, d_inp2, d_weight, d_bias,
                                              B*T, C, block_size
                                              );

        // napkin math: estimate the memory bandwidth achieved
        // for each (B,T,C) output element, we do 2 reads and 2 writes, plus 2 BT writes for mean/rstd
        // and e.g. A100 40GB PCIe is advertised at 1,555GB/s
        long memory_ops = B * T * (C * 4 + 2) * sizeof(floatX);
        float memory_bandwidth = memory_ops / elapsed_time / 1e6;
        float toks_per_msec = B * T / elapsed_time / 1e3;

        printf("block_size %4d | time %.4f ms | bandwidth %.2f GB/s | elements: %.2f ktok/ms\n",
               block_size, elapsed_time, memory_bandwidth, toks_per_msec);
    }

    // free memory
    free(residual);
    free(normed);
    free(mean);
    free(rstd);
    free(weight);
    free(bias);
    free(inp1);
    free(inp2);
    cudaCheck(cudaFree(d_residual));
    cudaCheck(cudaFree(d_normed));
    cudaCheck(cudaFree(d_mean));
    cudaCheck(cudaFree(d_rstd));
    cudaCheck(cudaFree(d_weight));
    cudaCheck(cudaFree(d_bias));
    cudaCheck(cudaFree(d_inp1));
    cudaCheck(cudaFree(d_inp2));

    return 0;
}
================================================ FILE: dev/cuda/gelu_backward.cu ================================================ /* Kernels for gelu backward pass.
Compile example:
nvcc -O3 --use_fast_math -lcublas -lcublasLt gelu_backward.cu -o gelu_backward

If encountering "error: identifier "M_PI" is undefined", add the following lines to the top of the file:

#define _USE_MATH_DEFINES
#include <math.h>  OR  #include <cmath>

version 1 is naive port from CPU code to kernel
./gelu_backward 1

version 2 uses the Packed128 data structure
./gelu_backward 2
*/

#include <stdio.h>
#include <stdlib.h>
#include <cuda_runtime.h>

#define ENABLE_BF16
#include "common.h"

// ----------------------------------------------------------------------------
// CPU code reference

#define GELU_SCALING_FACTOR sqrtf(2.0f / M_PI)

// Reference backward pass of the tanh-approximated GELU:
// dinp[i] = d/dx gelu(x) evaluated at x = inp[i], times dout[i].
// The product is rounded through floatX before being stored as float
// (presumably so the reference carries the same precision as the GPU output).
void gelu_backward_cpu(float* dinp, const float* inp, const float* dout, const int N) {
    for (int i = 0; i < N; i++) {
        float x = inp[i];
        float cube = 0.044715f * x * x * x;
        float tanh_arg = GELU_SCALING_FACTOR * (x + cube);
        float tanh_out = tanhf(tanh_arg);
        float coshf_out = coshf(tanh_arg);
        float sech_out = 1.0f / (coshf_out * coshf_out);
        float local_grad = 0.5f * (1.0f + tanh_out) + x * 0.5f * sech_out * GELU_SCALING_FACTOR * (1.0f + 3.0f * 0.044715f * x * x);
        dinp[i] = (floatX)(local_grad * (float)dout[i]);
    }
}

// ----------------------------------------------------------------------------
// GPU kernels

// elementwise ops are nice and ez
// one thread per element; threads with i >= N do nothing
__global__ void gelu_backward1(floatX* dinp, const floatX* inp, const floatX* dout, int N) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < N) {
        float x = (float)inp[i];
        float cube = 0.044715f * x * x * x;
        float tanh_arg = GELU_SCALING_FACTOR * (x + cube);
        float tanh_out = tanhf(tanh_arg);
        float coshf_out = coshf(tanh_arg);
        float sech_out = 1.0f / (coshf_out * coshf_out);
        float local_grad = 0.5f * (1.0f + tanh_out) + x * 0.5f * sech_out * GELU_SCALING_FACTOR * (1.0f + 3.0f * 0.044715f * x * x);
        dinp[i] = (floatX)(local_grad * (float)dout[i]);
    }
}

// one thread per x128 pack of x128::size elements.
// assumes N is a multiple of x128::size (true for the B*T*C used in main); otherwise the final
// pack would read and write past element N-1.
__global__ void gelu_backward2(floatX* dinp, const floatX* inp, const floatX* dout, const int N) {
    int i = (blockIdx.x * blockDim.x + threadIdx.x) * x128::size;
    if (i < N) {
        x128 packed_dinp;
        x128 packed_inp = load128cs(inp + i);
        x128 packed_dout = load128cs(dout + i);
        for (int k = 0; k < packed_inp.size; ++k) {
            float x = (float)packed_inp[k];
            float cube = 0.044715f * x * x * x;
            float tanh_arg = GELU_SCALING_FACTOR * (x + cube);
            float tanh_out = tanhf(tanh_arg);
            float coshf_out = coshf(tanh_arg);
            float sech_out = 1.0f / (coshf_out * coshf_out);
            float local_grad = 0.5f * (1.0f + tanh_out) + x * 0.5f * sech_out * GELU_SCALING_FACTOR * (1.0f + 3.0f * 0.044715f * x * x);
            packed_dinp[k] = (floatX)(local_grad * (float)packed_dout[k]);
        }

        store128(dinp + i, packed_dinp);
    }
}

// ----------------------------------------------------------------------------
// kernel launcher

void gelu_backward1(floatX* dinp, const floatX* inp, const floatX* dout, int N, const int block_size) {
    const int grid_size = ceil_div(N, block_size);
    gelu_backward1<<<grid_size, block_size>>>(dinp, inp, dout, N);
    cudaCheck(cudaGetLastError());
}

void gelu_backward2(floatX* dinp, const floatX* inp, const floatX* dout, int N, const int block_size) {
    // each thread handles x128::size elements
    const int grid_size = ceil_div(N, block_size * x128::size);
    gelu_backward2<<<grid_size, block_size>>>(dinp, inp, dout, N);
    cudaCheck(cudaGetLastError());
}

// kernel version dispatch
void gelu_backward(int kernel_num,
                   floatX* dinp,
                   const floatX* inp,
                   const floatX* dout,
                   int B, int T, int C,
                   int block_size) {
    switch (kernel_num) {
        case 1:
            gelu_backward1(dinp, inp, dout, B * T * C, block_size);
            break;
        case 2:
            gelu_backward2(dinp, inp, dout, B * T * C, block_size);
            break;
        default:
            printf("Invalid kernel number\n");
            exit(1);
    }
}

// ----------------------------------------------------------------------------

int main(int argc, char **argv) {
    setup_main();

    int B = 8;
    int T = 1024;
    int C = 768;

    // create host memory of random numbers
    float* dinp = (float*)malloc(B * T * C * sizeof(float));
    float* inp = make_random_float(B * T * C);
    float* dout = make_random_float(B * T * C);

    // read kernel_num from command line
    int kernel_num = 1;
    if (argc > 1) {
        kernel_num = atoi(argv[1]);
    }
    printf("Using kernel %d\n", kernel_num);

    // first check the correctness of the kernel
    gelu_backward_cpu(dinp, inp, dout, B * T * C);

    // move to GPU
    floatX* d_dinp;
    floatX* d_inp;
    floatX* d_dout;
    cudaCheck(cudaMalloc(&d_dinp, B * T * C * sizeof(floatX)));
    cudaCheck(cudaMalloc(&d_inp, B * T * C * sizeof(floatX)));
    cudaCheck(cudaMalloc(&d_dout, B * T * C * sizeof(floatX)));

    cudaCheck(memcpy_convert(d_inp, inp, B * T * C));
    cudaCheck(memcpy_convert(d_dout, dout, B * T * C));

    // time the kernel at different block sizes
    int block_sizes[] = {32, 64, 128, 256, 512, 1024};
    for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) {
        int block_size = block_sizes[j];
        printf("Checking block size %d.\n", block_size);
        gelu_backward(kernel_num, d_dinp, d_inp, d_dout, B, T, C, block_size);
#if !defined(ENABLE_BF16) && !defined(ENABLE_FP16)
        float tol = 1e-5;
#else
        float tol = 1e-2f;
#endif
        validate_result(d_dinp, dinp, "dinp", B * T * C, tol);
    }

    printf("All results match. Starting benchmarks.\n\n");

    for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) {
        int block_size = block_sizes[j];

        int repeat_times = 1000;
        float elapsed_time = benchmark_kernel(repeat_times, gelu_backward,
                                              kernel_num, d_dinp, d_inp, d_dout,
                                              B, T, C, block_size);

        // napkin math: estimate the memory bandwidth achieved
        // for each (B,T,C) element we do 2 reads (inp, dout) and 1 write (dinp), sizeof(floatX) bytes each
        // and e.g. A100 40GB PCIe is advertised at 1,555GB/s
        long memory_ops = (long)B * T * C * 3 * (long)sizeof(floatX);
        float memory_bandwidth = memory_ops / elapsed_time / 1e6;

        printf("block_size %4d | time %.4f ms | bandwidth %.2f GB/s\n", block_size, elapsed_time, memory_bandwidth);
    }

    // free memory
    free(dinp);
    free(inp);
    free(dout);

    cudaCheck(cudaFree(d_dinp));
    cudaCheck(cudaFree(d_inp));
    cudaCheck(cudaFree(d_dout));
    return 0;
}
================================================ FILE: dev/cuda/gelu_forward.cu ================================================ /* Kernels for gelu forward pass.
Compile example: nvcc -O3 --use_fast_math -lcublas -lcublasLt gelu_forward.cu -o gelu_forward If encountering "error: identifier "M_PI" is undefined", add the following lines to the top of the file: #define _USE_MATH_DEFINES #include OR #include version 1 is naive CPU port ./gelu_forward 1 version 2 is bfloat16 with the Packed128 data structure ./gelu_forward 2 */ #include #include #include #define ENABLE_BF16 #include "common.h" // ---------------------------------------------------------------------------- // CPU code reference #define GELU_SCALING_FACTOR sqrtf(2.0f / M_PI) void gelu_forward_cpu(float* out, const float* inp, int N) { for (int i = 0; i < N; i++) { float x = inp[i]; float cube = 0.044715f * x * x * x; out[i] = 0.5f * x * (1.0f + tanhf(GELU_SCALING_FACTOR * (x + cube))); } } // ---------------------------------------------------------------------------- // GPU kernels // elementwise ops are nice and ez __global__ void gelu_forward_kernel1(floatX* out, const floatX* inp, int N) { int i = blockIdx.x * blockDim.x + threadIdx.x; if (i < N) { float xi = inp[i]; float cube = 0.044715f * xi * xi * xi; out[i] = 0.5f * xi * (1.0f + tanhf(GELU_SCALING_FACTOR * (xi + cube))); } } // elementwise ops are nice and ez __global__ void gelu_forward_kernel2(floatX* out, const floatX* inp, int N) { int i = (blockIdx.x * blockDim.x + threadIdx.x) * x128::size; if (i < N) { x128 packed_out; x128 packed_inp = load128cs(inp + i); // load and do not keep in cache for(int k = 0; k < packed_inp.size; ++k) { float xi = (float)packed_inp[k]; float cube = 0.044715f * xi * xi * xi; packed_out[k] = (floatX)(0.5f * xi * (1.0f + tanhf(GELU_SCALING_FACTOR * (xi + cube)))); } // store instead of storecs (without cache streaming) in case it is useful for the // data to be in the cache for the next operation after this GeLU store128(out + i, packed_out); } } // ---------------------------------------------------------------------------- // kernel launcher void gelu_forward1(floatX* 
out, const floatX* inp, int N, const int block_size) { const int grid_size = ceil_div(N, block_size); gelu_forward_kernel1<<>>(out, inp, N); cudaCheck(cudaGetLastError()); } void gelu_forward2(floatX* out, const floatX* inp, int N, const int block_size) { const int grid_size = ceil_div(N, block_size * x128::size); gelu_forward_kernel2<<>>(out, inp, N); cudaCheck(cudaGetLastError()); } // kernel version dispatch void gelu_forward(int kernel_num, floatX* out, const floatX* inp, int B, int T, int C, int block_size) { switch (kernel_num) { case 1: gelu_forward1(out, inp, B * T * C, block_size); break; case 2: gelu_forward2(out, inp, B * T * C, block_size); break; default: printf("Invalid kernel number\n"); exit(1); } } // ---------------------------------------------------------------------------- int main(int argc, const char **argv) { setup_main(); int B = 8; int T = 1024; int C = 768; // create host memory of random numbers float* out = (float*)malloc(B * T * C * sizeof(float)); float* inp = make_random_float(B * T * C); // read kernel_num from command line int kernel_num = 1; if (argc > 1) { kernel_num = atoi(argv[1]); } printf("Using kernel %d\n", kernel_num); // first check the correctness of the kernel gelu_forward_cpu(out, inp, B * T * C); // move to GPU floatX* d_out; floatX* d_inp; cudaCheck(cudaMalloc(&d_out, B * T * C * sizeof(floatX))); cudaCheck(cudaMalloc(&d_inp, B * T * C * sizeof(floatX))); cudaCheck(memcpy_convert(d_inp, inp, B * T * C)); // time the kernel at different block sizes int block_sizes[] = {32, 64, 128, 256, 512, 1024}; for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) { int block_size = block_sizes[j]; printf("Checking block size %d.\n", block_size); gelu_forward(kernel_num, d_out, d_inp, B, T, C, block_size); #if !defined(ENABLE_BF16) && !defined(ENABLE_FP16) float tol = 1e-5; #else float tol = 1e-2f; #endif validate_result(d_out, out, "out", B * T * C, tol); } printf("All results match. 
Starting benchmarks.\n\n"); for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) { int block_size = block_sizes[j]; int repeat_times = 1000; float elapsed_time = benchmark_kernel(repeat_times, gelu_forward, kernel_num, d_out, d_inp, B, T, C, block_size); // napkin math: estimate the memory bandwidth achieved // for each (B,T,C) output element, we do 1 read and 1 write, 4 bytes each // and e.g. A100 40GB PCIe is advertised at 1,555GB/s long memory_ops = B * T * C * 2 * (int)sizeof(floatX); float memory_bandwidth = memory_ops / elapsed_time / 1e6; printf("block_size %4d | time %.4f ms | bandwidth %.2f GB/s\n", block_size, elapsed_time, memory_bandwidth); } // free memory free(out); free(inp); cudaCheck(cudaFree(d_out)); cudaCheck(cudaFree(d_inp)); return 0; } ================================================ FILE: dev/cuda/global_norm.cu ================================================ /* Kernels for a global norm. Global norm in this context means that we want to calculate a single norm cooperatively using all avalailable SMs, instead of multiple norms that can be handled by separate blocks. Compile example: nvcc -O3 --use_fast_math global_norm.cu -o global_norm */ #include #include #include // turn on bf16 as default, done up here for now #define ENABLE_BF16 #include "common.h" cudaDeviceProp deviceProp; float global_norm_cpu(const float* data, size_t count) { // accumulate in double so we have an accurate numerical reference double acc = 0.0; for(size_t i = 0; i < count; ++i) { acc += (double)data[i] * (double)data[i]; } return (float)acc; } template __global__ void norm_kernel1(float* out, const T* data, size_t count) { // we want as few atomics as possible, so each block tries to do // the maximum amount of work (so no fixed chunk, but instead iterating // until we run out of data), and then we reduce inside the block // and finally have just one atomic per block. 
namespace cg = cooperative_groups; cg::thread_block block = cg::this_thread_block(); cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block); __shared__ float block_result[32]; // out will be updated atomically from all thread blocks size_t index = threadIdx.x + blockDim.x * blockIdx.x; size_t grid_width = blockDim.x * gridDim.x; float accumulator = 0.f; for(size_t i = index; i < count; i += grid_width) { accumulator += (float)data[i] * (float)data[i]; } // warp-level reduce float warp_result = cg::reduce(warp, accumulator, cg::plus{}); block_result[warp.meta_group_rank()] = warp_result; block.sync(); if(warp.meta_group_rank() == 0) { float gather = warp.thread_rank() < warp.meta_group_size() ? block_result[warp.thread_rank()] : 0.f; float block_sum = cg::reduce(warp, gather, cg::plus{}); if(warp.thread_rank() == 0) { atomicAdd(out, block_sum); } } }

// Sum of squares of data[0..count), accumulated into *out with one atomicAdd per warp.
// *out must be zeroed by the caller (main does a cudaMemset before each check).
template<class T>
__global__ void norm_kernel2(float* out, const T* data, size_t count) {
    // concrete example for an A100 GPU (108 SMs, 2048 max threads each)
    // so there are 2048 * 108 = 221,184 threads total
    // say the block_size is 512, then we would launch 432 blocks in total
    // say num_params is ~100M, each thread will process ~500 elements
    // warps reduce with warp-level reduce, we have 221,184/32 = 6,912 warps
    // and then each warp atomicAdd's to global memory, total of 6,912 atomics

    // no shared memory; but one atomic per warp instead of per block
    namespace cg = cooperative_groups;
    cg::thread_block block = cg::this_thread_block();
    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);

    // out will be updated atomically from all thread blocks
    size_t index = threadIdx.x + blockDim.x * blockIdx.x;
    size_t grid_width = blockDim.x * gridDim.x;
    float accumulator = 0.f;
    for(size_t i = index; i < count; i += grid_width) {
        accumulator += (float)data[i] * (float)data[i];
    }
    // warp-level reduce
    float warp_result = cg::reduce(warp, accumulator, cg::plus<float>{});
    // and atomic in global buffer
    if(warp.thread_rank() == 0) {
        atomicAdd(out, warp_result);
    }
}

// Sum of squares of data[0..count), reduced per block and accumulated into *out
// with one atomicAdd per block. *out must be zeroed by the caller.
template<class T>
__global__ void norm_kernel3(float* out, const T* data, size_t count) {
    size_t index = blockIdx.x * blockDim.x + threadIdx.x;
    size_t grid_width = blockDim.x * gridDim.x;
    float accumulator = 0.f;
    for(size_t i = index; i < count; i += grid_width) {
        accumulator += (float)data[i] * (float)data[i];
    }
    // block-level reduce
    float block_sum = blockReduce<warpReduceSum>(accumulator);
    if(threadIdx.x == 0) {
        atomicAdd(out, block_sum);
    }
}

// Same as kernel3 but without atomic adds: the result is deterministic despite the
// non-associativity of floating point addition. Roughly same performance as kernel3.
// Writes one partial sum per block to out[blockIdx.x]; global_norm_aggregate_kernel
// then sums those partials into out[0].
template<class T>
__global__ void norm_kernel4(float* out, const T* data, size_t count) {
    size_t index = blockIdx.x * blockDim.x + threadIdx.x;
    size_t grid_width = blockDim.x * gridDim.x;
    float accumulator = 0.f;
    for(size_t i = index; i < count; i += grid_width) {
        accumulator += (float)data[i] * (float)data[i];
    }
    // block-level reduce
    float block_sum = blockReduce<warpReduceSum>(accumulator);
    // each block accumulates its partial sum to out[blockIdx.x]
    // we want to avoid using atomic add here so we combine this kernel with the aggregate kernel call
    // that sums up the partial block sums
    if(threadIdx.x == 0) {
        out[blockIdx.x] = block_sum;
    }
}

// Single-block reduction of the `count` partial sums left in out[0..count) by norm_kernel4.
// Each thread reads at most one partial, so count must not exceed blockDim.x.
__global__ void global_norm_aggregate_kernel(float* out, size_t count) {
    size_t index = threadIdx.x;
    // grab block sums from the previous kernel, use 0. as the neutral sum element
    float block_sum = (index < count) ? out[index] : 0.f;
    float sum = blockReduce<warpReduceSum>(block_sum);
    if(threadIdx.x == 0) {
        out[0] = sum;  // out[0] ends up with the final norm squared
    }
}

// ----------------------------------------------------------------------------
// kernel launchers

template<class T>
void global_norm1(float* out, const T* values, size_t count, int block_size) {
    // launch just enough blocks to fill the grid. deliberately no DIV_CEIL.
    // having one block less than possible is a tiny performance hit, having
    // one block too many is catastrophic, since it only can start once all the other
    // blocks finish. anyway, I think cuda_threads_per_SM should be a multiple of 512
    // on all gpus, so the division really is going to be exact.
    const int grid_size = cuda_threads_per_SM * cuda_num_SMs / block_size;
    assert(grid_size > 0);  // gives a better error than letting the call below fail
    norm_kernel1<<<grid_size, block_size>>>(out, values, count);
    cudaCheck(cudaGetLastError());
}

template<class T>
void global_norm2(float* out, const T* values, size_t count, int block_size) {
    // ditto
    const int grid_size = cuda_threads_per_SM * cuda_num_SMs / block_size;
    assert(grid_size > 0);  // gives a better error than letting the call below fail
    norm_kernel2<<<grid_size, block_size>>>(out, values, count);
    cudaCheck(cudaGetLastError());
}

template<class T>
void global_norm3(float* out, const T* values, size_t count, int block_size) {
    // launch just enough blocks to fill the grid. deliberately no DIV_CEIL.
    // having one block less than possible is a tiny performance hit, having
    // one block too many is catastrophic, since it only can start once all the other
    // blocks finish. anyway, I think cuda_threads_per_SM should be a multiple of 512
    // on all gpus, so the division really is going to be exact.
    const int grid_size = deviceProp.maxThreadsPerMultiProcessor * deviceProp.multiProcessorCount / block_size;
    assert(grid_size > 0);  // gives a better error than letting the call below fail
    norm_kernel3<<<grid_size, block_size>>>(out, values, count);
    cudaCheck(cudaGetLastError());
}

template<class T>
void global_norm4(float* out, const T* values, size_t count, int block_size) {
    if (block_size <= 64) {
        block_size = 128;  // tiny blocks would need far more than 1024 blocks to fill the GPU
    }
    // launch just enough blocks to fill the grid. deliberately no DIV_CEIL.
    // having one block less than possible is a tiny performance hit, having
    // one block too many is catastrophic, since it only can start once all the other
    // blocks finish. anyway, I think cuda_threads_per_SM should be a multiple of 512
    // on all gpus, so the division really is going to be exact.
    int grid_size = deviceProp.maxThreadsPerMultiProcessor * deviceProp.multiProcessorCount / block_size;
    assert(grid_size > 0);  // gives a better error than letting the call below fail
    // The partial block sums are later reduced by a single block of 1024 threads (and
    // main allocates `out` with room for 1024 partials), so cap the grid at 1024 instead
    // of asserting: e.g. an A100 at block_size 128 would need 1728 blocks. norm_kernel4
    // is grid-stride, so fewer blocks only costs occupancy, never correctness.
    if (grid_size > 1024) {
        grid_size = 1024;
    }
    norm_kernel4<<<grid_size, block_size>>>(out, values, count);
    cudaCheck(cudaGetLastError());
    global_norm_aggregate_kernel<<<1, 1024>>>(out, grid_size);
    cudaCheck(cudaGetLastError());
}

// Dispatch to the requested global-norm implementation; unknown kernel numbers abort
// instead of silently leaving `out` untouched.
void global_norm(int kernel_num, float* out, const floatX* values, size_t count, int block_size) {
    switch (kernel_num) {
        case 1:
            return global_norm1(out, values, count, block_size);
        case 2:
            return global_norm2(out, values, count, block_size);
        case 3:
            return global_norm3(out, values, count, block_size);
        case 4:
            return global_norm4(out, values, count, block_size);
        default:
            printf("Invalid kernel number\n");
            exit(1);
    }
}

int main(int argc, const char **argv) {
    setup_main();
    cudaGetDeviceProperties(&deviceProp, 0);

    int C = 768;
    int L = 12;

    size_t num_params = (size_t)(C * 4*C + C*C) * 2 * L;

    // create host memory of random numbers
    float* inp = make_random_float(num_params);
    // scale them down
    for(size_t i = 0; i < num_params; ++i) {
        inp[i] *= 1e-3;
    }

    // read kernel_num from command line
    int kernel_num = 1;
    if (argc > 1) {
        kernel_num = atoi(argv[1]);
    }
    printf("Using kernel %d\n", kernel_num);

    // first check the correctness of the kernel
    float out = global_norm_cpu(inp, num_params);

    // move to GPU
    float* d_out;
    floatX* d_inp;
    cudaCheck(cudaMalloc(&d_out, 1024 * sizeof(float)));  // 1024 needed for kernel 4
    cudaCheck(cudaMalloc(&d_inp, num_params * sizeof(floatX)));
    cudaCheck(memcpy_convert(d_inp, inp, num_params));

    int block_sizes[] = {32, 64, 128, 256, 512, 768, 1024};
    for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) {
        int block_size = block_sizes[j];
        printf("Checking block size %d.\n", block_size);
        cudaCheck(cudaMemset(d_out, 0, sizeof(float)));
        global_norm(kernel_num, d_out, d_inp, num_params, block_size);
        validate_result(d_out, &out, "out", 1, 1e-2f);
    }

    printf("All results match. Starting benchmarks.\n\n");

    for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) {
        int block_size = block_sizes[j];

        int repeat_times = 1000;
        float elapsed_time = benchmark_kernel(repeat_times, global_norm,
                                              kernel_num, d_out, d_inp,
                                              num_params, block_size);
        // each element is read exactly once; bytes / ms / 1e6 = GB/s
        size_t memory_ops = num_params * sizeof(floatX);
        float memory_bandwidth = memory_ops / elapsed_time / 1e6;

        printf("block_size %4d | time %.4f ms | bandwidth %.2f GB/s\n", block_size, elapsed_time, memory_bandwidth);
    }

    // free memory
    free(inp);
    cudaCheck(cudaFree(d_out));
    cudaCheck(cudaFree(d_inp));
}

================================================
FILE: dev/cuda/layernorm_backward.cu
================================================
/*
Kernels for layernorm backward pass.

Compile example:
nvcc -O3 --use_fast_math -lcublas -lcublasLt layernorm_backward.cu -o layernorm_backward

version 1 is a naive port from CPU code to kernel: parallelizes over B,T, loops over C
./layernorm_backward 1

version 2 moves a lot of reduction to shared memory over global memory
./layernorm_backward 2
*/

#include <stdio.h>
#include <stdlib.h>
#include <assert.h>
#include <cuda_runtime.h>
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#define ENABLE_BF16
#include "common.h"

// ----------------------------------------------------------------------------
// CPU code reference

void layernorm_forward_cpu(float* out, float* mean, float* rstd,
                           const float* inp, const float* weight, const float* bias,
                           int B, int T, int C) {
    // reference: https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html
    // both inp and out are (B,T,C) of the activations
    // mean and rstd are (B,T) buffers, to be used later in backward pass
    // at each position (b,t) of the input, the C-dimensional vector
    // of activations gets normalized, then scaled and shifted
    float eps = 1e-5f;
    for (int b = 0; b < B; b++) {
        for (int t = 0; t < T; t++) {
            // seek to the input position inp[b,t,:]
            const float* x = inp + b * T * C + t * C;
            // calculate the mean
            float m = 0.0f;
            for (int i = 0; i < C; i++) {
                m += x[i];
            }
            m = m/C;
            // calculate the variance (without any bias correction)
            float v = 0.0f;
            for (int i = 0; i < C; i++) {
                float xshift = x[i] - m;
                v += xshift * xshift;
            }
            v = v/C;
            // calculate the rstd (reciprocal standard deviation)
            float s = 1.0f / sqrtf(v + eps);
            // seek to the output position in out[b,t,:]
            float* out_bt = out + b * T * C + t * C;
            for (int i = 0; i < C; i++) {
                float n = (s * (x[i] - m)); // normalize
                float o = n * weight[i] + bias[i]; // scale and shift
                out_bt[i] = o; // write
            }
            // cache the mean and rstd for the backward pass later
            mean[b * T + t] = m;
            rstd[b * T + t] = s;
        }
    }
}

// Reference backward pass. Gradients are ACCUMULATED (+=) into dinp, dweight and dbias.
void layernorm_backward_cpu(float* dinp, float* dweight, float* dbias,
                            const float* dout, const float* inp, const float* weight, const float* mean, const float* rstd,
                            int B, int T, int C) {
    for (int b = 0; b < B; b++) {
        for (int t = 0; t < T; t++) {
            const float* dout_bt = dout + b * T * C + t * C;
            const float* inp_bt = inp + b * T * C + t * C;
            float* dinp_bt = dinp + b * T * C + t * C;
            const float mean_bt = mean[b * T + t];
            const float rstd_bt = rstd[b * T + t];

            // first: two reduce operations
            float dnorm_mean = 0.0f;
            float dnorm_norm_mean = 0.0f;
            for (int i = 0; i < C; i++) {
                float norm_bti = (inp_bt[i] - mean_bt) * rstd_bt;
                float dnorm_i = weight[i] * dout_bt[i];
                dnorm_mean += dnorm_i;
                dnorm_norm_mean += dnorm_i * norm_bti;
            }
            dnorm_mean = dnorm_mean / C;
            dnorm_norm_mean = dnorm_norm_mean / C;

            // now iterate again and accumulate all the gradients
            for (int i = 0; i < C; i++) {
                float norm_bti = (inp_bt[i] - mean_bt) * rstd_bt;
                float dnorm_i = weight[i] * dout_bt[i];
                // gradient contribution to bias
                dbias[i] += dout_bt[i];
                // gradient contribution to weight
                dweight[i] += norm_bti * dout_bt[i];
                // gradient contribution to input
                float dval = 0.0f;
                dval += dnorm_i; // term 1
                dval -= dnorm_mean; // term 2
                dval -= norm_bti * dnorm_norm_mean; // term 3
                dval *= rstd_bt; // final scale
                dinp_bt[i] += dval;
            }
        }
    }
}

// ----------------------------------------------------------------------------
// GPU kernels

// GPU helper functions for atomicAdd on smaller than 32-bit types.
// They issue a 2-wide atomicAdd on the enclosing 4-byte word, adding zero to the
// neighbouring 16-bit half that does not belong to `addr`.
#ifdef ENABLE_BF16
__device__ void atomicAddX(__nv_bfloat16* addr, __nv_bfloat16 val) {
    uintptr_t ptr_val = reinterpret_cast<uintptr_t>(addr);
    __nv_bfloat162* ptr_bf16 = reinterpret_cast<__nv_bfloat162*>(ptr_val & ~uintptr_t(0x3));

    // Prepare the value to add, setting the other half to zero
    __nv_bfloat162 add_val = (ptr_val & 0x3) ? __halves2bfloat162(__ushort_as_bfloat16(0), val)
                                             : __halves2bfloat162(val, __ushort_as_bfloat16(0));
    atomicAdd(ptr_bf16, add_val);
}
#endif
#ifdef ENABLE_FP16
__device__ void atomicAddX(half* addr, half val) {
    uintptr_t ptr_val = reinterpret_cast<uintptr_t>(addr);
    half2* ptr_fp16 = reinterpret_cast<half2*>(ptr_val & ~uintptr_t(0x3));

    // Prepare the value to add, setting the other half to zero
    half2 add_val = (ptr_val & 0x3) ? __halves2half2(__ushort_as_half(0), val)
                                    : __halves2half2(val, __ushort_as_half(0));
    atomicAdd(ptr_fp16, add_val);
}
#endif
__device__ void atomicAddX(float* addr, float val) {
    atomicAdd(addr, val);
}

// super naive kernel that just parallelizes over B,T and loops over C
__global__ void layernorm_backward_kernel1(float* dinp, float* dweight, float* dbias,
                                           const float* dout, const float* inp, const float* weight,
                                           const float* mean, const float* rstd,
                                           int B, int T, int C) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= B*T) return;  // no block-level barriers below, so an early return is safe here
    int b = idx / T;
    int t = idx % T;

    const float* dout_bt = dout + b * T * C + t * C;
    const float* inp_bt = inp + b * T * C + t * C;
    float* dinp_bt = dinp + b * T * C + t * C;
    const float mean_bt = mean[b * T + t];
    const float rstd_bt = rstd[b * T + t];

    // first: two reduce operations
    float dnorm_mean = 0.0f;
    float dnorm_norm_mean = 0.0f;
    for (int i = 0; i < C; i++) {
        float norm_bti = (inp_bt[i] - mean_bt) * rstd_bt;
        float dnorm_i = weight[i] * dout_bt[i];
        dnorm_mean += dnorm_i;
        dnorm_norm_mean += dnorm_i * norm_bti;
    }
    dnorm_mean = dnorm_mean / C;
    dnorm_norm_mean = dnorm_norm_mean / C;

    // now iterate again and accumulate all the gradients
    for (int i = 0; i < C; i++) {
        float norm_bti = (inp_bt[i] - mean_bt) * rstd_bt;
        float dnorm_i = weight[i] * dout_bt[i];
        // gradient contribution to bias
        atomicAdd(&dbias[i], dout_bt[i]);
        // gradient contribution to weight
        atomicAdd(&dweight[i], norm_bti * dout_bt[i]);
        // gradient contribution to input
        float dval = 0.0f;
        dval += dnorm_i; // term 1
        dval -= dnorm_mean; // term 2
        dval -= norm_bti * dnorm_norm_mean; // term 3
        dval *= rstd_bt; // final scale
        dinp_bt[i] += dval;
    }
}

// uses shared memory instead for the reduces: one warp per (b,t) row; per-block partial
// dbias/dweight are accumulated in shared memory, then atomically added (FP32) to
// dbias_tmp/dweight_tmp, which copy_to_dweight_dbias later converts into dbias/dweight.
template <typename Tdinp, typename Tparams, typename Tdout, typename Trest>
__global__ void layernorm_backward_kernel2(Tdinp* dinp, Tparams* dweight, Tparams* dbias,
                                           const Tdout* dout, const Trest* inp, const Tparams* weight,
                                           const Trest* mean, const Trest* rstd,
                                           int B, int T, int C, float* dweight_tmp, float* dbias_tmp) {
    extern __shared__ float shared[]; // size = 2 * C
    namespace cg = cooperative_groups;
    cg::thread_block block = cg::this_thread_block();
    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
    int idx = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank();
    int N = B * T;

    // the first half of shared memory is bias, second is weight
    float* dbias_shared = shared;
    float* dweight_shared = shared + C;

    // init shared memory to zero.
    // NOTE: out-of-range warps must NOT return early: every thread has to take part in this
    // strided init and the final strided flush, and has to reach both __syncthreads() calls
    // (a barrier skipped by some threads of the block is undefined behavior).
    #pragma unroll
    for(int i = threadIdx.x; i < C; i+= blockDim.x){
       dbias_shared[i] = 0.0f;
       dweight_shared[i] = 0.0f;
    }
    __syncthreads();

    if (idx < N) { // warp guard: idx is uniform across the warp, so the warp collectives below are safe
        int b = idx / T;
        int t = idx % T;

        const Tdout* dout_bt = dout + b * T * C + t * C;
        const Trest* inp_bt = inp + b * T * C + t * C;
        Tdinp* dinp_bt = dinp + b * T * C + t * C;
        const float mean_bt = (float)mean[b * T + t];
        const float rstd_bt = (float)rstd[b * T + t];

        // first: two reduce operations
        float dnorm_mean = 0.0f;
        float dnorm_norm_mean = 0.0f;
        for (int i = warp.thread_rank(); i < C; i  += warp.size()) {
            float norm_bti = ((float)inp_bt[i] - mean_bt) * rstd_bt;
            float dnorm_i = (float)weight[i] * (float)dout_bt[i];
            dnorm_mean += dnorm_i;
            dnorm_norm_mean += dnorm_i * norm_bti;
        }
        dnorm_mean = cg::reduce(warp, dnorm_mean, cg::plus<float>{});
        dnorm_norm_mean = cg::reduce(warp, dnorm_norm_mean, cg::plus<float>{});
        dnorm_mean = dnorm_mean / C;
        dnorm_norm_mean = dnorm_norm_mean / C;

        // now iterate again and accumulate all the gradients
        for (int i = warp.thread_rank(); i < C; i += warp.size()) {
            float norm_bti = ((float)inp_bt[i] - mean_bt) * rstd_bt;
            float dnorm_i = (float)weight[i] * (float)dout_bt[i];
            // gradient contribution to bias
            atomicAdd(&dbias_shared[i], (float)dout_bt[i]);
            // gradient contribution to weight
            atomicAdd(&dweight_shared[i], norm_bti * (float)dout_bt[i]);
            // gradient contribution to input
            float dval = 0.0f;
            dval += dnorm_i; // term 1
            dval -= dnorm_mean; // term 2
            dval -= norm_bti * dnorm_norm_mean; // term 3
            dval *= rstd_bt; // final scale
            dinp_bt[i] = (Tdinp)((float)dinp_bt[i] + dval);
        }
    }
    __syncthreads();

    // write to global memory
    for(int i = threadIdx.x; i < C; i+= blockDim.x) {
        atomicAdd(&dbias_tmp[i], dbias_shared[i]);
        atomicAdd(&dweight_tmp[i], dweight_shared[i]);
    }
}

// Converts the FP32 accumulators produced by kernel2 into the parameter-gradient type.
template <typename Tparams>
__global__ void copy_to_dweight_dbias(int C, Tparams* dbias, Tparams* dweight, float* dbias_tmp, float* dweight_tmp) {
    for (int i = threadIdx.x + blockDim.x * blockIdx.x; i < C; i += blockDim.x * gridDim.x) {
        dbias[i] = (Tparams)dbias_tmp[i];
        dweight[i] = (Tparams)dweight_tmp[i];
    }
}

// kernel2 is 1 threadblock for all Cs on 32 BTs (assuming threadblock size of 1024 threads = 32 warps)
// To minimise the amount of atomicAdds, we will aim for 1 threadblock per SM, processing (total BTs / threadblocks) BTs
template <typename Tdinp, typename Tparams, typename Tdout, typename Trest>
__global__ void layernorm_backward_kernel3(Tdinp* dinp, Tparams* dweight, Tparams* dbias,
                                           const Tdout* dout, const Trest* inp, const Tparams* weight,
                                           const Trest* mean, const Trest* rstd,
                                           int B, int T, int C) {
    extern __shared__ float shared[]; // size = 2 * C
    namespace cg = cooperative_groups;
    cg::thread_block block = cg::this_thread_block();
    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
    int base_idx = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank();

    // the first half of shared memory is bias, second is weight
    float* dbias_shared = shared;
    float* dweight_shared = shared + C;

    // init shared memory to zero
    #pragma unroll 4
    for(int i = threadIdx.x; i < C; i+= blockDim.x){
       dbias_shared[i] = 0.0f;
       dweight_shared[i] = 0.0f;
    }
    __syncthreads();

    int warps_in_grid = gridDim.x * warp.meta_group_size();
    for (int idx = base_idx; idx < B * T; idx += warps_in_grid) {
        int b = idx / T;
        int t = idx % T;

        const Tdout* dout_bt = dout + b * T * C + t * C;
        const Trest* inp_bt = inp + b * T * C + t * C;
        Tdinp* dinp_bt = dinp + b * T * C + t * C;
        const float mean_bt = (float)mean[b * T + t];
        const float rstd_bt = (float)rstd[b * T + t];

        // first: two reduce operations
        float dnorm_mean = 0.0f;
        float dnorm_norm_mean = 0.0f;
        for (int i = warp.thread_rank(); i < C; i  += warp.size()) {
            float norm_bti = ((float)inp_bt[i] - mean_bt) * rstd_bt;
            float dnorm_i = (float)weight[i] * (float)dout_bt[i];
            dnorm_mean += dnorm_i;
            dnorm_norm_mean += dnorm_i * norm_bti;
        }
        dnorm_mean = cg::reduce(warp, dnorm_mean, cg::plus<float>{});
        dnorm_norm_mean = cg::reduce(warp, dnorm_norm_mean, cg::plus<float>{});
        dnorm_mean = dnorm_mean / C;
        dnorm_norm_mean = dnorm_norm_mean / C;

        // now iterate again and accumulate all the gradients
        for (int i = warp.thread_rank(); i < C; i += warp.size()) {
            float dout_i = (float)__ldcs(&dout_bt[i]);
            float norm_bti = ((float)__ldcs(&inp_bt[i]) - mean_bt) * rstd_bt;
            float dnorm_i = (float)weight[i] * dout_i;
            // gradient contribution to bias
            atomicAdd(&dbias_shared[i], dout_i);
            // gradient contribution to weight
            atomicAdd(&dweight_shared[i], norm_bti * dout_i);
            // gradient contribution to input
            float dval = 0.0f;
            dval += dnorm_i; // term 1
            dval -= dnorm_mean; // term 2
            dval -= norm_bti * dnorm_norm_mean; // term 3
            dval *= rstd_bt; // final scale
            dinp_bt[i] = (Tdinp)((float)dinp_bt[i] + dval);
        }
    }
    __syncthreads();
    for(int i = threadIdx.x; i < C; i+= blockDim.x) {
        atomicAddX(&dbias[i], (Tparams)dbias_shared[i]);
        atomicAddX(&dweight[i], (Tparams)dweight_shared[i]);
    }
}

// atomicCAS version of kernel3.
// NOTE: the final write reinterprets dbias/dweight as __nv_bfloat162 regardless of Tparams,
// so it is only meaningful for bf16 parameters, and it walks C/2 pairs (an odd C would
// leave the last channel unwritten).
template <typename Tdinp, typename Tparams, typename Tdout, typename Trest>
__global__ void layernorm_backward_kernel4(Tdinp* dinp, Tparams* dweight, Tparams* dbias,
                                           const Tdout* dout, const Trest* inp, const Tparams* weight,
                                           const Trest* mean, const Trest* rstd,
                                           int B, int T, int C) {
    extern __shared__ float shared[]; // size = 2 * C
    namespace cg = cooperative_groups;
    cg::thread_block block = cg::this_thread_block();
    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
    int base_idx = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank();

    // the first half of shared memory is bias, second is weight
    float* dbias_shared = shared;
    float* dweight_shared = shared + C;

    // init shared memory to zero
    #pragma unroll 4
    for(int i = threadIdx.x; i < C; i+= blockDim.x){
       dbias_shared[i] = 0.0f;
       dweight_shared[i] = 0.0f;
    }
    __syncthreads();

    int warps_in_grid = gridDim.x * warp.meta_group_size();
    for (int idx = base_idx; idx < B * T; idx += warps_in_grid) {
        int b = idx / T;
        int t = idx % T;

        const Tdout* dout_bt = dout + b * T * C + t * C;
        const Trest* inp_bt = inp + b * T * C + t * C;
        Tdinp* dinp_bt = dinp + b * T * C + t * C;
        const float mean_bt = (float)mean[b * T + t];
        const float rstd_bt = (float)rstd[b * T + t];

        // first: two reduce operations
        float dnorm_mean = 0.0f;
        float dnorm_norm_mean = 0.0f;
        for (int i = warp.thread_rank(); i < C; i  += warp.size()) {
            float norm_bti = ((float)inp_bt[i] - mean_bt) * rstd_bt;
            float dnorm_i = (float)weight[i] * (float)dout_bt[i];
            dnorm_mean += dnorm_i;
            dnorm_norm_mean += dnorm_i * norm_bti;
        }
        dnorm_mean = cg::reduce(warp, dnorm_mean, cg::plus<float>{});
        dnorm_norm_mean = cg::reduce(warp, dnorm_norm_mean, cg::plus<float>{});
        dnorm_mean = dnorm_mean / C;
        dnorm_norm_mean = dnorm_norm_mean / C;

        // now iterate again and accumulate all the gradients
        for (int i = warp.thread_rank(); i < C; i += warp.size()) {
            float dout_i = (float)__ldcs(&dout_bt[i]);
            float norm_bti = ((float)__ldcs(&inp_bt[i]) - mean_bt) * rstd_bt;
            float dnorm_i = (float)weight[i] * dout_i;
            // gradient contribution to bias
            atomicAdd(&dbias_shared[i], dout_i);
            // gradient contribution to weight
            atomicAdd(&dweight_shared[i], norm_bti * dout_i);
            // gradient contribution to input
            float dval = 0.0f;
            dval += dnorm_i; // term 1
            dval -= dnorm_mean; // term 2
            dval -= norm_bti * dnorm_norm_mean; // term 3
            dval *= rstd_bt; // final scale
            dinp_bt[i] = (Tdinp)((float)dinp_bt[i] + dval);
        }
    }
    __syncthreads();

    __nv_bfloat162* dbiasVec2 = reinterpret_cast<__nv_bfloat162*>(dbias);
    __nv_bfloat162* dweightVec2 = reinterpret_cast<__nv_bfloat162*>(dweight);

    // write to global memory
    for(int i = threadIdx.x; i < C/2; i+= blockDim.x) {
        __nv_bfloat162 add_dbias = __halves2bfloat162((__nv_bfloat16)dbias_shared[i*2], (__nv_bfloat16)dbias_shared[i*2+1]);
        __nv_bfloat162 add_dweight = __halves2bfloat162((__nv_bfloat16)dweight_shared[i*2], (__nv_bfloat16)dweight_shared[i*2+1]);

        // Get the current value from L2 cache
        __nv_bfloat162 current_dbias = __ldcg(&dbiasVec2[i]);
        __nv_bfloat162 current_dweight = __ldcg(&dweightVec2[i]);

        // Add the two values
        __nv_bfloat162 new_dbias = add_dbias + current_dbias;
        __nv_bfloat162 new_dweight = add_dweight + current_dweight;

        // Write the result back to L2 cache using 32-bit integer atomic compare and exchange
        unsigned int current_dbias32b = *reinterpret_cast<unsigned int*>(&current_dbias);
        unsigned int current_dweight32b = *reinterpret_cast<unsigned int*>(&current_dweight);

        unsigned int new_dbias32b = *reinterpret_cast<unsigned int*>(&new_dbias);
        unsigned int new_dweight32b = *reinterpret_cast<unsigned int*>(&new_dweight);

        unsigned int old_dbias32b = atomicCAS((unsigned int*)&dbiasVec2[i], current_dbias32b, new_dbias32b);
        unsigned int old_dweight32b = atomicCAS((unsigned int*)&dweightVec2[i], current_dweight32b, new_dweight32b);

        // If the value has changed between read and atomic, we need to try again
        while (old_dbias32b != current_dbias32b) {
            current_dbias32b = old_dbias32b;
            new_dbias = *reinterpret_cast<__nv_bfloat162*>(&current_dbias32b) + add_dbias;
            new_dbias32b = *reinterpret_cast<unsigned int*>(&new_dbias);
            old_dbias32b = atomicCAS((unsigned int*)&dbiasVec2[i], current_dbias32b, new_dbias32b);
        }

        while (old_dweight32b != current_dweight32b) {
            current_dweight32b = old_dweight32b;
            new_dweight = *reinterpret_cast<__nv_bfloat162*>(&current_dweight32b) + add_dweight;
            new_dweight32b = *reinterpret_cast<unsigned int*>(&new_dweight);
            old_dweight32b = atomicCAS((unsigned int*)&dweightVec2[i], current_dweight32b, new_dweight32b);
        }
    }
}

// FP32 scratchpad per threadgroup, no global floating point atomics (only an atomicAdd on
// an unsigned int flag) (based on kernel3). Scratch layout as indexed below:
// [C * gridDim.x dbias partials | C * gridDim.x dweight partials | flag]. The block that
// observes the flag at gridDim.x-1 sums all partials and accumulates into dbias/dweight.
template <typename Tdinp, typename Tparams, typename Tdout, typename Trest>
__global__ void layernorm_backward_kernel5(Tdinp* dinp, Tparams* dweight, Tparams* dbias, float* scratch,
                                           const Tdout* dout, const Trest* inp, const Tparams* weight,
                                           const Trest* mean, const Trest* rstd,
                                           int B, int T, int C) {
    extern __shared__ float shared[]; // size = 2 * C + 1
    namespace cg = cooperative_groups;
    cg::thread_block block = cg::this_thread_block();
    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
    int base_idx = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank();

    // the first half of shared memory is bias, second is weight
    float* dbias_shared = shared;
    float* dweight_shared = shared + C;

    // init shared memory to zero
    #pragma unroll 4
    for(int i = threadIdx.x; i < C; i+= blockDim.x){
       dbias_shared[i] = 0.0f;
       dweight_shared[i] = 0.0f;
    }
    unsigned int *tmp_flag = (unsigned int*)(shared + C*2);
    __syncthreads();

    int warps_in_grid = gridDim.x * warp.meta_group_size();
    for (int idx = base_idx; idx < B * T; idx += warps_in_grid) {
        int b = idx / T;
        int t = idx % T;

        const Tdout* dout_bt = dout + b * T * C + t * C;
        const Trest* inp_bt = inp + b * T * C + t * C;
        Tdinp* dinp_bt = dinp + b * T * C + t * C;
        const float mean_bt = (float)mean[b * T + t];
        const float rstd_bt = (float)rstd[b * T + t];

        // first: two reduce operations
        float dnorm_mean = 0.0f;
        float dnorm_norm_mean = 0.0f;
        for (int i = warp.thread_rank(); i < C; i  += warp.size()) {
            float norm_bti = ((float)inp_bt[i] - mean_bt) * rstd_bt;
            float dnorm_i = (float)weight[i] * (float)dout_bt[i];
            dnorm_mean += dnorm_i;
            dnorm_norm_mean += dnorm_i * norm_bti;
        }
        dnorm_mean = cg::reduce(warp, dnorm_mean, cg::plus<float>{});
        dnorm_norm_mean = cg::reduce(warp, dnorm_norm_mean, cg::plus<float>{});
        dnorm_mean = dnorm_mean / C;
        dnorm_norm_mean = dnorm_norm_mean / C;

        // now iterate again and accumulate all the gradients
        for (int i = warp.thread_rank(); i < C; i += warp.size()) {
            float dout_i = (float)__ldcs(&dout_bt[i]);
            float norm_bti = ((float)__ldcs(&inp_bt[i]) - mean_bt) * rstd_bt;
            float dnorm_i = (float)weight[i] * dout_i;
            // gradient contribution to bias
            atomicAdd(&dbias_shared[i], dout_i);
            // gradient contribution to weight
            atomicAdd(&dweight_shared[i], norm_bti * dout_i);
            // gradient contribution to input
            float dval = 0.0f;
            dval += dnorm_i; // term 1
            dval -= dnorm_mean; // term 2
            dval -= norm_bti * dnorm_norm_mean; // term 3
            dval *= rstd_bt; // final scale
            dinp_bt[i] = (Tdinp)((float)dinp_bt[i] + dval);
        }
    }
    __syncthreads();

    float* scratch_dbias = scratch;
    float* scratch_dweight = scratch + C * gridDim.x;
    unsigned int* scratchFlag = (unsigned int*)(scratch + (2 * C * gridDim.x));

    for(int i = threadIdx.x; i < C; i+= blockDim.x) {
        scratch_dbias[i + C*blockIdx.x] = dbias_shared[i];
        scratch_dweight[i + C*blockIdx.x] = dweight_shared[i];
    }
    // make this block's partials visible device-wide before announcing completion
    __threadfence();
    __syncthreads();
    if (threadIdx.x == 0) {
        *tmp_flag = atomicAdd(scratchFlag, 1);
    }
    __syncthreads();
    if (*tmp_flag == gridDim.x-1) {
        // last block to finish, accumulate the scratchpad
        for (int i = threadIdx.x; i < C; i += blockDim.x) {
            float dbias_sum = 0.0f;
            float dweight_sum = 0.0f;
            #pragma unroll 8
            for (int j = 0; j < gridDim.x; j++) {
                dbias_sum += scratch_dbias[i + j*C];
                dweight_sum += scratch_dweight[i + j*C];
            }
            dbias[i] = (Tparams)((float)dbias[i] + dbias_sum);
            dweight[i] = (Tparams)((float)dweight[i] + dweight_sum);
        }
    }
}

// single FP32 scratchpad
// shared by all the threadblocks (based on kernels 3 & 5).
// Scratch layout as indexed below: [C dbias sums | C dweight sums | unsigned flag]; the
// scratch sums and the flag are accumulated with atomics, so they presumably have to be
// zeroed by the caller before each launch (TODO confirm in the launcher).
// The block that observes the flag at gridDim.x-1 adds the totals into dbias/dweight.
template <typename Tdinp, typename Tparams, typename Tdout, typename Trest>
__global__ void layernorm_backward_kernel6(Tdinp* dinp, Tparams* dweight, Tparams* dbias, float* scratch,
                                           const Tdout* dout, const Trest* inp, const Tparams* weight,
                                           const Trest* mean, const Trest* rstd,
                                           int B, int T, int C) {
    extern __shared__ float shared[]; // size = 2 * C + 1
    namespace cg = cooperative_groups;
    cg::thread_block block = cg::this_thread_block();
    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
    int base_idx = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank();

    // the first half of shared memory is bias, second is weight
    float* dbias_shared = shared;
    float* dweight_shared = shared + C;

    // init shared memory to zero
    #pragma unroll 4
    for(int i = threadIdx.x; i < C; i+= blockDim.x){
       dbias_shared[i] = 0.0f;
       dweight_shared[i] = 0.0f;
    }
    unsigned int *tmp_flag = (unsigned int*)(shared + C*2);
    __syncthreads();

    int warps_in_grid = gridDim.x * warp.meta_group_size();
    for (int idx = base_idx; idx < B * T; idx += warps_in_grid) {
        int b = idx / T;
        int t = idx % T;

        const Tdout* dout_bt = dout + b * T * C + t * C;
        const Trest* inp_bt = inp + b * T * C + t * C;
        Tdinp* dinp_bt = dinp + b * T * C + t * C;
        const float mean_bt = (float)mean[b * T + t];
        const float rstd_bt = (float)rstd[b * T + t];

        // first: two reduce operations
        float dnorm_mean = 0.0f;
        float dnorm_norm_mean = 0.0f;
        for (int i = warp.thread_rank(); i < C; i  += warp.size()) {
            float norm_bti = ((float)inp_bt[i] - mean_bt) * rstd_bt;
            float dnorm_i = (float)weight[i] * (float)dout_bt[i];
            dnorm_mean += dnorm_i;
            dnorm_norm_mean += dnorm_i * norm_bti;
        }
        dnorm_mean = cg::reduce(warp, dnorm_mean, cg::plus<float>{});
        dnorm_norm_mean = cg::reduce(warp, dnorm_norm_mean, cg::plus<float>{});
        dnorm_mean = dnorm_mean / C;
        dnorm_norm_mean = dnorm_norm_mean / C;

        // now iterate again and accumulate all the gradients
        for (int i = warp.thread_rank(); i < C; i += warp.size()) {
            float dout_i = (float)__ldcs(&dout_bt[i]);
            float norm_bti = ((float)__ldcs(&inp_bt[i]) - mean_bt) * rstd_bt;
            float dnorm_i = (float)weight[i] * dout_i;
            // gradient contribution to bias
            atomicAdd(&dbias_shared[i], dout_i);
            // gradient contribution to weight
            atomicAdd(&dweight_shared[i], norm_bti * dout_i);
            // gradient contribution to input
            float dval = 0.0f;
            dval += dnorm_i; // term 1
            dval -= dnorm_mean; // term 2
            dval -= norm_bti * dnorm_norm_mean; // term 3
            dval *= rstd_bt; // final scale
            dinp_bt[i] = (Tdinp)((float)dinp_bt[i] + dval);
        }
    }

    // Accumulate into a FP32 scratchpad
    // BF16 atomics are potentially much slower... and this is more precise!
    __syncthreads();
    float* scratch_dbias = scratch;
    float* scratch_dweight = scratch + C;
    unsigned int* scratchFlag = (unsigned int*)(scratch + (2 * C));
    for(int i = threadIdx.x; i < C; i+= blockDim.x) {
        atomicAdd(&scratch_dbias[i], dbias_shared[i]);
        atomicAdd(&scratch_dweight[i], dweight_shared[i]);
    }
    // order this block's scratch updates before the flag increment below, so the last
    // block is guaranteed to observe every block's contribution (same as kernel5)
    __threadfence();
    __syncthreads();
    if (threadIdx.x == 0) {
        *tmp_flag = atomicAdd(scratchFlag, 1);
    }
    __syncthreads();
    if (*tmp_flag == gridDim.x-1) {
        for(int i = threadIdx.x; i < C; i+= blockDim.x) {
            // todo - potentially do stochastic rounding here as well
            // accumulate (not overwrite), like the CPU reference and kernels 1/3/5/8
            dbias[i] = (Tparams)((float)dbias[i] + scratch_dbias[i]);
            dweight[i] = (Tparams)((float)dweight[i] + scratch_dweight[i]);
        }
    }
}

// Same as kernel 6 but without cooperative groups or templates
__global__ void layernorm_backward_kernel7(floatX* dinp, floatX* dweight, floatX* dbias, float* scratch,
                                           const floatX* dout, const floatX* inp, const floatX* weight,
                                           const floatX* mean, const floatX* rstd,
                                           int B, int T, int C) {
    extern __shared__ float shared[]; // size = 2 * C + 1
    int warpId = threadIdx.x / warpSize; // warp index within a block
    int warpsInBlock = blockDim.x / warpSize;
    int base_idx = blockIdx.x * warpsInBlock + warpId;
    int warpThreadIdx = threadIdx.x % warpSize; // Thread index within the warp
    int warps_in_grid = gridDim.x * warpsInBlock;

    // the first half of shared memory is bias, second is weight
    float* dbias_shared = shared;
    float* dweight_shared = shared + C;

    // init shared memory to zero
    #pragma unroll 4
    for(int i = threadIdx.x; i < C; i+= blockDim.x){
       dbias_shared[i] = 0.0f;
       dweight_shared[i] = 0.0f;
    }
    unsigned int *tmp_flag = (unsigned int*)(shared + C*2);
    __syncthreads();

    for (int idx = base_idx; idx < B * T; idx += warps_in_grid) {
        int b = idx / T;
        int t = idx % T;

        const floatX* dout_bt = dout + b * T * C + t * C;
        const floatX* inp_bt = inp + b * T * C + t * C;
        floatX* dinp_bt = dinp + b * T * C + t * C;
        const float mean_bt = (float)mean[b * T + t];
        const float rstd_bt = (float)rstd[b * T + t];

        // first: two reduce operations
        float dnorm_mean = 0.0f;
        float dnorm_norm_mean = 0.0f;
        for (int i = warpThreadIdx; i < C; i  += warpSize) {
            float norm_bti = ((float)inp_bt[i] - mean_bt) * rstd_bt;
            float dnorm_i = (float)weight[i] * (float)dout_bt[i];
            dnorm_mean += dnorm_i;
            dnorm_norm_mean += dnorm_i * norm_bti;
        }
        dnorm_mean = warpReduceSum(dnorm_mean);
        dnorm_norm_mean = warpReduceSum(dnorm_norm_mean);

        dnorm_mean = dnorm_mean / C;
        dnorm_norm_mean = dnorm_norm_mean / C;

        // now iterate again and accumulate all the gradients
        for (int i = warpThreadIdx; i < C; i += warpSize) {
            float dout_i = (float)__ldcs(&dout_bt[i]);
            float norm_bti = ((float)__ldcs(&inp_bt[i]) - mean_bt) * rstd_bt;
            float dnorm_i = (float)weight[i] * dout_i;
            // gradient contribution to bias
            atomicAdd(&dbias_shared[i], dout_i);
            // gradient contribution to weight
            atomicAdd(&dweight_shared[i], norm_bti * dout_i);
            // gradient contribution to input
            float dval = 0.0f;
            dval += dnorm_i; // term 1
            dval -= dnorm_mean; // term 2
            dval -= norm_bti * dnorm_norm_mean; // term 3
            dval *= rstd_bt; // final scale
            dinp_bt[i] = (floatX)((float)dinp_bt[i] + dval);
        }
    }

    // Accumulate into a FP32 scratchpad
    // BF16 atomics are potentially much slower... and this is more precise!
    __syncthreads();
    float* scratch_dbias = scratch;
    float* scratch_dweight = scratch + C;
    unsigned int* scratchFlag = (unsigned int*)(scratch + (2 * C));
    for(int i = threadIdx.x; i < C; i+= blockDim.x) {
        atomicAdd(&scratch_dbias[i], dbias_shared[i]);
        atomicAdd(&scratch_dweight[i], dweight_shared[i]);
    }
    // order this block's scratch updates before the flag increment below, so the last
    // block is guaranteed to observe every block's contribution
    __threadfence();
    __syncthreads();
    if (threadIdx.x == 0) {
        *tmp_flag = atomicAdd(scratchFlag, 1);
    }
    __syncthreads();
    if (*tmp_flag == gridDim.x-1) {
        for(int i = threadIdx.x; i < C; i+= blockDim.x) {
            // todo - potentially do stochastic rounding here as well
            // accumulate (not overwrite), like the CPU reference and kernels 1/3/5/8
            dbias[i] = (floatX)((float)dbias[i] + scratch_dbias[i]);
            dweight[i] = (floatX)((float)dweight[i] + scratch_dweight[i]);
        }
    }
}

// Vectorized (x128) version of kernel7.
// NOTE: iterations_C = C / C_per_iteration below, so channels beyond the last full
// multiple of warpSize * x128::size are silently skipped in the gradient pass;
// C is assumed to be such a multiple (kernel9 traps on that condition instead).
__global__ void __launch_bounds__(1024, MAX_1024_THREADS_BLOCKS)
layernorm_backward_kernel8(floatX* dinp, floatX* dweight, floatX* dbias, float* scratch,
                           const floatX* dout, const floatX* inp, const floatX* weight,
                           const floatX* mean, const floatX* rstd,
                           int B, int T, int C) {
    extern __shared__ float shared[]; // size = 2 * C + 1
    int warpId = threadIdx.x / warpSize; // warp index within a block
    int warpsInBlock = blockDim.x / warpSize; //number of warps in block
    int baseIdx = blockIdx.x * warpsInBlock + warpId;
    int warpThreadIdx = threadIdx.x % warpSize; // Thread index within the warp
    int warpsInGrid = gridDim.x * warpsInBlock;
    int C_per_iteration = warpSize * x128::size;
    int iterations_C = C / C_per_iteration;

    // the first half of shared memory is bias, second is weight
    float* dbias_shared = shared;
    float* dweight_shared = shared + C;

    // init shared memory to zero
    for(int i = threadIdx.x; i < C; i+= blockDim.x){
       dbias_shared[i] = 0.0f;
       dweight_shared[i] = 0.0f;
    }
    unsigned int *tmp_flag = (unsigned int*)(shared + C*2);
    __syncthreads();

    for (int idx = baseIdx; idx < B * T; idx += warpsInGrid) {
        int b = idx / T;
        int t = idx % T;

        const floatX* dout_bt = dout + b * T * C + t * C;
        const floatX* inp_bt = inp + b * T * C + t * C;
        floatX* dinp_bt = dinp + b * T * C + t * C;
        const float mean_bt = (float)mean[b * T + t];
        const float rstd_bt = (float)rstd[b * T + t];

        // first: two reduce operations
        float dnorm_mean = 0.0f;
        float dnorm_norm_mean = 0.0f;
        for (int i = warpThreadIdx * x128::size; i < C; i += warpSize * x128::size) {
            x128 dout128_i   = load128(dout_bt + i);
            x128 inp128_i    = load128(inp_bt  + i);
            x128 weight128_i = load128(weight  + i);
            for (int k = 0; k < x128::size; k++) {
                float norm_bti = ((float)inp128_i[k] - mean_bt) * rstd_bt;
                float dnorm_i = (float)weight128_i[k] * (float)dout128_i[k];
                dnorm_mean += dnorm_i;
                dnorm_norm_mean += dnorm_i * norm_bti;
            }
        }
        dnorm_mean = warpReduceSum(dnorm_mean) / C;
        dnorm_norm_mean = warpReduceSum(dnorm_norm_mean) / C;

        // now iterate again and accumulate all the gradients
        // unfortunately we cannot use the same index for x128 arrays and shared memory
        // as atomics can only be 32-bit rather than 128-bit (at least pre-SM90/Hopper)
        // so this would result in an 8-way bank conflict, and kill performance
        // so instead, we use a shared memory friendly index, and reorder before the final write
        for (int i = 0; i < iterations_C; i++) {
            int global_index = (warpThreadIdx * x128::size) + (i * C_per_iteration);
            int shared_index = warpThreadIdx + (i * C_per_iteration);
            x128 dout128   = load128cs(dout_bt + global_index);
            x128 inp128    = load128cs(inp_bt  + global_index);
            x128 dinp128   = load128(dinp_bt   + global_index);
            x128 weight128 = load128(weight    + global_index);

            for (int x = 0; x < x128::size; x++) {
                float dout_i = (float)dout128[x];
                float norm_bti = ((float)inp128[x] - mean_bt) * rstd_bt;
                float dnorm_i = (float)weight128[x] * dout_i;
                // gradient contribution to bias (using shared memory friendly index)
                atomicAdd(&dbias_shared[shared_index + x*warpSize], dout_i);
                // gradient contribution to weight (using shared memory friendly index)
                atomicAdd(&dweight_shared[shared_index + x*warpSize], norm_bti * dout_i);
                // gradient contribution to input
                float dval = 0.0f;
                dval += dnorm_i; // term 1
                dval -= dnorm_mean; // term 2
                dval -= norm_bti * dnorm_norm_mean; // term 3
                dval *= rstd_bt; // final scale
                dinp128[x] = (floatX)((float)dinp128[x] + dval);
            }
            // cache in L2 as this is read by the next kernel, but bypass L1 to minimise thrashing
            store128cg(dinp_bt + global_index, dinp128);
        }
    }
    // Accumulate into a FP32 scratchpad
    // BF16 atomics are potentially much slower... and this is more precise!
    // todo - could potentially avoid the extra copy if floatX is FP32, fairly negligible though
    __syncthreads();
    float* scratch_dbias = scratch;
    float* scratch_dweight = scratch + C;
    unsigned int* scratchFlag = (unsigned int*)(scratch + (2 * C));
    for(int i = threadIdx.x; i < C; i+= blockDim.x) {
        // global atomics in the same "shared memory banking friendly" order
        atomicAdd(&scratch_dbias[i], dbias_shared[i]);
        atomicAdd(&scratch_dweight[i], dweight_shared[i]);
    }
    // order this block's scratch updates before the flag increment below, so the last
    // block is guaranteed to observe every block's contribution
    __threadfence();
    __syncthreads();
    if (threadIdx.x == 0) {
        *tmp_flag = atomicInc(scratchFlag, gridDim.x);
    }
    __syncthreads();
    if (*tmp_flag == gridDim.x-1) {
        for (int i = warpId; i < iterations_C; i += warpsInBlock) {
            // reorder from atomic/shared memory-friendly index to real global memory index
            // and convert from float/FP32 to floatX/BF16 for the final write
            int global_index = (warpThreadIdx * x128::size) + (i * C_per_iteration);
            int shared_index = warpThreadIdx + (i * C_per_iteration);

            x128 dbias128 = load128(dbias + global_index);
            x128 dweight128 = load128(dweight + global_index);
            for (int x = 0; x < x128::size; x++) {
                float s_db = scratch_dbias[shared_index + x*warpSize];
                float s_dw = scratch_dweight[shared_index + x*warpSize];
                dbias128[x] = (floatX)(s_db + (float)dbias128[x]);
                dweight128[x] = (floatX)(s_dw + (float)dweight128[x]);
            }
            store128(dbias + global_index, dbias128);
            store128(dweight + global_index, dweight128);
        }
    }
}

__global__ void layernorm_backward_kernel9(floatX* dinp, floatX* dweight, floatX* dbias, float* scratch, const floatX* dout, const floatX* inp, const floatX* weight, const floatX* mean, const floatX* rstd, int B, int T, int C) { if(C % (32 * x128::size) != 0) { if(threadIdx.x == 0 && blockIdx.x == 0) {
printf("Number of channels is not a multiple of 32 * x128::size"); } __trap(); // prefer to crash here than run into a deadlock later on } int BLOCK_SIZE = blockDim.x; int warpsInBlock = BLOCK_SIZE / WARP_SIZE; //number of warps in block extern __shared__ float shared[]; // size = 2 * C + 1 int warpId = threadIdx.x / WARP_SIZE; // warp index within a block int baseIdx = blockIdx.x * warpsInBlock + warpId; int warpThreadIdx = threadIdx.x % WARP_SIZE; // Thread index within the warp int warpsInGrid = gridDim.x * warpsInBlock; int C_per_iteration = WARP_SIZE * x128::size; int iterations_C = ceil_div(C, C_per_iteration) + 2; // the first half of shared memory is bias, second is weight float* dbias_shared = shared; float* dweight_shared = shared + C; float* dbias_tmp_shared = shared + 2 * C; float* dweight_tmp_shared = shared + 2 * C + BLOCK_SIZE; // init shared memory to zero for(int i = threadIdx.x; i < C; i+= BLOCK_SIZE){ dbias_shared[i] = 0.0f; dweight_shared[i] = 0.0f; } unsigned int *tmp_flag = (unsigned int*)(shared + 2*C + 2*BLOCK_SIZE); __syncthreads(); for (int idx = baseIdx; idx < B * T; idx += warpsInGrid) { int b = idx / T; int t = idx % T; const floatX* dout_bt = dout + b * T * C + t * C; const floatX* inp_bt = inp + b * T * C + t * C; floatX* dinp_bt = dinp + b * T * C + t * C; const float mean_bt = (float)mean[b * T + t]; const float rstd_bt = (float)rstd[b * T + t]; // first: two reduce operations float dnorm_mean = 0.0f; float dnorm_norm_mean = 0.0f; for (int i = warpThreadIdx * x128::size; i < C; i += WARP_SIZE * x128::size) { x128 dout128_i = load128(dout_bt + i); x128 inp128_i = load128(inp_bt + i); x128 weight128_i = load128(weight + i); for (int k = 0; k < x128::size; k++) { float norm_bti = ((float)inp128_i[k] - mean_bt) * rstd_bt; float dnorm_i = (float)weight128_i[k] * (float)dout128_i[k]; dnorm_mean += dnorm_i; dnorm_norm_mean += dnorm_i * norm_bti; } } dnorm_mean = warpReduceSum(dnorm_mean) / C; dnorm_norm_mean = 
warpReduceSum(dnorm_norm_mean) / C; // now iterate again and accumulate all the gradients // unfortunately we cannot use the same index for x128 arrays and shared memory // as atomics can only be 32-bit rather than 128-bit (at least pre-SM90/Hopper) // so this would result in an 8-way bank conflict, and kill performance // so instead, we use a shared memory friendly index, and reorder before the final write for (int i = 0; i < iterations_C; i++) { int global_index = (warpThreadIdx * x128::size) + (i * C_per_iteration); int shared_index = warpThreadIdx + (i * C_per_iteration); if (global_index >= C) { break; } x128 dout128 = load128cs(dout_bt + global_index); x128 inp128 = load128cs(inp_bt + global_index); x128 dinp128 = load128(dinp_bt + global_index); x128 weight128 = load128(weight + global_index); for (int x = 0; x < x128::size; x++) { float dout_i = (float)dout128[x]; float norm_bti = ((float)inp128[x] - mean_bt) * rstd_bt; float dnorm_i = (float)weight128[x] * dout_i; // sum up the gradients for bias and weight across the entire block // this is basically a reduction (but only inter-warp, not intra-warp) // doing it this way allows us to avoid using atomics while using many warps if (warpId != 0) { dbias_tmp_shared[threadIdx.x] = dout_i; dweight_tmp_shared[threadIdx.x] = norm_bti * dout_i; } __syncthreads(); if (warpId == 0) { float dbias_tmp = dout_i; float dweight_tmp = norm_bti * dout_i; for (int j = 1; j < warpsInBlock; j++) { dbias_tmp += dbias_tmp_shared[threadIdx.x + j * WARP_SIZE]; dweight_tmp += dweight_tmp_shared[threadIdx.x + j * WARP_SIZE]; } // gradient contribution to bias (using shared memory friendly index) dbias_shared[shared_index + x*WARP_SIZE] += dbias_tmp; // gradient contribution to weight (using shared memory friendly index) dweight_shared[shared_index + x*WARP_SIZE] += dweight_tmp; } __syncthreads(); // gradient contribution to input float dval = 0.0f; dval += dnorm_i; // term 1 dval -= dnorm_mean; // term 2 dval -= norm_bti * 
dnorm_norm_mean; // term 3 dval *= rstd_bt; // final scale dinp128[x] = (floatX)((float)dinp128[x] + dval); } // cache in L2 as this is read by the next kernel, but bypass L1 to minimise thrashing store128cg(dinp_bt + global_index, dinp128); } } __syncthreads(); // Each block writes its partial sum to global memory // The last block to finish becomes responsible for summing up all the partial sums // This is done by atomically incrementing a flag (cleared to 0 before launching the kernel) unsigned int* scratchFlag = (unsigned int*)(scratch); // Increment scratch pointer by a full cacheline so that everything remains cacheline aligned scratch += 32; float* scratch_dbias = scratch; float* scratch_dweight = scratch + C; for(int i = threadIdx.x; i < C; i+= BLOCK_SIZE) { // Write to global memory in the same "shared memory banking friendly" order scratch_dbias[i + 2*C*blockIdx.x] = dbias_shared[i]; scratch_dweight[i + 2*C*blockIdx.x] = dweight_shared[i]; } __syncthreads(); if (threadIdx.x == 0) { *tmp_flag = atomicInc(scratchFlag, gridDim.x); } __syncthreads(); if (*tmp_flag == gridDim.x-1) { // Reduction of the partial sums by the final block // todo - there isn't enough parallelism even inside that single SM... // ==> so could maybe split into another kernel with YET ANOTHER level of reduction?! 
for(int i = threadIdx.x * f128::size; i < C; i+= BLOCK_SIZE * f128::size) { f128 dbias_accum = f128::zeros(); f128 dweight_accum = f128::zeros(); for (int read_block_idx = 0; read_block_idx < gridDim.x; read_block_idx++) { int offset = i + 2*C*read_block_idx; f128 dbias128 = load128(scratch_dbias + offset); f128 dweight128 = load128(scratch_dweight + offset); for(int k = 0; k < f128::size; k++) { dbias_accum[k] += dbias128[k]; dweight_accum[k] += dweight128[k]; } } store128(dbias_shared + i, dbias_accum); store128(dweight_shared + i, dweight_accum); } __syncthreads(); // reorder from atomic/shared memory-friendly index to real global memory index // and convert from float/FP32 to floatX/BF16 for the final write // this is separate also because it cannot use as many warps as the above (f128 vs x128) // todo - if we split this code into another kernel, we could maybe do it at the same time? for (int i = warpId; i < iterations_C; i += warpsInBlock) { int global_index = (warpThreadIdx * x128::size) + (i * C_per_iteration); int shared_index = warpThreadIdx + (i * C_per_iteration); if (global_index >= C) { break; } x128 dbias128 = load128(dbias + global_index); x128 dweight128 = load128(dweight + global_index); for (int x = 0; x < x128::size; x++) { float s_db = dbias_shared[shared_index + x*WARP_SIZE]; float s_dw = dweight_shared[shared_index + x*WARP_SIZE]; dbias128[x] = (floatX)(s_db + (float)dbias128[x]); dweight128[x] = (floatX)(s_dw + (float)dweight128[x]); } store128(dbias + global_index, dbias128); store128(dweight + global_index, dweight128); } } } // similar to kernel 9, but uses vectors to access shared memory, which also avoids the bank conflict problems, // and makes use require fewer barriers, at the cost of increased shared memory consumption. // warning: this kernel is _extremely_ close to getting register spills, so many "optimizations" turn out to be unhelpful // or need to be implemented in a very specific way. 
__global__ void __launch_bounds__(512, 2)
layernorm_backward_kernel10(floatX* dinp, floatX* dweight, floatX* dbias, float* scratch,
                            const floatX* dout, const floatX* inp, const floatX* weight,
                            const floatX* mean, const floatX* rstd,
                            int B, int T, int C) {
    // Dynamic shared memory layout (floats), with rounded_C = C rounded up to a multiple of 32 * x128::size:
    //   [0, rounded_C)              dbias accumulator (only warp 0 writes it in the main loop)
    //   [rounded_C, 2 * rounded_C)  dweight accumulator
    //   2 * (blockDim.x - 32) * f128::size floats of staging for warps 1..N-1
    // This is exactly what the layernorm_backward10 launcher reserves; for blockDim.x == 32
    // there is no staging area at all, so nothing below may touch memory past 2 * rounded_C.
    //
    // Global scratch layout: scratch[0] is the completion counter (zeroed before launch),
    // then after a 32-float header, each block owns a 2 * C slice (dbias, then dweight).
    int BLOCK_SIZE = blockDim.x;
    int warpsInBlock = BLOCK_SIZE / WARP_SIZE; //number of warps in block
    extern __shared__ float shared[];

    int warpId = threadIdx.x / WARP_SIZE; // warp index within a block
    int warpThreadIdx = threadIdx.x % WARP_SIZE; // Thread index within the warp
    int warpsInGrid = gridDim.x * warpsInBlock;
    int C_per_iteration = WARP_SIZE * x128::size;
    int iterations_C = ceil_div(C, C_per_iteration);

    // the first half of shared memory is bias, second is weight
    size_t rounded_C = ceil_div(C, (32 * x128::size)) * (32 * x128::size);
    float* dbias_shared = shared;
    float* dweight_shared = shared + rounded_C;
    // warp zero doesn't actually write to the _tmp_shared memory locations, so we don't need to reserve memory
    // the obvious solution is to change the addressing below to use (threadId.x-32) as offset, but that causes
    // register spills, so instead we mess with the base pointer here, which doesn't increase register usage.
    float* dbias_tmp_shared = shared + 2 * rounded_C - WARP_SIZE * f128::size;
    float* dweight_tmp_shared = shared + 2 * rounded_C + f128::size * BLOCK_SIZE - 2 * WARP_SIZE * f128::size;

    // init shared memory to zero
    for(int i = threadIdx.x * f128::size; i < rounded_C; i += BLOCK_SIZE * f128::size) {
        store128(dbias_shared + i, f128::zeros());
        store128(dweight_shared + i, f128::zeros());
    }
    __syncthreads();

    // The loop body contains __syncthreads(), so every warp of the block must run the same number
    // of iterations. Step over rows in block-uniform increments; a warp whose row lies past B*T
    // (when B*T is not a multiple of warpsInGrid) still reaches every barrier but contributes zeros,
    // so warp 0 never re-adds stale staging values from warps that would otherwise have left the loop.
    for (int bt_base = blockIdx.x * warpsInBlock; bt_base < B * T; bt_base += warpsInGrid) {
        const int bt = bt_base + warpId;
        const bool active = bt < B * T; // uniform across the warp; warp 0 is always active here
        const floatX* dout_bt = dout + bt * C;
        const floatX* inp_bt = inp + bt * C;
        floatX* dinp_bt = dinp + bt * C;

        // first: two reduce operations
        float dnorm_mean = 0.0f;
        float dnorm_norm_mean = 0.0f;
        if (active) {
            for (int i = warpThreadIdx * x128::size; i < C; i += WARP_SIZE * x128::size) {
                x128 dout128_i   = load128(dout_bt + i);
                x128 inp128_i    = load128(inp_bt  + i);
                x128 weight128_i = load128(weight  + i);
                for (int k = 0; k < x128::size; k++) {
                    float dnorm_i = (float)weight128_i[k] * (float)dout128_i[k];
                    dnorm_mean += dnorm_i;
                    dnorm_norm_mean += dnorm_i * (float)inp128_i[k];
                }
            }
        }

        const float mean_bt = active ? (float)mean[bt] : 0.0f;
        const float rstd_bt = active ? (float)rstd[bt] : 0.0f;
        dnorm_mean = warpReduceSum(dnorm_mean) / C;
        // mean(dnorm * (x - mean) * rstd) = rstd * mean(dnorm * x) - rstd * mean * mean(dnorm)
        dnorm_norm_mean = warpReduceSum(dnorm_norm_mean) / C * rstd_bt - dnorm_mean * mean_bt * rstd_bt;

        for (int c = 0; c < iterations_C; c++) {
            int global_index = (warpThreadIdx * x128::size) + (c * C_per_iteration);

            // lanes past C (and inactive warps) work on zeros so that they still join the barriers
            x128 dout128   = x128::zeros();
            x128 inp128    = x128::zeros();
            x128 dinp128   = x128::zeros();
            x128 weight128 = x128::zeros();

            if(active && global_index < C) {
                dout128 = load128cs(dout_bt + global_index);
                inp128 = load128cs(inp_bt + global_index);
                dinp128 = load128(dinp_bt + global_index);
                weight128 = load128(weight + global_index);
            }

            for(int o = 0; o < x128::size / f128::size; ++o) {
                f128 dbias_f;
                f128 dweight_f;
                for(int i = 0; i < f128::size; ++i) {
                    int x = o * f128::size + i;
                    float dout_i = (float)dout128[x];
                    float norm_bti = ((float)inp128[x] - mean_bt) * rstd_bt;
                    dbias_f[i] = dout_i;
                    dweight_f[i] = norm_bti * dout_i;

                    float dval = 0.0f;
                    dval += (float) weight128[x] * (float)dout128[x]; // term 1
                    dval -= dnorm_mean; // term 2
                    dval -= norm_bti * dnorm_norm_mean; // term 3
                    dval *= rstd_bt; // final scale
                    dinp128[x] = (floatX) ((float) dinp128[x] + dval);
                }

                if (warpId != 0) {
                    store128(dbias_tmp_shared + threadIdx.x * f128::size, dbias_f);
                    // this seems to generate a 64-bit store, instead of 128-bit.
                    // however, forcing 128-bit (e.g., using inline ptx), results in register
                    // spilling and much worse performance, so we'll keep it like this for now
                    // but ideally, we could reduce the register pressure a little.
                    store128(dweight_tmp_shared + threadIdx.x * f128::size, dweight_f);
                }
                __syncthreads();
                if (warpId == 0) {
                    for (int j = 1; j < warpsInBlock; j++) {
                        f128 dbias_tmp = load128(dbias_tmp_shared + f128::size * (threadIdx.x + j * WARP_SIZE));
                        f128 dweight_tmp = load128(dweight_tmp_shared + f128::size * (threadIdx.x + j * WARP_SIZE));
                        for(int i = 0; i < f128::size; ++i) {
                            dbias_f[i] += dbias_tmp[i];
                            dweight_f[i] += dweight_tmp[i];
                        }
                    }
                }
                __syncthreads();
                if (warpId == 0) {
                    f128 db_old = load128(dbias_shared + global_index + f128::size * o);
                    f128 dw_old = load128(dweight_shared + global_index + f128::size * o);
                    for(int i = 0; i < f128::size; ++i) {
                        dbias_f[i] += db_old[i];
                        dweight_f[i] += dw_old[i];
                    }
                    store128(dbias_shared + global_index + f128::size * o, dbias_f);
                    store128(dweight_shared + global_index + f128::size * o, dweight_f);
                }
            }
            if(active && global_index < C) {
                // cache in L2 as this is read by the next kernel, but bypass L1 to minimise thrashing
                store128cg(dinp_bt + global_index, dinp128);
            }
        }
    }
    __syncthreads();
    // Each block writes its partial sum to global memory
    // The last block to finish becomes responsible for summing up all the partial sums
    // This is done by atomically incrementing a flag (cleared to 0 before launching the kernel)
    unsigned int* scratchFlag = (unsigned int*)(scratch);
    // Increment scratch pointer by a full cacheline so that everything remains cacheline aligned
    scratch += 32;
    float* scratch_dbias = scratch;
    float* scratch_dweight = scratch + C;
    for(int i = threadIdx.x * f128::size; i < C; i += BLOCK_SIZE * f128::size) {
        // Write to global memory in the same "shared memory banking friendly" order
        store128(scratch_dbias + i + 2*C*blockIdx.x, load128(dbias_shared + i));
        store128(scratch_dweight + i + 2*C*blockIdx.x, load128(dweight_shared + i));
    }
    // publish this block's partial sums device-wide before the counter is incremented
    __threadfence();
    __syncthreads();
    // Thread 0 takes a ticket; __syncthreads_or broadcasts "this is the last block" to the whole
    // block without any shared memory (the previous shared-memory flag lived past the end of the
    // dynamic allocation when blockDim.x == 32).
    bool is_last_block = false;
    if (threadIdx.x == 0) {
        // atomicInc wraps back to 0 once the count reaches gridDim.x
        is_last_block = (atomicInc(scratchFlag, gridDim.x) == gridDim.x - 1);
        __threadfence(); // order the reads of other blocks' partial sums after the counter read
    }
    if (__syncthreads_or(is_last_block)) {
        // Reduction of the partial sums by the final block
        // todo - there isn't enough parallelism even inside that single SM...
        // ==> so could maybe split into another kernel with YET ANOTHER level of reduction?!
        for(int i = threadIdx.x * f128::size; i < C; i += BLOCK_SIZE * f128::size) {
            f128 dbias_accum = f128::zeros();
            f128 dweight_accum = f128::zeros();

            for (int read_block_idx = 0; read_block_idx < gridDim.x; read_block_idx++) {
                int offset = i + 2*C*read_block_idx;
                f128 dbias128 = load128(scratch_dbias + offset);
                f128 dweight128 = load128(scratch_dweight + offset);
                for(int k = 0; k < f128::size; k++) {
                    dbias_accum[k] += dbias128[k];
                    dweight_accum[k] += dweight128[k];
                }
            }
            store128(dbias_shared + i, dbias_accum);
            store128(dweight_shared + i, dweight_accum);
        }
        __syncthreads();

        // convert from float/FP32 to floatX/BF16 for the final write
        // this is separate because it cannot use as many warps as the above (f128 vs x128)
        // todo - if we split this code into another kernel, we could maybe do it at the same time?
        for (int c = warpId; c < iterations_C; c += warpsInBlock) {
            int global_index = (warpThreadIdx * x128::size) + (c * C_per_iteration);
            if (global_index >= C) {
                break;
            }

            x128 dbias128 = load128(dbias + global_index);
            x128 dweight128 = load128(dweight + global_index);
            for(int o = 0; o < x128::size / f128::size; ++o) {
                f128 s_db = load128(dbias_shared + global_index + o * f128::size);
                f128 s_dw = load128(dweight_shared + global_index + o * f128::size);
                for(int i = 0; i < f128::size; ++i) {
                    int x = o * f128::size + i;
                    dbias128[x] = (floatX)(s_db[i] + (float)dbias128[x]);
                    dweight128[x] = (floatX)(s_dw[i] + (float)dweight128[x]);
                }
            }
            store128(dbias + global_index, dbias128);
            store128(dweight + global_index, dweight128);
        }
    }
}

// ----------------------------------------------------------------------------
// kernel launchers

void layernorm_backward1(float* dinp, float* dweight, float* dbias,
                        const float* dout, const float* inp, const float* weight, const float* mean, const float* rstd,
                        int B, int T, int C, const int block_size) {
    const int N = B * T;
    const int grid_size = ceil_div(N, block_size);
    layernorm_backward_kernel1<<>>(dinp, dweight, dbias, dout, inp, weight, mean, rstd, B, T, C);
}

template void layernorm_backward2(Tdinp* dinp, Tparams* dweight, Tparams* dbias,
                        const Tdout* dout, const Trest* inp, const Tparams* weight, const Trest* mean, const Trest* rstd,
                        int B, int T, int C, int block_size) {
    const int N = B * T;
    const int grid_size = ceil_div(32*N, block_size);
    size_t shared_mem_size = 2 * C * sizeof(float);
    float* dweight_tmp;
    float* dbias_tmp;
    cudaCheck(cudaMalloc(&dweight_tmp, C * sizeof(float)));
    cudaCheck(cudaMalloc(&dbias_tmp, C * sizeof(float)));
    cudaMemset(dweight_tmp, 0, C * sizeof(float));
    cudaMemset(dbias_tmp, 0, C * sizeof(float));
    layernorm_backward_kernel2<<>>(dinp, dweight, dbias, dout, inp, weight, mean, rstd, B, T, C, dweight_tmp, dbias_tmp);
    copy_to_dweight_dbias<<<1, 512>>>(C, dweight, dbias, dweight_tmp, dbias_tmp);
cudaCheck(cudaFree(dweight_tmp)); cudaCheck(cudaFree(dbias_tmp)); } template void layernorm_backward3(Tdinp* dinp, Tparams* dweight, Tparams* dbias, const Tdout* dout, const Trest* inp, const Tparams* weight, const Trest* mean, const Trest* rstd, int B, int T, int C, int block_size) { const int grid_size = (1024/block_size) * cuda_num_SMs; size_t shared_mem_size = 2 * C * sizeof(float); layernorm_backward_kernel3<<>>(dinp, dweight, dbias, dout, inp, weight, mean, rstd, B, T, C); } template void layernorm_backward4(Tdinp* dinp, Tparams* dweight, Tparams* dbias, const Tdout* dout, const Trest* inp, const Tparams* weight, const Trest* mean, const Trest* rstd, int B, int T, int C, int block_size) { const int grid_size = (1024/block_size) * cuda_num_SMs; size_t shared_mem_size = 2 * C * sizeof(float); layernorm_backward_kernel4<<>>(dinp, dweight, dbias, dout, inp, weight, mean, rstd, B, T, C); } template void layernorm_backward5(Tdinp* dinp, Tparams* dweight, Tparams* dbias, float* scratch, const Tdout* dout, const Trest* inp, const Tparams* weight, const Trest* mean, const Trest* rstd, int B, int T, int C, int block_size) { const int grid_size = 1 * cuda_num_SMs; // only support 1 block per SM for simplicity, 1024 threads is best anyway size_t shared_mem_size = (2 * C + 1) * sizeof(float); cudaMemset(scratch, 0, (grid_size * 2 * C + 1) * sizeof(float)); layernorm_backward_kernel5<<>>(dinp, dweight, dbias, scratch, dout, inp, weight, mean, rstd, B, T, C); } template void layernorm_backward6(Tdinp* dinp, Tparams* dweight, Tparams* dbias, float* scratch, const Tdout* dout, const Trest* inp, const Tparams* weight, const Trest* mean, const Trest* rstd, int B, int T, int C, int block_size) { const int grid_size = (1024/block_size) * cuda_num_SMs; size_t shared_mem_size = (2 * C + 1) * sizeof(float); // Including this as part of the timing until we can parallelise it // It should fully hide the cost and improve kernel perf by >5% if done in parallel using CUDA streams 
cudaMemset(scratch, 0, (1 + 2 * C) * sizeof(float)); layernorm_backward_kernel6<<>>(dinp, dweight, dbias, scratch, dout, inp, weight, mean, rstd, B, T, C); } template void layernorm_backward7(Tdinp* dinp, Tparams* dweight, Tparams* dbias, float* scratch, const Tdout* dout, const Trest* inp, const Tparams* weight, const Trest* mean, const Trest* rstd, int B, int T, int C, int block_size) { const int grid_size = (1024/block_size) * cuda_num_SMs; size_t shared_mem_size = (2 * C + 1) * sizeof(float); // Including this as part of the timing until we can parallelise it // It should fully hide the cost and improve kernel perf by >5% if done in parallel using CUDA streams cudaMemset(scratch, 0, (1 + 2 * C) * sizeof(float)); layernorm_backward_kernel7<<>>(dinp, dweight, dbias, scratch, dout, inp, weight, mean, rstd, B, T, C); } template void layernorm_backward8(Tdinp* dinp, Tparams* dweight, Tparams* dbias, float* scratch, const Tdout* dout, const Trest* inp, const Tparams* weight, const Trest* mean, const Trest* rstd, int B, int T, int C, int block_size) { const int grid_size = (1024/block_size) * cuda_num_SMs; size_t shared_mem_size = (2 * C + 1) * sizeof(float); // Including this as part of the timing until we can parallelise it // It should fully hide the cost and improve kernel perf by >5% if done in parallel using CUDA streams cudaMemset(scratch, 0, (1 + 2 * C) * sizeof(float)); layernorm_backward_kernel8<<>>(dinp, dweight, dbias, scratch, dout, inp, weight, mean, rstd, B, T, C); } template void layernorm_backward9(Tdinp* dinp, Tparams* dweight, Tparams* dbias, float* scratch, const Tdout* dout, const Trest* inp, const Tparams* weight, const Trest* mean, const Trest* rstd, int B, int T, int C, int block_size) { assert(C % (32 * x128::size) == 0 && "Channels must be divisible by (32 * x128::size)"); const int grid_size = (1024/block_size) * cuda_num_SMs; // todo - heuristics for other GPUs? 
size_t shared_mem_size = (2 * C + 2 * block_size + 1) * sizeof(float); cudaMemset(scratch, 0, 1 * sizeof(float)); // just need to memset the flag for this version layernorm_backward_kernel9<<>>(dinp, dweight, dbias, scratch, dout, inp, weight, mean, rstd, B, T, C); } template void layernorm_backward10(Tdinp* dinp, Tparams* dweight, Tparams* dbias, float* scratch, const Tdout* dout, const Trest* inp, const Tparams* weight, const Trest* mean, const Trest* rstd, int B, int T, int C, int block_size) { if(block_size == 1024) { block_size = 512; } //assert(C % (32 * x128::size) == 0 && "Channels must be divisible by (32 * x128::size)"); const int grid_size = (1024/block_size) * cuda_num_SMs; // todo - heuristics for other GPUs? size_t rounded_C = ceil_div(C, (32 * x128::size)) * (32 * x128::size); size_t shared_mem_size = (2 * rounded_C + 2 * (block_size - 32) * f128::size) * sizeof(float); cudaCheck(cudaMemset(scratch, 0, 1 * sizeof(float))); // just need to memset the flag for this version layernorm_backward_kernel10<<>>(dinp, dweight, dbias, scratch, dout, inp, weight, mean, rstd, B, T, C); cudaCheck(cudaGetLastError()); } // kernel version dispatch void layernorm_backward(int kernel_num, floatX* dinp, floatX* dweight, floatX* dbias, float* scratch, const floatX* dout, const floatX* inp, const floatX* weight, const floatX* mean, const floatX* rstd, int B, int T, int C, const int block_size) { switch (kernel_num) { #if !defined(ENABLE_BF16) && !defined(ENABLE_FP16) case 1: layernorm_backward1(dinp, dweight, dbias, dout, inp, weight, mean, rstd, B, T, C, block_size); break; #endif case 2: layernorm_backward2(dinp, dweight, dbias, dout, inp, weight, mean, rstd, B, T, C, block_size); break; case 3: layernorm_backward3(dinp, dweight, dbias, dout, inp, weight, mean, rstd, B, T, C, block_size); break; #if defined(ENABLE_BF16) case 4: layernorm_backward4(dinp, dweight, dbias, dout, inp, weight, mean, rstd, B, T, C, block_size); break; #endif case 5: layernorm_backward5(dinp, 
dweight, dbias, scratch, dout, inp, weight, mean, rstd, B, T, C, block_size); break; case 6: layernorm_backward6(dinp, dweight, dbias, scratch, dout, inp, weight, mean, rstd, B, T, C, block_size); break; case 7: layernorm_backward7(dinp, dweight, dbias, scratch, dout, inp, weight, mean, rstd, B, T, C, block_size); break; case 8: layernorm_backward8(dinp, dweight, dbias, scratch, dout, inp, weight, mean, rstd, B, T, C, block_size); break; case 9: layernorm_backward9(dinp, dweight, dbias, scratch, dout, inp, weight, mean, rstd, B, T, C, block_size); break; case 10: layernorm_backward10(dinp, dweight, dbias, scratch, dout, inp, weight, mean, rstd, B, T, C, block_size); break; default: printf("Invalid kernel number\n"); exit(1); } cudaCheck(cudaGetLastError()); } // ---------------------------------------------------------------------------- int main(int argc, char **argv) { setup_main(); int B = 8; int T = 1024; int C = 1600; // this is the problematic size // first do the forward pass in CPU float* out = (float*)malloc(B * T * C * sizeof(float)); float* mean = (float*)malloc(B * T * sizeof(float)); float* rstd = (float*)malloc(B * T * sizeof(float)); float* inp = make_random_float(B * T * C); float* weight = make_random_float(C); float* bias = make_random_float(C); layernorm_forward_cpu(out, mean, rstd, inp, weight, bias, B, T, C); // now do the backward pass, again on CPU float *dout = make_random_float(B * T * C); float *dinp = make_zeros_float(B * T * C); float *dweight = make_zeros_float(C); float *dbias = make_zeros_float(C); layernorm_backward_cpu(dinp, dweight, dbias, dout, inp, weight, mean, rstd, B, T, C); // the above calculations act as the reference // now let's do the same on the GPU // read kernel_num from command line int kernel_num = 2; if (argc > 1) { kernel_num = atoi(argv[1]); } printf("Using kernel %d\n", kernel_num); // move all the variables we need for backward pass onto the GPU floatX* d_dinp; floatX* d_dweight; floatX* d_dbias; floatX* 
d_dout; floatX* d_inp; floatX* d_weight; floatX* d_mean; floatX* d_rstd; float* d_scratch; cudaCheck(cudaMalloc(&d_dinp, B * T * C * sizeof(floatX))); cudaCheck(cudaMalloc(&d_dweight, C * sizeof(floatX))); cudaCheck(cudaMalloc(&d_dbias, C * sizeof(floatX))); cudaCheck(cudaMalloc(&d_dout, B * T * C * sizeof(floatX))); cudaCheck(cudaMalloc(&d_inp, B * T * C * sizeof(floatX))); cudaCheck(cudaMalloc(&d_weight, C * sizeof(floatX))); cudaCheck(cudaMalloc(&d_mean, B * T * sizeof(floatX))); cudaCheck(cudaMalloc(&d_rstd, B * T * sizeof(floatX))); cudaCheck(cudaMalloc(&d_scratch, (1024/32) * cuda_num_SMs * (2 * C + 1) * sizeof(float))); // copy over the "inputs" to the backward call cudaCheck(memcpy_convert(d_dout, dout, B * T * C)); cudaCheck(memcpy_convert(d_inp, inp, B * T * C)); cudaCheck(memcpy_convert(d_weight, weight, C)); cudaCheck(memcpy_convert(d_mean, mean, B * T)); cudaCheck(memcpy_convert(d_rstd, rstd, B * T)); // launch the kernel // removed 768 because it doesn't work for kernel9 despite being OK in train_gpt2.cu?! int block_sizes[] = {32, 64, 128, 256, 512, /*768,*/ 1024}; for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) { int block_size = block_sizes[j]; // init the "outputs" of the backward call to zeros cudaCheck(cudaMemset(d_dinp, 0, B * T * C * sizeof(floatX))); cudaCheck(cudaMemset(d_dweight, 0, C * sizeof(floatX))); cudaCheck(cudaMemset(d_dbias, 0, C * sizeof(floatX))); layernorm_backward(kernel_num, d_dinp, d_dweight, d_dbias, d_scratch, d_dout, d_inp, d_weight, d_mean, d_rstd, B, T, C, block_size); // check the correctness of the kernel float error_threshold_dinp = sizeof(floatX) == 4 ? 1e-3f : 1e-1f; // allow larger errors for BF16/FP16 float error_threshold_dparams = sizeof(floatX) == 4 ? 1e-3f : 5e-1f; // much, much larger... 
printf("Checking correctness...\n"); printf("dinp:\n"); validate_result(d_dinp, dinp, "dinp", B * T * C, error_threshold_dinp); printf("dweight:\n"); validate_result(d_dweight, dweight, "dweight", C, error_threshold_dparams); printf("dbias:\n"); validate_result(d_dbias, dbias, "dbias", C, error_threshold_dparams); printf("All results match for block_size=%d.\n\n", block_size); } // now time the kernel for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) { int block_size = block_sizes[j]; int repeat_times = 100; float elapsed_time = benchmark_kernel(repeat_times, layernorm_backward, kernel_num, d_dinp, d_dweight, d_dbias, d_scratch, d_dout, d_inp, d_weight, d_mean, d_rstd, B, T, C, block_size); printf("block_size %4d time %.4f ms\n", block_size, elapsed_time); } // cleanups free(out); free(mean); free(rstd); free(inp); free(weight); free(bias); free(dout); free(dinp); free(dweight); free(dbias); cudaCheck(cudaFree(d_dinp)); cudaCheck(cudaFree(d_dweight)); cudaCheck(cudaFree(d_dbias)); cudaCheck(cudaFree(d_dout)); cudaCheck(cudaFree(d_inp)); cudaCheck(cudaFree(d_weight)); cudaCheck(cudaFree(d_mean)); cudaCheck(cudaFree(d_rstd)); cudaCheck(cudaFree(d_scratch)); return 0; } ================================================ FILE: dev/cuda/layernorm_forward.cu ================================================ /* Kernels for layernorm forward pass. 
Compile example:
nvcc -O3 --use_fast_math -lcublas -lcublasLt layernorm_forward.cu -o layernorm_forward

version 1 is naive port from CPU code to kernel: parallelizes over B,T, loops over C
./layernorm_forward 1

version 2 parallelizes over all of B,T,C
(mean and rstd use one block per row, the normalization is flattened over B*T*C)
./layernorm_forward 2

version 3 uses cooperative groups to parallelize over all of B,T,C (one warp per row)
./layernorm_forward 3

version 4 uses a more clever way to estimate variance, var(x) = mean(x**2) - mean(x)**2
(allowing us to do a single pass over x on load)
./layernorm_forward 4

version 5 allocates blocks per row instead of warps per row, same alg as 4 otherwise
./layernorm_forward 5

version 6 stages weight, bias and each input row in dynamic shared memory and uses 128-bit
loads/stores; if that much shared memory cannot be configured it falls back to version 5
./layernorm_forward 6
*/

#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <assert.h>
#include <cuda_runtime.h>
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include "common.h"

// ----------------------------------------------------------------------------
// CPU code reference

// GPT-2 layernorm forward pass.
// inp/out are indexed as (B,T,C) and weight/bias as (C,); mean/rstd are indexed as (B,T)
// and are cached for the backward pass. The variance is the biased (population) variance.
void layernorm_forward_cpu(float* out, float* mean, float* rstd,
                           const float* inp, const float* weight, const float* bias,
                           int B, int T, int C) {
    float eps = 1e-5f;
    for (int b = 0; b < B; b++) {
        for (int t = 0; t < T; t++) {
            // seek to the input position inp[b,t,:]
            const float* x = inp + b * T * C + t * C;
            // calculate the mean
            float m = 0.0f;
            for (int i = 0; i < C; i++) {
                m += x[i];
            }
            m = m/C;
            // calculate the variance (without any bias correction)
            float v = 0.0f;
            for (int i = 0; i < C; i++) {
                float xshift = x[i] - m;
                v += xshift * xshift;
            }
            v = v/C;
            // calculate the rstd
            float s = 1.0f / sqrtf(v + eps);
            // seek to the output position in out[b,t,:]
            float* out_bt = out + b * T * C + t * C;
            for (int i = 0; i < C; i++) {
                float n = (s * (x[i] - m)); // normalized output
                float o = n * weight[i] + bias[i]; // scale and shift it
                out_bt[i] = o; // write
            }
            // cache the mean and rstd for the backward pass later
            mean[b * T + t] = m;
            rstd[b * T + t] = s;
        }
    }
}

// ----------------------------------------------------------------------------
// GPU kernels

// naive drag and drop implementation into kernel, parallelize over B,T, loop over C
// one thread per row; N = B*T rows of C elements each
__global__ void layernorm_forward_kernel1(float* out, float* mean, float* rstd,
                                 const float* inp, const float* weight, const float* bias,
                                 int N, int C) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    float eps = 1e-5f;

    if (idx < N) {
        // seek to the input position inp[idx,:]
        const float* x = inp + idx * C;
        // calculate the mean
        float m = 0.0f;
        for (int i = 0; i < C; i++) {
            m += x[i];
        }
        m = m / C;
        // calculate the variance (without any bias correction)
        float v = 0.0f;
        for (int i = 0; i < C; i++) {
            float xshift = x[i] - m;
            v += xshift * xshift;
        }
        v = v / C;
        // calculate the rstd
        float s = 1.0f / sqrtf(v + eps);
        // seek to the output position in out[idx,:]
        float* out_idx = out + idx * C;
        for (int i = 0; i < C; i++) {
            float n = (s * (x[i] - m)); // normalized output
            float o = n * weight[i] + bias[i]; // scale and shift it
            out_idx[i] = o; // write
        }
        // cache the mean and rstd for the backward pass later
        mean[idx] = m;
        rstd[idx] = s;
    }
}

// one block per row: each thread accumulates a strided partial sum of the row, then a
// shared-memory tree reduction combines the partials.
// Launch with block_size * sizeof(float) bytes of dynamic shared memory.
// NOTE: the halving tree reduction is only correct when block_size is a power of two.
__global__ void mean_kernel(float* mean, const float* inp, int N, int C, int block_size) {
    extern __shared__ float shared[];
    int idx = blockIdx.x; // range [0, B*T)
    int tid = threadIdx.x; // range [0, block_size)
    const float* x = inp + idx * C;
    // thread coarsening
    float sum = 0.0f;
    for (int i = tid; i < C; i += block_size) {
        sum += x[i];
    }
    shared[tid] = sum;
    __syncthreads();
    // reductions
    for (int stride = block_size / 2; stride >= 1; stride /= 2) {
        __syncthreads();
        if (tid < stride) {
            shared[tid] += shared[tid + stride];
        }
    }
    // write the final result (at thread 0) to global memory
    if (tid == 0) {
        mean[idx] = shared[0] / C;
    }
}

// same structure as mean_kernel, but reduces sum((x - mean)^2) and writes 1/sqrt(var + eps).
// Requires mean[] to have been computed already; block_size must be a power of two.
__global__ void rstd_kernel(float* rstd, const float* inp, const float* mean, int N, int C, int block_size) {
    extern __shared__ float shared[];
    int idx = blockIdx.x; // range [0, B*T)
    int tid = threadIdx.x; // range [0, block_size)
    const float* x = inp + idx * C;
    float m = mean[idx];
    // thread coarsening
    float sum = 0.0f;
    for (int i = tid; i < C; i += block_size) {
        float diff = x[i] - m;
        sum += diff * diff;
    }
    shared[tid] = sum;
    __syncthreads();
    // reductions
    for (int stride = block_size / 2; stride >= 1; stride /= 2) {
        __syncthreads();
        if (tid < stride) {
            shared[tid] += shared[tid + stride];
        }
    }
    // write the final result (at thread 0) to global memory
    if (tid == 0) {
        rstd[idx] = 1.0f / sqrtf(shared[0] / C + 1e-5f);
    }
}

// one thread per element of the flattened (B*T*C) tensor
__global__ void normalization_kernel(float* out, const float* inp, float* mean, float* rstd,
                                     const float* weight, const float* bias, int B, int T, int C) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    // the grid is rounded up to a whole number of blocks, so the last block may overhang
    if (idx >= B * T * C) {
        return;
    }

    int bt = idx / C;
    int c = idx % C;

    float m = mean[bt];
    float s = rstd[bt];
    float xi = inp[idx];
    float n = s * (xi - m);
    float o = n * weight[c] + bias[c];

    out[idx] = o;
}

// one warp per row, two passes over the row (mean, then variance)
__global__ void layernorm_forward_kernel3(float* __restrict__ out, float* __restrict__ mean, float* __restrict__ rstd,
                                    const float*  __restrict__ inp, const float*  __restrict__ weight,
                                    const float* __restrict__ bias, int N, int C) {
    namespace cg = cooperative_groups;
    cg::thread_block block = cg::this_thread_block();
    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
    // meta_group_size is the number of warps in a block, and meta_group_rank is the warp index
    int idx = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank();
    if(idx >= N) {
        return;
    }

    // the row of input that this group of threads is responsible for
    const float* x = inp + idx * C;

    // mean
    float sum = 0.0f;
    for (int i = warp.thread_rank(); i < C; i += warp.size()) {
        sum += x[i];
    }
    sum = cg::reduce(warp, sum, cg::plus<float>{});
    float m = sum / C;
    if(warp.thread_rank() == 0 && mean != nullptr) {
        __stcs(mean + idx, m);
    }

    // rstd
    sum = 0.0f;
    for (int i = warp.thread_rank(); i < C; i += warp.size()) {
        float diff = x[i] - m;
        sum += diff * diff;
    }
    sum = cg::reduce(warp, sum, cg::plus<float>{});
    float s = rsqrtf(sum / C + 1e-5f);
    if(warp.thread_rank() == 0 && rstd != nullptr) {
        __stcs(rstd + idx, s);
    }

    // final normalization and scaling by weight/bias
    float* o = out + idx * C;
    for (int c = warp.thread_rank(); c < C; c += warp.size()) {
        // load and store using the .cs "streaming" hint to the compiler,
        // indicating that this data will not be reused soon, and can be streamed through the caches
        // this allows the threads to get more cache-hits for the (shared) weight and bias parameters
        float n = s * (__ldcs(x+c) - m);
        __stcs(o+c, n * weight[c] + bias[c]);
    }
}

// same as kernel 3 but uses var(x) == mean(x**2) - mean(x)**2
__global__ void layernorm_forward_kernel4(float* __restrict__ out, float* __restrict__ mean, float* __restrict__ rstd,
                                    const float*  __restrict__ inp, const float*  __restrict__ weight,
                                    const float* __restrict__ bias, int N, int C) {
    namespace cg = cooperative_groups;
    cg::thread_block block = cg::this_thread_block();
    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
    int idx = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank();
    if(idx >= N) {
        return;
    }

    // the row of input that this group of threads is responsible for
    const float* x = inp + idx * C;

    // thread coarsening through the row, reduce the sum in series
    float sum = 0.0f; // stores sum(x)
    float sum2 = 0.0f; // stores sum(x**2)
    for (int i = warp.thread_rank(); i < C; i += warp.size()) {
        float xi = x[i];
        sum += xi;
        sum2 += xi * xi;
    }
    // warp-level reduction at the end
    sum = cg::reduce(warp, sum, cg::plus<float>{}); // sum(x)
    sum2 = cg::reduce(warp, sum2, cg::plus<float>{}); // sum(x**2)
    sum /= C; // mean(x)
    sum2 /= C; // mean(x**2)

    // mean, var, rstd
    float m = sum;
    // mean(x**2) - mean(x)**2 can round to a small negative number (catastrophic cancellation);
    // clamp it so rsqrtf never sees a negative argument and produces NaN
    float var = fmaxf(sum2 - sum * sum, 0.0f);
    float s = rsqrtf(var + 1e-5f);

    // store the mean, no need to cache it
    if(warp.thread_rank() == 0 && mean != nullptr) {
        __stcs(mean + idx, m);
    }
    // store the rstd, no need to cache it
    if(warp.thread_rank() == 0 && rstd != nullptr) {
        __stcs(rstd + idx, s);
    }
    // final normalization and scaling by weight/bias
    float* o = out + idx * C;
    for (int c = warp.thread_rank(); c < C; c += warp.size()) {
        float n = s * (__ldcs(x+c) - m);
        __stcs(o+c, n * weight[c] + bias[c]);
    }
}

// like 4, but in kernel 5 we have each block doing one row, not just a
// single warp
// One block per row: threads accumulate strided partial sums of x and x**2, each warp
// reduces its partials, and warp results are combined through shared memory.
// Requires blockDim.x to be a multiple of 32 and at most 1024.
__global__ void layernorm_forward_kernel5(float* __restrict__ out, float* __restrict__ mean, float* __restrict__ rstd,
                                    const float*  __restrict__ inp, const float*  __restrict__ weight,
                                    const float* __restrict__ bias, int N, int C) {
    namespace cg = cooperative_groups;
    cg::thread_block block = cg::this_thread_block();
    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
    __shared__ float shared_sum[32]; // block_size max is 1024 = 32 * 32 warps
    __shared__ float shared_sum2[32]; // warps will be writing into shared memory after warp-reduce
    int num_warps = blockDim.x / 32;
    int warp_id = threadIdx.x / 32;
    int lane_id = threadIdx.x % 32;
    int idx = blockIdx.x; // simply one block per row
    // idx is the same for every thread of the block, so the whole block exits together and
    // no thread is left waiting at the __syncthreads below
    if (idx >= N) {
        return;
    }
    // the row of input that this group of threads is responsible for
    const float* x = inp + idx * C;
    // thread coarsening through the row, reduce the sum in series
    float thread_sum = 0.0f; // stores sum(x)
    float thread_sum2 = 0.0f; // stores sum(x**2)
    for (int i = threadIdx.x; i < C; i += blockDim.x) {
        float xi = x[i];
        thread_sum += xi;
        thread_sum2 += xi * xi;
    }
    // warp-level reduction
    float warp_sum = cg::reduce(warp, thread_sum, cg::plus<float>{}); // sum(x)
    float warp_sum2 = cg::reduce(warp, thread_sum2, cg::plus<float>{}); // sum(x**2)
    // store the warp-level reduction in shared memory (we could have lane_id == 0 guard but not needed)
    shared_sum[warp_id] = warp_sum;
    shared_sum2[warp_id] = warp_sum2;
    __syncthreads();
    // load results from shared memory to threads, pad with zeros for threads that are out of bounds
    warp_sum = (lane_id < num_warps) ? shared_sum[lane_id] : 0.0f;
    warp_sum2 = (lane_id < num_warps) ? shared_sum2[lane_id] : 0.0f;
    // now reduce the warp-level reductions
    float block_sum = cg::reduce(warp, warp_sum, cg::plus<float>{}); // sum(x)
    float block_sum2 = cg::reduce(warp, warp_sum2, cg::plus<float>{}); // sum(x**2)
    // mean, var, rstd
    block_sum /= C; // mean(x)
    block_sum2 /= C; // mean(x**2)
    float m = block_sum;
    // mean(x**2) - mean(x)**2 can round to a small negative number; clamp so rsqrtf stays finite
    float var = fmaxf(block_sum2 - m * m, 0.0f);
    float s = rsqrtf(var + 1e-5f);
    // store the mean, no need to cache it
    if(threadIdx.x == 0 && mean != nullptr) {
        __stcs(mean + idx, m);
    }
    // store the rstd, no need to cache it
    if(threadIdx.x == 0 && rstd != nullptr) {
        __stcs(rstd + idx, s);
    }
    // final normalization and scaling by weight/bias
    float* o = out + idx * C;
    for (int i = threadIdx.x; i < C; i += blockDim.x) {
        float n = s * (__ldcs(x+i) - m);
        __stcs(o+i, n * weight[i] + bias[i]);
    }
}

// Inspired by `fused_residual_forward_kernel5` in fused_residual_forward.cu
// Launch with blockDim = (WARP_SIZE, block_y): one warp per row and block_y rows per block.
// Dynamic shared memory holds weight, bias, then one cached input row per warp,
// i.e. (2 + block_y) * C floats.
// NOTE(review): the 128-bit loads/stores assume C is a multiple of x128::size; confirm for all callers.
__global__ void layernorm_forward_kernel6(float* __restrict__ out, float* __restrict__ mean, float* __restrict__ rstd,
                                    const float*  __restrict__ inp, const float*  __restrict__ weight,
                                    const float* __restrict__ bias, int N, int C) {
    assert(blockDim.x == WARP_SIZE);

    // load weights and biases into shared memory
    // do this before we allow any threads to exit!
    extern __shared__ char params[];
    // load128/store128 sometimes generated multiple instructions when the types here were floatX*, so
    // let's keep everything as x128
    x128* s_weight = reinterpret_cast<x128*>(params);
    x128* s_bias = reinterpret_cast<x128*>(params) + (C / x128::size);
    x128* s_in = reinterpret_cast<x128*>(params) + ((2 + threadIdx.y) * C / x128::size);

    // every thread of the block cooperates to fill s_weight / s_bias
    int sidx = (threadIdx.x + WARP_SIZE * threadIdx.y) * x128::size;
    for(int i = sidx; i < C; i += blockDim.y * WARP_SIZE * x128::size) {
        s_weight[i/x128::size] = load128(weight + i);
        s_bias[i/x128::size] = load128(bias + i);
    }
    __syncthreads();

    int idx = blockIdx.x * blockDim.y + threadIdx.y;
    if(idx >= N) { return; } // guard

    // adjust pointers to current token
    inp += idx * C;
    out += idx * C;

    const float eps = 1e-5f;
    // pass 1: read the row once from global memory, accumulate the sum and cache it in s_in
    float sum = 0.0f;
    for(int c = threadIdx.x * x128::size; c < C; c += WARP_SIZE * x128::size) {
        const x128 in_data = load128cs(inp + c);
        for(int k = 0; k < x128::size; ++k) {
            sum += (float)in_data[k];
        }
        s_in[c / x128::size] = in_data;
    }

    sum = warpReduceSum(sum);
    float m = sum / C;
    // pass 2: two-pass variance from the cached row (no cancellation issue here)
    float v = 0.f;

    for(int c = threadIdx.x * x128::size; c < C; c += WARP_SIZE * x128::size) {
        const x128 in_data = s_in[c / x128::size];
        for(int k = 0; k < x128::size; ++k) {
            v += ((float)in_data[k] - m) * ((float)in_data[k] - m);
        }
    }

    v = warpReduceSum(v) / C;
    float s = rsqrtf(v + eps);

    // pass 3: normalize, scale and shift, and stream the result out
    for(int c = threadIdx.x * x128::size; c < C; c += WARP_SIZE * x128::size) {
        const x128 in_data = s_in[c / x128::size];
        const x128 w = s_weight[c / x128::size];
        const x128 b = s_bias[c / x128::size];
        x128 out_data;
        for(int k = 0; k < x128::size; ++k) {
            float n = s * ((float)in_data[k] - m); // normalized output
            float o = n * (float)w[k] + (float)b[k]; // scale and shift it
            out_data[k] = o;
        }

        store128cs(out + c, out_data);
    }
    // cache the mean and rstd for the backward pass later
    if(threadIdx.x == 0 && mean != nullptr) {
        __stcs(mean + idx, m);
    }
    // store the rstd, no need to cache it
    if(threadIdx.x == 0 && rstd != nullptr) {
        __stcs(rstd + idx, s);
    }
}

//
// ----------------------------------------------------------------------------
// kernel launcher

void layernorm_forward1(float* out, float* mean, float* rstd,
                           const float* inp, const float* weight, const float* bias,
                           int B, int T, int C,
                           const int block_size) {
    const int N = B * T;
    const int grid_size = ceil_div(N, block_size);
    layernorm_forward_kernel1<<<grid_size, block_size>>>(out, mean, rstd, inp, weight, bias, N, C);
    cudaCheck(cudaGetLastError());
}

void layernorm_forward2(float* out, float* mean, float* rstd,
                       const float* inp, const float* weight, const float* bias,
                       int B, int T, int C,
                       const int block_size) {
    // mean_kernel / rstd_kernel use a halving tree reduction that needs a power-of-two block
    assert(block_size > 0 && (block_size & (block_size - 1)) == 0);
    int N = B * T;
    // in mean and rstd, threads cooperate within blocks via reductions
    // (one block per row, block_size floats of dynamic shared memory)
    mean_kernel<<<N, block_size, block_size * sizeof(float)>>>(mean, inp, N, C, block_size);
    cudaCheck(cudaGetLastError());
    rstd_kernel<<<N, block_size, block_size * sizeof(float)>>>(rstd, inp, mean, N, C, block_size);
    cudaCheck(cudaGetLastError());
    // in the normalization, everything just gets flattened out
    const int block_size2 = 256;
    const int grid_size = ceil_div(B * T * C, block_size2);
    normalization_kernel<<<grid_size, block_size2>>>(out, inp, mean, rstd, weight, bias, B, T, C);
    cudaCheck(cudaGetLastError());
}

void layernorm_forward3(float* out, float* mean, float* rstd,
                       const float* inp, const float* weight, const float* bias,
                       int B, int T, int C,
                       const int block_size) {
    assert(block_size % 32 == 0);
    const int N = B * T;
    // one warp (32 threads) per row
    const int grid_size = ceil_div(N * 32, block_size);
    layernorm_forward_kernel3<<<grid_size, block_size>>>(out, mean, rstd, inp, weight, bias, N, C);
    cudaCheck(cudaGetLastError());
}

void layernorm_forward4(float* out, float* mean, float* rstd,
                       const float* inp, const float* weight, const float* bias,
                       int B, int T, int C,
                       const int block_size) {
    assert(block_size % 32 == 0);
    const int N = B * T;
    // one warp (32 threads) per row
    const int grid_size = ceil_div(N * 32, block_size);
    layernorm_forward_kernel4<<<grid_size, block_size>>>(out, mean, rstd, inp, weight, bias, N, C);
    cudaCheck(cudaGetLastError());
}

void layernorm_forward5(float* out, float* mean, float* rstd,
                       const float* inp, const float* weight, const float* bias,
                       int B, int T, int C,
                       const int block_size) {
    assert(block_size % 32 == 0);
    assert(block_size <= 1024);
    const int N = B * T;
    // one block per row
    const int grid_size = N;
    layernorm_forward_kernel5<<<grid_size, block_size>>>(out, mean, rstd, inp, weight, bias, N, C);
    cudaCheck(cudaGetLastError());
}

void layernorm_forward6(float* out, float* mean, float* rstd,
                       const float* inp, const float* weight, const float* bias,
                       int B, int T, int C,
                       int block_size) {
    assert(block_size % 32 == 0);
    const int N = B * T;
    // block is (WARP_SIZE, block_y): one warp per row, block_y rows per block
    int block_y = block_size / WARP_SIZE;
    const int grid_size = ceil_div(N, block_y);
    // weight + bias + one input row per warp
    size_t smem = (2 + block_y) * C * sizeof(float);

    // in order to use more than 48 KiB of smem, need to call cudaFuncSetAttribute
    // this may fail, in which case we fall back to the smem free implementation.
    cudaCheck(cudaGetLastError());
    auto status = cudaFuncSetAttribute(layernorm_forward_kernel6, cudaFuncAttributeMaxDynamicSharedMemorySize, smem);
    // a failed attribute call records an error; clear it so the cudaCheck below
    // only reports errors from the launch itself
    cudaGetLastError();
    if (status == cudaSuccess) {
        layernorm_forward_kernel6<<<grid_size, dim3(WARP_SIZE, block_y), smem>>>(out, mean, rstd, inp, weight, bias, N, C);
    } else {
        // fall back to the version without shared memory (one block per row)
        const int fallback_grid_size = N;
        layernorm_forward_kernel5<<<fallback_grid_size, block_size>>>(out, mean, rstd, inp, weight, bias, N, C);
    }
    cudaCheck(cudaGetLastError());
}

// kernel version dispatch
void layernorm_forward(int kernel_num,
                    float* out, float* mean, float* rstd,
                    const float* inp, const float* weight, const float* bias,
                    int B, int T, int C,
                    const int block_size) {
    switch (kernel_num) {
        case 1:
            layernorm_forward1(out, mean, rstd, inp, weight, bias, B, T, C, block_size);
            break;
        case 2:
            layernorm_forward2(out, mean, rstd, inp, weight, bias, B, T, C, block_size);
            break;
        case 3:
            layernorm_forward3(out, mean, rstd, inp, weight, bias, B, T, C, block_size);
            break;
        case 4:
            layernorm_forward4(out, mean, rstd, inp, weight, bias, B, T, C, block_size);
            break;
        case 5:
            layernorm_forward5(out, mean, rstd, inp, weight, bias, B, T, C, block_size);
            break;
        case 6:
            layernorm_forward6(out, mean, rstd, inp, weight, bias, B, T, C, block_size);
            break;
        default:
            printf("Invalid kernel number\n");
            exit(1);
    }
}

// ----------------------------------------------------------------------------

int main(int argc, char **argv) {
    srand(0);

    int B = 8;
    int T = 1024;
    int C = 768;

    int deviceIdx = 0;
    cudaCheck(cudaSetDevice(deviceIdx));

    // create host memory of random numbers
    float* out = (float*)malloc(B * T * C * sizeof(float));
    float* mean = (float*)malloc(B * T * sizeof(float));
    float* rstd = (float*)malloc(B * T * sizeof(float));
    float* inp = make_random_float(B * T * C);
    float* weight = make_random_float(C);
    float* bias = make_random_float(C);

    // move to GPU
    float* d_out;
    float* d_mean;
    float* d_rstd;
    float* d_inp;
    float* d_weight;
    float* d_bias;
    cudaCheck(cudaMalloc(&d_out, B * T * C * sizeof(float)));
    cudaCheck(cudaMalloc(&d_mean, B * T * sizeof(float)));
    cudaCheck(cudaMalloc(&d_rstd, B * T * sizeof(float)));
    cudaCheck(cudaMalloc(&d_inp, B * T * C * sizeof(float)));
    cudaCheck(cudaMalloc(&d_weight, C * sizeof(float)));
    cudaCheck(cudaMalloc(&d_bias, C * sizeof(float)));
    cudaCheck(cudaMemcpy(d_inp, inp, B * T * C * sizeof(float), cudaMemcpyHostToDevice));
    cudaCheck(cudaMemcpy(d_weight, weight, C * sizeof(float), cudaMemcpyHostToDevice));
    cudaCheck(cudaMemcpy(d_bias, bias, C * sizeof(float), cudaMemcpyHostToDevice));

    // read kernel_num from command line
    int kernel_num = 2;
    if (argc > 1) {
        kernel_num = atoi(argv[1]);
    }
    printf("Using kernel %d\n", kernel_num);

    int block_sizes[] = {32, 64, 128, 256, 512, 1024};
    const int num_block_sizes = (int)(sizeof(block_sizes) / sizeof(block_sizes[0]));

    layernorm_forward_cpu(out, mean, rstd, inp, weight, bias, B, T, C);

    // check the correctness of the kernel at all block sizes
    for (int j = 0; j < num_block_sizes; j++) {
        int block_size = block_sizes[j];
        printf("Checking block size %d.\n", block_size);

        layernorm_forward(kernel_num, d_out, d_mean, d_rstd, d_inp, d_weight, d_bias, B, T, C, block_size);

        validate_result(d_out, out, "out", B * T * C, 1e-5f);
        validate_result(d_mean, mean, "mean", B * T, 1e-5f);
        validate_result(d_rstd, rstd, "rstd", B * T, 1e-5f);
    }

    printf("All results match. Starting benchmarks.\n\n");

    // time the kernel at different block sizes
    for (int j = 0; j < num_block_sizes; j++) {
        int block_size = block_sizes[j];

        int repeat_times = 2000;
        float elapsed_time = benchmark_kernel(repeat_times, layernorm_forward,
                                              kernel_num, d_out, d_mean, d_rstd, d_inp, d_weight, d_bias,
                                              B, T, C, block_size);

        // napkin math: estimate the memory bandwidth achieved
        // e.g. A100 40GB PCIe is advertised at 1,555GB/s
        // one read of inp plus one write of out; computed in size_t so larger B*T*C cannot overflow int
        size_t memory_ops = (size_t)2 * B * T * C * sizeof(float);
        float memory_bandwidth = memory_ops / elapsed_time / 1e6;

        printf("block_size %4d | time %.4f ms | bandwidth %.2f GB/s\n", block_size, elapsed_time, memory_bandwidth);
    }

    // free memory
    free(out);
    free(mean);
    free(rstd);
    free(inp);
    free(weight);
    free(bias);
    cudaCheck(cudaFree(d_out));
    cudaCheck(cudaFree(d_mean));
    cudaCheck(cudaFree(d_rstd));
    cudaCheck(cudaFree(d_inp));
    cudaCheck(cudaFree(d_weight));
    cudaCheck(cudaFree(d_bias));

    return 0;
}

================================================
FILE: dev/cuda/matmul_backward.cu
================================================
/*
Kernels for matmul backward pass.
Compile example:
nvcc -O3 --use_fast_math -lcublas -lcublasLt -Xcompiler -fopenmp matmul_backward.cu -o matmul_backward

OMP_NUM_THREADS=32 ./matmul_backward 1
*/

#include <stdio.h>
#include <stdlib.h>
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <cublasLt.h>
#include "common.h"

// ----------------------------------------------------------------------------
// CPU code reference

// Backward of out = inp @ weight^T + bias, indexed as dout (B,T,OC), inp/dinp (B,T,C),
// weight/dweight (OC,C), dbias (OC).
// dinp and dweight are accumulated into (+=); dbias, when non-NULL, is overwritten with the sum.
void matmul_backward_cpu(float* dinp, float* dweight, float* dbias,
                     float* dout, float* inp, float* weight,
                     int B, int T, int C, int OC) {
    // most of the running time is spent here and in matmul_forward
    // this backward could be done in a single "round" of loops
    // but that doesn't afford an efficient parallelization strategy

    // backward into inp first, parallelize over B,T
    #pragma omp parallel for collapse(2)
    for (int b = 0; b < B; b++) {
        for (int t = 0; t < T; t++) {
            float* dout_bt = dout + b * T * OC + t * OC;
            float* dinp_bt = dinp + b * T * C + t * C;
            for (int o = 0; o < OC; o++) {
                float* wrow = weight + o*C;
                float d = dout_bt[o];
                for (int i = 0; i < C; i++) {
                    dinp_bt[i] += wrow[i] * d;
                }
            }
        }
    }
    // backward into weight/bias, parallelize over output channels OC
    // (each thread owns one row of dweight, so there are no write races)
    #pragma omp parallel for
    for (int o = 0; o < OC; o++) {
        float* dwrow = dweight + o*C;
        double sum = 0.0; // double accumulator for the B*T-long bias reduction
        for (int b = 0; b < B; b++) {
            for (int t = 0; t < T; t++) {
                float* dout_bt = dout + b * T * OC + t * OC;
                float* inp_bt = inp + b * T * C + t * C;
                float d = dout_bt[o];
                if (dbias != NULL) { sum += d; }
                for (int i = 0; i < C; i++) {
                    dwrow[i] += inp_bt[i] * d;
                }
            }
        }
        if (dbias != NULL){dbias[o] = sum;}
    }
}

// ----------------------------------------------------------------------------
// GPU kernels

// naive kernel to backpropagate only the bias, it's just a sum :'(
// one thread per output channel, each summing its column of dout over all B*T rows
__global__ void matmul_backward_bias_kernel_naive(float* dbias, const float* dout, int B, int T, int OC) {
    int o = blockIdx.x * blockDim.x + threadIdx.x;
    if (o < OC) {
        double sum = 0.0;
        for (int b = 0; b < B; b++) {
            for (int t = 0; t < T; t++) {
                sum += dout[b * T * OC + t * OC + o];
            }
        }
        dbias[o] = sum;
    }
}

// use shared
// memory and coarsening + reductions
// One block per output channel o. Threads accumulate strided partial sums of column o of dout,
// then combine them in a shared-memory tree reduction (blockDim.x must be a power of two).
// Launch with blockDim.x * sizeof(float) bytes of dynamic shared memory. Overwrites dbias[o].
__global__ void matmul_backward_bias_kernel_faster(float* dbias, const float* dout, int B, int T, int OC) {
    extern __shared__ float shared[];
    int o = blockIdx.x; // range [0, OC)
    int tid = threadIdx.x; // range [0, block_size)
    int block_size = blockDim.x;
    const float* x = dout + o;
    // thread coarsening
    double sum = 0.0;
    for (int i = tid; i < B * T; i += block_size) {
        sum += x[i * OC];
    }
    shared[tid] = (float) sum;
    __syncthreads();
    // reductions
    for (int stride = block_size / 2; stride >= 1; stride /= 2) {
        __syncthreads();
        if (tid < stride) {
            shared[tid] += shared[tid + stride];
        }
    }
    // write the final result (at thread 0) to global memory
    if (tid == 0) {
        dbias[o] = shared[0];
    }
}

// ----------------------------------------------------------------------------
// kernel launcher

// version1: simple cuBLAS calls
// dinp and dweight are accumulated into (beta = 1). dbias, if given, is overwritten
// by matmul_backward_bias_kernel_faster. `ones` is currently unused (see the note below).
void matmul_backward1(float* dinp, float* dweight, float* dbias,
                      float* dout, float* inp, float* weight, float* ones,
                      int B, int T, int C, int OC) {
    float alpha = 1.0f;
    float beta = 1.0f; // note we must use beta = 1.0 so that we do a +=, as we should, because gradients add

    // for reference the API is:
    // cublasStatus_t cublasSgemm(cublasHandle_t handle,
    //                        cublasOperation_t transa, cublasOperation_t transb,
    //                        int m, int n, int k,
    //                        const float           *alpha,
    //                        const float           *A, int lda,
    //                        const float           *B, int ldb,
    //                        const float           *beta,
    //                        float           *C, int ldc)

    // recall the forward pass was calculated with alpha = 1.0f, beta = 0.0f as:
    // cublasSgemm(cublas_handle, CUBLAS_OP_T, CUBLAS_OP_N, OC, B*T, C, &alpha, weight, C, inp, C, &beta, out, OC);

    // backward to input
    cublasCheck(cublasSgemm(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, C, B*T, OC, &alpha, weight, C, dout, OC, &beta, dinp, C));
    // backward to weight
    cublasCheck(cublasSgemm(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_T, C, OC, B*T, &alpha, inp, C, dout, OC, &beta, dweight, C));
    // backward to bias, if given
    if (dbias != NULL) {
        // sum over B,T. In column-major terms dout is an (OC x B*T) matrix with lda = OC, so
        //   cublasSgemv(cublas_handle, CUBLAS_OP_N, OC, B*T, &alpha, dout, OC, ones, 1, &beta, dbias, 1)
        // should compute the same sum, but it needs a ones vector of length B*T.
        // NOTE(review): main only allocates OC ones, which is likely why the earlier gemv attempts
        // did not work. For now we use a dedicated reduction kernel instead.
        const int block_size = 512; // must be a power of two for the kernel's tree reduction
        dim3 block_dim(block_size);
        dim3 grid_dim(OC);
        size_t shared_mem_size = block_size * sizeof(float);
        matmul_backward_bias_kernel_faster<<<grid_dim, block_dim, shared_mem_size>>>(dbias, dout, B, T, OC);
        cudaCheck(cudaGetLastError());
    }
}

void matmul_backward(int kernel_num,
                     float* dinp, float* dweight, float* dbias,
                     float* dout, float* inp, float* weight, float* ones,
                     int B, int T, int C, int OC) {
    switch (kernel_num) {
        case 1:
            matmul_backward1(dinp, dweight, dbias, dout, inp, weight, ones, B, T, C, OC);
            break;
        default:
            printf("Invalid kernel number\n");
            exit(1);
    }
}

// ----------------------------------------------------------------------------

int main(int argc, char **argv) {
    srand(0);

    int B = 8;
    int T = 1024;
    int C = 768;
    int OC = 768 * 4; // expansion of 4, e.g. in the MLP

    // set up the device
    int deviceIdx = 0;
    cudaCheck(cudaSetDevice(deviceIdx));
    cudaDeviceProp deviceProp;
    cudaCheck(cudaGetDeviceProperties(&deviceProp, deviceIdx));
    printf("Device %d: %s\n", deviceIdx, deviceProp.name);

    // setup cuBLAS and its mathmodes, ensure fp32
    int enable_tf32 = 0; // use fp32 to get accurate results for checking w.r.t. CPU
    cublasCheck(cublasCreate(&cublas_handle));
    printf("enable_tf32: %d\n", enable_tf32);
    cublasMath_t cublas_math_mode = enable_tf32 ? CUBLAS_TF32_TENSOR_OP_MATH : CUBLAS_DEFAULT_MATH;
    cublasCheck(cublasSetMathMode(cublas_handle, cublas_math_mode));

    // create host memory of random numbers
    float* dinp = make_zeros_float(B * T * C);
    float* dweight = make_zeros_float(OC * C);
    float* dbias = make_zeros_float(OC);
    float* dout = make_random_float(B * T * OC);
    float* inp = make_random_float(B * T * C);
    float* weight = make_random_float(OC * C);
    float* ones = make_ones_float(OC);

    // move to GPU
    float* d_dinp;
    float* d_dweight;
    float* d_dbias;
    float* d_dout;
    float* d_inp;
    float* d_weight;
    float* d_ones;
    cudaCheck(cudaMalloc(&d_dinp, B * T * C * sizeof(float)));
    cudaCheck(cudaMalloc(&d_dweight, OC * C * sizeof(float)));
    cudaCheck(cudaMalloc(&d_dbias, OC * sizeof(float)));
    cudaCheck(cudaMalloc(&d_dout, B * T * OC * sizeof(float)));
    cudaCheck(cudaMalloc(&d_inp, B * T * C * sizeof(float)));
    cudaCheck(cudaMalloc(&d_weight, OC * C * sizeof(float)));
    cudaCheck(cudaMalloc(&d_ones, OC * sizeof(float)));
    cudaCheck(cudaMemcpy(d_dinp, dinp, B * T * C * sizeof(float), cudaMemcpyHostToDevice));
    cudaCheck(cudaMemcpy(d_dweight, dweight, OC * C * sizeof(float), cudaMemcpyHostToDevice));
    cudaCheck(cudaMemcpy(d_dbias, dbias, OC * sizeof(float), cudaMemcpyHostToDevice));
    cudaCheck(cudaMemcpy(d_dout, dout, B * T * OC * sizeof(float), cudaMemcpyHostToDevice));
    cudaCheck(cudaMemcpy(d_inp, inp, B * T * C * sizeof(float), cudaMemcpyHostToDevice));
    cudaCheck(cudaMemcpy(d_weight, weight, OC * C * sizeof(float), cudaMemcpyHostToDevice));
    cudaCheck(cudaMemcpy(d_ones, ones, OC * sizeof(float), cudaMemcpyHostToDevice));

    // read kernel_num from command line
    int kernel_num = 1;
    if (argc > 1) {
        kernel_num = atoi(argv[1]);
    }
    printf("Using kernel %d\n", kernel_num);

    // calculate the CPU reference
    matmul_backward_cpu(dinp, dweight, dbias, dout, inp, weight, B, T, C, OC);

    // calculate the GPU version
    matmul_backward(kernel_num, d_dinp, d_dweight, d_dbias, d_dout, d_inp, d_weight, d_ones, B, T, C, OC);

    // compare
    printf("Checking correctness...\n");
    printf("dinp:\n");
    validate_result(d_dinp, dinp, "dinp", B * T * C, 1e-3f);
    printf("dweight:\n");
    validate_result(d_dweight, dweight, "dweight", OC * C, 1e-3f);
    printf("dbias:\n");
    validate_result(d_dbias, dbias, "dbias", OC, 1e-3f);
    printf("All results match.\n\n");

    // now benchmark the kernel
    // (dinp/dweight keep accumulating across repeats; only the timing is used here)
    int repeat_times = 100;
    float elapsed_time = benchmark_kernel(repeat_times, matmul_backward,
                                          kernel_num, d_dinp, d_dweight, d_dbias, d_dout, d_inp, d_weight, d_ones,
                                          B, T, C, OC);
    printf("time %.4f ms\n", elapsed_time);

    // cleanups
    free(dinp);
    free(dweight);
    free(dbias);
    free(dout);
    free(inp);
    free(weight);
    free(ones);
    cudaCheck(cudaFree(d_dinp));
    cudaCheck(cudaFree(d_dweight));
    cudaCheck(cudaFree(d_dbias));
    cudaCheck(cudaFree(d_dout));
    cudaCheck(cudaFree(d_inp));
    cudaCheck(cudaFree(d_weight));
    cudaCheck(cudaFree(d_ones));
    cublasCheck(cublasDestroy(cublas_handle));

    return 0;
}

================================================
FILE: dev/cuda/matmul_backward_bias.cu
================================================
/*
Kernels for matmul backward pass bias only.
Compile example:
nvcc -O3 -lcublas -lcublasLt -std=c++17 matmul_backward_bias.cu -lineinfo -o matmul_backward_bias

./matmul_backward_bias 1
./matmul_backward_bias 2
./matmul_backward_bias 3
./matmul_backward_bias 4
./matmul_backward_bias 5
(kernel 5 is only compiled when ENABLE_BF16 is not defined, see the #ifndef guard around it;
this file defines ENABLE_BF16 below)

ncu:
sudo ncu --set full --import-source yes -o bias -f ./matmul_backward_bias 1
*/

#include <stdio.h>
#include <stdlib.h>
#include <assert.h>
#include <float.h>
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <cublasLt.h>
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <type_traits>

#define ENABLE_BF16
#include "common.h"

// ----------------------------------------------------------------------------
// utility functions
__host__ __device__ bool isPowerOfTwo(int n) {
    return (n > 0) && ((n & (n - 1)) == 0);
}

__host__ __device__ int largestPowerOfTwoLessOrEqual(int n) {
    // Return the largest power of 2 less than or equal to n
    if (n < 1) {
        return 0;
    }

    // repeatedly clear the lowest set bit until only the highest one remains
    while ((n & (n - 1)) > 0) {
        n = n & (n - 1);
    }

    return n;
}

// ----------------------------------------------------------------------------
// CPU code reference

// dbias[o] = sum over all (b,t) of dout[b,t,o]; dout is indexed as (B,T,OC).
// Only dbias and dout are used; the other parameters keep the matmul_backward signature.
void matmul_backward_bias_cpu(float* dinp, float* dweight, float* dbias,
                     float* dout, float* inp, float* weight,
                     int B, int T, int C, int OC) {
    for (int o = 0; o < OC; o++) {
        double sum = 0.0;
        for (int b = 0; b < B; b++) {
            for (int t = 0; t < T; t++) {
                float* dout_bt = dout + b * T * OC + t * OC;
                sum += dout_bt[o];
            }
        }
        dbias[o] = sum;
    }
}

// ----------------------------------------------------------------------------
// GPU kernels

float* dbias_buffer;

// One block per output channel, shared-memory tree reduction (blockDim.x must be a power of two).
// Launch with blockDim.x * sizeof(float) bytes of dynamic shared memory. Accumulates into dbias.
__global__ void matmul_backward_bias_kernel1(floatX* dbias, const floatX* dout, int B, int T, int OC) {
    extern __shared__ float shared[];
    int o = blockIdx.x; // range [0, OC)
    int tid = threadIdx.x; // range [0, block_size)
    int block_size = blockDim.x;
    const floatX* x = dout + o;
    // thread coarsening
    float sum = 0.0f;
    for (int i = tid; i < B * T; i += block_size) {
        sum += (float)x[i * OC];
    }
    shared[tid] = sum;
    __syncthreads();
    // reductions
    for (int stride = block_size / 2; stride >= 1; stride /= 2) {
        __syncthreads();
        if (tid < stride) {
            shared[tid] += shared[tid + stride];
        }
    }
    // write the final result (at thread 0) to global memory
    if (tid == 0) {
        dbias[o] = (floatX)((float)dbias[o] + shared[0]);
    }
}

// cooperative groups solution, one warp per output channel
__global__ void matmul_backward_bias_kernel2(floatX* dbias, const floatX* dout, int B, int T, int OC) {
    // dout is (B, T, OC), dbias is (OC)
    // e.g. if block_size = 128, then we have 4 warps per block, each in charge of one output channel
    namespace cg = cooperative_groups;
    cg::thread_block block = cg::this_thread_block();
    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
    // meta_group_size is the number of warps in a block (e.g. 4), meta_group_rank is the warp index (0,1,2,3)
    int idx = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank();
    if(idx >= OC) { return; }
    int BT = B * T; // number of elements to reduce in total, per channel
    // first, thread coarsening to sum reduce the problem size from B*T to 32
    float sum = 0.0f;
    for(int i = warp.thread_rank(); i < BT; i += warp.size()) {
        sum += (float)dout[i * OC + idx];
    }
    // now do a warp-level reduce to get the sum across the 32 threads in this warp
    sum = cg::reduce(warp, sum, cg::plus<float>{});
    // write the result to output (global memory)
    if(warp.thread_rank() == 0) {
        dbias[idx] = (floatX)((float)dbias[idx] + sum);
    }
}

__global__ void matmul_backward_bias_kernel3(floatX* dbias, const floatX* dout, int B, int T, int OC) {
    // dout is (B, T, OC), dbias is (OC)
    // in this version of the kernel the entire block of block_size is dedicated to one output channel
    namespace cg = cooperative_groups;
    cg::thread_block block = cg::this_thread_block();
    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
    __shared__ float shared_sum[32]; // block_size max is 1024 = 32 * 32 warps
    int BT = B * T; // number of elements to reduce in total, per channel
    int num_warps = blockDim.x / 32;
    int warp_id = threadIdx.x / 32;
    int lane_id = threadIdx.x % 32;
    int idx = blockIdx.x; // simply one block per row
    // round 1: thread coarsening to reduce the problem size from B*T to block_size
    float thread_sum = 0.0f;
    for(int i = threadIdx.x; i < BT; i += blockDim.x) {
        thread_sum += (float)dout[i * OC + idx];
    }
    // now do a warp-level reduce to get the sum across the 32 threads in each warp
    // reduce the problem size from block_size to block_size/32 i.e. `num_warps`
    float warp_sum = cg::reduce(warp, thread_sum, cg::plus<float>{});
    // store the warp sum in shared memory (we could have lane_id == 0 guard but not needed)
    shared_sum[warp_id] = warp_sum;
    __syncthreads();
    // load results from shared memory to threads, pad with zeros for threads that are out of bounds
    warp_sum = (lane_id < num_warps) ? shared_sum[lane_id] : 0.0f;
    // now reduce the warp-level reductions
    float block_sum = cg::reduce(warp, warp_sum, cg::plus<float>{}); // sum(x)
    // write the result to output (global memory)
    if(threadIdx.x == 0) {
        dbias[idx] = (floatX)((float)dbias[idx] + block_sum);
    }
}

// this kernel performs a column-wise reduction over dout, in PyTorch equivalent to:
// dbias = dout.sum((0,1))
// the idea is to employ one block to reduce along several columns,
// where each block has a width of 32 columns to ensure coalesced access.
// at the end we accumulate the reductions performed by the warps in each block via shared memory
__global__ void matmul_backward_bias_kernel4(floatX* dbias, const floatX* dout, int B, int T, int OC) {
    // this kernel is launched with 1D grid_dim of OC/32
    // for example let's say block_size is 128
    extern __shared__ float smem[]; // of size block_size (128)
    const int warp_id = threadIdx.x / warpSize; // warp index in the block, 0,1,2,3
    const int lane_id = threadIdx.x % warpSize; // thread index in the warp, 0,1,2,...,31
    const int tl = blockIdx.x * warpSize; // pointer to the start column for this block
    const int vstep = blockDim.x / warpSize; // number of warps in a block, e.g. 4
    // the column this lane reduces; when OC is not a multiple of 32 the last block's
    // trailing lanes fall outside dout/dbias and must neither read nor write
    const int col = tl + lane_id;
    const bool in_bounds = col < OC;

    // column reductions by looping through the rows
    // each of the 4 threads (of the same lane_id) offsets by its warp_id and then skips by vstep
    // together these 4 threads cover all B*T rows of this (lane_id) column
    // importantly, consecutive threads (in threadId) are processing adjacent columns,
    // leading to a coalesced memory access pattern
    float dout_sum = 0.0f;
    if (in_bounds) {
        const floatX* dout_col = dout + col;
        for (int row = warp_id; row < B * T; row += vstep) {
            dout_sum += (float)dout_col[row * OC];
        }
    }
    // every thread still writes its slot and reaches the barrier, in or out of bounds
    smem[lane_id + warp_id * warpSize] = dout_sum;
    __syncthreads();

    // warp_id 0 reduces the shared memory column-wise, linearly
    dout_sum = 0.0f;
    if (warp_id == 0 && in_bounds) {
        for (int j = 0; j < vstep; j++) {
            dout_sum += smem[lane_id + j * warpSize];
        }
        dbias[col] = (floatX)((float)dbias[col] + dout_sum);
    }
}

#ifndef ENABLE_BF16
// Only compiled when ENABLE_BF16 is not defined; this file defines it above, so this kernel is
// excluded in the current configuration.
// NOTE(review): presumably excluded because it needs atomicAdd on floatX.
__global__ void matmul_backward_bias_kernel5(floatX* dbias, const floatX* dout, int B, int T, int OC) {
    int oc = blockIdx.x * blockDim.x + threadIdx.x;
    if(oc >= OC) return;
    float sum = 0.0f;
    // grid-wide loop for maximum parallelism
    for (int i = blockIdx.y; i < B * T; i += gridDim.y) {
        sum += (float)dout[i * OC + oc];
    }
    // and atomically add everything together. atomics within one block are conflict-free!
    atomicAdd(dbias + oc, sum);
}
#endif

__global__ void cast_and_add_kernel(floatX* dst, const float* src, size_t n) {
    // used only for matmul_backward_bias kernel, a little bit embarrassing TODO delete later
    const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) { dst[idx] = (floatX)((float)dst[idx] + src[idx]); } // have to += because dbias is a parameter
}

__global__ void matmul_backward_bias_kernel7(float* dbias, const floatX* dout, int B, int T, int OC, const int block_size) {
    // note: this kernel reads in floatX, but it writes to float!
// this is because we're using atomics, which are super slow in < fp32 precision on < H100 GPUs // so the trick is do fp32 atomics to a buffer, and then copy_and_cast the result to floatX // (this also results in higher accuracy than doing accumulation directly in floatX) // see comments in matmul_backward() for an explanation of block/grid dimensions etc. const int block_size_x = 32; const int block_size_y = block_size / block_size_x; // 16 const int OC_per_warp = block_size_x * x128::size; // 256 at BF16 int local_oc = threadIdx.x * x128::size; int global_oc = blockIdx.x * OC_per_warp + local_oc; float accumulators[x128::size]; extern __shared__ float shared[]; for (int k = 0; k < x128::size; k++) { accumulators[k] = 0.0f; } int thread_id = threadIdx.y * block_size_x + threadIdx.x; for (int idx = thread_id; idx < OC_per_warp; idx += block_size) { shared[idx] = 0.0f; } __syncthreads(); if(global_oc < OC) { for (int idx = blockIdx.y*block_size_y + threadIdx.y; idx < B * T; idx += gridDim.y*block_size_y) { x128 packed_dout = load128(dout + global_oc + idx*OC); for (int k = 0; k < x128::size; k++) { accumulators[k] += (float)packed_dout[k]; } } // we need to avoid shared memory bank conflicts for the atomicAdd to maximise performance, // so we accumulate in a conflict-free order, then reorder to match the global memory order for (int k = 0; k < x128::size; k++) { atomicAdd(shared + threadIdx.x + (k * block_size_x), accumulators[k]); } } if (threadIdx.y >= x128::size) { return; } // only need this many warps to reorder the data __syncthreads(); // read the accumulated values in the conflict-free order int i = threadIdx.x + (threadIdx.y * block_size_x); float tmp = shared[i]; __syncthreads(); // write them back to shared memory in the global memory order // 8-way bank conflict for BF16 x128, but only 8x per threadblock (rather than 8x per warp) shared[local_oc + threadIdx.y] = tmp; __syncthreads(); // now we do a perfectly coalesced atomic add to global memory (1x 
128-byte cacheline per warp) if (i + blockIdx.x*OC_per_warp < OC) { atomicAdd(dbias + i + blockIdx.x*OC_per_warp, shared[i]); } } // We want to decrease the amount of channels handled by each block, so that we need fewer across-block reductions. // We do this by realizing the following: For scalar memory access, we need to read one element per thread in a warp // to read an entire cacheline, but for vectorized memory access, with 128 bit of data per thread, we only need eight // threads to fetch a cacheline, which means that we can already operate on a "depth" of four within a single warp. // => blockDim.x == 4, blockDim.y == 32/4 = 8 // template __global__ void matmul_backward_bias_kernel8(OutFloat* dbias, const floatX* dout, int B, int T, int OC, std::bool_constant) { constexpr const int bdx = 4; constexpr const int bdy = 32 / bdx; assert(blockDim.x == bdx); assert(blockDim.y == bdy); int warp_d = (int)threadIdx.x; int warp_c = (int)threadIdx.y; int block_d = (int)threadIdx.z; const int OC_per_warp = bdy * x128::size; // 64 at BF16 int local_oc = warp_c * x128::size; int global_oc = blockIdx.x * OC_per_warp + local_oc; int local_bt = warp_d + bdx * block_d; int bt_per_block = bdx * blockDim.z; float accumulators[x128::size]; for (int k = 0; k < x128::size; k++) { accumulators[k] = 0.0f; } if(global_oc < OC) { // sum up over all bt within registers for (int idx = blockIdx.y * bt_per_block + local_bt; idx < B * T; idx += gridDim.y * bt_per_block) { x128 packed_dout = load128(dout + global_oc + idx*OC); for (int k = 0; k < x128::size; k++) { accumulators[k] += (float)packed_dout[k]; } } } __shared__ float sub_results[x128::size][32][bdy]; // reduce within-warp results for (int k = 0; k < x128::size; k++) { float v = accumulators[k]; v += __shfl_down_sync(0xffffffff, v, 1, 4); v += __shfl_down_sync(0xffffffff, v, 2, 4); if(warp_d == 0) { sub_results[k][block_d][warp_c] = v; } } __syncthreads(); // block-wide reductions for (int k = block_d; k < x128::size; k += 
blockDim.z) { float a = 0.f; for (int r = warp_d; r < blockDim.z; r += bdx) { float v = sub_results[k][r][warp_c]; v += __shfl_down_sync(0xffffffff, v, 1, 4); v += __shfl_down_sync(0xffffffff, v, 2, 4); a += v; } if(warp_d == 0 && global_oc < OC) { // coalesced, but not cacheline-sized if constexpr (!Atomic) { dbias[global_oc + k] = (OutFloat)(a + (float)dbias[global_oc + k]); } else { atomicAdd(dbias + global_oc + k, a); } } } } // Like kernel 8, but instead of accumulating to the auxiliary buffer, it writes // multiple values that need to be summed up in a separate kernel call. // If UseAuxBuffer is false, gridDim.y has to be one, and results are added directly // to dbias. template __global__ void matmul_backward_bias_kernel9(OutFloat* dbias, const floatX* dout, int B, int T, int OC, std::bool_constant) { constexpr const int bdx = 4; constexpr const int bdy = 32 / bdx; assert(blockDim.x == bdx); assert(blockDim.y == bdy); int warp_d = (int)threadIdx.x; int warp_c = (int)threadIdx.y; int block_d = (int)threadIdx.z; const int OC_per_warp = bdy * x128::size; // 64 at BF16 int local_oc = warp_c * x128::size; int global_oc = blockIdx.x * OC_per_warp + local_oc; int local_bt = warp_d + bdx * block_d; int bt_per_block = bdx * blockDim.z; float accumulators[x128::size]; for (int k = 0; k < x128::size; k++) { accumulators[k] = 0.0f; } if(global_oc < OC) { // sum up over all bt within registers for (int idx = blockIdx.y * bt_per_block + local_bt; idx < B * T; idx += gridDim.y * bt_per_block) { x128 packed_dout = load128(dout + global_oc + idx*OC); for (int k = 0; k < x128::size; k++) { accumulators[k] += (float)packed_dout[k]; } } } __shared__ float sub_results[x128::size][32][bdy]; // reduce within-warp results for (int k = 0; k < x128::size; k++) { float v = accumulators[k]; v += __shfl_down_sync(0xffffffff, v, 1, 4); v += __shfl_down_sync(0xffffffff, v, 2, 4); if(warp_d == 0) { sub_results[k][block_d][warp_c] = v; } } __syncthreads(); // block-wide reductions for (int 
k = block_d; k < x128::size; k += blockDim.z) { float a = 0.f; for (int r = warp_d; r < blockDim.z; r += bdx) { float v = sub_results[k][r][warp_c]; v += __shfl_down_sync(0xffffffff, v, 1, 4); v += __shfl_down_sync(0xffffffff, v, 2, 4); a += v; } if(warp_d == 0 && global_oc < OC) { // coalesced, but not cacheline-sized if constexpr (!UseAuxBuffer) { dbias[global_oc + k] = (OutFloat)(a + (float)dbias[global_oc + k]); } else { dbias[global_oc + k + blockIdx.y * OC] = a; } } } } __global__ void reduce_add_sum_kernel(floatX* dst, const float* src, size_t n, size_t m) { const size_t idx = (blockIdx.x * blockDim.x + threadIdx.x) * f128::size; assert(n % x128::size == 0); if (idx < n) { f128 acc; for(int k = 0; k < f128::size; ++k) { acc[k] = 0.f; } for(int l = 0; l < m; ++l) { f128 s = load128(src + idx + n * l); for(int k = 0; k < f128::size; ++k) { acc[k] += s[k]; } } for(int k = 0; k < f128::size; ++k) { dst[idx + k] = (floatX) ((float)dst[idx + k] + acc[k]); } } } // ---------------------------------------------------------------------------- // kernel launcher // version1: simple cuBLAS calls void matmul_backward_bias1(floatX* dbias, const floatX* dout, int B, int T, int OC, int block_size) { block_size = largestPowerOfTwoLessOrEqual(block_size); assert(isPowerOfTwo(block_size)); // block_size needs to be power of 2 due to the reduction dim3 block_dim(block_size); dim3 grid_dim(OC); size_t shared_mem_size = block_size * sizeof(float); matmul_backward_bias_kernel1<<>>(dbias, dout, B, T, OC); cudaCheck(cudaGetLastError()); } void matmul_backward_bias2(floatX* dbias, const floatX* dout, int B, int T, int OC, int block_size) { // block_size 512 seems best const int grid_size = ceil_div(OC * 32, block_size); matmul_backward_bias_kernel2<<>>(dbias, dout, B, T, OC); cudaCheck(cudaGetLastError()); } void matmul_backward_bias3(floatX* dbias, const floatX* dout, int B, int T, int OC, int block_size) { // block_size 256 seems best matmul_backward_bias_kernel3<<>>(dbias, dout, 
B, T, OC); cudaCheck(cudaGetLastError()); } void matmul_backward_bias4(floatX* dbias, const floatX* dout, int B, int T, int OC, int block_size) { assert(OC % 32 == 0); // OC must be divisible by 32 for this kernel const int grid_size = OC / 32; matmul_backward_bias_kernel4<<>>(dbias, dout, B, T, OC); cudaCheck(cudaGetLastError()); } #ifndef ENABLE_BF16 void matmul_backward_bias5(floatX* dbias, const floatX* dout, int B, int T, int OC, int block_size) { const int grid_size_x = ceil_div(OC, block_size); const int grid_size_y = max(1, cuda_threads_per_SM * cuda_num_SMs / block_size); matmul_backward_bias_kernel5<<>>(dbias, dout, B, T, OC); cudaCheck(cudaGetLastError()); } #endif void matmul_backward_bias7(floatX* dbias, const floatX* dout, int B, int T, int OC, int block_size) { if(block_size < 256) { block_size = 256; } // Each warp is responsible for 32 * "x128::size" = 256 OCs at BF16 (OC must be a multiple of 256!) // Block size is 512 threads (16 warps) and we reduce those 16 values into 1 at the end // blockDim.x is 32 --> single warp being responsible for those 256 OCs // blockDim.y is 16 --> 16 parallel independent warps processing the same OCs for different BTs // gridDim.x is OC / 256 --> each block processes 256 OCs // grimDim.y is max(1, (cuda_num_SMs * threads_per_SM) / (512 * gridDim.x)); --> fill up the entire GPU! const int warp_size = 32; const int OC_per_warp = warp_size * x128::size; // 256 at BF16 const int block_size_x = 32; const int block_size_y = block_size / block_size_x; // 16 const int grid_size_x = ceil_div(OC, OC_per_warp); // e.g. 3 horizontal blocks for 768 OCs at BF16 const int grid_size_y = max(1, cuda_threads_per_SM * cuda_num_SMs / (block_size * grid_size_x)); // full GPU! 
assert(block_size_y >= x128::size); // part of the kernel assumes this is large enough to avoid loops cudaCheck(cudaMemset(dbias_buffer, 0, OC * sizeof(float))); matmul_backward_bias_kernel7<<>>(dbias_buffer, dout, B, T, OC, block_size); cudaCheck(cudaGetLastError()); cast_and_add_kernel<<>>(dbias, dbias_buffer, OC); cudaCheck(cudaGetLastError()); } void matmul_backward_bias8(floatX* dbias, const floatX* dout, int B, int T, int OC, int block_size) { dim3 block_dim = {4, 8, (unsigned)block_size/32}; const int OC_per_warp = block_dim.y * x128::size; // 64 at BF16 const int grid_size_x = ceil_div(OC, OC_per_warp); // e.g. 12 horizontal blocks for 768 OCs at BF16 const int grid_size_y = max(1, cuda_threads_per_SM * cuda_num_SMs / (block_size * grid_size_x)); // full GPU! // If we have enough OC that we don't need cross-block reductions, we can skip the bias_buffer accumulation // and write results directly to the output. if(grid_size_y == 1) { matmul_backward_bias_kernel8<<>>(dbias, dout, B, T, OC, std::bool_constant{}); cudaCheck(cudaGetLastError()); } else { cudaCheck(cudaMemset(dbias_buffer, 0, OC * sizeof(float))); matmul_backward_bias_kernel8<<>>(dbias_buffer, dout, B, T, OC, std::bool_constant{}); cudaCheck(cudaGetLastError()); cast_and_add_kernel<<>>(dbias, dbias_buffer, OC); cudaCheck(cudaGetLastError()); } } void matmul_backward_bias9(floatX* dbias, const floatX* dout, int B, int T, int OC, int block_size) { dim3 block_dim = {4, 8, (unsigned)block_size/32}; const int OC_per_warp = block_dim.y * x128::size; // 64 at BF16 const int grid_size_x = ceil_div(OC, OC_per_warp); // e.g. 12 horizontal blocks for 768 OCs at BF16 const int grid_size_y = max(1, cuda_threads_per_SM * cuda_num_SMs / (block_size * grid_size_x)); // full GPU! // If we have enough OC that we don't need cross-block reductions, we can skip the bias_buffer accumulation // and write results directly to the output. 
if(grid_size_y == 1) { matmul_backward_bias_kernel9<<>>(dbias, dout, B, T, OC, std::bool_constant{}); cudaCheck(cudaGetLastError()); } else { // kernel 9 overwrites temp buffer, so no need to memset matmul_backward_bias_kernel9<<>>(dbias_buffer, dout, B, T, OC, std::bool_constant{}); cudaCheck(cudaGetLastError()); reduce_add_sum_kernel<<>>(dbias, dbias_buffer, OC, grid_size_y); cudaCheck(cudaGetLastError()); } } void matmul_backward_bias(int kernel_num, floatX* dbias, floatX* dout, int B, int T, int OC, int block_size) { switch (kernel_num) { case 1: matmul_backward_bias1(dbias, dout, B, T, OC, block_size); break; case 2: matmul_backward_bias2(dbias, dout, B, T, OC, block_size); break; case 3: matmul_backward_bias3(dbias, dout, B, T, OC, block_size); break; case 4: matmul_backward_bias4(dbias, dout, B, T, OC, block_size); break; case 5: #ifndef ENABLE_BF16 matmul_backward_bias5(dbias, dout, B, T, OC, block_size); #else fprintf(stderr, "Kernel 5 is only supported for fp32"); exit(1); #endif break; case 7: matmul_backward_bias7(dbias, dout, B, T, OC, block_size); break; case 8: matmul_backward_bias8(dbias, dout, B, T, OC, block_size); break; case 9: matmul_backward_bias9(dbias, dout, B, T, OC, block_size); break; default: printf("Invalid kernel number\n"); exit(1); } } // ---------------------------------------------------------------------------- int main(int argc, char **argv) { setup_main(); int B = 8; int T = 1024; int C = 768; int OC = 768 * 4; // expansion of 4, e.g. 
in the MLP // read kernel_num from command line int kernel_num = 1; if (argc > 1) { kernel_num = atoi(argv[1]); } printf("Using kernel %d\n", kernel_num); // create host memory of random numbers float* dbias = make_zeros_float(OC); float* dout = make_random_float(B * T * OC); // move to GPU floatX* d_dbias; floatX* d_dout; cudaCheck(cudaMalloc(&d_dbias, OC * sizeof(floatX))); cudaCheck(cudaMalloc(&d_dout, B * T * OC * sizeof(floatX))); cudaCheck(cudaMalloc(&dbias_buffer, OC * sizeof(float) * 32)); cudaCheck(memcpy_convert(d_dbias, dbias, OC)); cudaCheck(memcpy_convert(d_dout, dout, B * T * OC)); // ncu debugging / profiling, do a single call // int block_size_debug; // if (kernel_num == 1) { block_size_debug = 512; // } else if (kernel_num == 2) { block_size_debug = 512; // } else { block_size_debug = 256; } // printf("kernel %d, block_size %d\n", kernel_num, block_size_debug); // matmul_backward_bias(kernel_num, NULL, NULL, d_dbias, d_dout, NULL, NULL, NULL, B, T, C, OC, block_size_debug); // exit(EXIT_SUCCESS); int block_sizes[] = {32, 64, 128, 256, 512, 768, 1024}; // calculate the CPU reference matmul_backward_bias_cpu(NULL, NULL, dbias, dout, NULL, NULL, B, T, C, OC); for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) { int block_size = block_sizes[j]; // memset the bias to zero cudaCheck(cudaMemset(d_dbias, 0, OC * sizeof(floatX))); // calculate the GPU version matmul_backward_bias(kernel_num, d_dbias, d_dout, B, T, OC, block_size); // compare printf("Checking correctness...\n"); float tol = std::is_same_v ? 
5e-3f : 1.0f; validate_result(d_dbias, dbias, "dbias", OC, tol); printf("All results match for block_size=%d.\n\n", block_size); } // now benchmark the kernel for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) { int block_size = block_sizes[j]; int repeat_times = 2000; float elapsed_time = benchmark_kernel(repeat_times, matmul_backward_bias, kernel_num, d_dbias, d_dout, B, T, OC, block_size); printf("block_size %d time %.4f ms\n", block_size, elapsed_time); } // cleanups free(dbias); free(dout); cudaCheck(cudaFree(dbias_buffer)); cudaCheck(cudaFree(d_dbias)); cudaCheck(cudaFree(d_dout)); return 0; } ================================================ FILE: dev/cuda/matmul_forward.cu ================================================ /* Kernels for matmul forward pass. It's advised to use OpenMP here because the CPU implementation is fairly slow otherwise Compile example: nvcc -O3 --use_fast_math -Xcompiler -fopenmp matmul_forward.cu -o matmul_forward -lcublas -lcublasLt version 1 is naive port from CPU code to kernel: parallelizes over B,T, loops over C OMP_NUM_THREADS=32 ./matmul_forward 1 version 2 calls cuBLAS, very fast OMP_NUM_THREADS=32 ./matmul_forward 2 version 3 calls cuBLASLt, should be even faster OMP_NUM_THREADS=32 ./matmul_forward 3 */ #include #include #include #include #include #include #include "common.h" // ---------------------------------------------------------------------------- // CPU code reference void matmul_forward_cpu(float* out, const float* inp, const float* weight, const float* bias, int B, int T, int C, int OC) { // OC is short for "output channels" // inp is (B,T,C), weight is (OC, C), bias is (OC) // out will be (B,T,OC) #pragma omp parallel for collapse(2) for (int b = 0; b < B; b++) { for (int t = 0; t < T; t++) { float* out_bt = out + b * T * OC + t * OC; const float* inp_bt = inp + b * T * C + t * C; for (int o = 0; o < OC; o++) { float val = (bias != NULL) ? 
bias[o] : 0.0f; const float* wrow = weight + o*C; for (int i = 0; i < C; i++) { val += inp_bt[i] * wrow[i]; } out_bt[o] = val; } } } } // ---------------------------------------------------------------------------- // GPU kernels // kernel 1: naive kernel, every thread handles one output element, direct global memory access __global__ void matmul_forward_kernel1(float* out, const float* inp, const float* weight, const float* bias, int BT, int C, int OC) { // out is (B,T,OC). OC is short for "output channels", e.g. OC = 4 * C // inp is (B,T,C), weight is (OC, C), bias is (OC) // in the naive kernel, every thread handles one element of out int bt = blockIdx.x * blockDim.x + threadIdx.x; int oc = blockIdx.y * blockDim.y + threadIdx.y; if (bt < BT && oc < OC) { float val = (bias != NULL) ? bias[oc] : 0.0f; const float* wrow = weight + oc * C; const float* inp_bt = inp + bt * C; for (int i = 0; i < C; i++) { val += inp_bt[i] * wrow[i]; } out[bt * OC + oc] = val; } } // is there no better way other than just adding bias with a whole separate kernel? // this is a highly memory-bound operation, should be fused into the matmul kernel // but i can't seem to find a cuBLAS function that does this __global__ void add_bias(float* out, const float* bias, int B, int T, int OC) { int idx = blockIdx.x * blockDim.x + threadIdx.x; int stride = blockDim.x * gridDim.x; for (int i = idx; i < B * T * OC; i += stride) { int col = i % OC; out[i] += bias[col]; } } // kernel 4: semi-efficient handwritten kernel // see trimat_forward.cu for some intermediate development steps __device__ float4 ld_vec(const float* address) { return *reinterpret_cast(address); } __device__ void st_vec(float* address, float4 val) { *reinterpret_cast(address) = val; } __global__ void __launch_bounds__(16*16) matmul_forward_kernel4(float* out, const float* inp, const float* weight, const float* bias, int C, int OC) { // out is (B,T,OC). OC is short for "output channels", e.g. 
OC = 4 * C // inp is (B,T,C), weight is (OC, C), bias is (OC) // each thread handles 8x8 elements; each block 128 by 128 elements. int oc = 8*(blockIdx.y * blockDim.y + threadIdx.y); // buffers to cache chunks of the input matrices __shared__ float lhs_s[128][32]; __shared__ float rhs_s[128][32]; // adjust our pointers for the current block inp += 128 * blockIdx.x * C; weight += 128 * blockIdx.y * C; out += 128 * blockIdx.x * OC + 128 * blockIdx.y; float vals[8][8] = {}; if(bias != NULL) { for (int i = 0; i < 8; i++) { for (int j = 0; j < 8; j += 4) { float4 b = ld_vec(bias + oc + j); vals[i][j+0] = b.x; vals[i][j+1] = b.y; vals[i][j+2] = b.z; vals[i][j+3] = b.w; } } } int si_start = 4*(16 * threadIdx.y + threadIdx.x); for (int so = 0; so < C; so += 32) { __syncthreads(); int xmod8 = threadIdx.x % 8; int xby8 = threadIdx.x / 8; int xo = 4 * xmod8; for(int y = 2 * threadIdx.y + xby8; y < 128; y += 32) { st_vec(&lhs_s[y][xo], ld_vec(inp + y * C + so + xo)); st_vec(&rhs_s[y][xo], ld_vec(weight + y * C + so + xo)); } __syncthreads(); for (int si = si_start; si < si_start + 32; si += 4) { float4 rhs[8]; for (int u = 0; u < 8; ++u) { rhs[u] = ld_vec(&rhs_s[u + 8 * threadIdx.y][si % 32]); } for (int ii = 0; ii < 8; ++ii) { float4 lhs = ld_vec(&lhs_s[ii + 8 * threadIdx.x][si % 32]); for (int ji = 0; ji < 8; ++ji) { vals[ii][ji] += lhs.x * rhs[ji].x; vals[ii][ji] += lhs.y * rhs[ji].y; vals[ii][ji] += lhs.z * rhs[ji].z; vals[ii][ji] += lhs.w * rhs[ji].w; } } } } for (int i = 0; i < 8; ++i) { for (int j = 0; j < 8; j += 4) { float4 result; result.x = vals[i][j + 0]; result.y = vals[i][j + 1]; result.z = vals[i][j + 2]; result.w = vals[i][j + 3]; st_vec(out + (8*threadIdx.x+i) * OC + 8*threadIdx.y + j, result); } } } // ---------------------------------------------------------------------------- // kernel launcher // kernel 1 is the most naive matmul kernel void matmul_forward1(float* out, const float* inp, const float* weight, const float* bias, int B, int T, int C, int OC, 
const int sqrt_block_size) { // out is (B,T,OC). OC is short for "output channels", e.g. OC = 4 * C // inp is (B,T,C), weight is (OC, C), bias is (OC) dim3 gridDim(ceil_div(B * T, sqrt_block_size), ceil_div(OC, sqrt_block_size)); dim3 blockDim(sqrt_block_size, sqrt_block_size); matmul_forward_kernel1<<>>(out, inp, weight, bias, B*T, C, OC); cudaCheck(cudaGetLastError()); } // kernel 2 calls cuBLAS, which should be very efficient void matmul_forward2(float* out, const float* inp, const float* weight, const float* bias, int B, int T, int C, int OC, const int sqrt_block_size) { // for reference API is: // cublasStatus_t cublasSgemm(cublasHandle_t handle, // cublasOperation_t transa, cublasOperation_t transb, // int m, int n, int k, // const float *alpha, // const float *A, int lda, // const float *B, int ldb, // const float *beta, // float *C, int ldc) // for us, inp is (B*T, C), weight is (OC, C), out is (B*T, OC) // cuBLAS does C = alpha * A * B + beta * C // where A is mxk, B is kxn, C is mxn // now, because we use row-major storage, cuBLAS (which is column-major) sees our matrices transposed. // algorithmically / in e.g. PyTorch we want to do: out = inp @ weight.T // but because cuBLAS is column-major, we actually want to get it to calculate out.T . Mathematically, this is: // out.T = weight @ inp.T // but again, our variables look transposed, so using the actual weight/inp we have here in this function, this becomes // out.T = weight.T @ inp // so we need to get cuBLAS to calculate weight.T @ inp (the variables here are the actual ones in this function) // => need to call cuBLAS with A = weight, B = inp // => need to call cuBLAS with transa = CUBLAS_OP_T, transb = CUBLAS_OP_N const float alpha = 1.0f; const float beta = 0.0f; cublasCheck(cublasSgemm(cublas_handle, CUBLAS_OP_T, CUBLAS_OP_N, OC, B*T, C, &alpha, weight, C, inp, C, &beta, out, OC)); // and now we still have to add the bias... 
(ew) if (bias != NULL) { int block_size = sqrt_block_size * sqrt_block_size; int grid_size = ceil_div(OC * B * T, block_size); add_bias<<>>(out, bias, B, T, OC); cudaCheck(cudaGetLastError()); } } // uses cublasLt to fuse the bias and gelu // https://docs.nvidia.com/cuda/cublas/#cublasltmatmul // https://github.com/NVIDIA/CUDALibrarySamples/blob/master/cuBLASLt/LtSgemm/sample_cublasLt_LtSgemm.cu void matmul_forward3(float* out, const float* inp, const float* weight, const float* bias, int B, int T, int C, int OC) { int has_bias = (bias != NULL); int has_gelu = 0; // check bias alignment if(((uintptr_t)bias % 16) != 0) { printf("Bias pointer is not aligned (cuBLASLt requirement)!\n"); exit(EXIT_FAILURE); } int returnedResults = 0; cublasLtMatmulDesc_t operationDesc; cublasLtMatmulPreference_t preference; cublasLtMatrixLayout_t weightLayout; cublasLtMatrixLayout_t inputLayout; cublasLtMatrixLayout_t outputLayout; cublasLtMatrixLayout_t biasLayout; cublasLtMatmulHeuristicResult_t heuristic; // create the operation descriptor cublasOperation_t opNoTranspose = CUBLAS_OP_N; cublasOperation_t opTranspose = CUBLAS_OP_T; cublasLtEpilogue_t epilogueBias = CUBLASLT_EPILOGUE_DEFAULT; if (has_bias && has_gelu) { epilogueBias = CUBLASLT_EPILOGUE_GELU_BIAS; } else if (has_bias) { epilogueBias = CUBLASLT_EPILOGUE_BIAS; } else if (has_gelu) { epilogueBias = CUBLASLT_EPILOGUE_GELU; } cublasCheck(cublasLtMatmulDescCreate(&operationDesc, cublas_compute_type, CUDA_R_32F)); cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &opTranspose, sizeof(opTranspose))); cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opNoTranspose, sizeof(opNoTranspose))); cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogueBias, sizeof(epilogueBias))); cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias))); // define matrix layouts 
cublasCheck(cublasLtMatrixLayoutCreate(&weightLayout, CUDA_R_32F, C, OC, C)); cublasCheck(cublasLtMatrixLayoutCreate(&inputLayout, CUDA_R_32F, C, B*T, C)); cublasCheck(cublasLtMatrixLayoutCreate(&outputLayout, CUDA_R_32F, OC, B*T, OC)); cublasCheck(cublasLtMatrixLayoutCreate(&biasLayout, CUDA_R_32F, OC, 1, OC)); // create a preference handle with specified max workspace cublasCheck(cublasLtMatmulPreferenceCreate(&preference)); cublasCheck(cublasLtMatmulPreferenceSetAttribute(preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &cublaslt_workspace_size, sizeof(cublaslt_workspace_size))); // find a suitable algorithm cublasCheck(cublasLtMatmulAlgoGetHeuristic(cublaslt_handle, operationDesc, weightLayout, inputLayout, outputLayout, outputLayout, preference, 1, &heuristic, &returnedResults)); if (returnedResults == 0) { printf("No cuBLASLt algorithm: B: %d, T: %d, C: %d, OC: %d, bias: %d, gelu: %d\n", B, T, C, OC, has_bias, has_gelu); exit(EXIT_FAILURE); } // call the matmul const float alpha = 1.0f, beta = 0.0f; cublasCheck(cublasLtMatmul(cublaslt_handle, operationDesc, &alpha, weight, weightLayout, inp, inputLayout, &beta, out, outputLayout, out, outputLayout, &heuristic.algo, cublaslt_workspace, cublaslt_workspace_size, 0)); // cleanups cublasCheck(cublasLtMatmulPreferenceDestroy(preference)); cublasCheck(cublasLtMatmulDescDestroy(operationDesc)); cublasCheck(cublasLtMatrixLayoutDestroy(weightLayout)); cublasCheck(cublasLtMatrixLayoutDestroy(inputLayout)); cublasCheck(cublasLtMatrixLayoutDestroy(outputLayout)); cublasCheck(cublasLtMatrixLayoutDestroy(biasLayout)); } // handwritten, relatively efficient non-tensorcore matmul kernel void matmul_forward4(float* out, const float* inp, const float* weight, const float* bias, int B, int T, int C, int OC, int sqrt_block_size) { // out is (B,T,OC). OC is short for "output channels", e.g. 
OC = 4 * C
    // inp is (B,T,C), weight is (OC, C), bias is (OC)
    sqrt_block_size = 16;

    dim3 gridDim(ceil_div(B * T, 8*sqrt_block_size), ceil_div(OC, 8*sqrt_block_size));
    dim3 blockDim(sqrt_block_size, sqrt_block_size);
    matmul_forward_kernel4<<<gridDim, blockDim>>>(out, inp, weight, bias, C, OC);
    cudaCheck(cudaGetLastError());
}

// kernel version dispatch
// Selects one of the matmul_forward{1..4} launchers by kernel_num; exits on an unknown number.
void matmul_forward(int kernel_num,
                    float* out,
                    const float* inp, const float* weight, const float* bias,
                    int B, int T, int C, int OC,
                    const int sqrt_block_size) {
    switch (kernel_num) {
        case 1:
            matmul_forward1(out, inp, weight, bias, B, T, C, OC, sqrt_block_size);
            break;
        case 2:
            matmul_forward2(out, inp, weight, bias, B, T, C, OC, sqrt_block_size);
            break;
        case 3:
            // the cuBLAS-based path takes no block size
            matmul_forward3(out, inp, weight, bias, B, T, C, OC);
            break;
        case 4:
            matmul_forward4(out, inp, weight, bias, B, T, C, OC, sqrt_block_size);
            break;
        default:
            printf("Invalid kernel number\n");
            exit(1);
    }
}

// ----------------------------------------------------------------------------

// Validates the selected kernel against the CPU reference at several block
// sizes, then benchmarks it and reports an estimated TFLOPS figure.
int main(int argc, char **argv) {
    srand(0);

    int B = 32;
    int T = 1024;
    int C = 768;
    int OC = 768 * 4; // expansion of 4, e.g. in the MLP

    // set up the device
    int deviceIdx = 0;
    cudaCheck(cudaSetDevice(deviceIdx));
    cudaDeviceProp deviceProp;
    cudaCheck(cudaGetDeviceProperties(&deviceProp, deviceIdx));
    printf("Device %d: %s\n", deviceIdx, deviceProp.name);

    // setup cuBLAS and cuBLASLt
    cublasCheck(cublasCreate(&cublas_handle));
    cublasCheck(cublasLtCreate(&cublaslt_handle));
    // TF32 precision is equivalent to torch.set_float32_matmul_precision('high')
    int enable_tf32 = deviceProp.major >= 8 ? 1 : 0;
    printf("enable_tf32: %d\n", enable_tf32);
    cublas_compute_type = enable_tf32 ? CUBLAS_COMPUTE_32F_FAST_TF32 : CUBLAS_COMPUTE_32F;
    cublasMath_t cublas_math_mode = enable_tf32 ? CUBLAS_TF32_TENSOR_OP_MATH : CUBLAS_DEFAULT_MATH;
    cublasCheck(cublasSetMathMode(cublas_handle, cublas_math_mode));
    // setup the (global) cuBLASLt workspace
    cudaCheck(cudaMalloc(&cublaslt_workspace, cublaslt_workspace_size));

    // create host memory of random numbers
    float* out = (float*)malloc(B * T * OC * sizeof(float));
    float* inp = make_random_float(B * T * C);
    float* weight = make_random_float(OC * C);
    float* bias = make_random_float(OC);

    // move to GPU
    float* d_out;
    float* d_inp;
    float* d_weight;
    float* d_bias;
    cudaCheck(cudaMalloc(&d_out, B * T * OC * sizeof(float)));
    cudaCheck(cudaMalloc(&d_inp, B * T * C * sizeof(float)));
    cudaCheck(cudaMalloc(&d_weight, C * OC * sizeof(float)));
    cudaCheck(cudaMalloc(&d_bias, OC * sizeof(float)));
    cudaCheck(cudaMemcpy(d_inp, inp, B * T * C * sizeof(float), cudaMemcpyHostToDevice));
    cudaCheck(cudaMemcpy(d_weight, weight, C * OC * sizeof(float), cudaMemcpyHostToDevice));
    cudaCheck(cudaMemcpy(d_bias, bias, OC * sizeof(float), cudaMemcpyHostToDevice));

    // read kernel_num from command line
    int kernel_num = 1;
    if (argc > 1) {
        kernel_num = atoi(argv[1]);
    }
    printf("Using kernel %d\n", kernel_num);

    // first check the correctness of the kernel
    matmul_forward_cpu(out, inp, weight, bias, B, T, C, OC);

    // time the kernel at different block sizes
    int sqrt_block_sizes[] = {4, 8, 16, 32};
    const int num_sqrt_block_sizes = (int)(sizeof(sqrt_block_sizes) / sizeof(sqrt_block_sizes[0]));

    for (int j = 0; j < num_sqrt_block_sizes; j++) {
        int sqrt_block_size = sqrt_block_sizes[j];
        printf("Checking block size %d x %d.\n", sqrt_block_size, sqrt_block_size);
        // clear the output first so a kernel that leaves elements unwritten cannot
        // pass validation on the results of a previous block-size run
        cudaCheck(cudaMemset(d_out, 0, B * T * OC * sizeof(float)));
        matmul_forward(kernel_num, d_out, d_inp, d_weight, d_bias, B, T, C, OC, sqrt_block_size);
        validate_result(d_out, out, "out", B * T * OC, 1e-1f);
    }

    printf("All results match. Starting benchmarks.\n\n");

    for (int j = 0; j < num_sqrt_block_sizes; j++) {
        int sqrt_block_size = sqrt_block_sizes[j];

        int repeat_times = 100;
        float elapsed_time = benchmark_kernel(repeat_times, matmul_forward,
                                              kernel_num, d_out, d_inp, d_weight, d_bias,
                                              B, T, C, OC, sqrt_block_size);

        // napkin math: estimate the flops achieved
        // e.g. A100 40GB PCIe is advertised at 19.5 TFLOPS fp32
        float tflops = (float)B * T * C * OC * 2 / elapsed_time * 1e3f / 1e12f;
        printf("sqrt_block_size %4d | time %.4f ms | tflops %.2f\n", sqrt_block_size, elapsed_time, tflops);
    }

    // free memory
    free(out);
    free(inp);
    free(weight);
    free(bias);
    cudaCheck(cudaFree(d_out));
    cudaCheck(cudaFree(d_inp));
    cudaCheck(cudaFree(d_weight));
    cudaCheck(cudaFree(d_bias));
    cudaCheck(cudaFree(cublaslt_workspace));
    cublasCheck(cublasDestroy(cublas_handle));
    cublasCheck(cublasLtDestroy(cublaslt_handle));
    return 0;
}

================================================
FILE: dev/cuda/nccl_all_reduce.cu
================================================
/*

A simple test of NCCL capabilities.
Fills a vector with 1s on the first GPU, 2s on the second, etc.
Then aggregates the values in the resulting vectors.

Compile example:
nvcc -lmpi -lnccl -I/usr/lib/x86_64-linux-gnu/openmpi/include -L/usr/lib/x86_64-linux-gnu/openmpi/lib/ -lcublas -lcublasLt nccl_all_reduce.cu -o nccl_all_reduce

Run on 2 local GPUs (set -np to a different value to change GPU count):
mpirun -np 2 ./nccl_all_reduce

*/

#include "common.h"
#include <assert.h>
#include <cuda_runtime.h>
#include <mpi.h>
#include <nccl.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>

// Prints the NCCL error string with the call site and exits if status is not ncclSuccess.
void nccl_check(ncclResult_t status, const char *file, int line) {
    if (status != ncclSuccess) {
        printf("[NCCL ERROR] at file %s:%d:\n%s\n", file, line, ncclGetErrorString(status));
        exit(EXIT_FAILURE);
    }
}
#define ncclCheck(err) (nccl_check(err, __FILE__, __LINE__))

// Prints the MPI error string with the call site and exits if status is not MPI_SUCCESS.
void mpi_check(int status, const char *file, int line) {
    if (status != MPI_SUCCESS) {
        char mpi_error[4096];
        int mpi_error_len = 0;
        // MPI_Error_string must not be called inside assert(): with NDEBUG defined the
        // whole call would be compiled out and an uninitialized buffer printed.
        if (MPI_Error_string(status, &mpi_error[0], &mpi_error_len) != MPI_SUCCESS) {
            mpi_error_len = snprintf(mpi_error, sizeof(mpi_error), "(unknown MPI error code %d)", status);
        }
        printf("[MPI ERROR] at file %s:%d:\n%.*s\n", file, line, mpi_error_len, mpi_error);
        exit(EXIT_FAILURE);
    }
}
#define mpiCheck(err) (mpi_check(err, __FILE__, __LINE__))

// Sets a vector to a predefined value
__global__ void set_vector(float *data, int N, float value) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;

    // Check for out-of-bounds access
    if (i < N) {
        data[i] = value;
    }
}

// Ceiling division: smallest q such that q * b >= a.
size_t cdiv(size_t a, size_t b) {
    return (a + b - 1) / b;
}

// Parameters specific to training on multiple GPUs.
typedef struct {
    int process_rank;      // Rank of this process among all MPI processes on all hosts. 0 if no multi-GPU.
    int num_processes;     // Total number of processes on all hosts. 1 if no multi-GPU.
    int local_device_idx;  // This process GPU index on current machine. 0 if no multi-GPU.
    ncclComm_t nccl_comm;  // NCCL communication primitive, used for collective multi-GPU work.
} MultiGpuConfig;

// Determine which GPU this process should use.
// Processes on the same machines use different GPU indices. Processes on other machines don't.
// Copied from NCCL examples: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/examples.html#example-2-one-device-per-process-or-thread
// Returns how many lower-ranked processes share this host (same short hostname),
// which is used as this process's GPU index on the machine.
int multi_gpu_get_local_device_idx(int process_rank, int num_processes) {
    char hostname[1024];
    hostname[1023] = '\0';
    // All processes on the same machine will share the same hostname.
    if (gethostname(hostname, 1023) != 0) {
        printf("[Process rank %d] ERROR: gethostname failed\n", process_rank);
        exit(EXIT_FAILURE);
    }
    // keep only the short hostname; stop at the terminator so we never scan
    // uninitialized bytes past the end of the string
    for (int i = 0; i < 1024 && hostname[i] != '\0'; i++) {
        if (hostname[i] == '.') {
            hostname[i] = '\0';
            break;
        }
    }
    uint64_t hostname_hash = 5381;
    for (int c = 0; hostname[c] != '\0'; c++) {
        hostname_hash = ((hostname_hash << 5) + hostname_hash) ^ hostname[c];
    }

    // Distribute all hostname hashes to all processes.
    uint64_t* all_hostsname_hashes = (uint64_t*)malloc(num_processes * sizeof(uint64_t));
    if (all_hostsname_hashes == NULL) {
        printf("[Process rank %d] ERROR: out of host memory\n", process_rank);
        exit(EXIT_FAILURE);
    }
    all_hostsname_hashes[process_rank] = hostname_hash;
    mpiCheck(MPI_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, all_hostsname_hashes, sizeof(uint64_t), MPI_BYTE, MPI_COMM_WORLD));

    // Identify which GPU we need to use.
    int local_device_idx = 0;
    for (int current_process = 0; current_process < num_processes; ++current_process) {
        if (current_process == process_rank) {
            // Found my gpu, local_device_idx now has my target GPU index.
            break;
        }
        if (all_hostsname_hashes[current_process] == all_hostsname_hashes[process_rank]) {
            // This process ID runs on the same machine, but it's not me, skip this GPU
            local_device_idx++;
        }
    }

    free(all_hostsname_hashes);
    return local_device_idx;
}

// Initializes MPI, binds this process to its local GPU and creates the NCCL communicator.
MultiGpuConfig multi_gpu_config_init(int *argc, char ***argv) {
    // Initialize MPI.
    MultiGpuConfig result;
    mpiCheck(MPI_Init(argc, argv));
    mpiCheck(MPI_Comm_rank(MPI_COMM_WORLD, &result.process_rank));
    mpiCheck(MPI_Comm_size(MPI_COMM_WORLD, &result.num_processes));
    result.local_device_idx = multi_gpu_get_local_device_idx(result.process_rank, result.num_processes);
    printf("[Process rank %d] Using GPU %d\n", result.process_rank, result.local_device_idx);
    cudaCheck(cudaSetDevice(result.local_device_idx));
    // rank 0 creates the NCCL unique id and broadcasts it to everyone else
    ncclUniqueId nccl_id;
    if (result.process_rank == 0) {
        ncclCheck(ncclGetUniqueId(&nccl_id));
    }
    mpiCheck(MPI_Bcast((void *)&nccl_id, sizeof(nccl_id), MPI_BYTE, 0, MPI_COMM_WORLD));
    ncclCheck(ncclCommInitRank(&result.nccl_comm, result.num_processes, nccl_id, result.process_rank));
    return result;
}

// Destroys the NCCL communicator and finalizes MPI.
void multi_gpu_config_free(const MultiGpuConfig* multi_gpu_config) {
    ncclCheck(ncclCommDestroy(multi_gpu_config->nccl_comm));
    mpiCheck(MPI_Finalize());
}

// Arithmetic mean of arr[0..size), accumulated in double. process_rank is unused.
float get_mean(float *arr, size_t size, int process_rank) {
    double sum = 0.0;
    for (size_t i = 0; i < size; ++i) {
        sum += arr[i];
    }
    return (float)(sum / size);
}

// Fills a per-rank buffer with (rank + 1), all-reduces it with ncclSum and checks
// that every element equals 1 + 2 + ... + num_processes. Exits non-zero on mismatch.
int main(int argc, char **argv) {
    // Some constants
    const size_t all_reduce_buffer_size = 32 * 1024 * 1024;
    const size_t threads_per_block = 1024;

    MultiGpuConfig multi_gpu_config = multi_gpu_config_init(&argc, &argv);

    // Allocating buffers on each of the devices.
    float *all_reduce_buffer;
    cudaCheck(cudaMalloc(&all_reduce_buffer, all_reduce_buffer_size * sizeof(float)));

    int n_blocks = cdiv(all_reduce_buffer_size, threads_per_block);
    // Set the allocated memory to a defined value.
    set_vector<<<n_blocks, threads_per_block>>>(
        all_reduce_buffer, all_reduce_buffer_size,
        (float)(multi_gpu_config.process_rank + 1));
    cudaCheck(cudaGetLastError());

    float *all_reduce_buffer_host = (float *)malloc(all_reduce_buffer_size * sizeof(float));
    if (all_reduce_buffer_host == NULL) {
        printf("[Process rank %d] ERROR: out of host memory\n", multi_gpu_config.process_rank);
        exit(EXIT_FAILURE);
    }

    cudaCheck(cudaMemcpy(all_reduce_buffer_host, all_reduce_buffer,
                         sizeof(float) * all_reduce_buffer_size, cudaMemcpyDeviceToHost));

    printf("[Process rank %d] average value before all reduce is %.6f\n", multi_gpu_config.process_rank,
           get_mean(all_reduce_buffer_host, all_reduce_buffer_size, multi_gpu_config.process_rank));

    float *all_reduce_buffer_recv;
    cudaCheck(cudaMalloc(&all_reduce_buffer_recv, all_reduce_buffer_size * sizeof(float)));

    ncclCheck(ncclAllReduce(
        (const void *)all_reduce_buffer, (void *)all_reduce_buffer_recv,
        all_reduce_buffer_size, ncclFloat, ncclSum, multi_gpu_config.nccl_comm, 0));

    cudaCheck(cudaMemcpy(all_reduce_buffer_host, all_reduce_buffer_recv,
                         sizeof(float) * all_reduce_buffer_size, cudaMemcpyDeviceToHost));

    float all_reduce_mean_value = get_mean(all_reduce_buffer_host, all_reduce_buffer_size, multi_gpu_config.process_rank);

    printf("[Process rank %d] average value after all reduce is %.6f\n", multi_gpu_config.process_rank, all_reduce_mean_value);

    float expected_all_reduce_mean_value = 0.0f;
    for (int i = 0; i != multi_gpu_config.num_processes; ++i) {
        expected_all_reduce_mean_value += i + 1;
    }

    // Compare without abs(): abs() has an int overload that would truncate a float
    // difference. Written so that a NaN mean also counts as a failure.
    float diff = expected_all_reduce_mean_value - all_reduce_mean_value;
    int exit_code = EXIT_SUCCESS;
    if (!(diff <= 1e-5f && diff >= -1e-5f)) {
        printf("[Process rank %d] ERROR: Unexpected all reduce value: %.8f, expected %.8f\n",
               multi_gpu_config.process_rank, all_reduce_mean_value, expected_all_reduce_mean_value);
        exit_code = EXIT_FAILURE;
    } else {
        printf("[Process rank %d] Checked against expected mean value. All good!\n", multi_gpu_config.process_rank);
    }

    free(all_reduce_buffer_host);
    cudaCheck(cudaFree(all_reduce_buffer));
    cudaCheck(cudaFree(all_reduce_buffer_recv));
    multi_gpu_config_free(&multi_gpu_config);
    return exit_code;
}

================================================
FILE: dev/cuda/permute.cu
================================================
/*
Kernels to demonstrate permute operation.

Compile example:
nvcc -O3 permute.cu -o permute

The goal is to permute a 4D matrix from its original shape (dim1, dim2, dim3, dim4) to a new shape (dim4, dim3, dim1, dim2).

Before permutation, we need to understand how to access elements in a flattened (linear) form of the matrix.

Given:

dim1 = size of the 1st dimension
dim2 = size of the 2nd dimension
dim3 = size of the 3rd dimension
dim4 = size of the 4th dimension

For any element in a 4D matrix at position (i1, i2, i3, i4), where:

i1 is the index in dimension 1
i2 is the index in dimension 2
i3 is the index in dimension 3
i4 is the index in dimension 4

If you find it challenging to calculate the indices i1, i2, i3, and i4, observe the pattern in the index calculations.
Initially, it might take some time to grasp, but with practice, you'll develop a mental model for it.

To calculate the indices, use the following formulas:

i1 = (idx / (dim2 * dim3 * dim4)) % dim1;
i2 = (idx / (dim3 * dim4)) % dim2;
i3 = (idx / dim4) % dim3;
i4 = idx % dim4;

Pattern Explanation:
To find the index for any dimension, divide the thread ID (idx) by the product of all subsequent dimensions.
Then, take the result modulo the size of the current dimension.

The linear index in a flattened 1D array is calculated as:
linear_idx = i1 × (dim2 × dim3 × dim4) + i2 × (dim3 × dim4) + i3 × dim4 + i4
This linear index uniquely identifies the position of the element in the 1D array.

To permute the matrix, we need to rearrange the indices according to the new shape.
In this case, we are permuting from (dim1, dim2, dim3, dim4) to (dim4, dim3, dim1, dim2).

After the permutation, the dimensions are arranged as follows:
dim1 becomes the new 3rd dimension.
dim2 becomes the new 4th dimension.
dim3 becomes the new 2nd dimension.
dim4 becomes the new 1st dimension.

permuted_idx = i4 * (dim3 * dim1 * dim2) + i3 * (dim1 * dim2) + i1 * dim2 + i2;

Here's how this works:
i4 * (dim3 * dim1 * dim2): skips the i4 complete blocks of size dim3 × dim1 × dim2 that come before the current element.
i3 * (dim1 * dim2): the offset within the current i4 block, selecting which of its i3 sub-blocks (each of size dim1 × dim2) we are in.
i1 * dim2: the offset within the current i3 block, selecting which of its i1 rows (each of length dim2) we are in.
i2: the offset within the current i1 row.

Finally, the value at linear index idx of the original matrix is stored at index permuted_idx of the permuted matrix.

--------------------------------------------------------------------------------------------------------------------------------------------------------

The same approach can be used to permute matrices with any number of dimensions.
*/

#include <stdio.h>
#include <stdlib.h>
#include <time.h>
#include <cuda_runtime.h>
#include "common.h"

// CPU function to permute a 4D matrix
// Reference implementation: out_matrix[(i4, i3, i1, i2)] = matrix[(i1, i2, i3, i4)].
void permute_cpu(const float* matrix, float* out_matrix, int dim1, int dim2, int dim3, int dim4) {
    int total_threads = dim1 * dim2 * dim3 * dim4;

    for (int idx = 0; idx < total_threads; idx++) {
        // Calculate the 4D indices from the linear index
        int i1 = (idx / (dim2 * dim3 * dim4)) % dim1;
        int i2 = (idx / (dim3 * dim4)) % dim2;
        int i3 = (idx / dim4) % dim3;
        int i4 = idx % dim4;

        // Compute the new index for the permuted matrix
        // Transpose from (dim1, dim2, dim3, dim4) to (dim4, dim3, dim1, dim2)
        int permuted_idx = i4 * (dim3 * dim1 * dim2) + i3 * (dim1 * dim2) + i1 * dim2 + i2;
        out_matrix[permuted_idx] = matrix[idx];
    }
}

// CUDA kernel to permute a 4D matrix
// One thread per element of the input; the index math matches permute_cpu.
__global__ void permute_kernel(const float* matrix, float* out_matrix, int dim1, int dim2, int dim3, int dim4) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;

    // Ensure index is within bounds
    if (idx < dim1 * dim2 * dim3 * dim4) {
        // Calculate the 4D indices from the linear index
        int i1 = (idx / (dim2 * dim3 * dim4)) % dim1;
        int i2 = (idx / (dim3 * dim4)) % dim2;
        int i3 = (idx / dim4) % dim3;
        int i4 = idx % dim4;

        // Compute the new index for the permuted matrix
        // Transpose from (dim1, dim2, dim3, dim4) to (dim4, dim3, dim1, dim2)
        int permuted_idx = i4 * (dim3 * dim1 * dim2) + i3 * (dim1 * dim2) + i1 * dim2 + i2;
        out_matrix[permuted_idx] = matrix[idx];
    }
}

// Host-side launcher for permute_kernel. benchmark_kernel calls its callable like a
// plain host function, so it must be given a launcher, not the __global__ kernel itself.
void permute_forward(const float* matrix, float* out_matrix, int dim1, int dim2, int dim3, int dim4, int block_size) {
    int total_threads = dim1 * dim2 * dim3 * dim4;
    int grid_size = ceil_div(total_threads, block_size);
    permute_kernel<<<grid_size, block_size>>>(matrix, out_matrix, dim1, dim2, dim3, dim4);
    cudaCheck(cudaGetLastError());
}

int main() {
    int dim_1 = 24;
    int dim_2 = 42;
    int dim_3 = 20;
    int dim_4 = 32;

    // Set up the device
    int deviceIdx = 0;
    cudaCheck(cudaSetDevice(deviceIdx));
    cudaDeviceProp deviceProp;
    cudaCheck(cudaGetDeviceProperties(&deviceProp, deviceIdx));
    printf("Device %d: %s\n", deviceIdx, deviceProp.name);

    // Allocate host memory, initialized with random values
    float* matrix = make_random_float(dim_1 * dim_2 * dim_3 * dim_4);
    float* permuted_matrix = (float*)malloc(dim_1 * dim_2 * dim_3 * dim_4 * sizeof(float));

    // Allocate device memory
    float *d_matrix, *d_permuted_matrix;
    cudaCheck(cudaMalloc(&d_matrix, dim_1 * dim_2 * dim_3 * dim_4 * sizeof(float)));
    cudaCheck(cudaMalloc(&d_permuted_matrix, dim_1 * dim_2 * dim_3 * dim_4 * sizeof(float)));

    // Copy matrix from host to device
    cudaCheck(cudaMemcpy(d_matrix, matrix, dim_1 * dim_2 * dim_3 * dim_4 * sizeof(float), cudaMemcpyHostToDevice));

    // Perform permutation on CPU (single run)
    clock_t start = clock();
    permute_cpu(matrix, permuted_matrix, dim_1, dim_2, dim_3, dim_4);
    clock_t end = clock();
    // clock() ticks / CLOCKS_PER_SEC gives seconds; convert to ms to match the GPU timing
    double elapsed_time_cpu = 1000.0 * (double)(end - start) / CLOCKS_PER_SEC;

    // Define block size; the launcher derives the grid size from it
    const int block_size = 256;

    // Launch CUDA kernel to perform permutation
    permute_forward(d_matrix, d_permuted_matrix, dim_1, dim_2, dim_3, dim_4, block_size);
    cudaCheck(cudaDeviceSynchronize()); // Ensure kernel execution is complete

    // Verify results
    printf("Checking correctness...\n");
    validate_result(d_permuted_matrix, permuted_matrix, "permuted_matrix", dim_1 * dim_2 * dim_3 * dim_4, 1e-5f);

    printf("All results match.\n\n");
    // benchmark kernel
    int repeat_times = 1000;
    float elapsed_time = benchmark_kernel(repeat_times, permute_forward,
                                          d_matrix, d_permuted_matrix, dim_1, dim_2, dim_3, dim_4,
                                          block_size);
    printf("time gpu %.4f ms\n", elapsed_time);
    printf("time cpu %.4f ms\n", elapsed_time_cpu);

    // Free allocated memory
    free(matrix);
    free(permuted_matrix);
    cudaCheck(cudaFree(d_matrix));
    cudaCheck(cudaFree(d_permuted_matrix));

    return 0;
}

================================================
FILE: dev/cuda/residual_forward.cu
================================================
/*
Kernels for residual forward pass.

Compile example:
nvcc -O3 --use_fast_math -lcublas -lcublasLt residual_forward.cu -o residual_forward

version 1 is naive port from CPU code to kernel
./residual_forward 1
version 2 packs input into 128 bit memory reads
./residual_forward 2
*/

#include <stdio.h>
#include <stdlib.h>
#include <cuda_runtime.h>
#define ENABLE_BF16
#include "common.h"

// ----------------------------------------------------------------------------
// CPU code reference lol

void residual_forward_cpu(float* out, const float* inp1, const float* inp2, int N) {
    for (int i = 0; i < N; i++) {
        out[i] = inp1[i] + inp2[i];
    }
}

// ----------------------------------------------------------------------------
// GPU kernels

// elementwise ops are nice and ez
// One thread per element; the sum is done in float and cast back to floatX.
__global__ void residual_forward_kernel1(floatX* out, const floatX* inp1, const floatX* inp2, int N) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < N) {
        out[idx] = (floatX)((float)inp1[idx] + (float)inp2[idx]);
    }
}

// One thread per x128 packet (x128::size elements).
// NOTE(review): only the packet start is bounds-checked, so this assumes N is a
// multiple of x128::size.
__global__ void residual_forward_kernel2(floatX* out, const floatX* inp1, const floatX* inp2, int N) {
    int idx = (blockIdx.x * blockDim.x + threadIdx.x) * x128::size;
    if (idx < N) {
        x128 packed_out;
        x128 packed_inp1 = load128cs(inp1 + idx);
        x128 packed_inp2 = load128cs(inp2 + idx);
        for (int k = 0; k < packed_inp1.size; ++k) {
            packed_out[k] = (floatX)((float)packed_inp1[k] + (float)packed_inp2[k]);
        }
        store128(out + idx, packed_out);
    }
}

// ----------------------------------------------------------------------------
// kernel launcher

void residual_forward1(floatX* out, const floatX* inp1, const floatX* inp2, int N, const int block_size) {
    const int grid_size = ceil_div(N, block_size);
    residual_forward_kernel1<<<grid_size, block_size>>>(out, inp1, inp2, N);
    cudaCheck(cudaGetLastError());
}

void residual_forward2(floatX* out, const floatX* inp1, const floatX* inp2, int N, const int block_size) {
    // each thread handles x128::size elements
    const int grid_size = ceil_div(N, (int)(block_size * x128::size));
    residual_forward_kernel2<<<grid_size, block_size>>>(out, inp1, inp2, N);
    cudaCheck(cudaGetLastError());
}

// kernel version dispatch
void residual_forward(int kernel_num,
                      floatX* out,
                      const floatX* inp1,
                      const floatX* inp2,
                      int N,
                      int block_size) {
    switch (kernel_num) {
        case 1:
            residual_forward1(out, inp1, inp2, N, block_size);
            break;
        case 2:
            residual_forward2(out, inp1, inp2, N, block_size);
            break;
        default:
            printf("Invalid kernel number\n");
            exit(1);
    }
}

// ----------------------------------------------------------------------------

int main(int argc, char **argv) {
    setup_main();

    int B = 8;
    int T = 1024;
    int C = 768;

    // create host memory of random numbers
    float* out = (float*)malloc(B * T * C * sizeof(float));
    float* inp1 = make_random_float(B * T * C);
    float* inp2 = make_random_float(B * T * C);

    // move to GPU
    floatX* d_out;
    floatX* d_inp1;
    floatX* d_inp2;
    cudaCheck(cudaMalloc(&d_out, B * T * C * sizeof(floatX)));
    cudaCheck(cudaMalloc(&d_inp1, B * T * C * sizeof(floatX)));
    cudaCheck(cudaMalloc(&d_inp2, B * T * C * sizeof(floatX)));
    cudaCheck(memcpy_convert(d_inp1, inp1, B * T * C));
    cudaCheck(memcpy_convert(d_inp2, inp2, B * T * C));

    // read kernel_num from command line
    int kernel_num = 1;
    if (argc > 1) {
        kernel_num = atoi(argv[1]);
    }
    printf("Using kernel %d\n", kernel_num);

    // first check the correctness of the kernel
    residual_forward_cpu(out, inp1, inp2, B * T * C);

    // time the kernel at different block sizes
    int block_sizes[] = {32, 64, 128, 256, 512, 1024};
    const int num_block_sizes = (int)(sizeof(block_sizes) / sizeof(block_sizes[0]));

    for (int j = 0; j < num_block_sizes; j++) {
        int block_size = block_sizes[j];
        printf("Checking block size %d.\n", block_size);
        residual_forward(kernel_num, d_out, d_inp1, d_inp2, B * T * C, block_size);
#if !defined(ENABLE_BF16) && !defined(ENABLE_FP16)
        float tol = 1e-5;
#else
        float tol = 1e-2f;
#endif
        validate_result(d_out, out, "out", B * T * C, tol);
    }

    printf("All results match. Starting benchmarks.\n\n");

    for (int j = 0; j < num_block_sizes; j++) {
        int block_size = block_sizes[j];

        int repeat_times = 1000;
        float elapsed_time = benchmark_kernel(repeat_times, residual_forward,
                                              kernel_num, d_out, d_inp1, d_inp2, B * T * C, block_size
                                              );

        // napkin math: estimate the memory bandwidth achieved
        // for each (B,T,C) output element, we do 2 reads and 1 write of sizeof(floatX)
        // bytes each (2 bytes with ENABLE_BF16, not 4)
        // and e.g. A100 40GB PCIe is advertised at 1,555GB/s
        long memory_ops = (long)B * T * C * 3 * (long)sizeof(floatX);
        float memory_bandwidth = memory_ops / elapsed_time / 1e6;

        printf("block_size %4d | time %.4f ms | bandwidth %.2f GB/s\n", block_size, elapsed_time, memory_bandwidth);
    }

    // free memory
    free(out);
    free(inp1);
    free(inp2);
    cudaCheck(cudaFree(d_out));
    cudaCheck(cudaFree(d_inp1));
    cudaCheck(cudaFree(d_inp2));

    return 0;
}

================================================
FILE: dev/cuda/softmax_forward.cu
================================================
/*
Kernels for softmax forward pass.

Compile example:
nvcc -O3 --use_fast_math -lcublas -lcublasLt softmax_forward.cu -o softmax_forward

version 1 is naive port from CPU code to kernel: parallelizes over B,T, loops over C
./softmax_forward 1

version 2 is a fused kernel that parallelizes over all of B,T,C
./softmax_forward 2

version 3 uses intra-warp reductions for maxval and sumval, must use block_size=32
./softmax_forward 3

version 4 uses both intra-warp reductions and shared memory for inter-warp reductions
so it can tolerate any block_size % 32 == 0. this is hopefully the most efficient version
./softmax_forward 4

version 5 is naive port from CPU code (softmax_online) to kernel: parallelizes over B,T, loops over C
./softmax_forward 5

version 6 is softmax_online that parallelizes over all of B,T,C
./softmax_forward 6

version 7 is softmax optimized for very large C.
./softmax_forward 7
*/

#include <stdio.h>
#include <stdlib.h>
#include <float.h>
#include <cuda_runtime.h>
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include "common.h"

// ----------------------------------------------------------------------------
// CPU code reference

void softmax_forward_cpu(float* out, const float* inp, int N, int C) {
    // inp is (N, C)
    // out is (N, C), each row of inp will get softmaxed
    for (int i = 0; i < N; i++) {
        const float* inp_row = inp + i * C;
        float* out_row = out + i * C;

        float maxval = -INFINITY;
        for (int j = 0; j < C; j++) {
            if (inp_row[j] > maxval) {
                maxval = inp_row[j];
            }
        }
        // Note: since we want to ensure that the CUDA-kernels are accurate,
        // we do this accumulation in higher precision, so we can be assured
        // that our ground-truth is of high quality.
        double sum = 0.0;
        for (int j = 0; j < C; j++) {
            out_row[j] = expf(inp_row[j] - maxval);
            sum += out_row[j];
        }
        float norm = 1.f / (float)sum;
        for (int j = 0; j < C; j++) {
            out_row[j] *= norm;
        }
    }
}

// online version of softmax on CPU from the paper "Online normalizer calculation for softmax"
void softmax_forward_online_cpu(float* out, const float* inp, int N, int C) {
    // inp is (N, C)
    // out is (N, C), each row of inp will get softmaxed
    for (int i = 0; i < N; i++) {
        const float* inp_row = inp + i * C;
        float* out_row = out + i * C;

        float maxval = -INFINITY;
        float sum = 0.0f;
        for (int j = 0; j < C; j++) {
            float maxval_prev = maxval;
            if (inp_row[j] > maxval) {
                maxval = inp_row[j];
                sum = sum * expf(maxval_prev - maxval) + expf(inp_row[j] - maxval);
            } else {
                sum += expf(inp_row[j] - maxval);
            }
        }

        for (int j = 0; j < C; j++) {
            out_row[j] = expf(inp_row[j] - maxval) / sum;
        }
    }
}

// ----------------------------------------------------------------------------
// GPU kernels

__global__ void softmax_forward_kernel1(float* out, const float* inp, int N, int C) {
    // inp is (N, C)
    // out is (N, C), each row of inp will get softmaxed
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < N) {
        const float* inp_row = inp + i * C;
        float* out_row = out + i * C;

        float maxval = -INFINITY;
        for (int j = 0; j < C; j++) {
            if (inp_row[j] > maxval) {
                maxval = inp_row[j];
            }
        }
        double sum = 0.0;
        for (int j = 0; j < C; j++) {
            out_row[j] = expf(inp_row[j] - maxval);
            sum += out_row[j];
        }
        for (int j = 0; j < C; j++) {
            out_row[j] /= (float)sum;
        }
    }
}

__global__ void softmax_forward_kernel2(float* out, const float* inp, int N, int C) {
    // inp is (N, C)
    // in each row of C elements, first calculates maxval, then returns expf(val - maxval)
    // the tree reductions below assume blockDim.x is a power of two
    extern __shared__ float shared[];
    int idx = blockIdx.x; // ranges [0, N)
    int tid = threadIdx.x; // ranges [0, block_size)
    int block_size = blockDim.x;
    const float* x = inp + idx * C; // idx-th row of inp
    // thread coarsening
    float maxval = -INFINITY;
    for (int i = tid; i < C; i += block_size) {
        maxval = fmaxf(maxval, x[i]);
    }
    shared[tid] = maxval;
    // reductions
    for (int stride = block_size / 2; stride >= 1; stride /= 2) {
        __syncthreads();
        if (tid < stride) {
            shared[tid] = fmaxf(shared[tid], shared[tid + stride]);
        }
    }
    __syncthreads();
    float offset = shared[0];
    // compute expf and write the result to global memory
    for (int i = tid; i < C; i += block_size) {
        out[idx * C + i] = expf(x[i] - offset);
    }
    __syncthreads();
    // thread coarsening again, for the sum
    x = out + idx * C; // idx-th row of out
    float sumval = 0.0f;
    for (int i = tid; i < C; i += block_size) {
        sumval += x[i];
    }
    shared[tid] = sumval;
    // reductions
    for (int stride = block_size / 2; stride >= 1; stride /= 2) {
        __syncthreads();
        if (tid < stride) {
            shared[tid] += shared[tid + stride];
        }
    }
    // broadcast the sum to all threads in the block
    __syncthreads();
    float sum = shared[0];
    // divide the input values by the sum
    for (int i = tid; i < C; i += block_size) {
        out[idx * C + i] = x[i] / sum;
    }
}

// warp-level reduction for finding the maximum value
__device__ float warpReduceMax(float val) {
    for (int offset = 16; offset > 0; offset /= 2) {
        val = fmaxf(val, __shfl_down_sync(0xFFFFFFFF, val, offset));
    }
    return val;
}

__global__ void softmax_forward_kernel3(float* out, const float* inp, int N, int C) {
    // kernel must use block size of 32
    extern __shared__ float shared[];
    int idx = blockIdx.x;
    int tid = threadIdx.x;
    const float* x = inp + idx * C;

    // Thread coarsening and within-warp reduction for maxval
    float maxval = -INFINITY;
    for (int i = tid; i < C; i += blockDim.x) {
        maxval = fmaxf(maxval, x[i]);
    }
    maxval = warpReduceMax(maxval);

    // Broadcast maxval within the warp
    float offset = __shfl_sync(0xFFFFFFFF, maxval, 0);

    // Compute expf and write the result to global memory
    for (int i = tid; i < C; i += blockDim.x) {
        out[idx * C + i] = expf(x[i] - offset);
    }

    // Thread coarsening and within-warp reduction for sumval
    x = out + idx * C;
    float sumval = 0.0f;
    for (int i = tid; i < C; i += blockDim.x) {
        sumval += x[i];
    }
    // No need to broadcast sumval since all threads in the warp will have the same value
    // (due to the fact that we're using __shfl_xor_sync)
    sumval = warpReduceSum(sumval);

    // Divide the input values by the sum
    for (int i = tid; i < C; i += blockDim.x) {
        out[idx * C + i] = x[i] / sumval;
    }
}

__global__ void softmax_forward_kernel4(float* out, const float* inp, int N, int C) {
    // out is (N, C) just like inp. Each row of inp will get softmaxed.
    // same as kernel3, but can handle any block size (multiple of 32)
    // each row of C elements is handled by block_size threads
    // furthermore, each block_size threads get executed in warps of 32 threads

    // special reduction operations warpReduceMax/warpReduceSum are used for intra-warp reductions
    // shared memory is used for inter-warp reduction
    extern __shared__ float shared[];
    int idx = blockIdx.x;
    int tid = threadIdx.x;
    int warpId = threadIdx.x / 32; // warp index within a block
    int laneId = threadIdx.x % 32; // thread index within a warp

    // the number of warps per block. recall that blockDim.x is block_size
    int warpsPerBlock = blockDim.x / 32;

    // shared[] must be allocated to have warpsPerBlock elements
    // those will be used for max and sum values
    float* max_or_sum_storage = shared;

    // one row of inp, i.e. inp[idx, :] of shape (C,)
    const float* x = inp + idx * C;

    // first, thread coarsening by directly accessing global memory in series
    float maxval = -INFINITY;
    for (int i = tid; i < C; i += blockDim.x) {
        maxval = fmaxf(maxval, x[i]);
    }
    // now within-warp reductions for maxval
    maxval = warpReduceMax(maxval);

    // the 0th thread of each warp writes the maxval of that warp to shared memory
    if (laneId == 0) max_or_sum_storage[warpId] = maxval;
    __syncthreads();

    // now the 0th thread of the block reduces the max values in shared memory, i.e. across warps
    if (tid == 0) {
        float val = max_or_sum_storage[tid];
        for (int i = 1; i < warpsPerBlock; i++) {
            val = fmaxf(val, max_or_sum_storage[i]);
        }
        // store the final max in the first position
        max_or_sum_storage[0] = val;
    }
    __syncthreads();
    // broadcast the max to all threads
    float offset = max_or_sum_storage[0];

    // compute expf and write the result to global memory
    for (int i = tid; i < C; i += blockDim.x) {
        out[idx * C + i] = expf(x[i] - offset);
    }

    // okay now we calculated exp(x - max(x))
    // step 2: sum all the values and divide by the sum

    // thread coarsening for sum
    x = out + idx * C;
    float sumval = 0.0f;
    for (int i = tid; i < C; i += blockDim.x) {
        sumval += x[i];
    }
    // within-warp reduction for sumval
    sumval = warpReduceSum(sumval);

    // the same shared slots are reused for the sums: every thread must have read the
    // broadcast max from max_or_sum_storage[0] before warp 0 overwrites that slot below
    __syncthreads();

    // write sumval to shared memory
    if (laneId == 0) max_or_sum_storage[warpId] = sumval;
    __syncthreads();

    // inter-thread reduction of sum
    if (tid == 0) {
        float val = max_or_sum_storage[tid];
        for (int i = 1; i < warpsPerBlock; ++i) {
            val += max_or_sum_storage[i];
        }
        max_or_sum_storage[0] = val;
    }
    __syncthreads();
    // broadcast the sum to all threads
    float sum = max_or_sum_storage[0];

    // divide the whole row by the sum
    for (int i = tid; i < C; i += blockDim.x) {
        out[idx * C + i] = x[i] / sum;
    }
}

__global__ void softmax_forward_online_kernel1(float* out, const float* inp, int N, int C) {
    // inp is (N, C)
    // out is (N, C), each row of inp will get softmaxed
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < N) {
        const float* inp_row = inp + i * C;
        float* out_row = out + i * C;

        float maxval = -INFINITY;
        double sum = 0.0;
        for (int j = 0; j < C; j++) {
            float maxval_prev = maxval;
            float current_val = inp_row[j];
            if (current_val > maxval) {
                maxval = current_val;
                sum = sum * expf(maxval_prev - maxval) + expf(current_val - maxval);
            }
            else {
                sum += expf(current_val - maxval);
            }
        }

        for (int j = 0; j < C; j++) {
            out_row[j] = expf(inp_row[j] - maxval) / sum;
        }
    }
}

// struct for the reduction operation, guarantees 8-byte alignment
struct __align__(8) SumMax
{
    float maxval;
    float sum;
};

// forceinline helps avoid function call overhead
// Merges two (maxval, sum) partials, rescaling the smaller one's sum to the larger max.
__device__ __forceinline__ SumMax reduce_sum_max_op(SumMax a, SumMax b) {
    bool a_bigger = (a.maxval > b.maxval);
    SumMax bigger_m = a_bigger ? a : b;
    SumMax smaller_m = a_bigger ? b : a;
    SumMax res;
    res.maxval = bigger_m.maxval;
    res.sum = bigger_m.sum + smaller_m.sum * expf(smaller_m.maxval - bigger_m.maxval);
    return res;
}

__global__ void softmax_forward_online_kernel2(float* out, const float* inp, int N, int C) {
    namespace cg = cooperative_groups;
    cg::thread_block block = cg::this_thread_block();
    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
    // one warp per row
    int idx = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank();
    if (idx >= N) {
        return;
    }

    // one row of inp, i.e. inp[idx, :] of shape (C,)
    const float* x = inp + idx * C;

    // base case for the reduction
    SumMax sm_partial;
    sm_partial.maxval = -INFINITY;
    sm_partial.sum = 0.0f;

    // first, thread coarsening by directly accessing global memory in series
    for (int i = warp.thread_rank(); i < C; i += warp.size()) {
        sm_partial = reduce_sum_max_op(sm_partial, { x[i], 1.0f });
    }

    // second, the reduction
    SumMax sm_total = cg::reduce(warp, sm_partial, reduce_sum_max_op);

    // divide the whole row by the sum
    for (int i = warp.thread_rank(); i < C; i += warp.size()) {
        // the below is equivalent to
        // out[idx * C + i] = expf(x[i] - sm_total.maxval) / sm_total.sum;
        // but uses special instruction that bypasses the cache
        __stcs(out + idx * C + i, expf(x[i] - sm_total.maxval) / sm_total.sum);
    }
}

__global__ void softmax_forward_kernel7(float* out, const float* inp, int N, int C) {
    // out is (N, C) just like inp. Each row of inp will get softmaxed.
    // same as kernel4, but optimised for very large Cs with advanced unrolling

    // The trick is to read into a register array (all indices known at compile time)
    // and always read UNROLL_FACTOR values to maximise memory level parallelism
    // even if we would be out of bounds, we set the index to min(C-1, idx)
    // so we just do some unnecessary reads (obviously bad for small C)
    // the writes are in a separate loop with a conditional check for out of bounds
    // making it separate is necessary to convince the compiler to do the right thing
    const int UNROLL_FACTOR = 8;
    const int warpsPerBlock = blockDim.x / 32;

    extern __shared__ float shared[];
    int idx = blockIdx.x;
    int tid = threadIdx.x;
    int warpId = threadIdx.x / 32; // warp index within a block
    int laneId = threadIdx.x % 32; // thread index within a warp

    // shared[] must be allocated to have 2 * warpsPerBlock elements
    // first half for max values, the second half for sum values
    float* maxvals = shared;
    float* sumvals = &shared[warpsPerBlock];

    // NOTE(review): threads with tid >= C return here, before the __syncthreads() and
    // full-mask warp shuffles below; this looks safe only when C >= blockDim.x — confirm
    // the launcher guarantees that.
    if (tid >= C) {
        maxvals[warpId] = -INFINITY;
        sumvals[warpId] = 0.0f;
        return;
    }

    const float* x = inp + idx * C; // input
    float* y = out + idx * C; // output

    // first, thread coarsening by directly accessing global memory in series
    float maxval = -INFINITY;
    for (int i = tid; i < C; i += blockDim.x * UNROLL_FACTOR) {
        #pragma unroll
        for (int u = 0; u < UNROLL_FACTOR; u++) {
            maxval = fmaxf(maxval, x[min(C - 1, i + u*blockDim.x)]);
        }
    }

    // now within-warp reductions for maxval
    maxval = warpReduceMax(maxval);
    // the 0th thread of each warp writes the maxval of that warp to shared memory
    if (laneId == 0) maxvals[warpId] = maxval;
    __syncthreads();
    // now the 0th thread reduces the maxvals in shared memory, i.e. across warps
    if (tid == 0) {
        float val = maxvals[tid];
        #pragma unroll
        for (int i = 1; i < warpsPerBlock; i++) {
            val = fmaxf(val, maxvals[i]);
        }
        // store the final max in the first position
        maxvals[0] = val;
    }
    __syncthreads();
    // broadcast the max to all threads
    float offset = maxvals[0];

    // compute expf and write the result to global memory
    // + thread coarsening for sum
    float sumval = 0.0f;
    for (int i = tid; i < C; i += blockDim.x * UNROLL_FACTOR) {
        float reg_array[UNROLL_FACTOR];
        #pragma unroll
        for (int u = 0; u < UNROLL_FACTOR; u++) {
            reg_array[u] = __ldcs(&x[min(C - 1, i + u*blockDim.x)]);
        }
        #pragma unroll
        for (int u = 0; u < UNROLL_FACTOR; u++) {
            if (i + u*blockDim.x < C) {
                float output = expf(reg_array[u] - offset);
                y[min(C - 1, i + u*blockDim.x)] = output; // compiler likes redundant min()?!
                sumval += output; // combined into the same loop unlike kernel3
            }
        }
    }

    // okay now we calculated exp(x - max(x))
    // step 2: sum all the values and divide by the sum

    // within-warp reduction for sumval
    sumval = warpReduceSum(sumval);
    // write sumval to shared memory
    if (laneId == 0) sumvals[warpId] = sumval;
    __syncthreads();
    // inter-thread reduction of sum
    if (tid == 0) {
        float val = sumvals[tid];
        #pragma unroll
        for (int i = 1; i < warpsPerBlock; ++i) {
            val += sumvals[i];
        }
        sumvals[0] = val;
    }
    __syncthreads();
    // broadcast the sum to all threads
    float sum = sumvals[0];

    // divide the whole row by the sum
    for (int i = tid; i < C; i += blockDim.x * UNROLL_FACTOR) {
        float reg_array[UNROLL_FACTOR];
        #pragma unroll
        for (int u = 0; u < UNROLL_FACTOR; u++) {
            reg_array[u] = y[min(C - 1, i + u*blockDim.x)];
        }
        #pragma unroll
        for (int u = 0; u < UNROLL_FACTOR; u++) {
            if (i + u*blockDim.x < C) {
                y[i + u*blockDim.x] = reg_array[u] / sum;
            }
        }
    }
}

__global__ void softmax_forward_online_kernel8(float* out, const float* inp, int N, int C) {
    // online softmax paper: http://arxiv.org/abs/1805.02867
    // online softmax reduces loops from 3 to 2
    // which is done by calculating sumval and maxval in one loop
    const int warpsPerBlock = blockDim.x / warpSize;
    int tid = threadIdx.x;

    // NOTE(review): tid is the block-wide thread index, but each warp handles its own
    // row; when C < blockDim.x whole warps exit here without processing their rows.
    // Confirm callers always use C >= blockDim.x.
    if (tid >= C) {
        return;
    }

    int warpId = tid / warpSize;
    int laneId = tid % warpSize;
    // one warp one row
    int row = blockIdx.x * warpsPerBlock + warpId;

    if (row >= N) {
        return;
    }

    const float* x = inp + row * C;
    float* const y = out + row * C;

    // merge calculating maxval and sumval in one loop
    // which is an arithmetic improvement from online softmax over normal softmax
    float maxval = -INFINITY, sumval = 0.0f, bigger;
    for (int i = laneId; i < C; i += warpSize) {
        // when updating the maxval, dynamically updates the previous sumval by
        // multiplying e^{previous_maxval - current_maxval}
        bigger = fmaxf(maxval, x[i]);
        sumval = sumval * expf(maxval - bigger) + expf(x[i] - bigger);
        maxval = bigger;
    }

    // use warp functions instead of cooperative groups for better readability
    // calculate the warp-wide maxval and sumval
    float offsetMaxval, offsetSumval;
    for (int offset = warpSize / 2; offset > 0; offset >>= 1) {
        __syncwarp();
        offsetMaxval = __shfl_down_sync(0xFFFFFFFF, maxval, offset);
        offsetSumval = __shfl_down_sync(0xFFFFFFFF, sumval, offset);
        if (offsetMaxval > maxval) {
            sumval *= expf(maxval - offsetMaxval);
            maxval = offsetMaxval;
        } else {
            offsetSumval *= expf(offsetMaxval - maxval);
        }
        sumval += offsetSumval;
    }

    // broadcast lane 0's maxval and sumval to the whole warp;
    // these are the maxval and sumval of the entire row
    maxval = __shfl_sync(0xFFFFFFFF, maxval, 0);
    sumval = __shfl_sync(0xFFFFFFFF, sumval, 0);

    for (int i = laneId; i < C; i += warpSize) {
        y[i] = expf(x[i] - maxval) / sumval;
    }
}

// ----------------------------------------------------------------------------
// kernel launcher

void softmax_forward1(float* out, const float* inp, int N, int C, const int block_size) {
    const int grid_size = ceil_div(N, block_size);
    softmax_forward_kernel1<<<grid_size, block_size>>>(out, inp, N, C);
    cudaCheck(cudaGetLastError());
}

void softmax_forward2(float* out, const float* inp, int N, int C, const int block_size) {
    int grid_size = N;
    size_t shared_mem_size = block_size * sizeof(float);
    softmax_forward_kernel2<<<grid_size, block_size, shared_mem_size>>>(out, inp, N, C);
    // surface launch errors, consistent with softmax_forward1
    cudaCheck(cudaGetLastError());
}

void softmax_forward3(float* out, const float* inp, int N, int C, int block_size) {
    block_size = 32; // awkward but ok.
this one only works with block size 32 int grid_size = N; size_t shared_mem_size = block_size * sizeof(float); softmax_forward_kernel3<<>>(out, inp, N, C); } void softmax_forward4(float* out, const float* inp, int N, int C, int block_size) { int grid_size = N; // for each warp in the block we need a float that will be used for both maxval and sumval size_t shared_mem_size = block_size / 32 * sizeof(float); softmax_forward_kernel4<<>>(out, inp, N, C); } void softmax_forward_online1(float* out, const float* inp, int N, int C, int block_size) { const int grid_size = ceil_div(N, block_size); softmax_forward_online_kernel1 <<>> (out, inp, N, C); cudaCheck(cudaGetLastError()); } void softmax_forward_online2(float* out, const float* inp, int N, int C, int block_size) { const int grid_size = ceil_div(N * 32, block_size); softmax_forward_online_kernel2 <<>> (out, inp, N, C); cudaCheck(cudaGetLastError()); } void softmax_forward7(float* out, const float* inp, int N, int C, int block_size) { int grid_size = N; size_t shared_mem_size = 2 * block_size / 32 * sizeof(float); softmax_forward_kernel7<<>>(out, inp, N, C); } void softmax_forward_online8(float* out, const float* inp, int N, int C, int block_size) { const int grid_size = ceil_div(N * 32, block_size); softmax_forward_online_kernel8<<>>(out, inp, N, C); cudaCheck(cudaGetLastError()); } // kernel version dispatch void softmax_forward(int kernel_num, float* out, const float* inp, int N, int C, const int block_size) { switch (kernel_num) { case 1: softmax_forward1(out, inp, N, C, block_size); break; case 2: softmax_forward2(out, inp, N, C, block_size); break; case 3: softmax_forward3(out, inp, N, C, block_size); break; case 4: softmax_forward4(out, inp, N, C, block_size); break; case 5: softmax_forward_online1(out, inp, N, C, block_size); break; case 6: softmax_forward_online2(out, inp, N, C, block_size); break; case 7: softmax_forward7(out, inp, N, C, block_size); break; case 8: softmax_forward_online8(out, inp, N, C, 
block_size); break; default: printf("Invalid kernel number\n"); exit(1); } } // ---------------------------------------------------------------------------- int main(int argc, char **argv) { srand(0); int B = 8; int T = 1024; int V = 50257; int deviceIdx = 0; cudaCheck(cudaSetDevice(deviceIdx)); // create host memory of random numbers float* out = (float*)malloc(B * T * V * sizeof(float)); float* inp = make_random_float(B * T * V); // make the input less uniformly random: Otherwise, all probabilities will be basically zero, // and the tests are not actually meaningful. const int* outliers = make_random_int(B * T * 3, V); for(int k = 0; k < 3; ++k) { for(int j = 0; j < B * T; ++j) { inp[j * V + outliers[j*3 + k]] *= 20; } } // move to GPU float* d_out; float* d_inp; cudaCheck(cudaMalloc(&d_out, B * T * V * sizeof(float))); cudaCheck(cudaMalloc(&d_inp, B * T * V * sizeof(float))); cudaCheck(cudaMemcpy(d_inp, inp, B * T * V * sizeof(float), cudaMemcpyHostToDevice)); // read kernel_num from command line int kernel_num = 1; if (argc > 1) { kernel_num = atoi(argv[1]); } printf("Using kernel %d\n", kernel_num); int block_sizes[] = {32, 64, 128, 256, 512, 1024}; softmax_forward_cpu(out, inp, B * T, V); { float max_el = -INFINITY; for(int i = 0; i < B * T * V; ++i) { max_el = max(max_el, out[i]); } assert(max_el > 1e-4); printf("Largest output is: %f\n", max_el); } // first check the correctness of the kernel for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) { int block_size = block_sizes[j]; printf("Checking block size %d.\n", block_size); softmax_forward(kernel_num, d_out, d_inp, B * T, V, block_size); validate_result(d_out, out, "out", B * T * V, 1e-4f); } printf("All results match. 
Starting benchmarks.\n\n"); // time the kernel at different block sizes for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) { int block_size = block_sizes[j]; int repeat_times = 100; float elapsed_time = benchmark_kernel(repeat_times, softmax_forward, kernel_num, d_out, d_inp, B * T, V, block_size ); printf("block_size %4d | time %.4f ms | per token %.2f µs\n", block_size, elapsed_time, elapsed_time * 1'000 / (B*T)); } // free memory free(out); free(inp); free((void*)outliers); cudaCheck(cudaFree(d_out)); cudaCheck(cudaFree(d_inp)); return 0; } ================================================ FILE: dev/cuda/trimat_forward.cu ================================================ /* Triangular matrix multiplication as in autoregressive attention. A short story. by @ngc92 Compile: nvcc -O3 --use_fast_math -lcublas -lcublasLt trimat_forward.cu -o trimat_forward -lcublas Run: cuBLAS baseline kernel ./trimat_forward 0 naive ./trimat_forward 1 registers ./trimat_forward 2 tri3 ./trimat_forward 3 tri4 ./trimat_forward 4 */ #include #include #include #include #include #include #include #include #include "common.h" static float* d_qkvr; // scratch for the cublas kernel /* ** Chapter I - Introduction ** * * You are Trimul. You've always wanted to do fast matrix multiplication, but they said * "Don't bother, big dumb Cublas is much faster than you!" * "I don't need to be faster than Cublas", you replied, "I can be smarter. Cublas calculates * the entire matrix, but I need only half. If I'm more than half as fast as Cublas, I'm * going to win." * * So to prove everyone wrong, you enter the TriMatlon, the most prestigious competition * for anyone paying Attention. * * Before you start preparing, lets have a look at the players involved * * First, there is the Referee (`trimul_cpu`), slow and ponderous, but producing results * beyond reproof. * Then, there is Cublas. 
Cublas' mind is so inflexible, it doesn't actually comprehend
 * what we are trying to do here, so Cublas has brought an assistant (`permute_kernel`)
 * that translates the competition into a task that it can solve. But once it recognizes
 * the problem, its muscle memory kicks in, and matrix products are produced faster than
 * the eye can see. Stuck in its routine, Cublas doesn't realize the task is already
 * finished with the lower triangle, though.
 *
 * If you can do without an assistant, and can solve the right task, then that's your opportunity
 * to shine!
 */


// CPU reference for the causal (lower-triangular) scaled Q @ K^T product.
// Taken from the attention forward pass.
// inp: (B, T, 3, NH, HS) packed QKV; out: (B, NH, T, T).
// out[b][nh][t][t2] = (q_t . k_t2) / sqrt(HS) for t2 <= t; entries above the diagonal are NAN.
void trimul_cpu(float* out, const float* inp,
                int B, int T, int C, int NH) {
    // inp shape: (B, T, 3, NH, HS)
    // out shape: (B, NH, T, T)
    int C3 = C*3;
    int HS = C / NH; // head size
    float scale = 1.0 / sqrtf(HS);

    for (int b = 0; b < B; b++) {
        for (int t = 0; t < T; t++) {
            for (int nh = 0; nh < NH; nh++) {
                // Q[b][nh][t][:] = inp[b][t][0][nh][:] (where : is the slice operator for hs)
                const float* query_t = inp + b * T * C3 + t * C3 + nh * HS;
                // out[b][nh][t][:]
                float* out_bth = out + b * NH * T * T + nh * T * T + t * T;

                // scaled query-key dot products for all keys at or before t
                for (int t2 = 0; t2 <= t; t2++) {
                    // K[b][nh][t2][:] = inp[b][t2][1][nh][:]
                    const float* key_t2 = inp + b * T * C3 + t2 * C3 + nh * HS + C; // +C because it's key

                    // Q[b][nh][t][:] dot K[b][nh][t2][:]
                    float val = 0.0f;
                    for (int i = 0; i < HS; i++) {
                        val += query_t[i] * key_t2[i];
                    }
                    val *= scale;

                    // out[b][nh][t][t2] = val
                    out_bth[t2] = val;
                }
                for (int t2 = t + 1; t2 < T; ++t2) {
                    // causal mask, using NAN to suppress warnings -> it could be -inf
                    // but it doesn't matter because in validate_result we ignore infinities/NANs
                    out_bth[t2] = NAN;
                }
            }
        }
    }
}

// Split the packed QKV tensor into three contiguous tensors.
// inp: (B, T, 3, NH, HS) -> q, k, v: each (B, NH, T, HS). One thread per output element.
__global__ void permute_kernel(float* q, float* k, float* v,
                               const float* inp,
                               int B, int T, int NH, int HS) {
    // okay so now, this kernel wants Q,K,V to all be of shape (B, NH, T, HS)
    // but instead, we have a single tensor QKV (inp) of shape (B, T, 3, NH, HS)
    int idx = blockIdx.x * blockDim.x + threadIdx.x;

    // Q[b][nh][t][hs] = inp[b][t][0][nh][hs]
    if (idx < B * NH * T * HS) {
        int b = idx / (NH * T * HS);
        int rest = idx % (NH * T * HS);
        int nh = rest / (T * HS);
        rest = rest % (T * HS);
        int t = rest / HS;
        int hs = rest % HS;

        int inp_idx = (b * T * 3 * NH * HS)
                    + (t * 3 * NH * HS)
                    + (0 * NH * HS)
                    + (nh * HS)
                    + hs;

        q[idx] = inp[inp_idx];
        k[idx] = inp[inp_idx + NH * HS];
        v[idx] = inp[inp_idx + 2 * (NH * HS)];
    }
}

// cuBLAS baseline: permute QKV into the d_qkvr scratch buffer, then compute the full
// (not just lower-triangular) preatt = Q @ K^T / sqrt(HS) with one strided batched SGEMM.
// preatt: (B, NH, T, T). Uses the globals d_qkvr and cublas_handle.
void trimul_cublas(float* preatt,
                   const float* inp,
                   int B, int T, int C, int NH) {
    int HS = C / NH; // head size

    // permute and separate inp from (B, T, 3, NH, HS) to 3X (B, NH, T, HS)
    float* q, * k, * v;
    q = d_qkvr + 0 * B * T * C;
    k = d_qkvr + 1 * B * T * C;
    v = d_qkvr + 2 * B * T * C;
    const int block_size = 256;
    int total_threads = B * NH * T * HS;
    int num_blocks = ceil_div(total_threads, block_size);
    // the extracted source had lost this launch configuration (`<<>>`), which does not compile
    permute_kernel<<<num_blocks, block_size>>>(q, k, v, inp, B, T, NH, HS);
    cudaCheck(cudaGetLastError());

    // batched matrix multiply with cuBLAS
    const float alpha = 1.0f / sqrtf(HS);
    const float beta = 0.0f;
    // This schedules in parallel B*NH matmuls of shape q@k^t = (T, HS) @ (HS, T) = (T, T).
    // IMPORTANT NOTE: Cublas uses a column-major (and we use row-major in our codebase) representation,
    // so this call might look confusing to you if you look at the `cublasSgemmStridedBatched` signature.
    //
    // In order to avoid having to do an additional transpose operation after this func call,
    // we need to pass in K as the first argument and Q as the second argument, which might make you think we're computing K^T @ Q.
    // That combined with the shapes we got after the permute kernel - (B, NH, T, HS) (I'll omit B, NH for brevity going forward)
    // and you might think we end up with (HS, T) @ (T, HS) = (HS, HS).
    // This is not the case. :)
    //
    // Cublas sees our row-major matrix (T, HS) as (HS, T), hence we set the lead dimensions to HS (see function signature).
    // We transpose K and end up computing K^T @ Q = (T, HS) @ (HS, T) = (T, T).
    // If you were to interpret the above formula K^T @ Q you might think we end up with:
    // -----------------------------------
    // k1.dot(q1) k1.dot(q2) ... k1.dot(qT)
    // k2.dot(q1) k2.dot(q2) ... k2.dot(qT)
    // ...
    // kT.dot(q1) kT.dot(q2) ... kT.dot(qT)
    // -----------------------------------
    // But as I mentioned, Cublas is column-major!
    // So given that the dot product is symmetric we can write k1.dot(q1) as q1.dot(k1) and transposing the above
    // representation we can see what we actually end up with in the row-major format:
    // -----------------------------------
    // q1.dot(k1) q1.dot(k2) ... q1.dot(kT)
    // q2.dot(k1) q2.dot(k2) ... q2.dot(kT)
    // ...
    // qT.dot(k1) qT.dot(k2) ... qT.dot(kT)
    // -----------------------------------
    // which is exactly what we wanted! :)
    cublasCheck(cublasSgemmStridedBatched(cublas_handle,
                                          CUBLAS_OP_T, CUBLAS_OP_N,
                                          T, T, HS,
                                          &alpha,
                                          k, HS, T * HS,
                                          q, HS, T * HS,
                                          &beta,
                                          preatt, T, T * T,
                                          B * NH));
}

/*
** Chapter II - Getting a Team **
 *
 * OK, you've registered for the competition, now what to do. TriMatlon is a team competition, so first, you need
 * to figure out what kind of team you need, and how to organize it. The individual instances and heads of the
 * problem are completely independent, so you can just send separate teams to work there completely independently.
 *
 * To figure out how to organize each team, you take out your spyglass (`Nsight Compute`) and look at how the Cublas teams
 * are handling their work.
 * Turns out, you need 256 athletes in each group, and those handle 128 x 128 of the tasks. They work together in
 * a tight square formation, 16 wide and 16 deep.
 *
 * So, you went out and got your 100 000 friends, and split them into groups (`trimul_launcher`). Each group gets
 * informed about where they should work (`trimul_global`) and goes off to do their thing (`matmul_tri_naive`).
 * Let's observe how we're doing.
*/

// using creates an alias for a function pointer
using matmul_fn_ptr = void(*)(float* p, int PS, const float* k, int KS, const float* q, int QS, int T, int HS, float alpha);

// One thread block per 128x128 output tile of one (batch, head) pair:
//   blockIdx.x selects the 128-row (query) tile, blockIdx.y the 128-column (key) tile,
//   blockIdx.z = b * NH + nh.
// The per-tile work is done by the compile-time selected device function `matmul_tri`.
template<matmul_fn_ptr matmul_tri>
__global__ void __launch_bounds__(256, 2) trimul_global(float* out, const float* inp, int T, int C, int NH) {
    // skip above the diagonal
    if (blockIdx.y > blockIdx.x)
        return;

    // set up indices
    int C3 = C*3;
    int HS = C / NH; // head size
    float scale = 1.0 / sqrtf(HS);

    // we put the "batch x head" dimension into the z block index.
    int b = blockIdx.z / NH;
    int nh = blockIdx.z % NH;

    // Get the base address for the current batch and head
    // shapes -> inp (B, T, 3, NH, HS), Q (B, NH, T, HS), K (B, NH, T, HS)
    const float* q = inp + b * T * C3 + nh * HS;      // Q[b][nh][:][:] = inp[b][:][0][nh][:]
    const float* k = inp + b * T * C3 + nh * HS + C;  // K[b][nh][:][:] = inp[b][:][1][nh][:]
    float* r = out + (b*NH + nh)*T*T;                 // out[b][nh][:][:]

    // start the multiplication (q and k rows are strided by C3 because Q/K/V stay interleaved in inp)
    matmul_tri(r, T, k, C3, q, C3, T, HS, scale);
}

// Launch `matmul_tri` over all lower-triangular 128x128 tiles of all B*NH score matrices.
// The template argument list and launch configuration below were lost in extraction (`<<>>`);
// they are restored from the indexing above: 16x16 = 256 threads per block (matching
// __launch_bounds__(256, 2)), a (T/128) x (T/128) tile grid, and B*NH in z.
template<matmul_fn_ptr matmul_tri>
void trimul_launcher(float* out, const float* inp, int B, int T, int C, int NH) {
    // we assume nice shapes here. Let's not make the code a mess by supporting weird shapes that you
    // wouldn't want to use anyway.
    assert(T % 128 == 0);
    // No need to ceil_div, if it's not a multiple of 128, we would get wrong results anyway.
    trimul_global<matmul_tri><<<dim3(T / 128, T / 128, NH * B), dim3(16, 16)>>>(out, inp, T, C, NH);
    cudaCheck(cudaGetLastError());
}

/*
** Chapter III - ... **
 *
 * You go over to the playing field. On one end of the field, there is a huge pile of funnily shaped cookie cutters.
 * Some in the shape of animals, some in the shape of a landscape. Each group of workers has assigned some runners,
 * fetching the cookie cutters for them. The workers seem very relaxed, chatting with each other, lounging about.
 * You focus in on one of them.
 *
 * He seems to be giving an instruction to a runner, and then turns back to reading a novel. The runner, meanwhile,
 * crosses the field and back, handing him an elephant shape.
Then she's off again to pick up a savannah background.
 * Having received the two shapes, he presses them into the dough and makes an elephant-in-the-savannah cookie. He hands
 * the cutters back to the runner. "Can you please fetch me an elephant and a jungle next?"
 * While she's on her way, he takes a sip of his cocktail.
 * This time, she's making only one trip, keeping the elephant in her pocket (_Cache_). Still, it seems to take forever.
 * You keep observing:
 *  - Elephant and zoo
 *  - Elephant and island
 *  ...
 *  - Lion and savannah
 *  - Lion and jungle
 *  - Lion and zoo
 *  ...
 *
 * The worker has his poor runner fetch the same things over and over again, looking like she's about to faint from exhaustion.
 * Even though she realizes this and always keeps one of them in her pocket, there is so much running,
 * and little actual work happening.
 *
 * Clearly, this isn't going to be effective, so you call a team meeting.
 */

// baseline implementation: 20 ms
//
// Every thread owns one 8x8 tile of p: rows (queries) grow downwards, columns (keys) to the right.
// A block covers a 128x128 tile, so a thread's tile starts at (128*blockIdx + 8*threadIdx) in each axis.
// Each of the 64 outputs is an independent HS-long dot product read straight from global memory.
__device__ void matmul_tri_naive(float* p, int PS, const float* k, int KS, const float* q, int QS, int T, int HS, float alpha) {
    const int row0 = 128 * blockIdx.x + 8 * threadIdx.x;
    const int col0 = 128 * blockIdx.y + 8 * threadIdx.y;

    // Tiles lying entirely above the diagonal have nothing to do. Tiles straddling the diagonal are
    // computed in full on purpose: the extra upper-triangle values are ignored by validate_result,
    // because the CPU reference stores NAN there.
    if (col0 > row0) {
        return;
    }

    for (int r = row0; r < row0 + 8; ++r) {
        const float* q_row = q + r * QS;
        for (int c = col0; c < col0 + 8; ++c) {
            const float* k_row = k + c * KS;
            float acc = 0;
            for (int d = 0; d < HS; ++d) {
                acc += q_row[d] * k_row[d];
            }
            p[r * PS + c] = acc * alpha;
        }
    }
}

/*
** Chapter IV - ... **
 *
 * Each worker is producing 64 combined cookies from 8 animals and 8 landscapes. They send their runners 64 times
 * to fetch the corresponding shapes. This is terribly inefficient; The runners need a minute or so for each trip,
 * but making a cookie can be done in just a second.
 *
 * "Let's try something different tomorrow: Just get all 16 cookie cutters that you need, and do all 64 combinations
 * of them! See all this free space on your workbench (_registers_), you can keep them all there for easy access."
 *
 * The next morning, you come back to the field for another practice session. Initially, there is bustling activity
 * with the runners, picking up 16 shapes for each worker. But then, the workers have to put down their newspapers
 * and start making cookies. Now there are 64 combinations, so it takes them a full minute.
 *
 * Not all groups of workers are equally fast. When the first group finishes with all animal-landscape combinations,
 * they already start asking the runners for the next set of cookie cutters, combining plants and houses. Even though
 * the workers are much busier than before, they are still spending most of their time just waiting.
 *
 * Still, instead of being busy for 20 hours, your team is now done with the task in just 3h 30 minutes; already, this
 * is five times faster.
 *
 * You think to yourself: "Why should we stop at 8 x 8 combinations? Let's do 16 x 16, that's only twice the work for
 * the runners, but four times as much for the actual workers."
 * You head over to the baking area, and make that suggestion to one of your team leaders.
 * "In theory, that sounds great", she agrees, "but see, we only have limited space on our workbenches (_registers_).
* There is still some room left, but we simply cannot bake 256 cookies at the same time, sorry."
 *
 * A different strategy is needed, then.
 */

// reorganize loops to enable data reuse: 3.5 ms
//
// Same 8x8-tile-per-thread layout as matmul_tri_naive, but the loop over the head dimension is outermost:
// for each d, the 8 query values and 8 key values are staged in registers once, and all 64 products are
// accumulated from them. Global loads drop from 2*64 to 2*8 per d.
__device__ void matmul_tri_registers(float* p, int PS, const float* k, int KS, const float* q, int QS, int T, int HS, float alpha) {
    const int row0 = 128 * blockIdx.x + 8 * threadIdx.x;
    const int col0 = 128 * blockIdx.y + 8 * threadIdx.y;
    if (col0 > row0) {
        return;  // tile lies entirely above the diagonal
    }

    // first query row, first key row and top-left output element of this thread's tile
    const float* q_tile = q + row0 * QS;
    const float* k_tile = k + col0 * KS;
    float* p_tile = p + row0 * PS + col0;

    float acc[8][8] = {};
    for (int d = 0; d < HS; ++d) {
        float qv[8];
        float kv[8];
        for (int u = 0; u < 8; ++u) {
            qv[u] = q_tile[u * QS + d];
            kv[u] = k_tile[u * KS + d];
        }
        for (int r = 0; r < 8; ++r) {
            for (int c = 0; c < 8; ++c) {
                acc[r][c] += qv[r] * kv[c];
            }
        }
    }

    for (int r = 0; r < 8; ++r) {
        for (int c = 0; c < 8; ++c) {
            p_tile[r * PS + c] = acc[r][c] * alpha;
        }
    }
}

/*
** Chapter IV - By the Bucketload **
 *
 * Despite the hectic activity, you pick out one of the runners. "Why are you always bringing just one shape? Wouldn't
 * it be much more efficient if you took more than one?"
 * "Of course", the runner answers, "but they've asked me for an elephant, a lion, a zebra, and a goldfish. These
 * are all over the place, I can't just pick them up at one spot (_strided access_).
 * "But the lion is right next to the palm tree. You could bring those two together?", you confirm.
 * "Yes", he says, "if they just asked for the different categories at the same time, that would make things
 * so much easier. See, I have this bucket, I could carry lots of things in one go if I could just scoop them up
 * from the same place (_coalesced access_).
 *
 * OK, then let's fetch the first animal, first plant, first vehicle, and first landmark shape in one go (_vectorized load_).
* [Here, the metaphor breaks down a bit: Since we're accumulating all the results, getting more data at the same time
 * depth-wise doesn't require more space on the workbench. We're stacking the cookies!]
 *
 * You also streamline the shape combination further. Instead of picking up all animals and landscapes at once, it is
 * more efficient, using less workbench space, to just pick up all animals. Then, you get one landscape, combine it
 * with all animals, get the next landscape, combine, and so on.
 *
 * In this way, instead of 2 x 8 x 4 cookie cutters that take up space, you only need (8+1) x 4 at the same time.
 *
 * With these optimizations, you are down to 100 minutes for this task. Still slower than Cublas, but not by much.
 *
 * In the arena, each team also has access to a small storage hut, much closer to their workbenches than the piles of
 * cookie cutters on the other side. Cublas is using them heavily, so maybe you should, too.
 */

// convenient helper functions to make the code below more readable.
// The reinterpret_cast target types had been stripped in extraction and are restored here.
// `address` must be 16-byte aligned for a float4 access (presumably guaranteed here because
// HS and the row strides are multiples of 4 — TODO confirm for other shapes).
__device__ float4 ld_vec(const float* address) {
    return *reinterpret_cast<const float4*>(address);
}

__device__ void st_vec(float* address, float4 val) {
    *reinterpret_cast<float4*>(address) = val;
}

// vector instructions for coalesced memory access: 1.7 ms
// Requires HS to be a multiple of 4 (the hs loop steps by 4 with float4 loads).
__device__ void matmul_tri3(float* p, int PS, const float* k, int KS, const float* q, int QS, int T, int HS, float alpha) {
    // Same logic as previous kernel we just load in float4 to improve coalescing
    int i_base = 128 * blockIdx.x + 8 * threadIdx.x;
    int j_base = 128 * blockIdx.y + 8 * threadIdx.y;
    if (j_base > i_base)
        return;

    // shift our pointers to the sub-block this thread is responsible for
    q += i_base * QS;
    k += j_base * KS;
    p += i_base * PS + j_base;

    float vals[8][8] = {};
    for (int hs = 0; hs < HS; hs += 4) {
        // load in float4 to improve coalescing
        float4 rhs[8];
        for (int u = 0; u < 8; ++u) {
            rhs[u] = ld_vec(k + u * KS + hs);
        }

        for (int i = 0; i < 8; ++i) {
            // no need to keep lhs around for the i loop, it's only reused in the j loop anyway.
            float4 lhs = ld_vec(q + i * QS + hs);
            for (int j = 0; j < 8; ++j) {
                vals[i][j] += lhs.x * rhs[j].x;
                vals[i][j] += lhs.y * rhs[j].y;
                vals[i][j] += lhs.z * rhs[j].z;
                vals[i][j] += lhs.w * rhs[j].w;
            }
        }
    }

    for (int i = 0; i < 8; ++i) {
        for (int j = 0; j < 8; j += 4) {
            float4 result;
            result.x = vals[i][j + 0] * alpha;
            result.y = vals[i][j + 1] * alpha;
            result.z = vals[i][j + 2] * alpha;
            result.w = vals[i][j + 3] * alpha;
            st_vec(p + i * PS + j, result);
        }
    }
}

/*
** Chapter V - Sharing is Caring **
 *
 * You take a look around the shed, and see that there are 32 shelves there. They are much larger than the workbenches,
 * giving you enough space for all the cookie cutters needed by the entire team.
 *
 * Within the team, workers have banded together in groups of 32. They are always doing the same thing, reducing the
 * amount of effort required for coordination. However, that also means that if you send them all to pick up different
 * cookie cutters from the same shelf, they will have to wait and queue up (_shared memory bank conflict_).
 *
 * In order to achieve maximum efficiency, we send the runners fetching cutters with the maximum bucket size: 32 different
 * categories at the same time.
 *
 * [I'm having trouble getting the specifics into the story in a sensible way. For now, please read the code for more
 * details.]
 *
 */
// Shared-memory tiled version. The whole block cooperatively stages 32-wide slices of 128 query rows and
// 128 key rows in shared memory, then each thread accumulates its 8x8 tile from there.
// Requires HS to be a multiple of 32 (the `so` loop steps by 32 with no tail handling).
__device__ void matmul_tri4(float* p, int PS, const float* k, int KS, const float* q, int QS, int T, int HS, float alpha) {
    int i_base = 128 * blockIdx.x + 8 * threadIdx.x;
    int j_base = 128 * blockIdx.y + 8 * threadIdx.y;

    // we need all threads for loading data, so none of them can chicken out early, even
    // if they are not responsible for any useful result. (This test is uniform across the block.)
    if (blockIdx.y > blockIdx.x)
        return;

    q += 128 * blockIdx.x * QS;
    k += 128 * blockIdx.y * KS;

    __shared__ float lhs_s[128][32];
    __shared__ float rhs_s[128][32];

    float vals[8][8] = {};
    for (int so = 0; so < HS; so += 32) {
        // Read a large slice of the input, worked on together by all threads.
        // They are organized differently for this part. We want to ensure
        // fully coalesced loads, so we let a single warp handle consecutive
        // addresses, which means we need to combine two threadIdx.y values
        // in one read operation.
        // note: threads may read data here that they don't need themselves.
        //       this really is a block-level operation.
        // note2: 16x16 threads (i.e. the block) will, through this for loop, fetch 32 dims from 128 keys and 128 queries
        // i.e. from Q/K, of shape (T, HS) take q[:128, so:so+32] and k[:128, so:so+32]
        // (`so` already advances in steps of 32, so it is the column offset itself)
        __syncthreads();
        for (int y = threadIdx.y / 2; y < 128; y += 8) {
            int xo = (threadIdx.y % 2) * 16;
            lhs_s[y][threadIdx.x + xo] = q[y * QS + so + threadIdx.x + xo];
            rhs_s[y][threadIdx.x + xo] = k[y * KS + so + threadIdx.x + xo];
        }
        __syncthreads();

        // Now we compute a partial dot product (only 32 dims) for all combinations of keys and queries (128x128).
        // Each thread does 8x8 of these partial dot products.
        // E.g. thread (0,0) covers queries 0-7 and keys 0-7. More generally first row of threads
        // (0,:) covers queries 0-7 with keys 0-127 and so on.
        // In the next iterations of the outer (`so`) loop we'll be accumulating values to `vals` until we
        // get the full dot product. We then later deposit it into the output matrix for all 8x8 blocks
        // that are below the diagonal.
        // The column index is rotated by threadIdx.x: rows are 32 floats wide, so the column picks the
        // shared-memory bank, and the rotation spreads a warp's reads over different banks.
        for (int si = 0; si < 32; ++si) {
            float rhs[8];
            for (int u = 0; u < 8; ++u) {
                rhs[u] = rhs_s[u + 8 * threadIdx.y][(si + threadIdx.x) % 32];
            }

            for (int ii = 0; ii < 8; ++ii) {
                float lhs = lhs_s[ii + 8 * threadIdx.x][(si + threadIdx.x) % 32];
                for (int ji = 0; ji < 8; ++ji) {
                    vals[ii][ji] += lhs * rhs[ji];
                }
            }
        }
    }

    // don't write above the diagonal
    if (j_base > i_base)
        return;

    for (int ii = 0; ii < 8; ++ii) {
        for (int ji = 0; ji < 8; ji += 4) {
            int i = i_base + ii;
            int j = j_base + ji;
            float4 result;
            result.x = vals[ii][ji + 0] * alpha;
            result.y = vals[ii][ji + 1] * alpha;
            result.z = vals[ii][ji + 2] * alpha;
            result.w = vals[ii][ji + 3] * alpha;
            st_vec(p + i * PS + j, result);
        }
    }
}

/*
** Chapter VI - Competition Day **
 *
 * Finally, you feel ready to take on Cublas. You hand out tickets to the event for your friends to see.
 *
 * ---------------------------------------------------------------------------------
 * |       CuBLAS vs TriMul - Fight of the Century                                 |
 * |                                                                               |
 * |   Ticket code:                                                                |
 * |   > nvcc -O3 --use_fast_math trimat_forward.cu -o trimat_forward -lcublas     |
 * |   > ./trimat_forward 4                                                        |
 * |                                                                               |
 * ---------------------------------------------------------------------------------
*/

// Dispatch to kernel `kernel_num` (0 = cuBLAS baseline, 1-4 = the custom variants above).
// The template arguments of trimul_launcher had been stripped in extraction and are restored here.
void trimul_gpu(int kernel_num,
                float* out, const float* inp,
                int B, int T, int C, int NH) {
    switch (kernel_num) {
        case 0:
            trimul_cublas(out, inp, B, T, C, NH);
            break;
        case 1:
            trimul_launcher<matmul_tri_naive>(out, inp, B, T, C, NH);
            break;
        case 2:
            trimul_launcher<matmul_tri_registers>(out, inp, B, T, C, NH);
            break;
        case 3:
            trimul_launcher<matmul_tri3>(out, inp, B, T, C, NH);
            break;
        case 4:
            trimul_launcher<matmul_tri4>(out, inp, B, T, C, NH);
            break;
        default:
            printf("Invalid kernel number\n");
            exit(1);
    }
}

// Validate the selected kernel against trimul_cpu, then time it against the cuBLAS baseline.
int main(int argc, char **argv) {
    setup_main();

    int B = 8;
    int T = 1024;
    int C = 768;
    int NH = 12;

    // create host memory of random numbers
    float* out = (float*)malloc(B * NH * T * T * sizeof(float));
    float* inp = make_random_float(B * T * 3 * C);

    // move to GPU
    float* d_out;
    float* d_inp;
    cudaCheck(cudaMalloc(&d_out, B * NH * T * T * sizeof(float)));
    cudaCheck(cudaMalloc(&d_inp, B * T * 3 * C * sizeof(float)));
    cudaCheck(cudaMemcpy(d_inp, inp, B * T * 3 * C * sizeof(float), cudaMemcpyHostToDevice));

    // buffer for cublas (holds the permuted q, k, v)
    cudaCheck(cudaMalloc(&d_qkvr, B * T * 3 * C * sizeof(float)));

    // read kernel_num from command line
    int kernel_num = 1;
    if (argc > 1) {
        kernel_num = atoi(argv[1]);
    }
    printf("Using kernel %d\n", kernel_num);

    // first check the correctness of the kernel
    trimul_cpu(out, inp, B, T, C, NH);
    trimul_gpu(kernel_num, d_out, d_inp, B, T, C, NH);
    validate_result(d_out, out, "out", B * NH * T * T, 1e-4f);

    printf("All results match. Starting benchmarks.\n\n");

    // benchmark speed of the kernel
    int repeat_times = 100;
    float elapsed_time = benchmark_kernel(repeat_times, trimul_gpu,
                                          kernel_num, d_out, d_inp,
                                          B, T, C, NH);

    float cublas_time = benchmark_kernel(repeat_times, trimul_gpu,
                                         0, d_out, d_inp,
                                         B, T, C, NH);

    printf("time %.2f ms vs %.2f ms for CuBLAS\n", elapsed_time, cublas_time);

    // free memory
    free(out);
    free(inp);
    cudaCheck(cudaFree(d_out));
    cudaCheck(cudaFree(d_inp));
    cudaCheck(cudaFree(d_qkvr));
    cublasDestroy(cublas_handle);

    return 0;
}
================================================
FILE: dev/data/README.md
================================================
# dev/data organization

The idea is that each dataset has a .py file here in the root of `dev/data`, and each dataset then creates a directory here, and writes and caches anything inside that directory. So for example:

- running `python tinystories.py` will create a directory `tinystories` with its .bin files inside it
- running `python tinyshakespeare.py` will create a directory `tinyshakespeare` with its .bin files inside it

And so on. This way we can nicely organize multiple datasets here, share common utilities between them, and then point the .py/.c code in the root of the project to them accordingly.

Note: we support "gpt-2" and "llama" (llama 3 in particular) models and the above scripts will tokenize gpt-2 by default.
================================================ FILE: dev/data/data_common.py ================================================ """ Common utilities for the datasets """ import requests from tqdm import tqdm import numpy as np def download_file(url: str, fname: str, chunk_size=1024): """Helper function to download a file from a given url""" resp = requests.get(url, stream=True) total = int(resp.headers.get("content-length", 0)) with open(fname, "wb") as file, tqdm( desc=fname, total=total, unit="iB", unit_scale=True, unit_divisor=1024, ) as bar: for data in resp.iter_content(chunk_size=chunk_size): size = file.write(data) bar.update(size) HEADERS_INFO = { "gpt-2": { "magic": 20240520, "version": 1, "token_dtype": np.uint16, }, "llama-3": { "magic": 20240801, "version": 7, "token_dtype": np.uint32, }, } def write_datafile(filename, toks, model_desc="gpt-2"): """ Saves token data as a .bin file, for reading in C. - First comes a header with 256 int32s - The tokens follow, each as uint16 (gpt-2) or uint32 (llama) """ assert len(toks) < 2**31, "token count too large" # ~2.1B tokens assert model_desc in ["gpt-2", "llama-3"], f"unknown model descriptor {model_desc}" info = HEADERS_INFO[model_desc] # construct the header header = np.zeros(256, dtype=np.int32) # header is always 256 int32 values header[0] = info["magic"] header[1] = info["version"] header[2] = len(toks) # number of tokens after the 256*4 bytes of header # construct the data (numpy array of tokens) toks_np = np.array(toks, dtype=info["token_dtype"]) # write to file num_bytes = (256 * 4) + (len(toks) * toks_np.itemsize) print(f"writing {len(toks):,} tokens to {filename} ({num_bytes:,} bytes) in the {model_desc} format") with open(filename, "wb") as f: f.write(header.tobytes()) f.write(toks_np.tobytes()) def write_evalfile(filename, datas): """ Saves eval data as a .bin file, for reading in C. Used for multiple-choice style evals, e.g. 
HellaSwag and MMLU - First comes a header with 256 int32s - The examples follow, each example is a stream of uint16_t: - delimiter of 2**16-1, i.e. 65,535 - , bytes encoding this example, allowing efficient skip to next - , the index of the example in the dataset -